From d85f682c2b37be3e7a3ca101f50c894c5247eab5 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Mon, 8 Jul 2024 21:24:50 +0200 Subject: [PATCH] some extensions to the type system --- src/pystencils/__init__.py | 3 +- .../backend/kernelcreation/freeze.py | 18 ++++-- src/pystencils/sympyextensions/typed_sympy.py | 60 ++++++++++++++----- src/pystencils/types/parsing.py | 2 + src/pystencils/types/types.py | 2 +- tests/nbackend/kernelcreation/test_freeze.py | 53 ++++++++++++---- tests/nbackend/types/test_types.py | 11 ++++ 7 files changed, 115 insertions(+), 34 deletions(-) diff --git a/src/pystencils/__init__.py b/src/pystencils/__init__.py index 3d3b7846a..c39cd3b82 100644 --- a/src/pystencils/__init__.py +++ b/src/pystencils/__init__.py @@ -6,7 +6,7 @@ from . import fd from . import stencil as stencil from .display_utils import get_code_obj, get_code_str, show_code, to_dot from .field import Field, FieldType, fields -from .types import create_type +from .types import create_type, create_numeric_type from .cache import clear_cache from .config import ( CreateKernelConfig, @@ -41,6 +41,7 @@ __all__ = [ "DEFAULTS", "TypedSymbol", "create_type", + "create_numeric_type", "make_slice", "CreateKernelConfig", "CpuOptimConfig", diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 3865db38f..59fa04b3b 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -7,13 +7,12 @@ import sympy.core.relational import sympy.logic.boolalg from sympy.codegen.ast import AssignmentBase, AugmentedAssignment +from ...sympyextensions.astnodes import Assignment, AssignmentCollection from ...sympyextensions import ( - Assignment, - AssignmentCollection, integer_functions, ConditionalFieldAccess, ) -from ...sympyextensions.typed_sympy import TypedSymbol, CastFunc +from ...sympyextensions.typed_sympy import TypedSymbol, CastFunc, DynamicType from ...sympyextensions.pointers import AddressOf from ...field import Field, FieldType @@ -58,7 +57,7 @@ from ..ast.expressions import ( ) from ..constants import PsConstant -from ...types import PsStructType +from ...types import PsStructType, PsType from ..exceptions import PsInputError from ..functions import PsMathFunction, MathFunctions @@ -465,7 +464,16 @@ class FreezeExpressions: return cast(PsCall, args[0]) def map_CastFunc(self, cast_expr: CastFunc) -> PsCast: - return PsCast(cast_expr.dtype, self.visit_expr(cast_expr.expr)) + dtype: PsType + match cast_expr.dtype: + case DynamicType.NUMERIC_TYPE: + dtype = self._ctx.default_dtype + case DynamicType.INDEX_TYPE: + dtype = self._ctx.index_dtype + case other if isinstance(other, PsType): + dtype = other + + return PsCast(dtype, self.visit_expr(cast_expr.expr)) def map_Relational(self, rel: sympy.core.relational.Relational) -> PsRel: arg1, arg2 = [self.visit_expr(arg) for arg in rel.args] diff --git a/src/pystencils/sympyextensions/typed_sympy.py b/src/pystencils/sympyextensions/typed_sympy.py index e022db511..cd5c80c88 100644 --- a/src/pystencils/sympyextensions/typed_sympy.py +++ b/src/pystencils/sympyextensions/typed_sympy.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import sympy as sp +from enum import Enum, auto -from ..types import PsType, PsNumericType, PsPointerType, PsBoolType, create_type +from ..types import PsType, PsNumericType, PsPointerType, PsBoolType, PsIntegerType, create_type def assumptions_from_dtype(dtype: PsType): @@ -33,20 +36,28 @@ def is_loop_counter_symbol(symbol): return None +class DynamicType(Enum): + NUMERIC_TYPE = auto() + INDEX_TYPE = auto() + + class PsTypeAtom(sp.Atom): """Wrapper around a PsType to disguise it as a SymPy atom.""" def __new__(cls, *args, **kwargs): return sp.Basic.__new__(cls) - def __init__(self, dtype: PsType) -> None: + def __init__(self, dtype: PsType | DynamicType) -> None: self._dtype = dtype def _sympystr(self, *args, **kwargs): return str(self._dtype) - def get(self) -> PsType: + def get(self) -> PsType | DynamicType: return self._dtype + + def _hashable_content(self): + return (self._dtype, ) class TypedSymbol(sp.Symbol): @@ -105,12 +116,15 @@ class FieldStrideSymbol(TypedSymbol): obj = FieldStrideSymbol.__xnew_cached_(cls, *args, **kwds) return obj - def __new_stage2__(cls, field_name: str, coordinate: int): + def __new_stage2__(cls, field_name: str, coordinate: int, dtype: PsIntegerType | None = None): from ..defaults import DEFAULTS + + if dtype is None: + dtype = DEFAULTS.index_dtype name = f"_stride_{field_name}_{coordinate}" obj = super(FieldStrideSymbol, cls).__xnew__( - cls, name, DEFAULTS.index_dtype, positive=True + cls, name, dtype, positive=True ) obj.field_name = field_name obj.coordinate = coordinate @@ -138,12 +152,15 @@ class FieldShapeSymbol(TypedSymbol): obj = FieldShapeSymbol.__xnew_cached_(cls, *args, **kwds) return obj - def __new_stage2__(cls, field_name: str, coordinate: int): + def __new_stage2__(cls, field_name: str, coordinate: int, dtype: PsIntegerType | None = None): from ..defaults import DEFAULTS + + if dtype is None: + dtype = DEFAULTS.index_dtype name = f"_size_{field_name}_{coordinate}" obj = super(FieldShapeSymbol, cls).__xnew__( - cls, name, DEFAULTS.index_dtype, positive=True + cls, name, dtype, positive=True ) obj.field_name = field_name obj.coordinate = coordinate @@ -190,10 +207,21 @@ class FieldPointerSymbol(TypedSymbol): class CastFunc(sp.Function): + """Use this function to introduce a static type cast into the output code. + + Usage: ``CastFunc(expr, target_type)`` becomes, in C code, ``(target_type) expr``. + The `target_type` may be a valid pystencils type specification parsable by `create_type`, + or a special value of the `DynamicType` enum. + These dynamic types can be used to select the target type according to the code generation context. """ - CastFunc is used in order to introduce static casts. They are especially useful as a way to signal what type - a certain node should have, if it is impossible to add a type to a node, e.g. a sp.Number. - """ + + @staticmethod + def as_numeric(expr): + return CastFunc(expr, DynamicType.NUMERIC_TYPE) + + @staticmethod + def as_index(expr): + return CastFunc(expr, DynamicType.INDEX_TYPE) is_Atom = True @@ -207,8 +235,12 @@ class CastFunc(sp.Function): if expr.__class__ == CastFunc: expr = expr.args[0] - if not isinstance(dtype, PsTypeAtom): - dtype = PsTypeAtom(create_type(dtype)) + if not isinstance(dtype, (PsTypeAtom)): + if isinstance(dtype, DynamicType): + dtype = PsTypeAtom(dtype) + else: + dtype = PsTypeAtom(create_type(dtype)) + # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads # to problems when for example comparing cast_func's for equality @@ -236,7 +268,7 @@ class CastFunc(sp.Function): return self.args[0].is_commutative @property - def dtype(self) -> PsType: + def dtype(self) -> PsType | DynamicType: assert isinstance(self.args[1], PsTypeAtom) return self.args[1].get() @@ -246,7 +278,7 @@ class CastFunc(sp.Function): @property def is_integer(self): - if isinstance(self.dtype, PsNumericType): + if isinstance(self.dtype, PsNumericType) or self.dtype == DynamicType.INDEX_TYPE: return self.dtype.is_int() or super().is_integer else: return super().is_integer diff --git a/src/pystencils/types/parsing.py b/src/pystencils/types/parsing.py index 75fb35d22..d6522e5bb 100644 --- a/src/pystencils/types/parsing.py +++ b/src/pystencils/types/parsing.py @@ -158,6 +158,8 @@ def parse_type_name(typename: str, const: bool): case "uint8" | "uint8_t": return PsUnsignedIntegerType(8, const=const) + case "half" | "float16": + return PsIeeeFloatType(16, const=const) case "float" | "float32": return PsIeeeFloatType(32, const=const) case "double" | "float64": diff --git a/src/pystencils/types/types.py b/src/pystencils/types/types.py index 2f0f2ff46..658225762 100644 --- a/src/pystencils/types/types.py +++ b/src/pystencils/types/types.py @@ -200,7 +200,7 @@ class PsStructType(PsType): @property def numpy_dtype(self) -> np.dtype: members = [(m.name, m.dtype.numpy_dtype) for m in self._members] - return np.dtype(members) + return np.dtype(members, align=True) @property def itemsize(self) -> int: diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py index b22df7d0b..f16a468e7 100644 --- a/tests/nbackend/kernelcreation/test_freeze.py +++ b/tests/nbackend/kernelcreation/test_freeze.py @@ -1,7 +1,8 @@ import sympy as sp import pytest -from pystencils import Assignment, fields +from pystencils import Assignment, fields, create_type, create_numeric_type +from pystencils.sympyextensions import CastFunc from pystencils.backend.ast.structural import ( PsAssignment, @@ -26,7 +27,8 @@ from pystencils.backend.ast.expressions import ( PsLe, PsGt, PsGe, - PsCall + PsCall, + PsCast, ) from pystencils.backend.constants import PsConstant from pystencils.backend.functions import PsMathFunction, MathFunctions @@ -182,14 +184,17 @@ def test_freeze_booleans(): assert expr.structurally_equal(PsOr(PsOr(PsOr(w2, x2), y2), z2)) -@pytest.mark.parametrize("rel_pair", [ - (sp.Eq, PsEq), - (sp.Ne, PsNe), - (sp.Lt, PsLt), - (sp.Gt, PsGt), - (sp.Le, PsLe), - (sp.Ge, PsGe) -]) +@pytest.mark.parametrize( + "rel_pair", + [ + (sp.Eq, PsEq), + (sp.Ne, PsNe), + (sp.Lt, PsLt), + (sp.Gt, PsGt), + (sp.Le, PsLe), + (sp.Ge, PsGe), + ], +) def test_freeze_relations(rel_pair): ctx = KernelCreationContext() freeze = FreezeExpressions(ctx) @@ -211,7 +216,7 @@ def test_freeze_piecewise(): freeze = FreezeExpressions(ctx) p, q, x, y, z = sp.symbols("p, q, x, y, z") - + p2 = PsExpression.make(ctx.get_symbol("p")) q2 = PsExpression.make(ctx.get_symbol("q")) x2 = PsExpression.make(ctx.get_symbol("x")) @@ -222,10 +227,10 @@ def test_freeze_piecewise(): expr = freeze(piecewise) assert isinstance(expr, PsTernary) - + should = PsTernary(p2, x2, PsTernary(q2, y2, z2)) assert expr.structurally_equal(should) - + piecewise = sp.Piecewise((x, p), (y, q), (z, sp.Or(p, q))) with pytest.raises(FreezeError): freeze(piecewise) @@ -259,3 +264,25 @@ def test_multiarg_min_max(): expr = freeze(sp.Max(w, x, y, z)) assert expr.structurally_equal(op(op(w2, x2), op(y2, z2))) + + +def test_cast_func(): + ctx = KernelCreationContext( + default_dtype=create_numeric_type("float16"), index_dtype=create_type("int16") + ) + freeze = FreezeExpressions(ctx) + + x, y, z = sp.symbols("x, y, z") + + x2 = PsExpression.make(ctx.get_symbol("x")) + y2 = PsExpression.make(ctx.get_symbol("y")) + z2 = PsExpression.make(ctx.get_symbol("z")) + + expr = freeze(CastFunc(x, create_type("int"))) + assert expr.structurally_equal(PsCast(create_type("int"), x2)) + + expr = freeze(CastFunc.as_numeric(y)) + assert expr.structurally_equal(PsCast(ctx.default_dtype, y2)) + + expr = freeze(CastFunc.as_index(z)) + assert expr.structurally_equal(PsCast(ctx.index_dtype, z2)) diff --git a/tests/nbackend/types/test_types.py b/tests/nbackend/types/test_types.py index 39f89e6fe..1cc2ae0e4 100644 --- a/tests/nbackend/types/test_types.py +++ b/tests/nbackend/types/test_types.py @@ -139,6 +139,17 @@ def test_struct_types(): with pytest.raises(PsTypeError): t.c_string() + t = PsStructType([ + ("a", SInt(8)), + ("b", SInt(16)), + ("c", SInt(64)) + ]) + + # Check that natural alignment is taken into account + numpy_type = np.dtype([("a", "i1"), ("b", "i2"), ("c", "i8")], align=True) + assert t.numpy_dtype == numpy_type + assert t.itemsize == numpy_type.itemsize == 16 + def test_pickle(): types = [ -- GitLab