From 3f32ceca0bafb9c93e0fbcae13e2679d29d95dc2 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Tue, 22 Oct 2024 14:17:06 +0200
Subject: [PATCH] adapt JIT compilers to changed kernelfunction interface. Fix
 sparse index translations.

---
 src/pystencils/backend/ast/expressions.py     |  2 +-
 .../backend/jit/cpu_extension_module.py       | 44 ++++++-------
 src/pystencils/backend/jit/gpu_cupy.py        | 64 +++++++++++--------
 src/pystencils/backend/kernelfunction.py      |  9 +--
 src/pystencils/backend/platforms/cuda.py      |  4 +-
 .../backend/platforms/generic_cpu.py          |  4 +-
 src/pystencils/backend/platforms/sycl.py      |  6 +-
 src/pystencils/boundaries/boundaryhandling.py | 10 +--
 tests/nbackend/kernelcreation/test_context.py |  3 +-
 tests/nbackend/test_extensions.py             |  5 +-
 tests/nbackend/test_memory.py                 | 22 +++----
 11 files changed, 92 insertions(+), 81 deletions(-)

diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py
index 6a04f4f95..32d06b633 100644
--- a/src/pystencils/backend/ast/expressions.py
+++ b/src/pystencils/backend/ast/expressions.py
@@ -266,7 +266,7 @@ class PsBufferAcc(PsLvalue, PsExpression):
         return PsBufferAcc(self._base_ptr.symbol, [i.clone() for i in self._index])
 
     def __repr__(self) -> str:
-        return f"PsArrayAccess({repr(self._base_ptr)}, {repr(self._index)})"
+        return f"PsBufferAcc({repr(self._base_ptr)}, {repr(self._index)})"
 
 
 class PsSubscript(PsLvalue, PsExpression):
diff --git a/src/pystencils/backend/jit/cpu_extension_module.py b/src/pystencils/backend/jit/cpu_extension_module.py
index b9b793589..dede60cba 100644
--- a/src/pystencils/backend/jit/cpu_extension_module.py
+++ b/src/pystencils/backend/jit/cpu_extension_module.py
@@ -13,11 +13,8 @@ from ..exceptions import PsInternalCompilerError
 from ..kernelfunction import (
     KernelFunction,
     KernelParameter,
-    FieldParameter,
-    FieldShapeParam,
-    FieldStrideParam,
-    FieldPointerParam,
 )
