From 6fed27bbc7e59c1982b01384c75daf5d9d5bff4a Mon Sep 17 00:00:00 2001
From: Michael Kuron <mkuron@icp.uni-stuttgart.de>
Date: Fri, 19 Feb 2021 16:40:58 +0100
Subject: [PATCH] some fixes for lbmpy vectorization

---
 pystencils/cpu/vectorization.py          | 2 ++
 pystencils/simp/assignment_collection.py | 5 ++++-
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py
index 2f49d44a2..4c632b145 100644
--- a/pystencils/cpu/vectorization.py
+++ b/pystencils/cpu/vectorization.py
@@ -210,6 +210,8 @@ def insert_vector_casts(ast_node):
                 # special treatment for the unary minus: make sure that the -1 has the same type as the argument
                 dtype = int
                 for arg in expr.args[1:]:
+                    if type(arg) is sp.Pow:
+                        arg = arg.args[0]
                     if type(arg) is vector_memory_access and arg.dtype.base_type.is_float():
                         dtype = arg.dtype.base_type.numpy_dtype.type
                     elif type(arg) is TypedSymbol and type(arg.dtype) is VectorType and arg.dtype.base_type.is_float():
diff --git a/pystencils/simp/assignment_collection.py b/pystencils/simp/assignment_collection.py
index 950644089..33102dee5 100644
--- a/pystencils/simp/assignment_collection.py
+++ b/pystencils/simp/assignment_collection.py
@@ -437,9 +437,10 @@ class AssignmentCollection:
 class SymbolGen:
     """Default symbol generator producing number symbols ζ_0, ζ_1, ..."""
 
-    def __init__(self, symbol="xi"):
+    def __init__(self, symbol="xi", dtype=None):
         self._ctr = 0
         self._symbol = symbol
+        self._dtype = dtype
 
     def __iter__(self):
         return self
@@ -447,4 +448,6 @@ class SymbolGen:
     def __next__(self):
         name = f"{self._symbol}_{self._ctr}"
         self._ctr += 1
+        if self._dtype is not None:
+            return pystencils.TypedSymbol(name, self._dtype)
         return sp.Symbol(name)
-- 
GitLab