diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 194900620b8727514d30aa75b4b69c1fca276004..33d977e78f3102c513bdfe26a8f0ecffc19a5d1e 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -345,8 +345,8 @@ class CBackend: pre_code = f"if ({first_cond} && {size_cond}) " + "{\n\t" + \ self._vector_instruction_set['cachelineZero'].format(ptr, **self._kwargs) + ';\n}\n' - code = self._vector_instruction_set[instr].format(ptr, self.sympy_printer.doprint(rhs), - printed_mask, **self._kwargs) + ';' + code = self._vector_instruction_set.operation(instr, data_type).format(ptr, self.sympy_printer.doprint(rhs), + printed_mask, **self._kwargs) + ';' flushcond = f"((uintptr_t) {ptr} & {CachelineSize.mask_symbol}) == {CachelineSize.last_symbol}" if nontemporal and 'flushCacheline' in self._vector_instruction_set: code2 = self._vector_instruction_set['flushCacheline'].format( @@ -624,24 +624,20 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): if type(expr_type) is not VectorType: return getattr(super(VectorizedCustomSympyPrinter, self), func_name)(expr, *args, **kwargs) else: - assert self.instruction_set['width'] == expr_type.width + # assert self.instruction_set['width'] == expr_type.width return None def _print_Abs(self, expr): - if 'abs' in self.instruction_set and isinstance(expr.args[0], VectorMemoryAccess): - return self.instruction_set['abs'].format(self._print(expr.args[0]), **self._kwargs) + expr_type = get_type_of_expression(expr) + if isinstance(expr_type, VectorType) and self.instruction_set.supports('abs'): + return self.instruction_set.operation('abs', expr_type).format(self._print(expr.args[0]), **self._kwargs) return super()._print_Abs(expr) def _typed_vectorized_number(self, expr, data_type): basic_data_type = data_type.base_type number = self._typed_number(expr, basic_data_type) instruction = 'makeVecConst' - if basic_data_type.is_bool(): - instruction = 'makeVecConstBool' - # TODO Vectorization Revamp: is int, or sint, or uint (my guess is sint) - elif basic_data_type.is_int(): - instruction = 'makeVecConstInt' - return self.instruction_set[instruction].format(number, **self._kwargs) + return self.instruction_set.operation(instruction, data_type).format(number, **self._kwargs) def _typed_vectorized_symbol(self, expr, data_type): if not isinstance(expr, TypedSymbol): @@ -652,12 +648,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): symbol = f'(({basic_data_type})({symbol}))' instruction = 'makeVecConst' - if basic_data_type.is_bool(): - instruction = 'makeVecConstBool' - # TODO Vectorization Revamp: is int, or sint, or uint (my guess is sint) - elif basic_data_type.is_int(): - instruction = 'makeVecConstInt' - return self.instruction_set[instruction].format(symbol, **self._kwargs) + return self.instruction_set.operation(instruction, data_type).format(symbol, **self._kwargs) def _typed_vectorized_access(self, expr, data_type): basic_data_type = data_type.base_type @@ -675,16 +666,15 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): # vector_memory_access is a cast_func itself so it should't be directly inside a cast_func assert not isinstance(arg, VectorMemoryAccess) if isinstance(arg, sp.Tuple): - is_boolean = get_type_of_expression(arg[0]) == create_type("bool") is_integer = get_type_of_expression(arg[0]) == create_type("int") printed_args = [self._print(a) for a in arg] - instruction = 'makeVecBool' if is_boolean else 'makeVecInt' if is_integer else 'makeVec' - if instruction == 'makeVecInt' and 'makeVecIndex' in self.instruction_set: + instruction = 'makeVec' + if is_integer and self.instruction_set.supports('makeVecIndex'): increments = np.array(arg)[1:] - np.array(arg)[:-1] if len(set(increments)) == 1: return self.instruction_set['makeVecIndex'].format(printed_args[0], increments[0], **self._kwargs) - return self.instruction_set[instruction].format(*printed_args, **self._kwargs) + return self.instruction_set.operation(instruction, data_type).format(*printed_args, **self._kwargs) else: if arg.is_Number and not isinstance(arg, (sp.core.numbers.Infinity, sp.core.numbers.NegativeInfinity)): return self._typed_vectorized_number(arg, data_type) @@ -706,7 +696,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): elif isinstance(arg, ResolvedFieldAccess): return self._typed_vectorized_access(arg, data_type) else: - raise NotImplementedError('Vectorizer cannot cast between different datatypes') + return self.instruction_set.operation('convert', data_type, [get_type_of_expression(arg)]).format(self._print(arg), **self._kwargs) + # raise NotImplementedError(f'Vectorizer cannot cast between different datatypes in expression "{expr}" ({sp.srepr(expr)})') # to_type = self.instruction_set['suffix'][data_type.base_type.c_name] # from_type = self.instruction_set['suffix'][get_type_of_expression(arg).base_type.c_name] # return self.instruction_set['cast'].format(from_type, to_type, self._print(arg)) @@ -715,17 +706,18 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): # raise ValueError(f'Non VectorType cast "{data_type}" in vectorized code.') def _print_Function(self, expr): + expr_type = get_type_of_expression(expr.args[0]) if isinstance(expr, VectorMemoryAccess): arg, data_type, aligned, _, mask, stride = expr.args if stride != 1: - return self.instruction_set['loadS'].format(f"& {self._print(arg)}", stride, **self._kwargs) - instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU'] + return self.instruction_set.operation('loadS', data_type).format(f"& {self._print(arg)}", stride, **self._kwargs) + instruction = self.instruction_set.operation('loadA', data_type) if aligned else self.instruction_set.operation('loadU', data_type) return instruction.format(f"& {self._print(arg)}", **self._kwargs) elif expr.func == DivFunc: result = self._scalarFallback('_print_Function', expr) if not result: - result = self.instruction_set['/'].format(self._print(expr.divisor), self._print(expr.dividend), - **self._kwargs) + result = self.instruction_set.operation('/', expr_type).format(self._print(expr.divisor), self._print(expr.dividend), + **self._kwargs) return result elif isinstance(expr, fast_division): raise ValueError("fast_division is only supported for Taget.GPU") @@ -735,7 +727,6 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): raise ValueError("fast_inv_sqrt is only supported for Taget.GPU") elif isinstance(expr, vec_any) or isinstance(expr, vec_all): instr = 'any' if isinstance(expr, vec_any) else 'all' - expr_type = get_type_of_expression(expr.args[0]) if type(expr_type) is not VectorType: return self._print(expr.args[0]) else: @@ -782,15 +773,13 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): args = expr.args # special treatment for all-integer args, for loop index arithmetic until we have proper int vectorization - suffix = "" - if all([(type(e) is CastFunc and str(e.dtype) == self.instruction_set['int']) or isinstance(e, sp.Integer) + if all([(type(e) is CastFunc and isinstance(e.dtype, BasicType) and e.dtype.is_int()) or isinstance(e, sp.Integer) or (type(e) is TypedSymbol and isinstance(e.dtype, BasicType) and e.dtype.is_int()) for e in args]): dtype = set([e.dtype for e in args if type(e) is CastFunc]) assert len(dtype) == 1 dtype = dtype.pop() args = [CastFunc(e, dtype) if (isinstance(e, sp.Integer) or isinstance(e, TypedSymbol)) else e for e in args] - suffix = "int" summands = [] for term in args: @@ -802,21 +791,22 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): summands.append(self.SummandInfo(sign, t)) # Use positive terms first summands.sort(key=lambda e: e.sign, reverse=True) + arg_types = [get_type_of_expression(a) for a in args] + target_type = collate_types(arg_types) # if no positive term exists, prepend a zero if summands[0].sign == -1: - arg_types = [get_type_of_expression(a) for a in args] - target_type = collate_types(arg_types) summands.insert(0, self.SummandInfo(1, self._print(CastFunc(0, target_type)))) assert len(summands) >= 2 processed = summands[0].term for summand in summands[1:]: - func = self.instruction_set['-' + suffix] if summand.sign == -1 else self.instruction_set['+' + suffix] - processed = func.format(processed, summand.term, **self._kwargs) + op = '-' if summand.sign == -1 else '+' + processed = self.instruction_set.operation(op, target_type).format(processed, summand.term, **self._kwargs) return processed def _print_Pow(self, expr): # Due to loop cutting sp.Mul is evaluated again. + expr_type = get_type_of_expression(expr) try: result = self._scalarFallback('_print_Pow', expr) @@ -825,8 +815,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): if result: return result - one = self.instruction_set['makeVecConst'].format(1.0, **self._kwargs) - root = self.instruction_set['sqrt'].format(self._print(expr.base), **self._kwargs) + one = self.instruction_set.operation('makeVecConst', expr_type).format(1.0, **self._kwargs) + root = self.instruction_set.operation('sqrt', expr_type).format(self._print(expr.base), **self._kwargs) if isinstance(expr.exp, CastFunc) and expr.exp.args[0].is_number: exp = expr.exp.args[0] @@ -845,7 +835,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): elif exp == 0.5: return root elif exp == -0.5: - return self.instruction_set['/'].format(one, root, **self._kwargs) + return self.instruction_set.operation('/', expr_type).format(one, root, **self._kwargs) else: raise ValueError("Generic exponential not supported: " + str(expr)) @@ -853,6 +843,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): # noinspection PyProtectedMember from sympy.core.mul import _keep_coeff + expr_type = get_type_of_expression(expr) + if not inside_add: result = self._scalarFallback('_print_Mul', expr) else: @@ -887,19 +879,19 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): result = a_str[0] for item in a_str[1:]: - result = self.instruction_set['*'].format(result, item, **self._kwargs) + result = self.instruction_set.operation('*', expr_type).format(result, item, **self._kwargs) if len(b) > 0: denominator_str = b_str[0] for item in b_str[1:]: - denominator_str = self.instruction_set['*'].format(denominator_str, item, **self._kwargs) - result = self.instruction_set['/'].format(result, denominator_str, **self._kwargs) + denominator_str = self.instruction_set.operation('*', expr_type).format(denominator_str, item, **self._kwargs) + result = self.instruction_set.operation('/', expr_type).format(result, denominator_str, **self._kwargs) if inside_add: return sign, result else: if sign < 0: - return self.instruction_set['*'].format(self._print(S.NegativeOne), result, **self._kwargs) + return self.instruction_set.operation('*', expr_type).format(self._print(S.NegativeOne), result, **self._kwargs) else: return result diff --git a/pystencils/backends/x86_instruction_sets.py b/pystencils/backends/x86_instruction_sets.py index 7f05a403d45a8ca890a78b7db92ff147d30dd7dc..bb1bac5aad8064df1aef5d21d6fdbca98a145bd2 100644 --- a/pystencils/backends/x86_instruction_sets.py +++ b/pystencils/backends/x86_instruction_sets.py @@ -1,3 +1,7 @@ +from typing import Dict, List +import numpy as np +from pystencils.typing.types import VectorType + def get_argument_string(intrinsic_id, width, function_shortcut): if intrinsic_id == 'makeVecConst' or intrinsic_id == 'makeVecConstInt': arg_string = f"({','.join(['{0}'] * width)})" @@ -176,3 +180,131 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'): result['streamFence'] = '_mm_mfence()' return result + +class X86InstructionSet: + instruction_set: str + comparisons: Dict[str, str] + base_names: Dict[str, str] + + def __init__(self, instruction_set: str = 'avx'): + assert instruction_set in ['sse', 'avx', 'avx512', 'avx512vl'] + self.instruction_set = instruction_set + + self.comparisons = { + '==': '_CMP_EQ_UQ', + '!=': '_CMP_NEQ_UQ', + '>=': '_CMP_GE_OQ', + '<=': '_CMP_LE_OQ', + '<': '_CMP_NGE_UQ', + '>': '_CMP_NLE_UQ', + } + self.base_names = { + '+': 'add[0, 1]', + '-': 'sub[0, 1]', + '*': 'mul[0, 1]', + '/': 'div[0, 1]', + '&': 'and[0, 1]', + '|': 'or[0, 1]', + 'blendv': 'blendv[0, 1, 2]', + + 'sqrt': 'sqrt[0]', + + 'makeVecConst': 'set[]', + 'makeVec': 'set[]', + + 'loadU': 'loadu[0]', + 'loadA': 'load[0]', + 'storeU': 'storeu[0,1]', + 'storeA': 'store[0,1]', + 'stream': 'stream[0,1]', + 'maskStoreA': 'mask_store[0, 2, 1]' if instruction_set.startswith('avx512') else 'maskstore[0, 2, 1]', + 'maskStoreU': 'mask_storeu[0, 2, 1]' if instruction_set.startswith('avx512') else 'maskstore[0, 2, 1]', + } + + for comparison_op, constant in self.comparisons.items(): + self.base_names[comparison_op] = f'cmp[0, 1, {constant}]' + + def type_name(self, dtype: VectorType) -> str: + if dtype.base_type.numpy_dtype == np.float16: + suffix = 'h' + elif dtype.base_type.numpy_dtype == np.float32: + suffix = '' + elif dtype.base_type.numpy_dtype == np.float64: + suffix = 'd' + elif dtype.base_type.is_int(): + suffix = 'i' + else: + raise RuntimeError(f'unsopported base type {dtype.base_type}') + + if dtype.bit_width == 64 and suffix == 'i': + return '__m64' + elif dtype.bit_width == 128 and suffix != 'h': + return '__m128' + suffix + elif dtype.bit_width == 256 and self.instruction_set.startswith('avx'): + if self.instruction_set == 'avx' and suffix == 'h': + raise RuntimeError('half precision not supported by avx instruction set') + return '__m256' + suffix + elif dtype.bit_width == 512 and self.instruction_set.startswith('avx512'): + return '__m512' + suffix + + raise RuntimeError(f'unsopported vector type {dtype}') + + def operation(self, op: str, dtype: VectorType, arg_types: List[VectorType] = []) -> str: + suffix = self.suffix(dtype) + + prefix = '' + if dtype.bit_width in [64, 128]: + prefix = '_mm' + elif dtype.bit_width in [256, 512]: + prefix = f'_mm{dtype.bit_width}' + else: + raise RuntimeError(f'unsopported vector type {dtype}') + + if op in self.base_names: + function_shortcut = self.base_names[op] + function_shortcut = function_shortcut.strip() + name = function_shortcut[:function_shortcut.index('[')] + + arg_string = get_argument_string(op, dtype.width, function_shortcut) + mask_suffix = '_mask' if self.instruction_set.startswith('avx512') and op in self.comparisons.keys() else '' + + if name == 'set' and suffix == 'epi64' and dtype.bit_width < 512: + suffix = 'epi64x' + + return f'{prefix}_{name}_{suffix}{mask_suffix}{arg_string}' + + elif op == 'abs': + if self.instruction_set == 'avx512': + return f'{prefix}_abs_{suffix}({{0}})' + else: + setsuf = "x" if dtype.bit_width < 512 and dtype.bit_width // dtype.width == 64 else "" + return f'{prefix}_castsi{dtype.bit_width}_{suffix}({prefix}_and_si{dtype.bit_width}(' + \ + f"{prefix}_set1_epi{dtype.bit_width // dtype.width}{setsuf}(0x7" + \ + 'f' * (dtype.bit_width // dtype.width // 4 - 1) + '), ' + \ + f'{prefix}_cast{suffix}_si{dtype.bit_width}({{0}})))' + + elif op == 'convert': + to_type = dtype + from_type = arg_types[0] + # TODO check that instruction exists + + name = f'cvt{self.suffix(from_type)}' + arg_string = get_argument_string(op, dtype.width, f'{name}[0]') + return f'{prefix}_{name}_{suffix}{arg_string}' + + def supports(self, op: str) -> bool: + return op in self.base_names or op in ['abs', 'convert'] + + def suffix(self, dtype: VectorType) -> str: + if dtype.base_type.numpy_dtype == np.float16: + return 'ph' + elif dtype.base_type.numpy_dtype == np.float32: + return 'ps' + elif dtype.base_type.numpy_dtype == np.float64: + return 'pd' + elif dtype.base_type.is_sint(): + return f'epi{dtype.base_type.bit_width}' + elif dtype.base_type.is_uint(): + return f'epu{dtype.base_type.bit_width}' + else: + raise RuntimeError(f'unsopported base type {dtype.base_type}') diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py index 6249a9303602f1ced020d8ecde6feed5b041e46d..ec1137ce26fa37fa4d614b691d4566ce72acaf2a 100644 --- a/pystencils/cpu/vectorization.py +++ b/pystencils/cpu/vectorization.py @@ -206,10 +206,9 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, assume_aligned, nontem loop_node.step = vector_width loop_node.subs(substitutions) - vector_int_width = ast_node.instruction_set['intwidth'] - arg_1 = CastFunc(loop_counter_symbol, VectorType(loop_counter_symbol.dtype, vector_int_width)) - arg_2 = CastFunc(tuple(range(vector_int_width if type(vector_int_width) is int else 2)), - VectorType(loop_counter_symbol.dtype, vector_int_width)) + arg_1 = CastFunc(loop_counter_symbol, VectorType(loop_counter_symbol.dtype, vector_width)) + arg_2 = CastFunc(tuple(range(vector_width if type(vector_width) is int else 2)), + VectorType(loop_counter_symbol.dtype, vector_width)) vector_loop_counter = arg_1 + arg_2 fast_subs(loop_node, {loop_counter_symbol: vector_loop_counter}, @@ -296,10 +295,13 @@ def insert_vector_casts(ast_node, instruction_set, loop_counter_symbol, default_ *expr.args[5:]) elif isinstance(expr, CastFunc): cast_type = expr.args[1] - arg = visit_expr(expr.args[0], default_type, force_vectorize) - assert cast_type in [BasicType('float32'), BasicType('float64')],\ - f'Vectorization cannot vectorize type {cast_type}' - return expr.func(arg, VectorType(cast_type, instruction_set['width'])) + if isinstance(cast_type, VectorType): + return expr + else: + arg = visit_expr(expr.args[0], default_type, force_vectorize) + # assert cast_type in [BasicType('float32'), BasicType('float64')],\ + # f'Vectorization cannot vectorize type {cast_type}' + return expr.func(arg, VectorType(cast_type, instruction_set['width'])) elif expr.func is sp.Abs and 'abs' not in instruction_set: # make abs a piecewise function if it is not natively supported by the instruction set new_arg = visit_expr(expr.args[0], default_type, force_vectorize) diff --git a/pystencils/typing/types.py b/pystencils/typing/types.py index f0f9744a558ed42bfebe08e6430da3aa21d8131a..46a129e4bba73282169e5760008a45c82ab97d36 100644 --- a/pystencils/typing/types.py +++ b/pystencils/typing/types.py @@ -121,6 +121,19 @@ class BasicType(AbstractType): def c_name(self) -> str: return numpy_name_to_c(self.numpy_dtype.name) + @property + def bit_width(self) -> int: + if self.numpy_dtype in [np.int8, np.uint8]: + return 8 + elif self.numpy_dtype in [np.int16, np.uint16, np.float16]: + return 16 + elif self.numpy_dtype in [np.int32, np.uint32, np.float32]: + return 32 + elif self.numpy_dtype in [np.int64, np.uint64, np.float64]: + return 64 + else: + raise NotImplementedError("Bit width of type {self.numpy_dtype} is unknown") + def __str__(self): return f'{self.c_name}{" const" if self.const else ""}' @@ -155,6 +168,10 @@ class VectorType(AbstractType): def item_size(self): return self.width * self.base_type.item_size + @property + def bit_width(self) -> int: + return self._base_type.bit_width * self.width + def __eq__(self, other): if not isinstance(other, VectorType): return False @@ -165,18 +182,7 @@ class VectorType(AbstractType): if self.instruction_set is None: return f"{self.base_type}[{self.width}]" else: - # TODO VectorizationRevamp: this seems super weird. the instruction_set should know how to print a type out! - # TODO VectorizationRevamp: this is error prone. base_type could be cons=True. Use dtype instead - if self.base_type == create_type("int64") or self.base_type == create_type("int32"): - return self.instruction_set['int'] - elif self.base_type == create_type("float64"): - return self.instruction_set['double'] - elif self.base_type == create_type("float32"): - return self.instruction_set['float'] - elif self.base_type == create_type("bool"): - return self.instruction_set['bool'] - else: - raise NotImplementedError() + return self.instruction_set.type_name(self) def __hash__(self): return hash((self.base_type, self.width))