Skip to content
Snippets Groups Projects
Commit 170a7717 authored by Michael Kuron's avatar Michael Kuron :mortar_board:
Browse files

clean up AES-NI RNG

parent fa0a09a5
Branches
Tags
1 merge request!30AES-NI Random Number Generator
Pipeline #17212 passed
...@@ -2,7 +2,11 @@ ...@@ -2,7 +2,11 @@
#error AES-NI and SSE2 need to be enabled #error AES-NI and SSE2 need to be enabled
#endif #endif
#include <x86intrin.h> #include <emmintrin.h> // SSE2
#include <wmmintrin.h> // AES
#ifdef __AVX512VL__
#include <immintrin.h> // AVX*
#endif
#include <cstdint> #include <cstdint>
#define QUALIFIERS inline #define QUALIFIERS inline
...@@ -14,22 +18,22 @@ typedef std::uint64_t uint64; ...@@ -14,22 +18,22 @@ typedef std::uint64_t uint64;
QUALIFIERS __m128i aesni1xm128i(const __m128i & in, const __m128i & k) { QUALIFIERS __m128i aesni1xm128i(const __m128i & in, const __m128i & k) {
__m128i x = _mm_xor_si128(k, in); __m128i x = _mm_xor_si128(k, in);
x = _mm_aesenc_si128(x, k); x = _mm_aesenc_si128(x, k); // 1
x = _mm_aesenc_si128(x, k); x = _mm_aesenc_si128(x, k); // 2
x = _mm_aesenc_si128(x, k); x = _mm_aesenc_si128(x, k); // 3
x = _mm_aesenc_si128(x, k); x = _mm_aesenc_si128(x, k); // 4
x = _mm_aesenc_si128(x, k); x = _mm_aesenc_si128(x, k); // 5
x = _mm_aesenc_si128(x, k); x = _mm_aesenc_si128(x, k); // 6
x = _mm_aesenc_si128(x, k); x = _mm_aesenc_si128(x, k); // 7
x = _mm_aesenc_si128(x, k); x = _mm_aesenc_si128(x, k); // 8
x = _mm_aesenc_si128(x, k); x = _mm_aesenc_si128(x, k); // 9
x = _mm_aesenclast_si128(x, k); x = _mm_aesenclast_si128(x, k); // 10
return x; return x;
} }
QUALIFIERS __m128 _my_cvtepu32_ps(const __m128i v) QUALIFIERS __m128 _my_cvtepu32_ps(const __m128i v)
{ {
#ifdef __AVX512F__ #ifdef __AVX512VL__
return _mm_cvtepu32_ps(v); return _mm_cvtepu32_ps(v);
#else #else
__m128i v2 = _mm_srli_epi32(v, 1); __m128i v2 = _mm_srli_epi32(v, 1);
...@@ -40,13 +44,14 @@ QUALIFIERS __m128 _my_cvtepu32_ps(const __m128i v) ...@@ -40,13 +44,14 @@ QUALIFIERS __m128 _my_cvtepu32_ps(const __m128i v)
#endif #endif
} }
QUALIFIERS __m128d _my_cvtepu64_pd(const __m128i v) QUALIFIERS __m128d _my_cvtepu64_pd(const __m128i x)
{ {
#ifdef __AVX512F__ #ifdef __AVX512VL__
return _mm_cvtepu64_pd(v); return _mm_cvtepu64_pd(x);
#else #else
#warning need to implement _my_cvtepu64_pd uint64 r[2];
return (__m128d) v; _mm_storeu_si128((__m128i*)r, x);
return _mm_set_pd((double)r[1], (double)r[0]);
#endif #endif
} }
...@@ -55,25 +60,27 @@ QUALIFIERS void aesni_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3 ...@@ -55,25 +60,27 @@ QUALIFIERS void aesni_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3
uint32 key0, uint32 key1, uint32 key2, uint32 key3, uint32 key0, uint32 key1, uint32 key2, uint32 key3,
double & rnd1, double & rnd2) double & rnd1, double & rnd2)
{ {
// pack input and call AES
__m128i c128 = _mm_set_epi32(ctr3, ctr2, ctr1, ctr0); __m128i c128 = _mm_set_epi32(ctr3, ctr2, ctr1, ctr0);
__m128i k128 = _mm_set_epi32(key3, key2, key1, key0); __m128i k128 = _mm_set_epi32(key3, key2, key1, key0);
c128 = aesni1xm128i(c128, k128); c128 = aesni1xm128i(c128, k128);
uint32 r[4]; // convert 32 to 64 bit and put 0th and 2nd element into x, 1st and 3rd element into y
_mm_storeu_si128((__m128i*)&r[0], c128); __m128i x = _mm_and_si128(c128, _mm_set_epi32(0, 0xffffffff, 0, 0xffffffff));
__m128i x = _mm_set_epi64x((uint64) r[2], (uint64) r[0]); __m128i y = _mm_and_si128(c128, _mm_set_epi32(0xffffffff, 0, 0xffffffff, 0));
__m128i y = _mm_set_epi64x((uint64) r[3], (uint64) r[1]); y = _mm_srli_si128(y, 4);
__m128i cnt = _mm_set_epi64x(53 - 32, 53 - 32); // calculate z = x ^ y << (53 - 32))
y = _mm_sll_epi64(y, cnt); __m128i z = _mm_sll_epi64(y, _mm_set_epi64x(53 - 32, 53 - 32));
__m128i z = _mm_xor_si128(x, y); z = _mm_xor_si128(x, z);
// convert uint64 to double
__m128d rs = _my_cvtepu64_pd(z); __m128d rs = _my_cvtepu64_pd(z);
const __m128d tp53 = _mm_set_pd1(TWOPOW53_INV_DOUBLE); // calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
const __m128d tp54 = _mm_set_pd1(TWOPOW53_INV_DOUBLE/2.0); rs = _mm_mul_pd(rs, _mm_set_pd1(TWOPOW53_INV_DOUBLE));
rs = _mm_mul_pd(rs, tp53); rs = _mm_add_pd(rs, _mm_set_pd1(TWOPOW53_INV_DOUBLE/2.0));
rs = _mm_add_pd(rs, tp54);
// store result
double rr[2]; double rr[2];
_mm_storeu_pd(rr, rs); _mm_storeu_pd(rr, rs);
rnd1 = rr[0]; rnd1 = rr[0];
...@@ -85,16 +92,18 @@ QUALIFIERS void aesni_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3, ...@@ -85,16 +92,18 @@ QUALIFIERS void aesni_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3, uint32 key0, uint32 key1, uint32 key2, uint32 key3,
float & rnd1, float & rnd2, float & rnd3, float & rnd4) float & rnd1, float & rnd2, float & rnd3, float & rnd4)
{ {
// pack input and call AES
__m128i c128 = _mm_set_epi32(ctr3, ctr2, ctr1, ctr0); __m128i c128 = _mm_set_epi32(ctr3, ctr2, ctr1, ctr0);
__m128i k128 = _mm_set_epi32(key3, key2, key1, key0); __m128i k128 = _mm_set_epi32(key3, key2, key1, key0);
c128 = aesni1xm128i(c128, k128); c128 = aesni1xm128i(c128, k128);
// convert uint32 to float
__m128 rs = _my_cvtepu32_ps(c128); __m128 rs = _my_cvtepu32_ps(c128);
const __m128 tp32 = _mm_set_ps1(TWOPOW32_INV_FLOAT); // calculate rs * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
const __m128 tp33 = _mm_set_ps1(TWOPOW32_INV_FLOAT/2.0f); rs = _mm_mul_ps(rs, _mm_set_ps1(TWOPOW32_INV_FLOAT));
rs = _mm_mul_ps(rs, tp32); rs = _mm_add_ps(rs, _mm_set_ps1(TWOPOW32_INV_FLOAT/2.0f));
rs = _mm_add_ps(rs, tp33);
// store result
float r[4]; float r[4];
_mm_storeu_ps(r, rs); _mm_storeu_ps(r, rs);
rnd1 = r[0]; rnd1 = r[0];
......
...@@ -21,7 +21,7 @@ def _get_rng_template(name, data_type, num_vars): ...@@ -21,7 +21,7 @@ def _get_rng_template(name, data_type, num_vars):
def _get_rng_code(template, dialect, vector_instruction_set, time_step, offsets, keys, dim, result_symbols): def _get_rng_code(template, dialect, vector_instruction_set, time_step, offsets, keys, dim, result_symbols):
parameters = [time_step] + [LoopOverCoordinate.get_loop_counter_symbol(i) + offsets[i] parameters = [time_step] + [LoopOverCoordinate.get_loop_counter_symbol(i) + offsets[i]
for i in range(dim)] + [0] * (3-dim) + list(keys) for i in range(dim)] + [0] * (3 - dim) + list(keys)
if dialect == 'cuda' or (dialect == 'c' and vector_instruction_set is None): if dialect == 'cuda' or (dialect == 'c' and vector_instruction_set is None):
return template.format(parameters=', '.join(str(p) for p in parameters), return template.format(parameters=', '.join(str(p) for p in parameters),
...@@ -67,7 +67,7 @@ class RNGBase(CustomCodeNode): ...@@ -67,7 +67,7 @@ class RNGBase(CustomCodeNode):
def get_code(self, dialect, vector_instruction_set): def get_code(self, dialect, vector_instruction_set):
template = _get_rng_template(self._name, self._data_type, self._num_vars) template = _get_rng_template(self._name, self._data_type, self._num_vars)
return _get_rng_code(template, dialect, vector_instruction_set, return _get_rng_code(template, dialect, vector_instruction_set,
self._time_step, self._offsets, self.keys, self._dim, self.result_symbols) self._time_step, self._offsets, self.keys, self._dim, self.result_symbols)
def __repr__(self): def __repr__(self):
return (", ".join(['{}'] * self._num_vars) + " <- {}RNG").format(*self.result_symbols, self._name.capitalize()) return (", ".join(['{}'] * self._num_vars) + " <- {}RNG").format(*self.result_symbols, self._name.capitalize())
......
...@@ -76,12 +76,6 @@ def test_aesni_double(): ...@@ -76,12 +76,6 @@ def test_aesni_double():
arr = dh.gather_array('f') arr = dh.gather_array('f')
assert np.logical_and(arr <= 1.0, arr >= 0).all() assert np.logical_and(arr <= 1.0, arr >= 0).all()
#x = aesni_reference[:,:,0::2]
#y = aesni_reference[:,:,1::2]
#z = x ^ y << (53 - 32)
#double_reference = z * 2.**-53 + 2.**-54
#assert(np.allclose(arr, double_reference, rtol=0, atol=np.finfo(np.float64).eps))
def test_aesni_float(): def test_aesni_float():
dh = ps.create_data_handling((2, 2), default_ghost_layers=0, default_target="cpu") dh = ps.create_data_handling((2, 2), default_ghost_layers=0, default_target="cpu")
...@@ -97,7 +91,4 @@ def test_aesni_float(): ...@@ -97,7 +91,4 @@ def test_aesni_float():
dh.run_kernel(kernel, time_step=124) dh.run_kernel(kernel, time_step=124)
dh.all_to_cpu() dh.all_to_cpu()
arr = dh.gather_array('f') arr = dh.gather_array('f')
assert np.logical_and(arr <= 1.0, arr >= 0).all() assert np.logical_and(arr <= 1.0, arr >= 0).all()
\ No newline at end of file
#float_reference = aesni_reference * 2.**-32 + 2.**-33
#assert(np.allclose(arr, float_reference, rtol=0, atol=np.finfo(np.float32).eps))
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment