From 97ba71de912dc55b168cd14354d719a532db754a Mon Sep 17 00:00:00 2001
From: Markus Holzer <markus.holzer@fau.de>
Date: Sat, 28 Sep 2024 12:39:17 +0200
Subject: [PATCH] [FIX] Exclude vector size of SVE/RVV instruction sets from
 free kernel parameters

---
 src/pystencils/astnodes.py                    | 13 +++++++-----
 .../backends/arm_instruction_sets.py          | 10 ++++++----
 src/pystencils/backends/cbackend.py           | 20 +------------------
 .../backends/riscv_instruction_sets.py        |  4 +++-
 src/pystencils/cpu/cpujit.py                  |  4 +---
 src/pystencils/typing/__init__.py             |  6 +++---
 src/pystencils/typing/typed_sympy.py          | 17 ++++++++++++++++
 tests/test_vectorization_specific.py          | 20 ++++++++++++++++++-
 8 files changed, 58 insertions(+), 36 deletions(-)

diff --git a/src/pystencils/astnodes.py b/src/pystencils/astnodes.py
index f399287ed..c8928a218 100644
--- a/src/pystencils/astnodes.py
+++ b/src/pystencils/astnodes.py
@@ -5,12 +5,12 @@ from typing import Any, List, Optional, Sequence, Set, Union
 
 import sympy as sp
 
-import pystencils
-from pystencils.typing.utilities import create_type, get_next_parent_of_type
+from pystencils.assignment import Assignment
 from pystencils.enums import Target, Backend
 from pystencils.field import Field
-from pystencils.typing.typed_sympy import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol, TypedSymbol
 from pystencils.sympyextensions import fast_subs
+from pystencils.typing import (create_type, get_next_parent_of_type,
+                               FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol, TypedSymbol, CFunction)
 
 NodeOrExpr = Union['Node', sp.Expr]
 
@@ -270,6 +270,9 @@ class KernelFunction(Node):
         parameters = [self.Parameter(symbol, get_fields(symbol)) for symbol in argument_symbols]
         if hasattr(self, 'indexing'):
             parameters += [self.Parameter(s, []) for s in self.indexing.symbolic_parameters()]
+        # Exclude parameters of type CFunction. Such symbols result in C function calls, which are resolved by
+        # including the respective header file in the compute kernel. Hence, they are not free kernel parameters.
+        parameters = [p for p in parameters if not isinstance(p.symbol, CFunction)]
         parameters.sort(key=lambda p: p.symbol.name)
         return parameters
 
@@ -387,7 +390,7 @@ class Block(Node):
     def symbols_defined(self):
         result = set()
         for a in self.args:
-            if isinstance(a, pystencils.Assignment):
+            if isinstance(a, Assignment):
                 result.update(a.free_symbols)
             else:
                 result.update(a.symbols_defined)
@@ -398,7 +401,7 @@ class Block(Node):
         result = set()
         defined_symbols = set()
         for a in self.args:
-            if isinstance(a, pystencils.Assignment):
+            if isinstance(a, Assignment):
                 result.update(a.free_symbols)
                 defined_symbols.update({a.lhs})
             else:
diff --git a/src/pystencils/backends/arm_instruction_sets.py b/src/pystencils/backends/arm_instruction_sets.py
index 227224f4e..fc0a8c450 100644
--- a/src/pystencils/backends/arm_instruction_sets.py
+++ b/src/pystencils/backends/arm_instruction_sets.py
@@ -1,3 +1,6 @@
+from pystencils.typing import CFunction
+
+
 def get_argument_string(function_shortcut, first=''):
     args = function_shortcut[function_shortcut.index('[') + 1: -1]
     arg_string = "("
@@ -66,10 +69,10 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
     if instruction_set.startswith('sve') or instruction_set == 'sme':
         base_names['stream'] = 'stnt1[0, 1]'
         prefix = 'sv'
-        suffix = f'_f{bits[data_type]}' 
+        suffix = f'_f{bits[data_type]}'
     elif instruction_set == 'neon':
         prefix = 'v'
-        suffix = f'q_f{bits[data_type]}' 
+        suffix = f'q_f{bits[data_type]}'
 
     if instruction_set in ['sve', 'sve2', 'sme']:
         predicate = f'{prefix}whilelt_b{bits[data_type]}_u64({{loop_counter}}, {{loop_stop}})'
@@ -91,7 +94,6 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
         result[intrinsic_id] = prefix + name + suffix + undef + arg_string
 
     if instruction_set in ['sve', 'sve2', 'sme']:
-        from pystencils.backends.cbackend import CFunction
         result['width'] = CFunction(width, "int")
         result['intwidth'] = CFunction(intwidth, "int")
     else:
@@ -134,7 +136,7 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
             result['maskStoreS'] = result['storeS'].replace(predicate, '{3}')
             if instruction_set.startswith('sve2') and instruction_set not in ('sve256', 'sve2048'):
                 result['maskStreamS'] = result['streamS'].replace(predicate, '{3}')
-        
+
         result['streamFence'] = '__dmb(15)'
 
         if instruction_set == 'sme':
diff --git a/src/pystencils/backends/cbackend.py b/src/pystencils/backends/cbackend.py
index 657f60d2f..5d6647607 100644
--- a/src/pystencils/backends/cbackend.py
+++ b/src/pystencils/backends/cbackend.py
@@ -6,7 +6,6 @@ from typing import Set
 import numpy as np
 import sympy as sp
 from sympy.core import S
-from sympy.core.cache import cacheit
 from sympy.logic.boolalg import BooleanFalse, BooleanTrue
 from sympy.functions.elementary.trigonometric import TrigonometricFunction, InverseTrigonometricFunction
 from sympy.functions.elementary.hyperbolic import HyperbolicFunction
@@ -15,7 +14,7 @@ from pystencils.astnodes import KernelFunction, LoopOverCoordinate, Node
 from pystencils.cpu.vectorization import vec_all, vec_any, CachelineSize
 from pystencils.typing import (
     PointerType, VectorType, CastFunc, create_type, get_type_of_expression,
-    ReinterpretCastFunc, VectorMemoryAccess, BasicType, TypedSymbol)
+    ReinterpretCastFunc, VectorMemoryAccess, BasicType, TypedSymbol, CFunction)
 from pystencils.enums import Backend
 from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
 from pystencils.functions import DivFunc, AddressOf
@@ -166,23 +165,6 @@ class PrintNode(CustomCodeNode):
         self.headers.append("<iostream>")
 
 
-class CFunction(TypedSymbol):
-    def __new__(cls, function, dtype):
-        return CFunction.__xnew_cached_(cls, function, dtype)
-
-    def __new_stage2__(cls, function, dtype):
-        return super(CFunction, cls).__xnew__(cls, function, dtype)
-
-    __xnew__ = staticmethod(__new_stage2__)
-    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))
-
-    def __getnewargs__(self):
-        return self.name, self.dtype
-
-    def __getnewargs_ex__(self):
-        return (self.name, self.dtype), {}
-
-
 # ------------------------------------------- Printer ------------------------------------------------------------------
 
 
diff --git a/src/pystencils/backends/riscv_instruction_sets.py b/src/pystencils/backends/riscv_instruction_sets.py
index 27f631e7f..0b303393e 100644
--- a/src/pystencils/backends/riscv_instruction_sets.py
+++ b/src/pystencils/backends/riscv_instruction_sets.py
@@ -1,3 +1,6 @@
+from pystencils.typing import CFunction
+
+
 def get_argument_string(function_shortcut, last=''):
     args = function_shortcut[function_shortcut.index('[') + 1: -1]
     arg_string = "("
@@ -78,7 +81,6 @@ def get_vector_instruction_set_riscv(data_type='double', instruction_set='rvv'):
 
         result[intrinsic_id] = prefix + name + suffix2 + arg_string
 
-    from pystencils.backends.cbackend import CFunction
     result['width'] = CFunction(width, "int")
     result['intwidth'] = CFunction(intwidth, "int")
 
diff --git a/src/pystencils/cpu/cpujit.py b/src/pystencils/cpu/cpujit.py
index 8cc100045..d9a320e76 100644
--- a/src/pystencils/cpu/cpujit.py
+++ b/src/pystencils/cpu/cpujit.py
@@ -62,7 +62,7 @@ import numpy as np
 
 from pystencils import FieldType
 from pystencils.astnodes import LoopOverCoordinate
-from pystencils.backends.cbackend import generate_c, get_headers, CFunction
+from pystencils.backends.cbackend import generate_c, get_headers
 from pystencils.cpu.msvc_detection import get_environment
 from pystencils.include import get_pystencils_include_path
 from pystencils.kernel_wrapper import KernelWrapper
@@ -447,8 +447,6 @@ def create_function_boilerplate_code(parameter_info, name, ast_node, insert_chec
             parameters.append(f"buffer_{field.name}.strides[{param.symbol.coordinate}] / {item_size}")
         elif param.is_field_shape:
             parameters.append(f"buffer_{param.field_name}.shape[{param.symbol.coordinate}]")
-        elif type(param.symbol) is CFunction:
-            continue
         else:
             extract_function, target_type = type_mapping[param.symbol.dtype.numpy_dtype.type]
             pre_call_code += template_extract_scalar.format(extract_function=extract_function,
diff --git a/src/pystencils/typing/__init__.py b/src/pystencils/typing/__init__.py
index ae4483da4..ae3974d82 100644
--- a/src/pystencils/typing/__init__.py
+++ b/src/pystencils/typing/__init__.py
@@ -3,14 +3,14 @@ from pystencils.typing.cast_functions import (CastFunc, BooleanCastFunc, VectorM
 from pystencils.typing.types import (is_supported_type, numpy_name_to_c, AbstractType, BasicType, VectorType,
                                      PointerType, StructType, create_type)
 from pystencils.typing.typed_sympy import (assumptions_from_dtype, TypedSymbol, FieldStrideSymbol, FieldShapeSymbol,
-                                           FieldPointerSymbol)
+                                           FieldPointerSymbol, CFunction)
 from pystencils.typing.utilities import (typed_symbols, get_base_type, result_type, collate_types,
                                          get_type_of_expression, get_next_parent_of_type, parents_of_type)
 
 
 __all__ = ['CastFunc', 'BooleanCastFunc', 'VectorMemoryAccess', 'ReinterpretCastFunc', 'PointerArithmeticFunc',
            'is_supported_type', 'numpy_name_to_c', 'AbstractType', 'BasicType',
-           'VectorType', 'PointerType', 'StructType', 'create_type',
-           'assumptions_from_dtype', 'TypedSymbol', 'FieldStrideSymbol', 'FieldShapeSymbol', 'FieldPointerSymbol',
+           'VectorType', 'PointerType', 'StructType', 'create_type', 'assumptions_from_dtype',
+           'TypedSymbol', 'FieldStrideSymbol', 'FieldShapeSymbol', 'FieldPointerSymbol', 'CFunction',
            'typed_symbols', 'get_base_type', 'result_type', 'collate_types',
            'get_type_of_expression', 'get_next_parent_of_type', 'parents_of_type']
diff --git a/src/pystencils/typing/typed_sympy.py b/src/pystencils/typing/typed_sympy.py
index 302c2f998..03471228a 100644
--- a/src/pystencils/typing/typed_sympy.py
+++ b/src/pystencils/typing/typed_sympy.py
@@ -178,3 +178,20 @@ class FieldPointerSymbol(TypedSymbol):
 
     __xnew__ = staticmethod(__new_stage2__)
     __xnew_cached_ = staticmethod(cacheit(__new_stage2__))
+
+
+class CFunction(TypedSymbol):
+    def __new__(cls, function, dtype):
+        return CFunction.__xnew_cached_(cls, function, dtype)
+
+    def __new_stage2__(cls, function, dtype):
+        return super(CFunction, cls).__xnew__(cls, function, dtype)
+
+    __xnew__ = staticmethod(__new_stage2__)
+    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))
+
+    def __getnewargs__(self):
+        return self.name, self.dtype
+
+    def __getnewargs_ex__(self):
+        return (self.name, self.dtype), {}
diff --git a/tests/test_vectorization_specific.py b/tests/test_vectorization_specific.py
index 19c6e0033..48dd20f51 100644
--- a/tests/test_vectorization_specific.py
+++ b/tests/test_vectorization_specific.py
@@ -8,8 +8,10 @@ import sympy as sp
 import pystencils as ps
 from pystencils.backends.simd_instruction_sets import (get_cacheline_size, get_supported_instruction_sets,
                                                        get_vector_instruction_set)
-from . import test_vectorization
 from pystencils.enums import Target
+from pystencils.typing import CFunction
+from . import test_vectorization
+
 
 supported_instruction_sets = get_supported_instruction_sets() if get_supported_instruction_sets() else []
 
@@ -274,6 +276,22 @@ def test_div_and_unevaluated_expr(dtype, instruction_set):
     assert 'pow' not in code
 
 
+@pytest.mark.parametrize('dtype', ('float32', 'float64'))
+@pytest.mark.parametrize('instruction_set', ('sve', 'sve2', 'sme', 'rvv'))
+def test_check_ast_parameters_sizeless(dtype, instruction_set):
+    f, g = ps.fields(f"f, g: {dtype}[3D]", layout='fzyx')
+
+    update_rule = [ps.Assignment(g.center(), 2 * f.center())]
+
+    config = pystencils.config.CreateKernelConfig(data_type=dtype,
+                                                  cpu_vectorize_info={'instruction_set': instruction_set})
+    ast = ps.create_kernel(update_rule, config=config)
+    ast_symbols = [p.symbol for p in ast.get_parameters()]
+    assert ast.instruction_set['width'] not in ast_symbols
+    assert ast.instruction_set['intwidth'] not in ast_symbols
+
+
+
 # TODO this test case needs a complete rework of the vectoriser. The reason is that the vectoriser does not
 # TODO vectorise symbols at the moment because they could be strides or field sizes, thus involved in pointer arithmetic
 # TODO This means that the vectoriser only works if fields are involved on the rhs.
-- 
GitLab