/*
    enoki/array_router.h -- Helper functions which route function calls
    in the enoki namespace to the intended recipients

    Enoki is a C++ template library that enables transparent vectorization
    of numerical kernels using SIMD instruction sets available on current
    processor architectures.

    Copyright (c) 2019 Wenzel Jakob <wenzel.jakob@epfl.ch>

    All rights reserved. Use of this source code is governed by a BSD-style
    license that can be found in the LICENSE file.
*/

#pragma once

#include "array_traits.h"
#include "array_fallbacks.h"

NAMESPACE_BEGIN(enoki)

/// Define a unary operation
#define ENOKI_ROUTE_UNARY(name, func)                                         \
    template <typename T, enable_if_array_t<T> = 0>                           \
    ENOKI_INLINE auto name(const T &a) {                                      \
        return eval(a).func##_(); /* Forward to array */                      \
    }

/// Define a unary operation with an immediate argument (e.g. sr<5>(...))
#define ENOKI_ROUTE_UNARY_IMM(name, func)                                     \
    template <size_t Imm, typename T, enable_if_array_t<T> = 0>               \
    ENOKI_INLINE auto name(const T &a) {                                      \
        return eval(a).template func##_<Imm>(); /* Forward to array */        \
    }

/// Define a unary operation with a fallback expression for scalar arguments
#define ENOKI_ROUTE_UNARY_SCALAR(name, func, expr)                            \
    template <typename T> ENOKI_INLINE auto name(const T &a) {                \
        if constexpr (!is_array_v<T>)                                         \
            return expr; /* Scalar fallback implementation */                 \
        else                                                                  \
            return eval(a).func##_(); /* Forward to array */                  \
    }

/// Define a unary operation with an immediate argument and a scalar fallback
#define ENOKI_ROUTE_UNARY_SCALAR_IMM(name, func, expr)                        \
    template <size_t Imm, typename T> ENOKI_INLINE auto name(const T &a) {    \
        if constexpr (!is_array_v<T>)                                         \
            return expr; /* Scalar fallback implementation */                 \
        else                                                                  \
            return eval(a).template func##_<Imm>(); /* Forward to array */    \
    }

/// Define a binary operation
#define ENOKI_ROUTE_BINARY(name, func)                                        \
    template <typename T1, typename T2, enable_if_array_any_t<T1, T2> = 0>    \
    ENOKI_INLINE auto name(const T1 &a1, const T2 &a2) {                      \
        using E = expr_t<T1, T2>;                                             \
        if constexpr (std::is_same_v<T1, E> && std::is_same_v<T2, E>)         \
            return a1.derived().func##_(a2.derived());                        \
        else                                                                  \
            return name(static_cast<ref_cast_t<T1, E>>(a1),                   \
                        static_cast<ref_cast_t<T2, E>>(a2));                  \
    }

/// Define a binary operation for bit operations
#define ENOKI_ROUTE_BINARY_BITOP(name, func)                                  \
    template <typename T1, typename T2, enable_if_array_any_t<T1, T2> = 0>    \
    ENOKI_INLINE auto name(const T1 &a1, const T2 &a2) {                      \
        using E = expr_t<T1, T2>;                                             \
        if constexpr (std::is_same_v<T1, E> && std::is_same_v<T2, E>)         \
            return a1.derived().func##_(a2.derived());                        \
        else if constexpr (is_mask_v<T2> && !is_array_v<T2>)                  \
            return a1.derived().func##_((const mask_t<T1> &) a2);             \
        else if constexpr (is_array_v<T2>) {                                  \
            if constexpr (std::decay_t<T2>::IsMask)                           \
                return a1.derived().func##_(                                  \
                    (const mask_t<T1> &) a2.derived());                       \
            else                                                              \
                return name(static_cast<ref_cast_t<T1, E>>(a1),               \
                            static_cast<ref_cast_t<T2, E>>(a2));              \
        } else {                                                              \
            return name(static_cast<ref_cast_t<T1, E>>(a1),                   \
                        static_cast<ref_cast_t<T2, E>>(a2));                  \
        }                                                                     \
    }

/// Define a binary operation (restricted to cases where 'cond' is true)
#define ENOKI_ROUTE_BINARY_COND(name, func, cond)                             \
    template <typename T1, typename T2, enable_if_t<cond> = 0,                \
              enable_if_array_any_t<T1, T2> = 0>                              \
    ENOKI_INLINE auto name(const T1 &a1, const T2 &a2) {                      \
        using E = expr_t<T1, T2>;                                             \
        if constexpr (std::is_same_v<T1, E> && std::is_same_v<T2, E>)         \
            return a1.derived().func##_(a2.derived());                        \
        else                                                                  \
            return name(static_cast<ref_cast_t<T1, E>>(a1),                   \
                        static_cast<ref_cast_t<T2, E>>(a2));                  \
    }

#define ENOKI_ROUTE_BINARY_SHIFT(name, func)                                  \
    template <typename T1, typename T2,                                       \
              enable_if_t<std::is_arithmetic_v<scalar_t<T1>>> = 0,            \
              enable_if_array_any_t<T1, T2> = 0>                              \
    ENOKI_INLINE auto name(const T1 &a1, const T2 &a2) {                      \
        using E = expr_t<T1, T2>;                                             \
        if constexpr (std::is_integral_v<T2>)                                 \
            return eval(a1).func##_((size_t) a2);                             \
        else if constexpr (std::is_same_v<T1, E> && std::is_same_v<T2, E>)    \
            return a1.derived().func##_(a2.derived());                        \
        else                                                                  \
            return name(static_cast<ref_cast_t<T1, E>>(a1),                   \
                        static_cast<ref_cast_t<T2, E>>(a2));                  \
    }

/// Define a binary operation with a fallback expression for scalar arguments
#define ENOKI_ROUTE_BINARY_SCALAR(name, func, expr)                           \
    template <typename T1, typename T2>                                       \
    ENOKI_INLINE auto name(const T1 &a1, const T2 &a2) {                      \
        using E = expr_t<T1, T2>;                                             \
        if constexpr (is_array_any_v<T1, T2>) {                               \
            if constexpr (std::is_same_v<T1, E> && std::is_same_v<T2, E>)     \
                return a1.derived().func##_(a2.derived());                    \
            else                                                              \
                return name(static_cast<ref_cast_t<T1, E>>(a1),               \
                            static_cast<ref_cast_t<T2, E>>(a2));              \
        } else {                                                              \
            return expr;                                                      \
        }                                                                     \
    }
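// Usage sketch (illustrative, not part of the original header): each macro
// above stamps out a dispatch function. E.g. ENOKI_ROUTE_UNARY_SCALAR(abs,
// abs, detail::abs_scalar(a)) generates an 'abs()' that forwards array
// arguments to the type's 'abs_()' method and evaluates the scalar fallback
// expression otherwise. Assuming the Array<Value, Size> type from
// enoki/array.h:
//
//     using FloatP = Array<float, 8>;
//     FloatP x = ...;
//     FloatP y = abs(x);      // expands to eval(x).abs_()
//     float  z = abs(-1.f);   // expands to detail::abs_scalar(-1.f)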
/// Define a ternary operation with a fallback expression for scalar arguments
#define ENOKI_ROUTE_TERNARY_SCALAR(name, func, expr)                          \
    template <typename T1, typename T2, typename T3>                          \
    ENOKI_INLINE auto name(const T1 &a1, const T2 &a2, const T3 &a3) {        \
        using E = expr_t<T1, T2, T3>;                                         \
        if constexpr (is_array_any_v<T1, T2, T3>) {                           \
            if constexpr (std::is_same_v<T1, E> &&                            \
                          std::is_same_v<T2, E> &&                            \
                          std::is_same_v<T3, E>)                              \
                return a1.derived().func##_(a2.derived(), a3.derived());      \
            else                                                              \
                return name(static_cast<ref_cast_t<T1, E>>(a1),               \
                            static_cast<ref_cast_t<T2, E>>(a2),               \
                            static_cast<ref_cast_t<T3, E>>(a3));              \
        } else {                                                              \
            return expr;                                                      \
        }                                                                     \
    }

/// Macro for compound assignment operators (operator+=, etc.)
#define ENOKI_ROUTE_COMPOUND_OPERATOR(op)                                     \
    template <typename T1, enable_if_t<is_array_v<T1> &&                      \
                                      !std::is_const_v<T1>> = 0, typename T2> \
    ENOKI_INLINE T1 &operator op##=(T1 &a1, const T2 &a2) {                   \
        a1 = a1 op a2;                                                        \
        return a1;                                                            \
    }

template <typename T, enable_if_array_t<T> = 0>
ENOKI_INLINE decltype(auto) eval(const T &x) {
    if constexpr (std::is_same_v<std::decay_t<T>, expr_t<T>>)
        return x.derived();
    else
        return expr_t<T>(x);
}

ENOKI_ROUTE_UNARY(operator-, neg)
ENOKI_ROUTE_UNARY(operator~, not)
ENOKI_ROUTE_UNARY(operator!, not)

ENOKI_ROUTE_BINARY_COND(operator+, add,
                        !std::is_pointer_v<scalar_t<T1>> &&
                        !std::is_pointer_v<scalar_t<T2>>)
ENOKI_ROUTE_BINARY_COND(operator-, sub,
                        !std::is_pointer_v<scalar_t<T1>> &&
                        !std::is_pointer_v<scalar_t<T2>>)
ENOKI_ROUTE_BINARY(operator*, mul)

ENOKI_ROUTE_BINARY_SHIFT(operator<<, sl)
ENOKI_ROUTE_BINARY_SHIFT(operator>>, sr)

ENOKI_ROUTE_UNARY_SCALAR_IMM(sl, sl, a << Imm)
ENOKI_ROUTE_UNARY_SCALAR_IMM(sr, sr, a >> Imm)

ENOKI_ROUTE_BINARY_BITOP(operator&,  and)
ENOKI_ROUTE_BINARY_BITOP(operator&&, and)
ENOKI_ROUTE_BINARY_BITOP(operator|,  or)
ENOKI_ROUTE_BINARY_BITOP(operator||, or)
ENOKI_ROUTE_BINARY_BITOP(operator^,  xor)
ENOKI_ROUTE_BINARY_SCALAR(andnot, andnot, a1 & !a2)

ENOKI_ROUTE_BINARY(operator<,  lt)
ENOKI_ROUTE_BINARY(operator<=, le)
ENOKI_ROUTE_BINARY(operator>,  gt)
ENOKI_ROUTE_BINARY(operator>=, ge)

ENOKI_ROUTE_BINARY_SCALAR(eq,  eq,  a1 == a2)
ENOKI_ROUTE_BINARY_SCALAR(neq, neq, a1 != a2)

ENOKI_ROUTE_COMPOUND_OPERATOR(+)
ENOKI_ROUTE_COMPOUND_OPERATOR(-)
ENOKI_ROUTE_COMPOUND_OPERATOR(*)
ENOKI_ROUTE_COMPOUND_OPERATOR(/)
ENOKI_ROUTE_COMPOUND_OPERATOR(^)
ENOKI_ROUTE_COMPOUND_OPERATOR(|)
ENOKI_ROUTE_COMPOUND_OPERATOR(&)
ENOKI_ROUTE_COMPOUND_OPERATOR(<<)
ENOKI_ROUTE_COMPOUND_OPERATOR(>>)

ENOKI_ROUTE_BINARY_SCALAR(max, max, (std::decay_t<E>) std::max((E) a1, (E) a2))
ENOKI_ROUTE_BINARY_SCALAR(min, min, (std::decay_t<E>) std::min((E) a1, (E) a2))

ENOKI_ROUTE_BINARY_SCALAR(dot,   dot,   (E) a1 * (E) a2)
ENOKI_ROUTE_BINARY_SCALAR(mulhi, mulhi, detail::mulhi_scalar(a1, a2))
ENOKI_ROUTE_UNARY_SCALAR(abs, abs, detail::abs_scalar(a))

ENOKI_ROUTE_TERNARY_SCALAR(fmadd,  fmadd,  detail::fmadd_scalar((E) a1, (E) a2, (E) a3))
ENOKI_ROUTE_TERNARY_SCALAR(fmsub,  fmsub,  detail::fmadd_scalar((E) a1, (E) a2, (E) -a3))
ENOKI_ROUTE_TERNARY_SCALAR(fnmadd, fnmadd, detail::fmadd_scalar((E) -a1, (E) a2, (E) a3))
ENOKI_ROUTE_TERNARY_SCALAR(fnmsub, fnmsub, detail::fmadd_scalar((E) -a1, (E) a2, (E) -a3))
ENOKI_ROUTE_TERNARY_SCALAR(fmaddsub, fmaddsub, fmsub(a1, a2, a3))
ENOKI_ROUTE_TERNARY_SCALAR(fmsubadd, fmsubadd, fmadd(a1, a2, a3))

ENOKI_ROUTE_UNARY_SCALAR(rcp,   rcp,   1 / a)
ENOKI_ROUTE_UNARY_SCALAR(rsqrt, rsqrt, detail::rsqrt_scalar(a))

ENOKI_ROUTE_UNARY_SCALAR(popcnt, popcnt, detail::popcnt_scalar(a))
ENOKI_ROUTE_UNARY_SCALAR(lzcnt,  lzcnt,  detail::lzcnt_scalar(a))
ENOKI_ROUTE_UNARY_SCALAR(tzcnt,  tzcnt,  detail::tzcnt_scalar(a))

ENOKI_ROUTE_UNARY_SCALAR(all, all, (bool) a)
ENOKI_ROUTE_UNARY_SCALAR(any, any, (bool) a)
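// Usage sketch (illustrative): the routed binary/ternary operations promote
// mixed arguments to the common expression type 'E = expr_t<...>' before
// dispatching, so arrays and scalars mix freely:
//
//     using FloatP = Array<float, 8>;
//     FloatP a = ..., b = ...;
//     FloatP r1 = fmadd(a, b, 1.f);      // a*b + 1 per lane, fused if supported
//     float  r2 = fmadd(2.f, 3.f, 4.f);  // scalar fallback: detail::fmadd_scalar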
ENOKI_ROUTE_UNARY_SCALAR(count, count, (size_t) ((bool) a ? 1 : 0))

ENOKI_ROUTE_UNARY_SCALAR(reverse, reverse, a)
ENOKI_ROUTE_UNARY_SCALAR(psum,  psum,  a)
ENOKI_ROUTE_UNARY_SCALAR(hsum,  hsum,  a)
ENOKI_ROUTE_UNARY_SCALAR(hprod, hprod, a)
ENOKI_ROUTE_UNARY_SCALAR(hmin,  hmin,  a)
ENOKI_ROUTE_UNARY_SCALAR(hmax,  hmax,  a)
ENOKI_ROUTE_UNARY_SCALAR(hmean, hmean, a)

ENOKI_ROUTE_UNARY_SCALAR(all_inner,   all_inner,   (bool) a)
ENOKI_ROUTE_UNARY_SCALAR(any_inner,   any_inner,   (bool) a)
ENOKI_ROUTE_UNARY_SCALAR(count_inner, count_inner, (size_t) ((bool) a ? 1 : 0))
ENOKI_ROUTE_UNARY_SCALAR(psum_inner,  psum_inner,  a)
ENOKI_ROUTE_UNARY_SCALAR(hsum_inner,  hsum_inner,  a)
ENOKI_ROUTE_UNARY_SCALAR(hprod_inner, hprod_inner, a)
ENOKI_ROUTE_UNARY_SCALAR(hmin_inner,  hmin_inner,  a)
ENOKI_ROUTE_UNARY_SCALAR(hmax_inner,  hmax_inner,  a)
ENOKI_ROUTE_UNARY_SCALAR(hmean_inner, hmean_inner, a)

ENOKI_ROUTE_UNARY_SCALAR(sqrt,  sqrt,  std::sqrt(a))
ENOKI_ROUTE_UNARY_SCALAR(floor, floor, std::floor(a))
ENOKI_ROUTE_UNARY_SCALAR(ceil,  ceil,  std::ceil(a))
ENOKI_ROUTE_UNARY_SCALAR(round, round, std::rint(a))
ENOKI_ROUTE_UNARY_SCALAR(trunc, trunc, std::trunc(a))

ENOKI_ROUTE_UNARY_IMM(rol_array, rol_array)
ENOKI_ROUTE_UNARY_IMM(ror_array, ror_array)

template <typename T> auto none(const T &value) {
    return !any(value);
}

template <typename T> auto none_inner(const T &value) {
    return !any_inner(value);
}

/// Floating point division
template <typename T1, typename T2, enable_if_array_any_t<T1, T2> = 0,
          enable_if_t<std::is_floating_point_v<scalar_t<expr_t<T1, T2>>>> = 0>
ENOKI_INLINE auto operator/(const T1 &a1, const T2 &a2) {
    using E = expr_t<T1, T2>;
    using T = expr_t<scalar_t<T1>, T2>;

    if constexpr (std::is_same_v<T1, E> && std::is_same_v<T2, E>)
        return a1.derived().div_(a2.derived());
    else if constexpr (array_depth_v<T1> > array_depth_v<T2>)
        return static_cast<ref_cast_t<T1, E>>(a1) *
               rcp((const T &) a2); // reciprocal approximation
    else
        return operator/(static_cast<ref_cast_t<T1, E>>(a1),
                         static_cast<ref_cast_t<T2, E>>(a2));
}

/// Integer division
template <typename T1, typename T2, enable_if_array_any_t<T1, T2> = 0,
          enable_if_t<!std::is_floating_point_v<scalar_t<expr_t<T1, T2>>> &&
                      is_array_v<expr_t<T1, T2>>> = 0>
ENOKI_INLINE auto operator/(const T1 &a1, const T2 &a2) {
    using E = expr_t<T1, T2>;

    if constexpr (std::is_same_v<T1, E> && std::is_same_v<T2, E>)
        return a1.derived().div_(a2.derived());
    else
        return operator/(static_cast<ref_cast_t<T1, E>>(a1),
                         static_cast<ref_cast_t<T2, E>>(a2));
}

/// Modulo operation
template <typename T1, typename T2, enable_if_array_any_t<T1, T2> = 0,
          enable_if_t<!std::is_floating_point_v<scalar_t<expr_t<T1, T2>>> &&
                      is_array_v<expr_t<T1, T2>>> = 0>
ENOKI_INLINE auto operator%(const T1 &a1, const T2 &a2) {
    using E = expr_t<T1, T2>;

    if constexpr (std::is_same_v<T1, E> && std::is_same_v<T2, E>)
        return a1.derived().mod_(a2.derived());
    else
        return operator%(static_cast<ref_cast_t<T1, E>>(a1),
                         static_cast<ref_cast_t<T2, E>>(a2));
}

/// Shuffle the entries of an array
template <size_t... Is, typename T>
ENOKI_INLINE auto shuffle(const T &a) {
    if constexpr (is_array_v<T>) {
        return eval(a).template shuffle_<Is...>();
    } else {
        static_assert(sizeof...(Is) == 1 && (... && (Is == 0)),
                      "Shuffle argument out of bounds!");
        return a;
    }
}
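// Usage sketch (illustrative): compile-time shuffles permute entries via a
// template index list; for non-array arguments only the identity
// shuffle<0>() is permitted (enforced by the static_assert above):
//
//     using FloatP = Array<float, 4>;
//     FloatP v(1.f, 2.f, 3.f, 4.f);
//     FloatP w = shuffle<3, 2, 1, 0>(v);   // (4, 3, 2, 1)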
/// Shuffle the entries of an array using an index array
template <typename Array, typename Index,
          enable_if_t<is_array_v<Array> && is_array_v<Index> &&
                      std::is_integral_v<scalar_t<Index>>> = 0>
ENOKI_INLINE Array shuffle(const Array &a, const Index &idx) {
    if constexpr (Index::Depth > Array::Depth) {
        Array result;
        for (size_t i = 0; i < Array::Size; ++i)
            result.coeff(i) = shuffle(a.derived().coeff(i), idx);
        return result;
    } else {
        return a.derived().shuffle_((int_array_t<Array> &) idx);
    }
}

/// Compute the square of the given value
template <typename T> ENOKI_INLINE auto sqr(const T &value) {
    return value * value;
}

/// Convert radians to degrees
template <typename T> ENOKI_INLINE auto rad_to_deg(const T &a) {
    return a * scalar_t<T>(180 / M_PI);
}

/// Convert degrees to radians
template <typename T> ENOKI_INLINE auto deg_to_rad(const T &a) {
    return a * scalar_t<T>(M_PI / 180);
}

template <typename T> ENOKI_INLINE auto sign_mask() {
    using Scalar = scalar_t<T>;
    using UInt = uint_array_t<Scalar>;
    return memcpy_cast<Scalar>(UInt(1) << (sizeof(UInt) * 8 - 1));
}

template <typename T, typename Expr = expr_t<T>>
ENOKI_INLINE Expr sign(const T &a) {
    using Scalar = scalar_t<Expr>;

    if constexpr (array_depth_v<Expr> >= 2) {
        Expr result;
        for (size_t i = 0; i < Expr::Size; ++i)
            result.coeff(i) = sign(a.coeff(i));
        return result;
    } else if constexpr (!std::is_signed_v<Scalar>) {
        return Expr(Scalar(1));
    } else if constexpr (!std::is_floating_point_v<Scalar> || is_diff_array_v<Expr>) {
        return select(a < Scalar(0), Expr(Scalar(-1)), Expr(Scalar(1)));
    } else if constexpr (is_scalar_v<Expr>) {
        return std::copysign(Scalar(1), a);
    } else {
        return (sign_mask<Expr>() & a) | Expr(Scalar(1));
    }
}

template <typename T1, typename T2, typename Expr = expr_t<T1, T2>>
ENOKI_INLINE Expr copysign(const T1 &a1, const T2 &a2) {
    using Scalar1 = scalar_t<T1>;
    using Scalar2 = scalar_t<T2>;
    static_assert(std::is_same_v<Scalar1, Scalar2> || !std::is_signed_v<Scalar2>,
                  "copysign(): Incompatible input arguments!");

    if constexpr (!std::is_same_v<T1, Expr> || !std::is_same_v<T2, Expr>) {
        return copysign((const Expr &) a1, (const Expr &) a2);
    } else if constexpr (array_depth_v<Expr> >= 2) {
        Expr result;
        for (size_t i = 0; i < Expr::Size; ++i)
            result.coeff(i) = copysign(a1.coeff(i), a2.coeff(i));
        return result;
    } else if constexpr (!std::is_floating_point_v<Scalar1>) {
        return select((a1 ^ a2) < Scalar1(0), -a1, a1);
    } else if constexpr (is_scalar_v<Expr>) {
        return std::copysign(a1, a2);
    } else if constexpr (is_diff_array_v<Expr>) {
        return abs(a1) * sign(a2);
    } else {
        return abs(a1) | (sign_mask<Expr>() & a2);
    }
}

template <typename T1, typename T2, typename Expr = expr_t<T1, T2>>
ENOKI_INLINE Expr copysign_neg(const T1 &a1, const T2 &a2) {
    using Scalar1 = scalar_t<T1>;
    using Scalar2 = scalar_t<T2>;
    static_assert(std::is_same_v<Scalar1, Scalar2> || !std::is_signed_v<Scalar2>,
                  "copysign_neg(): Incompatible input arguments!");

    if constexpr (!std::is_same_v<T1, Expr> || !std::is_same_v<T2, Expr>) {
        return copysign_neg((const Expr &) a1, (const Expr &) a2);
    } else if constexpr (array_depth_v<Expr> >= 2) {
        Expr result;
        for (size_t i = 0; i < Expr::Size; ++i)
            result.coeff(i) = copysign_neg(a1.coeff(i), a2.coeff(i));
        return result;
    } else if constexpr (!std::is_floating_point_v<Scalar1>) {
        return select((a1 ^ a2) < Scalar1(0), a1, -a1);
    } else if constexpr (is_scalar_v<Expr>) {
        return std::copysign(a1, -a2);
    } else if constexpr (is_diff_array_v<Expr>) {
        return abs(a1) * -sign(a2);
    } else {
        return abs(a1) | andnot(sign_mask<Expr>(), a2);
    }
}
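// Usage sketch (illustrative): sign() and copysign() avoid branches where
// possible -- for hardware float arrays they manipulate the IEEE sign bit
// directly via sign_mask():
//
//     using FloatP = Array<float, 8>;
//     FloatP s = sign(x);         // -1 or +1 per lane
//     FloatP c = copysign(x, y);  // |x| with the sign of y, per lane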
template <typename T1, typename T2, typename Expr = expr_t<T1, T2>>
ENOKI_INLINE Expr mulsign(const T1 &a1, const T2 &a2) {
    using Scalar1 = scalar_t<T1>;
    using Scalar2 = scalar_t<T2>;
    static_assert(std::is_same_v<Scalar1, Scalar2> || !std::is_signed_v<Scalar2>,
                  "mulsign(): Incompatible input arguments!");

    if constexpr (!std::is_same_v<T1, Expr> || !std::is_same_v<T2, Expr>) {
        return mulsign((const Expr &) a1, (const Expr &) a2);
    } else if constexpr (array_depth_v<Expr> >= 2) {
        Expr result;
        for (size_t i = 0; i < Expr::Size; ++i)
            result.coeff(i) = mulsign(a1.coeff(i), a2.coeff(i));
        return result;
    } else if constexpr (!std::is_floating_point_v<Scalar1>) {
        return select(a2 < Scalar1(0), -a1, a1);
    } else if constexpr (is_scalar_v<Expr>) {
        return a1 * std::copysign(Scalar1(1), a2);
    } else if constexpr (is_diff_array_v<Expr>) {
        return a1 * sign(a2);
    } else {
        return a1 ^ (sign_mask<Expr>() & a2);
    }
}

template <typename T1, typename T2, typename Expr = expr_t<T1, T2>>
ENOKI_INLINE Expr mulsign_neg(const T1 &a1, const T2 &a2) {
    using Scalar1 = scalar_t<T1>;
    using Scalar2 = scalar_t<T2>;
    static_assert(std::is_same_v<Scalar1, Scalar2> || !std::is_signed_v<Scalar2>,
                  "mulsign_neg(): Incompatible input arguments!");

    if constexpr (!std::is_same_v<T1, Expr> || !std::is_same_v<T2, Expr>) {
        return mulsign_neg((const Expr &) a1, (const Expr &) a2);
    } else if constexpr (array_depth_v<Expr> >= 2) {
        Expr result;
        for (size_t i = 0; i < Expr::Size; ++i)
            result.coeff(i) = mulsign_neg(a1.coeff(i), a2.coeff(i));
        return result;
    } else if constexpr (!std::is_floating_point_v<Scalar1>) {
        return select(a2 < Scalar1(0), a1, -a1);
    } else if constexpr (is_scalar_v<Expr>) {
        return a1 * std::copysign(Scalar1(1), -a2);
    } else if constexpr (is_diff_array_v<Expr>) {
        return a1 * -sign(a2);
    } else {
        return a1 ^ andnot(sign_mask<Expr>(), a2);
    }
}

template <typename M, typename T, typename F>
ENOKI_INLINE auto select(const M &m, const T &t, const F &f) {
    using E = expr_t<T, F>;
    if constexpr (!is_array_v<E>)
        return (bool) m ? (E) t : (E) f;
    else if constexpr (std::is_same_v<M, mask_t<E>> &&
                       std::is_same_v<T, E> &&
                       std::is_same_v<F, E>)
        return E::select_(m.derived(), t.derived(), f.derived());
    else
        return select((const mask_t<E> &) m, (const E &) t, (const E &) f);
}

template <typename T1, typename T2, enable_if_array_any_t<T1, T2> = 0>
ENOKI_INLINE bool operator==(const T1 &a1, const T2 &a2) {
    return all_nested(eq(a1, a2));
}

template <typename T1, typename T2, enable_if_array_any_t<T1, T2> = 0>
ENOKI_INLINE bool operator!=(const T1 &a1, const T2 &a2) {
    return any_nested(neq(a1, a2));
}

namespace detail {
    template <typename T>
    using has_ror = decltype(std::declval<T>().template ror_<0>());
    template <typename T>
    constexpr bool has_ror_v = is_detected_v<has_ror, T>;
}

/// Bit-level rotate left (with immediate offset value)
template <size_t Imm, typename T>
ENOKI_INLINE auto rol(const T &a) {
    constexpr size_t Mask = 8 * sizeof(scalar_t<T>) - 1u;
    using UInt = uint_array_t<T>;
    if constexpr (detail::has_ror_v<T>)
        return a.template rol_<Imm>();
    else
        return sl<Imm>(a) | T(sr<((~Imm + 1u) & Mask)>(UInt(a)));
}

/// Bit-level rotate left (with scalar or array offset value)
template <typename T1, typename T2>
ENOKI_INLINE auto rol(const T1 &a1, const T2 &a2) {
    if constexpr (detail::has_ror_v<T1>) {
        return a1.rol_(a2);
    } else {
        using U1 = uint_array_t<T1>;
        using U2 = uint_array_t<T2>;
        using Expr = expr_t<T1, T2>;
        constexpr scalar_t<U2> Mask = 8 * sizeof(scalar_t<T1>) - 1u;
        U1 u1 = (U1) a1;
        U2 u2 = (U2) a2;
        return Expr((u1 << u2) | (u1 >> ((~u2 + 1u) & Mask)));
    }
}

/// Bit-level rotate right (with immediate offset value)
template <size_t Imm, typename T>
ENOKI_INLINE T ror(const T &a) {
    constexpr size_t Mask = 8 * sizeof(scalar_t<T>) - 1u;
    using UInt = uint_array_t<T>;
    if constexpr (detail::has_ror_v<T>)
        return a.template ror_<Imm>();
    else
        return T(sr<Imm>(UInt(a))) | sl<((~Imm + 1u) & Mask)>(a);
}

/// Bit-level rotate right (with scalar or array offset value)
template <typename T1, typename T2>
ENOKI_INLINE auto ror(const T1 &a1, const T2 &a2) {
    if constexpr (detail::has_ror_v<T1>) {
        return a1.ror_(a2);
    } else {
        using U1 = uint_array_t<T1>;
        using U2 = uint_array_t<T2>;
        using Expr = expr_t<T1, T2>;
        constexpr scalar_t<U2> Mask = 8 * sizeof(scalar_t<T1>) - 1u;
        U1 u1 = (U1) a1;
        U2 u2 = (U2) a2;
        return Expr((u1 >> u2) | (u1 << ((~u2 + 1u) & Mask)));
    }
}
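// Usage sketch (illustrative): the rotations use the native rol_/ror_
// methods when present (detected via detail::has_ror_v) and otherwise fall
// back to the usual shift/or construction:
//
//     using UInt32P = Array<uint32_t, 8>;
//     UInt32P r1 = rol<8>(x);   // per-lane rotate left by 8 bits
//     UInt32P r2 = ror(x, y);   // per-lane rotate right by a variable amount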
/// Fast implementation for computing the base 2 log of an integer
template <typename T> ENOKI_INLINE auto log2i(T value) {
    return scalar_t<T>(sizeof(scalar_t<T>) * 8 - 1) - lzcnt(value);
}

template <typename T> struct MaskBit {
    MaskBit(T &mask, size_t index) : mask(mask), index(index) { }
    operator bool() const { return mask.bit_(index); }
    MaskBit &operator=(bool b) { mask.set_bit_(index, b); return *this; }
private:
    T mask;
    size_t index;
};

template <typename Target, typename Source>
ENOKI_INLINE Target reinterpret_array(const Source &src) {
    if constexpr (std::is_same_v<Source, Target>) {
        return src;
    } else if constexpr (std::is_constructible_v<Target, const Source &,
                                                 detail::reinterpret_flag>) {
        return Target(src, detail::reinterpret_flag());
    } else if constexpr (is_scalar_v<Source> && is_scalar_v<Target>) {
        if constexpr (sizeof(Source) == sizeof(Target)) {
            return memcpy_cast<Target>(src);
        } else {
            using SrcInt = int_array_t<Source>;
            using TrgInt = int_array_t<Target>;
            if constexpr (std::is_same_v<Target, bool>)
                return memcpy_cast<SrcInt>(src) != 0 ? true : false;
            else
                return memcpy_cast<Target>(memcpy_cast<SrcInt>(src) != 0
                                               ? TrgInt(-1) : TrgInt(0));
        }
    } else {
        static_assert(detail::false_v<Target, Source>,
                      "reinterpret_array(): don't know what to do!");
    }
}

template <typename Target, typename T>
ENOKI_INLINE Target reinterpret_array(const MaskBit<T> &src) {
    return reinterpret_array<Target>((bool) src);
}

/// Element-wise test for NaN values
template <typename T> ENOKI_INLINE auto isnan(const T &a) {
    return !eq(a, a);
}

/// Element-wise test for +/- infinity
template <typename T> ENOKI_INLINE auto isinf(const T &a) {
    return eq(abs(a), std::numeric_limits<scalar_t<T>>::infinity());
}

/// Element-wise test for finiteness
template <typename T> ENOKI_INLINE auto isfinite(const T &a) {
    return abs(a) < std::numeric_limits<scalar_t<T>>::infinity();
}

/// Extract the low elements from an array of even size
template <typename Array, enable_if_t<(Array::Size > 1 &&
                                       Array::Size != -1)> = 0>
auto low(const Array &a) { return a.derived().low_(); }

/// Extract the high elements from an array of even size
template <typename Array, enable_if_t<(Array::Size > 1 &&
                                       Array::Size != -1)> = 0>
auto high(const Array &a) { return a.derived().high_(); }

template <typename T, typename Arg> T floor2int(const Arg &a) {
    if constexpr (is_array_v<Arg>)
        return a.template floor2int_<T>();
    else
        return detail::floor2int_scalar<T>(a);
}

template <typename T, typename Arg> T ceil2int(const Arg &a) {
    if constexpr (is_array_v<Arg>)
        return a.template ceil2int_<T>();
    else
        return detail::ceil2int_scalar<T>(a);
}
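// Usage sketch (illustrative): reinterpret_array() reinterprets bits rather
// than converting values, e.g. to view float lanes as their integer bit
// patterns; for scalars of equal size it reduces to memcpy_cast():
//
//     using FloatP  = Array<float, 8>;
//     using UInt32P = Array<uint32_t, 8>;
//     UInt32P bits = reinterpret_array<UInt32P>(x);   // same bits, new type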
// -----------------------------------------------------------------------
//! @{ \name Miscellaneous routines for vector spaces
// -----------------------------------------------------------------------

template <typename T1, typename T2>
ENOKI_INLINE auto abs_dot(const T1 &a1, const T2 &a2) {
    return abs(dot(a1, a2));
}

template <typename T> ENOKI_INLINE auto norm(const T &v) {
    return sqrt(dot(v, v));
}

template <typename T> ENOKI_INLINE auto squared_norm(const T &v) {
    return dot(v, v);
}

template <typename T> ENOKI_INLINE auto normalize(const T &v) {
    return v * rsqrt(squared_norm(v));
}

template <typename T, enable_if_t<is_array_v<T>> = 0>
ENOKI_INLINE auto partition(const T &v) {
    return v.partition_();
}

template <typename T1, typename T2,
          enable_if_t<array_size_v<T1> == 3 && array_size_v<T2> == 3> = 0>
ENOKI_INLINE auto cross(const T1 &v1, const T2 &v2) {
#if defined(ENOKI_ARM_32) || defined(ENOKI_ARM_64)
    return fnmadd(
        shuffle<2, 0, 1>(v1), shuffle<1, 2, 0>(v2),
        shuffle<1, 2, 0>(v1) * shuffle<2, 0, 1>(v2)
    );
#else
    return fmsub(shuffle<1, 2, 0>(v1), shuffle<2, 0, 1>(v2),
                 shuffle<2, 0, 1>(v1) * shuffle<1, 2, 0>(v2));
#endif
}

template <typename T> decltype(auto) detach(T &value) {
    if constexpr (is_array_v<T>) {
        if constexpr (!is_diff_array_v<T>)
            return value;
        else if constexpr (array_depth_v<T> == 1)
            return value.value_();
        else
            return struct_support_t<T>::detach(value);
    } else {
        return struct_support_t<T>::detach(value);
    }
}

template <typename T> decltype(auto) gradient(T &&value) {
    if constexpr (is_array_v<T>) {
        if constexpr (!is_diff_array_v<T>)
            return value;
        else if constexpr (array_depth_v<T> == 1)
            return value.gradient_();
        else
            return struct_support_t<T>::gradient(value);
    } else {
        return struct_support_t<T>::gradient(value);
    }
}

//! @}
// -----------------------------------------------------------------------
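// Usage sketch (illustrative): the vector-space helpers above compose the
// routed primitives, so they work equally for plain vectors and SoA-style
// nested arrays:
//
//     using Vector3f = Array<float, 3>;
//     Vector3f n = normalize(cross(a, b));   // unit normal of a triangle
//     float    d = abs_dot(n, v);            // |<n, v>|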
// -----------------------------------------------------------------------
//! @{ \name Initialization, loading/writing data
// -----------------------------------------------------------------------

template <typename T> ENOKI_INLINE T zero(size_t size = 1);
template <typename T> ENOKI_INLINE T empty(size_t size = 1);

/// Construct an index sequence, i.e. 0, 1, 2, ..
template <typename Array, enable_if_t<is_dynamic_array_v<Array>> = 0>
ENOKI_INLINE Array arange(size_t end = 1) {
    return Array::arange_(0, (ssize_t) end, 1);
}

template <typename Array, enable_if_t<is_static_array_v<Array>> = 0>
ENOKI_INLINE Array arange(size_t end = Array::Size) {
    assert(end == Array::Size);
    (void) end;
    return Array::arange_(0, (ssize_t) Array::Size, 1);
}

template <typename Arg, enable_if_t<!is_array_v<Arg>> = 0>
ENOKI_INLINE Arg arange(size_t end = 1) {
    assert(end == 1);
    (void) end;
    return Arg(0);
}

template <typename T>
ENOKI_INLINE T arange(ssize_t start, ssize_t end, ssize_t step = 1) {
    if constexpr (is_static_array_v<T>) {
        assert(end - start == (ssize_t) T::Size * step);
        return T::arange_(start, end, step);
    } else if constexpr (is_dynamic_array_v<T>) {
        return T::arange_(start, end, step);
    } else {
        assert(end - start == step);
        (void) end; (void) step;
        return T(start);
    }
}

/// Construct an array that linearly interpolates from min..max
template <typename Array, enable_if_t<is_dynamic_array_v<Array>> = 0>
ENOKI_INLINE Array linspace(scalar_t<Array> min, scalar_t<Array> max,
                            size_t size = 1) {
    return Array::linspace_(min, max, size);
}

template <typename Array, enable_if_t<is_static_array_v<Array>> = 0>
ENOKI_INLINE Array linspace(scalar_t<Array> min, scalar_t<Array> max,
                            size_t size = Array::Size) {
    assert(size == Array::Size);
    (void) size;
    return Array::linspace_(min, max);
}

/// Construct an array that linearly interpolates from min..max (scalar fallback)
template <typename Arg, enable_if_t<!is_array_v<Arg>> = 0>
ENOKI_INLINE Arg linspace(scalar_t<Arg> min, scalar_t<Arg>, size_t size = 1) {
    assert(size == 1);
    (void) size;
    return min;
}

template <typename Return, typename Inner = value_t<Return>>
ENOKI_INLINE Return full(const Inner &inner, size_t size = 1) {
    ENOKI_MARK_USED(size);
    if constexpr (std::is_scalar_v<Return>)
        return inner;
    else
        return Return::full_(inner, size);
}

/// Load an array from aligned memory
template <typename T> ENOKI_INLINE T load(const void *mem) {
    if constexpr (is_array_v<T>) {
        return T::load_(mem);
    } else {
        assert((uintptr_t) mem % alignof(T) == 0);
        return *static_cast<const T *>(mem);
    }
}

/// Load an array from aligned memory (masked)
template <typename T>
ENOKI_INLINE T load(const void *mem, const mask_t<T> &mask) {
    if constexpr (is_array_v<T>) {
        return T::load_(mem, mask);
    } else {
        if (mask) {
            assert((uintptr_t) mem % alignof(T) == 0);
            return *static_cast<const T *>(mem);
        } else {
            return T(0);
        }
    }
}

/// Load an array from unaligned memory
template <typename T> ENOKI_INLINE T load_unaligned(const void *mem) {
    if constexpr (is_array_v<T>)
        return T::load_unaligned_(mem);
    else
        return *static_cast<const T *>(mem);
}

/// Load an array from unaligned memory (masked)
template <typename T>
ENOKI_INLINE T load_unaligned(const void *mem, const mask_t<T> &mask) {
    if constexpr (is_array_v<T>)
        return T::load_unaligned_(mem, mask);
    else
        return mask ? *static_cast<const T *>(mem) : T(0);
}
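// Usage sketch (illustrative): initialization and aligned memory traffic:
//
//     using FloatP = Array<float, 8>;
//     FloatP i = arange<FloatP>();             // 0, 1, ..., 7
//     FloatP l = linspace<FloatP>(0.f, 1.f);   // 8 evenly spaced values
//     alignas(32) float buf[8];
//     store(buf, l);                           // aligned store
//     FloatP v = load<FloatP>(buf);            // aligned load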
/// Store an array to aligned memory
template <typename T> ENOKI_INLINE void store(void *mem, const T &value) {
    if constexpr (is_array_v<T>) {
        value.store_(mem);
    } else {
        assert((uintptr_t) mem % alignof(T) == 0);
        *static_cast<T *>(mem) = value;
    }
}

/// Store an array to aligned memory (masked)
template <typename T>
ENOKI_INLINE void store(void *mem, const T &value, const mask_t<T> &mask) {
    if constexpr (is_array_v<T>) {
        value.store_(mem, mask);
    } else {
        if (mask) {
            assert((uintptr_t) mem % alignof(T) == 0);
            *static_cast<T *>(mem) = value;
        }
    }
}

/// Store an array to unaligned memory
template <typename T>
ENOKI_INLINE void store_unaligned(void *mem, const T &value) {
    if constexpr (is_array_v<T>)
        value.store_unaligned_(mem);
    else
        *static_cast<T *>(mem) = value;
}

/// Store an array to unaligned memory (masked)
template <typename T>
ENOKI_INLINE void store_unaligned(void *mem, const T &value, const mask_t<T> &mask) {
    if constexpr (is_array_v<T>)
        value.store_unaligned_(mem, mask);
    else if (mask)
        *static_cast<T *>(mem) = value;
}

template <typename T1, typename T2, enable_if_array_any_t<T1, T2> = 0>
auto concat(const T1 &a1, const T2 &a2) {
    static_assert(std::is_same_v<scalar_t<T1>, scalar_t<T2>>,
                  "concat(): Scalar types must be identical");

    constexpr size_t Depth1 = array_depth_v<T1>,
                     Depth2 = array_depth_v<T2>,
                     Depth  = std::max(Depth1, Depth2),
                     Size1  = array_size_v<T1>,
                     Size2  = array_size_v<T2>,
                     Size   = Size1 + Size2;

    using Value  = expr_t<value_t<T1>, value_t<T2>>;
    using Result = Array<Value, Size>;

    if constexpr (Result::Size1 == Size1 && Result::Size2 == Size2 &&
                  Depth1 == 1 && Depth2 == 1) {
        return Result(a1, a2);
    } else if constexpr (Depth1 == 1 && Depth2 == 0 &&
                         T1::ActualSize == Size) {
        Result result(a1);
#if defined(ENOKI_X86_SSE42)
        if constexpr (std::is_same_v<value_t<Result>, float>)
            result.m = _mm_insert_ps(result.m, _mm_set_ss(a2), 0b00110000);
        else
#endif
            result.coeff(Size1) = a2;
        return result;
    } else {
        Result result;
        if constexpr (Depth1 == Depth) {
            for (size_t i = 0; i < Size1; ++i)
                result.coeff(i) = a1.derived().coeff(i);
        } else {
            result.coeff(0) = a1;
        }
        if constexpr (Depth2 == Depth) {
            for (size_t i = 0; i < Size2; ++i)
                result.coeff(i + Size1) = a2.derived().coeff(i);
        } else {
            result.coeff(Size1) = a2;
        }
        return result;
    }
}

namespace detail {
    template <typename Return, size_t Offset, typename T, size_t... Index>
    static ENOKI_INLINE Return extract(const T &a, std::index_sequence<Index...>) {
        return Return(a.coeff(Index + Offset)...);
    }
}

template <size_t Size, typename T,
          typename Return = Array<value_t<T>, Size>>
ENOKI_INLINE Return head(const T &a) {
    if constexpr (T::ActualSize == Return::ActualSize) {
        return a;
    } else if constexpr (T::Size1 == Size) {
        return low(a);
    } else {
        static_assert(Size <= array_size_v<T>, "Array size mismatch");
        return detail::extract<Return, 0>(a, std::make_index_sequence<Size>());
    }
}

template <size_t Size, typename T,
          typename Return = Array<value_t<T>, Size>>
ENOKI_INLINE Return tail(const T &a) {
    if constexpr (T::Size == Return::Size) {
        return a;
    } else if constexpr (T::Size2 == Size) {
        return high(a);
    } else {
        static_assert(Size <= array_size_v<T>, "Array size mismatch");
        return detail::extract<Return, T::Size - Size>(
            a, std::make_index_sequence<Size>());
    }
}

/// Masked extraction operation
template <typename Array, typename Mask>
ENOKI_INLINE auto extract(const Array &value, const Mask &mask) {
    if constexpr (is_array_v<Array>)
        return (value_t<Array>) value.extract_(mask);
    else
        return value;
}

//! @}
// -----------------------------------------------------------------------
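// Usage sketch (illustrative): concat(), head() and tail() convert between
// array sizes, e.g. when moving to/from homogeneous coordinates:
//
//     using Vector3f = Array<float, 3>;
//     using Vector4f = Array<float, 4>;
//     Vector4f p = concat(v, 1.f);   // (v.x, v.y, v.z, 1)
//     Vector3f q = head<3>(p);       // drop the last component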
// -----------------------------------------------------------------------
//! @{ \name CUDA-specific forward declarations
// -----------------------------------------------------------------------

/* Documentation in 'cuda.h' */
extern ENOKI_IMPORT void cuda_trace_printf(const char *, uint32_t, uint32_t *);
extern ENOKI_IMPORT void cuda_var_mark_dirty(uint32_t);
extern ENOKI_IMPORT void cuda_eval(bool log_assembly = false);
extern ENOKI_IMPORT void cuda_sync();
extern ENOKI_IMPORT void cuda_set_scatter_gather_operand(uint32_t index,
                                                         bool gather = false);
extern ENOKI_IMPORT void cuda_set_log_level(uint32_t);
extern ENOKI_IMPORT uint32_t cuda_log_level();

/// Fancy templated 'printf', which extracts the indices of Enoki arrays
template <typename... Args>
void cuda_printf(const char *fmt, const Args &... args) {
    uint32_t indices[] = { args.index()..., 0 };
    cuda_trace_printf(fmt, (uint32_t) sizeof...(Args), indices);
}

template <typename T, enable_if_t<!is_diff_array_v<T> &&
                                  !is_cuda_array_v<T>> = 0>
ENOKI_INLINE void set_label(T &, const char *) { }

//! @}
// -----------------------------------------------------------------------

// -----------------------------------------------------------------------
//! @{ \name Scatter/gather/prefetch operations
// -----------------------------------------------------------------------

NAMESPACE_BEGIN(detail)

template <typename Array, bool Packed, size_t Mult1 = 0, size_t Mult2 = 1,
          typename Guide = Array, typename Func, typename Index1,
          typename Index2, typename Mask>
ENOKI_INLINE decltype(auto) do_recursive(const Func &func, const Index1 &offset1,
                                         const Index2 &offset2, const Mask &mask) {
    if constexpr (array_depth_v<Index1> + array_depth_v<Index2> !=
                  array_depth_v<Array>) {
        using NewIndex      = enoki::Array<scalar_t<Index2>, Guide::Size>;
        using CombinedIndex = replace_scalar_t<Index2, NewIndex>;

        constexpr size_t Size =
            (Packed || (array_depth_v<Index1> + array_depth_v<Index2> + 1 !=
                        array_depth_v<Array>))
                ? Guide::Size
                : enoki::Array<scalar_t<Index2>,
                               Guide::Size>::ActualSize; /* Deal with n=3 special case */

        CombinedIndex combined_offset =
            CombinedIndex(offset2 * scalar_t<Index2>(Size)) +
            full<CombinedIndex>(arange<NewIndex>());

        return do_recursive<Array, Packed, Mult1, Mult2 * Size, value_t<Guide>>(
            func, offset1, combined_offset, mask);
    } else {
        using CombinedIndex = replace_scalar_t<Index2, Index1>;

        CombinedIndex combined_offset =
            CombinedIndex(offset2) +
            enoki::full<CombinedIndex>(offset1) *
                scalar_t<Index2>(Mult1 == 0 ? Mult2 : Mult1);

        return func(combined_offset, full<mask_t<CombinedIndex>>(mask));
    }
}

template <typename Array> constexpr size_t fix_stride(size_t Stride) {
    if (has_avx2) {
        if (Stride % 8 == 0)      return 8;
        else if (Stride % 4 == 0) return 4;
        else                      return 1;
    }
    return Stride;
}

NAMESPACE_END(detail)
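// Note (illustrative): do_recursive() above flattens a gather/scatter of a
// nested array into a single indexed access. E.g. for an Array<FloatP, 3>
// ("3 components of 8 lanes each") addressed with an 8-wide index, the
// combined per-component offsets are computed along the lines of
//
//     combined = index * 3 + {0, 1, 2}
//
// so each component of a packed record is fetched with one indexed read.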
/// Masked prefetch operation
template <typename Array, bool Write = false, size_t Level = 2, size_t Stride_ = 0,
          bool Packed = true, typename Index,
          typename Mask = mask_t<replace_scalar_t<Index, scalar_t<Array>>>>
ENOKI_INLINE void prefetch(const void *mem, const Index &index,
                           const identity_t<Mask> &mask = true) {
    static_assert(is_std_int_v<scalar_t<Index>>,
                  "prefetch(): expected a signed 32/64-bit integer as 'index' argument!");
    constexpr size_t ScalarSize = sizeof(scalar_t<Array>);

    if constexpr (!is_array_v<Array> && !is_array_v<Index>) {
        /* Scalar case */
#if defined(ENOKI_X86_SSE42)
        if (mask) {
            constexpr size_t Stride = (Stride_ != 0) ? Stride_ : ScalarSize;
            const uint8_t *ptr = (const uint8_t *) mem + index * Index(Stride);
            constexpr auto Hint = Level == 1 ? _MM_HINT_T0 : _MM_HINT_T1;
            _mm_prefetch((char *) ptr, Hint);
        }
#else
        (void) mem; (void) index; (void) mask;
#endif
    } else if constexpr (std::is_same_v<array_shape_t<Index>, array_shape_t<Array>>) {
        /* Forward to the array-specific implementation */
        constexpr size_t Stride  = (Stride_ != 0) ? Stride_ : ScalarSize,
                         Stride2 = detail::fix_stride<Array>(Stride);

        Index index2 = Stride != Stride2 ? index * scalar_t<Index>(Stride / Stride2)
                                         : index;

        Array::template prefetch_<Write, Level, Stride2>(mem, index2, mask);
    } else if constexpr (array_depth_v<Array> > array_depth_v<Index>) {
        /* Dimension mismatch, reduce to a sequence of prefetch operations */
        static_assert((Stride_ / ScalarSize) * ScalarSize == Stride_,
                      "Stride must be divisible by sizeof(Scalar)");

        return detail::do_recursive<Array, Packed, Stride_ / ScalarSize>(
            [mem](const auto &index2, const auto &mask2) ENOKI_INLINE_LAMBDA {
                constexpr size_t ScalarSize2 = sizeof(scalar_t<Array>); // needed for MSVC
                prefetch<Array, Write, Level, ScalarSize2, Packed>(mem, index2, mask2);
            },
            index, scalar_t<Index>(0), mask);
    } else {
        static_assert(detail::false_v<Array, Index>,
                      "prefetch(): don't know what to do with the input arguments!");
    }
}

/// Masked gather operation
template <typename Array, size_t Stride_ = 0, bool Packed = true, bool Masked = true,
          typename Index,
          typename Mask = mask_t<replace_scalar_t<Index, scalar_t<Array>>>>
ENOKI_INLINE Array gather(const void *mem, const Index &index,
                          const identity_t<Mask> &mask) {
    static_assert(is_std_int_v<scalar_t<Index>>,
                  "gather(): expected a signed 32/64-bit integer as 'index' argument!");
    constexpr size_t ScalarSize = sizeof(scalar_t<Array>);

    if constexpr (!is_array_v<Array> && !is_array_v<Index>) {
        /* Scalar case */
        constexpr size_t Stride = (Stride_ != 0) ? Stride_ : ScalarSize;
        const Array *ptr = (const Array *) ((const uint8_t *) mem + index * Index(Stride));
        return mask ? *ptr : Array(0);
    } else if constexpr (std::is_same_v<array_shape_t<Index>, array_shape_t<Array>>) {
        /* Forward to the array-specific implementation */
        constexpr size_t Stride  = (Stride_ != 0) ? Stride_ : ScalarSize,
                         Stride2 = detail::fix_stride<Array>(Stride);

        Index index2 = Stride != Stride2 ? index * scalar_t<Index>(Stride / Stride2)
                                         : index;

        return Array::template gather_<Stride2>(mem, index2, mask);
    } else if constexpr (array_depth_v<Array> == 1 && array_depth_v<Index> == 0) {
        /* Turn into a load */
        ENOKI_MARK_USED(mask);
        constexpr size_t Stride = (Stride_ != 0) ? Stride_ :
            (Packed ? (sizeof(value_t<Array>) * array_size_v<Array>) : sizeof(Array));

        if constexpr (Masked)
            return load_unaligned<Array>((const uint8_t *) mem + Stride * (size_t) index, mask);
        else
            return load_unaligned<Array>((const uint8_t *) mem + Stride * (size_t) index);
    } else if constexpr (array_depth_v<Array> > array_depth_v<Index>) {
        /* Dimension mismatch, reduce to a sequence of gather operations */
        static_assert((Stride_ / ScalarSize) * ScalarSize == Stride_,
                      "Stride must be divisible by sizeof(Scalar)");

        return detail::do_recursive<Array, Packed, Stride_ / ScalarSize>(
            [mem](const auto &index2, const auto &mask2) ENOKI_INLINE_LAMBDA {
                constexpr size_t ScalarSize2 = sizeof(scalar_t<Array>); // needed for MSVC
                return gather<Array, ScalarSize2>(mem, index2, mask2);
            },
            index, scalar_t<Index>(0), mask);
    } else {
        static_assert(detail::false_v<Array, Index>,
                      "gather(): don't know what to do with the input arguments!");
    }
}
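// Usage sketch (illustrative): gather() reads from non-contiguous memory
// locations in a single call (using hardware gather instructions where
// available):
//
//     using FloatP = Array<float, 8>;
//     using Int32P = Array<int32_t, 8>;
//     Int32P idx = ...;                       // 8 indices into 'data'
//     FloatP v = gather<FloatP>(data, idx);   // v[i] = data[idx[i]]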
/// Masked scatter operation
template <size_t Stride_ = 0, bool Packed = true, bool Masked = true,
          typename Array, typename Index,
          typename Mask = mask_t<replace_scalar_t<Index, scalar_t<Array>>>>
ENOKI_INLINE void scatter(void *mem, const Array &value, const Index &index,
                          const identity_t<Mask> &mask) {
    static_assert(is_std_int_v<scalar_t<Index>>,
                  "scatter(): expected a signed 32/64-bit integer as 'index' argument!");
    constexpr size_t ScalarSize = sizeof(scalar_t<Array>);

    if constexpr (!is_array_v<Array> && !is_array_v<Index>) {
        /* Scalar case */
        constexpr size_t Stride = (Stride_ != 0) ? Stride_ : ScalarSize;
        Array *ptr = (Array *) ((uint8_t *) mem + index * Index(Stride));
        if (mask)
            *ptr = value;
    } else if constexpr (std::is_same_v<array_shape_t<Index>, array_shape_t<Array>>) {
        /* Forward to the array-specific implementation */
        constexpr size_t Stride  = (Stride_ != 0) ? Stride_ : ScalarSize,
                         Stride2 = detail::fix_stride<Array>(Stride);

        Index index2 = Stride != Stride2 ? index * scalar_t<Index>(Stride / Stride2)
                                         : index;

        value.template scatter_<Stride2>(mem, index2, mask);
    } else if constexpr (array_depth_v<Array> == 1 && array_depth_v<Index> == 0) {
        /* Turn into a store */
        ENOKI_MARK_USED(mask);
        constexpr size_t Stride = (Stride_ != 0) ? Stride_ :
            (Packed ? (sizeof(value_t<Array>) * array_size_v<Array>) : sizeof(Array));

        if constexpr (Masked)
            return store_unaligned((uint8_t *) mem + Stride * (size_t) index, value, mask);
        else
            return store_unaligned((uint8_t *) mem + Stride * (size_t) index, value);
    } else if constexpr (array_depth_v<Array> > array_depth_v<Index>) {
        /* Dimension mismatch, reduce to a sequence of scatter operations */
        static_assert((Stride_ / ScalarSize) * ScalarSize == Stride_,
                      "Stride must be divisible by sizeof(Scalar)");

        detail::do_recursive<Array, Packed, Stride_ / ScalarSize>(
            [mem, &value](const auto &index2, const auto &mask2) ENOKI_INLINE_LAMBDA {
                constexpr size_t ScalarSize2 = sizeof(scalar_t<Array>); // needed for MSVC
                scatter<ScalarSize2>(mem, value, index2, mask2);
            },
            index, scalar_t<Index>(0), mask);
    } else {
        static_assert(detail::false_v<Array, Index>,
                      "scatter(): don't know what to do with the input arguments!");
    }
}

/// Unmasked gather operation
template <typename Array, size_t Stride = 0, bool Packed = true, typename Index>
ENOKI_INLINE Array gather(const void *mem, const Index &index) {
    return gather<Array, Stride, Packed, false>(mem, index, true);
}

/// Unmasked scatter operation
template <size_t Stride = 0, bool Packed = true, typename Array, typename Index>
ENOKI_INLINE void scatter(void *mem, const Array &value, const Index &index) {
    scatter<Stride, Packed, false>(mem, value, index, true);
}

#if defined(__GNUC__)
#  pragma GCC diagnostic push
#  pragma GCC diagnostic ignored "-Wunused-value"
#endif

/// Conflict-free modification operation
template <typename Arg, size_t Stride = sizeof(scalar_t<Arg>),
          typename Func, typename Index, typename... Args>
void transform(void *mem, const Index &index, Func &&func, Args &&... args) {
    static_assert(is_std_int_v<scalar_t<Index>>,
                  "transform(): index argument must be a 32/64-bit integer array!");

    if constexpr (is_array_v<Arg>) {
        using Int = int_array_t<Arg>;
        if constexpr ((false, ..., is_mask_v<Args>))
            Arg::template transform_<Stride>(mem, (const Int &) index,
                                             (..., args), func, args...);
        else
            Arg::template transform_<Stride>(mem, (const Int &) index,
                                             mask_t<Arg>(true), func,
                                             args..., mask_t<Arg>(true));
    } else {
        Arg &ref = *(Arg *) ((uint8_t *) mem + index * Index(Stride));
        if constexpr ((false, ..., is_mask_v<Args>)) {
            if ((..., args))
                func(ref, args...);
        } else {
            func(ref, args..., true);
        }
    }
}

#if defined(__GNUC__)
#  pragma GCC diagnostic pop
#endif

/// Conflict-free scatter-add update
template <size_t Stride_ = 0, typename Arg, typename Index>
ENOKI_INLINE void scatter_add(void *mem, const Arg &value, const Index &index,
                              mask_t<Arg> mask = true) {
    static_assert(is_std_int_v<scalar_t<Index>>,
                  "scatter_add(): index argument must be a 32/64-bit integer array!");
    constexpr size_t Stride = Stride_ == 0 ? sizeof(scalar_t<Arg>) : Stride_;
    if constexpr (is_array_v<Arg>) {
        value.template scatter_add_<Stride>(mem, index, mask);
    } else {
        Arg &ref = *(Arg *) ((uint8_t *) mem + index * Index(Stride));
        if (mask)
            ref += value;
    }
}

/// Prefetch operations with an array source
template <typename Array, bool Write = false, size_t Level = 2, size_t Stride = 0,
          bool Packed = true, typename Source, typename... Args,
          enable_if_t<array_depth_v<Source> == 1> = 0>
ENOKI_INLINE void prefetch(const Source &source, const Args &... args) {
    prefetch<Array, Write, Level, Stride, Packed>(source.data(), args...);
}

//! @}
// -----------------------------------------------------------------------
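// Usage sketch (illustrative): scatter_add() performs a conflict-safe
// accumulation, e.g. building a histogram where several lanes may target
// the same bin:
//
//     using FloatP = Array<float, 8>;
//     using Int32P = Array<int32_t, 8>;
//     float hist[16] = { 0.f };
//     Int32P bin = ...;                     // 8 bin indices
//     scatter_add(hist, FloatP(1.f), bin);  // hist[bin[i]] += 1 per lane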
// -----------------------------------------------------------------------
//! @{ \name Nested horizontal reduction operators
// -----------------------------------------------------------------------

template <typename T> auto hsum_nested(const T &a) {
    if constexpr (array_depth_v<T> == 1)
        return hsum(a);
    else if constexpr (is_array_v<T>)
        return hsum_nested(hsum(a));
    else
        return a;
}

template <typename T> auto hprod_nested(const T &a) {
    if constexpr (array_depth_v<T> == 1)
        return hprod(a);
    else if constexpr (is_array_v<T>)
        return hprod_nested(hprod(a));
    else
        return a;
}

template <typename T> auto hmin_nested(const T &a) {
    if constexpr (array_depth_v<T> == 1)
        return hmin(a);
    else if constexpr (is_array_v<T>)
        return hmin_nested(hmin(a));
    else
        return a;
}

template <typename T> auto hmax_nested(const T &a) {
    if constexpr (array_depth_v<T> == 1)
        return hmax(a);
    else if constexpr (is_array_v<T>)
        return hmax_nested(hmax(a));
    else
        return a;
}

template <typename T> auto hmean_nested(const T &a) {
    if constexpr (array_depth_v<T> == 1)
        return hmean(a);
    else if constexpr (is_array_v<T>)
        return hmean_nested(hmean(a));
    else
        return a;
}

template <typename T> auto count_nested(const T &a) {
    if constexpr (is_array_v<T>)
        return hsum_nested(count(a));
    else
        return count(a);
}

template <typename T> auto any_nested(const T &a) {
    if constexpr (is_array_v<T>)
        return any_nested(any(a));
    else
        return any(a);
}

template <typename T> auto all_nested(const T &a) {
    if constexpr (is_array_v<T>)
        return all_nested(all(a));
    else
        return all(a);
}

template <typename T> auto none_nested(const T &a) {
    return !any_nested(a);
}

/// Convert an array with 1 entry into a scalar or throw an error
template <typename T> scalar_t<T> scalar_cast(const T &v) {
    static_assert(array_depth_v<T> <= 1);
    if constexpr (is_array_v<T>) {
        if (v.size() != 1)
            throw std::runtime_error("scalar_cast(): array should be of size 1!");
        return v.coeff(0);
    } else {
        return v;
    }
}

template <typename T1, typename T2>
bool allclose(const T1 &a, const T2 &b, float rtol = 1e-5f, float atol = 1e-8f,
              bool equal_nan = false) {
    auto cond = abs(a - b) <= abs(b) * rtol + atol;

    if constexpr (std::is_floating_point_v<scalar_t<T1>> &&
                  std::is_floating_point_v<scalar_t<T2>>) {
        if (equal_nan)
            cond |= isnan(a) & isnan(b);
    }

    return all_nested(cond);
}

//! @}
// -----------------------------------------------------------------------

// -----------------------------------------------------------------------
//! @{ \name Reduction operators that return a default argument when
//!         invoked using CUDA arrays
// -----------------------------------------------------------------------

template <bool Default, typename T> auto any_or(const T &value) {
    if constexpr (is_cuda_array_v<T>) {
        ENOKI_MARK_USED(value);
        return Default;
    } else {
        return any(value);
    }
}

template <bool Default, typename T> auto any_nested_or(const T &value) {
    if constexpr (is_cuda_array_v<T>) {
        ENOKI_MARK_USED(value);
        return Default;
    } else {
        return any_nested(value);
    }
}

template <bool Default, typename T> auto none_or(const T &value) {
    if constexpr (is_cuda_array_v<T>) {
        ENOKI_MARK_USED(value);
        return Default;
    } else {
        return none(value);
    }
}

template <bool Default, typename T> auto none_nested_or(const T &value) {
    if constexpr (is_cuda_array_v<T>) {
        ENOKI_MARK_USED(value);
        return Default;
    } else {
        return none_nested(value);
    }
}

template <bool Default, typename T> auto all_or(const T &value) {
    if constexpr (is_cuda_array_v<T>) {
        ENOKI_MARK_USED(value);
        return Default;
    } else {
        return all(value);
    }
}

template <bool Default, typename T> auto all_nested_or(const T &value) {
    if constexpr (is_cuda_array_v<T>) {
        ENOKI_MARK_USED(value);
        return Default;
    } else {
        return all_nested(value);
    }
}
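// Usage sketch (illustrative): the *_nested() reductions collapse every
// level of a nested array, which is what allclose() above relies on:
//
//     using FloatP    = Array<float, 8>;
//     using Vector3fP = Array<FloatP, 3>;
//     Vector3fP v = ...;
//     float total = hsum_nested(v);          // sum over all 3*8 entries
//     bool  ok    = allclose(v, v + 1e-7f);  // true within rtol/atol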
//! @}
// -----------------------------------------------------------------------

#undef ENOKI_ROUTE_UNARY
#undef ENOKI_ROUTE_UNARY_IMM
#undef ENOKI_ROUTE_UNARY_SCALAR
#undef ENOKI_ROUTE_UNARY_SCALAR_IMM
#undef ENOKI_ROUTE_BINARY
#undef ENOKI_ROUTE_BINARY_BITOP
#undef ENOKI_ROUTE_BINARY_COND
#undef ENOKI_ROUTE_BINARY_SHIFT
#undef ENOKI_ROUTE_BINARY_SCALAR
#undef ENOKI_ROUTE_TERNARY_SCALAR
#undef ENOKI_ROUTE_COMPOUND_OPERATOR

NAMESPACE_END(enoki)