From e485585005ba0e681af05545c65e6473acbfe629 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Tue, 22 Oct 2024 11:39:16 +0200 Subject: [PATCH] Adapt visitors, pt. 1 - Make freeze emit `PsBufferAcc`s - Typify `PsBufferAcc`s - Implement `LowerToC` - Subsume `EraseAnonStructs` into `LowerToC` --- .../backend/kernelcreation/context.py | 1 + .../backend/kernelcreation/freeze.py | 11 +- .../backend/kernelcreation/typification.py | 7 +- .../backend/transformations/__init__.py | 6 +- .../erase_anonymous_structs.py | 108 -------------- .../backend/transformations/lower_to_c.py | 140 ++++++++++++++++++ src/pystencils/kernelcreation.py | 4 +- tests/nbackend/kernelcreation/test_freeze.py | 6 +- .../transformations/test_lower_to_c.py | 115 ++++++++++++++ 9 files changed, 269 insertions(+), 129 deletions(-) delete mode 100644 src/pystencils/backend/transformations/erase_anonymous_structs.py create mode 100644 src/pystencils/backend/transformations/lower_to_c.py create mode 100644 tests/nbackend/transformations/test_lower_to_c.py diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index e75144dee..558cc8a0e 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -240,6 +240,7 @@ class KernelCreationContext: self._fields_collection.index_fields.add(field) case FieldType.CUSTOM: + buf = self._create_regular_field_buffer(field) self._fields_collection.custom_fields.add(field) case _: diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 1cadf1fa4..bdc8f1133 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -363,18 +363,11 @@ class FreezeExpressions: # For canonical representation, there must always be at least one index dimension indices = [PsExpression.make(PsConstant(0))] - summands = tuple( - idx * PsExpression.make(stride) - for idx, stride in zip(offsets + indices, array.strides, strict=True) - ) - - index = summands[0] if len(summands) == 1 else reduce(add, summands) - if struct_member_name is not None: # Produce a Lookup here, don't check yet if the member name is valid. That's the typifier's job. - return PsLookup(PsBufferAcc(ptr, index), struct_member_name) + return PsLookup(PsBufferAcc(ptr, offsets + indices), struct_member_name) else: - return PsBufferAcc(ptr, index) + return PsBufferAcc(ptr, offsets + indices) def map_ConditionalFieldAccess(self, acc: ConditionalFieldAccess): facc = self.visit_expr(acc.access) diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index 222b2cac3..975ea5a60 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -413,9 +413,10 @@ class Typifier: case PsLiteralExpr(lit): tc.apply_dtype(lit.dtype, expr) - case PsBufferAcc(bptr, idx): - tc.apply_dtype(bptr.array.element_type, expr) - self._handle_idx(idx) + case PsBufferAcc(_, indices): + tc.apply_dtype(expr.buffer.element_type, expr) + for idx in indices: + self._handle_idx(idx) case PsMemAcc(ptr, offset): ptr_tc = TypeContext() diff --git a/src/pystencils/backend/transformations/__init__.py b/src/pystencils/backend/transformations/__init__.py index 88ad9348f..7375af618 100644 --- a/src/pystencils/backend/transformations/__init__.py +++ b/src/pystencils/backend/transformations/__init__.py @@ -69,7 +69,7 @@ Loop Reshaping Transformations Code Lowering and Materialization --------------------------------- -.. autoclass:: EraseAnonymousStructTypes +.. autoclass:: LowerToC :members: __call__ .. autoclass:: SelectFunctions @@ -84,7 +84,7 @@ from .eliminate_branches import EliminateBranches from .hoist_loop_invariant_decls import HoistLoopInvariantDeclarations from .reshape_loops import ReshapeLoops from .add_pragmas import InsertPragmasAtLoops, LoopPragma, AddOpenMP -from .erase_anonymous_structs import EraseAnonymousStructTypes +from .lower_to_c import LowerToC from .select_functions import SelectFunctions from .select_intrinsics import MaterializeVectorIntrinsics @@ -98,7 +98,7 @@ __all__ = [ "InsertPragmasAtLoops", "LoopPragma", "AddOpenMP", - "EraseAnonymousStructTypes", + "LowerToC", "SelectFunctions", "MaterializeVectorIntrinsics", ] diff --git a/src/pystencils/backend/transformations/erase_anonymous_structs.py b/src/pystencils/backend/transformations/erase_anonymous_structs.py deleted file mode 100644 index 607877592..000000000 --- a/src/pystencils/backend/transformations/erase_anonymous_structs.py +++ /dev/null @@ -1,108 +0,0 @@ -from __future__ import annotations - -from ..kernelcreation.context import KernelCreationContext - -from ..constants import PsConstant -from ..ast.structural import PsAstNode -from ..ast.expressions import ( - PsBufferAcc, - PsLookup, - PsExpression, - PsMemAcc, - PsAddressOf, - PsCast, -) -from ..kernelcreation import Typifier -from ...types import PsStructType, PsPointerType - - -class EraseAnonymousStructTypes: - """Lower anonymous struct arrays to a byte-array representation. - - For arrays whose element type is an anonymous struct, the struct type is erased from the base pointer, - making it a pointer to uint8_t. - Member lookups on accesses into these arrays are then transformed using type casts. - """ - - def __init__(self, ctx: KernelCreationContext) -> None: - self._ctx = ctx - - self._substitutions: dict[PsArrayBasePointer, TypeErasedBasePointer] = dict() - - def __call__(self, node: PsAstNode) -> PsAstNode: - self._substitutions = dict() - - # Check if AST traversal is even necessary - if not any( - (isinstance(arr.element_type, PsStructType) and arr.element_type.anonymous) - for arr in self._ctx.arrays - ): - return node - - node = self.visit(node) - - for old, new in self._substitutions.items(): - self._ctx.replace_symbol(old, new) - - return node - - def visit(self, node: PsAstNode) -> PsAstNode: - match node: - case PsLookup(): - # descend into expr - return self.handle_lookup(node) - case _: - node.children = [self.visit(c) for c in node.children] - - return node - - def handle_lookup(self, lookup: PsLookup) -> PsExpression: - aggr = lookup.aggregate - if not isinstance(aggr, PsBufferAcc): - return lookup - - arr = aggr.buffer - if ( - not isinstance(arr.element_type, PsStructType) - or not arr.element_type.anonymous - ): - return lookup - - struct_type = arr.element_type - struct_size = struct_type.itemsize - - bp = aggr.base_ptr - - # Need to keep track of base pointers already seen, since symbols must be unique - if bp not in self._substitutions: - type_erased_bp = TypeErasedBasePointer(bp.name, arr) - self._substitutions[bp] = type_erased_bp - else: - type_erased_bp = self._substitutions[bp] - - base_index = aggr.index * PsExpression.make( - PsConstant(struct_size, self._ctx.index_dtype) - ) - - member_name = lookup.member_name - member = struct_type.find_member(member_name) - assert member is not None - - np_struct = struct_type.numpy_dtype - assert np_struct is not None - assert np_struct.fields is not None - member_offset = np_struct.fields[member_name][1] - - byte_index = base_index + PsExpression.make( - PsConstant(member_offset, self._ctx.index_dtype) - ) - type_erased_access = PsBufferAcc(type_erased_bp, byte_index) - - deref = PsMemAcc( - PsCast(PsPointerType(member.dtype), PsAddressOf(type_erased_access)), - PsExpression.make(PsConstant(0)) - ) - - typify = Typifier(self._ctx) - deref = typify(deref) - return deref diff --git a/src/pystencils/backend/transformations/lower_to_c.py b/src/pystencils/backend/transformations/lower_to_c.py new file mode 100644 index 000000000..8fd6f89ba --- /dev/null +++ b/src/pystencils/backend/transformations/lower_to_c.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +from typing import cast +from functools import reduce +import operator + +from ..kernelcreation import KernelCreationContext, Typifier + +from ..constants import PsConstant +from ..memory import PsSymbol, PsBuffer, BufferBasePtr +from ..ast.structural import PsAstNode +from ..ast.expressions import ( + PsBufferAcc, + PsLookup, + PsExpression, + PsMemAcc, + PsAddressOf, + PsCast, + PsSymbolExpr, +) +from ...types import PsStructType, PsPointerType + + +class LowerToC: + """Lower high-level IR constructs to C language concepts. + + This pass will replace a number of IR constructs that have no direct counterpart in the C language + to lower-level AST nodes. These include: + + - *Linearization of Buffer Accesses:* `PsBufferAcc` buffer accesses are linearized according to + their buffers' stride information and replaced by `PsMemAcc`. + - *Erasure of Anonymous Structs:* + For buffers whose element type is an anonymous struct, the struct type is erased from the base pointer, + making it a pointer to uint8_t. + Member lookups on accesses into these buffers are then transformed using type casts. + """ + + def __init__(self, ctx: KernelCreationContext) -> None: + self._ctx = ctx + self._typify = Typifier(ctx) + + self._substitutions: dict[PsSymbol, PsSymbol] = dict() + + def __call__(self, node: PsAstNode) -> PsAstNode: + self._substitutions = dict() + + node = self.visit(node) + + for old, new in self._substitutions.items(): + self._ctx.replace_symbol(old, new) + + return node + + def visit(self, node: PsAstNode) -> PsAstNode: + match node: + case PsBufferAcc(bptr, indices): + # Linearize + buf = node.buffer + + # Typifier allows different data types in each index + def maybe_cast(i: PsExpression): + if i.get_dtype() != buf.index_type: + return PsCast(buf.index_type, i) + else: + return i + + summands: list[PsExpression] = [ + maybe_cast(idx) * PsExpression.make(stride) + for idx, stride in zip(indices, buf.strides, strict=True) + ] + + linearized_idx: PsExpression = ( + summands[0] + if len(summands) == 1 + else reduce(operator.add, summands) + ) + + mem_acc = PsMemAcc(bptr, linearized_idx) + + return self._typify.typify_expression( + mem_acc, target_type=buf.element_type + )[0] + + case PsLookup(aggr, member_name) if isinstance( + aggr, PsBufferAcc + ) and isinstance( + aggr.buffer.element_type, PsStructType + ) and aggr.buffer.element_type.anonymous: + # Need to lower this buffer-lookup + linearized_acc = self.visit(aggr) + return self._lower_anon_lookup( + cast(PsMemAcc, linearized_acc), aggr.buffer, member_name + ) + + case _: + node.children = [self.visit(c) for c in node.children] + + return node + + def _lower_anon_lookup( + self, aggr: PsMemAcc, buf: PsBuffer, member_name: str + ) -> PsExpression: + struct_type = cast(PsStructType, buf.element_type) + struct_size = struct_type.itemsize + + assert isinstance(aggr.pointer, PsSymbolExpr) + bp = aggr.pointer.symbol + + # Need to keep track of base pointers already seen, since symbols must be unique + if bp not in self._substitutions: + type_erased_bp = PsSymbol(bp.name, bp.dtype) + type_erased_bp.add_property(BufferBasePtr(buf)) + self._substitutions[bp] = type_erased_bp + else: + type_erased_bp = self._substitutions[bp] + + base_index = aggr.offset * PsExpression.make( + PsConstant(struct_size, self._ctx.index_dtype) + ) + + member = struct_type.find_member(member_name) + assert member is not None + + np_struct = struct_type.numpy_dtype + assert np_struct is not None + assert np_struct.fields is not None + member_offset = np_struct.fields[member_name][1] + + byte_index = base_index + PsExpression.make( + PsConstant(member_offset, self._ctx.index_dtype) + ) + type_erased_access = PsMemAcc(PsExpression.make(type_erased_bp), byte_index) + + deref = PsMemAcc( + PsCast(PsPointerType(member.dtype), PsAddressOf(type_erased_access)), + PsExpression.make(PsConstant(0)), + ) + + deref = self._typify(deref) + return deref diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py index ae64bdea3..4ebafece7 100644 --- a/src/pystencils/kernelcreation.py +++ b/src/pystencils/kernelcreation.py @@ -20,7 +20,7 @@ from .backend.kernelcreation.iteration_space import ( from .backend.transformations import ( EliminateConstants, - EraseAnonymousStructTypes, + LowerToC, SelectFunctions, ) from .backend.kernelfunction import ( @@ -143,7 +143,7 @@ def create_kernel( kernel_ast = optimize_cpu(ctx, platform, kernel_ast, config.cpu_optim) - erase_anons = EraseAnonymousStructTypes(ctx) + erase_anons = LowerToC(ctx) kernel_ast = cast(PsBlock, erase_anons(kernel_ast)) select_functions = SelectFunctions(platform) diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py index 67e5d2319..ce4f61785 100644 --- a/tests/nbackend/kernelcreation/test_freeze.py +++ b/tests/nbackend/kernelcreation/test_freeze.py @@ -115,13 +115,11 @@ def test_freeze_fields(): lhs = PsBufferAcc( f_arr.base_pointer, - (PsExpression.make(counter) + zero) * PsExpression.make(f_arr.strides[0]) - + zero * one, + (PsExpression.make(counter) + zero, zero) ) rhs = PsBufferAcc( g_arr.base_pointer, - (PsExpression.make(counter) + zero) * PsExpression.make(g_arr.strides[0]) - + zero * one, + (PsExpression.make(counter) + zero, zero) ) should = PsAssignment(lhs, rhs) diff --git a/tests/nbackend/transformations/test_lower_to_c.py b/tests/nbackend/transformations/test_lower_to_c.py new file mode 100644 index 000000000..1c8a2c8f3 --- /dev/null +++ b/tests/nbackend/transformations/test_lower_to_c.py @@ -0,0 +1,115 @@ +from functools import reduce +from operator import add + +from pystencils import fields, Assignment, make_slice, Field +from pystencils.types import PsStructType, create_type + +from pystencils.backend.memory import BufferBasePtr +from pystencils.backend.kernelcreation import ( + KernelCreationContext, + AstFactory, + FullIterationSpace, +) +from pystencils.backend.transformations import LowerToC + +from pystencils.backend.ast.expressions import ( + PsBufferAcc, + PsMemAcc, + PsSymbolExpr, + PsExpression, + PsLookup, + PsAddressOf +) +from pystencils.backend.ast.structural import PsAssignment + + +def test_lower_buffer_accesses(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + + ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:42, :31]) + ctx.set_iteration_space(ispace) + + lower = LowerToC(ctx) + + f, g = fields("f(2), g(3): [2D]") + asm = Assignment(f.center(1), g[-1, 1](2)) + + f_buf = ctx.get_buffer(f) + g_buf = ctx.get_buffer(g) + + fasm = factory.parse_sympy(asm) + assert isinstance(fasm.lhs, PsBufferAcc) + assert isinstance(fasm.rhs, PsBufferAcc) + + fasm_lowered = lower(fasm) + assert isinstance(fasm_lowered, PsAssignment) + + assert isinstance(fasm_lowered.lhs, PsMemAcc) + assert isinstance(fasm_lowered.lhs.pointer, PsSymbolExpr) + assert fasm_lowered.lhs.pointer.symbol == f_buf.base_pointer + + zero = factory.parse_index(0) + expected_offset = reduce( + add, + ( + (PsExpression.make(dm.counter) + zero) * PsExpression.make(stride) + for dm, stride in zip(ispace.dimensions, f_buf.strides) + ), + ) + factory.parse_index(1) * PsExpression.make(f_buf.strides[-1]) + assert fasm_lowered.lhs.offset.structurally_equal(expected_offset) + + assert isinstance(fasm_lowered.rhs, PsMemAcc) + assert isinstance(fasm_lowered.rhs.pointer, PsSymbolExpr) + assert fasm_lowered.rhs.pointer.symbol == g_buf.base_pointer + + expected_offset = ( + (PsExpression.make(ispace.dimensions[0].counter) + factory.parse_index(-1)) + * PsExpression.make(g_buf.strides[0]) + + (PsExpression.make(ispace.dimensions[1].counter) + factory.parse_index(1)) + * PsExpression.make(g_buf.strides[1]) + + factory.parse_index(2) * PsExpression.make(g_buf.strides[-1]) + ) + assert fasm_lowered.rhs.offset.structurally_equal(expected_offset) + + +def test_lower_anonymous_structs(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + + ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:12]) + ctx.set_iteration_space(ispace) + + lower = LowerToC(ctx) + + stype = PsStructType( + [ + ("val", ctx.default_dtype), + ("x", ctx.index_dtype), + ] + ) + sfield = Field.create_generic("s", spatial_dimensions=1, dtype=stype) + + asm = Assignment(sfield.center("val"), 31.2) + + fasm = factory.parse_sympy(asm) + + sbuf = ctx.get_buffer(sfield) + + assert isinstance(fasm, PsAssignment) + assert isinstance(fasm.lhs, PsLookup) + + lowered_fasm = lower(fasm.clone()) + assert isinstance(lowered_fasm, PsAssignment) + assert isinstance(lowered_fasm.lhs, PsMemAcc) + assert isinstance( + lowered_fasm.lhs.pointer, PsAddressOf + ) + assert isinstance( + lowered_fasm.lhs.pointer.operand, PsMemAcc + ) + type_erased_pointer = lowered_fasm.lhs.pointer.operand.pointer + + assert isinstance(type_erased_pointer, PsSymbolExpr) + assert BufferBasePtr(sbuf) in type_erased_pointer.symbol.properties + assert type_erased_pointer.symbol.dtype == create_type("restrict uint8_t *") -- GitLab