Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
Commits on Source (2)
Showing
with 1017 additions and 448 deletions
......@@ -71,7 +71,10 @@ Typed Expressions
.. autoclass:: pystencils.DynamicType
:members:
.. autoclass:: pystencils.sympyextensions.CastFunc
.. autoclass:: pystencils.sympyextensions.typed_sympy.TypeCast
:members:
.. autoclass:: pystencils.sympyextensions.tcast
Integer Operations
......
......@@ -118,10 +118,10 @@ mypy src/pystencils
::::
:::{note}
Type checking is currently restricted to the `codegen`, `jit`, `backend`, and `types` modules,
since most code in the remaining modules is significantly older and is not comprehensively
type-annotated. As more modules are updated with type annotations, this list will expand in the future.
If you think a new module is ready to be type-checked, add an exception clause for it in the `mypy.ini` file.
Type checking is currently restricted only to a few modules, which are listed in the `mypy.ini` file.
Most code in the remaining modules is significantly older and is not comprehensively type-annotated.
As more modules are updated with type annotations, this list will expand in the future.
If you think a new module is ready to be type-checked, add an exception clause to `mypy.ini`.
:::
## Running the Test Suite
......
......@@ -82,6 +82,7 @@ Topics
user_manual/symbolic_language
user_manual/kernelcreation
user_manual/gpu_kernels
user_manual/WorkingWithTypes
.. toctree::
:maxdepth: 1
......
---
file_format: mystnb
kernelspec:
display_name: Python 3 (ipykernel)
language: python
name: python3
mystnb:
execution_mode: cache
---
# Working with Data Types
This guide will demonstrate the various options that exist to customize the data types
in generated kernels.
Data types can be modified on different levels of granularity:
Individual fields and symbols,
single subexpressions,
or the entire kernel.
```{code-cell} ipython3
:tags: [remove-cell]
import pystencils as ps
import sympy as sp
```
## Changing the Default Data Types
The pystencils code generator defines two default data types:
- The default *numeric type*, which is applied to all numerical computations that are not
otherwise explicitly typed; the default is `float64`.
- The default *index type*, which is used for all loop and field index calculations; the default is `int64`.
These can be modified by setting the
{any}`default_dtype <CreateKernelConfig.default_dtype>` and
{any}`index_type <CreateKernelConfig.index_dtype>`
options of the code generator configuration:
```{code-cell} ipython3
cfg = ps.CreateKernelConfig()
cfg.default_dtype = "float32"
cfg.index_dtype = "int32"
```
Modifying these will change the way types for [untyped symbols](#untyped-symbols)
and [dynamically typed expressions](#dynamic-typing) are computed.
## Setting the Types of Fields and Symbols
(untyped-symbols)=
### Untyped Symbols
Symbols used inside a kernel are most commonly created using
{any}`sp.symbols <sympy.core.symbol.symbols>` or
{any}`sp.Symbol <sympy.core.symbol.Symbol>`.
These symbols are *untyped*; they will receive a type during code generation
according to these rules:
- Free untyped symbols (i.e. symbols not defined by an assignment inside the kernel) receive the
{any}`default data type <CreateKernelConfig.default_dtype>` specified in the code generator configuration.
- Bound untyped symbols (i.e. symbols that *are* defined in an assignment)
receive the data type that was computed for the right-hand side expression of their defining assignment.
If you are working on kernels with homogenous data types, using untyped symbols will mostly be enough.
### Explicitly Typed Symbols and Fields
If you need more control over the data types in (parts of) your kernel,
you will have to explicitly specify them.
To set an explicit data type for a symbol, use the {any}`TypedSymbol` class of pystencils:
```{code-cell} ipython3
x_typed = ps.TypedSymbol("x", "uint32")
x_typed, str(x_typed.dtype)
```
You can set a `TypedSymbol` to any data type provided by [the type system](#page_type_system),
which will then be enforced by the code generator.
The same holds for fields:
When creating fields through the {any}`fields <pystencils.field.fields>` function,
add the type to the descriptor string; for instance:
```{code-cell} ipython3
f, g = ps.fields("f(1), g(3): float32[3D]")
str(f.dtype), str(g.dtype)
```
When using `Field.create_generic` or `Field.create_fixed_size`, on the other hand,
you can set the data type via the `dtype` keyword argument.
(dynamic-typing)=
### Dynamically Typed Symbols and Fields
Apart from explicitly setting data types,
`TypedSymbol`s and fields can also receive a *dynamic data type* (see {any}`DynamicType`).
There are two options:
- Symbols or fields annotated with {any}`DynamicType.NUMERIC_TYPE` will always receive
the {any}`default numeric type <CreateKernelConfig.default_dtype>` configured for the
code generator.
This is the default setting for fields
created through `fields`, `Field.create_generic` or `Field.create_fixed_size`.
- When annotated with {any}`DynamicType.INDEX_TYPE`, on the other hand, they will receive
the {any}`index data type <CreateKernelConfig.index_dtype>` configured for the kernel.
Using dynamic typing, you can enforce symbols to receive either the standard numeric or
index type without explicitly stating it, such that your kernel definition becomes
independent from the code generator configuration.
## Mixing Types Inside Expressions
Pystencils enforces that all symbols, constants, and fields occuring inside an expression
have the same data type.
The code generator will never introduce implicit casts--if any type conflicts arise, it will terminate with an error.
Still, there are cases where you want to combine subexpressions of different types;
maybe you need to compute geometric information from loop counters or other integers,
or you are doing mixed-precision numerical computations.
In these cases, you might have to introduce explicit type casts when values move from one type context to another.
<!-- 2. Annotate expressions with a specific data type to ensure computations are performed in that type.
TODO: See #97 (https://i10git.cs.fau.de/pycodegen/pystencils/-/issues/97)
-->
(type_casts)=
### Type Casts
Type casts can be introduced into kernels using the {any}`tcast` symbolic function.
It takes an expression and a data type, which is either an explicit type (see [the type system](#page_type_system))
or a dynamic type ({any}`DynamicType`):
```{code-cell} ipython3
x, y = sp.symbols("x, y")
expr1 = ps.tcast(x, "float32")
expr2 = ps.tcast(3 + y, ps.DynamicType.INDEX_TYPE)
str(expr1.dtype), str(expr2.dtype)
```
When a type cast occurs, pystencils will compute the type of its argument independently
and then introduce a runtime cast to the target type.
That target type must comply with the type computed for the outer expression,
which the cast is embedded in.
## Understanding the pystencils Type Inference System
To correctly apply varying data types to pystencils kernels, it is important to understand
how pystencils computes and propagates the data types of symbols and expressions.
Type inference happens on the level of assignments.
For each assignment $x := \mathrm{calc}(y_1, \dots, y_n)$,
the system first attempts to compute a *unique* type for the right-hand side (RHS) $\mathrm{calc}(y_1, \dots, y_n)$.
It searches for any subexpression inside the RHS for which a type is already known --
these might be typed symbols
(whose types are either set explicitly by the user,
or have been determined from their defining assignment),
field accesses,
or explicitly typed expressions.
It then attempts to apply that data type to the entire expression.
If type conflicts occur, the process fails and the code generator raises an error.
Otherwise, the resulting type is assigned to the left-hand side symbol $x$.
:::{admonition} Developer's To Do
It would be great to illustrate this using a GraphViz-plot of an AST,
with nodes colored according to their data types
:::
......@@ -138,7 +138,7 @@ This happens roughly according to the following rules:
We can observe this behavior by setting up a kernel including several fields with different data types:
```{code-cell} ipython3
from pystencils.sympyextensions import CastFunc
from pystencils.sympyextensions import tcast
f = ps.fields("f: float32[2D]")
g = ps.fields("g: float16[2D]")
......
......@@ -17,6 +17,12 @@ ignore_errors = False
[mypy-pystencils.jit.*]
ignore_errors = False
[mypy-pystencils.field]
ignore_errors = False
[mypy-pystencils.sympyextensions.typed_sympy]
ignore_errors = False
[mypy-setuptools.*]
ignore_missing_imports=true
......
......@@ -32,7 +32,7 @@ from .spatial_coordinates import (
from .assignment import Assignment, AddAugmentedAssignment, assignment_from_stencil
from .simp import AssignmentCollection
from .sympyextensions.typed_sympy import TypedSymbol, DynamicType
from .sympyextensions import SymbolCreator
from .sympyextensions import SymbolCreator, tcast
from .datahandling import create_data_handling
__all__ = [
......@@ -77,6 +77,7 @@ __all__ = [
"x_staggered_vector",
"fd",
"stencil",
"tcast",
]
from . import _version
......
......@@ -93,6 +93,16 @@ class KernelCreationContext:
def index_dtype(self) -> PsIntegerType:
"""Data type used by default for index expressions"""
return self._index_dtype
def resolve_dynamic_type(self, dtype: DynamicType | PsType) -> PsType:
"""Selects the appropriate data type for `DynamicType` instances, and returns all other types as they are."""
match dtype:
case DynamicType.NUMERIC_TYPE:
return self._default_dtype
case DynamicType.INDEX_TYPE:
return self._index_dtype
case _:
return dtype
@property
def metadata(self) -> dict[str, Any]:
......@@ -339,6 +349,8 @@ class KernelCreationContext:
if isinstance(s, TypedSymbol)
)
entry_type = self.resolve_dynamic_type(field.dtype)
if len(idx_types) > 1:
raise KernelConstraintsError(
f"Multiple incompatible types found in index symbols of field {field}: "
......@@ -375,10 +387,10 @@ class KernelCreationContext:
base_ptr = self.get_symbol(
DEFAULTS.field_pointer_name(field.name),
PsPointerType(field.dtype, restrict=True),
PsPointerType(entry_type, restrict=True),
)
return PsBuffer(field.name, field.dtype, base_ptr, buf_shape, buf_strides)
return PsBuffer(field.name, entry_type, base_ptr, buf_shape, buf_strides)
def _create_buffer_field_buffer(self, field: Field) -> PsBuffer:
if field.spatial_dimensions != 1:
......@@ -418,10 +430,11 @@ class KernelCreationContext:
]
buf_strides = [PsConstant(num_entries, idx_type), PsConstant(1, idx_type)]
buf_dtype = self.resolve_dynamic_type(field.dtype)
base_ptr = self.get_symbol(
DEFAULTS.field_pointer_name(field.name),
PsPointerType(field.dtype, restrict=True),
PsPointerType(buf_dtype, restrict=True),
)
return PsBuffer(field.name, field.dtype, base_ptr, buf_shape, buf_strides)
return PsBuffer(field.name, buf_dtype, base_ptr, buf_shape, buf_strides)
......@@ -13,7 +13,7 @@ from ...sympyextensions import (
integer_functions,
ConditionalFieldAccess,
)
from ...sympyextensions.typed_sympy import TypedSymbol, CastFunc, DynamicType
from ...sympyextensions.typed_sympy import TypedSymbol, TypeCast, DynamicType
from ...sympyextensions.pointers import AddressOf, mem_acc
from ...field import Field, FieldType
......@@ -261,14 +261,7 @@ class FreezeExpressions:
return num / denom
def map_TypedSymbol(self, expr: TypedSymbol):
dtype = expr.dtype
match dtype:
case DynamicType.NUMERIC_TYPE:
dtype = self._ctx.default_dtype
case DynamicType.INDEX_TYPE:
dtype = self._ctx.index_dtype
dtype = self._ctx.resolve_dynamic_type(expr.dtype)
symb = self._ctx.get_symbol(expr.name, dtype)
return PsSymbolExpr(symb)
......@@ -481,7 +474,7 @@ class FreezeExpressions:
]
return cast(PsCall, args[0])
def map_CastFunc(self, cast_expr: CastFunc) -> PsCast | PsConstantExpr:
def map_TypeCast(self, cast_expr: TypeCast) -> PsCast | PsConstantExpr:
dtype: PsType
match cast_expr.dtype:
case DynamicType.NUMERIC_TYPE:
......
from __future__ import annotations
import functools
import hashlib
import operator
......@@ -5,23 +7,28 @@ import pickle
import re
from enum import Enum
from itertools import chain
from typing import List, Optional, Sequence, Set, Tuple, Union
from typing import List, Optional, Sequence, Set, Tuple
from warnings import warn
import numpy as np
import sympy as sp
from sympy.core.cache import cacheit
from .defaults import DEFAULTS
from pystencils.alignedarray import aligned_empty
from pystencils.spatial_coordinates import x_staggered_vector, x_vector
from pystencils.stencil import direction_string_to_offset, inverse_direction, offset_to_direction_string
from pystencils.types import PsType, PsStructType, create_type
from pystencils.sympyextensions.typed_sympy import TypedSymbol, DynamicType
from pystencils.sympyextensions import is_integer_sequence
from pystencils.types import UserTypeSpec
from .alignedarray import aligned_empty
from .spatial_coordinates import x_staggered_vector, x_vector
from .stencil import (
direction_string_to_offset,
inverse_direction,
offset_to_direction_string,
)
from .types import PsType, PsStructType, create_type
from .sympyextensions.typed_sympy import TypedSymbol, DynamicType
from .sympyextensions import is_integer_sequence
from .types import UserTypeSpec
__all__ = ['Field', 'fields', 'FieldType', 'Field']
__all__ = ["Field", "fields", "FieldType", "Field"]
class FieldType(Enum):
......@@ -63,7 +70,10 @@ class FieldType(Enum):
@staticmethod
def is_staggered(field):
assert isinstance(field, Field)
return field.field_type == FieldType.STAGGERED or field.field_type == FieldType.STAGGERED_FLUX
return (
field.field_type == FieldType.STAGGERED
or field.field_type == FieldType.STAGGERED_FLUX
)
@staticmethod
def is_staggered_flux(field):
......@@ -123,23 +133,34 @@ class Field:
>>> assignments = [Assignment(dst[0,0](i), src[-offset](i)) for i, offset in enumerate(stencil)];
Args:
field_name: something
field_type: something
dtype: something
layout: something
shape: something
strides: something
field_name: The field's name
field_type: The kind of the field
dtype: Data type of the field's entries
layout: Linearization order of the field's spatial dimensions
shape: Total shape (spatial and index) of the field
strides: Linearization strides of the field
"""
@staticmethod
def create_generic(field_name, spatial_dimensions, dtype: UserTypeSpec = np.float64, index_dimensions=0,
layout='numpy', index_shape=None, field_type=FieldType.GENERIC) -> 'Field':
def create_generic(
field_name,
spatial_dimensions,
dtype: UserTypeSpec | DynamicType = DynamicType.NUMERIC_TYPE,
index_dimensions=0,
layout="numpy",
index_shape=None,
field_type=FieldType.GENERIC,
) -> "Field":
"""
Creates a generic field where the field size is not fixed i.e. can be called with arrays of different sizes
Creates a generic field where the field size is not fixed i.e. can be called with arrays of different sizes.
**Field Element Type** By default, the data type of the field entries is left undetermined until
code generation, at which point it is set to the default numerical type of the kernel.
You can specify a concrete type using the `dtype` parameter.
Args:
field_name: symbolic name for the field
dtype: numpy data type of the array the kernel is called with later
dtype: Data type of the field entries
spatial_dimensions: see documentation of Field
index_dimensions: see documentation of Field
layout: tuple specifying the loop ordering of the spatial dimensions e.g. (2, 1, 0 ) means that
......@@ -159,41 +180,60 @@ class Field:
layout = spatial_layout_string_to_tuple(layout, dim=spatial_dimensions)
total_dimensions = spatial_dimensions + index_dimensions
shape: tuple[TypedSymbol | int, ...]
if index_shape is None or len(index_shape) == 0:
shape = tuple([
TypedSymbol(DEFAULTS.field_shape_name(field_name, i), DynamicType.INDEX_TYPE)
for i in range(total_dimensions)
])
shape = tuple(
[
TypedSymbol(
DEFAULTS.field_shape_name(field_name, i), DynamicType.INDEX_TYPE
)
for i in range(total_dimensions)
]
)
else:
shape = tuple(
[
TypedSymbol(DEFAULTS.field_shape_name(field_name, i), DynamicType.INDEX_TYPE)
TypedSymbol(
DEFAULTS.field_shape_name(field_name, i), DynamicType.INDEX_TYPE
)
for i in range(spatial_dimensions)
] + list(index_shape)
]
+ list(index_shape)
)
strides = tuple([
TypedSymbol(DEFAULTS.field_stride_name(field_name, i), DynamicType.INDEX_TYPE)
for i in range(total_dimensions)
])
strides: tuple[TypedSymbol | int, ...] = tuple(
[
TypedSymbol(
DEFAULTS.field_stride_name(field_name, i), DynamicType.INDEX_TYPE
)
for i in range(total_dimensions)
]
)
dtype = create_type(dtype)
np_data_type = dtype.numpy_dtype
assert np_data_type is not None
if np_data_type.fields is not None:
if not isinstance(dtype, DynamicType):
dtype = create_type(dtype)
if isinstance(dtype, PsStructType):
if index_dimensions != 0:
raise ValueError("Structured arrays/fields are not allowed to have an index dimension")
raise ValueError(
"Structured arrays/fields are not allowed to have an index dimension"
)
shape += (1,)
strides += (1,)
if field_type == FieldType.STAGGERED and index_dimensions == 0:
raise ValueError("A staggered field needs at least one index dimension")
return Field(field_name, field_type, dtype, layout, shape, strides)
@staticmethod
def create_from_numpy_array(field_name: str, array: np.ndarray, index_dimensions: int = 0,
field_type=FieldType.GENERIC) -> 'Field':
def create_from_numpy_array(
field_name: str,
array: np.ndarray,
index_dimensions: int = 0,
field_type=FieldType.GENERIC,
) -> Field:
"""Creates a field based on the layout, data type, and shape of a given numpy array.
Kernels created for these kind of fields can only be called with arrays of the same layout, shape and type.
......@@ -206,7 +246,9 @@ class Field:
"""
spatial_dimensions = len(array.shape) - index_dimensions
if spatial_dimensions < 1:
raise ValueError("Too many index dimensions. At least one spatial dimension required")
raise ValueError(
"Too many index dimensions. At least one spatial dimension required"
)
full_layout = get_layout_of_array(array)
spatial_layout = tuple([i for i in full_layout if i < spatial_dimensions])
......@@ -218,21 +260,31 @@ class Field:
numpy_dtype = np.dtype(array.dtype)
if numpy_dtype.fields is not None:
if index_dimensions != 0:
raise ValueError("Structured arrays/fields are not allowed to have an index dimension")
raise ValueError(
"Structured arrays/fields are not allowed to have an index dimension"
)
shape += (1,)
strides += (1,)
if field_type == FieldType.STAGGERED and index_dimensions == 0:
raise ValueError("A staggered field needs at least one index dimension")
return Field(field_name, field_type, array.dtype, spatial_layout, shape, strides)
return Field(
field_name, field_type, array.dtype, spatial_layout, shape, strides
)
@staticmethod
def create_fixed_size(field_name: str, shape: Tuple[int, ...], index_dimensions: int = 0,
dtype: UserTypeSpec = np.float64, layout: str = 'numpy',
strides: Optional[Sequence[int]] = None,
field_type=FieldType.GENERIC) -> 'Field':
def create_fixed_size(
field_name: str,
shape: tuple[int, ...],
index_dimensions: int = 0,
dtype: UserTypeSpec | DynamicType = DynamicType.NUMERIC_TYPE,
layout: str | tuple[int, ...] = "numpy",
memory_strides: None | Sequence[int] = None,
strides: Optional[Sequence[int]] = None,
field_type=FieldType.GENERIC,
) -> Field:
"""
Creates a field with fixed sizes i.e. can be called only with arrays of the same size and layout
Creates a field with fixed sizes i.e. can be called only with arrays of the same size and layout.
Args:
field_name: symbolic name for the field
......@@ -240,54 +292,90 @@ class Field:
index_dimensions: how many of the trailing dimensions are interpreted as index (as opposed to spatial)
dtype: numpy data type of the array the kernel is called with later
layout: full layout of array, not only spatial dimensions
strides: strides in bytes or None to automatically compute them from shape (assuming no padding)
memory_strides: Linearization strides for each dimension;
i.e. the number of elements to skip to get from one index to the next in the respective dimension.
field_type: kind of field
"""
if strides is not None:
warn(
"The `strides` parameter to `Field.create_fixed_size` is deprecated "
"and will be removed in pystencils 2.1. "
"Use `memory_strides` instead; "
"beware that `memory_strides` takes the number of *elements* to skip, "
"instead of the number of bytes.",
FutureWarning
)
if memory_strides is not None:
raise ValueError("Cannot specify `memory_strides` and deprecated parameter `strides` at the same time.")
if isinstance(dtype, DynamicType):
raise ValueError("Cannot specify the deprecated parameter `strides` together with a `DynamicType`. "
"Set `memory_strides` instead.")
np_type = create_type(dtype).numpy_dtype
assert np_type is not None
memory_strides = tuple([s // np_type.itemsize for s in strides])
spatial_dimensions = len(shape) - index_dimensions
assert spatial_dimensions >= 1
if isinstance(layout, str):
layout = layout_string_to_tuple(layout, spatial_dimensions + index_dimensions)
layout = layout_string_to_tuple(
layout, spatial_dimensions + index_dimensions
)
if not isinstance(dtype, DynamicType):
dtype = create_type(dtype)
shape_tuple = tuple(int(s) for s in shape)
strides_tuple: tuple[int, ...]
shape = tuple(int(s) for s in shape)
if strides is None:
strides = compute_strides(shape, layout)
strides_tuple = compute_strides(shape_tuple, layout)
else:
assert len(strides) == len(shape)
strides = tuple([s // np.dtype(dtype).itemsize for s in strides])
assert len(strides) == len(shape_tuple)
strides_tuple = tuple(strides)
dtype = create_type(dtype)
numpy_dtype = dtype.numpy_dtype
assert numpy_dtype is not None
if numpy_dtype.fields is not None:
if isinstance(dtype, PsStructType):
if index_dimensions != 0:
raise ValueError("Structured arrays/fields are not allowed to have an index dimension")
shape += (1,)
strides += (1,)
raise ValueError(
"Structured arrays/fields are not allowed to have an index dimension"
)
shape_tuple += (1,)
strides_tuple += (1,)
if field_type == FieldType.STAGGERED and index_dimensions == 0:
raise ValueError("A staggered field needs at least one index dimension")
spatial_layout = list(layout)
for i in range(spatial_dimensions, len(layout)):
spatial_layout.remove(i)
return Field(field_name, field_type, dtype, tuple(spatial_layout), shape, strides)
return Field(
field_name,
field_type,
dtype,
tuple(spatial_layout),
shape_tuple,
strides_tuple,
)
def __init__(
self,
field_name: str,
field_type: FieldType,
dtype: UserTypeSpec,
dtype: UserTypeSpec | DynamicType,
layout: tuple[int, ...],
shape,
strides
strides,
):
"""Do not use directly. Use static create* methods"""
self._field_name = field_name
assert isinstance(field_type, FieldType)
assert len(shape) == len(strides)
self.field_type = field_type
self._dtype = create_type(dtype)
self._dtype: PsType | DynamicType = (
create_type(dtype) if not isinstance(dtype, DynamicType) else dtype
)
self._layout = normalize_layout(layout)
self.shape = shape
self.strides = strides
......@@ -299,9 +387,23 @@ class Field:
def new_field_with_different_name(self, new_name):
if self.has_fixed_shape:
return Field(new_name, self.field_type, self._dtype, self._layout, self.shape, self.strides)
return Field(
new_name,
self.field_type,
self._dtype,
self._layout,
self.shape,
self.strides,
)
else:
return Field(new_name, self.field_type, self.dtype, self.layout, self.shape, self.strides)
return Field(
new_name,
self.field_type,
self.dtype,
self.layout,
self.shape,
self.strides,
)
@property
def spatial_dimensions(self) -> int:
......@@ -328,7 +430,7 @@ class Field:
@property
def spatial_shape(self) -> Tuple[int, ...]:
return self.shape[:self.spatial_dimensions]
return self.shape[: self.spatial_dimensions]
@property
def has_fixed_shape(self):
......@@ -344,31 +446,34 @@ class Field:
@property
def spatial_strides(self):
return self.strides[:self.spatial_dimensions]
return self.strides[: self.spatial_dimensions]
@property
def index_strides(self):
return self.strides[self.spatial_dimensions:]
@property
def dtype(self) -> PsType:
def dtype(self) -> PsType | DynamicType:
return self._dtype
@property
def itemsize(self):
return self.dtype.itemsize
def itemsize(self) -> int | None:
if isinstance(self.dtype, PsType):
return self.dtype.itemsize
else:
return None
def __repr__(self):
if any(isinstance(s, sp.Symbol) for s in self.spatial_shape):
spatial_shape_str = f'{self.spatial_dimensions}d'
spatial_shape_str = f"{self.spatial_dimensions}d"
else:
spatial_shape_str = ','.join(str(i) for i in self.spatial_shape)
index_shape_str = ','.join(str(i) for i in self.index_shape)
spatial_shape_str = ",".join(str(i) for i in self.spatial_shape)
index_shape_str = ",".join(str(i) for i in self.index_shape)
if self.index_shape:
return f'{self._field_name}({index_shape_str}): {self.dtype}[{spatial_shape_str}]'
return f"{self._field_name}({index_shape_str}): {self.dtype}[{spatial_shape_str}]"
else:
return f'{self._field_name}: {self.dtype}[{spatial_shape_str}]'
return f"{self._field_name}: {self.dtype}[{spatial_shape_str}]"
def __str__(self):
return self.name
......@@ -389,12 +494,26 @@ class Field:
elif len(index_shape) == 1:
return sp.Matrix([self(i) for i in range(index_shape[0])])
elif len(index_shape) == 2:
return sp.Matrix([[self(i, j) for j in range(index_shape[1])] for i in range(index_shape[0])])
return sp.Matrix(
[
[self(i, j) for j in range(index_shape[1])]
for i in range(index_shape[0])
]
)
elif len(index_shape) == 3:
return sp.Array([[[self(i, j, k) for k in range(index_shape[2])]
for j in range(index_shape[1])] for i in range(index_shape[0])])
return sp.Array(
[
[
[self(i, j, k) for k in range(index_shape[2])]
for j in range(index_shape[1])
]
for i in range(index_shape[0])
]
)
else:
raise NotImplementedError("center_vector is not implemented for more than 3 index dimensions")
raise NotImplementedError(
"center_vector is not implemented for more than 3 index dimensions"
)
@property
def center(self):
......@@ -410,12 +529,20 @@ class Field:
if self.index_dimensions == 0:
return sp.Matrix([self.__getitem__(offset)])
elif self.index_dimensions == 1:
return sp.Matrix([self.__getitem__(offset)(i) for i in range(self.index_shape[0])])
return sp.Matrix(
[self.__getitem__(offset)(i) for i in range(self.index_shape[0])]
)
elif self.index_dimensions == 2:
return sp.Matrix([[self.__getitem__(offset)(i, k) for k in range(self.index_shape[1])]
for i in range(self.index_shape[0])])
return sp.Matrix(
[
[self.__getitem__(offset)(i, k) for k in range(self.index_shape[1])]
for i in range(self.index_shape[0])
]
)
else:
raise NotImplementedError("neighbor_vector is not implemented for more than 2 index dimensions")
raise NotImplementedError(
"neighbor_vector is not implemented for more than 2 index dimensions"
)
def __getitem__(self, offset):
if type(offset) is np.ndarray:
......@@ -425,7 +552,9 @@ class Field:
if type(offset) is not tuple:
offset = (offset,)
if len(offset) != self.spatial_dimensions:
raise ValueError(f"Wrong number of spatial indices: Got {len(offset)}, expected {self.spatial_dimensions}")
raise ValueError(
f"Wrong number of spatial indices: Got {len(offset)}, expected {self.spatial_dimensions}"
)
return Field.Access(self, offset)
def absolute_access(self, offset, index):
......@@ -448,7 +577,9 @@ class Field:
offset = tuple(direction_string_to_offset(offset, self.spatial_dimensions))
offset = tuple([o * sp.Rational(1, 2) for o in offset])
if len(offset) != self.spatial_dimensions:
raise ValueError(f"Wrong number of spatial indices: Got {len(offset)}, expected {self.spatial_dimensions}")
raise ValueError(
f"Wrong number of spatial indices: Got {len(offset)}, expected {self.spatial_dimensions}"
)
prefactor = 1
neighbor_vec = [0] * len(offset)
......@@ -462,25 +593,33 @@ class Field:
if FieldType.is_staggered_flux(self):
prefactor = -1
if neighbor not in self.staggered_stencil:
raise ValueError(f"{offset_orig} is not a valid neighbor for the {self.staggered_stencil_name} stencil")
raise ValueError(
f"{offset_orig} is not a valid neighbor for the {self.staggered_stencil_name} stencil"
)
offset = tuple(sp.Matrix(offset) - sp.Rational(1, 2) * sp.Matrix(neighbor_vec))
idx = self.staggered_stencil.index(neighbor)
if self.index_dimensions == 1: # this field stores a scalar value at each staggered position
if (
self.index_dimensions == 1
): # this field stores a scalar value at each staggered position
if index is not None:
raise ValueError("Cannot specify an index for a scalar staggered field")
return prefactor * Field.Access(self, offset, (idx,))
else: # this field stores a vector or tensor at each staggered position
if index is None:
raise ValueError(f"Wrong number of indices: Got 0, expected {self.index_dimensions - 1}")
raise ValueError(
f"Wrong number of indices: Got 0, expected {self.index_dimensions - 1}"
)
if type(index) is np.ndarray:
index = tuple(index)
if type(index) is not tuple:
index = (index,)
if self.index_dimensions != len(index) + 1:
raise ValueError(f"Wrong number of indices: Got {len(index)}, expected {self.index_dimensions - 1}")
raise ValueError(
f"Wrong number of indices: Got {len(index)}, expected {self.index_dimensions - 1}"
)
return prefactor * Field.Access(self, offset, (idx, *index))
......@@ -491,30 +630,54 @@ class Field:
if self.index_dimensions == 1:
return sp.Matrix([self.staggered_access(offset)])
elif self.index_dimensions == 2:
return sp.Matrix([self.staggered_access(offset, i) for i in range(self.index_shape[1])])
return sp.Matrix(
[self.staggered_access(offset, i) for i in range(self.index_shape[1])]
)
elif self.index_dimensions == 3:
return sp.Matrix([[self.staggered_access(offset, (i, k)) for k in range(self.index_shape[2])]
for i in range(self.index_shape[1])])
return sp.Matrix(
[
[
self.staggered_access(offset, (i, k))
for k in range(self.index_shape[2])
]
for i in range(self.index_shape[1])
]
)
else:
raise NotImplementedError("staggered_vector_access is not implemented for more than 3 index dimensions")
raise NotImplementedError(
"staggered_vector_access is not implemented for more than 3 index dimensions"
)
@property
def staggered_stencil(self):
assert FieldType.is_staggered(self)
stencils = {
2: {
2: ["W", "S"], # D2Q5
4: ["W", "S", "SW", "NW"] # D2Q9
},
2: {2: ["W", "S"], 4: ["W", "S", "SW", "NW"]}, # D2Q5 # D2Q9
3: {
3: ["W", "S", "B"], # D3Q7
7: ["W", "S", "B", "BSW", "TSW", "BNW", "TNW"], # D3Q15
9: ["W", "S", "B", "SW", "NW", "BW", "TW", "BS", "TS"], # D3Q19
13: ["W", "S", "B", "SW", "NW", "BW", "TW", "BS", "TS", "BSW", "TSW", "BNW", "TNW"] # D3Q27
}
13: [
"W",
"S",
"B",
"SW",
"NW",
"BW",
"TW",
"BS",
"TS",
"BSW",
"TSW",
"BNW",
"TNW",
], # D3Q27
},
}
if not self.index_shape[0] in stencils[self.spatial_dimensions]:
raise ValueError(f"No known stencil has {self.index_shape[0]} staggered points")
raise ValueError(
f"No known stencil has {self.index_shape[0]} staggered points"
)
return stencils[self.spatial_dimensions][self.index_shape[0]]
@property
......@@ -527,13 +690,15 @@ class Field:
return Field.Access(self, center)(*args, **kwargs)
def hashable_contents(self):
return (self._layout,
self.shape,
self.strides,
self.field_type,
self._field_name,
self.latex_name,
self._dtype)
return (
self._layout,
self.shape,
self.strides,
self.field_type,
self._field_name,
self.latex_name,
self._dtype,
)
def __hash__(self):
return hash(self.hashable_contents())
......@@ -545,36 +710,53 @@ class Field:
@property
def physical_coordinates(self):
if hasattr(self.coordinate_transform, '__call__'):
return self.coordinate_transform(self.coordinate_origin + x_vector(self.spatial_dimensions))
if hasattr(self.coordinate_transform, "__call__"):
return self.coordinate_transform(
self.coordinate_origin + x_vector(self.spatial_dimensions)
)
else:
return self.coordinate_transform @ (self.coordinate_origin + x_vector(self.spatial_dimensions))
return self.coordinate_transform @ (
self.coordinate_origin + x_vector(self.spatial_dimensions)
)
@property
def physical_coordinates_staggered(self):
return self.coordinate_transform @ \
(self.coordinate_origin + x_staggered_vector(self.spatial_dimensions))
return self.coordinate_transform @ (
self.coordinate_origin + x_staggered_vector(self.spatial_dimensions)
)
def index_to_physical(self, index_coordinates: sp.Matrix, staggered=False):
if staggered:
index_coordinates = sp.Matrix([0.5] * len(self.coordinate_origin)) + index_coordinates
if hasattr(self.coordinate_transform, '__call__'):
index_coordinates = (
sp.Matrix([0.5] * len(self.coordinate_origin)) + index_coordinates
)
if hasattr(self.coordinate_transform, "__call__"):
return self.coordinate_transform(self.coordinate_origin + index_coordinates)
else:
return self.coordinate_transform @ (self.coordinate_origin + index_coordinates)
return self.coordinate_transform @ (
self.coordinate_origin + index_coordinates
)
def physical_to_index(self, physical_coordinates: sp.Matrix, staggered=False):
if hasattr(self.coordinate_transform, '__call__'):
if hasattr(self.coordinate_transform, 'inv'):
return self.coordinate_transform.inv()(physical_coordinates) - self.coordinate_origin
if hasattr(self.coordinate_transform, "__call__"):
if hasattr(self.coordinate_transform, "inv"):
return (
self.coordinate_transform.inv()(physical_coordinates)
- self.coordinate_origin
)
else:
idx = sp.Matrix(sp.symbols(f'index_coordinates:{self.ndim}', real=True))
idx = sp.Matrix(sp.symbols(f"index_coordinates:{self.ndim}", real=True))
rtn = sp.solve(self.index_to_physical(idx) - physical_coordinates, idx)
assert rtn, f'Could not find inverese of coordinate_transform: {self.index_to_physical(idx)}'
assert (
rtn
), f"Could not find inverese of coordinate_transform: {self.index_to_physical(idx)}"
return rtn
else:
rtn = self.coordinate_transform.inv() @ physical_coordinates - self.coordinate_origin
rtn = (
self.coordinate_transform.inv() @ physical_coordinates
- self.coordinate_origin
)
if staggered:
rtn = sp.Matrix([i - 0.5 for i in rtn])
......@@ -603,18 +785,40 @@ class Field:
>>> central_y_component.at_index(0) # change component
v_C^0
"""
_iterable = False # see https://i10git.cs.fau.de/pycodegen/pystencils/-/merge_requests/166#note_10680
__match_args__ = ("field", "offsets", "index")
# for the type checker
_field: Field
_offsets: tuple[int | sp.Basic, ...]
_offsetName: str
_superscript: None | str
_index: tuple[int | sp.Basic, ...] | str
_indirect_addressing_fields: set[Field]
_is_absolute_access: bool
def __new__(cls, name, *args, **kwargs):
obj = Field.Access.__xnew_cached_(cls, name, *args, **kwargs)
return obj
def __new_stage2__(self, field, offsets=(0, 0, 0), idx=None, is_absolute_access=False, dtype=None):
def __new_stage2__( # type: ignore
self,
field: Field,
offsets: tuple[int, ...] = (0, 0, 0),
idx: None | tuple[int, ...] | str = None,
is_absolute_access: bool = False,
dtype: PsType | None = None,
):
field_name = field.name
offsets_and_index = (*offsets, *idx) if idx is not None else offsets
constant_offsets = not any([isinstance(o, sp.Basic) and not o.is_Integer for o in offsets_and_index])
constant_offsets = not any(
[
isinstance(o, sp.Basic) and not o.is_Integer
for o in offsets_and_index
]
)
if not idx:
idx = tuple([0] * field.index_dimensions)
......@@ -628,31 +832,36 @@ class Field:
else:
idx_str = ",".join([str(e) for e in idx])
superscript = idx_str
if field.has_fixed_index_shape and not isinstance(field.dtype, PsStructType):
if field.has_fixed_index_shape and not isinstance(
field.dtype, PsStructType
):
for i, bound in zip(idx, field.index_shape):
if i >= bound:
raise ValueError("Field index out of bounds")
else:
offset_name = hashlib.md5(pickle.dumps(offsets_and_index)).hexdigest()[:12]
offset_name = hashlib.md5(pickle.dumps(offsets_and_index)).hexdigest()[
:12
]
superscript = None
symbol_name = f"{field_name}_{offset_name}"
if superscript is not None:
symbol_name += "^" + superscript
obj: Field.Access
if dtype is not None:
obj = super(Field.Access, self).__xnew__(self, symbol_name, dtype)
else:
obj = super(Field.Access, self).__xnew__(self, symbol_name, field.dtype)
obj._field = field
obj._offsets = []
_offsets: list[sp.Basic | int] = []
for o in offsets:
if isinstance(o, sp.Basic):
obj._offsets.append(o)
_offsets.append(o)
else:
obj._offsets.append(int(o))
obj._offsets = tuple(sp.sympify(obj._offsets))
_offsets.append(int(o))
obj._offsets = tuple(sp.sympify(_offsets))
obj._offsetName = offset_name
obj._superscript = superscript
obj._index = idx
......@@ -660,19 +869,33 @@ class Field:
obj._indirect_addressing_fields = set()
for e in chain(obj._offsets, obj._index):
if isinstance(e, sp.Basic):
obj._indirect_addressing_fields.update(a.field for a in e.atoms(Field.Access))
obj._indirect_addressing_fields.update(
a.field for a in e.atoms(Field.Access)
)
obj._is_absolute_access = is_absolute_access
return obj
def __getnewargs__(self):
return self.field, self.offsets, self.index, self.is_absolute_access, self.dtype
return (
self.field,
self.offsets,
self.index,
self.is_absolute_access,
self.dtype,
)
def __getnewargs_ex__(self):
return (self.field, self.offsets, self.index, self.is_absolute_access, self.dtype), {}
return (
self.field,
self.offsets,
self.index,
self.is_absolute_access,
self.dtype,
), {}
# noinspection SpellCheckingInspection
__xnew__ = staticmethod(__new_stage2__)
__xnew__ = staticmethod(__new_stage2__) # type: ignore
# noinspection SpellCheckingInspection
__xnew_cached_ = staticmethod(cacheit(__new_stage2__))
......@@ -686,22 +909,34 @@ class Field:
idx = ()
if len(idx) != self.field.index_dimensions:
raise ValueError(f"Wrong number of indices: Got {len(idx)}, expected {self.field.index_dimensions}")
raise ValueError(
f"Wrong number of indices: Got {len(idx)}, expected {self.field.index_dimensions}"
)
if len(idx) == 1 and isinstance(idx[0], str):
struct_type = self.field.dtype
assert isinstance(struct_type, PsStructType)
dtype = struct_type.get_member(idx[0]).dtype
return Field.Access(self.field, self._offsets, idx,
is_absolute_access=self.is_absolute_access, dtype=dtype)
return Field.Access(
self.field,
self._offsets,
idx,
is_absolute_access=self.is_absolute_access,
dtype=dtype,
)
else:
return Field.Access(self.field, self._offsets, idx,
is_absolute_access=self.is_absolute_access, dtype=self.dtype)
return Field.Access(
self.field,
self._offsets,
idx,
is_absolute_access=self.is_absolute_access,
dtype=self.dtype,
)
def __getitem__(self, *idx):
return self.__call__(*idx)
@property
def field(self) -> 'Field':
def field(self) -> "Field":
"""Field that the Access points to"""
return self._field
......@@ -713,7 +948,7 @@ class Field:
@property
def required_ghost_layers(self) -> int:
"""Largest spatial distance that is accessed."""
return int(np.max(np.abs(self._offsets)))
return int(np.max(np.abs(self._offsets))) # type: ignore
@property
def nr_of_coordinates(self):
......@@ -735,7 +970,7 @@ class Field:
"""Value of index coordinates as tuple."""
return self._index
def neighbor(self, coord_id: int, offset: int) -> 'Field.Access':
def neighbor(self, coord_id: int, offset: int) -> "Field.Access":
"""Returns a new Access with changed spatial coordinates.
Args:
......@@ -749,10 +984,15 @@ class Field:
"""
offset_list = list(self.offsets)
offset_list[coord_id] += offset
return Field.Access(self.field, tuple(offset_list), self.index,
is_absolute_access=self.is_absolute_access, dtype=self.dtype)
return Field.Access(
self.field,
tuple(offset_list),
self.index,
is_absolute_access=self.is_absolute_access,
dtype=self.dtype,
)
def get_shifted(self, *shift) -> 'Field.Access':
def get_shifted(self, *shift) -> "Field.Access":
"""Returns a new Access with changed spatial coordinates
Example:
......@@ -760,13 +1000,15 @@ class Field:
>>> f[0,0].get_shifted(1, 1)
f_NE
"""
return Field.Access(self.field,
tuple(a + b for a, b in zip(shift, self.offsets)),
self.index,
is_absolute_access=self.is_absolute_access,
dtype=self.dtype)
return Field.Access(
self.field,
tuple(a + b for a, b in zip(shift, self.offsets)),
self.index,
is_absolute_access=self.is_absolute_access,
dtype=self.dtype,
)
def at_index(self, *idx_tuple) -> 'Field.Access':
def at_index(self, *idx_tuple) -> "Field.Access":
"""Returns new Access with changed index.
Example:
......@@ -774,15 +1016,22 @@ class Field:
>>> f(0).at_index(8)
f_C^8
"""
return Field.Access(self.field, self.offsets, idx_tuple,
is_absolute_access=self.is_absolute_access, dtype=self.dtype)
return Field.Access(
self.field,
self.offsets,
idx_tuple,
is_absolute_access=self.is_absolute_access,
dtype=self.dtype,
)
def _eval_subs(self, old, new):
return Field.Access(self.field,
tuple(sp.sympify(a).subs(old, new) for a in self.offsets),
tuple(sp.sympify(a).subs(old, new) for a in self.index),
is_absolute_access=self.is_absolute_access,
dtype=self.dtype)
return Field.Access(
self.field,
tuple(sp.sympify(a).subs(old, new) for a in self.offsets),
tuple(sp.sympify(a).subs(old, new) for a in self.index),
is_absolute_access=self.is_absolute_access,
dtype=self.dtype,
)
@property
def is_absolute_access(self) -> bool:
......@@ -790,30 +1039,43 @@ class Field:
return self._is_absolute_access
@property
def indirect_addressing_fields(self) -> Set['Field']:
def indirect_addressing_fields(self) -> Set["Field"]:
"""Returns a set of fields that the access depends on.
e.g. f[index_field[1, 0]], the outer access to f depends on index_field
"""
e.g. f[index_field[1, 0]], the outer access to f depends on index_field
"""
return self._indirect_addressing_fields
def _hashable_content(self):
super_class_contents = super(Field.Access, self)._hashable_content()
return (super_class_contents, self._field.hashable_contents(), *self._index,
*self._offsets, self._is_absolute_access)
return (
super_class_contents,
self._field.hashable_contents(),
*self._index,
*self._offsets,
self._is_absolute_access,
)
def _staggered_offset(self, offsets, index):
assert FieldType.is_staggered(self._field)
neighbor = self._field.staggered_stencil[index]
neighbor = direction_string_to_offset(neighbor, self._field.spatial_dimensions)
return [(o + sp.Rational(int(neighbor[i]), 2)) for i, o in enumerate(offsets)]
neighbor = direction_string_to_offset(
neighbor, self._field.spatial_dimensions
)
return [
(o + sp.Rational(int(neighbor[i]), 2)) for i, o in enumerate(offsets)
]
def _latex(self, _):
n = self._field.latex_name if self._field.latex_name else self._field.name
offset_str = ",".join([sp.latex(o) for o in self.offsets])
if FieldType.is_staggered(self._field):
offset_str = ",".join([sp.latex(self._staggered_offset(self.offsets, self.index[0])[i])
for i in range(len(self.offsets))])
offset_str = ",".join(
[
sp.latex(self._staggered_offset(self.offsets, self.index[0])[i])
for i in range(len(self.offsets))
]
)
if self.is_absolute_access:
offset_str = f"\\mathbf{offset_str}"
elif self.field.spatial_dimensions > 1:
......@@ -834,8 +1096,12 @@ class Field:
n = self._field.latex_name if self._field.latex_name else self._field.name
offset_str = ",".join([sp.latex(o) for o in self.offsets])
if FieldType.is_staggered(self._field):
offset_str = ",".join([sp.latex(self._staggered_offset(self.offsets, self.index[0])[i])
for i in range(len(self.offsets))])
offset_str = ",".join(
[
sp.latex(self._staggered_offset(self.offsets, self.index[0])[i])
for i in range(len(self.offsets))
]
)
if self.is_absolute_access:
offset_str = f"[abs]{offset_str}"
......@@ -851,12 +1117,36 @@ class Field:
return f"{n}[{offset_str}]"
def fields(description=None, index_dimensions=0, layout=None,
field_type=FieldType.GENERIC, **kwargs) -> Union[Field, List[Field]]:
def fields(
description=None,
index_dimensions=0,
layout=None,
field_type=FieldType.GENERIC,
**kwargs,
) -> Field | list[Field]:
"""Creates pystencils fields from a string description.
The description must be a string of the form
``"name(index-shape) [name(index-shape) ...]: <data-type>[<dimension-or-shape>]"``,
where:
- ``name`` is the name of the respective field
- ``(index-shape)`` is a tuple of integers describing the shape of the tensor on each field node
(can be omitted for scalar fields)
- ``<data-type>`` is the numerical data type of the field's entries;
this can be any type parseable by `create_type`,
as well as ``dyn`` for `DynamicType.NUMERIC_TYPE`
and ``dynidx`` for `DynamicType.INDEX_TYPE`.
- ``<dimension-or-shape>`` can be a dimensionality (e.g. ``1D``, ``2D``, ``3D``)
or a tuple of integers defining the spatial shape of the field.
Examples:
Create a 2D scalar and vector field:
Create a 3D scalar field of default numeric type:
>>> f = fields("f(1): [2D]")
>>> str(f.dtype)
'DynamicType.NUMERIC_TYPE'
Create a 2D scalar and vector field of 64-bit float type:
>>> s, v = fields("s, v(2): double[2D]")
>>> assert s.spatial_dimensions == 2 and s.index_dimensions == 0
>>> assert (v.spatial_dimensions, v.index_dimensions, v.index_shape) == (2, 1, (2,))
......@@ -882,35 +1172,70 @@ def fields(description=None, index_dimensions=0, layout=None,
>>> f = fields("pdfs(19) : float32[3D]", layout='fzyx')
>>> f.layout
(2, 1, 0)
Returns:
Sequence of fields created from the description
"""
result = []
if description:
field_descriptions, dtype, shape = _parse_description(description)
layout = 'numpy' if layout is None else layout
layout = "numpy" if layout is None else layout
for field_name, idx_shape in field_descriptions:
if field_name in kwargs:
arr = kwargs[field_name]
idx_shape_of_arr = () if not len(idx_shape) else arr.shape[-len(idx_shape):]
idx_shape_of_arr = (
() if not len(idx_shape) else arr.shape[-len(idx_shape):]
)
assert idx_shape_of_arr == idx_shape
f = Field.create_from_numpy_array(field_name, kwargs[field_name], index_dimensions=len(idx_shape),
field_type=field_type)
f = Field.create_from_numpy_array(
field_name,
kwargs[field_name],
index_dimensions=len(idx_shape),
field_type=field_type,
)
elif isinstance(shape, tuple):
f = Field.create_fixed_size(field_name, shape + idx_shape, dtype=dtype,
index_dimensions=len(idx_shape), layout=layout, field_type=field_type)
f = Field.create_fixed_size(
field_name,
shape + idx_shape,
dtype=dtype,
index_dimensions=len(idx_shape),
layout=layout,
field_type=field_type,
)
elif isinstance(shape, int):
f = Field.create_generic(field_name, spatial_dimensions=shape, dtype=dtype,
index_shape=idx_shape, layout=layout, field_type=field_type)
f = Field.create_generic(
field_name,
spatial_dimensions=shape,
dtype=dtype,
index_shape=idx_shape,
layout=layout,
field_type=field_type,
)
elif shape is None:
f = Field.create_generic(field_name, spatial_dimensions=2, dtype=dtype,
index_shape=idx_shape, layout=layout, field_type=field_type)
f = Field.create_generic(
field_name,
spatial_dimensions=2,
dtype=dtype,
index_shape=idx_shape,
layout=layout,
field_type=field_type,
)
else:
assert False
result.append(f)
else:
assert layout is None, "Layout can not be specified when creating Field from numpy array"
assert (
layout is None
), "Layout can not be specified when creating Field from numpy array"
for field_name, arr in kwargs.items():
result.append(Field.create_from_numpy_array(field_name, arr, index_dimensions=index_dimensions,
field_type=field_type))
result.append(
Field.create_from_numpy_array(
field_name,
arr,
index_dimensions=index_dimensions,
field_type=field_type,
)
)
if len(result) == 0:
raise ValueError("Could not parse field description")
......@@ -920,16 +1245,27 @@ def fields(description=None, index_dimensions=0, layout=None,
return result
def get_layout_from_strides(strides: Sequence[int], index_dimension_ids: Optional[List[int]] = None):
def get_layout_from_strides(
strides: Sequence[int], index_dimension_ids: Optional[List[int]] = None
):
index_dimension_ids = [] if index_dimension_ids is None else index_dimension_ids
coordinates = list(range(len(strides)))
relevant_strides = [stride for i, stride in enumerate(strides) if i not in index_dimension_ids]
result = [x for (y, x) in sorted(zip(relevant_strides, coordinates), key=lambda pair: pair[0], reverse=True)]
relevant_strides = [
stride for i, stride in enumerate(strides) if i not in index_dimension_ids
]
result = [
x
for (y, x) in sorted(
zip(relevant_strides, coordinates), key=lambda pair: pair[0], reverse=True
)
]
return normalize_layout(result)
def get_layout_of_array(arr: np.ndarray, index_dimension_ids: Optional[List[int]] = None):
""" Returns a list indicating the memory layout (linearization order) of the numpy array.
def get_layout_of_array(
arr: np.ndarray, index_dimension_ids: Optional[List[int]] = None
):
"""Returns a list indicating the memory layout (linearization order) of the numpy array.
Examples:
>>> get_layout_of_array(np.zeros([3,3,3]))
......@@ -946,7 +1282,9 @@ def get_layout_of_array(arr: np.ndarray, index_dimension_ids: Optional[List[int]
return get_layout_from_strides(arr.strides, index_dimension_ids)
def create_numpy_array_with_layout(shape, layout, alignment=False, byte_offset=0, **kwargs):
def create_numpy_array_with_layout(
shape, layout, alignment=False, byte_offset=0, **kwargs
):
"""Creates numpy array with given memory layout.
Args:
......@@ -970,7 +1308,10 @@ def create_numpy_array_with_layout(shape, layout, alignment=False, byte_offset=0
if cur_layout[i] != layout[i]:
index_to_swap_with = cur_layout.index(layout[i])
swaps.append((i, index_to_swap_with))
cur_layout[i], cur_layout[index_to_swap_with] = cur_layout[index_to_swap_with], cur_layout[i]
cur_layout[i], cur_layout[index_to_swap_with] = (
cur_layout[index_to_swap_with],
cur_layout[i],
)
assert tuple(cur_layout) == tuple(layout)
shape = list(shape)
......@@ -978,7 +1319,7 @@ def create_numpy_array_with_layout(shape, layout, alignment=False, byte_offset=0
shape[a], shape[b] = shape[b], shape[a]
if not alignment:
res = np.empty(shape, order='c', **kwargs)
res = np.empty(shape, order="c", **kwargs)
else:
res = aligned_empty(shape, alignment, byte_offset=byte_offset, **kwargs)
......@@ -990,37 +1331,43 @@ def create_numpy_array_with_layout(shape, layout, alignment=False, byte_offset=0
def spatial_layout_string_to_tuple(layout_str: str, dim: int) -> Tuple[int, ...]:
if dim <= 0:
raise ValueError("Dimensionality must be positive")
layout_str = layout_str.lower()
if layout_str in ('fzyx', 'zyxf', 'soa', 'aos'):
if layout_str in ("fzyx", "zyxf", "soa", "aos"):
if dim > 3:
raise ValueError(f"Invalid spatial dimensionality for layout descriptor {layout_str}: May be at most 3.")
raise ValueError(
f"Invalid spatial dimensionality for layout descriptor {layout_str}: May be at most 3."
)
return tuple(reversed(range(dim)))
if layout_str in ('f', 'reverse_numpy'):
if layout_str in ("f", "reverse_numpy"):
return tuple(reversed(range(dim)))
elif layout_str in ('c', 'numpy'):
elif layout_str in ("c", "numpy"):
return tuple(range(dim))
raise ValueError("Unknown layout descriptor " + layout_str)
def layout_string_to_tuple(layout_str, dim):
def layout_string_to_tuple(layout_str, dim) -> tuple[int, ...]:
if dim <= 0:
raise ValueError("Dimensionality must be positive")
layout_str = layout_str.lower()
if layout_str == 'fzyx' or layout_str == 'soa':
if layout_str == "fzyx" or layout_str == "soa":
if dim > 4:
raise ValueError(f"Invalid total dimensionality for layout descriptor {layout_str}: May be at most 4.")
raise ValueError(
f"Invalid total dimensionality for layout descriptor {layout_str}: May be at most 4."
)
return tuple(reversed(range(dim)))
elif layout_str == 'zyxf' or layout_str == 'aos':
elif layout_str == "zyxf" or layout_str == "aos":
if dim > 4:
raise ValueError(f"Invalid total dimensionality for layout descriptor {layout_str}: May be at most 4.")
raise ValueError(
f"Invalid total dimensionality for layout descriptor {layout_str}: May be at most 4."
)
return tuple(reversed(range(dim - 1))) + (dim - 1,)
elif layout_str == 'f' or layout_str == 'reverse_numpy':
elif layout_str == "f" or layout_str == "reverse_numpy":
return tuple(reversed(range(dim)))
elif layout_str == 'c' or layout_str == 'numpy':
elif layout_str == "c" or layout_str == "numpy":
return tuple(range(dim))
raise ValueError("Unknown layout descriptor " + layout_str)
......@@ -1055,7 +1402,8 @@ def compute_strides(shape, layout):
# ---------------------------------------- Parsing of string in fields() function --------------------------------------
field_description_regex = re.compile(r"""
field_description_regex = re.compile(
r"""
\s* # ignore leading white spaces
(\w+) # identifier is a sequence of alphanumeric characters, is stored in first group
(?: # optional index specification e.g. (1, 4, 2)
......@@ -1066,9 +1414,12 @@ field_description_regex = re.compile(r"""
\s*
)?
\s*,?\s* # ignore trailing white spaces and comma
""", re.VERBOSE)
""",
re.VERBOSE,
)
type_description_regex = re.compile(r"""
type_description_regex = re.compile(
r"""
\s*
(\w+)? # optional dtype
\s*
......@@ -1076,7 +1427,9 @@ type_description_regex = re.compile(r"""
([^\]]+)
\]
\s*
""", re.VERBOSE | re.IGNORECASE)
""",
re.VERBOSE | re.IGNORECASE,
)
def _parse_part1(d):
......@@ -1094,24 +1447,30 @@ def _parse_description(description):
result = type_description_regex.match(d)
if result:
data_type_str, size_info = result.group(1), result.group(2).strip().lower()
if data_type_str is None:
data_type_str = 'float64'
data_type_str = data_type_str.lower().strip()
if data_type_str is not None:
data_type_str = data_type_str.lower().strip()
if data_type_str:
match data_type_str:
case "dyn": dtype = DynamicType.NUMERIC_TYPE
case "dynidx": dtype = DynamicType.INDEX_TYPE
case _: dtype = create_type(data_type_str)
else:
dtype = DynamicType.NUMERIC_TYPE
if not data_type_str:
data_type_str = 'float64'
if size_info.endswith('d'):
if size_info.endswith("d"):
size_info = int(size_info[:-1])
else:
size_info = tuple(int(e) for e in size_info.split(","))
return data_type_str, size_info
return dtype, size_info
else:
raise ValueError("Could not parse field description")
if ':' in description:
field_description, field_info = description.split(':')
if ":" in description:
field_description, field_info = description.split(":")
else:
field_description, field_info = description, 'float64[2D]'
field_description, field_info = description, "float64[2D]"
fields_info = [e for e in _parse_part1(field_description)]
if not field_info:
......
from __future__ import annotations
from typing import Any
from typing import Any, cast
from os import path
import hashlib
......@@ -19,6 +19,7 @@ from ..types import (
PsUnsignedIntegerType,
PsSignedIntegerType,
PsIeeeFloatType,
PsPointerType,
)
from ..types.quick import Fp, SInt, UInt
from ..field import Field
......@@ -121,9 +122,7 @@ class PsKernelExtensioNModule:
def emit_call_wrapper(function_name: str, kernel: Kernel) -> str:
builder = CallWrapperBuilder()
for p in kernel.parameters:
builder.extract_parameter(p)
builder.extract_params(kernel.parameters)
# for c in kernel.constraints:
# builder.check_constraint(c)
......@@ -199,7 +198,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
"""
def __init__(self) -> None:
self._array_buffers: dict[Field, str] = dict()
self._buffer_types: dict[Field, PsType] = dict()
self._array_extractions: dict[Field, str] = dict()
self._array_frees: dict[Field, str] = dict()
......@@ -220,9 +219,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
return "PyLong_AsUnsignedLong"
case _:
raise ValueError(
f"Don't know how to cast Python objects to {dtype}"
)
raise ValueError(f"Don't know how to cast Python objects to {dtype}")
def _type_char(self, dtype: PsType) -> str | None:
if isinstance(
......@@ -233,37 +230,39 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
else:
return None
def extract_field(self, field: Field) -> str:
"""Adds an array, and returns the name of the underlying Py_Buffer."""
def get_field_buffer(self, field: Field) -> str:
"""Get the Python buffer object for the given field."""
return f"buffer_{field.name}"
def extract_field(self, field: Field) -> None:
"""Add the necessary code to extract the NumPy array for a given field"""
if field not in self._array_extractions:
extraction_code = self.TMPL_EXTRACT_ARRAY.format(name=field.name)
actual_dtype = self._buffer_types[field]
# Check array type
type_char = self._type_char(field.dtype)
type_char = self._type_char(actual_dtype)
if type_char is not None:
dtype_cond = f"buffer_{field.name}.format[0] == '{type_char}'"
extraction_code += self.TMPL_CHECK_ARRAY_TYPE.format(
cond=dtype_cond,
what="data type",
name=field.name,
expected=str(field.dtype),
expected=str(actual_dtype),
)
# Check item size
itemsize = field.dtype.itemsize
itemsize = actual_dtype.itemsize
item_size_cond = f"buffer_{field.name}.itemsize == {itemsize}"
extraction_code += self.TMPL_CHECK_ARRAY_TYPE.format(
cond=item_size_cond, what="itemsize", name=field.name, expected=itemsize
)
self._array_buffers[field] = f"buffer_{field.name}"
self._array_extractions[field] = extraction_code
release_code = f"PyBuffer_Release(&buffer_{field.name});"
self._array_frees[field] = release_code
return self._array_buffers[field]
def extract_scalar(self, param: Parameter) -> str:
if param not in self._scalar_extractions:
extract_func = self._scalar_extractor(param.dtype)
......@@ -279,7 +278,8 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
def extract_array_assoc_var(self, param: Parameter) -> str:
if param not in self._array_assoc_var_extractions:
field = param.fields[0]
buffer = self.extract_field(field)
buffer = self.get_field_buffer(field)
buffer_dtype = self._buffer_types[field]
code: str | None = None
for prop in param.properties:
......@@ -293,7 +293,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
case FieldStride(_, coord):
code = (
f"{param.dtype.c_string()} {param.name} = "
f"{buffer}.strides[{coord}] / {field.dtype.itemsize};"
f"{buffer}.strides[{coord}] / {buffer_dtype.itemsize};"
)
break
assert code is not None
......@@ -302,29 +302,48 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
return param.name
def extract_parameter(self, param: Parameter):
if param.is_field_parameter:
self.extract_array_assoc_var(param)
else:
self.extract_scalar(param)
def extract_params(self, params: tuple[Parameter, ...]) -> None:
for param in params:
if ptr_props := param.get_properties(FieldBasePtr):
prop: FieldBasePtr = cast(FieldBasePtr, ptr_props.pop())
field = prop.field
actual_field_type: PsType
from .. import DynamicType
if isinstance(field.dtype, DynamicType):
ptr_type = param.dtype
assert isinstance(ptr_type, PsPointerType)
actual_field_type = ptr_type.base_type
else:
actual_field_type = field.dtype
self._buffer_types[prop.field] = actual_field_type
self.extract_field(prop.field)
for param in params:
if param.is_field_parameter:
self.extract_array_assoc_var(param)
else:
self.extract_scalar(param)
# def check_constraint(self, constraint: KernelParamsConstraint):
# variables = constraint.get_parameters()
# def check_constraint(self, constraint: KernelParamsConstraint):
# variables = constraint.get_parameters()
# for var in variables:
# self.extract_parameter(var)
# for var in variables:
# self.extract_parameter(var)
# cond = constraint.to_code()
# cond = constraint.to_code()
# code = f"""
# if(!({cond}))
# {{
# PyErr_SetString(PyExc_ValueError, "Violated constraint: {constraint}");
# return NULL;
# }}
# """
# code = f"""
# if(!({cond}))
# {{
# PyErr_SetString(PyExc_ValueError, "Violated constraint: {constraint}");
# return NULL;
# }}
# """
# self._constraint_checks.append(code)
# self._constraint_checks.append(code)
def call(self, kernel: Kernel, params: tuple[Parameter, ...]):
param_list = ", ".join(p.name for p in params)
......
......@@ -19,7 +19,7 @@ from ..codegen import (
Parameter,
)
from ..codegen.properties import FieldShape, FieldStride, FieldBasePtr
from ..types import PsStructType
from ..types import PsStructType, PsPointerType
from ..include import get_pystencils_include_path
......@@ -160,8 +160,18 @@ class CupyKernelWrapper(KernelWrapper):
for prop in kparam.properties:
match prop:
case FieldBasePtr(field):
elem_dtype: PsType
from .. import DynamicType
if isinstance(field.dtype, DynamicType):
assert isinstance(kparam.dtype, PsPointerType)
elem_dtype = kparam.dtype.base_type
else:
elem_dtype = field.dtype
arr = kwargs[field.name]
if arr.dtype != field.dtype.numpy_dtype:
if arr.dtype != elem_dtype.numpy_dtype:
raise JitError(
f"Data type mismatch at array argument {field.name}:"
f"Expected {field.dtype}, got {arr.dtype}"
......
......@@ -2,7 +2,7 @@ import copy
import numpy as np
import sympy as sp
from .sympyextensions import TypedSymbol, CastFunc, fast_subs
from .sympyextensions import TypedSymbol, tcast, fast_subs
# from pystencils.sympyextensions.astnodes import LoopOverCoordinate # TODO nbackend: replace
# from pystencils.backends.cbackend import CustomCodeNode # TODO nbackend: replace
......@@ -48,7 +48,7 @@ class RNGBase:
def get_code(self, dialect, vector_instruction_set, print_arg):
code = "\n"
for r in self.result_symbols:
if vector_instruction_set and not self.args[1].atoms(CastFunc):
if vector_instruction_set and not self.args[1].atoms(tcast):
# this vector RNG has become scalar through substitution
code += f"{r.dtype} {r.name};\n"
else:
......
from .astnodes import ConditionalFieldAccess
from .typed_sympy import TypedSymbol, CastFunc
from .typed_sympy import TypedSymbol, tcast
from .pointers import mem_acc
from .math import (
......@@ -34,7 +34,7 @@ from .math import (
__all__ = [
"ConditionalFieldAccess",
"TypedSymbol",
"CastFunc",
"tcast",
"mem_acc",
"remove_higher_order_terms",
"prod",
......
......@@ -11,7 +11,7 @@ from sympy.functions import Abs
from sympy.core.numbers import Zero
from ..assignment import Assignment
from .typed_sympy import CastFunc
from .typed_sympy import TypeCast
from ..types import PsPointerType, PsVectorType
T = TypeVar('T')
......@@ -603,7 +603,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]],
visit_children = False
elif t.is_integer:
pass
elif isinstance(t, CastFunc):
elif isinstance(t, TypeCast):
visit_children = False
visit(t.args[0])
elif t.func is fast_sqrt:
......
from __future__ import annotations
from typing import cast
import sympy as sp
from enum import Enum, auto
......@@ -6,11 +7,14 @@ from enum import Enum, auto
from ..types import (
PsType,
PsNumericType,
PsBoolType,
create_type,
UserTypeSpec
)
from sympy.logic.boolalg import Boolean
from warnings import warn
def is_loop_counter_symbol(symbol):
from ..defaults import DEFAULTS
......@@ -37,11 +41,12 @@ class DynamicType(Enum):
class TypeAtom(sp.Atom):
"""Wrapper around a type to disguise it as a SymPy atom."""
def __new__(cls, *args, **kwargs):
return sp.Basic.__new__(cls)
_dtype: PsType | DynamicType
def __init__(self, dtype: PsType | DynamicType) -> None:
self._dtype = dtype
def __new__(cls, dtype: PsType | DynamicType):
obj = super().__new__(cls)
obj._dtype = dtype
return obj
def _sympystr(self, *args, **kwargs):
return str(self._dtype)
......@@ -52,6 +57,9 @@ class TypeAtom(sp.Atom):
def _hashable_content(self):
return (self._dtype,)
def __getnewargs__(self):
return (self._dtype,)
def assumptions_from_dtype(dtype: PsType | DynamicType):
"""Derives SymPy assumptions from :class:`PsAbstractType`
......@@ -133,144 +141,74 @@ class TypedSymbol(sp.Symbol):
return self.dtype.required_headers if isinstance(self.dtype, PsType) else set()
class CastFunc(sp.Function):
"""Use this function to introduce a static type cast into the output code.
Usage: ``CastFunc(expr, target_type)`` becomes, in C code, ``(target_type) expr``.
The ``target_type`` may be a valid pystencils type specification parsable by `create_type`,
or a special value of the `DynamicType` enum.
These dynamic types can be used to select the target type according to the code generation context.
"""
class TypeCast(sp.Function):
"""Explicitly cast an expression to a data type."""
@staticmethod
def as_numeric(expr):
return CastFunc(expr, DynamicType.NUMERIC_TYPE)
return TypeCast(expr, DynamicType.NUMERIC_TYPE)
@staticmethod
def as_index(expr):
return CastFunc(expr, DynamicType.INDEX_TYPE)
is_Atom = True
def __new__(cls, *args, **kwargs):
if len(args) != 2:
pass
expr, dtype, *other_args = args
# If we have two consecutive casts, throw the inner one away.
# This optimisation is only available for simple casts. Thus the == is intended here!
if expr.__class__ == CastFunc:
expr = expr.args[0]
if not isinstance(dtype, (TypeAtom)):
if isinstance(dtype, DynamicType):
dtype = TypeAtom(dtype)
else:
dtype = TypeAtom(create_type(dtype))
# to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well
# however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads
# to problems when for example comparing cast_func's for equality
#
# lhs = bitwise_and(a, cast_func(1, 'int'))
# rhs = cast_func(0, 'int')
# print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans
# -> thus a separate class boolean_cast_func is introduced
if isinstance(expr, sp.logic.boolalg.Boolean) and (
not isinstance(expr, TypedSymbol) or isinstance(expr.dtype, PsBoolType)
):
cls = BooleanCastFunc
return sp.Function.__new__(cls, expr, dtype, *other_args, **kwargs)
@property
def canonical(self):
if hasattr(self.args[0], "canonical"):
return self.args[0].canonical
else:
raise NotImplementedError()
@property
def is_commutative(self):
return self.args[0].is_commutative
@property
def dtype(self) -> PsType | DynamicType:
assert isinstance(self.args[1], TypeAtom)
return self.args[1].get()
return TypeCast(expr, DynamicType.INDEX_TYPE)
@property
def expr(self):
def expr(self) -> sp.Basic:
return self.args[0]
@property
def is_integer(self):
def dtype(self) -> PsType | DynamicType:
return cast(TypeAtom, self._args[1]).get()
def __new__(cls, expr: sp.Basic, dtype: UserTypeSpec | DynamicType | TypeAtom):
tatom: TypeAtom
match dtype:
case TypeAtom():
tatom = dtype
case DynamicType():
tatom = TypeAtom(dtype)
case _:
tatom = TypeAtom(create_type(dtype))
return super().__new__(cls, expr, tatom)
@classmethod
def eval(cls, expr: sp.Basic, tatom: TypeAtom) -> TypeCast | None:
dtype = tatom.get()
if cls is not BoolCast and isinstance(dtype, PsNumericType) and dtype.is_bool():
return BoolCast(expr, tatom)
return None
def _eval_is_integer(self):
if self.dtype == DynamicType.INDEX_TYPE:
return True
elif isinstance(self.dtype, PsNumericType):
return self.dtype.is_int() or super().is_integer
else:
return super().is_integer
@property
def is_negative(self):
"""
See :func:`.TypedSymbol.is_integer`
"""
if isinstance(self.dtype, PsNumericType):
if self.dtype.is_uint():
return False
return super().is_negative
@property
def is_nonnegative(self):
"""
See :func:`.TypedSymbol.is_integer`
"""
if self.is_negative is False:
if isinstance(self.dtype, PsNumericType) and self.dtype.is_int():
return True
def _eval_is_real(self):
if isinstance(self.dtype, DynamicType):
return True
if isinstance(self.dtype, PsNumericType) and (self.dtype.is_float() or self.dtype.is_int()):
return True
def _eval_is_nonnegative(self):
if isinstance(self.dtype, PsNumericType) and self.dtype.is_uint():
return True
else:
return super().is_nonnegative
@property
def is_real(self):
"""
See :func:`.TypedSymbol.is_integer`
"""
if isinstance(self.dtype, PsNumericType):
return self.dtype.is_int() or self.dtype.is_float() or super().is_real
else:
return super().is_real
class BooleanCastFunc(CastFunc, sp.logic.boolalg.Boolean):
# TODO: documentation
pass
class VectorMemoryAccess(CastFunc):
"""
Special memory access for vectorized kernel.
Arguments: read/write expression, type, aligned, non-temporal, mask (or none), stride
"""
nargs = (6,)
class BoolCast(TypeCast, Boolean):
pass
class ReinterpretCastFunc(CastFunc):
"""
Reinterpret cast is necessary for the StructType
"""
pass
tcast = TypeCast
class PointerArithmeticFunc(sp.Function, sp.logic.boolalg.Boolean):
# TODO: documentation, or deprecate!
@property
def canonical(self):
if hasattr(self.args[0], "canonical"):
return self.args[0].canonical
else:
raise NotImplementedError()
class CastFunc(TypeCast):
def __new__(cls, *args, **kwargs):
warn(
"CastFunc is deprecated and will be removed in pystencils 2.1. "
"Use `pystencils.tcast` instead.",
FutureWarning
)
return TypeCast.__new__(cls, *args, **kwargs)
......@@ -5,7 +5,7 @@ import pytest
import pystencils
from pystencils.types import PsPointerType, create_type
from pystencils.sympyextensions.pointers import AddressOf
from pystencils.sympyextensions.typed_sympy import CastFunc
from pystencils.sympyextensions.typed_sympy import tcast
from pystencils.simp import sympy_cse
import sympy as sp
......@@ -23,14 +23,14 @@ def test_address_of():
assignments = pystencils.AssignmentCollection({
s: AddressOf(x[0, 0]),
y[0, 0]: CastFunc(s, create_type('int64'))
y[0, 0]: tcast(s, create_type('int64'))
})
_ = pystencils.create_kernel(assignments).compile()
# pystencils.show_code(kernel.ast)
assignments = pystencils.AssignmentCollection({
y[0, 0]: CastFunc(AddressOf(x[0, 0]), create_type('int64'))
y[0, 0]: tcast(AddressOf(x[0, 0]), create_type('int64'))
})
_ = pystencils.create_kernel(assignments).compile()
......@@ -41,7 +41,7 @@ def test_address_of_with_cse():
x, y = pystencils.fields('x, y: int64[2d]')
assignments = pystencils.AssignmentCollection({
x[0, 0]: CastFunc(AddressOf(x[0, 0]), create_type('int64')) + 1
x[0, 0]: tcast(AddressOf(x[0, 0]), create_type('int64')) + 1
})
_ = pystencils.create_kernel(assignments).compile()
......
......@@ -3,7 +3,7 @@ import pytest
import sympy as sp
import pystencils as ps
from pystencils import DEFAULTS
from pystencils import DEFAULTS, DynamicType, create_type, fields
from pystencils.field import (
Field,
FieldType,
......@@ -15,6 +15,7 @@ from pystencils.field import (
def test_field_basic():
f = Field.create_generic("f", spatial_dimensions=2)
assert FieldType.is_generic(f)
assert f.dtype == DynamicType.NUMERIC_TYPE
assert f["E"] == f[1, 0]
assert f["N"] == f[0, 1]
assert "_" in f.center._latex("dummy")
......@@ -41,17 +42,16 @@ def test_field_basic():
assert f1.ndim == f.ndim
assert f1.values_per_cell() == f.values_per_cell()
fixed = ps.fields("f(5, 5) : double[20, 20]")
assert fixed.neighbor_vector((1, 1)).shape == (5, 5)
f = Field.create_fixed_size("f", (10, 10), strides=(80, 8), dtype=np.float64)
f = Field.create_fixed_size("f", (10, 10), strides=(10, 1), dtype=np.float64)
assert f.spatial_strides == (10, 1)
assert f.index_strides == ()
assert f.center_vector == sp.Matrix([f.center])
assert f.dtype == create_type("float64")
f1 = f.new_field_with_different_name("f1")
assert f1.ndim == f.ndim
assert f1.values_per_cell() == f.values_per_cell()
assert f1.dtype == create_type("float64")
f = Field.create_fixed_size("f", (8, 8, 2, 2), index_dimensions=2)
assert f.center_vector == sp.Matrix([[f(0, 0), f(0, 1)], [f(1, 0), f(1, 1)]])
......@@ -61,16 +61,48 @@ def test_field_basic():
neighbor = field_access.neighbor(coord_id=0, offset=-2)
assert neighbor.offsets == (-1, 1)
assert "_" in neighbor._latex("dummy")
assert f.dtype == DynamicType.NUMERIC_TYPE
f = Field.create_fixed_size("f", (8, 8, 2, 2, 2), index_dimensions=3)
assert f.center_vector == sp.Array(
[[[f(i, j, k) for k in range(2)] for j in range(2)] for i in range(2)]
)
assert f.dtype == DynamicType.NUMERIC_TYPE
f = Field.create_generic("f", spatial_dimensions=5, index_dimensions=2)
field_access = f[1, -1, 2, -3, 0](1, 0)
assert field_access.offsets == (1, -1, 2, -3, 0)
assert field_access.index == (1, 0)
assert f.dtype == DynamicType.NUMERIC_TYPE
def test_field_description_parsing():
f, g = fields("f(1), g(3): [2D]")
assert f.dtype == g.dtype == DynamicType.NUMERIC_TYPE
assert f.spatial_dimensions == g.spatial_dimensions == 2
assert f.index_shape == (1,)
assert g.index_shape == (3,)
f = fields("f: dyn[3D]")
assert f.dtype == DynamicType.NUMERIC_TYPE
idx = fields("idx: dynidx[3D]")
assert idx.dtype == DynamicType.INDEX_TYPE
h = fields("h: float32[3D]")
assert h.index_shape == ()
assert h.spatial_dimensions == 3
assert h.index_dimensions == 0
assert h.dtype == create_type("float32")
f: Field = fields("f(5, 5) : double[20, 20]")
assert f.dtype == create_type("float64")
assert f.spatial_shape == (20, 20)
assert f.index_shape == (5, 5)
assert f.neighbor_vector((1, 1)).shape == (5, 5)
def test_error_handling():
......@@ -145,7 +177,7 @@ def test_error_handling():
def test_decorator_scoping():
dst = ps.fields("dst : double[2D]")
dst = fields("dst : double[2D]")
def f1():
a = sp.Symbol("a")
......@@ -165,7 +197,7 @@ def test_decorator_scoping():
def test_string_creation():
x, y, z = ps.fields(" x(4), y(3,5) z : double[ 3, 47]")
x, y, z = fields(" x(4), y(3,5) z : double[ 3, 47]")
assert x.index_shape == (4,)
assert y.index_shape == (3, 5)
assert z.spatial_shape == (3, 47)
......@@ -173,9 +205,9 @@ def test_string_creation():
def test_itemsize():
x = ps.fields("x: float32[1d]")
y = ps.fields("y: float64[2d]")
i = ps.fields("i: int16[1d]")
x = fields("x: float32[1d]")
y = fields("y: float64[2d]")
i = fields("i: int16[1d]")
assert x.itemsize == 4
assert y.itemsize == 8
......@@ -249,7 +281,7 @@ def test_memory_layout_descriptors():
def test_staggered():
# D2Q5
j1, j2, j3 = ps.fields(
j1, j2, j3 = fields(
"j1(2), j2(2,2), j3(2,2,2) : double[2D]", field_type=FieldType.STAGGERED
)
......@@ -296,7 +328,7 @@ def test_staggered():
)
# D2Q9
k1, k2 = ps.fields("k1(4), k2(2) : double[2D]", field_type=FieldType.STAGGERED)
k1, k2 = fields("k1(4), k2(2) : double[2D]", field_type=FieldType.STAGGERED)
assert k1[1, 1](2) == k1.staggered_access("NE")
assert k1[0, 0](2) == k1.staggered_access("SW")
......@@ -319,7 +351,7 @@ def test_staggered():
]
# sign reversed when using as flux field
r = ps.fields("r(2) : double[2D]", field_type=FieldType.STAGGERED_FLUX)
r = fields("r(2) : double[2D]", field_type=FieldType.STAGGERED_FLUX)
assert r[0, 0](0) == r.staggered_access("W")
assert -r[1, 0](0) == r.staggered_access("E")
......
import numpy as np
import pickle
import sympy as sp
from sympy.logic import boolalg
from pystencils.sympyextensions.typed_sympy import (
TypedSymbol,
CastFunc,
tcast,
TypeAtom,
DynamicType,
)
......@@ -12,7 +15,7 @@ from pystencils.types.quick import UInt, Ptr
def test_type_atoms():
atom1 = TypeAtom(create_type("int32"))
atom2 = TypeAtom(create_type("int32"))
atom2 = TypeAtom(create_type(np.int32))
assert atom1 == atom2
......@@ -25,6 +28,11 @@ def test_type_atoms():
assert atom3 != atom4
assert atom4 != atom5
dump = pickle.dumps(atom1)
atom1_reconst = pickle.loads(dump)
assert atom1_reconst == atom1
def test_typed_symbol():
x = TypedSymbol("x", "uint32")
......@@ -46,12 +54,34 @@ def test_typed_symbol():
assert not z.is_nonnegative
def test_cast_func():
assert (
CastFunc(TypedSymbol("s", np.uint), np.int64).canonical
== TypedSymbol("s", np.uint).canonical
)
a = CastFunc(5, np.uint)
assert a.is_negative is False
assert a.is_nonnegative
def test_casts():
x, y = sp.symbols("x, y")
# Pickling
expr = tcast(x, "int32")
dump = pickle.dumps(expr)
expr_reconst = pickle.loads(dump)
assert expr_reconst == expr
# Boolean Casts
bool_expr = tcast(x, "bool")
assert isinstance(bool_expr, boolalg.Boolean)
# Check that we can construct boolean expressions with cast results
_ = boolalg.Or(bool_expr, y)
# Assumptions
expr = tcast(x, "int32")
assert expr.is_integer
assert expr.is_real
assert expr.is_nonnegative is None
expr = tcast(x, "uint32")
assert expr.is_integer
assert expr.is_real
assert expr.is_nonnegative
expr = tcast(x, "float32")
assert expr.is_integer is None
assert expr.is_real
assert expr.is_nonnegative is None
......@@ -9,7 +9,7 @@ from pystencils import (
DEFAULTS,
FieldType,
)
from pystencils.sympyextensions import CastFunc
from pystencils.sympyextensions import tcast
@pytest.mark.parametrize("index_dtype", ["int16", "int32", "uint32", "int64"])
......@@ -21,9 +21,9 @@ def test_spatial_counters_dense(index_dtype):
f = Field.create_generic("f", 3, "float64", index_shape=(3,), layout="fzyx")
asms = [
Assignment(f(0), CastFunc.as_numeric(z)),
Assignment(f(1), CastFunc.as_numeric(y)),
Assignment(f(2), CastFunc.as_numeric(x)),
Assignment(f(0), tcast.as_numeric(z)),
Assignment(f(1), tcast.as_numeric(y)),
Assignment(f(2), tcast.as_numeric(x)),
]
cfg = CreateKernelConfig(index_dtype=index_dtype)
......@@ -44,9 +44,9 @@ def test_spatial_counters_sparse(index_dtype):
f = Field.create_generic("f", 3, "float64", index_shape=(3,), layout="fzyx")
asms = [
Assignment(f(0), CastFunc.as_numeric(x)),
Assignment(f(1), CastFunc.as_numeric(y)),
Assignment(f(2), CastFunc.as_numeric(z)),
Assignment(f(0), tcast.as_numeric(x)),
Assignment(f(1), tcast.as_numeric(y)),
Assignment(f(2), tcast.as_numeric(z)),
]
idx_struct = DEFAULTS.index_struct(index_dtype, 3)
......