Commit f8b730e4 authored by Michael Kuron

Philox SIMD: cleanup

parent 891df9cc
@@ -4,7 +4,7 @@
 #include <emmintrin.h> // SSE2
 #include <wmmintrin.h> // AES
-#if defined(__AVX512VL__) || defined(__AVX512F__)
+#ifdef __AVX512VL__
 #include <immintrin.h> // AVX*
 #endif
 #include <cstdint>
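The tightened guard matters because the 128-bit unsigned conversions used below are EVEX-encoded instructions that require AVX-512VL; plain AVX-512F only provides the 512-bit forms, so testing `__AVX512F__` alone would select intrinsics that targets without VL (e.g. first-generation Xeon Phi) cannot encode. A minimal sketch of the distinction:

    #include <immintrin.h>
    #ifdef __AVX512VL__
    __m128 cvt128(__m128i v) { return _mm_cvtepu32_ps(v); }    // 128-bit form needs VL
    #endif
    #ifdef __AVX512F__
    __m512 cvt512(__m512i v) { return _mm512_cvtepu32_ps(v); } // plain AVX-512F suffices
    #endif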
@@ -33,7 +33,7 @@ QUALIFIERS __m128i aesni1xm128i(const __m128i & in, const __m128i & k) {
 QUALIFIERS __m128 _my_cvtepu32_ps(const __m128i v)
 {
-#if defined(__AVX512VL__) || defined(__AVX512F__)
+#ifdef __AVX512VL__
     return _mm_cvtepu32_ps(v);
 #else
     __m128i v2 = _mm_srli_epi32(v, 1);
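The `_mm_srli_epi32(v, 1)` context line hints at the pre-AVX-512 fallback: SSE2 can only convert signed 32-bit integers, so the usual trick splits the value into a halved part and its low bit, converts both exactly, and recombines. A sketch of that standard emulation, assumed to match the elided `#else` branch in spirit:

    #include <emmintrin.h>

    static inline __m128 cvtepu32_ps_sse2(__m128i v)
    {
        __m128i v2 = _mm_srli_epi32(v, 1);                 // v / 2, now fits in int32
        __m128i v1 = _mm_and_si128(v, _mm_set1_epi32(1));  // the bit lost by the shift
        __m128 v2f = _mm_cvtepi32_ps(v2);                  // exact signed conversions
        __m128 v1f = _mm_cvtepi32_ps(v1);
        return _mm_add_ps(_mm_add_ps(v2f, v2f), v1f);      // 2*(v/2) + (v&1)
    }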
@@ -46,7 +46,7 @@ QUALIFIERS __m128 _my_cvtepu32_ps(const __m128i v)
 QUALIFIERS __m128d _my_cvtepu64_pd(const __m128i x)
 {
-#if defined(__AVX512VL__) || defined(__AVX512F__)
+#ifdef __AVX512VL__
     return _mm_cvtepu64_pd(x);
 #else
     uint64 r[2];
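Likewise, the `uint64 r[2]` context line suggests the non-AVX-512 path simply round-trips through memory and lets the compiler do the unsigned-to-double conversion in scalar code. A hypothetical reconstruction, not taken from this commit:

    #include <emmintrin.h>
    #include <cstdint>

    static inline __m128d cvtepu64_pd_scalar(__m128i x)
    {
        uint64_t r[2];
        _mm_storeu_si128(reinterpret_cast<__m128i*>(r), x); // spill both u64 lanes
        return _mm_set_pd(static_cast<double>(r[1]),        // convert in scalar code,
                          static_cast<double>(r[0]));       // reload as __m128d
    }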
@@ -110,5 +110,4 @@ QUALIFIERS void aesni_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
     rnd2 = r[1];
     rnd3 = r[2];
     rnd4 = r[3];
-}
-
+}
\ No newline at end of file
@@ -1,8 +1,8 @@
 #include <cstdint>
-#ifdef __SSE__
+#ifdef __SSE4_1__
 #include <emmintrin.h> // SSE2
 #endif
-#ifdef __AVX__
+#ifdef __AVX2__
 #include <immintrin.h> // AVX*
 #else
 #include <smmintrin.h> // SSE4
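Both guard changes align the macros with what the code actually requires: the SIMD paths below assume at least SSE4.1 (note the `#ifdef __SSE4_1__` now guarding `_philox4x32round` further down), and AVX alone only added 256-bit floating-point operations; the 256-bit integer intrinsics this header relies on, such as `_mm256_mul_epu32` and `_mm256_add_epi32`, arrived with AVX2. For instance:

    #include <immintrin.h>
    #ifdef __AVX2__
    // 256-bit integer ops such as this one are AVX2, not AVX
    __m256i add8x32(__m256i a, __m256i b) { return _mm256_add_epi32(a, b); }
    #endif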
@@ -113,7 +113,8 @@ QUALIFIERS void philox_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3
     rnd4 = ctr[3] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f);
 }
 
-#ifdef __SSE__
+#ifndef __CUDA_ARCH__
+#ifdef __SSE4_1__
 QUALIFIERS void _philox4x32round(__m128i* ctr, __m128i* key)
 {
     __m128i lohi0a = _mm_mul_epu32(ctr[0], _mm_set1_epi32(PHILOX_M4x32_0));
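The `_mm_mul_epu32` context line is the heart of a Philox round: a widening 32x32-to-64 multiply (here on the even 32-bit lanes; the `lohi0a`/`lohi0b` pair suggests odd lanes are handled via a shuffle) whose halves are recombined with the counter words and keys. For reference, one round of Philox-4x32 in scalar form, after Salmon et al., "Parallel random numbers: as easy as 1, 2, 3" (SC'11); the constants are the standard Random123 values and are assumed to equal the `PHILOX_M4x32_*` macros used above:

    #include <cstdint>

    static const uint32_t PHILOX_M4x32_0 = 0xD2511F53u;
    static const uint32_t PHILOX_M4x32_1 = 0xCD9E8D57u;

    static inline void philox4x32round_scalar(uint32_t ctr[4], const uint32_t key[2])
    {
        uint64_t p0 = (uint64_t)PHILOX_M4x32_0 * ctr[0];  // widening 32x32->64 multiply,
        uint64_t p1 = (uint64_t)PHILOX_M4x32_1 * ctr[2];  // what _mm_mul_epu32 computes
        uint32_t out[4] = {
            (uint32_t)(p1 >> 32) ^ ctr[1] ^ key[0],       // hi1 ^ ctr1 ^ key0
            (uint32_t)p1,                                 // lo1
            (uint32_t)(p0 >> 32) ^ ctr[3] ^ key[1],       // hi0 ^ ctr3 ^ key1
            (uint32_t)p0                                  // lo0
        };
        for (int i = 0; i < 4; ++i) ctr[i] = out[i];      // keys are then bumped by the
    }                                                     // Weyl constants PHILOX_W32_*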
@@ -156,7 +157,7 @@ QUALIFIERS __m128 _my_cvtepu32_ps(const __m128i v)
 #endif
 }
 
-#if !defined(__AVX512VL__) && defined(__GNUC__) && __GNUC__ >= 5
+#if !defined(__AVX512VL__) && !defined(__AVX512F__) && defined(__GNUC__) && __GNUC__ >= 5
 __attribute__((optimize("no-associative-math")))
 #endif
 QUALIFIERS __m128d _my_cvtepu64_pd(const __m128i x)
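The `no-associative-math` attribute exists because the pre-AVX-512 emulation of unsigned 64-bit to double conversion depends on floating-point operations staying in exactly the written order; under `-ffast-math`, which implies `-fassociative-math`, GCC may otherwise reassociate them and corrupt the result. With AVX-512 the native instruction is used instead, which is why the attribute is now also skipped when `__AVX512F__` is defined. One widely used SSE4.1 emulation of this kind, shown as an illustrative sketch rather than the file's actual `#else` branch:

    #include <smmintrin.h>

    #if defined(__GNUC__) && !defined(__clang__)
    __attribute__((optimize("no-associative-math")))
    #endif
    static inline __m128d cvtepu64_pd_sse41(__m128i x)
    {
        __m128i hi = _mm_srli_epi64(x, 32);                       // high 32 bits of each u64
        hi = _mm_or_si128(hi, _mm_set1_epi64x(0x4530000000000000)); // as double: 2^84 + hi*2^32
        __m128i lo = _mm_blend_epi16(x,                           // keep low 32 bits, set
                _mm_set1_epi64x(0x4330000000000000), 0xcc);       // exponent: 2^52 + lo
        __m128d f = _mm_sub_pd(_mm_castsi128_pd(hi),              // hi*2^32 - 2^52; exact only
                _mm_castsi128_pd(_mm_set1_epi64x(0x4530000000100000))); // in this order (2^84 + 2^52)
        return _mm_add_pd(f, _mm_castsi128_pd(lo));               // hi*2^32 + lo
    }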
@@ -178,36 +179,36 @@ QUALIFIERS __m128d _uniform_double_hq(__m128i x, __m128i y)
     // convert 32 to 64 bit
     if (high)
     {
-        x = _mm_unpackhi_epi32(x, _mm_set1_epi32(0));
-        y = _mm_unpackhi_epi32(y, _mm_set1_epi32(0));;
+        x = _mm_unpackhi_epi32(x, _mm_setzero_si128());
+        y = _mm_unpackhi_epi32(y, _mm_setzero_si128());
     }
     else
     {
-        x = _mm_unpacklo_epi32(x, _mm_set1_epi32(0));
-        y = _mm_unpacklo_epi32(y, _mm_set1_epi32(0));;
+        x = _mm_unpacklo_epi32(x, _mm_setzero_si128());
+        y = _mm_unpacklo_epi32(y, _mm_setzero_si128());
     }
 
     // calculate z = x ^ y << (53 - 32))
-    __m128i z = _mm_sll_epi64(y, _mm_set_epi64x(53 - 32, 53 - 32));
+    __m128i z = _mm_sll_epi64(y, _mm_set1_epi64x(53 - 32));
     z = _mm_xor_si128(x, z);
 
     // convert uint64 to double
     __m128d rs = _my_cvtepu64_pd(z);
     // calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
 #ifdef __FMA__
-    rs = _mm_fmadd_pd(rs, _mm_set_pd1(TWOPOW53_INV_DOUBLE), _mm_set_pd1(TWOPOW53_INV_DOUBLE/2.0));
+    rs = _mm_fmadd_pd(rs, _mm_set1_pd(TWOPOW53_INV_DOUBLE), _mm_set1_pd(TWOPOW53_INV_DOUBLE/2.0));
 #else
-    rs = _mm_mul_pd(rs, _mm_set_pd1(TWOPOW53_INV_DOUBLE));
-    rs = _mm_add_pd(rs, _mm_set_pd1(TWOPOW53_INV_DOUBLE/2.0));
+    rs = _mm_mul_pd(rs, _mm_set1_pd(TWOPOW53_INV_DOUBLE));
+    rs = _mm_add_pd(rs, _mm_set1_pd(TWOPOW53_INV_DOUBLE/2.0));
 #endif
     return rs;
 }
 
-QUALIFIERS void philox_float16(__m128i ctr0, __m128i ctr1, __m128i ctr2, __m128i ctr3,
+QUALIFIERS void philox_float4(__m128i ctr0, __m128i ctr1, __m128i ctr2, __m128i ctr3,
                               uint32 key0, uint32 key1,
                               __m128 & rnd1, __m128 & rnd2, __m128 & rnd3, __m128 & rnd4)
 {
     __m128i key[2] = {_mm_set1_epi32(key0), _mm_set1_epi32(key1)};
     __m128i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
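What `_uniform_double_hq` computes per lane: zero-extend two 32-bit randoms, fuse them into 53 random bits via `z = x ^ (y << 21)`, convert to double, and scale into (0, 1) with a half-ulp offset so 0.0 is never produced. A scalar equivalent, where `TWOPOW53_INV_DOUBLE` is assumed to be 2^-53, matching its use here:

    #include <cstdint>

    static inline double uniform_double_scalar(uint32_t x, uint32_t y)
    {
        const double TWOPOW53_INV_DOUBLE = 1.0 / 9007199254740992.0; // 2^-53
        uint64_t z = (uint64_t)x ^ ((uint64_t)y << (53 - 32));       // 53 random bits
        // result lies in [2^-54, 1 - 2^-54]
        return (double)z * TWOPOW53_INV_DOUBLE + TWOPOW53_INV_DOUBLE / 2.0;
    }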
@@ -229,24 +230,24 @@ QUALIFIERS void philox_float16(__m128i ctr0, __m128i ctr1, __m128i ctr2, __m128i
     rnd4 = _my_cvtepu32_ps(ctr[3]);
     // calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
 #ifdef __FMA__
-    rnd1 = _mm_fmadd_ps(rnd1, _mm_set_ps1(TWOPOW32_INV_FLOAT), _mm_set_ps1(TWOPOW32_INV_FLOAT/2.0));
-    rnd2 = _mm_fmadd_ps(rnd2, _mm_set_ps1(TWOPOW32_INV_FLOAT), _mm_set_ps1(TWOPOW32_INV_FLOAT/2.0));
-    rnd3 = _mm_fmadd_ps(rnd3, _mm_set_ps1(TWOPOW32_INV_FLOAT), _mm_set_ps1(TWOPOW32_INV_FLOAT/2.0));
-    rnd4 = _mm_fmadd_ps(rnd4, _mm_set_ps1(TWOPOW32_INV_FLOAT), _mm_set_ps1(TWOPOW32_INV_FLOAT/2.0));
+    rnd1 = _mm_fmadd_ps(rnd1, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0));
+    rnd2 = _mm_fmadd_ps(rnd2, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0));
+    rnd3 = _mm_fmadd_ps(rnd3, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0));
+    rnd4 = _mm_fmadd_ps(rnd4, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0));
 #else
-    rnd1 = _mm_mul_ps(rnd1, _mm_set_ps1(TWOPOW32_INV_FLOAT));
-    rnd1 = _mm_add_ps(rnd1, _mm_set_ps1(TWOPOW32_INV_FLOAT/2.0f));
-    rnd2 = _mm_mul_ps(rnd2, _mm_set_ps1(TWOPOW32_INV_FLOAT));
-    rnd2 = _mm_add_ps(rnd2, _mm_set_ps1(TWOPOW32_INV_FLOAT/2.0f));
-    rnd3 = _mm_mul_ps(rnd3, _mm_set_ps1(TWOPOW32_INV_FLOAT));
-    rnd3 = _mm_add_ps(rnd3, _mm_set_ps1(TWOPOW32_INV_FLOAT/2.0f));
-    rnd4 = _mm_mul_ps(rnd4, _mm_set_ps1(TWOPOW32_INV_FLOAT));
-    rnd4 = _mm_add_ps(rnd4, _mm_set_ps1(TWOPOW32_INV_FLOAT/2.0f));
+    rnd1 = _mm_mul_ps(rnd1, _mm_set1_ps(TWOPOW32_INV_FLOAT));
+    rnd1 = _mm_add_ps(rnd1, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
+    rnd2 = _mm_mul_ps(rnd2, _mm_set1_ps(TWOPOW32_INV_FLOAT));
+    rnd2 = _mm_add_ps(rnd2, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
+    rnd3 = _mm_mul_ps(rnd3, _mm_set1_ps(TWOPOW32_INV_FLOAT));
+    rnd3 = _mm_add_ps(rnd3, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
+    rnd4 = _mm_mul_ps(rnd4, _mm_set1_ps(TWOPOW32_INV_FLOAT));
+    rnd4 = _mm_add_ps(rnd4, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
 #endif
 }
 
-QUALIFIERS void philox_double8(__m128i ctr0, __m128i ctr1, __m128i ctr2, __m128i ctr3,
+QUALIFIERS void philox_double2(__m128i ctr0, __m128i ctr1, __m128i ctr2, __m128i ctr3,
                               uint32 key0, uint32 key1,
                               __m128d & rnd1lo, __m128d & rnd1hi, __m128d & rnd2lo, __m128d & rnd2hi)
 {
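The renames (`philox_float16` to `philox_float4`, `philox_double8` to `philox_double2`) give every register width the same name as the scalar version, so C++ overload resolution selects the variant from the counter type. A hypothetical call site, assuming this header is included and SSE4.1 is enabled:

    #include <smmintrin.h>

    void demo_sse(__m128 out[4])
    {
        __m128i c0 = _mm_set1_epi32(0), c1 = _mm_set1_epi32(1),
                c2 = _mm_set1_epi32(2), c3 = _mm_set1_epi32(3);
        // Overload resolution picks the __m128i variant; an __m256i counter
        // would select the AVX2 version under the same name.
        philox_float4(c0, c1, c2, c3, 0xdeadbeefu, 0xcafef00du,
                      out[0], out[1], out[2], out[3]);
    }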
@@ -270,7 +271,7 @@ QUALIFIERS void philox_double8(__m128i ctr0, __m128i ctr1, __m128i ctr2, __m128i
 }
 #endif
 
-#ifdef __AVX__
+#ifdef __AVX2__
 QUALIFIERS void _philox4x32round(__m256i* ctr, __m256i* key)
 {
     __m256i lohi0a = _mm256_mul_epu32(ctr[0], _mm256_set1_epi32(PHILOX_M4x32_0));
@@ -313,7 +314,7 @@ QUALIFIERS __m256 _my256_cvtepu32_ps(const __m256i v)
 #endif
 }
 
-#if !defined(__AVX512VL__) && defined(__GNUC__) && __GNUC__ >= 5
+#if !defined(__AVX512VL__) && !defined(__AVX512F__) && defined(__GNUC__) && __GNUC__ >= 5
 __attribute__((optimize("no-associative-math")))
 #endif
 QUALIFIERS __m256d _my256_cvtepu64_pd(const __m256i x)
@@ -329,52 +330,42 @@ QUALIFIERS __m256d _my256_cvtepu64_pd(const __m256i x)
 #endif
 }
 
-QUALIFIERS __m256 _my256_set_ps1(const float v)
-{
-    return _mm256_set_ps(v, v, v, v, v, v, v, v);
-}
-
-QUALIFIERS __m256d _my256_set_pd1(const double v)
-{
-    return _mm256_set_pd(v, v, v, v);
-}
-
 template<bool high>
 QUALIFIERS __m256d _uniform_double_hq(__m256i x, __m256i y)
 {
     // convert 32 to 64 bit
     if (high)
     {
-        x = _mm256_unpackhi_epi32(x, _mm256_set1_epi32(0));
-        y = _mm256_unpackhi_epi32(y, _mm256_set1_epi32(0));;
+        x = _mm256_unpackhi_epi32(x, _mm256_setzero_si256());
+        y = _mm256_unpackhi_epi32(y, _mm256_setzero_si256());
     }
     else
     {
-        x = _mm256_unpacklo_epi32(x, _mm256_set1_epi32(0));
-        y = _mm256_unpacklo_epi32(y, _mm256_set1_epi32(0));;
+        x = _mm256_unpacklo_epi32(x, _mm256_setzero_si256());
+        y = _mm256_unpacklo_epi32(y, _mm256_setzero_si256());
     }
 
     // calculate z = x ^ y << (53 - 32))
-    __m256i z = _mm256_sll_epi64(y, _mm_set_epi64x(53 - 32, 53 - 32));
+    __m256i z = _mm256_sll_epi64(y, _mm_set1_epi64x(53 - 32));
     z = _mm256_xor_si256(x, z);
 
     // convert uint64 to double
     __m256d rs = _my256_cvtepu64_pd(z);
     // calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
 #ifdef __FMA__
-    rs = _mm256_fmadd_pd(rs, _my256_set_pd1(TWOPOW53_INV_DOUBLE), _my256_set_pd1(TWOPOW53_INV_DOUBLE/2.0));
+    rs = _mm256_fmadd_pd(rs, _mm256_set1_pd(TWOPOW53_INV_DOUBLE), _mm256_set1_pd(TWOPOW53_INV_DOUBLE/2.0));
 #else
-    rs = _mm256_mul_pd(rs, _my256_set_pd1(TWOPOW53_INV_DOUBLE));
-    rs = _mm256_add_pd(rs, _my256_set_pd1(TWOPOW53_INV_DOUBLE/2.0));
+    rs = _mm256_mul_pd(rs, _mm256_set1_pd(TWOPOW53_INV_DOUBLE));
+    rs = _mm256_add_pd(rs, _mm256_set1_pd(TWOPOW53_INV_DOUBLE/2.0));
 #endif
     return rs;
 }
 
-QUALIFIERS void philox_float32(__m256i ctr0, __m256i ctr1, __m256i ctr2, __m256i ctr3,
+QUALIFIERS void philox_float4(__m256i ctr0, __m256i ctr1, __m256i ctr2, __m256i ctr3,
                                uint32 key0, uint32 key1,
                                __m256 & rnd1, __m256 & rnd2, __m256 & rnd3, __m256 & rnd4)
 {
     __m256i key[2] = {_mm256_set1_epi32(key0), _mm256_set1_epi32(key1)};
     __m256i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
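Deleting `_my256_set_ps1` and `_my256_set_pd1` is safe because broadcast intrinsics exist at every register width and are the idiomatic spelling; the same reasoning applies to `_mm_set_ps1`/`_mm_set_pd1`, which are merely legacy aliases of `_mm_set1_ps`/`_mm_set1_pd`. For example:

    #include <immintrin.h>
    __m256  a = _mm256_set1_ps(1.0f); // replaces the hand-rolled _my256_set_ps1
    __m256d b = _mm256_set1_pd(1.0);  // replaces the hand-rolled _my256_set_pd1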
@@ -396,26 +387,26 @@ QUALIFIERS void philox_float32(__m256i ctr0, __m256i ctr1, __m256i ctr2, __m256i
     rnd4 = _my256_cvtepu32_ps(ctr[3]);
     // calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
 #ifdef __FMA__
-    rnd1 = _mm256_fmadd_ps(rnd1, _my256_set_ps1(TWOPOW32_INV_FLOAT), _my256_set_ps1(TWOPOW32_INV_FLOAT/2.0));
-    rnd2 = _mm256_fmadd_ps(rnd2, _my256_set_ps1(TWOPOW32_INV_FLOAT), _my256_set_ps1(TWOPOW32_INV_FLOAT/2.0));
-    rnd3 = _mm256_fmadd_ps(rnd3, _my256_set_ps1(TWOPOW32_INV_FLOAT), _my256_set_ps1(TWOPOW32_INV_FLOAT/2.0));
-    rnd4 = _mm256_fmadd_ps(rnd4, _my256_set_ps1(TWOPOW32_INV_FLOAT), _my256_set_ps1(TWOPOW32_INV_FLOAT/2.0));
+    rnd1 = _mm256_fmadd_ps(rnd1, _mm256_set1_ps(TWOPOW32_INV_FLOAT), _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0));
+    rnd2 = _mm256_fmadd_ps(rnd2, _mm256_set1_ps(TWOPOW32_INV_FLOAT), _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0));
+    rnd3 = _mm256_fmadd_ps(rnd3, _mm256_set1_ps(TWOPOW32_INV_FLOAT), _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0));
+    rnd4 = _mm256_fmadd_ps(rnd4, _mm256_set1_ps(TWOPOW32_INV_FLOAT), _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0));
 #else
-    rnd1 = _mm256_mul_ps(rnd1, _my256_set_ps1(TWOPOW32_INV_FLOAT));
-    rnd1 = _mm256_add_ps(rnd1, _my256_set_ps1(TWOPOW32_INV_FLOAT/2.0f));
-    rnd2 = _mm256_mul_ps(rnd2, _my256_set_ps1(TWOPOW32_INV_FLOAT));
-    rnd2 = _mm256_add_ps(rnd2, _my256_set_ps1(TWOPOW32_INV_FLOAT/2.0f));
-    rnd3 = _mm256_mul_ps(rnd3, _my256_set_ps1(TWOPOW32_INV_FLOAT));
-    rnd3 = _mm256_add_ps(rnd3, _my256_set_ps1(TWOPOW32_INV_FLOAT/2.0f));
-    rnd4 = _mm256_mul_ps(rnd4, _my256_set_ps1(TWOPOW32_INV_FLOAT));
-    rnd4 = _mm256_add_ps(rnd4, _my256_set_ps1(TWOPOW32_INV_FLOAT/2.0f));
+    rnd1 = _mm256_mul_ps(rnd1, _mm256_set1_ps(TWOPOW32_INV_FLOAT));
+    rnd1 = _mm256_add_ps(rnd1, _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
+    rnd2 = _mm256_mul_ps(rnd2, _mm256_set1_ps(TWOPOW32_INV_FLOAT));
+    rnd2 = _mm256_add_ps(rnd2, _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
+    rnd3 = _mm256_mul_ps(rnd3, _mm256_set1_ps(TWOPOW32_INV_FLOAT));
+    rnd3 = _mm256_add_ps(rnd3, _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
+    rnd4 = _mm256_mul_ps(rnd4, _mm256_set1_ps(TWOPOW32_INV_FLOAT));
+    rnd4 = _mm256_add_ps(rnd4, _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
 #endif
 }
 
-QUALIFIERS void philox_double16(__m256i ctr0, __m256i ctr1, __m256i ctr2, __m256i ctr3,
+QUALIFIERS void philox_double2(__m256i ctr0, __m256i ctr1, __m256i ctr2, __m256i ctr3,
                                uint32 key0, uint32 key1,
                                __m256d & rnd1lo, __m256d & rnd1hi, __m256d & rnd2lo, __m256d & rnd2hi)
 {
     __m256i key[2] = {_mm256_set1_epi32(key0), _mm256_set1_epi32(key1)};
     __m256i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
@@ -467,47 +458,37 @@ QUALIFIERS void _philox4x32bumpkey(__m512i* key)
     key[1] = _mm512_add_epi32(key[1], _mm512_set1_epi32(PHILOX_W32_1));
 }
 
-QUALIFIERS __m512 _my512_set_ps1(const float v)
-{
-    return _mm512_set_ps(v, v, v, v, v, v, v, v, v, v, v, v, v, v, v, v);
-}
-
-QUALIFIERS __m512d _my512_set_pd1(const double v)
-{
-    return _mm512_set_pd(v, v, v, v, v, v, v, v);
-}
-
 template<bool high>
 QUALIFIERS __m512d _uniform_double_hq(__m512i x, __m512i y)
 {
     // convert 32 to 64 bit
     if (high)
     {
-        x = _mm512_unpackhi_epi32(x, _mm512_set1_epi32(0));
-        y = _mm512_unpackhi_epi32(y, _mm512_set1_epi32(0));;
+        x = _mm512_unpackhi_epi32(x, _mm512_setzero_si512());
+        y = _mm512_unpackhi_epi32(y, _mm512_setzero_si512());
     }
     else
    {
-        x = _mm512_unpacklo_epi32(x, _mm512_set1_epi32(0));
-        y = _mm512_unpacklo_epi32(y, _mm512_set1_epi32(0));;
+        x = _mm512_unpacklo_epi32(x, _mm512_setzero_si512());
+        y = _mm512_unpacklo_epi32(y, _mm512_setzero_si512());
     }
 
     // calculate z = x ^ y << (53 - 32))
-    __m512i z = _mm512_sll_epi64(y, _mm_set_epi64x(53 - 32, 53 - 32));
+    __m512i z = _mm512_sll_epi64(y, _mm_set1_epi64x(53 - 32));
     z = _mm512_xor_si512(x, z);
 
     // convert uint64 to double
     __m512d rs = _mm512_cvtepu64_pd(z);
     // calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
-    rs = _mm512_fmadd_pd(rs, _my512_set_pd1(TWOPOW53_INV_DOUBLE), _my512_set_pd1(TWOPOW53_INV_DOUBLE/2.0));
+    rs = _mm512_fmadd_pd(rs, _mm512_set1_pd(TWOPOW53_INV_DOUBLE), _mm512_set1_pd(TWOPOW53_INV_DOUBLE/2.0));
     return rs;
 }
 
-QUALIFIERS void philox_float64(__m512i ctr0, __m512i ctr1, __m512i ctr2, __m512i ctr3,
+QUALIFIERS void philox_float4(__m512i ctr0, __m512i ctr1, __m512i ctr2, __m512i ctr3,
                                uint32 key0, uint32 key1,
                                __m512 & rnd1, __m512 & rnd2, __m512 & rnd3, __m512 & rnd4)
 {
     __m512i key[2] = {_mm512_set1_epi32(key0), _mm512_set1_epi32(key1)};
     __m512i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
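One subtlety preserved by the cleanup: the per-vector shift intrinsics take their count as an `__m128i` at every register width, so `_mm_set1_epi64x(53 - 32)` is the correct count operand even for the 512-bit shift. For example:

    #include <immintrin.h>
    #ifdef __AVX512F__
    __m512i shift_left_21(__m512i y)
    {
        // the count is an __m128i whose low 64 bits hold the shift amount
        return _mm512_sll_epi64(y, _mm_set1_epi64x(21));
    }
    #endif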
@@ -528,16 +509,16 @@ QUALIFIERS void philox_float64(__m512i ctr0, __m512i ctr1, __m512i ctr2, __m512i
     rnd3 = _mm512_cvtepu32_ps(ctr[2]);
     rnd4 = _mm512_cvtepu32_ps(ctr[3]);
     // calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
-    rnd1 = _mm512_fmadd_ps(rnd1, _my512_set_ps1(TWOPOW32_INV_FLOAT), _my512_set_ps1(TWOPOW32_INV_FLOAT/2.0));
-    rnd2 = _mm512_fmadd_ps(rnd2, _my512_set_ps1(TWOPOW32_INV_FLOAT), _my512_set_ps1(TWOPOW32_INV_FLOAT/2.0));
-    rnd3 = _mm512_fmadd_ps(rnd3, _my512_set_ps1(TWOPOW32_INV_FLOAT), _my512_set_ps1(TWOPOW32_INV_FLOAT/2.0));
-    rnd4 = _mm512_fmadd_ps(rnd4, _my512_set_ps1(TWOPOW32_INV_FLOAT), _my512_set_ps1(TWOPOW32_INV_FLOAT/2.0));
+    rnd1 = _mm512_fmadd_ps(rnd1, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0));
+    rnd2 = _mm512_fmadd_ps(rnd2, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0));
+    rnd3 = _mm512_fmadd_ps(rnd3, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0));
+    rnd4 = _mm512_fmadd_ps(rnd4, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0));
 }
 
-QUALIFIERS void philox_double32(__m512i ctr0, __m512i ctr1, __m512i ctr2, __m512i ctr3,
+QUALIFIERS void philox_double2(__m512i ctr0, __m512i ctr1, __m512i ctr2, __m512i ctr3,
                                uint32 key0, uint32 key1,
                                __m512d & rnd1lo, __m512d & rnd1hi, __m512d & rnd2lo, __m512d & rnd2hi)
 {
     __m512i key[2] = {_mm512_set1_epi32(key0), _mm512_set1_epi32(key1)};
     __m512i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
@@ -558,4 +539,5 @@ QUALIFIERS void philox_double32(__m512i ctr0, __m512i ctr1, __m512i ctr2, __m512
     rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
 }
 #endif
 
+#endif
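The new `#ifndef __CUDA_ARCH__` / `#endif` pair keeps every intrinsic-based overload out of nvcc's device-side compilation pass, where x86 intrinsics do not exist, while the scalar `philox_float4`/`philox_double2` remain available on both host and device. A plausible definition of `QUALIFIERS` consistent with that split (an assumption, not taken from this commit):

    #ifdef __CUDA_ARCH__
    #define QUALIFIERS static __forceinline__ __device__
    #else
    #define QUALIFIERS inline
    #endif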