diff --git a/pystencils/include/aesni_rand.h b/pystencils/include/aesni_rand.h index 09327f27b8bc0b3cdd0b16cb8a64e3237b555797..36d5bbf6f8859b25920c569d9476e3485d5c5c2f 100644 --- a/pystencils/include/aesni_rand.h +++ b/pystencils/include/aesni_rand.h @@ -4,7 +4,7 @@ #include <emmintrin.h> // SSE2 #include <wmmintrin.h> // AES -#ifdef __AVX512VL__ +#if defined(__AVX512VL__) || defined(__AVX512F__) #include <immintrin.h> // AVX* #endif #include <cstdint> @@ -33,7 +33,7 @@ QUALIFIERS __m128i aesni1xm128i(const __m128i & in, const __m128i & k) { QUALIFIERS __m128 _my_cvtepu32_ps(const __m128i v) { -#ifdef __AVX512VL__ +#if defined(__AVX512VL__) || defined(__AVX512F__) return _mm_cvtepu32_ps(v); #else __m128i v2 = _mm_srli_epi32(v, 1); @@ -46,7 +46,7 @@ QUALIFIERS __m128 _my_cvtepu32_ps(const __m128i v) QUALIFIERS __m128d _my_cvtepu64_pd(const __m128i x) { -#ifdef __AVX512VL__ +#if defined(__AVX512VL__) || defined(__AVX512F__) return _mm_cvtepu64_pd(x); #else uint64 r[2]; @@ -110,4 +110,5 @@ QUALIFIERS void aesni_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3, rnd2 = r[1]; rnd3 = r[2]; rnd4 = r[3]; -} \ No newline at end of file +} + diff --git a/pystencils/include/philox_rand.h b/pystencils/include/philox_rand.h index 283204921079ebfd79e022af53b17963de874cf4..6b19e8c3754841f19f01d94da51c46fe8e32f617 100644 --- a/pystencils/include/philox_rand.h +++ b/pystencils/include/philox_rand.h @@ -1,5 +1,17 @@ #include <cstdint> +#ifdef __SSE__ +#include <emmintrin.h> // SSE2 +#endif +#ifdef __AVX__ +#include <immintrin.h> // AVX* +#else +#include <smmintrin.h> // SSE4 +#ifdef __FMA__ +#include <immintrin.h> // FMA +#endif +#endif + #ifndef __CUDA_ARCH__ #define QUALIFIERS inline #else @@ -78,7 +90,6 @@ QUALIFIERS void philox_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr } - QUALIFIERS void philox_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3, uint32 key0, uint32 key1, float & rnd1, float & rnd2, float & rnd3, float & rnd4) @@ -100,4 +111,247 @@ QUALIFIERS void philox_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3 rnd2 = ctr[1] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f); rnd3 = ctr[2] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f); rnd4 = ctr[3] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f); -} \ No newline at end of file +} + +#ifdef __SSE__ +QUALIFIERS void _philox4x32round(__m128i* ctr, __m128i* key) +{ + __m128i lohi0a = _mm_mul_epu32(ctr[0], _mm_set1_epi32(PHILOX_M4x32_0)); + __m128i lohi0b = _mm_mul_epu32(_mm_srli_epi64(ctr[0], 32), _mm_set1_epi32(PHILOX_M4x32_0)); + __m128i lohi1a = _mm_mul_epu32(ctr[2], _mm_set1_epi32(PHILOX_M4x32_1)); + __m128i lohi1b = _mm_mul_epu32(_mm_srli_epi64(ctr[2], 32), _mm_set1_epi32(PHILOX_M4x32_1)); + + lohi0a = _mm_shuffle_epi32(lohi0a, 0xD8); + lohi0b = _mm_shuffle_epi32(lohi0b, 0xD8); + lohi1a = _mm_shuffle_epi32(lohi1a, 0xD8); + lohi1b = _mm_shuffle_epi32(lohi1b, 0xD8); + + __m128i lo0 = _mm_unpacklo_epi32(lohi0a, lohi0b); + __m128i hi0 = _mm_unpackhi_epi32(lohi0a, lohi0b); + __m128i lo1 = _mm_unpacklo_epi32(lohi1a, lohi1b); + __m128i hi1 = _mm_unpackhi_epi32(lohi1a, lohi1b); + + ctr[0] = _mm_xor_si128(_mm_xor_si128(hi1, ctr[1]), key[0]); + ctr[1] = lo1; + ctr[2] = _mm_xor_si128(_mm_xor_si128(hi0, ctr[3]), key[1]); + ctr[3] = lo0; +} + +QUALIFIERS void _philox4x32bumpkey(__m128i* key) +{ + key[0] = _mm_add_epi32(key[0], _mm_set1_epi32(PHILOX_W32_0)); + key[1] = _mm_add_epi32(key[1], _mm_set1_epi32(PHILOX_W32_1)); +} + +QUALIFIERS __m128 _my_cvtepu32_ps(const __m128i v) +{ +#if defined(__AVX512VL__) || defined(__AVX512F__) + return _mm_cvtepu32_ps(v); +#else + __m128i v2 = _mm_srli_epi32(v, 1); + __m128i v1 = _mm_and_si128(v, _mm_set1_epi32(1)); + __m128 v2f = _mm_cvtepi32_ps(v2); + __m128 v1f = _mm_cvtepi32_ps(v1); + return _mm_add_ps(_mm_add_ps(v2f, v2f), v1f); +#endif +} + + + +QUALIFIERS void philox_float16(__m128i ctr0, __m128i ctr1, __m128i ctr2, __m128i ctr3, + uint32 key0, uint32 key1, + __m128 & rnd1, __m128 & rnd2, __m128 & rnd3, __m128 & rnd4) +{ + __m128i key[2] = {_mm_set1_epi32(key0), _mm_set1_epi32(key1)}; + __m128i ctr[4] = {ctr0, ctr1, ctr2, ctr3}; + _philox4x32round(ctr, key); // 1 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10 + + // convert uint32 to float + rnd1 = _my_cvtepu32_ps(ctr[0]); + rnd2 = _my_cvtepu32_ps(ctr[1]); + rnd3 = _my_cvtepu32_ps(ctr[2]); + rnd4 = _my_cvtepu32_ps(ctr[3]); + // calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f) +#ifdef __FMA__ + rnd1 = _mm_fmadd_ps(rnd1, _mm_set_ps1(TWOPOW32_INV_FLOAT), _mm_set_ps1(TWOPOW32_INV_FLOAT/2.0)); + rnd2 = _mm_fmadd_ps(rnd2, _mm_set_ps1(TWOPOW32_INV_FLOAT), _mm_set_ps1(TWOPOW32_INV_FLOAT/2.0)); + rnd3 = _mm_fmadd_ps(rnd3, _mm_set_ps1(TWOPOW32_INV_FLOAT), _mm_set_ps1(TWOPOW32_INV_FLOAT/2.0)); + rnd4 = _mm_fmadd_ps(rnd4, _mm_set_ps1(TWOPOW32_INV_FLOAT), _mm_set_ps1(TWOPOW32_INV_FLOAT/2.0)); +#else + rnd1 = _mm_mul_ps(rnd1, _mm_set_ps1(TWOPOW32_INV_FLOAT)); + rnd1 = _mm_add_ps(rnd1, _mm_set_ps1(TWOPOW32_INV_FLOAT/2.0f)); + rnd2 = _mm_mul_ps(rnd2, _mm_set_ps1(TWOPOW32_INV_FLOAT)); + rnd2 = _mm_add_ps(rnd2, _mm_set_ps1(TWOPOW32_INV_FLOAT/2.0f)); + rnd3 = _mm_mul_ps(rnd3, _mm_set_ps1(TWOPOW32_INV_FLOAT)); + rnd3 = _mm_add_ps(rnd3, _mm_set_ps1(TWOPOW32_INV_FLOAT/2.0f)); + rnd4 = _mm_mul_ps(rnd4, _mm_set_ps1(TWOPOW32_INV_FLOAT)); + rnd4 = _mm_add_ps(rnd4, _mm_set_ps1(TWOPOW32_INV_FLOAT/2.0f)); +#endif +} +#endif + +#ifdef __AVX__ +QUALIFIERS void _philox4x32round(__m256i* ctr, __m256i* key) +{ + __m256i lohi0a = _mm256_mul_epu32(ctr[0], _mm256_set1_epi32(PHILOX_M4x32_0)); + __m256i lohi0b = _mm256_mul_epu32(_mm256_srli_epi64(ctr[0], 32), _mm256_set1_epi32(PHILOX_M4x32_0)); + __m256i lohi1a = _mm256_mul_epu32(ctr[2], _mm256_set1_epi32(PHILOX_M4x32_1)); + __m256i lohi1b = _mm256_mul_epu32(_mm256_srli_epi64(ctr[2], 32), _mm256_set1_epi32(PHILOX_M4x32_1)); + + lohi0a = _mm256_shuffle_epi32(lohi0a, 0xD8); + lohi0b = _mm256_shuffle_epi32(lohi0b, 0xD8); + lohi1a = _mm256_shuffle_epi32(lohi1a, 0xD8); + lohi1b = _mm256_shuffle_epi32(lohi1b, 0xD8); + + __m256i lo0 = _mm256_unpacklo_epi32(lohi0a, lohi0b); + __m256i hi0 = _mm256_unpackhi_epi32(lohi0a, lohi0b); + __m256i lo1 = _mm256_unpacklo_epi32(lohi1a, lohi1b); + __m256i hi1 = _mm256_unpackhi_epi32(lohi1a, lohi1b); + + ctr[0] = _mm256_xor_si256(_mm256_xor_si256(hi1, ctr[1]), key[0]); + ctr[1] = lo1; + ctr[2] = _mm256_xor_si256(_mm256_xor_si256(hi0, ctr[3]), key[1]); + ctr[3] = lo0; +} + +QUALIFIERS void _philox4x32bumpkey(__m256i* key) +{ + key[0] = _mm256_add_epi32(key[0], _mm256_set1_epi32(PHILOX_W32_0)); + key[1] = _mm256_add_epi32(key[1], _mm256_set1_epi32(PHILOX_W32_1)); +} + +QUALIFIERS __m256 _my256_cvtepu32_ps(const __m256i v) +{ +#if defined(__AVX512VL__) || defined(__AVX512F__) + return _mm256_cvtepu32_ps(v); +#else + __m256i v2 = _mm256_srli_epi32(v, 1); + __m256i v1 = _mm256_and_si256(v, _mm256_set1_epi32(1)); + __m256 v2f = _mm256_cvtepi32_ps(v2); + __m256 v1f = _mm256_cvtepi32_ps(v1); + return _mm256_add_ps(_mm256_add_ps(v2f, v2f), v1f); +#endif +} + +QUALIFIERS __m256 _my256_set_ps1(const float v) +{ + return _mm256_set_ps(v, v, v, v, v, v, v, v); +} + + +QUALIFIERS void philox_float32(__m256i ctr0, __m256i ctr1, __m256i ctr2, __m256i ctr3, + uint32 key0, uint32 key1, + __m256 & rnd1, __m256 & rnd2, __m256 & rnd3, __m256 & rnd4) +{ + __m256i key[2] = {_mm256_set1_epi32(key0), _mm256_set1_epi32(key1)}; + __m256i ctr[4] = {ctr0, ctr1, ctr2, ctr3}; + _philox4x32round(ctr, key); // 1 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10 + + // convert uint32 to float + rnd1 = _my256_cvtepu32_ps(ctr[0]); + rnd2 = _my256_cvtepu32_ps(ctr[1]); + rnd3 = _my256_cvtepu32_ps(ctr[2]); + rnd4 = _my256_cvtepu32_ps(ctr[3]); + // calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f) +#ifdef __FMA__ + rnd1 = _mm256_fmadd_ps(rnd1, _my256_set_ps1(TWOPOW32_INV_FLOAT), _my256_set_ps1(TWOPOW32_INV_FLOAT/2.0)); + rnd2 = _mm256_fmadd_ps(rnd2, _my256_set_ps1(TWOPOW32_INV_FLOAT), _my256_set_ps1(TWOPOW32_INV_FLOAT/2.0)); + rnd3 = _mm256_fmadd_ps(rnd3, _my256_set_ps1(TWOPOW32_INV_FLOAT), _my256_set_ps1(TWOPOW32_INV_FLOAT/2.0)); + rnd4 = _mm256_fmadd_ps(rnd4, _my256_set_ps1(TWOPOW32_INV_FLOAT), _my256_set_ps1(TWOPOW32_INV_FLOAT/2.0)); +#else + rnd1 = _mm256_mul_ps(rnd1, _my256_set_ps1(TWOPOW32_INV_FLOAT)); + rnd1 = _mm256_add_ps(rnd1, _my256_set_ps1(TWOPOW32_INV_FLOAT/2.0f)); + rnd2 = _mm256_mul_ps(rnd2, _my256_set_ps1(TWOPOW32_INV_FLOAT)); + rnd2 = _mm256_add_ps(rnd2, _my256_set_ps1(TWOPOW32_INV_FLOAT/2.0f)); + rnd3 = _mm256_mul_ps(rnd3, _my256_set_ps1(TWOPOW32_INV_FLOAT)); + rnd3 = _mm256_add_ps(rnd3, _my256_set_ps1(TWOPOW32_INV_FLOAT/2.0f)); + rnd4 = _mm256_mul_ps(rnd4, _my256_set_ps1(TWOPOW32_INV_FLOAT)); + rnd4 = _mm256_add_ps(rnd4, _my256_set_ps1(TWOPOW32_INV_FLOAT/2.0f)); +#endif +} +#endif + +#ifdef __AVX512F__ +QUALIFIERS void _philox4x32round(__m512i* ctr, __m512i* key) +{ + __m512i lohi0a = _mm512_mul_epu32(ctr[0], _mm512_set1_epi32(PHILOX_M4x32_0)); + __m512i lohi0b = _mm512_mul_epu32(_mm512_srli_epi64(ctr[0], 32), _mm512_set1_epi32(PHILOX_M4x32_0)); + __m512i lohi1a = _mm512_mul_epu32(ctr[2], _mm512_set1_epi32(PHILOX_M4x32_1)); + __m512i lohi1b = _mm512_mul_epu32(_mm512_srli_epi64(ctr[2], 32), _mm512_set1_epi32(PHILOX_M4x32_1)); + + lohi0a = _mm512_shuffle_epi32(lohi0a, 0xD8); + lohi0b = _mm512_shuffle_epi32(lohi0b, 0xD8); + lohi1a = _mm512_shuffle_epi32(lohi1a, 0xD8); + lohi1b = _mm512_shuffle_epi32(lohi1b, 0xD8); + + __m512i lo0 = _mm512_unpacklo_epi32(lohi0a, lohi0b); + __m512i hi0 = _mm512_unpackhi_epi32(lohi0a, lohi0b); + __m512i lo1 = _mm512_unpacklo_epi32(lohi1a, lohi1b); + __m512i hi1 = _mm512_unpackhi_epi32(lohi1a, lohi1b); + + ctr[0] = _mm512_xor_si512(_mm512_xor_si512(hi1, ctr[1]), key[0]); + ctr[1] = lo1; + ctr[2] = _mm512_xor_si512(_mm512_xor_si512(hi0, ctr[3]), key[1]); + ctr[3] = lo0; +} + +QUALIFIERS void _philox4x32bumpkey(__m512i* key) +{ + key[0] = _mm512_add_epi32(key[0], _mm512_set1_epi32(PHILOX_W32_0)); + key[1] = _mm512_add_epi32(key[1], _mm512_set1_epi32(PHILOX_W32_1)); +} + +QUALIFIERS __m512 _my512_set_ps1(const float v) +{ + return _mm512_set_ps(v, v, v, v, v, v, v, v, v, v, v, v, v, v, v, v); +} + + +QUALIFIERS void philox_float64(__m512i ctr0, __m512i ctr1, __m512i ctr2, __m512i ctr3, + uint32 key0, uint32 key1, + __m512 & rnd1, __m512 & rnd2, __m512 & rnd3, __m512 & rnd4) +{ + __m512i key[2] = {_mm512_set1_epi32(key0), _mm512_set1_epi32(key1)}; + __m512i ctr[4] = {ctr0, ctr1, ctr2, ctr3}; + _philox4x32round(ctr, key); // 1 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10 + + // convert uint32 to float + rnd1 = _mm512_cvtepu32_ps(ctr[0]); + rnd2 = _mm512_cvtepu32_ps(ctr[1]); + rnd3 = _mm512_cvtepu32_ps(ctr[2]); + rnd4 = _mm512_cvtepu32_ps(ctr[3]); + // calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f) + rnd1 = _mm512_fmadd_ps(rnd1, _my512_set_ps1(TWOPOW32_INV_FLOAT), _my512_set_ps1(TWOPOW32_INV_FLOAT/2.0)); + rnd2 = _mm512_fmadd_ps(rnd2, _my512_set_ps1(TWOPOW32_INV_FLOAT), _my512_set_ps1(TWOPOW32_INV_FLOAT/2.0)); + rnd3 = _mm512_fmadd_ps(rnd3, _my512_set_ps1(TWOPOW32_INV_FLOAT), _my512_set_ps1(TWOPOW32_INV_FLOAT/2.0)); + rnd4 = _mm512_fmadd_ps(rnd4, _my512_set_ps1(TWOPOW32_INV_FLOAT), _my512_set_ps1(TWOPOW32_INV_FLOAT/2.0)); +} +#endif +