diff --git a/src/pystencils/backend/kernelcreation/iteration_space.py b/src/pystencils/backend/kernelcreation/iteration_space.py index 4f057e1fc8c3c68c6dae3d35ace54c8be8f3de21..a8ec9e80c47c0f1fa82d1187de2d00f1606504a5 100644 --- a/src/pystencils/backend/kernelcreation/iteration_space.py +++ b/src/pystencils/backend/kernelcreation/iteration_space.py @@ -13,7 +13,7 @@ from ..memory import PsSymbol, PsBuffer from ..constants import PsConstant from ..ast.expressions import PsExpression, PsConstantExpr, PsTernary, PsEq, PsRem from ..ast.util import failing_cast -from ...types import PsStructType, constify +from ...types import PsStructType from ..exceptions import PsInputError, KernelConstraintsError if TYPE_CHECKING: @@ -379,7 +379,7 @@ def create_sparse_iteration_space( ) spatial_counters = [ - ctx.get_symbol(name, constify(ctx.index_dtype)) + ctx.get_symbol(name, ctx.index_dtype) for name in DEFAULTS.spatial_counter_names[:dim] ] diff --git a/src/pystencils/defaults.py b/src/pystencils/defaults.py index c7ac33347976bc2c5beca596ae78c553e0b3bf19..f1ac0efe59c670e4189565a3f9081249b20858c4 100644 --- a/src/pystencils/defaults.py +++ b/src/pystencils/defaults.py @@ -1,13 +1,10 @@ -from typing import TypeVar, Generic, Callable -from .types import PsType, PsIeeeFloatType, PsIntegerType, PsSignedIntegerType, PsStructType +from .types import PsIeeeFloatType, PsIntegerType, PsSignedIntegerType -from pystencils.sympyextensions.typed_sympy import TypedSymbol +from pystencils.sympyextensions.typed_sympy import TypedSymbol, DynamicType -SymbolT = TypeVar("SymbolT") - -class GenericDefaults(Generic[SymbolT]): - def __init__(self, symcreate: Callable[[str, PsType], SymbolT]): +class SympyDefaults(): + def __init__(self): self.numeric_dtype = PsIeeeFloatType(64) """Default data type for numerical computations""" @@ -18,26 +15,19 @@ class GenericDefaults(Generic[SymbolT]): """Names of the default spatial counters""" self.spatial_counters = ( - symcreate("ctr_0", self.index_dtype), - symcreate("ctr_1", self.index_dtype), - symcreate("ctr_2", self.index_dtype), + TypedSymbol("ctr_0", DynamicType.INDEX_TYPE), + TypedSymbol("ctr_1", DynamicType.INDEX_TYPE), + TypedSymbol("ctr_2", DynamicType.INDEX_TYPE), ) """Default spatial counters""" self._index_struct_coordinate_names = ("x", "y", "z") """Default names of spatial coordinate members in index list structures""" - self.index_struct_coordinates = ( - PsStructType.Member("x", self.index_dtype), - PsStructType.Member("y", self.index_dtype), - PsStructType.Member("z", self.index_dtype), - ) - """Default spatial coordinate members in index list structures""" - self.sparse_counter_name = "sparse_idx" """Name of the default sparse iteration counter""" - self.sparse_counter = symcreate(self.sparse_counter_name, self.index_dtype) + self.sparse_counter = TypedSymbol(self.sparse_counter_name, DynamicType.INDEX_TYPE) """Default sparse iteration counter.""" def field_shape_name(self, field_name: str, coord: int): @@ -50,5 +40,5 @@ class GenericDefaults(Generic[SymbolT]): return f"_data_{field_name}" -DEFAULTS = GenericDefaults[TypedSymbol](TypedSymbol) +DEFAULTS = SympyDefaults() """Default names and symbols used throughout code generation"""