Skip to content
Snippets Groups Projects
Commit 69a63b0b authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Refactor DEFAULTS & fix bugs concerning data types of spatial counter symbols

parent 778222bf
1 merge request!426Refactor DEFAULTS & fix bugs concerning data types of spatial counter symbols
......@@ -13,7 +13,7 @@ from ..memory import PsSymbol, PsBuffer
from ..constants import PsConstant
from ..ast.expressions import PsExpression, PsConstantExpr, PsTernary, PsEq, PsRem
from ..ast.util import failing_cast
from ...types import PsStructType, constify
from ...types import PsStructType
from ..exceptions import PsInputError, KernelConstraintsError
if TYPE_CHECKING:
......@@ -359,7 +359,7 @@ def create_sparse_iteration_space(
dim = archetype_field.spatial_dimensions
coord_members = [
PsStructType.Member(name, ctx.index_dtype)
for name in DEFAULTS._index_struct_coordinate_names[:dim]
for name in DEFAULTS.index_struct_coordinate_names[:dim]
]
# Determine index field
......@@ -379,7 +379,7 @@ def create_sparse_iteration_space(
)
spatial_counters = [
ctx.get_symbol(name, constify(ctx.index_dtype))
ctx.get_symbol(name, ctx.index_dtype)
for name in DEFAULTS.spatial_counter_names[:dim]
]
......
from typing import TypeVar, Generic, Callable
from .types import PsType, PsIeeeFloatType, PsIntegerType, PsSignedIntegerType, PsStructType
from .types import (
PsIeeeFloatType,
PsIntegerType,
PsSignedIntegerType,
PsStructType,
UserTypeSpec,
create_type,
)
from pystencils.sympyextensions.typed_sympy import TypedSymbol
from pystencils.sympyextensions.typed_sympy import TypedSymbol, DynamicType
SymbolT = TypeVar("SymbolT")
class GenericDefaults(Generic[SymbolT]):
def __init__(self, symcreate: Callable[[str, PsType], SymbolT]):
class SympyDefaults:
def __init__(self):
self.numeric_dtype = PsIeeeFloatType(64)
"""Default data type for numerical computations"""
......@@ -18,37 +22,38 @@ class GenericDefaults(Generic[SymbolT]):
"""Names of the default spatial counters"""
self.spatial_counters = (
symcreate("ctr_0", self.index_dtype),
symcreate("ctr_1", self.index_dtype),
symcreate("ctr_2", self.index_dtype),
TypedSymbol("ctr_0", DynamicType.INDEX_TYPE),
TypedSymbol("ctr_1", DynamicType.INDEX_TYPE),
TypedSymbol("ctr_2", DynamicType.INDEX_TYPE),
)
"""Default spatial counters"""
self._index_struct_coordinate_names = ("x", "y", "z")
self.index_struct_coordinate_names = ("x", "y", "z")
"""Default names of spatial coordinate members in index list structures"""
self.index_struct_coordinates = (
PsStructType.Member("x", self.index_dtype),
PsStructType.Member("y", self.index_dtype),
PsStructType.Member("z", self.index_dtype),
)
"""Default spatial coordinate members in index list structures"""
self.sparse_counter_name = "sparse_idx"
"""Name of the default sparse iteration counter"""
self.sparse_counter = symcreate(self.sparse_counter_name, self.index_dtype)
self.sparse_counter = TypedSymbol(
self.sparse_counter_name, DynamicType.INDEX_TYPE
)
"""Default sparse iteration counter."""
def field_shape_name(self, field_name: str, coord: int):
return f"_size_{field_name}_{coord}"
def field_stride_name(self, field_name: str, coord: int):
return f"_stride_{field_name}_{coord}"
def field_pointer_name(self, field_name: str):
return f"_data_{field_name}"
def index_struct(self, index_dtype: UserTypeSpec, dim: int) -> PsStructType:
idx_type = create_type(index_dtype)
return PsStructType(
[(name, idx_type) for name in self.index_struct_coordinate_names[:dim]]
)
DEFAULTS = GenericDefaults[TypedSymbol](TypedSymbol)
DEFAULTS = SympyDefaults()
"""Default names and symbols used throughout code generation"""
import pytest
import numpy as np
from pystencils import (
Field,
Assignment,
create_kernel,
CreateKernelConfig,
DEFAULTS,
FieldType,
)
from pystencils.sympyextensions import CastFunc
@pytest.mark.parametrize("index_dtype", ["int16", "int32", "uint32", "int64"])
def test_spatial_counters_dense(index_dtype):
# Parametrized over index_dtype to make sure the `DynamicType.INDEX` in the
# DEFAULTS works validly
x, y, z = DEFAULTS.spatial_counters
f = Field.create_generic("f", 3, "float64", index_shape=(3,), layout="fzyx")
asms = [
Assignment(f(0), CastFunc.as_numeric(z)),
Assignment(f(1), CastFunc.as_numeric(y)),
Assignment(f(2), CastFunc.as_numeric(x)),
]
cfg = CreateKernelConfig(index_dtype=index_dtype)
kernel = create_kernel(asms, cfg).compile()
f_arr = np.zeros((16, 16, 16, 3))
kernel(f=f_arr)
expected = np.mgrid[0:16, 0:16, 0:16].astype(np.float64).transpose()
np.testing.assert_equal(f_arr, expected)
@pytest.mark.parametrize("index_dtype", ["int16", "int32", "uint32", "int64"])
def test_spatial_counters_sparse(index_dtype):
x, y, z = DEFAULTS.spatial_counters
f = Field.create_generic("f", 3, "float64", index_shape=(3,), layout="fzyx")
asms = [
Assignment(f(0), CastFunc.as_numeric(x)),
Assignment(f(1), CastFunc.as_numeric(y)),
Assignment(f(2), CastFunc.as_numeric(z)),
]
idx_struct = DEFAULTS.index_struct(index_dtype, 3)
idx_field = Field.create_generic(
"index", 1, idx_struct, field_type=FieldType.INDEXED
)
cfg = CreateKernelConfig(index_dtype=index_dtype, index_field=idx_field)
kernel = create_kernel(asms, cfg).compile()
f_arr = np.zeros((16, 16, 16, 3))
idx_arr = np.array(
[(1, 4, 3), (5, 1, 6), (9, 5, 1), (3, 13, 7)], dtype=idx_struct.numpy_dtype
)
kernel(f=f_arr, index=idx_arr)
for t in idx_arr:
assert f_arr[t[0], t[1], t[2], 0] == t[0].astype(np.float64)
assert f_arr[t[0], t[1], t[2], 1] == t[1].astype(np.float64)
assert f_arr[t[0], t[1], t[2], 2] == t[2].astype(np.float64)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment