From fad3752a84dfbdb22857d2066e8b835fb573458f Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Thu, 23 Jan 2025 09:03:40 +0100 Subject: [PATCH] WIP update field to permit DynamicType --- .../backend/kernelcreation/context.py | 21 +++++- .../backend/kernelcreation/freeze.py | 9 +-- src/pystencils/field.py | 75 ++++++++++--------- tests/frontend/test_field.py | 50 ++++++++++--- 4 files changed, 96 insertions(+), 59 deletions(-) diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index 8f5931c64..68da893ff 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -93,6 +93,16 @@ class KernelCreationContext: def index_dtype(self) -> PsIntegerType: """Data type used by default for index expressions""" return self._index_dtype + + def resolve_dynamic_type(self, dtype: DynamicType | PsType) -> PsType: + """Selects the appropriate data type for `DynamicType` instances, and returns all other types as they are.""" + match dtype: + case DynamicType.NUMERIC_TYPE: + return self._default_dtype + case DynamicType.INDEX_TYPE: + return self._index_dtype + case _: + return dtype @property def metadata(self) -> dict[str, Any]: @@ -339,6 +349,8 @@ class KernelCreationContext: if isinstance(s, TypedSymbol) ) + entry_type = self.resolve_dynamic_type(field.dtype) + if len(idx_types) > 1: raise KernelConstraintsError( f"Multiple incompatible types found in index symbols of field {field}: " @@ -375,10 +387,10 @@ class KernelCreationContext: base_ptr = self.get_symbol( DEFAULTS.field_pointer_name(field.name), - PsPointerType(field.dtype, restrict=True), + PsPointerType(entry_type, restrict=True), ) - return PsBuffer(field.name, field.dtype, base_ptr, buf_shape, buf_strides) + return PsBuffer(field.name, entry_type, base_ptr, buf_shape, buf_strides) def _create_buffer_field_buffer(self, field: Field) -> PsBuffer: if field.spatial_dimensions != 1: @@ -418,10 +430,11 @@ class KernelCreationContext: ] buf_strides = [PsConstant(num_entries, idx_type), PsConstant(1, idx_type)] + buf_dtype = self.resolve_dynamic_type(field.dtype) base_ptr = self.get_symbol( DEFAULTS.field_pointer_name(field.name), - PsPointerType(field.dtype, restrict=True), + PsPointerType(buf_dtype, restrict=True), ) - return PsBuffer(field.name, field.dtype, base_ptr, buf_shape, buf_strides) + return PsBuffer(field.name, buf_dtype, base_ptr, buf_shape, buf_strides) diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 16710861b..5d18ecf82 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -261,14 +261,7 @@ class FreezeExpressions: return num / denom def map_TypedSymbol(self, expr: TypedSymbol): - dtype = expr.dtype - - match dtype: - case DynamicType.NUMERIC_TYPE: - dtype = self._ctx.default_dtype - case DynamicType.INDEX_TYPE: - dtype = self._ctx.index_dtype - + dtype = self._ctx.resolve_dynamic_type(expr.dtype) symb = self._ctx.get_symbol(expr.name, dtype) return PsSymbolExpr(symb) diff --git a/src/pystencils/field.py b/src/pystencils/field.py index 1a3a13b73..826c551a7 100644 --- a/src/pystencils/field.py +++ b/src/pystencils/field.py @@ -12,13 +12,13 @@ import sympy as sp from sympy.core.cache import cacheit from .defaults import DEFAULTS -from pystencils.alignedarray import aligned_empty -from pystencils.spatial_coordinates import x_staggered_vector, x_vector -from pystencils.stencil import direction_string_to_offset, inverse_direction, offset_to_direction_string -from pystencils.types import PsType, PsStructType, create_type -from pystencils.sympyextensions.typed_sympy import TypedSymbol, DynamicType -from pystencils.sympyextensions import is_integer_sequence -from pystencils.types import UserTypeSpec +from .alignedarray import aligned_empty +from .spatial_coordinates import x_staggered_vector, x_vector +from .stencil import direction_string_to_offset, inverse_direction, offset_to_direction_string +from .types import PsType, PsStructType, create_type +from .sympyextensions.typed_sympy import TypedSymbol, DynamicType +from .sympyextensions import is_integer_sequence +from .types import UserTypeSpec __all__ = ['Field', 'fields', 'FieldType', 'Field'] @@ -123,16 +123,16 @@ class Field: >>> assignments = [Assignment(dst[0,0](i), src[-offset](i)) for i, offset in enumerate(stencil)]; Args: - field_name: something - field_type: something - dtype: something - layout: something - shape: something - strides: something + field_name: The field's name + field_type: The kind of the field + dtype: Data type of the field's entries + layout: Linearization order of the field's spatial dimensions + shape: Total shape (spatial and index) of the field + strides: Linearization strides of the field """ @staticmethod - def create_generic(field_name, spatial_dimensions, dtype: UserTypeSpec = np.float64, index_dimensions=0, + def create_generic(field_name, spatial_dimensions, dtype: UserTypeSpec | DynamicType = DynamicType.NUMERIC_TYPE, index_dimensions=0, layout='numpy', index_shape=None, field_type=FieldType.GENERIC) -> 'Field': """ Creates a generic field where the field size is not fixed i.e. can be called with arrays of different sizes @@ -177,15 +177,15 @@ class Field: for i in range(total_dimensions) ]) - dtype = create_type(dtype) - np_data_type = dtype.numpy_dtype - assert np_data_type is not None - - if np_data_type.fields is not None: + if not isinstance(dtype, DynamicType): + dtype = create_type(dtype) + + if isinstance(dtype, PsStructType): if index_dimensions != 0: raise ValueError("Structured arrays/fields are not allowed to have an index dimension") shape += (1,) strides += (1,) + if field_type == FieldType.STAGGERED and index_dimensions == 0: raise ValueError("A staggered field needs at least one index dimension") @@ -228,7 +228,7 @@ class Field: @staticmethod def create_fixed_size(field_name: str, shape: Tuple[int, ...], index_dimensions: int = 0, - dtype: UserTypeSpec = np.float64, layout: str = 'numpy', + dtype: UserTypeSpec | DynamicType = DynamicType.NUMERIC_TYPE, layout: str = 'numpy', strides: Optional[Sequence[int]] = None, field_type=FieldType.GENERIC) -> 'Field': """ @@ -256,11 +256,10 @@ class Field: assert len(strides) == len(shape) strides = tuple([s // np.dtype(dtype).itemsize for s in strides]) - dtype = create_type(dtype) - numpy_dtype = dtype.numpy_dtype - assert numpy_dtype is not None + if not isinstance(dtype, DynamicType): + dtype = create_type(dtype) - if numpy_dtype.fields is not None: + if isinstance(dtype, PsStructType): if index_dimensions != 0: raise ValueError("Structured arrays/fields are not allowed to have an index dimension") shape += (1,) @@ -277,7 +276,7 @@ class Field: self, field_name: str, field_type: FieldType, - dtype: UserTypeSpec, + dtype: UserTypeSpec | DynamicType, layout: tuple[int, ...], shape, strides @@ -287,7 +286,7 @@ class Field: assert isinstance(field_type, FieldType) assert len(shape) == len(strides) self.field_type = field_type - self._dtype = create_type(dtype) + self._dtype: PsType | DynamicType = create_type(dtype) if not isinstance(dtype, DynamicType) else dtype self._layout = normalize_layout(layout) self.shape = shape self.strides = strides @@ -351,12 +350,15 @@ class Field: return self.strides[self.spatial_dimensions:] @property - def dtype(self) -> PsType: + def dtype(self) -> PsType | DynamicType: return self._dtype @property - def itemsize(self): - return self.dtype.itemsize + def itemsize(self) -> int | None: + if isinstance(self.dtype, PsType): + return self.dtype.itemsize + else: + return None def __repr__(self): if any(isinstance(s, sp.Symbol) for s in self.spatial_shape): @@ -1094,17 +1096,20 @@ def _parse_description(description): result = type_description_regex.match(d) if result: data_type_str, size_info = result.group(1), result.group(2).strip().lower() - if data_type_str is None: - data_type_str = 'float64' - data_type_str = data_type_str.lower().strip() + if data_type_str is not None: + data_type_str = data_type_str.lower().strip() + + if data_type_str: + dtype = create_type(data_type_str) + else: + dtype = DynamicType.NUMERIC_TYPE - if not data_type_str: - data_type_str = 'float64' if size_info.endswith('d'): size_info = int(size_info[:-1]) else: size_info = tuple(int(e) for e in size_info.split(",")) - return data_type_str, size_info + + return dtype, size_info else: raise ValueError("Could not parse field description") diff --git a/tests/frontend/test_field.py b/tests/frontend/test_field.py index dc804491b..6521e114f 100644 --- a/tests/frontend/test_field.py +++ b/tests/frontend/test_field.py @@ -3,7 +3,7 @@ import pytest import sympy as sp import pystencils as ps -from pystencils import DEFAULTS +from pystencils import DEFAULTS, DynamicType, create_type, fields from pystencils.field import ( Field, FieldType, @@ -15,6 +15,7 @@ from pystencils.field import ( def test_field_basic(): f = Field.create_generic("f", spatial_dimensions=2) assert FieldType.is_generic(f) + assert f.dtype == DynamicType.NUMERIC_TYPE assert f["E"] == f[1, 0] assert f["N"] == f[0, 1] assert "_" in f.center._latex("dummy") @@ -41,17 +42,16 @@ def test_field_basic(): assert f1.ndim == f.ndim assert f1.values_per_cell() == f.values_per_cell() - fixed = ps.fields("f(5, 5) : double[20, 20]") - assert fixed.neighbor_vector((1, 1)).shape == (5, 5) - f = Field.create_fixed_size("f", (10, 10), strides=(80, 8), dtype=np.float64) assert f.spatial_strides == (10, 1) assert f.index_strides == () assert f.center_vector == sp.Matrix([f.center]) + assert f.dtype == create_type("float64") f1 = f.new_field_with_different_name("f1") assert f1.ndim == f.ndim assert f1.values_per_cell() == f.values_per_cell() + assert f1.dtype == create_type("float64") f = Field.create_fixed_size("f", (8, 8, 2, 2), index_dimensions=2) assert f.center_vector == sp.Matrix([[f(0, 0), f(0, 1)], [f(1, 0), f(1, 1)]]) @@ -61,16 +61,42 @@ def test_field_basic(): neighbor = field_access.neighbor(coord_id=0, offset=-2) assert neighbor.offsets == (-1, 1) assert "_" in neighbor._latex("dummy") + assert f.dtype == DynamicType.NUMERIC_TYPE f = Field.create_fixed_size("f", (8, 8, 2, 2, 2), index_dimensions=3) assert f.center_vector == sp.Array( [[[f(i, j, k) for k in range(2)] for j in range(2)] for i in range(2)] ) + assert f.dtype == DynamicType.NUMERIC_TYPE f = Field.create_generic("f", spatial_dimensions=5, index_dimensions=2) field_access = f[1, -1, 2, -3, 0](1, 0) assert field_access.offsets == (1, -1, 2, -3, 0) assert field_access.index == (1, 0) + assert f.dtype == DynamicType.NUMERIC_TYPE + + +def test_field_description_parsing(): + f, g = fields("f(1), g(3): [2D]") + + assert f.dtype == g.dtype == DynamicType.NUMERIC_TYPE + assert f.spatial_dimensions == g.spatial_dimensions == 2 + assert f.index_shape == (1,) + assert g.index_shape == (3,) + + h = fields("h: float32[3D]") + + assert h.index_shape == () + assert h.spatial_dimensions == 3 + assert h.index_dimensions == 0 + assert h.dtype == create_type("float32") + + f: Field = fields("f(5, 5) : double[20, 20]") + + assert f.dtype == create_type("float64") + assert f.spatial_shape == (20, 20) + assert f.index_shape == (5, 5) + assert f.neighbor_vector((1, 1)).shape == (5, 5) def test_error_handling(): @@ -145,7 +171,7 @@ def test_error_handling(): def test_decorator_scoping(): - dst = ps.fields("dst : double[2D]") + dst = fields("dst : double[2D]") def f1(): a = sp.Symbol("a") @@ -165,7 +191,7 @@ def test_decorator_scoping(): def test_string_creation(): - x, y, z = ps.fields(" x(4), y(3,5) z : double[ 3, 47]") + x, y, z = fields(" x(4), y(3,5) z : double[ 3, 47]") assert x.index_shape == (4,) assert y.index_shape == (3, 5) assert z.spatial_shape == (3, 47) @@ -173,9 +199,9 @@ def test_string_creation(): def test_itemsize(): - x = ps.fields("x: float32[1d]") - y = ps.fields("y: float64[2d]") - i = ps.fields("i: int16[1d]") + x = fields("x: float32[1d]") + y = fields("y: float64[2d]") + i = fields("i: int16[1d]") assert x.itemsize == 4 assert y.itemsize == 8 @@ -249,7 +275,7 @@ def test_memory_layout_descriptors(): def test_staggered(): # D2Q5 - j1, j2, j3 = ps.fields( + j1, j2, j3 = fields( "j1(2), j2(2,2), j3(2,2,2) : double[2D]", field_type=FieldType.STAGGERED ) @@ -296,7 +322,7 @@ def test_staggered(): ) # D2Q9 - k1, k2 = ps.fields("k1(4), k2(2) : double[2D]", field_type=FieldType.STAGGERED) + k1, k2 = fields("k1(4), k2(2) : double[2D]", field_type=FieldType.STAGGERED) assert k1[1, 1](2) == k1.staggered_access("NE") assert k1[0, 0](2) == k1.staggered_access("SW") @@ -319,7 +345,7 @@ def test_staggered(): ] # sign reversed when using as flux field - r = ps.fields("r(2) : double[2D]", field_type=FieldType.STAGGERED_FLUX) + r = fields("r(2) : double[2D]", field_type=FieldType.STAGGERED_FLUX) assert r[0, 0](0) == r.staggered_access("W") assert -r[1, 0](0) == r.staggered_access("E") -- GitLab