diff --git a/src/pystencils/backend/kernelcreation/iteration_space.py b/src/pystencils/backend/kernelcreation/iteration_space.py index 4f057e1fc8c3c68c6dae3d35ace54c8be8f3de21..ba77ad24d70f97d2b17087448335ab14fd10b27f 100644 --- a/src/pystencils/backend/kernelcreation/iteration_space.py +++ b/src/pystencils/backend/kernelcreation/iteration_space.py @@ -13,7 +13,7 @@ from ..memory import PsSymbol, PsBuffer from ..constants import PsConstant from ..ast.expressions import PsExpression, PsConstantExpr, PsTernary, PsEq, PsRem from ..ast.util import failing_cast -from ...types import PsStructType, constify +from ...types import PsStructType from ..exceptions import PsInputError, KernelConstraintsError if TYPE_CHECKING: @@ -359,7 +359,7 @@ def create_sparse_iteration_space( dim = archetype_field.spatial_dimensions coord_members = [ PsStructType.Member(name, ctx.index_dtype) - for name in DEFAULTS._index_struct_coordinate_names[:dim] + for name in DEFAULTS.index_struct_coordinate_names[:dim] ] # Determine index field @@ -379,7 +379,7 @@ def create_sparse_iteration_space( ) spatial_counters = [ - ctx.get_symbol(name, constify(ctx.index_dtype)) + ctx.get_symbol(name, ctx.index_dtype) for name in DEFAULTS.spatial_counter_names[:dim] ] diff --git a/src/pystencils/defaults.py b/src/pystencils/defaults.py index c7ac33347976bc2c5beca596ae78c553e0b3bf19..0b6a48af1944f81a29d6dcd4b0d29e90aeb4ce2d 100644 --- a/src/pystencils/defaults.py +++ b/src/pystencils/defaults.py @@ -1,13 +1,17 @@ -from typing import TypeVar, Generic, Callable -from .types import PsType, PsIeeeFloatType, PsIntegerType, PsSignedIntegerType, PsStructType +from .types import ( + PsIeeeFloatType, + PsIntegerType, + PsSignedIntegerType, + PsStructType, + UserTypeSpec, + create_type, +) -from pystencils.sympyextensions.typed_sympy import TypedSymbol +from pystencils.sympyextensions.typed_sympy import TypedSymbol, DynamicType -SymbolT = TypeVar("SymbolT") - -class GenericDefaults(Generic[SymbolT]): - def __init__(self, symcreate: Callable[[str, PsType], SymbolT]): +class SympyDefaults: + def __init__(self): self.numeric_dtype = PsIeeeFloatType(64) """Default data type for numerical computations""" @@ -18,37 +22,38 @@ class GenericDefaults(Generic[SymbolT]): """Names of the default spatial counters""" self.spatial_counters = ( - symcreate("ctr_0", self.index_dtype), - symcreate("ctr_1", self.index_dtype), - symcreate("ctr_2", self.index_dtype), + TypedSymbol("ctr_0", DynamicType.INDEX_TYPE), + TypedSymbol("ctr_1", DynamicType.INDEX_TYPE), + TypedSymbol("ctr_2", DynamicType.INDEX_TYPE), ) """Default spatial counters""" - self._index_struct_coordinate_names = ("x", "y", "z") + self.index_struct_coordinate_names = ("x", "y", "z") """Default names of spatial coordinate members in index list structures""" - self.index_struct_coordinates = ( - PsStructType.Member("x", self.index_dtype), - PsStructType.Member("y", self.index_dtype), - PsStructType.Member("z", self.index_dtype), - ) - """Default spatial coordinate members in index list structures""" - self.sparse_counter_name = "sparse_idx" """Name of the default sparse iteration counter""" - self.sparse_counter = symcreate(self.sparse_counter_name, self.index_dtype) + self.sparse_counter = TypedSymbol( + self.sparse_counter_name, DynamicType.INDEX_TYPE + ) """Default sparse iteration counter.""" def field_shape_name(self, field_name: str, coord: int): return f"_size_{field_name}_{coord}" - + def field_stride_name(self, field_name: str, coord: int): return f"_stride_{field_name}_{coord}" - + def field_pointer_name(self, field_name: str): return f"_data_{field_name}" + def index_struct(self, index_dtype: UserTypeSpec, dim: int) -> PsStructType: + idx_type = create_type(index_dtype) + return PsStructType( + [(name, idx_type) for name in self.index_struct_coordinate_names[:dim]] + ) + -DEFAULTS = GenericDefaults[TypedSymbol](TypedSymbol) +DEFAULTS = SympyDefaults() """Default names and symbols used throughout code generation""" diff --git a/src/pystencils/py.typed b/src/pystencils/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/kernelcreation/test_spatial_counters.py b/tests/kernelcreation/test_spatial_counters.py new file mode 100644 index 0000000000000000000000000000000000000000..fdb365294c98311943c370cb650694b1a4bd8613 --- /dev/null +++ b/tests/kernelcreation/test_spatial_counters.py @@ -0,0 +1,70 @@ +import pytest +import numpy as np + +from pystencils import ( + Field, + Assignment, + create_kernel, + CreateKernelConfig, + DEFAULTS, + FieldType, +) +from pystencils.sympyextensions import CastFunc + + +@pytest.mark.parametrize("index_dtype", ["int16", "int32", "uint32", "int64"]) +def test_spatial_counters_dense(index_dtype): + # Parametrized over index_dtype to make sure the `DynamicType.INDEX` in the + # DEFAULTS works validly + x, y, z = DEFAULTS.spatial_counters + + f = Field.create_generic("f", 3, "float64", index_shape=(3,), layout="fzyx") + + asms = [ + Assignment(f(0), CastFunc.as_numeric(z)), + Assignment(f(1), CastFunc.as_numeric(y)), + Assignment(f(2), CastFunc.as_numeric(x)), + ] + + cfg = CreateKernelConfig(index_dtype=index_dtype) + kernel = create_kernel(asms, cfg).compile() + + f_arr = np.zeros((16, 16, 16, 3)) + kernel(f=f_arr) + + expected = np.mgrid[0:16, 0:16, 0:16].astype(np.float64).transpose() + + np.testing.assert_equal(f_arr, expected) + + +@pytest.mark.parametrize("index_dtype", ["int16", "int32", "uint32", "int64"]) +def test_spatial_counters_sparse(index_dtype): + x, y, z = DEFAULTS.spatial_counters + + f = Field.create_generic("f", 3, "float64", index_shape=(3,), layout="fzyx") + + asms = [ + Assignment(f(0), CastFunc.as_numeric(x)), + Assignment(f(1), CastFunc.as_numeric(y)), + Assignment(f(2), CastFunc.as_numeric(z)), + ] + + idx_struct = DEFAULTS.index_struct(index_dtype, 3) + idx_field = Field.create_generic( + "index", 1, idx_struct, field_type=FieldType.INDEXED + ) + + cfg = CreateKernelConfig(index_dtype=index_dtype, index_field=idx_field) + kernel = create_kernel(asms, cfg).compile() + + f_arr = np.zeros((16, 16, 16, 3)) + idx_arr = np.array( + [(1, 4, 3), (5, 1, 6), (9, 5, 1), (3, 13, 7)], dtype=idx_struct.numpy_dtype + ) + + kernel(f=f_arr, index=idx_arr) + + for t in idx_arr: + assert f_arr[t[0], t[1], t[2], 0] == t[0].astype(np.float64) + assert f_arr[t[0], t[1], t[2], 1] == t[1].astype(np.float64) + assert f_arr[t[0], t[1], t[2], 2] == t[2].astype(np.float64)