Skip to content
Snippets Groups Projects
Commit d85f682c authored by Frederik Hennig's avatar Frederik Hennig
Browse files

some extensions to the type system

parent b6f6afd8
No related branches found
No related tags found
1 merge request!400Extensions and fixes to the type system
Pipeline #67439 passed
......@@ -6,7 +6,7 @@ from . import fd
from . import stencil as stencil
from .display_utils import get_code_obj, get_code_str, show_code, to_dot
from .field import Field, FieldType, fields
from .types import create_type
from .types import create_type, create_numeric_type
from .cache import clear_cache
from .config import (
CreateKernelConfig,
......@@ -41,6 +41,7 @@ __all__ = [
"DEFAULTS",
"TypedSymbol",
"create_type",
"create_numeric_type",
"make_slice",
"CreateKernelConfig",
"CpuOptimConfig",
......
......@@ -7,13 +7,12 @@ import sympy.core.relational
import sympy.logic.boolalg
from sympy.codegen.ast import AssignmentBase, AugmentedAssignment
from ...sympyextensions.astnodes import Assignment, AssignmentCollection
from ...sympyextensions import (
Assignment,
AssignmentCollection,
integer_functions,
ConditionalFieldAccess,
)
from ...sympyextensions.typed_sympy import TypedSymbol, CastFunc
from ...sympyextensions.typed_sympy import TypedSymbol, CastFunc, DynamicType
from ...sympyextensions.pointers import AddressOf
from ...field import Field, FieldType
......@@ -58,7 +57,7 @@ from ..ast.expressions import (
)
from ..constants import PsConstant
from ...types import PsStructType
from ...types import PsStructType, PsType
from ..exceptions import PsInputError
from ..functions import PsMathFunction, MathFunctions
......@@ -465,7 +464,16 @@ class FreezeExpressions:
return cast(PsCall, args[0])
def map_CastFunc(self, cast_expr: CastFunc) -> PsCast:
return PsCast(cast_expr.dtype, self.visit_expr(cast_expr.expr))
dtype: PsType
match cast_expr.dtype:
case DynamicType.NUMERIC_TYPE:
dtype = self._ctx.default_dtype
case DynamicType.INDEX_TYPE:
dtype = self._ctx.index_dtype
case other if isinstance(other, PsType):
dtype = other
return PsCast(dtype, self.visit_expr(cast_expr.expr))
def map_Relational(self, rel: sympy.core.relational.Relational) -> PsRel:
arg1, arg2 = [self.visit_expr(arg) for arg in rel.args]
......
from __future__ import annotations
import sympy as sp
from enum import Enum, auto
from ..types import PsType, PsNumericType, PsPointerType, PsBoolType, create_type
from ..types import PsType, PsNumericType, PsPointerType, PsBoolType, PsIntegerType, create_type
def assumptions_from_dtype(dtype: PsType):
......@@ -33,20 +36,28 @@ def is_loop_counter_symbol(symbol):
return None
class DynamicType(Enum):
NUMERIC_TYPE = auto()
INDEX_TYPE = auto()
class PsTypeAtom(sp.Atom):
"""Wrapper around a PsType to disguise it as a SymPy atom."""
def __new__(cls, *args, **kwargs):
return sp.Basic.__new__(cls)
def __init__(self, dtype: PsType) -> None:
def __init__(self, dtype: PsType | DynamicType) -> None:
self._dtype = dtype
def _sympystr(self, *args, **kwargs):
return str(self._dtype)
def get(self) -> PsType:
def get(self) -> PsType | DynamicType:
return self._dtype
def _hashable_content(self):
return (self._dtype, )
class TypedSymbol(sp.Symbol):
......@@ -105,12 +116,15 @@ class FieldStrideSymbol(TypedSymbol):
obj = FieldStrideSymbol.__xnew_cached_(cls, *args, **kwds)
return obj
def __new_stage2__(cls, field_name: str, coordinate: int):
def __new_stage2__(cls, field_name: str, coordinate: int, dtype: PsIntegerType | None = None):
from ..defaults import DEFAULTS
if dtype is None:
dtype = DEFAULTS.index_dtype
name = f"_stride_{field_name}_{coordinate}"
obj = super(FieldStrideSymbol, cls).__xnew__(
cls, name, DEFAULTS.index_dtype, positive=True
cls, name, dtype, positive=True
)
obj.field_name = field_name
obj.coordinate = coordinate
......@@ -138,12 +152,15 @@ class FieldShapeSymbol(TypedSymbol):
obj = FieldShapeSymbol.__xnew_cached_(cls, *args, **kwds)
return obj
def __new_stage2__(cls, field_name: str, coordinate: int):
def __new_stage2__(cls, field_name: str, coordinate: int, dtype: PsIntegerType | None = None):
from ..defaults import DEFAULTS
if dtype is None:
dtype = DEFAULTS.index_dtype
name = f"_size_{field_name}_{coordinate}"
obj = super(FieldShapeSymbol, cls).__xnew__(
cls, name, DEFAULTS.index_dtype, positive=True
cls, name, dtype, positive=True
)
obj.field_name = field_name
obj.coordinate = coordinate
......@@ -190,10 +207,21 @@ class FieldPointerSymbol(TypedSymbol):
class CastFunc(sp.Function):
"""Use this function to introduce a static type cast into the output code.
Usage: ``CastFunc(expr, target_type)`` becomes, in C code, ``(target_type) expr``.
The `target_type` may be a valid pystencils type specification parsable by `create_type`,
or a special value of the `DynamicType` enum.
These dynamic types can be used to select the target type according to the code generation context.
"""
CastFunc is used in order to introduce static casts. They are especially useful as a way to signal what type
a certain node should have, if it is impossible to add a type to a node, e.g. a sp.Number.
"""
@staticmethod
def as_numeric(expr):
return CastFunc(expr, DynamicType.NUMERIC_TYPE)
@staticmethod
def as_index(expr):
return CastFunc(expr, DynamicType.INDEX_TYPE)
is_Atom = True
......@@ -207,8 +235,12 @@ class CastFunc(sp.Function):
if expr.__class__ == CastFunc:
expr = expr.args[0]
if not isinstance(dtype, PsTypeAtom):
dtype = PsTypeAtom(create_type(dtype))
if not isinstance(dtype, (PsTypeAtom)):
if isinstance(dtype, DynamicType):
dtype = PsTypeAtom(dtype)
else:
dtype = PsTypeAtom(create_type(dtype))
# to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well
# however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads
# to problems when for example comparing cast_func's for equality
......@@ -236,7 +268,7 @@ class CastFunc(sp.Function):
return self.args[0].is_commutative
@property
def dtype(self) -> PsType:
def dtype(self) -> PsType | DynamicType:
assert isinstance(self.args[1], PsTypeAtom)
return self.args[1].get()
......@@ -246,7 +278,7 @@ class CastFunc(sp.Function):
@property
def is_integer(self):
if isinstance(self.dtype, PsNumericType):
if isinstance(self.dtype, PsNumericType) or self.dtype == DynamicType.INDEX_TYPE:
return self.dtype.is_int() or super().is_integer
else:
return super().is_integer
......
......@@ -158,6 +158,8 @@ def parse_type_name(typename: str, const: bool):
case "uint8" | "uint8_t":
return PsUnsignedIntegerType(8, const=const)
case "half" | "float16":
return PsIeeeFloatType(16, const=const)
case "float" | "float32":
return PsIeeeFloatType(32, const=const)
case "double" | "float64":
......
......@@ -200,7 +200,7 @@ class PsStructType(PsType):
@property
def numpy_dtype(self) -> np.dtype:
members = [(m.name, m.dtype.numpy_dtype) for m in self._members]
return np.dtype(members)
return np.dtype(members, align=True)
@property
def itemsize(self) -> int:
......
import sympy as sp
import pytest
from pystencils import Assignment, fields
from pystencils import Assignment, fields, create_type, create_numeric_type
from pystencils.sympyextensions import CastFunc
from pystencils.backend.ast.structural import (
PsAssignment,
......@@ -26,7 +27,8 @@ from pystencils.backend.ast.expressions import (
PsLe,
PsGt,
PsGe,
PsCall
PsCall,
PsCast,
)
from pystencils.backend.constants import PsConstant
from pystencils.backend.functions import PsMathFunction, MathFunctions
......@@ -182,14 +184,17 @@ def test_freeze_booleans():
assert expr.structurally_equal(PsOr(PsOr(PsOr(w2, x2), y2), z2))
@pytest.mark.parametrize("rel_pair", [
(sp.Eq, PsEq),
(sp.Ne, PsNe),
(sp.Lt, PsLt),
(sp.Gt, PsGt),
(sp.Le, PsLe),
(sp.Ge, PsGe)
])
@pytest.mark.parametrize(
"rel_pair",
[
(sp.Eq, PsEq),
(sp.Ne, PsNe),
(sp.Lt, PsLt),
(sp.Gt, PsGt),
(sp.Le, PsLe),
(sp.Ge, PsGe),
],
)
def test_freeze_relations(rel_pair):
ctx = KernelCreationContext()
freeze = FreezeExpressions(ctx)
......@@ -211,7 +216,7 @@ def test_freeze_piecewise():
freeze = FreezeExpressions(ctx)
p, q, x, y, z = sp.symbols("p, q, x, y, z")
p2 = PsExpression.make(ctx.get_symbol("p"))
q2 = PsExpression.make(ctx.get_symbol("q"))
x2 = PsExpression.make(ctx.get_symbol("x"))
......@@ -222,10 +227,10 @@ def test_freeze_piecewise():
expr = freeze(piecewise)
assert isinstance(expr, PsTernary)
should = PsTernary(p2, x2, PsTernary(q2, y2, z2))
assert expr.structurally_equal(should)
piecewise = sp.Piecewise((x, p), (y, q), (z, sp.Or(p, q)))
with pytest.raises(FreezeError):
freeze(piecewise)
......@@ -259,3 +264,25 @@ def test_multiarg_min_max():
expr = freeze(sp.Max(w, x, y, z))
assert expr.structurally_equal(op(op(w2, x2), op(y2, z2)))
def test_cast_func():
ctx = KernelCreationContext(
default_dtype=create_numeric_type("float16"), index_dtype=create_type("int16")
)
freeze = FreezeExpressions(ctx)
x, y, z = sp.symbols("x, y, z")
x2 = PsExpression.make(ctx.get_symbol("x"))
y2 = PsExpression.make(ctx.get_symbol("y"))
z2 = PsExpression.make(ctx.get_symbol("z"))
expr = freeze(CastFunc(x, create_type("int")))
assert expr.structurally_equal(PsCast(create_type("int"), x2))
expr = freeze(CastFunc.as_numeric(y))
assert expr.structurally_equal(PsCast(ctx.default_dtype, y2))
expr = freeze(CastFunc.as_index(z))
assert expr.structurally_equal(PsCast(ctx.index_dtype, z2))
......@@ -139,6 +139,17 @@ def test_struct_types():
with pytest.raises(PsTypeError):
t.c_string()
t = PsStructType([
("a", SInt(8)),
("b", SInt(16)),
("c", SInt(64))
])
# Check that natural alignment is taken into account
numpy_type = np.dtype([("a", "i1"), ("b", "i2"), ("c", "i8")], align=True)
assert t.numpy_dtype == numpy_type
assert t.itemsize == numpy_type.itemsize == 16
def test_pickle():
types = [
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment