Skip to content
Snippets Groups Projects
Commit bd2c7d88 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Extend struct support; add lookup freeze and typificaiton; add various tests

parent 095436ed
Branches
Tags
No related merge requests found
Pipeline #61482 failed
Showing
with 247 additions and 60 deletions
......@@ -13,6 +13,7 @@ from .ast import (
PsConditional,
)
from .ast.kernelfunction import PsKernelFunction
from .typed_expressions import PsTypedVariable
def emit_code(kernel: PsKernelFunction):
......@@ -42,7 +43,7 @@ class CPrinter:
@visit.case(PsKernelFunction)
def function(self, func: PsKernelFunction) -> str:
params_spec = func.get_parameters()
params_str = ", ".join(f"{p.dtype} {p.name}" for p in params_spec.params)
params_str = ", ".join(f"{p.dtype.c_string()} {p.name}" for p in params_spec.params)
decl = f"FUNC_PREFIX void {func.name} ({params_str})"
body = self.visit(func.body)
return f"{decl}\n{body}"
......@@ -64,10 +65,11 @@ class CPrinter:
@visit.case(PsDeclaration)
def declaration(self, decl: PsDeclaration):
lhs_symb = decl.declared_variable.symbol
assert isinstance(lhs_symb, PsTypedVariable)
lhs_dtype = lhs_symb.dtype
rhs_code = self.visit(decl.rhs)
return self.indent(f"{lhs_dtype} {lhs_symb.name} = {rhs_code};")
return self.indent(f"{lhs_dtype.c_string()} {lhs_symb.name} = {rhs_code};")
@visit.case(PsAssignment)
def assignment(self, asm: PsAssignment):
......@@ -78,6 +80,8 @@ class CPrinter:
@visit.case(PsLoop)
def loop(self, loop: PsLoop):
ctr_symbol = loop.counter.symbol
assert isinstance(ctr_symbol, PsTypedVariable)
ctr = ctr_symbol.name
start_code = self.visit(loop.start)
stop_code = self.visit(loop.stop)
......
......@@ -120,6 +120,13 @@ class KernelCreationContext:
assert isinstance(field.dtype, (BasicType, StructType))
element_type = make_type(field.dtype.numpy_dtype)
# The frontend doesn't quite agree with itself on how to model
# fields with trivial index dimensions. Sometimes the index_shape is empty,
# sometimes its (1,). This is canonicalized here.
if not field.index_shape:
arr_shape += (1,)
arr_strides += (1,)
arr = PsLinearizedArray(
field.name, element_type, arr_shape, arr_strides, self.index_dtype
)
......
......@@ -20,12 +20,16 @@ from ..ast.nodes import (
PsLvalueExpr,
PsExpression,
)
from ..types import constify, make_type
from ..types import constify, make_type, PsStructType
from ..typed_expressions import PsTypedVariable
from ..arrays import PsArrayAccess
from ..exceptions import PsInputError
class FreezeError(Exception):
"""Signifies an error during expression freezing."""
class FreezeExpressions(SympyToPymbolicMapper):
def __init__(self, ctx: KernelCreationContext):
self._ctx = ctx
......@@ -85,7 +89,6 @@ class FreezeExpressions(SympyToPymbolicMapper):
ptr = array.base_pointer
offsets: list[pb.Expression] = [self.rec(o) for o in access.offsets]
indices: list[pb.Expression] = [self.rec(o) for o in access.index]
if not access.is_absolute_access:
match field.field_type:
......@@ -101,7 +104,7 @@ class FreezeExpressions(SympyToPymbolicMapper):
# flake8: noqa
sparse_ispace = self._ctx.get_sparse_iteration_space()
# Add sparse iteration counter to offset
assert len(offsets) == 1 # must have been checked by the context
assert len(offsets) == 1 # must have been checked by the context
offsets = [offsets[0] + sparse_ispace.sparse_counter]
case FieldType.CUSTOM:
raise ValueError("Custom fields support only absolute accesses.")
......@@ -110,6 +113,26 @@ class FreezeExpressions(SympyToPymbolicMapper):
f"Cannot translate accesses to field type {unknown} yet."
)
# If the array type is a struct, accesses are modelled using strings
# In that case, the index is empty
if isinstance(array.element_type, PsStructType):
if isinstance(access.index, str):
struct_member_name = access.index
indices = [0]
elif len(access.index) == 1 and isinstance(access.index[0], str):
struct_member_name = access.index[0]
indices = [0]
else:
raise FreezeError(
f"Unsupported access into field with struct-type elements: {access}"
)
else:
struct_member_name = None
indices = [self.rec(i) for i in access.index]
if not indices:
# For canonical representation, there must always be at least one index dimension
indices = [0]
summands = tuple(
idx * stride
for idx, stride in zip(offsets + indices, array.strides, strict=True)
......@@ -117,8 +140,12 @@ class FreezeExpressions(SympyToPymbolicMapper):
index = summands[0] if len(summands) == 1 else pb.Sum(summands)
return PsArrayAccess(ptr, index)
if struct_member_name is not None:
# Produce a pb.Lookup here, don't check yet if the member name is valid. That's the typifier's job.
return pb.Lookup(PsArrayAccess(ptr, index), struct_member_name)
else:
return PsArrayAccess(ptr, index)
def map_Function(self, func: sp.Function):
"""Map a SymPy function to a backend-supported function symbol.
......
......@@ -2,7 +2,7 @@ from typing import Sequence
from dataclasses import dataclass
from ...enums import Target
from ...field import Field
from ...field import Field, FieldType
from ..exceptions import PsOptionsError
from ..types import PsIntegerType, PsNumericType, PsIeeeFloatType
......@@ -73,3 +73,11 @@ class KernelCreationOptions:
"Parameters `iteration_slice`, `ghost_layers` and 'index_field` are mutually exclusive; "
"at most one of them may be set."
)
if (
self.index_field is not None
and self.index_field.field_type != FieldType.INDEXED
):
raise PsOptionsError(
"Only fields with `field_type == FieldType.INDEXED` can be specified as `index_field`"
)
......@@ -6,7 +6,7 @@ import pymbolic.primitives as pb
from pymbolic.mapper import Mapper
from .context import KernelCreationContext
from ..types import PsAbstractType, PsNumericType, deconstify
from ..types import PsAbstractType, PsNumericType, PsStructType, deconstify
from ..typed_expressions import PsTypedVariable, PsTypedConstant, ExprOrConstant
from ..arrays import PsArrayAccess
from ..ast import PsAstNode, PsBlock, PsExpression, PsAssignment
......@@ -24,9 +24,10 @@ NodeT = TypeVar("NodeT", bound=PsAstNode)
class UndeterminedType(PsNumericType):
"""Placeholder for types that could not yet be determined by the typifier.
Instances of this class should never leave the typifier; it is an error if they do.
"""
def create_constant(self, value: Any) -> Any:
return None
......@@ -51,7 +52,7 @@ class UndeterminedType(PsNumericType):
def __eq__(self, other: object) -> bool:
self._err()
def _c_string(self) -> str:
def c_string(self) -> str:
self._err()
......@@ -69,7 +70,7 @@ class DeferredTypedConstant(PsTypedConstant):
class TypeContext:
def __init__(self, target_type: PsNumericType | None):
def __init__(self, target_type: PsAbstractType | None):
self._target_type = deconstify(target_type) if target_type is not None else None
self._deferred_constants: list[DeferredTypedConstant] = []
......@@ -78,18 +79,28 @@ class TypeContext:
dc = DeferredTypedConstant(value)
self._deferred_constants.append(dc)
return dc
elif not isinstance(self._target_type, PsNumericType):
raise TypificationError(
f"Can't typify constant with non-numeric type {self._target_type}"
)
else:
return PsTypedConstant(value, self._target_type)
def apply(self, target_type: PsNumericType):
def apply(self, target_type: PsAbstractType):
assert self._target_type is None, "Type context was already resolved"
self._target_type = deconstify(target_type)
for dc in self._deferred_constants:
if not isinstance(self._target_type, PsNumericType):
raise TypificationError(
f"Can't typify constant with non-numeric type {self._target_type}"
)
dc.resolve(self._target_type)
self._deferred_constants = []
@property
def target_type(self) -> PsNumericType | None:
def target_type(self) -> PsAbstractType | None:
return self._target_type
......@@ -194,15 +205,32 @@ class Typifier(Mapper):
return tc.make_constant(value)
# Array Access
# Array Accesses and Lookups
def map_array_access(self, access: PsArrayAccess, tc: TypeContext) -> PsArrayAccess:
self._apply_target_type(access, access.dtype, tc)
index, _ = self.rec(
index = self.rec(
access.index_tuple[0], TypeContext(self._ctx.options.index_dtype)
)
return PsArrayAccess(access.base_ptr, index)
def map_lookup(self, lookup: pb.Lookup, tc: TypeContext) -> pb.Lookup:
aggr_tc = TypeContext(None)
aggregate = self.rec(lookup.aggregate, aggr_tc)
aggr_type = aggr_tc.target_type
if not isinstance(aggr_type, PsStructType):
raise TypificationError("Aggregate type of lookup was not a struct type.")
member = aggr_type.get_member(lookup.name)
if member is None:
raise TypificationError(
f"Aggregate of type {aggr_type} does not have a member {member}."
)
self._apply_target_type(lookup, member.dtype, tc)
return pb.Lookup(aggregate, member.name)
# Arithmetic Expressions
def map_sum(self, expr: pb.Sum, tc: TypeContext) -> pb.Sum:
......@@ -210,7 +238,7 @@ class Typifier(Mapper):
def map_product(self, expr: pb.Product, tc: TypeContext) -> pb.Product:
return pb.Product(tuple(self.rec(c, tc) for c in expr.children))
# Functions
def map_call(self, expr: pb.Call, tc: TypeContext) -> pb.Call:
......@@ -218,14 +246,13 @@ class Typifier(Mapper):
TODO: Figure out how to describe function signatures
"""
raise NotImplementedError()
# Internals
def _apply_target_type(
self, expr: ExprOrConstant, expr_type: PsAbstractType, tc: TypeContext
):
if tc.target_type is None:
assert isinstance(expr_type, PsNumericType)
tc.apply(expr_type)
elif deconstify(expr_type) != tc.target_type:
raise TypificationError(
......
......@@ -67,7 +67,7 @@ class PsAbstractType(ABC):
return "const " if self._const else ""
@abstractmethod
def _c_string(self) -> str:
def c_string(self) -> str:
...
# -------------------------------------------------------------------------------------------
......@@ -79,7 +79,7 @@ class PsAbstractType(ABC):
...
def __str__(self) -> str:
return self._c_string()
return self.c_string()
@abstractmethod
def __hash__(self) -> int:
......@@ -107,7 +107,7 @@ class PsCustomType(PsAbstractType):
def __hash__(self) -> int:
return hash(("PsCustomType", self._name, self._const))
def _c_string(self) -> str:
def c_string(self) -> str:
return f"{self._const_string()} {self._name}"
def __repr__(self) -> str:
......@@ -143,8 +143,8 @@ class PsPointerType(PsAbstractType):
def __hash__(self) -> int:
return hash(("PsPointerType", self._base_type, self._restrict, self._const))
def _c_string(self) -> str:
base_str = self._base_type._c_string()
def c_string(self) -> str:
base_str = self._base_type.c_string()
restrict_str = " RESTRICT" if self._restrict else ""
return f"{base_str} *{restrict_str} {self._const_string()}"
......@@ -189,6 +189,13 @@ class PsStructType(PsAbstractType):
def members(self) -> tuple[PsStructType.Member, ...]:
return self._members
def get_member(self, member_name: str) -> PsStructType.Member | None:
"""Find a member by name"""
for m in self._members:
if m.name == member_name:
return m
return None
@property
def name(self) -> str:
if self._name is None:
......@@ -206,12 +213,18 @@ class PsStructType(PsAbstractType):
members = [(m.name, m.dtype.numpy_dtype) for m in self._members]
return np.dtype(members)
def _c_string(self) -> str:
def c_string(self) -> str:
if self._name is None:
raise PsInternalCompilerError(
"Cannot retrieve C string for anonymous struct type"
)
return self._name
def __str__(self) -> str:
if self._name is None:
return "<anonymous>"
else:
return self._name
def __eq__(self, other: object) -> bool:
if not isinstance(other, PsStructType):
......@@ -359,7 +372,7 @@ class PsIntegerType(PsScalarType, ABC):
def __hash__(self) -> int:
return hash(("PsIntegerType", self._width, self._signed, self._const))
def _c_string(self) -> str:
def c_string(self) -> str:
prefix = "" if self._signed else "u"
return f"{self._const_string()}{prefix}int{self._width}_t"
......@@ -499,7 +512,7 @@ class PsIeeeFloatType(PsScalarType):
def __hash__(self) -> int:
return hash(("PsIeeeFloatType", self._width, self._const))
def _c_string(self) -> str:
def c_string(self) -> str:
match self._width:
case 16:
return f"{self._const_string()}half"
......
import pytest
import sympy as sp
import numpy as np
from pystencils import Assignment, Field, FieldType, AssignmentCollection
from pystencils.nbackend.kernelcreation import create_kernel, KernelCreationOptions
from pystencils.cpu.cpujit import compile_and_load
def test_indexed_kernel():
arr = np.zeros((3, 4))
dtype = np.dtype([('x', int), ('y', int), ('value', arr.dtype)])
index_arr = np.zeros((3,), dtype=dtype)
index_arr[0] = (0, 2, 3.0)
index_arr[1] = (1, 3, 42.0)
index_arr[2] = (2, 1, 5.0)
index_field = Field.create_from_numpy_array('index', index_arr, field_type=FieldType.INDEXED)
normal_field = Field.create_from_numpy_array('f', arr)
update_rule = AssignmentCollection([
Assignment(normal_field[0, 0], index_field('value'))
])
options = KernelCreationOptions(index_field=index_field)
ast = create_kernel(update_rule, options)
kernel = compile_and_load(ast)
kernel(f=arr, index=index_arr)
for i in range(index_arr.shape[0]):
np.testing.assert_allclose(arr[index_arr[i]['x'], index_arr[i]['y']], index_arr[i]['value'], atol=1e-13)
import pytest
from pystencils.field import Field, FieldType
from pystencils.nbackend.types.quick import *
from pystencils.nbackend.kernelcreation.options import (
KernelCreationOptions,
PsOptionsError,
)
def test_invalid_iteration_region_options():
idx_field = Field.create_generic(
"idx", spatial_dimensions=1, field_type=FieldType.INDEXED
)
with pytest.raises(PsOptionsError):
KernelCreationOptions(
ghost_layers=2, iteration_slice=(slice(1, -1), slice(1, -1))
)
with pytest.raises(PsOptionsError):
KernelCreationOptions(ghost_layers=2, index_field=idx_field)
def test_index_field_options():
with pytest.raises(PsOptionsError):
idx_field = Field.create_generic(
"idx", spatial_dimensions=1, field_type=FieldType.GENERIC
)
KernelCreationOptions(index_field=idx_field)
......@@ -3,10 +3,11 @@ import sympy as sp
import numpy as np
import pymbolic.primitives as pb
from pystencils import Assignment, TypedSymbol
from pystencils import Assignment, TypedSymbol, Field, FieldType
from pystencils.nbackend.ast import PsDeclaration
from pystencils.nbackend.types import constify, make_numeric_type
from pystencils.nbackend.types import constify, deconstify, PsStructType
from pystencils.nbackend.types.quick import *
from pystencils.nbackend.typed_expressions import PsTypedConstant, PsTypedVariable
from pystencils.nbackend.kernelcreation.options import KernelCreationOptions
from pystencils.nbackend.kernelcreation.context import KernelCreationContext
......@@ -45,6 +46,28 @@ def test_typify_simple():
check(fasm.rhs.expression)
def test_typify_structs():
options = KernelCreationOptions(default_dtype=Fp(32))
ctx = KernelCreationContext(options)
freeze = FreezeExpressions(ctx)
typify = Typifier(ctx)
np_struct = np.dtype([("size", np.uint32), ("data", np.float32)])
f = Field.create_generic("f", 1, dtype=np_struct, field_type=FieldType.CUSTOM)
x = sp.Symbol("x")
# Good
asm = Assignment(x, f.absolute_access((0,), "data"))
fasm = freeze(asm)
fasm = typify(fasm)
# Bad
asm = Assignment(x, f.absolute_access((0,), "size"))
fasm = freeze(asm)
with pytest.raises(TypificationError):
fasm = typify(fasm)
def test_contextual_typing():
options = KernelCreationOptions()
ctx = KernelCreationContext(options)
......
import pytest
from pystencils.nbackend.types.quick import *
def test_parsing_positive():
assert make_type("const uint32_t * restrict") == Ptr(UInt(32, const=True), restrict=True)
assert make_type("float * * const") == Ptr(Ptr(Fp(32)), const=True)
assert make_type("uint16 * const") == Ptr(UInt(16), const=True)
assert make_type("uint64 const * const") == Ptr(UInt(64, const=True), const=True)
def test_parsing_negative():
bad_specs = [
"const notatype * const",
"cnost uint32_t",
"uint45_t",
"int", # plain ints are ambiguous
"float float",
"double * int",
"bool"
]
for spec in bad_specs:
with pytest.raises(ValueError):
make_type(spec)
def test_numpy():
import numpy as np
assert make_type(np.single) == make_type(np.float32) == PsIeeeFloatType(32)
assert make_type(float) == make_type(np.double) == make_type(np.float64) == PsIeeeFloatType(64)
assert make_type(int) == make_type(np.int64) == PsSignedIntegerType(64)
......@@ -6,6 +6,54 @@ from pystencils.nbackend.types import *
from pystencils.nbackend.types.quick import *
@pytest.mark.parametrize("Type", [PsSignedIntegerType, PsUnsignedIntegerType, PsIeeeFloatType])
def test_widths(Type):
for width in Type.SUPPORTED_WIDTHS:
assert Type(width).width == width
for width in (1, 9, 33, 63):
with pytest.raises(ValueError):
Type(width)
def test_parsing_positive():
assert make_type("const uint32_t * restrict") == Ptr(
UInt(32, const=True), restrict=True
)
assert make_type("float * * const") == Ptr(Ptr(Fp(32)), const=True)
assert make_type("uint16 * const") == Ptr(UInt(16), const=True)
assert make_type("uint64 const * const") == Ptr(UInt(64, const=True), const=True)
def test_parsing_negative():
bad_specs = [
"const notatype * const",
"cnost uint32_t",
"uint45_t",
"int", # plain ints are ambiguous
"float float",
"double * int",
"bool",
]
for spec in bad_specs:
with pytest.raises(ValueError):
make_type(spec)
def test_numpy():
import numpy as np
assert make_type(np.single) == make_type(np.float32) == PsIeeeFloatType(32)
assert (
make_type(float)
== make_type(np.double)
== make_type(np.float64)
== PsIeeeFloatType(64)
)
assert make_type(int) == make_type(np.int64) == PsSignedIntegerType(64)
@pytest.mark.parametrize(
"numpy_type",
[
......@@ -68,5 +116,6 @@ def test_struct_types():
)
assert t.anonymous
assert str(t) == "<anonymous>"
with pytest.raises(PsInternalCompilerError):
str(t)
t.c_string()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment