From e5d49acf645f2dbc0221c08b49100c50fc8ed73d Mon Sep 17 00:00:00 2001
From: markus holzer <markus.holzer@fau.de>
Date: Wed, 26 Jan 2022 17:31:30 +0100
Subject: [PATCH] Fix division

---
 pystencils/backends/cbackend.py        | 24 +++++++++++++++++++++++-
 pystencils/cpu/vectorization.py        | 13 ++++++++-----
 pystencils/fast_approximation.py       |  1 +
 pystencils_tests/test_vectorization.py |  6 ++++--
 4 files changed, 36 insertions(+), 8 deletions(-)

diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py
index 696928ef..3645a89c 100644
--- a/pystencils/backends/cbackend.py
+++ b/pystencils/backends/cbackend.py
@@ -605,6 +605,22 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
             instruction = 'makeVecConstInt'
         return self.instruction_set[instruction].format(number, **self._kwargs)
 
+    def _typed_vectorized_symbol(self, expr, data_type):
+        if not isinstance(expr, TypedSymbol):
+            raise ValueError(f'{expr} is not a TypeSymbol. It is {expr.type=}')
+        basic_data_type = data_type.base_type
+        symbol = self._print(expr)
+        if basic_data_type != expr.dtype:
+            symbol = f'(({basic_data_type.data_type})({symbol}))'
+
+        instruction = 'makeVecConst'
+        if basic_data_type.is_bool():
+            instruction = 'makeVecConstBool'
+        # TODO is int, or sint, or uint?
+        elif basic_data_type.is_int():
+            instruction = 'makeVecConstInt'
+        return self.instruction_set[instruction].format(symbol, **self._kwargs)
+
     def _print_CastFunc(self, expr):
         arg, data_type = expr.args
         if type(data_type) is VectorType:
@@ -624,6 +640,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
             else:
                 if arg.is_Number and not isinstance(arg, (sp.core.numbers.Infinity, sp.core.numbers.NegativeInfinity)):
                     return self._typed_vectorized_number(arg, data_type)
+                elif isinstance(arg, TypedSymbol):
+                    return self._typed_vectorized_symbol(arg, data_type)
                 elif isinstance(arg, (InverseTrigonometricFunction, TrigonometricFunction, HyperbolicFunction)) \
                         and data_type == BasicType('float32'):
                     raise NotImplementedError('Vectorizer is not tested for trigonometric functions yes')
@@ -642,7 +660,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
                     data_type_prefix = self.instruction_set['dataTypePrefix'][data_type.base_type.c_name]
                     return f'(({data_type_prefix})({self._print(arg)}))'
         else:
-            raise ValueError(f'Non VectorType cast "{data_type}" in vectorized code.')
+            return self._scalarFallback('_print_Function', expr)
+            # raise ValueError(f'Non VectorType cast "{data_type}" in vectorized code.')
 
     def _print_Function(self, expr):
         if isinstance(expr, VectorMemoryAccess):
@@ -651,6 +670,9 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
                 return self.instruction_set['loadS'].format(f"& {self._print(arg)}", stride, **self._kwargs)
             instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU']
             return instruction.format(f"& {self._print(arg)}", **self._kwargs)
+        elif expr.func == DivFunc:
+            return self.instruction_set['/'].format(self._print(expr.divisor), self._print(expr.dividend),
+                                                    **self._kwargs)
         elif expr.func == fast_division:
             result = self._scalarFallback('_print_Function', expr)
             if not result:
diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py
index a161d587..812a6163 100644
--- a/pystencils/cpu/vectorization.py
+++ b/pystencils/cpu/vectorization.py
@@ -3,13 +3,14 @@ from typing import Container, Union
 
 import numpy as np
 import sympy as sp
