From 521faf6af73e819b2e3d255670fd533344ab5dc8 Mon Sep 17 00:00:00 2001 From: vy28quve <Fabian.Boehm@fau.de> Date: Mon, 23 Oct 2023 16:08:46 +0200 Subject: [PATCH] Small fix for cast func: consider it as vectorized already in is_scalar --- pystencils/astnodes.py | 3 ++- pystencils/cpu/vectorization.py | 5 ++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index 115d49ec8..5551580a2 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -55,7 +55,8 @@ class Node: for arg in self.args: if isinstance(arg, arg_type): result.add(arg) - result.update(arg.atoms(arg_type)) + arg_atoms = arg.atoms(arg_type) if (isinstance(arg, sp.Basic) or isinstance(arg, Node) or isinstance(arg, sp.Expr)) else {} + result.update(arg_atoms) return result diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py index a33c23f3a..6249a9303 100644 --- a/pystencils/cpu/vectorization.py +++ b/pystencils/cpu/vectorization.py @@ -260,6 +260,9 @@ def insert_vector_casts(ast_node, instruction_set, loop_counter_symbol, default_ handled_functions = (sp.Add, sp.Mul, vec_any, vec_all, DivFunc, sp.Abs) def is_scalar(expr) -> bool: + if isinstance(expr, CastFunc) and expr.args[1] in [BasicType('float32'), BasicType('float64')]: + return False + if hasattr(expr, "dtype"): if type(expr.dtype) is VectorType: return False @@ -285,7 +288,7 @@ def insert_vector_casts(ast_node, instruction_set, loop_counter_symbol, default_ if isinstance(expr, ast.ResolvedFieldAccess): #assert (expr.args[1] != loop_counter_symbol) and (expr.args[1].name != loop_counter_symbol.name), f"ResolvedFieldAccess {expr} is indexed by the coordinate that is vectorized over, not handled." if force_vectorize: - return CastFunc(expr, VectorType(expr.field.dtype, instruction_set['width'])) + return CastFunc(expr, VectorType(expr.field.dtype, instruction_set['width'])) else: return expr elif isinstance(expr, VectorMemoryAccess): -- GitLab