From 521faf6af73e819b2e3d255670fd533344ab5dc8 Mon Sep 17 00:00:00 2001
From: vy28quve <Fabian.Boehm@fau.de>
Date: Mon, 23 Oct 2023 16:08:46 +0200
Subject: [PATCH] Small fix for cast func: treat float CastFunc as already
 vectorized in is_scalar

---
 pystencils/astnodes.py          | 3 ++-
 pystencils/cpu/vectorization.py | 5 ++++-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py
index 115d49ec8..5551580a2 100644
--- a/pystencils/astnodes.py
+++ b/pystencils/astnodes.py
@@ -55,7 +55,8 @@ class Node:
         for arg in self.args:
             if isinstance(arg, arg_type):
                 result.add(arg)
-            result.update(arg.atoms(arg_type))
+            arg_atoms = arg.atoms(arg_type) if (isinstance(arg, sp.Basic) or isinstance(arg, Node) or isinstance(arg, sp.Expr)) else {}
+            result.update(arg_atoms)
         return result
 
 
diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py
index a33c23f3a..6249a9303 100644
--- a/pystencils/cpu/vectorization.py
+++ b/pystencils/cpu/vectorization.py
@@ -260,6 +260,9 @@ def insert_vector_casts(ast_node, instruction_set, loop_counter_symbol, default_
     handled_functions = (sp.Add, sp.Mul, vec_any, vec_all, DivFunc, sp.Abs)
 
     def is_scalar(expr) -> bool:
+        if isinstance(expr, CastFunc) and expr.args[1] in [BasicType('float32'), BasicType('float64')]:
+            return False
+
         if hasattr(expr, "dtype"):
             if type(expr.dtype) is VectorType:
                 return False
@@ -285,7 +288,7 @@ def insert_vector_casts(ast_node, instruction_set, loop_counter_symbol, default_
         if isinstance(expr, ast.ResolvedFieldAccess):
             #assert (expr.args[1] != loop_counter_symbol) and (expr.args[1].name != loop_counter_symbol.name), f"ResolvedFieldAccess {expr} is indexed by the coordinate that is vectorized over, not handled."
             if force_vectorize:
-                    return CastFunc(expr, VectorType(expr.field.dtype, instruction_set['width']))
+                return CastFunc(expr, VectorType(expr.field.dtype, instruction_set['width']))
             else:
                 return  expr
         elif isinstance(expr, VectorMemoryAccess):
-- 
GitLab