Skip to content
Snippets Groups Projects
Commit ec5dedb6 authored by Michael Kuron's avatar Michael Kuron :mortar_board:
Browse files

First attempt at calling vectorized RNG

parent 12e5f4d0
Branches philox-simd
No related tags found
No related merge requests found
......@@ -218,6 +218,30 @@ QUALIFIERS __m256i aesni1xm128i(const __m256i & in, const __m256i & k) {
return x;
}
QUALIFIERS void aesni_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
__m128 & rnd1, __m128 & rnd2, __m128 & rnd3, __m128 & rnd4)
{
__m128i ctr0v = _mm_add_epi32(_mm_set1_epi32(ctr0), _mm_set_epi32(3,2,1,0));
__m128i ctr1v = _mm_set1_epi32(ctr1);
__m128i ctr2v = _mm_set1_epi32(ctr2);
__m128i ctr3v = _mm_set1_epi32(ctr3);
aesni_float4(ctr0v, ctr1v, ctr2v, ctr3v, key0, key1, key2, key3, rnd1, rnd2, rnd3, rnd4);
}
QUALIFIERS void philox_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
__m128d & rnd1lo, __m128d & rnd1hi, __m128d & rnd2lo, __m128d & rnd2hi)
{
__m128i ctr0v = _mm_add_epi32(_mm_set1_epi32(ctr0), _mm_set_epi32(3,2,1,0));
__m128i ctr1v = _mm_set1_epi32(ctr1);
__m128i ctr2v = _mm_set1_epi32(ctr2);
__m128i ctr3v = _mm_set1_epi32(ctr3);
aesni_double2(ctr0v, ctr1v, ctr2v, ctr3v, key0, key1, key2, key3, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
}
template<bool high>
QUALIFIERS __m256d _uniform_double_hq(__m256i x, __m256i y)
{
......@@ -350,6 +374,30 @@ QUALIFIERS void aesni_double2(__m256i ctr0, __m256i ctr1, __m256i ctr2, __m256i
rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]);
rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
}
QUALIFIERS void aesni_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
__m256 & rnd1, __m256 & rnd2, __m256 & rnd3, __m256 & rnd4)
{
__m256i ctr0v = _mm256_add_epi32(_mm256_set1_epi32(ctr0), _mm256_set_epi32(7,6,5,4,3,2,1,0));
__m256i ctr1v = _mm256_set1_epi32(ctr1);
__m256i ctr2v = _mm256_set1_epi32(ctr2);
__m256i ctr3v = _mm256_set1_epi32(ctr3);
aesni_float4(ctr0v, ctr1v, ctr2v, ctr3v, key0, key1, key2, key3, rnd1, rnd2, rnd3, rnd4);
}
QUALIFIERS void aesni_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
__m256d & rnd1lo, __m256d & rnd1hi, __m256d & rnd2lo, __m256d & rnd2hi)
{
__m256i ctr0v = _mm256_add_epi32(_mm256_set1_epi32(ctr0), _mm256_set_epi32(7,6,5,4,3,2,1,0));
__m256i ctr1v = _mm256_set1_epi32(ctr1);
__m256i ctr2v = _mm256_set1_epi32(ctr2);
__m256i ctr3v = _mm256_set1_epi32(ctr3);
aesni_double2(ctr0v, ctr1v, ctr2v, ctr3v, key0, key1, key2, key3, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
}
#endif
......@@ -511,5 +559,29 @@ QUALIFIERS void aesni_double2(__m512i ctr0, __m512i ctr1, __m512i ctr2, __m512i
rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]);
rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
}
QUALIFIERS void aesni_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
__m512 & rnd1, __m512 & rnd2, __m512 & rnd3, __m512 & rnd4)
{
__m512i ctr0v = _mm512_add_epi32(_mm512_set1_epi32(ctr0), _mm512_set_epi32(15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0));
__m512i ctr1v = _mm512_set1_epi32(ctr1);
__m512i ctr2v = _mm512_set1_epi32(ctr2);
__m512i ctr3v = _mm512_set1_epi32(ctr3);
philox_float4(ctr0v, ctr1v, ctr2v, ctr3v, key0, key1, key2, key3, rnd1, rnd2, rnd3, rnd4);
}
QUALIFIERS void aesni_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
__m512d & rnd1lo, __m512d & rnd1hi, __m512d & rnd2lo, __m512d & rnd2hi)
{
__m512i ctr0v = _mm512_add_epi32(_mm512_set1_epi32(ctr0), _mm512_set_epi32(15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0));
__m512i ctr1v = _mm512_set1_epi32(ctr1);
__m512i ctr2v = _mm512_set1_epi32(ctr2);
__m512i ctr3v = _mm512_set1_epi32(ctr3);
philox_double2(ctr0v, ctr1v, ctr2v, ctr3v, key0, key1, key2, key3, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
}
#endif
......@@ -241,6 +241,30 @@ QUALIFIERS void philox_double2(__m128i ctr0, __m128i ctr1, __m128i ctr2, __m128i
rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]);
rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
}
QUALIFIERS void philox_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__m128 & rnd1, __m128 & rnd2, __m128 & rnd3, __m128 & rnd4)
{
__m128i ctr0v = _mm_add_epi32(_mm_set1_epi32(ctr0), _mm_set_epi32(3,2,1,0));
__m128i ctr1v = _mm_set1_epi32(ctr1);
__m128i ctr2v = _mm_set1_epi32(ctr2);
__m128i ctr3v = _mm_set1_epi32(ctr3);
philox_float4(ctr0v, ctr1v, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4);
}
QUALIFIERS void philox_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__m128d & rnd1lo, __m128d & rnd1hi, __m128d & rnd2lo, __m128d & rnd2hi)
{
__m128i ctr0v = _mm_add_epi32(_mm_set1_epi32(ctr0), _mm_set_epi32(3,2,1,0));
__m128i ctr1v = _mm_set1_epi32(ctr1);
__m128i ctr2v = _mm_set1_epi32(ctr2);
__m128i ctr3v = _mm_set1_epi32(ctr3);
philox_double2(ctr0v, ctr1v, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
}
#endif
#ifdef __AVX2__
......@@ -369,6 +393,30 @@ QUALIFIERS void philox_double2(__m256i ctr0, __m256i ctr1, __m256i ctr2, __m256i
rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]);
rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
}
QUALIFIERS void philox_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__m256 & rnd1, __m256 & rnd2, __m256 & rnd3, __m256 & rnd4)
{
__m256i ctr0v = _mm256_add_epi32(_mm256_set1_epi32(ctr0), _mm256_set_epi32(7,6,5,4,3,2,1,0));
__m256i ctr1v = _mm256_set1_epi32(ctr1);
__m256i ctr2v = _mm256_set1_epi32(ctr2);
__m256i ctr3v = _mm256_set1_epi32(ctr3);
philox_float4(ctr0v, ctr1v, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4);
}
QUALIFIERS void philox_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__m256d & rnd1lo, __m256d & rnd1hi, __m256d & rnd2lo, __m256d & rnd2hi)
{
__m256i ctr0v = _mm256_add_epi32(_mm256_set1_epi32(ctr0), _mm256_set_epi32(7,6,5,4,3,2,1,0));
__m256i ctr1v = _mm256_set1_epi32(ctr1);
__m256i ctr2v = _mm256_set1_epi32(ctr2);
__m256i ctr3v = _mm256_set1_epi32(ctr3);
philox_double2(ctr0v, ctr1v, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
}
#endif
#ifdef __AVX512F__
......@@ -481,6 +529,30 @@ QUALIFIERS void philox_double2(__m512i ctr0, __m512i ctr1, __m512i ctr2, __m512i
rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]);
rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
}
QUALIFIERS void philox_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__m512 & rnd1, __m512 & rnd2, __m512 & rnd3, __m512 & rnd4)
{
__m512i ctr0v = _mm512_add_epi32(_mm512_set1_epi32(ctr0), _mm512_set_epi32(15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0));
__m512i ctr1v = _mm512_set1_epi32(ctr1);
__m512i ctr2v = _mm512_set1_epi32(ctr2);
__m512i ctr3v = _mm512_set1_epi32(ctr3);
philox_float4(ctr0v, ctr1v, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4);
}
QUALIFIERS void philox_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__m512d & rnd1lo, __m512d & rnd1hi, __m512d & rnd2lo, __m512d & rnd2hi)
{
__m512i ctr0v = _mm512_add_epi32(_mm512_set1_epi32(ctr0), _mm512_set_epi32(15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0));
__m512i ctr1v = _mm512_set1_epi32(ctr1);
__m512i ctr2v = _mm512_set1_epi32(ctr2);
__m512i ctr3v = _mm512_set1_epi32(ctr3);
philox_double2(ctr0v, ctr1v, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
}
#endif
#endif
......@@ -6,16 +6,23 @@ from pystencils.astnodes import LoopOverCoordinate
from pystencils.backends.cbackend import CustomCodeNode
def _get_rng_template(name, data_type, num_vars):
def _data_type_to_str(data_type):
if data_type is np.float32:
c_type = "float"
return "float"
elif data_type is np.float64:
c_type = "double"
return "double"
elif type(data_type) is str:
return data_type
raise ValueError("%s is not a supported data type" % (data_type, ))
def _get_rng_template(name, data_type, num_vars):
c_type = _data_type_to_str(data_type)
template = "\n"
for i in range(num_vars):
template += "{{result_symbols[{}].dtype}} {{result_symbols[{}].name}};\n".format(i, i)
template += ("{}_{}{}({{parameters}}, " + ", ".join(["{{result_symbols[{}].name}}"] * num_vars) + ");\n") \
.format(name, c_type, num_vars, *tuple(range(num_vars)))
template += "{} {{result_symbols[{}].name}};\n".format(c_type, i, i)
template += ("{}({{parameters}}, " + ", ".join(["{{result_symbols[{}].name}}"] * num_vars) + ");\n") \
.format(name, *tuple(range(num_vars)))
return template
......@@ -23,7 +30,7 @@ def _get_rng_code(template, dialect, vector_instruction_set, time_step, offsets,
parameters = [time_step] + [LoopOverCoordinate.get_loop_counter_symbol(i) + offsets[i]
for i in range(dim)] + [0] * (3 - dim) + list(keys)
if dialect == 'cuda' or (dialect == 'c' and vector_instruction_set is None):
if dialect == 'cuda' or dialect == 'c':
return template.format(parameters=', '.join(str(p) for p in parameters),
result_symbols=result_symbols)
else:
......@@ -44,7 +51,7 @@ class RNGBase(CustomCodeNode):
super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols)
self._time_step = time_step
self._offsets = offsets
self.headers = ['"{}_rand.h"'.format(self._name)]
self.headers = ['"{}_rand.h"'.format(self._name.split('_')[0])]
self.keys = tuple(keys)
self._args = sp.sympify((dim, time_step, keys))
self._dim = dim
......@@ -65,7 +72,11 @@ class RNGBase(CustomCodeNode):
return self # nothing to replace inside this node - would destroy intermediate "dummy" by re-creating them
def get_code(self, dialect, vector_instruction_set):
template = _get_rng_template(self._name, self._data_type, self._num_vars)
if vector_instruction_set:
template = _get_rng_template(self._name, vector_instruction_set[_data_type_to_str(self._data_type)],
self._num_vars)
else:
template = _get_rng_template(self._name, self._data_type, self._num_vars)
return _get_rng_code(template, dialect, vector_instruction_set,
self._time_step, self._offsets, self.keys, self._dim, self.result_symbols)
......@@ -74,28 +85,28 @@ class RNGBase(CustomCodeNode):
class PhiloxTwoDoubles(RNGBase):
_name = "philox"
_name = "philox_double2"
_data_type = np.float64
_num_vars = 2
_num_keys = 2
class PhiloxFourFloats(RNGBase):
_name = "philox"
_name = "philox_float4"
_data_type = np.float32
_num_vars = 4
_num_keys = 2
class AESNITwoDoubles(RNGBase):
_name = "aesni"
_name = "aesni_double2"
_data_type = np.float64
_num_vars = 2
_num_keys = 4
class AESNIFourFloats(RNGBase):
_name = "aesni"
_name = "aesni_float4"
_data_type = np.float32
_num_vars = 4
_num_keys = 4
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment