From 7315fde9c9a2fe9105f29cd95dd471d247e52b64 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jan=20H=C3=B6nig?= <jan.hoenig@fau.de>
Date: Wed, 2 Feb 2022 15:09:06 +0100
Subject: [PATCH] More vector

---
 pystencils/backends/cbackend.py             | 27 ++++++++++++--------
 pystencils/backends/x86_instruction_sets.py |  8 +-----
 pystencils/cpu/vectorization.py             | 22 ++++++++++------
 pystencils_tests/test_vectorization.py      | 28 +++++++++++++--------
 4 files changed, 49 insertions(+), 36 deletions(-)

diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py
index 3645a89c..b631f0a8 100644
--- a/pystencils/backends/cbackend.py
+++ b/pystencils/backends/cbackend.py
@@ -644,7 +644,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
                     return self._typed_vectorized_symbol(arg, data_type)
                 elif isinstance(arg, (InverseTrigonometricFunction, TrigonometricFunction, HyperbolicFunction)) \
                         and data_type == BasicType('float32'):
-                    raise NotImplementedError('Vectorizer is not tested for trigonometric functions yes')
+                    raise NotImplementedError('Vectorizer is not tested for trigonometric functions yet')
                     # known = self.known_functions[arg.__class__.__name__.lower()]
                     # code = self._print(arg)
                     # return code.replace(known, f"{known}f")
@@ -657,8 +657,10 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
                     #         return code.replace(k, f'{k}f')
                     # raise ValueError(f"{code} doesn't give {known=} function back.")
                 else:
-                    data_type_prefix = self.instruction_set['dataTypePrefix'][data_type.base_type.c_name]
-                    return f'(({data_type_prefix})({self._print(arg)}))'
+                    raise NotImplementedError('Vectorizer cannot cast between different datatypes')
+                    # to_type = self.instruction_set['suffix'][data_type.base_type.c_name]
+                    # from_type = self.instruction_set['suffix'][get_type_of_expression(arg).base_type.c_name]
+                    # return self.instruction_set['cast'].format(from_type, to_type, self._print(arg))
         else:
             return self._scalarFallback('_print_Function', expr)
             # raise ValueError(f'Non VectorType cast "{data_type}" in vectorized code.')
@@ -775,19 +777,24 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
 
         one = self.instruction_set['makeVecConst'].format(1.0, **self._kwargs)
 
-        if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
-            return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")"
-        elif expr.exp == -1:
+        if isinstance(expr.exp, CastFunc) and expr.exp.args[0].is_number:
+            exp = expr.exp.args[0]
+        else:
+            exp = expr.exp
+
+        if exp.is_integer and exp.is_number and 0 < exp < 8:
+            return "(" + self._print(sp.Mul(*[expr.base] * exp, evaluate=False)) + ")"
+        elif exp == -1:
             one = self.instruction_set['makeVecConst'].format(1.0, **self._kwargs)
             return self.instruction_set['/'].format(one, self._print(expr.base), **self._kwargs)
-        elif expr.exp == 0.5:
+        elif exp == 0.5:
             return self.instruction_set['sqrt'].format(self._print(expr.base), **self._kwargs)
-        elif expr.exp == -0.5:
+        elif exp == -0.5:
             root = self.instruction_set['sqrt'].format(self._print(expr.base), **self._kwargs)
             return self.instruction_set['/'].format(one, root, **self._kwargs)
-        elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0:
+        elif exp.is_integer and exp.is_number and - 8 < exp < 0:
             return self.instruction_set['/'].format(one,
-                                                    self._print(sp.Mul(*[expr.base] * (-expr.exp), evaluate=False)),
+                                                    self._print(sp.Mul(*[expr.base] * (-exp), evaluate=False)),
                                                     **self._kwargs)
         else:
             raise ValueError("Generic exponential not supported: " + str(expr))
diff --git a/pystencils/backends/x86_instruction_sets.py b/pystencils/backends/x86_instruction_sets.py
index db3dc362..7653c7c6 100644
--- a/pystencils/backends/x86_instruction_sets.py
+++ b/pystencils/backends/x86_instruction_sets.py
@@ -51,7 +51,7 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'):
         'makeVecConstBool': 'set[]',
         'makeVecInt': 'set[]',
         'makeVecConstInt': 'set[]',
-        
+
         'loadU': 'loadu[0]',
         'loadA': 'load[0]',
         'storeU': 'storeu[0,1]',
@@ -93,7 +93,6 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'):
         ("float", "avx512"): 16,
         ("int", "avx512"): 16,
     }
-
     result = {
         'width': width[(data_type, instruction_set)],
         'intwidth': width[('int', instruction_set)],
@@ -114,11 +113,6 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'):
         mask_suffix = '_mask' if instruction_set == 'avx512' and intrinsic_id in comparisons.keys() else ''
         result[intrinsic_id] = pre + "_" + name + "_" + suf + mask_suffix + arg_string
 
-    result['dataTypePrefix'] = {
-        'double': "_" + pre[0:2] + pre[3:] + 'd',
-        'float': "_" + pre[0:2] + pre[3:],
-    }
-
     bit_width = result['width'] * (64 if data_type == 'double' else 32)
     result['double'] = f"__m{bit_width}d"
     result['float'] = f"__m{bit_width}"
diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py
index 47a529a4..4069a248 100644
--- a/pystencils/cpu/vectorization.py
+++ b/pystencils/cpu/vectorization.py
@@ -122,6 +122,7 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'best',
                                   "to differently typed floating point fields")
     float_size = field_float_dtypes.pop().numpy_dtype.itemsize
     assert float_size in (8, 4)
+    # TODO: future work allow mixed precision fields
     default_float_type = 'double' if float_size == 8 else 'float'
     vector_is = get_vector_instruction_set(default_float_type, instruction_set=instruction_set)
     vector_width = vector_is['width']
@@ -130,12 +131,14 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'best',
     strided = 'storeS' in vector_is and 'loadS' in vector_is
     keep_loop_stop = '{loop_stop}' in vector_is['storeA' if assume_aligned else 'storeU']
     vectorize_inner_loops_and_adapt_load_stores(kernel_ast, vector_width, assume_aligned, nontemporal,
-                                                strided, keep_loop_stop, assume_sufficient_line_padding)
-    insert_vector_casts(kernel_ast, default_float_type)
+                                                strided, keep_loop_stop, assume_sufficient_line_padding,
+                                                default_float_type)
+    # is in vectorize_inner_loops_and_adapt_load_stores.. insert_vector_casts(kernel_ast, default_float_type)
 
 
 def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_aligned, nontemporal_fields,
-                                                strided, keep_loop_stop, assume_sufficient_line_padding):
+                                                strided, keep_loop_stop, assume_sufficient_line_padding,
+                                                default_float_type):
     """Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type."""
     all_loops = filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment)
     inner_loops = [n for n in all_loops if n.is_innermost_loop]
@@ -158,6 +161,7 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
             if len(loop_nodes) == 0:
                 continue
             loop_node = loop_nodes[0]
+            # TODO loop_node is the vectorized one
 
         # Find all array accesses (indexed) that depend on the loop counter as offset
         loop_counter_symbol = ast.LoopOverCoordinate.get_loop_counter_symbol(loop_node.coordinate_to_loop_over)
@@ -215,6 +219,7 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
             substitutions.update({s[0]: s[1] for s in zip(rng.result_symbols, new_result_symbols)})
             rng._symbols_defined = set(new_result_symbols)
         fast_subs(loop_node, substitutions, skip=lambda e: isinstance(e, RNGBase))
+        insert_vector_casts(loop_node, default_float_type)
 
 
 def mask_conditionals(loop_body):
@@ -246,7 +251,8 @@ def mask_conditionals(loop_body):
 def insert_vector_casts(ast_node, default_float_type='double'):
     """Inserts necessary casts from scalar values to vector values."""
 
-    handled_functions = (sp.Add, sp.Mul, fast_division, fast_sqrt, fast_inv_sqrt, vec_any, vec_all, DivFunc)
+    handled_functions = (sp.Add, sp.Mul, fast_division, fast_sqrt, fast_inv_sqrt, vec_any, vec_all, DivFunc,
+                         sp.UnevaluatedExpr)
 
     def visit_expr(expr, default_type='double'):  # TODO get rid of default_type
         if isinstance(expr, VectorMemoryAccess):
@@ -254,8 +260,8 @@ def insert_vector_casts(ast_node, default_float_type='double'):
         elif isinstance(expr, CastFunc):
             cast_type = expr.args[1]
             arg = visit_expr(expr.args[0])
