Skip to content
Snippets Groups Projects
Commit fad3752a authored by Frederik Hennig's avatar Frederik Hennig
Browse files

WIP update field to permit DynamicType

parent 1bbe0b92
No related branches found
No related tags found
1 merge request!443Extended Support for Typing in the Symbolic Toolbox
Pipeline #72494 failed
......@@ -94,6 +94,16 @@ class KernelCreationContext:
"""Data type used by default for index expressions"""
return self._index_dtype
def resolve_dynamic_type(self, dtype: DynamicType | PsType) -> PsType:
"""Selects the appropriate data type for `DynamicType` instances, and returns all other types as they are."""
match dtype:
case DynamicType.NUMERIC_TYPE:
return self._default_dtype
case DynamicType.INDEX_TYPE:
return self._index_dtype
case _:
return dtype
@property
def metadata(self) -> dict[str, Any]:
return self._metadata
......@@ -339,6 +349,8 @@ class KernelCreationContext:
if isinstance(s, TypedSymbol)
)
entry_type = self.resolve_dynamic_type(field.dtype)
if len(idx_types) > 1:
raise KernelConstraintsError(
f"Multiple incompatible types found in index symbols of field {field}: "
......@@ -375,10 +387,10 @@ class KernelCreationContext:
base_ptr = self.get_symbol(
DEFAULTS.field_pointer_name(field.name),
PsPointerType(field.dtype, restrict=True),
PsPointerType(entry_type, restrict=True),
)
return PsBuffer(field.name, field.dtype, base_ptr, buf_shape, buf_strides)
return PsBuffer(field.name, entry_type, base_ptr, buf_shape, buf_strides)
def _create_buffer_field_buffer(self, field: Field) -> PsBuffer:
if field.spatial_dimensions != 1:
......@@ -418,10 +430,11 @@ class KernelCreationContext:
]
buf_strides = [PsConstant(num_entries, idx_type), PsConstant(1, idx_type)]
buf_dtype = self.resolve_dynamic_type(field.dtype)
base_ptr = self.get_symbol(
DEFAULTS.field_pointer_name(field.name),
PsPointerType(field.dtype, restrict=True),
PsPointerType(buf_dtype, restrict=True),
)
return PsBuffer(field.name, field.dtype, base_ptr, buf_shape, buf_strides)
return PsBuffer(field.name, buf_dtype, base_ptr, buf_shape, buf_strides)
......@@ -261,14 +261,7 @@ class FreezeExpressions:
return num / denom
def map_TypedSymbol(self, expr: TypedSymbol):
dtype = expr.dtype
match dtype:
case DynamicType.NUMERIC_TYPE:
dtype = self._ctx.default_dtype
case DynamicType.INDEX_TYPE:
dtype = self._ctx.index_dtype
dtype = self._ctx.resolve_dynamic_type(expr.dtype)
symb = self._ctx.get_symbol(expr.name, dtype)
return PsSymbolExpr(symb)
......
......@@ -12,13 +12,13 @@ import sympy as sp
from sympy.core.cache import cacheit
from .defaults import DEFAULTS
from pystencils.alignedarray import aligned_empty
from pystencils.spatial_coordinates import x_staggered_vector, x_vector
from pystencils.stencil import direction_string_to_offset, inverse_direction, offset_to_direction_string
from pystencils.types import PsType, PsStructType, create_type
from pystencils.sympyextensions.typed_sympy import TypedSymbol, DynamicType
from pystencils.sympyextensions import is_integer_sequence
from pystencils.types import UserTypeSpec
from .alignedarray import aligned_empty
from .spatial_coordinates import x_staggered_vector, x_vector
from .stencil import direction_string_to_offset, inverse_direction, offset_to_direction_string
from .types import PsType, PsStructType, create_type
from .sympyextensions.typed_sympy import TypedSymbol, DynamicType
from .sympyextensions import is_integer_sequence
from .types import UserTypeSpec
__all__ = ['Field', 'fields', 'FieldType', 'Field']
......@@ -123,16 +123,16 @@ class Field:
>>> assignments = [Assignment(dst[0,0](i), src[-offset](i)) for i, offset in enumerate(stencil)];
Args:
field_name: something
field_type: something
dtype: something
layout: something
shape: something
strides: something
field_name: The field's name
field_type: The kind of the field
dtype: Data type of the field's entries
layout: Linearization order of the field's spatial dimensions
shape: Total shape (spatial and index) of the field
strides: Linearization strides of the field
"""
@staticmethod
def create_generic(field_name, spatial_dimensions, dtype: UserTypeSpec = np.float64, index_dimensions=0,
def create_generic(field_name, spatial_dimensions, dtype: UserTypeSpec | DynamicType = DynamicType.NUMERIC_TYPE, index_dimensions=0,
layout='numpy', index_shape=None, field_type=FieldType.GENERIC) -> 'Field':
"""
Creates a generic field where the field size is not fixed i.e. can be called with arrays of different sizes
......@@ -177,15 +177,15 @@ class Field:
for i in range(total_dimensions)
])
if not isinstance(dtype, DynamicType):
dtype = create_type(dtype)
np_data_type = dtype.numpy_dtype
assert np_data_type is not None
if np_data_type.fields is not None:
if isinstance(dtype, PsStructType):
if index_dimensions != 0:
raise ValueError("Structured arrays/fields are not allowed to have an index dimension")
shape += (1,)
strides += (1,)
if field_type == FieldType.STAGGERED and index_dimensions == 0:
raise ValueError("A staggered field needs at least one index dimension")
......@@ -228,7 +228,7 @@ class Field:
@staticmethod
def create_fixed_size(field_name: str, shape: Tuple[int, ...], index_dimensions: int = 0,
dtype: UserTypeSpec = np.float64, layout: str = 'numpy',
dtype: UserTypeSpec | DynamicType = DynamicType.NUMERIC_TYPE, layout: str = 'numpy',
strides: Optional[Sequence[int]] = None,
field_type=FieldType.GENERIC) -> 'Field':
"""
......@@ -256,11 +256,10 @@ class Field:
assert len(strides) == len(shape)
strides = tuple([s // np.dtype(dtype).itemsize for s in strides])
if not isinstance(dtype, DynamicType):
dtype = create_type(dtype)
numpy_dtype = dtype.numpy_dtype
assert numpy_dtype is not None
if numpy_dtype.fields is not None:
if isinstance(dtype, PsStructType):
if index_dimensions != 0:
raise ValueError("Structured arrays/fields are not allowed to have an index dimension")
shape += (1,)
......@@ -277,7 +276,7 @@ class Field:
self,
field_name: str,
field_type: FieldType,
dtype: UserTypeSpec,
dtype: UserTypeSpec | DynamicType,
layout: tuple[int, ...],
shape,
strides
......@@ -287,7 +286,7 @@ class Field:
assert isinstance(field_type, FieldType)
assert len(shape) == len(strides)
self.field_type = field_type
self._dtype = create_type(dtype)
self._dtype: PsType | DynamicType = create_type(dtype) if not isinstance(dtype, DynamicType) else dtype
self._layout = normalize_layout(layout)
self.shape = shape
self.strides = strides
......@@ -351,12 +350,15 @@ class Field:
return self.strides[self.spatial_dimensions:]
@property
def dtype(self) -> PsType:
def dtype(self) -> PsType | DynamicType:
return self._dtype
@property
def itemsize(self):
def itemsize(self) -> int | None:
if isinstance(self.dtype, PsType):
return self.dtype.itemsize
else:
return None
def __repr__(self):
if any(isinstance(s, sp.Symbol) for s in self.spatial_shape):
......@@ -1094,17 +1096,20 @@ def _parse_description(description):
result = type_description_regex.match(d)
if result:
data_type_str, size_info = result.group(1), result.group(2).strip().lower()
if data_type_str is None:
data_type_str = 'float64'
if data_type_str is not None:
data_type_str = data_type_str.lower().strip()
if not data_type_str:
data_type_str = 'float64'
if data_type_str:
dtype = create_type(data_type_str)
else:
dtype = DynamicType.NUMERIC_TYPE
if size_info.endswith('d'):
size_info = int(size_info[:-1])
else:
size_info = tuple(int(e) for e in size_info.split(","))
return data_type_str, size_info
return dtype, size_info
else:
raise ValueError("Could not parse field description")
......
......@@ -3,7 +3,7 @@ import pytest
import sympy as sp
import pystencils as ps
from pystencils import DEFAULTS
from pystencils import DEFAULTS, DynamicType, create_type, fields
from pystencils.field import (
Field,
FieldType,
......@@ -15,6 +15,7 @@ from pystencils.field import (
def test_field_basic():
f = Field.create_generic("f", spatial_dimensions=2)
assert FieldType.is_generic(f)
assert f.dtype == DynamicType.NUMERIC_TYPE
assert f["E"] == f[1, 0]
assert f["N"] == f[0, 1]
assert "_" in f.center._latex("dummy")
......@@ -41,17 +42,16 @@ def test_field_basic():
assert f1.ndim == f.ndim
assert f1.values_per_cell() == f.values_per_cell()
fixed = ps.fields("f(5, 5) : double[20, 20]")
assert fixed.neighbor_vector((1, 1)).shape == (5, 5)
f = Field.create_fixed_size("f", (10, 10), strides=(80, 8), dtype=np.float64)
assert f.spatial_strides == (10, 1)
assert f.index_strides == ()
assert f.center_vector == sp.Matrix([f.center])
assert f.dtype == create_type("float64")
f1 = f.new_field_with_different_name("f1")
assert f1.ndim == f.ndim
assert f1.values_per_cell() == f.values_per_cell()
assert f1.dtype == create_type("float64")
f = Field.create_fixed_size("f", (8, 8, 2, 2), index_dimensions=2)
assert f.center_vector == sp.Matrix([[f(0, 0), f(0, 1)], [f(1, 0), f(1, 1)]])
......@@ -61,16 +61,42 @@ def test_field_basic():
neighbor = field_access.neighbor(coord_id=0, offset=-2)
assert neighbor.offsets == (-1, 1)
assert "_" in neighbor._latex("dummy")
assert f.dtype == DynamicType.NUMERIC_TYPE
f = Field.create_fixed_size("f", (8, 8, 2, 2, 2), index_dimensions=3)
assert f.center_vector == sp.Array(
[[[f(i, j, k) for k in range(2)] for j in range(2)] for i in range(2)]
)
assert f.dtype == DynamicType.NUMERIC_TYPE
f = Field.create_generic("f", spatial_dimensions=5, index_dimensions=2)
field_access = f[1, -1, 2, -3, 0](1, 0)
assert field_access.offsets == (1, -1, 2, -3, 0)
assert field_access.index == (1, 0)
assert f.dtype == DynamicType.NUMERIC_TYPE
def test_field_description_parsing():
f, g = fields("f(1), g(3): [2D]")
assert f.dtype == g.dtype == DynamicType.NUMERIC_TYPE
assert f.spatial_dimensions == g.spatial_dimensions == 2
assert f.index_shape == (1,)
assert g.index_shape == (3,)
h = fields("h: float32[3D]")
assert h.index_shape == ()
assert h.spatial_dimensions == 3
assert h.index_dimensions == 0
assert h.dtype == create_type("float32")
f: Field = fields("f(5, 5) : double[20, 20]")
assert f.dtype == create_type("float64")
assert f.spatial_shape == (20, 20)
assert f.index_shape == (5, 5)
assert f.neighbor_vector((1, 1)).shape == (5, 5)
def test_error_handling():
......@@ -145,7 +171,7 @@ def test_error_handling():
def test_decorator_scoping():
dst = ps.fields("dst : double[2D]")
dst = fields("dst : double[2D]")
def f1():
a = sp.Symbol("a")
......@@ -165,7 +191,7 @@ def test_decorator_scoping():
def test_string_creation():
x, y, z = ps.fields(" x(4), y(3,5) z : double[ 3, 47]")
x, y, z = fields(" x(4), y(3,5) z : double[ 3, 47]")
assert x.index_shape == (4,)
assert y.index_shape == (3, 5)
assert z.spatial_shape == (3, 47)
......@@ -173,9 +199,9 @@ def test_string_creation():
def test_itemsize():
x = ps.fields("x: float32[1d]")
y = ps.fields("y: float64[2d]")
i = ps.fields("i: int16[1d]")
x = fields("x: float32[1d]")
y = fields("y: float64[2d]")
i = fields("i: int16[1d]")
assert x.itemsize == 4
assert y.itemsize == 8
......@@ -249,7 +275,7 @@ def test_memory_layout_descriptors():
def test_staggered():
# D2Q5
j1, j2, j3 = ps.fields(
j1, j2, j3 = fields(
"j1(2), j2(2,2), j3(2,2,2) : double[2D]", field_type=FieldType.STAGGERED
)
......@@ -296,7 +322,7 @@ def test_staggered():
)
# D2Q9
k1, k2 = ps.fields("k1(4), k2(2) : double[2D]", field_type=FieldType.STAGGERED)
k1, k2 = fields("k1(4), k2(2) : double[2D]", field_type=FieldType.STAGGERED)
assert k1[1, 1](2) == k1.staggered_access("NE")
assert k1[0, 0](2) == k1.staggered_access("SW")
......@@ -319,7 +345,7 @@ def test_staggered():
]
# sign reversed when using as flux field
r = ps.fields("r(2) : double[2D]", field_type=FieldType.STAGGERED_FLUX)
r = fields("r(2) : double[2D]", field_type=FieldType.STAGGERED_FLUX)
assert r[0, 0](0) == r.staggered_access("W")
assert -r[1, 0](0) == r.staggered_access("E")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment