Skip to content
Snippets Groups Projects
Commit b8024740 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Add test cases for coordinate symbols

parent dd900d63
No related branches found
No related tags found
1 merge request!426Refactor DEFAULTS & fix bugs concerning data types of spatial counter symbols
...@@ -359,7 +359,7 @@ def create_sparse_iteration_space( ...@@ -359,7 +359,7 @@ def create_sparse_iteration_space(
dim = archetype_field.spatial_dimensions dim = archetype_field.spatial_dimensions
coord_members = [ coord_members = [
PsStructType.Member(name, ctx.index_dtype) PsStructType.Member(name, ctx.index_dtype)
for name in DEFAULTS._index_struct_coordinate_names[:dim] for name in DEFAULTS.index_struct_coordinate_names[:dim]
] ]
# Determine index field # Determine index field
......
from .types import PsIeeeFloatType, PsIntegerType, PsSignedIntegerType from .types import (
PsIeeeFloatType,
PsIntegerType,
PsSignedIntegerType,
PsStructType,
UserTypeSpec,
create_type,
)
from pystencils.sympyextensions.typed_sympy import TypedSymbol, DynamicType from pystencils.sympyextensions.typed_sympy import TypedSymbol, DynamicType
class SympyDefaults(): class SympyDefaults:
def __init__(self): def __init__(self):
self.numeric_dtype = PsIeeeFloatType(64) self.numeric_dtype = PsIeeeFloatType(64)
"""Default data type for numerical computations""" """Default data type for numerical computations"""
...@@ -21,24 +28,32 @@ class SympyDefaults(): ...@@ -21,24 +28,32 @@ class SympyDefaults():
) )
"""Default spatial counters""" """Default spatial counters"""
self._index_struct_coordinate_names = ("x", "y", "z") self.index_struct_coordinate_names = ("x", "y", "z")
"""Default names of spatial coordinate members in index list structures""" """Default names of spatial coordinate members in index list structures"""
self.sparse_counter_name = "sparse_idx" self.sparse_counter_name = "sparse_idx"
"""Name of the default sparse iteration counter""" """Name of the default sparse iteration counter"""
self.sparse_counter = TypedSymbol(self.sparse_counter_name, DynamicType.INDEX_TYPE) self.sparse_counter = TypedSymbol(
self.sparse_counter_name, DynamicType.INDEX_TYPE
)
"""Default sparse iteration counter.""" """Default sparse iteration counter."""
def field_shape_name(self, field_name: str, coord: int): def field_shape_name(self, field_name: str, coord: int):
return f"_size_{field_name}_{coord}" return f"_size_{field_name}_{coord}"
def field_stride_name(self, field_name: str, coord: int): def field_stride_name(self, field_name: str, coord: int):
return f"_stride_{field_name}_{coord}" return f"_stride_{field_name}_{coord}"
def field_pointer_name(self, field_name: str): def field_pointer_name(self, field_name: str):
return f"_data_{field_name}" return f"_data_{field_name}"
def index_struct(self, index_dtype: UserTypeSpec, dim: int) -> PsStructType:
idx_type = create_type(index_dtype)
return PsStructType(
[(name, idx_type) for name in self.index_struct_coordinate_names[:dim]]
)
DEFAULTS = SympyDefaults() DEFAULTS = SympyDefaults()
"""Default names and symbols used throughout code generation""" """Default names and symbols used throughout code generation"""
import pytest
import numpy as np
from pystencils import (
Field,
Assignment,
create_kernel,
CreateKernelConfig,
DEFAULTS,
FieldType,
)
from pystencils.sympyextensions import CastFunc
@pytest.mark.parametrize("index_dtype", ["int16", "int32", "uint32", "int64"])
def test_spatial_counters_dense(index_dtype):
# Parametrized over index_dtype to make sure the `DynamicType.INDEX` in the
# DEFAULTS works validly
x, y, z = DEFAULTS.spatial_counters
f = Field.create_generic("f", 3, "float64", index_shape=(3,), layout="fzyx")
asms = [
Assignment(f(0), CastFunc.as_numeric(z)),
Assignment(f(1), CastFunc.as_numeric(y)),
Assignment(f(2), CastFunc.as_numeric(x)),
]
cfg = CreateKernelConfig(index_dtype=index_dtype)
kernel = create_kernel(asms, cfg).compile()
f_arr = np.zeros((16, 16, 16, 3))
kernel(f=f_arr)
expected = np.mgrid[0:16, 0:16, 0:16].astype(np.float64).transpose()
np.testing.assert_equal(f_arr, expected)
@pytest.mark.parametrize("index_dtype", ["int16", "int32", "uint32", "int64"])
def test_spatial_counters_sparse(index_dtype):
x, y, z = DEFAULTS.spatial_counters
f = Field.create_generic("f", 3, "float64", index_shape=(3,), layout="fzyx")
asms = [
Assignment(f(0), CastFunc.as_numeric(x)),
Assignment(f(1), CastFunc.as_numeric(y)),
Assignment(f(2), CastFunc.as_numeric(z)),
]
idx_struct = DEFAULTS.index_struct(index_dtype, 3)
idx_field = Field.create_generic(
"index", 1, idx_struct, field_type=FieldType.INDEXED
)
cfg = CreateKernelConfig(index_dtype=index_dtype, index_field=idx_field)
kernel = create_kernel(asms, cfg).compile()
f_arr = np.zeros((16, 16, 16, 3))
idx_arr = np.array(
[(1, 4, 3), (5, 1, 6), (9, 5, 1), (3, 13, 7)], dtype=idx_struct.numpy_dtype
)
kernel(f=f_arr, index=idx_arr)
for t in idx_arr:
assert f_arr[t[0], t[1], t[2], 0] == t[0].astype(np.float64)
assert f_arr[t[0], t[1], t[2], 1] == t[1].astype(np.float64)
assert f_arr[t[0], t[1], t[2], 2] == t[2].astype(np.float64)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment