diff --git a/pystencils/include/aesni_rand.h b/pystencils/include/aesni_rand.h index b8efcd4b53f4afbb0e88a17746824de4153295bf..4646b17c15d7d0f8fa28e2b01b160a632d080f4f 100644 --- a/pystencils/include/aesni_rand.h +++ b/pystencils/include/aesni_rand.h @@ -104,6 +104,39 @@ QUALIFIERS void aesni_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3, } +template<bool high> +QUALIFIERS __m128d _uniform_double_hq(__m128i x, __m128i y) +{ + // convert 32 to 64 bit + if (high) + { + x = _mm_unpackhi_epi32(x, _mm_setzero_si128()); + y = _mm_unpackhi_epi32(y, _mm_setzero_si128()); + } + else + { + x = _mm_unpacklo_epi32(x, _mm_setzero_si128()); + y = _mm_unpacklo_epi32(y, _mm_setzero_si128()); + } + + // calculate z = x ^ y << (53 - 32)) + __m128i z = _mm_sll_epi64(y, _mm_set1_epi64x(53 - 32)); + z = _mm_xor_si128(x, z); + + // convert uint64 to double + __m128d rs = _my_cvtepu64_pd(z); + // calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0) +#ifdef __FMA__ + rs = _mm_fmadd_pd(rs, _mm_set1_pd(TWOPOW53_INV_DOUBLE), _mm_set1_pd(TWOPOW53_INV_DOUBLE/2.0)); +#else + rs = _mm_mul_pd(rs, _mm_set1_pd(TWOPOW53_INV_DOUBLE)); + rs = _mm_add_pd(rs, _mm_set1_pd(TWOPOW53_INV_DOUBLE/2.0)); +#endif + + return rs; +} + + QUALIFIERS void aesni_float4(__m128i ctr0, __m128i ctr1, __m128i ctr2, __m128i ctr3, uint32 key0, uint32 key1, uint32 key2, uint32 key3, __m128 & rnd1, __m128 & rnd2, __m128 & rnd3, __m128 & rnd4) @@ -142,6 +175,27 @@ QUALIFIERS void aesni_float4(__m128i ctr0, __m128i ctr1, __m128i ctr2, __m128i c } +QUALIFIERS void aesni_double2(__m128i ctr0, __m128i ctr1, __m128i ctr2, __m128i ctr3, + uint32 key0, uint32 key1, uint32 key2, uint32 key3, + __m128d & rnd1lo, __m128d & rnd1hi, __m128d & rnd2lo, __m128d & rnd2hi) +{ + // pack input and call AES + __m128i k128 = _mm_set_epi32(key3, key2, key1, key0); + __m128i ctr[4] = {ctr0, ctr1, ctr2, ctr3}; + _MY_TRANSPOSE4_EPI32(ctr[0], ctr[1], ctr[2], ctr[3]); + for (int i = 0; i < 4; ++i) + { + ctr[i] = aesni1xm128i(ctr[i], k128); + } + _MY_TRANSPOSE4_EPI32(ctr[0], ctr[1], ctr[2], ctr[3]); + + rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]); + rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]); + rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]); + rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]); +} + + #ifdef __AVX2__ QUALIFIERS __m256i aesni1xm128i(const __m256i & in, const __m256i & k) { __m256i x = _mm256_xor_si256(k, in); @@ -164,6 +218,38 @@ QUALIFIERS __m256i aesni1xm128i(const __m256i & in, const __m256i & k) { return x; } +template<bool high> +QUALIFIERS __m256d _uniform_double_hq(__m256i x, __m256i y) +{ + // convert 32 to 64 bit + if (high) + { + x = _mm256_unpackhi_epi32(x, _mm256_setzero_si256()); + y = _mm256_unpackhi_epi32(y, _mm256_setzero_si256()); + } + else + { + x = _mm256_unpacklo_epi32(x, _mm256_setzero_si256()); + y = _mm256_unpacklo_epi32(y, _mm256_setzero_si256()); + } + + // calculate z = x ^ y << (53 - 32)) + __m256i z = _mm256_sll_epi64(y, _mm_set1_epi64x(53 - 32)); + z = _mm256_xor_si256(x, z); + + // convert uint64 to double + __m256d rs = _my256_cvtepu64_pd(z); + // calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0) +#ifdef __FMA__ + rs = _mm256_fmadd_pd(rs, _mm256_set1_pd(TWOPOW53_INV_DOUBLE), _mm256_set1_pd(TWOPOW53_INV_DOUBLE/2.0)); +#else + rs = _mm256_mul_pd(rs, _mm256_set1_pd(TWOPOW53_INV_DOUBLE)); + rs = _mm256_add_pd(rs, _mm256_set1_pd(TWOPOW53_INV_DOUBLE/2.0)); +#endif + + return rs; +} + QUALIFIERS void aesni_float4(__m256i ctr0, __m256i ctr1, __m256i ctr2, __m256i ctr3, uint32 key0, uint32 key1, uint32 key2, uint32 key3, @@ -222,6 +308,48 @@ QUALIFIERS void aesni_float4(__m256i ctr0, __m256i ctr1, __m256i ctr2, __m256i c rnd4 = _mm256_add_ps(rnd4, _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f)); #endif } + + +QUALIFIERS void aesni_double2(__m256i ctr0, __m256i ctr1, __m256i ctr2, __m256i ctr3, + uint32 key0, uint32 key1, uint32 key2, uint32 key3, + __m256d & rnd1lo, __m256d & rnd1hi, __m256d & rnd2lo, __m256d & rnd2hi) +{ + // pack input and call AES + __m256i k256 = _mm256_set_epi32(key3, key2, key1, key0, key3, key2, key1, key0); + __m256i ctr[4] = {ctr0, ctr1, ctr2, ctr3}; + __m128i a[4], b[4]; + for (int i = 0; i < 4; ++i) + { + a[i] = _mm256_extractf128_si256(ctr[i], 0); + b[i] = _mm256_extractf128_si256(ctr[i], 1); + } + _MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]); + _MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]); + for (int i = 0; i < 4; ++i) + { + ctr[i] = _my256_set_m128i(b[i], a[i]); + } + for (int i = 0; i < 4; ++i) + { + ctr[i] = aesni1xm128i(ctr[i], k256); + } + for (int i = 0; i < 4; ++i) + { + a[i] = _mm256_extractf128_si256(ctr[i], 0); + b[i] = _mm256_extractf128_si256(ctr[i], 1); + } + _MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]); + _MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]); + for (int i = 0; i < 4; ++i) + { + ctr[i] = _my256_set_m128i(b[i], a[i]); + } + + rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]); + rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]); + rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]); + rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]); +} #endif @@ -249,6 +377,33 @@ QUALIFIERS __m512i aesni1xm128i(const __m512i & in, const __m512i & k) { return x; } +template<bool high> +QUALIFIERS __m512d _uniform_double_hq(__m512i x, __m512i y) +{ + // convert 32 to 64 bit + if (high) + { + x = _mm512_unpackhi_epi32(x, _mm512_setzero_si512()); + y = _mm512_unpackhi_epi32(y, _mm512_setzero_si512()); + } + else + { + x = _mm512_unpacklo_epi32(x, _mm512_setzero_si512()); + y = _mm512_unpacklo_epi32(y, _mm512_setzero_si512()); + } + + // calculate z = x ^ y << (53 - 32)) + __m512i z = _mm512_sll_epi64(y, _mm_set1_epi64x(53 - 32)); + z = _mm512_xor_si512(x, z); + + // convert uint64 to double + __m512d rs = _mm512_cvtepu64_pd(z); + // calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0) + rs = _mm512_fmadd_pd(rs, _mm512_set1_pd(TWOPOW53_INV_DOUBLE), _mm512_set1_pd(TWOPOW53_INV_DOUBLE/2.0)); + + return rs; +} + QUALIFIERS void aesni_float4(__m512i ctr0, __m512i ctr1, __m512i ctr2, __m512i ctr3, uint32 key0, uint32 key1, uint32 key2, uint32 key3, @@ -305,5 +460,56 @@ QUALIFIERS void aesni_float4(__m512i ctr0, __m512i ctr1, __m512i ctr2, __m512i c rnd3 = _mm512_fmadd_ps(rnd3, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0)); rnd4 = _mm512_fmadd_ps(rnd4, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0)); } + + +QUALIFIERS void aesni_double2(__m512i ctr0, __m512i ctr1, __m512i ctr2, __m512i ctr3, + uint32 key0, uint32 key1, uint32 key2, uint32 key3, + __m512d & rnd1lo, __m512d & rnd1hi, __m512d & rnd2lo, __m512d & rnd2hi) +{ + // pack input and call AES + __m512i k512 = _mm512_set_epi32(key3, key2, key1, key0, key3, key2, key1, key0, + key3, key2, key1, key0, key3, key2, key1, key0); + __m512i ctr[4] = {ctr0, ctr1, ctr2, ctr3}; + __m128i a[4], b[4], c[4], d[4]; + for (int i = 0; i < 4; ++i) + { + a[i] = _mm512_extracti32x4_epi32(ctr[i], 0); + b[i] = _mm512_extracti32x4_epi32(ctr[i], 1); + c[i] = _mm512_extracti32x4_epi32(ctr[i], 2); + d[i] = _mm512_extracti32x4_epi32(ctr[i], 3); + } + _MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]); + _MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]); + _MY_TRANSPOSE4_EPI32(c[0], c[1], c[2], c[3]); + _MY_TRANSPOSE4_EPI32(d[0], d[1], d[2], d[3]); + for (int i = 0; i < 4; ++i) + { + ctr[i] = _my512_set_m128i(d[i], c[i], b[i], a[i]); + } + for (int i = 0; i < 4; ++i) + { + ctr[i] = aesni1xm128i(ctr[i], k512); + } + for (int i = 0; i < 4; ++i) + { + a[i] = _mm512_extracti32x4_epi32(ctr[i], 0); + b[i] = _mm512_extracti32x4_epi32(ctr[i], 1); + c[i] = _mm512_extracti32x4_epi32(ctr[i], 2); + d[i] = _mm512_extracti32x4_epi32(ctr[i], 3); + } + _MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]); + _MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]); + _MY_TRANSPOSE4_EPI32(c[0], c[1], c[2], c[3]); + _MY_TRANSPOSE4_EPI32(d[0], d[1], d[2], d[3]); + for (int i = 0; i < 4; ++i) + { + ctr[i] = _my512_set_m128i(d[i], c[i], b[i], a[i]); + } + + rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]); + rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]); + rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]); + rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]); +} #endif