cocos-engine-external/sources/enoki/autodiff.h

1570 lines
54 KiB
C++

/*
enoki/autodiff.h -- Reverse mode automatic differentiation
Enoki is a C++ template library that enables transparent vectorization
of numerical kernels using SIMD instruction sets available on current
processor architectures.
Copyright (c) 2019 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a BSD-style
license that can be found in the LICENSE file.
*/
#pragma once
#include <enoki/array.h>
#include <vector>
#define ENOKI_AUTODIFF_H 1
NAMESPACE_BEGIN(enoki)
/// Tape used by the reverse mode automatic differentiation layer (see file
/// header). Everything except the destructor is private; the befriended
/// DiffArray template is the only client declared here. Only declarations
/// are visible in this header section -- the definitions live elsewhere.
template <typename Type> struct Tape {
private:
// DiffArray<T> is granted access to the private interface below
template <typename T> friend struct DiffArray;
// Implementation types; only forward-declared here
struct Detail;
struct Node;
struct Edge;
struct Special;
struct SimplificationLock;
// Identifier type used for tape entries (DiffArray passes 0 for "no node")
using Index = uint32_t;
// Mask and 64-bit integer array types matching `Type`
using Mask = mask_t<Type>;
using Int64 = int64_array_t<Type>;
// Private constructor; instances are obtained through get()
Tape();
// -----------------------------------------------------------------------
//! @{ \name Append unary/binary/ternary operations to the tape
// -----------------------------------------------------------------------
// Unary operation: one input index `i1` with weight `w1`
Index append(const char *label, size_t size, Index i1, const Type &w1);
// Binary operation: inputs `i1`, `i2` with weights `w1`, `w2`
Index append(const char *label, size_t size, Index i1, Index i2,
const Type &w1, const Type &w2);
// Ternary operation: inputs `i1`..`i3` with weights `w1`..`w3`
Index append(const char *label, size_t size, Index i1, Index i2, Index i3,
const Type &w1, const Type &w2, const Type &w3);
// Specialized entries used by DiffArray::psum_ / reverse_ / gather_ /
// scatter_ / scatter_add_ (`scatter_add` selects between the latter two)
Index append_psum(Index i);
Index append_reverse(Index i);
Index append_gather(const Int64 &offset, const Mask &mask);
void append_scatter(Index index, const Int64 &offset, const Mask &mask,
bool scatter_add);
//! @}
// -----------------------------------------------------------------------
// -----------------------------------------------------------------------
//! @{ \name Append nodes and edges to the tape
// -----------------------------------------------------------------------
Index append_node(size_t size, const char *label);
// Used by DiffArray::set_requires_gradient_ to create a new variable
Index append_leaf(size_t size);
void append_edge(Index src, Index dst, const Type &weight);
// NOTE(review): presumably an edge whose weight is weight1 * weight2 --
// confirm against the implementation
void append_edge_prod(Index src, Index dst, const Type &weight1,
const Type &weight2);
//! @}
// -----------------------------------------------------------------------
// -----------------------------------------------------------------------
//! @{ \name Reference counting
// -----------------------------------------------------------------------
// "ext" variants are called from DiffArray constructors, assignment
// operators and the destructor; the "int" variants additionally take the
// index of the referring entry (`from`)
void dec_ref_ext(Index index);
void inc_ref_ext(Index index);
void dec_ref_int(Index index, Index from);
void inc_ref_int(Index index, Index from);
void free_node(Index index);
//! @}
// -----------------------------------------------------------------------
// -----------------------------------------------------------------------
//! @{ \name Other operations
// -----------------------------------------------------------------------
void set_scatter_gather_operand(Index *index, size_t size, bool permute);
void push_prefix(const char *);
void pop_prefix();
// Traversal entry points; the overloads taking `index` start from a
// specific entry. `free_graph` is interpreted by the implementation.
void backward(bool free_graph);
void forward(bool free_graph);
void backward(Index index, bool free_graph);
void forward(Index index, bool free_graph);
// Gradient storage access for entry `index`
void set_gradient(Index index, const Type &value,
bool backward = true);
void set_label(Index index, const char *name);
const Type &gradient(Index index);
// Returns a string for the given entries (GraphViz output, per the name)
std::string graphviz(const std::vector<Index> &indices);
/// Current log level (0 == none, 1 == minimal, 2 == moderate, 3 == high, 4 == everything)
void set_log_level(uint32_t);
uint32_t log_level() const;
void set_graph_simplification(bool);
void simplify_graph();
std::string whos() const;
static void cuda_callback(void*);
//! @}
// -----------------------------------------------------------------------
// Accessor for the tape instance associated with `Type`
static Tape* get() ENOKI_PURE;
public:
// Public so that the std::unique_ptr below can destroy the instance
~Tape();
private:
// Owning pointer to the tape instance (one static member per `Type`)
static std::unique_ptr<Tape> s_tape;
// Pointer to the implementation object (struct Detail)
Detail *d;
};
template <typename Type>
struct DiffArray : ArrayBase<value_t<Type>, DiffArray<Type>> {
public:
// CRTP base class shared with the other Enoki array types
using Base = enoki::ArrayBase<value_t<Type>, DiffArray<Type>>;
using typename Base::Scalar;
// Tape specialization that records operations for this array type
using Tape = enoki::Tape<Type>;
// Identifier of a tape entry; the methods below pass 0 when nothing is recorded
using Index = uint32_t;
// The wrapped (non-differentiable) value type
using UnderlyingType = Type;
using ArrayType = DiffArray;
// Result type of comparisons (see eq_, lt_, ...)
using MaskType = DiffArray<mask_t<Type>>;
// Scalars are treated as arrays of size/depth 1
static constexpr size_t Size = is_scalar_v<Type> ? 1 : array_size_v<Type>;
static constexpr size_t Depth = is_scalar_v<Type> ? 1 : array_depth_v<Type>;
static constexpr bool IsMask = is_mask_v<Type>;
static constexpr bool IsCUDA = is_cuda_array_v<Type>;
static constexpr bool IsDiff = true;
// Derivatives are only tracked for floating point, non-mask value types;
// every tape interaction below is guarded by `if constexpr (Enabled)`
static constexpr bool Enabled =
std::is_floating_point_v<scalar_t<Type>> && !is_mask_v<Type>;
// Same DiffArray wrapper, with the underlying scalar type replaced by T
template <typename T>
using ReplaceValue = DiffArray<replace_scalar_t<Type, T, false>>;
static_assert(array_depth_v<Type> <= 1,
"DiffArray requires a scalar or (non-nested) static or "
"dynamic Enoki array as template parameter.");
// -----------------------------------------------------------------------
//! @{ \name Constructors / destructors
// -----------------------------------------------------------------------
/// Default constructor (compiler-generated).
DiffArray() = default;

/// Destructor: drops this variable's external reference on its tape entry.
/// Instantiations that do not track derivatives never touch the tape.
~DiffArray() {
    if constexpr (Enabled) {
        tape()->dec_ref_ext(m_index);
    }
}
/// Copy constructor: copies the value and the tape index of `a` and, for
/// differentiable types, registers one more external reference to that index.
DiffArray(const DiffArray &a)
    : m_value(a.m_value),
      m_index(a.m_index) {
    if constexpr (Enabled) {
        tape()->inc_ref_ext(m_index);
    }
}
/// Move constructor: steals the value of `a`. For differentiable types the
/// tape index is transferred as well and `a` is left with index 0, so no
/// reference count update is needed.
DiffArray(DiffArray &&a)
    : m_value(std::move(a.m_value)) {
    if constexpr (Enabled) {
        m_index = std::exchange(a.m_index, 0);
    }
}
/// Reinterpreting constructor: builds the value from the value of `v` using
/// the reinterpret tag. No derivative information is carried over.
template <typename T>
DiffArray(const DiffArray<T> &v, detail::reinterpret_flag)
    : m_value(v.value_(), detail::reinterpret_flag()) {
    /* no derivatives */
}
/// Converting copy constructor from a DiffArray over a different underlying
/// type; only the value is converted, the tape index is not copied.
template <typename Type2, enable_if_t<!std::is_same_v<Type, Type2>> = 0>
DiffArray(const DiffArray<Type2> &a)
    : m_value(a.value_()) { }

/// Converting move constructor from a DiffArray over a different underlying
/// type; only the value is moved, the tape index is not copied.
template <typename Type2, enable_if_t<!std::is_same_v<Type, Type2>> = 0>
DiffArray(DiffArray<Type2> &&a)
    : m_value(std::move(a.value_())) { }

/// Wraps an rvalue of the underlying type.
DiffArray(Type &&value)
    : m_value(std::move(value)) { }

/// Forwards one or more non-DiffArray arguments to the constructor of the
/// underlying value.
template <typename... Args,
          enable_if_t<sizeof...(Args) != 0 &&
                      std::conjunction_v<
                          std::negation<is_diff_array<Args>>...>> = 0>
DiffArray(Args&&... args)
    : m_value(std::forward<Args>(args)...) { }
/// Copy assignment. The reference on the incoming index is acquired before
/// the reference on the current index is released, which keeps
/// self-assignment safe.
DiffArray &operator=(const DiffArray &a) {
    m_value = a.m_value;
    if constexpr (Enabled) {
        auto tp = tape();
        tp->inc_ref_ext(a.m_index);
        tp->dec_ref_ext(m_index);
        m_index = a.m_index;
    }
    return *this;
}
/// Move assignment: moves the value and exchanges the tape indices, so the
/// previous index of `*this` ends up in `a`.
DiffArray &operator=(DiffArray &&a) {
    m_value = std::move(a.m_value);
    if constexpr (Enabled) {
        std::swap(m_index, a.m_index);
    }
    return *this;
}
//! @}
// -----------------------------------------------------------------------
// -----------------------------------------------------------------------
//! @{ \name Vertical operations
// -----------------------------------------------------------------------
/// Addition `*this + a`; records an "add" entry with weights (1, 1).
/// Unsupported for masks and pointer-valued arrays.
DiffArray add_(const DiffArray &a) const {
    if constexpr (!is_mask_v<Type> && !std::is_pointer_v<Scalar>) {
        Type sum = m_value + a.m_value;
        Index node = 0;
        if constexpr (Enabled) {
            node = tape()->append("add", slices(sum), m_index, a.m_index,
                                  1.f, 1.f);
        }
        return DiffArray::create(node, std::move(sum));
    } else {
        fail_unsupported("add_");
    }
}
/// Subtraction `*this - a`; records a "sub" entry with weights (1, -1).
/// Unsupported for masks and pointer-valued arrays.
DiffArray sub_(const DiffArray &a) const {
    if constexpr (!is_mask_v<Type> && !std::is_pointer_v<Scalar>) {
        Type diff = m_value - a.m_value;
        Index node = 0;
        if constexpr (Enabled) {
            node = tape()->append("sub", slices(diff), m_index, a.m_index,
                                  1.f, -1.f);
        }
        return DiffArray::create(node, std::move(diff));
    } else {
        fail_unsupported("sub_");
    }
}
/// Multiplication `*this * a`; records a "mul" entry whose weights are the
/// respective other factor. Unsupported for masks and pointer-valued arrays.
DiffArray mul_(const DiffArray &a) const {
    if constexpr (!is_mask_v<Type> && !std::is_pointer_v<Scalar>) {
        Type product = m_value * a.m_value;
        Index node = 0;
        if constexpr (Enabled) {
            node = tape()->append("mul", slices(product), m_index, a.m_index,
                                  a.m_value, m_value);
        }
        return DiffArray::create(node, std::move(product));
    } else {
        fail_unsupported("mul_");
    }
}
/// Division `*this / a`; records a "div" entry with weights
/// (1/a, -this/a^2). Unsupported for masks and pointer-valued arrays.
DiffArray div_(const DiffArray &a) const {
    if constexpr (!is_mask_v<Type> && !std::is_pointer_v<Scalar>) {
        Type quotient = m_value / a.m_value;
        Index node = 0;
        if constexpr (Enabled) {
            Type inv_a = rcp(a.m_value);
            node = tape()->append("div", slices(quotient),
                                  m_index, a.m_index,
                                  inv_a, -m_value * sqr(inv_a));
        }
        return DiffArray::create(node, std::move(quotient));
    } else {
        fail_unsupported("div_");
    }
}
/// Fused multiply-add `this * a + b`; records an "fmadd" entry with weights
/// (a, this, 1). Unsupported for masks.
DiffArray fmadd_(const DiffArray &a, const DiffArray &b) const {
    if constexpr (!is_mask_v<Type>) {
        Type value = fmadd(m_value, a.m_value, b.m_value);
        Index node = 0;
        if constexpr (Enabled) {
            node = tape()->append("fmadd", slices(value),
                                  m_index, a.m_index, b.m_index,
                                  a.m_value, m_value, 1);
        }
        return DiffArray::create(node, std::move(value));
    } else {
        fail_unsupported("fmadd_");
    }
}
/// Fused multiply-subtract `this * a - b`; records an "fmsub" entry with
/// weights (a, this, -1). Unsupported for masks.
DiffArray fmsub_(const DiffArray &a, const DiffArray &b) const {
    if constexpr (!is_mask_v<Type>) {
        Type value = fmsub(m_value, a.m_value, b.m_value);
        Index node = 0;
        if constexpr (Enabled) {
            node = tape()->append("fmsub", slices(value),
                                  m_index, a.m_index, b.m_index,
                                  a.m_value, m_value, -1);
        }
        return DiffArray::create(node, std::move(value));
    } else {
        fail_unsupported("fmsub_");
    }
}
/// Fused negative multiply-add `-(this * a) + b`; records an "fnmadd" entry
/// with weights (-a, -this, 1). Unsupported for masks.
DiffArray fnmadd_(const DiffArray &a, const DiffArray &b) const {
    if constexpr (!is_mask_v<Type>) {
        Type value = fnmadd(m_value, a.m_value, b.m_value);
        Index node = 0;
        if constexpr (Enabled) {
            node = tape()->append("fnmadd", slices(value),
                                  m_index, a.m_index, b.m_index,
                                  -a.m_value, -m_value, 1);
        }
        return DiffArray::create(node, std::move(value));
    } else {
        fail_unsupported("fnmadd_");
    }
}
/// Fused negative multiply-subtract `-(this * a) - b`; records an "fnmsub"
/// entry with weights (-a, -this, -1). Unsupported for masks.
DiffArray fnmsub_(const DiffArray &a, const DiffArray &b) const {
    if constexpr (!is_mask_v<Type>) {
        Type value = fnmsub(m_value, a.m_value, b.m_value);
        Index node = 0;
        if constexpr (Enabled) {
            node = tape()->append("fnmsub", slices(value),
                                  m_index, a.m_index, b.m_index,
                                  -a.m_value, -m_value, -1);
        }
        return DiffArray::create(node, std::move(value));
    } else {
        fail_unsupported("fnmsub_");
    }
}
/// Negation; records a "neg" entry with weight -1.
/// Unsupported for masks and pointer-valued arrays.
DiffArray neg_() const {
    if constexpr (!is_mask_v<Type> && !std::is_pointer_v<Scalar>) {
        Index node = 0;
        if constexpr (Enabled) {
            node = tape()->append("neg", slices(m_value), m_index, -1.f);
        }
        return DiffArray::create(node, -m_value);
    } else {
        fail_unsupported("neg_");
    }
}
/// Absolute value; records an "abs" entry weighted by sign(value).
/// Unsupported for masks and pointer-valued arrays.
DiffArray abs_() const {
    if constexpr (!is_mask_v<Type> && !std::is_pointer_v<Scalar>) {
        Index node = 0;
        if constexpr (Enabled) {
            node = tape()->append("abs", slices(m_value), m_index,
                                  sign(m_value));
        }
        return DiffArray::create(node, abs(m_value));
    } else {
        fail_unsupported("abs_");
    }
}
/// Square root; records a "sqrt" entry with weight 0.5 / sqrt(x).
/// Only supported for floating point, non-mask types.
DiffArray sqrt_() const {
    if constexpr (!is_mask_v<Type> && std::is_floating_point_v<Scalar>) {
        Type root = sqrt(m_value);
        Index node = 0;
        if constexpr (Enabled) {
            node = tape()->append("sqrt", slices(root), m_index,
                                  .5f / root);
        }
        return DiffArray::create(node, std::move(root));
    } else {
        fail_unsupported("sqrt_");
    }
}
/// Cube root; records a "cbrt" entry with weight 1 / (3 * cbrt(x)^2).
/// Only supported for floating point, non-mask types.
DiffArray cbrt_() const {
    if constexpr (!is_mask_v<Type> && std::is_floating_point_v<Scalar>) {
        Type root = cbrt(m_value);
        Index node = 0;
        if constexpr (Enabled) {
            node = tape()->append("cbrt", slices(root), m_index,
                                  1.f / (3 * sqr(root)));
        }
        return DiffArray::create(node, std::move(root));
    } else {
        fail_unsupported("cbrt_");
    }
}
/// Reciprocal; records an "rcp" entry with weight -(1/x)^2.
/// Only supported for floating point, non-mask types.
DiffArray rcp_() const {
    if constexpr (!is_mask_v<Type> && std::is_floating_point_v<Scalar>) {
        Type inv = rcp(m_value);
        Index node = 0;
        if constexpr (Enabled) {
            node = tape()->append("rcp", slices(inv), m_index,
                                  -sqr(inv));
        }
        return DiffArray::create(node, std::move(inv));
    } else {
        fail_unsupported("rcp_");
    }
}
/// Reciprocal square root; records an "rsqrt" entry with weight
/// -0.5 * rsqrt(x)^3. Only supported for floating point, non-mask types.
DiffArray rsqrt_() const {
    if constexpr (!is_mask_v<Type> && std::is_floating_point_v<Scalar>) {
        Type inv_root = rsqrt(m_value);
        Index node = 0;
        if constexpr (Enabled) {
            Type inv_root_2 = sqr(inv_root);
            Type inv_root_3 = inv_root * inv_root_2;
            node = tape()->append("rsqrt", slices(inv_root), m_index,
                                  -.5f * inv_root_3);
        }
        return DiffArray::create(node, std::move(inv_root));
    } else {
        fail_unsupported("rsqrt_");
    }
}
/// Minimum of `*this` and `a`. The recorded "min" entry assigns weight 1 to
/// `*this` where it is strictly smaller, and weight 1 to `a` elsewhere.
/// Unsupported for masks.
DiffArray min_(const DiffArray &a) const {
    if constexpr (!is_mask_v<Type>) {
        Type smaller = min(m_value, a.m_value);
        Index node = 0;
        if constexpr (Enabled) {
            mask_t<Type> self_wins = m_value < a.m_value;
            node = tape()->append("min", slices(smaller),
                                  m_index, a.m_index,
                                  select(self_wins, Type(1), Type(0)),
                                  select(self_wins, Type(0), Type(1)));
        }
        return DiffArray::create(node, std::move(smaller));
    } else {
        fail_unsupported("min_");
    }
}
/// Maximum of `*this` and `a`. The recorded "max" entry assigns weight 1 to
/// `*this` where it is strictly larger, and weight 1 to `a` elsewhere.
/// Unsupported for masks.
DiffArray max_(const DiffArray &a) const {
    if constexpr (!is_mask_v<Type>) {
        Type larger = max(m_value, a.m_value);
        Index node = 0;
        if constexpr (Enabled) {
            mask_t<Type> self_wins = m_value > a.m_value;
            node = tape()->append("max", slices(larger),
                                  m_index, a.m_index,
                                  select(self_wins, Type(1), Type(0)),
                                  select(self_wins, Type(0), Type(1)));
        }
        return DiffArray::create(node, std::move(larger));
    } else {
        fail_unsupported("max_");
    }
}
/// Mask-based selection between `t` (mask true) and `f` (mask false). The
/// recorded "select" entry uses the mask converted to 1/0 as weight for `t`
/// and its complement as weight for `f`.
static DiffArray select_(const DiffArray<mask_t<Type>> &m,
                         const DiffArray &t,
                         const DiffArray &f) {
    Type chosen = select(m.value_(), t.m_value, f.m_value);
    Index node = 0;
    if constexpr (Enabled) {
        node = tape()->append("select", slices(chosen),
                              t.m_index, f.m_index,
                              select(m.value_(), Type(1), Type(0)),
                              select(m.value_(), Type(0), Type(1)));
    }
    return DiffArray::create(node, std::move(chosen));
}
/// Rounds down; the result is created with index 0 (nothing is recorded).
/// Only supported for floating point, non-mask types.
DiffArray floor_() const {
    if constexpr (!is_mask_v<Type> && std::is_floating_point_v<Scalar>) {
        return DiffArray::create(0, floor(m_value));
    } else {
        fail_unsupported("floor_");
    }
}
/// Rounds up; the result is created with index 0 (nothing is recorded).
/// Only supported for floating point, non-mask types.
DiffArray ceil_() const {
    if constexpr (!is_mask_v<Type> && std::is_floating_point_v<Scalar>) {
        return DiffArray::create(0, ceil(m_value));
    } else {
        fail_unsupported("ceil_");
    }
}
/// Truncates; the result is created with index 0 (nothing is recorded).
/// Only supported for floating point, non-mask types.
DiffArray trunc_() const {
    if constexpr (!is_mask_v<Type> && std::is_floating_point_v<Scalar>) {
        return DiffArray::create(0, trunc(m_value));
    } else {
        fail_unsupported("trunc_");
    }
}
/// Rounds; the result is created with index 0 (nothing is recorded).
/// Only supported for floating point, non-mask types.
DiffArray round_() const {
    if constexpr (!is_mask_v<Type> && std::is_floating_point_v<Scalar>) {
        return DiffArray::create(0, round(m_value));
    } else {
        fail_unsupported("round_");
    }
}
template <typename T> T ceil2int_() const {
return T(ceil2int<typename T::UnderlyingType>(m_value));
}
template <typename T> T floor2int_() const {
return T(floor2int<typename T::UnderlyingType>(m_value));
}
/// Sine; records a "sin" entry weighted by the cosine of the value.
/// Only supported for floating point, non-mask types.
DiffArray sin_() const {
    if constexpr (!is_mask_v<Type> && std::is_floating_point_v<Scalar>) {
        auto [sin_x, cos_x] = sincos(m_value);
        Index node = 0;
        if constexpr (Enabled) {
            node = tape()->append("sin", slices(m_value), m_index, cos_x);
        }
        return DiffArray::create(node, std::move(sin_x));
    } else {
        fail_unsupported("sin_");
    }
}
/// Cosine; records a "cos" entry weighted by the negated sine of the value.
/// Only supported for floating point, non-mask types.
DiffArray cos_() const {
    if constexpr (!is_mask_v<Type> && std::is_floating_point_v<Scalar>) {
        auto [sin_x, cos_x] = sincos(m_value);
        Index node = 0;
        if constexpr (Enabled) {
            node = tape()->append("cos", slices(m_value), m_index, -sin_x);
        }
        return DiffArray::create(node, std::move(cos_x));
    } else {
        fail_unsupported("cos_");
    }
}
std::pair<DiffArray, DiffArray> sincos_() const {
if constexpr (is_mask_v<Type> || !std::is_floating_point_v<Scalar>) {
fail_unsupported("sincos_");
} else {
Index index_new_s = 0, index_new_c = 0;
auto [s, c] = sincos(m_value);
if constexpr (Enabled) {
index_new_s = tape()->append("sin", slices(m_value), m_index, c);
index_new_c = tape()->append("cos", slices(m_value), m_index, -s);
}
return {
DiffArray::create(index_new_s, std::move(s)),
DiffArray::create(index_new_c, std::move(c))
};
}
}
/// Tangent; records a "tan" entry with weight sec(x)^2.
/// Only supported for floating point, non-mask types.
DiffArray tan_() const {
    if constexpr (!is_mask_v<Type> && std::is_floating_point_v<Scalar>) {
        Index node = 0;
        if constexpr (Enabled) {
            node = tape()->append("tan", slices(m_value), m_index,
                                  sqr(sec(m_value)));
        }
        return DiffArray::create(node, tan(m_value));
    } else {
        fail_unsupported("tan_");
    }
}
/// Cosecant; records a "csc" entry with weight -csc(x) * cot(x).
/// Only supported for floating point, non-mask types.
DiffArray csc_() const {
    if constexpr (!is_mask_v<Type> && std::is_floating_point_v<Scalar>) {
        Type cosecant = csc(m_value);
        Index node = 0;
        if constexpr (Enabled) {
            node = tape()->append("csc", slices(m_value), m_index,
                                  -cosecant * cot(m_value));
        }
        return DiffArray::create(node, std::move(cosecant));
    } else {
        fail_unsupported("csc_");
    }
}
/// Secant; records a "sec" entry with weight sec(x) * tan(x).
/// Only supported for floating point, non-mask types.
DiffArray sec_() const {
    if constexpr (!is_mask_v<Type> && std::is_floating_point_v<Scalar>) {
        Type secant = sec(m_value);
        Index node = 0;
        if constexpr (Enabled) {
            node = tape()->append("sec", slices(m_value), m_index,
                                  secant * tan(m_value));
        }
        return DiffArray::create(node, std::move(secant));
    } else {
        fail_unsupported("sec_");
    }
}
/// Cotangent; records a "cot" entry with weight -csc(x)^2.
/// Only supported for floating point, non-mask types.
DiffArray cot_() const {
    if constexpr (!is_mask_v<Type> && std::is_floating_point_v<Scalar>) {
        Index node = 0;
        if constexpr (Enabled) {
            node = tape()->append("cot", slices(m_value), m_index,
                                  -sqr(csc(m_value)));
        }
        return DiffArray::create(node, cot(m_value));
    } else {
        fail_unsupported("cot_");
    }
}
/// Arcsine; records an "asin" entry with weight 1 / sqrt(1 - x^2).
/// Only supported for floating point, non-mask types.
DiffArray asin_() const {
    if constexpr (!is_mask_v<Type> && std::is_floating_point_v<Scalar>) {
        Index node = 0;
        if constexpr (Enabled) {
            node = tape()->append("asin", slices(m_value), m_index,
                                  rsqrt(1 - sqr(m_value)));
        }
        return DiffArray::create(node, asin(m_value));
    } else {
        fail_unsupported("asin_");
    }
}
/// Arccosine; records an "acos" entry with weight -1 / sqrt(1 - x^2).
/// Only supported for floating point, non-mask types.
DiffArray acos_() const {
    if constexpr (!is_mask_v<Type> && std::is_floating_point_v<Scalar>) {
        Index node = 0;
        if constexpr (Enabled) {
            node = tape()->append("acos", slices(m_value), m_index,
                                  -rsqrt(1 - sqr(m_value)));
        }
        return DiffArray::create(node, acos(m_value));
    } else {
        fail_unsupported("acos_");
    }
}
/// Arctangent; records an "atan" entry with weight 1 / (1 + x^2).
/// Only supported for floating point, non-mask types.
DiffArray atan_() const {
    if constexpr (!is_mask_v<Type> && std::is_floating_point_v<Scalar>) {
        Index node = 0;
        if constexpr (Enabled) {
            node = tape()->append("atan", slices(m_value), m_index,
                                  rcp(1 + sqr(m_value)));
        }
        return DiffArray::create(node, atan(m_value));
    } else {
        fail_unsupported("atan_");
    }
}
/// Two-argument arctangent atan2(this, x). With n = 1 / (this^2 + x^2), the
/// recorded "atan2" entry uses weights (n * x, -n * this).
/// Only supported for floating point, non-mask types.
DiffArray atan2_(const DiffArray &x) const {
    if constexpr (!is_mask_v<Type> && std::is_floating_point_v<Scalar>) {
        Index node = 0;
        if constexpr (Enabled) {
            Type inv_norm = rcp(sqr(m_value) + sqr(x.m_value));
            node = tape()->append("atan2", slices(inv_norm),
                                  m_index, x.m_index,
                                  inv_norm * x.m_value, -inv_norm * m_value);
        }
        return DiffArray::create(node, atan2(m_value, x.m_value));
    } else {
        fail_unsupported("atan2_");
    }
}
/// Hyperbolic sine; records a "sinh" entry weighted by cosh(x).
/// Only supported for floating point, non-mask types.
DiffArray sinh_() const {
    if constexpr (!is_mask_v<Type> && std::is_floating_point_v<Scalar>) {
        auto [sinh_x, cosh_x] = sincosh(m_value);
        Index node = 0;
        if constexpr (Enabled) {
            node = tape()->append("sinh", slices(m_value), m_index, cosh_x);
        }
        return DiffArray::create(node, std::move(sinh_x));
    } else {
        fail_unsupported("sinh_");
    }
}
/// Hyperbolic cosine; records a "cosh" entry weighted by sinh(x).
/// Only supported for floating point, non-mask types.
DiffArray cosh_() const {
    if constexpr (!is_mask_v<Type> && std::is_floating_point_v<Scalar>) {
        auto [sinh_x, cosh_x] = sincosh(m_value);
        Index node = 0;
        if constexpr (Enabled) {
            node = tape()->append("cosh", slices(m_value), m_index, sinh_x);
        }
        return DiffArray::create(node, std::move(cosh_x));
    } else {
        fail_unsupported("cosh_");
    }
}
/// Hyperbolic cosecant; records a "csch" entry with weight
/// -csch(x) * coth(x). Only supported for floating point, non-mask types.
DiffArray csch_() const {
    if constexpr (!is_mask_v<Type> && std::is_floating_point_v<Scalar>) {
        Type value = csch(m_value);
        Index node = 0;
        if constexpr (Enabled) {
            node = tape()->append("csch", slices(m_value), m_index,
                                  -value * coth(m_value));
        }
        return DiffArray::create(node, std::move(value));
    } else {
        fail_unsupported("csch_");
    }
}
/// Hyperbolic secant; records a "sech" entry with weight
/// -sech(x) * tanh(x). Only supported for floating point, non-mask types.
DiffArray sech_() const {
    if constexpr (!is_mask_v<Type> && std::is_floating_point_v<Scalar>) {
        Type value = sech(m_value);
        Index node = 0;
        if constexpr (Enabled) {
            node = tape()->append("sech", slices(m_value), m_index,
                                  -value * tanh(m_value));
        }
        return DiffArray::create(node, std::move(value));
    } else {
        fail_unsupported("sech_");
    }
}
/// Hyperbolic tangent; records a "tanh" entry with weight sech(x)^2.
/// Only supported for floating point, non-mask types.
DiffArray tanh_() const {
    if constexpr (is_mask_v<Type> || !std::is_floating_point_v<Scalar>) {
        fail_unsupported("tanh_");
    } else {
        Index index_new = 0;
        Type result = tanh(m_value);
        if constexpr (Enabled) {
            // The label names the operation, like every other method here
            // (previously this entry was mislabelled "index").
            index_new = tape()->append("tanh", slices(m_value), m_index,
                                       sqr(sech(m_value)));
        }
        return DiffArray::create(index_new, std::move(result));
    }
}
/// Inverse hyperbolic sine; records an "asinh" entry with weight
/// 1 / sqrt(1 + x^2). Only supported for floating point, non-mask types.
DiffArray asinh_() const {
    if constexpr (!is_mask_v<Type> && std::is_floating_point_v<Scalar>) {
        Index node = 0;
        if constexpr (Enabled) {
            node = tape()->append("asinh", slices(m_value), m_index,
                                  rsqrt((Scalar) 1 + sqr(m_value)));
        }
        return DiffArray::create(node, asinh(m_value));
    } else {
        fail_unsupported("asinh_");
    }
}
/// Inverse hyperbolic cosine; records an "acosh" entry with weight
/// 1 / sqrt(x^2 - 1). Only supported for floating point, non-mask types.
DiffArray acosh_() const {
    if constexpr (!is_mask_v<Type> && std::is_floating_point_v<Scalar>) {
        Index node = 0;
        if constexpr (Enabled) {
            node = tape()->append("acosh", slices(m_value), m_index,
                                  rsqrt(sqr(m_value) - (Scalar) 1));
        }
        return DiffArray::create(node, acosh(m_value));
    } else {
        fail_unsupported("acosh_");
    }
}
/// Inverse hyperbolic tangent; records an "atanh" entry with weight
/// 1 / (1 - x^2). Only supported for floating point, non-mask types.
DiffArray atanh_() const {
    if constexpr (!is_mask_v<Type> && std::is_floating_point_v<Scalar>) {
        Index node = 0;
        if constexpr (Enabled) {
            node = tape()->append("atanh", slices(m_value), m_index,
                                  rcp((Scalar) 1 - sqr(m_value)));
        }
        return DiffArray::create(node, atanh(m_value));
    } else {
        fail_unsupported("atanh_");
    }
}
/// Exponential; records an "exp" entry weighted by the result itself.
/// Only supported for floating point, non-mask types.
DiffArray exp_() const {
    if constexpr (!is_mask_v<Type> && std::is_floating_point_v<Scalar>) {
        Type value = exp(m_value);
        Index node = 0;
        if constexpr (Enabled) {
            node = tape()->append("exp", slices(m_value), m_index, value);
        }
        return DiffArray::create(node, std::move(value));
    } else {
        fail_unsupported("exp_");
    }
}
/// Natural logarithm; records a "log" entry with weight 1 / x.
/// Only supported for floating point, non-mask types.
DiffArray log_() const {
    if constexpr (!is_mask_v<Type> && std::is_floating_point_v<Scalar>) {
        Index node = 0;
        if constexpr (Enabled) {
            node = tape()->append("log", slices(m_value), m_index,
                                  rcp(m_value));
        }
        return DiffArray::create(node, log(m_value));
    } else {
        fail_unsupported("log_");
    }
}
/// Bitwise OR with another DiffArray of the same type. Supported for masks
/// and integral arrays; nothing is recorded on the tape.
DiffArray or_(const DiffArray &m) const {
    if constexpr (is_mask_v<Type> || std::is_integral_v<Scalar>) {
        return DiffArray::create(0, m_value | m.value_());
    } else {
        fail_unsupported("or_");
    }
}
/// Bitwise OR with a value of another type. When derivatives are tracked
/// and the argument is a mask, an "or" entry with weight 1 is recorded.
template <typename Mask>
DiffArray or_(const Mask &m) const {
    Index node = 0;
    if constexpr (Enabled && is_mask_v<Mask>) {
        node = tape()->append("or", slices(m_value), m_index, 1);
    }
    return DiffArray::create(node, m_value | m.value_());
}
/// Bitwise AND with another DiffArray of the same type. Supported for masks
/// and integral arrays; nothing is recorded on the tape.
DiffArray and_(const DiffArray &m) const {
    if constexpr (is_mask_v<Type> || std::is_integral_v<Scalar>) {
        return DiffArray::create(0, m_value & m.value_());
    } else {
        fail_unsupported("and_");
    }
}
/// Bitwise AND with a value of another type. When derivatives are tracked
/// and the argument is a mask, an "and" entry is recorded whose weight is
/// the mask converted to 1/0.
template <typename Mask>
DiffArray and_(const Mask &m) const {
    Index node = 0;
    if constexpr (Enabled && is_mask_v<Mask>) {
        node = tape()->append("and", slices(m_value), m_index,
                              select(m.value_(), Type(1), Type(0)));
    }
    return DiffArray::create(node, m_value & m.value_());
}
/// Bitwise XOR with another DiffArray of the same type. Supported for masks
/// and integral arrays; nothing is recorded on the tape.
DiffArray xor_(const DiffArray &m) const {
    if constexpr (is_mask_v<Type> || std::is_integral_v<Scalar>) {
        return DiffArray::create(0, m_value ^ m.value_());
    } else {
        fail_unsupported("xor_");
    }
}
/// Bitwise XOR with a value of another type. Fails at runtime when this
/// variable has a nonzero tape index, since no gradient is implemented.
template <typename Mask>
DiffArray xor_(const Mask &m) const {
    const bool tracked = Enabled && m_index != 0;
    if (tracked) {
        fail_unsupported("xor_ -- gradients are not implemented.");
    }
    return DiffArray(m_value ^ m.value_());
}
/// andnot() with another DiffArray of the same type. Supported for masks
/// and integral arrays; nothing is recorded on the tape.
DiffArray andnot_(const DiffArray &m) const {
    if constexpr (is_mask_v<Type> || std::is_integral_v<Scalar>) {
        return DiffArray::create(0, andnot(m_value, m.value_()));
    } else {
        fail_unsupported("andnot_");
    }
}
/// andnot() with a value of another type. Fails at runtime when this
/// variable has a nonzero tape index, since no gradient is implemented.
template <typename Mask>
DiffArray andnot_(const Mask &m) const {
    const bool tracked = Enabled && m_index != 0;
    if (tracked) {
        fail_unsupported("andnot_ -- gradients are not implemented.");
    }
    return DiffArray(andnot(m_value, m.value_()));
}
//! @}
// -----------------------------------------------------------------------
// -----------------------------------------------------------------------
//! @{ \name Operations that don't require derivatives
// -----------------------------------------------------------------------
/// Remainder `*this % a`; only supported for integral scalar types.
DiffArray mod_(const DiffArray &a) const {
    if constexpr (std::is_integral_v<Scalar>) {
        return m_value % a.m_value;
    } else {
        fail_unsupported("mod_");
    }
}
/// Forwards to mulhi() on the underlying values; only supported for
/// integral scalar types.
DiffArray mulhi_(const DiffArray &a) const {
    if constexpr (std::is_integral_v<Scalar>) {
        return mulhi(m_value, a.m_value);
    } else {
        fail_unsupported("mulhi_");
    }
}
/// Bitwise NOT. Supported for masks and integral arrays, but never for
/// pointer-valued arrays; nothing is recorded on the tape.
DiffArray not_() const {
    if constexpr ((is_mask_v<Type> || std::is_integral_v<Scalar>) &&
                  !std::is_pointer_v<Scalar>) {
        return DiffArray::create(0, ~m_value);
    } else {
        fail_unsupported("not_");
    }
}
/// Forwards to extract() with the given mask. Unsupported for masks and
/// for types that track derivatives.
template <typename Mask>
ENOKI_INLINE value_t<Type> extract_(const Mask &mask) const {
    if constexpr (!is_mask_v<Type> && !Enabled) {
        return extract(m_value, mask.value_());
    } else {
        fail_unsupported("extract_");
    }
}
/// Forwards to lzcnt(). Supported for masks and integral arrays, but never
/// for pointer-valued arrays; nothing is recorded on the tape.
DiffArray lzcnt_() const {
    if constexpr ((is_mask_v<Type> || std::is_integral_v<Scalar>) &&
                  !std::is_pointer_v<Scalar>) {
        return DiffArray::create(0, lzcnt(m_value));
    } else {
        fail_unsupported("lzcnt_");
    }
}
/// Forwards to tzcnt(). Supported for masks and integral arrays, but never
/// for pointer-valued arrays; nothing is recorded on the tape.
DiffArray tzcnt_() const {
    if constexpr ((is_mask_v<Type> || std::is_integral_v<Scalar>) &&
                  !std::is_pointer_v<Scalar>) {
        return DiffArray::create(0, tzcnt(m_value));
    } else {
        fail_unsupported("tzcnt_");
    }
}
/// Forwards to popcnt(). Supported for masks and integral arrays, but never
/// for pointer-valued arrays; nothing is recorded on the tape.
DiffArray popcnt_() const {
    if constexpr ((is_mask_v<Type> || std::is_integral_v<Scalar>) &&
                  !std::is_pointer_v<Scalar>) {
        return DiffArray::create(0, popcnt(m_value));
    } else {
        fail_unsupported("popcnt_");
    }
}
/// Left shift by the compile-time amount `Imm`; integral, non-mask types only.
template <size_t Imm>
DiffArray sl_() const {
    if constexpr (!is_mask_v<Type> && std::is_integral_v<Scalar>) {
        return DiffArray::create(0, sl<Imm>(m_value));
    } else {
        fail_unsupported("sl_");
    }
}
/// Right shift by the compile-time amount `Imm`; integral, non-mask types only.
template <size_t Imm>
DiffArray sr_() const {
    if constexpr (!is_mask_v<Type> && std::is_integral_v<Scalar>) {
        return DiffArray::create(0, sr<Imm>(m_value));
    } else {
        fail_unsupported("sr_");
    }
}
/// Left shift by per-element amounts `a`; integral, non-mask types only.
DiffArray sl_(const DiffArray &a) const {
    if constexpr (!is_mask_v<Type> && std::is_integral_v<Scalar>) {
        return DiffArray::create(0, m_value << a.m_value);
    } else {
        fail_unsupported("sl_");
    }
}
/// Right shift by per-element amounts `a`; integral, non-mask types only.
DiffArray sr_(const DiffArray &a) const {
    if constexpr (!is_mask_v<Type> && std::is_integral_v<Scalar>) {
        return DiffArray::create(0, m_value >> a.m_value);
    } else {
        fail_unsupported("sr_");
    }
}
/// Left shift by the runtime amount `size`; integral, non-mask types only.
DiffArray sl_(size_t size) const {
    if constexpr (!is_mask_v<Type> && std::is_integral_v<Scalar>) {
        return DiffArray::create(0, m_value << size);
    } else {
        fail_unsupported("sl_");
    }
}
/// Right shift by the runtime amount `size`; integral, non-mask types only.
DiffArray sr_(size_t size) const {
    if constexpr (!is_mask_v<Type> && std::is_integral_v<Scalar>) {
        return DiffArray::create(0, m_value >> size);
    } else {
        fail_unsupported("sr_");
    }
}
/// Forwards to rol<Imm>() (compile-time amount); integral, non-mask types only.
template <size_t Imm>
DiffArray rol_() const {
    if constexpr (!is_mask_v<Type> && std::is_integral_v<Scalar>) {
        return DiffArray::create(0, rol<Imm>(m_value));
    } else {
        fail_unsupported("rol_");
    }
}
/// Forwards to ror<Imm>() (compile-time amount); integral, non-mask types only.
template <size_t Imm>
DiffArray ror_() const {
    if constexpr (!is_mask_v<Type> && std::is_integral_v<Scalar>) {
        return DiffArray::create(0, ror<Imm>(m_value));
    } else {
        fail_unsupported("ror_");
    }
}
/// Forwards to rol() with per-element amounts `a`; integral, non-mask types only.
DiffArray rol_(const DiffArray &a) const {
    if constexpr (!is_mask_v<Type> && std::is_integral_v<Scalar>) {
        return DiffArray::create(0, rol(m_value, a.m_value));
    } else {
        fail_unsupported("rol_");
    }
}
/// Forwards to ror() with per-element amounts `a`; integral, non-mask types only.
DiffArray ror_(const DiffArray &a) const {
    if constexpr (!is_mask_v<Type> && std::is_integral_v<Scalar>) {
        return DiffArray::create(0, ror(m_value, a.m_value));
    } else {
        fail_unsupported("ror_");
    }
}
/// Comparisons: each compares the underlying values and wraps the resulting
/// mask in MaskType. Nothing is recorded on the tape.
auto eq_(const DiffArray &d) const {
    return MaskType(eq(m_value, d.m_value));
}
auto neq_(const DiffArray &d) const {
    return MaskType(neq(m_value, d.m_value));
}
auto lt_(const DiffArray &d) const {
    return MaskType(m_value < d.m_value);
}
auto le_(const DiffArray &d) const {
    return MaskType(m_value <= d.m_value);
}
auto gt_(const DiffArray &d) const {
    return MaskType(m_value > d.m_value);
}
auto ge_(const DiffArray &d) const {
    return MaskType(m_value >= d.m_value);
}
//! @}
// -----------------------------------------------------------------------
// -----------------------------------------------------------------------
//! @{ \name Scatter/gather operations
// -----------------------------------------------------------------------
/// Gathers values from `ptr` at the given offsets under `mask`. For
/// differentiable types the stride must equal sizeof(Scalar), and a gather
/// entry is appended to the tape.
template <size_t Stride, typename Offset, typename Mask>
static DiffArray gather_(const void *ptr, const Offset &offset,
                         const Mask &mask) {
    static_assert(!Enabled || Stride == sizeof(Scalar),
                  "Differentiable gather: unsupported stride!");
    Type gathered = gather<Type, Stride>(ptr, offset.value_(), mask.value_());
    Index node = 0;
    if constexpr (Enabled) {
        node = tape()->append_gather(offset.value_(), mask.value_());
    }
    return DiffArray::create(node, std::move(gathered));
}
/// Scatters the value to `ptr` at the given offsets under `mask`. For
/// differentiable types the stride must equal sizeof(Scalar), and a scatter
/// entry (scatter_add = false) is appended to the tape.
template <size_t Stride, typename Offset, typename Mask>
void scatter_(void *ptr, const Offset &offset, const Mask &mask) const {
    static_assert(!Enabled || Stride == sizeof(Scalar),
                  "Differentiable scatter: unsupported stride!");
    scatter<Stride>(ptr, m_value, offset.value_(), mask.value_());
    if constexpr (Enabled) {
        tape()->append_scatter(m_index, offset.value_(), mask.value_(),
                               false);
    }
}
/// Scatter-adds the value to `ptr` at the given offsets under `mask`. For
/// differentiable types the stride must equal sizeof(Scalar), and a scatter
/// entry (scatter_add = true) is appended to the tape.
template <size_t Stride, typename Offset, typename Mask>
void scatter_add_(void *ptr, const Offset &offset, const Mask &mask) const {
    static_assert(!Enabled || Stride == sizeof(Scalar),
                  "Differentiable scatter_add: unsupported stride!");
    scatter_add<Stride>(ptr, m_value, offset.value_(), mask.value_());
    if constexpr (Enabled) {
        tape()->append_scatter(m_index, offset.value_(), mask.value_(),
                               true);
    }
}
//! @}
// -----------------------------------------------------------------------
// -----------------------------------------------------------------------
//! @{ \name Horizontal operations
// -----------------------------------------------------------------------
/// Forwards to all() on the underlying mask; only supported for mask types.
auto all_() const {
    if constexpr (is_mask_v<Type>) {
        return all(m_value);
    } else {
        fail_unsupported("all_");
    }
}
/// Forwards to any() on the underlying mask; only supported for mask types.
auto any_() const {
    if constexpr (is_mask_v<Type>) {
        return any(m_value);
    } else {
        fail_unsupported("any_");
    }
}
/// Forwards to count() on the underlying mask; only supported for mask types.
auto count_() const {
    if constexpr (is_mask_v<Type>) {
        return count(m_value);
    } else {
        fail_unsupported("count_");
    }
}
/// Forwards to reverse(); a matching reverse entry is appended to the tape
/// for differentiable types. Unsupported for masks and pointer-valued arrays.
DiffArray reverse_() const {
    if constexpr (!is_mask_v<Type> && !std::is_pointer_v<Scalar>) {
        Index node = 0;
        if constexpr (Enabled) {
            node = tape()->append_reverse(m_index);
        }
        return DiffArray::create(node, reverse(m_value));
    } else {
        fail_unsupported("reverse_");
    }
}
/// Forwards to psum(); a matching psum entry is appended to the tape for
/// differentiable types. Unsupported for masks and pointer-valued arrays.
DiffArray psum_() const {
    if constexpr (!is_mask_v<Type> && !std::is_pointer_v<Scalar>) {
        Index node = 0;
        if constexpr (Enabled) {
            node = tape()->append_psum(m_index);
        }
        return DiffArray::create(node, psum(m_value));
    } else {
        fail_unsupported("psum_");
    }
}
/// Horizontal sum; records an "hsum" entry of size 1 with weight 1.
/// Unsupported for masks and pointer-valued arrays.
DiffArray hsum_() const {
    if constexpr (!is_mask_v<Type> && !std::is_pointer_v<Scalar>) {
        Index node = 0;
        if constexpr (Enabled) {
            node = tape()->append("hsum", 1, m_index, 1.f);
        }
        return DiffArray::create(node, hsum(m_value));
    } else {
        fail_unsupported("hsum_");
    }
}
/// Horizontal product; records an "hprod" entry of size 1 whose weight is
/// product / value, with 0 substituted where the value equals 0.
/// Unsupported for masks and pointer-valued arrays.
DiffArray hprod_() const {
    if constexpr (!is_mask_v<Type> && !std::is_pointer_v<Scalar>) {
        Type product = hprod(m_value);
        Index node = 0;
        if constexpr (Enabled) {
            node = tape()->append(
                "hprod", 1, m_index,
                select(eq(m_value, (Scalar) 0), (Scalar) 0,
                       product / m_value));
        }
        return DiffArray::create(node, std::move(product));
    } else {
        fail_unsupported("hprod_");
    }
}
/// Horizontal maximum. Fails at runtime when this variable has a nonzero
/// tape index, since no gradient is implemented.
/// Unsupported for masks and pointer-valued arrays.
DiffArray hmax_() const {
    if constexpr (!is_mask_v<Type> && !std::is_pointer_v<Scalar>) {
        const bool tracked = Enabled && m_index != 0;
        if (tracked) {
            fail_unsupported("hmax_: gradients not yet implemented!");
        }
        return DiffArray::create(0, hmax(m_value));
    } else {
        fail_unsupported("hmax_");
    }
}
/// Horizontal minimum. Fails at runtime when this variable has a nonzero
/// tape index, since no gradient is implemented.
/// Unsupported for masks and pointer-valued arrays.
DiffArray hmin_() const {
    if constexpr (!is_mask_v<Type> && !std::is_pointer_v<Scalar>) {
        const bool tracked = Enabled && m_index != 0;
        if (tracked) {
            fail_unsupported("hmin_: gradients not yet implemented!");
        }
        return DiffArray::create(0, hmin(m_value));
    } else {
        fail_unsupported("hmin_");
    }
}
template <typename T = Scalar, enable_if_t<std::is_pointer_v<T>> = 0>
auto partition_() const {
std::vector<std::pair<T, uint32_array_t<DiffArray, false>>> result;
auto p = partition(m_value);
result.reserve(p.size());
for (auto &kv : p)
result.emplace_back(kv.first, std::move(kv.second));
return result;
}
//! @}
// -----------------------------------------------------------------------
// -----------------------------------------------------------------------
//! @{ \name Access to internals
// -----------------------------------------------------------------------
/// Rebinds this variable to tape entry `index`. For differentiable types,
/// the new entry is referenced before the old one is released. The index is
/// stored regardless of whether derivatives are tracked.
void set_index_(Index index) {
    if constexpr (Enabled) {
        auto tp = tape();
        tp->inc_ref_ext(index);
        tp->dec_ref_ext(m_index);
    }
    m_index = index;
}
/// Tape index currently associated with this variable.
Index index_() const {
    return m_index;
}

/// Mutable access to the underlying value.
Type &value_() {
    return m_value;
}

/// Read-only access to the underlying value.
const Type &value_() const {
    return m_value;
}
//! @}
// -----------------------------------------------------------------------
// -----------------------------------------------------------------------
//! @{ \name Coefficient access
// -----------------------------------------------------------------------
/// Number of slices of the underlying value; scalars report 1.
ENOKI_INLINE size_t size() const {
    if constexpr (!is_scalar_v<Type>) {
        return slices(m_value);
    } else {
        return 1;
    }
}
/// True when the underlying array has zero slices; scalars are never empty.
ENOKI_INLINE bool empty() const {
    if constexpr (!is_scalar_v<Type>) {
        return slices(m_value) == 0;
    } else {
        return false;
    }
}
/// Resizes the underlying array; a no-op for scalar value types.
ENOKI_NOINLINE void resize(size_t size) {
    // Marks the parameter as used for the scalar case, where it is ignored
    ENOKI_MARK_USED(size);
    if constexpr (!is_scalar_v<Type>) {
        m_value.resize(size);
    }
}
/// Pointer to the underlying storage: the array's data(), or the address
/// of the value itself for scalars.
ENOKI_INLINE Scalar *data() {
    if constexpr (!is_scalar_v<Type>) {
        return m_value.data();
    } else {
        return &m_value;
    }
}
/// Read-only pointer to the underlying storage: the array's data(), or the
/// address of the value itself for scalars.
ENOKI_INLINE const Scalar *data() const {
    if constexpr (!is_scalar_v<Type>) {
        return m_value.data();
    } else {
        return &m_value;
    }
}
/// Coefficient access. Exactly `Depth` indices must be supplied (checked at
/// compile time); they are converted to size_t and forwarded to
/// m_value.coeff(). For scalar types the value itself is returned.
template <typename... Args>
ENOKI_INLINE decltype(auto) coeff(Args... args) {
    static_assert(sizeof...(Args) == Depth, "coeff(): Invalid number of arguments!");

    if constexpr (!is_scalar_v<Type>) {
        return m_value.coeff(size_t(args)...);
    } else {
        return m_value;
    }
}

/// Const overload of coeff().
template <typename... Args>
ENOKI_INLINE decltype(auto) coeff(Args... args) const {
    static_assert(sizeof...(Args) == Depth, "coeff(): Invalid number of arguments!");

    if constexpr (!is_scalar_v<Type>) {
        return m_value.coeff(size_t(args)...);
    } else {
        return m_value;
    }
}
/// Returns the gradient the tape holds for this array's index. Calls
/// fail_unsupported() when Enabled is false.
const Type &gradient_() const {
    if constexpr (Enabled) {
        return tape()->gradient(m_index);
    } else {
        fail_unsupported("gradient_");
    }
}
/// Returns the gradient the tape holds for an explicitly given index. Calls
/// fail_unsupported() when Enabled is false.
static const Type &gradient_static_(Index index) {
    if constexpr (Enabled) {
        return tape()->gradient(index);
    } else {
        fail_unsupported("gradient_static_");
    }
}
/// Forwards `value` and the `backward` flag to tape()->set_gradient() for
/// this array's index. Calls fail_unsupported() when Enabled is false.
void set_gradient_(const Type &value, bool backward = true) {
    if constexpr (Enabled) {
        return tape()->set_gradient(m_index, value, backward);
    } else {
        fail_unsupported("set_gradient_");
    }
}
//! @}
// -----------------------------------------------------------------------
// -----------------------------------------------------------------------
//! @{ \name Standard initializers
// -----------------------------------------------------------------------
/// Builds a DiffArray from enoki::empty<Type>(args...). The call is qualified
/// because the member function empty() would otherwise hide the free function.
template <typename... Args>
static DiffArray empty_(Args... args) {
    return enoki::empty<Type>(args...);
}
/// Builds a DiffArray from zero<Type>(args...).
template <typename... Args>
static DiffArray zero_(Args... args) {
    return zero<Type>(args...);
}
/// Builds a DiffArray from arange<Type>(args...).
template <typename... Args>
static DiffArray arange_(Args... args) {
    return arange<Type>(args...);
}
/// Builds a DiffArray from linspace<Type>(args...).
template <typename... Args>
static DiffArray linspace_(Args... args) {
    return linspace<Type>(args...);
}
/// Builds a DiffArray from full<Type>(args...).
template <typename... Args>
static DiffArray full_(Args... args) {
    return full<Type>(args...);
}
//! @}
// -----------------------------------------------------------------------
/// Attaches this array to the tape or detaches it from the tape.
///
/// Enabling on an array with index 0 appends a leaf node sized
/// slices(m_value) and stores its index. Disabling on an array with a
/// non-zero index releases the external reference and resets the index to 0.
/// All other combinations leave the array untouched. Calls fail_unsupported()
/// when Enabled is false.
void set_requires_gradient_(bool value) {
    if constexpr (Enabled) {
        const bool attached = m_index != 0;

        if (value && !attached) {
            m_index = tape()->append_leaf(slices(m_value));
        } else if (!value && attached) {
            tape()->dec_ref_ext(m_index);
            m_index = 0;
        }
    } else {
        fail_unsupported("set_requires_gradient_");
    }
}
/// True when the tape is enabled for this type and the array has a non-zero
/// tape index.
bool requires_gradient_() const {
    if constexpr (Enabled) {
        return m_index != 0;
    } else {
        return false;
    }
}
/// Labels the tape entry of this array (only when Enabled) and then forwards
/// the same label to set_label() for the wrapped value.
void set_label_(const char *label) const {
    ENOKI_MARK_USED(label);

    if constexpr (Enabled) {
        tape()->set_label(m_index, label);
    }

    set_label(m_value, label);
}
/// Invokes tape()->backward() for this array's index, passing `free_graph`
/// through. Calls fail_unsupported() when Enabled is false.
void backward_(bool free_graph) const {
    if constexpr (Enabled) {
        tape()->backward(m_index, free_graph);
    } else {
        fail_unsupported("backward_");
    }
}
/// Invokes tape()->forward() for this array's index, passing `free_graph`
/// through. Calls fail_unsupported() when Enabled is false.
void forward_(bool free_graph) const {
    if constexpr (Enabled) {
        tape()->forward(m_index, free_graph);
    } else {
        fail_unsupported("forward_");
    }
}
static void backward_static_(bool free_graph) {
tape()->backward(free_graph);
}
static void forward_static_(bool free_graph) {
tape()->forward(free_graph);
}
/// Returns the string produced by tape()->graphviz() for the given node
/// indices. Calls fail_unsupported() when Enabled is false.
static std::string graphviz_(const std::vector<Index> &indices) {
    if constexpr (Enabled) {
        return tape()->graphviz(indices);
    } else {
        fail_unsupported("graphviz_");
    }
}
/// Forwards `label` to tape()->push_prefix(). No-op when Enabled is false.
static void push_prefix_(const char *label) {
    if constexpr (Enabled) {
        tape()->push_prefix(label);
    }
}
/// Calls tape()->pop_prefix(). No-op when Enabled is false.
static void pop_prefix_() {
    if constexpr (Enabled) {
        tape()->pop_prefix();
    }
}
/// Forwards to tape()->inc_ref_ext(index). No-op when Enabled is false.
static void inc_ref_ext_(Index index) {
    if constexpr (Enabled) {
        tape()->inc_ref_ext(index);
    }
}
/// Forwards to tape()->dec_ref_ext(index). No-op when Enabled is false.
static void dec_ref_ext_(Index index) {
    if constexpr (Enabled) {
        tape()->dec_ref_ext(index);
    }
}
static void set_scatter_gather_operand_(const DiffArray &v, bool permute) {
ENOKI_MARK_USED(v);
ENOKI_MARK_USED(permute);
if constexpr (Enabled)
tape()->set_scatter_gather_operand(const_cast<Index *>(&v.m_index),
v.size(), permute);
}
/// Resets the tape's scatter/gather operand by passing a null pointer, size 0
/// and permute = false. No-op when Enabled is false.
static void clear_scatter_gather_operand_() {
    if constexpr (Enabled) {
        tape()->set_scatter_gather_operand(nullptr, 0, false);
    }
}
/// Forwards `level` to tape()->set_log_level(). No-op when Enabled is false.
static void set_log_level_(uint32_t level) {
    if constexpr (Enabled) {
        tape()->set_log_level(level);
    }
}
/// Returns tape()->log_level(), or 0 when Enabled is false.
static uint32_t log_level_() {
    if constexpr (!Enabled) {
        return 0;
    } else {
        return tape()->log_level();
    }
}
/// Forwards `level` to tape()->set_graph_simplification(). No-op when Enabled
/// is false.
static void set_graph_simplification_(uint32_t level) {
    if constexpr (Enabled) {
        tape()->set_graph_simplification(level);
    }
}
/// Calls tape()->simplify_graph(). No-op when Enabled is false.
static void simplify_graph_() {
    if constexpr (Enabled) {
        tape()->simplify_graph();
    }
}
/// Returns the string produced by tape()->whos(). Calls fail_unsupported()
/// when Enabled is false.
static std::string whos_() {
    if constexpr (Enabled) {
        return tape()->whos();
    } else {
        fail_unsupported("whos");
    }
}
/// Wraps the result of Type::map(ptr, size, dealloc) in a DiffArray with tape
/// index 0. Calls fail_unsupported() unless Type is a dynamic array.
static DiffArray map(void *ptr, size_t size, bool dealloc = false) {
    if constexpr (is_dynamic_array_v<Type>) {
        return DiffArray::create(0, Type::map(ptr, size, dealloc));
    } else {
        fail_unsupported("map");
    }
}
/// Wraps the result of Type::copy(ptr, size) in a DiffArray with tape index 0.
/// Calls fail_unsupported() unless Type is a dynamic array.
static DiffArray copy(const void *ptr, size_t size) {
    if constexpr (is_dynamic_array_v<Type>) {
        return DiffArray::create(0, Type::copy(ptr, size));
    } else {
        fail_unsupported("copy");
    }
}
/// Calls m_value.managed() when Type is a CUDA array (no-op otherwise) and
/// returns *this to allow chaining.
DiffArray &managed() {
    if constexpr (is_cuda_array_v<Type>) {
        m_value.managed();
    }
    return *this;
}

/// Const overload of managed().
const DiffArray &managed() const {
    if constexpr (is_cuda_array_v<Type>) {
        m_value.managed();
    }
    return *this;
}
/// Calls m_value.eval() when Type is a CUDA array (no-op otherwise) and
/// returns *this to allow chaining.
DiffArray &eval() {
    if constexpr (is_cuda_array_v<Type>) {
        m_value.eval();
    }
    return *this;
}

/// Const overload of eval().
const DiffArray &eval() const {
    if constexpr (is_cuda_array_v<Type>) {
        m_value.eval();
    }
    return *this;
}
/// Returns a call_support helper constructed from *this. Its first template
/// argument is the decayed pointee type of Scalar.
auto operator->() const {
    using Pointee  = std::remove_pointer_t<Scalar>;
    using BaseType = std::decay_t<Pointee>;
    return call_support<BaseType, DiffArray>(*this);
}
private:
/// Shorthand for Tape::get().
ENOKI_INLINE static Tape* tape() {
    return Tape::get();
}
/// Parameter type of create(): scalar values are taken by value, everything
/// else by rvalue reference.
using Arg = std::conditional_t<std::is_scalar_v<Type>, Type, Type&&>;

/// Builds a DiffArray from `value` and assigns `index` directly to its
/// m_index member (no inc_ref_ext() call is made in this function).
ENOKI_INLINE static DiffArray create(Index index, Arg value) {
    DiffArray out(std::move(value));
    out.m_index = index;
    return out;
}
/// Reports an unsupported operation on stderr and terminates the process
/// with EXIT_FAILURE.
///
/// \param msg
///     Name of the offending operation, optionally followed by further
///     detail (e.g. "hmin_: gradients not yet implemented!").
///
/// Callers pass an operation name rather than a type name, so the message is
/// worded accordingly and ends with a newline.
[[noreturn]]
ENOKI_NOINLINE static void fail_unsupported(const char *msg) {
    fprintf(stderr, "DiffArray: unsupported operation: %s\n", msg);
    exit(EXIT_FAILURE);
}
/// Wrapped value (a scalar or an Enoki array, depending on Type)
Type m_value;
/// Index of this array on the tape; 0 is the default and is treated as
/// "not attached" by requires_gradient_()
Index m_index = 0;
};
/// Assigns a label to a differentiable array. Arrays of depth >= 2 are
/// labeled component-wise, with ".<i>" appended to the label of the i-th
/// coefficient; depth-1 arrays forward to set_label_().
template <typename T, enable_if_t<is_diff_array_v<T>> = 0>
ENOKI_INLINE void set_label(const T& a, const char *label) {
    if constexpr (array_depth_v<T> < 2) {
        a.set_label_(label);
    } else {
        for (size_t i = 0; i < T::Size; ++i) {
            std::string child_label(label);
            child_label += ".";
            child_label += std::to_string(i);
            set_label(a.coeff(i), child_label.c_str());
        }
    }
}
/// Returns true if `a` is a differentiable array and it (or, for depth >= 2,
/// any of its coefficients) reports requires_gradient_(). Returns false for
/// all other types.
template <typename T> ENOKI_INLINE bool requires_gradient(T& a) {
    if constexpr (is_diff_array_v<T>) {
        if constexpr (array_depth_v<T> < 2) {
            return a.requires_gradient_();
        } else {
            bool found = false;
            for (size_t i = 0; i < a.size() && !found; ++i)
                found = requires_gradient(a.coeff(i));
            return found;
        }
    }
    return false;
}
/// Enables or disables gradient tracking on a differentiable array; arrays of
/// depth >= 2 are processed coefficient by coefficient. Does nothing for
/// types that are not differentiable arrays.
template <typename T> ENOKI_INLINE void set_requires_gradient(T& a, bool value = true) {
    if constexpr (is_diff_array_v<T>) {
        if constexpr (array_depth_v<T> < 2) {
            a.set_requires_gradient_(value);
        } else {
            for (size_t i = 0; i < a.size(); ++i) {
                set_requires_gradient(a.coeff(i), value);
            }
        }
    }
}
template <typename T> auto gradient_index(const T &a) {
if constexpr (array_depth_v<T> >= 2) {
using Result = std::array<decltype(gradient_index(a.coeff(0))), T::Size>;
Result result;
for (size_t i = 0; i < T::Size; ++i)
result[i] = gradient_index(a.coeff(i));
return result;
} else if constexpr (is_diff_array_v<T>) {
return a.index_();
} else {
static_assert(detail::false_v<T>, "The given array does not support derivatives.");
}
}
template <typename T1, typename T2> void set_gradient(T1 &a, const T2 &b, bool backward = true) {
if constexpr (array_depth_v<T1> >= 2) {
for (size_t i = 0; i < array_size_v<T1>; ++i)
set_gradient(a[i], b[i], backward);
} else if constexpr (is_diff_array_v<T1>) {
a.set_gradient_(b, backward);
} else {
static_assert(detail::false_v<T1, T2>, "The given array does not support derivatives.");
}
}
template <typename T1> void reattach(T1 &a, const T1 &b) {
if constexpr (array_depth_v<T1> >= 2) {
for (size_t i = 0; i < array_size_v<T1>; ++i)
reattach(a[i], b[i]);
} else if constexpr (is_diff_array_v<T1>) {
a.set_index_(b.index_());
} else {
static_assert(detail::false_v<T1>, "The given array does not support derivatives.");
}
}
/// Free-function wrapper around a.forward_(free_graph); `free_graph`
/// defaults to true.
template <typename T>
void forward(const T& a, bool free_graph = true) {
    a.forward_(free_graph);
}
/// Free-function wrapper around a.backward_(free_graph); `free_graph`
/// defaults to true.
template <typename T>
void backward(const T& a, bool free_graph = true) {
    a.backward_(free_graph);
}
/// Free-function wrapper around the static T::backward_static_(free_graph).
/// T must be given explicitly; `free_graph` defaults to true.
template <typename T>
void backward(bool free_graph = true) {
    T::backward_static_(free_graph);
}
/// Free-function wrapper around the static T::forward_static_(free_graph).
/// T must be given explicitly; `free_graph` defaults to true.
template <typename T>
void forward(bool free_graph = true) {
    T::forward_static_(free_graph);
}
namespace detail {
    /// Appends the non-zero tape indices found in `value` to `indices`.
    ///
    /// Depth-1 differentiable arrays contribute their own index (if it is
    /// non-zero); deeper differentiable arrays are traversed coefficient by
    /// coefficient. Values that are not differentiable arrays are ignored.
    template <typename T>
    void collect_indices(const T &value, std::vector<uint32_t> &indices) {
        if constexpr (is_diff_array_v<T>) {
            if constexpr (array_depth_v<T> == 1) {
                if (value.index_() != 0)
                    indices.push_back(value.index_());
            } else {
                for (size_t i = 0; i < T::Size; ++i)
                    collect_indices(value.coeff(i), indices);
            }
        }
    }
}
namespace detail {
    /// Maps an array type to the differentiable array type it is built from.
    /// Base case: the type maps to itself.
    template <typename T, typename SFINAE = int>
    struct diff_type {
        using type = T;
    };

