Skip to content
Snippets Groups Projects
Commit 42096a22 authored by Michael Kuron's avatar Michael Kuron :mortar_board:
Browse files

ARM Neon: Fix makeVec and add Philox

parent 15a7e62d
No related branches found
No related tags found
1 merge request!222ARM Neon: Fix makeVec and add Philox
Pipeline #30529 passed
......@@ -42,12 +42,12 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon', q
if q_registers is True:
q_reg = 'q'
width = 128 // bits[data_type]
intwidth = 128 // bits[data_type]
intwidth = 128 // bits['int']
suffix = f'q_f{bits[data_type]}'
else:
q_reg = ''
width = 64 // bits[data_type]
intwidth = 64 // bits[data_type]
intwidth = 64 // bits['int']
suffix = f'_f{bits[data_type]}'
result = dict()
......@@ -61,9 +61,10 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon', q
result[intrinsic_id] = 'v' + name + suffix + arg_string
result['makeVecConst'] = f'vdup{q_reg}_n_f{bits[data_type]}' + '({0})'
result['makeVec'] = f'vdup{q_reg}_n_f{bits[data_type]}' + '({0})'
result['makeVec'] = f'makeVec_f{bits[data_type]}' + '(' + ", ".join(['{' + str(i) + '}' for i in range(width)]) + \
')'
result['makeVecConstInt'] = f'vdup{q_reg}_n_s{bits["int"]}' + '({0})'
result['makeVecInt'] = f'vdup{q_reg}_n_s{bits["int"]}' + '({0})'
result['makeVecInt'] = f'makeVec_s{bits["int"]}' + '({0}, {1}, {2}, {3})'
result['+int'] = f"vaddq_s{bits['int']}" + "({0}, {1})"
......@@ -74,7 +75,7 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon', q
result[data_type] = f'float{bits[data_type]}x{width}_t'
result['int'] = f'int{bits["int"]}x{bits[data_type]}_t'
result['bool'] = f'uint{bits[data_type]}x{width}_t'
result['headers'] = ['<arm_neon.h>']
result['headers'] = ['<arm_neon.h>', '"arm_neon_helpers.h"']
result['!='] = f'vmvn{q_reg}_u{bits[data_type]}({result["=="]})'
......
......@@ -155,9 +155,8 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
loop_node.step = vector_width
loop_node.subs(substitutions)
vector_int_width = ast_node.instruction_set['intwidth']
vector_loop_counter = cast_func((loop_counter_symbol,) * vector_int_width,
VectorType(loop_counter_symbol.dtype, vector_int_width)) + \
cast_func(tuple(range(vector_int_width)), VectorType(loop_counter_symbol.dtype, vector_int_width))
vector_loop_counter = cast_func(loop_counter_symbol, VectorType(loop_counter_symbol.dtype, vector_int_width)) \
+ cast_func(tuple(range(vector_int_width)), VectorType(loop_counter_symbol.dtype, vector_int_width))
fast_subs(loop_node, {loop_counter_symbol: vector_loop_counter},
skip=lambda e: isinstance(e, ast.ResolvedFieldAccess) or isinstance(e, vector_memory_access))
......
#include <arm_neon.h>
inline float32x4_t makeVec_f32(float a, float b, float c, float d)
{
alignas(16) float data[4] = {a, b, c, d};
return vld1q_f32(data);
}
inline float64x2_t makeVec_f64(double a, double b)
{
alignas(16) double data[2] = {a, b};
return vld1q_f64(data);
}
inline int32x4_t makeVec_s32(int a, int b, int c, int d)
{
alignas(16) int data[4] = {a, b, c, d};
return vld1q_s32(data);
}
......@@ -12,6 +12,10 @@
#endif
#endif
#ifdef __ARM_NEON
#include <arm_neon.h>
#endif
#ifndef __CUDA_ARCH__
#define QUALIFIERS inline
#include "myintrin.h"
......@@ -277,6 +281,161 @@ QUALIFIERS void philox_double2(uint32 ctr0, __m128i ctr1, uint32 ctr2, uint32 ct
}
#endif
#if defined(__ARM_NEON)
QUALIFIERS void _philox4x32round(uint32x4_t* ctr, uint32x4_t* key)
{
uint32x4_t lohi0a = vreinterpretq_u32_u64(vmull_u32(vget_low_u32(ctr[0]), vdup_n_u32(PHILOX_M4x32_0)));
uint32x4_t lohi0b = vreinterpretq_u32_u64(vmull_high_u32(ctr[0], vdupq_n_u32(PHILOX_M4x32_0)));
uint32x4_t lohi1a = vreinterpretq_u32_u64(vmull_u32(vget_low_u32(ctr[2]), vdup_n_u32(PHILOX_M4x32_1)));
uint32x4_t lohi1b = vreinterpretq_u32_u64(vmull_high_u32(ctr[2], vdupq_n_u32(PHILOX_M4x32_1)));
uint32x4_t lo0 = vuzp1q_u32(lohi0a, lohi0b);
uint32x4_t lo1 = vuzp1q_u32(lohi1a, lohi1b);
uint32x4_t hi0 = vuzp2q_u32(lohi0a, lohi0b);
uint32x4_t hi1 = vuzp2q_u32(lohi1a, lohi1b);
ctr[0] = veorq_u32(veorq_u32(hi1, ctr[1]), key[0]);
ctr[1] = lo1;
ctr[2] = veorq_u32(veorq_u32(hi0, ctr[3]), key[1]);
ctr[3] = lo0;
}
QUALIFIERS void _philox4x32bumpkey(uint32x4_t* key)
{
key[0] = vaddq_u32(key[0], vdupq_n_u32(PHILOX_W32_0));
key[1] = vaddq_u32(key[1], vdupq_n_u32(PHILOX_W32_1));
}
template<bool high>
QUALIFIERS float64x2_t _uniform_double_hq(uint32x4_t x, uint32x4_t y)
{
// convert 32 to 64 bit
if (high)
{
x = vzip2q_u32(x, vdupq_n_u32(0));
y = vzip2q_u32(y, vdupq_n_u32(0));
}
else
{
x = vzip1q_u32(x, vdupq_n_u32(0));
y = vzip1q_u32(y, vdupq_n_u32(0));
}
// calculate z = x ^ y << (53 - 32))
uint64x2_t z = vshlq_n_u64(vreinterpretq_u64_u32(y), 53 - 32);
z = veorq_u64(vreinterpretq_u64_u32(x), z);
// convert uint64 to double
float64x2_t rs = vcvtq_f64_u64(z);
// calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
rs = vfmaq_f64(vdupq_n_f64(TWOPOW53_INV_DOUBLE/2.0), vdupq_n_f64(TWOPOW53_INV_DOUBLE), rs);
return rs;
}
QUALIFIERS void philox_float4(uint32x4_t ctr0, uint32x4_t ctr1, uint32x4_t ctr2, uint32x4_t ctr3,
uint32 key0, uint32 key1,
float32x4_t & rnd1, float32x4_t & rnd2, float32x4_t & rnd3, float32x4_t & rnd4)
{
uint32x4_t key[2] = {vdupq_n_u32(key0), vdupq_n_u32(key1)};
uint32x4_t ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
// convert uint32 to float
rnd1 = vcvtq_f32_u32(ctr[0]);
rnd2 = vcvtq_f32_u32(ctr[1]);
rnd3 = vcvtq_f32_u32(ctr[2]);
rnd4 = vcvtq_f32_u32(ctr[3]);
// calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
rnd1 = vfmaq_f32(vdupq_n_f32(TWOPOW32_INV_FLOAT/2.0), vdupq_n_f32(TWOPOW32_INV_FLOAT), rnd1);
rnd2 = vfmaq_f32(vdupq_n_f32(TWOPOW32_INV_FLOAT/2.0), vdupq_n_f32(TWOPOW32_INV_FLOAT), rnd2);
rnd3 = vfmaq_f32(vdupq_n_f32(TWOPOW32_INV_FLOAT/2.0), vdupq_n_f32(TWOPOW32_INV_FLOAT), rnd3);
rnd4 = vfmaq_f32(vdupq_n_f32(TWOPOW32_INV_FLOAT/2.0), vdupq_n_f32(TWOPOW32_INV_FLOAT), rnd4);
}
QUALIFIERS void philox_double2(uint32x4_t ctr0, uint32x4_t ctr1, uint32x4_t ctr2, uint32x4_t ctr3,
uint32 key0, uint32 key1,
float64x2_t & rnd1lo, float64x2_t & rnd1hi, float64x2_t & rnd2lo, float64x2_t & rnd2hi)
{
uint32x4_t key[2] = {vdupq_n_u32(key0), vdupq_n_u32(key1)};
uint32x4_t ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]);
rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]);
rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]);
rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
}
QUALIFIERS void philox_float4(uint32 ctr0, uint32x4_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
float32x4_t & rnd1, float32x4_t & rnd2, float32x4_t & rnd3, float32x4_t & rnd4)
{
uint32x4_t ctr0v = vdupq_n_u32(ctr0);
uint32x4_t ctr2v = vdupq_n_u32(ctr2);
uint32x4_t ctr3v = vdupq_n_u32(ctr3);
philox_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4);
}
QUALIFIERS void philox_float4(uint32 ctr0, int32x4_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
float32x4_t & rnd1, float32x4_t & rnd2, float32x4_t & rnd3, float32x4_t & rnd4)
{
philox_float4(ctr0, vreinterpretq_u32_s32(ctr1), ctr2, ctr3, key0, key1, rnd1, rnd2, rnd3, rnd4);
}
QUALIFIERS void philox_double2(uint32 ctr0, uint32x4_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
float64x2_t & rnd1lo, float64x2_t & rnd1hi, float64x2_t & rnd2lo, float64x2_t & rnd2hi)
{
uint32x4_t ctr0v = vdupq_n_u32(ctr0);
uint32x4_t ctr2v = vdupq_n_u32(ctr2);
uint32x4_t ctr3v = vdupq_n_u32(ctr3);
philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
}
QUALIFIERS void philox_double2(uint32 ctr0, uint32x4_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
float64x2_t & rnd1, float64x2_t & rnd2)
{
uint32x4_t ctr0v = vdupq_n_u32(ctr0);
uint32x4_t ctr2v = vdupq_n_u32(ctr2);
uint32x4_t ctr3v = vdupq_n_u32(ctr3);
float64x2_t ignore;
philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, ignore, rnd2, ignore);
}
QUALIFIERS void philox_double2(uint32 ctr0, int32x4_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
float64x2_t & rnd1, float64x2_t & rnd2)
{
philox_double2(ctr0, vreinterpretq_u32_s32(ctr1), ctr2, ctr3, key0, key1, rnd1, rnd2);
}
#endif
#ifdef __AVX2__
QUALIFIERS void _philox4x32round(__m256i* ctr, __m256i* key)
{
......
......@@ -58,8 +58,8 @@ class RNGBase(CustomCodeNode):
return code
def __repr__(self):
return (", ".join(['{}'] * self._num_vars) + " \\leftarrow {}RNG").format(*self.result_symbols,
self._name.capitalize())
return ", ".join([str(s) for s in self.result_symbols]) + " \\leftarrow " + \
self._name.capitalize() + "_RNG(" + ", ".join([str(a) for a in self.args]) + ")"
class PhiloxTwoDoubles(RNGBase):
......
......@@ -27,6 +27,8 @@ if get_compiler_config()['os'] == 'windows':
def test_rng(target, rng, precision, dtype, t=124, offsets=(0, 0), keys=(0, 0), offset_values=None):
if target == 'gpu':
pytest.importorskip('pycuda')
if instruction_sets and 'neon' in instruction_sets and rng == 'aesni':
pytest.xfail('AES not yet implemented for ARM Neon')
if rng == 'aesni' and len(keys) == 2:
keys *= 2
if offset_values is None:
......@@ -116,6 +118,8 @@ def test_rng_offsets(kind, vectorized):
@pytest.mark.parametrize('rng', ('philox', 'aesni'))
@pytest.mark.parametrize('precision,dtype', (('float', 'float'), ('double', 'double')))
def test_rng_vectorized(target, rng, precision, dtype, t=130, offsets=(1, 3), keys=(0, 0), offset_values=None):
if target == 'neon' and rng == 'aesni':
pytest.xfail('AES not yet implemented for ARM Neon')
cpu_vectorize_info = {'assume_inner_stride_one': True, 'assume_aligned': True, 'instruction_set': target}
dh = ps.create_data_handling((17, 17), default_ghost_layers=0, default_target='cpu')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment