diff --git a/pystencils/include/aesni_rand.h b/pystencils/include/aesni_rand.h
index c8b4089f86fb08c7740f72f17d288b707cff6a1c..b8efcd4b53f4afbb0e88a17746824de4153295bf 100644
--- a/pystencils/include/aesni_rand.h
+++ b/pystencils/include/aesni_rand.h
@@ -4,7 +4,7 @@
 
 #include <emmintrin.h> // SSE2
 #include <wmmintrin.h> // AES
-#ifdef __AVX512VL__
+#ifdef __AVX2__
 #include <immintrin.h> // AVX*
 #else
 #include <smmintrin.h> // SSE4
@@ -103,3 +103,220 @@ QUALIFIERS void aesni_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
     rnd4 = r[3];
 }
 
+
+// Vector overload of aesni_float4: counters arrive in 128-bit integer registers and four
+// __m128 vectors of floats are written to rnd1..rnd4. _MY_TRANSPOSE4_EPI32 (defined
+// elsewhere; presumably a 4x4 transpose of 32-bit lanes - confirm) is applied before and
+// after the AES step.
+QUALIFIERS void aesni_float4(__m128i ctr0, __m128i ctr1, __m128i ctr2, __m128i ctr3,
+                             uint32 key0, uint32 key1, uint32 key2, uint32 key3,
+                             __m128 & rnd1, __m128 & rnd2, __m128 & rnd3, __m128 & rnd4)
+{
+    // pack input and call AES
+    __m128i k128 = _mm_set_epi32(key3, key2, key1, key0);
+    __m128i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
+    _MY_TRANSPOSE4_EPI32(ctr[0], ctr[1], ctr[2], ctr[3]);
+    for (int i = 0; i < 4; ++i)
+    {
+        ctr[i] = aesni1xm128i(ctr[i], k128);
+    }
+    _MY_TRANSPOSE4_EPI32(ctr[0], ctr[1], ctr[2], ctr[3]);
+
+    // convert uint32 to float
+    rnd1 = _my_cvtepu32_ps(ctr[0]);
+    rnd2 = _my_cvtepu32_ps(ctr[1]);
+    rnd3 = _my_cvtepu32_ps(ctr[2]);
+    rnd4 = _my_cvtepu32_ps(ctr[3]);
+    // calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
+#ifdef __FMA__
+    rnd1 = _mm_fmadd_ps(rnd1, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
+    rnd2 = _mm_fmadd_ps(rnd2, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
+    rnd3 = _mm_fmadd_ps(rnd3, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
+    rnd4 = _mm_fmadd_ps(rnd4, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
+#else
+    rnd1 = _mm_mul_ps(rnd1, _mm_set1_ps(TWOPOW32_INV_FLOAT));
+    rnd1 = _mm_add_ps(rnd1, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
+    rnd2 = _mm_mul_ps(rnd2, _mm_set1_ps(TWOPOW32_INV_FLOAT));
+    rnd2 = _mm_add_ps(rnd2, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
+    rnd3 = _mm_mul_ps(rnd3, _mm_set1_ps(TWOPOW32_INV_FLOAT));
+    rnd3 = _mm_add_ps(rnd3, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
+    rnd4 = _mm_mul_ps(rnd4, _mm_set1_ps(TWOPOW32_INV_FLOAT));
+    rnd4 = _mm_add_ps(rnd4, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
+#endif
+}
+
+
+#ifdef __AVX2__
+// 256-bit overload of aesni1xm128i: handles the two 128-bit halves of in with the matching halves of k.
+QUALIFIERS __m256i aesni1xm128i(const __m256i & in, const __m256i & k) {
+#if defined(__VAES__) && defined(__AVX512VL__)
+    __m256i x = _mm256_xor_si256(k, in);
+    x = _mm256_aesenc_epi128(x, k);     // 1
+    x = _mm256_aesenc_epi128(x, k);     // 2
+    x = _mm256_aesenc_epi128(x, k);     // 3
+    x = _mm256_aesenc_epi128(x, k);     // 4
+    x = _mm256_aesenc_epi128(x, k);     // 5
+    x = _mm256_aesenc_epi128(x, k);     // 6
+    x = _mm256_aesenc_epi128(x, k);     // 7
+    x = _mm256_aesenc_epi128(x, k);     // 8
+    x = _mm256_aesenc_epi128(x, k);     // 9
+    x = _mm256_aesenclast_epi128(x, k); // 10
+    return x;
+#else
+    // fallback: run the 128-bit overload on each half; no 256-bit xor is needed on this path
+    __m128i a = aesni1xm128i(_mm256_extractf128_si256(in, 0), _mm256_extractf128_si256(k, 0));
+    __m128i b = aesni1xm128i(_mm256_extractf128_si256(in, 1), _mm256_extractf128_si256(k, 1));
+    return _my256_set_m128i(b, a);
+#endif
+}
+
+
+// 256-bit overload of aesni_float4. _MY_TRANSPOSE4_EPI32 is applied separately to the lower
+// (a) and upper (b) 128-bit halves of the counter registers, before and after the AES step.
+QUALIFIERS void aesni_float4(__m256i ctr0, __m256i ctr1, __m256i ctr2, __m256i ctr3,
+                             uint32 key0, uint32 key1, uint32 key2, uint32 key3,
+                             __m256 & rnd1, __m256 & rnd2, __m256 & rnd3, __m256 & rnd4)
+{
+    // pack input and call AES
+    __m256i k256 = _mm256_set_epi32(key3, key2, key1, key0, key3, key2, key1, key0);
+    __m256i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
+    __m128i a[4], b[4];
+    for (int i = 0; i < 4; ++i)
+    {
+        a[i] = _mm256_extractf128_si256(ctr[i], 0);
+        b[i] = _mm256_extractf128_si256(ctr[i], 1);
+    }
+    _MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]);
+    _MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]);
+    for (int i = 0; i < 4; ++i)
+    {
+        ctr[i] = _my256_set_m128i(b[i], a[i]);
+    }
+    for (int i = 0; i < 4; ++i)
+    {
+        ctr[i] = aesni1xm128i(ctr[i], k256);
+    }
+    for (int i = 0; i < 4; ++i)
+    {
+        a[i] = _mm256_extractf128_si256(ctr[i], 0);
+        b[i] = _mm256_extractf128_si256(ctr[i], 1);
+    }
+    _MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]);
+    _MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]);
+    for (int i = 0; i < 4; ++i)
+    {
+        ctr[i] = _my256_set_m128i(b[i], a[i]);
+    }
+
+    // convert uint32 to float
+    rnd1 = _my256_cvtepu32_ps(ctr[0]);
+    rnd2 = _my256_cvtepu32_ps(ctr[1]);
+    rnd3 = _my256_cvtepu32_ps(ctr[2]);
+    rnd4 = _my256_cvtepu32_ps(ctr[3]);
+    // calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
+#ifdef __FMA__
+    rnd1 = _mm256_fmadd_ps(rnd1, _mm256_set1_ps(TWOPOW32_INV_FLOAT), _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
+    rnd2 = _mm256_fmadd_ps(rnd2, _mm256_set1_ps(TWOPOW32_INV_FLOAT), _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
+    rnd3 = _mm256_fmadd_ps(rnd3, _mm256_set1_ps(TWOPOW32_INV_FLOAT), _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
+    rnd4 = _mm256_fmadd_ps(rnd4, _mm256_set1_ps(TWOPOW32_INV_FLOAT), _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
+#else
+    rnd1 = _mm256_mul_ps(rnd1, _mm256_set1_ps(TWOPOW32_INV_FLOAT));
+    rnd1 = _mm256_add_ps(rnd1, _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
+    rnd2 = _mm256_mul_ps(rnd2, _mm256_set1_ps(TWOPOW32_INV_FLOAT));
+    rnd2 = _mm256_add_ps(rnd2, _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
+    rnd3 = _mm256_mul_ps(rnd3, _mm256_set1_ps(TWOPOW32_INV_FLOAT));
+    rnd3 = _mm256_add_ps(rnd3, _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
+    rnd4 = _mm256_mul_ps(rnd4, _mm256_set1_ps(TWOPOW32_INV_FLOAT));
+    rnd4 = _mm256_add_ps(rnd4, _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
+#endif
+}
+#endif
+
+
+#ifdef __AVX512F__
+// 512-bit overload of aesni1xm128i: handles the four 128-bit lanes of in with the matching
+// lanes of k.
+QUALIFIERS __m512i aesni1xm128i(const __m512i & in, const __m512i & k) {
+#ifdef __VAES__
+    __m512i x = _mm512_xor_si512(k, in);
+    x = _mm512_aesenc_epi128(x, k);     // 1
+    x = _mm512_aesenc_epi128(x, k);     // 2
+    x = _mm512_aesenc_epi128(x, k);     // 3
+    x = _mm512_aesenc_epi128(x, k);     // 4
+    x = _mm512_aesenc_epi128(x, k);     // 5
+    x = _mm512_aesenc_epi128(x, k);     // 6
+    x = _mm512_aesenc_epi128(x, k);     // 7
+    x = _mm512_aesenc_epi128(x, k);     // 8
+    x = _mm512_aesenc_epi128(x, k);     // 9
+    x = _mm512_aesenclast_epi128(x, k); // 10
+    return x;
+#else
+    // fallback: run the 128-bit overload on each lane; no 512-bit xor is needed on this path
+    __m128i a = aesni1xm128i(_mm512_extracti32x4_epi32(in, 0), _mm512_extracti32x4_epi32(k, 0));
+    __m128i b = aesni1xm128i(_mm512_extracti32x4_epi32(in, 1), _mm512_extracti32x4_epi32(k, 1));
+    __m128i c = aesni1xm128i(_mm512_extracti32x4_epi32(in, 2), _mm512_extracti32x4_epi32(k, 2));
+    __m128i d = aesni1xm128i(_mm512_extracti32x4_epi32(in, 3), _mm512_extracti32x4_epi32(k, 3));
+    return _my512_set_m128i(d, c, b, a);
+#endif
+}
+
+
+// 512-bit overload of aesni_float4. _MY_TRANSPOSE4_EPI32 is applied separately to each of the
+// four 128-bit lanes (a, b, c, d) of the counter registers, before and after the AES step.
+QUALIFIERS void aesni_float4(__m512i ctr0, __m512i ctr1, __m512i ctr2, __m512i ctr3,
+                             uint32 key0, uint32 key1, uint32 key2, uint32 key3,
+                             __m512 & rnd1, __m512 & rnd2, __m512 & rnd3, __m512 & rnd4)
+{
+    // pack input and call AES
+    __m512i k512 = _mm512_set_epi32(key3, key2, key1, key0, key3, key2, key1, key0,
+                                    key3, key2, key1, key0, key3, key2, key1, key0);
+    __m512i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
+    __m128i a[4], b[4], c[4], d[4];
+    for (int i = 0; i < 4; ++i)
+    {
+        a[i] = _mm512_extracti32x4_epi32(ctr[i], 0);
+        b[i] = _mm512_extracti32x4_epi32(ctr[i], 1);
+        c[i] = _mm512_extracti32x4_epi32(ctr[i], 2);
+        d[i] = _mm512_extracti32x4_epi32(ctr[i], 3);
+    }
+    _MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]);
+    _MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]);
+    _MY_TRANSPOSE4_EPI32(c[0], c[1], c[2], c[3]);
+    _MY_TRANSPOSE4_EPI32(d[0], d[1], d[2], d[3]);
+    for (int i = 0; i < 4; ++i)
+    {
+        ctr[i] = _my512_set_m128i(d[i], c[i], b[i], a[i]);
+    }
+    for (int i = 0; i < 4; ++i)
+    {
+        ctr[i] = aesni1xm128i(ctr[i], k512);
+    }
+    for (int i = 0; i < 4; ++i)
+    {
+        a[i] = _mm512_extracti32x4_epi32(ctr[i], 0);
+        b[i] = _mm512_extracti32x4_epi32(ctr[i], 1);
+        c[i] = _mm512_extracti32x4_epi32(ctr[i], 2);
+        d[i] = _mm512_extracti32x4_epi32(ctr[i], 3);
+    }
+    _MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]);
+    _MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]);
+    _MY_TRANSPOSE4_EPI32(c[0], c[1], c[2], c[3]);
+    _MY_TRANSPOSE4_EPI32(d[0], d[1], d[2], d[3]);
+    for (int i = 0; i < 4; ++i)
+    {
+        ctr[i] = _my512_set_m128i(d[i], c[i], b[i], a[i]);
+    }
+
+    // convert uint32 to float
+    rnd1 = _mm512_cvtepu32_ps(ctr[0]);
+    rnd2 = _mm512_cvtepu32_ps(ctr[1]);
+    rnd3 = _mm512_cvtepu32_ps(ctr[2]);
+    rnd4 = _mm512_cvtepu32_ps(ctr[3]);
+    // calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
+    rnd1 = _mm512_fmadd_ps(rnd1, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
+    rnd2 = _mm512_fmadd_ps(rnd2, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
+    rnd3 = _mm512_fmadd_ps(rnd3, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
+    rnd4 = _mm512_fmadd_ps(rnd4, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
+}
+#endif
+
diff --git a/pystencils/include/myintrin.h b/pystencils/include/myintrin.h
index 38304cb5f4c9cead701e4e902aef05862754d237..5b7ee436ef1955988f2f8cb6ab1712738fa3f620 100644
--- a/pystencils/include/myintrin.h
+++ b/pystencils/include/myintrin.h
@@ -46,6 +46,14 @@ QUALIFIERS __m128d _my_cvtepu64_pd(const __m128i x)
 }
 #endif
 
+#ifdef __AVX__
+// builds a 256-bit integer vector from two 128-bit halves: lo is the lower half, hi the upper half
+QUALIFIERS __m256i _my256_set_m128i(__m128i hi, __m128i lo)
+{
+    return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), hi, 1);
+}
+#endif
+
 #ifdef __AVX2__
 QUALIFIERS __m256 _my256_cvtepu32_ps(const __m256i v)
 {
@@ -77,3 +85,11 @@ QUALIFIERS __m256d _my256_cvtepu64_pd(const __m256i x)
 }
 #endif
 
+#ifdef __AVX512F__
+// builds a 512-bit integer vector from four 128-bit lanes: a is lane 0 (lowest), d is lane 3 (highest)
+QUALIFIERS __m512i _my512_set_m128i(__m128i d, __m128i c, __m128i b, __m128i a)
+{
+    return _mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_castsi128_si512(a), b, 1), c, 2), d, 3);
+}
+#endif
+