    /// Convenience alias for diff_type<T>::type
    template <typename T>
    using diff_type_t = typename diff_type<T>::type;

    /// Recursive case: when the value type of T is a differentiable array,
    /// the lookup continues with that value type.
    template <typename T>
    struct diff_type<T, enable_if_t<is_diff_array_v<value_t<T>>>> {
        using type = typename diff_type<value_t<T>>::type;
    };
}
template <typename T> std::string graphviz(const T &value) {
std::vector<uint32_t> indices;
detail::collect_indices(value, indices);
return detail::diff_type_t<T>::graphviz_(indices);
}
// Select the storage-class and export decorations used by the explicit
// template instantiation list that follows:
//  - ENOKI_AUTODIFF_BUILD defined: 'extern' + ENOKI_EXPORT
//  - otherwise: ENOKI_IMPORT, with 'extern' omitted on MSVC only
//    (NOTE(review): presumably because MSVC rejects 'extern' combined with
//    an import attribute on explicit instantiations -- confirm)
#if defined(ENOKI_AUTODIFF_BUILD)
# define ENOKI_AUTODIFF_EXTERN extern
# define ENOKI_AUTODIFF_EXPORT ENOKI_EXPORT
#else
# define ENOKI_AUTODIFF_EXPORT ENOKI_IMPORT
# if defined(_MSC_VER)
# define ENOKI_AUTODIFF_EXTERN
#else
# define ENOKI_AUTODIFF_EXTERN extern
# endif
#endif
// Explicit instantiations of Tape and DiffArray for the supported value
// types, decorated with the macros selected above. The entire list is
// skipped when ENOKI_BUILD is defined.
#if !defined(ENOKI_BUILD)
// Scalar float/double variants (always present)
ENOKI_AUTODIFF_EXTERN template struct ENOKI_AUTODIFF_EXPORT Tape<float>;
ENOKI_AUTODIFF_EXTERN template struct ENOKI_AUTODIFF_EXPORT DiffArray<float>;
ENOKI_AUTODIFF_EXTERN template struct ENOKI_AUTODIFF_EXPORT Tape<double>;
ENOKI_AUTODIFF_EXTERN template struct ENOKI_AUTODIFF_EXPORT DiffArray<double>;
// Dynamic array variants, only when ENOKI_DYNAMIC_H is defined
# if defined(ENOKI_DYNAMIC_H)
ENOKI_AUTODIFF_EXTERN template struct ENOKI_AUTODIFF_EXPORT Tape<DynamicArray<Packet<float>>>;
ENOKI_AUTODIFF_EXTERN template struct ENOKI_AUTODIFF_EXPORT DiffArray<DynamicArray<Packet<float>>>;
ENOKI_AUTODIFF_EXTERN template struct ENOKI_AUTODIFF_EXPORT Tape<DynamicArray<Packet<double>>>;
ENOKI_AUTODIFF_EXTERN template struct ENOKI_AUTODIFF_EXPORT DiffArray<DynamicArray<Packet<double>>>;
# endif
// CUDA array variants, only when ENOKI_CUDA_H is defined
# if defined(ENOKI_CUDA_H)
ENOKI_AUTODIFF_EXTERN template struct ENOKI_AUTODIFF_EXPORT Tape<CUDAArray<float>>;
ENOKI_AUTODIFF_EXTERN template struct ENOKI_AUTODIFF_EXPORT DiffArray<CUDAArray<float>>;
ENOKI_AUTODIFF_EXTERN template struct ENOKI_AUTODIFF_EXPORT Tape<CUDAArray<double>>;
ENOKI_AUTODIFF_EXTERN template struct ENOKI_AUTODIFF_EXPORT DiffArray<CUDAArray<double>>;
# endif
#endif
NAMESPACE_END(enoki)