Commit cdf73d8f authored by Michael Kuron, committed by Jan Hönig

Sizeless vectorization

parent cc645538
Merge request !234: Sizeless vectorization
Showing 452 additions and 98 deletions
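This commit teaches pystencils to emit vector-length-agnostic SIMD code: for sizeless Arm SVE (instruction set name 'sve', without a bit width) and the RISC-V vector extension ('rvv'), the vector width becomes a C expression evaluated at run time rather than a Python integer fixed at code-generation time. A minimal usage sketch based on the tests further down (the field setup is illustrative):

```python
import pystencils as ps

# Request vector-length-agnostic SVE code; 'rvv' works the same way.
f, g = ps.fields("f, g : double[2D]")
update_rule = [ps.Assignment(g[0, 0], f[0, 0] + f[1, 0])]
ast = ps.create_kernel(update_rule, cpu_vectorize_info={'instruction_set': 'sve'})
kernel = ast.compile()  # needs an SVE-capable compiler and CPU (or QEMU)
```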
@@ -178,7 +178,7 @@ arm64v9:
extends: .multiarch_template
image: i10git.cs.fau.de:5005/pycodegen/pycodegen/arm64
variables:
PYSTENCILS_SIMD: "sve256,sve512"
PYSTENCILS_SIMD: "sve256,sve512,sve"
ASAN_OPTIONS: detect_leaks=0
LD_PRELOAD: /usr/lib/aarch64-linux-gnu/libasan.so.6
before_script:
@@ -186,6 +186,20 @@ arm64v9:
- sed -i s/march=native/march=armv8-a+sve/g ~/.config/pystencils/config.json
- sed -i s/g\+\+/clang++/g ~/.config/pystencils/config.json
riscv64:
# The RISC-V vector extension is still experimental and needs special compiler flags.
# Once it is officially released, this job should be cleaned up to match the others.
extends: .multiarch_template
image: i10git.cs.fau.de:5005/pycodegen/pycodegen/riscv64
variables:
PYSTENCILS_SIMD: "rvv"
QEMU_CPU: "rv64,x-v=true"
before_script:
- *multiarch_before_script
- sed -i 's/march=native/march=rv64imfdv0p10 -menable-experimental-extensions/g' ~/.config/pystencils/config.json
- sed -i s/g\+\+/clang++/g ~/.config/pystencils/config.json
- sed -i 's/fopenmp/fopenmp=libgomp -I\/usr\/include\/riscv64-linux-gnu/g' ~/.config/pystencils/config.json
minimal-conda:
stage: test
except:
@@ -28,13 +28,19 @@ def aligned_empty(shape, byte_alignment=True, dtype=np.float64, byte_offset=0, o
elif byte_alignment == 'cacheline':
cacheline_sizes = [get_cacheline_size(is_name) for is_name in instruction_sets]
if all([s is None for s in cacheline_sizes]):
byte_alignment = max([get_vector_instruction_set(type_name, is_name)['width'] * np.dtype(dtype).itemsize
for is_name in instruction_sets])
widths = [get_vector_instruction_set(type_name, is_name)['width'] * np.dtype(dtype).itemsize
for is_name in instruction_sets
if type(get_vector_instruction_set(type_name, is_name)['width']) is int]
byte_alignment = 64 if all([s is None for s in widths]) else max(widths)
else:
byte_alignment = max([s for s in cacheline_sizes if s is not None])
elif not any([type(get_vector_instruction_set(type_name, is_name)['width']) is int
for is_name in instruction_sets]):
byte_alignment = 64
else:
byte_alignment = max([get_vector_instruction_set(type_name, is_name)['width'] * np.dtype(dtype).itemsize
for is_name in instruction_sets])
for is_name in instruction_sets
if type(get_vector_instruction_set(type_name, is_name)['width']) is int])
if (not align_inner_coordinate) or (not hasattr(shape, '__len__')):
size = np.prod(shape)
d = np.dtype(dtype)
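Sizeless instruction sets report their 'width' as a C expression (e.g. svcntd()) instead of a Python int, so the allocation alignment can no longer be computed from it at code-generation time; the changed branches above fall back to 64 bytes, a common cache-line size. A condensed sketch of the decision (helper names are hypothetical):

```python
def fallback_byte_alignment(instruction_sets, itemsize, width_of):
    # Only fixed-width sets contribute; sizeless ones report a C expression.
    widths = [width_of(name) * itemsize for name in instruction_sets
              if isinstance(width_of(name), int)]
    return max(widths) if widths else 64  # 64-byte default for sizeless ISAs
```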
@@ -19,9 +19,8 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
if instruction_set != 'neon' and not instruction_set.startswith('sve'):
raise NotImplementedError(instruction_set)
if instruction_set == 'sve':
raise NotImplementedError("sizeless SVE is not implemented")
if instruction_set.startswith('sve'):
cmp = 'cmp'
elif instruction_set.startswith('sve'):
cmp = 'cmp'
bitwidth = int(instruction_set[3:])
elif instruction_set == 'neon':
@@ -53,8 +52,16 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
'float': 32,
'int': 32}
width = bitwidth // bits[data_type]
intwidth = bitwidth // bits['int']
result = dict()
if instruction_set == 'sve':
width = 'svcntd()' if data_type == 'double' else 'svcntw()'
intwidth = 'svcntw()'
result['bytes'] = 'svcntb()'
else:
width = bitwidth // bits[data_type]
intwidth = bitwidth // bits['int']
result['bytes'] = bitwidth // 8
if instruction_set.startswith('sve'):
prefix = 'sv'
suffix = f'_f{bits[data_type]}'
@@ -62,11 +69,12 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
prefix = 'v'
suffix = f'q_f{bits[data_type]}'
result = dict()
result['bytes'] = bitwidth // 8
predicate = f'{prefix}whilelt_b{bits[data_type]}(0, {width})'
int_predicate = f'{prefix}whilelt_b{bits["int"]}(0, {intwidth})'
if instruction_set == 'sve':
predicate = f'{prefix}whilelt_b{bits[data_type]}_u64({{loop_counter}}, {{loop_stop}})'
int_predicate = f'{prefix}whilelt_b{bits["int"]}_u64({{loop_counter}}, {{loop_stop}})'
else:
predicate = f'{prefix}whilelt_b{bits[data_type]}(0, {width})'
int_predicate = f'{prefix}whilelt_b{bits["int"]}(0, {intwidth})'
for intrinsic_id, function_shortcut in base_names.items():
function_shortcut = function_shortcut.strip()
@@ -80,8 +88,13 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
result[intrinsic_id] = prefix + name + suffix + undef + arg_string
result['width'] = width
result['intwidth'] = intwidth
if instruction_set == 'sve':
from pystencils.backends.cbackend import CFunction
result['width'] = CFunction(width, "int")
result['intwidth'] = CFunction(intwidth, "int")
else:
result['width'] = width
result['intwidth'] = intwidth
if instruction_set.startswith('sve'):
result['makeVecConst'] = f'svdup_f{bits[data_type]}' + '({0})'
@@ -89,17 +102,17 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
result['makeVecIndex'] = f'svindex_s{bits["int"]}' + '({0}, {1})'
vindex = f'svindex_u{bits[data_type]}(0, {{0}})'
result['scatter'] = f'svst1_scatter_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
vindex.format("{2}") + ', {1})'
result['gather'] = f'svld1_gather_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
vindex.format("{1}") + ')'
result['storeS'] = f'svst1_scatter_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
vindex.format("{2}") + ', {1})'
result['loadS'] = f'svld1_gather_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
vindex.format("{1}") + ')'
result['+int'] = f"svadd_s{bits['int']}_x({int_predicate}, " + "{0}, {1})"
result['float'] = 'svfloat32_st'
result['double'] = 'svfloat64_st'
result['int'] = f'svint{bits["int"]}_st'
result['bool'] = 'svbool_st'
result['float'] = f'svfloat{bits["float"]}_{"s" if instruction_set != "sve" else ""}t'
result['double'] = f'svfloat{bits["double"]}_{"s" if instruction_set != "sve" else ""}t'
result['int'] = f'svint{bits["int"]}_{"s" if instruction_set != "sve" else ""}t'
result['bool'] = f'svbool_{"s" if instruction_set != "sve" else ""}t'
result['headers'] = ['<arm_sve.h>', '"arm_neon_helpers.h"']
@@ -111,9 +124,10 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
result['maskStoreU'] = result['storeU'].replace(predicate, '{2}')
result['maskStoreA'] = result['storeA'].replace(predicate, '{2}')
result['maskScatter'] = result['scatter'].replace(predicate, '{3}')
result['maskStoreS'] = result['storeS'].replace(predicate, '{3}')
result['compile_flags'] = [f'-msve-vector-bits={bitwidth}']
if instruction_set != 'sve':
result['compile_flags'] = [f'-msve-vector-bits={bitwidth}']
else:
result['makeVecConst'] = f'vdupq_n_f{bits[data_type]}' + '({0})'
result['makeVec'] = f'makeVec_f{bits[data_type]}' + '(' + ", ".join(['{' + str(i) + '}' for i in
@@ -137,7 +151,7 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
result['any'] = f'vaddlvq_u8(vreinterpretq_u8_u{bits[data_type]}({{0}})) > 0'
result['all'] = f'vaddlvq_u8(vreinterpretq_u8_u{bits[data_type]}({{0}})) == 16*0xff'
if bitwidth & (bitwidth - 1) == 0:
if instruction_set == 'sve' or bitwidth & (bitwidth - 1) == 0:
# only power-of-2 vector sizes will evenly divide a cacheline
result['cachelineSize'] = 'cachelineSize()'
result['cachelineZero'] = 'cachelineZero((void*) {0})'
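The power-of-two check uses the usual bit trick, and sizeless SVE now passes it unconditionally: the assumption (true for shipping hardware, though the architecture permits any multiple of 128 bits) is that the runtime vector length is a power of two and therefore divides a cache line evenly:

```python
def is_power_of_two(n: int) -> bool:
    # A power of two has exactly one set bit; n & (n - 1) clears it to zero.
    return n > 0 and n & (n - 1) == 0
```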
@@ -6,6 +6,7 @@ from typing import Set
import numpy as np
import sympy as sp
from sympy.core import S
from sympy.core.cache import cacheit
from sympy.logic.boolalg import BooleanFalse, BooleanTrue
from pystencils.astnodes import KernelFunction, LoopOverCoordinate, Node
@@ -165,6 +166,23 @@ class PrintNode(CustomCodeNode):
self.headers.append("<iostream>")
class CFunction(TypedSymbol):
def __new__(cls, function, dtype):
return CFunction.__xnew_cached_(cls, function, dtype)
def __new_stage2__(cls, function, dtype):
return super(CFunction, cls).__xnew__(cls, function, dtype)
__xnew__ = staticmethod(__new_stage2__)
__xnew_cached_ = staticmethod(cacheit(__new_stage2__))
def __getnewargs__(self):
return self.name, self.dtype
def __getnewargs_ex__(self):
return (self.name, self.dtype), {}
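CFunction subclasses TypedSymbol so that an opaque C expression such as svcntd() can travel through sympy expressions and the type system like an ordinary symbol; the cacheit wrapper and the __getnewargs__/__getnewargs_ex__ hooks mirror TypedSymbol's caching and pickling behavior. A small usage sketch (the printed output is illustrative):

```python
# As used by the SVE and RISC-V backends: the vector 'width' is the value of
# an intrinsic that is only known at run time.
width = CFunction('svcntd()', 'int')  # doubles per SVE vector
print(width.name)   # svcntd()
print(width.dtype)  # int
```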
# ------------------------------------------- Printer ------------------------------------------------------------------
@@ -184,6 +202,8 @@ class CBackend:
self._indent = " "
self._dialect = dialect
self._signatureOnly = signature_only
self._kwargs = {}
self.sympy_printer._kwargs = self._kwargs
def __call__(self, node):
prev_is = VectorType.instruction_set
@@ -205,7 +225,8 @@
return str(node)
def _print_KernelFunction(self, node):
function_arguments = [f"{self._print(s.symbol.dtype)} {s.symbol.name}" for s in node.get_parameters()]
function_arguments = [f"{self._print(s.symbol.dtype)} {s.symbol.name}" for s in node.get_parameters()
if not type(s.symbol) is CFunction]
launch_bounds = ""
if self._dialect == 'cuda':
max_threads = node.indexing.max_threads_per_block()
@@ -232,6 +253,8 @@
condition = f"{counter_symbol} < {self.sympy_printer.doprint(node.stop)}"
update = f"{counter_symbol} += {self.sympy_printer.doprint(node.step)}"
loop_str = f"for ({start}; {condition}; {update})"
self._kwargs['loop_counter'] = counter_symbol
self._kwargs['loop_stop'] = node.stop
prefix = "\n".join(node.prefix_lines)
if prefix:
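Recording the innermost loop's counter symbol and stop expression is what makes the sizeless backends work: every instruction template is now rendered with .format(**self._kwargs), so placeholders such as {loop_counter} and {loop_stop} in the SVE predicates and RVV vector-length arguments are resolved at print time. A sketch with illustrative symbol names:

```python
# SVE predicate template from the Arm backend above:
template = 'svwhilelt_b64_u64({loop_counter}, {loop_stop})'
kwargs = {'loop_counter': 'ctr_0', 'loop_stop': '_size_f_0'}
print(template.format(**kwargs))  # svwhilelt_b64_u64(ctr_0, _size_f_0)
```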
@@ -265,7 +288,8 @@
if instr not in self._vector_instruction_set:
self._vector_instruction_set[instr] = self._vector_instruction_set['store' + instr[-1]].format(
'{0}', self._vector_instruction_set['blendv'].format(
self._vector_instruction_set['load' + instr[-1]].format('{0}'), '{1}', '{2}'))
self._vector_instruction_set['load' + instr[-1]].format('{0}', **self._kwargs),
'{1}', '{2}', **self._kwargs), **self._kwargs)
printed_mask = self.sympy_printer.doprint(mask)
if data_type.base_type.base_name == 'double':
if self._vector_instruction_set['double'] == '__m256d':
@@ -287,9 +311,9 @@
ptr = "&" + self.sympy_printer.doprint(node.lhs.args[0])
if stride != 1:
instr = 'maskScatter' if mask != True else 'scatter' # NOQA
instr = 'maskStoreS' if mask != True else 'storeS' # NOQA
return self._vector_instruction_set[instr].format(ptr, self.sympy_printer.doprint(rhs),
stride, printed_mask) + ';'
stride, printed_mask, **self._kwargs) + ';'
pre_code = ''
if nontemporal and 'cachelineZero' in self._vector_instruction_set:
@@ -301,22 +325,22 @@
element_size = 8 if data_type.base_type.base_name == 'double' else 4
size_cond = f"({offset} + {CachelineSize.symbol/element_size}) < {size}"
pre_code = f"if ({first_cond} && {size_cond}) " + "{\n\t" + \
self._vector_instruction_set['cachelineZero'].format(ptr) + ';\n}\n'
self._vector_instruction_set['cachelineZero'].format(ptr, **self._kwargs) + ';\n}\n'
code = self._vector_instruction_set[instr].format(ptr, self.sympy_printer.doprint(rhs),
printed_mask) + ';'
printed_mask, **self._kwargs) + ';'
flushcond = f"((uintptr_t) {ptr} & {CachelineSize.mask_symbol}) == {CachelineSize.last_symbol}"
if nontemporal and 'flushCacheline' in self._vector_instruction_set:
code2 = self._vector_instruction_set['flushCacheline'].format(
ptr, self.sympy_printer.doprint(rhs)) + ';'
ptr, self.sympy_printer.doprint(rhs), **self._kwargs) + ';'
code = f"{code}\nif ({flushcond}) {{\n\t{code2}\n}}"
elif nontemporal and 'storeAAndFlushCacheline' in self._vector_instruction_set:
tmpvar = '_tmp_' + hashlib.sha1(self.sympy_printer.doprint(rhs).encode('ascii')).hexdigest()[:8]
code = 'const ' + self._print(node.lhs.dtype).replace(' const', '') + ' ' + tmpvar + ' = ' \
+ self.sympy_printer.doprint(rhs) + ';'
code1 = self._vector_instruction_set[instr].format(ptr, tmpvar, printed_mask) + ';'
code2 = self._vector_instruction_set['storeAAndFlushCacheline'].format(ptr, tmpvar, printed_mask) \
+ ';'
code1 = self._vector_instruction_set[instr].format(ptr, tmpvar, printed_mask, **self._kwargs) + ';'
code2 = self._vector_instruction_set['storeAAndFlushCacheline'].format(ptr, tmpvar, printed_mask,
**self._kwargs) + ';'
code += f"\nif ({flushcond}) {{\n\t{code2}\n}} else {{\n\t{code1}\n}}"
return pre_code + code
else:
@@ -617,16 +641,16 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
def _print_Abs(self, expr):
if 'abs' in self.instruction_set and isinstance(expr.args[0], vector_memory_access):
return self.instruction_set['abs'].format(self._print(expr.args[0]))
return self.instruction_set['abs'].format(self._print(expr.args[0]), **self._kwargs)
return super()._print_Abs(expr)
def _print_Function(self, expr):
if isinstance(expr, vector_memory_access):
arg, data_type, aligned, _, mask, stride = expr.args
if stride != 1:
return self.instruction_set['gather'].format("& " + self._print(arg), stride)
return self.instruction_set['loadS'].format("& " + self._print(arg), stride, **self._kwargs)
instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU']
return instruction.format("& " + self._print(arg))
return instruction.format("& " + self._print(arg), **self._kwargs)
elif isinstance(expr, cast_func):
arg, data_type = expr.args
if type(data_type) is VectorType:
@@ -640,19 +664,21 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
if instruction == 'makeVecInt' and 'makeVecIndex' in self.instruction_set:
increments = np.array(arg)[1:] - np.array(arg)[:-1]
if len(set(increments)) == 1:
return self.instruction_set['makeVecIndex'].format(printed_args[0], increments[0])
return self.instruction_set[instruction].format(*printed_args)
return self.instruction_set['makeVecIndex'].format(printed_args[0], increments[0],
**self._kwargs)
return self.instruction_set[instruction].format(*printed_args, **self._kwargs)
else:
is_boolean = get_type_of_expression(arg) == create_type("bool")
is_integer = get_type_of_expression(arg) == create_type("int") or \
(isinstance(arg, TypedSymbol) and arg.dtype.is_int())
instruction = 'makeVecConstBool' if is_boolean else \
'makeVecConstInt' if is_integer else 'makeVecConst'
return self.instruction_set[instruction].format(self._print(arg))
return self.instruction_set[instruction].format(self._print(arg), **self._kwargs)
elif expr.func == fast_division:
result = self._scalarFallback('_print_Function', expr)
if not result:
result = self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1]))
result = self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1]),
**self._kwargs)
return result
elif expr.func == fast_sqrt:
return f"({self._print(sp.sqrt(expr.args[0]))})"
@@ -660,7 +686,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
result = self._scalarFallback('_print_Function', expr)
if not result:
if 'rsqrt' in self.instruction_set:
return self.instruction_set['rsqrt'].format(self._print(expr.args[0]))
return self.instruction_set['rsqrt'].format(self._print(expr.args[0]), **self._kwargs)
else:
return f"({self._print(1 / sp.sqrt(expr.args[0]))})"
elif isinstance(expr, vec_any) or isinstance(expr, vec_all):
@@ -672,8 +698,9 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
if isinstance(expr.args[0], sp.Rel):
op = expr.args[0].rel_op
if (instr, op) in self.instruction_set:
return self.instruction_set[(instr, op)].format(*[self._print(a) for a in expr.args[0].args])
return self.instruction_set[instr].format(self._print(expr.args[0]))
return self.instruction_set[(instr, op)].format(*[self._print(a) for a in expr.args[0].args],
**self._kwargs)
return self.instruction_set[instr].format(self._print(expr.args[0]), **self._kwargs)
return super(VectorizedCustomSympyPrinter, self)._print_Function(expr)
@@ -686,7 +713,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
assert len(arg_strings) > 0
result = arg_strings[0]
for item in arg_strings[1:]:
result = self.instruction_set['&'].format(result, item)
result = self.instruction_set['&'].format(result, item, **self._kwargs)
return result
def _print_Or(self, expr):
@@ -698,7 +725,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
assert len(arg_strings) > 0
result = arg_strings[0]
for item in arg_strings[1:]:
result = self.instruction_set['|'].format(result, item)
result = self.instruction_set['|'].format(result, item, **self._kwargs)
return result
def _print_Add(self, expr, order=None):
@@ -739,7 +766,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
processed = summands[0].term
for summand in summands[1:]:
func = self.instruction_set['-' + suffix] if summand.sign == -1 else self.instruction_set['+' + suffix]
processed = func.format(processed, summand.term)
processed = func.format(processed, summand.term, **self._kwargs)
return processed
def _print_Pow(self, expr):
@@ -747,21 +774,22 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
if result:
return result
one = self.instruction_set['makeVecConst'].format(1.0)
one = self.instruction_set['makeVecConst'].format(1.0, **self._kwargs)
if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")"
elif expr.exp == -1:
one = self.instruction_set['makeVecConst'].format(1.0)
return self.instruction_set['/'].format(one, self._print(expr.base))
one = self.instruction_set['makeVecConst'].format(1.0, **self._kwargs)
return self.instruction_set['/'].format(one, self._print(expr.base), **self._kwargs)
elif expr.exp == 0.5:
return self.instruction_set['sqrt'].format(self._print(expr.base))
return self.instruction_set['sqrt'].format(self._print(expr.base), **self._kwargs)
elif expr.exp == -0.5:
root = self.instruction_set['sqrt'].format(self._print(expr.base))
return self.instruction_set['/'].format(one, root)
root = self.instruction_set['sqrt'].format(self._print(expr.base), **self._kwargs)
return self.instruction_set['/'].format(one, root, **self._kwargs)
elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0:
return self.instruction_set['/'].format(one,
self._print(sp.Mul(*[expr.base] * (-expr.exp), evaluate=False)))
self._print(sp.Mul(*[expr.base] * (-expr.exp), evaluate=False)),
**self._kwargs)
else:
raise ValueError("Generic exponential not supported: " + str(expr))
@@ -800,19 +828,19 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
result = a_str[0]
for item in a_str[1:]:
result = self.instruction_set['*'].format(result, item)
result = self.instruction_set['*'].format(result, item, **self._kwargs)
if len(b) > 0:
denominator_str = b_str[0]
for item in b_str[1:]:
denominator_str = self.instruction_set['*'].format(denominator_str, item)
result = self.instruction_set['/'].format(result, denominator_str)
denominator_str = self.instruction_set['*'].format(denominator_str, item, **self._kwargs)
result = self.instruction_set['/'].format(result, denominator_str, **self._kwargs)
if inside_add:
return sign, result
else:
if sign < 0:
return self.instruction_set['*'].format(self._print(S.NegativeOne), result)
return self.instruction_set['*'].format(self._print(S.NegativeOne), result, **self._kwargs)
else:
return result
@@ -820,13 +848,13 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
result = self._scalarFallback('_print_Relational', expr)
if result:
return result
return self.instruction_set[expr.rel_op].format(self._print(expr.lhs), self._print(expr.rhs))
return self.instruction_set[expr.rel_op].format(self._print(expr.lhs), self._print(expr.rhs), **self._kwargs)
def _print_Equality(self, expr):
result = self._scalarFallback('_print_Equality', expr)
if result:
return result
return self.instruction_set['=='].format(self._print(expr.lhs), self._print(expr.rhs))
return self.instruction_set['=='].format(self._print(expr.lhs), self._print(expr.rhs), **self._kwargs)
def _print_Piecewise(self, expr):
result = self._scalarFallback('_print_Piecewise', expr)
......@@ -847,10 +875,11 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
if isinstance(condition, cast_func) and get_type_of_expression(condition.args[0]) == create_type("bool"):
if not KERNCRAFT_NO_TERNARY_MODE:
result = "(({}) ? ({}) : ({}))".format(self._print(condition.args[0]), self._print(true_expr),
result)
result, **self._kwargs)
else:
print("Warning - skipping ternary op")
else:
# noinspection SpellCheckingInspection
result = self.instruction_set['blendv'].format(result, self._print(true_expr), self._print(condition))
result = self.instruction_set['blendv'].format(result, self._print(true_expr), self._print(condition),
**self._kwargs)
return result
def get_argument_string(function_shortcut, last=''):
args = function_shortcut[function_shortcut.index('[') + 1: -1]
arg_string = "("
for arg in args.split(","):
arg = arg.strip()
if not arg:
continue
if arg in ('0', '1', '2', '3', '4', '5'):
arg_string += "{" + arg + "},"
else:
arg_string += arg + ","
if last:
arg_string += last + ','
arg_string = arg_string[:-1] + ")"
return arg_string
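get_argument_string converts the bracketed operand list of a shortcut like 'fadd_vv[0, 1]' into a format string, with `last` appending the trailing vector-length argument that every RVV intrinsic expects. For example:

```python
# Digits become positional placeholders; `last` appends the vector length.
print(get_argument_string('fadd_vv[0, 1]', last='vl'))  # -> ({0},{1},vl)
```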
def get_vector_instruction_set_riscv(data_type='double', instruction_set='rvv'):
assert instruction_set == 'rvv'
bits = {'double': 64,
'float': 32,
'int': 32}
base_names = {
'+': 'fadd_vv[0, 1]',
'-': 'fsub_vv[0, 1]',
'*': 'fmul_vv[0, 1]',
'/': 'fdiv_vv[0, 1]',
'sqrt': 'fsqrt_v[0]',
'loadU': f'le{bits[data_type]}_v[0]',
'loadA': f'le{bits[data_type]}_v[0]',
'storeU': f'se{bits[data_type]}_v[0, 1]',
'storeA': f'se{bits[data_type]}_v[0, 1]',
'maskStoreU': f'se{bits[data_type]}_v[2, 0, 1]',
'maskStoreA': f'se{bits[data_type]}_v[2, 0, 1]',
'loadS': f'lse{bits[data_type]}_v[0, 1]',
'storeS': f'sse{bits[data_type]}_v[0, 2, 1]',
'maskStoreS': f'sse{bits[data_type]}_v[2, 0, 3, 1]',
'abs': 'fabs_v[0]',
'==': 'mfeq_vv[0, 1]',
'!=': 'mfne_vv[0, 1]',
'<=': 'mfle_vv[0, 1]',
'<': 'mflt_vv[0, 1]',
'>=': 'mfge_vv[0, 1]',
'>': 'mfgt_vv[0, 1]',
'&': 'mand_mm[0, 1]',
'|': 'mor_mm[0, 1]',
'blendv': 'merge_vvm[2, 0, 1]',
'any': 'popc_m[0]',
'all': 'popc_m[0]',
}
result = dict()
width = f'vsetvlmax_e{bits[data_type]}m1()'
intwidth = f'vsetvlmax_e{bits["int"]}m1()'
result['bytes'] = 'vsetvlmax_e8m1()'
prefix = 'v'
suffix = f'_f{bits[data_type]}m1'
vl = '{loop_stop} - {loop_counter}'
int_vl = f'({vl})*{bits[data_type]//bits["int"]}'
for intrinsic_id, function_shortcut in base_names.items():
function_shortcut = function_shortcut.strip()
name = function_shortcut[:function_shortcut.index('[')]
if name.startswith('mf'):
suffix2 = suffix + f'_b{bits[data_type]}'
elif name.endswith('_mm') or name.endswith('_m'):
suffix2 = f'_b{bits[data_type]}'
elif intrinsic_id.startswith('mask'):
suffix2 = suffix + '_m'
else:
suffix2 = suffix
arg_string = get_argument_string(function_shortcut, last=vl)
result[intrinsic_id] = prefix + name + suffix2 + arg_string
from pystencils.backends.cbackend import CFunction
result['width'] = CFunction(width, "int")
result['intwidth'] = CFunction(intwidth, "int")
result['makeVecConst'] = f'vfmv_v_f_f{bits[data_type]}m1({{0}}, {vl})'
result['makeVecConstInt'] = f'vmv_v_x_i{bits["int"]}m1({{0}}, {int_vl})'
result['makeVecIndex'] = f'vmacc_vx_i{bits["int"]}m1({result["makeVecConstInt"]}, {{1}}, ' + \
f'vid_v_i{bits["int"]}m1({int_vl}), {int_vl})'
result['storeS'] = result['storeS'].replace('{2}', f'{{2}}*{bits[data_type]//8}')
result['loadS'] = result['loadS'].replace('{1}', f'{{1}}*{bits[data_type]//8}')
result['maskStoreS'] = result['maskStoreS'].replace('{3}', f'{{3}}*{bits[data_type]//8}')
result['+int'] = f"vadd_vv_i{bits['int']}m1({{0}}, {{1}}, {int_vl})"
result['float'] = f'vfloat{bits["float"]}m1_t'
result['double'] = f'vfloat{bits["double"]}m1_t'
result['int'] = f'vint{bits["int"]}m1_t'
result['bool'] = f'vbool{bits[data_type]}_t'
result['headers'] = ['<riscv_vector.h>']
result['any'] += ' > 0x0'
result['all'] += f' == vsetvl_e{bits[data_type]}m1({vl})'
return result
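Unlike the fixed-width backends, every RVV intrinsic takes an explicit vector length; here it is the remaining trip count {loop_stop} - {loop_counter}, so the final partial iteration is executed by the hardware instead of a scalar tail loop. The resulting templates look like this (output shown as comments is illustrative):

```python
vis = get_vector_instruction_set_riscv('double')
print(vis['+'])      # vfadd_vv_f64m1({0},{1},{loop_stop} - {loop_counter})
print(vis['width'])  # vsetvlmax_e64m1() -- a CFunction, resolved at run time
```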
@@ -6,6 +6,7 @@ from ctypes import CDLL
from pystencils.backends.x86_instruction_sets import get_vector_instruction_set_x86
from pystencils.backends.arm_instruction_sets import get_vector_instruction_set_arm
from pystencils.backends.ppc_instruction_sets import get_vector_instruction_set_ppc
from pystencils.backends.riscv_instruction_sets import get_vector_instruction_set_riscv
def get_vector_instruction_set(data_type='double', instruction_set='avx'):
@@ -13,6 +14,8 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
return get_vector_instruction_set_arm(data_type, instruction_set)
elif instruction_set in ['vsx']:
return get_vector_instruction_set_ppc(data_type, instruction_set)
elif instruction_set in ['rvv']:
return get_vector_instruction_set_riscv(data_type, instruction_set)
else:
return get_vector_instruction_set_x86(data_type, instruction_set)
@@ -30,6 +33,11 @@ def get_supported_instruction_sets():
return os.environ['PYSTENCILS_SIMD'].split(',')
if platform.system() == 'Darwin' and platform.machine() == 'arm64': # not supported by cpuinfo
return ['neon']
elif platform.system() == 'Linux' and platform.machine().startswith('riscv'): # not supported by cpuinfo
libc = CDLL('libc.so.6')
hwcap = libc.getauxval(16) # AT_HWCAP
hwcap_isa_v = 1 << (ord('V') - ord('A')) # COMPAT_HWCAP_ISA_V
return ['rvv'] if hwcap & hwcap_isa_v else []
elif platform.machine().startswith('ppc64'): # no flags reported by cpuinfo
import subprocess
import tempfile
@@ -74,8 +82,7 @@ def get_supported_instruction_sets():
if native_length != pwr2_length:
result.append(f"sve{pwr2_length}")
result.append(f"sve{native_length}")
else:
result.append("sve")
result.append("sve")
return result
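Two details above are worth spelling out: a fixed-width SVE machine now additionally reports plain 'sve', since vector-length-agnostic code runs on any hardware width, and RISC-V detection reads the ELF auxiliary vector, where getauxval(AT_HWCAP) returns a bitmask with bit ord(letter) - ord('A') set for each supported single-letter ISA extension. The bit arithmetic:

```python
AT_HWCAP = 16                       # request code passed to getauxval(3)
bit_v = 1 << (ord('V') - ord('A'))  # vector extension 'V' -> bit 21
assert bit_v == 1 << 21
```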
@@ -147,11 +147,11 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'):
vindex = f'{pre}_set_epi{bit_width//size}(' + ', '.join([str(i) for i in range(result['width'])][::-1]) + ')'
vindex = f'{pre}_mullo_epi{bit_width//size}({vindex}, {pre}_set1_epi{bit_width//size}({{0}}))'
result['scatter'] = f'{pre}_i{bit_width//size}scatter_{suf}({{0}}, ' + vindex.format("{2}") + \
f', {{1}}, {64//size})'
result['maskScatter'] = f'{pre}_mask_i{bit_width//size}scatter_{suf}({{0}}, {{3}}, ' + vindex.format("{2}") + \
f', {{1}}, {64//size})'
result['gather'] = f'{pre}_i{bit_width//size}gather_{suf}(' + vindex.format("{1}") + f', {{0}}, {64//size})'
result['storeS'] = f'{pre}_i{bit_width//size}scatter_{suf}({{0}}, ' + vindex.format("{2}") + \
f', {{1}}, {64//size})'
result['maskStoreS'] = f'{pre}_mask_i{bit_width//size}scatter_{suf}({{0}}, {{3}}, ' + vindex.format("{2}") + \
f', {{1}}, {64//size})'
result['loadS'] = f'{pre}_i{bit_width//size}gather_{suf}(' + vindex.format("{1}") + f', {{0}}, {64//size})'
if instruction_set == 'avx' and data_type == 'float':
result['rsqrt'] = f"{pre}_rsqrt_{suf}({{0}})"
@@ -59,7 +59,7 @@ from appdirs import user_cache_dir, user_config_dir
from pystencils import FieldType
from pystencils.astnodes import LoopOverCoordinate
from pystencils.backends.cbackend import generate_c, get_headers
from pystencils.backends.cbackend import generate_c, get_headers, CFunction
from pystencils.data_types import cast_func, VectorType, vector_memory_access
from pystencils.include import get_pystencils_include_path
from pystencils.kernel_wrapper import KernelWrapper
@@ -411,7 +411,7 @@ def create_function_boilerplate_code(parameter_info, name, ast_node, insert_chec
if has_openmp and has_nontemporal:
byte_width = ast_node.instruction_set['cachelineSize']
offset = max(max(ast_node.ghost_layers)) * item_size
offset_cond = f"(((uintptr_t) buffer_{field.name}.buf) + {offset}) % {byte_width} == 0"
offset_cond = f"(((uintptr_t) buffer_{field.name}.buf) + {offset}) % ({byte_width}) == 0"
message = str(offset) + ". This is probably due to a different number of ghost_layers chosen for " \
"the arrays and the kernel creation. If the number of ghost layers for " \
@@ -460,6 +460,8 @@ def create_function_boilerplate_code(parameter_info, name, ast_node, insert_chec
name=field.name))
elif param.is_field_shape:
parameters.append(f"buffer_{param.field_name}.shape[{param.symbol.coordinate}]")
elif type(param.symbol) is CFunction:
continue
else:
extract_function, target_type = type_mapping[param.symbol.dtype.numpy_dtype.type]
if np.issubdtype(param.symbol.dtype.numpy_dtype, np.complexfloating):
@@ -127,9 +127,10 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'best',
kernel_ast.instruction_set = vector_is
vectorize_rng(kernel_ast, vector_width)
scattergather = 'scatter' in vector_is and 'gather' in vector_is
strided = 'storeS' in vector_is and 'loadS' in vector_is
keep_loop_stop = '{loop_stop}' in vector_is['storeA' if assume_aligned else 'storeU']
vectorize_inner_loops_and_adapt_load_stores(kernel_ast, vector_width, assume_aligned, nontemporal,
scattergather, assume_sufficient_line_padding)
strided, keep_loop_stop, assume_sufficient_line_padding)
insert_vector_casts(kernel_ast)
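keep_loop_stop identifies the sizeless backends by inspecting the store template: if it references {loop_stop}, the predicated store (SVE) or explicit vector-length argument (RVV) already copes with a partial final iteration, so the loop bound must be kept exact instead of being rounded up to a multiple of the vector width. A sketch of the check, assuming the import path below:

```python
from pystencils.backends.simd_instruction_sets import get_vector_instruction_set

vis = get_vector_instruction_set('double', 'rvv')
keep_loop_stop = '{loop_stop}' in vis['storeU']  # True -> keep exact bound
```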
@@ -152,7 +153,7 @@ def vectorize_rng(kernel_ast, vector_width):
def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_aligned, nontemporal_fields,
scattergather, assume_sufficient_line_padding):
strided, keep_loop_stop, assume_sufficient_line_padding):
"""Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type."""
all_loops = filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment)
inner_loops = [n for n in all_loops if n.is_innermost_loop]
@@ -162,7 +163,9 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
loop_range = loop_node.stop - loop_node.start
# cut off the loop tail that is not a multiple of the vector width
if assume_aligned and assume_sufficient_line_padding:
if keep_loop_stop:
pass
elif assume_aligned and assume_sufficient_line_padding:
loop_range = loop_node.stop - loop_node.start
new_stop = loop_node.start + modulo_ceil(loop_range, vector_width)
loop_node.stop = new_stop
@@ -184,7 +187,7 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
loop_counter_is_offset = loop_counter_symbol not in (index - loop_counter_symbol).atoms()
aligned_access = (index - loop_counter_symbol).subs(zero_loop_counters) == 0
stride = sp.simplify(index.subs({loop_counter_symbol: loop_counter_symbol + 1}) - index)
if not loop_counter_is_offset and (not scattergather or loop_counter_symbol in stride.atoms()):
if not loop_counter_is_offset and (not strided or loop_counter_symbol in stride.atoms()):
successful = False
break
typed_symbol = base.label
@@ -197,7 +200,7 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
if hasattr(indexed, 'field'):
nontemporal = (indexed.field in nontemporal_fields) or (indexed.field.name in nontemporal_fields)
substitutions[indexed] = vector_memory_access(indexed, vec_type, use_aligned_access, nontemporal, True,
stride if scattergather else 1)
stride if strided else 1)
if nontemporal:
# insert NontemporalFence after the outermost loop
parent = loop_node.parent
@@ -214,7 +217,8 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
loop_node.subs(substitutions)
vector_int_width = ast_node.instruction_set['intwidth']
vector_loop_counter = cast_func(loop_counter_symbol, VectorType(loop_counter_symbol.dtype, vector_int_width)) \
+ cast_func(tuple(range(vector_int_width)), VectorType(loop_counter_symbol.dtype, vector_int_width))
+ cast_func(tuple(range(vector_int_width if type(vector_int_width) is int else 2)),
VectorType(loop_counter_symbol.dtype, vector_int_width))
fast_subs(loop_node, {loop_counter_symbol: vector_loop_counter},
skip=lambda e: isinstance(e, ast.ResolvedFieldAccess) or isinstance(e, vector_memory_access))
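When the instruction set is sizeless, intwidth is a CFunction rather than an int, so the counter vector literal cannot be enumerated; a dummy two-element tuple (0, 1) stands in. That is safe because the printer only inspects the increments between the tuple's elements and, finding them constant, emits the makeVecIndex intrinsic (svindex on SVE, vid_v/vmacc on RVV) instead of materializing the elements:

```python
import numpy as np

arg = (0, 1)  # placeholder tuple for an unknown vector length
increments = np.array(arg)[1:] - np.array(arg)[:-1]
assert len(set(increments)) == 1  # constant step -> makeVecIndex path
```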
@@ -712,7 +712,7 @@ class VectorType(Type):
def __str__(self):
if self.instruction_set is None:
return "%s[%d]" % (self.base_type, self.width)
return "%s[%s]" % (self.base_type, self.width)
else:
if self.base_type == create_type("int64") or self.base_type == create_type("int32"):
return self.instruction_set['int']
@@ -30,6 +30,10 @@
#endif
#endif
#ifdef __riscv_v
#include <riscv_vector.h>
#endif
#ifndef __CUDA_ARCH__
#define QUALIFIERS inline
#include "myintrin.h"
@@ -818,6 +822,156 @@ QUALIFIERS void philox_double2(uint32 ctr0, svint32_t ctr1, uint32 ctr2, uint32
#endif
#if defined(__riscv_v)
QUALIFIERS void _philox4x32round(vuint32m1_t & ctr0, vuint32m1_t & ctr1, vuint32m1_t & ctr2, vuint32m1_t & ctr3,
vuint32m1_t key0, vuint32m1_t key1)
{
vuint32m1_t lo0 = vmul_vv_u32m1(ctr0, vmv_v_x_u32m1(PHILOX_M4x32_0, vsetvlmax_e32m1()), vsetvlmax_e32m1());
vuint32m1_t lo1 = vmul_vv_u32m1(ctr2, vmv_v_x_u32m1(PHILOX_M4x32_1, vsetvlmax_e32m1()), vsetvlmax_e32m1());
vuint32m1_t hi0 = vmulhu_vv_u32m1(ctr0, vmv_v_x_u32m1(PHILOX_M4x32_0, vsetvlmax_e32m1()), vsetvlmax_e32m1());
vuint32m1_t hi1 = vmulhu_vv_u32m1(ctr2, vmv_v_x_u32m1(PHILOX_M4x32_1, vsetvlmax_e32m1()), vsetvlmax_e32m1());
ctr0 = vxor_vv_u32m1(vxor_vv_u32m1(hi1, ctr1, vsetvlmax_e32m1()), key0, vsetvlmax_e32m1());
ctr1 = lo1;
ctr2 = vxor_vv_u32m1(vxor_vv_u32m1(hi0, ctr3, vsetvlmax_e32m1()), key1, vsetvlmax_e32m1());
ctr3 = lo0;
}
QUALIFIERS void _philox4x32bumpkey(vuint32m1_t & key0, vuint32m1_t & key1)
{
key0 = vadd_vv_u32m1(key0, vmv_v_x_u32m1(PHILOX_W32_0, vsetvlmax_e32m1()), vsetvlmax_e32m1());
key1 = vadd_vv_u32m1(key1, vmv_v_x_u32m1(PHILOX_W32_1, vsetvlmax_e32m1()), vsetvlmax_e32m1());
}
template<bool high>
QUALIFIERS vfloat64m1_t _uniform_double_hq(vuint32m1_t x, vuint32m1_t y)
{
// convert 32 to 64 bit
if (high)
{
size_t s = vsetvlmax_e32m1();
x = vslidedown_vx_u32m1(vundefined_u32m1(), x, s/2, s);
y = vslidedown_vx_u32m1(vundefined_u32m1(), y, s/2, s);
}
vuint64m1_t x64 = vwcvtu_x_x_v_u64m1(vlmul_trunc_v_u32m1_u32mf2(x), vsetvlmax_e64m1());
vuint64m1_t y64 = vwcvtu_x_x_v_u64m1(vlmul_trunc_v_u32m1_u32mf2(y), vsetvlmax_e64m1());
// calculate z = x ^ (y << (53 - 32))
vuint64m1_t z = vsll_vx_u64m1(y64, 53 - 32, vsetvlmax_e64m1());
z = vxor_vv_u64m1(x64, z, vsetvlmax_e64m1());
// convert uint64 to double
vfloat64m1_t rs = vfcvt_f_xu_v_f64m1(z, vsetvlmax_e64m1());
// calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
rs = vfmadd_vv_f64m1(rs, vfmv_v_f_f64m1(TWOPOW53_INV_DOUBLE, vsetvlmax_e64m1()), vfmv_v_f_f64m1(TWOPOW53_INV_DOUBLE/2.0, vsetvlmax_e64m1()), vsetvlmax_e64m1());
return rs;
}
QUALIFIERS void philox_float4(vuint32m1_t ctr0, vuint32m1_t ctr1, vuint32m1_t ctr2, vuint32m1_t ctr3,
uint32 key0, uint32 key1,
vfloat32m1_t & rnd1, vfloat32m1_t & rnd2, vfloat32m1_t & rnd3, vfloat32m1_t & rnd4)
{
vuint32m1_t key0v = vmv_v_x_u32m1(key0, vsetvlmax_e32m1());
vuint32m1_t key1v = vmv_v_x_u32m1(key1, vsetvlmax_e32m1());
_philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 1
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 2
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 3
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 4
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 5
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 6
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 7
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 8
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 9
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 10
// convert uint32 to float
rnd1 = vfcvt_f_xu_v_f32m1(ctr0, vsetvlmax_e32m1());
rnd2 = vfcvt_f_xu_v_f32m1(ctr1, vsetvlmax_e32m1());
rnd3 = vfcvt_f_xu_v_f32m1(ctr2, vsetvlmax_e32m1());
rnd4 = vfcvt_f_xu_v_f32m1(ctr3, vsetvlmax_e32m1());
// calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
rnd1 = vfmadd_vv_f32m1(rnd1, vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT, vsetvlmax_e32m1()), vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT/2.0, vsetvlmax_e32m1()), vsetvlmax_e32m1());
rnd2 = vfmadd_vv_f32m1(rnd2, vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT, vsetvlmax_e32m1()), vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT/2.0, vsetvlmax_e32m1()), vsetvlmax_e32m1());
rnd3 = vfmadd_vv_f32m1(rnd3, vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT, vsetvlmax_e32m1()), vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT/2.0, vsetvlmax_e32m1()), vsetvlmax_e32m1());
rnd4 = vfmadd_vv_f32m1(rnd4, vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT, vsetvlmax_e32m1()), vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT/2.0, vsetvlmax_e32m1()), vsetvlmax_e32m1());
}
QUALIFIERS void philox_double2(vuint32m1_t ctr0, vuint32m1_t ctr1, vuint32m1_t ctr2, vuint32m1_t ctr3,
uint32 key0, uint32 key1,
vfloat64m1_t & rnd1lo, vfloat64m1_t & rnd1hi, vfloat64m1_t & rnd2lo, vfloat64m1_t & rnd2hi)
{
vuint32m1_t key0v = vmv_v_x_u32m1(key0, vsetvlmax_e32m1());
vuint32m1_t key1v = vmv_v_x_u32m1(key1, vsetvlmax_e32m1());
_philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 1
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 2
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 3
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 4
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 5
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 6
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 7
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 8
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 9
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 10
rnd1lo = _uniform_double_hq<false>(ctr0, ctr1);
rnd1hi = _uniform_double_hq<true>(ctr0, ctr1);
rnd2lo = _uniform_double_hq<false>(ctr2, ctr3);
rnd2hi = _uniform_double_hq<true>(ctr2, ctr3);
}
QUALIFIERS void philox_float4(uint32 ctr0, vuint32m1_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
vfloat32m1_t & rnd1, vfloat32m1_t & rnd2, vfloat32m1_t & rnd3, vfloat32m1_t & rnd4)
{
vuint32m1_t ctr0v = vmv_v_x_u32m1(ctr0, vsetvlmax_e32m1());
vuint32m1_t ctr2v = vmv_v_x_u32m1(ctr2, vsetvlmax_e32m1());
vuint32m1_t ctr3v = vmv_v_x_u32m1(ctr3, vsetvlmax_e32m1());
philox_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4);
}
QUALIFIERS void philox_float4(uint32 ctr0, vint32m1_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
vfloat32m1_t & rnd1, vfloat32m1_t & rnd2, vfloat32m1_t & rnd3, vfloat32m1_t & rnd4)
{
philox_float4(ctr0, vreinterpret_v_i32m1_u32m1(ctr1), ctr2, ctr3, key0, key1, rnd1, rnd2, rnd3, rnd4);
}
QUALIFIERS void philox_double2(uint32 ctr0, vuint32m1_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
vfloat64m1_t & rnd1lo, vfloat64m1_t & rnd1hi, vfloat64m1_t & rnd2lo, vfloat64m1_t & rnd2hi)
{
vuint32m1_t ctr0v = vmv_v_x_u32m1(ctr0, vsetvlmax_e32m1());
vuint32m1_t ctr2v = vmv_v_x_u32m1(ctr2, vsetvlmax_e32m1());
vuint32m1_t ctr3v = vmv_v_x_u32m1(ctr3, vsetvlmax_e32m1());
philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
}
QUALIFIERS void philox_double2(uint32 ctr0, vuint32m1_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
vfloat64m1_t & rnd1, vfloat64m1_t & rnd2)
{
vuint32m1_t ctr0v = vmv_v_x_u32m1(ctr0, vsetvlmax_e32m1());
vuint32m1_t ctr2v = vmv_v_x_u32m1(ctr2, vsetvlmax_e32m1());
vuint32m1_t ctr3v = vmv_v_x_u32m1(ctr3, vsetvlmax_e32m1());
vfloat64m1_t ignore;
philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, ignore, rnd2, ignore);
}
QUALIFIERS void philox_double2(uint32 ctr0, vint32m1_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
vfloat64m1_t & rnd1, vfloat64m1_t & rnd2)
{
philox_double2(ctr0, vreinterpret_v_i32m1_u32m1(ctr1), ctr2, ctr3, key0, key1, rnd1, rnd2);
}
#endif
#ifdef __AVX2__
QUALIFIERS void _philox4x32round(__m256i* ctr, __m256i* key)
{
@@ -12,7 +12,10 @@ supported_instruction_sets = get_supported_instruction_sets() if get_supported_i
@pytest.mark.parametrize('instruction_set', supported_instruction_sets)
@pytest.mark.parametrize('dtype', ('float', 'double'))
def test_vec_any(instruction_set, dtype):
width = get_vector_instruction_set(dtype, instruction_set)['width']
if instruction_set in ['sve', 'rvv']:
width = 4 # we don't know the actual value
else:
width = get_vector_instruction_set(dtype, instruction_set)['width']
data_arr = np.zeros((4*width, 4*width), dtype=np.float64 if dtype == 'double' else np.float32)
data_arr[3:9, 1:3*width-1] = 1.0
@@ -28,13 +31,20 @@ def test_vec_any(instruction_set, dtype):
cpu_vectorize_info={'instruction_set': instruction_set})
kernel = ast.compile()
kernel(data=data_arr)
np.testing.assert_equal(data_arr[3:9, :3*width], 2.0)
if instruction_set in ['sve', 'rvv']:
# we only know that the first value has changed
np.testing.assert_equal(data_arr[3:9, :3*width-1], 2.0)
else:
np.testing.assert_equal(data_arr[3:9, :3*width], 2.0)
@pytest.mark.parametrize('instruction_set', supported_instruction_sets)
@pytest.mark.parametrize('dtype', ('float', 'double'))
def test_vec_all(instruction_set, dtype):
width = get_vector_instruction_set(dtype, instruction_set)['width']
if instruction_set in ['sve', 'rvv']:
width = 1000  # we don't know the actual value; just needs to be guaranteed larger than the vector length
else:
width = get_vector_instruction_set(dtype, instruction_set)['width']
data_arr = np.zeros((4*width, 4*width), dtype=np.float64 if dtype == 'double' else np.float32)
data_arr[3:9, 1:3*width-1] = 1.0
@@ -49,11 +59,16 @@ def test_vec_all(instruction_set, dtype):
cpu_vectorize_info={'instruction_set': instruction_set})
kernel = ast.compile()
kernel(data=data_arr)
np.testing.assert_equal(data_arr[3:9, :1], 0.0)
np.testing.assert_equal(data_arr[3:9, 1:width], 1.0)
np.testing.assert_equal(data_arr[3:9, width:2*width], 2.0)
np.testing.assert_equal(data_arr[3:9, 2*width:3*width-1], 1.0)
np.testing.assert_equal(data_arr[3:9, 3*width-1:], 0.0)
if instruction_set in ['sve', 'rvv']:
# we only know that some values in the middle have been replaced
assert np.all(data_arr[3:9, :2] <= 1.0)
assert np.any(data_arr[3:9, 2:] == 2.0)
else:
np.testing.assert_equal(data_arr[3:9, :1], 0.0)
np.testing.assert_equal(data_arr[3:9, 1:width], 1.0)
np.testing.assert_equal(data_arr[3:9, width:2*width], 2.0)
np.testing.assert_equal(data_arr[3:9, 2*width:3*width-1], 1.0)
np.testing.assert_equal(data_arr[3:9, 3*width-1:], 0.0)
@pytest.mark.skipif(not supported_instruction_sets, reason='cannot detect CPU instruction set')
@@ -27,7 +27,7 @@ if get_compiler_config()['os'] == 'windows':
def test_rng(target, rng, precision, dtype, t=124, offsets=(0, 0), keys=(0, 0), offset_values=None):
if target == 'gpu':
pytest.importorskip('pycuda')
if instruction_sets and (set(['neon', 'vsx']).intersection(instruction_sets) or any([iset.startswith('sve') for iset in instruction_sets])) and rng == 'aesni':
if instruction_sets and set(['neon', 'sve', 'vsx', 'rvv']).intersection(instruction_sets) and rng == 'aesni':
pytest.xfail('AES not yet implemented for this architecture')
if rng == 'aesni' and len(keys) == 2:
keys *= 2
@@ -118,7 +118,7 @@ def test_rng_offsets(kind, vectorized):
@pytest.mark.parametrize('rng', ('philox', 'aesni'))
@pytest.mark.parametrize('precision,dtype', (('float', 'float'), ('double', 'double')))
def test_rng_vectorized(target, rng, precision, dtype, t=130, offsets=(1, 3), keys=(0, 0), offset_values=None):
if (target in ['neon', 'vsx'] or target.startswith('sve')) and rng == 'aesni':
if (target in ['neon', 'vsx', 'rvv'] or target.startswith('sve')) and rng == 'aesni':
pytest.xfail('AES not yet implemented for this architecture')
cpu_vectorize_info = {'assume_inner_stride_one': True, 'assume_aligned': True, 'instruction_set': target}
@@ -216,8 +216,7 @@ def test_logical_operators(instruction_set=instruction_set):
def test_hardware_query():
assert set(['sse', 'neon', 'vsx']).intersection(supported_instruction_sets) or \
any([iset.startswith('sve') for iset in supported_instruction_sets])
assert set(['sse', 'neon', 'sve', 'vsx', 'rvv']).intersection(supported_instruction_sets)
def test_vectorised_pow(instruction_set=instruction_set):
@@ -55,10 +55,10 @@ def test_vectorized_abs(instruction_set, dtype):
@pytest.mark.parametrize('dtype', ('float', 'double'))
@pytest.mark.parametrize('instruction_set', supported_instruction_sets)
def test_scatter_gather(instruction_set, dtype):
def test_strided(instruction_set, dtype):
f, g = ps.fields(f"f, g : float{64 if dtype == 'double' else 32}[2D]")
update_rule = [ps.Assignment(g[0, 0], f[0, 0] + f[-1, 0] + f[1, 0] + f[0, 1] + f[0, -1] + 42.0)]
if 'scatter' not in get_vector_instruction_set(dtype, instruction_set) and not instruction_set == 'avx512' and not instruction_set.startswith('sve'):
if 'storeS' not in get_vector_instruction_set(dtype, instruction_set) and not instruction_set in ['avx512', 'rvv'] and not instruction_set.startswith('sve'):
with pytest.warns(UserWarning) as warn:
ast = ps.create_kernel(update_rule, cpu_vectorize_info={'instruction_set': instruction_set})
assert 'Could not vectorize loop' in warn[0].message.args[0]
@@ -106,12 +106,13 @@ def test_alignment_and_correct_ghost_layers(gl_field, gl_kernel, instruction_set
@pytest.mark.parametrize('instruction_set', supported_instruction_sets)
def test_cacheline_size(instruction_set):
cacheline_size = get_cacheline_size(instruction_set)
if cacheline_size is None:
if cacheline_size is None and instruction_set in ['sse', 'avx', 'avx512', 'rvv']:
pytest.skip()
instruction_set = get_vector_instruction_set('double', instruction_set)
vector_size = instruction_set['bytes']
assert cacheline_size > 8 and cacheline_size < 0x100000, "Cache line size is implausible"
assert cacheline_size % vector_size == 0, "Cache line size should be multiple of vector size"
if type(vector_size) is int:
assert cacheline_size % vector_size == 0, "Cache line size should be multiple of vector size"
assert cacheline_size & (cacheline_size - 1) == 0, "Cache line size is not a power of 2"