diff --git a/src/pystencils/__init__.py b/src/pystencils/__init__.py index 07283d5294bc08c8e68e6a40af0f956b36a0129a..15aa3fd01d7b7a9e2831d2aa1ea800d6f713a665 100644 --- a/src/pystencils/__init__.py +++ b/src/pystencils/__init__.py @@ -1,10 +1,6 @@ """Module to generate stencil kernels in C or CUDA using sympy expressions and call them as Python functions""" -from .codegen import ( - Target, - CreateKernelConfig, - AUTO -) +from .codegen import Target, CreateKernelConfig, AUTO from .defaults import DEFAULTS from . import fd from . import stencil as stencil @@ -31,7 +27,13 @@ from .spatial_coordinates import ( ) from .assignment import Assignment, AddAugmentedAssignment, assignment_from_stencil from .simp import AssignmentCollection -from .sympyextensions.typed_sympy import TypedSymbol, DynamicType +from .sympyextensions.typed_sympy import ( + TypedSymbol, + DynamicType, + DynamicIndexType, + DynamicNumericType, + DynamicNumericArrayType, +) from .sympyextensions import SymbolCreator, tcast from .datahandling import create_data_handling @@ -42,6 +44,9 @@ __all__ = [ "DEFAULTS", "TypedSymbol", "DynamicType", + "DynamicIndexType", + "DynamicNumericType", + "DynamicNumericArrayType", "create_type", "create_numeric_type", "make_slice", @@ -81,4 +86,5 @@ __all__ = [ ] from . import _version -__version__ = _version.get_versions()['version'] + +__version__ = _version.get_versions()["version"] diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index 68da893ff6204c73d61dcebddb1da602f37520c7..b9192bba858f3e91a850d3296c8f33e167463e5b 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -7,12 +7,19 @@ import re from ...defaults import DEFAULTS from ...field import Field, FieldType -from ...sympyextensions.typed_sympy import TypedSymbol, DynamicType +from ...sympyextensions.typed_sympy import ( + TypedSymbol, + DynamicType, + DynamicIndexType, + DynamicNumericType, + DynamicNumericArrayType, +) from ..memory import PsSymbol, PsBuffer from ..constants import PsConstant from ...types import ( PsType, + PsArrayType, PsIntegerType, PsNumericType, PsPointerType, @@ -93,14 +100,20 @@ class KernelCreationContext: def index_dtype(self) -> PsIntegerType: """Data type used by default for index expressions""" return self._index_dtype - + def resolve_dynamic_type(self, dtype: DynamicType | PsType) -> PsType: """Selects the appropriate data type for `DynamicType` instances, and returns all other types as they are.""" match dtype: - case DynamicType.NUMERIC_TYPE: + case DynamicNumericType(): return self._default_dtype - case DynamicType.INDEX_TYPE: + case DynamicIndexType(): return self._index_dtype + case DynamicNumericArrayType(shape): + return PsArrayType(self._default_dtype, shape) + case DynamicType(): + raise PsInternalCompilerError( + f"Unknown `DynamicType` can not be resolved: {dtype}" + ) case _: return dtype @@ -325,9 +338,9 @@ class KernelCreationContext: def _normalize_type(self, s: TypedSymbol) -> PsIntegerType: match s.dtype: - case DynamicType.INDEX_TYPE: + case DynamicIndexType(): return self.index_dtype - case DynamicType.NUMERIC_TYPE: + case DynamicNumericType(): if isinstance(self.default_dtype, PsIntegerType): return self.default_dtype else: diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index b3ff5aefb525ef311d0e3199c79f60c52617a853..fbd891d2b84cb632e1c72ba16b03d92d1f362700 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -13,7 +13,14 @@ from ...sympyextensions import ( integer_functions, ConditionalFieldAccess, ) -from ...sympyextensions.typed_sympy import TypedSymbol, TypeCast, DynamicType +from ...sympyextensions.typed_sympy import ( + TypedSymbol, + TypeCast, + DynamicIndexType, + DynamicNumericType, + DynamicNumericArrayType, +) + from ...sympyextensions.pointers import AddressOf, mem_acc from ...field import Field, FieldType @@ -55,12 +62,12 @@ from ..ast.expressions import ( PsAnd, PsOr, PsNot, - PsMemAcc + PsMemAcc, ) from ..ast.vector import PsVecMemAcc from ..constants import PsConstant -from ...types import PsNumericType, PsStructType, PsType +from ...types import PsArrayType, PsNumericType, PsStructType, PsType from ..exceptions import PsInputError from ..functions import PsMathFunction, MathFunctions from ..exceptions import FreezeError @@ -214,7 +221,7 @@ class FreezeExpressions: exponent = expr.args[1] expr_frozen = self.visit_expr(base) - + if isinstance(exponent, sp.Rational): # Decompose rational exponent num: int = exponent.numerator @@ -234,7 +241,7 @@ class FreezeExpressions: denom = 1 assert denom == 1 - + # Pairwise multiplication for logarithmic runtime factors = [expr_frozen] + [expr_frozen.clone() for _ in range(num - 1)] while len(factors) > 1: @@ -250,7 +257,7 @@ class FreezeExpressions: expr_frozen = one / expr_frozen return expr_frozen - + # If we got this far, use pow exponent_frozen = self.visit_expr(exponent) expr_frozen = PsMathFunction(MathFunctions.Pow)(expr_frozen, exponent_frozen) @@ -280,22 +287,22 @@ class FreezeExpressions: raise FreezeError("Cannot translate an empty tuple.") items = [self.visit_expr(item) for item in expr] - + if any(isinstance(i, PsArrayInitList) for i in items): # base case: have nested arrays if not all(isinstance(i, PsArrayInitList) for i in items): raise FreezeError( f"Cannot translate nested arrays of non-uniform shape: {expr}" ) - + subarrays = cast(list[PsArrayInitList], items) shape_tail = subarrays[0].shape - + if not all(s.shape == shape_tail for s in subarrays[1:]): raise FreezeError( f"Cannot translate nested arrays of non-uniform shape: {expr}" ) - + return PsArrayInitList([s.items_grid for s in subarrays]) # type: ignore else: # base case: no nested arrays @@ -487,10 +494,12 @@ class FreezeExpressions: def map_TypeCast(self, cast_expr: TypeCast) -> PsCast | PsConstantExpr: dtype: PsType match cast_expr.dtype: - case DynamicType.NUMERIC_TYPE: + case DynamicNumericType(): dtype = self._ctx.default_dtype - case DynamicType.INDEX_TYPE: + case DynamicIndexType(): dtype = self._ctx.index_dtype + case DynamicNumericArrayType(shape): + dtype = PsArrayType(self._ctx.default_dtype, shape) case other if isinstance(other, PsType): dtype = other diff --git a/src/pystencils/defaults.py b/src/pystencils/defaults.py index 0b6a48af1944f81a29d6dcd4b0d29e90aeb4ce2d..d637182d777795aea701231b828927e2adf6f46c 100644 --- a/src/pystencils/defaults.py +++ b/src/pystencils/defaults.py @@ -7,7 +7,7 @@ from .types import ( create_type, ) -from pystencils.sympyextensions.typed_sympy import TypedSymbol, DynamicType +from pystencils.sympyextensions.typed_sympy import TypedSymbol, DynamicIndexType class SympyDefaults: @@ -22,9 +22,9 @@ class SympyDefaults: """Names of the default spatial counters""" self.spatial_counters = ( - TypedSymbol("ctr_0", DynamicType.INDEX_TYPE), - TypedSymbol("ctr_1", DynamicType.INDEX_TYPE), - TypedSymbol("ctr_2", DynamicType.INDEX_TYPE), + TypedSymbol("ctr_0", DynamicIndexType()), + TypedSymbol("ctr_1", DynamicIndexType()), + TypedSymbol("ctr_2", DynamicIndexType()), ) """Default spatial counters""" @@ -34,9 +34,7 @@ class SympyDefaults: self.sparse_counter_name = "sparse_idx" """Name of the default sparse iteration counter""" - self.sparse_counter = TypedSymbol( - self.sparse_counter_name, DynamicType.INDEX_TYPE - ) + self.sparse_counter = TypedSymbol(self.sparse_counter_name, DynamicIndexType()) """Default sparse iteration counter.""" def field_shape_name(self, field_name: str, coord: int): diff --git a/src/pystencils/field.py b/src/pystencils/field.py index 246232efde7a6b432598f614492725e2ea063cff..ecf47319527447f7bf63f141c565dc87fddfae1c 100644 --- a/src/pystencils/field.py +++ b/src/pystencils/field.py @@ -23,7 +23,12 @@ from .stencil import ( offset_to_direction_string, ) from .types import PsType, PsStructType, create_type -from .sympyextensions.typed_sympy import TypedSymbol, DynamicType +from .sympyextensions.typed_sympy import ( + TypedSymbol, + DynamicType, + DynamicIndexType, + DynamicNumericType, +) from .sympyextensions import is_integer_sequence from .types import UserTypeSpec @@ -145,7 +150,7 @@ class Field: def create_generic( field_name, spatial_dimensions, - dtype: UserTypeSpec | DynamicType = DynamicType.NUMERIC_TYPE, + dtype: UserTypeSpec | DynamicType = DynamicNumericType(), index_dimensions=0, layout="numpy", index_shape=None, @@ -186,7 +191,8 @@ class Field: shape = tuple( [ TypedSymbol( - DEFAULTS.field_shape_name(field_name, i), DynamicType.INDEX_TYPE + DEFAULTS.field_shape_name(field_name, i), + DynamicIndexType(), ) for i in range(total_dimensions) ] @@ -195,7 +201,8 @@ class Field: shape = tuple( [ TypedSymbol( - DEFAULTS.field_shape_name(field_name, i), DynamicType.INDEX_TYPE + DEFAULTS.field_shape_name(field_name, i), + DynamicIndexType(), ) for i in range(spatial_dimensions) ] @@ -205,7 +212,7 @@ class Field: strides: tuple[TypedSymbol | int, ...] = tuple( [ TypedSymbol( - DEFAULTS.field_stride_name(field_name, i), DynamicType.INDEX_TYPE + DEFAULTS.field_stride_name(field_name, i), DynamicIndexType() ) for i in range(total_dimensions) ] @@ -277,7 +284,7 @@ class Field: field_name: str, shape: tuple[int, ...], index_dimensions: int = 0, - dtype: UserTypeSpec | DynamicType = DynamicType.NUMERIC_TYPE, + dtype: UserTypeSpec | DynamicType = DynamicNumericType(), layout: str | tuple[int, ...] = "numpy", memory_strides: None | Sequence[int] = None, strides: Optional[Sequence[int]] = None, @@ -303,16 +310,20 @@ class Field: "Use `memory_strides` instead; " "beware that `memory_strides` takes the number of *elements* to skip, " "instead of the number of bytes.", - FutureWarning + FutureWarning, ) if memory_strides is not None: - raise ValueError("Cannot specify `memory_strides` and deprecated parameter `strides` at the same time.") - + raise ValueError( + "Cannot specify `memory_strides` and deprecated parameter `strides` at the same time." + ) + if isinstance(dtype, DynamicType): - raise ValueError("Cannot specify the deprecated parameter `strides` together with a `DynamicType`. " - "Set `memory_strides` instead.") - + raise ValueError( + "Cannot specify the deprecated parameter `strides` together with a `DynamicType`. " + "Set `memory_strides` instead." + ) + np_type = create_type(dtype).numpy_dtype assert np_type is not None memory_strides = tuple([s // np_type.itemsize for s in strides]) @@ -1135,8 +1146,8 @@ def fields( (can be omitted for scalar fields) - ``<data-type>`` is the numerical data type of the field's entries; this can be any type parseable by `create_type`, - as well as ``dyn`` for `DynamicType.NUMERIC_TYPE` - and ``dynidx`` for `DynamicType.INDEX_TYPE`. + as well as ``dyn`` for `DynamicNumericType()` + and ``dynidx`` for `DynamicIndexType()`. - ``<dimension-or-shape>`` can be a dimensionality (e.g. ``1D``, ``2D``, ``3D``) or a tuple of integers defining the spatial shape of the field. @@ -1144,7 +1155,7 @@ def fields( Create a 3D scalar field of default numeric type: >>> f = fields("f(1): [2D]") >>> str(f.dtype) - 'DynamicType.NUMERIC_TYPE' + 'DynamicNumericType' Create a 2D scalar and vector field of 64-bit float type: >>> s, v = fields("s, v(2): double[2D]") @@ -1452,11 +1463,14 @@ def _parse_description(description): if data_type_str: match data_type_str: - case "dyn": dtype = DynamicType.NUMERIC_TYPE - case "dynidx": dtype = DynamicType.INDEX_TYPE - case _: dtype = create_type(data_type_str) + case "dyn": + dtype = DynamicNumericType() + case "dynidx": + dtype = DynamicIndexType() + case _: + dtype = create_type(data_type_str) else: - dtype = DynamicType.NUMERIC_TYPE + dtype = DynamicNumericType() if size_info.endswith("d"): size_info = int(size_info[:-1]) diff --git a/src/pystencils/sympyextensions/typed_sympy.py b/src/pystencils/sympyextensions/typed_sympy.py index e2435d6bbe570887e0903c67f6041ed9911c02be..dbc96e5a1847c61fa5f6415ada625bda69d4b28c 100644 --- a/src/pystencils/sympyextensions/typed_sympy.py +++ b/src/pystencils/sympyextensions/typed_sympy.py @@ -1,15 +1,11 @@ from __future__ import annotations -from typing import cast +from abc import ABC +from dataclasses import dataclass +from typing import cast, Any, Sequence, SupportsIndex import sympy as sp -from enum import Enum, auto -from ..types import ( - PsType, - PsNumericType, - create_type, - UserTypeSpec -) +from ..types import PsType, PsNumericType, create_type, UserTypeSpec from sympy.logic.boolalg import Boolean @@ -25,18 +21,54 @@ def is_loop_counter_symbol(symbol): return None -class DynamicType(Enum): - """Dynamic data type that will be resolved during kernel creation""" +class DynamicType(ABC): + """Dynamic data type that will be resolved during kernel creation.""" - NUMERIC_TYPE = auto() - """Use the default numeric type set for the kernel""" - INDEX_TYPE = auto() +class DynamicNumericType(DynamicType): + """Use the default numeric type set for the kernel.""" + + def __eq__(self, o: Any) -> bool: + return type(self) is type(o) + + def __hash__(self) -> int: + return hash(()) + + def __str__(self) -> str: + return "DynamicNumericType" + + +class DynamicIndexType(DynamicType): """Use the default index type set for the kernel. - + This is guaranteed to be an interger type. """ + def __eq__(self, o: Any) -> bool: + return type(self) is type(o) + + def __hash__(self) -> int: + return hash(()) + + def __str__(self) -> str: + return "DynamicIndexType" + + +@dataclass +class DynamicNumericArrayType(DynamicType): + """An array of the default numeric type with a statically known shape.""" + + shape: SupportsIndex | Sequence[SupportsIndex] + + def __eq__(self, o: Any) -> bool: + return type(self) is type(o) and self.shape == o.shape + + def __hash__(self) -> int: + return hash(self.shape) + + def __str__(self) -> str: + return f"DynamicNumericArrayType({self.shape})" + class TypeAtom(sp.Atom): """Wrapper around a type to disguise it as a SymPy atom.""" @@ -56,10 +88,10 @@ class TypeAtom(sp.Atom): def _hashable_content(self): return (self._dtype,) - + def __getnewargs__(self): return (self._dtype,) - + def assumptions_from_dtype(dtype: PsType | DynamicType): """Derives SymPy assumptions from :class:`PsAbstractType` @@ -72,9 +104,9 @@ def assumptions_from_dtype(dtype: PsType | DynamicType): assumptions = dict() match dtype: - case DynamicType.INDEX_TYPE: + case DynamicIndexType(): assumptions.update({"integer": True, "real": True}) - case DynamicType.NUMERIC_TYPE: + case DynamicNumericType(): assumptions.update({"real": True}) case PsNumericType(): if dtype.is_int(): @@ -146,12 +178,12 @@ class TypeCast(sp.Function): @staticmethod def as_numeric(expr): - return TypeCast(expr, DynamicType.NUMERIC_TYPE) + return TypeCast(expr, DynamicNumericType()) @staticmethod def as_index(expr): - return TypeCast(expr, DynamicType.INDEX_TYPE) - + return TypeCast(expr, DynamicIndexType()) + @property def expr(self) -> sp.Basic: return self.args[0] @@ -159,7 +191,7 @@ class TypeCast(sp.Function): @property def dtype(self) -> PsType | DynamicType: return cast(TypeAtom, self._args[1]).get() - + def __new__(cls, expr: sp.Basic, dtype: UserTypeSpec | DynamicType | TypeAtom): tatom: TypeAtom match dtype: @@ -169,29 +201,31 @@ class TypeCast(sp.Function): tatom = TypeAtom(dtype) case _: tatom = TypeAtom(create_type(dtype)) - + return super().__new__(cls, expr, tatom) - + @classmethod def eval(cls, expr: sp.Basic, tatom: TypeAtom) -> TypeCast | None: dtype = tatom.get() if cls is not BoolCast and isinstance(dtype, PsNumericType) and dtype.is_bool(): return BoolCast(expr, tatom) - + return None - + def _eval_is_integer(self): - if self.dtype == DynamicType.INDEX_TYPE: + if self.dtype == DynamicIndexType(): return True if isinstance(self.dtype, PsNumericType) and self.dtype.is_int(): return True - + def _eval_is_real(self): - if isinstance(self.dtype, DynamicType): + if self.dtype == DynamicIndexType() or self.dtype == DynamicNumericType(): return True - if isinstance(self.dtype, PsNumericType) and (self.dtype.is_float() or self.dtype.is_int()): + if isinstance(self.dtype, PsNumericType) and ( + self.dtype.is_float() or self.dtype.is_int() + ): return True - + def _eval_is_nonnegative(self): if isinstance(self.dtype, PsNumericType) and self.dtype.is_uint(): return True @@ -209,6 +243,6 @@ class CastFunc(TypeCast): warn( "CastFunc is deprecated and will be removed in pystencils 2.1. " "Use `pystencils.tcast` instead.", - FutureWarning + FutureWarning, ) return TypeCast.__new__(cls, *args, **kwargs) diff --git a/tests/frontend/test_field.py b/tests/frontend/test_field.py index 6d2942569704b7ff85b15fd23432667ba109ed7d..dfbec1c85bb33d74e607a08f1682e404dc4e1431 100644 --- a/tests/frontend/test_field.py +++ b/tests/frontend/test_field.py @@ -3,7 +3,13 @@ import pytest import sympy as sp import pystencils as ps -from pystencils import DEFAULTS, DynamicType, create_type, fields +from pystencils import ( + DEFAULTS, + DynamicNumericType, + DynamicIndexType, + create_type, + fields, +) from pystencils.field import ( Field, FieldType, @@ -15,7 +21,7 @@ from pystencils.field import ( def test_field_basic(): f = Field.create_generic("f", spatial_dimensions=2) assert FieldType.is_generic(f) - assert f.dtype == DynamicType.NUMERIC_TYPE + assert f.dtype == DynamicNumericType() assert f["E"] == f[1, 0] assert f["N"] == f[0, 1] assert "_" in f.center._latex("dummy") @@ -61,34 +67,34 @@ def test_field_basic(): neighbor = field_access.neighbor(coord_id=0, offset=-2) assert neighbor.offsets == (-1, 1) assert "_" in neighbor._latex("dummy") - assert f.dtype == DynamicType.NUMERIC_TYPE + assert f.dtype == DynamicNumericType() f = Field.create_fixed_size("f", (8, 8, 2, 2, 2), index_dimensions=3) assert f.center_vector == sp.Array( [[[f(i, j, k) for k in range(2)] for j in range(2)] for i in range(2)] ) - assert f.dtype == DynamicType.NUMERIC_TYPE + assert f.dtype == DynamicNumericType() f = Field.create_generic("f", spatial_dimensions=5, index_dimensions=2) field_access = f[1, -1, 2, -3, 0](1, 0) assert field_access.offsets == (1, -1, 2, -3, 0) assert field_access.index == (1, 0) - assert f.dtype == DynamicType.NUMERIC_TYPE + assert f.dtype == DynamicNumericType() def test_field_description_parsing(): f, g = fields("f(1), g(3): [2D]") - assert f.dtype == g.dtype == DynamicType.NUMERIC_TYPE + assert f.dtype == g.dtype == DynamicNumericType() assert f.spatial_dimensions == g.spatial_dimensions == 2 assert f.index_shape == (1,) assert g.index_shape == (3,) f = fields("f: dyn[3D]") - assert f.dtype == DynamicType.NUMERIC_TYPE + assert f.dtype == DynamicNumericType() idx = fields("idx: dynidx[3D]") - assert idx.dtype == DynamicType.INDEX_TYPE + assert idx.dtype == DynamicIndexType() h = fields("h: float32[3D]") @@ -98,7 +104,7 @@ def test_field_description_parsing(): assert h.dtype == create_type("float32") f: Field = fields("f(5, 5) : double[20, 20]") - + assert f.dtype == create_type("float64") assert f.spatial_shape == (20, 20) assert f.index_shape == (5, 5) @@ -264,9 +270,7 @@ def test_memory_layout_descriptors(): == (3, 2, 1, 0) ) assert ( - layout_string_to_tuple("c", 4) - == layout_string_to_tuple("C", 4) - == (0, 1, 2, 3) + layout_string_to_tuple("c", 4) == layout_string_to_tuple("C", 4) == (0, 1, 2, 3) ) assert layout_string_to_tuple("C", 5) == (0, 1, 2, 3, 4) diff --git a/tests/frontend/test_typed_sympy.py b/tests/frontend/test_typed_sympy.py index bf6058537a7217851d22987f3b011edea08058c8..6f0e40a5fe8f8b44482d9abae5fcab4fb5b46f16 100644 --- a/tests/frontend/test_typed_sympy.py +++ b/tests/frontend/test_typed_sympy.py @@ -8,6 +8,9 @@ from pystencils.sympyextensions.typed_sympy import ( tcast, TypeAtom, DynamicType, + DynamicIndexType, + DynamicNumericType, + DynamicNumericArrayType, ) from pystencils.types import create_type from pystencils.types.quick import UInt, Ptr @@ -22,16 +25,19 @@ def test_type_atoms(): atom3 = TypeAtom(create_type("const int32")) assert atom1 != atom3 - atom4 = TypeAtom(DynamicType.INDEX_TYPE) - atom5 = TypeAtom(DynamicType.NUMERIC_TYPE) + atom4 = TypeAtom(DynamicIndexType()) + atom5 = TypeAtom(DynamicNumericType()) + atom6 = TypeAtom(DynamicNumericArrayType((4, 2))) assert atom3 != atom4 assert atom4 != atom5 + assert atom5 != atom6 - dump = pickle.dumps(atom1) - atom1_reconst = pickle.loads(dump) + for atom in [atom1, atom4, atom5, atom6]: + dump = pickle.dumps(atom) + atom_reconst = pickle.loads(dump) - assert atom1_reconst == atom1 + assert atom_reconst == atom def test_typed_symbol(): @@ -66,10 +72,10 @@ def test_casts(): # Boolean Casts bool_expr = tcast(x, "bool") assert isinstance(bool_expr, boolalg.Boolean) - + # Check that we can construct boolean expressions with cast results _ = boolalg.Or(bool_expr, y) - + # Assumptions expr = tcast(x, "int32") assert expr.is_integer diff --git a/tests/kernelcreation/test_iteration_slices.py b/tests/kernelcreation/test_iteration_slices.py index 2b3a8ebf0e2fbbd5ed94779b2eb764a6127fc030..ab6d4c52b7d09a4217bcb83e65ce376c6c0933fd 100644 --- a/tests/kernelcreation/test_iteration_slices.py +++ b/tests/kernelcreation/test_iteration_slices.py @@ -13,7 +13,7 @@ from pystencils import ( make_slice, Target, CreateKernelConfig, - DynamicType, + DynamicIndexType, ) from pystencils.sympyextensions.integer_functions import int_rem from pystencils.simp import sympy_cse_on_assignment_list @@ -63,7 +63,7 @@ def test_sliced_iteration(): make_slice[2:-2:2, ::3], make_slice[10:, :-5:2], make_slice[-5:-1, -1], - make_slice[-3, -1] + make_slice[-3, -1], ], ) def test_numerical_slices(gen_config: CreateKernelConfig, xp, islice): @@ -94,7 +94,7 @@ def test_symbolic_slice(gen_config: CreateKernelConfig, xp): shape = (16, 16) sx, sy, ex, ey = [ - TypedSymbol(n, DynamicType.INDEX_TYPE) for n in ("sx", "sy", "ex", "ey") + TypedSymbol(n, DynamicIndexType()) for n in ("sx", "sy", "ex", "ey") ] f_arr = xp.zeros(shape) diff --git a/tests/nbackend/kernelcreation/test_context.py b/tests/nbackend/kernelcreation/test_context.py index 200c1e34e8ab3ac04fa119491805ef61111062c6..3be7d941b7bd172818f64410b2d33d61b78c9e66 100644 --- a/tests/nbackend/kernelcreation/test_context.py +++ b/tests/nbackend/kernelcreation/test_context.py @@ -1,7 +1,7 @@ from itertools import chain import pytest -from pystencils import Field, TypedSymbol, FieldType, DynamicType +from pystencils import Field, TypedSymbol, FieldType, DynamicNumericType from pystencils.backend.kernelcreation import KernelCreationContext from pystencils.backend.constants import PsConstant @@ -100,8 +100,8 @@ def test_invalid_fields(): FieldType.GENERIC, Fp(32), (0,), - (TypedSymbol("nx", DynamicType.NUMERIC_TYPE),), - (TypedSymbol("sx", DynamicType.NUMERIC_TYPE),), + (TypedSymbol("nx", DynamicNumericType()),), + (TypedSymbol("sx", DynamicNumericType()),), ) with pytest.raises(KernelConstraintsError): @@ -121,7 +121,9 @@ def test_duplicate_fields(): f_buf = ctx.get_buffer(f) g_buf = ctx.get_buffer(g) - for sf, sg in zip(chain(f_buf.shape, f_buf.strides), chain(g_buf.shape, g_buf.strides)): + for sf, sg in zip( + chain(f_buf.shape, f_buf.strides), chain(g_buf.shape, g_buf.strides) + ): # Must be the same assert sf == sg diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py index f6c8f85b2b3df2289e809728b9e7b014d6428976..96c1d626ed3b451f09ad7fd886620dfa550ebb29 100644 --- a/tests/nbackend/kernelcreation/test_freeze.py +++ b/tests/nbackend/kernelcreation/test_freeze.py @@ -7,11 +7,15 @@ from pystencils import ( create_type, create_numeric_type, TypedSymbol, - DynamicType, + DynamicIndexType, + DynamicNumericType, + DynamicNumericArrayType, ) from pystencils.sympyextensions import tcast from pystencils.sympyextensions.pointers import mem_acc +from pystencils.types import PsArrayType + from pystencils.backend.ast.structural import ( PsAssignment, PsDeclaration, @@ -287,8 +291,9 @@ def test_dynamic_types(): ) freeze = FreezeExpressions(ctx) - x, y = [TypedSymbol(n, DynamicType.NUMERIC_TYPE) for n in "xy"] - p, q = [TypedSymbol(n, DynamicType.INDEX_TYPE) for n in "pq"] + x, y = [TypedSymbol(n, DynamicNumericType()) for n in "xy"] + p, q = [TypedSymbol(n, DynamicIndexType()) for n in "pq"] + arr = TypedSymbol("arr", DynamicNumericArrayType((4, 2))) expr = freeze(x + y) @@ -299,6 +304,9 @@ def test_dynamic_types(): assert ctx.get_symbol("p").dtype == ctx.index_dtype assert ctx.get_symbol("q").dtype == ctx.index_dtype + expr = freeze(arr) + assert ctx.get_symbol("arr").dtype == PsArrayType(ctx.default_dtype, (4, 2)) + def test_cast_func(): ctx = KernelCreationContext(