diff --git a/pystencils/include/philox_rand.h b/pystencils/include/philox_rand.h
index 6b19e8c3754841f19f01d94da51c46fe8e32f617..a60f7fdf9857ac89b5c8037cfbfcf11a0012ec4c 100644
--- a/pystencils/include/philox_rand.h
+++ b/pystencils/include/philox_rand.h
@@ -156,6 +156,53 @@ QUALIFIERS __m128 _my_cvtepu32_ps(const __m128i v)
 #endif
 }
 
+#if !defined(__AVX512VL__) && defined(__GNUC__) && __GNUC__ >= 5
+__attribute__((optimize("no-associative-math")))
+#endif
+QUALIFIERS __m128d _my_cvtepu64_pd(const __m128i x)
+{
+#ifdef __AVX512VL__
+    return _mm_cvtepu64_pd(x);
+#else
+    __m128i xH = _mm_srli_epi64(x, 32);
+    xH = _mm_or_si128(xH, _mm_castpd_si128(_mm_set1_pd(19342813113834066795298816.)));          // 2^84
+    __m128i xL = _mm_blend_epi16(x, _mm_castpd_si128(_mm_set1_pd(0x0010000000000000)), 0xcc);   // 2^52
+    __m128d f = _mm_sub_pd(_mm_castsi128_pd(xH), _mm_set1_pd(19342813118337666422669312.));     // 2^84 + 2^52
+    return _mm_add_pd(f, _mm_castsi128_pd(xL));
+#endif
+}
+
+template<bool high>
+QUALIFIERS __m128d _uniform_double_hq(__m128i x, __m128i y)
+{
+    // convert 32 to 64 bit
+    if (high)
+    {
+        x = _mm_unpackhi_epi32(x, _mm_set1_epi32(0));
+        y = _mm_unpackhi_epi32(y, _mm_set1_epi32(0));
+    }
+    else
+    {
+        x = _mm_unpacklo_epi32(x, _mm_set1_epi32(0));
+        y = _mm_unpacklo_epi32(y, _mm_set1_epi32(0));
+    }
+
+    // calculate z = x ^ (y << (53 - 32))
+    __m128i z = _mm_sll_epi64(y, _mm_set_epi64x(53 - 32, 53 - 32));
+    z = _mm_xor_si128(x, z);
+
+    // convert uint64 to double
+    __m128d rs = _my_cvtepu64_pd(z);
+    // calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
+#ifdef __FMA__
+    rs = _mm_fmadd_pd(rs, _mm_set_pd1(TWOPOW53_INV_DOUBLE), _mm_set_pd1(TWOPOW53_INV_DOUBLE/2.0));
+#else
+    rs = _mm_mul_pd(rs, _mm_set_pd1(TWOPOW53_INV_DOUBLE));
+    rs = _mm_add_pd(rs, _mm_set_pd1(TWOPOW53_INV_DOUBLE/2.0));
+#endif
+
+    return rs;
+}
 
 
 QUALIFIERS void philox_float16(__m128i ctr0, __m128i ctr1, __m128i ctr2, __m128i ctr3,
@@ -197,6 +244,30 @@ QUALIFIERS void philox_float16(__m128i ctr0, __m128i ctr1, __m128i ctr2, __m128i
     rnd4 = _mm_add_ps(rnd4, _mm_set_ps1(TWOPOW32_INV_FLOAT/2.0f));
 #endif
 }
+
+
+QUALIFIERS void philox_double8(__m128i ctr0, __m128i ctr1, __m128i ctr2, __m128i ctr3,
+                               uint32 key0, uint32 key1,
+                               __m128d & rnd1lo, __m128d & rnd1hi, __m128d & rnd2lo, __m128d & rnd2hi)
+{
+    __m128i key[2] = {_mm_set1_epi32(key0), _mm_set1_epi32(key1)};
+    __m128i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
+    _philox4x32round(ctr, key);                           // 1
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 2
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 3
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 4
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 5
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 6
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 7
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 8
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 9
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 10
+
+    rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]);
+    rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]);
+    rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]);
+    rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
+}
 #endif
 
 #ifdef __AVX__
@@ -242,11 +313,64 @@ QUALIFIERS __m256 _my256_cvtepu32_ps(const __m256i v)
 #endif
 }
 
+#if !defined(__AVX512VL__) && defined(__GNUC__) && __GNUC__ >= 5
+__attribute__((optimize("no-associative-math")))
+#endif
+QUALIFIERS __m256d _my256_cvtepu64_pd(const __m256i x)
+{
+#ifdef __AVX512VL__
+    return _mm256_cvtepu64_pd(x);
+#else
+    __m256i xH = _mm256_srli_epi64(x, 32);
+    xH = _mm256_or_si256(xH, _mm256_castpd_si256(_mm256_set1_pd(19342813113834066795298816.)));          // 2^84
+    __m256i xL = _mm256_blend_epi16(x, _mm256_castpd_si256(_mm256_set1_pd(0x0010000000000000)), 0xcc);   // 2^52
+    __m256d f = _mm256_sub_pd(_mm256_castsi256_pd(xH), _mm256_set1_pd(19342813118337666422669312.));     // 2^84 + 2^52
+    return _mm256_add_pd(f, _mm256_castsi256_pd(xL));
+#endif
+}
+
 QUALIFIERS __m256 _my256_set_ps1(const float v)
 {
     return _mm256_set_ps(v, v, v, v, v, v, v, v);
 }
 
+QUALIFIERS __m256d _my256_set_pd1(const double v)
+{
+    return _mm256_set_pd(v, v, v, v);
+}
+
+template<bool high>
+QUALIFIERS __m256d _uniform_double_hq(__m256i x, __m256i y)
+{
+    // convert 32 to 64 bit
+    if (high)
+    {
+        x = _mm256_unpackhi_epi32(x, _mm256_set1_epi32(0));
+        y = _mm256_unpackhi_epi32(y, _mm256_set1_epi32(0));
+    }
+    else
+    {
+        x = _mm256_unpacklo_epi32(x, _mm256_set1_epi32(0));
+        y = _mm256_unpacklo_epi32(y, _mm256_set1_epi32(0));
+    }
+
+    // calculate z = x ^ (y << (53 - 32))
+    __m256i z = _mm256_sll_epi64(y, _mm_set_epi64x(53 - 32, 53 - 32));
+    z = _mm256_xor_si256(x, z);
+
+    // convert uint64 to double
+    __m256d rs = _my256_cvtepu64_pd(z);
+    // calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
+#ifdef __FMA__
+    rs = _mm256_fmadd_pd(rs, _my256_set_pd1(TWOPOW53_INV_DOUBLE), _my256_set_pd1(TWOPOW53_INV_DOUBLE/2.0));
+#else
+    rs = _mm256_mul_pd(rs, _my256_set_pd1(TWOPOW53_INV_DOUBLE));
+    rs = _mm256_add_pd(rs, _my256_set_pd1(TWOPOW53_INV_DOUBLE/2.0));
+#endif
+
+    return rs;
+}
+
 
 QUALIFIERS void philox_float32(__m256i ctr0, __m256i ctr1, __m256i ctr2, __m256i ctr3,
                                uint32 key0, uint32 key1,
@@ -287,6 +411,30 @@ QUALIFIERS void philox_float32(__m256i ctr0, __m256i ctr1, __m256i ctr2, __m256i
     rnd4 = _mm256_add_ps(rnd4, _my256_set_ps1(TWOPOW32_INV_FLOAT/2.0f));
 #endif
 }
+
+
+QUALIFIERS void philox_double16(__m256i ctr0, __m256i ctr1, __m256i ctr2, __m256i ctr3,
+                                uint32 key0, uint32 key1,
+                                __m256d & rnd1lo, __m256d & rnd1hi, __m256d & rnd2lo, __m256d & rnd2hi)
+{
+    __m256i key[2] = {_mm256_set1_epi32(key0), _mm256_set1_epi32(key1)};
+    __m256i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
+    _philox4x32round(ctr, key);                           // 1
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 2
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 3
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 4
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 5
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 6
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 7
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 8
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 9
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 10
+
+    rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]);
+    rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]);
+    rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]);
+    rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
+}
 #endif
 
 #ifdef __AVX512F__
@@ -324,6 +472,38 @@ QUALIFIERS __m512 _my512_set_ps1(const float v)
     return _mm512_set_ps(v, v, v, v, v, v, v, v, v, v, v, v, v, v, v, v);
 }
 
+QUALIFIERS __m512d _my512_set_pd1(const double v)
+{
+    return _mm512_set_pd(v, v, v, v, v, v, v, v);
+}
+
+template<bool high>
+QUALIFIERS __m512d _uniform_double_hq(__m512i x, __m512i y)
+{
+    // convert 32 to 64 bit
+    if (high)
+    {
+        x = _mm512_unpackhi_epi32(x, _mm512_set1_epi32(0));
+        y = _mm512_unpackhi_epi32(y, _mm512_set1_epi32(0));
+    }
+    else
+    {
+        x = _mm512_unpacklo_epi32(x, _mm512_set1_epi32(0));
+        y = _mm512_unpacklo_epi32(y, _mm512_set1_epi32(0));
+    }
+
+    // calculate z = x ^ (y << (53 - 32))
+    __m512i z = _mm512_sll_epi64(y, _mm_set_epi64x(53 - 32, 53 - 32));
+    z = _mm512_xor_si512(x, z);
+
+    // convert uint64 to double
+    __m512d rs = _mm512_cvtepu64_pd(z);
+    // calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
+    rs = _mm512_fmadd_pd(rs, _my512_set_pd1(TWOPOW53_INV_DOUBLE), _my512_set_pd1(TWOPOW53_INV_DOUBLE/2.0));
+
+    return rs;
+}
+
 
 QUALIFIERS void philox_float64(__m512i ctr0, __m512i ctr1, __m512i ctr2, __m512i ctr3,
                                uint32 key0, uint32 key1,
@@ -353,5 +533,29 @@ QUALIFIERS void philox_float64(__m512i ctr0, __m512i ctr1, __m512i ctr2, __m512i
     rnd3 = _mm512_fmadd_ps(rnd3, _my512_set_ps1(TWOPOW32_INV_FLOAT), _my512_set_ps1(TWOPOW32_INV_FLOAT/2.0));
     rnd4 = _mm512_fmadd_ps(rnd4, _my512_set_ps1(TWOPOW32_INV_FLOAT), _my512_set_ps1(TWOPOW32_INV_FLOAT/2.0));
 }
+
+
+QUALIFIERS void philox_double32(__m512i ctr0, __m512i ctr1, __m512i ctr2, __m512i ctr3,
+                                uint32 key0, uint32 key1,
+                                __m512d & rnd1lo, __m512d & rnd1hi, __m512d & rnd2lo, __m512d & rnd2hi)
+{
+    __m512i key[2] = {_mm512_set1_epi32(key0), _mm512_set1_epi32(key1)};
+    __m512i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
+    _philox4x32round(ctr, key);                           // 1
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 2
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 3
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 4
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 5
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 6
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 7
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 8
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 9
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 10
+
+    rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]);
+    rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]);
+    rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]);
+    rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
+}
 #endif
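
A note for reviewers on the magic constants in _my_cvtepu64_pd / _my256_cvtepu64_pd: they implement the classic bit-pattern trick for converting unsigned 64-bit integers to double on targets without AVX-512VL's native _mm_cvtepu64_pd. The scalar sketch below models one lane; it assumes round-to-nearest mode and no reassociation of the floating-point operations, which is exactly what the no-associative-math attribute enforces for the vector version. The snippet is illustration only, not part of the patch.

    #include <cstdint>
    #include <cstring>

    // Scalar model of the uint64 -> double conversion in _my_cvtepu64_pd.
    double u64_to_double(uint64_t v)
    {
        // OR-ing the exponent pattern of 2^52 onto the low 32 bits yields
        // the double 2^52 + lo exactly (lo sits in the mantissa).
        uint64_t lo_bits = (v & 0xFFFFFFFFu) | 0x4330000000000000ull;  // 2^52
        // Same trick for the high 32 bits, scaled by 2^32 via the
        // exponent pattern of 2^84.
        uint64_t hi_bits = (v >> 32) | 0x4530000000000000ull;          // 2^84
        double lo, hi;
        std::memcpy(&lo, &lo_bits, sizeof lo);
        std::memcpy(&hi, &hi_bits, sizeof hi);
        // Subtract both offsets from hi (exact), then add lo; the final
        // addition is the only rounding step.
        return (hi - 19342813118337666422669312.0) + lo;               // 2^84 + 2^52
    }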
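Likewise, a scalar model of what each lane of _uniform_double_hq computes may make the new double path easier to check. It assumes TWOPOW53_INV_DOUBLE is defined elsewhere in this header as 2^-53, by analogy with TWOPOW32_INV_FLOAT in the float path; again illustration only.

    #include <cstdint>

    // Per-lane model: combine two 32-bit Philox outputs into 53 random
    // bits and map them to a double in the open interval (0, 1).
    double uniform_double_model(uint32_t x, uint32_t y)
    {
        const double TWOPOW53_INV_DOUBLE = 1.0 / 9007199254740992.0;  // 2^-53, assumed
        // z = x ^ (y << 21): x fills bits 0..31, y fills bits 21..52,
        // so z < 2^53 holds without any masking.
        uint64_t z = (uint64_t)x ^ ((uint64_t)y << (53 - 32));
        // Affine map to (0, 1); the half-step offset keeps 0.0 out of range.
        return (double)z * TWOPOW53_INV_DOUBLE + TWOPOW53_INV_DOUBLE / 2.0;
    }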
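Finally, a hypothetical caller for the new SSE entry point, showing how the eight doubles come back through the four output registers. The function name draw8, the counter assignment, and the keys are made up for illustration; only the philox_double8 signature is taken from the patch.

    #include <emmintrin.h>
    #include "philox_rand.h"  // provides philox_double8 and the uint32 typedef

    // Draws eight uniform doubles into out[0..7]. Any scheme that never
    // reuses a (counter, key) tuple yields an independent stream.
    void draw8(double * out, uint32 block, uint32 step)
    {
        __m128i ctr0 = _mm_set_epi32(3, 2, 1, 0);  // distinct per-lane counters
        __m128i ctr1 = _mm_set1_epi32(block);
        __m128i ctr2 = _mm_set1_epi32(step);
        __m128i ctr3 = _mm_set1_epi32(0);
        __m128d r1lo, r1hi, r2lo, r2hi;
        philox_double8(ctr0, ctr1, ctr2, ctr3, 0xC0FFEEu, 0xDECAFu,  // arbitrary keys
                       r1lo, r1hi, r2lo, r2hi);
        _mm_storeu_pd(out + 0, r1lo);
        _mm_storeu_pd(out + 2, r1hi);
        _mm_storeu_pd(out + 4, r2lo);
        _mm_storeu_pd(out + 6, r2hi);
    }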