diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index 115d49ec8c7e6b631a16dcb2c73915254b32d8e3..5551580a2315d12dd3c687bf7135928b9be23440 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -55,7 +55,8 @@ class Node: for arg in self.args: if isinstance(arg, arg_type): result.add(arg) - result.update(arg.atoms(arg_type)) + arg_atoms = arg.atoms(arg_type) if (isinstance(arg, sp.Basic) or isinstance(arg, Node) or isinstance(arg, sp.Expr)) else {} + result.update(arg_atoms) return result diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py index a33c23f3a7155010d1da7e9ff86094683bdb57b6..6249a9303602f1ced020d8ecde6feed5b041e46d 100644 --- a/pystencils/cpu/vectorization.py +++ b/pystencils/cpu/vectorization.py @@ -260,6 +260,9 @@ def insert_vector_casts(ast_node, instruction_set, loop_counter_symbol, default_ handled_functions = (sp.Add, sp.Mul, vec_any, vec_all, DivFunc, sp.Abs) def is_scalar(expr) -> bool: + if isinstance(expr, CastFunc) and expr.args[1] in [BasicType('float32'), BasicType('float64')]: + return False + if hasattr(expr, "dtype"): if type(expr.dtype) is VectorType: return False @@ -285,7 +288,7 @@ def insert_vector_casts(ast_node, instruction_set, loop_counter_symbol, default_ if isinstance(expr, ast.ResolvedFieldAccess): #assert (expr.args[1] != loop_counter_symbol) and (expr.args[1].name != loop_counter_symbol.name), f"ResolvedFieldAccess {expr} is indexed by the coordinate that is vectorized over, not handled." if force_vectorize: - return CastFunc(expr, VectorType(expr.field.dtype, instruction_set['width'])) + return CastFunc(expr, VectorType(expr.field.dtype, instruction_set['width'])) else: return expr elif isinstance(expr, VectorMemoryAccess):