Skip to content
Snippets Groups Projects

some fixes for lbmpy vectorization

Merged Michael Kuron requested to merge vectorization into master
2 files
+ 8
4
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -209,10 +209,11 @@ def insert_vector_casts(ast_node):
@@ -209,10 +209,11 @@ def insert_vector_casts(ast_node):
if expr.func is sp.Mul and expr.args[0] == -1:
if expr.func is sp.Mul and expr.args[0] == -1:
# special treatment for the unary minus: make sure that the -1 has the same type as the argument
# special treatment for the unary minus: make sure that the -1 has the same type as the argument
dtype = int
dtype = int
for arg in expr.args[1:]:
for arg in expr.atoms(vector_memory_access):
if type(arg) is vector_memory_access and arg.dtype.base_type.is_float():
if arg.dtype.base_type.is_float():
dtype = arg.dtype.base_type.numpy_dtype.type
dtype = arg.dtype.base_type.numpy_dtype.type
elif type(arg) is TypedSymbol and type(arg.dtype) is VectorType and arg.dtype.base_type.is_float():
for arg in expr.atoms(TypedSymbol):
 
if type(arg.dtype) is VectorType and arg.dtype.base_type.is_float():
dtype = arg.dtype.base_type.numpy_dtype.type
dtype = arg.dtype.base_type.numpy_dtype.type
if dtype is not int:
if dtype is not int:
if dtype is np.float32:
if dtype is np.float32:
Loading