Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
Commits on Source (2)
Showing
with 1017 additions and 448 deletions
......@@ -71,7 +71,10 @@ Typed Expressions
.. autoclass:: pystencils.DynamicType
:members:
.. autoclass:: pystencils.sympyextensions.CastFunc
.. autoclass:: pystencils.sympyextensions.typed_sympy.TypeCast
:members:
.. autoclass:: pystencils.sympyextensions.tcast
Integer Operations
......
......@@ -118,10 +118,10 @@ mypy src/pystencils
::::
:::{note}
Type checking is currently restricted to the `codegen`, `jit`, `backend`, and `types` modules,
since most code in the remaining modules is significantly older and is not comprehensively
type-annotated. As more modules are updated with type annotations, this list will expand in the future.
If you think a new module is ready to be type-checked, add an exception clause for it in the `mypy.ini` file.
Type checking is currently restricted only to a few modules, which are listed in the `mypy.ini` file.
Most code in the remaining modules is significantly older and is not comprehensively type-annotated.
As more modules are updated with type annotations, this list will expand in the future.
If you think a new module is ready to be type-checked, add an exception clause to `mypy.ini`.
:::
## Running the Test Suite
......
......@@ -82,6 +82,7 @@ Topics
user_manual/symbolic_language
user_manual/kernelcreation
user_manual/gpu_kernels
user_manual/WorkingWithTypes
.. toctree::
:maxdepth: 1
......
---
file_format: mystnb
kernelspec:
display_name: Python 3 (ipykernel)
language: python
name: python3
mystnb:
execution_mode: cache
---
# Working with Data Types
This guide will demonstrate the various options that exist to customize the data types
in generated kernels.
Data types can be modified on different levels of granularity:
Individual fields and symbols,
single subexpressions,
or the entire kernel.
```{code-cell} ipython3
:tags: [remove-cell]
import pystencils as ps
import sympy as sp
```
## Changing the Default Data Types
The pystencils code generator defines two default data types:
- The default *numeric type*, which is applied to all numerical computations that are not
otherwise explicitly typed; the default is `float64`.
- The default *index type*, which is used for all loop and field index calculations; the default is `int64`.
These can be modified by setting the
{any}`default_dtype <CreateKernelConfig.default_dtype>` and
{any}`index_type <CreateKernelConfig.index_dtype>`
options of the code generator configuration:
```{code-cell} ipython3
cfg = ps.CreateKernelConfig()
cfg.default_dtype = "float32"
cfg.index_dtype = "int32"
```
Modifying these will change the way types for [untyped symbols](#untyped-symbols)
and [dynamically typed expressions](#dynamic-typing) are computed.
## Setting the Types of Fields and Symbols
(untyped-symbols)=
### Untyped Symbols
Symbols used inside a kernel are most commonly created using
{any}`sp.symbols <sympy.core.symbol.symbols>` or
{any}`sp.Symbol <sympy.core.symbol.Symbol>`.
These symbols are *untyped*; they will receive a type during code generation
according to these rules:
- Free untyped symbols (i.e. symbols not defined by an assignment inside the kernel) receive the
{any}`default data type <CreateKernelConfig.default_dtype>` specified in the code generator configuration.
- Bound untyped symbols (i.e. symbols that *are* defined in an assignment)
receive the data type that was computed for the right-hand side expression of their defining assignment.
If you are working on kernels with homogenous data types, using untyped symbols will mostly be enough.
### Explicitly Typed Symbols and Fields
If you need more control over the data types in (parts of) your kernel,
you will have to explicitly specify them.
To set an explicit data type for a symbol, use the {any}`TypedSymbol` class of pystencils:
```{code-cell} ipython3
x_typed = ps.TypedSymbol("x", "uint32")
x_typed, str(x_typed.dtype)
```
You can set a `TypedSymbol` to any data type provided by [the type system](#page_type_system),
which will then be enforced by the code generator.
The same holds for fields:
When creating fields through the {any}`fields <pystencils.field.fields>` function,
add the type to the descriptor string; for instance:
```{code-cell} ipython3
f, g = ps.fields("f(1), g(3): float32[3D]")
str(f.dtype), str(g.dtype)
```
When using `Field.create_generic` or `Field.create_fixed_size`, on the other hand,
you can set the data type via the `dtype` keyword argument.
(dynamic-typing)=
### Dynamically Typed Symbols and Fields
Apart from explicitly setting data types,
`TypedSymbol`s and fields can also receive a *dynamic data type* (see {any}`DynamicType`).
There are two options:
- Symbols or fields annotated with {any}`DynamicType.NUMERIC_TYPE` will always receive
the {any}`default numeric type <CreateKernelConfig.default_dtype>` configured for the
code generator.
This is the default setting for fields
created through `fields`, `Field.create_generic` or `Field.create_fixed_size`.
- When annotated with {any}`DynamicType.INDEX_TYPE`, on the other hand, they will receive
the {any}`index data type <CreateKernelConfig.index_dtype>` configured for the kernel.
Using dynamic typing, you can enforce symbols to receive either the standard numeric or
index type without explicitly stating it, such that your kernel definition becomes
independent from the code generator configuration.
## Mixing Types Inside Expressions
Pystencils enforces that all symbols, constants, and fields occuring inside an expression
have the same data type.
The code generator will never introduce implicit casts--if any type conflicts arise, it will terminate with an error.
Still, there are cases where you want to combine subexpressions of different types;
maybe you need to compute geometric information from loop counters or other integers,
or you are doing mixed-precision numerical computations.
In these cases, you might have to introduce explicit type casts when values move from one type context to another.
<!-- 2. Annotate expressions with a specific data type to ensure computations are performed in that type.
TODO: See #97 (https://i10git.cs.fau.de/pycodegen/pystencils/-/issues/97)
-->
(type_casts)=
### Type Casts
Type casts can be introduced into kernels using the {any}`tcast` symbolic function.
It takes an expression and a data type, which is either an explicit type (see [the type system](#page_type_system))
or a dynamic type ({any}`DynamicType`):
```{code-cell} ipython3
x, y = sp.symbols("x, y")
expr1 = ps.tcast(x, "float32")
expr2 = ps.tcast(3 + y, ps.DynamicType.INDEX_TYPE)
str(expr1.dtype), str(expr2.dtype)
```
When a type cast occurs, pystencils will compute the type of its argument independently
and then introduce a runtime cast to the target type.
That target type must comply with the type computed for the outer expression,
which the cast is embedded in.
## Understanding the pystencils Type Inference System
To correctly apply varying data types to pystencils kernels, it is important to understand
how pystencils computes and propagates the data types of symbols and expressions.
Type inference happens on the level of assignments.
For each assignment $x := \mathrm{calc}(y_1, \dots, y_n)$,
the system first attempts to compute a *unique* type for the right-hand side (RHS) $\mathrm{calc}(y_1, \dots, y_n)$.
It searches for any subexpression inside the RHS for which a type is already known --
these might be typed symbols
(whose types are either set explicitly by the user,
or have been determined from their defining assignment),
field accesses,
or explicitly typed expressions.
It then attempts to apply that data type to the entire expression.
If type conflicts occur, the process fails and the code generator raises an error.
Otherwise, the resulting type is assigned to the left-hand side symbol $x$.
:::{admonition} Developer's To Do
It would be great to illustrate this using a GraphViz-plot of an AST,
with nodes colored according to their data types
:::
......@@ -138,7 +138,7 @@ This happens roughly according to the following rules:
We can observe this behavior by setting up a kernel including several fields with different data types:
```{code-cell} ipython3
from pystencils.sympyextensions import CastFunc
from pystencils.sympyextensions import tcast
f = ps.fields("f: float32[2D]")
g = ps.fields("g: float16[2D]")
......
......@@ -17,6 +17,12 @@ ignore_errors = False
[mypy-pystencils.jit.*]
ignore_errors = False
[mypy-pystencils.field]
ignore_errors = False
[mypy-pystencils.sympyextensions.typed_sympy]
ignore_errors = False
[mypy-setuptools.*]
ignore_missing_imports=true
......
......@@ -32,7 +32,7 @@ from .spatial_coordinates import (
from .assignment import Assignment, AddAugmentedAssignment, assignment_from_stencil
from .simp import AssignmentCollection
from .sympyextensions.typed_sympy import TypedSymbol, DynamicType
from .sympyextensions import SymbolCreator
from .sympyextensions import SymbolCreator, tcast
from .datahandling import create_data_handling
__all__ = [
......@@ -77,6 +77,7 @@ __all__ = [
"x_staggered_vector",
"fd",
"stencil",
"tcast",
]
from . import _version
......
......@@ -93,6 +93,16 @@ class KernelCreationContext:
def index_dtype(self) -> PsIntegerType:
"""Data type used by default for index expressions"""
return self._index_dtype
def resolve_dynamic_type(self, dtype: DynamicType | PsType) -> PsType:
"""Selects the appropriate data type for `DynamicType` instances, and returns all other types as they are."""
match dtype:
case DynamicType.NUMERIC_TYPE:
return self._default_dtype
case DynamicType.INDEX_TYPE:
return self._index_dtype
case _:
return dtype
@property
def metadata(self) -> dict[str, Any]:
......@@ -339,6 +349,8 @@ class KernelCreationContext:
if isinstance(s, TypedSymbol)
)
entry_type = self.resolve_dynamic_type(field.dtype)
if len(idx_types) > 1:
raise KernelConstraintsError(
f"Multiple incompatible types found in index symbols of field {field}: "
......@@ -375,10 +387,10 @@ class KernelCreationContext:
base_ptr = self.get_symbol(
DEFAULTS.field_pointer_name(field.name),
PsPointerType(field.dtype, restrict=True),
PsPointerType(entry_type, restrict=True),
)
return PsBuffer(field.name, field.dtype, base_ptr, buf_shape, buf_strides)
return PsBuffer(field.name, entry_type, base_ptr, buf_shape, buf_strides)
def _create_buffer_field_buffer(self, field: Field) -> PsBuffer:
if field.spatial_dimensions != 1:
......@@ -418,10 +430,11 @@ class KernelCreationContext:
]
buf_strides = [PsConstant(num_entries, idx_type), PsConstant(1, idx_type)]
buf_dtype = self.resolve_dynamic_type(field.dtype)
base_ptr = self.get_symbol(
DEFAULTS.field_pointer_name(field.name),
PsPointerType(field.dtype, restrict=True),
PsPointerType(buf_dtype, restrict=True),
)
return PsBuffer(field.name, field.dtype, base_ptr, buf_shape, buf_strides)
return PsBuffer(field.name, buf_dtype, base_ptr, buf_shape, buf_strides)
......@@ -13,7 +13,7 @@ from ...sympyextensions import (
integer_functions,
ConditionalFieldAccess,
)
from ...sympyextensions.typed_sympy import TypedSymbol, CastFunc, DynamicType
from ...sympyextensions.typed_sympy import TypedSymbol, TypeCast, DynamicType
from ...sympyextensions.pointers import AddressOf, mem_acc
from ...field import Field, FieldType
......@@ -261,14 +261,7 @@ class FreezeExpressions:
return num / denom
def map_TypedSymbol(self, expr: TypedSymbol):
dtype = expr.dtype
match dtype:
case DynamicType.NUMERIC_TYPE:
dtype = self._ctx.default_dtype
case DynamicType.INDEX_TYPE:
dtype = self._ctx.index_dtype
dtype = self._ctx.resolve_dynamic_type(expr.dtype)
symb = self._ctx.get_symbol(expr.name, dtype)
return PsSymbolExpr(symb)
......@@ -481,7 +474,7 @@ class FreezeExpressions:
]
return cast(PsCall, args[0])
def map_CastFunc(self, cast_expr: CastFunc) -> PsCast | PsConstantExpr:
def map_TypeCast(self, cast_expr: TypeCast) -> PsCast | PsConstantExpr:
dtype: PsType
match cast_expr.dtype:
case DynamicType.NUMERIC_TYPE:
......
This diff is collapsed.
from __future__ import annotations
from typing import Any
from typing import Any, cast
from os import path
import hashlib
......@@ -19,6 +19,7 @@ from ..types import (
PsUnsignedIntegerType,
PsSignedIntegerType,
PsIeeeFloatType,
PsPointerType,
)
from ..types.quick import Fp, SInt, UInt
from ..field import Field
......@@ -121,9 +122,7 @@ class PsKernelExtensioNModule:
def emit_call_wrapper(function_name: str, kernel: Kernel) -> str:
builder = CallWrapperBuilder()
for p in kernel.parameters:
builder.extract_parameter(p)
builder.extract_params(kernel.parameters)
# for c in kernel.constraints:
# builder.check_constraint(c)
......@@ -199,7 +198,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
"""
def __init__(self) -> None:
self._array_buffers: dict[Field, str] = dict()
self._buffer_types: dict[Field, PsType] = dict()
self._array_extractions: dict[Field, str] = dict()
self._array_frees: dict[Field, str] = dict()
......@@ -220,9 +219,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
return "PyLong_AsUnsignedLong"
case _:
raise ValueError(
f"Don't know how to cast Python objects to {dtype}"
)
raise ValueError(f"Don't know how to cast Python objects to {dtype}")
def _type_char(self, dtype: PsType) -> str | None:
if isinstance(
......@@ -233,37 +230,39 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
else:
return None
def extract_field(self, field: Field) -> str:
"""Adds an array, and returns the name of the underlying Py_Buffer."""
def get_field_buffer(self, field: Field) -> str:
"""Get the Python buffer object for the given field."""
return f"buffer_{field.name}"
def extract_field(self, field: Field) -> None:
"""Add the necessary code to extract the NumPy array for a given field"""
if field not in self._array_extractions:
extraction_code = self.TMPL_EXTRACT_ARRAY.format(name=field.name)
actual_dtype = self._buffer_types[field]
# Check array type
type_char = self._type_char(field.dtype)
type_char = self._type_char(actual_dtype)
if type_char is not None:
dtype_cond = f"buffer_{field.name}.format[0] == '{type_char}'"
extraction_code += self.TMPL_CHECK_ARRAY_TYPE.format(
cond=dtype_cond,
what="data type",
name=field.name,
expected=str(field.dtype),
expected=str(actual_dtype),
)
# Check item size
itemsize = field.dtype.itemsize
itemsize = actual_dtype.itemsize
item_size_cond = f"buffer_{field.name}.itemsize == {itemsize}"
extraction_code += self.TMPL_CHECK_ARRAY_TYPE.format(
cond=item_size_cond, what="itemsize", name=field.name, expected=itemsize
)
self._array_buffers[field] = f"buffer_{field.name}"
self._array_extractions[field] = extraction_code
release_code = f"PyBuffer_Release(&buffer_{field.name});"
self._array_frees[field] = release_code
return self._array_buffers[field]
def extract_scalar(self, param: Parameter) -> str:
if param not in self._scalar_extractions:
extract_func = self._scalar_extractor(param.dtype)
......@@ -279,7 +278,8 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
def extract_array_assoc_var(self, param: Parameter) -> str:
if param not in self._array_assoc_var_extractions:
field = param.fields[0]
buffer = self.extract_field(field)
buffer = self.get_field_buffer(field)
buffer_dtype = self._buffer_types[field]
code: str | None = None
for prop in param.properties:
......@@ -293,7 +293,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
case FieldStride(_, coord):
code = (
f"{param.dtype.c_string()} {param.name} = "
f"{buffer}.strides[{coord}] / {field.dtype.itemsize};"
f"{buffer}.strides[{coord}] / {buffer_dtype.itemsize};"
)
break
assert code is not None
......@@ -302,29 +302,48 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
return param.name
def extract_parameter(self, param: Parameter):
if param.is_field_parameter:
self.extract_array_assoc_var(param)
else:
self.extract_scalar(param)
def extract_params(self, params: tuple[Parameter, ...]) -> None:
for param in params:
if ptr_props := param.get_properties(FieldBasePtr):
prop: FieldBasePtr = cast(FieldBasePtr, ptr_props.pop())
field = prop.field
actual_field_type: PsType
from .. import DynamicType
if isinstance(field.dtype, DynamicType):
ptr_type = param.dtype
assert isinstance(ptr_type, PsPointerType)
actual_field_type = ptr_type.base_type
else:
actual_field_type = field.dtype
self._buffer_types[prop.field] = actual_field_type
self.extract_field(prop.field)
for param in params:
if param.is_field_parameter:
self.extract_array_assoc_var(param)
else:
self.extract_scalar(param)
# def check_constraint(self, constraint: KernelParamsConstraint):
# variables = constraint.get_parameters()
# def check_constraint(self, constraint: KernelParamsConstraint):
# variables = constraint.get_parameters()
# for var in variables:
# self.extract_parameter(var)
# for var in variables:
# self.extract_parameter(var)
# cond = constraint.to_code()
# cond = constraint.to_code()
# code = f"""
# if(!({cond}))
# {{
# PyErr_SetString(PyExc_ValueError, "Violated constraint: {constraint}");
# return NULL;
# }}
# """
# code = f"""
# if(!({cond}))
# {{
# PyErr_SetString(PyExc_ValueError, "Violated constraint: {constraint}");
# return NULL;
# }}
# """
# self._constraint_checks.append(code)
# self._constraint_checks.append(code)
def call(self, kernel: Kernel, params: tuple[Parameter, ...]):
param_list = ", ".join(p.name for p in params)
......
......@@ -19,7 +19,7 @@ from ..codegen import (
Parameter,
)
from ..codegen.properties import FieldShape, FieldStride, FieldBasePtr
from ..types import PsStructType
from ..types import PsStructType, PsPointerType
from ..include import get_pystencils_include_path
......@@ -160,8 +160,18 @@ class CupyKernelWrapper(KernelWrapper):
for prop in kparam.properties:
match prop:
case FieldBasePtr(field):
elem_dtype: PsType
from .. import DynamicType
if isinstance(field.dtype, DynamicType):
assert isinstance(kparam.dtype, PsPointerType)
elem_dtype = kparam.dtype.base_type
else:
elem_dtype = field.dtype
arr = kwargs[field.name]
if arr.dtype != field.dtype.numpy_dtype:
if arr.dtype != elem_dtype.numpy_dtype:
raise JitError(
f"Data type mismatch at array argument {field.name}:"
f"Expected {field.dtype}, got {arr.dtype}"
......
......@@ -2,7 +2,7 @@ import copy
import numpy as np
import sympy as sp
from .sympyextensions import TypedSymbol, CastFunc, fast_subs
from .sympyextensions import TypedSymbol, tcast, fast_subs
# from pystencils.sympyextensions.astnodes import LoopOverCoordinate # TODO nbackend: replace
# from pystencils.backends.cbackend import CustomCodeNode # TODO nbackend: replace
......@@ -48,7 +48,7 @@ class RNGBase:
def get_code(self, dialect, vector_instruction_set, print_arg):
code = "\n"
for r in self.result_symbols:
if vector_instruction_set and not self.args[1].atoms(CastFunc):
if vector_instruction_set and not self.args[1].atoms(tcast):
# this vector RNG has become scalar through substitution
code += f"{r.dtype} {r.name};\n"
else:
......
from .astnodes import ConditionalFieldAccess
from .typed_sympy import TypedSymbol, CastFunc
from .typed_sympy import TypedSymbol, tcast
from .pointers import mem_acc
from .math import (
......@@ -34,7 +34,7 @@ from .math import (
__all__ = [
"ConditionalFieldAccess",
"TypedSymbol",
"CastFunc",
"tcast",
"mem_acc",
"remove_higher_order_terms",
"prod",
......
......@@ -11,7 +11,7 @@ from sympy.functions import Abs
from sympy.core.numbers import Zero
from ..assignment import Assignment
from .typed_sympy import CastFunc
from .typed_sympy import TypeCast
from ..types import PsPointerType, PsVectorType
T = TypeVar('T')
......@@ -603,7 +603,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]],
visit_children = False
elif t.is_integer:
pass
elif isinstance(t, CastFunc):
elif isinstance(t, TypeCast):
visit_children = False
visit(t.args[0])
elif t.func is fast_sqrt:
......
from __future__ import annotations
from typing import cast
import sympy as sp
from enum import Enum, auto
......@@ -6,11 +7,14 @@ from enum import Enum, auto
from ..types import (
PsType,
PsNumericType,
PsBoolType,
create_type,
UserTypeSpec
)
from sympy.logic.boolalg import Boolean
from warnings import warn
def is_loop_counter_symbol(symbol):
from ..defaults import DEFAULTS
......@@ -37,11 +41,12 @@ class DynamicType(Enum):
class TypeAtom(sp.Atom):
"""Wrapper around a type to disguise it as a SymPy atom."""
def __new__(cls, *args, **kwargs):
return sp.Basic.__new__(cls)
_dtype: PsType | DynamicType
def __init__(self, dtype: PsType | DynamicType) -> None:
self._dtype = dtype
def __new__(cls, dtype: PsType | DynamicType):
obj = super().__new__(cls)
obj._dtype = dtype
return obj
def _sympystr(self, *args, **kwargs):
return str(self._dtype)
......@@ -52,6 +57,9 @@ class TypeAtom(sp.Atom):
def _hashable_content(self):
return (self._dtype,)
def __getnewargs__(self):
return (self._dtype,)
def assumptions_from_dtype(dtype: PsType | DynamicType):
"""Derives SymPy assumptions from :class:`PsAbstractType`
......@@ -133,144 +141,74 @@ class TypedSymbol(sp.Symbol):
return self.dtype.required_headers if isinstance(self.dtype, PsType) else set()
class CastFunc(sp.Function):
"""Use this function to introduce a static type cast into the output code.
Usage: ``CastFunc(expr, target_type)`` becomes, in C code, ``(target_type) expr``.
The ``target_type`` may be a valid pystencils type specification parsable by `create_type`,
or a special value of the `DynamicType` enum.
These dynamic types can be used to select the target type according to the code generation context.
"""
class TypeCast(sp.Function):
"""Explicitly cast an expression to a data type."""
@staticmethod
def as_numeric(expr):
return CastFunc(expr, DynamicType.NUMERIC_TYPE)
return TypeCast(expr, DynamicType.NUMERIC_TYPE)
@staticmethod
def as_index(expr):
return CastFunc(expr, DynamicType.INDEX_TYPE)
is_Atom = True
def __new__(cls, *args, **kwargs):
if len(args) != 2:
pass
expr, dtype, *other_args = args
# If we have two consecutive casts, throw the inner one away.
# This optimisation is only available for simple casts. Thus the == is intended here!
if expr.__class__ == CastFunc:
expr = expr.args[0]
if not isinstance(dtype, (TypeAtom)):
if isinstance(dtype, DynamicType):
dtype = TypeAtom(dtype)
else:
dtype = TypeAtom(create_type(dtype))
# to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well
# however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads
# to problems when for example comparing cast_func's for equality
#
# lhs = bitwise_and(a, cast_func(1, 'int'))
# rhs = cast_func(0, 'int')
# print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans
# -> thus a separate class boolean_cast_func is introduced
if isinstance(expr, sp.logic.boolalg.Boolean) and (
not isinstance(expr, TypedSymbol) or isinstance(expr.dtype, PsBoolType)
):
cls = BooleanCastFunc
return sp.Function.__new__(cls, expr, dtype, *other_args, **kwargs)
@property
def canonical(self):
if hasattr(self.args[0], "canonical"):
return self.args[0].canonical
else:
raise NotImplementedError()
@property
def is_commutative(self):
return self.args[0].is_commutative
@property
def dtype(self) -> PsType | DynamicType:
assert isinstance(self.args[1], TypeAtom)
return self.args[1].get()
return TypeCast(expr, DynamicType.INDEX_TYPE)
@property
def expr(self):
def expr(self) -> sp.Basic:
return self.args[0]
@property
def is_integer(self):
def dtype(self) -> PsType | DynamicType:
return cast(TypeAtom, self._args[1]).get()
def __new__(cls, expr: sp.Basic, dtype: UserTypeSpec | DynamicType | TypeAtom):
tatom: TypeAtom
match dtype:
case TypeAtom():
tatom = dtype
case DynamicType():
tatom = TypeAtom(dtype)
case _:
tatom = TypeAtom(create_type(dtype))
return super().__new__(cls, expr, tatom)
@classmethod
def eval(cls, expr: sp.Basic, tatom: TypeAtom) -> TypeCast | None:
dtype = tatom.get()
if cls is not BoolCast and isinstance(dtype, PsNumericType) and dtype.is_bool():
return BoolCast(expr, tatom)
return None
def _eval_is_integer(self):
if self.dtype == DynamicType.INDEX_TYPE:
return True
elif isinstance(self.dtype, PsNumericType):
return self.dtype.is_int() or super().is_integer
else:
return super().is_integer
@property
def is_negative(self):
"""
See :func:`.TypedSymbol.is_integer`
"""
if isinstance(self.dtype, PsNumericType):
if self.dtype.is_uint():
return False
return super().is_negative
@property
def is_nonnegative(self):
"""
See :func:`.TypedSymbol.is_integer`
"""
if self.is_negative is False:
if isinstance(self.dtype, PsNumericType) and self.dtype.is_int():
return True
def _eval_is_real(self):
if isinstance(self.dtype, DynamicType):
return True
if isinstance(self.dtype, PsNumericType) and (self.dtype.is_float() or self.dtype.is_int()):
return True
def _eval_is_nonnegative(self):
if isinstance(self.dtype, PsNumericType) and self.dtype.is_uint():
return True
else:
return super().is_nonnegative
@property
def is_real(self):
"""
See :func:`.TypedSymbol.is_integer`
"""
if isinstance(self.dtype, PsNumericType):
return self.dtype.is_int() or self.dtype.is_float() or super().is_real
else:
return super().is_real
class BooleanCastFunc(CastFunc, sp.logic.boolalg.Boolean):
# TODO: documentation
pass
class VectorMemoryAccess(CastFunc):
"""
Special memory access for vectorized kernel.
Arguments: read/write expression, type, aligned, non-temporal, mask (or none), stride
"""
nargs = (6,)
class BoolCast(TypeCast, Boolean):
pass
class ReinterpretCastFunc(CastFunc):
"""
Reinterpret cast is necessary for the StructType
"""
pass
tcast = TypeCast
class PointerArithmeticFunc(sp.Function, sp.logic.boolalg.Boolean):
# TODO: documentation, or deprecate!
@property
def canonical(self):
if hasattr(self.args[0], "canonical"):
return self.args[0].canonical
else:
raise NotImplementedError()
class CastFunc(TypeCast):
def __new__(cls, *args, **kwargs):
warn(
"CastFunc is deprecated and will be removed in pystencils 2.1. "
"Use `pystencils.tcast` instead.",
FutureWarning
)
return TypeCast.__new__(cls, *args, **kwargs)
......@@ -5,7 +5,7 @@ import pytest
import pystencils
from pystencils.types import PsPointerType, create_type
from pystencils.sympyextensions.pointers import AddressOf
from pystencils.sympyextensions.typed_sympy import CastFunc
from pystencils.sympyextensions.typed_sympy import tcast
from pystencils.simp import sympy_cse
import sympy as sp
......@@ -23,14 +23,14 @@ def test_address_of():
assignments = pystencils.AssignmentCollection({
s: AddressOf(x[0, 0]),
y[0, 0]: CastFunc(s, create_type('int64'))
y[0, 0]: tcast(s, create_type('int64'))
})
_ = pystencils.create_kernel(assignments).compile()
# pystencils.show_code(kernel.ast)
assignments = pystencils.AssignmentCollection({
y[0, 0]: CastFunc(AddressOf(x[0, 0]), create_type('int64'))
y[0, 0]: tcast(AddressOf(x[0, 0]), create_type('int64'))
})
_ = pystencils.create_kernel(assignments).compile()
......@@ -41,7 +41,7 @@ def test_address_of_with_cse():
x, y = pystencils.fields('x, y: int64[2d]')
assignments = pystencils.AssignmentCollection({
x[0, 0]: CastFunc(AddressOf(x[0, 0]), create_type('int64')) + 1
x[0, 0]: tcast(AddressOf(x[0, 0]), create_type('int64')) + 1
})
_ = pystencils.create_kernel(assignments).compile()
......
......@@ -3,7 +3,7 @@ import pytest
import sympy as sp
import pystencils as ps
from pystencils import DEFAULTS
from pystencils import DEFAULTS, DynamicType, create_type, fields
from pystencils.field import (
Field,
FieldType,
......@@ -15,6 +15,7 @@ from pystencils.field import (
def test_field_basic():
f = Field.create_generic("f", spatial_dimensions=2)
assert FieldType.is_generic(f)
assert f.dtype == DynamicType.NUMERIC_TYPE
assert f["E"] == f[1, 0]
assert f["N"] == f[0, 1]
assert "_" in f.center._latex("dummy")
......@@ -41,17 +42,16 @@ def test_field_basic():
assert f1.ndim == f.ndim
assert f1.values_per_cell() == f.values_per_cell()
fixed = ps.fields("f(5, 5) : double[20, 20]")
assert fixed.neighbor_vector((1, 1)).shape == (5, 5)
f = Field.create_fixed_size("f", (10, 10), strides=(80, 8), dtype=np.float64)
f = Field.create_fixed_size("f", (10, 10), strides=(10, 1), dtype=np.float64)
assert f.spatial_strides == (10, 1)
assert f.index_strides == ()
assert f.center_vector == sp.Matrix([f.center])
assert f.dtype == create_type("float64")
f1 = f.new_field_with_different_name("f1")
assert f1.ndim == f.ndim
assert f1.values_per_cell() == f.values_per_cell()
assert f1.dtype == create_type("float64")
f = Field.create_fixed_size("f", (8, 8, 2, 2), index_dimensions=2)
assert f.center_vector == sp.Matrix([[f(0, 0), f(0, 1)], [f(1, 0), f(1, 1)]])
......@@ -61,16 +61,48 @@ def test_field_basic():
neighbor = field_access.neighbor(coord_id=0, offset=-2)
assert neighbor.offsets == (-1, 1)
assert "_" in neighbor._latex("dummy")
assert f.dtype == DynamicType.NUMERIC_TYPE
f = Field.create_fixed_size("f", (8, 8, 2, 2, 2), index_dimensions=3)
assert f.center_vector == sp.Array(
[[[f(i, j, k) for k in range(2)] for j in range(2)] for i in range(2)]
)
assert f.dtype == DynamicType.NUMERIC_TYPE
f = Field.create_generic("f", spatial_dimensions=5, index_dimensions=2)
field_access = f[1, -1, 2, -3, 0](1, 0)
assert field_access.offsets == (1, -1, 2, -3, 0)
assert field_access.index == (1, 0)
assert f.dtype == DynamicType.NUMERIC_TYPE
def test_field_description_parsing():
f, g = fields("f(1), g(3): [2D]")
assert f.dtype == g.dtype == DynamicType.NUMERIC_TYPE
assert f.spatial_dimensions == g.spatial_dimensions == 2
assert f.index_shape == (1,)
assert g.index_shape == (3,)
f = fields("f: dyn[3D]")
assert f.dtype == DynamicType.NUMERIC_TYPE
idx = fields("idx: dynidx[3D]")
assert idx.dtype == DynamicType.INDEX_TYPE
h = fields("h: float32[3D]")
assert h.index_shape == ()
assert h.spatial_dimensions == 3
assert h.index_dimensions == 0
assert h.dtype == create_type("float32")
f: Field = fields("f(5, 5) : double[20, 20]")
assert f.dtype == create_type("float64")
assert f.spatial_shape == (20, 20)
assert f.index_shape == (5, 5)
assert f.neighbor_vector((1, 1)).shape == (5, 5)
def test_error_handling():
......@@ -145,7 +177,7 @@ def test_error_handling():
def test_decorator_scoping():
dst = ps.fields("dst : double[2D]")
dst = fields("dst : double[2D]")
def f1():
a = sp.Symbol("a")
......@@ -165,7 +197,7 @@ def test_decorator_scoping():
def test_string_creation():
x, y, z = ps.fields(" x(4), y(3,5) z : double[ 3, 47]")
x, y, z = fields(" x(4), y(3,5) z : double[ 3, 47]")
assert x.index_shape == (4,)
assert y.index_shape == (3, 5)
assert z.spatial_shape == (3, 47)
......@@ -173,9 +205,9 @@ def test_string_creation():
def test_itemsize():
x = ps.fields("x: float32[1d]")
y = ps.fields("y: float64[2d]")
i = ps.fields("i: int16[1d]")
x = fields("x: float32[1d]")
y = fields("y: float64[2d]")
i = fields("i: int16[1d]")
assert x.itemsize == 4
assert y.itemsize == 8
......@@ -249,7 +281,7 @@ def test_memory_layout_descriptors():
def test_staggered():
# D2Q5
j1, j2, j3 = ps.fields(
j1, j2, j3 = fields(
"j1(2), j2(2,2), j3(2,2,2) : double[2D]", field_type=FieldType.STAGGERED
)
......@@ -296,7 +328,7 @@ def test_staggered():
)
# D2Q9
k1, k2 = ps.fields("k1(4), k2(2) : double[2D]", field_type=FieldType.STAGGERED)
k1, k2 = fields("k1(4), k2(2) : double[2D]", field_type=FieldType.STAGGERED)
assert k1[1, 1](2) == k1.staggered_access("NE")
assert k1[0, 0](2) == k1.staggered_access("SW")
......@@ -319,7 +351,7 @@ def test_staggered():
]
# sign reversed when using as flux field
r = ps.fields("r(2) : double[2D]", field_type=FieldType.STAGGERED_FLUX)
r = fields("r(2) : double[2D]", field_type=FieldType.STAGGERED_FLUX)
assert r[0, 0](0) == r.staggered_access("W")
assert -r[1, 0](0) == r.staggered_access("E")
......
import numpy as np
import pickle
import sympy as sp
from sympy.logic import boolalg
from pystencils.sympyextensions.typed_sympy import (
TypedSymbol,
CastFunc,
tcast,
TypeAtom,
DynamicType,
)
......@@ -12,7 +15,7 @@ from pystencils.types.quick import UInt, Ptr
def test_type_atoms():
atom1 = TypeAtom(create_type("int32"))
atom2 = TypeAtom(create_type("int32"))
atom2 = TypeAtom(create_type(np.int32))
assert atom1 == atom2
......@@ -25,6 +28,11 @@ def test_type_atoms():
assert atom3 != atom4
assert atom4 != atom5
dump = pickle.dumps(atom1)
atom1_reconst = pickle.loads(dump)
assert atom1_reconst == atom1
def test_typed_symbol():
x = TypedSymbol("x", "uint32")
......@@ -46,12 +54,34 @@ def test_typed_symbol():
assert not z.is_nonnegative
def test_cast_func():
assert (
CastFunc(TypedSymbol("s", np.uint), np.int64).canonical
== TypedSymbol("s", np.uint).canonical
)
a = CastFunc(5, np.uint)
assert a.is_negative is False
assert a.is_nonnegative
def test_casts():
x, y = sp.symbols("x, y")
# Pickling
expr = tcast(x, "int32")
dump = pickle.dumps(expr)
expr_reconst = pickle.loads(dump)
assert expr_reconst == expr
# Boolean Casts
bool_expr = tcast(x, "bool")
assert isinstance(bool_expr, boolalg.Boolean)
# Check that we can construct boolean expressions with cast results
_ = boolalg.Or(bool_expr, y)
# Assumptions
expr = tcast(x, "int32")
assert expr.is_integer
assert expr.is_real
assert expr.is_nonnegative is None
expr = tcast(x, "uint32")
assert expr.is_integer
assert expr.is_real
assert expr.is_nonnegative
expr = tcast(x, "float32")
assert expr.is_integer is None
assert expr.is_real
assert expr.is_nonnegative is None
......@@ -9,7 +9,7 @@ from pystencils import (
DEFAULTS,
FieldType,
)
from pystencils.sympyextensions import CastFunc
from pystencils.sympyextensions import tcast
@pytest.mark.parametrize("index_dtype", ["int16", "int32", "uint32", "int64"])
......@@ -21,9 +21,9 @@ def test_spatial_counters_dense(index_dtype):
f = Field.create_generic("f", 3, "float64", index_shape=(3,), layout="fzyx")
asms = [
Assignment(f(0), CastFunc.as_numeric(z)),
Assignment(f(1), CastFunc.as_numeric(y)),
Assignment(f(2), CastFunc.as_numeric(x)),
Assignment(f(0), tcast.as_numeric(z)),
Assignment(f(1), tcast.as_numeric(y)),
Assignment(f(2), tcast.as_numeric(x)),
]
cfg = CreateKernelConfig(index_dtype=index_dtype)
......@@ -44,9 +44,9 @@ def test_spatial_counters_sparse(index_dtype):
f = Field.create_generic("f", 3, "float64", index_shape=(3,), layout="fzyx")
asms = [
Assignment(f(0), CastFunc.as_numeric(x)),
Assignment(f(1), CastFunc.as_numeric(y)),
Assignment(f(2), CastFunc.as_numeric(z)),
Assignment(f(0), tcast.as_numeric(x)),
Assignment(f(1), tcast.as_numeric(y)),
Assignment(f(2), tcast.as_numeric(z)),
]
idx_struct = DEFAULTS.index_struct(index_dtype, 3)
......