From 9f3406f63b56d1ae2044e5618a3e625263b79f0d Mon Sep 17 00:00:00 2001
From: Markus Holzer <markus.holzer@fau.de>
Date: Wed, 25 Sep 2024 13:03:29 +0200
Subject: [PATCH] Fix CFunction in parameters

---
 src/pystencils/astnodes.py           | 14 ++++++++------
 src/pystencils/backends/cbackend.py  | 20 +-------------------
 src/pystencils/typing/__init__.py    |  6 +++---
 src/pystencils/typing/typed_sympy.py | 17 +++++++++++++++++
 tests/test_vectorization_specific.py | 20 +++++++++++++++++++-
 5 files changed, 48 insertions(+), 29 deletions(-)

diff --git a/src/pystencils/astnodes.py b/src/pystencils/astnodes.py
index f0755a418..a54d350bc 100644
--- a/src/pystencils/astnodes.py
+++ b/src/pystencils/astnodes.py
@@ -5,12 +5,12 @@ from typing import Any, List, Optional, Sequence, Set, Union
 
 import sympy as sp
 
-import pystencils
-from pystencils.typing.utilities import create_type, get_next_parent_of_type
+from pystencils.assignment import Assignment
 from pystencils.enums import Target, Backend
 from pystencils.field import Field
-from pystencils.typing.typed_sympy import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol, TypedSymbol
 from pystencils.sympyextensions import fast_subs
+from pystencils.typing import (CFunction, create_type, get_next_parent_of_type,
+                               FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol, TypedSymbol)
 
 NodeOrExpr = Union['Node', sp.Expr]
 
@@ -270,7 +270,9 @@ class KernelFunction(Node):
         parameters = [self.Parameter(symbol, get_fields(symbol)) for symbol in argument_symbols]
         if hasattr(self, 'indexing'):
             parameters += [self.Parameter(s, []) for s in self.indexing.symbolic_parameters()]
-        parameters = [p for p in parameters if p.symbol.name != "svcntd()"]
+        # Exclude paramters of type CFunction. These parameters will result in a C function call that will be handled
+        # by including a respective header file in the compute kernel. Hence, it is not a free parameter.
+        # parameters = [p for p in parameters if not isinstance(p.symbol, CFunction)]
         parameters.sort(key=lambda p: p.symbol.name)
         return parameters
 
@@ -388,7 +390,7 @@ class Block(Node):
     def symbols_defined(self):
         result = set()
         for a in self.args:
-            if isinstance(a, pystencils.Assignment):
+            if isinstance(a, Assignment):
                 result.update(a.free_symbols)
             else:
                 result.update(a.symbols_defined)
@@ -399,7 +401,7 @@ class Block(Node):
         result = set()
         defined_symbols = set()
         for a in self.args:
-            if isinstance(a, pystencils.Assignment):
+            if isinstance(a, Assignment):
                 result.update(a.free_symbols)
                 defined_symbols.update({a.lhs})
             else:
diff --git a/src/pystencils/backends/cbackend.py b/src/pystencils/backends/cbackend.py
index 7dbf84d37..7065f7feb 100644
--- a/src/pystencils/backends/cbackend.py
+++ b/src/pystencils/backends/cbackend.py
@@ -6,7 +6,6 @@ from typing import Set
 import numpy as np
 import sympy as sp
 from sympy.core import S
-from sympy.core.cache import cacheit
 from sympy.logic.boolalg import BooleanFalse, BooleanTrue
 from sympy.functions.elementary.trigonometric import TrigonometricFunction, InverseTrigonometricFunction
 from sympy.functions.elementary.hyperbolic import HyperbolicFunction
@@ -15,7 +14,7 @@ from pystencils.astnodes import KernelFunction, LoopOverCoordinate, Node
 from pystencils.cpu.vectorization import vec_all, vec_any, CachelineSize
 from pystencils.typing import (
     PointerType, VectorType, CastFunc, create_type, get_type_of_expression,
-    ReinterpretCastFunc, VectorMemoryAccess, BasicType, TypedSymbol)
+    ReinterpretCastFunc, VectorMemoryAccess, BasicType, TypedSymbol, CFunction)
 from pystencils.enums import Backend
 from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
 from pystencils.functions import DivFunc, AddressOf
@@ -166,23 +165,6 @@ class PrintNode(CustomCodeNode):
         self.headers.append("<iostream>")
 
 
-class CFunction(TypedSymbol):
-    def __new__(cls, function, dtype):
-        return CFunction.__xnew_cached_(cls, function, dtype)
-
-    def __new_stage2__(cls, function, dtype):
-        return super(CFunction, cls).__xnew__(cls, function, dtype)
-
-    __xnew__ = staticmethod(__new_stage2__)
-    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))
-
-    def __getnewargs__(self):
-        return self.name, self.dtype
-
-    def __getnewargs_ex__(self):
-        return (self.name, self.dtype), {}
-
-
 # ------------------------------------------- Printer ------------------------------------------------------------------
 
 
diff --git a/src/pystencils/typing/__init__.py b/src/pystencils/typing/__init__.py
index ae4483da4..ae3974d82 100644
--- a/src/pystencils/typing/__init__.py
+++ b/src/pystencils/typing/__init__.py
@@ -3,14 +3,14 @@ from pystencils.typing.cast_functions import (CastFunc, BooleanCastFunc, VectorM
 from pystencils.typing.types import (is_supported_type, numpy_name_to_c, AbstractType, BasicType, VectorType,
                                      PointerType, StructType, create_type)
 from pystencils.typing.typed_sympy import (assumptions_from_dtype, TypedSymbol, FieldStrideSymbol, FieldShapeSymbol,
-                                           FieldPointerSymbol)
+                                           FieldPointerSymbol, CFunction)
 from pystencils.typing.utilities import (typed_symbols, get_base_type, result_type, collate_types,
                                          get_type_of_expression, get_next_parent_of_type, parents_of_type)
 
 
 __all__ = ['CastFunc', 'BooleanCastFunc', 'VectorMemoryAccess', 'ReinterpretCastFunc', 'PointerArithmeticFunc',
            'is_supported_type', 'numpy_name_to_c', 'AbstractType', 'BasicType',
-           'VectorType', 'PointerType', 'StructType', 'create_type',
-           'assumptions_from_dtype', 'TypedSymbol', 'FieldStrideSymbol', 'FieldShapeSymbol', 'FieldPointerSymbol',
+           'VectorType', 'PointerType', 'StructType', 'create_type', 'assumptions_from_dtype',
+           'TypedSymbol', 'FieldStrideSymbol', 'FieldShapeSymbol', 'FieldPointerSymbol', 'CFunction',
            'typed_symbols', 'get_base_type', 'result_type', 'collate_types',
            'get_type_of_expression', 'get_next_parent_of_type', 'parents_of_type']
diff --git a/src/pystencils/typing/typed_sympy.py b/src/pystencils/typing/typed_sympy.py
index 302c2f998..bbaf3ef91 100644
--- a/src/pystencils/typing/typed_sympy.py
+++ b/src/pystencils/typing/typed_sympy.py
@@ -178,3 +178,20 @@ class FieldPointerSymbol(TypedSymbol):
 
     __xnew__ = staticmethod(__new_stage2__)
     __xnew_cached_ = staticmethod(cacheit(__new_stage2__))
+
+
+class CFunction(TypedSymbol):
+    def __new__(cls, function, dtype):
+        return CFunction.__xnew_cached_(cls, function, dtype)
+
+    def __new_stage2__(cls, function, dtype):
+        return super(CFunction, cls).__xnew__(cls, function, dtype)
+
+    __xnew__ = staticmethod(__new_stage2__)
+    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))
+
+    def __getnewargs__(self):
+        return self.name, self.dtype
+
+    def __getnewargs_ex__(self):
+        return (self.name, self.dtype), {}
\ No newline at end of file
diff --git a/tests/test_vectorization_specific.py b/tests/test_vectorization_specific.py
index dcebeae60..26a1ae3ec 100644
--- a/tests/test_vectorization_specific.py
+++ b/tests/test_vectorization_specific.py
@@ -8,8 +8,10 @@ import sympy as sp
 import pystencils as ps
 from pystencils.backends.simd_instruction_sets import (get_cacheline_size, get_supported_instruction_sets,
                                                        get_vector_instruction_set)
-from . import test_vectorization
 from pystencils.enums import Target
+from pystencils.typing import CFunction
+from . import test_vectorization
+
 
 supported_instruction_sets = get_supported_instruction_sets() if get_supported_instruction_sets() else []
 
@@ -32,6 +34,8 @@ def test_vectorisation_varying_arch(instruction_set):
 
     config = pystencils.config.CreateKernelConfig(cpu_vectorize_info={'instruction_set': instruction_set})
     ast = ps.create_kernel(update_rule, config=config)
+    for parameter in ast.get_parameters():
+        assert not isinstance(parameter.symbol, CFunction)
     kernel = ast.compile()
     kernel(f=arr)
     np.testing.assert_equal(arr, 2)
@@ -51,6 +55,8 @@ def test_vectorized_abs(instruction_set, dtype):
 
     config = pystencils.config.CreateKernelConfig(cpu_vectorize_info={'instruction_set': instruction_set})
     ast = ps.create_kernel(update_rule, config=config)
+    for parameter in ast.get_parameters():
+        assert not isinstance(parameter.symbol, CFunction)
 
     func = ast.compile()
     dst = np.zeros_like(arr)
@@ -118,6 +124,8 @@ def test_alignment_and_correct_ghost_layers(gl_field, gl_kernel, instruction_set
     config = pystencils.config.CreateKernelConfig(target=dh.default_target,
                                                   cpu_vectorize_info=opt, ghost_layers=gl_kernel)
     ast = ps.create_kernel(update_rule, config=config)
+    for parameter in ast.get_parameters():
+        assert not isinstance(parameter.symbol, CFunction)
     kernel = ast.compile()
     if ('loadA' in ast.instruction_set or 'storeA' in ast.instruction_set) and gl_kernel != gl_field:
         with pytest.raises(ValueError):
@@ -165,6 +173,8 @@ def test_square_root(dtype, instruction_set, field_layout):
           ps.Assignment(sp.Symbol("xi_2"), sp.Symbol("xi") * sp.sqrt(src_field.center))]
 
     ast = ps.create_kernel(eq, config=config)
+    for parameter in ast.get_parameters():
+        assert not isinstance(parameter.symbol, CFunction)
     ast.compile()
     code = ps.get_code_str(ast)
     print(code)
@@ -185,6 +195,8 @@ def test_square_root_2(dtype, instruction_set, padding):
 
     config = ps.CreateKernelConfig(data_type=dtype, default_number_float=dtype, cpu_vectorize_info=cpu_vec)
     ast = ps.create_kernel(up, config=config)
+    for parameter in ast.get_parameters():
+        assert not isinstance(parameter.symbol, CFunction)
     ast.compile()
 
     code = ps.get_code_str(ast)
@@ -208,6 +220,8 @@ def test_pow(dtype, instruction_set, padding):
           ps.Assignment(sp.Symbol("xi_2"), sp.Symbol("xi") * sp.Pow(src_field.center, 0.5))]
 
     ast = ps.create_kernel(eq, config=config)
+    for parameter in ast.get_parameters():
+        assert not isinstance(parameter.symbol, CFunction)
     ast.compile()
     code = ps.get_code_str(ast)
 
@@ -235,6 +249,8 @@ def test_issue62(dtype, instruction_set, padding):
                                    cpu_vectorize_info=opt)
 
     ast = ps.create_kernel(up, config=config)
+    for parameter in ast.get_parameters():
+        assert not isinstance(parameter.symbol, CFunction)
     ast.compile()
     code = ps.get_code_str(ast)
 
@@ -261,6 +277,8 @@ def test_div_and_unevaluated_expr(dtype, instruction_set):
                                    cpu_vectorize_info=opt)
 
     ast = ps.create_kernel(up, config=config)
+    for parameter in ast.get_parameters():
+        assert not isinstance(parameter.symbol, CFunction)
     code = ps.get_code_str(ast)
     # print(code)
 
-- 
GitLab