-from sympy.logic.boolalg import BooleanFunction
+from sympy.logic.boolalg import BooleanFunction, BooleanAtom
 
 import pystencils.astnodes as ast
 from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets, get_vector_instruction_set
 from pystencils.typing import (
     PointerType, TypedSymbol, VectorType, CastFunc, collate_types, get_type_of_expression, VectorMemoryAccess)
 from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
+from pystencils.functions import DivFunc
 from pystencils.field import Field
 from pystencils.integer_functions import modulo_ceil, modulo_floor
 from pystencils.sympyextensions import fast_subs
@@ -245,13 +246,13 @@ def mask_conditionals(loop_body):
 def insert_vector_casts(ast_node, default_float_type='double'):
     """Inserts necessary casts from scalar values to vector values."""
 
-    handled_functions = (sp.Add, sp.Mul, fast_division, fast_sqrt, fast_inv_sqrt, vec_any, vec_all)
+    handled_functions = (sp.Add, sp.Mul, fast_division, fast_sqrt, fast_inv_sqrt, vec_any, vec_all, DivFunc)
 
-    def visit_expr(expr, default_type='double'):
+    def visit_expr(expr, default_type='double'):  # TODO get rid of default_type
         if isinstance(expr, VectorMemoryAccess):
             return VectorMemoryAccess(*expr.args[0:4], visit_expr(expr.args[4], default_type), *expr.args[5:])
         elif isinstance(expr, CastFunc):
-            return expr
+            return expr  # TODO here, since CastFunc might not be vector???
         elif expr.func is sp.Abs and 'abs' not in ast_node.instruction_set:
             new_arg = visit_expr(expr.args[0], default_type)
             base_type = get_type_of_expression(expr.args[0]).base_type if type(expr.args[0]) is VectorMemoryAccess \
@@ -307,8 +308,10 @@ def insert_vector_casts(ast_node, default_float_type='double'):
                                  for a, t in zip(new_conditions, types_of_conditions)]
 
             return sp.Piecewise(*[(r, c) for r, c in zip(casted_results, casted_conditions)])
-        else:
+        elif isinstance(expr, (sp.Number, TypedSymbol, BooleanAtom)):
             return expr
+        else:
+            raise NotImplementedError(f'Should I raise or should I return now? {expr}')
 
     def visit_node(node, substitution_dict, default_type='double'):
         substitution_dict = substitution_dict.copy()
diff --git a/pystencils/fast_approximation.py b/pystencils/fast_approximation.py
index 9eee41a9..65f85a71 100644
--- a/pystencils/fast_approximation.py
+++ b/pystencils/fast_approximation.py
@@ -9,6 +9,7 @@ from pystencils.assignment import Assignment
 
 # noinspection PyPep8Naming
 class fast_division(sp.Function):
+    # TODO how is this fast? The printer prints a normal division???
     nargs = (2,)
 
 
diff --git a/pystencils_tests/test_vectorization.py b/pystencils_tests/test_vectorization.py
index 6e8b0a4f..a7a335c7 100644
--- a/pystencils_tests/test_vectorization.py
+++ b/pystencils_tests/test_vectorization.py
@@ -171,9 +171,9 @@ def test_piecewise2(instruction_set=instruction_set):
         g[0, 0]     @= s.result
 
     ast = ps.create_kernel(test_kernel)
-    ps.show_code(ast)
+    # ps.show_code(ast)
     vectorize(ast, instruction_set=instruction_set)
-    ps.show_code(ast)
+    # ps.show_code(ast)
     func = ast.compile()
     func(f=arr, g=arr)
     np.testing.assert_equal(arr, np.ones_like(arr))
@@ -189,7 +189,9 @@ def test_piecewise3(instruction_set=instruction_set):
         g[0, 0] @= 1.0 / (s.b + s.k) if f[0, 0] > 0.0 else 1.0
 
     ast = ps.create_kernel(test_kernel)
+    ps.show_code(ast)
     vectorize(ast, instruction_set=instruction_set)
+    ps.show_code(ast)
     ast.compile()
 
 
-- 
GitLab