From e485585005ba0e681af05545c65e6473acbfe629 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Tue, 22 Oct 2024 11:39:16 +0200
Subject: [PATCH] Adapt visitors, pt. 1

 - Make freeze emit `PsBufferAcc`s
 - Typify `PsBufferAcc`s
 - Implement `LowerToC`
 - Subsume `EraseAnonStructs` into `LowerToC`
---
 .../backend/kernelcreation/context.py         |   1 +
 .../backend/kernelcreation/freeze.py          |  11 +-
 .../backend/kernelcreation/typification.py    |   7 +-
 .../backend/transformations/__init__.py       |   6 +-
 .../erase_anonymous_structs.py                | 108 --------------
 .../backend/transformations/lower_to_c.py     | 140 ++++++++++++++++++
 src/pystencils/kernelcreation.py              |   4 +-
 tests/nbackend/kernelcreation/test_freeze.py  |   6 +-
 .../transformations/test_lower_to_c.py        | 115 ++++++++++++++
 9 files changed, 269 insertions(+), 129 deletions(-)
 delete mode 100644 src/pystencils/backend/transformations/erase_anonymous_structs.py
 create mode 100644 src/pystencils/backend/transformations/lower_to_c.py
 create mode 100644 tests/nbackend/transformations/test_lower_to_c.py

diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py
index e75144dee..558cc8a0e 100644
--- a/src/pystencils/backend/kernelcreation/context.py
+++ b/src/pystencils/backend/kernelcreation/context.py
@@ -240,6 +240,7 @@ class KernelCreationContext:
                 self._fields_collection.index_fields.add(field)
 
             case FieldType.CUSTOM:
+                buf = self._create_regular_field_buffer(field)
                 self._fields_collection.custom_fields.add(field)
 
             case _:
diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py
index 1cadf1fa4..bdc8f1133 100644
--- a/src/pystencils/backend/kernelcreation/freeze.py
+++ b/src/pystencils/backend/kernelcreation/freeze.py
@@ -363,18 +363,11 @@ class FreezeExpressions:
                 # For canonical representation, there must always be at least one index dimension
                 indices = [PsExpression.make(PsConstant(0))]
 
-        summands = tuple(
-            idx * PsExpression.make(stride)
-            for idx, stride in zip(offsets + indices, array.strides, strict=True)
-        )
-
-        index = summands[0] if len(summands) == 1 else reduce(add, summands)
-
         if struct_member_name is not None:
             # Produce a Lookup here, don't check yet if the member name is valid. That's the typifier's job.
-            return PsLookup(PsBufferAcc(ptr, index), struct_member_name)
+            return PsLookup(PsBufferAcc(ptr, offsets + indices), struct_member_name)
         else:
-            return PsBufferAcc(ptr, index)
+            return PsBufferAcc(ptr, offsets + indices)
 
     def map_ConditionalFieldAccess(self, acc: ConditionalFieldAccess):
         facc = self.visit_expr(acc.access)
diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py
index 222b2cac3..975ea5a60 100644
--- a/src/pystencils/backend/kernelcreation/typification.py
+++ b/src/pystencils/backend/kernelcreation/typification.py
@@ -413,9 +413,10 @@ class Typifier:
             case PsLiteralExpr(lit):
                 tc.apply_dtype(lit.dtype, expr)
 
-            case PsBufferAcc(bptr, idx):
-                tc.apply_dtype(bptr.array.element_type, expr)
-                self._handle_idx(idx)
+            case PsBufferAcc(_, indices):
+                tc.apply_dtype(expr.buffer.element_type, expr)
+                for idx in indices:
+                    self._handle_idx(idx)
 
             case PsMemAcc(ptr, offset):
                 ptr_tc = TypeContext()
diff --git a/src/pystencils/backend/transformations/__init__.py b/src/pystencils/backend/transformations/__init__.py
index 88ad9348f..7375af618 100644
--- a/src/pystencils/backend/transformations/__init__.py
+++ b/src/pystencils/backend/transformations/__init__.py
@@ -69,7 +69,7 @@ Loop Reshaping Transformations
 Code Lowering and Materialization
 ---------------------------------
 
-.. autoclass:: EraseAnonymousStructTypes
+.. autoclass:: LowerToC
     :members: __call__
 
 .. autoclass:: SelectFunctions
@@ -84,7 +84,7 @@ from .eliminate_branches import EliminateBranches
 from .hoist_loop_invariant_decls import HoistLoopInvariantDeclarations
 from .reshape_loops import ReshapeLoops
 from .add_pragmas import InsertPragmasAtLoops, LoopPragma, AddOpenMP
-from .erase_anonymous_structs import EraseAnonymousStructTypes
+from .lower_to_c import LowerToC
 from .select_functions import SelectFunctions
 from .select_intrinsics import MaterializeVectorIntrinsics
 
@@ -98,7 +98,7 @@ __all__ = [
     "InsertPragmasAtLoops",
     "LoopPragma",
     "AddOpenMP",
-    "EraseAnonymousStructTypes",
+    "LowerToC",
     "SelectFunctions",
     "MaterializeVectorIntrinsics",
 ]
diff --git a/src/pystencils/backend/transformations/erase_anonymous_structs.py b/src/pystencils/backend/transformations/erase_anonymous_structs.py
deleted file mode 100644
index 607877592..000000000
--- a/src/pystencils/backend/transformations/erase_anonymous_structs.py
+++ /dev/null
@@ -1,108 +0,0 @@
-from __future__ import annotations
-
-from ..kernelcreation.context import KernelCreationContext
-
-from ..constants import PsConstant
-from ..ast.structural import PsAstNode
-from ..ast.expressions import (
-    PsBufferAcc,
-    PsLookup,
-    PsExpression,
-    PsMemAcc,
-    PsAddressOf,
-    PsCast,
-)
-from ..kernelcreation import Typifier
-from ...types import PsStructType, PsPointerType
-
-
-class EraseAnonymousStructTypes:
-    """Lower anonymous struct arrays to a byte-array representation.
-
-    For arrays whose element type is an anonymous struct, the struct type is erased from the base pointer,
-    making it a pointer to uint8_t.
-    Member lookups on accesses into these arrays are then transformed using type casts.
-    """
-
-    def __init__(self, ctx: KernelCreationContext) -> None:
-        self._ctx = ctx
-
-        self._substitutions: dict[PsArrayBasePointer, TypeErasedBasePointer] = dict()
-
-    def __call__(self, node: PsAstNode) -> PsAstNode:
-        self._substitutions = dict()
-
-        #   Check if AST traversal is even necessary
-        if not any(
-            (isinstance(arr.element_type, PsStructType) and arr.element_type.anonymous)
-            for arr in self._ctx.arrays
-        ):
-            return node
-
-        node = self.visit(node)
-
-        for old, new in self._substitutions.items():
-            self._ctx.replace_symbol(old, new)
-
-        return node
-
-    def visit(self, node: PsAstNode) -> PsAstNode:
-        match node:
-            case PsLookup():
-                # descend into expr
-                return self.handle_lookup(node)
-            case _:
-                node.children = [self.visit(c) for c in node.children]
-
-        return node
-
-    def handle_lookup(self, lookup: PsLookup) -> PsExpression:
-        aggr = lookup.aggregate
-        if not isinstance(aggr, PsBufferAcc):
-            return lookup
-
-        arr = aggr.buffer
-        if (
-            not isinstance(arr.element_type, PsStructType)
-            or not arr.element_type.anonymous
-        ):
-            return lookup
-
-        struct_type = arr.element_type
-        struct_size = struct_type.itemsize
-
-        bp = aggr.base_ptr
-
-        #   Need to keep track of base pointers already seen, since symbols must be unique
-        if bp not in self._substitutions:
-            type_erased_bp = TypeErasedBasePointer(bp.name, arr)
-            self._substitutions[bp] = type_erased_bp
-        else:
-            type_erased_bp = self._substitutions[bp]
-
-        base_index = aggr.index * PsExpression.make(
-            PsConstant(struct_size, self._ctx.index_dtype)
-        )
-
-        member_name = lookup.member_name
-        member = struct_type.find_member(member_name)
-        assert member is not None
-
-        np_struct = struct_type.numpy_dtype
-        assert np_struct is not None
-        assert np_struct.fields is not None
-        member_offset = np_struct.fields[member_name][1]
-
-        byte_index = base_index + PsExpression.make(
-            PsConstant(member_offset, self._ctx.index_dtype)
-        )
-        type_erased_access = PsBufferAcc(type_erased_bp, byte_index)
-
-        deref = PsMemAcc(
-            PsCast(PsPointerType(member.dtype), PsAddressOf(type_erased_access)),
-            PsExpression.make(PsConstant(0))
-        )
-
-        typify = Typifier(self._ctx)
-        deref = typify(deref)
-        return deref
diff --git a/src/pystencils/backend/transformations/lower_to_c.py b/src/pystencils/backend/transformations/lower_to_c.py
new file mode 100644
index 000000000..8fd6f89ba
--- /dev/null
+++ b/src/pystencils/backend/transformations/lower_to_c.py
@@ -0,0 +1,170 @@
+from __future__ import annotations
+
+from typing import cast
+from functools import reduce
+import operator
+
+from ..kernelcreation import KernelCreationContext, Typifier
+
+from ..constants import PsConstant
+from ..memory import PsSymbol, PsBuffer, BufferBasePtr
+from ..ast.structural import PsAstNode
+from ..ast.expressions import (
+    PsBufferAcc,
+    PsLookup,
+    PsExpression,
+    PsMemAcc,
+    PsAddressOf,
+    PsCast,
+    PsSymbolExpr,
+)
+from ...types import PsStructType, PsPointerType, PsUnsignedIntegerType
+
+
+class LowerToC:
+    """Lower high-level IR constructs to C language concepts.
+
+    This pass will replace a number of IR constructs that have no direct counterpart in the C language
+    to lower-level AST nodes. These include:
+
+    - *Linearization of Buffer Accesses:* `PsBufferAcc` buffer accesses are linearized according to
+      their buffers' stride information and replaced by `PsMemAcc`.
+    - *Erasure of Anonymous Structs:*
+      For buffers whose element type is an anonymous struct, the struct type is erased from the base pointer,
+      making it a pointer to uint8_t.
+      Member lookups on accesses into these buffers are then transformed using type casts.
+    """
+
+    def __init__(self, ctx: KernelCreationContext) -> None:
+        self._ctx = ctx
+        self._typify = Typifier(ctx)
+
+        #   Maps each original base pointer to its type-erased replacement
+        self._substitutions: dict[PsSymbol, PsSymbol] = dict()
+
+    def __call__(self, node: PsAstNode) -> PsAstNode:
+        """Lower the AST rooted at `node` and return the resulting tree.
+
+        Every base pointer that was type-erased during the traversal is afterward
+        replaced by its type-erased counterpart via `KernelCreationContext.replace_symbol`.
+        """
+        self._substitutions = dict()
+
+        node = self.visit(node)
+
+        for old, new in self._substitutions.items():
+            self._ctx.replace_symbol(old, new)
+
+        return node
+
+    def visit(self, node: PsAstNode) -> PsAstNode:
+        """Lower `node` and its subtree; returns the node that takes its place."""
+        match node:
+            case PsBufferAcc(bptr, indices):
+                #   Linearize
+                buf = node.buffer
+
+                #   Typifier allows different data types in each index
+                def maybe_cast(i: PsExpression):
+                    if i.get_dtype() != buf.index_type:
+                        return PsCast(buf.index_type, i)
+                    else:
+                        return i
+
+                #   Lower the index expressions first, so that buffer accesses
+                #   nested inside them are not left behind in the linearized offset
+                summands: list[PsExpression] = [
+                    maybe_cast(cast(PsExpression, self.visit(idx)))
+                    * PsExpression.make(stride)
+                    for idx, stride in zip(indices, buf.strides, strict=True)
+                ]
+
+                linearized_idx: PsExpression = (
+                    summands[0]
+                    if len(summands) == 1
+                    else reduce(operator.add, summands)
+                )
+
+                mem_acc = PsMemAcc(bptr, linearized_idx)
+
+                return self._typify.typify_expression(
+                    mem_acc, target_type=buf.element_type
+                )[0]
+
+            case PsLookup(aggr, member_name) if isinstance(
+                aggr, PsBufferAcc
+            ) and isinstance(
+                aggr.buffer.element_type, PsStructType
+            ) and aggr.buffer.element_type.anonymous:
+                #   Need to lower this buffer-lookup
+                linearized_acc = self.visit(aggr)
+                return self._lower_anon_lookup(
+                    cast(PsMemAcc, linearized_acc), aggr.buffer, member_name
+                )
+
+            case _:
+                node.children = [self.visit(c) for c in node.children]
+
+        return node
+
+    def _lower_anon_lookup(
+        self, aggr: PsMemAcc, buf: PsBuffer, member_name: str
+    ) -> PsExpression:
+        """Lower a member lookup on an element of an anonymous-struct buffer.
+
+        Args:
+            aggr: The already linearized access to the struct element
+            buf: The buffer being accessed
+            member_name: Name of the struct member being looked up
+
+        Returns:
+            A typified memory access through the type-erased base pointer,
+            cast to a pointer to the member's data type
+        """
+        struct_type = cast(PsStructType, buf.element_type)
+        struct_size = struct_type.itemsize
+
+        assert isinstance(aggr.pointer, PsSymbolExpr)
+        bp = aggr.pointer.symbol
+
+        #   Need to keep track of base pointers already seen, since symbols must be unique
+        if bp not in self._substitutions:
+            bp_type = bp.dtype
+            assert isinstance(bp_type, PsPointerType)
+
+            #   Erase the struct type: same pointer qualifiers, but pointing to bytes
+            erased_type = PsPointerType(
+                PsUnsignedIntegerType(8, const=bp_type.base_type.const),
+                const=bp_type.const,
+                restrict=bp_type.restrict,
+            )
+            type_erased_bp = PsSymbol(bp.name, erased_type)
+            type_erased_bp.add_property(BufferBasePtr(buf))
+            self._substitutions[bp] = type_erased_bp
+        else:
+            type_erased_bp = self._substitutions[bp]
+
+        base_index = aggr.offset * PsExpression.make(
+            PsConstant(struct_size, self._ctx.index_dtype)
+        )
+
+        member = struct_type.find_member(member_name)
+        assert member is not None
+
+        np_struct = struct_type.numpy_dtype
+        assert np_struct is not None
+        assert np_struct.fields is not None
+        member_offset = np_struct.fields[member_name][1]
+
+        byte_index = base_index + PsExpression.make(
+            PsConstant(member_offset, self._ctx.index_dtype)
+        )
+        type_erased_access = PsMemAcc(PsExpression.make(type_erased_bp), byte_index)
+
+        deref = PsMemAcc(
+            PsCast(PsPointerType(member.dtype), PsAddressOf(type_erased_access)),
+            PsExpression.make(PsConstant(0)),
+        )
+
+        deref = self._typify(deref)
+        return deref
diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py
index ae64bdea3..4ebafece7 100644
--- a/src/pystencils/kernelcreation.py
+++ b/src/pystencils/kernelcreation.py
@@ -20,7 +20,7 @@ from .backend.kernelcreation.iteration_space import (
 
 from .backend.transformations import (
     EliminateConstants,
-    EraseAnonymousStructTypes,
+    LowerToC,
     SelectFunctions,
 )
 from .backend.kernelfunction import (
@@ -143,7 +143,7 @@ def create_kernel(
 
         kernel_ast = optimize_cpu(ctx, platform, kernel_ast, config.cpu_optim)
 
-    erase_anons = EraseAnonymousStructTypes(ctx)
+    erase_anons = LowerToC(ctx)
     kernel_ast = cast(PsBlock, erase_anons(kernel_ast))
 
     select_functions = SelectFunctions(platform)
diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py
index 67e5d2319..ce4f61785 100644
--- a/tests/nbackend/kernelcreation/test_freeze.py
+++ b/tests/nbackend/kernelcreation/test_freeze.py
@@ -115,13 +115,11 @@ def test_freeze_fields():
 
     lhs = PsBufferAcc(
         f_arr.base_pointer,
-        (PsExpression.make(counter) + zero) * PsExpression.make(f_arr.strides[0])
-        + zero * one,
+        (PsExpression.make(counter) + zero, zero)
     )
     rhs = PsBufferAcc(
         g_arr.base_pointer,
-        (PsExpression.make(counter) + zero) * PsExpression.make(g_arr.strides[0])
-        + zero * one,
+        (PsExpression.make(counter) + zero, zero)
     )
 
     should = PsAssignment(lhs, rhs)
diff --git a/tests/nbackend/transformations/test_lower_to_c.py b/tests/nbackend/transformations/test_lower_to_c.py
new file mode 100644
index 000000000..1c8a2c8f3
--- /dev/null
+++ b/tests/nbackend/transformations/test_lower_to_c.py
@@ -0,0 +1,118 @@
+from functools import reduce
+from operator import add
+
+from pystencils import fields, Assignment, make_slice, Field
+from pystencils.types import PsStructType, create_type
+
+from pystencils.backend.memory import BufferBasePtr
+from pystencils.backend.kernelcreation import (
+    KernelCreationContext,
+    AstFactory,
+    FullIterationSpace,
+)
+from pystencils.backend.transformations import LowerToC
+
+from pystencils.backend.ast.expressions import (
+    PsBufferAcc,
+    PsMemAcc,
+    PsSymbolExpr,
+    PsExpression,
+    PsLookup,
+    PsAddressOf,
+    PsCast,
+)
+from pystencils.backend.ast.structural import PsAssignment
+
+
+def test_lower_buffer_accesses():
+    """Buffer accesses are linearized into `PsMemAcc`s using the buffers' strides."""
+    ctx = KernelCreationContext()
+    factory = AstFactory(ctx)
+
+    ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:42, :31])
+    ctx.set_iteration_space(ispace)
+
+    lower = LowerToC(ctx)
+
+    f, g = fields("f(2), g(3): [2D]")
+    asm = Assignment(f.center(1), g[-1, 1](2))
+
+    f_buf = ctx.get_buffer(f)
+    g_buf = ctx.get_buffer(g)
+
+    fasm = factory.parse_sympy(asm)
+    assert isinstance(fasm.lhs, PsBufferAcc)
+    assert isinstance(fasm.rhs, PsBufferAcc)
+
+    fasm_lowered = lower(fasm)
+    assert isinstance(fasm_lowered, PsAssignment)
+
+    assert isinstance(fasm_lowered.lhs, PsMemAcc)
+    assert isinstance(fasm_lowered.lhs.pointer, PsSymbolExpr)
+    assert fasm_lowered.lhs.pointer.symbol == f_buf.base_pointer
+
+    zero = factory.parse_index(0)
+    expected_offset = reduce(
+        add,
+        (
+            (PsExpression.make(dm.counter) + zero) * PsExpression.make(stride)
+            for dm, stride in zip(ispace.dimensions, f_buf.strides)
+        ),
+    ) + factory.parse_index(1) * PsExpression.make(f_buf.strides[-1])
+    assert fasm_lowered.lhs.offset.structurally_equal(expected_offset)
+
+    assert isinstance(fasm_lowered.rhs, PsMemAcc)
+    assert isinstance(fasm_lowered.rhs.pointer, PsSymbolExpr)
+    assert fasm_lowered.rhs.pointer.symbol == g_buf.base_pointer
+
+    expected_offset = (
+        (PsExpression.make(ispace.dimensions[0].counter) + factory.parse_index(-1))
+        * PsExpression.make(g_buf.strides[0])
+        + (PsExpression.make(ispace.dimensions[1].counter) + factory.parse_index(1))
+        * PsExpression.make(g_buf.strides[1])
+        + factory.parse_index(2) * PsExpression.make(g_buf.strides[-1])
+    )
+    assert fasm_lowered.rhs.offset.structurally_equal(expected_offset)
+
+
+def test_lower_anonymous_structs():
+    """Anonymous-struct lookups are lowered to casts on a type-erased base pointer."""
+    ctx = KernelCreationContext()
+    factory = AstFactory(ctx)
+
+    ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:12])
+    ctx.set_iteration_space(ispace)
+
+    lower = LowerToC(ctx)
+
+    stype = PsStructType(
+        [
+            ("val", ctx.default_dtype),
+            ("x", ctx.index_dtype),
+        ]
+    )
+    sfield = Field.create_generic("s", spatial_dimensions=1, dtype=stype)
+
+    asm = Assignment(sfield.center("val"), 31.2)
+
+    fasm = factory.parse_sympy(asm)
+
+    sbuf = ctx.get_buffer(sfield)
+
+    assert isinstance(fasm, PsAssignment)
+    assert isinstance(fasm.lhs, PsLookup)
+
+    lowered_fasm = lower(fasm.clone())
+    assert isinstance(lowered_fasm, PsAssignment)
+
+    #   The lowered lookup has the shape
+    #   PsMemAcc(PsCast(ptr_type, PsAddressOf(PsMemAcc(erased_bp, byte_idx))), 0)
+    assert isinstance(lowered_fasm.lhs, PsMemAcc)
+    assert isinstance(lowered_fasm.lhs.pointer, PsCast)
+    assert isinstance(lowered_fasm.lhs.pointer.operand, PsAddressOf)
+    assert isinstance(lowered_fasm.lhs.pointer.operand.operand, PsMemAcc)
+    type_erased_pointer = lowered_fasm.lhs.pointer.operand.operand.pointer
+
+    assert isinstance(type_erased_pointer, PsSymbolExpr)
+    assert BufferBasePtr(sbuf) in type_erased_pointer.symbol.properties
+    assert type_erased_pointer.symbol.dtype == create_type("restrict uint8_t *")
-- 
GitLab