Commit 4c842f1a authored by Michael Kuron

vectorize philox_float4

parent 745be03a
@@ -4,7 +4,7 @@
#include <emmintrin.h> // SSE2
#include <wmmintrin.h> // AES
-#ifdef __AVX512VL__
+#if defined(__AVX512VL__) || defined(__AVX512F__)
#include <immintrin.h> // AVX*
#endif
#include <cstdint>
@@ -33,7 +33,7 @@ QUALIFIERS __m128i aesni1xm128i(const __m128i & in, const __m128i & k) {
QUALIFIERS __m128 _my_cvtepu32_ps(const __m128i v)
{
-#ifdef __AVX512VL__
+#if defined(__AVX512VL__) || defined(__AVX512F__)
return _mm_cvtepu32_ps(v);
#else
__m128i v2 = _mm_srli_epi32(v, 1);
@@ -46,7 +46,7 @@ QUALIFIERS __m128 _my_cvtepu32_ps(const __m128i v)
QUALIFIERS __m128d _my_cvtepu64_pd(const __m128i x)
{
-#ifdef __AVX512VL__
+#if defined(__AVX512VL__) || defined(__AVX512F__)
return _mm_cvtepu64_pd(x);
#else
uint64 r[2];
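SSE2 also lacks an unsigned 64-bit-to-double conversion, which is why this helper exists; the fallback branch (cut off by the hunk boundary above) stores the lanes and converts them in scalar code. A minimal sketch of that general approach, as a hypothetical standalone helper rather than the file's exact body:

#include <cstdint>
#include <emmintrin.h>

// Hypothetical fallback: spill both uint64 lanes to memory, convert each
// one in scalar code, and repack the two doubles.
static inline __m128d cvtepu64_pd_fallback(const __m128i x)
{
    uint64_t r[2];
    _mm_storeu_si128((__m128i*)r, x);
    return _mm_set_pd((double)r[1], (double)r[0]); // _mm_set_pd takes (hi, lo)
}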
@@ -110,4 +110,5 @@ QUALIFIERS void aesni_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
rnd2 = r[1];
rnd3 = r[2];
rnd4 = r[3];
-}
\ No newline at end of file
+}
#include <cstdint>
#ifdef __SSE__
#include <emmintrin.h> // SSE2
#endif
#ifdef __AVX__
#include <immintrin.h> // AVX*
#else
#include <smmintrin.h> // SSE4
#ifdef __FMA__
#include <immintrin.h> // FMA
#endif
#endif
#ifndef __CUDA_ARCH__
#define QUALIFIERS inline
#else
@@ -78,7 +90,6 @@ QUALIFIERS void philox_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr
}
QUALIFIERS void philox_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
float & rnd1, float & rnd2, float & rnd3, float & rnd4)
@@ -100,4 +111,247 @@ QUALIFIERS void philox_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
rnd2 = ctr[1] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f);
rnd3 = ctr[2] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f);
rnd4 = ctr[3] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f);
-}
\ No newline at end of file
+}
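The affine map above turns each uint32 counter word into a float strictly inside (0,1): with TWOPOW32_INV_FLOAT = 2^-32, the smallest input 0 maps to 2^-33 and the largest input 2^32 - 1 maps to 1 - 2^-32 + 2^-33, still below 1. A scalar sketch of the same transform; the constant's value is assumed to match the file's definition:

#include <cstdint>
#include <cstdio>

static const float TWOPOW32_INV_FLOAT = 2.3283064365386963e-10f; // 2^-32

float u32_to_unit_float(uint32_t x)
{
    // x * 2^-32 + 2^-33, strictly inside (0, 1) for every x
    return x * TWOPOW32_INV_FLOAT + TWOPOW32_INV_FLOAT / 2.0f;
}

int main()
{
    printf("%.10g %.10g\n", u32_to_unit_float(0u), u32_to_unit_float(0xFFFFFFFFu));
    return 0;
}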
#ifdef __SSE__
QUALIFIERS void _philox4x32round(__m128i* ctr, __m128i* key)
{
__m128i lohi0a = _mm_mul_epu32(ctr[0], _mm_set1_epi32(PHILOX_M4x32_0));
__m128i lohi0b = _mm_mul_epu32(_mm_srli_epi64(ctr[0], 32), _mm_set1_epi32(PHILOX_M4x32_0));
__m128i lohi1a = _mm_mul_epu32(ctr[2], _mm_set1_epi32(PHILOX_M4x32_1));
__m128i lohi1b = _mm_mul_epu32(_mm_srli_epi64(ctr[2], 32), _mm_set1_epi32(PHILOX_M4x32_1));
lohi0a = _mm_shuffle_epi32(lohi0a, 0xD8);
lohi0b = _mm_shuffle_epi32(lohi0b, 0xD8);
lohi1a = _mm_shuffle_epi32(lohi1a, 0xD8);
lohi1b = _mm_shuffle_epi32(lohi1b, 0xD8);
__m128i lo0 = _mm_unpacklo_epi32(lohi0a, lohi0b);
__m128i hi0 = _mm_unpackhi_epi32(lohi0a, lohi0b);
__m128i lo1 = _mm_unpacklo_epi32(lohi1a, lohi1b);
__m128i hi1 = _mm_unpackhi_epi32(lohi1a, lohi1b);
ctr[0] = _mm_xor_si128(_mm_xor_si128(hi1, ctr[1]), key[0]);
ctr[1] = lo1;
ctr[2] = _mm_xor_si128(_mm_xor_si128(hi0, ctr[3]), key[1]);
ctr[3] = lo0;
}
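The shuffle/unpack sequence above only reassembles the 64-bit products into 32-bit halves: _mm_mul_epu32 multiplies the even lanes, the srli-by-32 copy covers the odd lanes, the 0xD8 shuffle groups each product pair into [lo, lo, hi, hi], and the unpacks interleave the even- and odd-lane results. One round in scalar form for comparison; the multipliers are assumed to be the standard Philox constants behind the file's macros:

#include <cstdint>

// One scalar Philox-4x32 round; the SSE code above runs four of these
// side by side, one per 32-bit lane.
static inline void philox4x32round_scalar(uint32_t ctr[4], const uint32_t key[2])
{
    uint64_t p0 = (uint64_t)ctr[0] * 0xD2511F53u; // PHILOX_M4x32_0
    uint64_t p1 = (uint64_t)ctr[2] * 0xCD9E8D57u; // PHILOX_M4x32_1
    uint32_t n0 = (uint32_t)(p1 >> 32) ^ ctr[1] ^ key[0];
    uint32_t n1 = (uint32_t)p1;
    uint32_t n2 = (uint32_t)(p0 >> 32) ^ ctr[3] ^ key[1];
    uint32_t n3 = (uint32_t)p0;
    ctr[0] = n0; ctr[1] = n1; ctr[2] = n2; ctr[3] = n3;
}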
QUALIFIERS void _philox4x32bumpkey(__m128i* key)
{
key[0] = _mm_add_epi32(key[0], _mm_set1_epi32(PHILOX_W32_0));
key[1] = _mm_add_epi32(key[1], _mm_set1_epi32(PHILOX_W32_1));
}
QUALIFIERS __m128 _my_cvtepu32_ps(const __m128i v)
{
#if defined(__AVX512VL__) || defined(__AVX512F__)
return _mm_cvtepu32_ps(v);
#else
__m128i v2 = _mm_srli_epi32(v, 1);
__m128i v1 = _mm_and_si128(v, _mm_set1_epi32(1));
__m128 v2f = _mm_cvtepi32_ps(v2);
__m128 v1f = _mm_cvtepi32_ps(v1);
return _mm_add_ps(_mm_add_ps(v2f, v2f), v1f);
#endif
}
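The fallback exists because _mm_cvtepi32_ps is a signed conversion: inputs at or above 2^31 would come out negative. Splitting v as 2*(v >> 1) + (v & 1) keeps every intermediate below 2^31, and both halves convert exactly. The same trick in scalar form:

#include <cstdint>

// Convert an unsigned 32-bit value through the signed int32 path.
float u32_to_float_via_signed(uint32_t v)
{
    float v2f = (float)(int32_t)(v >> 1); // v >> 1 < 2^31, safe as signed
    float v1f = (float)(int32_t)(v & 1u); // 0 or 1
    return v2f + v2f + v1f;               // equals (float)v up to rounding
}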
QUALIFIERS void philox_float16(__m128i ctr0, __m128i ctr1, __m128i ctr2, __m128i ctr3,
uint32 key0, uint32 key1,
__m128 & rnd1, __m128 & rnd2, __m128 & rnd3, __m128 & rnd4)
{
__m128i key[2] = {_mm_set1_epi32(key0), _mm_set1_epi32(key1)};
__m128i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
// convert uint32 to float
rnd1 = _my_cvtepu32_ps(ctr[0]);
rnd2 = _my_cvtepu32_ps(ctr[1]);
rnd3 = _my_cvtepu32_ps(ctr[2]);
rnd4 = _my_cvtepu32_ps(ctr[3]);
// calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
#ifdef __FMA__
rnd1 = _mm_fmadd_ps(rnd1, _mm_set_ps1(TWOPOW32_INV_FLOAT), _mm_set_ps1(TWOPOW32_INV_FLOAT/2.0));
rnd2 = _mm_fmadd_ps(rnd2, _mm_set_ps1(TWOPOW32_INV_FLOAT), _mm_set_ps1(TWOPOW32_INV_FLOAT/2.0));
rnd3 = _mm_fmadd_ps(rnd3, _mm_set_ps1(TWOPOW32_INV_FLOAT), _mm_set_ps1(TWOPOW32_INV_FLOAT/2.0));
rnd4 = _mm_fmadd_ps(rnd4, _mm_set_ps1(TWOPOW32_INV_FLOAT), _mm_set_ps1(TWOPOW32_INV_FLOAT/2.0));
#else
rnd1 = _mm_mul_ps(rnd1, _mm_set_ps1(TWOPOW32_INV_FLOAT));
rnd1 = _mm_add_ps(rnd1, _mm_set_ps1(TWOPOW32_INV_FLOAT/2.0f));
rnd2 = _mm_mul_ps(rnd2, _mm_set_ps1(TWOPOW32_INV_FLOAT));
rnd2 = _mm_add_ps(rnd2, _mm_set_ps1(TWOPOW32_INV_FLOAT/2.0f));
rnd3 = _mm_mul_ps(rnd3, _mm_set_ps1(TWOPOW32_INV_FLOAT));
rnd3 = _mm_add_ps(rnd3, _mm_set_ps1(TWOPOW32_INV_FLOAT/2.0f));
rnd4 = _mm_mul_ps(rnd4, _mm_set_ps1(TWOPOW32_INV_FLOAT));
rnd4 = _mm_add_ps(rnd4, _mm_set_ps1(TWOPOW32_INV_FLOAT/2.0f));
#endif
}
#endif
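A hypothetical call of the SSE path (not part of the commit): each SIMD lane runs an independent Philox stream, so lane i of ctr0..ctr3 forms one 128-bit counter, and the four __m128 outputs carry 16 floats in total.

// Sketch: four streams whose first counter word is 0..3, the rest zero;
// the key values here are arbitrary.
__m128i c0 = _mm_set_epi32(3, 2, 1, 0);
__m128i c1 = _mm_setzero_si128();
__m128i c2 = _mm_setzero_si128();
__m128i c3 = _mm_setzero_si128();
__m128 r1, r2, r3, r4;
philox_float16(c0, c1, c2, c3, 0xDEADBEEFu, 0x12345678u, r1, r2, r3, r4);
// r1[i] is the first uniform float of stream i, r2[i] the second, and so on.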
#ifdef __AVX2__
QUALIFIERS void _philox4x32round(__m256i* ctr, __m256i* key)
{
__m256i lohi0a = _mm256_mul_epu32(ctr[0], _mm256_set1_epi32(PHILOX_M4x32_0));
__m256i lohi0b = _mm256_mul_epu32(_mm256_srli_epi64(ctr[0], 32), _mm256_set1_epi32(PHILOX_M4x32_0));
__m256i lohi1a = _mm256_mul_epu32(ctr[2], _mm256_set1_epi32(PHILOX_M4x32_1));
__m256i lohi1b = _mm256_mul_epu32(_mm256_srli_epi64(ctr[2], 32), _mm256_set1_epi32(PHILOX_M4x32_1));
lohi0a = _mm256_shuffle_epi32(lohi0a, 0xD8);
lohi0b = _mm256_shuffle_epi32(lohi0b, 0xD8);
lohi1a = _mm256_shuffle_epi32(lohi1a, 0xD8);
lohi1b = _mm256_shuffle_epi32(lohi1b, 0xD8);
__m256i lo0 = _mm256_unpacklo_epi32(lohi0a, lohi0b);
__m256i hi0 = _mm256_unpackhi_epi32(lohi0a, lohi0b);
__m256i lo1 = _mm256_unpacklo_epi32(lohi1a, lohi1b);
__m256i hi1 = _mm256_unpackhi_epi32(lohi1a, lohi1b);
ctr[0] = _mm256_xor_si256(_mm256_xor_si256(hi1, ctr[1]), key[0]);
ctr[1] = lo1;
ctr[2] = _mm256_xor_si256(_mm256_xor_si256(hi0, ctr[3]), key[1]);
ctr[3] = lo0;
}
QUALIFIERS void _philox4x32bumpkey(__m256i* key)
{
key[0] = _mm256_add_epi32(key[0], _mm256_set1_epi32(PHILOX_W32_0));
key[1] = _mm256_add_epi32(key[1], _mm256_set1_epi32(PHILOX_W32_1));
}
QUALIFIERS __m256 _my256_cvtepu32_ps(const __m256i v)
{
#if defined(__AVX512VL__) || defined(__AVX512F__)
return _mm256_cvtepu32_ps(v);
#else
__m256i v2 = _mm256_srli_epi32(v, 1);
__m256i v1 = _mm256_and_si256(v, _mm256_set1_epi32(1));
__m256 v2f = _mm256_cvtepi32_ps(v2);
__m256 v1f = _mm256_cvtepi32_ps(v1);
return _mm256_add_ps(_mm256_add_ps(v2f, v2f), v1f);
#endif
}
QUALIFIERS __m256 _my256_set_ps1(const float v)
{
return _mm256_set_ps(v, v, v, v, v, v, v, v);
}
QUALIFIERS void philox_float32(__m256i ctr0, __m256i ctr1, __m256i ctr2, __m256i ctr3,
uint32 key0, uint32 key1,
__m256 & rnd1, __m256 & rnd2, __m256 & rnd3, __m256 & rnd4)
{
__m256i key[2] = {_mm256_set1_epi32(key0), _mm256_set1_epi32(key1)};
__m256i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
// convert uint32 to float
rnd1 = _my256_cvtepu32_ps(ctr[0]);
rnd2 = _my256_cvtepu32_ps(ctr[1]);
rnd3 = _my256_cvtepu32_ps(ctr[2]);
rnd4 = _my256_cvtepu32_ps(ctr[3]);
// calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
#ifdef __FMA__
rnd1 = _mm256_fmadd_ps(rnd1, _my256_set_ps1(TWOPOW32_INV_FLOAT), _my256_set_ps1(TWOPOW32_INV_FLOAT/2.0));
rnd2 = _mm256_fmadd_ps(rnd2, _my256_set_ps1(TWOPOW32_INV_FLOAT), _my256_set_ps1(TWOPOW32_INV_FLOAT/2.0));
rnd3 = _mm256_fmadd_ps(rnd3, _my256_set_ps1(TWOPOW32_INV_FLOAT), _my256_set_ps1(TWOPOW32_INV_FLOAT/2.0));
rnd4 = _mm256_fmadd_ps(rnd4, _my256_set_ps1(TWOPOW32_INV_FLOAT), _my256_set_ps1(TWOPOW32_INV_FLOAT/2.0));
#else
rnd1 = _mm256_mul_ps(rnd1, _my256_set_ps1(TWOPOW32_INV_FLOAT));
rnd1 = _mm256_add_ps(rnd1, _my256_set_ps1(TWOPOW32_INV_FLOAT/2.0f));
rnd2 = _mm256_mul_ps(rnd2, _my256_set_ps1(TWOPOW32_INV_FLOAT));
rnd2 = _mm256_add_ps(rnd2, _my256_set_ps1(TWOPOW32_INV_FLOAT/2.0f));
rnd3 = _mm256_mul_ps(rnd3, _my256_set_ps1(TWOPOW32_INV_FLOAT));
rnd3 = _mm256_add_ps(rnd3, _my256_set_ps1(TWOPOW32_INV_FLOAT/2.0f));
rnd4 = _mm256_mul_ps(rnd4, _my256_set_ps1(TWOPOW32_INV_FLOAT));
rnd4 = _mm256_add_ps(rnd4, _my256_set_ps1(TWOPOW32_INV_FLOAT/2.0f));
#endif
}
#endif
#ifdef __AVX512F__
QUALIFIERS void _philox4x32round(__m512i* ctr, __m512i* key)
{
__m512i lohi0a = _mm512_mul_epu32(ctr[0], _mm512_set1_epi32(PHILOX_M4x32_0));
__m512i lohi0b = _mm512_mul_epu32(_mm512_srli_epi64(ctr[0], 32), _mm512_set1_epi32(PHILOX_M4x32_0));
__m512i lohi1a = _mm512_mul_epu32(ctr[2], _mm512_set1_epi32(PHILOX_M4x32_1));
__m512i lohi1b = _mm512_mul_epu32(_mm512_srli_epi64(ctr[2], 32), _mm512_set1_epi32(PHILOX_M4x32_1));
lohi0a = _mm512_shuffle_epi32(lohi0a, 0xD8);
lohi0b = _mm512_shuffle_epi32(lohi0b, 0xD8);
lohi1a = _mm512_shuffle_epi32(lohi1a, 0xD8);
lohi1b = _mm512_shuffle_epi32(lohi1b, 0xD8);
__m512i lo0 = _mm512_unpacklo_epi32(lohi0a, lohi0b);
__m512i hi0 = _mm512_unpackhi_epi32(lohi0a, lohi0b);
__m512i lo1 = _mm512_unpacklo_epi32(lohi1a, lohi1b);
__m512i hi1 = _mm512_unpackhi_epi32(lohi1a, lohi1b);
ctr[0] = _mm512_xor_si512(_mm512_xor_si512(hi1, ctr[1]), key[0]);
ctr[1] = lo1;
ctr[2] = _mm512_xor_si512(_mm512_xor_si512(hi0, ctr[3]), key[1]);
ctr[3] = lo0;
}
QUALIFIERS void _philox4x32bumpkey(__m512i* key)
{
key[0] = _mm512_add_epi32(key[0], _mm512_set1_epi32(PHILOX_W32_0));
key[1] = _mm512_add_epi32(key[1], _mm512_set1_epi32(PHILOX_W32_1));
}
QUALIFIERS __m512 _my512_set_ps1(const float v)
{
return _mm512_set_ps(v, v, v, v, v, v, v, v, v, v, v, v, v, v, v, v);
}
QUALIFIERS void philox_float64(__m512i ctr0, __m512i ctr1, __m512i ctr2, __m512i ctr3,
uint32 key0, uint32 key1,
__m512 & rnd1, __m512 & rnd2, __m512 & rnd3, __m512 & rnd4)
{
__m512i key[2] = {_mm512_set1_epi32(key0), _mm512_set1_epi32(key1)};
__m512i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
// convert uint32 to float
rnd1 = _mm512_cvtepu32_ps(ctr[0]);
rnd2 = _mm512_cvtepu32_ps(ctr[1]);
rnd3 = _mm512_cvtepu32_ps(ctr[2]);
rnd4 = _mm512_cvtepu32_ps(ctr[3]);
// calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
rnd1 = _mm512_fmadd_ps(rnd1, _my512_set_ps1(TWOPOW32_INV_FLOAT), _my512_set_ps1(TWOPOW32_INV_FLOAT/2.0));
rnd2 = _mm512_fmadd_ps(rnd2, _my512_set_ps1(TWOPOW32_INV_FLOAT), _my512_set_ps1(TWOPOW32_INV_FLOAT/2.0));
rnd3 = _mm512_fmadd_ps(rnd3, _my512_set_ps1(TWOPOW32_INV_FLOAT), _my512_set_ps1(TWOPOW32_INV_FLOAT/2.0));
rnd4 = _mm512_fmadd_ps(rnd4, _my512_set_ps1(TWOPOW32_INV_FLOAT), _my512_set_ps1(TWOPOW32_INV_FLOAT/2.0));
}
#endif
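The names encode width: philox_float4 produces 4 floats per call, and the SSE, AVX2, and AVX-512 variants produce 16, 32, and 64. A hypothetical compile-time selection of the widest available variant, for illustration only:

#if defined(__AVX512F__)
  #define PHILOX_FLOAT_VARIANT philox_float64 // 4 x __m512 = 64 floats per call
#elif defined(__AVX2__)
  #define PHILOX_FLOAT_VARIANT philox_float32 // 4 x __m256 = 32 floats per call
#elif defined(__SSE__)
  #define PHILOX_FLOAT_VARIANT philox_float16 // 4 x __m128 = 16 floats per call
#else
  #define PHILOX_FLOAT_VARIANT philox_float4  // scalar: 4 floats per call
#endif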