+from ..properties import FieldBasePtr, FieldShape, FieldStride
 from ..constraints import KernelParamsConstraint
 from ...types import (
     PsType,
@@ -209,7 +206,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
         self._array_extractions: dict[Field, str] = dict()
         self._array_frees: dict[Field, str] = dict()
 
-        self._array_assoc_var_extractions: dict[FieldParameter, str] = dict()
+        self._array_assoc_var_extractions: dict[KernelParameter, str] = dict()
         self._scalar_extractions: dict[KernelParameter, str] = dict()
 
         self._constraint_checks: list[str] = []
@@ -282,31 +279,34 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
 
         return param.name
 
-    def extract_array_assoc_var(self, param: FieldParameter) -> str:
+    def extract_array_assoc_var(self, param: KernelParameter) -> str:
         if param not in self._array_assoc_var_extractions:
-            field = param.field
+            field = param.fields.pop()
             buffer = self.extract_field(field)
-            match param:
-                case FieldPointerParam():
-                    code = f"{param.dtype} {param.name} = ({param.dtype}) {buffer}.buf;"
-                case FieldShapeParam():
-                    coord = param.coordinate
-                    code = f"{param.dtype} {param.name} = {buffer}.shape[{coord}];"
-                case FieldStrideParam():
-                    coord = param.coordinate
-                    code = (
-                        f"{param.dtype} {param.name} = "
-                        f"{buffer}.strides[{coord}] / {field.dtype.itemsize};"
-                    )
-                case _:
-                    assert False, "unreachable code"
+            code: str | None = None
+
+            for prop in param.properties:
+                match prop:
+                    case FieldBasePtr():
+                        code = f"{param.dtype} {param.name} = ({param.dtype}) {buffer}.buf;"
+                        break
+                    case FieldShape(_, coord):
+                        code = f"{param.dtype} {param.name} = {buffer}.shape[{coord}];"
+                        break
+                    case FieldStride(_, coord):
+                        code = (
+                            f"{param.dtype} {param.name} = "
+                            f"{buffer}.strides[{coord}] / {field.dtype.itemsize};"
+                        )
+                        break
+            assert code is not None
 
             self._array_assoc_var_extractions[param] = code
 
         return param.name
 
     def extract_parameter(self, param: KernelParameter):
-        if isinstance(param, FieldParameter):
+        if param.is_field_parameter:
             self.extract_array_assoc_var(param)
         else:
             self.extract_scalar(param)
diff --git a/src/pystencils/backend/jit/gpu_cupy.py b/src/pystencils/backend/jit/gpu_cupy.py
index d6aaac2d2..15f5f6967 100644
--- a/src/pystencils/backend/jit/gpu_cupy.py
+++ b/src/pystencils/backend/jit/gpu_cupy.py
@@ -16,11 +16,9 @@ from .jit import JitBase, JitError, KernelWrapper
 from ..kernelfunction import (
     KernelFunction,
     GpuKernelFunction,
-    FieldPointerParam,
-    FieldShapeParam,
-    FieldStrideParam,
     KernelParameter,
 )
+from ..properties import FieldShape, FieldStride, FieldBasePtr
 from ..emission import emit_code
 from ...types import PsStructType
 
@@ -98,8 +96,8 @@ class CupyKernelWrapper(KernelWrapper):
         field_shapes = set()
         index_shapes = set()
 
-        def check_shape(field_ptr: FieldPointerParam, arr: cp.ndarray):
-            field = field_ptr.field
+        def check_shape(field_ptr: KernelParameter, arr: cp.ndarray):
+            field = field_ptr.fields.pop()
 
             if field.has_fixed_shape:
                 expected_shape = tuple(int(s) for s in field.shape)
@@ -118,7 +116,7 @@ class CupyKernelWrapper(KernelWrapper):
                 if isinstance(field.dtype, PsStructType):
                     assert expected_strides[-1] == 1
                     expected_strides = expected_strides[:-1]
-                
+
                 actual_strides = tuple(s // arr.dtype.itemsize for s in arr.strides)
                 if expected_strides != actual_strides:
                     raise ValueError(
@@ -149,28 +147,38 @@ class CupyKernelWrapper(KernelWrapper):
         arr: cp.ndarray
 
         for kparam in self._kfunc.parameters:
-            match kparam:
-                case FieldPointerParam(_, dtype, field):
-                    arr = kwargs[field.name]
-                    if arr.dtype != field.dtype.numpy_dtype:
-                        raise JitError(
-                            f"Data type mismatch at array argument {field.name}:"
-                            f"Expected {field.dtype}, got {arr.dtype}"
-                        )
-                    check_shape(kparam, arr)
-                    args.append(arr)
-
-                case FieldShapeParam(name, dtype, field, coord):
-                    arr = kwargs[field.name]
-                    add_arg(name, arr.shape[coord], dtype)
-
-                case FieldStrideParam(name, dtype, field, coord):
-                    arr = kwargs[field.name]
-                    add_arg(name, arr.strides[coord] // arr.dtype.itemsize, dtype)
-
-                case KernelParameter(name, dtype):
-                    val: Any = kwargs[name]
-                    add_arg(name, val, dtype)
+            if kparam.is_field_parameter:
+                #   Determine field-associated data to pass in
+                for prop in kparam.properties:
+                    match prop:
+                        case FieldBasePtr(field):
+                            arr = kwargs[field.name]
+                            if arr.dtype != field.dtype.numpy_dtype:
+                                raise JitError(
+                                    f"Data type mismatch at array argument {field.name}:"
+                                    f"Expected {field.dtype}, got {arr.dtype}"
+                                )
+                            check_shape(kparam, arr)
+                            args.append(arr)
+                            break
+
+                        case FieldShape(field, coord):
+                            arr = kwargs[field.name]
+                            add_arg(kparam.name, arr.shape[coord], kparam.dtype)
+                            break
+
+                        case FieldStride(field, coord):
+                            arr = kwargs[field.name]
+                            add_arg(
+                                kparam.name,
+                                arr.strides[coord] // arr.dtype.itemsize,
+                                kparam.dtype,
+                            )
+                            break
+            else:
+                #   scalar parameter
+                val: Any = kwargs[kparam.name]
+                add_arg(kparam.name, val, kparam.dtype)
 
         #   Determine launch grid
         from ..ast.expressions import evaluate_expression
diff --git a/src/pystencils/backend/kernelfunction.py b/src/pystencils/backend/kernelfunction.py
index a5bdab623..da0b59e8f 100644
--- a/src/pystencils/backend/kernelfunction.py
+++ b/src/pystencils/backend/kernelfunction.py
@@ -83,9 +83,9 @@ class KernelParameter:
         return set(p.field for p in filter(lambda p: isinstance(p, _FieldProperty), self.properties))  # type: ignore
 
     def get_properties(
-        self, prop_type: type[PsSymbolProperty]
+        self, prop_type: type[PsSymbolProperty] | tuple[type[PsSymbolProperty], ...]
     ) -> set[PsSymbolProperty]:
-        """Retrieve all properties of the given type attached to this parameter"""
+        """Retrieve all properties of the given type(s) attached to this parameter"""
         return set(filter(lambda p: isinstance(p, prop_type), self._properties))
     
     @property
@@ -94,11 +94,6 @@ class KernelParameter:
 
     @property
     def is_field_parameter(self) -> bool:
-        warn(
-            "`is_field_parameter` is deprecated and will be removed in a future version of pystencils. "
-            "Check `param.fields` for emptiness instead.",
-            DeprecationWarning,
-        )
         return bool(self.fields)
 
     @property
diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py
index 6100a371b..323dcc5a9 100644
--- a/src/pystencils/backend/platforms/cuda.py
+++ b/src/pystencils/backend/platforms/cuda.py
@@ -7,6 +7,7 @@ from ..kernelcreation import (
     IterationSpace,
     FullIterationSpace,
     SparseIterationSpace,
+    AstFactory
 )
 
 from ..kernelcreation.context import KernelCreationContext
@@ -159,6 +160,7 @@ class CudaPlatform(GenericGpu):
     def _prepend_sparse_translation(
         self, body: PsBlock, ispace: SparseIterationSpace
     ) -> tuple[PsBlock, GpuThreadsRange]:
+        factory = AstFactory(self._ctx)
         ispace.sparse_counter.dtype = constify(ispace.sparse_counter.get_dtype())
 
         sparse_ctr = PsExpression.make(ispace.sparse_counter)
@@ -173,7 +175,7 @@ class CudaPlatform(GenericGpu):
                 PsLookup(
                     PsBufferAcc(
                         ispace.index_list.base_pointer,
-                        (sparse_ctr,),
+                        (sparse_ctr, factory.parse_index(0)),
                     ),
                     coord.name,
                 ),
diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py
index 95aaf50c4..f8cae89fc 100644
--- a/src/pystencils/backend/platforms/generic_cpu.py
+++ b/src/pystencils/backend/platforms/generic_cpu.py
@@ -124,13 +124,15 @@ class GenericCpu(Platform):
         return PsBlock([loops])
 
     def _create_sparse_loop(self, body: PsBlock, ispace: SparseIterationSpace):
+        factory = AstFactory(self._ctx)
+
         mappings = [
             PsDeclaration(
                 PsSymbolExpr(ctr),
                 PsLookup(
                     PsBufferAcc(
                         ispace.index_list.base_pointer,
-                        (PsExpression.make(ispace.sparse_counter),),
+                        (PsExpression.make(ispace.sparse_counter), factory.parse_index(0)),
                     ),
                     coord.name,
                 ),
diff --git a/src/pystencils/backend/platforms/sycl.py b/src/pystencils/backend/platforms/sycl.py
index b8684ce22..ec5e7eda0 100644
--- a/src/pystencils/backend/platforms/sycl.py
+++ b/src/pystencils/backend/platforms/sycl.py
@@ -20,7 +20,7 @@ from ..ast.expressions import (
 )
 from ..extensions.cpp import CppMethodCall
 
-from ..kernelcreation.context import KernelCreationContext
+from ..kernelcreation import KernelCreationContext, AstFactory
 from ..constants import PsConstant
 from .generic_gpu import GenericGpu, GpuThreadsRange
 from ..exceptions import MaterializationError
@@ -147,6 +147,8 @@ class SyclPlatform(GenericGpu):
     def _prepend_sparse_translation(
         self, body: PsBlock, ispace: SparseIterationSpace
     ) -> tuple[PsBlock, GpuThreadsRange]:
+        factory = AstFactory(self._ctx)
+        
         id_type = PsCustomType("sycl::id< 1 >", const=True)
         id_symbol = PsExpression.make(self._ctx.get_symbol("id", id_type))
 
@@ -165,7 +167,7 @@ class SyclPlatform(GenericGpu):
                 PsLookup(
                     PsBufferAcc(
                         ispace.index_list.base_pointer,
-                        (sparse_ctr,),
+                        (sparse_ctr, factory.parse_index(0)),
                     ),
                     coord.name,
                 ),
diff --git a/src/pystencils/boundaries/boundaryhandling.py b/src/pystencils/boundaries/boundaryhandling.py
index c7657ec51..52ded8ab2 100644
--- a/src/pystencils/boundaries/boundaryhandling.py
+++ b/src/pystencils/boundaries/boundaryhandling.py
@@ -12,7 +12,7 @@ from pystencils.types import PsIntegerType
 from pystencils.types.quick import Arr, SInt
 from pystencils.gpu.gpu_array_handler import GPUArrayHandler
 from pystencils.field import Field, FieldType
-from pystencils.backend.kernelfunction import FieldPointerParam
+from pystencils.backend.properties import FieldBasePtr
 
 try:
     # noinspection PyPep8Naming
@@ -244,9 +244,9 @@ class BoundaryHandling:
             for b_obj, idx_arr in b[self._index_array_name].boundary_object_to_index_list.items():
                 kwargs[self._field_name] = b[self._field_name]
                 kwargs['indexField'] = idx_arr
-                data_used_in_kernel = (p.field.name
+                data_used_in_kernel = (p.fields.pop().name
                                        for p in self._boundary_object_to_boundary_info[b_obj].kernel.parameters
-                                       if isinstance(p, FieldPointerParam) and p.field.name not in kwargs)
+                                       if bool(p.get_properties(FieldBasePtr)) and p.fields.pop().name not in kwargs)
                 kwargs.update({name: b[name] for name in data_used_in_kernel})
 
                 self._boundary_object_to_boundary_info[b_obj].kernel(**kwargs)
@@ -260,9 +260,9 @@ class BoundaryHandling:
                 arguments = kwargs.copy()
                 arguments[self._field_name] = b[self._field_name]
                 arguments['indexField'] = idx_arr
-                data_used_in_kernel = (p.field.name
+                data_used_in_kernel = (p.fields.pop().name
                                        for p in self._boundary_object_to_boundary_info[b_obj].kernel.parameters
-                                       if isinstance(p, FieldPointerParam) and p.field.name not in arguments)
+                                       if bool(p.get_properties(FieldBasePtr)) and p.fields.pop().name not in arguments)
                 arguments.update({name: b[name] for name in data_used_in_kernel if name not in arguments})
 
                 kernel = self._boundary_object_to_boundary_info[b_obj].kernel
diff --git a/tests/nbackend/kernelcreation/test_context.py b/tests/nbackend/kernelcreation/test_context.py
index ff766e6b5..384fc9315 100644
--- a/tests/nbackend/kernelcreation/test_context.py
+++ b/tests/nbackend/kernelcreation/test_context.py
@@ -5,7 +5,8 @@ from pystencils import Field, TypedSymbol, FieldType, DynamicType
 
 from pystencils.backend.kernelcreation import KernelCreationContext
 from pystencils.backend.constants import PsConstant
-from pystencils.backend.memory import PsSymbol, FieldShape, FieldStride
+from pystencils.backend.memory import PsSymbol
+from pystencils.backend.properties import FieldShape, FieldStride
 from pystencils.backend.exceptions import KernelConstraintsError
 from pystencils.types.quick import SInt, Fp
 from pystencils.types import deconstify
diff --git a/tests/nbackend/test_extensions.py b/tests/nbackend/test_extensions.py
index 914d05594..b1403185c 100644
--- a/tests/nbackend/test_extensions.py
+++ b/tests/nbackend/test_extensions.py
@@ -3,7 +3,7 @@ import sympy as sp
 
 from pystencils import make_slice, Field, Assignment
 from pystencils.backend.kernelcreation import KernelCreationContext, AstFactory, FullIterationSpace
-from pystencils.backend.transformations import CanonicalizeSymbols, HoistLoopInvariantDeclarations
+from pystencils.backend.transformations import CanonicalizeSymbols, HoistLoopInvariantDeclarations, LowerToC
 from pystencils.backend.literals import PsLiteral
 from pystencils.backend.emission import CAstPrinter
 from pystencils.backend.ast.expressions import PsExpression, PsSubscript
@@ -46,6 +46,9 @@ def test_literals():
     hoist = HoistLoopInvariantDeclarations(ctx)
     ast = hoist(ast)
 
+    lower = LowerToC(ctx)
+    ast = lower(ast)
+
     assert isinstance(ast, PsBlock)
     assert len(ast.statements) == 2
     assert ast.statements[0] == x_decl
diff --git a/tests/nbackend/test_memory.py b/tests/nbackend/test_memory.py
index fb2ab340e..5841e0f4f 100644
--- a/tests/nbackend/test_memory.py
+++ b/tests/nbackend/test_memory.py
@@ -1,8 +1,7 @@
 import pytest
 
-from typing import ClassVar
 from dataclasses import dataclass
-from pystencils.backend.memory import PsSymbol, PsSymbolProperty
+from pystencils.backend.memory import PsSymbol, PsSymbolProperty, UniqueSymbolProperty
 
 
 def test_properties():
@@ -16,9 +15,8 @@ def test_properties():
         s: str
 
     @dataclass(frozen=True)
-    class UniqueProperty(PsSymbolProperty):
+    class MyUniqueProperty(UniqueSymbolProperty):
         val: int
-        _unique: ClassVar[bool] = True
 
     s = PsSymbol("s")
 
@@ -36,17 +34,17 @@ def test_properties():
 
     assert s.get_properties(NumbersProperty) == {NumbersProperty(42, 8.71)}
     
-    assert not s.get_properties(UniqueProperty)
+    assert not s.get_properties(MyUniqueProperty)
     
-    s.add_property(UniqueProperty(13))
-    assert s.get_properties(UniqueProperty) == {UniqueProperty(13)}
+    s.add_property(MyUniqueProperty(13))
+    assert s.get_properties(MyUniqueProperty) == {MyUniqueProperty(13)}
 
     #   Adding the same one again does not raise
-    s.add_property(UniqueProperty(13))
-    assert s.get_properties(UniqueProperty) == {UniqueProperty(13)}
+    s.add_property(MyUniqueProperty(13))
+    assert s.get_properties(MyUniqueProperty) == {MyUniqueProperty(13)}
 
     with pytest.raises(ValueError):
-        s.add_property(UniqueProperty(14))
+        s.add_property(MyUniqueProperty(14))
 
-    s.remove_property(UniqueProperty(13))
-    assert not s.get_properties(UniqueProperty)
+    s.remove_property(MyUniqueProperty(13))
+    assert not s.get_properties(MyUniqueProperty)
-- 
GitLab