diff --git a/docs/source/api/symbolic/sympyextensions.rst b/docs/source/api/symbolic/sympyextensions.rst index e3d10fbdf67a1fc26fe1e339b0e642d86f1be51e..4190569d2ab838356a4f501e6811bb7e9202e666 100644 --- a/docs/source/api/symbolic/sympyextensions.rst +++ b/docs/source/api/symbolic/sympyextensions.rst @@ -71,7 +71,10 @@ Typed Expressions .. autoclass:: pystencils.DynamicType :members: -.. autoclass:: pystencils.sympyextensions.CastFunc +.. autoclass:: pystencils.sympyextensions.typed_sympy.TypeCast + :members: + +.. autoclass:: pystencils.sympyextensions.tcast Integer Operations diff --git a/docs/source/contributing/dev-workflow.md b/docs/source/contributing/dev-workflow.md index 2aee09ba2e78bd0041ec6d2d2860385514240ecf..fe8b70e7703385d45f7fd2d53822424b193c2592 100644 --- a/docs/source/contributing/dev-workflow.md +++ b/docs/source/contributing/dev-workflow.md @@ -118,10 +118,10 @@ mypy src/pystencils :::: :::{note} -Type checking is currently restricted to the `codegen`, `jit`, `backend`, and `types` modules, -since most code in the remaining modules is significantly older and is not comprehensively -type-annotated. As more modules are updated with type annotations, this list will expand in the future. -If you think a new module is ready to be type-checked, add an exception clause for it in the `mypy.ini` file. +Type checking is currently restricted only to a few modules, which are listed in the `mypy.ini` file. +Most code in the remaining modules is significantly older and is not comprehensively type-annotated. +As more modules are updated with type annotations, this list will expand in the future. +If you think a new module is ready to be type-checked, add an exception clause to `mypy.ini`. ::: ## Running the Test Suite diff --git a/docs/source/index.rst b/docs/source/index.rst index cb455c8b4d1589353a7538c0e98b5eab864b4392..6dba50af184ee95c7378a2e923bd76a6d97883a2 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -82,6 +82,7 @@ Topics user_manual/symbolic_language user_manual/kernelcreation user_manual/gpu_kernels + user_manual/WorkingWithTypes .. toctree:: :maxdepth: 1 diff --git a/docs/source/user_manual/WorkingWithTypes.md b/docs/source/user_manual/WorkingWithTypes.md new file mode 100644 index 0000000000000000000000000000000000000000..e0f9283773cea55453ecdfe2377dc82b096e0741 --- /dev/null +++ b/docs/source/user_manual/WorkingWithTypes.md @@ -0,0 +1,164 @@ +--- +file_format: mystnb +kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +mystnb: + execution_mode: cache +--- + +# Working with Data Types + +This guide will demonstrate the various options that exist to customize the data types +in generated kernels. +Data types can be modified on different levels of granularity: +Individual fields and symbols, +single subexpressions, +or the entire kernel. + +```{code-cell} ipython3 +:tags: [remove-cell] +import pystencils as ps +import sympy as sp +``` + +## Changing the Default Data Types + +The pystencils code generator defines two default data types: + - The default *numeric type*, which is applied to all numerical computations that are not + otherwise explicitly typed; the default is `float64`. + - The default *index type*, which is used for all loop and field index calculations; the default is `int64`. + +These can be modified by setting the +{any}`default_dtype <CreateKernelConfig.default_dtype>` and +{any}`index_type <CreateKernelConfig.index_dtype>` +options of the code generator configuration: + +```{code-cell} ipython3 +cfg = ps.CreateKernelConfig() +cfg.default_dtype = "float32" +cfg.index_dtype = "int32" +``` + +Modifying these will change the way types for [untyped symbols](#untyped-symbols) +and [dynamically typed expressions](#dynamic-typing) are computed. + +## Setting the Types of Fields and Symbols + +(untyped-symbols)= +### Untyped Symbols + +Symbols used inside a kernel are most commonly created using +{any}`sp.symbols <sympy.core.symbol.symbols>` or +{any}`sp.Symbol <sympy.core.symbol.Symbol>`. +These symbols are *untyped*; they will receive a type during code generation +according to these rules: + - Free untyped symbols (i.e. symbols not defined by an assignment inside the kernel) receive the + {any}`default data type <CreateKernelConfig.default_dtype>` specified in the code generator configuration. + - Bound untyped symbols (i.e. symbols that *are* defined in an assignment) + receive the data type that was computed for the right-hand side expression of their defining assignment. + +If you are working on kernels with homogenous data types, using untyped symbols will mostly be enough. + +### Explicitly Typed Symbols and Fields + +If you need more control over the data types in (parts of) your kernel, +you will have to explicitly specify them. +To set an explicit data type for a symbol, use the {any}`TypedSymbol` class of pystencils: + +```{code-cell} ipython3 +x_typed = ps.TypedSymbol("x", "uint32") +x_typed, str(x_typed.dtype) +``` + +You can set a `TypedSymbol` to any data type provided by [the type system](#page_type_system), +which will then be enforced by the code generator. + +The same holds for fields: +When creating fields through the {any}`fields <pystencils.field.fields>` function, +add the type to the descriptor string; for instance: + +```{code-cell} ipython3 +f, g = ps.fields("f(1), g(3): float32[3D]") +str(f.dtype), str(g.dtype) +``` + +When using `Field.create_generic` or `Field.create_fixed_size`, on the other hand, +you can set the data type via the `dtype` keyword argument. + +(dynamic-typing)= +### Dynamically Typed Symbols and Fields + +Apart from explicitly setting data types, +`TypedSymbol`s and fields can also receive a *dynamic data type* (see {any}`DynamicType`). +There are two options: + - Symbols or fields annotated with {any}`DynamicType.NUMERIC_TYPE` will always receive + the {any}`default numeric type <CreateKernelConfig.default_dtype>` configured for the + code generator. + This is the default setting for fields + created through `fields`, `Field.create_generic` or `Field.create_fixed_size`. + - When annotated with {any}`DynamicType.INDEX_TYPE`, on the other hand, they will receive + the {any}`index data type <CreateKernelConfig.index_dtype>` configured for the kernel. + +Using dynamic typing, you can enforce symbols to receive either the standard numeric or +index type without explicitly stating it, such that your kernel definition becomes +independent from the code generator configuration. + +## Mixing Types Inside Expressions + +Pystencils enforces that all symbols, constants, and fields occuring inside an expression +have the same data type. +The code generator will never introduce implicit casts--if any type conflicts arise, it will terminate with an error. + +Still, there are cases where you want to combine subexpressions of different types; +maybe you need to compute geometric information from loop counters or other integers, +or you are doing mixed-precision numerical computations. +In these cases, you might have to introduce explicit type casts when values move from one type context to another. + + <!-- 2. Annotate expressions with a specific data type to ensure computations are performed in that type. + TODO: See #97 (https://i10git.cs.fau.de/pycodegen/pystencils/-/issues/97) + --> + +(type_casts)= +### Type Casts + +Type casts can be introduced into kernels using the {any}`tcast` symbolic function. +It takes an expression and a data type, which is either an explicit type (see [the type system](#page_type_system)) +or a dynamic type ({any}`DynamicType`): + +```{code-cell} ipython3 +x, y = sp.symbols("x, y") +expr1 = ps.tcast(x, "float32") +expr2 = ps.tcast(3 + y, ps.DynamicType.INDEX_TYPE) + +str(expr1.dtype), str(expr2.dtype) +``` + +When a type cast occurs, pystencils will compute the type of its argument independently +and then introduce a runtime cast to the target type. +That target type must comply with the type computed for the outer expression, +which the cast is embedded in. + +## Understanding the pystencils Type Inference System + +To correctly apply varying data types to pystencils kernels, it is important to understand +how pystencils computes and propagates the data types of symbols and expressions. + +Type inference happens on the level of assignments. +For each assignment $x := \mathrm{calc}(y_1, \dots, y_n)$, +the system first attempts to compute a *unique* type for the right-hand side (RHS) $\mathrm{calc}(y_1, \dots, y_n)$. +It searches for any subexpression inside the RHS for which a type is already known -- +these might be typed symbols +(whose types are either set explicitly by the user, +or have been determined from their defining assignment), +field accesses, +or explicitly typed expressions. +It then attempts to apply that data type to the entire expression. +If type conflicts occur, the process fails and the code generator raises an error. +Otherwise, the resulting type is assigned to the left-hand side symbol $x$. + +:::{admonition} Developer's To Do +It would be great to illustrate this using a GraphViz-plot of an AST, +with nodes colored according to their data types +::: diff --git a/docs/source/user_manual/kernelcreation.md b/docs/source/user_manual/kernelcreation.md index c85c8f99d3490602321c57f881b32b0127051c70..ad346473cc05dc746794ffc8f56f5bf21ffdec90 100644 --- a/docs/source/user_manual/kernelcreation.md +++ b/docs/source/user_manual/kernelcreation.md @@ -138,7 +138,7 @@ This happens roughly according to the following rules: We can observe this behavior by setting up a kernel including several fields with different data types: ```{code-cell} ipython3 -from pystencils.sympyextensions import CastFunc +from pystencils.sympyextensions import tcast f = ps.fields("f: float32[2D]") g = ps.fields("g: float16[2D]") diff --git a/mypy.ini b/mypy.ini index cc23a503a2da6c9849d3a41e82fe8ceb8de13b43..c8a7195e2e28bffbeb79e1e552822cea4e8dd041 100644 --- a/mypy.ini +++ b/mypy.ini @@ -17,6 +17,12 @@ ignore_errors = False [mypy-pystencils.jit.*] ignore_errors = False +[mypy-pystencils.field] +ignore_errors = False + +[mypy-pystencils.sympyextensions.typed_sympy] +ignore_errors = False + [mypy-setuptools.*] ignore_missing_imports=true diff --git a/src/pystencils/__init__.py b/src/pystencils/__init__.py index a23ce185d1a4f6c9cd9a17fccf315462eddf287f..07283d5294bc08c8e68e6a40af0f956b36a0129a 100644 --- a/src/pystencils/__init__.py +++ b/src/pystencils/__init__.py @@ -32,7 +32,7 @@ from .spatial_coordinates import ( from .assignment import Assignment, AddAugmentedAssignment, assignment_from_stencil from .simp import AssignmentCollection from .sympyextensions.typed_sympy import TypedSymbol, DynamicType -from .sympyextensions import SymbolCreator +from .sympyextensions import SymbolCreator, tcast from .datahandling import create_data_handling __all__ = [ @@ -77,6 +77,7 @@ __all__ = [ "x_staggered_vector", "fd", "stencil", + "tcast", ] from . import _version diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index 8f5931c6494bbf1eca950e38df0053e79af3e81b..68da893ff6204c73d61dcebddb1da602f37520c7 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -93,6 +93,16 @@ class KernelCreationContext: def index_dtype(self) -> PsIntegerType: """Data type used by default for index expressions""" return self._index_dtype + + def resolve_dynamic_type(self, dtype: DynamicType | PsType) -> PsType: + """Selects the appropriate data type for `DynamicType` instances, and returns all other types as they are.""" + match dtype: + case DynamicType.NUMERIC_TYPE: + return self._default_dtype + case DynamicType.INDEX_TYPE: + return self._index_dtype + case _: + return dtype @property def metadata(self) -> dict[str, Any]: @@ -339,6 +349,8 @@ class KernelCreationContext: if isinstance(s, TypedSymbol) ) + entry_type = self.resolve_dynamic_type(field.dtype) + if len(idx_types) > 1: raise KernelConstraintsError( f"Multiple incompatible types found in index symbols of field {field}: " @@ -375,10 +387,10 @@ class KernelCreationContext: base_ptr = self.get_symbol( DEFAULTS.field_pointer_name(field.name), - PsPointerType(field.dtype, restrict=True), + PsPointerType(entry_type, restrict=True), ) - return PsBuffer(field.name, field.dtype, base_ptr, buf_shape, buf_strides) + return PsBuffer(field.name, entry_type, base_ptr, buf_shape, buf_strides) def _create_buffer_field_buffer(self, field: Field) -> PsBuffer: if field.spatial_dimensions != 1: @@ -418,10 +430,11 @@ class KernelCreationContext: ] buf_strides = [PsConstant(num_entries, idx_type), PsConstant(1, idx_type)] + buf_dtype = self.resolve_dynamic_type(field.dtype) base_ptr = self.get_symbol( DEFAULTS.field_pointer_name(field.name), - PsPointerType(field.dtype, restrict=True), + PsPointerType(buf_dtype, restrict=True), ) - return PsBuffer(field.name, field.dtype, base_ptr, buf_shape, buf_strides) + return PsBuffer(field.name, buf_dtype, base_ptr, buf_shape, buf_strides) diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index c38fcbc9730e046b2d51762c5846c1542f0cbe74..4fd09f879dd8d98903753c8709543e0bcc3fd3e1 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -13,7 +13,7 @@ from ...sympyextensions import ( integer_functions, ConditionalFieldAccess, ) -from ...sympyextensions.typed_sympy import TypedSymbol, CastFunc, DynamicType +from ...sympyextensions.typed_sympy import TypedSymbol, TypeCast, DynamicType from ...sympyextensions.pointers import AddressOf, mem_acc from ...field import Field, FieldType @@ -270,14 +270,7 @@ class FreezeExpressions: return num / denom def map_TypedSymbol(self, expr: TypedSymbol): - dtype = expr.dtype - - match dtype: - case DynamicType.NUMERIC_TYPE: - dtype = self._ctx.default_dtype - case DynamicType.INDEX_TYPE: - dtype = self._ctx.index_dtype - + dtype = self._ctx.resolve_dynamic_type(expr.dtype) symb = self._ctx.get_symbol(expr.name, dtype) return PsSymbolExpr(symb) @@ -490,7 +483,7 @@ class FreezeExpressions: ] return cast(PsCall, args[0]) - def map_CastFunc(self, cast_expr: CastFunc) -> PsCast | PsConstantExpr: + def map_TypeCast(self, cast_expr: TypeCast) -> PsCast | PsConstantExpr: dtype: PsType match cast_expr.dtype: case DynamicType.NUMERIC_TYPE: diff --git a/src/pystencils/field.py b/src/pystencils/field.py index 1a3a13b73125bd07e6161a1694a6bf03dc2ba506..246232efde7a6b432598f614492725e2ea063cff 100644 --- a/src/pystencils/field.py +++ b/src/pystencils/field.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import functools import hashlib import operator @@ -5,23 +7,28 @@ import pickle import re from enum import Enum from itertools import chain -from typing import List, Optional, Sequence, Set, Tuple, Union +from typing import List, Optional, Sequence, Set, Tuple +from warnings import warn import numpy as np import sympy as sp from sympy.core.cache import cacheit from .defaults import DEFAULTS -from pystencils.alignedarray import aligned_empty -from pystencils.spatial_coordinates import x_staggered_vector, x_vector -from pystencils.stencil import direction_string_to_offset, inverse_direction, offset_to_direction_string -from pystencils.types import PsType, PsStructType, create_type -from pystencils.sympyextensions.typed_sympy import TypedSymbol, DynamicType -from pystencils.sympyextensions import is_integer_sequence -from pystencils.types import UserTypeSpec +from .alignedarray import aligned_empty +from .spatial_coordinates import x_staggered_vector, x_vector +from .stencil import ( + direction_string_to_offset, + inverse_direction, + offset_to_direction_string, +) +from .types import PsType, PsStructType, create_type +from .sympyextensions.typed_sympy import TypedSymbol, DynamicType +from .sympyextensions import is_integer_sequence +from .types import UserTypeSpec -__all__ = ['Field', 'fields', 'FieldType', 'Field'] +__all__ = ["Field", "fields", "FieldType", "Field"] class FieldType(Enum): @@ -63,7 +70,10 @@ class FieldType(Enum): @staticmethod def is_staggered(field): assert isinstance(field, Field) - return field.field_type == FieldType.STAGGERED or field.field_type == FieldType.STAGGERED_FLUX + return ( + field.field_type == FieldType.STAGGERED + or field.field_type == FieldType.STAGGERED_FLUX + ) @staticmethod def is_staggered_flux(field): @@ -123,23 +133,34 @@ class Field: >>> assignments = [Assignment(dst[0,0](i), src[-offset](i)) for i, offset in enumerate(stencil)]; Args: - field_name: something - field_type: something - dtype: something - layout: something - shape: something - strides: something + field_name: The field's name + field_type: The kind of the field + dtype: Data type of the field's entries + layout: Linearization order of the field's spatial dimensions + shape: Total shape (spatial and index) of the field + strides: Linearization strides of the field """ @staticmethod - def create_generic(field_name, spatial_dimensions, dtype: UserTypeSpec = np.float64, index_dimensions=0, - layout='numpy', index_shape=None, field_type=FieldType.GENERIC) -> 'Field': + def create_generic( + field_name, + spatial_dimensions, + dtype: UserTypeSpec | DynamicType = DynamicType.NUMERIC_TYPE, + index_dimensions=0, + layout="numpy", + index_shape=None, + field_type=FieldType.GENERIC, + ) -> "Field": """ - Creates a generic field where the field size is not fixed i.e. can be called with arrays of different sizes + Creates a generic field where the field size is not fixed i.e. can be called with arrays of different sizes. + + **Field Element Type** By default, the data type of the field entries is left undetermined until + code generation, at which point it is set to the default numerical type of the kernel. + You can specify a concrete type using the `dtype` parameter. Args: field_name: symbolic name for the field - dtype: numpy data type of the array the kernel is called with later + dtype: Data type of the field entries spatial_dimensions: see documentation of Field index_dimensions: see documentation of Field layout: tuple specifying the loop ordering of the spatial dimensions e.g. (2, 1, 0 ) means that @@ -159,41 +180,60 @@ class Field: layout = spatial_layout_string_to_tuple(layout, dim=spatial_dimensions) total_dimensions = spatial_dimensions + index_dimensions + shape: tuple[TypedSymbol | int, ...] + if index_shape is None or len(index_shape) == 0: - shape = tuple([ - TypedSymbol(DEFAULTS.field_shape_name(field_name, i), DynamicType.INDEX_TYPE) - for i in range(total_dimensions) - ]) + shape = tuple( + [ + TypedSymbol( + DEFAULTS.field_shape_name(field_name, i), DynamicType.INDEX_TYPE + ) + for i in range(total_dimensions) + ] + ) else: shape = tuple( [ - TypedSymbol(DEFAULTS.field_shape_name(field_name, i), DynamicType.INDEX_TYPE) + TypedSymbol( + DEFAULTS.field_shape_name(field_name, i), DynamicType.INDEX_TYPE + ) for i in range(spatial_dimensions) - ] + list(index_shape) + ] + + list(index_shape) ) - strides = tuple([ - TypedSymbol(DEFAULTS.field_stride_name(field_name, i), DynamicType.INDEX_TYPE) - for i in range(total_dimensions) - ]) + strides: tuple[TypedSymbol | int, ...] = tuple( + [ + TypedSymbol( + DEFAULTS.field_stride_name(field_name, i), DynamicType.INDEX_TYPE + ) + for i in range(total_dimensions) + ] + ) - dtype = create_type(dtype) - np_data_type = dtype.numpy_dtype - assert np_data_type is not None - - if np_data_type.fields is not None: + if not isinstance(dtype, DynamicType): + dtype = create_type(dtype) + + if isinstance(dtype, PsStructType): if index_dimensions != 0: - raise ValueError("Structured arrays/fields are not allowed to have an index dimension") + raise ValueError( + "Structured arrays/fields are not allowed to have an index dimension" + ) shape += (1,) strides += (1,) + if field_type == FieldType.STAGGERED and index_dimensions == 0: raise ValueError("A staggered field needs at least one index dimension") return Field(field_name, field_type, dtype, layout, shape, strides) @staticmethod - def create_from_numpy_array(field_name: str, array: np.ndarray, index_dimensions: int = 0, - field_type=FieldType.GENERIC) -> 'Field': + def create_from_numpy_array( + field_name: str, + array: np.ndarray, + index_dimensions: int = 0, + field_type=FieldType.GENERIC, + ) -> Field: """Creates a field based on the layout, data type, and shape of a given numpy array. Kernels created for these kind of fields can only be called with arrays of the same layout, shape and type. @@ -206,7 +246,9 @@ class Field: """ spatial_dimensions = len(array.shape) - index_dimensions if spatial_dimensions < 1: - raise ValueError("Too many index dimensions. At least one spatial dimension required") + raise ValueError( + "Too many index dimensions. At least one spatial dimension required" + ) full_layout = get_layout_of_array(array) spatial_layout = tuple([i for i in full_layout if i < spatial_dimensions]) @@ -218,21 +260,31 @@ class Field: numpy_dtype = np.dtype(array.dtype) if numpy_dtype.fields is not None: if index_dimensions != 0: - raise ValueError("Structured arrays/fields are not allowed to have an index dimension") + raise ValueError( + "Structured arrays/fields are not allowed to have an index dimension" + ) shape += (1,) strides += (1,) if field_type == FieldType.STAGGERED and index_dimensions == 0: raise ValueError("A staggered field needs at least one index dimension") - return Field(field_name, field_type, array.dtype, spatial_layout, shape, strides) + return Field( + field_name, field_type, array.dtype, spatial_layout, shape, strides + ) @staticmethod - def create_fixed_size(field_name: str, shape: Tuple[int, ...], index_dimensions: int = 0, - dtype: UserTypeSpec = np.float64, layout: str = 'numpy', - strides: Optional[Sequence[int]] = None, - field_type=FieldType.GENERIC) -> 'Field': + def create_fixed_size( + field_name: str, + shape: tuple[int, ...], + index_dimensions: int = 0, + dtype: UserTypeSpec | DynamicType = DynamicType.NUMERIC_TYPE, + layout: str | tuple[int, ...] = "numpy", + memory_strides: None | Sequence[int] = None, + strides: Optional[Sequence[int]] = None, + field_type=FieldType.GENERIC, + ) -> Field: """ - Creates a field with fixed sizes i.e. can be called only with arrays of the same size and layout + Creates a field with fixed sizes i.e. can be called only with arrays of the same size and layout. Args: field_name: symbolic name for the field @@ -240,54 +292,90 @@ class Field: index_dimensions: how many of the trailing dimensions are interpreted as index (as opposed to spatial) dtype: numpy data type of the array the kernel is called with later layout: full layout of array, not only spatial dimensions - strides: strides in bytes or None to automatically compute them from shape (assuming no padding) + memory_strides: Linearization strides for each dimension; + i.e. the number of elements to skip to get from one index to the next in the respective dimension. field_type: kind of field """ + if strides is not None: + warn( + "The `strides` parameter to `Field.create_fixed_size` is deprecated " + "and will be removed in pystencils 2.1. " + "Use `memory_strides` instead; " + "beware that `memory_strides` takes the number of *elements* to skip, " + "instead of the number of bytes.", + FutureWarning + ) + + if memory_strides is not None: + raise ValueError("Cannot specify `memory_strides` and deprecated parameter `strides` at the same time.") + + if isinstance(dtype, DynamicType): + raise ValueError("Cannot specify the deprecated parameter `strides` together with a `DynamicType`. " + "Set `memory_strides` instead.") + + np_type = create_type(dtype).numpy_dtype + assert np_type is not None + memory_strides = tuple([s // np_type.itemsize for s in strides]) + spatial_dimensions = len(shape) - index_dimensions assert spatial_dimensions >= 1 if isinstance(layout, str): - layout = layout_string_to_tuple(layout, spatial_dimensions + index_dimensions) + layout = layout_string_to_tuple( + layout, spatial_dimensions + index_dimensions + ) + + if not isinstance(dtype, DynamicType): + dtype = create_type(dtype) + + shape_tuple = tuple(int(s) for s in shape) + strides_tuple: tuple[int, ...] - shape = tuple(int(s) for s in shape) if strides is None: - strides = compute_strides(shape, layout) + strides_tuple = compute_strides(shape_tuple, layout) else: - assert len(strides) == len(shape) - strides = tuple([s // np.dtype(dtype).itemsize for s in strides]) + assert len(strides) == len(shape_tuple) + strides_tuple = tuple(strides) - dtype = create_type(dtype) - numpy_dtype = dtype.numpy_dtype - assert numpy_dtype is not None - - if numpy_dtype.fields is not None: + if isinstance(dtype, PsStructType): if index_dimensions != 0: - raise ValueError("Structured arrays/fields are not allowed to have an index dimension") - shape += (1,) - strides += (1,) + raise ValueError( + "Structured arrays/fields are not allowed to have an index dimension" + ) + shape_tuple += (1,) + strides_tuple += (1,) if field_type == FieldType.STAGGERED and index_dimensions == 0: raise ValueError("A staggered field needs at least one index dimension") spatial_layout = list(layout) for i in range(spatial_dimensions, len(layout)): spatial_layout.remove(i) - return Field(field_name, field_type, dtype, tuple(spatial_layout), shape, strides) + return Field( + field_name, + field_type, + dtype, + tuple(spatial_layout), + shape_tuple, + strides_tuple, + ) def __init__( self, field_name: str, field_type: FieldType, - dtype: UserTypeSpec, + dtype: UserTypeSpec | DynamicType, layout: tuple[int, ...], shape, - strides + strides, ): """Do not use directly. Use static create* methods""" self._field_name = field_name assert isinstance(field_type, FieldType) assert len(shape) == len(strides) self.field_type = field_type - self._dtype = create_type(dtype) + self._dtype: PsType | DynamicType = ( + create_type(dtype) if not isinstance(dtype, DynamicType) else dtype + ) self._layout = normalize_layout(layout) self.shape = shape self.strides = strides @@ -299,9 +387,23 @@ class Field: def new_field_with_different_name(self, new_name): if self.has_fixed_shape: - return Field(new_name, self.field_type, self._dtype, self._layout, self.shape, self.strides) + return Field( + new_name, + self.field_type, + self._dtype, + self._layout, + self.shape, + self.strides, + ) else: - return Field(new_name, self.field_type, self.dtype, self.layout, self.shape, self.strides) + return Field( + new_name, + self.field_type, + self.dtype, + self.layout, + self.shape, + self.strides, + ) @property def spatial_dimensions(self) -> int: @@ -328,7 +430,7 @@ class Field: @property def spatial_shape(self) -> Tuple[int, ...]: - return self.shape[:self.spatial_dimensions] + return self.shape[: self.spatial_dimensions] @property def has_fixed_shape(self): @@ -344,31 +446,34 @@ class Field: @property def spatial_strides(self): - return self.strides[:self.spatial_dimensions] + return self.strides[: self.spatial_dimensions] @property def index_strides(self): return self.strides[self.spatial_dimensions:] @property - def dtype(self) -> PsType: + def dtype(self) -> PsType | DynamicType: return self._dtype @property - def itemsize(self): - return self.dtype.itemsize + def itemsize(self) -> int | None: + if isinstance(self.dtype, PsType): + return self.dtype.itemsize + else: + return None def __repr__(self): if any(isinstance(s, sp.Symbol) for s in self.spatial_shape): - spatial_shape_str = f'{self.spatial_dimensions}d' + spatial_shape_str = f"{self.spatial_dimensions}d" else: - spatial_shape_str = ','.join(str(i) for i in self.spatial_shape) - index_shape_str = ','.join(str(i) for i in self.index_shape) + spatial_shape_str = ",".join(str(i) for i in self.spatial_shape) + index_shape_str = ",".join(str(i) for i in self.index_shape) if self.index_shape: - return f'{self._field_name}({index_shape_str}): {self.dtype}[{spatial_shape_str}]' + return f"{self._field_name}({index_shape_str}): {self.dtype}[{spatial_shape_str}]" else: - return f'{self._field_name}: {self.dtype}[{spatial_shape_str}]' + return f"{self._field_name}: {self.dtype}[{spatial_shape_str}]" def __str__(self): return self.name @@ -389,12 +494,26 @@ class Field: elif len(index_shape) == 1: return sp.Matrix([self(i) for i in range(index_shape[0])]) elif len(index_shape) == 2: - return sp.Matrix([[self(i, j) for j in range(index_shape[1])] for i in range(index_shape[0])]) + return sp.Matrix( + [ + [self(i, j) for j in range(index_shape[1])] + for i in range(index_shape[0]) + ] + ) elif len(index_shape) == 3: - return sp.Array([[[self(i, j, k) for k in range(index_shape[2])] - for j in range(index_shape[1])] for i in range(index_shape[0])]) + return sp.Array( + [ + [ + [self(i, j, k) for k in range(index_shape[2])] + for j in range(index_shape[1]) + ] + for i in range(index_shape[0]) + ] + ) else: - raise NotImplementedError("center_vector is not implemented for more than 3 index dimensions") + raise NotImplementedError( + "center_vector is not implemented for more than 3 index dimensions" + ) @property def center(self): @@ -410,12 +529,20 @@ class Field: if self.index_dimensions == 0: return sp.Matrix([self.__getitem__(offset)]) elif self.index_dimensions == 1: - return sp.Matrix([self.__getitem__(offset)(i) for i in range(self.index_shape[0])]) + return sp.Matrix( + [self.__getitem__(offset)(i) for i in range(self.index_shape[0])] + ) elif self.index_dimensions == 2: - return sp.Matrix([[self.__getitem__(offset)(i, k) for k in range(self.index_shape[1])] - for i in range(self.index_shape[0])]) + return sp.Matrix( + [ + [self.__getitem__(offset)(i, k) for k in range(self.index_shape[1])] + for i in range(self.index_shape[0]) + ] + ) else: - raise NotImplementedError("neighbor_vector is not implemented for more than 2 index dimensions") + raise NotImplementedError( + "neighbor_vector is not implemented for more than 2 index dimensions" + ) def __getitem__(self, offset): if type(offset) is np.ndarray: @@ -425,7 +552,9 @@ class Field: if type(offset) is not tuple: offset = (offset,) if len(offset) != self.spatial_dimensions: - raise ValueError(f"Wrong number of spatial indices: Got {len(offset)}, expected {self.spatial_dimensions}") + raise ValueError( + f"Wrong number of spatial indices: Got {len(offset)}, expected {self.spatial_dimensions}" + ) return Field.Access(self, offset) def absolute_access(self, offset, index): @@ -448,7 +577,9 @@ class Field: offset = tuple(direction_string_to_offset(offset, self.spatial_dimensions)) offset = tuple([o * sp.Rational(1, 2) for o in offset]) if len(offset) != self.spatial_dimensions: - raise ValueError(f"Wrong number of spatial indices: Got {len(offset)}, expected {self.spatial_dimensions}") + raise ValueError( + f"Wrong number of spatial indices: Got {len(offset)}, expected {self.spatial_dimensions}" + ) prefactor = 1 neighbor_vec = [0] * len(offset) @@ -462,25 +593,33 @@ class Field: if FieldType.is_staggered_flux(self): prefactor = -1 if neighbor not in self.staggered_stencil: - raise ValueError(f"{offset_orig} is not a valid neighbor for the {self.staggered_stencil_name} stencil") + raise ValueError( + f"{offset_orig} is not a valid neighbor for the {self.staggered_stencil_name} stencil" + ) offset = tuple(sp.Matrix(offset) - sp.Rational(1, 2) * sp.Matrix(neighbor_vec)) idx = self.staggered_stencil.index(neighbor) - if self.index_dimensions == 1: # this field stores a scalar value at each staggered position + if ( + self.index_dimensions == 1 + ): # this field stores a scalar value at each staggered position if index is not None: raise ValueError("Cannot specify an index for a scalar staggered field") return prefactor * Field.Access(self, offset, (idx,)) else: # this field stores a vector or tensor at each staggered position if index is None: - raise ValueError(f"Wrong number of indices: Got 0, expected {self.index_dimensions - 1}") + raise ValueError( + f"Wrong number of indices: Got 0, expected {self.index_dimensions - 1}" + ) if type(index) is np.ndarray: index = tuple(index) if type(index) is not tuple: index = (index,) if self.index_dimensions != len(index) + 1: - raise ValueError(f"Wrong number of indices: Got {len(index)}, expected {self.index_dimensions - 1}") + raise ValueError( + f"Wrong number of indices: Got {len(index)}, expected {self.index_dimensions - 1}" + ) return prefactor * Field.Access(self, offset, (idx, *index)) @@ -491,30 +630,54 @@ class Field: if self.index_dimensions == 1: return sp.Matrix([self.staggered_access(offset)]) elif self.index_dimensions == 2: - return sp.Matrix([self.staggered_access(offset, i) for i in range(self.index_shape[1])]) + return sp.Matrix( + [self.staggered_access(offset, i) for i in range(self.index_shape[1])] + ) elif self.index_dimensions == 3: - return sp.Matrix([[self.staggered_access(offset, (i, k)) for k in range(self.index_shape[2])] - for i in range(self.index_shape[1])]) + return sp.Matrix( + [ + [ + self.staggered_access(offset, (i, k)) + for k in range(self.index_shape[2]) + ] + for i in range(self.index_shape[1]) + ] + ) else: - raise NotImplementedError("staggered_vector_access is not implemented for more than 3 index dimensions") + raise NotImplementedError( + "staggered_vector_access is not implemented for more than 3 index dimensions" + ) @property def staggered_stencil(self): assert FieldType.is_staggered(self) stencils = { - 2: { - 2: ["W", "S"], # D2Q5 - 4: ["W", "S", "SW", "NW"] # D2Q9 - }, + 2: {2: ["W", "S"], 4: ["W", "S", "SW", "NW"]}, # D2Q5 # D2Q9 3: { 3: ["W", "S", "B"], # D3Q7 7: ["W", "S", "B", "BSW", "TSW", "BNW", "TNW"], # D3Q15 9: ["W", "S", "B", "SW", "NW", "BW", "TW", "BS", "TS"], # D3Q19 - 13: ["W", "S", "B", "SW", "NW", "BW", "TW", "BS", "TS", "BSW", "TSW", "BNW", "TNW"] # D3Q27 - } + 13: [ + "W", + "S", + "B", + "SW", + "NW", + "BW", + "TW", + "BS", + "TS", + "BSW", + "TSW", + "BNW", + "TNW", + ], # D3Q27 + }, } if not self.index_shape[0] in stencils[self.spatial_dimensions]: - raise ValueError(f"No known stencil has {self.index_shape[0]} staggered points") + raise ValueError( + f"No known stencil has {self.index_shape[0]} staggered points" + ) return stencils[self.spatial_dimensions][self.index_shape[0]] @property @@ -527,13 +690,15 @@ class Field: return Field.Access(self, center)(*args, **kwargs) def hashable_contents(self): - return (self._layout, - self.shape, - self.strides, - self.field_type, - self._field_name, - self.latex_name, - self._dtype) + return ( + self._layout, + self.shape, + self.strides, + self.field_type, + self._field_name, + self.latex_name, + self._dtype, + ) def __hash__(self): return hash(self.hashable_contents()) @@ -545,36 +710,53 @@ class Field: @property def physical_coordinates(self): - if hasattr(self.coordinate_transform, '__call__'): - return self.coordinate_transform(self.coordinate_origin + x_vector(self.spatial_dimensions)) + if hasattr(self.coordinate_transform, "__call__"): + return self.coordinate_transform( + self.coordinate_origin + x_vector(self.spatial_dimensions) + ) else: - return self.coordinate_transform @ (self.coordinate_origin + x_vector(self.spatial_dimensions)) + return self.coordinate_transform @ ( + self.coordinate_origin + x_vector(self.spatial_dimensions) + ) @property def physical_coordinates_staggered(self): - return self.coordinate_transform @ \ - (self.coordinate_origin + x_staggered_vector(self.spatial_dimensions)) + return self.coordinate_transform @ ( + self.coordinate_origin + x_staggered_vector(self.spatial_dimensions) + ) def index_to_physical(self, index_coordinates: sp.Matrix, staggered=False): if staggered: - index_coordinates = sp.Matrix([0.5] * len(self.coordinate_origin)) + index_coordinates - if hasattr(self.coordinate_transform, '__call__'): + index_coordinates = ( + sp.Matrix([0.5] * len(self.coordinate_origin)) + index_coordinates + ) + if hasattr(self.coordinate_transform, "__call__"): return self.coordinate_transform(self.coordinate_origin + index_coordinates) else: - return self.coordinate_transform @ (self.coordinate_origin + index_coordinates) + return self.coordinate_transform @ ( + self.coordinate_origin + index_coordinates + ) def physical_to_index(self, physical_coordinates: sp.Matrix, staggered=False): - if hasattr(self.coordinate_transform, '__call__'): - if hasattr(self.coordinate_transform, 'inv'): - return self.coordinate_transform.inv()(physical_coordinates) - self.coordinate_origin + if hasattr(self.coordinate_transform, "__call__"): + if hasattr(self.coordinate_transform, "inv"): + return ( + self.coordinate_transform.inv()(physical_coordinates) + - self.coordinate_origin + ) else: - idx = sp.Matrix(sp.symbols(f'index_coordinates:{self.ndim}', real=True)) + idx = sp.Matrix(sp.symbols(f"index_coordinates:{self.ndim}", real=True)) rtn = sp.solve(self.index_to_physical(idx) - physical_coordinates, idx) - assert rtn, f'Could not find inverese of coordinate_transform: {self.index_to_physical(idx)}' + assert ( + rtn + ), f"Could not find inverese of coordinate_transform: {self.index_to_physical(idx)}" return rtn else: - rtn = self.coordinate_transform.inv() @ physical_coordinates - self.coordinate_origin + rtn = ( + self.coordinate_transform.inv() @ physical_coordinates + - self.coordinate_origin + ) if staggered: rtn = sp.Matrix([i - 0.5 for i in rtn]) @@ -603,18 +785,40 @@ class Field: >>> central_y_component.at_index(0) # change component v_C^0 """ + _iterable = False # see https://i10git.cs.fau.de/pycodegen/pystencils/-/merge_requests/166#note_10680 __match_args__ = ("field", "offsets", "index") + # for the type checker + _field: Field + _offsets: tuple[int | sp.Basic, ...] + _offsetName: str + _superscript: None | str + _index: tuple[int | sp.Basic, ...] | str + _indirect_addressing_fields: set[Field] + _is_absolute_access: bool + def __new__(cls, name, *args, **kwargs): obj = Field.Access.__xnew_cached_(cls, name, *args, **kwargs) return obj - def __new_stage2__(self, field, offsets=(0, 0, 0), idx=None, is_absolute_access=False, dtype=None): + def __new_stage2__( # type: ignore + self, + field: Field, + offsets: tuple[int, ...] = (0, 0, 0), + idx: None | tuple[int, ...] | str = None, + is_absolute_access: bool = False, + dtype: PsType | None = None, + ): field_name = field.name offsets_and_index = (*offsets, *idx) if idx is not None else offsets - constant_offsets = not any([isinstance(o, sp.Basic) and not o.is_Integer for o in offsets_and_index]) + constant_offsets = not any( + [ + isinstance(o, sp.Basic) and not o.is_Integer + for o in offsets_and_index + ] + ) if not idx: idx = tuple([0] * field.index_dimensions) @@ -628,31 +832,36 @@ class Field: else: idx_str = ",".join([str(e) for e in idx]) superscript = idx_str - if field.has_fixed_index_shape and not isinstance(field.dtype, PsStructType): + if field.has_fixed_index_shape and not isinstance( + field.dtype, PsStructType + ): for i, bound in zip(idx, field.index_shape): if i >= bound: raise ValueError("Field index out of bounds") else: - offset_name = hashlib.md5(pickle.dumps(offsets_and_index)).hexdigest()[:12] + offset_name = hashlib.md5(pickle.dumps(offsets_and_index)).hexdigest()[ + :12 + ] superscript = None symbol_name = f"{field_name}_{offset_name}" if superscript is not None: symbol_name += "^" + superscript + obj: Field.Access if dtype is not None: obj = super(Field.Access, self).__xnew__(self, symbol_name, dtype) else: obj = super(Field.Access, self).__xnew__(self, symbol_name, field.dtype) obj._field = field - obj._offsets = [] + _offsets: list[sp.Basic | int] = [] for o in offsets: if isinstance(o, sp.Basic): - obj._offsets.append(o) + _offsets.append(o) else: - obj._offsets.append(int(o)) - obj._offsets = tuple(sp.sympify(obj._offsets)) + _offsets.append(int(o)) + obj._offsets = tuple(sp.sympify(_offsets)) obj._offsetName = offset_name obj._superscript = superscript obj._index = idx @@ -660,19 +869,33 @@ class Field: obj._indirect_addressing_fields = set() for e in chain(obj._offsets, obj._index): if isinstance(e, sp.Basic): - obj._indirect_addressing_fields.update(a.field for a in e.atoms(Field.Access)) + obj._indirect_addressing_fields.update( + a.field for a in e.atoms(Field.Access) + ) obj._is_absolute_access = is_absolute_access return obj def __getnewargs__(self): - return self.field, self.offsets, self.index, self.is_absolute_access, self.dtype + return ( + self.field, + self.offsets, + self.index, + self.is_absolute_access, + self.dtype, + ) def __getnewargs_ex__(self): - return (self.field, self.offsets, self.index, self.is_absolute_access, self.dtype), {} + return ( + self.field, + self.offsets, + self.index, + self.is_absolute_access, + self.dtype, + ), {} # noinspection SpellCheckingInspection - __xnew__ = staticmethod(__new_stage2__) + __xnew__ = staticmethod(__new_stage2__) # type: ignore # noinspection SpellCheckingInspection __xnew_cached_ = staticmethod(cacheit(__new_stage2__)) @@ -686,22 +909,34 @@ class Field: idx = () if len(idx) != self.field.index_dimensions: - raise ValueError(f"Wrong number of indices: Got {len(idx)}, expected {self.field.index_dimensions}") + raise ValueError( + f"Wrong number of indices: Got {len(idx)}, expected {self.field.index_dimensions}" + ) if len(idx) == 1 and isinstance(idx[0], str): struct_type = self.field.dtype assert isinstance(struct_type, PsStructType) dtype = struct_type.get_member(idx[0]).dtype - return Field.Access(self.field, self._offsets, idx, - is_absolute_access=self.is_absolute_access, dtype=dtype) + return Field.Access( + self.field, + self._offsets, + idx, + is_absolute_access=self.is_absolute_access, + dtype=dtype, + ) else: - return Field.Access(self.field, self._offsets, idx, - is_absolute_access=self.is_absolute_access, dtype=self.dtype) + return Field.Access( + self.field, + self._offsets, + idx, + is_absolute_access=self.is_absolute_access, + dtype=self.dtype, + ) def __getitem__(self, *idx): return self.__call__(*idx) @property - def field(self) -> 'Field': + def field(self) -> "Field": """Field that the Access points to""" return self._field @@ -713,7 +948,7 @@ class Field: @property def required_ghost_layers(self) -> int: """Largest spatial distance that is accessed.""" - return int(np.max(np.abs(self._offsets))) + return int(np.max(np.abs(self._offsets))) # type: ignore @property def nr_of_coordinates(self): @@ -735,7 +970,7 @@ class Field: """Value of index coordinates as tuple.""" return self._index - def neighbor(self, coord_id: int, offset: int) -> 'Field.Access': + def neighbor(self, coord_id: int, offset: int) -> "Field.Access": """Returns a new Access with changed spatial coordinates. Args: @@ -749,10 +984,15 @@ class Field: """ offset_list = list(self.offsets) offset_list[coord_id] += offset - return Field.Access(self.field, tuple(offset_list), self.index, - is_absolute_access=self.is_absolute_access, dtype=self.dtype) + return Field.Access( + self.field, + tuple(offset_list), + self.index, + is_absolute_access=self.is_absolute_access, + dtype=self.dtype, + ) - def get_shifted(self, *shift) -> 'Field.Access': + def get_shifted(self, *shift) -> "Field.Access": """Returns a new Access with changed spatial coordinates Example: @@ -760,13 +1000,15 @@ class Field: >>> f[0,0].get_shifted(1, 1) f_NE """ - return Field.Access(self.field, - tuple(a + b for a, b in zip(shift, self.offsets)), - self.index, - is_absolute_access=self.is_absolute_access, - dtype=self.dtype) + return Field.Access( + self.field, + tuple(a + b for a, b in zip(shift, self.offsets)), + self.index, + is_absolute_access=self.is_absolute_access, + dtype=self.dtype, + ) - def at_index(self, *idx_tuple) -> 'Field.Access': + def at_index(self, *idx_tuple) -> "Field.Access": """Returns new Access with changed index. Example: @@ -774,15 +1016,22 @@ class Field: >>> f(0).at_index(8) f_C^8 """ - return Field.Access(self.field, self.offsets, idx_tuple, - is_absolute_access=self.is_absolute_access, dtype=self.dtype) + return Field.Access( + self.field, + self.offsets, + idx_tuple, + is_absolute_access=self.is_absolute_access, + dtype=self.dtype, + ) def _eval_subs(self, old, new): - return Field.Access(self.field, - tuple(sp.sympify(a).subs(old, new) for a in self.offsets), - tuple(sp.sympify(a).subs(old, new) for a in self.index), - is_absolute_access=self.is_absolute_access, - dtype=self.dtype) + return Field.Access( + self.field, + tuple(sp.sympify(a).subs(old, new) for a in self.offsets), + tuple(sp.sympify(a).subs(old, new) for a in self.index), + is_absolute_access=self.is_absolute_access, + dtype=self.dtype, + ) @property def is_absolute_access(self) -> bool: @@ -790,30 +1039,43 @@ class Field: return self._is_absolute_access @property - def indirect_addressing_fields(self) -> Set['Field']: + def indirect_addressing_fields(self) -> Set["Field"]: """Returns a set of fields that the access depends on. - e.g. f[index_field[1, 0]], the outer access to f depends on index_field - """ + e.g. f[index_field[1, 0]], the outer access to f depends on index_field + """ return self._indirect_addressing_fields def _hashable_content(self): super_class_contents = super(Field.Access, self)._hashable_content() - return (super_class_contents, self._field.hashable_contents(), *self._index, - *self._offsets, self._is_absolute_access) + return ( + super_class_contents, + self._field.hashable_contents(), + *self._index, + *self._offsets, + self._is_absolute_access, + ) def _staggered_offset(self, offsets, index): assert FieldType.is_staggered(self._field) neighbor = self._field.staggered_stencil[index] - neighbor = direction_string_to_offset(neighbor, self._field.spatial_dimensions) - return [(o + sp.Rational(int(neighbor[i]), 2)) for i, o in enumerate(offsets)] + neighbor = direction_string_to_offset( + neighbor, self._field.spatial_dimensions + ) + return [ + (o + sp.Rational(int(neighbor[i]), 2)) for i, o in enumerate(offsets) + ] def _latex(self, _): n = self._field.latex_name if self._field.latex_name else self._field.name offset_str = ",".join([sp.latex(o) for o in self.offsets]) if FieldType.is_staggered(self._field): - offset_str = ",".join([sp.latex(self._staggered_offset(self.offsets, self.index[0])[i]) - for i in range(len(self.offsets))]) + offset_str = ",".join( + [ + sp.latex(self._staggered_offset(self.offsets, self.index[0])[i]) + for i in range(len(self.offsets)) + ] + ) if self.is_absolute_access: offset_str = f"\\mathbf{offset_str}" elif self.field.spatial_dimensions > 1: @@ -834,8 +1096,12 @@ class Field: n = self._field.latex_name if self._field.latex_name else self._field.name offset_str = ",".join([sp.latex(o) for o in self.offsets]) if FieldType.is_staggered(self._field): - offset_str = ",".join([sp.latex(self._staggered_offset(self.offsets, self.index[0])[i]) - for i in range(len(self.offsets))]) + offset_str = ",".join( + [ + sp.latex(self._staggered_offset(self.offsets, self.index[0])[i]) + for i in range(len(self.offsets)) + ] + ) if self.is_absolute_access: offset_str = f"[abs]{offset_str}" @@ -851,12 +1117,36 @@ class Field: return f"{n}[{offset_str}]" -def fields(description=None, index_dimensions=0, layout=None, - field_type=FieldType.GENERIC, **kwargs) -> Union[Field, List[Field]]: +def fields( + description=None, + index_dimensions=0, + layout=None, + field_type=FieldType.GENERIC, + **kwargs, +) -> Field | list[Field]: """Creates pystencils fields from a string description. + The description must be a string of the form + ``"name(index-shape) [name(index-shape) ...]: <data-type>[<dimension-or-shape>]"``, + where: + + - ``name`` is the name of the respective field + - ``(index-shape)`` is a tuple of integers describing the shape of the tensor on each field node + (can be omitted for scalar fields) + - ``<data-type>`` is the numerical data type of the field's entries; + this can be any type parseable by `create_type`, + as well as ``dyn`` for `DynamicType.NUMERIC_TYPE` + and ``dynidx`` for `DynamicType.INDEX_TYPE`. + - ``<dimension-or-shape>`` can be a dimensionality (e.g. ``1D``, ``2D``, ``3D``) + or a tuple of integers defining the spatial shape of the field. + Examples: - Create a 2D scalar and vector field: + Create a 3D scalar field of default numeric type: + >>> f = fields("f(1): [2D]") + >>> str(f.dtype) + 'DynamicType.NUMERIC_TYPE' + + Create a 2D scalar and vector field of 64-bit float type: >>> s, v = fields("s, v(2): double[2D]") >>> assert s.spatial_dimensions == 2 and s.index_dimensions == 0 >>> assert (v.spatial_dimensions, v.index_dimensions, v.index_shape) == (2, 1, (2,)) @@ -882,35 +1172,70 @@ def fields(description=None, index_dimensions=0, layout=None, >>> f = fields("pdfs(19) : float32[3D]", layout='fzyx') >>> f.layout (2, 1, 0) + + Returns: + Sequence of fields created from the description """ result = [] if description: field_descriptions, dtype, shape = _parse_description(description) - layout = 'numpy' if layout is None else layout + layout = "numpy" if layout is None else layout for field_name, idx_shape in field_descriptions: if field_name in kwargs: arr = kwargs[field_name] - idx_shape_of_arr = () if not len(idx_shape) else arr.shape[-len(idx_shape):] + idx_shape_of_arr = ( + () if not len(idx_shape) else arr.shape[-len(idx_shape):] + ) assert idx_shape_of_arr == idx_shape - f = Field.create_from_numpy_array(field_name, kwargs[field_name], index_dimensions=len(idx_shape), - field_type=field_type) + f = Field.create_from_numpy_array( + field_name, + kwargs[field_name], + index_dimensions=len(idx_shape), + field_type=field_type, + ) elif isinstance(shape, tuple): - f = Field.create_fixed_size(field_name, shape + idx_shape, dtype=dtype, - index_dimensions=len(idx_shape), layout=layout, field_type=field_type) + f = Field.create_fixed_size( + field_name, + shape + idx_shape, + dtype=dtype, + index_dimensions=len(idx_shape), + layout=layout, + field_type=field_type, + ) elif isinstance(shape, int): - f = Field.create_generic(field_name, spatial_dimensions=shape, dtype=dtype, - index_shape=idx_shape, layout=layout, field_type=field_type) + f = Field.create_generic( + field_name, + spatial_dimensions=shape, + dtype=dtype, + index_shape=idx_shape, + layout=layout, + field_type=field_type, + ) elif shape is None: - f = Field.create_generic(field_name, spatial_dimensions=2, dtype=dtype, - index_shape=idx_shape, layout=layout, field_type=field_type) + f = Field.create_generic( + field_name, + spatial_dimensions=2, + dtype=dtype, + index_shape=idx_shape, + layout=layout, + field_type=field_type, + ) else: assert False result.append(f) else: - assert layout is None, "Layout can not be specified when creating Field from numpy array" + assert ( + layout is None + ), "Layout can not be specified when creating Field from numpy array" for field_name, arr in kwargs.items(): - result.append(Field.create_from_numpy_array(field_name, arr, index_dimensions=index_dimensions, - field_type=field_type)) + result.append( + Field.create_from_numpy_array( + field_name, + arr, + index_dimensions=index_dimensions, + field_type=field_type, + ) + ) if len(result) == 0: raise ValueError("Could not parse field description") @@ -920,16 +1245,27 @@ def fields(description=None, index_dimensions=0, layout=None, return result -def get_layout_from_strides(strides: Sequence[int], index_dimension_ids: Optional[List[int]] = None): +def get_layout_from_strides( + strides: Sequence[int], index_dimension_ids: Optional[List[int]] = None +): index_dimension_ids = [] if index_dimension_ids is None else index_dimension_ids coordinates = list(range(len(strides))) - relevant_strides = [stride for i, stride in enumerate(strides) if i not in index_dimension_ids] - result = [x for (y, x) in sorted(zip(relevant_strides, coordinates), key=lambda pair: pair[0], reverse=True)] + relevant_strides = [ + stride for i, stride in enumerate(strides) if i not in index_dimension_ids + ] + result = [ + x + for (y, x) in sorted( + zip(relevant_strides, coordinates), key=lambda pair: pair[0], reverse=True + ) + ] return normalize_layout(result) -def get_layout_of_array(arr: np.ndarray, index_dimension_ids: Optional[List[int]] = None): - """ Returns a list indicating the memory layout (linearization order) of the numpy array. +def get_layout_of_array( + arr: np.ndarray, index_dimension_ids: Optional[List[int]] = None +): + """Returns a list indicating the memory layout (linearization order) of the numpy array. Examples: >>> get_layout_of_array(np.zeros([3,3,3])) @@ -946,7 +1282,9 @@ def get_layout_of_array(arr: np.ndarray, index_dimension_ids: Optional[List[int] return get_layout_from_strides(arr.strides, index_dimension_ids) -def create_numpy_array_with_layout(shape, layout, alignment=False, byte_offset=0, **kwargs): +def create_numpy_array_with_layout( + shape, layout, alignment=False, byte_offset=0, **kwargs +): """Creates numpy array with given memory layout. Args: @@ -970,7 +1308,10 @@ def create_numpy_array_with_layout(shape, layout, alignment=False, byte_offset=0 if cur_layout[i] != layout[i]: index_to_swap_with = cur_layout.index(layout[i]) swaps.append((i, index_to_swap_with)) - cur_layout[i], cur_layout[index_to_swap_with] = cur_layout[index_to_swap_with], cur_layout[i] + cur_layout[i], cur_layout[index_to_swap_with] = ( + cur_layout[index_to_swap_with], + cur_layout[i], + ) assert tuple(cur_layout) == tuple(layout) shape = list(shape) @@ -978,7 +1319,7 @@ def create_numpy_array_with_layout(shape, layout, alignment=False, byte_offset=0 shape[a], shape[b] = shape[b], shape[a] if not alignment: - res = np.empty(shape, order='c', **kwargs) + res = np.empty(shape, order="c", **kwargs) else: res = aligned_empty(shape, alignment, byte_offset=byte_offset, **kwargs) @@ -990,37 +1331,43 @@ def create_numpy_array_with_layout(shape, layout, alignment=False, byte_offset=0 def spatial_layout_string_to_tuple(layout_str: str, dim: int) -> Tuple[int, ...]: if dim <= 0: raise ValueError("Dimensionality must be positive") - + layout_str = layout_str.lower() - if layout_str in ('fzyx', 'zyxf', 'soa', 'aos'): + if layout_str in ("fzyx", "zyxf", "soa", "aos"): if dim > 3: - raise ValueError(f"Invalid spatial dimensionality for layout descriptor {layout_str}: May be at most 3.") + raise ValueError( + f"Invalid spatial dimensionality for layout descriptor {layout_str}: May be at most 3." + ) return tuple(reversed(range(dim))) - - if layout_str in ('f', 'reverse_numpy'): + + if layout_str in ("f", "reverse_numpy"): return tuple(reversed(range(dim))) - elif layout_str in ('c', 'numpy'): + elif layout_str in ("c", "numpy"): return tuple(range(dim)) raise ValueError("Unknown layout descriptor " + layout_str) -def layout_string_to_tuple(layout_str, dim): +def layout_string_to_tuple(layout_str, dim) -> tuple[int, ...]: if dim <= 0: raise ValueError("Dimensionality must be positive") - + layout_str = layout_str.lower() - if layout_str == 'fzyx' or layout_str == 'soa': + if layout_str == "fzyx" or layout_str == "soa": if dim > 4: - raise ValueError(f"Invalid total dimensionality for layout descriptor {layout_str}: May be at most 4.") + raise ValueError( + f"Invalid total dimensionality for layout descriptor {layout_str}: May be at most 4." + ) return tuple(reversed(range(dim))) - elif layout_str == 'zyxf' or layout_str == 'aos': + elif layout_str == "zyxf" or layout_str == "aos": if dim > 4: - raise ValueError(f"Invalid total dimensionality for layout descriptor {layout_str}: May be at most 4.") + raise ValueError( + f"Invalid total dimensionality for layout descriptor {layout_str}: May be at most 4." + ) return tuple(reversed(range(dim - 1))) + (dim - 1,) - elif layout_str == 'f' or layout_str == 'reverse_numpy': + elif layout_str == "f" or layout_str == "reverse_numpy": return tuple(reversed(range(dim))) - elif layout_str == 'c' or layout_str == 'numpy': + elif layout_str == "c" or layout_str == "numpy": return tuple(range(dim)) raise ValueError("Unknown layout descriptor " + layout_str) @@ -1055,7 +1402,8 @@ def compute_strides(shape, layout): # ---------------------------------------- Parsing of string in fields() function -------------------------------------- -field_description_regex = re.compile(r""" +field_description_regex = re.compile( + r""" \s* # ignore leading white spaces (\w+) # identifier is a sequence of alphanumeric characters, is stored in first group (?: # optional index specification e.g. (1, 4, 2) @@ -1066,9 +1414,12 @@ field_description_regex = re.compile(r""" \s* )? \s*,?\s* # ignore trailing white spaces and comma -""", re.VERBOSE) +""", + re.VERBOSE, +) -type_description_regex = re.compile(r""" +type_description_regex = re.compile( + r""" \s* (\w+)? # optional dtype \s* @@ -1076,7 +1427,9 @@ type_description_regex = re.compile(r""" ([^\]]+) \] \s* -""", re.VERBOSE | re.IGNORECASE) +""", + re.VERBOSE | re.IGNORECASE, +) def _parse_part1(d): @@ -1094,24 +1447,30 @@ def _parse_description(description): result = type_description_regex.match(d) if result: data_type_str, size_info = result.group(1), result.group(2).strip().lower() - if data_type_str is None: - data_type_str = 'float64' - data_type_str = data_type_str.lower().strip() + if data_type_str is not None: + data_type_str = data_type_str.lower().strip() + + if data_type_str: + match data_type_str: + case "dyn": dtype = DynamicType.NUMERIC_TYPE + case "dynidx": dtype = DynamicType.INDEX_TYPE + case _: dtype = create_type(data_type_str) + else: + dtype = DynamicType.NUMERIC_TYPE - if not data_type_str: - data_type_str = 'float64' - if size_info.endswith('d'): + if size_info.endswith("d"): size_info = int(size_info[:-1]) else: size_info = tuple(int(e) for e in size_info.split(",")) - return data_type_str, size_info + + return dtype, size_info else: raise ValueError("Could not parse field description") - if ':' in description: - field_description, field_info = description.split(':') + if ":" in description: + field_description, field_info = description.split(":") else: - field_description, field_info = description, 'float64[2D]' + field_description, field_info = description, "float64[2D]" fields_info = [e for e in _parse_part1(field_description)] if not field_info: diff --git a/src/pystencils/jit/cpu_extension_module.py b/src/pystencils/jit/cpu_extension_module.py index befb033e6f7969a5ffd9bc7742e9e7ab691da47d..55f1961ca5c00963c16912ada738788688a93452 100644 --- a/src/pystencils/jit/cpu_extension_module.py +++ b/src/pystencils/jit/cpu_extension_module.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any +from typing import Any, cast from os import path import hashlib @@ -19,6 +19,7 @@ from ..types import ( PsUnsignedIntegerType, PsSignedIntegerType, PsIeeeFloatType, + PsPointerType, ) from ..types.quick import Fp, SInt, UInt from ..field import Field @@ -121,9 +122,7 @@ class PsKernelExtensioNModule: def emit_call_wrapper(function_name: str, kernel: Kernel) -> str: builder = CallWrapperBuilder() - - for p in kernel.parameters: - builder.extract_parameter(p) + builder.extract_params(kernel.parameters) # for c in kernel.constraints: # builder.check_constraint(c) @@ -199,7 +198,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ """ def __init__(self) -> None: - self._array_buffers: dict[Field, str] = dict() + self._buffer_types: dict[Field, PsType] = dict() self._array_extractions: dict[Field, str] = dict() self._array_frees: dict[Field, str] = dict() @@ -220,9 +219,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ return "PyLong_AsUnsignedLong" case _: - raise ValueError( - f"Don't know how to cast Python objects to {dtype}" - ) + raise ValueError(f"Don't know how to cast Python objects to {dtype}") def _type_char(self, dtype: PsType) -> str | None: if isinstance( @@ -233,37 +230,39 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ else: return None - def extract_field(self, field: Field) -> str: - """Adds an array, and returns the name of the underlying Py_Buffer.""" + def get_field_buffer(self, field: Field) -> str: + """Get the Python buffer object for the given field.""" + return f"buffer_{field.name}" + + def extract_field(self, field: Field) -> None: + """Add the necessary code to extract the NumPy array for a given field""" if field not in self._array_extractions: extraction_code = self.TMPL_EXTRACT_ARRAY.format(name=field.name) + actual_dtype = self._buffer_types[field] # Check array type - type_char = self._type_char(field.dtype) + type_char = self._type_char(actual_dtype) if type_char is not None: dtype_cond = f"buffer_{field.name}.format[0] == '{type_char}'" extraction_code += self.TMPL_CHECK_ARRAY_TYPE.format( cond=dtype_cond, what="data type", name=field.name, - expected=str(field.dtype), + expected=str(actual_dtype), ) # Check item size - itemsize = field.dtype.itemsize + itemsize = actual_dtype.itemsize item_size_cond = f"buffer_{field.name}.itemsize == {itemsize}" extraction_code += self.TMPL_CHECK_ARRAY_TYPE.format( cond=item_size_cond, what="itemsize", name=field.name, expected=itemsize ) - self._array_buffers[field] = f"buffer_{field.name}" self._array_extractions[field] = extraction_code release_code = f"PyBuffer_Release(&buffer_{field.name});" self._array_frees[field] = release_code - return self._array_buffers[field] - def extract_scalar(self, param: Parameter) -> str: if param not in self._scalar_extractions: extract_func = self._scalar_extractor(param.dtype) @@ -279,7 +278,8 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ def extract_array_assoc_var(self, param: Parameter) -> str: if param not in self._array_assoc_var_extractions: field = param.fields[0] - buffer = self.extract_field(field) + buffer = self.get_field_buffer(field) + buffer_dtype = self._buffer_types[field] code: str | None = None for prop in param.properties: @@ -293,7 +293,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ case FieldStride(_, coord): code = ( f"{param.dtype.c_string()} {param.name} = " - f"{buffer}.strides[{coord}] / {field.dtype.itemsize};" + f"{buffer}.strides[{coord}] / {buffer_dtype.itemsize};" ) break assert code is not None @@ -302,29 +302,48 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ return param.name - def extract_parameter(self, param: Parameter): - if param.is_field_parameter: - self.extract_array_assoc_var(param) - else: - self.extract_scalar(param) + def extract_params(self, params: tuple[Parameter, ...]) -> None: + for param in params: + if ptr_props := param.get_properties(FieldBasePtr): + prop: FieldBasePtr = cast(FieldBasePtr, ptr_props.pop()) + field = prop.field + actual_field_type: PsType + + from .. import DynamicType + + if isinstance(field.dtype, DynamicType): + ptr_type = param.dtype + assert isinstance(ptr_type, PsPointerType) + actual_field_type = ptr_type.base_type + else: + actual_field_type = field.dtype + + self._buffer_types[prop.field] = actual_field_type + self.extract_field(prop.field) + + for param in params: + if param.is_field_parameter: + self.extract_array_assoc_var(param) + else: + self.extract_scalar(param) -# def check_constraint(self, constraint: KernelParamsConstraint): -# variables = constraint.get_parameters() + # def check_constraint(self, constraint: KernelParamsConstraint): + # variables = constraint.get_parameters() -# for var in variables: -# self.extract_parameter(var) + # for var in variables: + # self.extract_parameter(var) -# cond = constraint.to_code() + # cond = constraint.to_code() -# code = f""" -# if(!({cond})) -# {{ -# PyErr_SetString(PyExc_ValueError, "Violated constraint: {constraint}"); -# return NULL; -# }} -# """ + # code = f""" + # if(!({cond})) + # {{ + # PyErr_SetString(PyExc_ValueError, "Violated constraint: {constraint}"); + # return NULL; + # }} + # """ -# self._constraint_checks.append(code) + # self._constraint_checks.append(code) def call(self, kernel: Kernel, params: tuple[Parameter, ...]): param_list = ", ".join(p.name for p in params) diff --git a/src/pystencils/jit/gpu_cupy.py b/src/pystencils/jit/gpu_cupy.py index c208ac2196151d079ca5081f1377c55d18a9393c..a407bb75e08bfde9911070aef03b4a1769a6221a 100644 --- a/src/pystencils/jit/gpu_cupy.py +++ b/src/pystencils/jit/gpu_cupy.py @@ -19,7 +19,7 @@ from ..codegen import ( Parameter, ) from ..codegen.properties import FieldShape, FieldStride, FieldBasePtr -from ..types import PsStructType +from ..types import PsStructType, PsPointerType from ..include import get_pystencils_include_path @@ -160,8 +160,18 @@ class CupyKernelWrapper(KernelWrapper): for prop in kparam.properties: match prop: case FieldBasePtr(field): + + elem_dtype: PsType + + from .. import DynamicType + if isinstance(field.dtype, DynamicType): + assert isinstance(kparam.dtype, PsPointerType) + elem_dtype = kparam.dtype.base_type + else: + elem_dtype = field.dtype + arr = kwargs[field.name] - if arr.dtype != field.dtype.numpy_dtype: + if arr.dtype != elem_dtype.numpy_dtype: raise JitError( f"Data type mismatch at array argument {field.name}:" f"Expected {field.dtype}, got {arr.dtype}" diff --git a/src/pystencils/rng.py b/src/pystencils/rng.py index d6c6cd2741ee3e7442bd9fa4a96f4e9983d524e3..4f8316fa75284ed0fa3385744bd9b93f88d5ae65 100644 --- a/src/pystencils/rng.py +++ b/src/pystencils/rng.py @@ -2,7 +2,7 @@ import copy import numpy as np import sympy as sp -from .sympyextensions import TypedSymbol, CastFunc, fast_subs +from .sympyextensions import TypedSymbol, tcast, fast_subs # from pystencils.sympyextensions.astnodes import LoopOverCoordinate # TODO nbackend: replace # from pystencils.backends.cbackend import CustomCodeNode # TODO nbackend: replace @@ -48,7 +48,7 @@ class RNGBase: def get_code(self, dialect, vector_instruction_set, print_arg): code = "\n" for r in self.result_symbols: - if vector_instruction_set and not self.args[1].atoms(CastFunc): + if vector_instruction_set and not self.args[1].atoms(tcast): # this vector RNG has become scalar through substitution code += f"{r.dtype} {r.name};\n" else: diff --git a/src/pystencils/sympyextensions/__init__.py b/src/pystencils/sympyextensions/__init__.py index 7431416c9eb9bcd4433dab76c32fb1b755501105..2d874fdc0778a331aaf61ed938981f533eafbecb 100644 --- a/src/pystencils/sympyextensions/__init__.py +++ b/src/pystencils/sympyextensions/__init__.py @@ -1,5 +1,5 @@ from .astnodes import ConditionalFieldAccess -from .typed_sympy import TypedSymbol, CastFunc +from .typed_sympy import TypedSymbol, tcast from .pointers import mem_acc from .math import ( @@ -34,7 +34,7 @@ from .math import ( __all__ = [ "ConditionalFieldAccess", "TypedSymbol", - "CastFunc", + "tcast", "mem_acc", "remove_higher_order_terms", "prod", diff --git a/src/pystencils/sympyextensions/math.py b/src/pystencils/sympyextensions/math.py index 9841a98bd83162fbb080db370556de70612bc398..33c035499ee80598303c8a26b028e47dfae72cc3 100644 --- a/src/pystencils/sympyextensions/math.py +++ b/src/pystencils/sympyextensions/math.py @@ -11,7 +11,7 @@ from sympy.functions import Abs from sympy.core.numbers import Zero from ..assignment import Assignment -from .typed_sympy import CastFunc +from .typed_sympy import TypeCast from ..types import PsPointerType, PsVectorType T = TypeVar('T') @@ -603,7 +603,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]], visit_children = False elif t.is_integer: pass - elif isinstance(t, CastFunc): + elif isinstance(t, TypeCast): visit_children = False visit(t.args[0]) elif t.func is fast_sqrt: diff --git a/src/pystencils/sympyextensions/typed_sympy.py b/src/pystencils/sympyextensions/typed_sympy.py index 39202296b477dc17ea6e9564548ef841fd04594d..e2435d6bbe570887e0903c67f6041ed9911c02be 100644 --- a/src/pystencils/sympyextensions/typed_sympy.py +++ b/src/pystencils/sympyextensions/typed_sympy.py @@ -1,4 +1,5 @@ from __future__ import annotations +from typing import cast import sympy as sp from enum import Enum, auto @@ -6,11 +7,14 @@ from enum import Enum, auto from ..types import ( PsType, PsNumericType, - PsBoolType, create_type, UserTypeSpec ) +from sympy.logic.boolalg import Boolean + +from warnings import warn + def is_loop_counter_symbol(symbol): from ..defaults import DEFAULTS @@ -37,11 +41,12 @@ class DynamicType(Enum): class TypeAtom(sp.Atom): """Wrapper around a type to disguise it as a SymPy atom.""" - def __new__(cls, *args, **kwargs): - return sp.Basic.__new__(cls) + _dtype: PsType | DynamicType - def __init__(self, dtype: PsType | DynamicType) -> None: - self._dtype = dtype + def __new__(cls, dtype: PsType | DynamicType): + obj = super().__new__(cls) + obj._dtype = dtype + return obj def _sympystr(self, *args, **kwargs): return str(self._dtype) @@ -52,6 +57,9 @@ class TypeAtom(sp.Atom): def _hashable_content(self): return (self._dtype,) + def __getnewargs__(self): + return (self._dtype,) + def assumptions_from_dtype(dtype: PsType | DynamicType): """Derives SymPy assumptions from :class:`PsAbstractType` @@ -133,144 +141,74 @@ class TypedSymbol(sp.Symbol): return self.dtype.required_headers if isinstance(self.dtype, PsType) else set() -class CastFunc(sp.Function): - """Use this function to introduce a static type cast into the output code. - - Usage: ``CastFunc(expr, target_type)`` becomes, in C code, ``(target_type) expr``. - The ``target_type`` may be a valid pystencils type specification parsable by `create_type`, - or a special value of the `DynamicType` enum. - These dynamic types can be used to select the target type according to the code generation context. - """ +class TypeCast(sp.Function): + """Explicitly cast an expression to a data type.""" @staticmethod def as_numeric(expr): - return CastFunc(expr, DynamicType.NUMERIC_TYPE) + return TypeCast(expr, DynamicType.NUMERIC_TYPE) @staticmethod def as_index(expr): - return CastFunc(expr, DynamicType.INDEX_TYPE) - - is_Atom = True - - def __new__(cls, *args, **kwargs): - if len(args) != 2: - pass - expr, dtype, *other_args = args - - # If we have two consecutive casts, throw the inner one away. - # This optimisation is only available for simple casts. Thus the == is intended here! - if expr.__class__ == CastFunc: - expr = expr.args[0] - - if not isinstance(dtype, (TypeAtom)): - if isinstance(dtype, DynamicType): - dtype = TypeAtom(dtype) - else: - dtype = TypeAtom(create_type(dtype)) - - # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well - # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads - # to problems when for example comparing cast_func's for equality - # - # lhs = bitwise_and(a, cast_func(1, 'int')) - # rhs = cast_func(0, 'int') - # print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans - # -> thus a separate class boolean_cast_func is introduced - if isinstance(expr, sp.logic.boolalg.Boolean) and ( - not isinstance(expr, TypedSymbol) or isinstance(expr.dtype, PsBoolType) - ): - cls = BooleanCastFunc - - return sp.Function.__new__(cls, expr, dtype, *other_args, **kwargs) - - @property - def canonical(self): - if hasattr(self.args[0], "canonical"): - return self.args[0].canonical - else: - raise NotImplementedError() - - @property - def is_commutative(self): - return self.args[0].is_commutative - - @property - def dtype(self) -> PsType | DynamicType: - assert isinstance(self.args[1], TypeAtom) - return self.args[1].get() - + return TypeCast(expr, DynamicType.INDEX_TYPE) + @property - def expr(self): + def expr(self) -> sp.Basic: return self.args[0] @property - def is_integer(self): + def dtype(self) -> PsType | DynamicType: + return cast(TypeAtom, self._args[1]).get() + + def __new__(cls, expr: sp.Basic, dtype: UserTypeSpec | DynamicType | TypeAtom): + tatom: TypeAtom + match dtype: + case TypeAtom(): + tatom = dtype + case DynamicType(): + tatom = TypeAtom(dtype) + case _: + tatom = TypeAtom(create_type(dtype)) + + return super().__new__(cls, expr, tatom) + + @classmethod + def eval(cls, expr: sp.Basic, tatom: TypeAtom) -> TypeCast | None: + dtype = tatom.get() + if cls is not BoolCast and isinstance(dtype, PsNumericType) and dtype.is_bool(): + return BoolCast(expr, tatom) + + return None + + def _eval_is_integer(self): if self.dtype == DynamicType.INDEX_TYPE: return True - elif isinstance(self.dtype, PsNumericType): - return self.dtype.is_int() or super().is_integer - else: - return super().is_integer - - @property - def is_negative(self): - """ - See :func:`.TypedSymbol.is_integer` - """ - if isinstance(self.dtype, PsNumericType): - if self.dtype.is_uint(): - return False - - return super().is_negative - - @property - def is_nonnegative(self): - """ - See :func:`.TypedSymbol.is_integer` - """ - if self.is_negative is False: + if isinstance(self.dtype, PsNumericType) and self.dtype.is_int(): + return True + + def _eval_is_real(self): + if isinstance(self.dtype, DynamicType): + return True + if isinstance(self.dtype, PsNumericType) and (self.dtype.is_float() or self.dtype.is_int()): + return True + + def _eval_is_nonnegative(self): + if isinstance(self.dtype, PsNumericType) and self.dtype.is_uint(): return True - else: - return super().is_nonnegative - - @property - def is_real(self): - """ - See :func:`.TypedSymbol.is_integer` - """ - if isinstance(self.dtype, PsNumericType): - return self.dtype.is_int() or self.dtype.is_float() or super().is_real - else: - return super().is_real - - -class BooleanCastFunc(CastFunc, sp.logic.boolalg.Boolean): - # TODO: documentation - pass - - -class VectorMemoryAccess(CastFunc): - """ - Special memory access for vectorized kernel. - Arguments: read/write expression, type, aligned, non-temporal, mask (or none), stride - """ - nargs = (6,) +class BoolCast(TypeCast, Boolean): + pass -class ReinterpretCastFunc(CastFunc): - """ - Reinterpret cast is necessary for the StructType - """ - pass +tcast = TypeCast -class PointerArithmeticFunc(sp.Function, sp.logic.boolalg.Boolean): - # TODO: documentation, or deprecate! - @property - def canonical(self): - if hasattr(self.args[0], "canonical"): - return self.args[0].canonical - else: - raise NotImplementedError() +class CastFunc(TypeCast): + def __new__(cls, *args, **kwargs): + warn( + "CastFunc is deprecated and will be removed in pystencils 2.1. " + "Use `pystencils.tcast` instead.", + FutureWarning + ) + return TypeCast.__new__(cls, *args, **kwargs) diff --git a/tests/frontend/test_address_of.py b/tests/frontend/test_address_of.py index 99f33ddbdfa7054bf5f27c08848640ee03f64555..62d7f00d56b288c009c9dc4fcfade95b95acdd41 100644 --- a/tests/frontend/test_address_of.py +++ b/tests/frontend/test_address_of.py @@ -5,7 +5,7 @@ import pytest import pystencils from pystencils.types import PsPointerType, create_type from pystencils.sympyextensions.pointers import AddressOf -from pystencils.sympyextensions.typed_sympy import CastFunc +from pystencils.sympyextensions.typed_sympy import tcast from pystencils.simp import sympy_cse import sympy as sp @@ -23,14 +23,14 @@ def test_address_of(): assignments = pystencils.AssignmentCollection({ s: AddressOf(x[0, 0]), - y[0, 0]: CastFunc(s, create_type('int64')) + y[0, 0]: tcast(s, create_type('int64')) }) _ = pystencils.create_kernel(assignments).compile() # pystencils.show_code(kernel.ast) assignments = pystencils.AssignmentCollection({ - y[0, 0]: CastFunc(AddressOf(x[0, 0]), create_type('int64')) + y[0, 0]: tcast(AddressOf(x[0, 0]), create_type('int64')) }) _ = pystencils.create_kernel(assignments).compile() @@ -41,7 +41,7 @@ def test_address_of_with_cse(): x, y = pystencils.fields('x, y: int64[2d]') assignments = pystencils.AssignmentCollection({ - x[0, 0]: CastFunc(AddressOf(x[0, 0]), create_type('int64')) + 1 + x[0, 0]: tcast(AddressOf(x[0, 0]), create_type('int64')) + 1 }) _ = pystencils.create_kernel(assignments).compile() diff --git a/tests/frontend/test_field.py b/tests/frontend/test_field.py index dc804491bee8023e7b0e1b665d5f9cd252d64c1d..6d2942569704b7ff85b15fd23432667ba109ed7d 100644 --- a/tests/frontend/test_field.py +++ b/tests/frontend/test_field.py @@ -3,7 +3,7 @@ import pytest import sympy as sp import pystencils as ps -from pystencils import DEFAULTS +from pystencils import DEFAULTS, DynamicType, create_type, fields from pystencils.field import ( Field, FieldType, @@ -15,6 +15,7 @@ from pystencils.field import ( def test_field_basic(): f = Field.create_generic("f", spatial_dimensions=2) assert FieldType.is_generic(f) + assert f.dtype == DynamicType.NUMERIC_TYPE assert f["E"] == f[1, 0] assert f["N"] == f[0, 1] assert "_" in f.center._latex("dummy") @@ -41,17 +42,16 @@ def test_field_basic(): assert f1.ndim == f.ndim assert f1.values_per_cell() == f.values_per_cell() - fixed = ps.fields("f(5, 5) : double[20, 20]") - assert fixed.neighbor_vector((1, 1)).shape == (5, 5) - - f = Field.create_fixed_size("f", (10, 10), strides=(80, 8), dtype=np.float64) + f = Field.create_fixed_size("f", (10, 10), strides=(10, 1), dtype=np.float64) assert f.spatial_strides == (10, 1) assert f.index_strides == () assert f.center_vector == sp.Matrix([f.center]) + assert f.dtype == create_type("float64") f1 = f.new_field_with_different_name("f1") assert f1.ndim == f.ndim assert f1.values_per_cell() == f.values_per_cell() + assert f1.dtype == create_type("float64") f = Field.create_fixed_size("f", (8, 8, 2, 2), index_dimensions=2) assert f.center_vector == sp.Matrix([[f(0, 0), f(0, 1)], [f(1, 0), f(1, 1)]]) @@ -61,16 +61,48 @@ def test_field_basic(): neighbor = field_access.neighbor(coord_id=0, offset=-2) assert neighbor.offsets == (-1, 1) assert "_" in neighbor._latex("dummy") + assert f.dtype == DynamicType.NUMERIC_TYPE f = Field.create_fixed_size("f", (8, 8, 2, 2, 2), index_dimensions=3) assert f.center_vector == sp.Array( [[[f(i, j, k) for k in range(2)] for j in range(2)] for i in range(2)] ) + assert f.dtype == DynamicType.NUMERIC_TYPE f = Field.create_generic("f", spatial_dimensions=5, index_dimensions=2) field_access = f[1, -1, 2, -3, 0](1, 0) assert field_access.offsets == (1, -1, 2, -3, 0) assert field_access.index == (1, 0) + assert f.dtype == DynamicType.NUMERIC_TYPE + + +def test_field_description_parsing(): + f, g = fields("f(1), g(3): [2D]") + + assert f.dtype == g.dtype == DynamicType.NUMERIC_TYPE + assert f.spatial_dimensions == g.spatial_dimensions == 2 + assert f.index_shape == (1,) + assert g.index_shape == (3,) + + f = fields("f: dyn[3D]") + assert f.dtype == DynamicType.NUMERIC_TYPE + + idx = fields("idx: dynidx[3D]") + assert idx.dtype == DynamicType.INDEX_TYPE + + h = fields("h: float32[3D]") + + assert h.index_shape == () + assert h.spatial_dimensions == 3 + assert h.index_dimensions == 0 + assert h.dtype == create_type("float32") + + f: Field = fields("f(5, 5) : double[20, 20]") + + assert f.dtype == create_type("float64") + assert f.spatial_shape == (20, 20) + assert f.index_shape == (5, 5) + assert f.neighbor_vector((1, 1)).shape == (5, 5) def test_error_handling(): @@ -145,7 +177,7 @@ def test_error_handling(): def test_decorator_scoping(): - dst = ps.fields("dst : double[2D]") + dst = fields("dst : double[2D]") def f1(): a = sp.Symbol("a") @@ -165,7 +197,7 @@ def test_decorator_scoping(): def test_string_creation(): - x, y, z = ps.fields(" x(4), y(3,5) z : double[ 3, 47]") + x, y, z = fields(" x(4), y(3,5) z : double[ 3, 47]") assert x.index_shape == (4,) assert y.index_shape == (3, 5) assert z.spatial_shape == (3, 47) @@ -173,9 +205,9 @@ def test_string_creation(): def test_itemsize(): - x = ps.fields("x: float32[1d]") - y = ps.fields("y: float64[2d]") - i = ps.fields("i: int16[1d]") + x = fields("x: float32[1d]") + y = fields("y: float64[2d]") + i = fields("i: int16[1d]") assert x.itemsize == 4 assert y.itemsize == 8 @@ -249,7 +281,7 @@ def test_memory_layout_descriptors(): def test_staggered(): # D2Q5 - j1, j2, j3 = ps.fields( + j1, j2, j3 = fields( "j1(2), j2(2,2), j3(2,2,2) : double[2D]", field_type=FieldType.STAGGERED ) @@ -296,7 +328,7 @@ def test_staggered(): ) # D2Q9 - k1, k2 = ps.fields("k1(4), k2(2) : double[2D]", field_type=FieldType.STAGGERED) + k1, k2 = fields("k1(4), k2(2) : double[2D]", field_type=FieldType.STAGGERED) assert k1[1, 1](2) == k1.staggered_access("NE") assert k1[0, 0](2) == k1.staggered_access("SW") @@ -319,7 +351,7 @@ def test_staggered(): ] # sign reversed when using as flux field - r = ps.fields("r(2) : double[2D]", field_type=FieldType.STAGGERED_FLUX) + r = fields("r(2) : double[2D]", field_type=FieldType.STAGGERED_FLUX) assert r[0, 0](0) == r.staggered_access("W") assert -r[1, 0](0) == r.staggered_access("E") diff --git a/tests/frontend/test_typed_sympy.py b/tests/frontend/test_typed_sympy.py index 41015f96bfa6950a57f9ccfa3194c128c2bc0f69..bf6058537a7217851d22987f3b011edea08058c8 100644 --- a/tests/frontend/test_typed_sympy.py +++ b/tests/frontend/test_typed_sympy.py @@ -1,8 +1,11 @@ import numpy as np +import pickle +import sympy as sp +from sympy.logic import boolalg from pystencils.sympyextensions.typed_sympy import ( TypedSymbol, - CastFunc, + tcast, TypeAtom, DynamicType, ) @@ -12,7 +15,7 @@ from pystencils.types.quick import UInt, Ptr def test_type_atoms(): atom1 = TypeAtom(create_type("int32")) - atom2 = TypeAtom(create_type("int32")) + atom2 = TypeAtom(create_type(np.int32)) assert atom1 == atom2 @@ -25,6 +28,11 @@ def test_type_atoms(): assert atom3 != atom4 assert atom4 != atom5 + dump = pickle.dumps(atom1) + atom1_reconst = pickle.loads(dump) + + assert atom1_reconst == atom1 + def test_typed_symbol(): x = TypedSymbol("x", "uint32") @@ -46,12 +54,34 @@ def test_typed_symbol(): assert not z.is_nonnegative -def test_cast_func(): - assert ( - CastFunc(TypedSymbol("s", np.uint), np.int64).canonical - == TypedSymbol("s", np.uint).canonical - ) - - a = CastFunc(5, np.uint) - assert a.is_negative is False - assert a.is_nonnegative +def test_casts(): + x, y = sp.symbols("x, y") + + # Pickling + expr = tcast(x, "int32") + dump = pickle.dumps(expr) + expr_reconst = pickle.loads(dump) + assert expr_reconst == expr + + # Boolean Casts + bool_expr = tcast(x, "bool") + assert isinstance(bool_expr, boolalg.Boolean) + + # Check that we can construct boolean expressions with cast results + _ = boolalg.Or(bool_expr, y) + + # Assumptions + expr = tcast(x, "int32") + assert expr.is_integer + assert expr.is_real + assert expr.is_nonnegative is None + + expr = tcast(x, "uint32") + assert expr.is_integer + assert expr.is_real + assert expr.is_nonnegative + + expr = tcast(x, "float32") + assert expr.is_integer is None + assert expr.is_real + assert expr.is_nonnegative is None diff --git a/tests/kernelcreation/test_spatial_counters.py b/tests/kernelcreation/test_spatial_counters.py index fdb365294c98311943c370cb650694b1a4bd8613..4f865ad97f42f31133cc5d0dc3fbba569f6f577d 100644 --- a/tests/kernelcreation/test_spatial_counters.py +++ b/tests/kernelcreation/test_spatial_counters.py @@ -9,7 +9,7 @@ from pystencils import ( DEFAULTS, FieldType, ) -from pystencils.sympyextensions import CastFunc +from pystencils.sympyextensions import tcast @pytest.mark.parametrize("index_dtype", ["int16", "int32", "uint32", "int64"]) @@ -21,9 +21,9 @@ def test_spatial_counters_dense(index_dtype): f = Field.create_generic("f", 3, "float64", index_shape=(3,), layout="fzyx") asms = [ - Assignment(f(0), CastFunc.as_numeric(z)), - Assignment(f(1), CastFunc.as_numeric(y)), - Assignment(f(2), CastFunc.as_numeric(x)), + Assignment(f(0), tcast.as_numeric(z)), + Assignment(f(1), tcast.as_numeric(y)), + Assignment(f(2), tcast.as_numeric(x)), ] cfg = CreateKernelConfig(index_dtype=index_dtype) @@ -44,9 +44,9 @@ def test_spatial_counters_sparse(index_dtype): f = Field.create_generic("f", 3, "float64", index_shape=(3,), layout="fzyx") asms = [ - Assignment(f(0), CastFunc.as_numeric(x)), - Assignment(f(1), CastFunc.as_numeric(y)), - Assignment(f(2), CastFunc.as_numeric(z)), + Assignment(f(0), tcast.as_numeric(x)), + Assignment(f(1), tcast.as_numeric(y)), + Assignment(f(2), tcast.as_numeric(z)), ] idx_struct = DEFAULTS.index_struct(index_dtype, 3) diff --git a/tests/kernelcreation/test_type_cast.py b/tests/kernelcreation/test_type_cast.py index 8ad6d867042ff6e57e5baee8c12ff45bae17e8e4..6b7acbbedbe395f36f306b76eeb09b8edc7444d9 100644 --- a/tests/kernelcreation/test_type_cast.py +++ b/tests/kernelcreation/test_type_cast.py @@ -8,7 +8,7 @@ from pystencils import ( Assignment, Field, ) -from pystencils.sympyextensions.typed_sympy import CastFunc +from pystencils.sympyextensions.typed_sympy import tcast AVAIL_TARGETS_NO_SSE = [t for t in Target.available_targets() if Target._SSE not in t] @@ -55,7 +55,7 @@ def test_type_cast(gen_config, xp, from_type, to_type): inp_field = Field.create_from_numpy_array("inp", inp) outp_field = Field.create_from_numpy_array("outp", outp) - asms = [Assignment(outp_field.center(), CastFunc(inp_field.center(), to_type))] + asms = [Assignment(outp_field.center(), tcast(inp_field.center(), to_type))] kernel = create_kernel(asms, gen_config) kfunc = kernel.compile() diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py index fccec711423699c3d413682c3a5e8a99d5e092f1..f6c8f85b2b3df2289e809728b9e7b014d6428976 100644 --- a/tests/nbackend/kernelcreation/test_freeze.py +++ b/tests/nbackend/kernelcreation/test_freeze.py @@ -9,7 +9,7 @@ from pystencils import ( TypedSymbol, DynamicType, ) -from pystencils.sympyextensions import CastFunc +from pystencils.sympyextensions import tcast from pystencils.sympyextensions.pointers import mem_acc from pystencils.backend.ast.structural import ( @@ -312,16 +312,16 @@ def test_cast_func(): y2 = PsExpression.make(ctx.get_symbol("y")) z2 = PsExpression.make(ctx.get_symbol("z")) - expr = freeze(CastFunc(x, create_type("int"))) + expr = freeze(tcast(x, create_type("int"))) assert expr.structurally_equal(PsCast(create_type("int"), x2)) - expr = freeze(CastFunc.as_numeric(y)) + expr = freeze(tcast.as_numeric(y)) assert expr.structurally_equal(PsCast(ctx.default_dtype, y2)) - expr = freeze(CastFunc.as_index(z)) + expr = freeze(tcast.as_index(z)) assert expr.structurally_equal(PsCast(ctx.index_dtype, z2)) - expr = freeze(CastFunc(42, create_type("int16"))) + expr = freeze(tcast(42, create_type("int16"))) assert expr.structurally_equal(PsConstantExpr(PsConstant(42, create_type("int16")))) diff --git a/tests/nbackend/transformations/test_ast_vectorizer.py b/tests/nbackend/transformations/test_ast_vectorizer.py index f92f1e768a6e04c1eb5292612d6406365520bb72..3ccb479e5552bcd02954b9ed8518ef3ad0f90bfb 100644 --- a/tests/nbackend/transformations/test_ast_vectorizer.py +++ b/tests/nbackend/transformations/test_ast_vectorizer.py @@ -2,7 +2,7 @@ import sympy as sp import pytest from pystencils import Assignment, TypedSymbol, fields, FieldType, make_slice -from pystencils.sympyextensions import CastFunc, mem_acc +from pystencils.sympyextensions import tcast, mem_acc from pystencils.sympyextensions.pointers import AddressOf from pystencils.backend.constants import PsConstant @@ -109,7 +109,7 @@ def test_vectorize_casts_and_counter(): axis = VectorizationAxis(ctr, vec_ctr) vc = VectorizationContext(ctx, 4, axis) - expr = factory.parse_sympy(CastFunc(sp.Symbol("ctr"), create_type("float32"))) + expr = factory.parse_sympy(tcast(sp.Symbol("ctr"), create_type("float32"))) vec_expr = vectorize.visit(expr, vc) assert isinstance(vec_expr, PsCast) @@ -136,7 +136,7 @@ def test_invalid_vectorization(): axis = VectorizationAxis(ctr) vc = VectorizationContext(ctx, 4, axis) - expr = factory.parse_sympy(CastFunc(sp.Symbol("ctr"), create_type("float32"))) + expr = factory.parse_sympy(tcast(sp.Symbol("ctr"), create_type("float32"))) with pytest.raises(VectorizationError): # Fails since no vectorized counter was specified @@ -177,7 +177,7 @@ def test_vectorize_declarations(): [ factory.parse_sympy(asm) for asm in [ - Assignment(x, CastFunc.as_numeric(ctr)), + Assignment(x, tcast.as_numeric(ctr)), Assignment(y, sp.cos(x)), Assignment(z, x**2 + 2 * y / 4), Assignment(w, -x + y - z), diff --git a/tests/nbackend/transformations/test_canonicalize_symbols.py b/tests/nbackend/transformations/test_canonicalize_symbols.py index 2758d123417eb8e3015ed1d6b4d8cf0ba7c14611..dbc4ba10b71668af43d3a352d9cc49a8c9d61140 100644 --- a/tests/nbackend/transformations/test_canonicalize_symbols.py +++ b/tests/nbackend/transformations/test_canonicalize_symbols.py @@ -17,7 +17,7 @@ def test_deduplication(): factory = AstFactory(ctx) canonicalize = CanonicalizeSymbols(ctx) - f = Field.create_fixed_size("f", (5, 5), strides=(5, 1)) + f = Field.create_fixed_size("f", (5, 5), memory_strides=(5, 1)) x, y, z = sp.symbols("x, y, z") ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:, :], f) diff --git a/tests/nbackend/transformations/test_hoist_invariants.py b/tests/nbackend/transformations/test_hoist_invariants.py index daa2760c0b376dc0bb1f2ca59703a15efc5c2312..1f27a5a4cd17d9b20e54b3c44d1e733f8374f947 100644 --- a/tests/nbackend/transformations/test_hoist_invariants.py +++ b/tests/nbackend/transformations/test_hoist_invariants.py @@ -33,7 +33,7 @@ def test_hoist_multiple_loops(): canonicalize = CanonicalizeSymbols(ctx) hoist = HoistLoopInvariantDeclarations(ctx) - f = Field.create_fixed_size("f", (5, 5), strides=(5, 1)) + f = Field.create_fixed_size("f", (5, 5), memory_strides=(5, 1)) x, y, z = sp.symbols("x, y, z") ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:, :], f) diff --git a/tests/runtime/test_datahandling.py b/tests/runtime/test_datahandling.py index 62ba64056ab6d4062e49b76376d4e3cf3560ccf2..9d7ff924e8d7eba9039f8f0796145bd7de116ef5 100644 --- a/tests/runtime/test_datahandling.py +++ b/tests/runtime/test_datahandling.py @@ -249,7 +249,7 @@ def test_add_arrays(): dh = create_data_handling(domain_size=domain_shape, default_ghost_layers=0, default_layout='numpy') x_, y_ = dh.add_arrays(field_description) - x, y = ps.fields(field_description + ': [3,4,5]') + x, y = ps.fields(field_description + ': float64[3,4,5]') assert x_ == x assert y_ == y