Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found
Select Git revision

Target

Select target project
  • anirudh.jonnalagadda/pystencils
  • hyteg/pystencils
  • jbadwaik/pystencils
  • jngrad/pystencils
  • itischler/pystencils
  • ob28imeq/pystencils
  • hoenig/pystencils
  • Bindgen/pystencils
  • hammer/pystencils
  • da15siwa/pystencils
  • holzer/pystencils
  • alexander.reinauer/pystencils
  • ec93ujoh/pystencils
  • Harke/pystencils
  • seitz/pystencils
  • pycodegen/pystencils
16 results
Select Git revision
Show changes
Commits on Source (11)
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#include <emmintrin.h> // SSE2 #include <emmintrin.h> // SSE2
#include <wmmintrin.h> // AES #include <wmmintrin.h> // AES
#ifdef __AVX512VL__ #ifdef __AVX2__
#include <immintrin.h> // AVX* #include <immintrin.h> // AVX*
#else #else
#include <smmintrin.h> // SSE4 #include <smmintrin.h> // SSE4
...@@ -18,6 +18,8 @@ ...@@ -18,6 +18,8 @@
#define TWOPOW53_INV_DOUBLE (1.1102230246251565e-16) #define TWOPOW53_INV_DOUBLE (1.1102230246251565e-16)
#define TWOPOW32_INV_FLOAT (2.3283064e-10f) #define TWOPOW32_INV_FLOAT (2.3283064e-10f)
#include "myintrin.h"
typedef std::uint32_t uint32; typedef std::uint32_t uint32;
typedef std::uint64_t uint64; typedef std::uint64_t uint64;
...@@ -36,35 +38,6 @@ QUALIFIERS __m128i aesni1xm128i(const __m128i & in, const __m128i & k) { ...@@ -36,35 +38,6 @@ QUALIFIERS __m128i aesni1xm128i(const __m128i & in, const __m128i & k) {
return x; return x;
} }
QUALIFIERS __m128 _my_cvtepu32_ps(const __m128i v)
{
#ifdef __AVX512VL__
return _mm_cvtepu32_ps(v);
#else
__m128i v2 = _mm_srli_epi32(v, 1);
__m128i v1 = _mm_and_si128(v, _mm_set1_epi32(1));
__m128 v2f = _mm_cvtepi32_ps(v2);
__m128 v1f = _mm_cvtepi32_ps(v1);
return _mm_add_ps(_mm_add_ps(v2f, v2f), v1f);
#endif
}
#if !defined(__AVX512VL__) && defined(__GNUC__) && __GNUC__ >= 5
__attribute__((optimize("no-associative-math")))
#endif
QUALIFIERS __m128d _my_cvtepu64_pd(const __m128i x)
{
#ifdef __AVX512VL__
return _mm_cvtepu64_pd(x);
#else
__m128i xH = _mm_srli_epi64(x, 32);
xH = _mm_or_si128(xH, _mm_castpd_si128(_mm_set1_pd(19342813113834066795298816.))); // 2^84
__m128i xL = _mm_blend_epi16(x, _mm_castpd_si128(_mm_set1_pd(0x0010000000000000)), 0xcc); // 2^52
__m128d f = _mm_sub_pd(_mm_castsi128_pd(xH), _mm_set1_pd(19342813118337666422669312.)); // 2^84 + 2^52
return _mm_add_pd(f, _mm_castsi128_pd(xL));
#endif
}
QUALIFIERS void aesni_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3, QUALIFIERS void aesni_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3, uint32 key0, uint32 key1, uint32 key2, uint32 key3,
...@@ -130,3 +103,485 @@ QUALIFIERS void aesni_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3, ...@@ -130,3 +103,485 @@ QUALIFIERS void aesni_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
rnd4 = r[3]; rnd4 = r[3];
} }
template<bool high>
QUALIFIERS __m128d _uniform_double_hq(__m128i x, __m128i y)
{
// convert 32 to 64 bit
if (high)
{
x = _mm_unpackhi_epi32(x, _mm_setzero_si128());
y = _mm_unpackhi_epi32(y, _mm_setzero_si128());
}
else
{
x = _mm_unpacklo_epi32(x, _mm_setzero_si128());
y = _mm_unpacklo_epi32(y, _mm_setzero_si128());
}
// calculate z = x ^ y << (53 - 32))
__m128i z = _mm_sll_epi64(y, _mm_set1_epi64x(53 - 32));
z = _mm_xor_si128(x, z);
// convert uint64 to double
__m128d rs = _my_cvtepu64_pd(z);
// calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
#ifdef __FMA__
rs = _mm_fmadd_pd(rs, _mm_set1_pd(TWOPOW53_INV_DOUBLE), _mm_set1_pd(TWOPOW53_INV_DOUBLE/2.0));
#else
rs = _mm_mul_pd(rs, _mm_set1_pd(TWOPOW53_INV_DOUBLE));
rs = _mm_add_pd(rs, _mm_set1_pd(TWOPOW53_INV_DOUBLE/2.0));
#endif
return rs;
}
QUALIFIERS void aesni_float4(__m128i ctr0, __m128i ctr1, __m128i ctr2, __m128i ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
__m128 & rnd1, __m128 & rnd2, __m128 & rnd3, __m128 & rnd4)
{
// pack input and call AES
__m128i k128 = _mm_set_epi32(key3, key2, key1, key0);
__m128i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_MY_TRANSPOSE4_EPI32(ctr[0], ctr[1], ctr[2], ctr[3]);
for (int i = 0; i < 4; ++i)
{
ctr[i] = aesni1xm128i(ctr[i], k128);
}
_MY_TRANSPOSE4_EPI32(ctr[0], ctr[1], ctr[2], ctr[3]);
// convert uint32 to float
rnd1 = _my_cvtepu32_ps(ctr[0]);
rnd2 = _my_cvtepu32_ps(ctr[1]);
rnd3 = _my_cvtepu32_ps(ctr[2]);
rnd4 = _my_cvtepu32_ps(ctr[3]);
// calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
#ifdef __FMA__
rnd1 = _mm_fmadd_ps(rnd1, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd2 = _mm_fmadd_ps(rnd2, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd3 = _mm_fmadd_ps(rnd3, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd4 = _mm_fmadd_ps(rnd4, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0));
#else
rnd1 = _mm_mul_ps(rnd1, _mm_set1_ps(TWOPOW32_INV_FLOAT));
rnd1 = _mm_add_ps(rnd1, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
rnd2 = _mm_mul_ps(rnd2, _mm_set1_ps(TWOPOW32_INV_FLOAT));
rnd2 = _mm_add_ps(rnd2, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
rnd3 = _mm_mul_ps(rnd3, _mm_set1_ps(TWOPOW32_INV_FLOAT));
rnd3 = _mm_add_ps(rnd3, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
rnd4 = _mm_mul_ps(rnd4, _mm_set1_ps(TWOPOW32_INV_FLOAT));
rnd4 = _mm_add_ps(rnd4, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
#endif
}
QUALIFIERS void aesni_double2(__m128i ctr0, __m128i ctr1, __m128i ctr2, __m128i ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
__m128d & rnd1lo, __m128d & rnd1hi, __m128d & rnd2lo, __m128d & rnd2hi)
{
// pack input and call AES
__m128i k128 = _mm_set_epi32(key3, key2, key1, key0);
__m128i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_MY_TRANSPOSE4_EPI32(ctr[0], ctr[1], ctr[2], ctr[3]);
for (int i = 0; i < 4; ++i)
{
ctr[i] = aesni1xm128i(ctr[i], k128);
}
_MY_TRANSPOSE4_EPI32(ctr[0], ctr[1], ctr[2], ctr[3]);
rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]);
rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]);
rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]);
rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
}
#ifdef __AVX2__
QUALIFIERS __m256i aesni1xm128i(const __m256i & in, const __m256i & k) {
__m256i x = _mm256_xor_si256(k, in);
#if defined(__VAES__) && defined(__AVX512VL__)
x = _mm256_aesenc_epi128(x, k); // 1
x = _mm256_aesenc_epi128(x, k); // 2
x = _mm256_aesenc_epi128(x, k); // 3
x = _mm256_aesenc_epi128(x, k); // 4
x = _mm256_aesenc_epi128(x, k); // 5
x = _mm256_aesenc_epi128(x, k); // 6
x = _mm256_aesenc_epi128(x, k); // 7
x = _mm256_aesenc_epi128(x, k); // 8
x = _mm256_aesenc_epi128(x, k); // 9
x = _mm256_aesenclast_epi128(x, k); // 10
#else
__m128i a = aesni1xm128i(_mm256_extractf128_si256(in, 0), _mm256_extractf128_si256(k, 0));
__m128i b = aesni1xm128i(_mm256_extractf128_si256(in, 1), _mm256_extractf128_si256(k, 1));
x = _my256_set_m128i(b, a);
#endif
return x;
}
QUALIFIERS void aesni_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
__m128 & rnd1, __m128 & rnd2, __m128 & rnd3, __m128 & rnd4)
{
__m128i ctr0v = _mm_add_epi32(_mm_set1_epi32(ctr0), _mm_set_epi32(3,2,1,0));
__m128i ctr1v = _mm_set1_epi32(ctr1);
__m128i ctr2v = _mm_set1_epi32(ctr2);
__m128i ctr3v = _mm_set1_epi32(ctr3);
aesni_float4(ctr0v, ctr1v, ctr2v, ctr3v, key0, key1, key2, key3, rnd1, rnd2, rnd3, rnd4);
}
QUALIFIERS void philox_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
__m128d & rnd1lo, __m128d & rnd1hi, __m128d & rnd2lo, __m128d & rnd2hi)
{
__m128i ctr0v = _mm_add_epi32(_mm_set1_epi32(ctr0), _mm_set_epi32(3,2,1,0));
__m128i ctr1v = _mm_set1_epi32(ctr1);
__m128i ctr2v = _mm_set1_epi32(ctr2);
__m128i ctr3v = _mm_set1_epi32(ctr3);
aesni_double2(ctr0v, ctr1v, ctr2v, ctr3v, key0, key1, key2, key3, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
}
template<bool high>
QUALIFIERS __m256d _uniform_double_hq(__m256i x, __m256i y)
{
// convert 32 to 64 bit
if (high)
{
x = _mm256_unpackhi_epi32(x, _mm256_setzero_si256());
y = _mm256_unpackhi_epi32(y, _mm256_setzero_si256());
}
else
{
x = _mm256_unpacklo_epi32(x, _mm256_setzero_si256());
y = _mm256_unpacklo_epi32(y, _mm256_setzero_si256());
}
// calculate z = x ^ y << (53 - 32))
__m256i z = _mm256_sll_epi64(y, _mm_set1_epi64x(53 - 32));
z = _mm256_xor_si256(x, z);
// convert uint64 to double
__m256d rs = _my256_cvtepu64_pd(z);
// calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
#ifdef __FMA__
rs = _mm256_fmadd_pd(rs, _mm256_set1_pd(TWOPOW53_INV_DOUBLE), _mm256_set1_pd(TWOPOW53_INV_DOUBLE/2.0));
#else
rs = _mm256_mul_pd(rs, _mm256_set1_pd(TWOPOW53_INV_DOUBLE));
rs = _mm256_add_pd(rs, _mm256_set1_pd(TWOPOW53_INV_DOUBLE/2.0));
#endif
return rs;
}
QUALIFIERS void aesni_float4(__m256i ctr0, __m256i ctr1, __m256i ctr2, __m256i ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
__m256 & rnd1, __m256 & rnd2, __m256 & rnd3, __m256 & rnd4)
{
// pack input and call AES
__m256i k256 = _mm256_set_epi32(key3, key2, key1, key0, key3, key2, key1, key0);
__m256i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
__m128i a[4], b[4];
for (int i = 0; i < 4; ++i)
{
a[i] = _mm256_extractf128_si256(ctr[i], 0);
b[i] = _mm256_extractf128_si256(ctr[i], 1);
}
_MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]);
_MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]);
for (int i = 0; i < 4; ++i)
{
ctr[i] = _my256_set_m128i(b[i], a[i]);
}
for (int i = 0; i < 4; ++i)
{
ctr[i] = aesni1xm128i(ctr[i], k256);
}
for (int i = 0; i < 4; ++i)
{
a[i] = _mm256_extractf128_si256(ctr[i], 0);
b[i] = _mm256_extractf128_si256(ctr[i], 1);
}
_MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]);
_MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]);
for (int i = 0; i < 4; ++i)
{
ctr[i] = _my256_set_m128i(b[i], a[i]);
}
// convert uint32 to float
rnd1 = _my256_cvtepu32_ps(ctr[0]);
rnd2 = _my256_cvtepu32_ps(ctr[1]);
rnd3 = _my256_cvtepu32_ps(ctr[2]);
rnd4 = _my256_cvtepu32_ps(ctr[3]);
// calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
#ifdef __FMA__
rnd1 = _mm256_fmadd_ps(rnd1, _mm256_set1_ps(TWOPOW32_INV_FLOAT), _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd2 = _mm256_fmadd_ps(rnd2, _mm256_set1_ps(TWOPOW32_INV_FLOAT), _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd3 = _mm256_fmadd_ps(rnd3, _mm256_set1_ps(TWOPOW32_INV_FLOAT), _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd4 = _mm256_fmadd_ps(rnd4, _mm256_set1_ps(TWOPOW32_INV_FLOAT), _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0));
#else
rnd1 = _mm256_mul_ps(rnd1, _mm256_set1_ps(TWOPOW32_INV_FLOAT));
rnd1 = _mm256_add_ps(rnd1, _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
rnd2 = _mm256_mul_ps(rnd2, _mm256_set1_ps(TWOPOW32_INV_FLOAT));
rnd2 = _mm256_add_ps(rnd2, _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
rnd3 = _mm256_mul_ps(rnd3, _mm256_set1_ps(TWOPOW32_INV_FLOAT));
rnd3 = _mm256_add_ps(rnd3, _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
rnd4 = _mm256_mul_ps(rnd4, _mm256_set1_ps(TWOPOW32_INV_FLOAT));
rnd4 = _mm256_add_ps(rnd4, _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
#endif
}
QUALIFIERS void aesni_double2(__m256i ctr0, __m256i ctr1, __m256i ctr2, __m256i ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
__m256d & rnd1lo, __m256d & rnd1hi, __m256d & rnd2lo, __m256d & rnd2hi)
{
// pack input and call AES
__m256i k256 = _mm256_set_epi32(key3, key2, key1, key0, key3, key2, key1, key0);
__m256i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
__m128i a[4], b[4];
for (int i = 0; i < 4; ++i)
{
a[i] = _mm256_extractf128_si256(ctr[i], 0);
b[i] = _mm256_extractf128_si256(ctr[i], 1);
}
_MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]);
_MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]);
for (int i = 0; i < 4; ++i)
{
ctr[i] = _my256_set_m128i(b[i], a[i]);
}
for (int i = 0; i < 4; ++i)
{
ctr[i] = aesni1xm128i(ctr[i], k256);
}
for (int i = 0; i < 4; ++i)
{
a[i] = _mm256_extractf128_si256(ctr[i], 0);
b[i] = _mm256_extractf128_si256(ctr[i], 1);
}
_MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]);
_MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]);
for (int i = 0; i < 4; ++i)
{
ctr[i] = _my256_set_m128i(b[i], a[i]);
}
rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]);
rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]);
rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]);
rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
}
QUALIFIERS void aesni_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
__m256 & rnd1, __m256 & rnd2, __m256 & rnd3, __m256 & rnd4)
{
__m256i ctr0v = _mm256_add_epi32(_mm256_set1_epi32(ctr0), _mm256_set_epi32(7,6,5,4,3,2,1,0));
__m256i ctr1v = _mm256_set1_epi32(ctr1);
__m256i ctr2v = _mm256_set1_epi32(ctr2);
__m256i ctr3v = _mm256_set1_epi32(ctr3);
aesni_float4(ctr0v, ctr1v, ctr2v, ctr3v, key0, key1, key2, key3, rnd1, rnd2, rnd3, rnd4);
}
QUALIFIERS void aesni_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
__m256d & rnd1lo, __m256d & rnd1hi, __m256d & rnd2lo, __m256d & rnd2hi)
{
__m256i ctr0v = _mm256_add_epi32(_mm256_set1_epi32(ctr0), _mm256_set_epi32(7,6,5,4,3,2,1,0));
__m256i ctr1v = _mm256_set1_epi32(ctr1);
__m256i ctr2v = _mm256_set1_epi32(ctr2);
__m256i ctr3v = _mm256_set1_epi32(ctr3);
aesni_double2(ctr0v, ctr1v, ctr2v, ctr3v, key0, key1, key2, key3, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
}
#endif
#ifdef __AVX512F__
QUALIFIERS __m512i aesni1xm128i(const __m512i & in, const __m512i & k) {
__m512i x = _mm512_xor_si512(k, in);
#ifdef __VAES__
x = _mm512_aesenc_epi128(x, k); // 1
x = _mm512_aesenc_epi128(x, k); // 2
x = _mm512_aesenc_epi128(x, k); // 3
x = _mm512_aesenc_epi128(x, k); // 4
x = _mm512_aesenc_epi128(x, k); // 5
x = _mm512_aesenc_epi128(x, k); // 6
x = _mm512_aesenc_epi128(x, k); // 7
x = _mm512_aesenc_epi128(x, k); // 8
x = _mm512_aesenc_epi128(x, k); // 9
x = _mm512_aesenclast_epi128(x, k); // 10
#else
__m128i a = aesni1xm128i(_mm512_extracti32x4_epi32(in, 0), _mm512_extracti32x4_epi32(k, 0));
__m128i b = aesni1xm128i(_mm512_extracti32x4_epi32(in, 1), _mm512_extracti32x4_epi32(k, 1));
__m128i c = aesni1xm128i(_mm512_extracti32x4_epi32(in, 2), _mm512_extracti32x4_epi32(k, 2));
__m128i d = aesni1xm128i(_mm512_extracti32x4_epi32(in, 3), _mm512_extracti32x4_epi32(k, 3));
x = _my512_set_m128i(d, c, b, a);
#endif
return x;
}
template<bool high>
QUALIFIERS __m512d _uniform_double_hq(__m512i x, __m512i y)
{
// convert 32 to 64 bit
if (high)
{
x = _mm512_unpackhi_epi32(x, _mm512_setzero_si512());
y = _mm512_unpackhi_epi32(y, _mm512_setzero_si512());
}
else
{
x = _mm512_unpacklo_epi32(x, _mm512_setzero_si512());
y = _mm512_unpacklo_epi32(y, _mm512_setzero_si512());
}
// calculate z = x ^ y << (53 - 32))
__m512i z = _mm512_sll_epi64(y, _mm_set1_epi64x(53 - 32));
z = _mm512_xor_si512(x, z);
// convert uint64 to double
__m512d rs = _mm512_cvtepu64_pd(z);
// calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
rs = _mm512_fmadd_pd(rs, _mm512_set1_pd(TWOPOW53_INV_DOUBLE), _mm512_set1_pd(TWOPOW53_INV_DOUBLE/2.0));
return rs;
}
QUALIFIERS void aesni_float4(__m512i ctr0, __m512i ctr1, __m512i ctr2, __m512i ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
__m512 & rnd1, __m512 & rnd2, __m512 & rnd3, __m512 & rnd4)
{
// pack input and call AES
__m512i k512 = _mm512_set_epi32(key3, key2, key1, key0, key3, key2, key1, key0,
key3, key2, key1, key0, key3, key2, key1, key0);
__m512i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
__m128i a[4], b[4], c[4], d[4];
for (int i = 0; i < 4; ++i)
{
a[i] = _mm512_extracti32x4_epi32(ctr[i], 0);
b[i] = _mm512_extracti32x4_epi32(ctr[i], 1);
c[i] = _mm512_extracti32x4_epi32(ctr[i], 2);
d[i] = _mm512_extracti32x4_epi32(ctr[i], 3);
}
_MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]);
_MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]);
_MY_TRANSPOSE4_EPI32(c[0], c[1], c[2], c[3]);
_MY_TRANSPOSE4_EPI32(d[0], d[1], d[2], d[3]);
for (int i = 0; i < 4; ++i)
{
ctr[i] = _my512_set_m128i(d[i], c[i], b[i], a[i]);
}
for (int i = 0; i < 4; ++i)
{
ctr[i] = aesni1xm128i(ctr[i], k512);
}
for (int i = 0; i < 4; ++i)
{
a[i] = _mm512_extracti32x4_epi32(ctr[i], 0);
b[i] = _mm512_extracti32x4_epi32(ctr[i], 1);
c[i] = _mm512_extracti32x4_epi32(ctr[i], 2);
d[i] = _mm512_extracti32x4_epi32(ctr[i], 3);
}
_MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]);
_MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]);
_MY_TRANSPOSE4_EPI32(c[0], c[1], c[2], c[3]);
_MY_TRANSPOSE4_EPI32(d[0], d[1], d[2], d[3]);
for (int i = 0; i < 4; ++i)
{
ctr[i] = _my512_set_m128i(d[i], c[i], b[i], a[i]);
}
// convert uint32 to float
rnd1 = _mm512_cvtepu32_ps(ctr[0]);
rnd2 = _mm512_cvtepu32_ps(ctr[1]);
rnd3 = _mm512_cvtepu32_ps(ctr[2]);
rnd4 = _mm512_cvtepu32_ps(ctr[3]);
// calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
rnd1 = _mm512_fmadd_ps(rnd1, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd2 = _mm512_fmadd_ps(rnd2, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd3 = _mm512_fmadd_ps(rnd3, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd4 = _mm512_fmadd_ps(rnd4, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0));
}
QUALIFIERS void aesni_double2(__m512i ctr0, __m512i ctr1, __m512i ctr2, __m512i ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
__m512d & rnd1lo, __m512d & rnd1hi, __m512d & rnd2lo, __m512d & rnd2hi)
{
// pack input and call AES
__m512i k512 = _mm512_set_epi32(key3, key2, key1, key0, key3, key2, key1, key0,
key3, key2, key1, key0, key3, key2, key1, key0);
__m512i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
__m128i a[4], b[4], c[4], d[4];
for (int i = 0; i < 4; ++i)
{
a[i] = _mm512_extracti32x4_epi32(ctr[i], 0);
b[i] = _mm512_extracti32x4_epi32(ctr[i], 1);
c[i] = _mm512_extracti32x4_epi32(ctr[i], 2);
d[i] = _mm512_extracti32x4_epi32(ctr[i], 3);
}
_MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]);
_MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]);
_MY_TRANSPOSE4_EPI32(c[0], c[1], c[2], c[3]);
_MY_TRANSPOSE4_EPI32(d[0], d[1], d[2], d[3]);
for (int i = 0; i < 4; ++i)
{
ctr[i] = _my512_set_m128i(d[i], c[i], b[i], a[i]);
}
for (int i = 0; i < 4; ++i)
{
ctr[i] = aesni1xm128i(ctr[i], k512);
}
for (int i = 0; i < 4; ++i)
{
a[i] = _mm512_extracti32x4_epi32(ctr[i], 0);
b[i] = _mm512_extracti32x4_epi32(ctr[i], 1);
c[i] = _mm512_extracti32x4_epi32(ctr[i], 2);
d[i] = _mm512_extracti32x4_epi32(ctr[i], 3);
}
_MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]);
_MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]);
_MY_TRANSPOSE4_EPI32(c[0], c[1], c[2], c[3]);
_MY_TRANSPOSE4_EPI32(d[0], d[1], d[2], d[3]);
for (int i = 0; i < 4; ++i)
{
ctr[i] = _my512_set_m128i(d[i], c[i], b[i], a[i]);
}
rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]);
rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]);
rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]);
rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
}
QUALIFIERS void aesni_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
__m512 & rnd1, __m512 & rnd2, __m512 & rnd3, __m512 & rnd4)
{
__m512i ctr0v = _mm512_add_epi32(_mm512_set1_epi32(ctr0), _mm512_set_epi32(15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0));
__m512i ctr1v = _mm512_set1_epi32(ctr1);
__m512i ctr2v = _mm512_set1_epi32(ctr2);
__m512i ctr3v = _mm512_set1_epi32(ctr3);
philox_float4(ctr0v, ctr1v, ctr2v, ctr3v, key0, key1, key2, key3, rnd1, rnd2, rnd3, rnd4);
}
QUALIFIERS void aesni_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
__m512d & rnd1lo, __m512d & rnd1hi, __m512d & rnd2lo, __m512d & rnd2hi)
{
__m512i ctr0v = _mm512_add_epi32(_mm512_set1_epi32(ctr0), _mm512_set_epi32(15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0));
__m512i ctr1v = _mm512_set1_epi32(ctr1);
__m512i ctr2v = _mm512_set1_epi32(ctr2);
__m512i ctr3v = _mm512_set1_epi32(ctr3);
philox_double2(ctr0v, ctr1v, ctr2v, ctr3v, key0, key1, key2, key3, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
}
#endif
#pragma once
#ifdef __SSE2__
QUALIFIERS __m128 _my_cvtepu32_ps(const __m128i v)
{
#ifdef __AVX512VL__
return _mm_cvtepu32_ps(v);
#else
__m128i v2 = _mm_srli_epi32(v, 1);
__m128i v1 = _mm_and_si128(v, _mm_set1_epi32(1));
__m128 v2f = _mm_cvtepi32_ps(v2);
__m128 v1f = _mm_cvtepi32_ps(v1);
return _mm_add_ps(_mm_add_ps(v2f, v2f), v1f);
#endif
}
QUALIFIERS void _MY_TRANSPOSE4_EPI32(__m128i & R0, __m128i & R1, __m128i & R2, __m128i & R3)
{
__m128i T0, T1, T2, T3;
T0 = _mm_unpacklo_epi32(R0, R1);
T1 = _mm_unpacklo_epi32(R2, R3);
T2 = _mm_unpackhi_epi32(R0, R1);
T3 = _mm_unpackhi_epi32(R2, R3);
R0 = _mm_unpacklo_epi64(T0, T1);
R1 = _mm_unpackhi_epi64(T0, T1);
R2 = _mm_unpacklo_epi64(T2, T3);
R3 = _mm_unpackhi_epi64(T2, T3);
}
#endif
#ifdef __SSE4_1__
#if !defined(__AVX512VL__) && defined(__GNUC__) && __GNUC__ >= 5
__attribute__((optimize("no-associative-math")))
#endif
QUALIFIERS __m128d _my_cvtepu64_pd(const __m128i x)
{
#ifdef __AVX512VL__
return _mm_cvtepu64_pd(x);
#else
__m128i xH = _mm_srli_epi64(x, 32);
xH = _mm_or_si128(xH, _mm_castpd_si128(_mm_set1_pd(19342813113834066795298816.))); // 2^84
__m128i xL = _mm_blend_epi16(x, _mm_castpd_si128(_mm_set1_pd(0x0010000000000000)), 0xcc); // 2^52
__m128d f = _mm_sub_pd(_mm_castsi128_pd(xH), _mm_set1_pd(19342813118337666422669312.)); // 2^84 + 2^52
return _mm_add_pd(f, _mm_castsi128_pd(xL));
#endif
}
#endif
#ifdef __AVX__
QUALIFIERS __m256i _my256_set_m128i(__m128i hi, __m128i lo)
{
return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), hi, 1);
}
#endif
#ifdef __AVX2__
QUALIFIERS __m256 _my256_cvtepu32_ps(const __m256i v)
{
#ifdef __AVX512VL__
return _mm256_cvtepu32_ps(v);
#else
__m256i v2 = _mm256_srli_epi32(v, 1);
__m256i v1 = _mm256_and_si256(v, _mm256_set1_epi32(1));
__m256 v2f = _mm256_cvtepi32_ps(v2);
__m256 v1f = _mm256_cvtepi32_ps(v1);
return _mm256_add_ps(_mm256_add_ps(v2f, v2f), v1f);
#endif
}
#if !defined(__AVX512VL__) && defined(__GNUC__) && __GNUC__ >= 5
__attribute__((optimize("no-associative-math")))
#endif
QUALIFIERS __m256d _my256_cvtepu64_pd(const __m256i x)
{
#ifdef __AVX512VL__
return _mm256_cvtepu64_pd(x);
#else
__m256i xH = _mm256_srli_epi64(x, 32);
xH = _mm256_or_si256(xH, _mm256_castpd_si256(_mm256_set1_pd(19342813113834066795298816.))); // 2^84
__m256i xL = _mm256_blend_epi16(x, _mm256_castpd_si256(_mm256_set1_pd(0x0010000000000000)), 0xcc); // 2^52
__m256d f = _mm256_sub_pd(_mm256_castsi256_pd(xH), _mm256_set1_pd(19342813118337666422669312.)); // 2^84 + 2^52
return _mm256_add_pd(f, _mm256_castsi256_pd(xL));
#endif
}
#endif
#ifdef __AVX512F__
QUALIFIERS __m512i _my512_set_m128i(__m128i d, __m128i c, __m128i b, __m128i a)
{
return _mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_castsi128_si512(a), b, 1), c, 2), d, 3);
}
#endif
#include <cstdint> #include <cstdint>
#ifdef __SSE4_1__
#include <emmintrin.h> // SSE2
#endif
#ifdef __AVX2__
#include <immintrin.h> // AVX*
#else
#include <smmintrin.h> // SSE4
#ifdef __FMA__
#include <immintrin.h> // FMA
#endif
#endif
#ifndef __CUDA_ARCH__ #ifndef __CUDA_ARCH__
#define QUALIFIERS inline #define QUALIFIERS inline
#include "myintrin.h"
#else #else
#define QUALIFIERS static __forceinline__ __device__ #define QUALIFIERS static __forceinline__ __device__
#endif #endif
...@@ -78,7 +91,6 @@ QUALIFIERS void philox_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr ...@@ -78,7 +91,6 @@ QUALIFIERS void philox_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr
} }
QUALIFIERS void philox_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3, QUALIFIERS void philox_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key0, uint32 key1,
float & rnd1, float & rnd2, float & rnd3, float & rnd4) float & rnd1, float & rnd2, float & rnd3, float & rnd4)
...@@ -100,4 +112,447 @@ QUALIFIERS void philox_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3 ...@@ -100,4 +112,447 @@ QUALIFIERS void philox_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3
rnd2 = ctr[1] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f); rnd2 = ctr[1] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f);
rnd3 = ctr[2] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f); rnd3 = ctr[2] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f);
rnd4 = ctr[3] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f); rnd4 = ctr[3] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f);
} }
\ No newline at end of file
#ifndef __CUDA_ARCH__
#ifdef __SSE4_1__
QUALIFIERS void _philox4x32round(__m128i* ctr, __m128i* key)
{
__m128i lohi0a = _mm_mul_epu32(ctr[0], _mm_set1_epi32(PHILOX_M4x32_0));
__m128i lohi0b = _mm_mul_epu32(_mm_srli_epi64(ctr[0], 32), _mm_set1_epi32(PHILOX_M4x32_0));
__m128i lohi1a = _mm_mul_epu32(ctr[2], _mm_set1_epi32(PHILOX_M4x32_1));
__m128i lohi1b = _mm_mul_epu32(_mm_srli_epi64(ctr[2], 32), _mm_set1_epi32(PHILOX_M4x32_1));
lohi0a = _mm_shuffle_epi32(lohi0a, 0xD8);
lohi0b = _mm_shuffle_epi32(lohi0b, 0xD8);
lohi1a = _mm_shuffle_epi32(lohi1a, 0xD8);
lohi1b = _mm_shuffle_epi32(lohi1b, 0xD8);
__m128i lo0 = _mm_unpacklo_epi32(lohi0a, lohi0b);
__m128i hi0 = _mm_unpackhi_epi32(lohi0a, lohi0b);
__m128i lo1 = _mm_unpacklo_epi32(lohi1a, lohi1b);
__m128i hi1 = _mm_unpackhi_epi32(lohi1a, lohi1b);
ctr[0] = _mm_xor_si128(_mm_xor_si128(hi1, ctr[1]), key[0]);
ctr[1] = lo1;
ctr[2] = _mm_xor_si128(_mm_xor_si128(hi0, ctr[3]), key[1]);
ctr[3] = lo0;
}
QUALIFIERS void _philox4x32bumpkey(__m128i* key)
{
key[0] = _mm_add_epi32(key[0], _mm_set1_epi32(PHILOX_W32_0));
key[1] = _mm_add_epi32(key[1], _mm_set1_epi32(PHILOX_W32_1));
}
template<bool high>
QUALIFIERS __m128d _uniform_double_hq(__m128i x, __m128i y)
{
// convert 32 to 64 bit
if (high)
{
x = _mm_unpackhi_epi32(x, _mm_setzero_si128());
y = _mm_unpackhi_epi32(y, _mm_setzero_si128());
}
else
{
x = _mm_unpacklo_epi32(x, _mm_setzero_si128());
y = _mm_unpacklo_epi32(y, _mm_setzero_si128());
}
// calculate z = x ^ y << (53 - 32))
__m128i z = _mm_sll_epi64(y, _mm_set1_epi64x(53 - 32));
z = _mm_xor_si128(x, z);
// convert uint64 to double
__m128d rs = _my_cvtepu64_pd(z);
// calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
#ifdef __FMA__
rs = _mm_fmadd_pd(rs, _mm_set1_pd(TWOPOW53_INV_DOUBLE), _mm_set1_pd(TWOPOW53_INV_DOUBLE/2.0));
#else
rs = _mm_mul_pd(rs, _mm_set1_pd(TWOPOW53_INV_DOUBLE));
rs = _mm_add_pd(rs, _mm_set1_pd(TWOPOW53_INV_DOUBLE/2.0));
#endif
return rs;
}
QUALIFIERS void philox_float4(__m128i ctr0, __m128i ctr1, __m128i ctr2, __m128i ctr3,
uint32 key0, uint32 key1,
__m128 & rnd1, __m128 & rnd2, __m128 & rnd3, __m128 & rnd4)
{
__m128i key[2] = {_mm_set1_epi32(key0), _mm_set1_epi32(key1)};
__m128i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
// convert uint32 to float
rnd1 = _my_cvtepu32_ps(ctr[0]);
rnd2 = _my_cvtepu32_ps(ctr[1]);
rnd3 = _my_cvtepu32_ps(ctr[2]);
rnd4 = _my_cvtepu32_ps(ctr[3]);
// calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
#ifdef __FMA__
rnd1 = _mm_fmadd_ps(rnd1, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd2 = _mm_fmadd_ps(rnd2, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd3 = _mm_fmadd_ps(rnd3, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd4 = _mm_fmadd_ps(rnd4, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0));
#else
rnd1 = _mm_mul_ps(rnd1, _mm_set1_ps(TWOPOW32_INV_FLOAT));
rnd1 = _mm_add_ps(rnd1, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
rnd2 = _mm_mul_ps(rnd2, _mm_set1_ps(TWOPOW32_INV_FLOAT));
rnd2 = _mm_add_ps(rnd2, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
rnd3 = _mm_mul_ps(rnd3, _mm_set1_ps(TWOPOW32_INV_FLOAT));
rnd3 = _mm_add_ps(rnd3, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
rnd4 = _mm_mul_ps(rnd4, _mm_set1_ps(TWOPOW32_INV_FLOAT));
rnd4 = _mm_add_ps(rnd4, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
#endif
}
QUALIFIERS void philox_double2(__m128i ctr0, __m128i ctr1, __m128i ctr2, __m128i ctr3,
uint32 key0, uint32 key1,
__m128d & rnd1lo, __m128d & rnd1hi, __m128d & rnd2lo, __m128d & rnd2hi)
{
__m128i key[2] = {_mm_set1_epi32(key0), _mm_set1_epi32(key1)};
__m128i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]);
rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]);
rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]);
rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
}
QUALIFIERS void philox_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__m128 & rnd1, __m128 & rnd2, __m128 & rnd3, __m128 & rnd4)
{
__m128i ctr0v = _mm_add_epi32(_mm_set1_epi32(ctr0), _mm_set_epi32(3,2,1,0));
__m128i ctr1v = _mm_set1_epi32(ctr1);
__m128i ctr2v = _mm_set1_epi32(ctr2);
__m128i ctr3v = _mm_set1_epi32(ctr3);
philox_float4(ctr0v, ctr1v, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4);
}
QUALIFIERS void philox_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__m128d & rnd1lo, __m128d & rnd1hi, __m128d & rnd2lo, __m128d & rnd2hi)
{
__m128i ctr0v = _mm_add_epi32(_mm_set1_epi32(ctr0), _mm_set_epi32(3,2,1,0));
__m128i ctr1v = _mm_set1_epi32(ctr1);
__m128i ctr2v = _mm_set1_epi32(ctr2);
__m128i ctr3v = _mm_set1_epi32(ctr3);
philox_double2(ctr0v, ctr1v, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
}
#endif
#ifdef __AVX2__
QUALIFIERS void _philox4x32round(__m256i* ctr, __m256i* key)
{
__m256i lohi0a = _mm256_mul_epu32(ctr[0], _mm256_set1_epi32(PHILOX_M4x32_0));
__m256i lohi0b = _mm256_mul_epu32(_mm256_srli_epi64(ctr[0], 32), _mm256_set1_epi32(PHILOX_M4x32_0));
__m256i lohi1a = _mm256_mul_epu32(ctr[2], _mm256_set1_epi32(PHILOX_M4x32_1));
__m256i lohi1b = _mm256_mul_epu32(_mm256_srli_epi64(ctr[2], 32), _mm256_set1_epi32(PHILOX_M4x32_1));
lohi0a = _mm256_shuffle_epi32(lohi0a, 0xD8);
lohi0b = _mm256_shuffle_epi32(lohi0b, 0xD8);
lohi1a = _mm256_shuffle_epi32(lohi1a, 0xD8);
lohi1b = _mm256_shuffle_epi32(lohi1b, 0xD8);
__m256i lo0 = _mm256_unpacklo_epi32(lohi0a, lohi0b);
__m256i hi0 = _mm256_unpackhi_epi32(lohi0a, lohi0b);
__m256i lo1 = _mm256_unpacklo_epi32(lohi1a, lohi1b);
__m256i hi1 = _mm256_unpackhi_epi32(lohi1a, lohi1b);
ctr[0] = _mm256_xor_si256(_mm256_xor_si256(hi1, ctr[1]), key[0]);
ctr[1] = lo1;
ctr[2] = _mm256_xor_si256(_mm256_xor_si256(hi0, ctr[3]), key[1]);
ctr[3] = lo0;
}
QUALIFIERS void _philox4x32bumpkey(__m256i* key)
{
key[0] = _mm256_add_epi32(key[0], _mm256_set1_epi32(PHILOX_W32_0));
key[1] = _mm256_add_epi32(key[1], _mm256_set1_epi32(PHILOX_W32_1));
}
template<bool high>
QUALIFIERS __m256d _uniform_double_hq(__m256i x, __m256i y)
{
// convert 32 to 64 bit
if (high)
{
x = _mm256_unpackhi_epi32(x, _mm256_setzero_si256());
y = _mm256_unpackhi_epi32(y, _mm256_setzero_si256());
}
else
{
x = _mm256_unpacklo_epi32(x, _mm256_setzero_si256());
y = _mm256_unpacklo_epi32(y, _mm256_setzero_si256());
}
// calculate z = x ^ y << (53 - 32))
__m256i z = _mm256_sll_epi64(y, _mm_set1_epi64x(53 - 32));
z = _mm256_xor_si256(x, z);
// convert uint64 to double
__m256d rs = _my256_cvtepu64_pd(z);
// calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
#ifdef __FMA__
rs = _mm256_fmadd_pd(rs, _mm256_set1_pd(TWOPOW53_INV_DOUBLE), _mm256_set1_pd(TWOPOW53_INV_DOUBLE/2.0));
#else
rs = _mm256_mul_pd(rs, _mm256_set1_pd(TWOPOW53_INV_DOUBLE));
rs = _mm256_add_pd(rs, _mm256_set1_pd(TWOPOW53_INV_DOUBLE/2.0));
#endif
return rs;
}
QUALIFIERS void philox_float4(__m256i ctr0, __m256i ctr1, __m256i ctr2, __m256i ctr3,
uint32 key0, uint32 key1,
__m256 & rnd1, __m256 & rnd2, __m256 & rnd3, __m256 & rnd4)
{
__m256i key[2] = {_mm256_set1_epi32(key0), _mm256_set1_epi32(key1)};
__m256i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
// convert uint32 to float
rnd1 = _my256_cvtepu32_ps(ctr[0]);
rnd2 = _my256_cvtepu32_ps(ctr[1]);
rnd3 = _my256_cvtepu32_ps(ctr[2]);
rnd4 = _my256_cvtepu32_ps(ctr[3]);
// calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
#ifdef __FMA__
rnd1 = _mm256_fmadd_ps(rnd1, _mm256_set1_ps(TWOPOW32_INV_FLOAT), _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd2 = _mm256_fmadd_ps(rnd2, _mm256_set1_ps(TWOPOW32_INV_FLOAT), _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd3 = _mm256_fmadd_ps(rnd3, _mm256_set1_ps(TWOPOW32_INV_FLOAT), _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd4 = _mm256_fmadd_ps(rnd4, _mm256_set1_ps(TWOPOW32_INV_FLOAT), _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0));
#else
rnd1 = _mm256_mul_ps(rnd1, _mm256_set1_ps(TWOPOW32_INV_FLOAT));
rnd1 = _mm256_add_ps(rnd1, _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
rnd2 = _mm256_mul_ps(rnd2, _mm256_set1_ps(TWOPOW32_INV_FLOAT));
rnd2 = _mm256_add_ps(rnd2, _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
rnd3 = _mm256_mul_ps(rnd3, _mm256_set1_ps(TWOPOW32_INV_FLOAT));
rnd3 = _mm256_add_ps(rnd3, _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
rnd4 = _mm256_mul_ps(rnd4, _mm256_set1_ps(TWOPOW32_INV_FLOAT));
rnd4 = _mm256_add_ps(rnd4, _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
#endif
}
QUALIFIERS void philox_double2(__m256i ctr0, __m256i ctr1, __m256i ctr2, __m256i ctr3,
uint32 key0, uint32 key1,
__m256d & rnd1lo, __m256d & rnd1hi, __m256d & rnd2lo, __m256d & rnd2hi)
{
__m256i key[2] = {_mm256_set1_epi32(key0), _mm256_set1_epi32(key1)};
__m256i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]);
rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]);
rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]);
rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
}
QUALIFIERS void philox_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__m256 & rnd1, __m256 & rnd2, __m256 & rnd3, __m256 & rnd4)
{
__m256i ctr0v = _mm256_add_epi32(_mm256_set1_epi32(ctr0), _mm256_set_epi32(7,6,5,4,3,2,1,0));
__m256i ctr1v = _mm256_set1_epi32(ctr1);
__m256i ctr2v = _mm256_set1_epi32(ctr2);
__m256i ctr3v = _mm256_set1_epi32(ctr3);
philox_float4(ctr0v, ctr1v, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4);
}
QUALIFIERS void philox_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__m256d & rnd1lo, __m256d & rnd1hi, __m256d & rnd2lo, __m256d & rnd2hi)
{
__m256i ctr0v = _mm256_add_epi32(_mm256_set1_epi32(ctr0), _mm256_set_epi32(7,6,5,4,3,2,1,0));
__m256i ctr1v = _mm256_set1_epi32(ctr1);
__m256i ctr2v = _mm256_set1_epi32(ctr2);
__m256i ctr3v = _mm256_set1_epi32(ctr3);
philox_double2(ctr0v, ctr1v, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
}
#endif
#ifdef __AVX512F__
QUALIFIERS void _philox4x32round(__m512i* ctr, __m512i* key)
{
__m512i lohi0a = _mm512_mul_epu32(ctr[0], _mm512_set1_epi32(PHILOX_M4x32_0));
__m512i lohi0b = _mm512_mul_epu32(_mm512_srli_epi64(ctr[0], 32), _mm512_set1_epi32(PHILOX_M4x32_0));
__m512i lohi1a = _mm512_mul_epu32(ctr[2], _mm512_set1_epi32(PHILOX_M4x32_1));
__m512i lohi1b = _mm512_mul_epu32(_mm512_srli_epi64(ctr[2], 32), _mm512_set1_epi32(PHILOX_M4x32_1));
lohi0a = _mm512_shuffle_epi32(lohi0a, _MM_PERM_DBCA);
lohi0b = _mm512_shuffle_epi32(lohi0b, _MM_PERM_DBCA);
lohi1a = _mm512_shuffle_epi32(lohi1a, _MM_PERM_DBCA);
lohi1b = _mm512_shuffle_epi32(lohi1b, _MM_PERM_DBCA);
__m512i lo0 = _mm512_unpacklo_epi32(lohi0a, lohi0b);
__m512i hi0 = _mm512_unpackhi_epi32(lohi0a, lohi0b);
__m512i lo1 = _mm512_unpacklo_epi32(lohi1a, lohi1b);
__m512i hi1 = _mm512_unpackhi_epi32(lohi1a, lohi1b);
ctr[0] = _mm512_xor_si512(_mm512_xor_si512(hi1, ctr[1]), key[0]);
ctr[1] = lo1;
ctr[2] = _mm512_xor_si512(_mm512_xor_si512(hi0, ctr[3]), key[1]);
ctr[3] = lo0;
}
QUALIFIERS void _philox4x32bumpkey(__m512i* key)
{
key[0] = _mm512_add_epi32(key[0], _mm512_set1_epi32(PHILOX_W32_0));
key[1] = _mm512_add_epi32(key[1], _mm512_set1_epi32(PHILOX_W32_1));
}
template<bool high>
QUALIFIERS __m512d _uniform_double_hq(__m512i x, __m512i y)
{
// convert 32 to 64 bit
if (high)
{
x = _mm512_unpackhi_epi32(x, _mm512_setzero_si512());
y = _mm512_unpackhi_epi32(y, _mm512_setzero_si512());
}
else
{
x = _mm512_unpacklo_epi32(x, _mm512_setzero_si512());
y = _mm512_unpacklo_epi32(y, _mm512_setzero_si512());
}
// calculate z = x ^ y << (53 - 32))
__m512i z = _mm512_sll_epi64(y, _mm_set1_epi64x(53 - 32));
z = _mm512_xor_si512(x, z);
// convert uint64 to double
__m512d rs = _mm512_cvtepu64_pd(z);
// calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
rs = _mm512_fmadd_pd(rs, _mm512_set1_pd(TWOPOW53_INV_DOUBLE), _mm512_set1_pd(TWOPOW53_INV_DOUBLE/2.0));
return rs;
}
QUALIFIERS void philox_float4(__m512i ctr0, __m512i ctr1, __m512i ctr2, __m512i ctr3,
uint32 key0, uint32 key1,
__m512 & rnd1, __m512 & rnd2, __m512 & rnd3, __m512 & rnd4)
{
__m512i key[2] = {_mm512_set1_epi32(key0), _mm512_set1_epi32(key1)};
__m512i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
// convert uint32 to float
rnd1 = _mm512_cvtepu32_ps(ctr[0]);
rnd2 = _mm512_cvtepu32_ps(ctr[1]);
rnd3 = _mm512_cvtepu32_ps(ctr[2]);
rnd4 = _mm512_cvtepu32_ps(ctr[3]);
// calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
rnd1 = _mm512_fmadd_ps(rnd1, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd2 = _mm512_fmadd_ps(rnd2, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd3 = _mm512_fmadd_ps(rnd3, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd4 = _mm512_fmadd_ps(rnd4, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0));
}
QUALIFIERS void philox_double2(__m512i ctr0, __m512i ctr1, __m512i ctr2, __m512i ctr3,
uint32 key0, uint32 key1,
__m512d & rnd1lo, __m512d & rnd1hi, __m512d & rnd2lo, __m512d & rnd2hi)
{
__m512i key[2] = {_mm512_set1_epi32(key0), _mm512_set1_epi32(key1)};
__m512i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]);
rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]);
rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]);
rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
}
QUALIFIERS void philox_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__m512 & rnd1, __m512 & rnd2, __m512 & rnd3, __m512 & rnd4)
{
__m512i ctr0v = _mm512_add_epi32(_mm512_set1_epi32(ctr0), _mm512_set_epi32(15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0));
__m512i ctr1v = _mm512_set1_epi32(ctr1);
__m512i ctr2v = _mm512_set1_epi32(ctr2);
__m512i ctr3v = _mm512_set1_epi32(ctr3);
philox_float4(ctr0v, ctr1v, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4);
}
QUALIFIERS void philox_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__m512d & rnd1lo, __m512d & rnd1hi, __m512d & rnd2lo, __m512d & rnd2hi)
{
__m512i ctr0v = _mm512_add_epi32(_mm512_set1_epi32(ctr0), _mm512_set_epi32(15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0));
__m512i ctr1v = _mm512_set1_epi32(ctr1);
__m512i ctr2v = _mm512_set1_epi32(ctr2);
__m512i ctr3v = _mm512_set1_epi32(ctr3);
philox_double2(ctr0v, ctr1v, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
}
#endif
#endif
...@@ -6,16 +6,23 @@ from pystencils.astnodes import LoopOverCoordinate ...@@ -6,16 +6,23 @@ from pystencils.astnodes import LoopOverCoordinate
from pystencils.backends.cbackend import CustomCodeNode from pystencils.backends.cbackend import CustomCodeNode
def _get_rng_template(name, data_type, num_vars): def _data_type_to_str(data_type):
if data_type is np.float32: if data_type is np.float32:
c_type = "float" return "float"
elif data_type is np.float64: elif data_type is np.float64:
c_type = "double" return "double"
elif type(data_type) is str:
return data_type
raise ValueError("%s is not a supported data type" % (data_type, ))
def _get_rng_template(name, data_type, num_vars):
c_type = _data_type_to_str(data_type)
template = "\n" template = "\n"
for i in range(num_vars): for i in range(num_vars):
template += "{{result_symbols[{}].dtype}} {{result_symbols[{}].name}};\n".format(i, i) template += "{} {{result_symbols[{}].name}};\n".format(c_type, i, i)
template += ("{}_{}{}({{parameters}}, " + ", ".join(["{{result_symbols[{}].name}}"] * num_vars) + ");\n") \ template += ("{}({{parameters}}, " + ", ".join(["{{result_symbols[{}].name}}"] * num_vars) + ");\n") \
.format(name, c_type, num_vars, *tuple(range(num_vars))) .format(name, *tuple(range(num_vars)))
return template return template
...@@ -23,7 +30,7 @@ def _get_rng_code(template, dialect, vector_instruction_set, time_step, offsets, ...@@ -23,7 +30,7 @@ def _get_rng_code(template, dialect, vector_instruction_set, time_step, offsets,
parameters = [time_step] + [LoopOverCoordinate.get_loop_counter_symbol(i) + offsets[i] parameters = [time_step] + [LoopOverCoordinate.get_loop_counter_symbol(i) + offsets[i]
for i in range(dim)] + [0] * (3 - dim) + list(keys) for i in range(dim)] + [0] * (3 - dim) + list(keys)
if dialect == 'cuda' or (dialect == 'c' and vector_instruction_set is None): if dialect == 'cuda' or dialect == 'c':
return template.format(parameters=', '.join(str(p) for p in parameters), return template.format(parameters=', '.join(str(p) for p in parameters),
result_symbols=result_symbols) result_symbols=result_symbols)
else: else:
...@@ -44,7 +51,7 @@ class RNGBase(CustomCodeNode): ...@@ -44,7 +51,7 @@ class RNGBase(CustomCodeNode):
super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols) super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols)
self._time_step = time_step self._time_step = time_step
self._offsets = offsets self._offsets = offsets
self.headers = ['"{}_rand.h"'.format(self._name)] self.headers = ['"{}_rand.h"'.format(self._name.split('_')[0])]
self.keys = tuple(keys) self.keys = tuple(keys)
self._args = sp.sympify((dim, time_step, keys)) self._args = sp.sympify((dim, time_step, keys))
self._dim = dim self._dim = dim
...@@ -65,7 +72,11 @@ class RNGBase(CustomCodeNode): ...@@ -65,7 +72,11 @@ class RNGBase(CustomCodeNode):
return self # nothing to replace inside this node - would destroy intermediate "dummy" by re-creating them return self # nothing to replace inside this node - would destroy intermediate "dummy" by re-creating them
def get_code(self, dialect, vector_instruction_set): def get_code(self, dialect, vector_instruction_set):
template = _get_rng_template(self._name, self._data_type, self._num_vars) if vector_instruction_set:
template = _get_rng_template(self._name, vector_instruction_set[_data_type_to_str(self._data_type)],
self._num_vars)
else:
template = _get_rng_template(self._name, self._data_type, self._num_vars)
return _get_rng_code(template, dialect, vector_instruction_set, return _get_rng_code(template, dialect, vector_instruction_set,
self._time_step, self._offsets, self.keys, self._dim, self.result_symbols) self._time_step, self._offsets, self.keys, self._dim, self.result_symbols)
...@@ -74,28 +85,28 @@ class RNGBase(CustomCodeNode): ...@@ -74,28 +85,28 @@ class RNGBase(CustomCodeNode):
class PhiloxTwoDoubles(RNGBase): class PhiloxTwoDoubles(RNGBase):
_name = "philox" _name = "philox_double2"
_data_type = np.float64 _data_type = np.float64
_num_vars = 2 _num_vars = 2
_num_keys = 2 _num_keys = 2
class PhiloxFourFloats(RNGBase): class PhiloxFourFloats(RNGBase):
_name = "philox" _name = "philox_float4"
_data_type = np.float32 _data_type = np.float32
_num_vars = 4 _num_vars = 4
_num_keys = 2 _num_keys = 2
class AESNITwoDoubles(RNGBase): class AESNITwoDoubles(RNGBase):
_name = "aesni" _name = "aesni_double2"
_data_type = np.float64 _data_type = np.float64
_num_vars = 2 _num_vars = 2
_num_keys = 4 _num_keys = 4
class AESNIFourFloats(RNGBase): class AESNIFourFloats(RNGBase):
_name = "aesni" _name = "aesni_float4"
_data_type = np.float32 _data_type = np.float32
_num_vars = 4 _num_vars = 4
_num_keys = 4 _num_keys = 4
......