From 8b90d7f9eb9d8763440d6d7c2a95f78fa4819a34 Mon Sep 17 00:00:00 2001
From: Yun Hsiao Wu
Date: Mon, 29 Mar 2021 18:18:43 +0800
Subject: [PATCH] enoki headers (#140)

---
 sources/SocketRocket/CMakeLists.txt |    6 +-
 sources/enoki/array.h               |  182 ++
 sources/enoki/array_avx.h           | 1173 +++++++++++++
 sources/enoki/array_avx2.h          | 1257 ++++++++++++++
 sources/enoki/array_avx512.h        | 1928 +++++++++++++++++++++
 sources/enoki/array_base.h          |  240 +++
 sources/enoki/array_call.h          |  291 ++++
 sources/enoki/array_enum.h          |   82 +
 sources/enoki/array_fallbacks.h     |  546 ++++++
 sources/enoki/array_generic.h       |  626 +++++++
 sources/enoki/array_idiv.h          |  327 ++++
 sources/enoki/array_intrin.h        |  326 ++++
 sources/enoki/array_kmask.h         |  296 ++++
 sources/enoki/array_macro.h         |  419 +++++
 sources/enoki/array_masked.h        |   92 +
 sources/enoki/array_math.h          | 1505 +++++++++++++++++
 sources/enoki/array_neon.h          | 1305 +++++++++++++++
 sources/enoki/array_recursive.h     |  556 ++++++
 sources/enoki/array_round.h         |  156 ++
 sources/enoki/array_router.h        | 1400 ++++++++++++++++
 sources/enoki/array_sse42.h         | 2410 +++++++++++++++++++++++++++
 sources/enoki/array_static.h        | 1231 ++++++++++++++
 sources/enoki/array_struct.h        |  544 ++++++
 sources/enoki/array_traits.h        |  615 +++++++
 sources/enoki/array_utils.h         |  200 +++
 sources/enoki/autodiff.h            | 1569 +++++++++++++++++
 sources/enoki/color.h               |   95 ++
 sources/enoki/complex.h             |  289 ++++
 sources/enoki/cuda.h                | 1026 ++++++++++++
 sources/enoki/dynamic.h             | 1145 +++++++++++++
 sources/enoki/fwd.h                 |  330 ++++
 sources/enoki/half.h                |  193 +++
 sources/enoki/matrix.h              |  658 ++++++++
 sources/enoki/morton.h              |  161 ++
 sources/enoki/python.h              |  229 +++
 sources/enoki/quaternion.h          |  361 ++++
 sources/enoki/random.h              |  333 ++++
 sources/enoki/sh.h                  |  843 ++++++++++
 sources/enoki/special.h             |  675 ++++++++
 sources/enoki/stl.h                 |  323 ++++
 sources/enoki/transform.h           |  202 +++
 41 files changed, 26142 insertions(+), 3 deletions(-)
 create mode 100644 sources/enoki/array.h
 create mode 100644 sources/enoki/array_avx.h
 create mode 100644 sources/enoki/array_avx2.h
 create mode 100644 sources/enoki/array_avx512.h
 create mode 100644 sources/enoki/array_base.h
 create mode 100644 sources/enoki/array_call.h
 create mode 100644 sources/enoki/array_enum.h
 create mode 100644 sources/enoki/array_fallbacks.h
 create mode 100644 sources/enoki/array_generic.h
 create mode 100644 sources/enoki/array_idiv.h
 create mode 100644 sources/enoki/array_intrin.h
 create mode 100644 sources/enoki/array_kmask.h
 create mode 100644 sources/enoki/array_macro.h
 create mode 100644 sources/enoki/array_masked.h
 create mode 100644 sources/enoki/array_math.h
 create mode 100644 sources/enoki/array_neon.h
 create mode 100644 sources/enoki/array_recursive.h
 create mode 100644 sources/enoki/array_round.h
 create mode 100644 sources/enoki/array_router.h
 create mode 100644 sources/enoki/array_sse42.h
 create mode 100644 sources/enoki/array_static.h
 create mode 100644 sources/enoki/array_struct.h
 create mode 100644 sources/enoki/array_traits.h
 create mode 100644 sources/enoki/array_utils.h
 create mode 100644 sources/enoki/autodiff.h
 create mode 100644 sources/enoki/color.h
 create mode 100644 sources/enoki/complex.h
 create mode 100644 sources/enoki/cuda.h
 create mode 100644 sources/enoki/dynamic.h
 create mode 100644 sources/enoki/fwd.h
 create mode 100644 sources/enoki/half.h
 create mode 100644 sources/enoki/matrix.h
 create mode 100644 sources/enoki/morton.h
 create mode 100644 sources/enoki/python.h
 create mode 100644 sources/enoki/quaternion.h
 create mode 100644 sources/enoki/random.h
 create mode 100644 
sources/enoki/sh.h create mode 100644 sources/enoki/special.h create mode 100644 sources/enoki/stl.h create mode 100644 sources/enoki/transform.h diff --git a/sources/SocketRocket/CMakeLists.txt b/sources/SocketRocket/CMakeLists.txt index ae2ed906..928f7c6d 100644 --- a/sources/SocketRocket/CMakeLists.txt +++ b/sources/SocketRocket/CMakeLists.txt @@ -47,7 +47,7 @@ set(SOCKET_ROCKET_SOURCES set(SOCKET_ROCKET_SOURCES_M ${SOCKET_ROCKET_SOURCES}) list(FILTER SOCKET_ROCKET_SOURCES_M INCLUDE REGEX ".*m$") -set_source_files_properties(${SOCKET_ROCKET_SOURCES_M} PROPERTIES COMPILE_FLAGS +set_source_files_properties(${SOCKET_ROCKET_SOURCES_M} PROPERTIES COMPILE_FLAGS -fobjc-arc ) @@ -62,5 +62,5 @@ list(APPEND CC_EXTERNAL_PRIVATE_INCLUDES ${CMAKE_CURRENT_LIST_DIR}/Internal/Proxy ) -list(APPEND CC_EXTERNAL_SROUCES ${SOCKET_ROCKET_SOURCES}) -list(APPEND CC_EXTERNAL_INCLUDES ${CMAKE_CURRENT_LIST_DIR}) \ No newline at end of file +list(APPEND CC_EXTERNAL_SOURCES ${SOCKET_ROCKET_SOURCES}) +list(APPEND CC_EXTERNAL_INCLUDES ${CMAKE_CURRENT_LIST_DIR}) diff --git a/sources/enoki/array.h b/sources/enoki/array.h new file mode 100644 index 00000000..c4b12360 --- /dev/null +++ b/sources/enoki/array.h @@ -0,0 +1,182 @@ +/* + enoki/array.h -- Main header file for the Enoki array class and + various template specializations + + Enoki is a C++ template library that enables transparent vectorization + of numerical kernels using SIMD instruction sets available on current + processor architectures. + + Copyright (c) 2019 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a BSD-style + license that can be found in the LICENSE file. +*/ + +#pragma once + +#if defined(_MSC_VER) +# pragma warning(push) +# pragma warning(disable: 4146) // warning C4146: unary minus operator applied to unsigned type, result still unsigned +# pragma warning(disable: 4554) // warning C4554: '>>': check operator precedence for possible error; use parentheses to clarify precedence +# pragma warning(disable: 4702) // warning C4702: unreachable code +# pragma warning(disable: 4522) // warning C4522: multiple assignment operators specified +# pragma warning(disable: 4310) // warning C4310: cast truncates constant value +# pragma warning(disable: 4127) // warning C4127: conditional expression is constant +#elif defined(__GNUC__) && !defined(__clang__) +# pragma GCC diagnostic push +# pragma GCC diagnostic ignored "-Wclass-memaccess" +#endif + +#include + +#include + +#if defined(ENOKI_ARM_NEON) || defined(ENOKI_X86_SSE42) +# include +#endif + +#if defined(ENOKI_X86_AVX512F) +# include +#endif + +#if defined(ENOKI_X86_SSE42) +# include +#endif + +#if defined(ENOKI_X86_AVX) +# include +#endif + +#if defined(ENOKI_X86_AVX2) +# include +#endif + +#if defined(ENOKI_X86_AVX512F) +# include +#endif + +#if defined(ENOKI_ARM_NEON) +# include +#endif + +#include +#include +#include +#include +#include + +#include + +NAMESPACE_BEGIN(enoki) + +template +struct Array : StaticArrayImpl> { + + using Base = StaticArrayImpl>; + + using ArrayType = Array; + using MaskType = Mask; + + /// Type alias for creating a similar-shaped array over a different type + template using ReplaceValue = Array; + + ENOKI_ARRAY_IMPORT(Base, Array) +}; + +template +struct Mask : StaticArrayImpl> { + + using Base = StaticArrayImpl>; + + using ArrayType = Array; + using MaskType = Mask; + + /// Type alias for creating a similar-shaped array over a different type + template using ReplaceValue = Mask; + + Mask() = default; + + template Mask(T &&value) + : 
Base(std::forward(value), detail::reinterpret_flag()) { } + + template Mask(T &&value, detail::reinterpret_flag) + : Base(std::forward(value), detail::reinterpret_flag()) { } + + /// Construct from sub-arrays + template == array_depth_v && array_size_v == Base::Size1 && + array_depth_v == array_depth_v && array_size_v == Base::Size2 && + Base::Size2 != 0> = 0> + Mask(const T1 &a1, const T2 &a2) + : Base(a1, a2) { } + + template ...>> = 0> + Mask(Ts&&... ts) : Base(std::forward(ts)...) { } + + ENOKI_ARRAY_IMPORT_BASIC(Base, Mask) + using Base::operator=; +}; + +template +struct Packet : StaticArrayImpl> { + + using Base = StaticArrayImpl>; + + using ArrayType = Packet; + using MaskType = PacketMask; + + static constexpr bool BroadcastPreferOuter = false; + + /// Type alias for creating a similar-shaped array over a different type + template using ReplaceValue = Packet; + + ENOKI_ARRAY_IMPORT(Base, Packet) +}; + +template +struct PacketMask : StaticArrayImpl> { + + using Base = StaticArrayImpl>; + + static constexpr bool BroadcastPreferOuter = false; + + using ArrayType = Packet; + using MaskType = PacketMask; + + /// Type alias for creating a similar-shaped array over a different type + template using ReplaceValue = PacketMask; + + PacketMask() = default; + + template PacketMask(T &&value) + : Base(std::forward(value), detail::reinterpret_flag()) { } + + template PacketMask(T &&value, detail::reinterpret_flag) + : Base(std::forward(value), detail::reinterpret_flag()) { } + + /// Construct from sub-arrays + template == array_depth_v && array_size_v == Base::Size1 && + array_depth_v == array_depth_v && array_size_v == Base::Size2 && + Base::Size2 != 0> = 0> + PacketMask(const T1 &a1, const T2 &a2) + : Base(a1, a2) { } + + template ...>> = 0> + PacketMask(Ts&&... ts) : Base(std::forward(ts)...) { } + + ENOKI_ARRAY_IMPORT_BASIC(Base, PacketMask) + using Base::operator=; +}; + +NAMESPACE_END(enoki) + +#if defined(_MSC_VER) +# pragma warning(pop) +#elif defined(__GNUC__) && !defined(__clang__) +# pragma GCC diagnostic pop +#endif diff --git a/sources/enoki/array_avx.h b/sources/enoki/array_avx.h new file mode 100644 index 00000000..a1bb60c9 --- /dev/null +++ b/sources/enoki/array_avx.h @@ -0,0 +1,1173 @@ +/* + enoki/array_avx.h -- Packed SIMD array (AVX specialization) + + Enoki is a C++ template library that enables transparent vectorization + of numerical kernels using SIMD instruction sets available on current + processor architectures. + + Copyrighe (c) 2019 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a BSD-style + license that can be found in the LICENSE file. +*/ + +#pragma once + +NAMESPACE_BEGIN(enoki) +NAMESPACE_BEGIN(detail) +template <> struct is_native : std::true_type { } ; +template <> struct is_native : std::true_type { }; +template <> struct is_native : std::true_type { }; +NAMESPACE_END(detail) + +/// Partial overload of StaticArrayImpl using AVX intrinsics (single precision) +template struct alignas(32) + StaticArrayImpl + : StaticArrayBase { + ENOKI_NATIVE_ARRAY(float, 8, __m256) + + // ----------------------------------------------------------------------- + //! @{ \name Value constructors + // ----------------------------------------------------------------------- + + ENOKI_INLINE StaticArrayImpl(Value value) : m(_mm256_set1_ps(value)) { } + ENOKI_INLINE StaticArrayImpl(Value v0, Value v1, Value v2, Value v3, + Value v4, Value v5, Value v6, Value v7) + : m(_mm256_setr_ps(v0, v1, v2, v3, v4, v5, v6, v7)) { } + + //! 
@} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! @{ \name Type converting constructors + // ----------------------------------------------------------------------- + +#if defined(ENOKI_X86_F16C) + ENOKI_CONVERT(half) + : m(_mm256_cvtph_ps(_mm_loadu_si128((const __m128i *) a.derived().data()))) { } +#endif + + ENOKI_CONVERT(float) : m(a.derived().m) { } + +#if defined(ENOKI_X86_AVX2) + ENOKI_CONVERT(int32_t) : m(_mm256_cvtepi32_ps(a.derived().m)) { } +#endif + + ENOKI_CONVERT(uint32_t) { + #if defined(ENOKI_X86_AVX512DQ) && defined(ENOKI_X86_AVX512VL) + m = _mm256_cvtepu32_ps(a.derived().m); + #else + int32_array_t ai(a); + Derived result = + Derived(ai & 0x7fffffff) + + (Derived(float(1u << 31)) & mask_t(sr<31>(ai))); + m = result.m; + #endif + } + +#if defined(ENOKI_X86_AVX512F) + ENOKI_CONVERT(double) + :m(_mm512_cvtpd_ps(a.derived().m)) { } +#else + ENOKI_CONVERT(double) + : m(detail::concat(_mm256_cvtpd_ps(low(a).m), + _mm256_cvtpd_ps(high(a).m))) { } +#endif + +#if defined(ENOKI_X86_AVX512DQ) + ENOKI_CONVERT(int64_t) : m(_mm512_cvtepi64_ps(a.derived().m)) { } + ENOKI_CONVERT(uint64_t) : m(_mm512_cvtepu64_ps(a.derived().m)) { } +#endif + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! @{ \name Reinterpreting constructors, mask converters + // ----------------------------------------------------------------------- + + ENOKI_REINTERPRET(bool) { + uint64_t ival; + memcpy(&ival, a.derived().data(), 8); + __m128i value = _mm_cmpgt_epi8( + detail::mm_cvtsi64_si128((long long) ival), _mm_setzero_si128()); + #if defined(ENOKI_X86_AVX2) + m = _mm256_castsi256_ps(_mm256_cvtepi8_epi32(value)); + #else + m = _mm256_castsi256_ps(_mm256_insertf128_si256( + _mm256_castsi128_si256(_mm_cvtepi8_epi32(value)), + _mm_cvtepi8_epi32(_mm_srli_si128(value, 4)), 1)); + #endif + } + + ENOKI_REINTERPRET(float) : m(a.derived().m) { } + +#if defined(ENOKI_X86_AVX2) + ENOKI_REINTERPRET(int32_t) : m(_mm256_castsi256_ps(a.derived().m)) { } + ENOKI_REINTERPRET(uint32_t) : m(_mm256_castsi256_ps(a.derived().m)) { } +#else + ENOKI_REINTERPRET(int32_t) + : m(detail::concat(_mm_castsi128_ps(low(a).m), + _mm_castsi128_ps(high(a).m))) { } + + ENOKI_REINTERPRET(uint32_t) + : m(detail::concat(_mm_castsi128_ps(low(a).m), + _mm_castsi128_ps(high(a).m))) { } +#endif + +#if defined(ENOKI_X86_AVX512DQ) && defined(ENOKI_X86_AVX512VL) + /// Handled by KMask +#elif defined(ENOKI_X86_AVX512F) + ENOKI_REINTERPRET(double) + : m(_mm512_castps512_ps256(_mm512_maskz_mov_ps( + (__mmask16) a.derived().k, _mm512_set1_ps(memcpy_cast(int32_t(-1)))))) { } + ENOKI_REINTERPRET(int64_t) + : m(_mm512_castps512_ps256(_mm512_maskz_mov_ps( + (__mmask16) a.derived().k, _mm512_set1_ps(memcpy_cast(int32_t(-1)))))) { } + ENOKI_REINTERPRET(uint64_t) + : m(_mm512_castps512_ps256(_mm512_maskz_mov_ps( + (__mmask16) a.derived().k, _mm512_set1_ps(memcpy_cast(int32_t(-1)))))) { } +#else + ENOKI_REINTERPRET(double) + : m(_mm256_castsi256_ps(detail::mm512_cvtepi64_epi32( + _mm256_castpd_si256(low(a).m), _mm256_castpd_si256(high(a).m)))) { } +# if defined(ENOKI_X86_AVX2) + ENOKI_REINTERPRET(int64_t) + : m(_mm256_castsi256_ps( + detail::mm512_cvtepi64_epi32(low(a).m, high(a).m))) { } + ENOKI_REINTERPRET(uint64_t) + : m(_mm256_castsi256_ps( + detail::mm512_cvtepi64_epi32(low(a).m, high(a).m))) { } +# else + ENOKI_REINTERPRET(int64_t) 
+ : m(_mm256_castsi256_ps(detail::mm512_cvtepi64_epi32( + low(low(a)).m, high(low(a)).m, + low(high(a)).m, high(high(a)).m))) { } + ENOKI_REINTERPRET(uint64_t) + : m(_mm256_castsi256_ps(detail::mm512_cvtepi64_epi32( + low(low(a)).m, high(low(a)).m, + low(high(a)).m, high(high(a)).m))) { } +# endif +#endif + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! @{ \name Converting from/to half size vectors + // ----------------------------------------------------------------------- + + StaticArrayImpl(const Array1 &a1, const Array2 &a2) + : m(detail::concat(a1.m, a2.m)) { } + + ENOKI_INLINE Array1 low_() const { return _mm256_castps256_ps128(m); } + ENOKI_INLINE Array2 high_() const { return _mm256_extractf128_ps(m, 1); } + + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! @{ \name Vertical operations + // ----------------------------------------------------------------------- + + ENOKI_INLINE Derived add_(Ref a) const { return _mm256_add_ps(m, a.m); } + ENOKI_INLINE Derived sub_(Ref a) const { return _mm256_sub_ps(m, a.m); } + ENOKI_INLINE Derived mul_(Ref a) const { return _mm256_mul_ps(m, a.m); } + ENOKI_INLINE Derived div_(Ref a) const { return _mm256_div_ps(m, a.m); } + + template ENOKI_INLINE Derived or_(const T &a) const { + #if defined(ENOKI_X86_AVX512VL) + if constexpr (is_mask_v) + return _mm256_mask_mov_ps(m, a.k, _mm256_set1_ps(memcpy_cast(int32_t(-1)))); + else + #endif + return _mm256_or_ps(m, a.m); + } + + template ENOKI_INLINE Derived and_(const T &a) const { + #if defined(ENOKI_X86_AVX512VL) + if constexpr (is_mask_v) + return _mm256_maskz_mov_ps(a.k, m); + else + #endif + return _mm256_and_ps(m, a.m); + } + + template ENOKI_INLINE Derived xor_(const T &a) const { + #if defined(ENOKI_X86_AVX512VL) + if constexpr (is_mask_v) + return _mm256_mask_xor_ps(m, a.k, m, _mm256_set1_ps(memcpy_cast(int32_t(-1)))); + else + #endif + return _mm256_xor_ps(m, a.m); + } + + template ENOKI_INLINE Derived andnot_(const T &a) const { + #if defined(ENOKI_X86_AVX512VL) + if constexpr (is_mask_v) + return _mm256_mask_mov_ps(m, a.k, _mm256_setzero_ps()); + else + #endif + return _mm256_andnot_ps(a.m, m); + } + + #if defined(ENOKI_X86_AVX512VL) + #define ENOKI_COMP(name, NAME) mask_t::from_k(_mm256_cmp_ps_mask(m, a.m, _CMP_##NAME)) + #else + #define ENOKI_COMP(name, NAME) mask_t(_mm256_cmp_ps(m, a.m, _CMP_##NAME)) + #endif + + ENOKI_INLINE auto lt_ (Ref a) const { return ENOKI_COMP(lt, LT_OQ); } + ENOKI_INLINE auto gt_ (Ref a) const { return ENOKI_COMP(gt, GT_OQ); } + ENOKI_INLINE auto le_ (Ref a) const { return ENOKI_COMP(le, LE_OQ); } + ENOKI_INLINE auto ge_ (Ref a) const { return ENOKI_COMP(ge, GE_OQ); } + ENOKI_INLINE auto eq_ (Ref a) const { + using Int = int_array_t; + if constexpr (IsMask_) + return mask_t(eq(Int(derived()), Int(a))); + else + return ENOKI_COMP(eq, EQ_OQ); + } + + ENOKI_INLINE auto neq_(Ref a) const { + using Int = int_array_t; + if constexpr (IsMask_) + return mask_t(neq(Int(derived()), Int(a))); + else + return ENOKI_COMP(neq, NEQ_UQ); + } + + #undef ENOKI_COMP + + ENOKI_INLINE Derived abs_() const { return _mm256_andnot_ps(_mm256_set1_ps(-0.f), m); } + ENOKI_INLINE Derived min_(Ref b) const { return _mm256_min_ps(b.m, m); } + ENOKI_INLINE Derived max_(Ref b) const { return _mm256_max_ps(b.m, m); } + ENOKI_INLINE Derived ceil_() const 
{ return _mm256_ceil_ps(m); } + ENOKI_INLINE Derived floor_() const { return _mm256_floor_ps(m); } + ENOKI_INLINE Derived sqrt_() const { return _mm256_sqrt_ps(m); } + + ENOKI_INLINE Derived round_() const { + return _mm256_round_ps(m, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + } + + ENOKI_INLINE Derived trunc_() const { + return _mm256_round_ps(m, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC); + } + + template + static ENOKI_INLINE Derived select_(const Mask &m, Ref t, Ref f) { + #if !defined(ENOKI_X86_AVX512VL) + return _mm256_blendv_ps(f.m, t.m, m.m); + #else + return _mm256_mask_blend_ps(m.k, f.m, t.m); + #endif + } + +#if defined(ENOKI_X86_FMA) + ENOKI_INLINE Derived fmadd_ (Ref b, Ref c) const { return _mm256_fmadd_ps (m, b.m, c.m); } + ENOKI_INLINE Derived fmsub_ (Ref b, Ref c) const { return _mm256_fmsub_ps (m, b.m, c.m); } + ENOKI_INLINE Derived fnmadd_ (Ref b, Ref c) const { return _mm256_fnmadd_ps (m, b.m, c.m); } + ENOKI_INLINE Derived fnmsub_ (Ref b, Ref c) const { return _mm256_fnmsub_ps (m, b.m, c.m); } + ENOKI_INLINE Derived fmsubadd_(Ref b, Ref c) const { return _mm256_fmsubadd_ps(m, b.m, c.m); } + ENOKI_INLINE Derived fmaddsub_(Ref b, Ref c) const { return _mm256_fmaddsub_ps(m, b.m, c.m); } +#endif + + template + ENOKI_INLINE Derived shuffle_() const { + #if defined(ENOKI_X86_AVX2) + return _mm256_permutevar8x32_ps(m, + _mm256_setr_epi32(I0, I1, I2, I3, I4, I5, I6, I7)); + #else + return Base::template shuffle_(); + #endif + } + + + template + ENOKI_INLINE Derived shuffle_(const Index &index_) const { + #if defined(ENOKI_X86_AVX2) + return _mm256_permutevar8x32_ps(m, index_.m); + #else + __m128i i0 = low(index_).m, + i1 = high(index_).m; + + // swap low and high part of table + __m256 m2 = _mm256_permute2f128_ps(m, m, 1); + + __m256i index = _mm256_insertf128_si256(_mm256_castsi128_si256(i0), i1, 1); + + __m256 r0 = _mm256_permutevar_ps(m, index), + r1 = _mm256_permutevar_ps(m2, index); + + __m128i k0 = _mm_slli_epi32(i0, 29), + k1 = _mm_slli_epi32(_mm_xor_si128(i1, _mm_set1_epi32(4)), 29); + + __m256 k = _mm256_insertf128_ps( + _mm256_castps128_ps256(_mm_castsi128_ps(k0)), + _mm_castsi128_ps(k1), 1); + + return _mm256_blendv_ps(r0, r1, k); + #endif + } + +#if defined(ENOKI_X86_AVX512VL) + ENOKI_INLINE Derived ldexp_(Ref arg) const { return _mm256_scalef_ps(m, arg.m); } + + ENOKI_INLINE std::pair frexp_() const { + return std::make_pair( + _mm256_getmant_ps(m, _MM_MANT_NORM_p5_1, _MM_MANT_SIGN_src), + _mm256_getexp_ps(m)); + } +#endif + + ENOKI_INLINE Derived rcp_() const { + #if defined(ENOKI_X86_AVX512ER) + /* rel err < 2^28, use as is (even in non-approximate mode) */ + return _mm512_castps512_ps256( + _mm512_rcp28_ps(_mm512_castps256_ps512(m))); + #else + /* Use best reciprocal approximation available on the current + hardware and refine */ + __m256 r; + #if defined(ENOKI_X86_AVX512VL) + r = _mm256_rcp14_ps(m); /* rel error < 2^-14 */ + #else + r = _mm256_rcp_ps(m); /* rel error < 1.5*2^-12 */ + #endif + + /* Refine using one Newton-Raphson iteration */ + __m256 t0 = _mm256_add_ps(r, r), + t1 = _mm256_mul_ps(r, m), + ro = r; + (void) ro; + + #if defined(ENOKI_X86_FMA) + r = _mm256_fnmadd_ps(t1, r, t0); + #else + r = _mm256_sub_ps(t0, _mm256_mul_ps(r, t1)); + #endif + + #if defined(ENOKI_X86_AVX512VL) + return _mm256_fixupimm_ps(r, m, _mm256_set1_epi32(0x0087A622), 0); + #else + return _mm256_blendv_ps(r, ro, t1); /* mask bit is '1' iff t1 == nan */ + #endif + #endif + } + + ENOKI_INLINE Derived rsqrt_() const { + #if defined(ENOKI_X86_AVX512ER) + /* rel err < 2^28, 
use as is (even in non-approximate mode) */ + return _mm512_castps512_ps256( + _mm512_rsqrt28_ps(_mm512_castps256_ps512(m))); + #else + /* Use best reciprocal square root approximation available + on the current hardware and refine */ + __m256 r; + #if defined(ENOKI_X86_AVX512VL) + r = _mm256_rsqrt14_ps(m); /* rel error < 2^-14 */ + #else + r = _mm256_rsqrt_ps(m); /* rel error < 1.5*2^-12 */ + #endif + + /* Refine using one Newton-Raphson iteration */ + const __m256 c0 = _mm256_set1_ps(.5f), + c1 = _mm256_set1_ps(3.f); + + __m256 t0 = _mm256_mul_ps(r, c0), + t1 = _mm256_mul_ps(r, m), + ro = r; + (void) ro; + + #if defined(ENOKI_X86_FMA) + r = _mm256_mul_ps(_mm256_fnmadd_ps(t1, r, c1), t0); + #else + r = _mm256_mul_ps(_mm256_sub_ps(c1, _mm256_mul_ps(t1, r)), t0); + #endif + + #if defined(ENOKI_X86_AVX512VL) + return _mm256_fixupimm_ps(r, m, _mm256_set1_epi32(0x0383A622), 0); + #else + return _mm256_blendv_ps(r, ro, t1); /* mask bit is '1' iff t1 == nan */ + #endif + #endif + } + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! @{ \name Horizontal operations + // ----------------------------------------------------------------------- + + ENOKI_INLINE Value hsum_() const { return hsum(low_() + high_()); } + ENOKI_INLINE Value hprod_() const { return hprod(low_() * high_()); } + ENOKI_INLINE Value hmin_() const { return hmin(min(low_(), high_())); } + ENOKI_INLINE Value hmax_() const { return hmax(max(low_(), high_())); } + + ENOKI_INLINE bool all_() const { return _mm256_movemask_ps(m) == 0xFF;} + ENOKI_INLINE bool any_() const { return _mm256_movemask_ps(m) != 0x0; } + + ENOKI_INLINE uint32_t bitmask_() const { return (uint32_t) _mm256_movemask_ps(m); } + ENOKI_INLINE size_t count_() const { return (size_t) _mm_popcnt_u32(bitmask_()); } + + ENOKI_INLINE Value dot_(Ref a) const { + __m256 dp = _mm256_dp_ps(m, a.m, 0b11110001); + __m128 m0 = _mm256_castps256_ps128(dp); + __m128 m1 = _mm256_extractf128_ps(dp, 1); + __m128 m = _mm_add_ss(m0, m1); + return _mm_cvtss_f32(m); + } + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! @{ \name Masked versions of key operations + // ----------------------------------------------------------------------- + +#if defined(ENOKI_X86_AVX512VL) + template + ENOKI_INLINE void massign_(const Derived &a, const Mask &mask) { m = _mm256_mask_mov_ps(m, mask.k, a.m); } + template + ENOKI_INLINE void madd_ (const Derived &a, const Mask &mask) { m = _mm256_mask_add_ps(m, mask.k, m, a.m); } + template + ENOKI_INLINE void msub_ (const Derived &a, const Mask &mask) { m = _mm256_mask_sub_ps(m, mask.k, m, a.m); } + template + ENOKI_INLINE void mmul_ (const Derived &a, const Mask &mask) { m = _mm256_mask_mul_ps(m, mask.k, m, a.m); } + template + ENOKI_INLINE void mdiv_ (const Derived &a, const Mask &mask) { m = _mm256_mask_div_ps(m, mask.k, m, a.m); } + template + ENOKI_INLINE void mor_ (const Derived &a, const Mask &mask) { m = _mm256_mask_or_ps(m, mask.k, m, a.m); } + template + ENOKI_INLINE void mand_ (const Derived &a, const Mask &mask) { m = _mm256_mask_and_ps(m, mask.k, m, a.m); } + template + ENOKI_INLINE void mxor_ (const Derived &a, const Mask &mask) { m = _mm256_mask_xor_ps(m, mask.k, m, a.m); } +#endif + + //! 
@} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! @{ \name Initialization, loading/writing data + // ----------------------------------------------------------------------- + + ENOKI_INLINE void store_(void *ptr) const { + assert((uintptr_t) ptr % 32 == 0); + _mm256_store_ps((Value *) ENOKI_ASSUME_ALIGNED(ptr, 32), m); + } + + template + ENOKI_INLINE void store_(void *ptr, const Mask &mask) const { + #if defined(ENOKI_X86_AVX512VL) + _mm256_mask_store_ps((Value *) ptr, mask.k, m); + #else + _mm256_maskstore_ps((Value *) ptr, _mm256_castps_si256(mask.m), m); + #endif + } + + ENOKI_INLINE void store_unaligned_(void *ptr) const { + _mm256_storeu_ps((Value *) ptr, m); + } + + template + ENOKI_INLINE void store_unaligned_(void *ptr, const Mask &mask) const { + #if defined(ENOKI_X86_AVX512VL) + _mm256_mask_storeu_ps((Value *) ptr, mask.k, m); + #else + _mm256_maskstore_ps((Value *) ptr, _mm256_castps_si256(mask.m), m); + #endif + } + + static ENOKI_INLINE Derived load_(const void *ptr) { + assert((uintptr_t) ptr % 32 == 0); + return _mm256_load_ps((const Value *) ENOKI_ASSUME_ALIGNED(ptr, 32)); + } + + template + static ENOKI_INLINE Derived load_(const void *ptr, const Mask &mask) { + #if defined(ENOKI_X86_AVX512VL) + return _mm256_maskz_load_ps(mask.k, ptr); + #else + return _mm256_maskload_ps((const Value *) ptr, _mm256_castps_si256(mask.m)); + #endif + } + + static ENOKI_INLINE Derived load_unaligned_(const void *ptr) { + return _mm256_loadu_ps((const Value *) ptr); + } + + template + static ENOKI_INLINE Derived load_unaligned_(const void *ptr, const Mask &mask) { + #if defined(ENOKI_X86_AVX512VL) + return _mm256_maskz_loadu_ps(mask.k, ptr); + #else + return _mm256_maskload_ps((const Value *) ptr, _mm256_castps_si256(mask.m)); + #endif + } + + static ENOKI_INLINE Derived zero_() { return _mm256_setzero_ps(); } + +#if defined(ENOKI_X86_AVX2) + template + static ENOKI_INLINE Derived gather_(const void *ptr, const Index &index, const Mask &mask) { + #if defined(ENOKI_X86_AVX512VL) + if constexpr (sizeof(scalar_t) == 4) + return _mm256_mmask_i32gather_ps(_mm256_setzero_ps(), mask.k, index.m, (const float *) ptr, Stride); + else + return _mm512_mask_i64gather_ps(_mm256_setzero_ps(), mask.k, index.m, (const float *) ptr, Stride); + #else + if constexpr (sizeof(scalar_t) == 4) + return _mm256_mask_i32gather_ps(_mm256_setzero_ps(), (const float *) ptr, index.m, mask.m, Stride); + else + return Derived( + _mm256_mask_i64gather_ps(_mm_setzero_ps(), (const float *) ptr, low(index).m, low(mask).m, Stride), + _mm256_mask_i64gather_ps(_mm_setzero_ps(), (const float *) ptr, high(index).m, high(mask).m, Stride) + ); + #endif + } +#endif + +#if defined(ENOKI_X86_AVX512VL) + template + ENOKI_INLINE void scatter_(void *ptr, const Index &index, const Mask &mask) const { + if constexpr (sizeof(scalar_t) == 4) + _mm256_mask_i32scatter_ps(ptr, mask.k, index.m, m, Stride); + else + _mm512_mask_i64scatter_ps(ptr, mask.k, index.m, m, Stride); + } +#endif + + template + ENOKI_INLINE Value extract_(const Mask &mask) const { + #if !defined(ENOKI_X86_AVX512VL) + unsigned int k = (unsigned int) _mm256_movemask_ps(mask.m); + return coeff((size_t) (tzcnt(k) & 7)); + #else + return _mm256_cvtss_f32(_mm256_mask_compress_ps(_mm256_setzero_ps(), mask.k, m)); + #endif + } + + template + ENOKI_INLINE size_t compress_(float *&ptr, const Mask &mask) const { + #if defined(ENOKI_X86_AVX512VL) + _mm256_storeu_ps(ptr, 
_mm256_maskz_compress_ps(mask.k, m)); + size_t kn = (size_t) _mm_popcnt_u32(mask.k); + ptr += kn; + return kn; + #elif defined(ENOKI_X86_AVX2) && defined(ENOKI_X86_64) + /** Clever BMI2-based partitioning algorithm by Christoph Diegelmann + see https://goo.gl/o3ysMN for context */ + + unsigned int k = (unsigned int) _mm256_movemask_epi8(_mm256_castps_si256(mask.m)); + uint32_t wanted_indices = _pext_u32(0x76543210, k); + uint64_t expanded_indices = _pdep_u64((uint64_t) wanted_indices, + 0x0F0F0F0F0F0F0F0Full); + size_t kn = (size_t) (_mm_popcnt_u32(k) >> 2); + + __m128i bytevec = detail::mm_cvtsi64_si128((long long) expanded_indices); + __m256i shufmask = _mm256_cvtepu8_epi32(bytevec); + __m256 perm = _mm256_permutevar8x32_ps(m, shufmask); + + _mm256_storeu_ps(ptr, perm); + ptr += kn; + return kn; + #else + size_t r0 = compress(ptr, low(derived()), low(mask)); + size_t r1 = compress(ptr, high(derived()), high(mask)); + return r0 + r1; + #endif + } + + //! @} + // ----------------------------------------------------------------------- +} ENOKI_MAY_ALIAS; + +/// Partial overload of StaticArrayImpl using AVX intrinsics (double precision) +template struct alignas(32) + StaticArrayImpl + : StaticArrayBase { + ENOKI_NATIVE_ARRAY(double, 4, __m256d) + + // ----------------------------------------------------------------------- + //! @{ \name Value constructors + // ----------------------------------------------------------------------- + + ENOKI_INLINE StaticArrayImpl(Value value) : m(_mm256_set1_pd(value)) { } + ENOKI_INLINE StaticArrayImpl(Value v0, Value v1, Value v2, Value v3) + : m(_mm256_setr_pd(v0, v1, v2, v3)) { } + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! @{ \name Type converting constructors + // ----------------------------------------------------------------------- + +#if defined(ENOKI_X86_F16C) + ENOKI_CONVERT(half) { + m = _mm256_cvtps_pd( + _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *) a.derived().data()))); + } +#endif + + ENOKI_CONVERT(float) : m(_mm256_cvtps_pd(a.derived().m)) { } + ENOKI_CONVERT(int32_t) : m(_mm256_cvtepi32_pd(a.derived().m)) { } + +#if defined(ENOKI_X86_AVX512DQ) && defined(ENOKI_X86_AVX512VL) + ENOKI_CONVERT(uint32_t) : m(_mm256_cvtepu32_pd(a.derived().m)) { } +#endif + + ENOKI_CONVERT(double) : m(a.derived().m) { } + +#if defined(ENOKI_X86_AVX512DQ) && defined(ENOKI_X86_AVX512VL) + ENOKI_CONVERT(int64_t) : m(_mm256_cvtepi64_pd(a.derived().m)) { } + ENOKI_CONVERT(uint64_t) : m(_mm256_cvtepu64_pd(a.derived().m)) { } +#endif + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! 
@{ \name Reinterpreting constructors, mask converters + // ----------------------------------------------------------------------- + + ENOKI_REINTERPRET(bool) { + int ival; + memcpy(&ival, a.derived().data(), 4); + __m128i value = _mm_cmpgt_epi8( + _mm_cvtsi32_si128(ival), _mm_setzero_si128()); + #if defined(ENOKI_X86_AVX2) + m = _mm256_castsi256_pd(_mm256_cvtepi8_epi64(value)); + #else + m = _mm256_castsi256_pd(_mm256_insertf128_si256( + _mm256_castsi128_si256(_mm_cvtepi8_epi64(value)), + _mm_cvtepi8_epi64(_mm_srli_si128(value, 2)), 1)); + #endif + } + + ENOKI_REINTERPRET(float) + : m(_mm256_castsi256_pd( + detail::mm256_cvtepi32_epi64(_mm_castps_si128(a.derived().m)))) { } + + ENOKI_REINTERPRET(int32_t) + : m(_mm256_castsi256_pd(detail::mm256_cvtepi32_epi64(a.derived().m))) { } + + ENOKI_REINTERPRET(uint32_t) + : m(_mm256_castsi256_pd(detail::mm256_cvtepi32_epi64(a.derived().m))) { } + + ENOKI_REINTERPRET(double) : m(a.derived().m) { } + +#if defined(ENOKI_X86_AVX2) + ENOKI_REINTERPRET(int64_t) : m(_mm256_castsi256_pd(a.derived().m)) { } + ENOKI_REINTERPRET(uint64_t) : m(_mm256_castsi256_pd(a.derived().m)) { } +#else + ENOKI_REINTERPRET(int64_t) + : m(detail::concat(_mm_castsi128_pd(low(a).m), + _mm_castsi128_pd(high(a).m))) { } + ENOKI_REINTERPRET(uint64_t) + : m(detail::concat(_mm_castsi128_pd(low(a).m), + _mm_castsi128_pd(high(a).m))) { } +#endif + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! @{ \name Converting from/to half size vectors + // ----------------------------------------------------------------------- + + StaticArrayImpl(const Array1 &a1, const Array2 &a2) + : m(detail::concat(a1.m, a2.m)) { } + + ENOKI_INLINE Array1 low_() const { return _mm256_castpd256_pd128(m); } + ENOKI_INLINE Array2 high_() const { return _mm256_extractf128_pd(m, 1); } + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! 
@{ \name Vertical operations + // ----------------------------------------------------------------------- + + ENOKI_INLINE Derived add_(Ref a) const { return _mm256_add_pd(m, a.m); } + ENOKI_INLINE Derived sub_(Ref a) const { return _mm256_sub_pd(m, a.m); } + ENOKI_INLINE Derived mul_(Ref a) const { return _mm256_mul_pd(m, a.m); } + ENOKI_INLINE Derived div_(Ref a) const { return _mm256_div_pd(m, a.m); } + + template ENOKI_INLINE Derived or_(const T &a) const { + #if defined(ENOKI_X86_AVX512VL) + if constexpr (is_mask_v) + return _mm256_mask_mov_pd(m, a.k, _mm256_set1_pd(memcpy_cast(int64_t(-1)))); + else + #endif + return _mm256_or_pd(m, a.m); + } + + template ENOKI_INLINE Derived and_(const T &a) const { + #if defined(ENOKI_X86_AVX512VL) + if constexpr (is_mask_v) + return _mm256_maskz_mov_pd(a.k, m); + else + #endif + return _mm256_and_pd(m, a.m); + } + + template ENOKI_INLINE Derived xor_(const T &a) const { + #if defined(ENOKI_X86_AVX512VL) + if constexpr (is_mask_v) + return _mm256_mask_xor_pd(m, a.k, m, _mm256_set1_pd(memcpy_cast(int64_t(-1)))); + else + #endif + return _mm256_xor_pd(m, a.m); + } + + template ENOKI_INLINE Derived andnot_(const T &a) const { + #if defined(ENOKI_X86_AVX512VL) + if constexpr (is_mask_v) + return _mm256_mask_mov_pd(m, a.k, _mm256_setzero_pd()); + else + #endif + return _mm256_andnot_pd(a.m, m); + } + + #if defined(ENOKI_X86_AVX512VL) + #define ENOKI_COMP(name, NAME) mask_t::from_k(_mm256_cmp_pd_mask(m, a.m, _CMP_##NAME)) + #else + #define ENOKI_COMP(name, NAME) mask_t(_mm256_cmp_pd(m, a.m, _CMP_##NAME)) + #endif + + ENOKI_INLINE auto lt_ (Ref a) const { return ENOKI_COMP(lt, LT_OQ); } + ENOKI_INLINE auto gt_ (Ref a) const { return ENOKI_COMP(gt, GT_OQ); } + ENOKI_INLINE auto le_ (Ref a) const { return ENOKI_COMP(le, LE_OQ); } + ENOKI_INLINE auto ge_ (Ref a) const { return ENOKI_COMP(ge, GE_OQ); } + + ENOKI_INLINE auto eq_ (Ref a) const { + using Int = int_array_t; + if constexpr (IsMask_) + return mask_t(eq(Int(derived()), Int(a))); + else + return ENOKI_COMP(eq, EQ_OQ); + } + + ENOKI_INLINE auto neq_(Ref a) const { + using Int = int_array_t; + if constexpr (IsMask_) + return mask_t(neq(Int(derived()), Int(a))); + else + return ENOKI_COMP(neq, NEQ_UQ); + } + + #undef ENOKI_COMP + + ENOKI_INLINE Derived abs_() const { return _mm256_andnot_pd(_mm256_set1_pd(-0.), m); } + ENOKI_INLINE Derived min_(Ref b) const { return _mm256_min_pd(b.m, m); } + ENOKI_INLINE Derived max_(Ref b) const { return _mm256_max_pd(b.m, m); } + ENOKI_INLINE Derived ceil_() const { return _mm256_ceil_pd(m); } + ENOKI_INLINE Derived floor_() const { return _mm256_floor_pd(m); } + ENOKI_INLINE Derived sqrt_() const { return _mm256_sqrt_pd(m); } + + ENOKI_INLINE Derived round_() const { + return _mm256_round_pd(m, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + } + + ENOKI_INLINE Derived trunc_() const { + return _mm256_round_pd(m, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC); + } + + template + static ENOKI_INLINE Derived select_(const Mask &m, Ref t, Ref f) { + #if !defined(ENOKI_X86_AVX512VL) + return _mm256_blendv_pd(f.m, t.m, m.m); + #else + return _mm256_mask_blend_pd(m.k, f.m, t.m); + #endif + } + +#if defined(ENOKI_X86_FMA) + ENOKI_INLINE Derived fmadd_ (Ref b, Ref c) const { return _mm256_fmadd_pd (m, b.m, c.m); } + ENOKI_INLINE Derived fmsub_ (Ref b, Ref c) const { return _mm256_fmsub_pd (m, b.m, c.m); } + ENOKI_INLINE Derived fnmadd_ (Ref b, Ref c) const { return _mm256_fnmadd_pd (m, b.m, c.m); } + ENOKI_INLINE Derived fnmsub_ (Ref b, Ref c) const { return _mm256_fnmsub_pd (m, 
b.m, c.m); } + ENOKI_INLINE Derived fmsubadd_(Ref b, Ref c) const { return _mm256_fmsubadd_pd(m, b.m, c.m); } + ENOKI_INLINE Derived fmaddsub_(Ref b, Ref c) const { return _mm256_fmaddsub_pd(m, b.m, c.m); } +#endif + +#if defined(ENOKI_X86_AVX2) + template + ENOKI_INLINE Derived shuffle_() const { + return _mm256_permute4x64_pd(m, _MM_SHUFFLE(I3, I2, I1, I0)); + } + + template + ENOKI_INLINE Derived shuffle_(const Index &index) const { + return Base::shuffle_(index); + } +#endif + + +#if defined(ENOKI_X86_AVX512VL) + ENOKI_INLINE Derived ldexp_(Ref arg) const { return _mm256_scalef_pd(m, arg.m); } + + ENOKI_INLINE std::pair frexp_() const { + return std::make_pair( + _mm256_getmant_pd(m, _MM_MANT_NORM_p5_1, _MM_MANT_SIGN_src), + _mm256_getexp_pd(m)); + } +#endif + +#if defined(ENOKI_X86_AVX512VL) || defined(ENOKI_X86_AVX512ER) + ENOKI_INLINE Derived rcp_() const { + /* Use best reciprocal approximation available on the current + hardware and refine */ + __m256d r; + #if defined(ENOKI_X86_AVX512ER) + /* rel err < 2^28 */ + r = _mm512_castpd512_pd256( + _mm512_rcp28_pd(_mm512_castpd256_pd512(m))); + #elif defined(ENOKI_X86_AVX512VL) + r = _mm256_rcp14_pd(m); /* rel error < 2^-14 */ + #endif + + __m256d ro = r, t0, t1; + (void) ro; + + /* Refine using 1-2 Newton-Raphson iterations */ + ENOKI_UNROLL for (int i = 0; i < (has_avx512er ? 1 : 2); ++i) { + t0 = _mm256_add_pd(r, r); + t1 = _mm256_mul_pd(r, m); + r = _mm256_fnmadd_pd(t1, r, t0); + } + + #if defined(ENOKI_X86_AVX512VL) + return _mm256_fixupimm_pd(r, m, _mm256_set1_epi32(0x0087A622), 0); + #else + return _mm256_blendv_pd(r, ro, t1); /* mask bit is '1' iff t1 == nan */ + #endif + } + + ENOKI_INLINE Derived rsqrt_() const { + /* Use best reciprocal square root approximation available + on the current hardware and refine */ + __m256d r; + #if defined(ENOKI_X86_AVX512ER) + /* rel err < 2^28 */ + r = _mm512_castpd512_pd256( + _mm512_rsqrt28_pd(_mm512_castpd256_pd512(m))); + #elif defined(ENOKI_X86_AVX512VL) + r = _mm256_rsqrt14_pd(m); /* rel error < 2^-14 */ + #endif + + const __m256d c0 = _mm256_set1_pd(0.5), + c1 = _mm256_set1_pd(3.0); + + __m256d ro = r, t0, t1; + (void) ro; + + /* Refine using 1-2 Newton-Raphson iterations */ + ENOKI_UNROLL for (int i = 0; i < (has_avx512er ? 1 : 2); ++i) { + t0 = _mm256_mul_pd(r, c0); + t1 = _mm256_mul_pd(r, m); + r = _mm256_mul_pd(_mm256_fnmadd_pd(t1, r, c1), t0); + } + + #if defined(ENOKI_X86_AVX512VL) + return _mm256_fixupimm_pd(r, m, _mm256_set1_epi32(0x0383A622), 0); + #else + return _mm256_blendv_pd(r, ro, t1); /* mask bit is '1' iff t1 == nan */ + #endif + } +#endif + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! @{ \name Horizontal operations + // ----------------------------------------------------------------------- + + ENOKI_INLINE Value hsum_() const { return hsum(low_() + high_()); } + ENOKI_INLINE Value hprod_() const { return hprod(low_() * high_()); } + ENOKI_INLINE Value hmin_() const { return hmin(min(low_(), high_())); } + ENOKI_INLINE Value hmax_() const { return hmax(max(low_(), high_())); } + + ENOKI_INLINE bool all_() const { return _mm256_movemask_pd(m) == 0xF;} + ENOKI_INLINE bool any_() const { return _mm256_movemask_pd(m) != 0x0; } + + ENOKI_INLINE uint32_t bitmask_() const { return (uint32_t) _mm256_movemask_pd(m); } + ENOKI_INLINE size_t count_() const { return (size_t) _mm_popcnt_u32(bitmask_()); } + + //! 
@} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! @{ \name Masked versions of key operations + // ----------------------------------------------------------------------- + +#if defined(ENOKI_X86_AVX512VL) + template + ENOKI_INLINE void massign_(const Derived &a, const Mask &mask) { m = _mm256_mask_mov_pd(m, mask.k, a.m); } + template + ENOKI_INLINE void madd_ (const Derived &a, const Mask &mask) { m = _mm256_mask_add_pd(m, mask.k, m, a.m); } + template + ENOKI_INLINE void msub_ (const Derived &a, const Mask &mask) { m = _mm256_mask_sub_pd(m, mask.k, m, a.m); } + template + ENOKI_INLINE void mmul_ (const Derived &a, const Mask &mask) { m = _mm256_mask_mul_pd(m, mask.k, m, a.m); } + template + ENOKI_INLINE void mdiv_ (const Derived &a, const Mask &mask) { m = _mm256_mask_div_pd(m, mask.k, m, a.m); } + template + ENOKI_INLINE void mor_ (const Derived &a, const Mask &mask) { m = _mm256_mask_or_pd(m, mask.k, m, a.m); } + template + ENOKI_INLINE void mand_ (const Derived &a, const Mask &mask) { m = _mm256_mask_and_pd(m, mask.k, m, a.m); } + template + ENOKI_INLINE void mxor_ (const Derived &a, const Mask &mask) { m = _mm256_mask_xor_pd(m, mask.k, m, a.m); } +#endif + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! @{ \name Initialization, loading/writing data + // ----------------------------------------------------------------------- + + ENOKI_INLINE void store_(void *ptr) const { + assert((uintptr_t) ptr % 32 == 0); + _mm256_store_pd((Value *) ENOKI_ASSUME_ALIGNED(ptr, 32), m); + } + + template + ENOKI_INLINE void store_(void *ptr, const Mask &mask) const { + #if defined(ENOKI_X86_AVX512VL) + _mm256_mask_store_pd((Value *) ptr, mask.k, m); + #else + _mm256_maskstore_pd((Value *) ptr, _mm256_castpd_si256(mask.m), m); + #endif + } + + ENOKI_INLINE void store_unaligned_(void *ptr) const { + _mm256_storeu_pd((Value *) ptr, m); + } + + template + ENOKI_INLINE void store_unaligned_(void *ptr, const Mask &mask) const { + #if defined(ENOKI_X86_AVX512VL) + _mm256_mask_storeu_pd((Value *) ptr, mask.k, m); + #else + _mm256_maskstore_pd((Value *) ptr, _mm256_castpd_si256(mask.m), m); + #endif + } + + static ENOKI_INLINE Derived load_(const void *ptr) { + assert((uintptr_t) ptr % 32 == 0); + return _mm256_load_pd((const Value *) ENOKI_ASSUME_ALIGNED(ptr, 32)); + } + + template + static ENOKI_INLINE Derived load_(const void *ptr, const Mask &mask) { + #if defined(ENOKI_X86_AVX512VL) + return _mm256_maskz_load_pd(mask.k, ptr); + #else + return _mm256_maskload_pd((const Value *) ptr, _mm256_castpd_si256(mask.m)); + #endif + } + + static ENOKI_INLINE Derived load_unaligned_(const void *ptr) { + return _mm256_loadu_pd((const Value *) ptr); + } + + template + static ENOKI_INLINE Derived load_unaligned_(const void *ptr, const Mask &mask) { + #if defined(ENOKI_X86_AVX512VL) + return _mm256_maskz_loadu_pd(mask.k, ptr); + #else + return _mm256_maskload_pd((const Value *) ptr, _mm256_castpd_si256(mask.m)); + #endif + } + + static ENOKI_INLINE Derived zero_() { return _mm256_setzero_pd(); } + +#if defined(ENOKI_X86_AVX2) + template + static ENOKI_INLINE Derived gather_(const void *ptr, const Index &index, const Mask &mask) { + #if !defined(ENOKI_X86_AVX512VL) + if constexpr (sizeof(scalar_t) == 4) + return _mm256_mask_i32gather_pd(_mm256_setzero_pd(), (const double *) ptr, index.m, 
mask.m, Stride); + else + return _mm256_mask_i64gather_pd(_mm256_setzero_pd(), (const double *) ptr, index.m, mask.m, Stride); + #else + if constexpr (sizeof(scalar_t) == 4) + return _mm256_mmask_i32gather_pd(_mm256_setzero_pd(), mask.k, index.m, (const double *) ptr, Stride); + else + return _mm256_mmask_i64gather_pd(_mm256_setzero_pd(), mask.k, index.m, (const double *) ptr, Stride); + #endif + } +#endif + +#if defined(ENOKI_X86_AVX512VL) + template + ENOKI_INLINE void scatter_(void *ptr, const Index &index, const Mask &mask) const { + if constexpr (sizeof(scalar_t) == 4) + _mm256_mask_i32scatter_pd(ptr, mask.k, index.m, m, Stride); + else + _mm256_mask_i64scatter_pd(ptr, mask.k, index.m, m, Stride); + } + + template + ENOKI_INLINE Value extract_(const Mask &mask) const { + return _mm256_cvtsd_f64(_mm256_mask_compress_pd(_mm256_setzero_pd(), mask.k, m)); + } + + template + ENOKI_INLINE size_t compress_(double *&ptr, const Mask &mask) const { + _mm256_storeu_pd(ptr, _mm256_mask_compress_pd(_mm256_setzero_pd(), mask.k, m)); + size_t kn = (size_t) _mm_popcnt_u32(mask.k); + ptr += kn; + return kn; + } +#endif + + //! @} + // ----------------------------------------------------------------------- +} ENOKI_MAY_ALIAS; + +/// Partial overload of StaticArrayImpl for the n=3 case (double precision) +template struct alignas(32) + StaticArrayImpl + : StaticArrayImpl { + using Base = StaticArrayImpl; + + ENOKI_DECLARE_3D_ARRAY(StaticArrayImpl) + +#if defined(ENOKI_X86_F16C) + template + ENOKI_INLINE StaticArrayImpl(const StaticArrayBase &a) { + uint16_t temp[4]; + memcpy(temp, a.derived().data(), sizeof(uint16_t) * 3); + temp[3] = 0; + m = _mm256_cvtps_pd(_mm_cvtph_ps(_mm_loadl_epi64((const __m128i *) temp))); + } +#endif + + template + ENOKI_INLINE Derived shuffle_() const { + return Base::template shuffle_(); + } + + template + ENOKI_INLINE Derived shuffle_(const Index &index) const { + return Base::shuffle_(index); + } + + // ----------------------------------------------------------------------- + //! @{ \name Horizontal operations (adapted for the n=3 case) + // ----------------------------------------------------------------------- + + #define ENOKI_HORIZONTAL_OP(name, op) \ + ENOKI_INLINE Value name##_() const { \ + __m128d t1 = _mm256_extractf128_pd(m, 1); \ + __m128d t2 = _mm256_castpd256_pd128(m); \ + t1 = _mm_##op##_sd(t1, t2); \ + t2 = _mm_permute_pd(t2, 1); \ + t2 = _mm_##op##_sd(t2, t1); \ + return _mm_cvtsd_f64(t2); \ + } + + ENOKI_HORIZONTAL_OP(hsum, add) + ENOKI_HORIZONTAL_OP(hprod, mul) + ENOKI_HORIZONTAL_OP(hmin, min) + ENOKI_HORIZONTAL_OP(hmax, max) + + #undef ENOKI_HORIZONTAL_OP + + ENOKI_INLINE bool all_() const { return (_mm256_movemask_pd(m) & 7) == 7; } + ENOKI_INLINE bool any_() const { return (_mm256_movemask_pd(m) & 7) != 0; } + + ENOKI_INLINE uint32_t bitmask_() const { return (uint32_t) (_mm256_movemask_pd(m) & 7); } + ENOKI_INLINE size_t count_() const { return (size_t) _mm_popcnt_u32(bitmask_()); } + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! 
@{ \name Loading/writing data (adapted for the n=3 case) + // ----------------------------------------------------------------------- + + static ENOKI_INLINE auto mask_() { + #if defined(ENOKI_X86_AVX512VL) + return mask_t::from_k((__mmask8) 7); + #else + return mask_t(_mm256_castsi256_pd(_mm256_setr_epi64x(-1, -1, -1, 0))); + #endif + } + + using Base::load_; + using Base::load_unaligned_; + using Base::store_; + using Base::store_unaligned_; + + ENOKI_INLINE void store_(void *ptr) const { + memcpy(ptr, &m, sizeof(Value) * 3); + } + ENOKI_INLINE void store_unaligned_(void *ptr) const { + store_(ptr); + } + static ENOKI_INLINE Derived load_(const void *ptr) { + return Base::load_unaligned_(ptr); + } + static ENOKI_INLINE Derived load_unaligned_(const void *ptr) { + Derived result; + memcpy(&result.m, ptr, sizeof(Value) * 3); + return result; + } + + template + ENOKI_INLINE void store_(void *ptr, const Mask &mask) const { + Base::store_(ptr, mask & mask_()); + } + template + ENOKI_INLINE void store_unaligned_(void *ptr, const Mask &mask) const { + Base::store_unaligned_(ptr, mask & mask_()); + } + template + static ENOKI_INLINE Derived load_(const void *ptr, const Mask &mask) { + return Base::load_(ptr, mask & mask_()); + } + template + static ENOKI_INLINE Derived load_unaligned_(const void *ptr, const Mask &mask) { + return Base::load_unaligned_(ptr, mask & mask_()); + } + +#if defined(ENOKI_X86_AVX2) + template + static ENOKI_INLINE Derived gather_(const void *ptr, const Index &index, const Mask &mask) { + return Base::template gather_(ptr, index, mask & mask_()); + } +#endif + +#if defined(ENOKI_X86_AVX512VL) + template + ENOKI_INLINE void scatter_(void *ptr, const Index &index, const Mask &mask) const { + Base::template scatter_(ptr, index, mask & mask_()); + } +#endif + + template + ENOKI_INLINE size_t compress_(double *&ptr, const Mask &mask) const { + return Base::compress_(ptr, mask & mask_()); + } + + //! @} + // ----------------------------------------------------------------------- +} ENOKI_MAY_ALIAS; + +#if defined(ENOKI_X86_AVX512VL) +template +ENOKI_DECLARE_KMASK(float, 8, Derived_, int) +template +ENOKI_DECLARE_KMASK(double, 4, Derived_, int) +template +ENOKI_DECLARE_KMASK(double, 3, Derived_, int) +#endif + +NAMESPACE_END(enoki) diff --git a/sources/enoki/array_avx2.h b/sources/enoki/array_avx2.h new file mode 100644 index 00000000..2ab6a273 --- /dev/null +++ b/sources/enoki/array_avx2.h @@ -0,0 +1,1257 @@ +/* + enoki/array_avx2.h -- Packed SIMD array (AVX2 specialization) + + Enoki is a C++ template library that enables transparent vectorization + of numerical kernels using SIMD instruction sets available on current + processor architectures. + + Copyrighe (c) 2019 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a BSD-style + license that can be found in the LICENSE file. +*/ + +#pragma once + +NAMESPACE_BEGIN(enoki) +NAMESPACE_BEGIN(detail) +template struct is_native> : std::true_type { }; +template struct is_native> : std::true_type { }; +template struct is_native> : std::true_type { }; +NAMESPACE_END(detail) + +/// Partial overload of StaticArrayImpl using AVX intrinsics (32 bit integers) +template struct alignas(32) + StaticArrayImpl> + : StaticArrayBase { + ENOKI_NATIVE_ARRAY(Value_, 8, __m256i) + + // ----------------------------------------------------------------------- + //! 
@{ \name Value constructors + // ----------------------------------------------------------------------- + + ENOKI_INLINE StaticArrayImpl(Value value) : m(_mm256_set1_epi32((int32_t) value)) { } + ENOKI_INLINE StaticArrayImpl(Value v0, Value v1, Value v2, Value v3, + Value v4, Value v5, Value v6, Value v7) + : m(_mm256_setr_epi32((int32_t) v0, (int32_t) v1, (int32_t) v2, (int32_t) v3, + (int32_t) v4, (int32_t) v5, (int32_t) v6, (int32_t) v7)) { } + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! @{ \name Type converting constructors + // ----------------------------------------------------------------------- + + ENOKI_CONVERT(float) { + if constexpr (std::is_signed_v) { + m = _mm256_cvttps_epi32(a.derived().m); + } else { + #if defined(ENOKI_X86_AVX512VL) + m = _mm256_cvttps_epu32(a.derived().m); + #else + constexpr uint32_t limit = 1u << 31; + const __m256 limit_f = _mm256_set1_ps((float) limit); + const __m256i limit_i = _mm256_set1_epi32((int) limit); + + __m256 v = a.derived().m; + + __m256i mask = + _mm256_castps_si256(_mm256_cmp_ps(v, limit_f, _CMP_GE_OQ)); + + __m256i b2 = _mm256_add_epi32( + _mm256_cvttps_epi32(_mm256_sub_ps(v, limit_f)), limit_i); + + __m256i b1 = _mm256_cvttps_epi32(v); + + m = _mm256_blendv_epi8(b1, b2, mask); + #endif + } + } + + ENOKI_CONVERT(int32_t) : m(a.derived().m) { } + ENOKI_CONVERT(uint32_t) : m(a.derived().m) { } + + ENOKI_CONVERT(double) { + if constexpr (std::is_signed_v) { + #if defined(ENOKI_X86_AVX512F) + m = _mm512_cvttpd_epi32(a.derived().m); + #else + m = detail::concat(_mm256_cvttpd_epi32(low(a).m), + _mm256_cvttpd_epi32(high(a).m)); + #endif + } else { + #if defined(ENOKI_X86_AVX512F) + m = _mm512_cvttpd_epu32(a.derived().m); + #else + ENOKI_TRACK_SCALAR("Constructor (converting, double[8] -> [u]int32[8])"); + for (size_t i = 0; i < Size; ++i) + coeff(i) = Value(a.derived().coeff(i)); + #endif + } + } + + ENOKI_CONVERT(int64_t) { + #if defined(ENOKI_X86_AVX512F) + m = _mm512_cvtepi64_epi32(a.derived().m); + #else + m = detail::mm512_cvtepi64_epi32(low(a).m, high(a).m); + #endif + } + + ENOKI_CONVERT(uint64_t) { + #if defined(ENOKI_X86_AVX512F) + m = _mm512_cvtepi64_epi32(a.derived().m); + #else + m = detail::mm512_cvtepi64_epi32(low(a).m, high(a).m); + #endif + } + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! 
@{ \name Reinterpreting constructors, mask converters + // ----------------------------------------------------------------------- + + ENOKI_REINTERPRET(bool) { + uint64_t ival; + memcpy(&ival, a.derived().data(), 8); + __m128i value = _mm_cmpgt_epi8(detail::mm_cvtsi64_si128((long long) ival), + _mm_setzero_si128()); + m = _mm256_cvtepi8_epi32(value); + } + + ENOKI_REINTERPRET(float) : m(_mm256_castps_si256(a.derived().m)) { } + ENOKI_REINTERPRET(int32_t) : m(a.derived().m) { } + ENOKI_REINTERPRET(uint32_t) : m(a.derived().m) { } + +#if defined(ENOKI_X86_AVX512DQ) && defined(ENOKI_X86_AVX512VL) + ENOKI_REINTERPRET(double) : m(_mm256_movm_epi32(a.derived().k)) { } + ENOKI_REINTERPRET(int64_t) : m(_mm256_movm_epi32(a.derived().k)) { } + ENOKI_REINTERPRET(uint64_t) : m(_mm256_movm_epi32(a.derived().k)) { } +#elif defined(ENOKI_X86_AVX512F) + ENOKI_REINTERPRET(double) + : m(_mm512_castsi512_si256(_mm512_maskz_mov_epi32( + (__mmask16) a.derived().k, _mm512_set1_epi32(int32_t(-1))))) { } + ENOKI_REINTERPRET(int64_t) + : m(_mm512_castsi512_si256(_mm512_maskz_mov_epi32( + (__mmask16) a.derived().k, _mm512_set1_epi32(int32_t(-1))))) { } + ENOKI_REINTERPRET(uint64_t) + : m(_mm512_castsi512_si256(_mm512_maskz_mov_epi32( + (__mmask16) a.derived().k, _mm512_set1_epi32(int32_t(-1))))) { } +#else + ENOKI_REINTERPRET(double) + : m(detail::mm512_cvtepi64_epi32(_mm256_castpd_si256(low(a).m), + _mm256_castpd_si256(high(a).m))) { } + ENOKI_REINTERPRET(int64_t) + : m(detail::mm512_cvtepi64_epi32(low(a).m, high(a).m)) { } + ENOKI_REINTERPRET(uint64_t) + : m(detail::mm512_cvtepi64_epi32(low(a).m, high(a).m)) { } +#endif + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! @{ \name Converting from/to half size vectors + // ----------------------------------------------------------------------- + + StaticArrayImpl(const Array1 &a1, const Array2 &a2) + : m(detail::concat(a1.m, a2.m)) { } + + ENOKI_INLINE Array1 low_() const { return _mm256_castsi256_si128(m); } + ENOKI_INLINE Array2 high_() const { return _mm256_extractf128_si256(m, 1); } + + //! 
@} + // ----------------------------------------------------------------------- + + ENOKI_INLINE Derived add_(Ref a) const { return _mm256_add_epi32(m, a.m); } + ENOKI_INLINE Derived sub_(Ref a) const { return _mm256_sub_epi32(m, a.m); } + ENOKI_INLINE Derived mul_(Ref a) const { return _mm256_mullo_epi32(m, a.m); } + + template ENOKI_INLINE Derived or_(const T &a) const { + #if defined(ENOKI_X86_AVX512VL) + if constexpr (is_mask_v) + return _mm256_mask_mov_epi32(m, a.k, _mm256_set1_epi32(-1)); + else + #endif + return _mm256_or_si256(m, a.m); + } + + template ENOKI_INLINE Derived and_(const T &a) const { + #if defined(ENOKI_X86_AVX512VL) + if constexpr (is_mask_v) + return _mm256_maskz_mov_epi32(a.k, m); + else + #endif + return _mm256_and_si256(m, a.m); + } + + template ENOKI_INLINE Derived xor_(const T &a) const { + #if defined(ENOKI_X86_AVX512VL) + if constexpr (is_mask_v) + return _mm256_mask_xor_epi32(m, a.k, m, _mm256_set1_epi32(-1)); + else + #endif + return _mm256_xor_si256(m, a.m); + } + + template ENOKI_INLINE Derived andnot_(const T &a) const { + #if defined(ENOKI_X86_AVX512VL) + if constexpr (is_mask_v) + return _mm256_mask_mov_epi32(m, a.k, _mm256_setzero_si256()); + else + #endif + return _mm256_andnot_si256(a.m, m); + } + + template ENOKI_INLINE Derived sl_() const { + return _mm256_slli_epi32(m, (int) Imm); + } + + template ENOKI_INLINE Derived sr_() const { + return std::is_signed_v ? _mm256_srai_epi32(m, (int) Imm) + : _mm256_srli_epi32(m, (int) Imm); + } + + ENOKI_INLINE Derived sl_(size_t k) const { + return _mm256_sll_epi32(m, _mm_set1_epi64x((long long) k)); + } + + ENOKI_INLINE Derived sr_(size_t k) const { + return std::is_signed_v + ? _mm256_sra_epi32(m, _mm_set1_epi64x((long long) k)) + : _mm256_srl_epi32(m, _mm_set1_epi64x((long long) k)); + } + + ENOKI_INLINE Derived sl_(Ref k) const { + return _mm256_sllv_epi32(m, k.m); + } + + ENOKI_INLINE Derived sr_(Ref k) const { + return std::is_signed_v ? _mm256_srav_epi32(m, k.m) + : _mm256_srlv_epi32(m, k.m); + } + +#if defined(ENOKI_X86_AVX512VL) + template ENOKI_INLINE Derived rol_() const { return _mm256_rol_epi32(m, (int) Imm); } + template ENOKI_INLINE Derived ror_() const { return _mm256_ror_epi32(m, (int) Imm); } + ENOKI_INLINE Derived rol_(Ref k) const { return _mm256_rolv_epi32(m, k.m); } + ENOKI_INLINE Derived ror_(Ref k) const { return _mm256_rorv_epi32(m, k.m); } +#endif + + ENOKI_INLINE auto eq_(Ref a) const { + using Return = mask_t; + + #if defined(ENOKI_X86_AVX512VL) + return Return::from_k(_mm256_cmpeq_epi32_mask(m, a.m)); + #else + return Return(_mm256_cmpeq_epi32(m, a.m)); + #endif + } + + ENOKI_INLINE auto neq_(Ref a) const { + #if defined(ENOKI_X86_AVX512VL) + return mask_t::from_k(_mm256_cmpneq_epi32_mask(m, a.m)); + #else + return ~eq_(a); + #endif + } + + ENOKI_INLINE auto lt_(Ref a) const { + using Return = mask_t; + + #if !defined(ENOKI_X86_AVX512VL) + if constexpr (std::is_signed_v) { + return Return(_mm256_cmpgt_epi32(a.m, m)); + } else { + const __m256i offset = _mm256_set1_epi32((int32_t) 0x80000000ul); + return Return(_mm256_cmpgt_epi32(_mm256_sub_epi32(a.m, offset), + _mm256_sub_epi32(m, offset))); + } + #else + return Return::from_k(std::is_signed_v + ? 
_mm256_cmplt_epi32_mask(m, a.m) + : _mm256_cmplt_epu32_mask(m, a.m)); + #endif + } + + ENOKI_INLINE auto gt_(Ref a) const { + using Return = mask_t; + + #if !defined(ENOKI_X86_AVX512VL) + if constexpr (std::is_signed_v) { + return Return(_mm256_cmpgt_epi32(m, a.m)); + } else { + const __m256i offset = _mm256_set1_epi32((int32_t) 0x80000000ul); + return Return(_mm256_cmpgt_epi32(_mm256_sub_epi32(m, offset), + _mm256_sub_epi32(a.m, offset))); + } + #else + return Return::from_k(std::is_signed_v + ? _mm256_cmpgt_epi32_mask(m, a.m) + : _mm256_cmpgt_epu32_mask(m, a.m)); + #endif + } + + ENOKI_INLINE auto le_(Ref a) const { + #if defined(ENOKI_X86_AVX512VL) + return mask_t::from_k(std::is_signed_v + ? _mm256_cmple_epi32_mask(m, a.m) + : _mm256_cmple_epu32_mask(m, a.m)); + #else + return ~gt_(a); + #endif + } + + ENOKI_INLINE auto ge_(Ref a) const { + #if defined(ENOKI_X86_AVX512VL) + return mask_t::from_k(std::is_signed_v + ? _mm256_cmpge_epi32_mask(m, a.m) + : _mm256_cmpge_epu32_mask(m, a.m)); + #else + return ~lt_(a); + #endif + } + + ENOKI_INLINE Derived min_(Ref a) const { + return std::is_signed_v ? _mm256_min_epi32(a.m, m) + : _mm256_min_epu32(a.m, m); + } + + ENOKI_INLINE Derived max_(Ref a) const { + return std::is_signed_v ? _mm256_max_epi32(a.m, m) + : _mm256_max_epu32(a.m, m); + } + + ENOKI_INLINE Derived abs_() const { + return std::is_signed_v ? _mm256_abs_epi32(m) : m; + } + + template + static ENOKI_INLINE Derived select_(const Mask &m, Ref t, Ref f) { + #if !defined(ENOKI_X86_AVX512VL) + return _mm256_blendv_epi8(f.m, t.m, m.m); + #else + return _mm256_mask_blend_epi32(m.k, f.m, t.m); + #endif + } + + template + ENOKI_INLINE Derived shuffle_() const { + return _mm256_permutevar8x32_epi32(m, + _mm256_setr_epi32(I0, I1, I2, I3, I4, I5, I6, I7)); + } + + template + ENOKI_INLINE Derived shuffle_(const Index &index) const { + return _mm256_permutevar8x32_epi32(m, index.m); + } + + ENOKI_INLINE Derived mulhi_(Ref a) const { + Derived even, odd; + + if constexpr (std::is_signed_v) { + even.m = _mm256_srli_epi64(_mm256_mul_epi32(m, a.m), 32); + odd.m = _mm256_mul_epi32(_mm256_srli_epi64(m, 32), _mm256_srli_epi64(a.m, 32)); + } else { + even.m = _mm256_srli_epi64(_mm256_mul_epu32(m, a.m), 32); + odd.m = _mm256_mul_epu32(_mm256_srli_epi64(m, 32), _mm256_srli_epi64(a.m, 32)); + } + + #if defined(ENOKI_X86_AVX512VL) + const mask_t blend = mask_t::from_k(0b01010101); + #else + const mask_t blend(Value(-1), Value(0), Value(-1), Value(0), + Value(-1), Value(0), Value(-1), Value(0)); + #endif + + return select(blend, even, odd); + } + +#if defined(ENOKI_X86_AVX512CD) && defined(ENOKI_X86_AVX512VL) + ENOKI_INLINE Derived lzcnt_() const { return _mm256_lzcnt_epi32(m); } + ENOKI_INLINE Derived tzcnt_() const { return Value(32) - lzcnt(~derived() & (derived() - Value(1))); } +#endif + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! 
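/* The lt_()/gt_() fallbacks above emulate unsigned 32-bit comparisons on
   AVX2, which only provides a signed compare (_mm256_cmpgt_epi32), by biasing
   both operands with 0x80000000 first. A minimal standalone sketch of the
   same trick, assuming AVX2 and <immintrin.h>; the helper name is
   hypothetical, and XOR with the sign bit is equivalent to the subtraction
   used above: */
static inline __m256i cmpgt_epu32_sketch(__m256i a, __m256i b) {
    const __m256i bias = _mm256_set1_epi32((int) 0x80000000u);
    // Flipping the sign bit maps unsigned order onto signed order, so a
    // signed greater-than on the biased values yields the unsigned result.
    return _mm256_cmpgt_epi32(_mm256_xor_si256(a, bias),
                              _mm256_xor_si256(b, bias));
}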
@{ \name Horizontal operations + // ----------------------------------------------------------------------- + + ENOKI_INLINE Value hsum_() const { return hsum(low_() + high_()); } + ENOKI_INLINE Value hprod_() const { return hprod(low_() * high_()); } + ENOKI_INLINE Value hmin_() const { return hmin(min(low_(), high_())); } + ENOKI_INLINE Value hmax_() const { return hmax(max(low_(), high_())); } + + ENOKI_INLINE bool all_() const { return _mm256_movemask_ps(_mm256_castsi256_ps(m)) == 0xFF; } + ENOKI_INLINE bool any_() const { return _mm256_movemask_ps(_mm256_castsi256_ps(m)) != 0; } + + ENOKI_INLINE uint32_t bitmask_() const { return (uint32_t) _mm256_movemask_ps(_mm256_castsi256_ps(m)); } + ENOKI_INLINE size_t count_() const { return (size_t) _mm_popcnt_u32(bitmask_()); } + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! @{ \name Masked versions of key operations + // ----------------------------------------------------------------------- + +#if defined(ENOKI_X86_AVX512VL) + template + ENOKI_INLINE void massign_(const Derived &a, const Mask &mask) { m = _mm256_mask_mov_epi32(m, mask.k, a.m); } + template + ENOKI_INLINE void madd_ (const Derived &a, const Mask &mask) { m = _mm256_mask_add_epi32(m, mask.k, m, a.m); } + template + ENOKI_INLINE void msub_ (const Derived &a, const Mask &mask) { m = _mm256_mask_sub_epi32(m, mask.k, m, a.m); } + template + ENOKI_INLINE void mmul_ (const Derived &a, const Mask &mask) { m = _mm256_mask_mullo_epi32(m, mask.k, m, a.m); } + template + ENOKI_INLINE void mor_ (const Derived &a, const Mask &mask) { m = _mm256_mask_or_epi32(m, mask.k, m, a.m); } + template + ENOKI_INLINE void mand_ (const Derived &a, const Mask &mask) { m = _mm256_mask_and_epi32(m, mask.k, m, a.m); } + template + ENOKI_INLINE void mxor_ (const Derived &a, const Mask &mask) { m = _mm256_mask_xor_epi32(m, mask.k, m, a.m); } +#endif + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! 
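/* hsum_() and the other horizontal reductions above work by adding the two
   128-bit halves of the register and recursing on the half-width type. The
   same halving strategy written out directly with SSE2/AVX2 intrinsics, as a
   sketch assuming <immintrin.h>; the helper name is hypothetical: */
static inline int hsum_epi32_sketch(__m256i v) {
    __m128i lo = _mm256_castsi256_si128(v);              // lanes 0..3
    __m128i hi = _mm256_extracti128_si256(v, 1);         // lanes 4..7
    __m128i s  = _mm_add_epi32(lo, hi);                  // four partial sums
    s = _mm_add_epi32(s, _mm_shuffle_epi32(s, _MM_SHUFFLE(1, 0, 3, 2)));
    s = _mm_add_epi32(s, _mm_shuffle_epi32(s, _MM_SHUFFLE(2, 3, 0, 1)));
    return _mm_cvtsi128_si32(s);                         // lane 0 holds the total
}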
@{ \name Initialization, loading/writing data + // ----------------------------------------------------------------------- + + ENOKI_INLINE void store_(void *ptr) const { + assert((uintptr_t) ptr % 32 == 0); + _mm256_store_si256((__m256i *) ENOKI_ASSUME_ALIGNED(ptr, 32), m); + } + + template + ENOKI_INLINE void store_(void *ptr, const Mask &mask) const { + #if defined(ENOKI_X86_AVX512VL) + _mm256_mask_store_epi32((Value *) ptr, mask.k, m); + #else + _mm256_maskstore_epi32((int *) ptr, mask.m, m); + #endif + } + + ENOKI_INLINE void store_unaligned_(void *ptr) const { + _mm256_storeu_si256((__m256i *) ptr, m); + } + + template + ENOKI_INLINE void store_unaligned_(void *ptr, const Mask &mask) const { + #if defined(ENOKI_X86_AVX512VL) + _mm256_mask_storeu_epi32((Value *) ptr, mask.k, m); + #else + _mm256_maskstore_epi32((int *) ptr, mask.m, m); + #endif + } + + static ENOKI_INLINE Derived load_(const void *ptr) { + assert((uintptr_t) ptr % 32 == 0); + return _mm256_load_si256((const __m256i *) ENOKI_ASSUME_ALIGNED(ptr, 32)); + } + + template + static ENOKI_INLINE Derived load_(const void *ptr, const Mask &mask) { + #if defined(ENOKI_X86_AVX512VL) + return _mm256_maskz_load_epi32(mask.k, ptr); + #else + return _mm256_maskload_epi32((const int *) ptr, mask.m); + #endif + } + + static ENOKI_INLINE Derived load_unaligned_(const void *ptr) { + return _mm256_loadu_si256((const __m256i *) ptr); + } + + template + static ENOKI_INLINE Derived load_unaligned_(const void *ptr, const Mask &mask) { + #if defined(ENOKI_X86_AVX512VL) + return _mm256_maskz_loadu_epi32(mask.k, ptr); + #else + return _mm256_maskload_epi32((const int *) ptr, mask.m); + #endif + } + + static ENOKI_INLINE Derived zero_() { return _mm256_setzero_si256(); } + + template + static ENOKI_INLINE Derived gather_(const void *ptr, const Index &index, const Mask &mask) { + #if defined(ENOKI_X86_AVX512VL) && defined(ENOKI_X86_AVX512DQ) + if constexpr (sizeof(scalar_t) == 4) + return _mm256_mmask_i32gather_epi32(_mm256_setzero_si256(), mask.k, index.m, (const int *) ptr, Stride); + else + return _mm512_mask_i64gather_epi32(_mm256_setzero_si256(), mask.k, index.m, (const int *) ptr, Stride); + #else + if constexpr (sizeof(scalar_t) == 4) { + return _mm256_mask_i32gather_epi32( + _mm256_setzero_si256(), (const int *) ptr, index.m, mask.m, Stride); + } else { + return Derived( + _mm256_mask_i64gather_epi32(_mm_setzero_si128(), (const int *) ptr, low(index).m, low(mask).m, Stride), + _mm256_mask_i64gather_epi32(_mm_setzero_si128(), (const int *) ptr, high(index).m, high(mask).m, Stride) + ); + } + #endif + } + +#if defined(ENOKI_X86_AVX512VL) + template + ENOKI_INLINE void scatter_(void *ptr, const Index &index, const Mask &mask) const { + if constexpr (sizeof(scalar_t) == 4) + _mm256_mask_i32scatter_epi32(ptr, mask.k, index.m, m, Stride); + else + _mm512_mask_i64scatter_epi32(ptr, mask.k, index.m, m, Stride); + } +#endif + + template + ENOKI_INLINE Value extract_(const Mask &mask) const { + #if !defined(ENOKI_X86_AVX512VL) + unsigned int k = (unsigned int) _mm256_movemask_ps(_mm256_castsi256_ps(mask.m)); + return coeff((size_t) (detail::tzcnt_scalar(k) & 7)); + #else + return (Value) _mm_cvtsi128_si32(_mm256_castsi256_si128( + _mm256_mask_compress_epi32(_mm256_setzero_si256(), mask.k, m))); + #endif + } + + template + ENOKI_INLINE size_t compress_(Value_ *&ptr, const Mask &mask) const { + #if defined(ENOKI_X86_AVX512VL) + size_t kn = (size_t) _mm_popcnt_u32(mask.k); + _mm256_storeu_si256((__m256i *) ptr, _mm256_maskz_compress_epi32(mask.k, m)); 
+ ptr += kn; + return kn; + #elif defined(ENOKI_X86_64) // requires _pdep_u64 + /** Clever BMI2-based partitioning algorithm by Christoph Diegelmann + see https://goo.gl/o3ysMN for context */ + + unsigned int k = (unsigned int) _mm256_movemask_epi8(mask.m); + uint32_t wanted_indices = _pext_u32(0x76543210, k); + uint64_t expanded_indices = _pdep_u64((uint64_t) wanted_indices, + 0x0F0F0F0F0F0F0F0Full); + size_t kn = (size_t) (_mm_popcnt_u32(k) >> 2); + + __m128i bytevec = detail::mm_cvtsi64_si128((long long) expanded_indices); + __m256i shufmask = _mm256_cvtepu8_epi32(bytevec); + __m256i perm = _mm256_permutevar8x32_epi32(m, shufmask); + + _mm256_storeu_si256((__m256i *) ptr, perm); + ptr += kn; + return kn; + #else + return Base::compress_(ptr, mask); + #endif + } + + //! @} + // ----------------------------------------------------------------------- +} ENOKI_MAY_ALIAS; + +/// Partial overload of StaticArrayImpl using AVX intrinsics (64 bit integers) +template struct alignas(32) + StaticArrayImpl> + : StaticArrayBase { + ENOKI_NATIVE_ARRAY(Value_, 4, __m256i) + + // ----------------------------------------------------------------------- + //! @{ \name Value constructors + // ----------------------------------------------------------------------- + + ENOKI_INLINE StaticArrayImpl(const Value &value) + : m(_mm256_set1_epi64x((long long) value)) { } + + ENOKI_INLINE StaticArrayImpl(Value v0, Value v1, Value v2, Value v3) + : m(_mm256_setr_epi64x((long long) v0, (long long) v1, + (long long) v2, (long long) v3)) { } + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! @{ \name Type converting constructors + // ----------------------------------------------------------------------- + +#if defined(ENOKI_X86_AVX512DQ) && defined(ENOKI_X86_AVX512VL) + ENOKI_CONVERT(float) { + m = std::is_signed_v ? _mm256_cvttps_epi64(a.derived().m) + : _mm256_cvttps_epu64(a.derived().m); + } + + ENOKI_CONVERT(double) { + m = std::is_signed_v ? _mm256_cvttpd_epi64(a.derived().m) + : _mm256_cvttpd_epu64(a.derived().m); + } +#endif + ENOKI_CONVERT(int32_t) : m(_mm256_cvtepi32_epi64(a.derived().m)) { } + ENOKI_CONVERT(uint32_t) : m(_mm256_cvtepu32_epi64(a.derived().m)) { } + + ENOKI_CONVERT(int64_t) : m(a.derived().m) { } + ENOKI_CONVERT(uint64_t) : m(a.derived().m) { } + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! @{ \name Reinterpreting constructors, mask converters + // ----------------------------------------------------------------------- + + ENOKI_REINTERPRET(bool) { + int ival; + memcpy(&ival, a.derived().data(), 4); + m = _mm256_cvtepi8_epi64( + _mm_cmpgt_epi8(_mm_cvtsi32_si128(ival), _mm_setzero_si128())); + } + + ENOKI_REINTERPRET(float) + : m(_mm256_cvtepi32_epi64(_mm_castps_si128(a.derived().m))) { } + ENOKI_REINTERPRET(int32_t) : m(_mm256_cvtepi32_epi64(a.derived().m)) { } + ENOKI_REINTERPRET(uint32_t) : m(_mm256_cvtepi32_epi64(a.derived().m)) { } + + ENOKI_REINTERPRET(double) : m(_mm256_castpd_si256(a.derived().m)) { } + ENOKI_REINTERPRET(int64_t) : m(a.derived().m) { } + ENOKI_REINTERPRET(uint64_t) : m(a.derived().m) { } + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! 
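/* The BMI2 fallback in compress_() of the 32-bit specialization above packs
   the active lanes to the front without AVX512: _pext_u32 extracts the 4-bit
   lane indices of the selected lanes from the constant 0x76543210, and
   _pdep_u64 re-expands them to one index per byte, ready to be fed through
   _mm256_cvtepu8_epi32 into _mm256_permutevar8x32_epi32. A sketch of just the
   index construction, assuming BMI2, <immintrin.h> and <cstdint>; the
   function name is hypothetical, and byte_mask is what _mm256_movemask_epi8
   returns for a lane mask (four identical bits per 32-bit lane): */
static inline uint64_t compress_shuffle_bytes_sketch(uint32_t byte_mask) {
    uint32_t packed = _pext_u32(0x76543210u, byte_mask);   // surviving lane ids, packed
    return _pdep_u64(packed, 0x0F0F0F0F0F0F0F0Full);       // one lane id per byte
}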
@{ \name Converting from/to half size vectors + // ----------------------------------------------------------------------- + + StaticArrayImpl(const Array1 &a1, const Array2 &a2) + : m(detail::concat(a1.m, a2.m)) { } + + ENOKI_INLINE Array1 low_() const { return _mm256_castsi256_si128(m); } + ENOKI_INLINE Array2 high_() const { return _mm256_extractf128_si256(m, 1); } + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! @{ \name Vertical operations + // ----------------------------------------------------------------------- + + ENOKI_INLINE Derived add_(Ref a) const { return _mm256_add_epi64(m, a.m); } + ENOKI_INLINE Derived sub_(Ref a) const { return _mm256_sub_epi64(m, a.m); } + + ENOKI_INLINE Derived mul_(Ref a) const { + #if defined(ENOKI_X86_AVX512DQ) && defined(ENOKI_X86_AVX512VL) + return _mm256_mullo_epi64(m, a.m); + #else + __m256i h0 = _mm256_srli_epi64(m, 32); + __m256i h1 = _mm256_srli_epi64(a.m, 32); + __m256i low = _mm256_mul_epu32(m, a.m); + __m256i mix0 = _mm256_mul_epu32(m, h1); + __m256i mix1 = _mm256_mul_epu32(h0, a.m); + __m256i mix = _mm256_add_epi64(mix0, mix1); + __m256i mix_s = _mm256_slli_epi64(mix, 32); + return _mm256_add_epi64(mix_s, low); + #endif + } + + ENOKI_INLINE Derived mulhi_(Ref b) const { + if constexpr (std::is_unsigned_v) { + const __m256i low_bits = _mm256_set1_epi64x(0xffffffffu); + __m256i al = m, bl = b.m; + + __m256i ah = _mm256_srli_epi64(al, 32); + __m256i bh = _mm256_srli_epi64(bl, 32); + + // 4x unsigned 32x32->64 bit multiplication + __m256i albl = _mm256_mul_epu32(al, bl); + __m256i albh = _mm256_mul_epu32(al, bh); + __m256i ahbl = _mm256_mul_epu32(ah, bl); + __m256i ahbh = _mm256_mul_epu32(ah, bh); + + // Calculate a possible carry from the low bits of the multiplication. 
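// With a = ah*2^32 + al and b = bh*2^32 + bl, the full product is
//     a*b = ahbh*2^64 + (ahbl + albh)*2^32 + albl,
// so its high 64 bits equal ahbh + hi32(ahbl) + hi32(albh) plus a carry,
// where carry = (hi32(albl) + lo32(ahbl) + lo32(albh)) >> 32. Those are
// exactly the three terms accumulated below before the final shift.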
+ __m256i carry = _mm256_add_epi64( + _mm256_srli_epi64(albl, 32), + _mm256_add_epi64(_mm256_and_si256(albh, low_bits), + _mm256_and_si256(ahbl, low_bits))); + + __m256i s0 = _mm256_add_epi64(ahbh, _mm256_srli_epi64(carry, 32)); + __m256i s1 = _mm256_add_epi64(_mm256_srli_epi64(albh, 32), + _mm256_srli_epi64(ahbl, 32)); + + return _mm256_add_epi64(s0, s1); + + } else { + const Derived mask(0xffffffff); + const Derived a = derived(); + Derived ah = sr<32>(a), bh = sr<32>(b), + al = a & mask, bl = b & mask; + + Derived albl_hi = _mm256_srli_epi64(_mm256_mul_epu32(m, b.m), 32); + + Derived t = ah * bl + albl_hi; + Derived w1 = al * bh + (t & mask); + + return ah * bh + sr<32>(t) + sr<32>(w1); + } + } + + template ENOKI_INLINE Derived or_(const T &a) const { + #if defined(ENOKI_X86_AVX512VL) + if constexpr (is_mask_v) + return _mm256_mask_mov_epi64(m, a.k, _mm256_set1_epi64x(-1)); + else + #endif + return _mm256_or_si256(m, a.m); + } + + template ENOKI_INLINE Derived and_(const T &a) const { + #if defined(ENOKI_X86_AVX512VL) + if constexpr (is_mask_v) + return _mm256_maskz_mov_epi64(a.k, m); + else + #endif + return _mm256_and_si256(m, a.m); + } + + template ENOKI_INLINE Derived xor_(const T &a) const { + #if defined(ENOKI_X86_AVX512VL) + if constexpr (is_mask_v) + return _mm256_mask_xor_epi64(m, a.k, m, _mm256_set1_epi64x(-1)); + else + #endif + return _mm256_xor_si256(m, a.m); + } + + template ENOKI_INLINE Derived andnot_(const T &a) const { + #if defined(ENOKI_X86_AVX512VL) + if constexpr (is_mask_v) + return _mm256_mask_mov_epi64(m, a.k, _mm256_setzero_si256()); + else + #endif + return _mm256_andnot_si256(a.m, m); + } + + template ENOKI_INLINE Derived sl_() const { + return _mm256_slli_epi64(m, (int) k); + } + + template ENOKI_INLINE Derived sr_() const { + if constexpr (std::is_signed_v) { + #if defined(ENOKI_X86_AVX512VL) + return _mm256_srai_epi64(m, (int) k); + #else + const __m256i offset = _mm256_set1_epi64x((long long) 0x8000000000000000ull); + __m256i s1 = _mm256_srli_epi64(_mm256_add_epi64(m, offset), (int) k); + __m256i s2 = _mm256_srli_epi64(offset, (int) k); + return _mm256_sub_epi64(s1, s2); + #endif + } else { + return _mm256_srli_epi64(m, (int) k); + } + } + + ENOKI_INLINE Derived sl_(size_t k) const { + return _mm256_sll_epi64(m, _mm_set1_epi64x((long long) k)); + } + + ENOKI_INLINE Derived sr_(size_t k) const { + if constexpr (std::is_signed_v) { + #if defined(ENOKI_X86_AVX512VL) + return _mm256_sra_epi64(m, _mm_set1_epi64x((long long) k)); + #else + const __m256i offset = _mm256_set1_epi64x((long long) 0x8000000000000000ull); + __m128i s0 = _mm_set1_epi64x((long long) k); + __m256i s1 = _mm256_srl_epi64(_mm256_add_epi64(m, offset), s0); + __m256i s2 = _mm256_srl_epi64(offset, s0); + return _mm256_sub_epi64(s1, s2); + #endif + } else { + return _mm256_srl_epi64(m, _mm_set1_epi64x((long long) k)); + } + } + + ENOKI_INLINE Derived sl_(Ref k) const { + return _mm256_sllv_epi64(m, k.m); + } + + ENOKI_INLINE Derived sr_(Ref k) const { + if constexpr (std::is_signed_v) { + #if defined(ENOKI_X86_AVX512VL) + return _mm256_srav_epi64(m, k.m); + #else + const __m256i offset = _mm256_set1_epi64x((long long) 0x8000000000000000ull); + __m256i s1 = _mm256_srlv_epi64(_mm256_add_epi64(m, offset), k.m); + __m256i s2 = _mm256_srlv_epi64(offset, k.m); + return _mm256_sub_epi64(s1, s2); + #endif + } else { + return _mm256_srlv_epi64(m, k.m); + } + } + +#if defined(ENOKI_X86_AVX512VL) + template ENOKI_INLINE Derived rol_() const { return _mm256_rol_epi64(m, (int) Imm); } + template 
ENOKI_INLINE Derived ror_() const { return _mm256_ror_epi64(m, (int) Imm); } + ENOKI_INLINE Derived rol_(Ref k) const { return _mm256_rolv_epi64(m, k.m); } + ENOKI_INLINE Derived ror_(Ref k) const { return _mm256_rorv_epi64(m, k.m); } +#endif + + ENOKI_INLINE auto eq_(Ref a) const { + using Return = mask_t; + + #if defined(ENOKI_X86_AVX512VL) + return Return::from_k(_mm256_cmpeq_epi64_mask(m, a.m)); + #else + return Return(_mm256_cmpeq_epi64(m, a.m)); + #endif + } + + ENOKI_INLINE auto neq_(Ref a) const { + #if defined(ENOKI_X86_AVX512VL) + return mask_t::from_k(_mm256_cmpneq_epi64_mask(m, a.m)); + #else + return ~eq_(a); + #endif + } + + ENOKI_INLINE auto lt_(Ref a) const { + using Return = mask_t; + + #if !defined(ENOKI_X86_AVX512VL) + if constexpr (std::is_signed_v) { + return Return(_mm256_cmpgt_epi64(a.m, m)); + } else { + const __m256i offset = + _mm256_set1_epi64x((long long) 0x8000000000000000ull); + return Return(_mm256_cmpgt_epi64( + _mm256_sub_epi64(a.m, offset), + _mm256_sub_epi64(m, offset) + )); + } + #else + return Return::from_k(std::is_signed_v + ? _mm256_cmplt_epi64_mask(m, a.m) + : _mm256_cmplt_epu64_mask(m, a.m)); + #endif + } + + ENOKI_INLINE auto gt_(Ref a) const { + using Return = mask_t; + + #if !defined(ENOKI_X86_AVX512VL) + if constexpr (std::is_signed_v) { + return Return(_mm256_cmpgt_epi64(m, a.m)); + } else { + const __m256i offset = + _mm256_set1_epi64x((long long) 0x8000000000000000ull); + return Return(_mm256_cmpgt_epi64( + _mm256_sub_epi64(m, offset), + _mm256_sub_epi64(a.m, offset) + )); + } + #else + return Return::from_k(std::is_signed_v + ? _mm256_cmpgt_epi64_mask(m, a.m) + : _mm256_cmpgt_epu64_mask(m, a.m)); + #endif + } + + ENOKI_INLINE auto le_(Ref a) const { + #if defined(ENOKI_X86_AVX512VL) + return mask_t::from_k(std::is_signed_v + ? _mm256_cmple_epi64_mask(m, a.m) + : _mm256_cmple_epu64_mask(m, a.m)); + #else + return ~gt_(a); + #endif + } + + ENOKI_INLINE auto ge_(Ref a) const { + #if defined(ENOKI_X86_AVX512VL) + return mask_t::from_k(std::is_signed_v + ? _mm256_cmpge_epi64_mask(m, a.m) + : _mm256_cmpge_epu64_mask(m, a.m)); + #else + return ~lt_(a); + #endif + } + + ENOKI_INLINE Derived min_(Ref a) const { + #if defined(ENOKI_X86_AVX512VL) + return std::is_signed_v ? _mm256_min_epi64(a.m, m) + : _mm256_min_epu64(a.m, m); + #else + return select(derived() < a, derived(), a); + #endif + } + + ENOKI_INLINE Derived max_(Ref a) const { + #if defined(ENOKI_X86_AVX512VL) + return std::is_signed_v ? 
_mm256_max_epi64(a.m, m) + : _mm256_max_epu64(a.m, m); + #else + return select(derived() > a, derived(), a); + #endif + } + + ENOKI_INLINE Derived abs_() const { + if constexpr (std::is_signed_v) { + #if defined(ENOKI_X86_AVX512VL) + return _mm256_abs_epi64(m); + #else + return select(derived() < zero(), + ~derived() + Derived(Value(1)), derived()); + #endif + } else { + return m; + } + } + + template + static ENOKI_INLINE Derived select_(const Mask &m, Ref t, Ref f) { + #if !defined(ENOKI_X86_AVX512VL) + return _mm256_blendv_epi8(f.m, t.m, m.m); + #else + return _mm256_mask_blend_epi64(m.k, f.m, t.m); + #endif + } + + template + ENOKI_INLINE Derived shuffle_() const { + return _mm256_permute4x64_epi64(m, _MM_SHUFFLE(I3, I2, I1, I0)); + } + + template + ENOKI_INLINE Derived shuffle_(const Index &index) const { + return Base::shuffle_(index); + } + +#if defined(ENOKI_X86_AVX512CD) && defined(ENOKI_X86_AVX512VL) + ENOKI_INLINE Derived lzcnt_() const { return _mm256_lzcnt_epi64(m); } + ENOKI_INLINE Derived tzcnt_() const { return Value(64) - lzcnt(~derived() & (derived() - Value(1))); } +#endif + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! @{ \name Masked versions of key operations + // ----------------------------------------------------------------------- + +#if defined(ENOKI_X86_AVX512VL) + template + ENOKI_INLINE void massign_(const Derived &a, const Mask &mask) { m = _mm256_mask_mov_epi64(m, mask.k, a.m); } + template + ENOKI_INLINE void madd_ (const Derived &a, const Mask &mask) { m = _mm256_mask_add_epi64(m, mask.k, m, a.m); } + template + ENOKI_INLINE void msub_ (const Derived &a, const Mask &mask) { m = _mm256_mask_sub_epi64(m, mask.k, m, a.m); } + template + ENOKI_INLINE void mmul_ (const Derived &a, const Mask &mask) { m = _mm256_mask_mullo_epi64(m, mask.k, m, a.m); } + template + ENOKI_INLINE void mor_ (const Derived &a, const Mask &mask) { m = _mm256_mask_or_epi64(m, mask.k, m, a.m); } + template + ENOKI_INLINE void mand_ (const Derived &a, const Mask &mask) { m = _mm256_mask_and_epi64(m, mask.k, m, a.m); } + template + ENOKI_INLINE void mxor_ (const Derived &a, const Mask &mask) { m = _mm256_mask_xor_epi64(m, mask.k, m, a.m); } +#endif + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! @{ \name Horizontal operations + // ----------------------------------------------------------------------- + // + ENOKI_INLINE Value hsum_() const { return hsum(low_() + high_()); } + ENOKI_INLINE Value hprod_() const { return hprod(low_() * high_()); } + ENOKI_INLINE Value hmin_() const { return hmin(min(low_(), high_())); } + ENOKI_INLINE Value hmax_() const { return hmax(max(low_(), high_())); } + + ENOKI_INLINE bool all_() const { return _mm256_movemask_pd(_mm256_castsi256_pd(m)) == 0xF; } + ENOKI_INLINE bool any_() const { return _mm256_movemask_pd(_mm256_castsi256_pd(m)) != 0; } + + ENOKI_INLINE uint32_t bitmask_() const { return (uint32_t) _mm256_movemask_pd(_mm256_castsi256_pd(m)); } + ENOKI_INLINE size_t count_() const { return (size_t) _mm_popcnt_u32(bitmask_()); } + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! 
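/* tzcnt_() above is derived from lzcnt_: the expression ~x & (x - 1) keeps
   ones exactly in the trailing-zero positions of x, so 64 minus the leading
   zero count of that value is the trailing-zero count of x. A portable scalar
   sketch of the identity, assuming <cstdint>; the helper name is hypothetical
   and the loop merely stands in for a hardware lzcnt: */
static inline unsigned tzcnt64_sketch(uint64_t x) {
    uint64_t t = ~x & (x - 1);              // ones exactly where x has trailing zeros
    unsigned lz = 0;
    for (uint64_t probe = 1ull << 63; probe != 0 && (t & probe) == 0; probe >>= 1)
        ++lz;                               // leading zeros of t
    return 64u - lz;                        // trailing zeros of x (64 when x == 0)
}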
@{ \name Initialization, loading/writing data + // ----------------------------------------------------------------------- + + ENOKI_INLINE void store_(void *ptr) const { + assert((uintptr_t) ptr % 32 == 0); + _mm256_store_si256((__m256i *) ENOKI_ASSUME_ALIGNED(ptr, 32), m); + } + + template + ENOKI_INLINE void store_(void *ptr, const Mask &mask) const { + #if defined(ENOKI_X86_AVX512VL) + _mm256_mask_store_epi64(ptr, mask.k, m); + #else + _mm256_maskstore_epi64((long long *) ptr, mask.m, m); + #endif + } + + ENOKI_INLINE void store_unaligned_(void *ptr) const { + _mm256_storeu_si256((__m256i *) ptr, m); + } + + template + ENOKI_INLINE void store_unaligned_(void *ptr, const Mask &mask) const { + #if defined(ENOKI_X86_AVX512VL) + _mm256_mask_storeu_epi64(ptr, mask.k, m); + #else + _mm256_maskstore_epi64((long long *) ptr, mask.m, m); + #endif + } + + static ENOKI_INLINE Derived load_(const void *ptr) { + assert((uintptr_t) ptr % 32 == 0); + return _mm256_load_si256((const __m256i *) ENOKI_ASSUME_ALIGNED(ptr, 32)); + } + + template + static ENOKI_INLINE Derived load_(const void *ptr, const Mask &mask) { + #if defined(ENOKI_X86_AVX512VL) + return _mm256_maskz_load_epi64(mask.k, ptr); + #else + return _mm256_maskload_epi64((const long long *) ptr, mask.m); + #endif + } + + static ENOKI_INLINE Derived load_unaligned_(const void *ptr) { + return _mm256_loadu_si256((const __m256i *) ptr); + } + + template + static ENOKI_INLINE Derived load_unaligned_(const void *ptr, const Mask &mask) { + #if defined(ENOKI_X86_AVX512VL) + return _mm256_maskz_loadu_epi64(mask.k, ptr); + #else + return _mm256_maskload_epi64((const long long *) ptr, mask.m); + #endif + } + + static ENOKI_INLINE Derived zero_() { return _mm256_setzero_si256(); } + + template + static ENOKI_INLINE Derived gather_(const void *ptr, const Index &index, const Mask &mask) { + #if defined(ENOKI_X86_AVX512VL) + if constexpr (sizeof(scalar_t) == 4) + return _mm256_mmask_i32gather_epi64(_mm256_setzero_si256(), mask.k, index.m, (const long long *) ptr, Stride); + else + return _mm256_mmask_i64gather_epi64(_mm256_setzero_si256(), mask.k, index.m, (const long long *) ptr, Stride); + #else + if constexpr (sizeof(scalar_t) == 4) + return _mm256_mask_i32gather_epi64(_mm256_setzero_si256(), (const long long *) ptr, index.m, mask.m, Stride); + else + return _mm256_mask_i64gather_epi64(_mm256_setzero_si256(), (const long long *) ptr, index.m, mask.m, Stride); + #endif + } + +#if defined(ENOKI_X86_AVX512VL) + template + ENOKI_INLINE void scatter_(void *ptr, const Index &index, const Mask &mask) const { + if constexpr (sizeof(scalar_t) == 4) + _mm256_mask_i32scatter_epi64(ptr, mask.k, index.m, m, Stride); + else + _mm256_mask_i64scatter_epi64(ptr, mask.k, index.m, m, Stride); + } +#endif + + template + ENOKI_INLINE Value extract_(const Mask &mask) const { + #if defined(ENOKI_X86_AVX512VL) + return (Value) detail::mm_cvtsi128_si64(_mm256_castsi256_si128( + _mm256_mask_compress_epi64(_mm256_setzero_si256(), mask.k, m))); + #else + unsigned int k = + (unsigned int) _mm256_movemask_pd(_mm256_castsi256_pd(mask.m)); + return coeff((size_t) (tzcnt(k) & 3)); + #endif + } + + template + ENOKI_INLINE size_t compress_(Value_ *&ptr, const Mask &mask) const { + #if defined(ENOKI_X86_AVX512VL) + size_t kn = (size_t) _mm_popcnt_u32(mask.k); + _mm256_storeu_si256((__m256i *) ptr, _mm256_maskz_compress_epi64(mask.k, m)); + ptr += kn; + return kn; + #elif defined(ENOKI_X86_64) // requires _pdep_u64 + /** Clever BMI2-based partitioning algorithm by Christoph Diegelmann + 
see https://goo.gl/o3ysMN for context */ + + unsigned int k = (unsigned int) _mm256_movemask_epi8(mask.m); + uint32_t wanted_indices = _pext_u32(0x76543210, k); + uint64_t expanded_indices = _pdep_u64((uint64_t) wanted_indices, + 0x0F0F0F0F0F0F0F0Full); + size_t kn = (size_t) (_mm_popcnt_u32(k) >> 3); + + __m128i bytevec = detail::mm_cvtsi64_si128((long long) expanded_indices); + __m256i shufmask = _mm256_cvtepu8_epi32(bytevec); + + __m256i perm = _mm256_permutevar8x32_epi32(m, shufmask); + + _mm256_storeu_si256((__m256i *) ptr, perm); + ptr += kn; + return kn; + #else + return Base::compress_(ptr, mask); + #endif + } + + //! @} + // ----------------------------------------------------------------------- +} ENOKI_MAY_ALIAS; + +/// Partial overload of StaticArrayImpl for the n=3 case (64 bit integers) +template struct alignas(32) + StaticArrayImpl> + : StaticArrayImpl { + using Base = StaticArrayImpl; + + ENOKI_DECLARE_3D_ARRAY(StaticArrayImpl) + + template + ENOKI_INLINE Derived shuffle_() const { + return Base::template shuffle_(); + } + + template + ENOKI_INLINE Derived shuffle_(const Index &idx) const { + return Base::shuffle_(idx); + } + + // ----------------------------------------------------------------------- + //! @{ \name Horizontal operations (adapted for the n=3 case) + // ----------------------------------------------------------------------- + + ENOKI_INLINE Value hsum_() const { + Value result = coeff(0); + for (size_t i = 1; i < 3; ++i) + result += coeff(i); + return result; + } + + ENOKI_INLINE Value hprod_() const { + Value result = coeff(0); + for (size_t i = 1; i < 3; ++i) + result *= coeff(i); + return result; + } + + ENOKI_INLINE Value hmin_() const { + Value result = coeff(0); + for (size_t i = 1; i < 3; ++i) + result = std::min(result, coeff(i)); + return result; + } + + ENOKI_INLINE Value hmax_() const { + Value result = coeff(0); + for (size_t i = 1; i < 3; ++i) + result = std::max(result, coeff(i)); + return result; + } + + ENOKI_INLINE bool all_() const { return (_mm256_movemask_pd(_mm256_castsi256_pd(m)) & 7) == 7;} + ENOKI_INLINE bool any_() const { return (_mm256_movemask_pd(_mm256_castsi256_pd(m)) & 7) != 0; } + + ENOKI_INLINE uint32_t bitmask_() const { return (uint32_t) _mm256_movemask_pd(_mm256_castsi256_pd(m)) & 7; } + ENOKI_INLINE size_t count_() const { return (size_t) _mm_popcnt_u32(bitmask_()); } + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! 
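/* The n=3 specialization above keeps its data in a 4-lane register whose last
   lane is padding: horizontal reductions therefore loop over three
   coefficients only, and the movemask-based predicates mask their result with
   7. The load/store overloads below follow the same rule -- they transfer
   only sizeof(Value) * 3 bytes (plain memcpy, or mask_() for the masked
   variants) so that the padding lane never reads or writes memory the caller
   does not own. */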
@{ \name Loading/writing data (adapted for the n=3 case) + // ----------------------------------------------------------------------- + + static ENOKI_INLINE auto mask_() { + #if defined(ENOKI_X86_AVX512VL) + return mask_t::from_k((__mmask8) 7); + #else + return mask_t(_mm256_setr_epi64x( + (int64_t) -1, (int64_t) -1, (int64_t) -1, (int64_t) 0)); + #endif + } + + using Base::load_; + using Base::load_unaligned_; + using Base::store_; + using Base::store_unaligned_; + + ENOKI_INLINE void store_(void *ptr) const { + memcpy(ptr, &m, sizeof(Value) * 3); + } + + ENOKI_INLINE void store_unaligned_(void *ptr) const { + store_(ptr); + } + + static ENOKI_INLINE Derived load_(const void *ptr) { + return Base::load_unaligned_(ptr); + } + + static ENOKI_INLINE Derived load_unaligned_(const void *ptr) { + Derived result; + memcpy(&result.m, ptr, sizeof(Value) * 3); + return result; + } + + template + ENOKI_INLINE void store_unaligned_(void *ptr, const Mask &mask) const { + return Base::store_unaligned_(ptr, mask & mask_()); + } + + template + ENOKI_INLINE void store_(void *ptr, const Mask &mask) const { + return Base::store_(ptr, mask & mask_()); + } + + template + static ENOKI_INLINE Derived load_(const void *ptr, const Mask &mask) { + return Base::load_(ptr, mask & mask_()); + } + + template + static ENOKI_INLINE Derived load_unaligned_(const void *ptr, const Mask &mask) { + return Base::load_unaligned_(ptr, mask & mask_()); + } + + template + static ENOKI_INLINE Derived gather_(const void *ptr, const Index &index, const Mask &mask) { + return Base::template gather_(ptr, index, mask & mask_()); + } + +#if defined(ENOKI_X86_AVX512VL) + template + ENOKI_INLINE void scatter_(void *ptr, const Index &index, const Mask &mask) const { + Base::template scatter_(ptr, index, mask & mask_()); + } +#endif + + template + ENOKI_INLINE size_t compress_(Value_ *&ptr, const Mask &mask) const { + return Base::compress_(ptr, mask & mask_()); + } + + //! @} + // ----------------------------------------------------------------------- +} ENOKI_MAY_ALIAS; + +#if defined(ENOKI_X86_AVX512VL) +template +ENOKI_DECLARE_KMASK(Value_, 8, Derived_, enable_if_int32_t) +template +ENOKI_DECLARE_KMASK(Value_, 4, Derived_, enable_if_int64_t) +template +ENOKI_DECLARE_KMASK(Value_, 3, Derived_, enable_if_int64_t) +#endif + +NAMESPACE_END(enoki) diff --git a/sources/enoki/array_avx512.h b/sources/enoki/array_avx512.h new file mode 100644 index 00000000..e8e14b20 --- /dev/null +++ b/sources/enoki/array_avx512.h @@ -0,0 +1,1928 @@ +/* + enoki/array_avx512.h -- Packed SIMD array (AVX512 specialization) + + Enoki is a C++ template library that enables transparent vectorization + of numerical kernels using SIMD instruction sets available on current + processor architectures. + + Copyrighe (c) 2019 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a BSD-style + license that can be found in the LICENSE file. +*/ + +#pragma once + +NAMESPACE_BEGIN(enoki) +NAMESPACE_BEGIN(detail) +template <> struct is_native : std::true_type { } ; +template <> struct is_native : std::true_type { }; +template struct is_native> : std::true_type { }; +template struct is_native> : std::true_type { }; +NAMESPACE_END(detail) + +/// Partial overload of StaticArrayImpl using AVX512 intrinsics (single precision) +template struct alignas(64) + StaticArrayImpl + : StaticArrayBase { + ENOKI_NATIVE_ARRAY(float, 16, __m512) + + // ----------------------------------------------------------------------- + //! 
@{ \name Value constructors + // ----------------------------------------------------------------------- + + ENOKI_INLINE StaticArrayImpl(Value value) : m(_mm512_set1_ps(value)) { } + + ENOKI_INLINE StaticArrayImpl(Value f0, Value f1, Value f2, Value f3, + Value f4, Value f5, Value f6, Value f7, + Value f8, Value f9, Value f10, Value f11, + Value f12, Value f13, Value f14, Value f15) + : m(_mm512_setr_ps(f0, f1, f2, f3, f4, f5, f6, f7, f8, + f9, f10, f11, f12, f13, f14, f15)) { } + + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! @{ \name Type converting constructors + // ----------------------------------------------------------------------- + + ENOKI_CONVERT(half) + : m(_mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *) a.derived().data()))) { } + + ENOKI_CONVERT(float) : m(a.derived().m) { } + + ENOKI_CONVERT(int32_t) : m(_mm512_cvtepi32_ps(a.derived().m)) { } + + ENOKI_CONVERT(uint32_t) : m(_mm512_cvtepu32_ps(a.derived().m)) { } + + ENOKI_CONVERT(double) + : m(detail::concat(_mm512_cvtpd_ps(low(a).m), + _mm512_cvtpd_ps(high(a).m))) { } + +#if defined(ENOKI_X86_AVX512DQ) + ENOKI_CONVERT(int64_t) + : m(detail::concat(_mm512_cvtepi64_ps(low(a).m), + _mm512_cvtepi64_ps(high(a).m))) { } + + ENOKI_CONVERT(uint64_t) + : m(detail::concat(_mm512_cvtepu64_ps(low(a).m), + _mm512_cvtepu64_ps(high(a).m))) { } +#elif defined(ENOKI_X86_AVX512CD) + /* Emulate uint64_t -> float conversion instead of falling + back to scalar operations. This is quite a bit faster + (>6x for unsigned, >5x for signed). */ + + ENOKI_CONVERT(uint64_t) { + using Int64 = int64_array_t; + using Int32 = uint32_array_t; + + auto lz = lzcnt(a); + auto shift = (63 - 23) - Int64(lz); + auto abs_shift = abs(shift); + auto nzero_mask = neq(a, 0ull); + auto mant = select(shift > 0, a >> abs_shift, a << abs_shift); + auto exp = sl<23>(uint64_t(127 + 63) - lz) & nzero_mask; + auto comb = exp | (mant & 0x7fffffull); + + m = reinterpret_array(Int32(comb)).m; + } + + ENOKI_CONVERT(int64_t) { + using Int32 = uint32_array_t; + + auto b = abs(a), lz = lzcnt(b); + auto shift = (63 - 23) - lz; + auto abs_shift = abs(shift); + auto nzero_mask = neq(a, 0ll); + auto mant = select(shift > 0, b >> abs_shift, b << abs_shift); + auto sign = sr<32>(a) & 0x80000000ll; + auto exp = sl<23>(int64_t(127 + 63) - lz) & nzero_mask; + auto comb = exp | (mant & 0x7fffffll) | sign; + + m = reinterpret_array(Int32(comb)).m; + } +#endif + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! @{ \name Reinterpreting constructors, mask converters + // ----------------------------------------------------------------------- + + ENOKI_REINTERPRET(float) : m(a.derived().m) { } + + ENOKI_REINTERPRET(int32_t) : m(_mm512_castsi512_ps(a.derived().m)) { } + ENOKI_REINTERPRET(uint32_t) : m(_mm512_castsi512_ps(a.derived().m)) { } + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! 
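/* The uint64_t -> float conversion above assembles the result directly from
   bits: the leading-zero count both normalizes the mantissa and produces the
   biased exponent 127 + 63 - lz. A scalar sketch of the same construction,
   which truncates rather than rounds just like the vector code; it assumes
   <cstdint> and <cstring>, and the helper name is hypothetical: */
static inline float u64_to_float_sketch(uint64_t a) {
    if (a == 0)
        return 0.f;
    int lz = 0;                                     // leading zeros of 'a'
    for (uint64_t probe = 1ull << 63; (a & probe) == 0; probe >>= 1)
        ++lz;
    int shift = (63 - 23) - lz;                     // align the MSB with mantissa bit 23
    uint64_t mant = shift > 0 ? (a >> shift) : (a << -shift);
    uint32_t bits = (uint32_t(127 + 63 - lz) << 23) | uint32_t(mant & 0x7fffffu);
    float result;
    std::memcpy(&result, &bits, sizeof result);     // bit-cast the assembled pattern
    return result;
}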
@{ \name Converting from/to half size vectors + // ----------------------------------------------------------------------- + + StaticArrayImpl(const Array1 &a1, const Array2 &a2) + : m(detail::concat(a1.m, a2.m)) { } + + ENOKI_INLINE Array1 low_() const { return _mm512_castps512_ps256(m); } + ENOKI_INLINE Array2 high_() const { + #if defined(ENOKI_X86_AVX512DQ) + return _mm512_extractf32x8_ps(m, 1); + #else + return _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(m), 1)); + #endif + } + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! @{ \name Vertical operations + // ----------------------------------------------------------------------- + + ENOKI_INLINE Derived add_(Ref a) const { return _mm512_add_ps(m, a.m); } + ENOKI_INLINE Derived sub_(Ref a) const { return _mm512_sub_ps(m, a.m); } + ENOKI_INLINE Derived mul_(Ref a) const { return _mm512_mul_ps(m, a.m); } + ENOKI_INLINE Derived div_(Ref a) const { return _mm512_div_ps(m, a.m); } + + template ENOKI_INLINE Derived or_(const T &a) const { + if constexpr (is_mask_v) { + return _mm512_mask_mov_ps(m, a.k, _mm512_set1_ps(memcpy_cast(int32_t(-1)))); + } else { + #if defined(ENOKI_X86_AVX512DQ) + return _mm512_or_ps(m, a.m); + #else + return _mm512_castsi512_ps( + _mm512_or_si512(_mm512_castps_si512(m), _mm512_castps_si512(a.m))); + #endif + } + } + + template ENOKI_INLINE Derived and_(const T &a) const { + if constexpr (is_mask_v) { + return _mm512_maskz_mov_ps(a.k, m); + } else { + #if defined(ENOKI_X86_AVX512DQ) + return _mm512_and_ps(m, a.m); + #else + return _mm512_castsi512_ps( + _mm512_and_si512(_mm512_castps_si512(m), _mm512_castps_si512(a.m))); + #endif + } + } + + template ENOKI_INLINE Derived andnot_(const T &a) const { + if constexpr (is_mask_v) { + return _mm512_mask_mov_ps(m, a.k, _mm512_setzero_ps()); + } else { + #if defined(ENOKI_X86_AVX512DQ) + return _mm512_andnot_ps(a.m, m); + #else + return _mm512_castsi512_ps( + _mm512_andnot_si512(_mm512_castps_si512(a.m), _mm512_castps_si512(m))); + #endif + } + } + + template ENOKI_INLINE Derived xor_(const T &a) const { + if constexpr (is_mask_v) { + const __m512 c = _mm512_set1_ps(memcpy_cast(int32_t(-1))); + #if defined(ENOKI_X86_AVX512DQ) + return _mm512_mask_xor_ps(m, a.k, m, c); + #else + const __m512i v0 = _mm512_castps_si512(m); + return _mm512_castsi512_ps(_mm512_mask_xor_epi32(v0, a.k, v0, c)); + #endif + } else { + #if defined(ENOKI_X86_AVX512DQ) + return _mm512_xor_ps(m, a.m); + #else + return _mm512_castsi512_ps( + _mm512_xor_si512(_mm512_castps_si512(m), _mm512_castps_si512(a.m))); + #endif + } + } + + ENOKI_INLINE auto lt_ (Ref a) const { return mask_t::from_k(_mm512_cmp_ps_mask(m, a.m, _CMP_LT_OQ)); } + ENOKI_INLINE auto gt_ (Ref a) const { return mask_t::from_k(_mm512_cmp_ps_mask(m, a.m, _CMP_GT_OQ)); } + ENOKI_INLINE auto le_ (Ref a) const { return mask_t::from_k(_mm512_cmp_ps_mask(m, a.m, _CMP_LE_OQ)); } + ENOKI_INLINE auto ge_ (Ref a) const { return mask_t::from_k(_mm512_cmp_ps_mask(m, a.m, _CMP_GE_OQ)); } + ENOKI_INLINE auto eq_ (Ref a) const { return mask_t::from_k(_mm512_cmp_ps_mask(m, a.m, _CMP_EQ_OQ)); } + ENOKI_INLINE auto neq_(Ref a) const { return mask_t::from_k(_mm512_cmp_ps_mask(m, a.m, _CMP_NEQ_UQ)); } + + ENOKI_INLINE Derived abs_() const { return andnot_(Derived(_mm512_set1_ps(-0.f))); } + + ENOKI_INLINE Derived min_(Ref b) const { return _mm512_min_ps(b.m, m); } + ENOKI_INLINE Derived max_(Ref b) const { return 
_mm512_max_ps(b.m, m); } + ENOKI_INLINE Derived ceil_() const { return _mm512_ceil_ps(m); } + ENOKI_INLINE Derived floor_() const { return _mm512_floor_ps(m); } + ENOKI_INLINE Derived sqrt_() const { return _mm512_sqrt_ps(m); } + + ENOKI_INLINE Derived round_() const { + return _mm512_roundscale_ps(m, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + } + + ENOKI_INLINE Derived trunc_() const { + return _mm512_roundscale_ps(m, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC); + } + + template + ENOKI_INLINE auto ceil2int_() const { + if constexpr (sizeof(scalar_t) == 4) { + if constexpr (std::is_signed_v>) + return T(_mm512_cvt_roundps_epi32(m, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)); + else + return T(_mm512_cvt_roundps_epu32(m, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)); + } else { + #if defined(ENOKI_X86_AVX512DQ) + using A = typename T::Array1; + if constexpr (std::is_signed_v>) + return T( + A(_mm512_cvt_roundps_epi64(low(derived()).m, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)), + A(_mm512_cvt_roundps_epi64(high(derived()).m, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)) + ); + else + return T( + A(_mm512_cvt_roundps_epu64(low(derived()).m, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)), + A(_mm512_cvt_roundps_epu64(high(derived()).m, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)) + ); + #else + return Base::template ceil2int_(); + #endif + } + } + + template + ENOKI_INLINE auto floor2int_() const { + if constexpr (sizeof(scalar_t) == 4) { + if constexpr (std::is_signed_v>) + return T(_mm512_cvt_roundps_epi32(m, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)); + else + return T(_mm512_cvt_roundps_epu32(m, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)); + } else { + #if defined(ENOKI_X86_AVX512DQ) + using A = typename T::Array1; + + if constexpr (std::is_signed_v>) + return T( + A(_mm512_cvt_roundps_epi64(low(derived()).m, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)), + A(_mm512_cvt_roundps_epi64(high(derived()).m, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)) + ); + else + return T( + A(_mm512_cvt_roundps_epu64(low(derived()).m, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)), + A(_mm512_cvt_roundps_epu64(high(derived()).m, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)) + ); + #else + return Base::template floor2int_(); + #endif + } + } + + ENOKI_INLINE Derived fmadd_ (Ref b, Ref c) const { return _mm512_fmadd_ps (m, b.m, c.m); } + ENOKI_INLINE Derived fmsub_ (Ref b, Ref c) const { return _mm512_fmsub_ps (m, b.m, c.m); } + ENOKI_INLINE Derived fnmadd_ (Ref b, Ref c) const { return _mm512_fnmadd_ps (m, b.m, c.m); } + ENOKI_INLINE Derived fnmsub_ (Ref b, Ref c) const { return _mm512_fnmsub_ps (m, b.m, c.m); } + ENOKI_INLINE Derived fmsubadd_(Ref b, Ref c) const { return _mm512_fmsubadd_ps(m, b.m, c.m); } + ENOKI_INLINE Derived fmaddsub_(Ref b, Ref c) const { return _mm512_fmaddsub_ps(m, b.m, c.m); } + + template + static ENOKI_INLINE Derived select_(const Mask &m, Ref t, Ref f) { + return _mm512_mask_blend_ps(m.k, f.m, t.m); + } + + template + ENOKI_INLINE Derived shuffle_() const { + const __m512i idx = + _mm512_setr_epi32(I0, I1, I2, I3, I4, I5, I6, I7, I8, + I9, I10, I11, I12, I13, I14, I15); + return _mm512_permutexvar_ps(idx, m); + } + + template ENOKI_INLINE Derived shuffle_(const Index &index) const { + return _mm512_permutexvar_ps(index.m, m); + } + + ENOKI_INLINE Derived rcp_() const { + #if defined(ENOKI_X86_AVX512ER) + /* rel err < 2^28, use as is */ + return _mm512_rcp28_ps(m); + #else + /* Use best reciprocal approximation available on the current + hardware and refine */ + __m512 r = 
_mm512_rcp14_ps(m); /* rel error < 2^-14 */ + + /* Refine using one Newton-Raphson iteration */ + __m512 t0 = _mm512_add_ps(r, r), + t1 = _mm512_mul_ps(r, m); + + r = _mm512_fnmadd_ps(t1, r, t0); + + return _mm512_fixupimm_ps(r, m, + _mm512_set1_epi32(0x0087A622), 0); + #endif + } + + ENOKI_INLINE Derived rsqrt_() const { + #if defined(ENOKI_X86_AVX512ER) + /* rel err < 2^28, use as is */ + return _mm512_rsqrt28_ps(m); + #else + __m512 r = _mm512_rsqrt14_ps(m); /* rel error < 2^-14 */ + + /* Refine using one Newton-Raphson iteration */ + const __m512 c0 = _mm512_set1_ps(0.5f), + c1 = _mm512_set1_ps(3.0f); + + __m512 t0 = _mm512_mul_ps(r, c0), + t1 = _mm512_mul_ps(r, m); + + r = _mm512_mul_ps(_mm512_fnmadd_ps(t1, r, c1), t0); + + return _mm512_fixupimm_ps(r, m, + _mm512_set1_epi32(0x0383A622), 0); + #endif + } + + ENOKI_INLINE Derived ldexp_(Ref arg) const { return _mm512_scalef_ps(m, arg.m); } + + ENOKI_INLINE std::pair frexp_() const { + return std::make_pair( + _mm512_getmant_ps(m, _MM_MANT_NORM_p5_1, _MM_MANT_SIGN_src), + _mm512_getexp_ps(m)); + } + + // ----------------------------------------------------------------------- + //! @{ \name Horizontal operations + // ----------------------------------------------------------------------- + + ENOKI_INLINE Value hsum_() const { return hsum(low_() + high_()); } + ENOKI_INLINE Value hprod_() const { return hprod(low_() * high_()); } + ENOKI_INLINE Value hmin_() const { return hmin(min(low_(), high_())); } + ENOKI_INLINE Value hmax_() const { return hmax(max(low_(), high_())); } + + //! @} + // ----------------------------------------------------------------------- + + ENOKI_INLINE void store_(void *ptr) const { + assert((uintptr_t) ptr % 64 == 0); + _mm512_store_ps((Value *) ENOKI_ASSUME_ALIGNED(ptr, 64), m); + } + + template + ENOKI_INLINE void store_(void *ptr, const Mask &mask) const { + assert((uintptr_t) ptr % 64 == 0); + _mm512_mask_store_ps((Value *) ENOKI_ASSUME_ALIGNED(ptr, 64), mask.k, m); + } + + ENOKI_INLINE void store_unaligned_(void *ptr) const { + _mm512_storeu_ps((Value *) ptr, m); + } + + template + ENOKI_INLINE void store_unaligned_(void *ptr, const Mask &mask) const { + _mm512_mask_storeu_ps((Value *) ptr, mask.k, m); + } + + static ENOKI_INLINE Derived load_(const void *ptr) { + assert((uintptr_t) ptr % 64 == 0); + return _mm512_load_ps((const Value *) ENOKI_ASSUME_ALIGNED(ptr, 64)); + } + + template + static ENOKI_INLINE Derived load_(const void *ptr, const Mask &mask) { + assert((uintptr_t) ptr % 64 == 0); + return _mm512_maskz_load_ps(mask.k, (const Value *) ENOKI_ASSUME_ALIGNED(ptr, 64)); + } + + static ENOKI_INLINE Derived load_unaligned_(const void *ptr) { + return _mm512_loadu_ps((const Value *) ptr); + } + + template + static ENOKI_INLINE Derived load_unaligned_(const void *ptr, const Mask &mask) { + return _mm512_maskz_loadu_ps(mask.k, (const Value *) ptr); + } + + static ENOKI_INLINE Derived zero_() { return _mm512_setzero_ps(); } + +#if defined(ENOKI_X86_AVX512PF) + template + static ENOKI_INLINE void prefetch_(const void *ptr, const Index &index, const Mask &mask) { + constexpr auto Hint = Level == 1 ? 
_MM_HINT_T0 : _MM_HINT_T1; + if constexpr (sizeof(scalar_t) == 4) { + if constexpr (Write) + _mm512_mask_prefetch_i32scatter_ps((void *) ptr, mask.k, index.m, Stride, Hint); + else + _mm512_mask_prefetch_i32gather_ps(index.m, mask.k, ptr, Stride, Hint); + } else { + if constexpr (Write) { + _mm512_mask_prefetch_i64scatter_ps((void *) ptr, low(mask).k, low(index).m, Stride, Hint); + _mm512_mask_prefetch_i64scatter_ps((void *) ptr, high(mask).k, high(index).m, Stride, Hint); + } else { + _mm512_mask_prefetch_i64gather_ps(low(index).m, low(mask).k, ptr, Stride, Hint); + _mm512_mask_prefetch_i64gather_ps(high(index).m, high(mask).k, ptr, Stride, Hint); + } + } + } +#endif + + template + static ENOKI_INLINE Derived gather_(const void *ptr, const Index &index, const Mask &mask) { + if constexpr (sizeof(scalar_t) == 4) { + return _mm512_mask_i32gather_ps(_mm512_setzero_ps(), mask.k, index.m, (const float *) ptr, Stride); + } else { + return detail::concat( + _mm512_mask_i64gather_ps(_mm256_setzero_ps(), low(mask).k, low(index).m, (const float *) ptr, Stride), + _mm512_mask_i64gather_ps(_mm256_setzero_ps(), high(mask).k, high(index).m, (const float *) ptr, Stride)); + } + } + + template + ENOKI_INLINE void scatter_(void *ptr, const Index &index, const Mask &mask) const { + if constexpr (sizeof(scalar_t) == 4) { + _mm512_mask_i32scatter_ps(ptr, mask.k, index.m, m, Stride); + } else { + _mm512_mask_i64scatter_ps(ptr, low(mask).k, low(index).m, low(derived()).m, Stride); + _mm512_mask_i64scatter_ps(ptr, high(mask).k, high(index).m, high(derived()).m, Stride); + } + } + + template + ENOKI_INLINE Value extract_(const Mask &mask) const { + return _mm_cvtss_f32(_mm512_castps512_ps128(_mm512_maskz_compress_ps(mask.k, m))); + } + + template + ENOKI_INLINE size_t compress_(float *&ptr, const Mask &mask) const { + _mm512_storeu_ps(ptr, _mm512_maskz_compress_ps(mask.k, m)); + size_t kn = (size_t) _mm_popcnt_u32(mask.k); + ptr += kn; + return kn; + } + +#if defined(ENOKI_X86_AVX512CD) + template + static ENOKI_INLINE void transform_(void *mem, + Index index, + const Mask &mask, + const Func &func, + const Args &... args) { + Derived values = _mm512_mask_i32gather_ps( + _mm512_undefined_ps(), mask.k, index.m, mem, (int) Stride); + + index.m = _mm512_mask_mov_epi32(_mm512_set1_epi32(-1), mask.k, index.m); + + __m512i conflicts = _mm512_conflict_epi32(index.m); + __m512i perm_idx = _mm512_sub_epi32(_mm512_set1_epi32(31), _mm512_lzcnt_epi32(conflicts)); + __mmask16 todo = _mm512_mask_test_epi32_mask(mask.k, conflicts, _mm512_set1_epi32(-1)); + + func(values, args...); + + ENOKI_NOUNROLL while (ENOKI_UNLIKELY(!_mm512_kortestz(todo, todo))) { + __mmask16 cur = _mm512_mask_testn_epi32_mask( + todo, conflicts, _mm512_broadcastmw_epi32(todo)); + values.m = _mm512_mask_permutexvar_ps(values.m, cur, perm_idx, values.m); + + __m512 backup(values.m); + func(values, args...); + + values.m = _mm512_mask_mov_ps(backup, cur, values.m); + todo = _mm512_kxor(todo, cur); + } + + _mm512_mask_i32scatter_ps(mem, mask.k, index.m, values.m, (int) Stride); + } +#endif + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! 
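/* transform_() above performs a masked read-modify-write scatter (think of a
   histogram update): lanes whose indices collide are detected with
   _mm512_conflict_epi32, and the functor is re-applied to one group of
   conflicting lanes at a time so that every lane's update reaches memory.
   Its scalar semantics amount to the following sketch, assuming <cstddef>
   and <cstdint>; the names are hypothetical, and a sequential loop needs no
   conflict handling at all: */
static inline void transform_scalar_sketch(float *mem, const int32_t *index,
                                           const bool *mask, size_t n,
                                           void (*func)(float &)) {
    for (size_t i = 0; i < n; ++i)
        if (mask[i])
            func(mem[index[i]]);     // e.g. func = [](float &x) { x += 1.f; }
}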
@{ \name Masked versions of key operations + // ----------------------------------------------------------------------- + + template + ENOKI_INLINE void massign_(const Derived &a, const Mask &mask) { m = _mm512_mask_mov_ps(m, mask.k, a.m); } + template + ENOKI_INLINE void madd_ (const Derived &a, const Mask &mask) { m = _mm512_mask_add_ps(m, mask.k, m, a.m); } + template + ENOKI_INLINE void msub_ (const Derived &a, const Mask &mask) { m = _mm512_mask_sub_ps(m, mask.k, m, a.m); } + template + ENOKI_INLINE void mmul_ (const Derived &a, const Mask &mask) { m = _mm512_mask_mul_ps(m, mask.k, m, a.m); } + template + ENOKI_INLINE void mdiv_ (const Derived &a, const Mask &mask) { m = _mm512_mask_div_ps(m, mask.k, m, a.m); } + template + ENOKI_INLINE void mor_ (const Derived &a, const Mask &mask) { + #if defined(ENOKI_X86_AVX512DQ) + m = _mm512_mask_or_ps(m, mask.k, m, a.m); + #else + m = _mm512_castsi512_ps( + _mm512_or_si512(_mm512_castps_si512(m), mask.k, + _mm512_castps_si512(m), _mm512_castps_si512(a.m))); + #endif + } + + template + ENOKI_INLINE void mand_ (const Derived &a, const Mask &mask) { + #if defined(ENOKI_X86_AVX512DQ) + m = _mm512_mask_and_ps(m, mask.k, m, a.m); + #else + m = _mm512_castsi512_ps(_mm512_and_si512(_mm512_castps_si512(m), mask.k, + _mm512_castps_si512(m), + _mm512_castps_si512(a.m))); + #endif + } + + template + ENOKI_INLINE void mxor_ (const Derived &a, const Mask &mask) { + #if defined(ENOKI_X86_AVX512DQ) + m = _mm512_mask_xor_ps(m, mask.k, m, a.m); + #else + m = _mm512_castsi512_ps(_mm512_xor_si512(_mm512_castps_si512(m), mask.k, + _mm512_castps_si512(m), + _mm512_castps_si512(a.m))); + #endif + } + + //! @} + // ----------------------------------------------------------------------- +} ENOKI_MAY_ALIAS; + +/// Partial overload of StaticArrayImpl using AVX512 intrinsics (double precision) +template struct alignas(64) + StaticArrayImpl + : StaticArrayBase { + ENOKI_NATIVE_ARRAY(double, 8, __m512d) + + // ----------------------------------------------------------------------- + //! @{ \name Value constructors + // ----------------------------------------------------------------------- + + ENOKI_INLINE StaticArrayImpl(const Value &value) : m(_mm512_set1_pd(value)) { } + ENOKI_INLINE StaticArrayImpl(Value f0, Value f1, Value f2, Value f3, + Value f4, Value f5, Value f6, Value f7) + : m(_mm512_setr_pd(f0, f1, f2, f3, f4, f5, f6, f7)) { } + + + // ----------------------------------------------------------------------- + //! @{ \name Type converting constructors + // ----------------------------------------------------------------------- + + ENOKI_CONVERT(half) + : m(_mm512_cvtps_pd( + _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *) a.derived().data())))) { } + + ENOKI_CONVERT(float) : m(_mm512_cvtps_pd(a.derived().m)) { } + + ENOKI_CONVERT(double) : m(a.derived().m) { } + + ENOKI_CONVERT(int32_t) : m(_mm512_cvtepi32_pd(a.derived().m)) { } + + ENOKI_CONVERT(uint32_t) : m(_mm512_cvtepu32_pd(a.derived().m)) { } + +#if defined(ENOKI_X86_AVX512DQ) + ENOKI_CONVERT(int64_t) + : m(_mm512_cvtepi64_pd(a.derived().m)) { } + + ENOKI_CONVERT(uint64_t) + : m(_mm512_cvtepu64_pd(a.derived().m)) { } +#elif defined(ENOKI_X86_AVX512CD) + /* Emulate uint64_t -> double conversion instead of falling + back to scalar operations. This is quite a bit faster + (>5.5x for unsigned, > for signed). 
*/ + + ENOKI_CONVERT(uint64_t) { + using Int64 = int64_array_t; + + auto lz = lzcnt(a); + auto shift = (63 - 52) - Int64(lz); + auto abs_shift = abs(shift); + auto nzero_mask = neq(a, 0ull); + auto mant = select(shift > 0, a >> abs_shift, a << abs_shift); + auto exp = sl<52>(uint64_t(1023 + 63) - lz) & nzero_mask; + auto comb = exp | (mant & 0xfffffffffffffull); + + m = reinterpret_array(comb).m; + } + + ENOKI_CONVERT(int64_t) { + auto b = abs(a), lz = lzcnt(b); + auto shift = (63 - 52) - lz; + auto abs_shift = abs(shift); + auto nzero_mask = neq(a, 0ll); + auto mant = select(shift > 0, b >> abs_shift, b << abs_shift); + auto sign = a & 0x8000000000000000ull; + auto exp = sl<52>(int64_t(1023 + 63) - lz) & nzero_mask; + auto comb = exp | (mant & 0xfffffffffffffull) | sign; + + m = reinterpret_array(comb).m; + } +#endif + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! @{ \name Reinterpreting constructors, mask converters + // ----------------------------------------------------------------------- + + ENOKI_REINTERPRET(double) : m(a.derived().m) { } + + ENOKI_REINTERPRET(int64_t) : m(_mm512_castsi512_pd(a.derived().m)) { } + ENOKI_REINTERPRET(uint64_t) : m(_mm512_castsi512_pd(a.derived().m)) { } + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! @{ \name Converting from/to half size vectors + // ----------------------------------------------------------------------- + + StaticArrayImpl(const Array1 &a1, const Array2 &a2) + : m(detail::concat(a1.m, a2.m)) { } + + ENOKI_INLINE Array1 low_() const { return _mm512_castpd512_pd256(m); } + ENOKI_INLINE Array2 high_() const { return _mm512_extractf64x4_pd(m, 1); } + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! 
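/* rcp_() and rsqrt_() in this header -- the single-precision versions above
   and the double-precision ones below -- start from a hardware estimate
   (about 2^-14 relative error, or 2^-28 with AVX512ER) and polish it with
   Newton-Raphson steps before patching special inputs via fixupimm. One
   reciprocal step in scalar form, as a sketch; the helper name is
   hypothetical: */
static inline double rcp_newton_step_sketch(double x, double r) {
    // r' = r * (2 - x*r) == 2*r - x*r*r, the same update the vector code
    // writes as fnmadd(x*r, r, r + r).
    return r * (2.0 - x * r);
}
/* Each step roughly squares the relative error, which is why the
   double-precision variant below runs one iteration when the 2^-28 AVX512ER
   estimate is available and two otherwise. */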
@{ \name Vertical operations + // ----------------------------------------------------------------------- + + ENOKI_INLINE Derived add_(Ref a) const { return _mm512_add_pd(m, a.m); } + ENOKI_INLINE Derived sub_(Ref a) const { return _mm512_sub_pd(m, a.m); } + ENOKI_INLINE Derived mul_(Ref a) const { return _mm512_mul_pd(m, a.m); } + ENOKI_INLINE Derived div_(Ref a) const { return _mm512_div_pd(m, a.m); } + + template ENOKI_INLINE Derived or_(const T &a) const { + if constexpr (is_mask_v) { + return _mm512_mask_mov_pd(m, a.k, _mm512_set1_pd(memcpy_cast(int64_t(-1)))); + } else { + #if defined(ENOKI_X86_AVX512DQ) + return _mm512_or_pd(m, a.m); + #else + return _mm512_castsi512_pd( + _mm512_or_si512(_mm512_castpd_si512(m), _mm512_castpd_si512(a.m))); + #endif + } + } + + template ENOKI_INLINE Derived and_(const T &a) const { + if constexpr (is_mask_v) { + return _mm512_maskz_mov_pd(a.k, m); + } else { + #if defined(ENOKI_X86_AVX512DQ) + return _mm512_and_pd(m, a.m); + #else + return _mm512_castsi512_pd( + _mm512_and_si512(_mm512_castpd_si512(m), _mm512_castpd_si512(a.m))); + #endif + } + } + + template ENOKI_INLINE Derived andnot_(const T &a) const { + if constexpr (is_mask_v) { + return _mm512_mask_mov_pd(m, a.k, _mm512_setzero_pd()); + } else { + #if defined(ENOKI_X86_AVX512DQ) + return _mm512_andnot_pd(a.m, m); + #else + return _mm512_castsi512_pd( + _mm512_andnot_si512(_mm512_castpd_si512(a.m), _mm512_castpd_si512(m))); + #endif + } + } + + template ENOKI_INLINE Derived xor_(const T &a) const { + if constexpr (is_mask_v) { + const __m512 c = _mm512_set1_pd(memcpy_cast(int64_t(-1))); + #if defined(ENOKI_X86_AVX512DQ) + return _mm512_mask_xor_pd(m, a.k, m, c); + #else + const __m512i v0 = _mm512_castpd_si512(m); + return _mm512_castsi512_pd(_mm512_mask_xor_epi64(v0, a.k, v0, c)); + #endif + } else { + #if defined(ENOKI_X86_AVX512DQ) + return _mm512_xor_pd(m, a.m); + #else + return _mm512_castsi512_pd( + _mm512_xor_si512(_mm512_castpd_si512(m), _mm512_castpd_si512(a.m))); + #endif + } + } + + ENOKI_INLINE auto lt_ (Ref a) const { return mask_t::from_k(_mm512_cmp_pd_mask(m, a.m, _CMP_LT_OQ)); } + ENOKI_INLINE auto gt_ (Ref a) const { return mask_t::from_k(_mm512_cmp_pd_mask(m, a.m, _CMP_GT_OQ)); } + ENOKI_INLINE auto le_ (Ref a) const { return mask_t::from_k(_mm512_cmp_pd_mask(m, a.m, _CMP_LE_OQ)); } + ENOKI_INLINE auto ge_ (Ref a) const { return mask_t::from_k(_mm512_cmp_pd_mask(m, a.m, _CMP_GE_OQ)); } + ENOKI_INLINE auto eq_ (Ref a) const { return mask_t::from_k(_mm512_cmp_pd_mask(m, a.m, _CMP_EQ_OQ)); } + ENOKI_INLINE auto neq_(Ref a) const { return mask_t::from_k(_mm512_cmp_pd_mask(m, a.m, _CMP_NEQ_UQ)); } + + ENOKI_INLINE Derived abs_() const { return andnot_(Derived(_mm512_set1_pd(-0.0))); } + + ENOKI_INLINE Derived min_(Ref b) const { return _mm512_min_pd(b.m, m); } + ENOKI_INLINE Derived max_(Ref b) const { return _mm512_max_pd(b.m, m); } + ENOKI_INLINE Derived ceil_() const { return _mm512_ceil_pd(m); } + ENOKI_INLINE Derived floor_() const { return _mm512_floor_pd(m); } + ENOKI_INLINE Derived sqrt_() const { return _mm512_sqrt_pd(m); } + + template + ENOKI_INLINE auto ceil2int_() const { + if constexpr (sizeof(scalar_t) == 4) { + if constexpr (std::is_signed_v>) + return T(_mm512_cvt_roundpd_epi32(m, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)); + else + return T(_mm512_cvt_roundpd_epu32(m, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)); + } else { + #if defined(ENOKI_X86_AVX512DQ) + if constexpr (std::is_signed_v>) + return T(_mm512_cvt_roundpd_epi64(m, _MM_FROUND_TO_POS_INF | 
_MM_FROUND_NO_EXC)); + else + return T(_mm512_cvt_roundpd_epu64(m, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)); + #else + return Base::template ceil2int_(); + #endif + } + } + + template + ENOKI_INLINE auto floor2int_() const { + if constexpr (sizeof(scalar_t) == 4) { + if constexpr (std::is_signed_v>) + return T(_mm512_cvt_roundpd_epi32(m, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)); + else + return T(_mm512_cvt_roundpd_epu32(m, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)); + } else { + #if defined(ENOKI_X86_AVX512DQ) + if constexpr (std::is_signed_v>) + return T(_mm512_cvt_roundpd_epi64(m, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)); + else + return T(_mm512_cvt_roundpd_epu64(m, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)); + #else + return Base::template floor2int_(); + #endif + } + } + + ENOKI_INLINE Derived round_() const { + return _mm512_roundscale_pd(m, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + } + + ENOKI_INLINE Derived trunc_() const { + return _mm512_roundscale_pd(m, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC); + } + + ENOKI_INLINE Derived fmadd_ (Ref b, Ref c) const { return _mm512_fmadd_pd (m, b.m, c.m); } + ENOKI_INLINE Derived fmsub_ (Ref b, Ref c) const { return _mm512_fmsub_pd (m, b.m, c.m); } + ENOKI_INLINE Derived fnmadd_ (Ref b, Ref c) const { return _mm512_fnmadd_pd (m, b.m, c.m); } + ENOKI_INLINE Derived fnmsub_ (Ref b, Ref c) const { return _mm512_fnmsub_pd (m, b.m, c.m); } + ENOKI_INLINE Derived fmsubadd_(Ref b, Ref c) const { return _mm512_fmsubadd_pd(m, b.m, c.m); } + ENOKI_INLINE Derived fmaddsub_(Ref b, Ref c) const { return _mm512_fmaddsub_pd(m, b.m, c.m); } + + template + static ENOKI_INLINE Derived select_(const Mask &m, Ref t, Ref f) { + return _mm512_mask_blend_pd(m.k, f.m, t.m); + } + + template + ENOKI_INLINE Derived shuffle_() const { + const __m512i idx = + _mm512_setr_epi64(I0, I1, I2, I3, I4, I5, I6, I7); + return _mm512_permutexvar_pd(idx, m); + } + + template ENOKI_INLINE Derived shuffle_(const Index &index) const { + return _mm512_permutexvar_pd(index.m, m); + } + + ENOKI_INLINE Derived rcp_() const { + /* Use best reciprocal approximation available on the current + hardware and refine */ + __m512d r; + + #if defined(ENOKI_X86_AVX512ER) + r = _mm512_rcp28_pd(m); /* rel err < 2^28 */ + #else + r = _mm512_rcp14_pd(m); /* rel error < 2^-14 */ + #endif + + /* Refine using 1-2 Newton-Raphson iterations */ + ENOKI_UNROLL for (int i = 0; i < (has_avx512er ? 1 : 2); ++i) { + __m512d t0 = _mm512_add_pd(r, r); + __m512d t1 = _mm512_mul_pd(r, m); + + r = _mm512_fnmadd_pd(t1, r, t0); + } + + return _mm512_fixupimm_pd(r, m, + _mm512_set1_epi32(0x0087A622), 0); + } + + ENOKI_INLINE Derived rsqrt_() const { + /* Use best reciprocal square root approximation available + on the current hardware and refine */ + __m512d r; + #if defined(ENOKI_X86_AVX512ER) + r = _mm512_rsqrt28_pd(m); /* rel err < 2^28 */ + #else + r = _mm512_rsqrt14_pd(m); /* rel error < 2^-14 */ + #endif + + const __m512d c0 = _mm512_set1_pd(0.5), + c1 = _mm512_set1_pd(3.0); + + /* Refine using 1-2 Newton-Raphson iterations */ + ENOKI_UNROLL for (int i = 0; i < (has_avx512er ? 
1 : 2); ++i) { + __m512d t0 = _mm512_mul_pd(r, c0); + __m512d t1 = _mm512_mul_pd(r, m); + + r = _mm512_mul_pd(_mm512_fnmadd_pd(t1, r, c1), t0); + } + + return _mm512_fixupimm_pd(r, m, + _mm512_set1_epi32(0x0383A622), 0); + } + + + ENOKI_INLINE Derived ldexp_(Ref arg) const { return _mm512_scalef_pd(m, arg.m); } + + ENOKI_INLINE std::pair frexp_() const { + return std::make_pair( + _mm512_getmant_pd(m, _MM_MANT_NORM_p5_1, _MM_MANT_SIGN_src), + _mm512_getexp_pd(m)); + } + + // ----------------------------------------------------------------------- + //! @{ \name Horizontal operations + // ----------------------------------------------------------------------- + + ENOKI_INLINE Value hsum_() const { return hsum(low_() + high_()); } + ENOKI_INLINE Value hprod_() const { return hprod(low_() * high_()); } + ENOKI_INLINE Value hmin_() const { return hmin(min(low_(), high_())); } + ENOKI_INLINE Value hmax_() const { return hmax(max(low_(), high_())); } + + //! @} + // ----------------------------------------------------------------------- + + ENOKI_INLINE void store_(void *ptr) const { + assert((uintptr_t) ptr % 64 == 0); + _mm512_store_pd((Value *) ENOKI_ASSUME_ALIGNED(ptr, 64), m); + } + + template + ENOKI_INLINE void store_(void *ptr, const Mask &mask) const { + assert((uintptr_t) ptr % 64 == 0); + _mm512_mask_store_pd((Value *) ENOKI_ASSUME_ALIGNED(ptr, 64), mask.k, m); + } + + ENOKI_INLINE void store_unaligned_(void *ptr) const { + _mm512_storeu_pd((Value *) ptr, m); + } + + template + ENOKI_INLINE void store_unaligned_(void *ptr, const Mask &mask) const { + _mm512_mask_storeu_pd((Value *) ptr, mask.k, m); + } + + static ENOKI_INLINE Derived load_(const void *ptr) { + assert((uintptr_t) ptr % 64 == 0); + return _mm512_load_pd((const Value *) ENOKI_ASSUME_ALIGNED(ptr, 64)); + } + + template + static ENOKI_INLINE Derived load_(const void *ptr, const Mask &mask) { + assert((uintptr_t) ptr % 64 == 0); + return _mm512_maskz_load_pd(mask.k, (const Value *) ENOKI_ASSUME_ALIGNED(ptr, 64)); + } + + static ENOKI_INLINE Derived load_unaligned_(const void *ptr) { + return _mm512_loadu_pd((const Value *) ptr); + } + + template + static ENOKI_INLINE Derived load_unaligned_(const void *ptr, const Mask &mask) { + return _mm512_maskz_loadu_pd(mask.k, (const Value *) ptr); + } + + static ENOKI_INLINE Derived zero_() { return _mm512_setzero_pd(); } + +#if defined(ENOKI_X86_AVX512PF) + template + static ENOKI_INLINE void prefetch_(const void *ptr, const Index &index, const Mask &mask) { + constexpr auto Hint = Level == 1 ? 
_MM_HINT_T0 : _MM_HINT_T1; + if constexpr (sizeof(scalar_t) == 4) { + if (Write) + _mm512_mask_prefetch_i32scatter_pd((void *) ptr, mask.k, index.m, Stride, Hint); + else + _mm512_mask_prefetch_i32gather_pd(index.m, mask.k, ptr, Stride, Hint); + } else { + if (Write) + _mm512_mask_prefetch_i64scatter_pd((void *) ptr, mask.k, index.m, Stride, Hint); + else + _mm512_mask_prefetch_i64gather_pd(index.m, mask.k, ptr, Stride, Hint); + } + } +#endif + + template + static ENOKI_INLINE Derived gather_(const void *ptr, const Index &index, const Mask &mask) { + if constexpr (sizeof(scalar_t) == 4) + return _mm512_mask_i32gather_pd(_mm512_setzero_pd(), mask.k, index.m, (const double *) ptr, Stride); + else + return _mm512_mask_i64gather_pd(_mm512_setzero_pd(), mask.k, index.m, (const double *) ptr, Stride); + } + + template + ENOKI_INLINE void scatter_(void *ptr, const Index &index, const Mask &mask) const { + if constexpr (sizeof(scalar_t) == 4) + _mm512_mask_i32scatter_pd(ptr, mask.k, index.m, m, Stride); + else + _mm512_mask_i64scatter_pd(ptr, mask.k, index.m, m, Stride); + } + + template + ENOKI_INLINE Value extract_(const Mask &mask) const { + return _mm_cvtsd_f64(_mm512_castpd512_pd128(_mm512_maskz_compress_pd(mask.k, m))); + } + + template + ENOKI_INLINE size_t compress_(double *&ptr, const Mask &mask) const { + _mm512_storeu_pd(ptr, _mm512_maskz_compress_pd(mask.k, m)); + size_t kn = (size_t) _mm_popcnt_u32(mask.k); + ptr += kn; + return kn; + } + +#if defined(ENOKI_X86_AVX512CD) + template + static ENOKI_INLINE void transform_(void *mem, + Index index, + const Mask &mask, + const Func &func, + const Args &... args) { + Derived values = _mm512_mask_i64gather_pd( + _mm512_undefined_pd(), mask.k, index.m, mem, (int) Stride); + + index.m = _mm512_mask_mov_epi64(_mm512_set1_epi64(-1), mask.k, index.m); + + __m512i conflicts = _mm512_conflict_epi64(index.m); + __m512i perm_idx = _mm512_sub_epi64(_mm512_set1_epi64(63), _mm512_lzcnt_epi64(conflicts)); + __mmask8 todo = _mm512_mask_test_epi64_mask(mask.k, conflicts, _mm512_set1_epi64(-1)); + + func(values, args...); + + ENOKI_NOUNROLL while (ENOKI_UNLIKELY(todo)) { + __mmask8 cur = _mm512_mask_testn_epi64_mask( + todo, conflicts, _mm512_broadcastmb_epi64(todo)); + values.m = _mm512_mask_permutexvar_pd(values.m, cur, perm_idx, values.m); + + __m512d backup(values.m); + func(values, args...); + + values.m = _mm512_mask_mov_pd(backup, cur, values.m); + todo ^= cur; + } + + _mm512_mask_i64scatter_pd(mem, mask.k, index.m, values.m, (int) Stride); + } +#endif + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! 
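The AVX512CD-based transform_ above is a gather-modify-scatter in which duplicate indices are detected with _mm512_conflict_epi64 and resolved by re-running the functor until every conflicting lane has been folded in. Its net effect matches this scalar reference loop (a hypothetical sketch that uses a per-element functor rather than a packet functor):

    // Hypothetical scalar reference for the semantics of transform_(): every
    // active lane performs a read-modify-write, and lanes sharing an index
    // observe the updates made by earlier lanes.
    template <typename Func>
    void transform_reference(double *mem, const int64_t (&index)[8],
                             const bool (&mask)[8], Func func) {
        for (int i = 0; i < 8; ++i)
            if (mask[i])
                func(mem[index[i]]);   // e.g. a scatter-add passes [](double &v) { v += 1.0; }
    }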
@{ \name Masked versions of key operations + // ----------------------------------------------------------------------- + + template + ENOKI_INLINE void massign_(const Derived &a, const Mask &mask) { m = _mm512_mask_mov_pd(m, mask.k, a.m); } + template + ENOKI_INLINE void madd_ (const Derived &a, const Mask &mask) { m = _mm512_mask_add_pd(m, mask.k, m, a.m); } + template + ENOKI_INLINE void msub_ (const Derived &a, const Mask &mask) { m = _mm512_mask_sub_pd(m, mask.k, m, a.m); } + template + ENOKI_INLINE void mmul_ (const Derived &a, const Mask &mask) { m = _mm512_mask_mul_pd(m, mask.k, m, a.m); } + template + ENOKI_INLINE void mdiv_ (const Derived &a, const Mask &mask) { m = _mm512_mask_div_pd(m, mask.k, m, a.m); } + template + ENOKI_INLINE void mor_ (const Derived &a, const Mask &mask) { + #if defined(ENOKI_X86_AVX512DQ) + m = _mm512_mask_or_pd(m, mask.k, m, a.m); + #else + m = _mm512_castsi512_pd(_mm512_or_si512(_mm512_castpd_si512(m), mask.k, + _mm512_castpd_si512(m), + _mm512_castpd_si512(a.m))); + #endif + } + + template + ENOKI_INLINE void mand_ (const Derived &a, const Mask &mask) { + #if defined(ENOKI_X86_AVX512DQ) + m = _mm512_mask_and_pd(m, mask.k, m, a.m); + #else + m = _mm512_castsi512_pd(_mm512_and_si512(_mm512_castpd_si512(m), mask.k, + _mm512_castpd_si512(m), + _mm512_castpd_si512(a.m))); + #endif + } + + template + ENOKI_INLINE void mxor_ (const Derived &a, const Mask &mask) { + #if defined(ENOKI_X86_AVX512DQ) + m = _mm512_mask_xor_pd(m, mask.k, m, a.m); + #else + m = _mm512_castsi512_pd(_mm512_xor_si512(_mm512_castpd_si512(m), mask.k, + _mm512_castpd_si512(m), + _mm512_castpd_si512(a.m))); + #endif + } + + //! @} + // ----------------------------------------------------------------------- +} ENOKI_MAY_ALIAS; + +/// Partial overload of StaticArrayImpl using AVX512 intrinsics (32 bit integers) +template struct alignas(64) + StaticArrayImpl> + : StaticArrayBase { + ENOKI_NATIVE_ARRAY(Value_, 16, __m512i) + + + // ----------------------------------------------------------------------- + //! @{ \name Value constructors + // ----------------------------------------------------------------------- + + ENOKI_INLINE StaticArrayImpl(Value value) : m(_mm512_set1_epi32((int32_t) value)) { } + + ENOKI_INLINE StaticArrayImpl(Value f0, Value f1, Value f2, Value f3, + Value f4, Value f5, Value f6, Value f7, + Value f8, Value f9, Value f10, Value f11, + Value f12, Value f13, Value f14, Value f15) + : m(_mm512_setr_epi32( + (int32_t) f0, (int32_t) f1, (int32_t) f2, (int32_t) f3, + (int32_t) f4, (int32_t) f5, (int32_t) f6, (int32_t) f7, + (int32_t) f8, (int32_t) f9, (int32_t) f10, (int32_t) f11, + (int32_t) f12, (int32_t) f13, (int32_t) f14, (int32_t) f15)) { } + + // ----------------------------------------------------------------------- + //! @{ \name Type converting constructors + // ----------------------------------------------------------------------- + + ENOKI_CONVERT(int32_t) : m(a.derived().m) { } + ENOKI_CONVERT(uint32_t) : m(a.derived().m) { } + + ENOKI_CONVERT(float) { + m = std::is_signed_v ? _mm512_cvttps_epi32(a.derived().m) + : _mm512_cvttps_epu32(a.derived().m); + } + + ENOKI_CONVERT(double) { + m = std::is_signed_v + ? 
detail::concat(_mm512_cvttpd_epi32(low(a).m), + _mm512_cvttpd_epi32(high(a).m)) + : detail::concat(_mm512_cvttpd_epu32(low(a).m), + _mm512_cvttpd_epu32(high(a).m)); + } + + ENOKI_CONVERT(int64_t) + : m(detail::concat(_mm512_cvtepi64_epi32(low(a).m), + _mm512_cvtepi64_epi32(high(a).m))) { } + + ENOKI_CONVERT(uint64_t) + : m(detail::concat(_mm512_cvtepi64_epi32(low(a).m), + _mm512_cvtepi64_epi32(high(a).m))) { } + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! @{ \name Reinterpreting constructors, mask converters + // ----------------------------------------------------------------------- + + ENOKI_REINTERPRET(float) : m(_mm512_castps_si512(a.derived().m)) { } + ENOKI_REINTERPRET(int32_t) : m(a.derived().m) { } + ENOKI_REINTERPRET(uint32_t) : m(a.derived().m) { } + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! @{ \name Converting from/to half size vectors + // ----------------------------------------------------------------------- + + StaticArrayImpl(const Array1 &a1, const Array2 &a2) + : m(detail::concat(a1.m, a2.m)) { } + + ENOKI_INLINE Array1 low_() const { return _mm512_castsi512_si256(m); } + ENOKI_INLINE Array2 high_() const { + #if defined(ENOKI_X86_AVX512DQ) + return _mm512_extracti32x8_epi32(m, 1); + #else + return _mm512_extracti64x4_epi64(m, 1); + #endif + } + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! @{ \name Vertical operations + // ----------------------------------------------------------------------- + + ENOKI_INLINE Derived add_(Ref a) const { return _mm512_add_epi32(m, a.m); } + ENOKI_INLINE Derived sub_(Ref a) const { return _mm512_sub_epi32(m, a.m); } + ENOKI_INLINE Derived mul_(Ref a) const { return _mm512_mullo_epi32(m, a.m); } + + template + ENOKI_INLINE Derived or_ (const T &a) const { + if constexpr (is_mask_v) + return _mm512_mask_mov_epi32(m, a.k, _mm512_set1_epi32(int32_t(-1))); + else + return _mm512_or_epi32(m, a.m); + } + + template + ENOKI_INLINE Derived and_ (const T &a) const { + if constexpr (is_mask_v) + return _mm512_maskz_mov_epi32(a.k, m); + else + return _mm512_and_epi32(m, a.m); + } + + template + ENOKI_INLINE Derived andnot_ (const T &a) const { + if constexpr (is_mask_v) + return _mm512_mask_mov_epi32(m, a.k, _mm512_setzero_si512()); + else + return _mm512_andnot_epi32(m, a.m); + } + + template + ENOKI_INLINE Derived xor_ (const T &a) const { + if constexpr (is_mask_v) + return _mm512_mask_xor_epi32(m, a.k, m, _mm512_set1_epi32(int32_t(-1))); + else + return _mm512_xor_epi32(m, a.m); + } + + template ENOKI_INLINE Derived sl_() const { + return _mm512_slli_epi32(m, (int) k); + } + + template ENOKI_INLINE Derived sr_() const { + return std::is_signed_v ? _mm512_srai_epi32(m, (int) k) + : _mm512_srli_epi32(m, (int) k); + } + + ENOKI_INLINE Derived sl_(size_t k) const { + return _mm512_sll_epi32(m, _mm_set1_epi64x((long long) k)); + } + + ENOKI_INLINE Derived sr_(size_t k) const { + return std::is_signed_v + ? _mm512_sra_epi32(m, _mm_set1_epi64x((long long) k)) + : _mm512_srl_epi32(m, _mm_set1_epi64x((long long) k)); + } + + ENOKI_INLINE Derived sl_(Ref k) const { + return _mm512_sllv_epi32(m, k.m); + } + + ENOKI_INLINE Derived sr_(Ref k) const { + return std::is_signed_v ? 
_mm512_srav_epi32(m, k.m) + : _mm512_srlv_epi32(m, k.m); + } + + ENOKI_INLINE Derived rol_(Ref k) const { return _mm512_rolv_epi32(m, k.m); } + ENOKI_INLINE Derived ror_(Ref k) const { return _mm512_rorv_epi32(m, k.m); } + + template + ENOKI_INLINE Derived rol_() const { return _mm512_rol_epi32(m, (int) Imm); } + + template + ENOKI_INLINE Derived ror_() const { return _mm512_ror_epi32(m, (int) Imm); } + + ENOKI_INLINE auto lt_ (Ref a) const { return mask_t::from_k(_mm512_cmp_epi32_mask(m, a.m, _MM_CMPINT_LT)); } + ENOKI_INLINE auto gt_ (Ref a) const { return mask_t::from_k(_mm512_cmp_epi32_mask(m, a.m, _MM_CMPINT_GT)); } + ENOKI_INLINE auto le_ (Ref a) const { return mask_t::from_k(_mm512_cmp_epi32_mask(m, a.m, _MM_CMPINT_LE)); } + ENOKI_INLINE auto ge_ (Ref a) const { return mask_t::from_k(_mm512_cmp_epi32_mask(m, a.m, _MM_CMPINT_GE)); } + ENOKI_INLINE auto eq_ (Ref a) const { return mask_t::from_k(_mm512_cmp_epi32_mask(m, a.m, _MM_CMPINT_EQ)); } + ENOKI_INLINE auto neq_(Ref a) const { return mask_t::from_k(_mm512_cmp_epi32_mask(m, a.m, _MM_CMPINT_NE)); } + + ENOKI_INLINE Derived min_(Ref a) const { + return std::is_signed_v ? _mm512_min_epi32(a.m, m) + : _mm512_min_epu32(a.m, m); + } + + ENOKI_INLINE Derived max_(Ref a) const { + return std::is_signed_v ? _mm512_max_epi32(a.m, m) + : _mm512_max_epu32(a.m, m); + } + + ENOKI_INLINE Derived abs_() const { + return std::is_signed_v ? _mm512_abs_epi32(m) : m; + } + + template + static ENOKI_INLINE Derived select_(const Mask &m, Ref t, Ref f) { + return _mm512_mask_blend_epi32(m.k, f.m, t.m); + } + + template + ENOKI_INLINE Derived shuffle_() const { + const __m512i idx = + _mm512_setr_epi32(I0, I1, I2, I3, I4, I5, I6, I7, I8, + I9, I10, I11, I12, I13, I14, I15); + return _mm512_permutexvar_epi32(idx, m); + } + + template ENOKI_INLINE Derived shuffle_(const Index &index) const { + return _mm512_permutexvar_epi32(index.m, m); + } + + ENOKI_INLINE Derived mulhi_(Ref a) const { + auto blend = mask_t::from_k(0b0101010101010101); + Derived even, odd; + + if constexpr (std::is_signed_v) { + even.m = _mm512_srli_epi64(_mm512_mul_epi32(m, a.m), 32); + odd.m = _mm512_mul_epi32(_mm512_srli_epi64(m, 32), + _mm512_srli_epi64(a.m, 32)); + } else { + even.m = _mm512_srli_epi64(_mm512_mul_epu32(m, a.m), 32); + odd.m = _mm512_mul_epu32(_mm512_srli_epi64(m, 32), + _mm512_srli_epi64(a.m, 32)); + } + + return select(blend, even, odd); + } + +#if defined(ENOKI_X86_AVX512CD) + ENOKI_INLINE Derived lzcnt_() const { return _mm512_lzcnt_epi32(m); } + ENOKI_INLINE Derived tzcnt_() const { return Value(32) - lzcnt(~derived() & (derived() - Value(1))); } +#endif + +#if defined(ENOKI_X86_AVX512VPOPCNTDQ) + ENOKI_INLINE Derived popcnt_() const { return _mm512_popcnt_epi32(m); } +#endif + + // ----------------------------------------------------------------------- + //! @{ \name Horizontal operations + // ----------------------------------------------------------------------- + + ENOKI_INLINE Value hsum_() const { return hsum(low_() + high_()); } + ENOKI_INLINE Value hprod_() const { return hprod(low_() * high_()); } + ENOKI_INLINE Value hmin_() const { return hmin(min(low_(), high_())); } + ENOKI_INLINE Value hmax_() const { return hmax(max(low_(), high_())); } + + //! 
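The tzcnt_ fallback above derives a trailing-zero count from the available leading-zero instruction through the identity tzcnt(x) = 32 - lzcnt(~x & (x - 1)): the expression ~x & (x - 1) keeps exactly the bits below the lowest set bit. A scalar sanity check (hypothetical helper; the zero guard is only needed because __builtin_clz(0) is undefined, whereas _mm512_lzcnt_epi32 returns 32 for a zero lane):

    #include <cstdint>

    inline uint32_t tzcnt_via_lzcnt(uint32_t x) {
        uint32_t below = ~x & (x - 1u);   // ones strictly below the lowest set bit of x
        uint32_t lz    = (below == 0) ? 32u : (uint32_t) __builtin_clz(below);
        return 32u - lz;                  // == number of trailing zeros of x
    }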
@} + // ----------------------------------------------------------------------- + // + + ENOKI_INLINE void store_(void *ptr) const { + assert((uintptr_t) ptr % 64 == 0); + _mm512_store_si512((__m512i *) ENOKI_ASSUME_ALIGNED(ptr, 64), m); + } + + template + ENOKI_INLINE void store_(void *ptr, const Mask &mask) const { + assert((uintptr_t) ptr % 64 == 0); + _mm512_mask_store_epi32((__m512i *) ENOKI_ASSUME_ALIGNED(ptr, 64), mask.k, m); + } + + ENOKI_INLINE void store_unaligned_(void *ptr) const { + _mm512_storeu_si512((__m512i *) ptr, m); + } + + template + ENOKI_INLINE void store_unaligned_(void *ptr, const Mask &mask) const { + _mm512_mask_storeu_epi32((__m512i *) ptr, mask.k, m); + } + + static ENOKI_INLINE Derived load_(const void *ptr) { + assert((uintptr_t) ptr % 64 == 0); + return _mm512_load_si512((const __m512i *) ENOKI_ASSUME_ALIGNED(ptr, 64)); + } + + template + static ENOKI_INLINE Derived load_(const void *ptr, const Mask &mask) { + assert((uintptr_t) ptr % 64 == 0); + return _mm512_maskz_load_epi32(mask.k, (const __m512i *) ENOKI_ASSUME_ALIGNED(ptr, 64)); + } + + static ENOKI_INLINE Derived load_unaligned_(const void *ptr) { + return _mm512_loadu_si512((const __m512i *) ptr); + } + + template + static ENOKI_INLINE Derived load_unaligned_(const void *ptr, const Mask &mask) { + return _mm512_maskz_loadu_epi32(mask.k, (const __m512i *) ptr); + } + + static ENOKI_INLINE Derived zero_() { return _mm512_setzero_si512(); } + +#if defined(ENOKI_X86_AVX512PF) + template + static ENOKI_INLINE void prefetch_(const void *ptr, const Index &index, const Mask &mask) { + constexpr auto Hint = Level == 1 ? _MM_HINT_T0 : _MM_HINT_T1; + + if constexpr (sizeof(scalar_t) == 4) { + if (Write) + _mm512_mask_prefetch_i32scatter_ps((void *) ptr, mask.k, index.m, Stride, Hint); + else + _mm512_mask_prefetch_i32gather_ps(index.m, mask.k, ptr, Stride, Hint); + } else { + if (Write) { + _mm512_mask_prefetch_i64scatter_ps((void *) ptr, low(mask).k, low(index).m, Stride, Hint); + _mm512_mask_prefetch_i64scatter_ps((void *) ptr, high(mask).k, high(index).m, Stride, Hint); + } else { + _mm512_mask_prefetch_i64gather_ps(low(index).m, low(mask).k, ptr, Stride, Hint); + _mm512_mask_prefetch_i64gather_ps(high(index).m, high(mask).k, ptr, Stride, Hint); + } + } + } +#endif + + template + static ENOKI_INLINE Derived gather_(const void *ptr, const Index &index, const Mask &mask) { + if constexpr (sizeof(scalar_t) == 4) { + return _mm512_mask_i32gather_epi32(_mm512_setzero_si512(), mask.k, index.m, (const float *) ptr, Stride); + } else { + return detail::concat( + _mm512_mask_i64gather_epi32(_mm256_setzero_si256(), low(mask).k, low(index).m, (const float *) ptr, Stride), + _mm512_mask_i64gather_epi32(_mm256_setzero_si256(), high(mask).k, high(index).m, (const float *) ptr, Stride)); + } + } + + template + ENOKI_INLINE void scatter_(void *ptr, const Index &index, const Mask &mask) const { + if constexpr (sizeof(scalar_t) == 4) { + _mm512_mask_i32scatter_epi32(ptr, mask.k, index.m, m, Stride); + } else { + _mm512_mask_i64scatter_epi32(ptr, low(mask).k, low(index).m, low(derived()).m, Stride); + _mm512_mask_i64scatter_epi32(ptr, high(mask).k, high(index).m, high(derived()).m, Stride); + } + } + + template + ENOKI_INLINE Value extract_(const Mask &mask) const { + return (Value) _mm_cvtsi128_si32(_mm512_castsi512_si128(_mm512_maskz_compress_epi32(mask.k, m))); + } + + template + ENOKI_INLINE size_t compress_(Value_ *&ptr, const Mask &mask) const { + _mm512_storeu_si512((__m512i *) ptr, _mm512_maskz_compress_epi32(mask.k, 
m)); + size_t kn = (size_t) _mm_popcnt_u32(mask.k); + ptr += kn; + return kn; + } + +#if defined(ENOKI_X86_AVX512CD) + template + static ENOKI_INLINE void transform_(void *mem, + Index index, + const Mask &mask, + const Func &func, + const Args &... args) { + Derived values = _mm512_mask_i32gather_epi32( + _mm512_undefined_epi32(), mask.k, index.m, mem, (int) Stride); + + index.m = _mm512_mask_mov_epi32(_mm512_set1_epi32(-1), mask.k, index.m); + + __m512i conflicts = _mm512_conflict_epi32(index.m); + __m512i perm_idx = _mm512_sub_epi32(_mm512_set1_epi32(31), _mm512_lzcnt_epi32(conflicts)); + __mmask16 todo = _mm512_mask_test_epi32_mask(mask.k, conflicts, _mm512_set1_epi32(-1)); + + func(values, args...); + + ENOKI_NOUNROLL while (ENOKI_UNLIKELY(!_mm512_kortestz(todo, todo))) { + __mmask16 cur = _mm512_mask_testn_epi32_mask( + todo, conflicts, _mm512_broadcastmw_epi32(todo)); + values.m = _mm512_mask_permutexvar_epi32(values.m, cur, perm_idx, values.m); + + __m512i backup(values.m); + func(values, args...); + + values.m = _mm512_mask_mov_epi32(backup, cur, values.m); + todo = _mm512_kxor(todo, cur); + } + + _mm512_mask_i32scatter_epi32(mem, mask.k, index.m, values.m, (int) Stride); + } +#endif + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! @{ \name Masked versions of key operations + // ----------------------------------------------------------------------- + + template + ENOKI_INLINE void massign_(const Derived &a, const Mask &mask) { m = _mm512_mask_mov_epi32(m, mask.k, a.m); } + template + ENOKI_INLINE void madd_ (const Derived &a, const Mask &mask) { m = _mm512_mask_add_epi32(m, mask.k, m, a.m); } + template + ENOKI_INLINE void msub_ (const Derived &a, const Mask &mask) { m = _mm512_mask_sub_epi32(m, mask.k, m, a.m); } + template + ENOKI_INLINE void mmul_ (const Derived &a, const Mask &mask) { m = _mm512_mask_mullo_epi32(m, mask.k, m, a.m); } + template + ENOKI_INLINE void mor_ (const Derived &a, const Mask &mask) { m = _mm512_mask_or_epi32(m, mask.k, m, a.m); } + template + ENOKI_INLINE void mand_ (const Derived &a, const Mask &mask) { m = _mm512_mask_and_epi32(m, mask.k, m, a.m); } + template + ENOKI_INLINE void mxor_ (const Derived &a, const Mask &mask) { m = _mm512_mask_xor_epi32(m, mask.k, m, a.m); } + + //! @} + // ----------------------------------------------------------------------- +} ENOKI_MAY_ALIAS; + +/// Partial overload of StaticArrayImpl using AVX512 intrinsics (64 bit integers) +template struct alignas(64) + StaticArrayImpl> + : StaticArrayBase { + ENOKI_NATIVE_ARRAY(Value_, 8, __m512i) + + + // ----------------------------------------------------------------------- + //! @{ \name Value constructors + // ----------------------------------------------------------------------- + + ENOKI_INLINE StaticArrayImpl(const Value &value) : m(_mm512_set1_epi64((long long) value)) { } + ENOKI_INLINE StaticArrayImpl(Value f0, Value f1, Value f2, Value f3, + Value f4, Value f5, Value f6, Value f7) + : m(_mm512_setr_epi64((long long) f0, (long long) f1, (long long) f2, + (long long) f3, (long long) f4, (long long) f5, + (long long) f6, (long long) f7)) { } + + // ----------------------------------------------------------------------- + //! @{ \name Type converting constructors + // ----------------------------------------------------------------------- + +#if defined(ENOKI_X86_AVX512DQ) + ENOKI_CONVERT(float) { + m = std::is_signed_v ? 
_mm512_cvttps_epi64(a.derived().m) + : _mm512_cvttps_epu64(a.derived().m); + } +#else + /* Emulate float -> uint64 conversion instead of falling + back to scalar operations. This is quite a bit faster (~4x!) */ + + ENOKI_CONVERT(float) { + using Int32 = int_array_t; + using UInt32 = uint_array_t; + using UInt64 = uint64_array_t; + + /* Shift out sign bit */ + auto b = reinterpret_array(a); + b += b; + + auto mant = UInt64((b & 0xffffffu) | 0x1000000u); + auto shift = (24 + 127) - Int32(sr<24>(b)); + auto abs_shift = UInt64(abs(shift)); + + auto result = select(shift > 0, mant >> abs_shift, mant << abs_shift); + + if constexpr (std::is_signed_v) + result[a < 0] = -result; + + m = result.m; + } +#endif + + ENOKI_CONVERT(int32_t) + : m(_mm512_cvtepi32_epi64(a.derived().m)) { } + + ENOKI_CONVERT(uint32_t) + : m(_mm512_cvtepu32_epi64(a.derived().m)) { } + +#if defined(ENOKI_X86_AVX512DQ) + ENOKI_CONVERT(double) { + m = std::is_signed_v ? _mm512_cvttpd_epi64(a.derived().m) + : _mm512_cvttpd_epu64(a.derived().m); + } +#else + /* Emulate double -> uint64 conversion instead of falling + back to scalar operations. This is quite a bit faster (>~11x!) */ + + ENOKI_CONVERT(double) { + using Int64 = int_array_t; + using UInt64 = uint_array_t; + + /* Shift out sign bit */ + auto b = reinterpret_array(a); + b += b; + + auto mant = (b & 0x1fffffffffffffull) | 0x20000000000000ull; + auto shift = (53 + 1023) - Int64(sr<53>(b)); + auto abs_shift = UInt64(abs(shift)); + + auto result = select(shift > 0, mant >> abs_shift, mant << abs_shift); + + if constexpr (std::is_signed_v) + result[a < 0] = -result; + + m = result.m; + } +#endif + + ENOKI_CONVERT(int64_t) : m(a.derived().m) { } + ENOKI_CONVERT(uint64_t) : m(a.derived().m) { } + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! @{ \name Reinterpreting constructors, mask converters + // ----------------------------------------------------------------------- + + ENOKI_REINTERPRET(double) : m(_mm512_castpd_si512(a.derived().m)) { } + ENOKI_REINTERPRET(int64_t) : m(a.derived().m) { } + ENOKI_REINTERPRET(uint64_t) : m(a.derived().m) { } + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! @{ \name Converting from/to half size vectors + // ----------------------------------------------------------------------- + + StaticArrayImpl(const Array1 &a1, const Array2 &a2) + : m(detail::concat(a1.m, a2.m)) { } + + ENOKI_INLINE Array1 low_() const { return _mm512_castsi512_si256(m); } + ENOKI_INLINE Array2 high_() const { return _mm512_extracti64x4_epi64(m, 1); } + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! 
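The emulated double -> uint64 conversion above avoids a scalar fallback by shifting out the sign bit, restoring the implicit mantissa bit, and shifting the result into place according to the exponent. A scalar sketch of the same arithmetic (hypothetical helper, intended for inputs in [0, 2^64), truncating like the vector path):

    #include <cstdint>
    #include <cstring>

    inline uint64_t double_to_uint64_sketch(double d) {
        uint64_t b;
        std::memcpy(&b, &d, sizeof(b));
        b += b;                                               // shift out the sign bit
        uint64_t mant  = (b & 0x1fffffffffffffull)            // fraction (pre-shifted by one)
                       | 0x20000000000000ull;                 // restore the implicit leading 1
        int64_t  shift = (53 + 1023) - (int64_t) (b >> 53);   // biased exponent sits at bits 63..53
        if (shift >= 64)
            return 0;                                         // tiny inputs truncate to zero
        return shift > 0 ? (mant >> shift) : (mant << -shift);
    }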
@{ \name Vertical operations + // ----------------------------------------------------------------------- + + ENOKI_INLINE Derived add_(Ref a) const { return _mm512_add_epi64(m, a.m); } + ENOKI_INLINE Derived sub_(Ref a) const { return _mm512_sub_epi64(m, a.m); } + + ENOKI_INLINE Derived mul_(Ref a) const { + #if defined(ENOKI_X86_AVX512DQ) && defined(ENOKI_X86_AVX512VL) + return _mm512_mullo_epi64(m, a.m); + #else + __m512i h0 = _mm512_srli_epi64(m, 32); + __m512i h1 = _mm512_srli_epi64(a.m, 32); + __m512i low = _mm512_mul_epu32(m, a.m); + __m512i mix0 = _mm512_mul_epu32(m, h1); + __m512i mix1 = _mm512_mul_epu32(h0, a.m); + __m512i mix = _mm512_add_epi64(mix0, mix1); + __m512i mix_s = _mm512_slli_epi64(mix, 32); + return _mm512_add_epi64(mix_s, low); + #endif + } + + template + ENOKI_INLINE Derived or_ (const T &a) const { + if constexpr (is_mask_v) + return _mm512_mask_mov_epi64(m, a.k, _mm512_set1_epi64(int64_t(-1))); + else + return _mm512_or_epi64(m, a.m); + } + + template + ENOKI_INLINE Derived and_ (const T &a) const { + if constexpr (is_mask_v) + return _mm512_maskz_mov_epi64(a.k, m); + else + return _mm512_and_epi64(m, a.m); + } + + template + ENOKI_INLINE Derived andnot_ (const T &a) const { + if constexpr (is_mask_v) + return _mm512_mask_mov_epi64(m, a.k, _mm512_setzero_si512()); + else + return _mm512_andnot_epi64(m, a.m); + } + + template + ENOKI_INLINE Derived xor_ (const T &a) const { + if constexpr (is_mask_v) + return _mm512_mask_xor_epi64(m, a.k, m, _mm512_set1_epi64(int64_t(-1))); + else + return _mm512_xor_epi64(m, a.m); + } + + template ENOKI_INLINE Derived sl_() const { + return _mm512_slli_epi64(m, (int) k); + } + + template ENOKI_INLINE Derived sr_() const { + return std::is_signed_v ? _mm512_srai_epi64(m, (int) k) + : _mm512_srli_epi64(m, (int) k); + } + + ENOKI_INLINE Derived sl_(size_t k) const { + return _mm512_sll_epi64(m, _mm_set1_epi64x((long long) k)); + } + + ENOKI_INLINE Derived sr_(size_t k) const { + return std::is_signed_v + ? _mm512_sra_epi64(m, _mm_set1_epi64x((long long) k)) + : _mm512_srl_epi64(m, _mm_set1_epi64x((long long) k)); + } + + ENOKI_INLINE Derived sl_(Ref k) const { + return _mm512_sllv_epi64(m, k.m); + } + + ENOKI_INLINE Derived sr_(Ref k) const { + return std::is_signed_v ? _mm512_srav_epi64(m, k.m) + : _mm512_srlv_epi64(m, k.m); + } + + ENOKI_INLINE Derived rol_(Ref k) const { return _mm512_rolv_epi64(m, k.m); } + ENOKI_INLINE Derived ror_(Ref k) const { return _mm512_rorv_epi64(m, k.m); } + + template + ENOKI_INLINE Derived rol_() const { return _mm512_rol_epi64(m, (int) Imm); } + + template + ENOKI_INLINE Derived ror_() const { return _mm512_ror_epi64(m, (int) Imm); } + + ENOKI_INLINE auto lt_ (Ref a) const { return mask_t::from_k(_mm512_cmp_epi64_mask(m, a.m, _MM_CMPINT_LT)); } + ENOKI_INLINE auto gt_ (Ref a) const { return mask_t::from_k(_mm512_cmp_epi64_mask(m, a.m, _MM_CMPINT_GT)); } + ENOKI_INLINE auto le_ (Ref a) const { return mask_t::from_k(_mm512_cmp_epi64_mask(m, a.m, _MM_CMPINT_LE)); } + ENOKI_INLINE auto ge_ (Ref a) const { return mask_t::from_k(_mm512_cmp_epi64_mask(m, a.m, _MM_CMPINT_GE)); } + ENOKI_INLINE auto eq_ (Ref a) const { return mask_t::from_k(_mm512_cmp_epi64_mask(m, a.m, _MM_CMPINT_EQ)); } + ENOKI_INLINE auto neq_(Ref a) const { return mask_t::from_k(_mm512_cmp_epi64_mask(m, a.m, _MM_CMPINT_NE)); } + + ENOKI_INLINE Derived min_(Ref a) const { + return std::is_signed_v ? _mm512_min_epi64(a.m, m) + : _mm512_min_epu64(a.m, m); + } + + ENOKI_INLINE Derived max_(Ref a) const { + return std::is_signed_v ? 
_mm512_max_epi64(a.m, m) + : _mm512_max_epu64(a.m, m); + } + + ENOKI_INLINE Derived abs_() const { + return std::is_signed_v ? _mm512_abs_epi64(m) : m; + } + + template + static ENOKI_INLINE Derived select_(const Mask &m, Ref t, Ref f) { + return _mm512_mask_blend_epi64(m.k, f.m, t.m); + } + + template + ENOKI_INLINE Derived shuffle_() const { + const __m512i idx = + _mm512_setr_epi64(I0, I1, I2, I3, I4, I5, I6, I7); + return _mm512_permutexvar_epi64(idx, m); + } + + template ENOKI_INLINE Derived shuffle_(const Index &index) const { + return _mm512_permutexvar_epi64(index.m, m); + } + + ENOKI_INLINE Derived mulhi_(Ref b) const { + if (std::is_unsigned_v) { + const __m512i low_bits = _mm512_set1_epi64(0xffffffffu); + __m512i al = m, bl = b.m; + __m512i ah = _mm512_srli_epi64(al, 32); + __m512i bh = _mm512_srli_epi64(bl, 32); + + // 4x unsigned 32x32->64 bit multiplication + __m512i albl = _mm512_mul_epu32(al, bl); + __m512i albh = _mm512_mul_epu32(al, bh); + __m512i ahbl = _mm512_mul_epu32(ah, bl); + __m512i ahbh = _mm512_mul_epu32(ah, bh); + + // Calculate a possible carry from the low bits of the multiplication. + __m512i carry = _mm512_add_epi64( + _mm512_srli_epi64(albl, 32), + _mm512_add_epi64(_mm512_and_epi64(albh, low_bits), + _mm512_and_epi64(ahbl, low_bits))); + + __m512i s0 = _mm512_add_epi64(ahbh, _mm512_srli_epi64(carry, 32)); + __m512i s1 = _mm512_add_epi64(_mm512_srli_epi64(albh, 32), + _mm512_srli_epi64(ahbl, 32)); + + return _mm512_add_epi64(s0, s1); + } else { + const Derived mask(0xffffffff); + const Derived a = derived(); + Derived ah = sr<32>(a), bh = sr<32>(b), + al = a & mask, bl = b & mask; + + Derived albl_hi = _mm512_srli_epi64(_mm512_mul_epu32(m, b.m), 32); + + Derived t = ah * bl + albl_hi; + Derived w1 = al * bh + (t & mask); + + return ah * bh + sr<32>(t) + sr<32>(w1); + } + } + +#if defined(ENOKI_X86_AVX512CD) + ENOKI_INLINE Derived lzcnt_() const { return _mm512_lzcnt_epi64(m); } + ENOKI_INLINE Derived tzcnt_() const { return Value(64) - lzcnt(~derived() & (derived() - Value(1))); } +#endif + +#if defined(ENOKI_X86_AVX512VPOPCNTDQ) + ENOKI_INLINE Derived popcnt_() const { return _mm512_popcnt_epi64(m); } +#endif + + // ----------------------------------------------------------------------- + //! @{ \name Horizontal operations + // ----------------------------------------------------------------------- + + ENOKI_INLINE Value hsum_() const { return hsum(low_() + high_()); } + ENOKI_INLINE Value hprod_() const { return hprod(low_() * high_()); } + ENOKI_INLINE Value hmin_() const { return hmin(min(low_(), high_())); } + ENOKI_INLINE Value hmax_() const { return hmax(max(low_(), high_())); } + + //! 
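Without AVX512DQ there is no vector 64-bit multiply, so mulhi_ above reconstructs the upper 64 bits of the 128-bit product from four 32x32 -> 64-bit partial products. The unsigned branch corresponds to this scalar computation (hypothetical helper):

    #include <cstdint>

    inline uint64_t mulhi64_sketch(uint64_t a, uint64_t b) {
        uint64_t al = a & 0xffffffffull, ah = a >> 32;
        uint64_t bl = b & 0xffffffffull, bh = b >> 32;

        uint64_t albl = al * bl, albh = al * bh,
                 ahbl = ah * bl, ahbh = ah * bh;

        // Carry that propagates out of the middle 64 bits of the full product
        uint64_t carry = (albl >> 32) + (albh & 0xffffffffull) + (ahbl & 0xffffffffull);

        return ahbh + (carry >> 32) + (albh >> 32) + (ahbl >> 32);
    }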
@} + // ----------------------------------------------------------------------- + + ENOKI_INLINE void store_(void *ptr) const { + assert((uint64_t) ptr % 64 == 0); + _mm512_store_si512((__m512i *) ENOKI_ASSUME_ALIGNED(ptr, 64), m); + } + + template + ENOKI_INLINE void store_(void *ptr, const Mask &mask) const { + assert((uint64_t) ptr % 64 == 0); + _mm512_mask_store_epi64((__m512i *) ENOKI_ASSUME_ALIGNED(ptr, 64), mask.k, m); + } + + ENOKI_INLINE void store_unaligned_(void *ptr) const { + _mm512_storeu_si512((__m512i *) ptr, m); + } + + template + ENOKI_INLINE void store_unaligned_(void *ptr, const Mask &mask) const { + _mm512_mask_storeu_epi64((__m512i *) ptr, mask.k, m); + } + + static ENOKI_INLINE Derived load_(const void *ptr) { + assert((uint64_t) ptr % 64 == 0); + return _mm512_load_si512((const __m512i *) ENOKI_ASSUME_ALIGNED(ptr, 64)); + } + + template + static ENOKI_INLINE Derived load_(const void *ptr, const Mask &mask) { + assert((uint64_t) ptr % 64 == 0); + return _mm512_maskz_load_epi64(mask.k, (const __m512i *) ENOKI_ASSUME_ALIGNED(ptr, 64)); + } + + static ENOKI_INLINE Derived load_unaligned_(const void *ptr) { + return _mm512_loadu_si512((const __m512i *) ptr); + } + + template + static ENOKI_INLINE Derived load_unaligned_(const void *ptr, const Mask &mask) { + return _mm512_maskz_loadu_epi64(mask.k, (const __m512i *) ptr); + } + + static ENOKI_INLINE Derived zero_() { return _mm512_setzero_si512(); } + +#if defined(ENOKI_X86_AVX512PF) + template + static ENOKI_INLINE void prefetch_(const void *ptr, const Index &index, const Mask &mask) { + constexpr auto Hint = Level == 1 ? _MM_HINT_T0 : _MM_HINT_T1; + + if constexpr (sizeof(scalar_t) == 4) { + if constexpr (Write) + _mm512_mask_prefetch_i32scatter_pd((void *) ptr, mask.k, index.m, Stride, Hint); + else + _mm512_mask_prefetch_i32gather_pd(index.m, mask.k, ptr, Stride, Hint); + } else { + if constexpr (Write) + _mm512_mask_prefetch_i64scatter_pd((void *) ptr, mask.k, index.m, Stride, Hint); + else + _mm512_mask_prefetch_i64gather_pd(index.m, mask.k, ptr, Stride, Hint); + } + } +#endif + + template + static ENOKI_INLINE Derived gather_(const void *ptr, const Index &index, const Mask &mask) { + if constexpr (sizeof(scalar_t) == 4) + return _mm512_mask_i32gather_epi64(_mm512_setzero_si512(), mask.k, index.m, (const float *) ptr, Stride); + else + return _mm512_mask_i64gather_epi64(_mm512_setzero_si512(), mask.k, index.m, (const float *) ptr, Stride); + } + + + template + ENOKI_INLINE void scatter_(void *ptr, const Index &index, const Mask &mask) const { + if constexpr (sizeof(scalar_t) == 4) + _mm512_mask_i32scatter_epi64(ptr, mask.k, index.m, m, Stride); + else + _mm512_mask_i64scatter_epi64(ptr, mask.k, index.m, m, Stride); + } + + template + ENOKI_INLINE Value extract_(const Mask &mask) const { + return (Value) _mm_cvtsi128_si64(_mm512_castsi512_si128(_mm512_maskz_compress_epi64(mask.k, m))); + } + + template + ENOKI_INLINE size_t compress_(Value_ *&ptr, const Mask &mask) const { + _mm512_storeu_si512((__m512i *) ptr, _mm512_maskz_compress_epi64(mask.k, m)); + size_t kn = (size_t) _mm_popcnt_u32(mask.k); + ptr += kn; + return kn; + } + +#if defined(ENOKI_X86_AVX512CD) + template + static ENOKI_INLINE void transform_(void *mem, + Index index, + const Mask &mask, + const Func &func, + const Args &... 
args) { + Derived values = _mm512_mask_i64gather_epi64( + _mm512_undefined_epi32(), mask.k, index.m, mem, (int) Stride); + + index.m = _mm512_mask_mov_epi64(_mm512_set1_epi64(-1), mask.k, index.m); + + __m512i conflicts = _mm512_conflict_epi64(index.m); + __m512i perm_idx = _mm512_sub_epi64(_mm512_set1_epi64(63), _mm512_lzcnt_epi64(conflicts)); + __mmask8 todo = _mm512_mask_test_epi64_mask(mask.k, conflicts, _mm512_set1_epi64(-1)); + + func(values, args...); + + ENOKI_NOUNROLL while (ENOKI_UNLIKELY(todo)) { + __mmask8 cur = _mm512_mask_testn_epi64_mask( + todo, conflicts, _mm512_broadcastmb_epi64(todo)); + values.m = _mm512_mask_permutexvar_epi64(values.m, cur, perm_idx, values.m); + + __m512i backup(values.m); + func(values, args...); + + values.m = _mm512_mask_mov_epi64(backup, cur, values.m); + todo ^= cur; + } + + _mm512_mask_i64scatter_epi64(mem, mask.k, index.m, values.m, (int) Stride); + } +#endif + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! @{ \name Masked versions of key operations + // ----------------------------------------------------------------------- + + template + ENOKI_INLINE void massign_(const Derived &a, const Mask &mask) { m = _mm512_mask_mov_epi64(m, mask.k, a.m); } + template + ENOKI_INLINE void madd_ (const Derived &a, const Mask &mask) { m = _mm512_mask_add_epi64(m, mask.k, m, a.m); } + template + ENOKI_INLINE void msub_ (const Derived &a, const Mask &mask) { m = _mm512_mask_sub_epi64(m, mask.k, m, a.m); } + template + ENOKI_INLINE void mmul_ (const Derived &a, const Mask &mask) { + #if defined(ENOKI_X86_AVX512DQ) && defined(ENOKI_X86_AVX512VL) + m = _mm512_mask_mullo_epi64(m, mask.k, m, a.m); + #else + m = select(mask, a * derived(), derived()).m; + #endif + } + template + ENOKI_INLINE void mor_ (const Derived &a, const Mask &mask) { m = _mm512_mask_or_epi64(m, mask.k, m, a.m); } + template + ENOKI_INLINE void mand_ (const Derived &a, const Mask &mask) { m = _mm512_mask_and_epi64(m, mask.k, m, a.m); } + template + ENOKI_INLINE void mxor_ (const Derived &a, const Mask &mask) { m = _mm512_mask_xor_epi64(m, mask.k, m, a.m); } + + //! @} + // ----------------------------------------------------------------------- +} ENOKI_MAY_ALIAS; + +template +ENOKI_DECLARE_KMASK(float, 16, Derived_, int) +template +ENOKI_DECLARE_KMASK(double, 8, Derived_, int) +template +ENOKI_DECLARE_KMASK(Value_, 16, Derived_, enable_if_int32_t) +template +ENOKI_DECLARE_KMASK(Value_, 8, Derived_, enable_if_int64_t) + +NAMESPACE_END(enoki) diff --git a/sources/enoki/array_base.h b/sources/enoki/array_base.h new file mode 100644 index 00000000..d51804cb --- /dev/null +++ b/sources/enoki/array_base.h @@ -0,0 +1,240 @@ +/* + enoki/array_base.h -- Base class of all Enoki arrays + + Enoki is a C++ template library that enables transparent vectorization + of numerical kernels using SIMD instruction sets available on current + processor architectures. + + Copyright (c) 2019 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a BSD-style + license that can be found in the LICENSE file. +*/ + +#include +#include +#include + +NAMESPACE_BEGIN(enoki) + +template struct ArrayBase { + // ----------------------------------------------------------------------- + //! 
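ArrayBase below relies on the curiously recurring template pattern: generic fallbacks live in the base class and reach concrete functionality through derived(), so no virtual dispatch (and no vtable pointer) is involved. A minimal standalone sketch of the idiom, using hypothetical types unrelated to Enoki:

    template <typename Derived> struct BaseSketch {
        Derived       &derived()       { return (Derived &) *this; }
        const Derived &derived() const { return (const Derived &) *this; }

        // A generic operation expressed in terms of the derived class
        int doubled() const { return derived().value() * 2; }
    };

    struct IntBox : BaseSketch<IntBox> {
        int v = 21;
        int value() const { return v; }
    };

    // IntBox{}.doubled() == 42; dispatch is resolved entirely at compile time.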
@{ \name Curiously Recurring Template design pattern + // ----------------------------------------------------------------------- + + /// Alias to the derived type + using Derived = Derived_; + + /// Cast to derived type + ENOKI_INLINE Derived &derived() { return (Derived &) *this; } + + /// Cast to derived type (const version) + ENOKI_INLINE const Derived &derived() const { return (Derived &) *this; } + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! @{ \name Basic declarations + // ----------------------------------------------------------------------- + + /// Actual type underlying the derived array + using Value = Value_; + + /// Scalar data type all the way at the lowest level + using Scalar = scalar_t; + + /// Specifies how deeply nested this array is + static constexpr size_t Depth = 1 + array_depth_v; + + /// Is this a mask type? + static constexpr bool IsMask = is_mask_v; + + /// Is this a dynamically allocated array (no by default) + static constexpr bool IsDynamic = is_dynamic_v; + + /// Does this array compute derivatives using automatic differentation? + static constexpr bool IsDiff = is_diff_array_v; + + /// Does this array reside on the GPU? (via CUDA) + static constexpr bool IsCUDA = is_cuda_array_v; + + /// Does this array map operations onto native vector instructions? + static constexpr bool IsNative = false; + + /// Is this an AVX512-style 'k' mask register? + static constexpr bool IsKMask = false; + + /// Is the storage representation of this array implemented recursively? + static constexpr bool IsRecursive = false; + + /// Always prefer broadcasting to the outer dimensions of a N-D array + static constexpr bool BroadcastPreferOuter = true; + + /// Does this array represent a fixed size vector? + static constexpr bool IsVector = false; + + /// Does this array represent a complex number? + static constexpr bool IsComplex = false; + + /// Does this array represent a quaternion? + static constexpr bool IsQuaternion = false; + + /// Does this array represent a matrix? + static constexpr bool IsMatrix = false; + + /// Does this array represent the result of a 'masked(...)' epxpression? + static constexpr bool IsMaskedArray = false; + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! @{ \name Iterators + // ----------------------------------------------------------------------- + + ENOKI_INLINE auto begin() const { return derived().data(); } + ENOKI_INLINE auto begin() { return derived().data(); } + ENOKI_INLINE auto end() const { return derived().data() + derived().size(); } + ENOKI_INLINE auto end() { return derived().data() + derived().size(); } + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! 
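The static flags and iterators declared above allow generic code to inspect array types at compile time and to traverse a packet with an ordinary range-for loop. A small usage sketch, assuming the headers added by this patch are reachable as <enoki/array.h> and that the packet type exposes data()/size() as usual:

    #include <enoki/array.h>

    using FloatP = enoki::Array<float, 8>;

    static_assert(FloatP::Depth == 1, "one level of array nesting");
    static_assert(!FloatP::IsMask,    "a value array, not a mask");

    float sum_sketch(const FloatP &x) {
        float total = 0.f;
        for (float v : x)      // begin()/end() forward to data()/size()
            total += v;
        return total;          // same result as enoki's hsum(x)
    }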
@{ \name Element access + // ----------------------------------------------------------------------- + + /// Array indexing operator with bounds checks in debug mode + ENOKI_INLINE decltype(auto) operator[](size_t i) { + #if !defined(NDEBUG) && !defined(ENOKI_DISABLE_RANGE_CHECK) + if (i >= derived().size()) + throw std::out_of_range( + "ArrayBase: out of range access (tried to access index " + + std::to_string(i) + " in an array of size " + + std::to_string(derived().size()) + ")"); + #endif + return derived().coeff(i); + } + + /// Array indexing operator with bounds checks in debug mode, const version + ENOKI_INLINE decltype(auto) operator[](size_t i) const { + #if !defined(NDEBUG) && !defined(ENOKI_DISABLE_RANGE_CHECK) + if (i >= derived().size()) + throw std::out_of_range( + "ArrayBase: out of range access (tried to access index " + + std::to_string(i) + " in an array of size " + + std::to_string(derived().size()) + ")"); + #endif + return derived().coeff(i); + } + + template = 0> + ENOKI_INLINE auto operator[](const Mask &m) { + return detail::MaskedArray{ derived(), (const mask_t &) m }; + } + + //! @} + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + //! @{ \name Fallback implementations for masked operations + // ----------------------------------------------------------------------- + + #define ENOKI_MASKED_OPERATOR_FALLBACK(name, expr) \ + template \ + ENOKI_INLINE void m##name##_(const T &e, const Mask &m) { \ + derived() = select(m, expr, derived()); \ + } + + ENOKI_MASKED_OPERATOR_FALLBACK(assign, e) + ENOKI_MASKED_OPERATOR_FALLBACK(add, derived() + e) + ENOKI_MASKED_OPERATOR_FALLBACK(sub, derived() - e) + ENOKI_MASKED_OPERATOR_FALLBACK(mul, derived() * e) + ENOKI_MASKED_OPERATOR_FALLBACK(div, derived() / e) + ENOKI_MASKED_OPERATOR_FALLBACK(or, derived() | e) + ENOKI_MASKED_OPERATOR_FALLBACK(and, derived() & e) + ENOKI_MASKED_OPERATOR_FALLBACK(xor, derived() ^ e) + + #undef ENOKI_MASKED_OPERATOR_FALLBACK + + //! @} + // ----------------------------------------------------------------------- + + /// Dot product fallback implementation + ENOKI_INLINE auto dot_(const Derived &a) const { return hsum(derived() * a); } + + /// Horizontal mean fallback implementation + ENOKI_INLINE auto hmean_() const { + return hsum(derived()) * (1.f / derived().size()); + } + + template + ENOKI_INLINE void scatter_add_(void *mem, const Index &index, + const Mask &mask) const { + transform( + mem, index, [](auto &a, auto &b, auto &) { a += b; }, + derived(), mask); + } +}; + +namespace detail { + template + ENOKI_INLINE bool convert_mask(T value) { + if constexpr (std::is_same_v) + return value; + else + return memcpy_cast::UInt>(value) != 0; + } + + template + void print(Stream &os, const Array &a, bool abbrev, + const std::array &size, Indices... indices) { + ENOKI_MARK_USED(size); + ENOKI_MARK_USED(abbrev); + if constexpr (sizeof...(Indices) == N) { + os << a.derived().coeff(indices...); + } else { + constexpr size_t k = N - sizeof...(Indices) - 1; + os << "["; + for (size_t i = 0; i < size[k]; ++i) { + if constexpr (is_dynamic_array_v) { + if (size[k] > 20 && i == 5 && abbrev) { + if (k > 0) { + os << ".. " << size[k] - 10 << " skipped ..,\n"; + for (size_t j = 0; j <= sizeof...(Indices); ++j) + os << " "; + } else { + os << ".. 
" << size[k] - 10 << " skipped .., "; + } + i = size[k] - 6; + continue; + } + } + print(os, a, abbrev, size, i, indices...); + if (i + 1 < size[k]) { + if constexpr (k == 0) { + os << ", "; + } else { + os << ",\n"; + for (size_t j = 0; j <= sizeof...(Indices); ++j) + os << " "; + } + } + } + os << "]"; + } + } +} + +template +ENOKI_NOINLINE std::ostream &operator<<(std::ostream &os, const ArrayBase &a) { + if (ragged(a)) + os << "[ragged array]"; + else + detail::print(os, a, true, shape(a)); + return os; +} + + +NAMESPACE_END(enoki) diff --git a/sources/enoki/array_call.h b/sources/enoki/array_call.h new file mode 100644 index 00000000..d020e370 --- /dev/null +++ b/sources/enoki/array_call.h @@ -0,0 +1,291 @@ +/* + enoki/array_call.h -- Enoki arrays of pointers, support for + array (virtual) method calls + + Copyright (c) 2019 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a BSD-style + license that can be found in the LICENSE file. +*/ + +#pragma once + +#include + +NAMESPACE_BEGIN(enoki) + +template struct call_support { + call_support(const Storage &) { } +}; + +template +struct StaticArrayImpl::use_pointer_impl>> + : StaticArrayImpl { + + using UnderlyingType = std::uintptr_t; + + using Base = StaticArrayImpl; + + ENOKI_ARRAY_DEFAULTS(StaticArrayImpl) + using Base::derived; + + using Value = std::conditional_t; + using Scalar = std::conditional_t; + + StaticArrayImpl() = default; + StaticArrayImpl(Value value) : Base(UnderlyingType(value)) { } + StaticArrayImpl(std::nullptr_t) : Base(UnderlyingType(0)) { } + + template > = 0> + StaticArrayImpl(const T &b) : Base(b) { } + + template + StaticArrayImpl(const T &b, detail::reinterpret_flag) + : Base(b, detail::reinterpret_flag()) { } + + template == array_depth_v && array_size_v == Base::Size1 && + array_depth_v == array_depth_v && array_size_v == Base::Size2 && + Base::Size2 != 0> = 0> + StaticArrayImpl(const T1 &a1, const T2 &a2) + : Base(a1, a2) { } + + ENOKI_INLINE decltype(auto) coeff(size_t i) const { + using Coeff = decltype(Base::coeff(i)); + if constexpr (std::is_same_v) + return (const Value &) Base::coeff(i); + else + return Base::coeff(i); + } + + ENOKI_INLINE decltype(auto) coeff(size_t i) { + using Coeff = decltype(Base::coeff(i)); + if constexpr (std::is_same_v) + return (Value &) Base::coeff(i); + else + return Base::coeff(i); + } + + template + ENOKI_INLINE size_t compress_(T *&ptr, const Mask &mask) const { + return Base::compress_((UnderlyingType *&) ptr, mask); + } + + auto operator->() const { + using BaseType = std::decay_t>>; + return call_support(derived()); + } + + template Derived_& operator=(T&& t) { + ENOKI_MARK_USED(t); + if constexpr (std::is_same_v) + return (Derived_ &) Base::operator=(UnderlyingType(0)); + else if constexpr (std::is_convertible_v) + return (Derived_ &) Base::operator=(UnderlyingType(t)); + else + return (Derived_ &) Base::operator=(std::forward(t)); + } +}; + +NAMESPACE_BEGIN(detail) +template typename T, typename... Args> +struct is_callable : std::false_type {}; +template