Skip to content
Snippets Groups Projects
Commit 6209c0da authored by Michael Kuron's avatar Michael Kuron :mortar_board: Committed by Markus Holzer
Browse files

OpenCL RNG

parent fce1a417
No related branches found
No related tags found
1 merge request!247OpenCL RNG
#include <cstdint> #ifndef __OPENCL_VERSION__
#if defined(__SSE2__) || defined(_MSC_VER) #if defined(__SSE2__) || defined(_MSC_VER)
#include <emmintrin.h> // SSE2 #include <emmintrin.h> // SSE2
#endif #endif
...@@ -33,12 +32,15 @@ ...@@ -33,12 +32,15 @@
#ifdef __riscv_v #ifdef __riscv_v
#include <riscv_vector.h> #include <riscv_vector.h>
#endif #endif
#endif
#ifndef __CUDA_ARCH__ #ifdef __CUDA_ARCH__
#define QUALIFIERS static __forceinline__ __device__
#elif defined(__OPENCL_VERSION__)
#define QUALIFIERS static inline
#else
#define QUALIFIERS inline #define QUALIFIERS inline
#include "myintrin.h" #include "myintrin.h"
#else
#define QUALIFIERS static __forceinline__ __device__
#endif #endif
#define PHILOX_W32_0 (0x9E3779B9) #define PHILOX_W32_0 (0x9E3779B9)
...@@ -48,8 +50,15 @@ ...@@ -48,8 +50,15 @@
#define TWOPOW53_INV_DOUBLE (1.1102230246251565e-16) #define TWOPOW53_INV_DOUBLE (1.1102230246251565e-16)
#define TWOPOW32_INV_FLOAT (2.3283064e-10f) #define TWOPOW32_INV_FLOAT (2.3283064e-10f)
#ifdef __OPENCL_VERSION__
#include "opencl_stdint.h"
typedef uint32_t uint32;
typedef uint64_t uint64;
#else
#include <cstdint>
typedef std::uint32_t uint32; typedef std::uint32_t uint32;
typedef std::uint64_t uint64; typedef std::uint64_t uint64;
#endif
#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_SVE_BITS) && __ARM_FEATURE_SVE_BITS > 0 #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_SVE_BITS) && __ARM_FEATURE_SVE_BITS > 0
typedef svfloat32_t svfloat32_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS))); typedef svfloat32_t svfloat32_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS)));
...@@ -67,6 +76,9 @@ QUALIFIERS uint32 mulhilo32(uint32 a, uint32 b, uint32* hip) ...@@ -67,6 +76,9 @@ QUALIFIERS uint32 mulhilo32(uint32 a, uint32 b, uint32* hip)
#if defined(__powerpc__) && (!defined(__clang__) || defined(__xlC__)) #if defined(__powerpc__) && (!defined(__clang__) || defined(__xlC__))
*hip = __mulhwu(a,b); *hip = __mulhwu(a,b);
return a*b; return a*b;
#elif defined(__OPENCL_VERSION__)
*hip = mul_hi(a,b);
return a*b;
#else #else
uint64 product = ((uint64)a) * ((uint64)b); uint64 product = ((uint64)a) * ((uint64)b);
*hip = product >> 32; *hip = product >> 32;
...@@ -106,7 +118,12 @@ QUALIFIERS double _uniform_double_hq(uint32 x, uint32 y) ...@@ -106,7 +118,12 @@ QUALIFIERS double _uniform_double_hq(uint32 x, uint32 y)
QUALIFIERS void philox_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3, QUALIFIERS void philox_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, double & rnd1, double & rnd2) uint32 key0, uint32 key1,
#ifdef __OPENCL_VERSION__
double * rnd1, double * rnd2)
#else
double & rnd1, double & rnd2)
#endif
{ {
uint32 key[2] = {key0, key1}; uint32 key[2] = {key0, key1};
uint32 ctr[4] = {ctr0, ctr1, ctr2, ctr3}; uint32 ctr[4] = {ctr0, ctr1, ctr2, ctr3};
...@@ -121,14 +138,23 @@ QUALIFIERS void philox_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr ...@@ -121,14 +138,23 @@ QUALIFIERS void philox_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9 _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10 _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
#ifdef __OPENCL_VERSION__
*rnd1 = _uniform_double_hq(ctr[0], ctr[1]);
*rnd2 = _uniform_double_hq(ctr[2], ctr[3]);
#else
rnd1 = _uniform_double_hq(ctr[0], ctr[1]); rnd1 = _uniform_double_hq(ctr[0], ctr[1]);
rnd2 = _uniform_double_hq(ctr[2], ctr[3]); rnd2 = _uniform_double_hq(ctr[2], ctr[3]);
#endif
} }
QUALIFIERS void philox_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3, QUALIFIERS void philox_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key0, uint32 key1,
#ifdef __OPENCL_VERSION__
float * rnd1, float * rnd2, float * rnd3, float * rnd4)
#else
float & rnd1, float & rnd2, float & rnd3, float & rnd4) float & rnd1, float & rnd2, float & rnd3, float & rnd4)
#endif
{ {
uint32 key[2] = {key0, key1}; uint32 key[2] = {key0, key1};
uint32 ctr[4] = {ctr0, ctr1, ctr2, ctr3}; uint32 ctr[4] = {ctr0, ctr1, ctr2, ctr3};
...@@ -143,13 +169,20 @@ QUALIFIERS void philox_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3 ...@@ -143,13 +169,20 @@ QUALIFIERS void philox_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9 _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10 _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
#ifdef __OPENCL_VERSION__
*rnd1 = ctr[0] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f);
*rnd2 = ctr[1] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f);
*rnd3 = ctr[2] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f);
*rnd4 = ctr[3] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f);
#else
rnd1 = ctr[0] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f); rnd1 = ctr[0] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f);
rnd2 = ctr[1] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f); rnd2 = ctr[1] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f);
rnd3 = ctr[2] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f); rnd3 = ctr[2] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f);
rnd4 = ctr[3] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f); rnd4 = ctr[3] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f);
#endif
} }
#ifndef __CUDA_ARCH__ #if !defined(__CUDA_ARCH__) && !defined(__OPENCL_VERSION__)
#if defined(__SSE4_1__) || defined(_MSC_VER) #if defined(__SSE4_1__) || defined(_MSC_VER)
QUALIFIERS void _philox4x32round(__m128i* ctr, __m128i* key) QUALIFIERS void _philox4x32round(__m128i* ctr, __m128i* key)
{ {
...@@ -318,13 +351,13 @@ QUALIFIERS void _philox4x32round(__vector unsigned int* ctr, __vector unsigned i ...@@ -318,13 +351,13 @@ QUALIFIERS void _philox4x32round(__vector unsigned int* ctr, __vector unsigned i
{ {
#ifndef _ARCH_PWR8 #ifndef _ARCH_PWR8
__vector unsigned int lo0 = vec_mul(ctr[0], vec_splats(PHILOX_M4x32_0)); __vector unsigned int lo0 = vec_mul(ctr[0], vec_splats(PHILOX_M4x32_0));
__vector unsigned int lo1 = vec_mul(ctr[2], vec_splats(PHILOX_M4x32_1));
__vector unsigned int hi0 = vec_mulhuw(ctr[0], vec_splats(PHILOX_M4x32_0)); __vector unsigned int hi0 = vec_mulhuw(ctr[0], vec_splats(PHILOX_M4x32_0));
__vector unsigned int lo1 = vec_mul(ctr[2], vec_splats(PHILOX_M4x32_1));
__vector unsigned int hi1 = vec_mulhuw(ctr[2], vec_splats(PHILOX_M4x32_1)); __vector unsigned int hi1 = vec_mulhuw(ctr[2], vec_splats(PHILOX_M4x32_1));
#elif defined(_ARCH_PWR10) #elif defined(_ARCH_PWR10)
__vector unsigned int lo0 = vec_mul(ctr[0], vec_splats(PHILOX_M4x32_0)); __vector unsigned int lo0 = vec_mul(ctr[0], vec_splats(PHILOX_M4x32_0));
__vector unsigned int lo1 = vec_mul(ctr[2], vec_splats(PHILOX_M4x32_1));
__vector unsigned int hi0 = vec_mulh(ctr[0], vec_splats(PHILOX_M4x32_0)); __vector unsigned int hi0 = vec_mulh(ctr[0], vec_splats(PHILOX_M4x32_0));
__vector unsigned int lo1 = vec_mul(ctr[2], vec_splats(PHILOX_M4x32_1));
__vector unsigned int hi1 = vec_mulh(ctr[2], vec_splats(PHILOX_M4x32_1)); __vector unsigned int hi1 = vec_mulh(ctr[2], vec_splats(PHILOX_M4x32_1));
#else #else
__vector unsigned int lohi0a = (__vector unsigned int) vec_mule(ctr[0], vec_splats(PHILOX_M4x32_0)); __vector unsigned int lohi0a = (__vector unsigned int) vec_mule(ctr[0], vec_splats(PHILOX_M4x32_0));
...@@ -675,8 +708,8 @@ QUALIFIERS void philox_double2(uint32 ctr0, int32x4_t ctr1, uint32 ctr2, uint32 ...@@ -675,8 +708,8 @@ QUALIFIERS void philox_double2(uint32 ctr0, int32x4_t ctr1, uint32 ctr2, uint32
QUALIFIERS void _philox4x32round(svuint32x4_t & ctr, svuint32x2_t & key) QUALIFIERS void _philox4x32round(svuint32x4_t & ctr, svuint32x2_t & key)
{ {
svuint32_t lo0 = svmul_u32_x(svptrue_b32(), svget4_u32(ctr, 0), svdup_u32(PHILOX_M4x32_0)); svuint32_t lo0 = svmul_u32_x(svptrue_b32(), svget4_u32(ctr, 0), svdup_u32(PHILOX_M4x32_0));
svuint32_t lo1 = svmul_u32_x(svptrue_b32(), svget4_u32(ctr, 2), svdup_u32(PHILOX_M4x32_1));
svuint32_t hi0 = svmulh_u32_x(svptrue_b32(), svget4_u32(ctr, 0), svdup_u32(PHILOX_M4x32_0)); svuint32_t hi0 = svmulh_u32_x(svptrue_b32(), svget4_u32(ctr, 0), svdup_u32(PHILOX_M4x32_0));
svuint32_t lo1 = svmul_u32_x(svptrue_b32(), svget4_u32(ctr, 2), svdup_u32(PHILOX_M4x32_1));
svuint32_t hi1 = svmulh_u32_x(svptrue_b32(), svget4_u32(ctr, 2), svdup_u32(PHILOX_M4x32_1)); svuint32_t hi1 = svmulh_u32_x(svptrue_b32(), svget4_u32(ctr, 2), svdup_u32(PHILOX_M4x32_1));
ctr = svset4_u32(ctr, 0, sveor_u32_x(svptrue_b32(), sveor_u32_x(svptrue_b32(), hi1, svget4_u32(ctr, 1)), svget2_u32(key, 0))); ctr = svset4_u32(ctr, 0, sveor_u32_x(svptrue_b32(), sveor_u32_x(svptrue_b32(), hi1, svget4_u32(ctr, 1)), svget2_u32(key, 0)));
...@@ -827,8 +860,8 @@ QUALIFIERS void _philox4x32round(vuint32m1_t & ctr0, vuint32m1_t & ctr1, vuint32 ...@@ -827,8 +860,8 @@ QUALIFIERS void _philox4x32round(vuint32m1_t & ctr0, vuint32m1_t & ctr1, vuint32
vuint32m1_t key0, vuint32m1_t key1) vuint32m1_t key0, vuint32m1_t key1)
{ {
vuint32m1_t lo0 = vmul_vv_u32m1(ctr0, vmv_v_x_u32m1(PHILOX_M4x32_0, vsetvlmax_e32m1()), vsetvlmax_e32m1()); vuint32m1_t lo0 = vmul_vv_u32m1(ctr0, vmv_v_x_u32m1(PHILOX_M4x32_0, vsetvlmax_e32m1()), vsetvlmax_e32m1());
vuint32m1_t lo1 = vmul_vv_u32m1(ctr2, vmv_v_x_u32m1(PHILOX_M4x32_1, vsetvlmax_e32m1()), vsetvlmax_e32m1());
vuint32m1_t hi0 = vmulhu_vv_u32m1(ctr0, vmv_v_x_u32m1(PHILOX_M4x32_0, vsetvlmax_e32m1()), vsetvlmax_e32m1()); vuint32m1_t hi0 = vmulhu_vv_u32m1(ctr0, vmv_v_x_u32m1(PHILOX_M4x32_0, vsetvlmax_e32m1()), vsetvlmax_e32m1());
vuint32m1_t lo1 = vmul_vv_u32m1(ctr2, vmv_v_x_u32m1(PHILOX_M4x32_1, vsetvlmax_e32m1()), vsetvlmax_e32m1());
vuint32m1_t hi1 = vmulhu_vv_u32m1(ctr2, vmv_v_x_u32m1(PHILOX_M4x32_1, vsetvlmax_e32m1()), vsetvlmax_e32m1()); vuint32m1_t hi1 = vmulhu_vv_u32m1(ctr2, vmv_v_x_u32m1(PHILOX_M4x32_1, vsetvlmax_e32m1()), vsetvlmax_e32m1());
ctr0 = vxor_vv_u32m1(vxor_vv_u32m1(hi1, ctr1, vsetvlmax_e32m1()), key0, vsetvlmax_e32m1()); ctr0 = vxor_vv_u32m1(vxor_vv_u32m1(hi1, ctr1, vsetvlmax_e32m1()), key0, vsetvlmax_e32m1());
......
...@@ -53,8 +53,9 @@ class RNGBase(CustomCodeNode): ...@@ -53,8 +53,9 @@ class RNGBase(CustomCodeNode):
else: else:
code += f"{vector_instruction_set[r.dtype.base_name] if vector_instruction_set else r.dtype} " + \ code += f"{vector_instruction_set[r.dtype.base_name] if vector_instruction_set else r.dtype} " + \
f"{r.name};\n" f"{r.name};\n"
code += (self._name + "(" + ", ".join([print_arg(a) for a in self.args] args = [print_arg(a) for a in self.args] + \
+ [r.name for r in self.result_symbols]) + ");\n") [('&' if dialect == 'opencl' else '') + r.name for r in self.result_symbols]
code += (self._name + "(" + ", ".join(args) + ");\n")
return code return code
def __repr__(self): def __repr__(self):
......
...@@ -21,12 +21,16 @@ if get_compiler_config()['os'] == 'windows': ...@@ -21,12 +21,16 @@ if get_compiler_config()['os'] == 'windows':
instruction_sets.remove('avx512') instruction_sets.remove('avx512')
@pytest.mark.parametrize('target,rng', (('cpu', 'philox'), ('cpu', 'aesni'), ('gpu', 'philox'))) @pytest.mark.parametrize('target,rng', (('cpu', 'philox'), ('cpu', 'aesni'), ('gpu', 'philox'), ('opencl', 'philox')))
@pytest.mark.parametrize('precision', ('float', 'double')) @pytest.mark.parametrize('precision', ('float', 'double'))
@pytest.mark.parametrize('dtype', ('float', 'double')) @pytest.mark.parametrize('dtype', ('float', 'double'))
def test_rng(target, rng, precision, dtype, t=124, offsets=(0, 0), keys=(0, 0), offset_values=None): def test_rng(target, rng, precision, dtype, t=124, offsets=(0, 0), keys=(0, 0), offset_values=None):
if target == 'gpu': if target == 'gpu':
pytest.importorskip('pycuda') pytest.importorskip('pycuda')
if target == 'opencl':
pytest.importorskip('pyopencl')
from pystencils.opencl.opencljit import init_globally
init_globally()
if instruction_sets and set(['neon', 'sve', 'vsx', 'rvv']).intersection(instruction_sets) and rng == 'aesni': if instruction_sets and set(['neon', 'sve', 'vsx', 'rvv']).intersection(instruction_sets) and rng == 'aesni':
pytest.xfail('AES not yet implemented for this architecture') pytest.xfail('AES not yet implemented for this architecture')
if rng == 'aesni' and len(keys) == 2: if rng == 'aesni' and len(keys) == 2:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment