/* enoki/array_recursive.h -- Template specialization that recursively instantiates Array instances with smaller sizes when the requested packet float array size is not directly supported by the processor's SIMD instructions Enoki is a C++ template library that enables transparent vectorization of numerical kernels using ENOKI instruction sets available on current processor architectures. Copyright (c) 2019 Wenzel Jakob All rights reserved. Use of this source code is governed by a BSD-style license that can be found in the LICENSE file. */ #pragma once #include NAMESPACE_BEGIN(enoki) template struct StaticArrayImpl::use_recursive_impl>> : StaticArrayBase { using Base = StaticArrayBase; ENOKI_ARRAY_IMPORT_BASIC(Base, StaticArrayImpl) using typename Base::Array1; using typename Base::Array2; using Base::Size1; using Base::Size2; using Ref = const Derived &; static constexpr bool IsRecursive = true; StaticArrayImpl() = default; /// Initialize all entries with a constant ENOKI_INLINE StaticArrayImpl(const Value &value) : a1(value), a2(value) { } /// Initialize from a list of component values template ...>> = 0> ENOKI_INLINE StaticArrayImpl(Ts... args) { alignas(alignof(Array1)) Value storage[Size] = { (Value) args... }; a1 = load(storage); a2 = load(storage + Size1); } /// Construct from the two sub-components template = 0> ENOKI_INLINE StaticArrayImpl(const T1 &a1, const T2 &a2) : a1(a1), a2(a2) { } /// Cast another array template = 0> ENOKI_INLINE StaticArrayImpl( const StaticArrayBase &a) : a1(low(a)), a2(high(a)) { } /// Reinterpret another array template = 0> ENOKI_INLINE StaticArrayImpl( const StaticArrayBase &a, detail::reinterpret_flag) : a1(low (a), detail::reinterpret_flag()), a2(high(a), detail::reinterpret_flag()) { } /// Reinterpret another array (masks) template = 0> ENOKI_INLINE StaticArrayImpl(bool b, detail::reinterpret_flag) : a1(b, detail::reinterpret_flag()), a2(b, detail::reinterpret_flag()) { } template = 0> ENOKI_INLINE StaticArrayImpl &operator=(Value_ v) { *this = StaticArrayImpl(v); return *this; } template = 0> ENOKI_INLINE StaticArrayImpl &operator=(bool v) { *this = StaticArrayImpl(v, detail::reinterpret_flag()); return *this; } // ----------------------------------------------------------------------- //! @{ \name Vertical operations // ----------------------------------------------------------------------- ENOKI_INLINE Derived add_(Ref a) const { return Derived(a1 + a.a1, a2 + a.a2); } ENOKI_INLINE Derived sub_(Ref a) const { return Derived(a1 - a.a1, a2 - a.a2); } ENOKI_INLINE Derived mul_(Ref a) const { return Derived(a1 * a.a1, a2 * a.a2); } ENOKI_INLINE Derived div_(Ref a) const { return Derived(a1 / a.a1, a2 / a.a2); } ENOKI_INLINE Derived mod_(Ref a) const { return Derived(a1 % a.a1, a2 % a.a2); } ENOKI_INLINE Derived mulhi_(Ref a) const { return Derived(mulhi(a1, a.a1), mulhi(a2, a.a2)); } ENOKI_INLINE Derived fmod_(Ref a) const { return Derived(fmod(a1, a.a1), fmod(a2, a.a2)); } ENOKI_INLINE auto lt_ (Ref a) const { return mask_t(a1 < a.a1, a2 < a.a2); } ENOKI_INLINE auto gt_ (Ref a) const { return mask_t(a1 > a.a1, a2 > a.a2); } ENOKI_INLINE auto le_ (Ref a) const { return mask_t(a1 <= a.a1, a2 <= a.a2); } ENOKI_INLINE auto ge_ (Ref a) const { return mask_t(a1 >= a.a1, a2 >= a.a2); } ENOKI_INLINE auto eq_ (Ref a) const { return mask_t(eq(a1, a.a1), eq(a2, a.a2)); } ENOKI_INLINE auto neq_(Ref a) const { return mask_t(neq(a1, a.a1), neq(a2, a.a2)); } ENOKI_INLINE Derived min_(Ref a) const { return Derived(min(a1, a.a1), min(a2, a.a2)); } ENOKI_INLINE Derived max_(Ref a) const { return Derived(max(a1, a.a1), max(a2, a.a2)); } ENOKI_INLINE Derived abs_() const { return Derived(abs(a1), abs(a2)); } ENOKI_INLINE Derived ceil_() const { return Derived(ceil(a1), ceil(a2)); } ENOKI_INLINE Derived floor_() const { return Derived(floor(a1), floor(a2)); } ENOKI_INLINE Derived sqrt_() const { return Derived(sqrt(a1), sqrt(a2)); } ENOKI_INLINE Derived round_() const { return Derived(round(a1), round(a2)); } ENOKI_INLINE Derived trunc_() const { return Derived(trunc(a1), trunc(a2)); } ENOKI_INLINE Derived rcp_() const { return Derived(rcp(a1), rcp(a2)); } ENOKI_INLINE Derived rsqrt_() const { return Derived(rsqrt(a1), rsqrt(a2)); } ENOKI_INLINE Derived not_() const { return Derived(~a1, ~a2); } ENOKI_INLINE Derived neg_() const { return Derived(-a1, -a2); } ENOKI_INLINE Derived fmadd_(Ref b, Ref c) const { return Derived(fmadd(a1, b.a1, c.a1), fmadd(a2, b.a2, c.a2)); } ENOKI_INLINE Derived fnmadd_(Ref b, Ref c) const { return Derived(fnmadd(a1, b.a1, c.a1), fnmadd(a2, b.a2, c.a2)); } ENOKI_INLINE Derived fmsub_(Ref b, Ref c) const { return Derived(fmsub(a1, b.a1, c.a1), fmsub(a2, b.a2, c.a2)); } ENOKI_INLINE Derived fnmsub_(Ref b, Ref c) const { return Derived(fnmsub(a1, b.a1, c.a1), fnmsub(a2, b.a2, c.a2)); } ENOKI_INLINE Derived fmaddsub_(Ref b, Ref c) const { return Derived(fmaddsub(a1, b.a1, c.a1), fmaddsub(a2, b.a2, c.a2)); } ENOKI_INLINE Derived fmsubadd_(Ref b, Ref c) const { return Derived(fmsubadd(a1, b.a1, c.a1), fmsubadd(a2, b.a2, c.a2)); } template ENOKI_INLINE Derived or_(const T &a) const { return Derived(a1 | low(a), a2 | high(a)); } template ENOKI_INLINE Derived andnot_(const T &a) const { return Derived(andnot(a1, low(a)), andnot(a2, high(a))); } template ENOKI_INLINE Derived and_(const T &a) const { return Derived(a1 & low(a), a2 & high(a)); } template ENOKI_INLINE Derived xor_(const T &a) const { return Derived(a1 ^ low(a), a2 ^ high(a)); } template ENOKI_INLINE Derived sl_() const { return Derived(sl(a1), sl(a2)); } ENOKI_INLINE Derived sl_(size_t k) const { return Derived(a1 << k, a2 << k); } ENOKI_INLINE Derived sl_(Ref a) const { return Derived(a1 << a.a1, a2 << a.a2); } template ENOKI_INLINE Derived sr_() const { return Derived(sr(a1), sr(a2)); } ENOKI_INLINE Derived sr_(size_t k) const { return Derived(a1 >> k, a2 >> k); } ENOKI_INLINE Derived sr_(Ref a) const { return Derived(a1 >> a.a1, a2 >> a.a2); } template ENOKI_INLINE Derived rol_() const { return Derived(rol(a1), rol(a2)); } template ENOKI_INLINE Derived ror_() const { return Derived(ror(a1), ror(a2)); } ENOKI_INLINE Derived rol_(Ref arg) const { return Derived(rol(a1, arg.a1), rol(a2, arg.a2)); } ENOKI_INLINE Derived ror_(Ref arg) const { return Derived(ror(a1, arg.a1), ror(a2, arg.a2)); } template static ENOKI_INLINE Derived select_(const Mask &m, Ref t, Ref f) { return Derived(select(m.a1, t.a1, f.a1), select(m.a2, t.a2, f.a2)); } template ENOKI_INLINE Derived ror_array_() const { if constexpr (Size1 == Size2) { static_assert( Imm <= Size1 && Imm <= Size2, "ror_array(): Refusing to rotate a recursively defined array by an " "amount that is larger than the recursive array sizes."); const mask_t mask = arange() >= Scalar(Imm); Array1 a1_r = ror_array(a1); Array2 a2_r = ror_array(a2); return Derived( select(mask, a1_r, a2_r), select(mask, a2_r, a1_r) ); } else { return Base::template ror_array_(); } } template ENOKI_INLINE Derived rol_array_() const { if constexpr (Size1 == Size2) { static_assert( Imm <= Size1 && Imm <= Size2, "rol_array(): Refusing to rotate a recursively defined array " "by an amount that is larger than the recursive array sizes."); const mask_t mask = arange() < Scalar(Size1 - Imm); Array1 a1_r = rol_array(a1); Array2 a2_r = rol_array(a2); return Derived( select(mask, a1_r, a2_r), select(mask, a2_r, a1_r) ); } else { return Base::template rol_array_(); } } Derived ldexp_(Ref a) const { return Derived(ldexp(a1, a.a1), ldexp(a2, a.a2)); } std::pair frexp_() const { auto r1 = frexp(a1); auto r2 = frexp(a2); return std::make_pair( Derived(r1.first, r2.first), Derived(r1.second, r2.second) ); } template ENOKI_INLINE auto ceil2int_() const { return T(ceil2int(a1), ceil2int(a2)); } template ENOKI_INLINE auto floor2int_() const { return T(floor2int(a1), floor2int(a2)); } Derived lzcnt_() const { return Derived(lzcnt(a1), lzcnt(a2)); } Derived tzcnt_() const { return Derived(tzcnt(a1), tzcnt(a2)); } Derived popcnt_() const { return Derived(popcnt(a1), popcnt(a2)); } template static constexpr auto split_(std::index_sequence, std::index_sequence) { constexpr std::array out { Is ... }; return std::make_pair(std::index_sequence(), std::index_sequence()); } template ENOKI_INLINE Derived shuffle_() const { if constexpr (Size1 != Size2) { return Base::template shuffle_(); } else { constexpr auto indices = split_(std::index_sequence(), std::make_index_sequence()); return shuffle_impl_(indices.first, indices.second); } } template ENOKI_INLINE Derived shuffle_impl_(std::index_sequence, std::index_sequence) const { using Int = int_array_t; Array1 a1l = a1.template shuffle_<(size_t) std::min(Size1 - 1, Indices1)...>(), a1h = a2.template shuffle_<(size_t) std::max((ssize_t) 0, (ssize_t) Indices1 - (ssize_t) Size1)...>(), a1f = select(Int(Indices1...) < Int(Size1), a1l, a1h); Array2 a2l = a1.template shuffle_(), a2h = a2.template shuffle_<(size_t) std::max((ssize_t) 0, (ssize_t) Indices2 - (ssize_t) Size1)...>(), a2f = select(Int(Indices2...) < Int(Size1), a2l, a2h); return Derived(a1f, a2f); } template ENOKI_INLINE Derived shuffle_(const Index &index) const { if constexpr (Size1 != Size2) { return Base::shuffle_(index); } else { auto il = low(index), ih = high(index); decltype(il) size = scalar_t(Size1); Array1 a1l = a1.shuffle_(il), a1h = a2.shuffle_(il - size), a1f = select(il < size, a1l, a1h); Array2 a2l = a1.shuffle_(ih), a2h = a2.shuffle_(ih - size), a2f = select(ih < size, a2l, a2h); return Derived(a1f, a2f); } } #define ENOKI_MASKED_OPERATOR(name) \ template \ ENOKI_INLINE void m##name##_(Ref value, const Mask &mask) { \ a1.m##name##_(low(value), low(mask)); \ a2.m##name##_(high(value), high(mask)); \ } ENOKI_MASKED_OPERATOR(assign) ENOKI_MASKED_OPERATOR(add) ENOKI_MASKED_OPERATOR(sub) ENOKI_MASKED_OPERATOR(mul) ENOKI_MASKED_OPERATOR(div) ENOKI_MASKED_OPERATOR(and) ENOKI_MASKED_OPERATOR(or) ENOKI_MASKED_OPERATOR(xor) #undef ENOKI_MASKED_OPERATOR //! @} // ----------------------------------------------------------------------- // ----------------------------------------------------------------------- //! @{ \name Horizontal operations // ----------------------------------------------------------------------- ENOKI_INLINE Value hsum_() const { if constexpr (Size1 == Size2) return hsum(a1 + a2); else return hsum(a1) + hsum(a2); } ENOKI_INLINE Value hprod_() const { if constexpr (Size1 == Size2) return hprod(a1 * a2); else return hprod(a1) * hprod(a2); } ENOKI_INLINE Value hmin_() const { if constexpr (Size1 == Size2) return hmin(min(a1, a2)); else return min(hmin(a1), hmin(a2)); } ENOKI_INLINE Value hmax_() const { if constexpr (Size1 == Size2) return hmax(max(a1, a2)); else return max(hmax(a1), hmax(a2)); } ENOKI_INLINE Value dot_(Ref a) const { if constexpr (Size1 == Size2) return hsum(fmadd(a1, a.a1, a2 * a.a2)); else return dot(a1, a.a1) + dot(a2, a.a2); } ENOKI_INLINE bool all_() const { if constexpr (Size1 == Size2) return all(a1 & a2); else return all(a1) && all(a2); } ENOKI_INLINE bool any_() const { if constexpr (Size1 == Size2) return any(a1 | a2); else return any(a1) || any(a2); } ENOKI_INLINE size_t count_() const { return count(a1) + count(a2); } //! @} // ----------------------------------------------------------------------- // ----------------------------------------------------------------------- //! @{ \name Initialization, loading/writing data // ----------------------------------------------------------------------- ENOKI_INLINE void store_(void *mem) const { store((uint8_t *) mem, a1); store((uint8_t *) mem + sizeof(Array1), a2); } template ENOKI_INLINE void store_(void *mem, const Mask &mask) const { store((uint8_t *) mem, a1, low(mask)); store((uint8_t *) mem + sizeof(Array1), a2, high(mask)); } ENOKI_INLINE void store_unaligned_(void *mem) const { store_unaligned((uint8_t *) mem, a1); store_unaligned((uint8_t *) mem + sizeof(Array1), a2); } template ENOKI_INLINE void store_unaligned_(void *mem, const Mask &mask) const { store_unaligned((uint8_t *) mem, a1, low(mask)); store_unaligned((uint8_t *) mem + sizeof(Array1), a2, high(mask)); } static ENOKI_INLINE Derived load_(const void *mem) { return Derived( load((uint8_t *) mem), load((uint8_t *) mem + sizeof(Array1)) ); } template static ENOKI_INLINE Derived load_(const void *mem, const Mask &mask) { return Derived( load((uint8_t *) mem, low(mask)), load((uint8_t *) mem + sizeof(Array1), high(mask)) ); } static ENOKI_INLINE Derived load_unaligned_(const void *a) { return Derived( load_unaligned((uint8_t *) a), load_unaligned((uint8_t *) a + sizeof(Array1)) ); } template static ENOKI_INLINE Derived load_unaligned_(const void *a, const Mask &mask) { return Derived( load_unaligned((uint8_t *) a, low(mask)), load_unaligned((uint8_t *) a + sizeof(Array1), high(mask)) ); } static ENOKI_INLINE Derived zero_() { return Derived(zero(), zero()); } template static ENOKI_INLINE void prefetch_(const void *ptr, const Index &index, const Mask &mask) { prefetch(ptr, low(index), low(mask)); prefetch(ptr, high(index), high(mask)); } template static ENOKI_INLINE Derived gather_(const void *ptr, const Index &index, const Mask &mask) { return Derived( gather(ptr, low(index), low(mask)), gather(ptr, high(index), high(mask)) ); } template ENOKI_INLINE void scatter_(void *ptr, const Index &index, const Mask &mask) const { scatter(ptr, a1, low(index), low(mask)); scatter(ptr, a2, high(index), high(mask)); } template static ENOKI_INLINE void transform_(void *ptr, const Index &index, const Mask &, const Func &func, const Args &... args) { transform(ptr, low(index), func, low(args)...); transform(ptr, high(index), func, high(args)...); } template ENOKI_INLINE Value extract_(const Mask &mask) const { if constexpr (Size1 == Size2) { return extract(select(low(mask), a1, a2), low(mask) | high(mask)); } else { if (ENOKI_LIKELY(any(low(mask)))) return extract(a1, low(mask)); else return extract(a2, high(mask)); } } template ENOKI_INLINE size_t compress_(T *&ptr, const Mask &mask) const { size_t r0 = compress(ptr, a1, low(mask)); size_t r1 = compress(ptr, a2, high(mask)); return r0 + r1; } //! @} // ----------------------------------------------------------------------- // ----------------------------------------------------------------------- //! @{ \name Component access // ----------------------------------------------------------------------- ENOKI_INLINE const Array1& low_() const { return a1; } ENOKI_INLINE const Array2& high_() const { return a2; } ENOKI_INLINE decltype(auto) coeff(size_t i) const { if constexpr (Size1 == Size2) return ((i < Size1) ? a1 : a2).coeff(i % Size1); else return (i < Size1) ? a1.coeff(i) : a2.coeff(i - Size1); } ENOKI_INLINE decltype(auto) coeff(size_t i) { if constexpr (Size1 == Size2) return ((i < Size1) ? a1 : a2).coeff(i % Size1); else return (i < Size1) ? a1.coeff(i) : a2.coeff(i - Size1); } //! @} // ----------------------------------------------------------------------- Array1 a1; Array2 a2; }; NAMESPACE_END(enoki)