diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 696928ef5a9198d25d769748d6324e1b2971f372..3645a89c3c8f9d85b803076bc7102f4bac3a0eac 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -605,6 +605,22 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): instruction = 'makeVecConstInt' return self.instruction_set[instruction].format(number, **self._kwargs) + def _typed_vectorized_symbol(self, expr, data_type): + if not isinstance(expr, TypedSymbol): + raise ValueError(f'{expr} is not a TypeSymbol. It is {expr.type=}') + basic_data_type = data_type.base_type + symbol = self._print(expr) + if basic_data_type != expr.dtype: + symbol = f'(({basic_data_type.data_type})({symbol}))' + + instruction = 'makeVecConst' + if basic_data_type.is_bool(): + instruction = 'makeVecConstBool' + # TODO is int, or sint, or uint? + elif basic_data_type.is_int(): + instruction = 'makeVecConstInt' + return self.instruction_set[instruction].format(symbol, **self._kwargs) + def _print_CastFunc(self, expr): arg, data_type = expr.args if type(data_type) is VectorType: @@ -624,6 +640,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): else: if arg.is_Number and not isinstance(arg, (sp.core.numbers.Infinity, sp.core.numbers.NegativeInfinity)): return self._typed_vectorized_number(arg, data_type) + elif isinstance(arg, TypedSymbol): + return self._typed_vectorized_symbol(arg, data_type) elif isinstance(arg, (InverseTrigonometricFunction, TrigonometricFunction, HyperbolicFunction)) \ and data_type == BasicType('float32'): raise NotImplementedError('Vectorizer is not tested for trigonometric functions yes') @@ -642,7 +660,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): data_type_prefix = self.instruction_set['dataTypePrefix'][data_type.base_type.c_name] return f'(({data_type_prefix})({self._print(arg)}))' else: - raise ValueError(f'Non VectorType cast "{data_type}" in vectorized code.') + return self._scalarFallback('_print_Function', expr) + # raise ValueError(f'Non VectorType cast "{data_type}" in vectorized code.') def _print_Function(self, expr): if isinstance(expr, VectorMemoryAccess): @@ -651,6 +670,9 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): return self.instruction_set['loadS'].format(f"& {self._print(arg)}", stride, **self._kwargs) instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU'] return instruction.format(f"& {self._print(arg)}", **self._kwargs) + elif expr.func == DivFunc: + return self.instruction_set['/'].format(self._print(expr.divisor), self._print(expr.dividend), + **self._kwargs) elif expr.func == fast_division: result = self._scalarFallback('_print_Function', expr) if not result: diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py index a161d5879dcef112c8fc0ddb6e71daab730fcc01..812a6163465295911f8b252a2a4eab0af7ec2417 100644 --- a/pystencils/cpu/vectorization.py +++ b/pystencils/cpu/vectorization.py @@ -3,13 +3,14 @@ from typing import Container, Union import numpy as np import sympy as sp -from sympy.logic.boolalg import BooleanFunction +from sympy.logic.boolalg import BooleanFunction, BooleanAtom import pystencils.astnodes as ast from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets, get_vector_instruction_set from pystencils.typing import ( PointerType, TypedSymbol, VectorType, CastFunc, collate_types, get_type_of_expression, VectorMemoryAccess) from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt +from pystencils.functions import DivFunc from pystencils.field import Field from pystencils.integer_functions import modulo_ceil, modulo_floor from pystencils.sympyextensions import fast_subs @@ -245,13 +246,13 @@ def mask_conditionals(loop_body): def insert_vector_casts(ast_node, default_float_type='double'): """Inserts necessary casts from scalar values to vector values.""" - handled_functions = (sp.Add, sp.Mul, fast_division, fast_sqrt, fast_inv_sqrt, vec_any, vec_all) + handled_functions = (sp.Add, sp.Mul, fast_division, fast_sqrt, fast_inv_sqrt, vec_any, vec_all, DivFunc) - def visit_expr(expr, default_type='double'): + def visit_expr(expr, default_type='double'): # TODO get rid of default_type if isinstance(expr, VectorMemoryAccess): return VectorMemoryAccess(*expr.args[0:4], visit_expr(expr.args[4], default_type), *expr.args[5:]) elif isinstance(expr, CastFunc): - return expr + return expr # TODO here, since CastFunc might not be vector??? elif expr.func is sp.Abs and 'abs' not in ast_node.instruction_set: new_arg = visit_expr(expr.args[0], default_type) base_type = get_type_of_expression(expr.args[0]).base_type if type(expr.args[0]) is VectorMemoryAccess \ @@ -307,8 +308,10 @@ def insert_vector_casts(ast_node, default_float_type='double'): for a, t in zip(new_conditions, types_of_conditions)] return sp.Piecewise(*[(r, c) for r, c in zip(casted_results, casted_conditions)]) - else: + elif isinstance(expr, (sp.Number, TypedSymbol, BooleanAtom)): return expr + else: + raise NotImplementedError(f'Should I raise or should I return now? {expr}') def visit_node(node, substitution_dict, default_type='double'): substitution_dict = substitution_dict.copy() diff --git a/pystencils/fast_approximation.py b/pystencils/fast_approximation.py index 9eee41a96f96d05b9fc9be3443a7291359369857..65f85a71a25e2da0082a61a5418ce4c4eb656af0 100644 --- a/pystencils/fast_approximation.py +++ b/pystencils/fast_approximation.py @@ -9,6 +9,7 @@ from pystencils.assignment import Assignment # noinspection PyPep8Naming class fast_division(sp.Function): + # TODO how is this fast? The printer prints a normal division??? nargs = (2,) diff --git a/pystencils_tests/test_vectorization.py b/pystencils_tests/test_vectorization.py index 6e8b0a4ff68c2453ddfc336ed63fb31d62e7ad4c..a7a335c7592f87df4524276dffd18b03c8f0a1c8 100644 --- a/pystencils_tests/test_vectorization.py +++ b/pystencils_tests/test_vectorization.py @@ -171,9 +171,9 @@ def test_piecewise2(instruction_set=instruction_set): g[0, 0] @= s.result ast = ps.create_kernel(test_kernel) - ps.show_code(ast) + # ps.show_code(ast) vectorize(ast, instruction_set=instruction_set) - ps.show_code(ast) + # ps.show_code(ast) func = ast.compile() func(f=arr, g=arr) np.testing.assert_equal(arr, np.ones_like(arr)) @@ -189,7 +189,9 @@ def test_piecewise3(instruction_set=instruction_set): g[0, 0] @= 1.0 / (s.b + s.k) if f[0, 0] > 0.0 else 1.0 ast = ps.create_kernel(test_kernel) + ps.show_code(ast) vectorize(ast, instruction_set=instruction_set) + ps.show_code(ast) ast.compile()