From e5d49acf645f2dbc0221c08b49100c50fc8ed73d Mon Sep 17 00:00:00 2001 From: markus holzer <markus.holzer@fau.de> Date: Wed, 26 Jan 2022 17:31:30 +0100 Subject: [PATCH] Fix division --- pystencils/backends/cbackend.py | 24 +++++++++++++++++++++++- pystencils/cpu/vectorization.py | 13 ++++++++----- pystencils/fast_approximation.py | 1 + pystencils_tests/test_vectorization.py | 6 ++++-- 4 files changed, 36 insertions(+), 8 deletions(-) diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 696928ef..3645a89c 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -605,6 +605,22 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): instruction = 'makeVecConstInt' return self.instruction_set[instruction].format(number, **self._kwargs) + def _typed_vectorized_symbol(self, expr, data_type): + if not isinstance(expr, TypedSymbol): + raise ValueError(f'{expr} is not a TypeSymbol. It is {expr.type=}') + basic_data_type = data_type.base_type + symbol = self._print(expr) + if basic_data_type != expr.dtype: + symbol = f'(({basic_data_type.data_type})({symbol}))' + + instruction = 'makeVecConst' + if basic_data_type.is_bool(): + instruction = 'makeVecConstBool' + # TODO is int, or sint, or uint? + elif basic_data_type.is_int(): + instruction = 'makeVecConstInt' + return self.instruction_set[instruction].format(symbol, **self._kwargs) + def _print_CastFunc(self, expr): arg, data_type = expr.args if type(data_type) is VectorType: @@ -624,6 +640,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): else: if arg.is_Number and not isinstance(arg, (sp.core.numbers.Infinity, sp.core.numbers.NegativeInfinity)): return self._typed_vectorized_number(arg, data_type) + elif isinstance(arg, TypedSymbol): + return self._typed_vectorized_symbol(arg, data_type) elif isinstance(arg, (InverseTrigonometricFunction, TrigonometricFunction, HyperbolicFunction)) \ and data_type == BasicType('float32'): raise NotImplementedError('Vectorizer is not tested for trigonometric functions yes') @@ -642,7 +660,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): data_type_prefix = self.instruction_set['dataTypePrefix'][data_type.base_type.c_name] return f'(({data_type_prefix})({self._print(arg)}))' else: - raise ValueError(f'Non VectorType cast "{data_type}" in vectorized code.') + return self._scalarFallback('_print_Function', expr) + # raise ValueError(f'Non VectorType cast "{data_type}" in vectorized code.') def _print_Function(self, expr): if isinstance(expr, VectorMemoryAccess): @@ -651,6 +670,9 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): return self.instruction_set['loadS'].format(f"& {self._print(arg)}", stride, **self._kwargs) instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU'] return instruction.format(f"& {self._print(arg)}", **self._kwargs) + elif expr.func == DivFunc: + return self.instruction_set['/'].format(self._print(expr.divisor), self._print(expr.dividend), + **self._kwargs) elif expr.func == fast_division: result = self._scalarFallback('_print_Function', expr) if not result: diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py index a161d587..812a6163 100644 --- a/pystencils/cpu/vectorization.py +++ b/pystencils/cpu/vectorization.py @@ -3,13 +3,14 @@ from typing import Container, Union import numpy as np import sympy as sp -from sympy.logic.boolalg import BooleanFunction +from sympy.logic.boolalg import BooleanFunction, BooleanAtom import pystencils.astnodes as ast from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets, get_vector_instruction_set from pystencils.typing import ( PointerType, TypedSymbol, VectorType, CastFunc, collate_types, get_type_of_expression, VectorMemoryAccess) from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt +from pystencils.functions import DivFunc from pystencils.field import Field from pystencils.integer_functions import modulo_ceil, modulo_floor from pystencils.sympyextensions import fast_subs @@ -245,13 +246,13 @@ def mask_conditionals(loop_body): def insert_vector_casts(ast_node, default_float_type='double'): """Inserts necessary casts from scalar values to vector values.""" - handled_functions = (sp.Add, sp.Mul, fast_division, fast_sqrt, fast_inv_sqrt, vec_any, vec_all) + handled_functions = (sp.Add, sp.Mul, fast_division, fast_sqrt, fast_inv_sqrt, vec_any, vec_all, DivFunc) - def visit_expr(expr, default_type='double'): + def visit_expr(expr, default_type='double'): # TODO get rid of default_type if isinstance(expr, VectorMemoryAccess): return VectorMemoryAccess(*expr.args[0:4], visit_expr(expr.args[4], default_type), *expr.args[5:]) elif isinstance(expr, CastFunc): - return expr + return expr # TODO here, since CastFunc might not be vector??? elif expr.func is sp.Abs and 'abs' not in ast_node.instruction_set: new_arg = visit_expr(expr.args[0], default_type) base_type = get_type_of_expression(expr.args[0]).base_type if type(expr.args[0]) is VectorMemoryAccess \ @@ -307,8 +308,10 @@ def insert_vector_casts(ast_node, default_float_type='double'): for a, t in zip(new_conditions, types_of_conditions)] return sp.Piecewise(*[(r, c) for r, c in zip(casted_results, casted_conditions)]) - else: + elif isinstance(expr, (sp.Number, TypedSymbol, BooleanAtom)): return expr + else: + raise NotImplementedError(f'Should I raise or should I return now? {expr}') def visit_node(node, substitution_dict, default_type='double'): substitution_dict = substitution_dict.copy() diff --git a/pystencils/fast_approximation.py b/pystencils/fast_approximation.py index 9eee41a9..65f85a71 100644 --- a/pystencils/fast_approximation.py +++ b/pystencils/fast_approximation.py @@ -9,6 +9,7 @@ from pystencils.assignment import Assignment # noinspection PyPep8Naming class fast_division(sp.Function): + # TODO how is this fast? The printer prints a normal division??? nargs = (2,) diff --git a/pystencils_tests/test_vectorization.py b/pystencils_tests/test_vectorization.py index 6e8b0a4f..a7a335c7 100644 --- a/pystencils_tests/test_vectorization.py +++ b/pystencils_tests/test_vectorization.py @@ -171,9 +171,9 @@ def test_piecewise2(instruction_set=instruction_set): g[0, 0] @= s.result ast = ps.create_kernel(test_kernel) - ps.show_code(ast) + # ps.show_code(ast) vectorize(ast, instruction_set=instruction_set) - ps.show_code(ast) + # ps.show_code(ast) func = ast.compile() func(f=arr, g=arr) np.testing.assert_equal(arr, np.ones_like(arr)) @@ -189,7 +189,9 @@ def test_piecewise3(instruction_set=instruction_set): g[0, 0] @= 1.0 / (s.b + s.k) if f[0, 0] > 0.0 else 1.0 ast = ps.create_kernel(test_kernel) + ps.show_code(ast) vectorize(ast, instruction_set=instruction_set) + ps.show_code(ast) ast.compile() -- GitLab