diff --git a/mypy.ini b/mypy.ini index cc23a503a2da6c9849d3a41e82fe8ceb8de13b43..08f073f7c6f688fe110ee71c1aa0836b0cd90ec5 100644 --- a/mypy.ini +++ b/mypy.ini @@ -17,6 +17,9 @@ ignore_errors = False [mypy-pystencils.jit.*] ignore_errors = False +[mypy-pystencils.sympyextensions.typed_sympy] +ignore_errors = False + [mypy-setuptools.*] ignore_missing_imports=true diff --git a/src/pystencils/__init__.py b/src/pystencils/__init__.py index 6cb375b61e63904c0cd3c2e6e9d6a3be86be6b29..3d3727e8c60d83c475eaedfc435b11fdbb8dc58b 100644 --- a/src/pystencils/__init__.py +++ b/src/pystencils/__init__.py @@ -36,7 +36,7 @@ from .spatial_coordinates import ( from .assignment import Assignment, AddAugmentedAssignment, assignment_from_stencil from .simp import AssignmentCollection from .sympyextensions.typed_sympy import TypedSymbol, DynamicType -from .sympyextensions import SymbolCreator +from .sympyextensions import SymbolCreator, tcast from .datahandling import create_data_handling __all__ = [ @@ -85,6 +85,7 @@ __all__ = [ "x_staggered_vector", "fd", "stencil", + "tcast", ] from . import _version diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 44ee170775982916e34ab6da6656461962763cd4..16710861b9a032a8bee3daae4b2483c9432912ee 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -13,7 +13,7 @@ from ...sympyextensions import ( integer_functions, ConditionalFieldAccess, ) -from ...sympyextensions.typed_sympy import TypedSymbol, CastFunc, DynamicType +from ...sympyextensions.typed_sympy import TypedSymbol, TypeCast, DynamicType from ...sympyextensions.pointers import AddressOf, mem_acc from ...field import Field, FieldType @@ -481,7 +481,7 @@ class FreezeExpressions: ] return cast(PsCall, args[0]) - def map_CastFunc(self, cast_expr: CastFunc) -> PsCast | PsConstantExpr: + def map_TypeCast(self, cast_expr: TypeCast) -> PsCast | PsConstantExpr: dtype: PsType match cast_expr.dtype: case DynamicType.NUMERIC_TYPE: diff --git a/src/pystencils/rng.py b/src/pystencils/rng.py index d6c6cd2741ee3e7442bd9fa4a96f4e9983d524e3..4f8316fa75284ed0fa3385744bd9b93f88d5ae65 100644 --- a/src/pystencils/rng.py +++ b/src/pystencils/rng.py @@ -2,7 +2,7 @@ import copy import numpy as np import sympy as sp -from .sympyextensions import TypedSymbol, CastFunc, fast_subs +from .sympyextensions import TypedSymbol, tcast, fast_subs # from pystencils.sympyextensions.astnodes import LoopOverCoordinate # TODO nbackend: replace # from pystencils.backends.cbackend import CustomCodeNode # TODO nbackend: replace @@ -48,7 +48,7 @@ class RNGBase: def get_code(self, dialect, vector_instruction_set, print_arg): code = "\n" for r in self.result_symbols: - if vector_instruction_set and not self.args[1].atoms(CastFunc): + if vector_instruction_set and not self.args[1].atoms(tcast): # this vector RNG has become scalar through substitution code += f"{r.dtype} {r.name};\n" else: diff --git a/src/pystencils/sympyextensions/__init__.py b/src/pystencils/sympyextensions/__init__.py index 7431416c9eb9bcd4433dab76c32fb1b755501105..2d874fdc0778a331aaf61ed938981f533eafbecb 100644 --- a/src/pystencils/sympyextensions/__init__.py +++ b/src/pystencils/sympyextensions/__init__.py @@ -1,5 +1,5 @@ from .astnodes import ConditionalFieldAccess -from .typed_sympy import TypedSymbol, CastFunc +from .typed_sympy import TypedSymbol, tcast from .pointers import mem_acc from .math import ( @@ -34,7 +34,7 @@ from .math import ( __all__ = [ "ConditionalFieldAccess", "TypedSymbol", - "CastFunc", + "tcast", "mem_acc", "remove_higher_order_terms", "prod", diff --git a/src/pystencils/sympyextensions/math.py b/src/pystencils/sympyextensions/math.py index 9841a98bd83162fbb080db370556de70612bc398..33c035499ee80598303c8a26b028e47dfae72cc3 100644 --- a/src/pystencils/sympyextensions/math.py +++ b/src/pystencils/sympyextensions/math.py @@ -11,7 +11,7 @@ from sympy.functions import Abs from sympy.core.numbers import Zero from ..assignment import Assignment -from .typed_sympy import CastFunc +from .typed_sympy import TypeCast from ..types import PsPointerType, PsVectorType T = TypeVar('T') @@ -603,7 +603,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]], visit_children = False elif t.is_integer: pass - elif isinstance(t, CastFunc): + elif isinstance(t, TypeCast): visit_children = False visit(t.args[0]) elif t.func is fast_sqrt: diff --git a/src/pystencils/sympyextensions/typed_sympy.py b/src/pystencils/sympyextensions/typed_sympy.py index 39202296b477dc17ea6e9564548ef841fd04594d..509752bdf8fe0a4cca1504b703ad47e5cc2b90c1 100644 --- a/src/pystencils/sympyextensions/typed_sympy.py +++ b/src/pystencils/sympyextensions/typed_sympy.py @@ -1,4 +1,5 @@ from __future__ import annotations +from typing import cast import sympy as sp from enum import Enum, auto @@ -6,11 +7,14 @@ from enum import Enum, auto from ..types import ( PsType, PsNumericType, - PsBoolType, create_type, UserTypeSpec ) +from sympy.logic.boolalg import Boolean + +from warnings import warn + def is_loop_counter_symbol(symbol): from ..defaults import DEFAULTS @@ -37,11 +41,12 @@ class DynamicType(Enum): class TypeAtom(sp.Atom): """Wrapper around a type to disguise it as a SymPy atom.""" - def __new__(cls, *args, **kwargs): - return sp.Basic.__new__(cls) + _dtype: PsType | DynamicType - def __init__(self, dtype: PsType | DynamicType) -> None: - self._dtype = dtype + def __new__(cls, dtype: PsType | DynamicType): + obj = super().__new__(cls) + obj._dtype = dtype + return obj def _sympystr(self, *args, **kwargs): return str(self._dtype) @@ -52,6 +57,9 @@ class TypeAtom(sp.Atom): def _hashable_content(self): return (self._dtype,) + def __getnewargs__(self): + return (self._dtype,) + def assumptions_from_dtype(dtype: PsType | DynamicType): """Derives SymPy assumptions from :class:`PsAbstractType` @@ -133,144 +141,76 @@ class TypedSymbol(sp.Symbol): return self.dtype.required_headers if isinstance(self.dtype, PsType) else set() -class CastFunc(sp.Function): - """Use this function to introduce a static type cast into the output code. - - Usage: ``CastFunc(expr, target_type)`` becomes, in C code, ``(target_type) expr``. - The ``target_type`` may be a valid pystencils type specification parsable by `create_type`, - or a special value of the `DynamicType` enum. - These dynamic types can be used to select the target type according to the code generation context. - """ +class TypeCast(sp.Function): + """Explicitly cast an expression to a data type.""" @staticmethod def as_numeric(expr): - return CastFunc(expr, DynamicType.NUMERIC_TYPE) + return TypeCast(expr, DynamicType.NUMERIC_TYPE) @staticmethod def as_index(expr): - return CastFunc(expr, DynamicType.INDEX_TYPE) - - is_Atom = True - - def __new__(cls, *args, **kwargs): - if len(args) != 2: - pass - expr, dtype, *other_args = args - - # If we have two consecutive casts, throw the inner one away. - # This optimisation is only available for simple casts. Thus the == is intended here! - if expr.__class__ == CastFunc: - expr = expr.args[0] - - if not isinstance(dtype, (TypeAtom)): - if isinstance(dtype, DynamicType): - dtype = TypeAtom(dtype) - else: - dtype = TypeAtom(create_type(dtype)) - - # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well - # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads - # to problems when for example comparing cast_func's for equality - # - # lhs = bitwise_and(a, cast_func(1, 'int')) - # rhs = cast_func(0, 'int') - # print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans - # -> thus a separate class boolean_cast_func is introduced - if isinstance(expr, sp.logic.boolalg.Boolean) and ( - not isinstance(expr, TypedSymbol) or isinstance(expr.dtype, PsBoolType) - ): - cls = BooleanCastFunc - - return sp.Function.__new__(cls, expr, dtype, *other_args, **kwargs) - - @property - def canonical(self): - if hasattr(self.args[0], "canonical"): - return self.args[0].canonical - else: - raise NotImplementedError() - - @property - def is_commutative(self): - return self.args[0].is_commutative - - @property - def dtype(self) -> PsType | DynamicType: - assert isinstance(self.args[1], TypeAtom) - return self.args[1].get() - + return TypeCast(expr, DynamicType.INDEX_TYPE) + @property - def expr(self): + def expr(self) -> sp.Basic: return self.args[0] @property - def is_integer(self): + def dtype(self) -> PsType | DynamicType: + return cast(TypeAtom, self._args[1]).get() + + def __new__(cls, expr: sp.Basic, dtype: UserTypeSpec | DynamicType | TypeAtom): + tatom: TypeAtom + match dtype: + case TypeAtom(): + tatom = dtype + case DynamicType(): + tatom = TypeAtom(dtype) + case _: + tatom = TypeAtom(create_type(dtype)) + + return super().__new__(cls, expr, tatom) + + @classmethod + def eval(cls, expr: sp.Basic, tatom: TypeAtom) -> TypeCast | None: + if isinstance(expr, TypeCast): + return TypeCast(expr.args[0], tatom) + + dtype = tatom.get() + if cls is not BoolCast and isinstance(dtype, PsNumericType) and dtype.is_bool(): + return BoolCast(expr, tatom) + + return None + + def _eval_is_integer(self): if self.dtype == DynamicType.INDEX_TYPE: return True - elif isinstance(self.dtype, PsNumericType): - return self.dtype.is_int() or super().is_integer - else: - return super().is_integer - - @property - def is_negative(self): - """ - See :func:`.TypedSymbol.is_integer` - """ - if isinstance(self.dtype, PsNumericType): - if self.dtype.is_uint(): - return False - - return super().is_negative - - @property - def is_nonnegative(self): - """ - See :func:`.TypedSymbol.is_integer` - """ - if self.is_negative is False: + if isinstance(self.dtype, PsNumericType) and self.dtype.is_int(): + return True + + def _eval_is_real(self): + if isinstance(self.dtype, DynamicType): + return True + if isinstance(self.dtype, PsNumericType) and (self.dtype.is_float() or self.dtype.is_int()): + return True + + def _eval_is_nonnegative(self): + if isinstance(self.dtype, PsNumericType) and self.dtype.is_uint(): return True - else: - return super().is_nonnegative - - @property - def is_real(self): - """ - See :func:`.TypedSymbol.is_integer` - """ - if isinstance(self.dtype, PsNumericType): - return self.dtype.is_int() or self.dtype.is_float() or super().is_real - else: - return super().is_real - - -class BooleanCastFunc(CastFunc, sp.logic.boolalg.Boolean): - # TODO: documentation - pass - - -class VectorMemoryAccess(CastFunc): - """ - Special memory access for vectorized kernel. - Arguments: read/write expression, type, aligned, non-temporal, mask (or none), stride - """ - nargs = (6,) +class BoolCast(TypeCast, Boolean): + pass -class ReinterpretCastFunc(CastFunc): - """ - Reinterpret cast is necessary for the StructType - """ - pass +tcast = TypeCast -class PointerArithmeticFunc(sp.Function, sp.logic.boolalg.Boolean): - # TODO: documentation, or deprecate! - @property - def canonical(self): - if hasattr(self.args[0], "canonical"): - return self.args[0].canonical - else: - raise NotImplementedError() +class CastFunc(sp.Function): + def __new__(cls, *args, **kwargs): + warn( + "CastFunc is deprecated and will be removed in pystencils 2.1. " + "Use `pystencils.tcast` instead.", + FutureWarning + ) diff --git a/tests/frontend/test_address_of.py b/tests/frontend/test_address_of.py index 99f33ddbdfa7054bf5f27c08848640ee03f64555..62d7f00d56b288c009c9dc4fcfade95b95acdd41 100644 --- a/tests/frontend/test_address_of.py +++ b/tests/frontend/test_address_of.py @@ -5,7 +5,7 @@ import pytest import pystencils from pystencils.types import PsPointerType, create_type from pystencils.sympyextensions.pointers import AddressOf -from pystencils.sympyextensions.typed_sympy import CastFunc +from pystencils.sympyextensions.typed_sympy import tcast from pystencils.simp import sympy_cse import sympy as sp @@ -23,14 +23,14 @@ def test_address_of(): assignments = pystencils.AssignmentCollection({ s: AddressOf(x[0, 0]), - y[0, 0]: CastFunc(s, create_type('int64')) + y[0, 0]: tcast(s, create_type('int64')) }) _ = pystencils.create_kernel(assignments).compile() # pystencils.show_code(kernel.ast) assignments = pystencils.AssignmentCollection({ - y[0, 0]: CastFunc(AddressOf(x[0, 0]), create_type('int64')) + y[0, 0]: tcast(AddressOf(x[0, 0]), create_type('int64')) }) _ = pystencils.create_kernel(assignments).compile() @@ -41,7 +41,7 @@ def test_address_of_with_cse(): x, y = pystencils.fields('x, y: int64[2d]') assignments = pystencils.AssignmentCollection({ - x[0, 0]: CastFunc(AddressOf(x[0, 0]), create_type('int64')) + 1 + x[0, 0]: tcast(AddressOf(x[0, 0]), create_type('int64')) + 1 }) _ = pystencils.create_kernel(assignments).compile() diff --git a/tests/frontend/test_typed_sympy.py b/tests/frontend/test_typed_sympy.py index 41015f96bfa6950a57f9ccfa3194c128c2bc0f69..29b872952f605511ca44501efa761f889838a1a9 100644 --- a/tests/frontend/test_typed_sympy.py +++ b/tests/frontend/test_typed_sympy.py @@ -1,8 +1,11 @@ import numpy as np +import pickle +import sympy as sp +from sympy.logic import boolalg from pystencils.sympyextensions.typed_sympy import ( TypedSymbol, - CastFunc, + tcast, TypeAtom, DynamicType, ) @@ -12,7 +15,7 @@ from pystencils.types.quick import UInt, Ptr def test_type_atoms(): atom1 = TypeAtom(create_type("int32")) - atom2 = TypeAtom(create_type("int32")) + atom2 = TypeAtom(create_type(np.int32)) assert atom1 == atom2 @@ -25,6 +28,11 @@ def test_type_atoms(): assert atom3 != atom4 assert atom4 != atom5 + dump = pickle.dumps(atom1) + atom1_reconst = pickle.loads(dump) + + assert atom1_reconst == atom1 + def test_typed_symbol(): x = TypedSymbol("x", "uint32") @@ -46,12 +54,38 @@ def test_typed_symbol(): assert not z.is_nonnegative -def test_cast_func(): - assert ( - CastFunc(TypedSymbol("s", np.uint), np.int64).canonical - == TypedSymbol("s", np.uint).canonical - ) - - a = CastFunc(5, np.uint) - assert a.is_negative is False - assert a.is_nonnegative +def test_casts(): + x, y = sp.symbols("x, y") + + # Pickling + expr = tcast(x, "int32") + dump = pickle.dumps(expr) + expr_reconst = pickle.loads(dump) + assert expr_reconst == expr + + # Double Cast Elimination + expr = tcast(tcast(x, "int32"), "uint64") + assert expr == tcast(x, "uint64") + + # Boolean Casts + bool_expr = tcast(x, "bool") + assert isinstance(bool_expr, boolalg.Boolean) + + # Check that we can construct boolean expressions with cast results + _ = boolalg.Or(bool_expr, y) + + # Assumptions + expr = tcast(x, "int32") + assert expr.is_integer + assert expr.is_real + assert expr.is_nonnegative is None + + expr = tcast(x, "uint32") + assert expr.is_integer + assert expr.is_real + assert expr.is_nonnegative + + expr = tcast(x, "float32") + assert expr.is_integer is None + assert expr.is_real + assert expr.is_nonnegative is None diff --git a/tests/kernelcreation/test_spatial_counters.py b/tests/kernelcreation/test_spatial_counters.py index fdb365294c98311943c370cb650694b1a4bd8613..4f865ad97f42f31133cc5d0dc3fbba569f6f577d 100644 --- a/tests/kernelcreation/test_spatial_counters.py +++ b/tests/kernelcreation/test_spatial_counters.py @@ -9,7 +9,7 @@ from pystencils import ( DEFAULTS, FieldType, ) -from pystencils.sympyextensions import CastFunc +from pystencils.sympyextensions import tcast @pytest.mark.parametrize("index_dtype", ["int16", "int32", "uint32", "int64"]) @@ -21,9 +21,9 @@ def test_spatial_counters_dense(index_dtype): f = Field.create_generic("f", 3, "float64", index_shape=(3,), layout="fzyx") asms = [ - Assignment(f(0), CastFunc.as_numeric(z)), - Assignment(f(1), CastFunc.as_numeric(y)), - Assignment(f(2), CastFunc.as_numeric(x)), + Assignment(f(0), tcast.as_numeric(z)), + Assignment(f(1), tcast.as_numeric(y)), + Assignment(f(2), tcast.as_numeric(x)), ] cfg = CreateKernelConfig(index_dtype=index_dtype) @@ -44,9 +44,9 @@ def test_spatial_counters_sparse(index_dtype): f = Field.create_generic("f", 3, "float64", index_shape=(3,), layout="fzyx") asms = [ - Assignment(f(0), CastFunc.as_numeric(x)), - Assignment(f(1), CastFunc.as_numeric(y)), - Assignment(f(2), CastFunc.as_numeric(z)), + Assignment(f(0), tcast.as_numeric(x)), + Assignment(f(1), tcast.as_numeric(y)), + Assignment(f(2), tcast.as_numeric(z)), ] idx_struct = DEFAULTS.index_struct(index_dtype, 3) diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py index ce4f6178511aaad913819eafb59d9ccae42ee992..b7b2ed19e114a3c3d7568fbe6b6ea035848f42e4 100644 --- a/tests/nbackend/kernelcreation/test_freeze.py +++ b/tests/nbackend/kernelcreation/test_freeze.py @@ -9,7 +9,7 @@ from pystencils import ( TypedSymbol, DynamicType, ) -from pystencils.sympyextensions import CastFunc +from pystencils.sympyextensions import tcast from pystencils.sympyextensions.pointers import mem_acc from pystencils.backend.ast.structural import ( @@ -318,16 +318,16 @@ def test_cast_func(): y2 = PsExpression.make(ctx.get_symbol("y")) z2 = PsExpression.make(ctx.get_symbol("z")) - expr = freeze(CastFunc(x, create_type("int"))) + expr = freeze(tcast(x, create_type("int"))) assert expr.structurally_equal(PsCast(create_type("int"), x2)) - expr = freeze(CastFunc.as_numeric(y)) + expr = freeze(tcast.as_numeric(y)) assert expr.structurally_equal(PsCast(ctx.default_dtype, y2)) - expr = freeze(CastFunc.as_index(z)) + expr = freeze(tcast.as_index(z)) assert expr.structurally_equal(PsCast(ctx.index_dtype, z2)) - expr = freeze(CastFunc(42, create_type("int16"))) + expr = freeze(tcast(42, create_type("int16"))) assert expr.structurally_equal(PsConstantExpr(PsConstant(42, create_type("int16")))) diff --git a/tests/nbackend/transformations/test_ast_vectorizer.py b/tests/nbackend/transformations/test_ast_vectorizer.py index ea425349529c45b94317a98d2d9f305933c9ba60..81a301278505f8db916044c8a3885de84f1d3a00 100644 --- a/tests/nbackend/transformations/test_ast_vectorizer.py +++ b/tests/nbackend/transformations/test_ast_vectorizer.py @@ -2,7 +2,7 @@ import sympy as sp import pytest from pystencils import Assignment, TypedSymbol, fields, FieldType -from pystencils.sympyextensions import CastFunc, mem_acc +from pystencils.sympyextensions import tcast, mem_acc from pystencils.sympyextensions.pointers import AddressOf from pystencils.backend.constants import PsConstant @@ -101,7 +101,7 @@ def test_vectorize_casts_and_counter(): axis = VectorizationAxis(ctr, vec_ctr) vc = VectorizationContext(ctx, 4, axis) - expr = factory.parse_sympy(CastFunc(sp.Symbol("ctr"), create_type("float32"))) + expr = factory.parse_sympy(tcast(sp.Symbol("ctr"), create_type("float32"))) vec_expr = vectorize.visit(expr, vc) assert isinstance(vec_expr, PsCast) @@ -128,7 +128,7 @@ def test_invalid_vectorization(): axis = VectorizationAxis(ctr) vc = VectorizationContext(ctx, 4, axis) - expr = factory.parse_sympy(CastFunc(sp.Symbol("ctr"), create_type("float32"))) + expr = factory.parse_sympy(tcast(sp.Symbol("ctr"), create_type("float32"))) with pytest.raises(VectorizationError): # Fails since no vectorized counter was specified @@ -169,7 +169,7 @@ def test_vectorize_declarations(): [ factory.parse_sympy(asm) for asm in [ - Assignment(x, CastFunc.as_numeric(ctr)), + Assignment(x, tcast.as_numeric(ctr)), Assignment(y, sp.cos(x)), Assignment(z, x**2 + 2 * y / 4), Assignment(w, -x + y - z), diff --git a/tests/runtime/test_data/datahandling_save_test.npz b/tests/runtime/test_data/datahandling_save_test.npz index 22202358a4fa1d1cea4db89c0889f5bca636598b..486c7ee74d4421d563c3b1c2e3739d8db6308b07 100644 Binary files a/tests/runtime/test_data/datahandling_save_test.npz and b/tests/runtime/test_data/datahandling_save_test.npz differ