-            assert(cast_type in [BasicType('float32'), BasicType('float64')],
-                   f'Vectorization cannot vectorize type {cast_type}')
+            assert cast_type in [BasicType('float32'), BasicType('float64')],\
+                f'Vectorization cannot vectorize type {cast_type}'
             return expr.func(arg, VectorType(cast_type))
         elif expr.func is sp.Abs and 'abs' not in ast_node.instruction_set:
             new_arg = visit_expr(expr.args[0], default_type)
@@ -325,8 +331,8 @@ def insert_vector_casts(ast_node, default_float_type='double'):
                 # TODO only if not remainder loop (? if no VectorAccess then remainder loop)
                 assignment = arg
                 # If there is a remainder loop we do not vectorise it, thus lhs will indicate this
-                if isinstance(assignment.lhs, ast.ResolvedFieldAccess):
-                    continue
+                # if isinstance(assignment.lhs, ast.ResolvedFieldAccess):
+                    # continue
                 subs_expr = fast_subs(assignment.rhs, substitution_dict,
                                       skip=lambda e: isinstance(e, ast.ResolvedFieldAccess))
                 assignment.rhs = visit_expr(subs_expr, default_type)
diff --git a/pystencils_tests/test_vectorization.py b/pystencils_tests/test_vectorization.py
index 9c9a99c3..55070e54 100644
--- a/pystencils_tests/test_vectorization.py
+++ b/pystencils_tests/test_vectorization.py
@@ -1,5 +1,7 @@
 import numpy as np
 
+import pytest
+
 import pystencils.config
 import sympy as sp
 
@@ -21,17 +23,14 @@ else:
 # CI:
 # FAILED pystencils_tests/test_vectorization.py::test_vectorised_pow - NotImple...
 # FAILED pystencils_tests/test_vectorization.py::test_inplace_update - NotImple...
-# FAILED pystencils_tests/test_vectorization.py::test_vectorization_fixed_size
 # FAILED pystencils_tests/test_vectorization.py::test_vectorised_fast_approximations
-# FAILED pystencils_tests/test_vectorization.py::test_vectorization_variable_size
+# test_issue40
 
 # Jan:
-# test_aligned_and_nt_stores
-# test_aligned_and_nt_stores_openmp
-# test_hardware_query
-# test_vectorised_fast_approximations
+# test_vectorised_pow
+# test_issue40
 
-# TODO: Skip tests if no instruction set is available
+# TODO: Skip tests if no instruction set is available and check all codes if they are really vectorised !
 def test_vector_type_propagation(instruction_set=instruction_set):
     a, b, c, d, e = sp.symbols("a b c d e")
     arr = np.ones((2 ** 2 + 2, 2 ** 3 + 2))
@@ -136,7 +135,7 @@ def test_vectorization_fixed_size(instruction_set=instruction_set):
         code = ps.get_code_str(ast)
         add_instruction = instructions["+"][:instructions["+"].find("(")]
         assert add_instruction in code
-        # print(code)
+        print(code)
 
         func = ast.compile()
         dst = np.zeros_like(arr)
@@ -289,6 +288,7 @@ def test_vectorised_pow(instruction_set=instruction_set):
 
 
 def test_vectorised_fast_approximations(instruction_set=instruction_set):
+    # fast_approximations are a gpu thing
     arr = np.zeros((24, 24))
     f, g = ps.fields(f=arr, g=arr)
 
@@ -296,18 +296,24 @@ def test_vectorised_fast_approximations(instruction_set=instruction_set):
     assignment = ps.Assignment(g[0, 0], insert_fast_sqrts(expr))
     ast = ps.create_kernel(assignment)
     vectorize(ast, instruction_set=instruction_set)
-    ast.compile()
+
+    with pytest.raises(Exception):
+        ast.compile()
 
     expr = f[0, 0] / f[1, 0]
     assignment = ps.Assignment(g[0, 0], insert_fast_divisions(expr))
     ast = ps.create_kernel(assignment)
     vectorize(ast, instruction_set=instruction_set)
-    ast.compile()
+
+    with pytest.raises(Exception):
+        ast.compile()
 
     assignment = ps.Assignment(sp.Symbol("tmp"), 3 / sp.sqrt(f[0, 0] + f[1, 0]))
     ast = ps.create_kernel(insert_fast_sqrts(assignment))
     vectorize(ast, instruction_set=instruction_set)
-    ast.compile()
+
+    with pytest.raises(Exception):
+        ast.compile()
 
 
 def test_issue40(*_):
-- 
GitLab