Skip to content
Snippets Groups Projects
Commit e5e1a95c authored by Frederik Hennig's avatar Frederik Hennig
Browse files

add freeze and typify unit tests. various minor fixes

parent 5ab75b3a
Branches
Tags
No related merge requests found
Pipeline #61147 failed
Showing
with 232 additions and 50 deletions
......@@ -156,6 +156,9 @@ class PsLinearizedArray:
def __hash__(self) -> int:
return hash(self._hashable_contents())
def __repr__(self) -> str:
return f"PsLinearizedArray({self._name}: {self.element_type}[{len(self.shape)}D])"
class PsArrayAssocVar(PsTypedVariable, ABC):
......
......@@ -2,9 +2,11 @@ from __future__ import annotations
from typing import Sequence, Iterable, cast, TypeAlias
from types import NoneType
from pymbolic.primitives import Variable
from abc import ABC, abstractmethod
from ..typed_expressions import PsTypedVariable, ExprOrConstant
from ..typed_expressions import ExprOrConstant
from ..arrays import PsArrayAccess
from .util import failing_cast
......@@ -35,6 +37,15 @@ class PsAstNode(ABC):
def set_child(self, idx: int, c: PsAstNode):
...
def __eq__(self, other: object) -> bool:
if not isinstance(other, PsAstNode):
return False
return type(self) is type(other) and self.children == other.children
def __hash__(self) -> int:
return hash((type(self), self.children))
class PsBlock(PsAstNode):
__match_args__ = ("statements",)
......@@ -56,6 +67,10 @@ class PsBlock(PsAstNode):
def statements(self, stm: Sequence[PsAstNode]):
self._statements = list(stm)
def __repr__(self) -> str:
contents = ", ".join(repr(c) for c in self.children)
return f"PsBlock( {contents} )"
class PsLeafNode(PsAstNode):
def get_children(self) -> tuple[PsAstNode, ...]:
......@@ -81,12 +96,23 @@ class PsExpression(PsLeafNode):
def expression(self, expr: ExprOrConstant):
self._expr = expr
def __repr__(self) -> str:
return repr(self._expr)
def __eq__(self, other: object) -> bool:
if not isinstance(other, PsExpression):
return False
return type(self) is type(other) and self._expr == other._expr
def __hash__(self) -> int:
return hash((type(self), self._expr))
class PsLvalueExpr(PsExpression):
"""Wrapper around pymbolics expressions that may occur at the left-hand side of an assignment"""
def __init__(self, expr: PsLvalue):
if not isinstance(expr, (PsTypedVariable, PsArrayAccess)):
if not isinstance(expr, (Variable, PsArrayAccess)):
raise TypeError("Expression was not a valid lvalue")
super(PsLvalueExpr, self).__init__(expr)
......@@ -97,19 +123,19 @@ class PsSymbolExpr(PsLvalueExpr):
__match_args__ = ("symbol",)
def __init__(self, symbol: PsTypedVariable):
def __init__(self, symbol: Variable):
super().__init__(symbol)
@property
def symbol(self) -> PsTypedVariable:
return cast(PsTypedVariable, self._expr)
def symbol(self) -> Variable:
return cast(Variable, self._expr)
@symbol.setter
def symbol(self, symbol: PsTypedVariable):
def symbol(self, symbol: Variable):
self._expr = symbol
PsLvalue: TypeAlias = PsTypedVariable | PsArrayAccess
PsLvalue: TypeAlias = Variable | PsArrayAccess
"""Types of expressions that may occur on the left-hand side of assignments."""
......@@ -151,6 +177,9 @@ class PsAssignment(PsAstNode):
else:
assert False, "unreachable code"
def __repr__(self) -> str:
return f"PsAssignment({repr(self._lhs)}, {repr(self._rhs)})"
class PsDeclaration(PsAssignment):
__match_args__ = (
......@@ -186,6 +215,9 @@ class PsDeclaration(PsAssignment):
else:
assert False, "unreachable code"
def __repr__(self) -> str:
return f"PsDeclaration({repr(self._lhs)}, {repr(self._rhs)})"
class PsLoop(PsAstNode):
__match_args__ = ("counter", "start", "stop", "step", "body")
......
from .options import KernelCreationOptions
from .kernelcreation import create_kernel
from .context import KernelCreationContext
from .analysis import KernelAnalysis
from .freeze import FreezeExpressions
from .typification import Typifier
from .iteration_space import FullIterationSpace, SparseIterationSpace
__all__ = [
"KernelCreationOptions",
"create_kernel",
"KernelCreationContext",
"KernelAnalysis",
"FreezeExpressions",
"Typifier",
"FullIterationSpace",
"SparseIterationSpace",
]
from __future__ import annotations
from typing import cast
from dataclasses import dataclass
from ...field import Field, FieldType
......@@ -16,12 +15,12 @@ from .options import KernelCreationOptions
from .iteration_space import IterationSpace, FullIterationSpace, SparseIterationSpace
@dataclass
class FieldsInKernel:
domain_fields: set[Field] = set()
index_fields: set[Field] = set()
custom_fields: set[Field] = set()
buffer_fields: set[Field] = set()
def __init__(self) -> None:
self.domain_fields: set[Field] = set()
self.index_fields: set[Field] = set()
self.custom_fields: set[Field] = set()
self.buffer_fields: set[Field] = set()
class KernelCreationContext:
......@@ -70,6 +69,8 @@ class KernelCreationContext:
def constraints(self) -> tuple[PsKernelConstraint, ...]:
return tuple(self._constraints)
# Fields and Arrays
@property
def fields(self) -> FieldsInKernel:
return self._fields_collection
......@@ -113,7 +114,9 @@ class KernelCreationContext:
self._arrays[field] = arr
return self._arrays[field]
return self._arrays[field]
# Iteration Space
def set_iteration_space(self, ispace: IterationSpace):
if self._ispace is not None:
......
......@@ -2,11 +2,18 @@ import pymbolic.primitives as pb
from pymbolic.interop.sympy import SympyToPymbolicMapper
from ...field import Field, FieldType
from ...typing import BasicType
from .context import KernelCreationContext
from ..ast.nodes import PsAssignment
from ..types import PsSignedIntegerType, PsIeeeFloatType, PsUnsignedIntegerType
from ..ast.nodes import (
PsAssignment,
PsDeclaration,
PsSymbolExpr,
PsLvalueExpr,
PsExpression,
)
from ..types import constify, make_type
from ..typed_expressions import PsTypedVariable
from ..arrays import PsArrayAccess
......@@ -18,19 +25,21 @@ class FreezeExpressions(SympyToPymbolicMapper):
def map_Assignment(self, expr): # noqa
lhs = self.rec(expr.lhs)
rhs = self.rec(expr.rhs)
return PsAssignment(lhs, rhs)
def map_BasicType(self, expr):
width = expr.numpy_dtype.itemsize * 8
const = expr.const
if expr.is_float():
return PsIeeeFloatType(width, const)
elif expr.is_uint():
return PsUnsignedIntegerType(width, const)
elif expr.is_int():
return PsSignedIntegerType(width, const)
if isinstance(lhs, pb.Variable):
return PsDeclaration(PsSymbolExpr(lhs), PsExpression(rhs))
elif isinstance(lhs, PsArrayAccess):
return PsAssignment(PsLvalueExpr(lhs), PsExpression(rhs))
else:
assert False, "That should not have happened."
def map_BasicType(self, expr: BasicType):
# TODO: This should not be necessary; the frontend should use the new type system.
dtype = make_type(expr.numpy_dtype.type)
if expr.const:
return constify(dtype)
else:
raise NotImplementedError("Data type not supported.")
return dtype
def map_FieldShapeSymbol(self, expr):
dtype = self.rec(expr.dtype)
......@@ -53,7 +62,10 @@ class FreezeExpressions(SympyToPymbolicMapper):
case FieldType.GENERIC:
# Add the iteration counters
offsets = [
i + o for i, o in zip(self._ctx.get_iteration_space().spatial_indices, offsets)
i + o
for i, o in zip(
self._ctx.get_iteration_space().spatial_indices, offsets
)
]
case FieldType.INDEXED:
# flake8: noqa
......@@ -68,11 +80,11 @@ class FreezeExpressions(SympyToPymbolicMapper):
f"Cannot translate accesses to field type {unknown} yet."
)
index = pb.Sum(
tuple(
idx * stride
for idx, stride in zip(offsets + indices, array.strides, strict=True)
)
summands = tuple(
idx * stride
for idx, stride in zip(offsets + indices, array.strides, strict=True)
)
index = summands[0] if len(summands) == 1 else pb.Sum(summands)
return PsArrayAccess(ptr, index)
......@@ -19,14 +19,11 @@ from .iteration_space import (
def create_kernel(assignments: AssignmentCollection, options: KernelCreationOptions):
# 1. Prepare context
ctx = KernelCreationContext(options)
# 2. Check kernel constraints and collect knowledge
analysis = KernelAnalysis(ctx)
analysis(assignments)
# 3. Create iteration space
ispace: IterationSpace = (
create_sparse_iteration_space(ctx, assignments)
if len(ctx.fields.index_fields) > 0
......@@ -35,13 +32,9 @@ def create_kernel(assignments: AssignmentCollection, options: KernelCreationOpti
ctx.set_iteration_space(ispace)
# 4. Freeze assignments
# This call is the same for both domain and indexed kernels
freeze = FreezeExpressions(ctx)
kernel_body: PsBlock = freeze(assignments)
# 5. Typify
# Also the same for both types of kernels
typify = Typifier(ctx)
kernel_body = typify(kernel_body)
......
......@@ -105,9 +105,9 @@ class Typifier(Mapper):
def map_array_access(
self, access: PsArrayAccess, target_type: PsNumericType | None
) -> tuple[PsArrayAccess, PsNumericType]:
self._check_target_type(access, access.array.element_type, target_type)
self._check_target_type(access, access.dtype, target_type)
index, _ = self.rec(access.index_tuple[0], self._ctx.options.index_dtype)
return PsArrayAccess(access.base_ptr, index), cast(PsNumericType, access.array.element_type)
return PsArrayAccess(access.base_ptr, index), cast(PsNumericType, access.dtype)
# Arithmetic Expressions
......@@ -116,7 +116,7 @@ class Typifier(Mapper):
expr: pb.Expression,
args: Sequence[Any],
target_type: PsNumericType | None,
) -> tuple[Sequence[ExprOrConstant], PsNumericType]:
) -> tuple[tuple[ExprOrConstant], PsNumericType]:
"""Typify all arguments of a multi-argument expression with the same type."""
new_args = [None] * len(args)
common_type: PsNumericType | None = None
......@@ -134,7 +134,7 @@ class Typifier(Mapper):
assert common_type is not None
return cast(Sequence[ExprOrConstant], new_args), common_type
return cast(tuple[ExprOrConstant], tuple(new_args)), common_type
def map_sum(
self, expr: pb.Sum, target_type: PsNumericType | None
......
......@@ -80,6 +80,8 @@ class PsTypedConstant:
Usage of `//` and the pymbolic `FloorDiv` is illegal.
"""
__match_args__ = ("value", "dtype")
@staticmethod
def try_create(value: Any, dtype: PsNumericType):
try:
......@@ -100,6 +102,10 @@ class PsTypedConstant:
self._dtype = constify(dtype)
self._value = self._dtype.create_constant(value)
@property
def value(self) -> Any:
return self._value
@property
def dtype(self) -> PsNumericType:
return self._dtype
......
......@@ -12,6 +12,8 @@ from .basic_types import (
deconstify,
)
from .quick import make_type
from .exception import PsTypeError
__all__ = [
......@@ -26,5 +28,6 @@ __all__ = [
"PsIeeeFloatType",
"constify",
"deconstify",
"make_type",
"PsTypeError",
]
......@@ -381,10 +381,7 @@ class PsIeeeFloatType(PsScalarType):
def create_constant(self, value: Any) -> Any:
np_type = self.NUMPY_TYPES[self._width]
if isinstance(value, int) and value in (0, 1, -1):
return np_type(value)
if isinstance(value, float):
if isinstance(value, int) or isinstance(value, float):
return np_type(value)
if isinstance(value, np_type):
......
......@@ -68,7 +68,7 @@ def parse_type_string(s: str) -> PsAbstractType:
raise ValueError(f"Could not parse token '{s}' as C type.")
case _:
raise ValueError(f"Could not parse token '{s}`' as C type.")
raise ValueError(f"Could not parse token '{s}' as C type.")
def parse_type_name(typename: str, const: bool):
......
import sympy as sp
import pymbolic.primitives as pb
from pystencils import Assignment, fields
from pystencils.nbackend.ast import (
PsAssignment,
PsDeclaration,
PsExpression,
PsSymbolExpr,
PsLvalueExpr,
)
from pystencils.nbackend.typed_expressions import PsTypedConstant, PsTypedVariable
from pystencils.nbackend.arrays import PsArrayAccess
from pystencils.nbackend.kernelcreation import (
KernelCreationOptions,
KernelCreationContext,
FreezeExpressions,
FullIterationSpace,
)
def test_freeze_simple():
options = KernelCreationOptions()
ctx = KernelCreationContext(options)
freeze = FreezeExpressions(ctx)
x, y, z = sp.symbols("x, y, z")
asm = Assignment(z, 2 * x + y)
fasm = freeze(asm)
pb_x, pb_y, pb_z = pb.variables("x y z")
assert fasm == PsDeclaration(PsSymbolExpr(pb_z), PsExpression(pb_y + 2 * pb_x))
assert fasm != PsAssignment(PsSymbolExpr(pb_z), PsExpression(pb_y + 2 * pb_x))
def test_freeze_fields():
options = KernelCreationOptions()
ctx = KernelCreationContext(options)
start = PsTypedConstant(0, ctx.index_dtype)
stop = PsTypedConstant(42, ctx.index_dtype)
step = PsTypedConstant(1, ctx.index_dtype)
counter = PsTypedVariable("ctr", ctx.index_dtype)
ispace = FullIterationSpace(
ctx, [FullIterationSpace.Dimension(start, stop, step, counter)]
)
ctx.set_iteration_space(ispace)
freeze = FreezeExpressions(ctx)
f, g = fields("f, g : [1D]")
asm = Assignment(f.center(0), g.center(0))
f_arr = ctx.get_array(f)
g_arr = ctx.get_array(g)
fasm = freeze(asm)
lhs = PsArrayAccess(f_arr.base_pointer, counter * f_arr.strides[0])
rhs = PsArrayAccess(g_arr.base_pointer, counter * g_arr.strides[0])
should = PsAssignment(PsLvalueExpr(lhs), PsExpression(rhs))
assert fasm == should
import pytest
import sympy as sp
import pymbolic.primitives as pb
from pystencils import Assignment
from pystencils.nbackend.ast import PsDeclaration
from pystencils.nbackend.types import constify
from pystencils.nbackend.typed_expressions import PsTypedConstant, PsTypedVariable
from pystencils.nbackend.kernelcreation.options import KernelCreationOptions
from pystencils.nbackend.kernelcreation.context import KernelCreationContext
from pystencils.nbackend.kernelcreation.freeze import FreezeExpressions
from pystencils.nbackend.kernelcreation.typification import Typifier
def test_typify_simple():
options = KernelCreationOptions()
ctx = KernelCreationContext(options)
freeze = FreezeExpressions(ctx)
typify = Typifier(ctx)
x, y, z = sp.symbols("x, y, z")
asm = Assignment(z, 2 * x + y)
fasm = freeze(asm)
fasm = typify(fasm)
assert isinstance(fasm, PsDeclaration)
def check(expr):
match expr:
case PsTypedConstant(value, dtype):
assert value == 2
assert dtype == constify(ctx.options.default_dtype)
case PsTypedVariable(name, dtype):
assert name in "xyz"
assert dtype == ctx.options.default_dtype
case pb.Variable:
pytest.fail("Encountered untyped variable")
case pb.Sum(cs) | pb.Product(cs):
[check(c) for c in cs]
case _:
pytest.fail("Non-exhaustive pattern matcher.")
check(fasm.lhs.expression)
check(fasm.rhs.expression)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment