-
Frederik Hennig authored
- Add a `dtype` member to all expression nodes - Make the `Typifier` apply `dtype`s to all expressions - Adapt transformations and IterationSpace to set data types on created expressions - Refactor TypeContext and contextual typing interface to be more intuitive - Refactor the Typifier to apply more operations through the TypeContext Squashed commit of the following: commit 3e81188a318aa1dc294cf0cd11bf2ec7f62a9b55 Author: Frederik Hennig <frederik.hennig@fau.de> Date: Wed Mar 27 17:00:17 2024 +0100 Improve typification of integer expressions - Check integer type constraint in `_apply_target_type` to correctly catch deferred expressions commit 63d0cfa5ea1b8a41c9a74bbfcf0618fad03ffa48 Merge: 671f057 075ae357 Author: Frederik Hennig <frederik.hennig@fau.de> Date: Wed Mar 27 16:46:28 2024 +0100 Merge branch 'backend-rework' into b_refactor_typing commit 671f0578a39e452504243019dab28d93f0114082 Author: Frederik Hennig <frederik.hennig@fau.de> Date: Tue Mar 26 16:39:43 2024 +0100 Fix documentation for Typifier and PsExpression commit 3ec258517ad8a510118265184b5dc7805128dcd3 Author: Frederik Hennig <frederik.hennig@fau.de> Date: Mon Mar 25 17:14:21 2024 +0100 Typing refactor: - Annotate all expressions with types - Refactor Typifier for cleaner information flow and better readability - Have iteration space and transformers typify newly created AST nodes
Frederik Hennig authored- Add a `dtype` member to all expression nodes - Make the `Typifier` apply `dtype`s to all expressions - Adapt transformations and IterationSpace to set data types on created expressions - Refactor TypeContext and contextual typing interface to be more intuitive - Refactor the Typifier to apply more operations through the TypeContext Squashed commit of the following: commit 3e81188a318aa1dc294cf0cd11bf2ec7f62a9b55 Author: Frederik Hennig <frederik.hennig@fau.de> Date: Wed Mar 27 17:00:17 2024 +0100 Improve typification of integer expressions - Check integer type constraint in `_apply_target_type` to correctly catch deferred expressions commit 63d0cfa5ea1b8a41c9a74bbfcf0618fad03ffa48 Merge: 671f057 075ae357 Author: Frederik Hennig <frederik.hennig@fau.de> Date: Wed Mar 27 16:46:28 2024 +0100 Merge branch 'backend-rework' into b_refactor_typing commit 671f0578a39e452504243019dab28d93f0114082 Author: Frederik Hennig <frederik.hennig@fau.de> Date: Tue Mar 26 16:39:43 2024 +0100 Fix documentation for Typifier and PsExpression commit 3ec258517ad8a510118265184b5dc7805128dcd3 Author: Frederik Hennig <frederik.hennig@fau.de> Date: Mon Mar 25 17:14:21 2024 +0100 Typing refactor: - Annotate all expressions with types - Refactor Typifier for cleaner information flow and better readability - Have iteration space and transformers typify newly created AST nodes
erase_anonymous_structs.py 3.28 KiB
from __future__ import annotations
from ..kernelcreation.context import KernelCreationContext
from ..constants import PsConstant
from ..ast.structural import PsAstNode
from ..ast.expressions import (
PsArrayAccess,
PsLookup,
PsExpression,
PsDeref,
PsAddressOf,
PsCast,
)
from ..kernelcreation import Typifier
from ..arrays import PsArrayBasePointer, TypeErasedBasePointer
from ...types import PsStructType, PsPointerType
class EraseAnonymousStructTypes:
"""Lower anonymous struct arrays to a byte-array representation.
For arrays whose element type is an anonymous struct, the struct type is erased from the base pointer,
making it a pointer to uint8_t.
Member lookups on accesses into these arrays are then transformed using type casts.
"""
def __init__(self, ctx: KernelCreationContext) -> None:
self._ctx = ctx
self._substitutions: dict[PsArrayBasePointer, TypeErasedBasePointer] = dict()
def __call__(self, node: PsAstNode) -> PsAstNode:
self._substitutions = dict()
# Check if AST traversal is even necessary
if not any(
(isinstance(arr.element_type, PsStructType) and arr.element_type.anonymous)
for arr in self._ctx.arrays
):
return node
node = self.visit(node)
for old, new in self._substitutions.items():
self._ctx.replace_symbol(old, new)
return node
def visit(self, node: PsAstNode) -> PsAstNode:
match node:
case PsLookup():
# descend into expr
return self.handle_lookup(node)
case _:
node.children = [self.visit(c) for c in node.children]
return node
def handle_lookup(self, lookup: PsLookup) -> PsExpression:
aggr = lookup.aggregate
if not isinstance(aggr, PsArrayAccess):
return lookup
arr = aggr.array
if (
not isinstance(arr.element_type, PsStructType)
or not arr.element_type.anonymous
):
return lookup
struct_type = arr.element_type
struct_size = struct_type.itemsize
bp = aggr.base_ptr
# Need to keep track of base pointers already seen, since symbols must be unique
if bp not in self._substitutions:
type_erased_bp = TypeErasedBasePointer(bp.name, arr)
self._substitutions[bp] = type_erased_bp
else:
type_erased_bp = self._substitutions[bp]
base_index = aggr.index * PsExpression.make(
PsConstant(struct_size, self._ctx.index_dtype)
)
member_name = lookup.member_name
member = struct_type.find_member(member_name)
assert member is not None
np_struct = struct_type.numpy_dtype
assert np_struct is not None
assert np_struct.fields is not None
member_offset = np_struct.fields[member_name][1]
byte_index = base_index + PsExpression.make(
PsConstant(member_offset, self._ctx.index_dtype)
)
type_erased_access = PsArrayAccess(type_erased_bp, byte_index)
deref = PsDeref(
PsCast(PsPointerType(member.dtype), PsAddressOf(type_erased_access))
)
typify = Typifier(self._ctx)
deref = typify(deref)
return deref