From 5bdabcf6f147241e69431809686c910c5451e58f Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Fri, 24 Jan 2025 09:29:27 +0100 Subject: [PATCH] adapt JIT compilers to DynamicType in fields. Enable typechecking for the field module. --- docs/source/contributing/dev-workflow.md | 8 +- mypy.ini | 3 + src/pystencils/field.py | 684 +++++++++++++++------ src/pystencils/jit/cpu_extension_module.py | 89 +-- src/pystencils/jit/gpu_cupy.py | 14 +- tests/frontend/test_field.py | 4 +- tests/runtime/test_datahandling.py | 2 +- 7 files changed, 574 insertions(+), 230 deletions(-) diff --git a/docs/source/contributing/dev-workflow.md b/docs/source/contributing/dev-workflow.md index 2aee09ba2..fe8b70e77 100644 --- a/docs/source/contributing/dev-workflow.md +++ b/docs/source/contributing/dev-workflow.md @@ -118,10 +118,10 @@ mypy src/pystencils :::: :::{note} -Type checking is currently restricted to the `codegen`, `jit`, `backend`, and `types` modules, -since most code in the remaining modules is significantly older and is not comprehensively -type-annotated. As more modules are updated with type annotations, this list will expand in the future. -If you think a new module is ready to be type-checked, add an exception clause for it in the `mypy.ini` file. +Type checking is currently restricted only to a few modules, which are listed in the `mypy.ini` file. +Most code in the remaining modules is significantly older and is not comprehensively type-annotated. +As more modules are updated with type annotations, this list will expand in the future. +If you think a new module is ready to be type-checked, add an exception clause to `mypy.ini`. ::: ## Running the Test Suite diff --git a/mypy.ini b/mypy.ini index 08f073f7c..c8a7195e2 100644 --- a/mypy.ini +++ b/mypy.ini @@ -17,6 +17,9 @@ ignore_errors = False [mypy-pystencils.jit.*] ignore_errors = False +[mypy-pystencils.field] +ignore_errors = False + [mypy-pystencils.sympyextensions.typed_sympy] ignore_errors = False diff --git a/src/pystencils/field.py b/src/pystencils/field.py index 826c551a7..2a74824c4 100644 --- a/src/pystencils/field.py +++ b/src/pystencils/field.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import functools import hashlib import operator @@ -14,14 +16,18 @@ from sympy.core.cache import cacheit from .defaults import DEFAULTS from .alignedarray import aligned_empty from .spatial_coordinates import x_staggered_vector, x_vector -from .stencil import direction_string_to_offset, inverse_direction, offset_to_direction_string +from .stencil import ( + direction_string_to_offset, + inverse_direction, + offset_to_direction_string, +) from .types import PsType, PsStructType, create_type from .sympyextensions.typed_sympy import TypedSymbol, DynamicType from .sympyextensions import is_integer_sequence from .types import UserTypeSpec -__all__ = ['Field', 'fields', 'FieldType', 'Field'] +__all__ = ["Field", "fields", "FieldType", "Field"] class FieldType(Enum): @@ -63,7 +69,10 @@ class FieldType(Enum): @staticmethod def is_staggered(field): assert isinstance(field, Field) - return field.field_type == FieldType.STAGGERED or field.field_type == FieldType.STAGGERED_FLUX + return ( + field.field_type == FieldType.STAGGERED + or field.field_type == FieldType.STAGGERED_FLUX + ) @staticmethod def is_staggered_flux(field): @@ -132,8 +141,15 @@ class Field: """ @staticmethod - def create_generic(field_name, spatial_dimensions, dtype: UserTypeSpec | DynamicType = DynamicType.NUMERIC_TYPE, index_dimensions=0, - layout='numpy', index_shape=None, field_type=FieldType.GENERIC) -> 'Field': + def create_generic( + field_name, + spatial_dimensions, + dtype: UserTypeSpec | DynamicType = DynamicType.NUMERIC_TYPE, + index_dimensions=0, + layout="numpy", + index_shape=None, + field_type=FieldType.GENERIC, + ) -> "Field": """ Creates a generic field where the field size is not fixed i.e. can be called with arrays of different sizes @@ -159,30 +175,45 @@ class Field: layout = spatial_layout_string_to_tuple(layout, dim=spatial_dimensions) total_dimensions = spatial_dimensions + index_dimensions + shape: tuple[TypedSymbol | int, ...] + if index_shape is None or len(index_shape) == 0: - shape = tuple([ - TypedSymbol(DEFAULTS.field_shape_name(field_name, i), DynamicType.INDEX_TYPE) - for i in range(total_dimensions) - ]) + shape = tuple( + [ + TypedSymbol( + DEFAULTS.field_shape_name(field_name, i), DynamicType.INDEX_TYPE + ) + for i in range(total_dimensions) + ] + ) else: shape = tuple( [ - TypedSymbol(DEFAULTS.field_shape_name(field_name, i), DynamicType.INDEX_TYPE) + TypedSymbol( + DEFAULTS.field_shape_name(field_name, i), DynamicType.INDEX_TYPE + ) for i in range(spatial_dimensions) - ] + list(index_shape) + ] + + list(index_shape) ) - strides = tuple([ - TypedSymbol(DEFAULTS.field_stride_name(field_name, i), DynamicType.INDEX_TYPE) - for i in range(total_dimensions) - ]) + strides: tuple[TypedSymbol | int, ...] = tuple( + [ + TypedSymbol( + DEFAULTS.field_stride_name(field_name, i), DynamicType.INDEX_TYPE + ) + for i in range(total_dimensions) + ] + ) if not isinstance(dtype, DynamicType): dtype = create_type(dtype) if isinstance(dtype, PsStructType): if index_dimensions != 0: - raise ValueError("Structured arrays/fields are not allowed to have an index dimension") + raise ValueError( + "Structured arrays/fields are not allowed to have an index dimension" + ) shape += (1,) strides += (1,) @@ -192,8 +223,12 @@ class Field: return Field(field_name, field_type, dtype, layout, shape, strides) @staticmethod - def create_from_numpy_array(field_name: str, array: np.ndarray, index_dimensions: int = 0, - field_type=FieldType.GENERIC) -> 'Field': + def create_from_numpy_array( + field_name: str, + array: np.ndarray, + index_dimensions: int = 0, + field_type=FieldType.GENERIC, + ) -> "Field": """Creates a field based on the layout, data type, and shape of a given numpy array. Kernels created for these kind of fields can only be called with arrays of the same layout, shape and type. @@ -206,7 +241,9 @@ class Field: """ spatial_dimensions = len(array.shape) - index_dimensions if spatial_dimensions < 1: - raise ValueError("Too many index dimensions. At least one spatial dimension required") + raise ValueError( + "Too many index dimensions. At least one spatial dimension required" + ) full_layout = get_layout_of_array(array) spatial_layout = tuple([i for i in full_layout if i < spatial_dimensions]) @@ -218,19 +255,28 @@ class Field: numpy_dtype = np.dtype(array.dtype) if numpy_dtype.fields is not None: if index_dimensions != 0: - raise ValueError("Structured arrays/fields are not allowed to have an index dimension") + raise ValueError( + "Structured arrays/fields are not allowed to have an index dimension" + ) shape += (1,) strides += (1,) if field_type == FieldType.STAGGERED and index_dimensions == 0: raise ValueError("A staggered field needs at least one index dimension") - return Field(field_name, field_type, array.dtype, spatial_layout, shape, strides) + return Field( + field_name, field_type, array.dtype, spatial_layout, shape, strides + ) @staticmethod - def create_fixed_size(field_name: str, shape: Tuple[int, ...], index_dimensions: int = 0, - dtype: UserTypeSpec | DynamicType = DynamicType.NUMERIC_TYPE, layout: str = 'numpy', - strides: Optional[Sequence[int]] = None, - field_type=FieldType.GENERIC) -> 'Field': + def create_fixed_size( + field_name: str, + shape: tuple[int, ...], + index_dimensions: int = 0, + dtype: UserTypeSpec = np.float64, + layout: str | tuple[int, ...] = "numpy", + strides: Optional[Sequence[int]] = None, + field_type=FieldType.GENERIC, + ) -> "Field": """ Creates a field with fixed sizes i.e. can be called only with arrays of the same size and layout @@ -243,34 +289,58 @@ class Field: strides: strides in bytes or None to automatically compute them from shape (assuming no padding) field_type: kind of field """ + if isinstance(dtype, DynamicType): + raise ValueError( + "Parameter `dtype` to `Field.create_fixed_size` does not accept `DynamicType`." + ) + spatial_dimensions = len(shape) - index_dimensions assert spatial_dimensions >= 1 if isinstance(layout, str): - layout = layout_string_to_tuple(layout, spatial_dimensions + index_dimensions) + layout = layout_string_to_tuple( + layout, spatial_dimensions + index_dimensions + ) + + dtype = create_type(dtype) + + shape_tuple = tuple(int(s) for s in shape) + strides_tuple: tuple[int, ...] - shape = tuple(int(s) for s in shape) if strides is None: - strides = compute_strides(shape, layout) + strides_tuple = compute_strides(shape_tuple, layout) else: - assert len(strides) == len(shape) - strides = tuple([s // np.dtype(dtype).itemsize for s in strides]) + assert len(strides) == len(shape_tuple) + np_type = dtype.numpy_dtype - if not isinstance(dtype, DynamicType): - dtype = create_type(dtype) + if np_type is None: + raise ValueError( + f"Cannot create fixed-size field of data type {dtype} which has no NumPy representation." + ) + + strides_tuple = tuple([s // np_type.itemsize for s in strides]) if isinstance(dtype, PsStructType): if index_dimensions != 0: - raise ValueError("Structured arrays/fields are not allowed to have an index dimension") - shape += (1,) - strides += (1,) + raise ValueError( + "Structured arrays/fields are not allowed to have an index dimension" + ) + shape_tuple += (1,) + strides_tuple += (1,) if field_type == FieldType.STAGGERED and index_dimensions == 0: raise ValueError("A staggered field needs at least one index dimension") spatial_layout = list(layout) for i in range(spatial_dimensions, len(layout)): spatial_layout.remove(i) - return Field(field_name, field_type, dtype, tuple(spatial_layout), shape, strides) + return Field( + field_name, + field_type, + dtype, + tuple(spatial_layout), + shape_tuple, + strides_tuple, + ) def __init__( self, @@ -279,14 +349,16 @@ class Field: dtype: UserTypeSpec | DynamicType, layout: tuple[int, ...], shape, - strides + strides, ): """Do not use directly. Use static create* methods""" self._field_name = field_name assert isinstance(field_type, FieldType) assert len(shape) == len(strides) self.field_type = field_type - self._dtype: PsType | DynamicType = create_type(dtype) if not isinstance(dtype, DynamicType) else dtype + self._dtype: PsType | DynamicType = ( + create_type(dtype) if not isinstance(dtype, DynamicType) else dtype + ) self._layout = normalize_layout(layout) self.shape = shape self.strides = strides @@ -298,9 +370,23 @@ class Field: def new_field_with_different_name(self, new_name): if self.has_fixed_shape: - return Field(new_name, self.field_type, self._dtype, self._layout, self.shape, self.strides) + return Field( + new_name, + self.field_type, + self._dtype, + self._layout, + self.shape, + self.strides, + ) else: - return Field(new_name, self.field_type, self.dtype, self.layout, self.shape, self.strides) + return Field( + new_name, + self.field_type, + self.dtype, + self.layout, + self.shape, + self.strides, + ) @property def spatial_dimensions(self) -> int: @@ -327,7 +413,7 @@ class Field: @property def spatial_shape(self) -> Tuple[int, ...]: - return self.shape[:self.spatial_dimensions] + return self.shape[: self.spatial_dimensions] @property def has_fixed_shape(self): @@ -343,7 +429,7 @@ class Field: @property def spatial_strides(self): - return self.strides[:self.spatial_dimensions] + return self.strides[: self.spatial_dimensions] @property def index_strides(self): @@ -362,15 +448,15 @@ class Field: def __repr__(self): if any(isinstance(s, sp.Symbol) for s in self.spatial_shape): - spatial_shape_str = f'{self.spatial_dimensions}d' + spatial_shape_str = f"{self.spatial_dimensions}d" else: - spatial_shape_str = ','.join(str(i) for i in self.spatial_shape) - index_shape_str = ','.join(str(i) for i in self.index_shape) + spatial_shape_str = ",".join(str(i) for i in self.spatial_shape) + index_shape_str = ",".join(str(i) for i in self.index_shape) if self.index_shape: - return f'{self._field_name}({index_shape_str}): {self.dtype}[{spatial_shape_str}]' + return f"{self._field_name}({index_shape_str}): {self.dtype}[{spatial_shape_str}]" else: - return f'{self._field_name}: {self.dtype}[{spatial_shape_str}]' + return f"{self._field_name}: {self.dtype}[{spatial_shape_str}]" def __str__(self): return self.name @@ -391,12 +477,26 @@ class Field: elif len(index_shape) == 1: return sp.Matrix([self(i) for i in range(index_shape[0])]) elif len(index_shape) == 2: - return sp.Matrix([[self(i, j) for j in range(index_shape[1])] for i in range(index_shape[0])]) + return sp.Matrix( + [ + [self(i, j) for j in range(index_shape[1])] + for i in range(index_shape[0]) + ] + ) elif len(index_shape) == 3: - return sp.Array([[[self(i, j, k) for k in range(index_shape[2])] - for j in range(index_shape[1])] for i in range(index_shape[0])]) + return sp.Array( + [ + [ + [self(i, j, k) for k in range(index_shape[2])] + for j in range(index_shape[1]) + ] + for i in range(index_shape[0]) + ] + ) else: - raise NotImplementedError("center_vector is not implemented for more than 3 index dimensions") + raise NotImplementedError( + "center_vector is not implemented for more than 3 index dimensions" + ) @property def center(self): @@ -412,12 +512,20 @@ class Field: if self.index_dimensions == 0: return sp.Matrix([self.__getitem__(offset)]) elif self.index_dimensions == 1: - return sp.Matrix([self.__getitem__(offset)(i) for i in range(self.index_shape[0])]) + return sp.Matrix( + [self.__getitem__(offset)(i) for i in range(self.index_shape[0])] + ) elif self.index_dimensions == 2: - return sp.Matrix([[self.__getitem__(offset)(i, k) for k in range(self.index_shape[1])] - for i in range(self.index_shape[0])]) + return sp.Matrix( + [ + [self.__getitem__(offset)(i, k) for k in range(self.index_shape[1])] + for i in range(self.index_shape[0]) + ] + ) else: - raise NotImplementedError("neighbor_vector is not implemented for more than 2 index dimensions") + raise NotImplementedError( + "neighbor_vector is not implemented for more than 2 index dimensions" + ) def __getitem__(self, offset): if type(offset) is np.ndarray: @@ -427,7 +535,9 @@ class Field: if type(offset) is not tuple: offset = (offset,) if len(offset) != self.spatial_dimensions: - raise ValueError(f"Wrong number of spatial indices: Got {len(offset)}, expected {self.spatial_dimensions}") + raise ValueError( + f"Wrong number of spatial indices: Got {len(offset)}, expected {self.spatial_dimensions}" + ) return Field.Access(self, offset) def absolute_access(self, offset, index): @@ -450,7 +560,9 @@ class Field: offset = tuple(direction_string_to_offset(offset, self.spatial_dimensions)) offset = tuple([o * sp.Rational(1, 2) for o in offset]) if len(offset) != self.spatial_dimensions: - raise ValueError(f"Wrong number of spatial indices: Got {len(offset)}, expected {self.spatial_dimensions}") + raise ValueError( + f"Wrong number of spatial indices: Got {len(offset)}, expected {self.spatial_dimensions}" + ) prefactor = 1 neighbor_vec = [0] * len(offset) @@ -464,25 +576,33 @@ class Field: if FieldType.is_staggered_flux(self): prefactor = -1 if neighbor not in self.staggered_stencil: - raise ValueError(f"{offset_orig} is not a valid neighbor for the {self.staggered_stencil_name} stencil") + raise ValueError( + f"{offset_orig} is not a valid neighbor for the {self.staggered_stencil_name} stencil" + ) offset = tuple(sp.Matrix(offset) - sp.Rational(1, 2) * sp.Matrix(neighbor_vec)) idx = self.staggered_stencil.index(neighbor) - if self.index_dimensions == 1: # this field stores a scalar value at each staggered position + if ( + self.index_dimensions == 1 + ): # this field stores a scalar value at each staggered position if index is not None: raise ValueError("Cannot specify an index for a scalar staggered field") return prefactor * Field.Access(self, offset, (idx,)) else: # this field stores a vector or tensor at each staggered position if index is None: - raise ValueError(f"Wrong number of indices: Got 0, expected {self.index_dimensions - 1}") + raise ValueError( + f"Wrong number of indices: Got 0, expected {self.index_dimensions - 1}" + ) if type(index) is np.ndarray: index = tuple(index) if type(index) is not tuple: index = (index,) if self.index_dimensions != len(index) + 1: - raise ValueError(f"Wrong number of indices: Got {len(index)}, expected {self.index_dimensions - 1}") + raise ValueError( + f"Wrong number of indices: Got {len(index)}, expected {self.index_dimensions - 1}" + ) return prefactor * Field.Access(self, offset, (idx, *index)) @@ -493,30 +613,54 @@ class Field: if self.index_dimensions == 1: return sp.Matrix([self.staggered_access(offset)]) elif self.index_dimensions == 2: - return sp.Matrix([self.staggered_access(offset, i) for i in range(self.index_shape[1])]) + return sp.Matrix( + [self.staggered_access(offset, i) for i in range(self.index_shape[1])] + ) elif self.index_dimensions == 3: - return sp.Matrix([[self.staggered_access(offset, (i, k)) for k in range(self.index_shape[2])] - for i in range(self.index_shape[1])]) + return sp.Matrix( + [ + [ + self.staggered_access(offset, (i, k)) + for k in range(self.index_shape[2]) + ] + for i in range(self.index_shape[1]) + ] + ) else: - raise NotImplementedError("staggered_vector_access is not implemented for more than 3 index dimensions") + raise NotImplementedError( + "staggered_vector_access is not implemented for more than 3 index dimensions" + ) @property def staggered_stencil(self): assert FieldType.is_staggered(self) stencils = { - 2: { - 2: ["W", "S"], # D2Q5 - 4: ["W", "S", "SW", "NW"] # D2Q9 - }, + 2: {2: ["W", "S"], 4: ["W", "S", "SW", "NW"]}, # D2Q5 # D2Q9 3: { 3: ["W", "S", "B"], # D3Q7 7: ["W", "S", "B", "BSW", "TSW", "BNW", "TNW"], # D3Q15 9: ["W", "S", "B", "SW", "NW", "BW", "TW", "BS", "TS"], # D3Q19 - 13: ["W", "S", "B", "SW", "NW", "BW", "TW", "BS", "TS", "BSW", "TSW", "BNW", "TNW"] # D3Q27 - } + 13: [ + "W", + "S", + "B", + "SW", + "NW", + "BW", + "TW", + "BS", + "TS", + "BSW", + "TSW", + "BNW", + "TNW", + ], # D3Q27 + }, } if not self.index_shape[0] in stencils[self.spatial_dimensions]: - raise ValueError(f"No known stencil has {self.index_shape[0]} staggered points") + raise ValueError( + f"No known stencil has {self.index_shape[0]} staggered points" + ) return stencils[self.spatial_dimensions][self.index_shape[0]] @property @@ -529,13 +673,15 @@ class Field: return Field.Access(self, center)(*args, **kwargs) def hashable_contents(self): - return (self._layout, - self.shape, - self.strides, - self.field_type, - self._field_name, - self.latex_name, - self._dtype) + return ( + self._layout, + self.shape, + self.strides, + self.field_type, + self._field_name, + self.latex_name, + self._dtype, + ) def __hash__(self): return hash(self.hashable_contents()) @@ -547,36 +693,53 @@ class Field: @property def physical_coordinates(self): - if hasattr(self.coordinate_transform, '__call__'): - return self.coordinate_transform(self.coordinate_origin + x_vector(self.spatial_dimensions)) + if hasattr(self.coordinate_transform, "__call__"): + return self.coordinate_transform( + self.coordinate_origin + x_vector(self.spatial_dimensions) + ) else: - return self.coordinate_transform @ (self.coordinate_origin + x_vector(self.spatial_dimensions)) + return self.coordinate_transform @ ( + self.coordinate_origin + x_vector(self.spatial_dimensions) + ) @property def physical_coordinates_staggered(self): - return self.coordinate_transform @ \ - (self.coordinate_origin + x_staggered_vector(self.spatial_dimensions)) + return self.coordinate_transform @ ( + self.coordinate_origin + x_staggered_vector(self.spatial_dimensions) + ) def index_to_physical(self, index_coordinates: sp.Matrix, staggered=False): if staggered: - index_coordinates = sp.Matrix([0.5] * len(self.coordinate_origin)) + index_coordinates - if hasattr(self.coordinate_transform, '__call__'): + index_coordinates = ( + sp.Matrix([0.5] * len(self.coordinate_origin)) + index_coordinates + ) + if hasattr(self.coordinate_transform, "__call__"): return self.coordinate_transform(self.coordinate_origin + index_coordinates) else: - return self.coordinate_transform @ (self.coordinate_origin + index_coordinates) + return self.coordinate_transform @ ( + self.coordinate_origin + index_coordinates + ) def physical_to_index(self, physical_coordinates: sp.Matrix, staggered=False): - if hasattr(self.coordinate_transform, '__call__'): - if hasattr(self.coordinate_transform, 'inv'): - return self.coordinate_transform.inv()(physical_coordinates) - self.coordinate_origin + if hasattr(self.coordinate_transform, "__call__"): + if hasattr(self.coordinate_transform, "inv"): + return ( + self.coordinate_transform.inv()(physical_coordinates) + - self.coordinate_origin + ) else: - idx = sp.Matrix(sp.symbols(f'index_coordinates:{self.ndim}', real=True)) + idx = sp.Matrix(sp.symbols(f"index_coordinates:{self.ndim}", real=True)) rtn = sp.solve(self.index_to_physical(idx) - physical_coordinates, idx) - assert rtn, f'Could not find inverese of coordinate_transform: {self.index_to_physical(idx)}' + assert ( + rtn + ), f"Could not find inverese of coordinate_transform: {self.index_to_physical(idx)}" return rtn else: - rtn = self.coordinate_transform.inv() @ physical_coordinates - self.coordinate_origin + rtn = ( + self.coordinate_transform.inv() @ physical_coordinates + - self.coordinate_origin + ) if staggered: rtn = sp.Matrix([i - 0.5 for i in rtn]) @@ -605,18 +768,40 @@ class Field: >>> central_y_component.at_index(0) # change component v_C^0 """ + _iterable = False # see https://i10git.cs.fau.de/pycodegen/pystencils/-/merge_requests/166#note_10680 __match_args__ = ("field", "offsets", "index") + # for the type checker + _field: Field + _offsets: tuple[int | sp.Basic, ...] + _offsetName: str + _superscript: None | str + _index: tuple[int | sp.Basic, ...] | str + _indirect_addressing_fields: set[Field] + _is_absolute_access: bool + def __new__(cls, name, *args, **kwargs): obj = Field.Access.__xnew_cached_(cls, name, *args, **kwargs) return obj - def __new_stage2__(self, field, offsets=(0, 0, 0), idx=None, is_absolute_access=False, dtype=None): + def __new_stage2__( # type: ignore + self, + field: Field, + offsets: tuple[int, ...] = (0, 0, 0), + idx: None | tuple[int, ...] | str = None, + is_absolute_access: bool = False, + dtype: PsType | None = None, + ): field_name = field.name offsets_and_index = (*offsets, *idx) if idx is not None else offsets - constant_offsets = not any([isinstance(o, sp.Basic) and not o.is_Integer for o in offsets_and_index]) + constant_offsets = not any( + [ + isinstance(o, sp.Basic) and not o.is_Integer + for o in offsets_and_index + ] + ) if not idx: idx = tuple([0] * field.index_dimensions) @@ -630,31 +815,36 @@ class Field: else: idx_str = ",".join([str(e) for e in idx]) superscript = idx_str - if field.has_fixed_index_shape and not isinstance(field.dtype, PsStructType): + if field.has_fixed_index_shape and not isinstance( + field.dtype, PsStructType + ): for i, bound in zip(idx, field.index_shape): if i >= bound: raise ValueError("Field index out of bounds") else: - offset_name = hashlib.md5(pickle.dumps(offsets_and_index)).hexdigest()[:12] + offset_name = hashlib.md5(pickle.dumps(offsets_and_index)).hexdigest()[ + :12 + ] superscript = None symbol_name = f"{field_name}_{offset_name}" if superscript is not None: symbol_name += "^" + superscript + obj: Field.Access if dtype is not None: obj = super(Field.Access, self).__xnew__(self, symbol_name, dtype) else: obj = super(Field.Access, self).__xnew__(self, symbol_name, field.dtype) obj._field = field - obj._offsets = [] + _offsets: list[sp.Basic | int] = [] for o in offsets: if isinstance(o, sp.Basic): - obj._offsets.append(o) + _offsets.append(o) else: - obj._offsets.append(int(o)) - obj._offsets = tuple(sp.sympify(obj._offsets)) + _offsets.append(int(o)) + obj._offsets = tuple(sp.sympify(_offsets)) obj._offsetName = offset_name obj._superscript = superscript obj._index = idx @@ -662,19 +852,33 @@ class Field: obj._indirect_addressing_fields = set() for e in chain(obj._offsets, obj._index): if isinstance(e, sp.Basic): - obj._indirect_addressing_fields.update(a.field for a in e.atoms(Field.Access)) + obj._indirect_addressing_fields.update( + a.field for a in e.atoms(Field.Access) + ) obj._is_absolute_access = is_absolute_access return obj def __getnewargs__(self): - return self.field, self.offsets, self.index, self.is_absolute_access, self.dtype + return ( + self.field, + self.offsets, + self.index, + self.is_absolute_access, + self.dtype, + ) def __getnewargs_ex__(self): - return (self.field, self.offsets, self.index, self.is_absolute_access, self.dtype), {} + return ( + self.field, + self.offsets, + self.index, + self.is_absolute_access, + self.dtype, + ), {} # noinspection SpellCheckingInspection - __xnew__ = staticmethod(__new_stage2__) + __xnew__ = staticmethod(__new_stage2__) # type: ignore # noinspection SpellCheckingInspection __xnew_cached_ = staticmethod(cacheit(__new_stage2__)) @@ -688,22 +892,34 @@ class Field: idx = () if len(idx) != self.field.index_dimensions: - raise ValueError(f"Wrong number of indices: Got {len(idx)}, expected {self.field.index_dimensions}") + raise ValueError( + f"Wrong number of indices: Got {len(idx)}, expected {self.field.index_dimensions}" + ) if len(idx) == 1 and isinstance(idx[0], str): struct_type = self.field.dtype assert isinstance(struct_type, PsStructType) dtype = struct_type.get_member(idx[0]).dtype - return Field.Access(self.field, self._offsets, idx, - is_absolute_access=self.is_absolute_access, dtype=dtype) + return Field.Access( + self.field, + self._offsets, + idx, + is_absolute_access=self.is_absolute_access, + dtype=dtype, + ) else: - return Field.Access(self.field, self._offsets, idx, - is_absolute_access=self.is_absolute_access, dtype=self.dtype) + return Field.Access( + self.field, + self._offsets, + idx, + is_absolute_access=self.is_absolute_access, + dtype=self.dtype, + ) def __getitem__(self, *idx): return self.__call__(*idx) @property - def field(self) -> 'Field': + def field(self) -> "Field": """Field that the Access points to""" return self._field @@ -715,7 +931,7 @@ class Field: @property def required_ghost_layers(self) -> int: """Largest spatial distance that is accessed.""" - return int(np.max(np.abs(self._offsets))) + return int(np.max(np.abs(self._offsets))) # type: ignore @property def nr_of_coordinates(self): @@ -737,7 +953,7 @@ class Field: """Value of index coordinates as tuple.""" return self._index - def neighbor(self, coord_id: int, offset: int) -> 'Field.Access': + def neighbor(self, coord_id: int, offset: int) -> "Field.Access": """Returns a new Access with changed spatial coordinates. Args: @@ -751,10 +967,15 @@ class Field: """ offset_list = list(self.offsets) offset_list[coord_id] += offset - return Field.Access(self.field, tuple(offset_list), self.index, - is_absolute_access=self.is_absolute_access, dtype=self.dtype) + return Field.Access( + self.field, + tuple(offset_list), + self.index, + is_absolute_access=self.is_absolute_access, + dtype=self.dtype, + ) - def get_shifted(self, *shift) -> 'Field.Access': + def get_shifted(self, *shift) -> "Field.Access": """Returns a new Access with changed spatial coordinates Example: @@ -762,13 +983,15 @@ class Field: >>> f[0,0].get_shifted(1, 1) f_NE """ - return Field.Access(self.field, - tuple(a + b for a, b in zip(shift, self.offsets)), - self.index, - is_absolute_access=self.is_absolute_access, - dtype=self.dtype) + return Field.Access( + self.field, + tuple(a + b for a, b in zip(shift, self.offsets)), + self.index, + is_absolute_access=self.is_absolute_access, + dtype=self.dtype, + ) - def at_index(self, *idx_tuple) -> 'Field.Access': + def at_index(self, *idx_tuple) -> "Field.Access": """Returns new Access with changed index. Example: @@ -776,15 +999,22 @@ class Field: >>> f(0).at_index(8) f_C^8 """ - return Field.Access(self.field, self.offsets, idx_tuple, - is_absolute_access=self.is_absolute_access, dtype=self.dtype) + return Field.Access( + self.field, + self.offsets, + idx_tuple, + is_absolute_access=self.is_absolute_access, + dtype=self.dtype, + ) def _eval_subs(self, old, new): - return Field.Access(self.field, - tuple(sp.sympify(a).subs(old, new) for a in self.offsets), - tuple(sp.sympify(a).subs(old, new) for a in self.index), - is_absolute_access=self.is_absolute_access, - dtype=self.dtype) + return Field.Access( + self.field, + tuple(sp.sympify(a).subs(old, new) for a in self.offsets), + tuple(sp.sympify(a).subs(old, new) for a in self.index), + is_absolute_access=self.is_absolute_access, + dtype=self.dtype, + ) @property def is_absolute_access(self) -> bool: @@ -792,30 +1022,43 @@ class Field: return self._is_absolute_access @property - def indirect_addressing_fields(self) -> Set['Field']: + def indirect_addressing_fields(self) -> Set["Field"]: """Returns a set of fields that the access depends on. - e.g. f[index_field[1, 0]], the outer access to f depends on index_field - """ + e.g. f[index_field[1, 0]], the outer access to f depends on index_field + """ return self._indirect_addressing_fields def _hashable_content(self): super_class_contents = super(Field.Access, self)._hashable_content() - return (super_class_contents, self._field.hashable_contents(), *self._index, - *self._offsets, self._is_absolute_access) + return ( + super_class_contents, + self._field.hashable_contents(), + *self._index, + *self._offsets, + self._is_absolute_access, + ) def _staggered_offset(self, offsets, index): assert FieldType.is_staggered(self._field) neighbor = self._field.staggered_stencil[index] - neighbor = direction_string_to_offset(neighbor, self._field.spatial_dimensions) - return [(o + sp.Rational(int(neighbor[i]), 2)) for i, o in enumerate(offsets)] + neighbor = direction_string_to_offset( + neighbor, self._field.spatial_dimensions + ) + return [ + (o + sp.Rational(int(neighbor[i]), 2)) for i, o in enumerate(offsets) + ] def _latex(self, _): n = self._field.latex_name if self._field.latex_name else self._field.name offset_str = ",".join([sp.latex(o) for o in self.offsets]) if FieldType.is_staggered(self._field): - offset_str = ",".join([sp.latex(self._staggered_offset(self.offsets, self.index[0])[i]) - for i in range(len(self.offsets))]) + offset_str = ",".join( + [ + sp.latex(self._staggered_offset(self.offsets, self.index[0])[i]) + for i in range(len(self.offsets)) + ] + ) if self.is_absolute_access: offset_str = f"\\mathbf{offset_str}" elif self.field.spatial_dimensions > 1: @@ -836,8 +1079,12 @@ class Field: n = self._field.latex_name if self._field.latex_name else self._field.name offset_str = ",".join([sp.latex(o) for o in self.offsets]) if FieldType.is_staggered(self._field): - offset_str = ",".join([sp.latex(self._staggered_offset(self.offsets, self.index[0])[i]) - for i in range(len(self.offsets))]) + offset_str = ",".join( + [ + sp.latex(self._staggered_offset(self.offsets, self.index[0])[i]) + for i in range(len(self.offsets)) + ] + ) if self.is_absolute_access: offset_str = f"[abs]{offset_str}" @@ -853,8 +1100,13 @@ class Field: return f"{n}[{offset_str}]" -def fields(description=None, index_dimensions=0, layout=None, - field_type=FieldType.GENERIC, **kwargs) -> Union[Field, List[Field]]: +def fields( + description=None, + index_dimensions=0, + layout=None, + field_type=FieldType.GENERIC, + **kwargs, +) -> Union[Field, List[Field]]: """Creates pystencils fields from a string description. Examples: @@ -888,31 +1140,63 @@ def fields(description=None, index_dimensions=0, layout=None, result = [] if description: field_descriptions, dtype, shape = _parse_description(description) - layout = 'numpy' if layout is None else layout + layout = "numpy" if layout is None else layout for field_name, idx_shape in field_descriptions: if field_name in kwargs: arr = kwargs[field_name] - idx_shape_of_arr = () if not len(idx_shape) else arr.shape[-len(idx_shape):] + idx_shape_of_arr = ( + () if not len(idx_shape) else arr.shape[-len(idx_shape):] + ) assert idx_shape_of_arr == idx_shape - f = Field.create_from_numpy_array(field_name, kwargs[field_name], index_dimensions=len(idx_shape), - field_type=field_type) + f = Field.create_from_numpy_array( + field_name, + kwargs[field_name], + index_dimensions=len(idx_shape), + field_type=field_type, + ) elif isinstance(shape, tuple): - f = Field.create_fixed_size(field_name, shape + idx_shape, dtype=dtype, - index_dimensions=len(idx_shape), layout=layout, field_type=field_type) + f = Field.create_fixed_size( + field_name, + shape + idx_shape, + dtype=dtype, + index_dimensions=len(idx_shape), + layout=layout, + field_type=field_type, + ) elif isinstance(shape, int): - f = Field.create_generic(field_name, spatial_dimensions=shape, dtype=dtype, - index_shape=idx_shape, layout=layout, field_type=field_type) + f = Field.create_generic( + field_name, + spatial_dimensions=shape, + dtype=dtype, + index_shape=idx_shape, + layout=layout, + field_type=field_type, + ) elif shape is None: - f = Field.create_generic(field_name, spatial_dimensions=2, dtype=dtype, - index_shape=idx_shape, layout=layout, field_type=field_type) + f = Field.create_generic( + field_name, + spatial_dimensions=2, + dtype=dtype, + index_shape=idx_shape, + layout=layout, + field_type=field_type, + ) else: assert False result.append(f) else: - assert layout is None, "Layout can not be specified when creating Field from numpy array" + assert ( + layout is None + ), "Layout can not be specified when creating Field from numpy array" for field_name, arr in kwargs.items(): - result.append(Field.create_from_numpy_array(field_name, arr, index_dimensions=index_dimensions, - field_type=field_type)) + result.append( + Field.create_from_numpy_array( + field_name, + arr, + index_dimensions=index_dimensions, + field_type=field_type, + ) + ) if len(result) == 0: raise ValueError("Could not parse field description") @@ -922,16 +1206,27 @@ def fields(description=None, index_dimensions=0, layout=None, return result -def get_layout_from_strides(strides: Sequence[int], index_dimension_ids: Optional[List[int]] = None): +def get_layout_from_strides( + strides: Sequence[int], index_dimension_ids: Optional[List[int]] = None +): index_dimension_ids = [] if index_dimension_ids is None else index_dimension_ids coordinates = list(range(len(strides))) - relevant_strides = [stride for i, stride in enumerate(strides) if i not in index_dimension_ids] - result = [x for (y, x) in sorted(zip(relevant_strides, coordinates), key=lambda pair: pair[0], reverse=True)] + relevant_strides = [ + stride for i, stride in enumerate(strides) if i not in index_dimension_ids + ] + result = [ + x + for (y, x) in sorted( + zip(relevant_strides, coordinates), key=lambda pair: pair[0], reverse=True + ) + ] return normalize_layout(result) -def get_layout_of_array(arr: np.ndarray, index_dimension_ids: Optional[List[int]] = None): - """ Returns a list indicating the memory layout (linearization order) of the numpy array. +def get_layout_of_array( + arr: np.ndarray, index_dimension_ids: Optional[List[int]] = None +): + """Returns a list indicating the memory layout (linearization order) of the numpy array. Examples: >>> get_layout_of_array(np.zeros([3,3,3])) @@ -948,7 +1243,9 @@ def get_layout_of_array(arr: np.ndarray, index_dimension_ids: Optional[List[int] return get_layout_from_strides(arr.strides, index_dimension_ids) -def create_numpy_array_with_layout(shape, layout, alignment=False, byte_offset=0, **kwargs): +def create_numpy_array_with_layout( + shape, layout, alignment=False, byte_offset=0, **kwargs +): """Creates numpy array with given memory layout. Args: @@ -972,7 +1269,10 @@ def create_numpy_array_with_layout(shape, layout, alignment=False, byte_offset=0 if cur_layout[i] != layout[i]: index_to_swap_with = cur_layout.index(layout[i]) swaps.append((i, index_to_swap_with)) - cur_layout[i], cur_layout[index_to_swap_with] = cur_layout[index_to_swap_with], cur_layout[i] + cur_layout[i], cur_layout[index_to_swap_with] = ( + cur_layout[index_to_swap_with], + cur_layout[i], + ) assert tuple(cur_layout) == tuple(layout) shape = list(shape) @@ -980,7 +1280,7 @@ def create_numpy_array_with_layout(shape, layout, alignment=False, byte_offset=0 shape[a], shape[b] = shape[b], shape[a] if not alignment: - res = np.empty(shape, order='c', **kwargs) + res = np.empty(shape, order="c", **kwargs) else: res = aligned_empty(shape, alignment, byte_offset=byte_offset, **kwargs) @@ -992,37 +1292,43 @@ def create_numpy_array_with_layout(shape, layout, alignment=False, byte_offset=0 def spatial_layout_string_to_tuple(layout_str: str, dim: int) -> Tuple[int, ...]: if dim <= 0: raise ValueError("Dimensionality must be positive") - + layout_str = layout_str.lower() - if layout_str in ('fzyx', 'zyxf', 'soa', 'aos'): + if layout_str in ("fzyx", "zyxf", "soa", "aos"): if dim > 3: - raise ValueError(f"Invalid spatial dimensionality for layout descriptor {layout_str}: May be at most 3.") + raise ValueError( + f"Invalid spatial dimensionality for layout descriptor {layout_str}: May be at most 3." + ) return tuple(reversed(range(dim))) - - if layout_str in ('f', 'reverse_numpy'): + + if layout_str in ("f", "reverse_numpy"): return tuple(reversed(range(dim))) - elif layout_str in ('c', 'numpy'): + elif layout_str in ("c", "numpy"): return tuple(range(dim)) raise ValueError("Unknown layout descriptor " + layout_str) -def layout_string_to_tuple(layout_str, dim): +def layout_string_to_tuple(layout_str, dim) -> tuple[int, ...]: if dim <= 0: raise ValueError("Dimensionality must be positive") - + layout_str = layout_str.lower() - if layout_str == 'fzyx' or layout_str == 'soa': + if layout_str == "fzyx" or layout_str == "soa": if dim > 4: - raise ValueError(f"Invalid total dimensionality for layout descriptor {layout_str}: May be at most 4.") + raise ValueError( + f"Invalid total dimensionality for layout descriptor {layout_str}: May be at most 4." + ) return tuple(reversed(range(dim))) - elif layout_str == 'zyxf' or layout_str == 'aos': + elif layout_str == "zyxf" or layout_str == "aos": if dim > 4: - raise ValueError(f"Invalid total dimensionality for layout descriptor {layout_str}: May be at most 4.") + raise ValueError( + f"Invalid total dimensionality for layout descriptor {layout_str}: May be at most 4." + ) return tuple(reversed(range(dim - 1))) + (dim - 1,) - elif layout_str == 'f' or layout_str == 'reverse_numpy': + elif layout_str == "f" or layout_str == "reverse_numpy": return tuple(reversed(range(dim))) - elif layout_str == 'c' or layout_str == 'numpy': + elif layout_str == "c" or layout_str == "numpy": return tuple(range(dim)) raise ValueError("Unknown layout descriptor " + layout_str) @@ -1057,7 +1363,8 @@ def compute_strides(shape, layout): # ---------------------------------------- Parsing of string in fields() function -------------------------------------- -field_description_regex = re.compile(r""" +field_description_regex = re.compile( + r""" \s* # ignore leading white spaces (\w+) # identifier is a sequence of alphanumeric characters, is stored in first group (?: # optional index specification e.g. (1, 4, 2) @@ -1068,9 +1375,12 @@ field_description_regex = re.compile(r""" \s* )? \s*,?\s* # ignore trailing white spaces and comma -""", re.VERBOSE) +""", + re.VERBOSE, +) -type_description_regex = re.compile(r""" +type_description_regex = re.compile( + r""" \s* (\w+)? # optional dtype \s* @@ -1078,7 +1388,9 @@ type_description_regex = re.compile(r""" ([^\]]+) \] \s* -""", re.VERBOSE | re.IGNORECASE) +""", + re.VERBOSE | re.IGNORECASE, +) def _parse_part1(d): @@ -1104,19 +1416,19 @@ def _parse_description(description): else: dtype = DynamicType.NUMERIC_TYPE - if size_info.endswith('d'): + if size_info.endswith("d"): size_info = int(size_info[:-1]) else: size_info = tuple(int(e) for e in size_info.split(",")) - + return dtype, size_info else: raise ValueError("Could not parse field description") - if ':' in description: - field_description, field_info = description.split(':') + if ":" in description: + field_description, field_info = description.split(":") else: - field_description, field_info = description, 'float64[2D]' + field_description, field_info = description, "float64[2D]" fields_info = [e for e in _parse_part1(field_description)] if not field_info: diff --git a/src/pystencils/jit/cpu_extension_module.py b/src/pystencils/jit/cpu_extension_module.py index befb033e6..48a7eebeb 100644 --- a/src/pystencils/jit/cpu_extension_module.py +++ b/src/pystencils/jit/cpu_extension_module.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any +from typing import Any, cast from os import path import hashlib @@ -19,6 +19,7 @@ from ..types import ( PsUnsignedIntegerType, PsSignedIntegerType, PsIeeeFloatType, + PsPointerType, ) from ..types.quick import Fp, SInt, UInt from ..field import Field @@ -121,9 +122,7 @@ class PsKernelExtensioNModule: def emit_call_wrapper(function_name: str, kernel: Kernel) -> str: builder = CallWrapperBuilder() - - for p in kernel.parameters: - builder.extract_parameter(p) + builder.extract_params(kernel.parameters) # for c in kernel.constraints: # builder.check_constraint(c) @@ -199,7 +198,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ """ def __init__(self) -> None: - self._array_buffers: dict[Field, str] = dict() + self._buffer_types: dict[Field, PsType] = dict() self._array_extractions: dict[Field, str] = dict() self._array_frees: dict[Field, str] = dict() @@ -220,9 +219,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ return "PyLong_AsUnsignedLong" case _: - raise ValueError( - f"Don't know how to cast Python objects to {dtype}" - ) + raise ValueError(f"Don't know how to cast Python objects to {dtype}") def _type_char(self, dtype: PsType) -> str | None: if isinstance( @@ -233,37 +230,39 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ else: return None - def extract_field(self, field: Field) -> str: + def get_field_buffer(self, field: Field): + """Get the Python buffer object for the given field.""" + return f"buffer_{field.name}" + + def extract_field(self, field: Field): """Adds an array, and returns the name of the underlying Py_Buffer.""" if field not in self._array_extractions: extraction_code = self.TMPL_EXTRACT_ARRAY.format(name=field.name) + actual_dtype = self._buffer_types[field] # Check array type - type_char = self._type_char(field.dtype) + type_char = self._type_char(actual_dtype) if type_char is not None: dtype_cond = f"buffer_{field.name}.format[0] == '{type_char}'" extraction_code += self.TMPL_CHECK_ARRAY_TYPE.format( cond=dtype_cond, what="data type", name=field.name, - expected=str(field.dtype), + expected=str(actual_dtype), ) # Check item size - itemsize = field.dtype.itemsize + itemsize = actual_dtype.itemsize item_size_cond = f"buffer_{field.name}.itemsize == {itemsize}" extraction_code += self.TMPL_CHECK_ARRAY_TYPE.format( cond=item_size_cond, what="itemsize", name=field.name, expected=itemsize ) - self._array_buffers[field] = f"buffer_{field.name}" self._array_extractions[field] = extraction_code release_code = f"PyBuffer_Release(&buffer_{field.name});" self._array_frees[field] = release_code - return self._array_buffers[field] - def extract_scalar(self, param: Parameter) -> str: if param not in self._scalar_extractions: extract_func = self._scalar_extractor(param.dtype) @@ -279,7 +278,8 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ def extract_array_assoc_var(self, param: Parameter) -> str: if param not in self._array_assoc_var_extractions: field = param.fields[0] - buffer = self.extract_field(field) + buffer = self.get_field_buffer(field) + buffer_dtype = self._buffer_types[field] code: str | None = None for prop in param.properties: @@ -293,7 +293,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ case FieldStride(_, coord): code = ( f"{param.dtype.c_string()} {param.name} = " - f"{buffer}.strides[{coord}] / {field.dtype.itemsize};" + f"{buffer}.strides[{coord}] / {buffer_dtype.itemsize};" ) break assert code is not None @@ -302,29 +302,48 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ return param.name - def extract_parameter(self, param: Parameter): - if param.is_field_parameter: - self.extract_array_assoc_var(param) - else: - self.extract_scalar(param) + def extract_params(self, params: tuple[Parameter, ...]): + for param in params: + if ptr_props := param.get_properties(FieldBasePtr): + prop: FieldBasePtr = cast(FieldBasePtr, ptr_props.pop()) + field = prop.field + actual_field_type: PsType + + from .. import DynamicType + + if isinstance(field.dtype, DynamicType): + ptr_type = param.dtype + assert isinstance(ptr_type, PsPointerType) + actual_field_type = ptr_type.base_type + else: + actual_field_type = field.dtype + + self._buffer_types[prop.field] = actual_field_type + self.extract_field(prop.field) + + for param in params: + if param.is_field_parameter: + self.extract_array_assoc_var(param) + else: + self.extract_scalar(param) -# def check_constraint(self, constraint: KernelParamsConstraint): -# variables = constraint.get_parameters() + # def check_constraint(self, constraint: KernelParamsConstraint): + # variables = constraint.get_parameters() -# for var in variables: -# self.extract_parameter(var) + # for var in variables: + # self.extract_parameter(var) -# cond = constraint.to_code() + # cond = constraint.to_code() -# code = f""" -# if(!({cond})) -# {{ -# PyErr_SetString(PyExc_ValueError, "Violated constraint: {constraint}"); -# return NULL; -# }} -# """ + # code = f""" + # if(!({cond})) + # {{ + # PyErr_SetString(PyExc_ValueError, "Violated constraint: {constraint}"); + # return NULL; + # }} + # """ -# self._constraint_checks.append(code) + # self._constraint_checks.append(code) def call(self, kernel: Kernel, params: tuple[Parameter, ...]): param_list = ", ".join(p.name for p in params) diff --git a/src/pystencils/jit/gpu_cupy.py b/src/pystencils/jit/gpu_cupy.py index c208ac219..a407bb75e 100644 --- a/src/pystencils/jit/gpu_cupy.py +++ b/src/pystencils/jit/gpu_cupy.py @@ -19,7 +19,7 @@ from ..codegen import ( Parameter, ) from ..codegen.properties import FieldShape, FieldStride, FieldBasePtr -from ..types import PsStructType +from ..types import PsStructType, PsPointerType from ..include import get_pystencils_include_path @@ -160,8 +160,18 @@ class CupyKernelWrapper(KernelWrapper): for prop in kparam.properties: match prop: case FieldBasePtr(field): + + elem_dtype: PsType + + from .. import DynamicType + if isinstance(field.dtype, DynamicType): + assert isinstance(kparam.dtype, PsPointerType) + elem_dtype = kparam.dtype.base_type + else: + elem_dtype = field.dtype + arr = kwargs[field.name] - if arr.dtype != field.dtype.numpy_dtype: + if arr.dtype != elem_dtype.numpy_dtype: raise JitError( f"Data type mismatch at array argument {field.name}:" f"Expected {field.dtype}, got {arr.dtype}" diff --git a/tests/frontend/test_field.py b/tests/frontend/test_field.py index 6521e114f..56c3bdabb 100644 --- a/tests/frontend/test_field.py +++ b/tests/frontend/test_field.py @@ -61,13 +61,13 @@ def test_field_basic(): neighbor = field_access.neighbor(coord_id=0, offset=-2) assert neighbor.offsets == (-1, 1) assert "_" in neighbor._latex("dummy") - assert f.dtype == DynamicType.NUMERIC_TYPE + assert f.dtype == create_type("float64") f = Field.create_fixed_size("f", (8, 8, 2, 2, 2), index_dimensions=3) assert f.center_vector == sp.Array( [[[f(i, j, k) for k in range(2)] for j in range(2)] for i in range(2)] ) - assert f.dtype == DynamicType.NUMERIC_TYPE + assert f.dtype == create_type("float64") f = Field.create_generic("f", spatial_dimensions=5, index_dimensions=2) field_access = f[1, -1, 2, -3, 0](1, 0) diff --git a/tests/runtime/test_datahandling.py b/tests/runtime/test_datahandling.py index 29e639c88..c73ec829d 100644 --- a/tests/runtime/test_datahandling.py +++ b/tests/runtime/test_datahandling.py @@ -249,7 +249,7 @@ def test_add_arrays(): dh = create_data_handling(domain_size=domain_shape, default_ghost_layers=0, default_layout='numpy') x_, y_ = dh.add_arrays(field_description) - x, y = ps.fields(field_description + ': [3,4,5]') + x, y = ps.fields(field_description + ': float64[3,4,5]') assert x_ == x assert y_ == y -- GitLab