Skip to content
Snippets Groups Projects
Commit 521faf6a authored by Fabian Böhm's avatar Fabian Böhm
Browse files

Small fix for cast func: consider it as vectorized already in is_scalar

parent 13074410
Branches
No related tags found
No related merge requests found
......@@ -55,7 +55,8 @@ class Node:
for arg in self.args:
if isinstance(arg, arg_type):
result.add(arg)
result.update(arg.atoms(arg_type))
arg_atoms = arg.atoms(arg_type) if (isinstance(arg, sp.Basic) or isinstance(arg, Node) or isinstance(arg, sp.Expr)) else {}
result.update(arg_atoms)
return result
......
......@@ -260,6 +260,9 @@ def insert_vector_casts(ast_node, instruction_set, loop_counter_symbol, default_
handled_functions = (sp.Add, sp.Mul, vec_any, vec_all, DivFunc, sp.Abs)
def is_scalar(expr) -> bool:
if isinstance(expr, CastFunc) and expr.args[1] in [BasicType('float32'), BasicType('float64')]:
return False
if hasattr(expr, "dtype"):
if type(expr.dtype) is VectorType:
return False
......@@ -285,7 +288,7 @@ def insert_vector_casts(ast_node, instruction_set, loop_counter_symbol, default_
if isinstance(expr, ast.ResolvedFieldAccess):
#assert (expr.args[1] != loop_counter_symbol) and (expr.args[1].name != loop_counter_symbol.name), f"ResolvedFieldAccess {expr} is indexed by the coordinate that is vectorized over, not handled."
if force_vectorize:
return CastFunc(expr, VectorType(expr.field.dtype, instruction_set['width']))
return CastFunc(expr, VectorType(expr.field.dtype, instruction_set['width']))
else:
return expr
elif isinstance(expr, VectorMemoryAccess):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment