cocos-engine-external/sources/enoki/array_avx512.h

1929 lines
80 KiB
C++

/*
enoki/array_avx512.h -- Packed SIMD array (AVX512 specialization)
Enoki is a C++ template library that enables transparent vectorization
of numerical kernels using SIMD instruction sets available on current
processor architectures.
Copyright (c) 2019 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a BSD-style
license that can be found in the LICENSE file.
*/
#pragma once
NAMESPACE_BEGIN(enoki)
NAMESPACE_BEGIN(detail)

/* Trait flags for (scalar type, lane count) pairs that this AVX512 header
   implements with a native 512-bit register: 16 x 32-bit lanes and
   8 x 64-bit lanes. NOTE(review): the integer specializations of
   StaticArrayImpl are presumably defined further down in this header. */
template <> struct is_native<float, 16> : std::true_type { };
template <> struct is_native<double, 8> : std::true_type { };

/// Any 32-bit integer type with 16 lanes
template <typename T>
struct is_native<T, 16, enable_if_int32_t<T>> : std::true_type { };

/// Any 64-bit integer type with 8 lanes
template <typename T>
struct is_native<T, 8, enable_if_int64_t<T>> : std::true_type { };

NAMESPACE_END(detail)
/// Partial overload of StaticArrayImpl using AVX512 intrinsics (single precision)
template <bool IsMask_, typename Derived_> struct alignas(64)
    StaticArrayImpl<float, 16, IsMask_, Derived_>
    : StaticArrayBase<float, 16, IsMask_, Derived_> {

    /* Declares the __m512 register member `m` together with the aliases used
       below (Value, Derived, Ref, Array1, Array2, Base).
       NOTE(review): inferred from usage in this struct; confirm against the
       macro definition. */
    ENOKI_NATIVE_ARRAY(float, 16, __m512)

    // -----------------------------------------------------------------------
    //! @{ \name Value constructors
    // -----------------------------------------------------------------------

    /// Broadcast \c value to all 16 lanes
    ENOKI_INLINE StaticArrayImpl(Value value) : m(_mm512_set1_ps(value)) { }

    /// Initialize lanes 0..15 from \c f0 .. \c f15 (in that order)
    ENOKI_INLINE StaticArrayImpl(Value f0, Value f1, Value f2, Value f3,
                                 Value f4, Value f5, Value f6, Value f7,
                                 Value f8, Value f9, Value f10, Value f11,
                                 Value f12, Value f13, Value f14, Value f15)
        : m(_mm512_setr_ps(f0, f1, f2, f3, f4, f5, f6, f7, f8,
                           f9, f10, f11, f12, f13, f14, f15)) { }

    //! @}
    // -----------------------------------------------------------------------

    // -----------------------------------------------------------------------
    //! @{ \name Type converting constructors
    // -----------------------------------------------------------------------

    /// half -> float: unaligned load of 16 packed halves, then widening
    ENOKI_CONVERT(half)
        : m(_mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *) a.derived().data()))) { }

    ENOKI_CONVERT(float) : m(a.derived().m) { }

    ENOKI_CONVERT(int32_t) : m(_mm512_cvtepi32_ps(a.derived().m)) { }

    ENOKI_CONVERT(uint32_t) : m(_mm512_cvtepu32_ps(a.derived().m)) { }

    /// double -> float: narrow each 8-lane half separately and concatenate
    ENOKI_CONVERT(double)
        : m(detail::concat(_mm512_cvtpd_ps(low(a).m),
                           _mm512_cvtpd_ps(high(a).m))) { }

#if defined(ENOKI_X86_AVX512DQ)
    ENOKI_CONVERT(int64_t)
        : m(detail::concat(_mm512_cvtepi64_ps(low(a).m),
                           _mm512_cvtepi64_ps(high(a).m))) { }

    ENOKI_CONVERT(uint64_t)
        : m(detail::concat(_mm512_cvtepu64_ps(low(a).m),
                           _mm512_cvtepu64_ps(high(a).m))) { }
#elif defined(ENOKI_X86_AVX512CD)
    /* Emulate uint64_t -> float conversion instead of falling
       back to scalar operations. This is quite a bit faster
       (>6x for unsigned, >5x for signed).

       The leading set bit is located with lzcnt and shifted to bit 23 (the
       implicit leading one), and the biased exponent is built directly.
       NOTE: the discarded low-order bits are truncated, not rounded to
       nearest. Inputs with more than 24 significant bits can therefore
       come out one ulp smaller in magnitude than with the DQ path above. */
    ENOKI_CONVERT(uint64_t) {
        using Int64 = int64_array_t<Derived2>;
        using Int32 = uint32_array_t<Derived2>;

        auto lz = lzcnt(a);
        auto shift = (63 - 23) - Int64(lz);
        auto abs_shift = abs(shift);
        auto nzero_mask = neq(a, 0ull);
        auto mant = select(shift > 0, a >> abs_shift, a << abs_shift);
        auto exp = sl<23>(uint64_t(127 + 63) - lz) & nzero_mask;
        auto comb = exp | (mant & 0x7fffffull);

        m = reinterpret_array<Derived>(Int32(comb)).m;
    }

    /// Same scheme as above on |a|; the sign bit is moved down to bit 31
    ENOKI_CONVERT(int64_t) {
        using Int32 = uint32_array_t<Derived2>;

        auto b = abs(a), lz = lzcnt(b);
        auto shift = (63 - 23) - lz;
        auto abs_shift = abs(shift);
        auto nzero_mask = neq(a, 0ll);
        auto mant = select(shift > 0, b >> abs_shift, b << abs_shift);
        auto sign = sr<32>(a) & 0x80000000ll;
        auto exp = sl<23>(int64_t(127 + 63) - lz) & nzero_mask;
        auto comb = exp | (mant & 0x7fffffll) | sign;

        m = reinterpret_array<Derived>(Int32(comb)).m;
    }
#endif

    //! @}
    // -----------------------------------------------------------------------

    // -----------------------------------------------------------------------
    //! @{ \name Reinterpreting constructors, mask converters
    // -----------------------------------------------------------------------

    ENOKI_REINTERPRET(float) : m(a.derived().m) { }
    ENOKI_REINTERPRET(int32_t) : m(_mm512_castsi512_ps(a.derived().m)) { }
    ENOKI_REINTERPRET(uint32_t) : m(_mm512_castsi512_ps(a.derived().m)) { }

    //! @}
    // -----------------------------------------------------------------------

    // -----------------------------------------------------------------------
    //! @{ \name Converting from/to half size vectors
    // -----------------------------------------------------------------------

    StaticArrayImpl(const Array1 &a1, const Array2 &a2)
        : m(detail::concat(a1.m, a2.m)) { }

    /// Lower 8 lanes (register cast, no data movement)
    ENOKI_INLINE Array1 low_() const { return _mm512_castps512_ps256(m); }

    /// Upper 8 lanes
    ENOKI_INLINE Array2 high_() const {
        #if defined(ENOKI_X86_AVX512DQ)
            return _mm512_extractf32x8_ps(m, 1);
        #else
            /* AVX512F only has the 64-bit-granular extract; same bits */
            return _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(m), 1));
        #endif
    }

    //! @}
    // -----------------------------------------------------------------------

    // -----------------------------------------------------------------------
    //! @{ \name Vertical operations
    // -----------------------------------------------------------------------

    ENOKI_INLINE Derived add_(Ref a) const { return _mm512_add_ps(m, a.m); }
    ENOKI_INLINE Derived sub_(Ref a) const { return _mm512_sub_ps(m, a.m); }
    ENOKI_INLINE Derived mul_(Ref a) const { return _mm512_mul_ps(m, a.m); }
    ENOKI_INLINE Derived div_(Ref a) const { return _mm512_div_ps(m, a.m); }

    /// Bitwise OR; with a mask operand, selected lanes become all-ones
    template <typename T> ENOKI_INLINE Derived or_(const T &a) const {
        if constexpr (is_mask_v<T>) {
            return _mm512_mask_mov_ps(m, a.k, _mm512_set1_ps(memcpy_cast<Value>(int32_t(-1))));
        } else {
            #if defined(ENOKI_X86_AVX512DQ)
                return _mm512_or_ps(m, a.m);
            #else
                return _mm512_castsi512_ps(
                    _mm512_or_si512(_mm512_castps_si512(m), _mm512_castps_si512(a.m)));
            #endif
        }
    }

    /// Bitwise AND; with a mask operand, unselected lanes are zeroed
    template <typename T> ENOKI_INLINE Derived and_(const T &a) const {
        if constexpr (is_mask_v<T>) {
            return _mm512_maskz_mov_ps(a.k, m);
        } else {
            #if defined(ENOKI_X86_AVX512DQ)
                return _mm512_and_ps(m, a.m);
            #else
                return _mm512_castsi512_ps(
                    _mm512_and_si512(_mm512_castps_si512(m), _mm512_castps_si512(a.m)));
            #endif
        }
    }

    /// Computes this & ~a; with a mask operand, selected lanes are zeroed
    template <typename T> ENOKI_INLINE Derived andnot_(const T &a) const {
        if constexpr (is_mask_v<T>) {
            return _mm512_mask_mov_ps(m, a.k, _mm512_setzero_ps());
        } else {
            #if defined(ENOKI_X86_AVX512DQ)
                return _mm512_andnot_ps(a.m, m);
            #else
                return _mm512_castsi512_ps(
                    _mm512_andnot_si512(_mm512_castps_si512(a.m), _mm512_castps_si512(m)));
            #endif
        }
    }

    /// Bitwise XOR; with a mask operand, all bits of selected lanes are flipped
    template <typename T> ENOKI_INLINE Derived xor_(const T &a) const {
        if constexpr (is_mask_v<T>) {
            #if defined(ENOKI_X86_AVX512DQ)
                const __m512 c = _mm512_set1_ps(memcpy_cast<Value>(int32_t(-1)));
                return _mm512_mask_xor_ps(m, a.k, m, c);
            #else
                /* Integer-domain fallback: _mm512_mask_xor_epi32 needs an
                   __m512i all-ones operand (an __m512 does not convert) */
                const __m512i v0 = _mm512_castps_si512(m);
                return _mm512_castsi512_ps(
                    _mm512_mask_xor_epi32(v0, a.k, v0, _mm512_set1_epi32(-1)));
            #endif
        } else {
            #if defined(ENOKI_X86_AVX512DQ)
                return _mm512_xor_ps(m, a.m);
            #else
                return _mm512_castsi512_ps(
                    _mm512_xor_si512(_mm512_castps_si512(m), _mm512_castps_si512(a.m)));
            #endif
        }
    }

    /* Ordered predicates (false if either operand is NaN); neq_ is unordered
       (true if either operand is NaN) */
    ENOKI_INLINE auto lt_ (Ref a) const { return mask_t<Derived>::from_k(_mm512_cmp_ps_mask(m, a.m, _CMP_LT_OQ)); }
    ENOKI_INLINE auto gt_ (Ref a) const { return mask_t<Derived>::from_k(_mm512_cmp_ps_mask(m, a.m, _CMP_GT_OQ)); }
    ENOKI_INLINE auto le_ (Ref a) const { return mask_t<Derived>::from_k(_mm512_cmp_ps_mask(m, a.m, _CMP_LE_OQ)); }
    ENOKI_INLINE auto ge_ (Ref a) const { return mask_t<Derived>::from_k(_mm512_cmp_ps_mask(m, a.m, _CMP_GE_OQ)); }
    ENOKI_INLINE auto eq_ (Ref a) const { return mask_t<Derived>::from_k(_mm512_cmp_ps_mask(m, a.m, _CMP_EQ_OQ)); }
    ENOKI_INLINE auto neq_(Ref a) const { return mask_t<Derived>::from_k(_mm512_cmp_ps_mask(m, a.m, _CMP_NEQ_UQ)); }

    /// Clears the sign bit
    ENOKI_INLINE Derived abs_() const { return andnot_(Derived(_mm512_set1_ps(-0.f))); }

    /* The operand order is deliberate: x86 min/max return the second operand
       when either input is NaN */
    ENOKI_INLINE Derived min_(Ref b) const { return _mm512_min_ps(b.m, m); }
    ENOKI_INLINE Derived max_(Ref b) const { return _mm512_max_ps(b.m, m); }
    ENOKI_INLINE Derived ceil_()     const { return _mm512_ceil_ps(m);     }
    ENOKI_INLINE Derived floor_()    const { return _mm512_floor_ps(m);    }
    ENOKI_INLINE Derived sqrt_()     const { return _mm512_sqrt_ps(m);     }

    ENOKI_INLINE Derived round_() const {
        return _mm512_roundscale_ps(m, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
    }

    ENOKI_INLINE Derived trunc_() const {
        return _mm512_roundscale_ps(m, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC);
    }

    /// Convert to integers, rounding toward +infinity
    template <typename T>
    ENOKI_INLINE auto ceil2int_() const {
        if constexpr (sizeof(scalar_t<T>) == 4) {
            if constexpr (std::is_signed_v<scalar_t<T>>)
                return T(_mm512_cvt_roundps_epi32(m, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC));
            else
                return T(_mm512_cvt_roundps_epu32(m, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC));
        } else {
            #if defined(ENOKI_X86_AVX512DQ)
                /* 64-bit results: convert each 8-lane half separately */
                using A = typename T::Array1;
                if constexpr (std::is_signed_v<scalar_t<T>>)
                    return T(
                        A(_mm512_cvt_roundps_epi64(low(derived()).m, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)),
                        A(_mm512_cvt_roundps_epi64(high(derived()).m, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC))
                    );
                else
                    return T(
                        A(_mm512_cvt_roundps_epu64(low(derived()).m, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)),
                        A(_mm512_cvt_roundps_epu64(high(derived()).m, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC))
                    );
            #else
                return Base::template ceil2int_<T>();
            #endif
        }
    }

    /// Convert to integers, rounding toward -infinity
    template <typename T>
    ENOKI_INLINE auto floor2int_() const {
        if constexpr (sizeof(scalar_t<T>) == 4) {
            if constexpr (std::is_signed_v<scalar_t<T>>)
                return T(_mm512_cvt_roundps_epi32(m, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC));
            else
                return T(_mm512_cvt_roundps_epu32(m, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC));
        } else {
            #if defined(ENOKI_X86_AVX512DQ)
                using A = typename T::Array1;
                if constexpr (std::is_signed_v<scalar_t<T>>)
                    return T(
                        A(_mm512_cvt_roundps_epi64(low(derived()).m, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)),
                        A(_mm512_cvt_roundps_epi64(high(derived()).m, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC))
                    );
                else
                    return T(
                        A(_mm512_cvt_roundps_epu64(low(derived()).m, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)),
                        A(_mm512_cvt_roundps_epu64(high(derived()).m, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC))
                    );
            #else
                return Base::template floor2int_<T>();
            #endif
        }
    }

    ENOKI_INLINE Derived fmadd_   (Ref b, Ref c) const { return _mm512_fmadd_ps   (m, b.m, c.m); }
    ENOKI_INLINE Derived fmsub_   (Ref b, Ref c) const { return _mm512_fmsub_ps   (m, b.m, c.m); }
    ENOKI_INLINE Derived fnmadd_  (Ref b, Ref c) const { return _mm512_fnmadd_ps  (m, b.m, c.m); }
    ENOKI_INLINE Derived fnmsub_  (Ref b, Ref c) const { return _mm512_fnmsub_ps  (m, b.m, c.m); }
    ENOKI_INLINE Derived fmsubadd_(Ref b, Ref c) const { return _mm512_fmsubadd_ps(m, b.m, c.m); }
    ENOKI_INLINE Derived fmaddsub_(Ref b, Ref c) const { return _mm512_fmaddsub_ps(m, b.m, c.m); }

    /// Lane-wise \c m ? \c t : \c f
    template <typename Mask>
    static ENOKI_INLINE Derived select_(const Mask &m, Ref t, Ref f) {
        return _mm512_mask_blend_ps(m.k, f.m, t.m);
    }

    /// Compile-time permutation: lane k of the result is lane I<k> of *this
    template <size_t I0, size_t I1, size_t I2, size_t I3, size_t I4,
              size_t I5, size_t I6, size_t I7, size_t I8, size_t I9,
              size_t I10, size_t I11, size_t I12, size_t I13, size_t I14,
              size_t I15>
    ENOKI_INLINE Derived shuffle_() const {
        const __m512i idx =
            _mm512_setr_epi32(I0, I1, I2, I3, I4, I5, I6, I7, I8,
                              I9, I10, I11, I12, I13, I14, I15);
        return _mm512_permutexvar_ps(idx, m);
    }

    /// Run-time permutation driven by an index array
    template <typename Index> ENOKI_INLINE Derived shuffle_(const Index &index) const {
        return _mm512_permutexvar_ps(index.m, m);
    }

    ENOKI_INLINE Derived rcp_() const {
        #if defined(ENOKI_X86_AVX512ER)
            /* rel err < 2^-28, use as is */
            return _mm512_rcp28_ps(m);
        #else
            /* Use best reciprocal approximation available on the current
               hardware and refine */
            __m512 r = _mm512_rcp14_ps(m); /* rel error < 2^-14 */

            /* Refine using one Newton-Raphson iteration:
               r' = 2r - (r*m)*r */
            __m512 t0 = _mm512_add_ps(r, r),
                   t1 = _mm512_mul_ps(r, m);

            r = _mm512_fnmadd_ps(t1, r, t0);

            /* NOTE(review): the fixup table constant presumably patches
               special inputs (+-0, +-inf, NaN) the refinement mishandles */
            return _mm512_fixupimm_ps(r, m,
                _mm512_set1_epi32(0x0087A622), 0);
        #endif
    }

    ENOKI_INLINE Derived rsqrt_() const {
        #if defined(ENOKI_X86_AVX512ER)
            /* rel err < 2^-28, use as is */
            return _mm512_rsqrt28_ps(m);
        #else
            __m512 r = _mm512_rsqrt14_ps(m); /* rel error < 2^-14 */

            /* Refine using one Newton-Raphson iteration:
               r' = (r/2) * (3 - (r*m)*r) */
            const __m512 c0 = _mm512_set1_ps(0.5f),
                         c1 = _mm512_set1_ps(3.0f);

            __m512 t0 = _mm512_mul_ps(r, c0),
                   t1 = _mm512_mul_ps(r, m);

            r = _mm512_mul_ps(_mm512_fnmadd_ps(t1, r, c1), t0);

            /* NOTE(review): fixup table presumably handles special inputs */
            return _mm512_fixupimm_ps(r, m,
                _mm512_set1_epi32(0x0383A622), 0);
        #endif
    }

    /// this * 2^floor(arg), lane-wise (VSCALEFPS)
    ENOKI_INLINE Derived ldexp_(Ref arg) const { return _mm512_scalef_ps(m, arg.m); }

    /* Mantissa normalized to [0.5, 1) with the input's sign, and exponent
       floor(log2|x|) from getexp. NOTE(review): this exponent is one less than
       std::frexp's for the same mantissa range; confirm callers expect that. */
    ENOKI_INLINE std::pair<Derived, Derived> frexp_() const {
        return std::make_pair<Derived, Derived>(
            _mm512_getmant_ps(m, _MM_MANT_NORM_p5_1, _MM_MANT_SIGN_src),
            _mm512_getexp_ps(m));
    }

    //! @}
    // -----------------------------------------------------------------------

    // -----------------------------------------------------------------------
    //! @{ \name Horizontal operations
    // -----------------------------------------------------------------------

    /* Combine the two 8-lane halves lane-wise, then reduce the 8-lane result */
    ENOKI_INLINE Value hsum_()  const { return hsum(low_() + high_()); }
    ENOKI_INLINE Value hprod_() const { return hprod(low_() * high_()); }
    ENOKI_INLINE Value hmin_()  const { return hmin(min(low_(), high_())); }
    ENOKI_INLINE Value hmax_()  const { return hmax(max(low_(), high_())); }

    //! @}
    // -----------------------------------------------------------------------

    // -----------------------------------------------------------------------
    //! @{ \name Loading/writing data
    // -----------------------------------------------------------------------

    /// Aligned store; \c ptr must be 64-byte aligned
    ENOKI_INLINE void store_(void *ptr) const {
        assert((uintptr_t) ptr % 64 == 0);
        _mm512_store_ps((Value *) ENOKI_ASSUME_ALIGNED(ptr, 64), m);
    }

    /// Aligned store of the lanes selected by \c mask only
    template <typename Mask>
    ENOKI_INLINE void store_(void *ptr, const Mask &mask) const {
        assert((uintptr_t) ptr % 64 == 0);
        _mm512_mask_store_ps((Value *) ENOKI_ASSUME_ALIGNED(ptr, 64), mask.k, m);
    }

    ENOKI_INLINE void store_unaligned_(void *ptr) const {
        _mm512_storeu_ps((Value *) ptr, m);
    }

    template <typename Mask>
    ENOKI_INLINE void store_unaligned_(void *ptr, const Mask &mask) const {
        _mm512_mask_storeu_ps((Value *) ptr, mask.k, m);
    }

    /// Aligned load; \c ptr must be 64-byte aligned
    static ENOKI_INLINE Derived load_(const void *ptr) {
        assert((uintptr_t) ptr % 64 == 0);
        return _mm512_load_ps((const Value *) ENOKI_ASSUME_ALIGNED(ptr, 64));
    }

    /// Aligned masked load; unselected lanes are zero
    template <typename Mask>
    static ENOKI_INLINE Derived load_(const void *ptr, const Mask &mask) {
        assert((uintptr_t) ptr % 64 == 0);
        return _mm512_maskz_load_ps(mask.k, (const Value *) ENOKI_ASSUME_ALIGNED(ptr, 64));
    }

    static ENOKI_INLINE Derived load_unaligned_(const void *ptr) {
        return _mm512_loadu_ps((const Value *) ptr);
    }

    template <typename Mask>
    static ENOKI_INLINE Derived load_unaligned_(const void *ptr, const Mask &mask) {
        return _mm512_maskz_loadu_ps(mask.k, (const Value *) ptr);
    }

    static ENOKI_INLINE Derived zero_() { return _mm512_setzero_ps(); }

#if defined(ENOKI_X86_AVX512PF)
    template <bool Write, size_t Level, size_t Stride, typename Index, typename Mask>
    static ENOKI_INLINE void prefetch_(const void *ptr, const Index &index, const Mask &mask) {
        constexpr auto Hint = Level == 1 ? _MM_HINT_T0 : _MM_HINT_T1;

        if constexpr (sizeof(scalar_t<Index>) == 4) {
            if constexpr (Write)
                _mm512_mask_prefetch_i32scatter_ps((void *) ptr, mask.k, index.m, Stride, Hint);
            else
                _mm512_mask_prefetch_i32gather_ps(index.m, mask.k, ptr, Stride, Hint);
        } else {
            /* 64-bit indices: process the two 8-lane halves separately */
            if constexpr (Write) {
                _mm512_mask_prefetch_i64scatter_ps((void *) ptr, low(mask).k, low(index).m, Stride, Hint);
                _mm512_mask_prefetch_i64scatter_ps((void *) ptr, high(mask).k, high(index).m, Stride, Hint);
            } else {
                _mm512_mask_prefetch_i64gather_ps(low(index).m, low(mask).k, ptr, Stride, Hint);
                _mm512_mask_prefetch_i64gather_ps(high(index).m, high(mask).k, ptr, Stride, Hint);
            }
        }
    }
#endif

    /// Masked gather; unselected lanes are zero
    template <size_t Stride, typename Index, typename Mask>
    static ENOKI_INLINE Derived gather_(const void *ptr, const Index &index, const Mask &mask) {
        if constexpr (sizeof(scalar_t<Index>) == 4) {
            return _mm512_mask_i32gather_ps(_mm512_setzero_ps(), mask.k, index.m, (const float *) ptr, Stride);
        } else {
            return detail::concat(
                _mm512_mask_i64gather_ps(_mm256_setzero_ps(), low(mask).k, low(index).m, (const float *) ptr, Stride),
                _mm512_mask_i64gather_ps(_mm256_setzero_ps(), high(mask).k, high(index).m, (const float *) ptr, Stride));
        }
    }

    /// Masked scatter of the selected lanes
    template <size_t Stride, typename Index, typename Mask>
    ENOKI_INLINE void scatter_(void *ptr, const Index &index, const Mask &mask) const {
        if constexpr (sizeof(scalar_t<Index>) == 4) {
            _mm512_mask_i32scatter_ps(ptr, mask.k, index.m, m, Stride);
        } else {
            _mm512_mask_i64scatter_ps(ptr, low(mask).k, low(index).m, low(derived()).m, Stride);
            _mm512_mask_i64scatter_ps(ptr, high(mask).k, high(index).m, high(derived()).m, Stride);
        }
    }

    /// First (lowest-index) lane selected by \c mask, or 0 if none
    template <typename Mask>
    ENOKI_INLINE Value extract_(const Mask &mask) const {
        return _mm_cvtss_f32(_mm512_castps512_ps128(_mm512_maskz_compress_ps(mask.k, m)));
    }

    /* Packs the selected lanes to the front and stores all 16 lanes unaligned
       at \c ptr (the destination must have room for 16 floats), then advances
       \c ptr by the number of selected lanes and returns that count */
    template <typename Mask>
    ENOKI_INLINE size_t compress_(float *&ptr, const Mask &mask) const {
        _mm512_storeu_ps(ptr, _mm512_maskz_compress_ps(mask.k, m));
        size_t kn = (size_t) _mm_popcnt_u32(mask.k);
        ptr += kn;
        return kn;
    }

#if defined(ENOKI_X86_AVX512CD)
    /* Gather-modify-scatter where several active lanes may share an index.
       Conflict detection (vpconflictd) lets later duplicates see the values
       produced for earlier duplicates, so \c func is applied sequentially
       per distinct index before the final scatter. */
    template <size_t Stride, typename Index, typename Mask, typename Func, typename... Args>
    static ENOKI_INLINE void transform_(void *mem,
                                        Index index,
                                        const Mask &mask,
                                        const Func &func,
                                        const Args &... args) {
        Derived values = _mm512_mask_i32gather_ps(
            _mm512_undefined_ps(), mask.k, index.m, mem, (int) Stride);

        /* Inactive lanes get index -1 so they never conflict with active ones */
        index.m = _mm512_mask_mov_epi32(_mm512_set1_epi32(-1), mask.k, index.m);

        /* conflicts: per lane, bitmask of lower lanes with the same index;
           perm_idx: the highest such lane; todo: active lanes with any conflict */
        __m512i conflicts = _mm512_conflict_epi32(index.m);
        __m512i perm_idx  = _mm512_sub_epi32(_mm512_set1_epi32(31), _mm512_lzcnt_epi32(conflicts));
        __mmask16 todo    = _mm512_mask_test_epi32_mask(mask.k, conflicts, _mm512_set1_epi32(-1));

        func(values, args...);

        ENOKI_NOUNROLL while (ENOKI_UNLIKELY(!_mm512_kortestz(todo, todo))) {
            /* Lanes whose earlier duplicates are all resolved */
            __mmask16 cur = _mm512_mask_testn_epi32_mask(
                todo, conflicts, _mm512_broadcastmw_epi32(todo));
            values.m = _mm512_mask_permutexvar_ps(values.m, cur, perm_idx, values.m);

            __m512 backup(values.m);
            func(values, args...);

            values.m = _mm512_mask_mov_ps(backup, cur, values.m);
            todo = _mm512_kxor(todo, cur);
        }

        _mm512_mask_i32scatter_ps(mem, mask.k, index.m, values.m, (int) Stride);
    }
#endif

    //! @}
    // -----------------------------------------------------------------------

    // -----------------------------------------------------------------------
    //! @{ \name Masked versions of key operations
    // -----------------------------------------------------------------------

    template <typename Mask>
    ENOKI_INLINE void massign_(const Derived &a, const Mask &mask) { m = _mm512_mask_mov_ps(m, mask.k, a.m); }
    template <typename Mask>
    ENOKI_INLINE void madd_   (const Derived &a, const Mask &mask) { m = _mm512_mask_add_ps(m, mask.k, m, a.m); }
    template <typename Mask>
    ENOKI_INLINE void msub_   (const Derived &a, const Mask &mask) { m = _mm512_mask_sub_ps(m, mask.k, m, a.m); }
    template <typename Mask>
    ENOKI_INLINE void mmul_   (const Derived &a, const Mask &mask) { m = _mm512_mask_mul_ps(m, mask.k, m, a.m); }
    template <typename Mask>
    ENOKI_INLINE void mdiv_   (const Derived &a, const Mask &mask) { m = _mm512_mask_div_ps(m, mask.k, m, a.m); }

    /* Without AVX512DQ there are no masked float bitwise ops; use the
       AVX512F integer forms _mm512_mask_{or,and,xor}_epi32(src, k, a, b),
       which keep src in unselected lanes */
    template <typename Mask>
    ENOKI_INLINE void mor_    (const Derived &a, const Mask &mask) {
        #if defined(ENOKI_X86_AVX512DQ)
            m = _mm512_mask_or_ps(m, mask.k, m, a.m);
        #else
            const __m512i v0 = _mm512_castps_si512(m);
            m = _mm512_castsi512_ps(
                _mm512_mask_or_epi32(v0, mask.k, v0, _mm512_castps_si512(a.m)));
        #endif
    }

    template <typename Mask>
    ENOKI_INLINE void mand_   (const Derived &a, const Mask &mask) {
        #if defined(ENOKI_X86_AVX512DQ)
            m = _mm512_mask_and_ps(m, mask.k, m, a.m);
        #else
            const __m512i v0 = _mm512_castps_si512(m);
            m = _mm512_castsi512_ps(
                _mm512_mask_and_epi32(v0, mask.k, v0, _mm512_castps_si512(a.m)));
        #endif
    }

    template <typename Mask>
    ENOKI_INLINE void mxor_   (const Derived &a, const Mask &mask) {
        #if defined(ENOKI_X86_AVX512DQ)
            m = _mm512_mask_xor_ps(m, mask.k, m, a.m);
        #else
            const __m512i v0 = _mm512_castps_si512(m);
            m = _mm512_castsi512_ps(
                _mm512_mask_xor_epi32(v0, mask.k, v0, _mm512_castps_si512(a.m)));
        #endif
    }

    //! @}
    // -----------------------------------------------------------------------
} ENOKI_MAY_ALIAS;
/// Partial overload of StaticArrayImpl using AVX512 intrinsics (double precision)
template <bool IsMask_, typename Derived_> struct alignas(64)
StaticArrayImpl<double, 8, IsMask_, Derived_>
: StaticArrayBase<double, 8, IsMask_, Derived_> {
ENOKI_NATIVE_ARRAY(double, 8, __m512d)
// -----------------------------------------------------------------------
//! @{ \name Value constructors
// -----------------------------------------------------------------------
/// Broadcast constructor: all 8 lanes receive \c value
ENOKI_INLINE StaticArrayImpl(const Value &value)
    : m(_mm512_set1_pd(value)) { }

/// Element-wise constructor: \c f0 goes to lane 0, ..., \c f7 to lane 7
ENOKI_INLINE StaticArrayImpl(Value f0, Value f1, Value f2, Value f3,
                             Value f4, Value f5, Value f6, Value f7)
    : m(_mm512_setr_pd(f0, f1, f2, f3,
                       f4, f5, f6, f7)) { }
//! @}
// -----------------------------------------------------------------------
//! @{ \name Type converting constructors
// -----------------------------------------------------------------------
/* half -> double: 8 packed halves are loaded unaligned, widened to float,
   then widened again to double */
ENOKI_CONVERT(half)
    : m(_mm512_cvtps_pd(_mm256_cvtph_ps(
          _mm_loadu_si128((const __m128i *) a.derived().data())))) { }

/// float -> double widening
ENOKI_CONVERT(float) : m(_mm512_cvtps_pd(a.derived().m)) { }

/// double -> double: plain register copy
ENOKI_CONVERT(double) : m(a.derived().m) { }

/// 32-bit integers (signed / unsigned) -> double; every such value is exactly representable
ENOKI_CONVERT(int32_t)  : m(_mm512_cvtepi32_pd(a.derived().m)) { }
ENOKI_CONVERT(uint32_t) : m(_mm512_cvtepu32_pd(a.derived().m)) { }
#if defined(ENOKI_X86_AVX512DQ)
// Native 64-bit integer -> double conversions (AVX512DQ); these round
// according to the current MXCSR rounding mode.
ENOKI_CONVERT(int64_t)
: m(_mm512_cvtepi64_pd(a.derived().m)) { }
ENOKI_CONVERT(uint64_t)
: m(_mm512_cvtepu64_pd(a.derived().m)) { }
#elif defined(ENOKI_X86_AVX512CD)
/* Emulate uint64_t -> double conversion instead of falling
back to scalar operations. This is quite a bit faster
(>5.5x for unsigned; no figure was recorded for the signed case).
Scheme: lzcnt locates the leading set bit (position 63 - lz), which is
shifted to bit 52 (the implicit leading one of the mantissa). The biased
exponent 1023 + (63 - lz) is built directly, and both are OR-ed together.
NOTE: bits shifted out on the right are truncated, not rounded to nearest.
Inputs with more than 53 significant bits can therefore differ by one ulp
(toward zero) from the AVX512DQ path above. */
ENOKI_CONVERT(uint64_t) {
using Int64 = int64_array_t<Derived2>;
auto lz = lzcnt(a);
// Signed shift amount: > 0 means shift right, <= 0 means shift left.
auto shift = (63 - 52) - Int64(lz);
auto abs_shift = abs(shift);
// Forces the exponent field to 0 for a == 0 (the mantissa is already 0).
auto nzero_mask = neq(a, 0ull);
auto mant = select(shift > 0, a >> abs_shift, a << abs_shift);
auto exp = sl<52>(uint64_t(1023 + 63) - lz) & nzero_mask;
// The mask drops the implicit leading one at bit 52.
auto comb = exp | (mant & 0xfffffffffffffull);
m = reinterpret_array<Derived>(comb).m;
}
// Signed variant: same scheme applied to |a|, with a's sign bit copied
// into bit 63 of the result.
ENOKI_CONVERT(int64_t) {
auto b = abs(a), lz = lzcnt(b);
auto shift = (63 - 52) - lz;
auto abs_shift = abs(shift);
auto nzero_mask = neq(a, 0ll);
auto mant = select(shift > 0, b >> abs_shift, b << abs_shift);
auto sign = a & 0x8000000000000000ull;
auto exp = sl<52>(int64_t(1023 + 63) - lz) & nzero_mask;
auto comb = exp | (mant & 0xfffffffffffffull) | sign;
m = reinterpret_array<Derived>(comb).m;
}
#endif
//! @}
// -----------------------------------------------------------------------
// -----------------------------------------------------------------------
//! @{ \name Reinterpreting constructors, mask converters
// -----------------------------------------------------------------------
/* Bit-level reinterpretation from 64-bit lanes: no value conversion takes
   place, the register bits are simply relabeled as double */
ENOKI_REINTERPRET(double)   : m(a.derived().m) { }
ENOKI_REINTERPRET(uint64_t) : m(_mm512_castsi512_pd(a.derived().m)) { }
ENOKI_REINTERPRET(int64_t)  : m(_mm512_castsi512_pd(a.derived().m)) { }
//! @}
// -----------------------------------------------------------------------
// -----------------------------------------------------------------------
//! @{ \name Converting from/to half size vectors
// -----------------------------------------------------------------------
/// Assemble an 8-lane array from two half-size arrays via detail::concat
StaticArrayImpl(const Array1 &a1, const Array2 &a2)
    : m(detail::concat(a1.m, a2.m)) { }

/// Lower 4 lanes (register cast, no data movement)
ENOKI_INLINE Array1 low_() const {
    return _mm512_castpd512_pd256(m);
}

/// Upper 4 lanes
ENOKI_INLINE Array2 high_() const {
    const __m256d upper = _mm512_extractf64x4_pd(m, 1);
    return upper;
}
//! @}
// -----------------------------------------------------------------------
// -----------------------------------------------------------------------
//! @{ \name Vertical operations
// -----------------------------------------------------------------------
/* Lane-wise IEEE arithmetic on all 8 doubles: *this (op) rhs */
ENOKI_INLINE Derived add_(Ref rhs) const {
    return _mm512_add_pd(m, rhs.m);
}
ENOKI_INLINE Derived sub_(Ref rhs) const {
    return _mm512_sub_pd(m, rhs.m);
}
ENOKI_INLINE Derived mul_(Ref rhs) const {
    return _mm512_mul_pd(m, rhs.m);
}
ENOKI_INLINE Derived div_(Ref rhs) const {
    return _mm512_div_pd(m, rhs.m);
}
/// Bitwise OR; with a mask operand, the selected lanes become all-ones
template <typename T> ENOKI_INLINE Derived or_(const T &a) const {
    if constexpr (!is_mask_v<T>) {
        #if defined(ENOKI_X86_AVX512DQ)
            return _mm512_or_pd(m, a.m);
        #else
            /* AVX512F lacks the pd form: operate on the integer view */
            const __m512i lhs = _mm512_castpd_si512(m),
                          rhs = _mm512_castpd_si512(a.m);
            return _mm512_castsi512_pd(_mm512_or_si512(lhs, rhs));
        #endif
    } else {
        const __m512d all_ones = _mm512_set1_pd(memcpy_cast<Value>(int64_t(-1)));
        return _mm512_mask_mov_pd(m, a.k, all_ones);
    }
}

/// Bitwise AND; with a mask operand, the unselected lanes are zeroed
template <typename T> ENOKI_INLINE Derived and_(const T &a) const {
    if constexpr (!is_mask_v<T>) {
        #if defined(ENOKI_X86_AVX512DQ)
            return _mm512_and_pd(m, a.m);
        #else
            const __m512i lhs = _mm512_castpd_si512(m),
                          rhs = _mm512_castpd_si512(a.m);
            return _mm512_castsi512_pd(_mm512_and_si512(lhs, rhs));
        #endif
    } else {
        return _mm512_maskz_mov_pd(a.k, m);
    }
}

/// Computes *this & ~a; with a mask operand, the selected lanes are zeroed
template <typename T> ENOKI_INLINE Derived andnot_(const T &a) const {
    if constexpr (!is_mask_v<T>) {
        #if defined(ENOKI_X86_AVX512DQ)
            return _mm512_andnot_pd(a.m, m);
        #else
            const __m512i negated = _mm512_castpd_si512(a.m),
                          kept    = _mm512_castpd_si512(m);
            return _mm512_castsi512_pd(_mm512_andnot_si512(negated, kept));
        #endif
    } else {
        const __m512d zeros = _mm512_setzero_pd();
        return _mm512_mask_mov_pd(m, a.k, zeros);
    }
}
/* Bitwise XOR. With a mask operand, every bit of the selected lanes is
   flipped and the other lanes are left unchanged. */
template <typename T> ENOKI_INLINE Derived xor_(const T &a) const {
    if constexpr (is_mask_v<T>) {
        #if defined(ENOKI_X86_AVX512DQ)
            /* Must be __m512d: _mm512_set1_pd yields a double vector, and
               _mm512_mask_xor_pd expects one (the previous __m512 did not match) */
            const __m512d c = _mm512_set1_pd(memcpy_cast<Value>(int64_t(-1)));
            return _mm512_mask_xor_pd(m, a.k, m, c);
        #else
            /* Integer-domain fallback: _mm512_mask_xor_epi64 takes __m512i
               operands, so the all-ones vector is built as integers */
            const __m512i v0 = _mm512_castpd_si512(m);
            return _mm512_castsi512_pd(
                _mm512_mask_xor_epi64(v0, a.k, v0, _mm512_set1_epi64(-1)));
        #endif
    } else {
        #if defined(ENOKI_X86_AVX512DQ)
            return _mm512_xor_pd(m, a.m);
        #else
            return _mm512_castsi512_pd(
                _mm512_xor_si512(_mm512_castpd_si512(m), _mm512_castpd_si512(a.m)));
        #endif
    }
}
/* Lane-wise comparisons returning a mask. The ordered (_OQ) predicates are
   false when either operand is NaN; neq_ uses the unordered _CMP_NEQ_UQ
   predicate and is therefore true when either operand is NaN. */
ENOKI_INLINE auto lt_(Ref rhs) const {
    return mask_t<Derived>::from_k(_mm512_cmp_pd_mask(m, rhs.m, _CMP_LT_OQ));
}
ENOKI_INLINE auto gt_(Ref rhs) const {
    return mask_t<Derived>::from_k(_mm512_cmp_pd_mask(m, rhs.m, _CMP_GT_OQ));
}
ENOKI_INLINE auto le_(Ref rhs) const {
    return mask_t<Derived>::from_k(_mm512_cmp_pd_mask(m, rhs.m, _CMP_LE_OQ));
}
ENOKI_INLINE auto ge_(Ref rhs) const {
    return mask_t<Derived>::from_k(_mm512_cmp_pd_mask(m, rhs.m, _CMP_GE_OQ));
}
ENOKI_INLINE auto eq_(Ref rhs) const {
    return mask_t<Derived>::from_k(_mm512_cmp_pd_mask(m, rhs.m, _CMP_EQ_OQ));
}
ENOKI_INLINE auto neq_(Ref rhs) const {
    return mask_t<Derived>::from_k(_mm512_cmp_pd_mask(m, rhs.m, _CMP_NEQ_UQ));
}

/// Absolute value: clears the sign bit via andnot with -0.0
ENOKI_INLINE Derived abs_() const {
    const Derived sign_bit(_mm512_set1_pd(-0.0));
    return andnot_(sign_bit);
}

/* The operand order is deliberate: x86 min/max return the second operand
   (here *this) when either input is NaN */
ENOKI_INLINE Derived min_(Ref other) const { return _mm512_min_pd(other.m, m); }
ENOKI_INLINE Derived max_(Ref other) const { return _mm512_max_pd(other.m, m); }

ENOKI_INLINE Derived ceil_()  const { return _mm512_ceil_pd(m); }
ENOKI_INLINE Derived floor_() const { return _mm512_floor_pd(m); }
ENOKI_INLINE Derived sqrt_()  const { return _mm512_sqrt_pd(m); }
/// Convert the 8 doubles to integers of type T, rounding toward +infinity
template <typename T>
ENOKI_INLINE auto ceil2int_() const {
    using Int = scalar_t<T>;
    constexpr bool IsSigned = std::is_signed_v<Int>;

    if constexpr (sizeof(Int) == 4) {
        /* 32-bit targets are always available in AVX512F */
        if constexpr (IsSigned)
            return T(_mm512_cvt_roundpd_epi32(m, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC));
        else
            return T(_mm512_cvt_roundpd_epu32(m, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC));
    } else {
        /* 64-bit targets need AVX512DQ; otherwise defer to the generic base */
        #if defined(ENOKI_X86_AVX512DQ)
            if constexpr (IsSigned)
                return T(_mm512_cvt_roundpd_epi64(m, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC));
            else
                return T(_mm512_cvt_roundpd_epu64(m, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC));
        #else
            return Base::template ceil2int_<T>();
        #endif
    }
}

/// Convert the 8 doubles to integers of type T, rounding toward -infinity
template <typename T>
ENOKI_INLINE auto floor2int_() const {
    using Int = scalar_t<T>;
    constexpr bool IsSigned = std::is_signed_v<Int>;

    if constexpr (sizeof(Int) == 4) {
        if constexpr (IsSigned)
            return T(_mm512_cvt_roundpd_epi32(m, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC));
        else
            return T(_mm512_cvt_roundpd_epu32(m, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC));
    } else {
        #if defined(ENOKI_X86_AVX512DQ)
            if constexpr (IsSigned)
                return T(_mm512_cvt_roundpd_epi64(m, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC));
            else
                return T(_mm512_cvt_roundpd_epu64(m, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC));
        #else
            return Base::template floor2int_<T>();
        #endif
    }
}
/// Round to nearest integer (ties to even), suppressing FP exceptions
ENOKI_INLINE Derived round_() const {
    constexpr int mode = _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC;
    return _mm512_roundscale_pd(m, mode);
}

/// Round toward zero, suppressing FP exceptions
ENOKI_INLINE Derived trunc_() const {
    constexpr int mode = _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC;
    return _mm512_roundscale_pd(m, mode);
}

/* Fused multiply-add family: each computes a single-rounding combination of
   (*this * b) and c, e.g. fmadd_ = this*b + c, fnmadd_ = -(this*b) + c;
   fmaddsub_/fmsubadd_ alternate subtract/add across even/odd lanes */
ENOKI_INLINE Derived fmadd_(Ref b, Ref c) const    { return _mm512_fmadd_pd(m, b.m, c.m); }
ENOKI_INLINE Derived fmsub_(Ref b, Ref c) const    { return _mm512_fmsub_pd(m, b.m, c.m); }
ENOKI_INLINE Derived fnmadd_(Ref b, Ref c) const   { return _mm512_fnmadd_pd(m, b.m, c.m); }
ENOKI_INLINE Derived fnmsub_(Ref b, Ref c) const   { return _mm512_fnmsub_pd(m, b.m, c.m); }
ENOKI_INLINE Derived fmsubadd_(Ref b, Ref c) const { return _mm512_fmsubadd_pd(m, b.m, c.m); }
ENOKI_INLINE Derived fmaddsub_(Ref b, Ref c) const { return _mm512_fmaddsub_pd(m, b.m, c.m); }
/// Lane-wise cond ? t : f (the blend takes its 3rd operand where the mask bit is set)
template <typename Mask>
static ENOKI_INLINE Derived select_(const Mask &cond, Ref t, Ref f) {
    return _mm512_mask_blend_pd(cond.k, f.m, t.m);
}

/// Compile-time permutation: lane k of the result is lane I<k> of *this
template <size_t I0, size_t I1, size_t I2, size_t I3, size_t I4, size_t I5,
          size_t I6, size_t I7>
ENOKI_INLINE Derived shuffle_() const {
    return _mm512_permutexvar_pd(
        _mm512_setr_epi64(I0, I1, I2, I3, I4, I5, I6, I7), m);
}

/// Run-time permutation: lane k of the result is lane index[k] of *this
template <typename Index> ENOKI_INLINE Derived shuffle_(const Index &index) const {
    const __m512i lanes = index.m;
    return _mm512_permutexvar_pd(lanes, m);
}
/* Reciprocal 1/x: start from the best hardware estimate, then refine with
   Newton-Raphson steps  est' = 2*est - (est*x)*est.
   AVX512ER's rcp28 (rel. error < 2^-28) needs one step; the AVX512F rcp14
   (rel. error < 2^-14) needs two. */
ENOKI_INLINE Derived rcp_() const {
    #if defined(ENOKI_X86_AVX512ER)
        __m512d est = _mm512_rcp28_pd(m);
    #else
        __m512d est = _mm512_rcp14_pd(m);
    #endif

    const int steps = has_avx512er ? 1 : 2;
    ENOKI_UNROLL for (int step = 0; step < steps; ++step) {
        __m512d doubled = _mm512_add_pd(est, est);
        __m512d product = _mm512_mul_pd(est, m);
        est = _mm512_fnmadd_pd(product, est, doubled);
    }

    /* NOTE(review): the constant fixup table presumably patches special
       inputs (+-0, +-inf, NaN) that the refinement mishandles; confirm
       against the VFIXUPIMMPD token table */
    return _mm512_fixupimm_pd(est, m, _mm512_set1_epi32(0x0087A622), 0);
}
/* Reciprocal square root 1/sqrt(x): start from the best hardware estimate,
   then refine with Newton-Raphson steps  est' = (est/2) * (3 - (est*x)*est).
   AVX512ER's rsqrt28 (rel. error < 2^-28) needs one step; the AVX512F
   rsqrt14 (rel. error < 2^-14) needs two. */
ENOKI_INLINE Derived rsqrt_() const {
    #if defined(ENOKI_X86_AVX512ER)
        __m512d est = _mm512_rsqrt28_pd(m);
    #else
        __m512d est = _mm512_rsqrt14_pd(m);
    #endif

    const __m512d half  = _mm512_set1_pd(0.5),
                  three = _mm512_set1_pd(3.0);

    const int steps = has_avx512er ? 1 : 2;
    ENOKI_UNROLL for (int step = 0; step < steps; ++step) {
        __m512d half_est = _mm512_mul_pd(est, half);
        __m512d product  = _mm512_mul_pd(est, m);
        est = _mm512_mul_pd(_mm512_fnmadd_pd(product, est, three), half_est);
    }

    /* NOTE(review): constant fixup table, presumably for special inputs */
    return _mm512_fixupimm_pd(est, m, _mm512_set1_epi32(0x0383A622), 0);
}
/// this * 2^floor(arg), lane-wise (VSCALEFPD)
ENOKI_INLINE Derived ldexp_(Ref arg) const {
    return _mm512_scalef_pd(m, arg.m);
}

/* Returns (mantissa, exponent). The mantissa is normalized to [0.5, 1) and
   keeps the input's sign; the exponent is getexp's floor(log2|x|).
   NOTE(review): this exponent is one less than std::frexp's for the same
   mantissa range; confirm callers expect this convention. */
ENOKI_INLINE std::pair<Derived, Derived> frexp_() const {
    return std::make_pair<Derived, Derived>(
        _mm512_getmant_pd(m, _MM_MANT_NORM_p5_1, _MM_MANT_SIGN_src),
        _mm512_getexp_pd(m)
    );
}
// -----------------------------------------------------------------------
//! @{ \name Horizontal operations
// -----------------------------------------------------------------------
/* Horizontal reductions: fold the two 4-lane halves together lane-wise,
   then reduce the resulting half-size array */
ENOKI_INLINE Value hsum_() const {
    const Array1 lo = low_();
    const Array2 hi = high_();
    return hsum(lo + hi);
}
ENOKI_INLINE Value hprod_() const {
    const Array1 lo = low_();
    const Array2 hi = high_();
    return hprod(lo * hi);
}
ENOKI_INLINE Value hmin_() const {
    const Array1 lo = low_();
    const Array2 hi = high_();
    return hmin(min(lo, hi));
}
ENOKI_INLINE Value hmax_() const {
    const Array1 lo = low_();
    const Array2 hi = high_();
    return hmax(max(lo, hi));
}
//! @}
// -----------------------------------------------------------------------
/// Aligned store of all 8 lanes; \c ptr must be 64-byte aligned
ENOKI_INLINE void store_(void *ptr) const {
    assert((uintptr_t) ptr % 64 == 0);
    Value *dst = (Value *) ENOKI_ASSUME_ALIGNED(ptr, 64);
    _mm512_store_pd(dst, m);
}

/// Aligned store of the lanes selected by \c mask only
template <typename Mask>
ENOKI_INLINE void store_(void *ptr, const Mask &mask) const {
    assert((uintptr_t) ptr % 64 == 0);
    Value *dst = (Value *) ENOKI_ASSUME_ALIGNED(ptr, 64);
    _mm512_mask_store_pd(dst, mask.k, m);
}

/// Unaligned store of all 8 lanes
ENOKI_INLINE void store_unaligned_(void *ptr) const {
    Value *dst = (Value *) ptr;
    _mm512_storeu_pd(dst, m);
}

/// Unaligned store of the lanes selected by \c mask only
template <typename Mask>
ENOKI_INLINE void store_unaligned_(void *ptr, const Mask &mask) const {
    Value *dst = (Value *) ptr;
    _mm512_mask_storeu_pd(dst, mask.k, m);
}

/// Aligned load of 8 lanes; \c ptr must be 64-byte aligned
static ENOKI_INLINE Derived load_(const void *ptr) {
    assert((uintptr_t) ptr % 64 == 0);
    const Value *src = (const Value *) ENOKI_ASSUME_ALIGNED(ptr, 64);
    return _mm512_load_pd(src);
}

/// Aligned masked load; unselected lanes are zero
template <typename Mask>
static ENOKI_INLINE Derived load_(const void *ptr, const Mask &mask) {
    assert((uintptr_t) ptr % 64 == 0);
    const Value *src = (const Value *) ENOKI_ASSUME_ALIGNED(ptr, 64);
    return _mm512_maskz_load_pd(mask.k, src);
}

/// Unaligned load of 8 lanes
static ENOKI_INLINE Derived load_unaligned_(const void *ptr) {
    const Value *src = (const Value *) ptr;
    return _mm512_loadu_pd(src);
}

/// Unaligned masked load; unselected lanes are zero
template <typename Mask>
static ENOKI_INLINE Derived load_unaligned_(const void *ptr, const Mask &mask) {
    const Value *src = (const Value *) ptr;
    return _mm512_maskz_loadu_pd(mask.k, src);
}

/// All lanes +0.0
static ENOKI_INLINE Derived zero_() { return _mm512_setzero_pd(); }
#if defined(ENOKI_X86_AVX512PF)
/// Gather/scatter prefetch of the addresses ptr + index * Stride for the
/// lanes selected by \c mask. Level 1 maps to the T0 hint, any other level
/// to T1. Write selects the scatter (prefetch-for-write) form.
template <bool Write, size_t Level, size_t Stride, typename Index, typename Mask>
static ENOKI_INLINE void prefetch_(const void *ptr, const Index &index, const Mask &mask) {
    constexpr auto Hint = Level == 1 ? _MM_HINT_T0 : _MM_HINT_T1;
    constexpr bool Index32 = sizeof(scalar_t<Index>) == 4;
    if constexpr (Write) {
        if constexpr (Index32)
            _mm512_mask_prefetch_i32scatter_pd((void *) ptr, mask.k, index.m, Stride, Hint);
        else
            _mm512_mask_prefetch_i64scatter_pd((void *) ptr, mask.k, index.m, Stride, Hint);
    } else {
        if constexpr (Index32)
            _mm512_mask_prefetch_i32gather_pd(index.m, mask.k, ptr, Stride, Hint);
        else
            _mm512_mask_prefetch_i64gather_pd(index.m, mask.k, ptr, Stride, Hint);
    }
}
#endif
/// Masked gather of 8 doubles from ptr + index * Stride (bytes);
/// lanes not selected by \c mask are set to zero. 32 and 64 bit indices
/// are both supported.
template <size_t Stride, typename Index, typename Mask>
static ENOKI_INLINE Derived gather_(const void *ptr, const Index &index, const Mask &mask) {
    const double *base = (const double *) ptr;
    const __m512d fallback = _mm512_setzero_pd();
    if constexpr (sizeof(scalar_t<Index>) != 4)
        return _mm512_mask_i64gather_pd(fallback, mask.k, index.m, base, Stride);
    else
        return _mm512_mask_i32gather_pd(fallback, mask.k, index.m, base, Stride);
}
/// Masked scatter of the 8 lanes to ptr + index * Stride (bytes).
template <size_t Stride, typename Index, typename Mask>
ENOKI_INLINE void scatter_(void *ptr, const Index &index, const Mask &mask) const {
    if constexpr (sizeof(scalar_t<Index>) != 4) {
        _mm512_mask_i64scatter_pd(ptr, mask.k, index.m, m, Stride);
    } else {
        _mm512_mask_i32scatter_pd(ptr, mask.k, index.m, m, Stride);
    }
}
/// Return the first lane selected by \c mask (0.0 if none is selected):
/// compress moves the selected lanes to the front, then lane 0 is read.
template <typename Mask>
ENOKI_INLINE Value extract_(const Mask &mask) const {
    const __m512d packed = _mm512_maskz_compress_pd(mask.k, m);
    return _mm_cvtsd_f64(_mm512_castpd512_pd128(packed));
}
/// Append the lanes selected by \c mask contiguously at \c ptr, advance
/// \c ptr past them and return how many were written.
/// NOTE: the store always writes a full 8-lane vector, so the destination
/// must have room for 8 doubles even if fewer lanes are selected.
template <typename Mask>
ENOKI_INLINE size_t compress_(double *&ptr, const Mask &mask) const {
    const __m512d packed = _mm512_maskz_compress_pd(mask.k, m);
    _mm512_storeu_pd(ptr, packed);
    const size_t count = (size_t) _mm_popcnt_u32(mask.k);
    ptr += count;
    return count;
}
#if defined(ENOKI_X86_AVX512CD)
/**
 * Gather the values at \c mem[index] for the lanes enabled in \c mask,
 * invoke \c func on them and scatter the results back to memory.
 *
 * Lanes that share the same index are found with AVX512CD conflict
 * detection and are processed one after another, so each duplicate lane
 * operates on the value produced for the nearest preceding lane with the
 * same index.
 *
 * This variant uses 64 bit gather/scatter and 64 bit conflict detection on
 * \c index.m, i.e. it expects 64 bit indices.
 *
 * NOTE(review): \c func is called as func(values, args...) and its return
 * value is discarded, so it presumably updates \c values in place through a
 * reference — confirm against callers.
 */
template <size_t Stride, typename Index, typename Mask, typename Func, typename... Args>
static ENOKI_INLINE void transform_(void *mem,
Index index,
const Mask &mask,
const Func &func,
const Args &... args) {
// Load the current memory contents for the active lanes
Derived values = _mm512_mask_i64gather_pd(
_mm512_undefined_pd(), mask.k, index.m, mem, (int) Stride);
// Inactive lanes get index -1 (all bits set)
index.m = _mm512_mask_mov_epi64(_mm512_set1_epi64(-1), mask.k, index.m);
// Lane i: bit j set for every lower lane j < i holding the same index
__m512i conflicts = _mm512_conflict_epi64(index.m);
// 63 - lzcnt = highest set bit, i.e. the nearest preceding duplicate lane
__m512i perm_idx = _mm512_sub_epi64(_mm512_set1_epi64(63), _mm512_lzcnt_epi64(conflicts));
// Active lanes that have at least one preceding duplicate still need work
__mmask8 todo = _mm512_mask_test_epi64_mask(mask.k, conflicts, _mm512_set1_epi64(-1));
// First pass over all lanes; lanes without duplicates are final after this
func(values, args...);
ENOKI_NOUNROLL while (ENOKI_UNLIKELY(todo)) {
// Pending lanes whose preceding duplicates have all been resolved
__mmask8 cur = _mm512_mask_testn_epi64_mask(
todo, conflicts, _mm512_broadcastmb_epi64(todo));
// Those lanes start from the updated value of their preceding duplicate
values.m = _mm512_mask_permutexvar_pd(values.m, cur, perm_idx, values.m);
__m512d backup(values.m);
// func runs on the full vector, but only the lanes in 'cur' keep its result
func(values, args...);
values.m = _mm512_mask_mov_pd(backup, cur, values.m);
todo ^= cur;
}
// Relies on scatter writing overlapping lanes from lowest to highest lane,
// so the highest duplicate (which holds the fully chained result) wins
_mm512_mask_i64scatter_pd(mem, mask.k, index.m, values.m, (int) Stride);
}
#endif
//! @}
// -----------------------------------------------------------------------
// -----------------------------------------------------------------------
//! @{ \name Masked versions of key operations
// -----------------------------------------------------------------------
/// Masked assignment: lanes selected by \c mask take the value of \c a
template <typename Mask>
ENOKI_INLINE void massign_(const Derived &a, const Mask &mask) { m = _mm512_mask_mov_pd(m, mask.k, a.m); }
/// Masked arithmetic: lanes selected by \c mask become m (op) a, others keep m
template <typename Mask>
ENOKI_INLINE void madd_ (const Derived &a, const Mask &mask) { m = _mm512_mask_add_pd(m, mask.k, m, a.m); }
template <typename Mask>
ENOKI_INLINE void msub_ (const Derived &a, const Mask &mask) { m = _mm512_mask_sub_pd(m, mask.k, m, a.m); }
template <typename Mask>
ENOKI_INLINE void mmul_ (const Derived &a, const Mask &mask) { m = _mm512_mask_mul_pd(m, mask.k, m, a.m); }
template <typename Mask>
ENOKI_INLINE void mdiv_ (const Derived &a, const Mask &mask) { m = _mm512_mask_div_pd(m, mask.k, m, a.m); }

/* Masked bitwise operations. AVX512DQ provides masked floating point
   versions directly. Plain AVX512F only has the integer forms, so those
   are applied to the raw bit pattern of each 64 bit lane via casts. */
template <typename Mask>
ENOKI_INLINE void mor_ (const Derived &a, const Mask &mask) {
#if defined(ENOKI_X86_AVX512DQ)
    m = _mm512_mask_or_pd(m, mask.k, m, a.m);
#else
    const __m512i bits = _mm512_castpd_si512(m);
    m = _mm512_castsi512_pd(
        _mm512_mask_or_epi64(bits, mask.k, bits, _mm512_castpd_si512(a.m)));
#endif
}
template <typename Mask>
ENOKI_INLINE void mand_ (const Derived &a, const Mask &mask) {
#if defined(ENOKI_X86_AVX512DQ)
    m = _mm512_mask_and_pd(m, mask.k, m, a.m);
#else
    const __m512i bits = _mm512_castpd_si512(m);
    m = _mm512_castsi512_pd(
        _mm512_mask_and_epi64(bits, mask.k, bits, _mm512_castpd_si512(a.m)));
#endif
}
template <typename Mask>
ENOKI_INLINE void mxor_ (const Derived &a, const Mask &mask) {
#if defined(ENOKI_X86_AVX512DQ)
    m = _mm512_mask_xor_pd(m, mask.k, m, a.m);
#else
    const __m512i bits = _mm512_castpd_si512(m);
    m = _mm512_castsi512_pd(
        _mm512_mask_xor_epi64(bits, mask.k, bits, _mm512_castpd_si512(a.m)));
#endif
}
//! @}
// -----------------------------------------------------------------------
} ENOKI_MAY_ALIAS;
/// Partial overload of StaticArrayImpl using AVX512 intrinsics (32 bit integers)
///
/// Serves both signed and unsigned 32 bit lanes (enable_if_int32_t). Operations
/// whose result depends on signedness (float conversions, right shifts,
/// ordered comparisons, min/max, abs, mulhi) choose the signed or unsigned
/// instruction based on std::is_signed_v<Value>.
template <typename Value_, bool IsMask_, typename Derived_> struct alignas(64)
    StaticArrayImpl<Value_, 16, IsMask_, Derived_, enable_if_int32_t<Value_>>
    : StaticArrayBase<Value_, 16, IsMask_, Derived_> {

    ENOKI_NATIVE_ARRAY(Value_, 16, __m512i)

    // -----------------------------------------------------------------------
    //! @{ \name Value constructors
    // -----------------------------------------------------------------------

    /// Broadcast \c value to all 16 lanes
    ENOKI_INLINE StaticArrayImpl(Value value) : m(_mm512_set1_epi32((int32_t) value)) { }

    /// Initialize each lane individually (\c f0 goes to lane 0)
    ENOKI_INLINE StaticArrayImpl(Value f0, Value f1, Value f2, Value f3,
                                 Value f4, Value f5, Value f6, Value f7,
                                 Value f8, Value f9, Value f10, Value f11,
                                 Value f12, Value f13, Value f14, Value f15)
        : m(_mm512_setr_epi32(
              (int32_t) f0, (int32_t) f1, (int32_t) f2, (int32_t) f3,
              (int32_t) f4, (int32_t) f5, (int32_t) f6, (int32_t) f7,
              (int32_t) f8, (int32_t) f9, (int32_t) f10, (int32_t) f11,
              (int32_t) f12, (int32_t) f13, (int32_t) f14, (int32_t) f15)) { }

    //! @}
    // -----------------------------------------------------------------------

    // -----------------------------------------------------------------------
    //! @{ \name Type converting constructors
    // -----------------------------------------------------------------------

    // int32 <-> uint32: same register contents, no instruction needed
    ENOKI_CONVERT(int32_t) : m(a.derived().m) { }
    ENOKI_CONVERT(uint32_t) : m(a.derived().m) { }

    /// float -> 32 bit integer, truncating toward zero
    ENOKI_CONVERT(float) {
        m = std::is_signed_v<Value> ? _mm512_cvttps_epi32(a.derived().m)
                                    : _mm512_cvttps_epu32(a.derived().m);
    }

    /// double -> 32 bit integer, truncating toward zero. The source spans two
    /// 512-bit halves; each converts to 256 bits and the halves are rejoined.
    ENOKI_CONVERT(double) {
        m = std::is_signed_v<Value>
                ? detail::concat(_mm512_cvttpd_epi32(low(a).m),
                                 _mm512_cvttpd_epi32(high(a).m))
                : detail::concat(_mm512_cvttpd_epu32(low(a).m),
                                 _mm512_cvttpd_epu32(high(a).m));
    }

    /// 64 -> 32 bit narrowing: keeps the low 32 bits of every lane
    ENOKI_CONVERT(int64_t)
        : m(detail::concat(_mm512_cvtepi64_epi32(low(a).m),
                           _mm512_cvtepi64_epi32(high(a).m))) { }
    ENOKI_CONVERT(uint64_t)
        : m(detail::concat(_mm512_cvtepi64_epi32(low(a).m),
                           _mm512_cvtepi64_epi32(high(a).m))) { }

    //! @}
    // -----------------------------------------------------------------------

    // -----------------------------------------------------------------------
    //! @{ \name Reinterpreting constructors, mask converters
    // -----------------------------------------------------------------------

    ENOKI_REINTERPRET(float) : m(_mm512_castps_si512(a.derived().m)) { }
    ENOKI_REINTERPRET(int32_t) : m(a.derived().m) { }
    ENOKI_REINTERPRET(uint32_t) : m(a.derived().m) { }

    //! @}
    // -----------------------------------------------------------------------

    // -----------------------------------------------------------------------
    //! @{ \name Converting from/to half size vectors
    // -----------------------------------------------------------------------

    StaticArrayImpl(const Array1 &a1, const Array2 &a2)
        : m(detail::concat(a1.m, a2.m)) { }

    /// Lanes 0-7 (reinterpretation of the low 256 bits)
    ENOKI_INLINE Array1 low_() const { return _mm512_castsi512_si256(m); }

    /// Lanes 8-15 (extraction of the high 256 bits)
    ENOKI_INLINE Array2 high_() const {
#if defined(ENOKI_X86_AVX512DQ)
        return _mm512_extracti32x8_epi32(m, 1);
#else
        return _mm512_extracti64x4_epi64(m, 1);
#endif
    }

    //! @}
    // -----------------------------------------------------------------------

    // -----------------------------------------------------------------------
    //! @{ \name Vertical operations
    // -----------------------------------------------------------------------

    ENOKI_INLINE Derived add_(Ref a) const { return _mm512_add_epi32(m, a.m); }
    ENOKI_INLINE Derived sub_(Ref a) const { return _mm512_sub_epi32(m, a.m); }
    ENOKI_INLINE Derived mul_(Ref a) const { return _mm512_mullo_epi32(m, a.m); }

    /// m | a. With a mask argument, selected lanes become all ones.
    template <typename T>
    ENOKI_INLINE Derived or_ (const T &a) const {
        if constexpr (is_mask_v<T>)
            return _mm512_mask_mov_epi32(m, a.k, _mm512_set1_epi32(int32_t(-1)));
        else
            return _mm512_or_epi32(m, a.m);
    }

    /// m & a. With a mask argument, unselected lanes become zero.
    template <typename T>
    ENOKI_INLINE Derived and_ (const T &a) const {
        if constexpr (is_mask_v<T>)
            return _mm512_maskz_mov_epi32(a.k, m);
        else
            return _mm512_and_epi32(m, a.m);
    }

    /// m & ~a. With a mask argument, selected lanes become zero.
    /// Note: _mm512_andnot_epi32(x, y) computes (~x) & y, so the operand to
    /// be complemented (a) goes first.
    template <typename T>
    ENOKI_INLINE Derived andnot_ (const T &a) const {
        if constexpr (is_mask_v<T>)
            return _mm512_mask_mov_epi32(m, a.k, _mm512_setzero_si512());
        else
            return _mm512_andnot_epi32(a.m, m);
    }

    /// m ^ a. With a mask argument, selected lanes are bitwise inverted.
    template <typename T>
    ENOKI_INLINE Derived xor_ (const T &a) const {
        if constexpr (is_mask_v<T>)
            return _mm512_mask_xor_epi32(m, a.k, m, _mm512_set1_epi32(int32_t(-1)));
        else
            return _mm512_xor_epi32(m, a.m);
    }

    template <size_t k> ENOKI_INLINE Derived sl_() const {
        return _mm512_slli_epi32(m, (int) k);
    }

    /// Right shift: arithmetic for signed lanes, logical for unsigned lanes
    template <size_t k> ENOKI_INLINE Derived sr_() const {
        return std::is_signed_v<Value> ? _mm512_srai_epi32(m, (int) k)
                                       : _mm512_srli_epi32(m, (int) k);
    }

    ENOKI_INLINE Derived sl_(size_t k) const {
        return _mm512_sll_epi32(m, _mm_set1_epi64x((long long) k));
    }

    ENOKI_INLINE Derived sr_(size_t k) const {
        return std::is_signed_v<Value>
                   ? _mm512_sra_epi32(m, _mm_set1_epi64x((long long) k))
                   : _mm512_srl_epi32(m, _mm_set1_epi64x((long long) k));
    }

    ENOKI_INLINE Derived sl_(Ref k) const {
        return _mm512_sllv_epi32(m, k.m);
    }

    ENOKI_INLINE Derived sr_(Ref k) const {
        return std::is_signed_v<Value> ? _mm512_srav_epi32(m, k.m)
                                       : _mm512_srlv_epi32(m, k.m);
    }

    ENOKI_INLINE Derived rol_(Ref k) const { return _mm512_rolv_epi32(m, k.m); }
    ENOKI_INLINE Derived ror_(Ref k) const { return _mm512_rorv_epi32(m, k.m); }

    template <size_t Imm>
    ENOKI_INLINE Derived rol_() const { return _mm512_rol_epi32(m, (int) Imm); }

    template <size_t Imm>
    ENOKI_INLINE Derived ror_() const { return _mm512_ror_epi32(m, (int) Imm); }

    /* Ordered comparisons must respect signedness: the epi32 form treats
       unsigned values >= 2^31 as negative, so unsigned lanes use epu32.
       Equality does not depend on signedness. */
    ENOKI_INLINE auto lt_ (Ref a) const {
        return mask_t<Derived>::from_k(std::is_signed_v<Value>
            ? _mm512_cmp_epi32_mask(m, a.m, _MM_CMPINT_LT)
            : _mm512_cmp_epu32_mask(m, a.m, _MM_CMPINT_LT));
    }
    ENOKI_INLINE auto gt_ (Ref a) const {
        return mask_t<Derived>::from_k(std::is_signed_v<Value>
            ? _mm512_cmp_epi32_mask(m, a.m, _MM_CMPINT_GT)
            : _mm512_cmp_epu32_mask(m, a.m, _MM_CMPINT_GT));
    }
    ENOKI_INLINE auto le_ (Ref a) const {
        return mask_t<Derived>::from_k(std::is_signed_v<Value>
            ? _mm512_cmp_epi32_mask(m, a.m, _MM_CMPINT_LE)
            : _mm512_cmp_epu32_mask(m, a.m, _MM_CMPINT_LE));
    }
    ENOKI_INLINE auto ge_ (Ref a) const {
        return mask_t<Derived>::from_k(std::is_signed_v<Value>
            ? _mm512_cmp_epi32_mask(m, a.m, _MM_CMPINT_GE)
            : _mm512_cmp_epu32_mask(m, a.m, _MM_CMPINT_GE));
    }
    ENOKI_INLINE auto eq_ (Ref a) const { return mask_t<Derived>::from_k(_mm512_cmp_epi32_mask(m, a.m, _MM_CMPINT_EQ)); }
    ENOKI_INLINE auto neq_(Ref a) const { return mask_t<Derived>::from_k(_mm512_cmp_epi32_mask(m, a.m, _MM_CMPINT_NE)); }

    ENOKI_INLINE Derived min_(Ref a) const {
        return std::is_signed_v<Value> ? _mm512_min_epi32(a.m, m)
                                       : _mm512_min_epu32(a.m, m);
    }

    ENOKI_INLINE Derived max_(Ref a) const {
        return std::is_signed_v<Value> ? _mm512_max_epi32(a.m, m)
                                       : _mm512_max_epu32(a.m, m);
    }

    ENOKI_INLINE Derived abs_() const {
        return std::is_signed_v<Value> ? _mm512_abs_epi32(m) : m;
    }

    /// Lane-wise m.k ? t : f (blend takes its second operand where the bit is set)
    template <typename Mask>
    static ENOKI_INLINE Derived select_(const Mask &m, Ref t, Ref f) {
        return _mm512_mask_blend_epi32(m.k, f.m, t.m);
    }

    template <size_t I0, size_t I1, size_t I2, size_t I3, size_t I4,
              size_t I5, size_t I6, size_t I7, size_t I8, size_t I9,
              size_t I10, size_t I11, size_t I12, size_t I13, size_t I14,
              size_t I15>
    ENOKI_INLINE Derived shuffle_() const {
        const __m512i idx =
            _mm512_setr_epi32(I0, I1, I2, I3, I4, I5, I6, I7, I8,
                              I9, I10, I11, I12, I13, I14, I15);
        return _mm512_permutexvar_epi32(idx, m);
    }

    template <typename Index> ENOKI_INLINE Derived shuffle_(const Index &index) const {
        return _mm512_permutexvar_epi32(index.m, m);
    }

    /// High 32 bits of the lane-wise 32x32 -> 64 bit product.
    /// mul_epi32/mul_epu32 only multiply the even lanes, so the odd lanes are
    /// shifted down into even position first; the even results are shifted
    /// down by 32 while the odd results already have their high half in the
    /// odd lane. A blend then interleaves both.
    ENOKI_INLINE Derived mulhi_(Ref a) const {
        auto blend = mask_t<Derived>::from_k(0b0101010101010101);
        Derived even, odd;

        if constexpr (std::is_signed_v<Value>) {
            even.m = _mm512_srli_epi64(_mm512_mul_epi32(m, a.m), 32);
            odd.m = _mm512_mul_epi32(_mm512_srli_epi64(m, 32),
                                     _mm512_srli_epi64(a.m, 32));
        } else {
            even.m = _mm512_srli_epi64(_mm512_mul_epu32(m, a.m), 32);
            odd.m = _mm512_mul_epu32(_mm512_srli_epi64(m, 32),
                                     _mm512_srli_epi64(a.m, 32));
        }

        return select(blend, even, odd);
    }

#if defined(ENOKI_X86_AVX512CD)
    ENOKI_INLINE Derived lzcnt_() const { return _mm512_lzcnt_epi32(m); }

    /// ~x & (x - 1) has exactly the trailing-zero bits of x set
    ENOKI_INLINE Derived tzcnt_() const { return Value(32) - lzcnt(~derived() & (derived() - Value(1))); }
#endif

#if defined(ENOKI_X86_AVX512VPOPCNTDQ)
    ENOKI_INLINE Derived popcnt_() const { return _mm512_popcnt_epi32(m); }
#endif

    //! @}
    // -----------------------------------------------------------------------

    // -----------------------------------------------------------------------
    //! @{ \name Horizontal operations
    // -----------------------------------------------------------------------

    ENOKI_INLINE Value hsum_() const { return hsum(low_() + high_()); }
    ENOKI_INLINE Value hprod_() const { return hprod(low_() * high_()); }
    ENOKI_INLINE Value hmin_() const { return hmin(min(low_(), high_())); }
    ENOKI_INLINE Value hmax_() const { return hmax(max(low_(), high_())); }

    //! @}
    // -----------------------------------------------------------------------

    // -----------------------------------------------------------------------
    //! @{ \name Initialization, loading/writing data
    // -----------------------------------------------------------------------

    ENOKI_INLINE void store_(void *ptr) const {
        assert((uintptr_t) ptr % 64 == 0);
        _mm512_store_si512((__m512i *) ENOKI_ASSUME_ALIGNED(ptr, 64), m);
    }

    template <typename Mask>
    ENOKI_INLINE void store_(void *ptr, const Mask &mask) const {
        assert((uintptr_t) ptr % 64 == 0);
        _mm512_mask_store_epi32((__m512i *) ENOKI_ASSUME_ALIGNED(ptr, 64), mask.k, m);
    }

    ENOKI_INLINE void store_unaligned_(void *ptr) const {
        _mm512_storeu_si512((__m512i *) ptr, m);
    }

    template <typename Mask>
    ENOKI_INLINE void store_unaligned_(void *ptr, const Mask &mask) const {
        _mm512_mask_storeu_epi32((__m512i *) ptr, mask.k, m);
    }

    static ENOKI_INLINE Derived load_(const void *ptr) {
        assert((uintptr_t) ptr % 64 == 0);
        return _mm512_load_si512((const __m512i *) ENOKI_ASSUME_ALIGNED(ptr, 64));
    }

    template <typename Mask>
    static ENOKI_INLINE Derived load_(const void *ptr, const Mask &mask) {
        assert((uintptr_t) ptr % 64 == 0);
        return _mm512_maskz_load_epi32(mask.k, (const __m512i *) ENOKI_ASSUME_ALIGNED(ptr, 64));
    }

    static ENOKI_INLINE Derived load_unaligned_(const void *ptr) {
        return _mm512_loadu_si512((const __m512i *) ptr);
    }

    template <typename Mask>
    static ENOKI_INLINE Derived load_unaligned_(const void *ptr, const Mask &mask) {
        return _mm512_maskz_loadu_epi32(mask.k, (const __m512i *) ptr);
    }

    static ENOKI_INLINE Derived zero_() { return _mm512_setzero_si512(); }

#if defined(ENOKI_X86_AVX512PF)
    /// With 64 bit indices, the 16 lanes are prefetched as two 8-lane halves
    template <bool Write, size_t Level, size_t Stride, typename Index, typename Mask>
    static ENOKI_INLINE void prefetch_(const void *ptr, const Index &index, const Mask &mask) {
        constexpr auto Hint = Level == 1 ? _MM_HINT_T0 : _MM_HINT_T1;

        if constexpr (sizeof(scalar_t<Index>) == 4) {
            if (Write)
                _mm512_mask_prefetch_i32scatter_ps((void *) ptr, mask.k, index.m, Stride, Hint);
            else
                _mm512_mask_prefetch_i32gather_ps(index.m, mask.k, ptr, Stride, Hint);
        } else {
            if (Write) {
                _mm512_mask_prefetch_i64scatter_ps((void *) ptr, low(mask).k, low(index).m, Stride, Hint);
                _mm512_mask_prefetch_i64scatter_ps((void *) ptr, high(mask).k, high(index).m, Stride, Hint);
            } else {
                _mm512_mask_prefetch_i64gather_ps(low(index).m, low(mask).k, ptr, Stride, Hint);
                _mm512_mask_prefetch_i64gather_ps(high(index).m, high(mask).k, ptr, Stride, Hint);
            }
        }
    }
#endif

    /// Masked gather (unselected lanes are zero); 64 bit indices are handled
    /// as two 8-lane gathers that are concatenated
    template <size_t Stride, typename Index, typename Mask>
    static ENOKI_INLINE Derived gather_(const void *ptr, const Index &index, const Mask &mask) {
        if constexpr (sizeof(scalar_t<Index>) == 4) {
            return _mm512_mask_i32gather_epi32(_mm512_setzero_si512(), mask.k, index.m, (const float *) ptr, Stride);
        } else {
            return detail::concat(
                _mm512_mask_i64gather_epi32(_mm256_setzero_si256(), low(mask).k, low(index).m, (const float *) ptr, Stride),
                _mm512_mask_i64gather_epi32(_mm256_setzero_si256(), high(mask).k, high(index).m, (const float *) ptr, Stride));
        }
    }

    template <size_t Stride, typename Index, typename Mask>
    ENOKI_INLINE void scatter_(void *ptr, const Index &index, const Mask &mask) const {
        if constexpr (sizeof(scalar_t<Index>) == 4) {
            _mm512_mask_i32scatter_epi32(ptr, mask.k, index.m, m, Stride);
        } else {
            _mm512_mask_i64scatter_epi32(ptr, low(mask).k, low(index).m, low(derived()).m, Stride);
            _mm512_mask_i64scatter_epi32(ptr, high(mask).k, high(index).m, high(derived()).m, Stride);
        }
    }

    /// First lane selected by \c mask (0 if none)
    template <typename Mask>
    ENOKI_INLINE Value extract_(const Mask &mask) const {
        return (Value) _mm_cvtsi128_si32(_mm512_castsi512_si128(_mm512_maskz_compress_epi32(mask.k, m)));
    }

    /// Append the selected lanes at \c ptr and advance it; a full 16-lane
    /// vector is always stored, so the destination needs room for 16 values
    template <typename Mask>
    ENOKI_INLINE size_t compress_(Value_ *&ptr, const Mask &mask) const {
        _mm512_storeu_si512((__m512i *) ptr, _mm512_maskz_compress_epi32(mask.k, m));
        size_t kn = (size_t) _mm_popcnt_u32(mask.k);
        ptr += kn;
        return kn;
    }

#if defined(ENOKI_X86_AVX512CD)
    /// Gather mem[index] for active lanes, apply \c func, scatter back.
    /// Lanes with duplicate indices are resolved sequentially using AVX512CD
    /// conflict detection (see the comments inside the loop).
    template <size_t Stride, typename Index, typename Mask, typename Func, typename... Args>
    static ENOKI_INLINE void transform_(void *mem,
                                        Index index,
                                        const Mask &mask,
                                        const Func &func,
                                        const Args &... args) {
        Derived values = _mm512_mask_i32gather_epi32(
            _mm512_undefined_epi32(), mask.k, index.m, mem, (int) Stride);

        // Inactive lanes get index -1
        index.m = _mm512_mask_mov_epi32(_mm512_set1_epi32(-1), mask.k, index.m);

        // Bit j of lane i is set if lane j < i holds the same index
        __m512i conflicts = _mm512_conflict_epi32(index.m);
        // Nearest preceding lane with the same index
        __m512i perm_idx = _mm512_sub_epi32(_mm512_set1_epi32(31), _mm512_lzcnt_epi32(conflicts));
        // Active lanes that still depend on a preceding duplicate
        __mmask16 todo = _mm512_mask_test_epi32_mask(mask.k, conflicts, _mm512_set1_epi32(-1));

        func(values, args...);

        ENOKI_NOUNROLL while (ENOKI_UNLIKELY(!_mm512_kortestz(todo, todo))) {
            // Pending lanes whose preceding duplicates are all resolved
            __mmask16 cur = _mm512_mask_testn_epi32_mask(
                todo, conflicts, _mm512_broadcastmw_epi32(todo));
            values.m = _mm512_mask_permutexvar_epi32(values.m, cur, perm_idx, values.m);

            __m512i backup(values.m);
            func(values, args...);
            // Only the lanes in 'cur' keep the new result
            values.m = _mm512_mask_mov_epi32(backup, cur, values.m);
            todo = _mm512_kxor(todo, cur);
        }

        _mm512_mask_i32scatter_epi32(mem, mask.k, index.m, values.m, (int) Stride);
    }
#endif

    //! @}
    // -----------------------------------------------------------------------

    // -----------------------------------------------------------------------
    //! @{ \name Masked versions of key operations
    // -----------------------------------------------------------------------

    template <typename Mask>
    ENOKI_INLINE void massign_(const Derived &a, const Mask &mask) { m = _mm512_mask_mov_epi32(m, mask.k, a.m); }
    template <typename Mask>
    ENOKI_INLINE void madd_ (const Derived &a, const Mask &mask) { m = _mm512_mask_add_epi32(m, mask.k, m, a.m); }
    template <typename Mask>
    ENOKI_INLINE void msub_ (const Derived &a, const Mask &mask) { m = _mm512_mask_sub_epi32(m, mask.k, m, a.m); }
    template <typename Mask>
    ENOKI_INLINE void mmul_ (const Derived &a, const Mask &mask) { m = _mm512_mask_mullo_epi32(m, mask.k, m, a.m); }
    template <typename Mask>
    ENOKI_INLINE void mor_ (const Derived &a, const Mask &mask) { m = _mm512_mask_or_epi32(m, mask.k, m, a.m); }
    template <typename Mask>
    ENOKI_INLINE void mand_ (const Derived &a, const Mask &mask) { m = _mm512_mask_and_epi32(m, mask.k, m, a.m); }
    template <typename Mask>
    ENOKI_INLINE void mxor_ (const Derived &a, const Mask &mask) { m = _mm512_mask_xor_epi32(m, mask.k, m, a.m); }

    //! @}
    // -----------------------------------------------------------------------
} ENOKI_MAY_ALIAS;
/// Partial overload of StaticArrayImpl using AVX512 intrinsics (64 bit integers)
template <typename Value_, bool IsMask_, typename Derived_> struct alignas(64)
StaticArrayImpl<Value_, 8, IsMask_, Derived_, enable_if_int64_t<Value_>>
: StaticArrayBase<Value_, 8, IsMask_, Derived_> {
ENOKI_NATIVE_ARRAY(Value_, 8, __m512i)
// -----------------------------------------------------------------------
//! @{ \name Value constructors
// -----------------------------------------------------------------------
/// Broadcast \c value to all 8 lanes
ENOKI_INLINE StaticArrayImpl(const Value &value)
    : m(_mm512_set1_epi64(static_cast<long long>(value))) { }

/// Initialize each lane individually (\c f0 goes to lane 0)
ENOKI_INLINE StaticArrayImpl(Value f0, Value f1, Value f2, Value f3,
                             Value f4, Value f5, Value f6, Value f7)
    : m(_mm512_setr_epi64(
          static_cast<long long>(f0), static_cast<long long>(f1),
          static_cast<long long>(f2), static_cast<long long>(f3),
          static_cast<long long>(f4), static_cast<long long>(f5),
          static_cast<long long>(f6), static_cast<long long>(f7))) { }
//! @}
// -----------------------------------------------------------------------
//! @{ \name Type converting constructors
// -----------------------------------------------------------------------
#if defined(ENOKI_X86_AVX512DQ)
// float -> 64 bit integer via native AVX512DQ conversion, truncating toward zero
ENOKI_CONVERT(float) {
m = std::is_signed_v<Value> ? _mm512_cvttps_epi64(a.derived().m)
: _mm512_cvttps_epu64(a.derived().m);
}
#else
/* Emulate float -> (u)int64 conversion instead of falling
back to scalar operations. This is quite a bit faster (~4x!)

The result is truncated toward zero (fraction bits are shifted out).
Out-of-range inputs, NaN and infinity are not special-cased. */
ENOKI_CONVERT(float) {
using Int32 = int_array_t<Derived2>;
using UInt32 = uint_array_t<Derived2>;
using UInt64 = uint64_array_t<Derived2>;
/* Shift out sign bit */
auto b = reinterpret_array<UInt32>(a);
b += b;
// After the shift: exponent in bits 24-31, mantissa in bits 1-23.
// Restore the implicit leading one at bit 24 -> mant = 1.m * 2^24
auto mant = UInt64((b & 0xffffffu) | 0x1000000u);
// Right-shift amount that turns mant into the integer part (bias 127, 24 bits of fraction)
auto shift = (24 + 127) - Int32(sr<24>(b));
auto abs_shift = UInt64(abs(shift));
// Positive shift: drop fraction bits; otherwise scale up
auto result = select(shift > 0, mant >> abs_shift, mant << abs_shift);
// Reapply the sign for signed targets
if constexpr (std::is_signed_v<Value>)
result[a < 0] = -result;
m = result.m;
}
#endif
// 32 -> 64 bit widening: sign extension for int32, zero extension for uint32
ENOKI_CONVERT(int32_t)
: m(_mm512_cvtepi32_epi64(a.derived().m)) { }
ENOKI_CONVERT(uint32_t)
: m(_mm512_cvtepu32_epi64(a.derived().m)) { }
#if defined(ENOKI_X86_AVX512DQ)
// double -> 64 bit integer via native AVX512DQ conversion, truncating toward zero
ENOKI_CONVERT(double) {
m = std::is_signed_v<Value> ? _mm512_cvttpd_epi64(a.derived().m)
: _mm512_cvttpd_epu64(a.derived().m);
}
#else
/* Emulate double -> (u)int64 conversion instead of falling
back to scalar operations. This is quite a bit faster (>~11x!)

Same scheme as the float version above: truncates toward zero and does
not special-case out-of-range inputs, NaN or infinity. */
ENOKI_CONVERT(double) {
using Int64 = int_array_t<Derived2>;
using UInt64 = uint_array_t<Derived2>;
/* Shift out sign bit */
auto b = reinterpret_array<UInt64>(a);
b += b;
// After the shift: exponent in bits 53-63, mantissa in bits 1-52.
// Restore the implicit leading one at bit 53 -> mant = 1.m * 2^53
auto mant = (b & 0x1fffffffffffffull) | 0x20000000000000ull;
// Right-shift amount that turns mant into the integer part (bias 1023, 53 bits of fraction)
auto shift = (53 + 1023) - Int64(sr<53>(b));
auto abs_shift = UInt64(abs(shift));
auto result = select(shift > 0, mant >> abs_shift, mant << abs_shift);
if constexpr (std::is_signed_v<Value>)
result[a < 0] = -result;
m = result.m;
}
#endif
// int64 <-> uint64: register contents are copied unchanged
ENOKI_CONVERT(int64_t) : m(a.derived().m) { }
ENOKI_CONVERT(uint64_t) : m(a.derived().m) { }
//! @}
// -----------------------------------------------------------------------
// -----------------------------------------------------------------------
//! @{ \name Reinterpreting constructors, mask converters
// -----------------------------------------------------------------------
// Reinterpret the raw 512 bits of a double array as 64 bit integer lanes
ENOKI_REINTERPRET(double) : m(_mm512_castpd_si512(a.derived().m)) { }
// Same-width integer sources: the register is copied unchanged
ENOKI_REINTERPRET(int64_t) : m(a.derived().m) { }
ENOKI_REINTERPRET(uint64_t) : m(a.derived().m) { }
//! @}
// -----------------------------------------------------------------------
// -----------------------------------------------------------------------
//! @{ \name Converting from/to half size vectors
// -----------------------------------------------------------------------
// Join two 4-lane halves into one 8-lane array (presumably a1 supplies the
// low lanes, matching low_() below — detail::concat is defined elsewhere)
StaticArrayImpl(const Array1 &a1, const Array2 &a2)
: m(detail::concat(a1.m, a2.m)) { }
// Lanes 0-3: reinterpretation of the low 256 bits
ENOKI_INLINE Array1 low_() const { return _mm512_castsi512_si256(m); }
// Lanes 4-7: extraction of the high 256 bits
ENOKI_INLINE Array2 high_() const { return _mm512_extracti64x4_epi64(m, 1); }
//! @}
// -----------------------------------------------------------------------
// -----------------------------------------------------------------------
//! @{ \name Vertical operations
// -----------------------------------------------------------------------
ENOKI_INLINE Derived add_(Ref a) const { return _mm512_add_epi64(m, a.m); }
ENOKI_INLINE Derived sub_(Ref a) const { return _mm512_sub_epi64(m, a.m); }

/// Lane-wise 64 bit multiplication (low 64 bits of the product, identical
/// for signed and unsigned lanes)
ENOKI_INLINE Derived mul_(Ref a) const {
#if defined(ENOKI_X86_AVX512DQ)
    // The 512-bit VPMULLQ form only requires AVX512DQ; AVX512VL is needed
    // solely for the 128/256-bit variants.
    return _mm512_mullo_epi64(m, a.m);
#else
    /* Schoolbook decomposition: with x = xh*2^32 + xl,
       x*y mod 2^64 = xl*yl + ((xl*yh + xh*yl) << 32).
       mul_epu32 multiplies the low 32 bits of each 64 bit lane. */
    __m512i h0 = _mm512_srli_epi64(m, 32);
    __m512i h1 = _mm512_srli_epi64(a.m, 32);
    __m512i low = _mm512_mul_epu32(m, a.m);
    __m512i mix0 = _mm512_mul_epu32(m, h1);
    __m512i mix1 = _mm512_mul_epu32(h0, a.m);
    __m512i mix = _mm512_add_epi64(mix0, mix1);
    __m512i mix_s = _mm512_slli_epi64(mix, 32);
    return _mm512_add_epi64(mix_s, low);
#endif
}
/// m | a. With a mask argument, selected lanes become all ones.
template <typename T>
ENOKI_INLINE Derived or_ (const T &a) const {
    if constexpr (is_mask_v<T>)
        return _mm512_mask_mov_epi64(m, a.k, _mm512_set1_epi64(int64_t(-1)));
    else
        return _mm512_or_epi64(m, a.m);
}
/// m & a. With a mask argument, unselected lanes become zero.
template <typename T>
ENOKI_INLINE Derived and_ (const T &a) const {
    if constexpr (is_mask_v<T>)
        return _mm512_maskz_mov_epi64(a.k, m);
    else
        return _mm512_and_epi64(m, a.m);
}
/// m & ~a. With a mask argument, selected lanes become zero.
/// Note: _mm512_andnot_epi64(x, y) computes (~x) & y, so the operand to be
/// complemented (a) must come first for the result to match the mask case.
template <typename T>
ENOKI_INLINE Derived andnot_ (const T &a) const {
    if constexpr (is_mask_v<T>)
        return _mm512_mask_mov_epi64(m, a.k, _mm512_setzero_si512());
    else
        return _mm512_andnot_epi64(a.m, m);
}
/// m ^ a. With a mask argument, selected lanes are bitwise inverted.
template <typename T>
ENOKI_INLINE Derived xor_ (const T &a) const {
    if constexpr (is_mask_v<T>)
        return _mm512_mask_xor_epi64(m, a.k, m, _mm512_set1_epi64(int64_t(-1)));
    else
        return _mm512_xor_epi64(m, a.m);
}
/// Shift every lane left by the compile-time amount \c k
template <size_t k> ENOKI_INLINE Derived sl_() const {
    return _mm512_slli_epi64(m, (int) k);
}
/// Shift right by the compile-time amount \c k: arithmetic (sign-filling)
/// for signed lanes, logical (zero-filling) for unsigned lanes
template <size_t k> ENOKI_INLINE Derived sr_() const {
    if constexpr (std::is_signed_v<Value>)
        return _mm512_srai_epi64(m, (int) k);
    else
        return _mm512_srli_epi64(m, (int) k);
}
/// Shift left by a run-time amount shared by all lanes
ENOKI_INLINE Derived sl_(size_t k) const {
    const __m128i count = _mm_set1_epi64x((long long) k);
    return _mm512_sll_epi64(m, count);
}
/// Shift right by a run-time amount shared by all lanes
ENOKI_INLINE Derived sr_(size_t k) const {
    const __m128i count = _mm_set1_epi64x((long long) k);
    if constexpr (std::is_signed_v<Value>)
        return _mm512_sra_epi64(m, count);
    else
        return _mm512_srl_epi64(m, count);
}
/// Per-lane variable shifts: lane i is shifted by k[i]
ENOKI_INLINE Derived sl_(Ref k) const { return _mm512_sllv_epi64(m, k.m); }
ENOKI_INLINE Derived sr_(Ref k) const {
    if constexpr (std::is_signed_v<Value>)
        return _mm512_srav_epi64(m, k.m);
    else
        return _mm512_srlv_epi64(m, k.m);
}
/// Rotations by a per-lane amount ...
ENOKI_INLINE Derived rol_(Ref k) const { return _mm512_rolv_epi64(m, k.m); }
ENOKI_INLINE Derived ror_(Ref k) const { return _mm512_rorv_epi64(m, k.m); }
/// ... and by a compile-time amount
template <size_t Imm> ENOKI_INLINE Derived rol_() const { return _mm512_rol_epi64(m, (int) Imm); }
template <size_t Imm> ENOKI_INLINE Derived ror_() const { return _mm512_ror_epi64(m, (int) Imm); }
/* Ordered comparisons must respect signedness: _mm512_cmp_epi64_mask treats
   unsigned values >= 2^63 as negative, so unsigned lanes use the epu64 form.
   Equality does not depend on signedness. */
ENOKI_INLINE auto lt_ (Ref a) const {
    return mask_t<Derived>::from_k(std::is_signed_v<Value>
        ? _mm512_cmp_epi64_mask(m, a.m, _MM_CMPINT_LT)
        : _mm512_cmp_epu64_mask(m, a.m, _MM_CMPINT_LT));
}
ENOKI_INLINE auto gt_ (Ref a) const {
    return mask_t<Derived>::from_k(std::is_signed_v<Value>
        ? _mm512_cmp_epi64_mask(m, a.m, _MM_CMPINT_GT)
        : _mm512_cmp_epu64_mask(m, a.m, _MM_CMPINT_GT));
}
ENOKI_INLINE auto le_ (Ref a) const {
    return mask_t<Derived>::from_k(std::is_signed_v<Value>
        ? _mm512_cmp_epi64_mask(m, a.m, _MM_CMPINT_LE)
        : _mm512_cmp_epu64_mask(m, a.m, _MM_CMPINT_LE));
}
ENOKI_INLINE auto ge_ (Ref a) const {
    return mask_t<Derived>::from_k(std::is_signed_v<Value>
        ? _mm512_cmp_epi64_mask(m, a.m, _MM_CMPINT_GE)
        : _mm512_cmp_epu64_mask(m, a.m, _MM_CMPINT_GE));
}
ENOKI_INLINE auto eq_ (Ref a) const { return mask_t<Derived>::from_k(_mm512_cmp_epi64_mask(m, a.m, _MM_CMPINT_EQ)); }
ENOKI_INLINE auto neq_(Ref a) const { return mask_t<Derived>::from_k(_mm512_cmp_epi64_mask(m, a.m, _MM_CMPINT_NE)); }
/// Lane-wise minimum (signed or unsigned according to Value)
ENOKI_INLINE Derived min_(Ref a) const {
    if constexpr (std::is_signed_v<Value>)
        return _mm512_min_epi64(a.m, m);
    else
        return _mm512_min_epu64(a.m, m);
}
/// Lane-wise maximum (signed or unsigned according to Value)
ENOKI_INLINE Derived max_(Ref a) const {
    if constexpr (std::is_signed_v<Value>)
        return _mm512_max_epi64(a.m, m);
    else
        return _mm512_max_epu64(a.m, m);
}
/// Absolute value; unsigned lanes are returned unchanged
ENOKI_INLINE Derived abs_() const {
    if constexpr (!std::is_signed_v<Value>)
        return m;
    else
        return _mm512_abs_epi64(m);
}
/// Lane-wise m.k ? t : f (blend takes its second operand where the bit is set)
template <typename Mask>
static ENOKI_INLINE Derived select_(const Mask &m, Ref t, Ref f) {
    const auto lanes = m.k;
    return _mm512_mask_blend_epi64(lanes, f.m, t.m);
}
/// Compile-time permutation: lane i of the result is lane I<i> of this array
template <size_t I0, size_t I1, size_t I2, size_t I3, size_t I4, size_t I5,
          size_t I6, size_t I7>
ENOKI_INLINE Derived shuffle_() const {
    return _mm512_permutexvar_epi64(
        _mm512_setr_epi64(I0, I1, I2, I3, I4, I5, I6, I7), m);
}
/// Run-time permutation: lane i of the result is lane index[i] of this array
template <typename Index> ENOKI_INLINE Derived shuffle_(const Index &index) const {
    const __m512i perm = index.m;
    return _mm512_permutexvar_epi64(perm, m);
}
/// High 64 bits of the lane-wise 64x64 -> 128 bit product, assembled from
/// 32x32 -> 64 bit partial products (no native 64 bit high multiply exists).
ENOKI_INLINE Derived mulhi_(Ref b) const {
if (std::is_unsigned_v<Value>) {
const __m512i low_bits = _mm512_set1_epi64(0xffffffffu);
__m512i al = m, bl = b.m;
__m512i ah = _mm512_srli_epi64(al, 32);
__m512i bh = _mm512_srli_epi64(bl, 32);
// Four unsigned 32x32->64 bit partial products per lane
// (mul_epu32 only reads the low 32 bits of each 64 bit lane)
__m512i albl = _mm512_mul_epu32(al, bl);
__m512i albh = _mm512_mul_epu32(al, bh);
__m512i ahbl = _mm512_mul_epu32(ah, bl);
__m512i ahbh = _mm512_mul_epu32(ah, bh);
// Calculate a possible carry from the low bits of the multiplication:
// sum of everything that lands in bits 32-63 of the 128 bit product.
__m512i carry = _mm512_add_epi64(
_mm512_srli_epi64(albl, 32),
_mm512_add_epi64(_mm512_and_epi64(albh, low_bits),
_mm512_and_epi64(ahbl, low_bits)));
// High half = ah*bh + carry-out of bits 32-63 + high halves of the cross terms
__m512i s0 = _mm512_add_epi64(ahbh, _mm512_srli_epi64(carry, 32));
__m512i s1 = _mm512_add_epi64(_mm512_srli_epi64(albh, 32),
_mm512_srli_epi64(ahbl, 32));
return _mm512_add_epi64(s0, s1);
} else {
// Signed variant: high halves (ah, bh) are arithmetic-shifted and thus
// signed, low halves (al, bl) are treated as unsigned 32 bit values.
const Derived mask(0xffffffff);
const Derived a = derived();
Derived ah = sr<32>(a), bh = sr<32>(b),
al = a & mask, bl = b & mask;
// Unsigned al*bl, keeping only its upper 32 bits
Derived albl_hi = _mm512_srli_epi64(_mm512_mul_epu32(m, b.m), 32);
// Accumulate the cross terms in two steps so no intermediate overflows
Derived t = ah * bl + albl_hi;
Derived w1 = al * bh + (t & mask);
return ah * bh + sr<32>(t) + sr<32>(w1);
}
}
#if defined(ENOKI_X86_AVX512CD)
/// Leading zero count per lane (native VPLZCNTQ)
ENOKI_INLINE Derived lzcnt_() const { return _mm512_lzcnt_epi64(m); }
/// Trailing zero count per lane: ~x & (x - 1) has exactly the trailing-zero
/// bits of x set, so its leading zero count is 64 minus the answer
ENOKI_INLINE Derived tzcnt_() const {
    const Derived &x = derived();
    const auto trailing = ~x & (x - Value(1));
    return Value(64) - lzcnt(trailing);
}
#endif
#if defined(ENOKI_X86_AVX512VPOPCNTDQ)
/// Population count per lane (native VPOPCNTQ)
ENOKI_INLINE Derived popcnt_() const { return _mm512_popcnt_epi64(m); }
#endif
//! @}
// -----------------------------------------------------------------------

// -----------------------------------------------------------------------
//! @{ \name Horizontal operations
// -----------------------------------------------------------------------
/// Horizontal reductions: combine the two 256-bit halves lane-wise, then
/// let the 4-lane array finish the reduction.
ENOKI_INLINE Value hsum_() const {
    const auto lo = low_();
    const auto hi = high_();
    return hsum(lo + hi);
}
ENOKI_INLINE Value hprod_() const {
    const auto lo = low_();
    const auto hi = high_();
    return hprod(lo * hi);
}
ENOKI_INLINE Value hmin_() const {
    const auto lo = low_();
    const auto hi = high_();
    return hmin(min(lo, hi));
}
ENOKI_INLINE Value hmax_() const {
    const auto lo = low_();
    const auto hi = high_();
    return hmax(max(lo, hi));
}
//! @}
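// -----------------------------------------------------------------------

// -----------------------------------------------------------------------
//! @{ \name Initialization, loading/writing data
// -----------------------------------------------------------------------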
/// Aligned store of all 8 lanes. Alignment is checked via uintptr_t (the
/// integer type guaranteed to round-trip a pointer), consistent with the
/// other AVX512 specializations.
ENOKI_INLINE void store_(void *ptr) const {
    assert((uintptr_t) ptr % 64 == 0);
    _mm512_store_si512((__m512i *) ENOKI_ASSUME_ALIGNED(ptr, 64), m);
}
/// Aligned store of only the lanes selected by \c mask
template <typename Mask>
ENOKI_INLINE void store_(void *ptr, const Mask &mask) const {
    assert((uintptr_t) ptr % 64 == 0);
    _mm512_mask_store_epi64((__m512i *) ENOKI_ASSUME_ALIGNED(ptr, 64), mask.k, m);
}
/// Store of all 8 lanes without alignment requirement
ENOKI_INLINE void store_unaligned_(void *ptr) const {
    _mm512_storeu_si512((__m512i *) ptr, m);
}
/// Unaligned store of only the lanes selected by \c mask
template <typename Mask>
ENOKI_INLINE void store_unaligned_(void *ptr, const Mask &mask) const {
    _mm512_mask_storeu_epi64((__m512i *) ptr, mask.k, m);
}
/// Aligned load of 8 lanes; ptr must be 64-byte aligned
static ENOKI_INLINE Derived load_(const void *ptr) {
    assert((uintptr_t) ptr % 64 == 0);
    return _mm512_load_si512((const __m512i *) ENOKI_ASSUME_ALIGNED(ptr, 64));
}
/// Aligned masked load; unselected lanes are zeroed
template <typename Mask>
static ENOKI_INLINE Derived load_(const void *ptr, const Mask &mask) {
    assert((uintptr_t) ptr % 64 == 0);
    return _mm512_maskz_load_epi64(mask.k, (const __m512i *) ENOKI_ASSUME_ALIGNED(ptr, 64));
}
/// Load of 8 lanes without alignment requirement
static ENOKI_INLINE Derived load_unaligned_(const void *ptr) {
    return _mm512_loadu_si512((const __m512i *) ptr);
}
/// Unaligned masked load; unselected lanes are zeroed
template <typename Mask>
static ENOKI_INLINE Derived load_unaligned_(const void *ptr, const Mask &mask) {
    return _mm512_maskz_loadu_epi64(mask.k, (const __m512i *) ptr);
}
/// Array with all lanes set to zero
static ENOKI_INLINE Derived zero_() { return _mm512_setzero_si512(); }
#if defined(ENOKI_X86_AVX512PF)
/// Gather/scatter prefetch of ptr + index * Stride for the lanes selected by
/// \c mask. Level 1 maps to the T0 hint, any other level to T1; Write picks
/// the scatter (prefetch-for-write) form. Uses the _pd forms (8 x 64 bit).
template <bool Write, size_t Level, size_t Stride, typename Index, typename Mask>
static ENOKI_INLINE void prefetch_(const void *ptr, const Index &index, const Mask &mask) {
    constexpr auto Hint = Level == 1 ? _MM_HINT_T0 : _MM_HINT_T1;
    constexpr bool Index32 = sizeof(scalar_t<Index>) == 4;
    if constexpr (Write) {
        if constexpr (Index32)
            _mm512_mask_prefetch_i32scatter_pd((void *) ptr, mask.k, index.m, Stride, Hint);
        else
            _mm512_mask_prefetch_i64scatter_pd((void *) ptr, mask.k, index.m, Stride, Hint);
    } else {
        if constexpr (Index32)
            _mm512_mask_prefetch_i32gather_pd(index.m, mask.k, ptr, Stride, Hint);
        else
            _mm512_mask_prefetch_i64gather_pd(index.m, mask.k, ptr, Stride, Hint);
    }
}
#endif
/// Masked gather of 8 64-bit values from `ptr + index[i] * Stride`.
/// Lanes not selected by `mask` are zero. Both 32-bit and 64-bit index
/// vectors are supported; each needs its own gather instruction.
template <size_t Stride, typename Index, typename Mask>
static ENOKI_INLINE Derived gather_(const void *ptr, const Index &index, const Mask &mask) {
    const __m512i fill = _mm512_setzero_si512();
    if constexpr (sizeof(scalar_t<Index>) != 4) {
        return _mm512_mask_i64gather_epi64(fill, mask.k, index.m, (const float *) ptr, Stride);
    } else {
        return _mm512_mask_i32gather_epi64(fill, mask.k, index.m, (const float *) ptr, Stride);
    }
}
/// Masked scatter of the 8 lanes to `ptr + index[i] * Stride`; only lanes
/// selected by `mask` are written. Dispatches on the index element width.
template <size_t Stride, typename Index, typename Mask>
ENOKI_INLINE void scatter_(void *ptr, const Index &index, const Mask &mask) const {
    constexpr bool Index64 = sizeof(scalar_t<Index>) != 4;
    if constexpr (Index64) {
        _mm512_mask_i64scatter_epi64(ptr, mask.k, index.m, m, Stride);
    } else {
        _mm512_mask_i32scatter_epi64(ptr, mask.k, index.m, m, Stride);
    }
}
/// Returns the value of the lowest lane selected by `mask`.
/// The selected lanes are compressed to the front (others zeroed) and lane 0
/// is read, so an empty mask yields 0.
template <typename Mask>
ENOKI_INLINE Value extract_(const Mask &mask) const {
    const __m512i packed = _mm512_maskz_compress_epi64(mask.k, m);
    const __m128i low = _mm512_castsi512_si128(packed);
    return (Value) _mm_cvtsi128_si64(low);
}
/// Writes the lanes selected by `mask` contiguously to `ptr`, advances `ptr`
/// past them and returns how many were written.
/// Note: a full 64-byte vector is stored (unaligned); only the first `count`
/// entries are meaningful, so the destination needs room for all 8 values.
template <typename Mask>
ENOKI_INLINE size_t compress_(Value_ *&ptr, const Mask &mask) const {
    const size_t count = (size_t) _mm_popcnt_u32(mask.k);
    _mm512_storeu_si512((__m512i *) ptr, _mm512_maskz_compress_epi64(mask.k, m));
    ptr += count;
    return count;
}
#if defined(ENOKI_X86_AVX512CD)
/// Gather -> modify -> scatter over `mem` at addresses `mem + index[i] * Stride`
/// for lanes selected by `mask`, correctly handling lanes that share an index.
///
/// `func(values, args...)` is called with the gathered values; its return value
/// is ignored, so it must update `values` in place. When several active lanes
/// refer to the same address, AVX512CD conflict detection serializes them: each
/// duplicate lane re-runs `func` on the result produced by the nearest preceding
/// lane with the same index. The final scatter therefore stores the accumulated
/// result at each address.
///
/// Only 64-bit indices are supported here (`index.m` is fed to the i64
/// gather/scatter and to `_mm512_conflict_epi64`).
template <size_t Stride, typename Index, typename Mask, typename Func, typename... Args>
static ENOKI_INLINE void transform_(void *mem,
Index index,
const Mask &mask,
const Func &func,
const Args &... args) {
// Gather current values for active lanes (inactive lanes are left undefined).
Derived values = _mm512_mask_i64gather_epi64(
_mm512_undefined_epi32(), mask.k, index.m, mem, (int) Stride);
// Set inactive lanes to index -1 so they do not compare equal to active lanes.
// NOTE(review): an active lane whose index is literally -1 would conflict with
// inactive lanes; presumably never occurs in practice — confirm.
index.m = _mm512_mask_mov_epi64(_mm512_set1_epi64(-1), mask.k, index.m);
// conflicts[i]: bitmask of earlier lanes j < i holding the same index as lane i.
__m512i conflicts = _mm512_conflict_epi64(index.m);
// perm_idx[i] = position of the highest set bit of conflicts[i], i.e. the nearest
// preceding duplicate lane (-1 if there is none; only used for conflicting lanes).
__m512i perm_idx = _mm512_sub_epi64(_mm512_set1_epi64(63), _mm512_lzcnt_epi64(conflicts));
// Active lanes that have at least one earlier duplicate still need fixing up.
__mmask8 todo = _mm512_mask_test_epi64_mask(mask.k, conflicts, _mm512_set1_epi64(-1));
// First pass: apply func to every lane at once.
func(values, args...);
ENOKI_NOUNROLL while (ENOKI_UNLIKELY(todo)) {
// Pick pending lanes none of whose duplicates are still pending, i.e. lanes
// whose predecessor already holds its final value.
__mmask8 cur = _mm512_mask_testn_epi64_mask(
todo, conflicts, _mm512_broadcastmb_epi64(todo));
// Pull in the predecessor's updated value for those lanes.
values.m = _mm512_mask_permutexvar_epi64(values.m, cur, perm_idx, values.m);
__m512i backup(values.m);
// Re-apply func, then keep its result only in the `cur` lanes.
func(values, args...);
values.m = _mm512_mask_mov_epi64(backup, cur, values.m);
todo ^= cur;
}
// Write back active lanes. Among duplicates, the highest lane (which carries
// the fully accumulated value) is written last by the scatter.
_mm512_mask_i64scatter_epi64(mem, mask.k, index.m, values.m, (int) Stride);
}
#endif
//! @}
// -----------------------------------------------------------------------
// -----------------------------------------------------------------------
//! @{ \name Masked versions of key operations
// -----------------------------------------------------------------------
/// Masked assignment: lanes selected by `mask` take their value from `a`;
/// all other lanes keep their current value.
template <typename Mask>
ENOKI_INLINE void massign_(const Derived &a, const Mask &mask) {
    // blend(k, x, y) picks y where k is set and x elsewhere.
    m = _mm512_mask_blend_epi64(mask.k, m, a.m);
}
/// Masked in-place addition: `m[i] += a[i]` for lanes selected by `mask`.
template <typename Mask>
ENOKI_INLINE void madd_(const Derived &a, const Mask &mask) {
    const __m512i passthru = m; // value kept in unselected lanes
    m = _mm512_mask_add_epi64(passthru, mask.k, m, a.m);
}
/// Masked in-place subtraction: `m[i] -= a[i]` for lanes selected by `mask`.
template <typename Mask>
ENOKI_INLINE void msub_(const Derived &a, const Mask &mask) {
    const __m512i passthru = m; // value kept in unselected lanes
    m = _mm512_mask_sub_epi64(passthru, mask.k, m, a.m);
}
/// Masked in-place multiplication: `m[i] *= a[i]` (low 64 bits of the product)
/// for lanes selected by `mask`; other lanes are unchanged.
template <typename Mask>
ENOKI_INLINE void mmul_(const Derived &a, const Mask &mask) {
#if defined(ENOKI_X86_AVX512DQ)
    // The 512-bit masked vpmullq only requires AVX512DQ; AVX512VL is needed
    // solely for the 128/256-bit forms, so it is not part of this guard.
    m = _mm512_mask_mullo_epi64(m, mask.k, m, a.m);
#else
    // No native 64-bit lane multiply: compute the full product, then keep it
    // only in the selected lanes.
    m = select(mask, a * derived(), derived()).m;
#endif
}
/// Masked in-place bitwise OR: `m[i] |= a[i]` for lanes selected by `mask`.
template <typename Mask>
ENOKI_INLINE void mor_(const Derived &a, const Mask &mask) {
    const __m512i passthru = m; // value kept in unselected lanes
    m = _mm512_mask_or_epi64(passthru, mask.k, m, a.m);
}
/// Masked in-place bitwise AND: `m[i] &= a[i]` for lanes selected by `mask`.
template <typename Mask>
ENOKI_INLINE void mand_(const Derived &a, const Mask &mask) {
    const __m512i passthru = m; // value kept in unselected lanes
    m = _mm512_mask_and_epi64(passthru, mask.k, m, a.m);
}
/// Masked in-place bitwise XOR: `m[i] ^= a[i]` for lanes selected by `mask`.
template <typename Mask>
ENOKI_INLINE void mxor_(const Derived &a, const Mask &mask) {
    const __m512i passthru = m; // value kept in unselected lanes
    m = _mm512_mask_xor_epi64(passthru, mask.k, m, a.m);
}
//! @}
// -----------------------------------------------------------------------
} ENOKI_MAY_ALIAS;
// k-mask declarations for the four native AVX512 array configurations:
// float x 16, double x 8, 32-bit integer x 16 and 64-bit integer x 8.
// NOTE(review): ENOKI_DECLARE_KMASK is defined elsewhere. It presumably
// specializes the mask type of these arrays to an AVX512 __mmask register
// (the `.k` member used by the masked operations above). Confirm against
// the macro definition.
template <typename Derived_>
ENOKI_DECLARE_KMASK(float, 16, Derived_, int)
template <typename Derived_>
ENOKI_DECLARE_KMASK(double, 8, Derived_, int)
template <typename Value_, typename Derived_>
ENOKI_DECLARE_KMASK(Value_, 16, Derived_, enable_if_int32_t<Value_>)
template <typename Value_, typename Derived_>
ENOKI_DECLARE_KMASK(Value_, 8, Derived_, enable_if_int64_t<Value_>)
NAMESPACE_END(enoki)