Skip to content
Snippets Groups Projects
Commit 6a934814 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

type erasure for anonymous structs; full translation pass for index kernels

parent bd2c7d88
No related branches found
No related tags found
No related merge requests found
Pipeline #61493 failed
......@@ -50,7 +50,7 @@ from abc import ABC
import pymbolic.primitives as pb
from .types import PsAbstractType, PsPointerType, PsIntegerType, PsSignedIntegerType
from .types import PsAbstractType, PsPointerType, PsIntegerType, PsUnsignedIntegerType, PsSignedIntegerType
from .typed_expressions import PsTypedVariable, ExprOrConstant, PsTypedConstant
......@@ -110,7 +110,7 @@ class PsLinearizedArray:
@property
def name(self):
return self._name
@property
def base_pointer(self) -> PsArrayBasePointer:
return self._base_ptr
......@@ -119,9 +119,21 @@ class PsLinearizedArray:
def shape(self) -> tuple[PsArrayShapeVar | PsTypedConstant, ...]:
return self._shape
@property
def shape_spec(self) -> tuple[EllipsisType | int, ...]:
return tuple(
(s.value if isinstance(s, PsTypedConstant) else ...) for s in self._shape
)
@property
def strides(self) -> tuple[PsArrayStrideVar | PsTypedConstant, ...]:
return self._strides
@property
def strides_spec(self) -> tuple[EllipsisType | int, ...]:
return tuple(
(s.value if isinstance(s, PsTypedConstant) else ...) for s in self._strides
)
@property
def element_type(self):
......@@ -134,12 +146,8 @@ class PsLinearizedArray:
if these variables would occur in here, an infinite recursion would follow.
Hence they are filtered and replaced by the ellipsis.
"""
shape_clean = tuple(
(s if isinstance(s, PsTypedConstant) else ...) for s in self._shape
)
strides_clean = tuple(
(s if isinstance(s, PsTypedConstant) else ...) for s in self._strides
)
shape_clean = self.shape_spec
strides_clean = self.strides_spec
return (
self._name,
self._element_type,
......@@ -156,9 +164,11 @@ class PsLinearizedArray:
def __hash__(self) -> int:
return hash(self._hashable_contents())
def __repr__(self) -> str:
return f"PsLinearizedArray({self._name}: {self.element_type}[{len(self.shape)}D])"
return (
f"PsLinearizedArray({self._name}: {self.element_type}[{len(self.shape)}D])"
)
class PsArrayAssocVar(PsTypedVariable, ABC):
......@@ -195,6 +205,17 @@ class PsArrayBasePointer(PsArrayAssocVar):
def __getinitargs__(self):
return self.name, self.array
class TypeErasedBasePointer(PsArrayBasePointer):
"""Base pointer for arrays whose element type has been erased.
Used primarily for arrays of anonymous structs."""
def __init__(self, name: str, array: PsLinearizedArray):
dtype = PsPointerType(PsUnsignedIntegerType(8))
super(PsArrayBasePointer, self).__init__(name, dtype, array)
self._array = array
class PsArrayShapeVar(PsArrayAssocVar):
......@@ -244,7 +265,6 @@ class PsArrayStrideVar(PsArrayAssocVar):
class PsArrayAccess(pb.Subscript):
mapper_method = intern("map_array_access")
def __init__(self, base_ptr: PsArrayBasePointer, index: ExprOrConstant):
......
......@@ -14,21 +14,34 @@ from .ast import (
)
from .ast.kernelfunction import PsKernelFunction
from .typed_expressions import PsTypedVariable
from .functions import Deref, AddressOf, Cast
def emit_code(kernel: PsKernelFunction):
# TODO: Specialize for different targets
printer = CPrinter()
printer = CAstPrinter()
return printer.print(kernel)
class CPrinter:
class CExpressionsPrinter(CCodeMapper):
def map_deref(self, deref: Deref, enclosing_prec):
return "*"
def map_address_of(self, addrof: AddressOf, enclosing_prec):
return "&"
def map_cast(self, cast: Cast, enclosing_prec):
return f"({cast.target_type.c_string()})"
class CAstPrinter:
def __init__(self, indent_width=3):
self._indent_width = indent_width
self._current_indent_level = 0
self._pb_cmapper = CCodeMapper()
self._expr_printer = CExpressionsPrinter()
def indent(self, line):
return " " * self._current_indent_level + line
......@@ -60,7 +73,7 @@ class CPrinter:
@visit.case(PsExpression)
def pymb_expression(self, expr: PsExpression):
return self._pb_cmapper(expr.expression)
return self._expr_printer(expr.expression)
@visit.case(PsDeclaration)
def declaration(self, decl: PsDeclaration):
......@@ -81,7 +94,7 @@ class CPrinter:
def loop(self, loop: PsLoop):
ctr_symbol = loop.counter.symbol
assert isinstance(ctr_symbol, PsTypedVariable)
ctr = ctr_symbol.name
start_code = self.visit(loop.start)
stop_code = self.visit(loop.stop)
......
......@@ -13,12 +13,63 @@ TODO: Maybe add a way for the user to register additional functions
TODO: Figure out the best way to describe function signatures and overloads for typing
"""
from sys import intern
import pymbolic.primitives as pb
from abc import ABC, abstractmethod
from .types import PsAbstractType
from .typed_expressions import ExprOrConstant
class PsFunction(pb.FunctionSymbol, ABC):
@property
@abstractmethod
def arg_count(self) -> int:
"Number of arguments this function takes"
class Deref(PsFunction):
"""Dereferences a pointer."""
mapper_method = intern("map_deref")
@property
def arg_count(self) -> int:
return 1
deref = Deref()
class AddressOf(PsFunction):
"""Take the address of an object"""
mapper_method = intern("map_address_of")
@property
def arg_count(self) -> int:
return 1
address_of = AddressOf()
class Cast(PsFunction):
mapper_method = intern("map_cast")
"""An unsafe C-style type cast"""
def __init__(self, target_type: PsAbstractType):
self._target_type = target_type
@property
def arg_count(self) -> int:
return 1
@property
def target_type(self) -> PsAbstractType:
return self._target_type
def cast(target_type: PsAbstractType, arg: ExprOrConstant):
return Cast(target_type)(ExprOrConstant)
......@@ -12,6 +12,7 @@ from .iteration_space import (
create_sparse_iteration_space,
create_full_iteration_space,
)
from .transformations import EraseAnonymousStructTypes
def create_kernel(assignments: AssignmentCollection, options: KernelCreationOptions):
......@@ -45,6 +46,7 @@ def create_kernel(assignments: AssignmentCollection, options: KernelCreationOpti
raise NotImplementedError("Target platform not implemented")
kernel_ast = platform.materialize_iteration_space(kernel_body, ispace)
kernel_ast = EraseAnonymousStructTypes(ctx)(kernel_ast)
# 7. Apply optimizations
# - Vectorization
......
from __future__ import annotations
from typing import TypeVar
import pymbolic.primitives as pb
from pymbolic.mapper import IdentityMapper
from .context import KernelCreationContext
from ..ast import PsAstNode, PsExpression
from ..arrays import PsArrayAccess, TypeErasedBasePointer
from ..typed_expressions import PsTypedConstant
from ..types import PsStructType, PsPointerType
from ..functions import deref, address_of, Cast
NodeT = TypeVar("NodeT", bound=PsAstNode)
class EraseAnonymousStructTypes(IdentityMapper):
"""Lower anonymous struct arrays to a byte-array representation.
Arrays whose element type is an anonymous struct are transformed to arrays with element type UInt(8).
Lookups on accesses into these arrays are transformed using type casts.
"""
def __init__(self, ctx: KernelCreationContext) -> None:
self._ctx = ctx
def __call__(self, node: NodeT) -> NodeT:
match node:
case PsExpression(expr):
# descend into expr
node.expression = self.rec(expr)
case other:
for c in other.children:
self(c)
return node
def map_lookup(self, lookup: pb.Lookup) -> pb.Expression:
aggr = lookup.aggregate
if not isinstance(aggr, PsArrayAccess):
return lookup
arr = aggr.array
if (
not isinstance(arr.element_type, PsStructType)
or not arr.element_type.anonymous
):
return lookup
struct_type = arr.element_type
struct_size = struct_type.itemsize
bp = aggr.base_ptr
type_erased_bp = TypeErasedBasePointer(bp.name, arr)
base_index = aggr.index_tuple[0] * PsTypedConstant(struct_size, self._ctx.index_dtype)
member_name = lookup.name
member = struct_type.get_member(member_name)
assert member is not None
np_struct = struct_type.numpy_dtype
assert np_struct is not None
assert np_struct.fields is not None
member_offset = np_struct.fields[member_name][1]
byte_index = base_index + PsTypedConstant(member_offset, self._ctx.index_dtype)
type_erased_access = PsArrayAccess(type_erased_bp, byte_index)
cast = Cast(PsPointerType(member.dtype))
return deref(cast(address_of(type_erased_access)))
......@@ -209,9 +209,13 @@ class PsStructType(PsAbstractType):
return self._name is None
@property
def numpy_dtype(self) -> np.dtype | None:
def numpy_dtype(self) -> np.dtype:
members = [(m.name, m.dtype.numpy_dtype) for m in self._members]
return np.dtype(members)
@property
def itemsize(self) -> int:
return self.numpy_dtype.itemsize
def c_string(self) -> str:
if self._name is None:
......
......@@ -6,7 +6,7 @@ from pystencils.nbackend.ast import *
from pystencils.nbackend.typed_expressions import *
from pystencils.nbackend.arrays import PsLinearizedArray, PsArrayBasePointer, PsArrayAccess
from pystencils.nbackend.types.quick import *
from pystencils.nbackend.emission import CPrinter
from pystencils.nbackend.emission import CAstPrinter
def test_basic_kernel():
......@@ -32,7 +32,7 @@ def test_basic_kernel():
func = PsKernelFunction(PsBlock([loop]), target=Target.CPU)
printer = CPrinter()
printer = CAstPrinter()
code = printer.print(func)
paramlist = func.get_parameters().params
......
......@@ -6,7 +6,9 @@ from pystencils.nbackend.types import *
from pystencils.nbackend.types.quick import *
@pytest.mark.parametrize("Type", [PsSignedIntegerType, PsUnsignedIntegerType, PsIeeeFloatType])
@pytest.mark.parametrize(
"Type", [PsSignedIntegerType, PsUnsignedIntegerType, PsIeeeFloatType]
)
def test_widths(Type):
for width in Type.SUPPORTED_WIDTHS:
assert Type(width).width == width
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment