Skip to content
Snippets Groups Projects
Commit 722fa3ce authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Merge branch 'fhennig/dtype-bugs' into 'v2.0-dev'

Refactor DEFAULTS & fix bugs concerning data types of spatial counter symbols

See merge request !426
parents 778222bf 69a63b0b
No related branches found
No related tags found
1 merge request!426Refactor DEFAULTS & fix bugs concerning data types of spatial counter symbols
Pipeline #70221 passed
...@@ -13,7 +13,7 @@ from ..memory import PsSymbol, PsBuffer ...@@ -13,7 +13,7 @@ from ..memory import PsSymbol, PsBuffer
from ..constants import PsConstant from ..constants import PsConstant
from ..ast.expressions import PsExpression, PsConstantExpr, PsTernary, PsEq, PsRem from ..ast.expressions import PsExpression, PsConstantExpr, PsTernary, PsEq, PsRem
from ..ast.util import failing_cast from ..ast.util import failing_cast
from ...types import PsStructType, constify from ...types import PsStructType
from ..exceptions import PsInputError, KernelConstraintsError from ..exceptions import PsInputError, KernelConstraintsError
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -359,7 +359,7 @@ def create_sparse_iteration_space( ...@@ -359,7 +359,7 @@ def create_sparse_iteration_space(
dim = archetype_field.spatial_dimensions dim = archetype_field.spatial_dimensions
coord_members = [ coord_members = [
PsStructType.Member(name, ctx.index_dtype) PsStructType.Member(name, ctx.index_dtype)
for name in DEFAULTS._index_struct_coordinate_names[:dim] for name in DEFAULTS.index_struct_coordinate_names[:dim]
] ]
# Determine index field # Determine index field
...@@ -379,7 +379,7 @@ def create_sparse_iteration_space( ...@@ -379,7 +379,7 @@ def create_sparse_iteration_space(
) )
spatial_counters = [ spatial_counters = [
ctx.get_symbol(name, constify(ctx.index_dtype)) ctx.get_symbol(name, ctx.index_dtype)
for name in DEFAULTS.spatial_counter_names[:dim] for name in DEFAULTS.spatial_counter_names[:dim]
] ]
......
from typing import TypeVar, Generic, Callable from .types import (
from .types import PsType, PsIeeeFloatType, PsIntegerType, PsSignedIntegerType, PsStructType PsIeeeFloatType,
PsIntegerType,
PsSignedIntegerType,
PsStructType,
UserTypeSpec,
create_type,
)
from pystencils.sympyextensions.typed_sympy import TypedSymbol from pystencils.sympyextensions.typed_sympy import TypedSymbol, DynamicType
SymbolT = TypeVar("SymbolT")
class SympyDefaults:
class GenericDefaults(Generic[SymbolT]): def __init__(self):
def __init__(self, symcreate: Callable[[str, PsType], SymbolT]):
self.numeric_dtype = PsIeeeFloatType(64) self.numeric_dtype = PsIeeeFloatType(64)
"""Default data type for numerical computations""" """Default data type for numerical computations"""
...@@ -18,37 +22,38 @@ class GenericDefaults(Generic[SymbolT]): ...@@ -18,37 +22,38 @@ class GenericDefaults(Generic[SymbolT]):
"""Names of the default spatial counters""" """Names of the default spatial counters"""
self.spatial_counters = ( self.spatial_counters = (
symcreate("ctr_0", self.index_dtype), TypedSymbol("ctr_0", DynamicType.INDEX_TYPE),
symcreate("ctr_1", self.index_dtype), TypedSymbol("ctr_1", DynamicType.INDEX_TYPE),
symcreate("ctr_2", self.index_dtype), TypedSymbol("ctr_2", DynamicType.INDEX_TYPE),
) )
"""Default spatial counters""" """Default spatial counters"""
self._index_struct_coordinate_names = ("x", "y", "z") self.index_struct_coordinate_names = ("x", "y", "z")
"""Default names of spatial coordinate members in index list structures""" """Default names of spatial coordinate members in index list structures"""
self.index_struct_coordinates = (
PsStructType.Member("x", self.index_dtype),
PsStructType.Member("y", self.index_dtype),
PsStructType.Member("z", self.index_dtype),
)
"""Default spatial coordinate members in index list structures"""
self.sparse_counter_name = "sparse_idx" self.sparse_counter_name = "sparse_idx"
"""Name of the default sparse iteration counter""" """Name of the default sparse iteration counter"""
self.sparse_counter = symcreate(self.sparse_counter_name, self.index_dtype) self.sparse_counter = TypedSymbol(
self.sparse_counter_name, DynamicType.INDEX_TYPE
)
"""Default sparse iteration counter.""" """Default sparse iteration counter."""
def field_shape_name(self, field_name: str, coord: int): def field_shape_name(self, field_name: str, coord: int):
return f"_size_{field_name}_{coord}" return f"_size_{field_name}_{coord}"
def field_stride_name(self, field_name: str, coord: int): def field_stride_name(self, field_name: str, coord: int):
return f"_stride_{field_name}_{coord}" return f"_stride_{field_name}_{coord}"
def field_pointer_name(self, field_name: str): def field_pointer_name(self, field_name: str):
return f"_data_{field_name}" return f"_data_{field_name}"
def index_struct(self, index_dtype: UserTypeSpec, dim: int) -> PsStructType:
idx_type = create_type(index_dtype)
return PsStructType(
[(name, idx_type) for name in self.index_struct_coordinate_names[:dim]]
)
DEFAULTS = GenericDefaults[TypedSymbol](TypedSymbol) DEFAULTS = SympyDefaults()
"""Default names and symbols used throughout code generation""" """Default names and symbols used throughout code generation"""
import pytest
import numpy as np
from pystencils import (
Field,
Assignment,
create_kernel,
CreateKernelConfig,
DEFAULTS,
FieldType,
)
from pystencils.sympyextensions import CastFunc
@pytest.mark.parametrize("index_dtype", ["int16", "int32", "uint32", "int64"])
def test_spatial_counters_dense(index_dtype):
# Parametrized over index_dtype to make sure the `DynamicType.INDEX` in the
# DEFAULTS works validly
x, y, z = DEFAULTS.spatial_counters
f = Field.create_generic("f", 3, "float64", index_shape=(3,), layout="fzyx")
asms = [
Assignment(f(0), CastFunc.as_numeric(z)),
Assignment(f(1), CastFunc.as_numeric(y)),
Assignment(f(2), CastFunc.as_numeric(x)),
]
cfg = CreateKernelConfig(index_dtype=index_dtype)
kernel = create_kernel(asms, cfg).compile()
f_arr = np.zeros((16, 16, 16, 3))
kernel(f=f_arr)
expected = np.mgrid[0:16, 0:16, 0:16].astype(np.float64).transpose()
np.testing.assert_equal(f_arr, expected)
@pytest.mark.parametrize("index_dtype", ["int16", "int32", "uint32", "int64"])
def test_spatial_counters_sparse(index_dtype):
x, y, z = DEFAULTS.spatial_counters
f = Field.create_generic("f", 3, "float64", index_shape=(3,), layout="fzyx")
asms = [
Assignment(f(0), CastFunc.as_numeric(x)),
Assignment(f(1), CastFunc.as_numeric(y)),
Assignment(f(2), CastFunc.as_numeric(z)),
]
idx_struct = DEFAULTS.index_struct(index_dtype, 3)
idx_field = Field.create_generic(
"index", 1, idx_struct, field_type=FieldType.INDEXED
)
cfg = CreateKernelConfig(index_dtype=index_dtype, index_field=idx_field)
kernel = create_kernel(asms, cfg).compile()
f_arr = np.zeros((16, 16, 16, 3))
idx_arr = np.array(
[(1, 4, 3), (5, 1, 6), (9, 5, 1), (3, 13, 7)], dtype=idx_struct.numpy_dtype
)
kernel(f=f_arr, index=idx_arr)
for t in idx_arr:
assert f_arr[t[0], t[1], t[2], 0] == t[0].astype(np.float64)
assert f_arr[t[0], t[1], t[2], 1] == t[1].astype(np.float64)
assert f_arr[t[0], t[1], t[2], 2] == t[2].astype(np.float64)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment