Skip to content
Snippets Groups Projects
Commit dd900d63 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Refactor DEFAULTS to use DynamicType. Do not constify spatial counters in sparse iteration space.

parent 778222bf
No related branches found
No related tags found
No related merge requests found
Pipeline #70099 passed
......@@ -13,7 +13,7 @@ from ..memory import PsSymbol, PsBuffer
from ..constants import PsConstant
from ..ast.expressions import PsExpression, PsConstantExpr, PsTernary, PsEq, PsRem
from ..ast.util import failing_cast
from ...types import PsStructType, constify
from ...types import PsStructType
from ..exceptions import PsInputError, KernelConstraintsError
if TYPE_CHECKING:
......@@ -379,7 +379,7 @@ def create_sparse_iteration_space(
)
spatial_counters = [
ctx.get_symbol(name, constify(ctx.index_dtype))
ctx.get_symbol(name, ctx.index_dtype)
for name in DEFAULTS.spatial_counter_names[:dim]
]
......
from typing import TypeVar, Generic, Callable
from .types import PsType, PsIeeeFloatType, PsIntegerType, PsSignedIntegerType, PsStructType
from .types import PsIeeeFloatType, PsIntegerType, PsSignedIntegerType
from pystencils.sympyextensions.typed_sympy import TypedSymbol
from pystencils.sympyextensions.typed_sympy import TypedSymbol, DynamicType
SymbolT = TypeVar("SymbolT")
class GenericDefaults(Generic[SymbolT]):
def __init__(self, symcreate: Callable[[str, PsType], SymbolT]):
class SympyDefaults():
def __init__(self):
self.numeric_dtype = PsIeeeFloatType(64)
"""Default data type for numerical computations"""
......@@ -18,26 +15,19 @@ class GenericDefaults(Generic[SymbolT]):
"""Names of the default spatial counters"""
self.spatial_counters = (
symcreate("ctr_0", self.index_dtype),
symcreate("ctr_1", self.index_dtype),
symcreate("ctr_2", self.index_dtype),
TypedSymbol("ctr_0", DynamicType.INDEX_TYPE),
TypedSymbol("ctr_1", DynamicType.INDEX_TYPE),
TypedSymbol("ctr_2", DynamicType.INDEX_TYPE),
)
"""Default spatial counters"""
self._index_struct_coordinate_names = ("x", "y", "z")
"""Default names of spatial coordinate members in index list structures"""
self.index_struct_coordinates = (
PsStructType.Member("x", self.index_dtype),
PsStructType.Member("y", self.index_dtype),
PsStructType.Member("z", self.index_dtype),
)
"""Default spatial coordinate members in index list structures"""
self.sparse_counter_name = "sparse_idx"
"""Name of the default sparse iteration counter"""
self.sparse_counter = symcreate(self.sparse_counter_name, self.index_dtype)
self.sparse_counter = TypedSymbol(self.sparse_counter_name, DynamicType.INDEX_TYPE)
"""Default sparse iteration counter."""
def field_shape_name(self, field_name: str, coord: int):
......@@ -50,5 +40,5 @@ class GenericDefaults(Generic[SymbolT]):
return f"_data_{field_name}"
DEFAULTS = GenericDefaults[TypedSymbol](TypedSymbol)
DEFAULTS = SympyDefaults()
"""Default names and symbols used throughout code generation"""
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment