diff --git a/docs/source/api/symbolic/index.md b/docs/source/api/symbolic/index.md index fad3df20b68fcac0aba0fcec3ef4564138df5f01..86c25e10ea0eba76ec96eaa5c886d654647bbfde 100644 --- a/docs/source/api/symbolic/index.md +++ b/docs/source/api/symbolic/index.md @@ -3,6 +3,7 @@ :::{toctree} :maxdepth: 1 +sympy_features field assignments sympyextensions diff --git a/docs/source/api/symbolic/sympy_features.md b/docs/source/api/symbolic/sympy_features.md new file mode 100644 index 0000000000000000000000000000000000000000..d7d0b8211658fb2d35575d825503dc85c98f9e27 --- /dev/null +++ b/docs/source/api/symbolic/sympy_features.md @@ -0,0 +1,107 @@ +# Supported Subset of SymPy + +This page lists the parts of [SymPy](https://sympy.org)'s symbolic algebra toolbox +that pystencils is able to parse and translate into code. +This includes a list of restrictions and caveats. + +## sympy.core + +:::{list-table} + +* - Symbols + - {any}`sp.Symbol <sympy.core.symbol.Symbol>` + - Represent untyped variables +* - Integers + - {any}`sp.Integer <sympy.core.numbers.Integer>` + - +* - Rational Numbers + - {any}`sp.Rational <sympy.core.numbers.Rational>` + - Non-integer rationals are interpreted using the division operation + of their context's data type. +* - Arbitrary-Precision Floating Point + - {any}`sp.Float <sympy.core.numbers.Float>` + - Will initially be narrowed to double-precision (aka. the Python {any}`float` type), +* - Transcendentals: $\pi$ and $e$ + - {any}`sp.pi <sympy.core.numbers.Pi>`, {any}`sp.E <sympy.core.numbers.Exp1>` + - Only valid in floating-point contexts. + Will be rounded to the nearest number representable in the + respective data type (e.g. `pi = 3.1415927` for `float32`). +* - Infinities ($\pm \infty$) + - {any}`sp.Infinity <sympy.core.numbers.Infinity>`, + {any}`sp.NegativeInfinity <sympy.core.numbers.NegativeInfinity>` + - Only valid in floating point contexts. +* - Arithmetic + - {any}`sp.Add <sympy.core.add.Add>`, + {any}`sp.Mul <sympy.core.mul.Mul>`, + {any}`sp.Pow <sympy.core.power.Pow>`, + - Integer powers up to $8$ will be expanded by pairwise multiplication. + Negative integer powers will be replaced by divisions. + Square root powers with a numerator $\le 8$ will be replaced + by (products of) the `sqrt` function. +* - Relations (`==`, `<=`, `>`, ...) + - {any}`sp.Relational <sympy.core.relational.Relational>` + - Result has type `boolean` and can only be used in boolean contexts. +* - (Nested) Tuples + - {any}`sp.Tuple <sympy.core.containers.Tuple>` + - Tuples of expressions are interpreted as array literals. + Tuples that contain further nested tuples must have a uniform, cuboid structure, + i.e. represent a proper n-dimensional array, + to be parsed as multidimensional array literals; otherwise an error is raised. +::: + +## sympy.functions + +:::{list-table} +* - [Trigonometry](https://docs.sympy.org/latest/modules/functions/elementary.html#trigonometric) + - {any}`sp.sin <sympy.functions.elementary.trigonometric.sin>`, + {any}`sp.asin <sympy.functions.elementary.trigonometric.asin>`, + ... + - Only valid in floating-point contexts +* - Hyperbolic Functions + - {any}`sp.sinh <sympy.functions.elementary.hyperbolic.sinh>`, + {any}`sp.cosh <sympy.functions.elementary.hyperbolic.cosh>`, + - Only valid in floating-point contexts +* - Exponentials + - {any}`sp.exp <sympy.functions.elementary.exponential.exp>`, + {any}`sp.log <sympy.functions.elementary.exponential.log>`, + - Only valid in floating-point contexts +* - Absolute + - {any}`sp.Abs <sympy.functions.elementary.complexes.Abs>` + - +* - Rounding + - {any}`sp.floor <sympy.functions.elementary.integers.floor>`, + {any}`sp.ceiling <sympy.functions.elementary.integers.ceiling>` + - Result will have the same data type as the arguments, so in order to + get an integer, a type cast is additionally required (see {any}`tcast <pystencils.sympyextensions.tcast>`) +* - Min/Max + - {any}`sp.Min <sympy.functions.elementary.miscellaneous.Min>`, + {any}`sp.Max <sympy.functions.elementary.miscellaneous.Max>` + - +* - Piecewise Functions + - {any}`sp.Piecewise <sympy.functions.elementary.piecewise.Piecewise>` + - Cases of the piecewise function must be exhaustive; i.e. end with a default case. +::: + +## sympy.logic + +:::{list-table} +* - Boolean atoms + - {any}`sp.true <sympy.logic.boolalg.BooleanTrue>`, + {any}`sp.false <sympy.logic.boolalg.BooleanFalse>` + - +* - Basic boolean connectives + - {any}`sp.And <sympy.logic.boolalg.And>`, + {any}`sp.Or <sympy.logic.boolalg.Or>`, + {any}`sp.Not <sympy.logic.boolalg.Not>` + - +::: + +## sympy.tensor + +:::{list-table} +* - Indexed Objects + - {any}`sp.Indexed <sympy.tensor.indexed.Indexed>` + - Base of the indexed object must have a {any}`PsArrayType` of the correct dimensionality. + Currently, only symbols ({any}`sp.Symbol <sympy.core.symbol.Symbol>` or {any}`TypedSymbol`) + can be used as the base of an `Indexed`. +::: diff --git a/docs/source/api/symbolic/sympyextensions.rst b/docs/source/api/symbolic/sympyextensions.rst index 4190569d2ab838356a4f501e6811bb7e9202e666..99b0774c9e6a5a8d9770b2425ee4dbea66aa1916 100644 --- a/docs/source/api/symbolic/sympyextensions.rst +++ b/docs/source/api/symbolic/sympyextensions.rst @@ -95,3 +95,13 @@ Integer Operations integer_functions.round_to_multiple_towards_zero integer_functions.ceil_to_multiple integer_functions.div_ceil + +Bit Masks and Bit-Set Conditionals +---------------------------------- + +.. autosummary:: + :toctree: generated + :nosignatures: + :template: autosummary/sympy_class.rst + + bit_masks.bit_conditional diff --git a/docs/source/backend/objects.md b/docs/source/backend/objects.md new file mode 100644 index 0000000000000000000000000000000000000000..d2b2a2cd5278530930347fc78517c38c668c4f8f --- /dev/null +++ b/docs/source/backend/objects.md @@ -0,0 +1,181 @@ +# Constants, Memory Objects, and Functions + +## Memory Objects: Symbols and Buffers + +### The Memory Model + +In order to reason about memory accesses, mutability, invariance, and aliasing, the *pystencils* backend uses +a very simple memory model. There are three types of memory objects: + +- Symbols ({any}`PsSymbol`), which act as registers for data storage within the scope of a kernel +- Field buffers ({any}`PsBuffer`), which represent a contiguous block of memory the kernel has access to, and +- the *unmanaged heap*, which is a global catch-all memory object which all pointers not belonging to a field + array point into. + +All of these objects are disjoint, and cannot alias each other. +Each symbol exists in isolation, +field buffers do not overlap, +and raw pointers are assumed not to point into memory owned by a symbol or field array. +Instead, all raw pointers point into unmanaged heap memory, and are assumed to *always* alias one another: +Each change brought to unmanaged memory by one raw pointer is assumed to affect the memory pointed to by +another raw pointer. + +### Symbols + +In the pystencils IR, instances of {any}`PsSymbol` represent what is generally known as "virtual registers". +These are memory locations that are private to a function, cannot be aliased or pointed to, and will finally reside +either in physical registers or on the stack. +Each symbol has a name and a data type. The data type may initially be {any}`None`, in which case it should soon after be +determined by the {any}`Typifier`. + +Other than their front-end counterpart {any}`sympy.Symbol <sympy.core.symbol.Symbol>`, +{any}`PsSymbol` instances are mutable; +their properties can and often will change over time. +As a consequence, they are not comparable by value: +two {any}`PsSymbol` instances with the same name and data type will in general *not* be equal. +In fact, most of the time, it is an error to have two identical symbol instances active. + +#### Creating Symbols + +During kernel translation, symbols never exist in isolation, but should always be managed by a {any}`KernelCreationContext`. +Symbols can be created and retrieved using {any}`add_symbol <KernelCreationContext.add_symbol>` +and {any}`find_symbol <KernelCreationContext.find_symbol>`. +A symbol can also be duplicated using {any}`duplicate_symbol <KernelCreationContext.duplicate_symbol>`, +which assigns a new name to the symbol's copy. +The {any}`KernelCreationContext` keeps track of all existing symbols during a kernel translation run +and makes sure that no name and data type conflicts may arise. + +Never call the constructor of {any}`PsSymbol` directly unless you really know what you are doing. + +#### Symbol Properties + +Symbols can be annotated with arbitrary information using *symbol properties*. +Each symbol property type must be a subclass of {any}`PsSymbolProperty`. +It is strongly recommended to implement property types using frozen +[dataclasses](https://docs.python.org/3/library/dataclasses.html). +For example, this snippet defines a property type that models pointer alignment requirements: + +```{code-block} python + +@dataclass(frozen=True) +class AlignmentProperty(UniqueSymbolProperty) + """Require this pointer symbol to be aligned at a particular byte boundary.""" + + byte_boundary: int + +``` + +Inheriting from {any}`UniqueSymbolProperty` ensures that at most one property of this type can be attached to +a symbol at any time. +Properties can be added, queried, and removed using the {any}`PsSymbol` properties API listed below. + +Many symbol properties are more relevant to consumers of generated kernels than to the code generator itself. +The above alignment property, for instance, may be added to a pointer symbol by a vectorization pass +to document its assumption that the pointer be properly aligned, in order to emit aligned load and store instructions. +It then becomes the responsibility of the runtime system embedding the kernel to check this prequesite before calling the kernel. +To make sure this information becomes visible, any properties attached to symbols exposed as kernel parameters will also +be added to their respective {any}`Parameter` instance. + +### Buffers + +Buffers, as represented by the {any}`PsBuffer` class, represent contiguous, n-dimensional, linearized cuboid blocks of memory. +Each buffer has a fixed name and element data type, +and will be represented in the IR via three sets of symbols: + +- The *base pointer* is a symbol of pointer type which points into the buffer's underlying memory area. + Each buffer has at least one, its primary base pointer, whose pointed-to type must be the same as the + buffer's element type. There may be additional base pointers pointing into subsections of that memory. + These additional base pointers may also have deviating data types, as is for instance required for + type erasure in certain cases. + To communicate its role to the code generation system, + each base pointer needs to be marked as such using the {any}`BufferBasePtr` property, + . +- The buffer *shape* defines the size of the buffer in each dimension. Each shape entry is either a `symbol <PsSymbol>` + or a {any}`constant <PsConstant>`. +- The buffer *strides* define the step size to go from one entry to the next in each dimension. + Like the shape, each stride entry is also either a symbol or a constant. + +The shape and stride symbols must all have the same data type, which will be stored as the buffer's index data type. + +#### Creating and Managing Buffers + +Similarily to symbols, buffers are typically managed by the {any}`KernelCreationContext`, which associates each buffer +to a front-end {any}`Field`. Buffers for fields can be obtained using {any}`get_buffer <KernelCreationContext.get_buffer>`. +The context makes sure to avoid name conflicts between buffers. + +## Constants + +In the pystencils IR, numerical constants are represented by the {any}`PsConstant` class. +It interacts with the type system (in particular, with {any}`PsNumericType` and its subclasses) +to facilitate bit-exact storage, arithmetic, and type conversion of constants. + +Each constant has a *value* and a *data type*. As long as the data type is `None`, +the constant is untyped and its value may be any Python object. +To add a data type, an instance of {any}`PsNumericType` must either be set in the constructor, +or be applied by converting an existing constant using {any}`interpret_as <PsConstant.interpret_as>`. +Once a data type is set, the set of legal values is constrained by that type. + +To facilitate the correctness of the internal representation, `PsConstant` calls {any}`PsNumericType.create_constant`. +This method must be overridden by subclasses of `PsNumericType`; it either returns an object +that represents the numerical constant according to the rules of the data type, +or raises an exception if that is not possible. +The fixed-width integers, the IEEE-754 floating point types, and the corresponding vector variants that are +implemented in pystencils use NumPy for this purpose. + +The same protocol is used for type conversion of constants, using {any}`PsConstant.reinterpret_as`. + +## Functions + +The pystencils IR models several kinds of functions: + + - Pure mathematical functions (such as `sqrt`, `sin`, `exp`, ...), through {any}`PsMathFunction`; + - Special numerical constants (such as $\pi$, $e$, $\pm \infty$) as 0-ary functions through {any}`PsConstantFunction` + - External functions with a fixed C-like signature using {any}`CFunction`. + +All of these inherit from the common base class {any}`PsFunction`. +Functions of the former two categories are purely internal to pystencils +and must be lowered to a platform-specific implementation at some point +during the kernel creation process. +The latter category, `CFunction`, represents these concrete functions. +It is used to inject platform-specific runtime APIs, vector intrinsics, and user-defined +external functions into the IR. + +### Side Effects + +`PsMathFunction` and `PsConstantFunction` represent *pure* functions. +Their occurences may be moved, optimized, or eliminated by the code generator. +For `CFunction`, on the other hand, side effects are conservatively assumed, +such that these cannot be freely manipulated. + +## Literals + +In the pystencils IR, a *literal* is an expression string, with an associated data type, +that is taken literally and printed out verbatim by the code generator. +They are represented by the {any}`PsLiteral` class, +and are used to represent compiler-builtins +(like the CUDA variables `threadIdx`, `blockIdx`, ...), +preprocessor macros (like `INFINITY`), +and other pieces of code that could not otherwise be modelled. +Literals are assumed to be *constant* with respect to the kernel, +and their evaluation is assumed to be free of side effects. + +## API Documentation + +```{eval-rst} + +.. automodule:: pystencils.codegen.properties + :members: + +.. automodule:: pystencils.backend.memory + :members: + +.. automodule:: pystencils.backend.constants + :members: + +.. automodule:: pystencils.backend.literals + :members: + +.. automodule:: pystencils.backend.functions + :members: + +``` \ No newline at end of file diff --git a/docs/source/backend/objects.rst b/docs/source/backend/objects.rst deleted file mode 100644 index 942e6070f2c997c7bf3e59d67e7c44bd53806e12..0000000000000000000000000000000000000000 --- a/docs/source/backend/objects.rst +++ /dev/null @@ -1,125 +0,0 @@ -**************************************** -Constants, Memory Objects, and Functions -**************************************** - -Memory Objects: Symbols and Buffers -=================================== - -The Memory Model ----------------- - -In order to reason about memory accesses, mutability, invariance, and aliasing, the *pystencils* backend uses -a very simple memory model. There are three types of memory objects: - -- Symbols (`PsSymbol`), which act as registers for data storage within the scope of a kernel -- Field buffers (`PsBuffer`), which represent a contiguous block of memory the kernel has access to, and -- the *unmanaged heap*, which is a global catch-all memory object which all pointers not belonging to a field - array point into. - -All of these objects are disjoint, and cannot alias each other. -Each symbol exists in isolation, -field buffers do not overlap, -and raw pointers are assumed not to point into memory owned by a symbol or field array. -Instead, all raw pointers point into unmanaged heap memory, and are assumed to *always* alias one another: -Each change brought to unmanaged memory by one raw pointer is assumed to affect the memory pointed to by -another raw pointer. - -Symbols -------- - -In the pystencils IR, instances of `PsSymbol` represent what is generally known as "virtual registers". -These are memory locations that are private to a function, cannot be aliased or pointed to, and will finally reside -either in physical registers or on the stack. -Each symbol has a name and a data type. The data type may initially be `None`, in which case it should soon after be -determined by the `Typifier`. - -Other than their front-end counterpart `sympy.Symbol <sympy.core.symbol.Symbol>`, `PsSymbol` instances are mutable; -their properties can and often will change over time. -As a consequence, they are not comparable by value: -two `PsSymbol` instances with the same name and data type will in general *not* be equal. -In fact, most of the time, it is an error to have two identical symbol instances active. - -Creating Symbols -^^^^^^^^^^^^^^^^ - -During kernel translation, symbols never exist in isolation, but should always be managed by a `KernelCreationContext`. -Symbols can be created and retrieved using `add_symbol <KernelCreationContext.add_symbol>` and `find_symbol <KernelCreationContext.find_symbol>`. -A symbol can also be duplicated using `duplicate_symbol <KernelCreationContext.duplicate_symbol>`, which assigns a new name to the symbol's copy. -The `KernelCreationContext` keeps track of all existing symbols during a kernel translation run -and makes sure that no name and data type conflicts may arise. - -Never call the constructor of `PsSymbol` directly unless you really know what you are doing. - -Symbol Properties -^^^^^^^^^^^^^^^^^ - -Symbols can be annotated with arbitrary information using *symbol properties*. -Each symbol property type must be a subclass of `PsSymbolProperty`. -It is strongly recommended to implement property types using frozen -`dataclasses <https://docs.python.org/3/library/dataclasses.html>`_. -For example, this snippet defines a property type that models pointer alignment requirements: - -.. code-block:: python - - @dataclass(frozen=True) - class AlignmentProperty(UniqueSymbolProperty) - """Require this pointer symbol to be aligned at a particular byte boundary.""" - - byte_boundary: int - -Inheriting from `UniqueSymbolProperty` ensures that at most one property of this type can be attached to -a symbol at any time. -Properties can be added, queried, and removed using the `PsSymbol` properties API listed below. - -Many symbol properties are more relevant to consumers of generated kernels than to the code generator itself. -The above alignment property, for instance, may be added to a pointer symbol by a vectorization pass -to document its assumption that the pointer be properly aligned, in order to emit aligned load and store instructions. -It then becomes the responsibility of the runtime system embedding the kernel to check this prequesite before calling the kernel. -To make sure this information becomes visible, any properties attached to symbols exposed as kernel parameters will also -be added to their respective `Parameter` instance. - -Buffers -------- - -Buffers, as represented by the `PsBuffer` class, represent contiguous, n-dimensional, linearized cuboid blocks of memory. -Each buffer has a fixed name and element data type, -and will be represented in the IR via three sets of symbols: - -- The *base pointer* is a symbol of pointer type which points into the buffer's underlying memory area. - Each buffer has at least one, its primary base pointer, whose pointed-to type must be the same as the - buffer's element type. There may be additional base pointers pointing into subsections of that memory. - These additional base pointers may also have deviating data types, as is for instance required for - type erasure in certain cases. - To communicate its role to the code generation system, - each base pointer needs to be marked as such using the `BufferBasePtr` property, - . -- The buffer *shape* defines the size of the buffer in each dimension. Each shape entry is either a `symbol <PsSymbol>` - or a `constant <PsConstant>`. -- The buffer *strides* define the step size to go from one entry to the next in each dimension. - Like the shape, each stride entry is also either a symbol or a constant. - -The shape and stride symbols must all have the same data type, which will be stored as the buffer's index data type. - -Creating and Managing Buffers -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Similarily to symbols, buffers are typically managed by the `KernelCreationContext`, which associates each buffer -to a front-end `Field`. Buffers for fields can be obtained using `get_buffer <KernelCreationContext.get_buffer>`. -The context makes sure to avoid name conflicts between buffers. - -API Documentation -================= - -.. automodule:: pystencils.codegen.properties - :members: - -.. automodule:: pystencils.backend.memory - :members: - -.. automodule:: pystencils.backend.constants - :members: - -.. autoclass:: pystencils.backend.literals.PsLiteral - :members: - -.. automodule:: pystencils.backend.functions diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index 7cb39bced8b01fa9a45f264f7e586f223f3edf97..6dc69793887c278d706367d4cb50e066909582a8 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -11,7 +11,7 @@ from ..memory import PsSymbol, PsBuffer, BufferBasePtr from ..constants import PsConstant from ..literals import PsLiteral from ..functions import PsFunction -from ...types import PsType +from ...types import PsType, deconstify from .util import failing_cast from ..exceptions import PsInternalCompilerError @@ -27,6 +27,10 @@ class PsExpression(PsAstNode, ABC): only constant expressions, symbol expressions, and array accesses immediately inherit their type from their constant, symbol, or array, respectively. + The type assigned to an expression must never be ``const``. + Constness and mutability are properties of memory objects, not of expressions; + they are verified by inspecting the memory objects referred to by an expression. + The canonical way to add types to newly constructed expressions is through the `Typifier`. It should be run at least once on any expression constructed by the backend. @@ -144,7 +148,7 @@ class PsSymbolExpr(PsLeafMixIn, PsLvalue, PsExpression): __match_args__ = ("symbol",) def __init__(self, symbol: PsSymbol): - super().__init__(symbol.dtype) + super().__init__(deconstify(symbol.dtype) if symbol.dtype is not None else None) self._symbol = symbol @property @@ -172,7 +176,7 @@ class PsConstantExpr(PsLeafMixIn, PsExpression): __match_args__ = ("constant",) def __init__(self, constant: PsConstant): - super().__init__(constant.dtype) + super().__init__(deconstify(constant.dtype) if constant.dtype is not None else None) self._constant = constant @property @@ -200,7 +204,7 @@ class PsLiteralExpr(PsLeafMixIn, PsExpression): __match_args__ = ("literal",) def __init__(self, literal: PsLiteral): - super().__init__(literal.dtype) + super().__init__(deconstify(literal.dtype) if literal.dtype is not None else None) self._literal = literal @property @@ -238,7 +242,7 @@ class PsBufferAcc(PsLvalue, PsExpression): self._base_ptr = PsExpression.make(base_ptr) self._index = list(index) - self._dtype = bptr_prop.buffer.element_type + self._dtype = deconstify(bptr_prop.buffer.element_type) @property def base_pointer(self) -> PsSymbolExpr: @@ -582,14 +586,26 @@ class PsAddressOf(PsUnOp): class PsCast(PsUnOp): + """C-style type cast. + + Convert values to another type according to C casting rules. + The target type may be `None`, in which case it will be inferred by the `Typifier` + according to the surrounding type context. + + Args: + target_type: Target type of the cast, + or `None` if the target type should be inferred from the surrounding context + operand: Expression whose value will be cast + """ + __match_args__ = ("target_type", "operand") - def __init__(self, target_type: PsType, operand: PsExpression): + def __init__(self, target_type: PsType | None, operand: PsExpression): super().__init__(operand) - self._target_type = target_type + self._target_type = deconstify(target_type) if target_type is not None else None @property - def target_type(self) -> PsType: + def target_type(self) -> PsType | None: return self._target_type @target_type.setter diff --git a/src/pystencils/backend/emission/base_printer.py b/src/pystencils/backend/emission/base_printer.py index c4ac0640c44a222f3fbecef9086fdbd36e68f8bc..ff53acf102c574a2498485c91b08afc48626b98c 100644 --- a/src/pystencils/backend/emission/base_printer.py +++ b/src/pystencils/backend/emission/base_printer.py @@ -397,7 +397,7 @@ class BasePrinter(ABC): pass @abstractmethod - def _type_str(self, dtype: PsType) -> str: + def _type_str(self, dtype: PsType | None) -> str: """Return a valid string representation of the given type""" def _char_and_op(self, node: PsBinOp) -> tuple[str, Ops]: diff --git a/src/pystencils/backend/emission/c_printer.py b/src/pystencils/backend/emission/c_printer.py index 40cd692836117d48a0ab6f955681085c90fa0b86..317af587cce3b3f522a7cf496a9d87fdb663557e 100644 --- a/src/pystencils/backend/emission/c_printer.py +++ b/src/pystencils/backend/emission/c_printer.py @@ -60,7 +60,10 @@ class CAstPrinter(BasePrinter): return dtype.create_literal(constant.value) - def _type_str(self, dtype: PsType): + def _type_str(self, dtype: PsType | None): + if dtype is None: + raise EmissionError("Cannot emit untyped object as C code.") + try: return dtype.c_string() except PsTypeError: diff --git a/src/pystencils/backend/functions.py b/src/pystencils/backend/functions.py index 4e38de5e9f11ca1d971ae6659f04e6df7b47f64a..3ff61e039c7d858f0775f2140b69531b67fb48cc 100644 --- a/src/pystencils/backend/functions.py +++ b/src/pystencils/backend/functions.py @@ -1,37 +1,10 @@ -""" -Functions supported by pystencils. - -Every supported function might require handling logic in the following modules: - -- In `freeze.FreezeExpressions`, a case in `map_Function` or a separate mapper method to catch its frontend variant -- In each backend platform, a case in `Platform.select_function` to map the function onto a concrete - C/C++ implementation -- If very special typing rules apply, a case in `typification.Typifier`. - -In most cases, typification of function applications will require no special handling. - -.. autoclass:: PsFunction - :members: - -.. autoclass:: MathFunctions - :members: - :undoc-members: - -.. autoclass:: PsMathFunction - :members: - -.. autoclass:: CFunction - :members: - -""" - from __future__ import annotations from typing import Any, Sequence, TYPE_CHECKING from abc import ABC from enum import Enum from ..sympyextensions import ReductionOp -from ..types import PsType +from ..types import PsType, PsNumericType, PsTypeError from .exceptions import PsInternalCompilerError if TYPE_CHECKING: @@ -96,31 +69,17 @@ class MathFunctions(Enum): self.num_args = num_args -class NumericLimitsFunctions(Enum): - """Numerical limits functions supported by the backend. - - Each platform has to materialize these functions to a concrete implementation. - """ - - Min = ("min", 0) - Max = ("max", 0) - - def __init__(self, func_name, num_args): - self.function_name = func_name - self.num_args = num_args - - class PsMathFunction(PsFunction): - """Homogenously typed mathematical functions.""" + """Homogeneously typed mathematical functions.""" __match_args__ = ("func",) - def __init__(self, func: MathFunctions | NumericLimitsFunctions) -> None: + def __init__(self, func: MathFunctions) -> None: super().__init__(func.function_name, func.num_args) self._func = func @property - def func(self) -> MathFunctions | NumericLimitsFunctions: + def func(self) -> MathFunctions: return self._func def __str__(self) -> str: @@ -177,6 +136,91 @@ class PsReductionFunction(PsFunction): return hash(self._func) +class ConstantFunctions(Enum): + """Numerical constant functions. + + Each platform has to materialize these functions to a concrete implementation. + """ + + Pi = "pi" + E = "e" + PosInfinity = "pos_infinity" + NegInfinity = "neg_infinity" + + def __init__(self, func_name): + self.function_name = func_name + + +class PsConstantFunction(PsFunction): + """Data-type-specific numerical constants. + + Represents numerical constants which need to be exactly represented, + e.g. transcendental numbers and non-finite constants. + + Functions of this class are treated the same as `PsConstant` instances + by most transforms. + In particular, they are subject to the same contextual typing rules, + and will be broadcast by the vectorizer. + """ + + __match_args__ = ("func,") + + def __init__( + self, func: ConstantFunctions, dtype: PsNumericType | None = None + ) -> None: + super().__init__(func.function_name, 0) + self._func = func + self._set_dtype(dtype) + + @property + def func(self) -> ConstantFunctions: + return self._func + + @property + def dtype(self) -> PsNumericType | None: + """This constant function's data type, or ``None`` if it is untyped.""" + return self._dtype + + @dtype.setter + def dtype(self, t: PsNumericType): + self._set_dtype(t) + + def get_dtype(self) -> PsNumericType: + """Retrieve this constant function's data type, throwing an exception if it is untyped.""" + if self._dtype is None: + raise PsInternalCompilerError( + "Data type of constant function was not set." + ) + return self._dtype + + def __str__(self) -> str: + return f"{self._func.function_name}" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, PsConstantFunction): + return False + + return (self._func, self._dtype) == (other._func, other._dtype) + + def __hash__(self) -> int: + return hash((self._func, self._dtype)) + + def _set_dtype(self, dtype: PsNumericType | None): + if dtype is not None: + match self._func: + case ( + ConstantFunctions.Pi + | ConstantFunctions.E + | ConstantFunctions.PosInfinity + | ConstantFunctions.NegInfinity + ) if not dtype.is_float(): + raise PsTypeError( + f"Invalid type for {self.func.function_name}: {dtype}" + ) + + self._dtype = dtype + + class CFunction(PsFunction): """A concrete C function. diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index 48e2f4a3a7f349c407472acfef60ec94ecea8f4c..319d6061cfc9ff1baf6068b62b7cc0699a75ff82 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -7,7 +7,7 @@ from collections import namedtuple, defaultdict import re from ..ast.expressions import PsExpression, PsConstantExpr, PsCall -from ..functions import NumericLimitsFunctions, PsMathFunction +from ..functions import PsConstantFunction, ConstantFunctions from ...defaults import DEFAULTS from ...field import Field, FieldType from ...sympyextensions import ReductionOp @@ -235,9 +235,9 @@ class KernelCreationContext: case ReductionOp.Mul: init_val = PsConstantExpr(PsConstant(1)) case ReductionOp.Min: - init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Max), []) + init_val = PsCall(PsConstantFunction(ConstantFunctions.PosInfinity), []) case ReductionOp.Max: - init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Min), []) + init_val = PsCall(PsConstantFunction(ConstantFunctions.NegInfinity), []) case _: raise PsInternalCompilerError( f"Unsupported kind of reduction assignment: {reduction_op}." diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 5987165678327eb0e49ca10f4ee66bb7f7fffd39..0fb1bd9865bca1718336b9cf3c6cde4d91397af2 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -16,6 +16,7 @@ from ..reduction_op_mapping import reduction_op_to_expr from ...sympyextensions.typed_sympy import TypedSymbol, TypeCast, DynamicType from ...sympyextensions.pointers import AddressOf, mem_acc from ...sympyextensions.reduction import ReductionAssignment, ReductionOp +from ...sympyextensions.bit_masks import bit_conditional from ...field import Field, FieldType from .context import KernelCreationContext @@ -63,7 +64,12 @@ from ..ast.vector import PsVecMemAcc from ..constants import PsConstant from ...types import PsNumericType, PsStructType, PsType from ..exceptions import PsInputError -from ..functions import PsMathFunction, MathFunctions +from ..functions import ( + PsMathFunction, + MathFunctions, + PsConstantFunction, + ConstantFunctions, +) from ..exceptions import FreezeError @@ -300,6 +306,24 @@ class FreezeExpressions: denom = PsConstantExpr(PsConstant(expr.denominator)) return num / denom + def map_NumberSymbol(self, expr: sp.Number): + func: ConstantFunctions + match expr: + case sp.core.numbers.Pi(): + func = ConstantFunctions.Pi + case sp.core.numbers.Exp1(): + func = ConstantFunctions.E + case _: + raise FreezeError(f"Cannot translate number symbol {expr}") + + return PsCall(PsConstantFunction(func), []) + + def map_Infinity(self, _: sp.core.numbers.Infinity): + return PsCall(PsConstantFunction(ConstantFunctions.PosInfinity), []) + + def map_NegativeInfinity(self, _: sp.core.numbers.NegativeInfinity): + return PsCall(PsConstantFunction(ConstantFunctions.NegInfinity), []) + def map_TypedSymbol(self, expr: TypedSymbol): dtype = self._ctx.resolve_dynamic_type(expr.dtype) symb = self._ctx.get_symbol(expr.name, dtype) @@ -567,3 +591,19 @@ class FreezeExpressions: def map_Not(self, neg: sympy.logic.Not) -> PsNot: arg = self.visit_expr(neg.args[0]) return PsNot(arg) + + def map_bit_conditional(self, conditional: bit_conditional): + args = [self.visit_expr(arg) for arg in conditional.args] + bitpos, mask, then_expr = args[:3] + + one = PsExpression.make(PsConstant(1)) + extract_bit = PsBitwiseAnd(PsRightShift(mask, bitpos), one) + masked_then_expr = PsCast(None, extract_bit) * then_expr + + if len(args) == 4: + else_expr = args[3] + invert_bit = PsBitwiseXor(extract_bit.clone(), one.clone()) + masked_else_expr = PsCast(None, invert_bit) * else_expr + return masked_then_expr + masked_else_expr + else: + return masked_then_expr diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index 1c34fac6fdbb4b78830c1c4b88f0485740475a2a..c966262eaa5026682cfc98cd9a4d6955403668c8 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -50,7 +50,7 @@ from ..ast.expressions import ( PsNot, ) from ..ast.vector import PsVecBroadcast, PsVecMemAcc, PsVecHorizontal -from ..functions import PsMathFunction, CFunction +from ..functions import PsMathFunction, CFunction, PsConstantFunction from ..ast.util import determine_memory_object from ..exceptions import TypificationError @@ -66,7 +66,10 @@ class TypeContext: """Typing context, with support for type inference and checking. Instances of this class are used to propagate and check data types across expression subtrees - of the AST. Each type context has a target type `target_type`, which shall be applied to all expressions it covers + of the AST. Each type context has a target type `target_type`, which shall be applied to all expressions it covers. + + Just like the types of expressions, the context's target type must never be ``const``. + This is ensured by this class, which removes const-qualification from the target type when it is set. """ def __init__( @@ -140,6 +143,7 @@ class TypeContext: def _propagate_target_type(self): assert self._target_type is not None + assert not self._target_type.const for hook in self._hooks: hook(self._target_type) @@ -151,9 +155,10 @@ class TypeContext: def _apply_target_type(self, expr: PsExpression): assert self._target_type is not None + assert not self._target_type.const if expr.dtype is not None: - if not self._compatible(expr.dtype): + if expr.dtype != self._target_type: raise TypificationError( f"Type mismatch at expression {expr}: Expression type did not match the context's target type\n" f" Expression type: {expr.dtype}\n" @@ -164,7 +169,8 @@ class TypeContext: case PsConstantExpr(c): if not isinstance(self._target_type, PsNumericType): raise TypificationError( - f"Can't typify constant with non-numeric type {self._target_type}" + f"Can't typify constant with non-numeric type {self._target_type}\n" + f" at: {expr}" ) if c.dtype is None: expr.constant = c.interpret_as(self._target_type) @@ -186,7 +192,7 @@ class TypeContext: case PsSymbolExpr(symb): if symb.dtype is None: # Symbols are not forced to constness - symb.dtype = deconstify(self._target_type) + symb.dtype = self._target_type elif not self._compatible(symb.dtype): raise TypificationError( f"Type mismatch at symbol {symb}: Symbol type did not match the context's target type\n" @@ -194,6 +200,21 @@ class TypeContext: f" Target type: {self._target_type}" ) + case PsCall(func) if isinstance(func, PsConstantFunction): + if not isinstance(self._target_type, PsNumericType): + raise TypificationError( + f"Can't typify constant function with non-numeric type {self._target_type}\n" + f" at: {expr}" + ) + if func.dtype is None: + func.dtype = constify(self._target_type) + elif not self._compatible(func.dtype): + raise TypificationError( + f"Type mismatch at constant function {func}: Type did not match the context's target type\n" + f" Function type: {func.dtype}\n" + f" Target type: {self._target_type}" + ) + case PsNumericOpTrait() if ( not isinstance(self._target_type, PsNumericType) or self._target_type.is_bool() @@ -224,6 +245,9 @@ class TypeContext: f" Expression: {expr}" f" Type Context: {self._target_type}" ) + + case PsCast(cast_target, _) if cast_target is None: + expr.target_type = self._target_type # endif expr.dtype = self._target_type @@ -628,6 +652,12 @@ class Typifier: self.visit_expr(arg, tc) tc.infer_dtype(expr) + case PsConstantFunction(): + if function.dtype is not None: + tc.apply_dtype(function.dtype, expr) + else: + tc.infer_dtype(expr) + case CFunction(_, arg_types, ret_type): tc.apply_dtype(ret_type, expr) @@ -646,7 +676,7 @@ class Typifier: f" Array: {expr}" ) - case PsCast(dtype, arg): + case PsCast(cast_target_type, arg): arg_tc = TypeContext() self.visit_expr(arg, arg_tc) @@ -655,7 +685,10 @@ class Typifier: f"Unable to determine type of argument to Cast: {arg}" ) - tc.apply_dtype(dtype, expr) + if cast_target_type is None: + tc.infer_dtype(expr) + else: + tc.apply_dtype(cast_target_type, expr) case PsVecBroadcast(lanes, arg): op_tc = TypeContext() diff --git a/src/pystencils/backend/literals.py b/src/pystencils/backend/literals.py index 976e6b2030d2350a0c7105f8a3a17cfcccf393fd..82ee81373d02052b918389c077309fd9d4e8b8c8 100644 --- a/src/pystencils/backend/literals.py +++ b/src/pystencils/backend/literals.py @@ -11,6 +11,9 @@ class PsLiteral: Each literal has to be annotated with a type, and is considered constant within the scope of a kernel. Instances of `PsLiteral` are immutable. + + The code generator assumes literals to be *constant* and *pure* with respect to the kernel: + their evaluation at kernel runtime must not include any side effects. """ __match_args__ = ("text", "dtype") diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py index 05e95011dabb9e475d608e292b971dbfa9625c7b..6e6488ee1b461aac3694573835488d88760af361 100644 --- a/src/pystencils/backend/platforms/cuda.py +++ b/src/pystencils/backend/platforms/cuda.py @@ -20,12 +20,11 @@ from ..ast.structural import ( PsStructuralNode, ) from ..constants import PsConstant -from ..exceptions import MaterializationError -from ..functions import NumericLimitsFunctions, CFunction +from ..functions import CFunction from ..literals import PsLiteral from ..reduction_op_mapping import reduction_op_to_expr from ...sympyextensions import ReductionOp -from ...types import PsType, PsIeeeFloatType, PsCustomType, PsPointerType, PsScalarType +from ...types import PsIeeeFloatType, PsCustomType, PsPointerType, PsScalarType from ...types.quick import SInt, UInt @@ -34,9 +33,7 @@ class CudaPlatform(GenericGpu): @property def required_headers(self) -> set[str]: - return super().required_headers | { - '"npp.h"', - } + return super().required_headers | {'"pystencils_runtime/cuda.cuh"'} def resolve_reduction( self, @@ -120,20 +117,3 @@ class CudaPlatform(GenericGpu): return shuffles, PsConditional( cond, PsBlock([PsStatement(PsCall(func, func_args))]) ) - - def resolve_numeric_limits( - self, func: NumericLimitsFunctions, dtype: PsType - ) -> PsExpression: - assert isinstance(dtype, PsIeeeFloatType) - - match func: - case NumericLimitsFunctions.Min: - define = f"NPP_MINABS_{dtype.width}F" - case NumericLimitsFunctions.Max: - define = f"NPP_MAXABS_{dtype.width}F" - case _: - raise MaterializationError( - f"Cannot materialize call to function {func}" - ) - - return PsLiteralExpr(PsLiteral(define, dtype)) diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py index 4f8b562fa6fc3a69c95f5a0a7cffec3a78686bec..1dc2914b9f587a11180e0df426eb2bb38a7377b8 100644 --- a/src/pystencils/backend/platforms/generic_cpu.py +++ b/src/pystencils/backend/platforms/generic_cpu.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from typing import Sequence +import numpy as np from ..ast.expressions import PsCall, PsMemAcc, PsConstantExpr @@ -7,12 +8,12 @@ from ..ast import PsAstNode from ..functions import ( CFunction, MathFunctions, - NumericLimitsFunctions, ReductionFunctions, PsMathFunction, PsReductionFunction, + PsConstantFunction, + ConstantFunctions, ) -from ..literals import PsLiteral from ..reduction_op_mapping import reduction_op_to_expr from ...sympyextensions import ReductionOp from ...types import PsIntegerType, PsIeeeFloatType, PsScalarType, PsPointerType @@ -34,12 +35,9 @@ from ..ast.expressions import ( PsExpression, PsBufferAcc, PsLookup, - PsGe, - PsLe, - PsTernary, - PsLiteralExpr, ) from ..ast.vector import PsVecMemAcc +from ..kernelcreation import Typifier from ...types import PsVectorType, PsCustomType @@ -71,7 +69,7 @@ class GenericCpu(Platform): self, call: PsCall ) -> PsExpression | tuple[tuple[PsStructuralNode, ...], PsAstNode]: call_func = call.function - assert isinstance(call_func, PsReductionFunction | PsMathFunction) + assert isinstance(call_func, (PsReductionFunction | PsMathFunction | PsConstantFunction)) func = call_func.func @@ -105,62 +103,74 @@ class GenericCpu(Platform): return potential_call dtype = call.get_dtype() - arg_types = (dtype,) * func.num_args + arg_types = (dtype,) * call.function.arg_count + + expr: PsExpression | None = None + + if isinstance(dtype, PsIeeeFloatType): + if dtype.width in (32, 64): + match func: + case ( + MathFunctions.Exp + | MathFunctions.Log + | MathFunctions.Sin + | MathFunctions.Cos + | MathFunctions.Tan + | MathFunctions.Sinh + | MathFunctions.Cosh + | MathFunctions.ASin + | MathFunctions.ACos + | MathFunctions.ATan + | MathFunctions.ATan2 + | MathFunctions.Pow + | MathFunctions.Sqrt + | MathFunctions.Floor + | MathFunctions.Ceil + ): + call.function = CFunction(func.function_name, arg_types, dtype) + expr = call + case MathFunctions.Abs | MathFunctions.Min | MathFunctions.Max: + call.function = CFunction( + "f" + func.function_name, arg_types, dtype + ) + expr = call - if isinstance(dtype, PsScalarType) and func in ( - NumericLimitsFunctions.Min, - NumericLimitsFunctions.Max, - ): - return PsLiteralExpr( - PsLiteral( - f"std::numeric_limits<{dtype.c_string()}>::{func.function_name}()", - dtype, - ) - ) - - if isinstance(dtype, PsIeeeFloatType) and dtype.width in (32, 64): - cfunc: CFunction - match func: - case ( - MathFunctions.Exp - | MathFunctions.Log - | MathFunctions.Sin - | MathFunctions.Cos - | MathFunctions.Tan - | MathFunctions.Sinh - | MathFunctions.Cosh - | MathFunctions.ASin - | MathFunctions.ACos - | MathFunctions.ATan - | MathFunctions.ATan2 - | MathFunctions.Pow - | MathFunctions.Sqrt - | MathFunctions.Floor - | MathFunctions.Ceil - ): - cfunc = CFunction(func.function_name, arg_types, dtype) - case MathFunctions.Abs | MathFunctions.Min | MathFunctions.Max: - cfunc = CFunction("f" + func.function_name, arg_types, dtype) - - call.function = cfunc - return call - - if isinstance(dtype, PsIntegerType): match func: - case MathFunctions.Abs: - zero = PsExpression.make(PsConstant(0, dtype)) - arg = call.args[0] - return PsTernary(PsGe(arg, zero), arg, -arg) - case MathFunctions.Min: - arg1, arg2 = call.args - return PsTernary(PsLe(arg1, arg2), arg1, arg2) - case MathFunctions.Max: - arg1, arg2 = call.args - return PsTernary(PsGe(arg1, arg2), arg1, arg2) - - raise MaterializationError( - f"No implementation available for function {func} on data type {dtype}" - ) + case ConstantFunctions.Pi: + assert dtype.numpy_dtype is not None + expr = PsExpression.make( + PsConstant(dtype.numpy_dtype.type(np.pi), dtype) + ) + + case ConstantFunctions.E: + assert dtype.numpy_dtype is not None + expr = PsExpression.make( + PsConstant(dtype.numpy_dtype.type(np.e), dtype) + ) + + case ConstantFunctions.PosInfinity | ConstantFunctions.NegInfinity: + call.function = CFunction( + f"std::numeric_limits< {dtype.c_string()} >::infinity", + [], + dtype, + ) + if func == ConstantFunctions.NegInfinity: + expr = -call + else: + expr = call + + elif isinstance(dtype, PsIntegerType): + expr = self._select_integer_function(call) + + if expr is not None: + if expr.dtype is None: + typify = Typifier(self._ctx) + expr = typify(expr) + return expr + else: + raise MaterializationError( + f"No implementation available for function {func} on data type {dtype}" + ) # Internals diff --git a/src/pystencils/backend/platforms/generic_gpu.py b/src/pystencils/backend/platforms/generic_gpu.py index 8a4dd11a29c3243bc1cbad2a857646a1a7be038e..2cfe11d51c80888089039905d31b4dc63cf487a1 100644 --- a/src/pystencils/backend/platforms/generic_gpu.py +++ b/src/pystencils/backend/platforms/generic_gpu.py @@ -3,15 +3,14 @@ from __future__ import annotations import operator from abc import ABC, abstractmethod from functools import reduce +import numpy as np from ..ast import PsAstNode -from ..constants import PsConstant from ...sympyextensions.reduction import ReductionOp from ...types import ( constify, deconstify, - PsScalarType, - PsType, + PsIntegerType, ) from ...types.quick import SInt from ..exceptions import MaterializationError @@ -26,6 +25,7 @@ from ..kernelcreation import ( AstFactory, ) +from ..constants import PsConstant from ..kernelcreation.context import KernelCreationContext from ..ast.structural import ( PsBlock, @@ -48,13 +48,15 @@ from ..ast.expressions import ( from ..ast.expressions import PsLt, PsAnd from ...types import PsSignedIntegerType, PsIeeeFloatType from ..literals import PsLiteral + from ..functions import ( MathFunctions, CFunction, ReductionFunctions, - NumericLimitsFunctions, PsReductionFunction, PsMathFunction, + PsConstantFunction, + ConstantFunctions, ) int32 = PsSignedIntegerType(width=32, const=False) @@ -205,12 +207,6 @@ class GenericGpu(Platform): '"gpu_atomics.h"', } - @abstractmethod - def resolve_numeric_limits( - self, func: NumericLimitsFunctions, dtype: PsType - ) -> PsExpression: - pass - @abstractmethod def resolve_reduction( self, @@ -302,7 +298,7 @@ class GenericGpu(Platform): self, call: PsCall ) -> PsExpression | tuple[tuple[PsStructuralNode, ...], PsAstNode]: call_func = call.function - assert isinstance(call_func, PsReductionFunction | PsMathFunction) + assert isinstance(call_func, (PsReductionFunction | PsMathFunction | PsConstantFunction)) func = call_func.func @@ -316,10 +312,8 @@ class GenericGpu(Platform): return self.resolve_reduction(ptr_expr, symbol_expr, op) dtype = call.get_dtype() - arg_types = (dtype,) * func.num_args - - if isinstance(dtype, PsScalarType) and isinstance(func, NumericLimitsFunctions): - return self.resolve_numeric_limits(func, dtype) + arg_types = (dtype,) * call.function.arg_count + expr: PsExpression | None = None if isinstance(dtype, PsIeeeFloatType) and func in MathFunctions: match func: @@ -335,7 +329,8 @@ class GenericGpu(Platform): prefix = "h" if dtype.width == 16 else "" suffix = "f" if dtype.width == 32 else "" name = f"{prefix}{func.function_name}{suffix}" - cfunc = CFunction(name, arg_types, dtype) + call.function = CFunction(name, arg_types, dtype) + expr = call case ( MathFunctions.Pow @@ -350,25 +345,53 @@ class GenericGpu(Platform): # These are unavailable for fp16 suffix = "f" if dtype.width == 32 else "" name = f"{func.function_name}{suffix}" - cfunc = CFunction(name, arg_types, dtype) + call.function = CFunction(name, arg_types, dtype) + expr = call case ( MathFunctions.Min | MathFunctions.Max | MathFunctions.Abs ) if dtype.width in (32, 64): suffix = "f" if dtype.width == 32 else "" name = f"f{func.function_name}{suffix}" - cfunc = CFunction(name, arg_types, dtype) + call.function = CFunction(name, arg_types, dtype) + expr = call case MathFunctions.Abs if dtype.width == 16: - cfunc = CFunction(" __habs", arg_types, dtype) + call.function = CFunction(" __habs", arg_types, dtype) + expr = call + + case ConstantFunctions.Pi: + assert dtype.numpy_dtype is not None + expr = PsExpression.make( + PsConstant(dtype.numpy_dtype.type(np.pi), dtype) + ) + + case ConstantFunctions.E: + assert dtype.numpy_dtype is not None + expr = PsExpression.make( + PsConstant(dtype.numpy_dtype.type(np.e), dtype) + ) + + case ConstantFunctions.PosInfinity: + expr = PsExpression.make(PsLiteral(f"PS_FP{dtype.width}_INFINITY", dtype)) + + case ConstantFunctions.NegInfinity: + expr = PsExpression.make(PsLiteral(f"PS_FP{dtype.width}_NEG_INFINITY", dtype)) case _: raise MaterializationError( f"Cannot materialize call to function {func}" ) - call.function = cfunc - return call + if isinstance(dtype, PsIntegerType): + expr = self._select_integer_function(call) + + if expr is not None: + if expr.dtype is None: + typify = Typifier(self._ctx) + typify(expr) + + return expr raise MaterializationError( f"No implementation available for function {func} on data type {dtype}" diff --git a/src/pystencils/backend/platforms/hip.py b/src/pystencils/backend/platforms/hip.py index 404d9bb27bf7eded15b9e2f7dc60fb70ceedb1fe..b712ac69951639957f2d946549d4dafc53700c99 100644 --- a/src/pystencils/backend/platforms/hip.py +++ b/src/pystencils/backend/platforms/hip.py @@ -2,13 +2,10 @@ from __future__ import annotations from .generic_gpu import GenericGpu from ..ast import PsAstNode -from ..ast.expressions import PsExpression, PsLiteralExpr +from ..ast.expressions import PsExpression from ..ast.structural import PsStructuralNode from ..exceptions import MaterializationError -from ..functions import NumericLimitsFunctions -from ..literals import PsLiteral from ...sympyextensions import ReductionOp -from ...types import PsType, PsIeeeFloatType class HipPlatform(GenericGpu): @@ -16,19 +13,7 @@ class HipPlatform(GenericGpu): @property def required_headers(self) -> set[str]: - return super().required_headers | {'"pystencils_runtime/hip.h"', "<limits>"} - - def resolve_numeric_limits( - self, func: NumericLimitsFunctions, dtype: PsType - ) -> PsExpression: - assert isinstance(dtype, PsIeeeFloatType) - - return PsLiteralExpr( - PsLiteral( - f"std::numeric_limits<{dtype.c_string()}>::{func.function_name}()", - dtype, - ) - ) + return super().required_headers | {'"pystencils_runtime/hip.h"'} def resolve_reduction( self, diff --git a/src/pystencils/backend/platforms/platform.py b/src/pystencils/backend/platforms/platform.py index 7b81865aef50d63e9913161933f90ee86b24533e..f66b15741407ccbc6d14ee168d1ea0bc92334b39 100644 --- a/src/pystencils/backend/platforms/platform.py +++ b/src/pystencils/backend/platforms/platform.py @@ -1,8 +1,11 @@ from abc import ABC, abstractmethod from ..ast import PsAstNode +from ...types import PsIntegerType from ..ast.structural import PsBlock, PsStructuralNode -from ..ast.expressions import PsCall, PsExpression +from ..ast.expressions import PsCall, PsExpression, PsTernary, PsGe, PsLe +from ..functions import PsMathFunction, MathFunctions +from ..constants import PsConstant from ..kernelcreation.context import KernelCreationContext from ..kernelcreation.iteration_space import IterationSpace @@ -44,3 +47,26 @@ class Platform(ABC): If no viable implementation exists, raise a `MaterializationError`. """ pass + + # Some common lowerings + + def _select_integer_function(self, call: PsCall) -> PsExpression | None: + assert isinstance(call.function, PsMathFunction) + + func = call.function.func + dtype = call.get_dtype() + assert isinstance(dtype, PsIntegerType) + + match func: + case MathFunctions.Abs: + zero = PsExpression.make(PsConstant(0, dtype)) + arg = call.args[0] + return PsTernary(PsGe(arg, zero), arg, -arg) + case MathFunctions.Min: + arg1, arg2 = call.args + return PsTernary(PsLe(arg1, arg2), arg1, arg2) + case MathFunctions.Max: + arg1, arg2 = call.args + return PsTernary(PsGe(arg1, arg2), arg1, arg2) + case _: + return None diff --git a/src/pystencils/backend/platforms/sycl.py b/src/pystencils/backend/platforms/sycl.py index 22d60f9b0e9f8b18995338973ade8ebe8caab8bf..dd043f06be79a645cc2d80ec37233f659b846273 100644 --- a/src/pystencils/backend/platforms/sycl.py +++ b/src/pystencils/backend/platforms/sycl.py @@ -15,9 +15,6 @@ from ..ast.expressions import ( PsLt, PsAnd, PsCall, - PsGe, - PsLe, - PsTernary, PsLookup, PsBufferAcc, ) @@ -93,17 +90,8 @@ class SyclPlatform(Platform): return call if isinstance(dtype, PsIntegerType): - match func: - case MathFunctions.Abs: - zero = PsExpression.make(PsConstant(0, dtype)) - arg = call.args[0] - return PsTernary(PsGe(arg, zero), arg, -arg) - case MathFunctions.Min: - arg1, arg2 = call.args - return PsTernary(PsLe(arg1, arg2), arg1, arg2) - case MathFunctions.Max: - arg1, arg2 = call.args - return PsTernary(PsGe(arg1, arg2), arg1, arg2) + if (expr := self._select_integer_function(call)) is not None: + return expr raise MaterializationError( f"No implementation available for function {func} on data type {dtype}" diff --git a/src/pystencils/backend/transformations/ast_vectorizer.py b/src/pystencils/backend/transformations/ast_vectorizer.py index c793c424d2417cbbdcc0cf3782e696c7c9226bb6..d63a3522d67f0c13f7a3022b295bf0f301dbaa84 100644 --- a/src/pystencils/backend/transformations/ast_vectorizer.py +++ b/src/pystencils/backend/transformations/ast_vectorizer.py @@ -9,7 +9,7 @@ from ...types import PsType, PsVectorType, PsBoolType, PsScalarType from ..kernelcreation import KernelCreationContext, AstFactory from ..memory import PsSymbol from ..constants import PsConstant -from ..functions import PsMathFunction +from ..functions import PsMathFunction, PsConstantFunction from ..ast import PsAstNode from ..ast.structural import ( @@ -270,7 +270,9 @@ class AstVectorizer: return self.visit(node, vc) @overload - def visit(self, node: PsStructuralNode, vc: VectorizationContext) -> PsStructuralNode: + def visit( + self, node: PsStructuralNode, vc: VectorizationContext + ) -> PsStructuralNode: pass @overload @@ -368,8 +370,13 @@ class AstVectorizer: "since no vectorized version of the counter was present in the context." ) - # Symbols, constants, and literals that can be broadcast - case PsSymbolExpr() | PsConstantExpr() | PsLiteral(): + # Symbols, constants, constant functions, and literals that can be broadcast + case ( + PsSymbolExpr() + | PsConstantExpr() + | PsLiteral() + | PsCall(PsConstantFunction()) + ): if isinstance(expr.dtype, PsScalarType): # Broadcast constant or non-vectorized scalar symbol vec_expr = PsVecBroadcast(vc.lanes, expr.clone()) @@ -381,6 +388,11 @@ class AstVectorizer: # Unary Ops case PsCast(target_type, operand): + if target_type is None: + raise VectorizationError( + f"Unable to vectorize type cast with unknown target type: {expr}" + ) + vec_expr = PsCast( vc.vector_type(target_type), self.visit_expr(operand, vc) ) diff --git a/src/pystencils/backend/transformations/select_functions.py b/src/pystencils/backend/transformations/select_functions.py index 9ce4046931036a5e16eda2c9cd16faa272752acc..d005acb4bcf3042473826383384e16d9fc7dd4fc 100644 --- a/src/pystencils/backend/transformations/select_functions.py +++ b/src/pystencils/backend/transformations/select_functions.py @@ -3,7 +3,7 @@ from ..exceptions import MaterializationError from ..platforms import Platform from ..ast import PsAstNode from ..ast.expressions import PsCall, PsExpression -from ..functions import PsMathFunction, PsReductionFunction +from ..functions import PsMathFunction, PsConstantFunction, PsReductionFunction class SelectFunctions: @@ -48,7 +48,9 @@ class SelectFunctions: ) else: return node - elif isinstance(node, PsCall) and isinstance(node.function, PsMathFunction): + elif isinstance(node, PsCall) and isinstance( + node.function, (PsMathFunction | PsConstantFunction) + ): resolved_func = self._platform.select_function(node) assert isinstance(resolved_func, PsExpression) diff --git a/src/pystencils/bit_masks.py b/src/pystencils/bit_masks.py index 95a59e1ff712e3a6a3f449c92e2c8d8bc0e12a04..67d8b4eed3d4a938b97d4926f5dd8a28b105febc 100644 --- a/src/pystencils/bit_masks.py +++ b/src/pystencils/bit_masks.py @@ -1,9 +1,12 @@ -from .sympyextensions.bit_masks import flag_cond as _flag_cond - +from .sympyextensions.bit_masks import bit_conditional from warnings import warn -warn( - "Importing the `pystencils.bit_masks` module is deprecated. " - "Import `flag_cond` from `pystencils.sympyextensions` instead." -) -flag_cond = _flag_cond + +class flag_cond(bit_conditional): + def __new__(cls, *args, **kwargs): + warn( + "flag_cond is deprecated and will be removed in pystencils 2.1. " + "Use `pystencils.sympyextensions.bit_conditional` instead.", + FutureWarning + ) + return bit_conditional.__new__(cls, *args, **kwargs) diff --git a/src/pystencils/boundaries/boundaryhandling.py b/src/pystencils/boundaries/boundaryhandling.py index 58340c3e0fbb16b98af2cf08c3d1894ca34a2309..f0a66ac840b42afe66eb37fc17ab4ec87ae16556 100644 --- a/src/pystencils/boundaries/boundaryhandling.py +++ b/src/pystencils/boundaries/boundaryhandling.py @@ -4,6 +4,7 @@ import numpy as np import sympy as sp from pystencils import create_kernel, CreateKernelConfig, Target +from pystencils.types import UserTypeSpec, create_numeric_type from pystencils.assignment import Assignment from pystencils.boundaries.createindexlist import ( create_boundary_index_array, numpy_data_type_for_boundary_object) @@ -84,13 +85,14 @@ class FlagInterface: class BoundaryHandling: def __init__(self, data_handling, field_name, stencil, name="boundary_handling", flag_interface=None, - target: Target = Target.CPU, openmp=True): + target: Target = Target.CPU, default_dtype: UserTypeSpec = "float64", openmp=True): assert data_handling.has_data(field_name) assert data_handling.dim == len(stencil[0]), "Dimension of stencil and data handling do not match" self._data_handling = data_handling self._field_name = field_name self._index_array_name = name + "IndexArrays" self._target = target + self._default_dtype = create_numeric_type(default_dtype) self._openmp = openmp self._boundary_object_to_boundary_info = {} self.stencil = stencil @@ -313,8 +315,11 @@ class BoundaryHandling: return self._boundary_object_to_boundary_info[boundary_obj].flag def _create_boundary_kernel(self, symbolic_field, symbolic_index_field, boundary_obj): - return create_boundary_kernel(symbolic_field, symbolic_index_field, self.stencil, boundary_obj, - target=self._target, cpu_openmp=self._openmp) + cfg = CreateKernelConfig() + cfg.target = self._target + cfg.default_dtype = self._default_dtype + cfg.cpu.openmp.enable = self._openmp + return create_boundary_kernel(symbolic_field, symbolic_index_field, self.stencil, boundary_obj, cfg) def _create_index_fields(self): dh = self._data_handling @@ -452,11 +457,14 @@ class BoundaryOffsetInfo: return sp.Symbol("invdir") -def create_boundary_kernel(field, index_field, stencil, boundary_functor, target=Target.CPU, **kernel_creation_args): +def create_boundary_kernel(field, index_field, stencil, boundary_functor, cfg: CreateKernelConfig): # TODO: reconsider how to control the index_dtype in boundary kernels - config = CreateKernelConfig(index_field=index_field, target=target, index_dtype=SInt(32), **kernel_creation_args) + config = cfg.copy() + config.index_field = index_field + idx_dtype = SInt(32) + config.index_dtype = idx_dtype - offset_info = BoundaryOffsetInfo(stencil, config.index_dtype) + offset_info = BoundaryOffsetInfo(stencil, idx_dtype) elements = offset_info.get_array_declarations() dir_symbol = TypedSymbol("dir", config.index_dtype) elements += [Assignment(dir_symbol, index_field[0]('dir'))] diff --git a/src/pystencils/codegen/config.py b/src/pystencils/codegen/config.py index 8e7e54ff1125a8bba2ba35c223277ee2867c28b7..295289ac01dd5d204101f8734fd27caf43ddebba 100644 --- a/src/pystencils/codegen/config.py +++ b/src/pystencils/codegen/config.py @@ -682,7 +682,7 @@ class CreateKernelConfig(ConfigBase): if cpu_vectorize_info is not None: _deprecated_option("cpu_vectorize_info", "cpu_optim.vectorize") if "instruction_set" in cpu_vectorize_info: - if self.target != Target.GenericCPU: + if self.target is not None and self.target != Target.GenericCPU: raise ValueError( "Setting 'instruction_set' in the deprecated 'cpu_vectorize_info' option is only " "valid if `target == Target.CPU`." diff --git a/src/pystencils/include/pystencils_runtime/bits/gpu_infinities.h b/src/pystencils/include/pystencils_runtime/bits/gpu_infinities.h new file mode 100644 index 0000000000000000000000000000000000000000..0cfc05171699691e8b854ae4d7f2d2fd02e39334 --- /dev/null +++ b/src/pystencils/include/pystencils_runtime/bits/gpu_infinities.h @@ -0,0 +1,10 @@ +#pragma once + +#define PS_FP16_INFINITY __short_as_half(0x7c00) +#define PS_FP16_NEG_INFINITY __short_as_half(0xfc00) + +#define PS_FP32_INFINITY __int_as_float(0x7f800000) +#define PS_FP32_NEG_INFINITY __int_as_float(0xff800000) + +#define PS_FP64_INFINITY __longlong_as_double(0x7ff0000000000000) +#define PS_FP64_NEG_INFINITY __longlong_as_double(0xfff0000000000000) diff --git a/src/pystencils/include/pystencils_runtime/cuda.cuh b/src/pystencils/include/pystencils_runtime/cuda.cuh new file mode 100644 index 0000000000000000000000000000000000000000..6a22e0b9034d224a4fda52233faab73cabd8a01d --- /dev/null +++ b/src/pystencils/include/pystencils_runtime/cuda.cuh @@ -0,0 +1,5 @@ +#pragma once + +#include <cuda_fp16.h> + +#include "./bits/gpu_infinities.h" diff --git a/src/pystencils/include/pystencils_runtime/hip.h b/src/pystencils/include/pystencils_runtime/hip.h index 4bf4917f8aef1054813eb62a5596a908defeb30c..10084103af058d80aa68cc8e2bc58e7c31246181 100644 --- a/src/pystencils/include/pystencils_runtime/hip.h +++ b/src/pystencils/include/pystencils_runtime/hip.h @@ -1,5 +1,9 @@ #pragma once +#include <hip/hip_fp16.h> + +#include "./bits/gpu_infinities.h" + #ifdef __HIPCC_RTC__ typedef __hip_uint8_t uint8_t; typedef __hip_int8_t int8_t; diff --git a/src/pystencils/jit/gpu_cupy.py b/src/pystencils/jit/gpu_cupy.py index 760dccf41883c26b21c0fd8fde5820832bead3d1..893589a011e30fb184ee3e845c32ba2aea68eeb9 100644 --- a/src/pystencils/jit/gpu_cupy.py +++ b/src/pystencils/jit/gpu_cupy.py @@ -257,10 +257,6 @@ class CupyJit(JitBase): if '"pystencils_runtime/half.h"' in headers: headers.remove('"pystencils_runtime/half.h"') - if cp.cuda.runtime.is_hip: - headers.add("<hip/hip_fp16.h>") - else: - headers.add("<cuda_fp16.h>") code = "\n".join(f"#include {header}" for header in headers) diff --git a/src/pystencils/simp/simplifications.py b/src/pystencils/simp/simplifications.py index 9368c8f51a4aabd03c15a0741db5930eb8865884..baecf6cb4118770d64a582310c3962facf95b99a 100644 --- a/src/pystencils/simp/simplifications.py +++ b/src/pystencils/simp/simplifications.py @@ -1,13 +1,20 @@ +from __future__ import annotations +from typing import TYPE_CHECKING + from itertools import chain from typing import Callable, List, Sequence, Union from collections import defaultdict import sympy as sp +from ..types import UserTypeSpec from ..assignment import Assignment -from ..sympyextensions import subs_additive, is_constant, recursive_collect +from ..sympyextensions import subs_additive, is_constant, recursive_collect, tcast from ..sympyextensions.typed_sympy import TypedSymbol +if TYPE_CHECKING: + from .assignment_collection import AssignmentCollection + # TODO rewrite with SymPy AST # def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]: @@ -170,14 +177,19 @@ def add_subexpressions_for_sums(ac): return ac.new_with_substitutions(substitutions, True, substitute_on_lhs=False) -def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments=True, data_type=None): +def add_subexpressions_for_field_reads( + ac: AssignmentCollection, + subexpressions=True, + main_assignments=True, + data_type: UserTypeSpec | None = None +): r"""Substitutes field accesses on rhs of assignments with subexpressions Can change semantics of the update rule (which is the goal of this transformation) This is useful if a field should be update in place - all values are loaded before into subexpression variables, then the new values are computed and written to the same field in-place. Additionally, if a datatype is given to the function the rhs symbol of the new isolated field read will have - this data type. This is useful for mixed precision kernels + this data type, and an explicit cast is inserted. This is useful for mixed precision kernels """ field_reads = set() to_iterate = [] @@ -201,8 +213,23 @@ def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments substitutions.update({fa: TypedSymbol(lhs.name, data_type)}) else: substitutions.update({fa: lhs}) - return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True, - substitute_on_lhs=False, sort_topologically=False) + + ac = ac.new_with_substitutions( + substitutions, + add_substitutions_as_subexpressions=False, + substitute_on_lhs=False, + sort_topologically=False + ) + + loads: list[Assignment] = [] + for fa in field_reads: + rhs = fa if data_type is None else tcast(fa, data_type) + loads.append( + Assignment(substitutions[fa], rhs) + ) + + ac.subexpressions = loads + ac.subexpressions + return ac def transform_rhs(assignment_list, transformation, *args, **kwargs): diff --git a/src/pystencils/sympyextensions/__init__.py b/src/pystencils/sympyextensions/__init__.py index c575feeb3d973ba77d3ecaeb461730dd2b427fef..8c7d10a806a06a42bed5f6510304c517da6fae28 100644 --- a/src/pystencils/sympyextensions/__init__.py +++ b/src/pystencils/sympyextensions/__init__.py @@ -2,6 +2,7 @@ from .astnodes import ConditionalFieldAccess from .typed_sympy import TypedSymbol, CastFunc, tcast, DynamicType from .pointers import mem_acc from .reduction import reduction_assignment, ReductionOp +from .bit_masks import bit_conditional from .math import ( prod, @@ -67,4 +68,5 @@ __all__ = [ "get_symmetric_part", "SymbolCreator", "DynamicType", + "bit_conditional" ] diff --git a/src/pystencils/sympyextensions/bit_masks.py b/src/pystencils/sympyextensions/bit_masks.py index 57f2ab5fb517dae70a36dfeae0a8640b8eced312..36108f4ac323fa59e782efd6d071061ea6838ba6 100644 --- a/src/pystencils/sympyextensions/bit_masks.py +++ b/src/pystencils/sympyextensions/bit_masks.py @@ -2,51 +2,22 @@ import sympy as sp # noinspection PyPep8Naming -class flag_cond(sp.Function): - """Evaluates a flag condition on a bit mask, and returns the value of one of two expressions, - depending on whether the flag is set. - - Three argument version: - ``` - flag_cond(flag_bit, mask, expr) = expr if (flag_bit is set in mask) else 0 - ``` - - Four argument version: - ``` - flag_cond(flag_bit, mask, expr_then, expr_else) = expr_then if (flag_bit is set in mask) else expr_else - ``` - """ - - nargs = (3, 4) - - def __new__(cls, flag_bit, mask_expression, *expressions): +class bit_conditional(sp.Function): + """Evaluates a bit condition on an integer mask, and returns the value of one of two expressions, + depending on whether the bit is set. - # TODO Jan reintroduce checking - # flag_dtype = get_type_of_expression(flag_bit) - # if not flag_dtype.is_int(): - # raise ValueError('Argument flag_bit must be of integer type.') - # - # mask_dtype = get_type_of_expression(mask_expression) - # if not mask_dtype.is_int(): - # raise ValueError('Argument mask_expression must be of integer type.') + Semantics: - return super().__new__(cls, flag_bit, mask_expression, *expressions) + .. code-block:: none - def to_c(self, print_func): - flag_bit = self.args[0] - mask = self.args[1] + # Three-argument version + flag_cond(bitpos, mask, expr) = expr if (bitpos is set in mask) else 0 - then_expression = self.args[2] + # Four-argument version + flag_cond(bitpos, mask, expr_then, expr_else) = expr_then if (bitpos is set in mask) else expr_else - flag_bit_code = print_func(flag_bit) - mask_code = print_func(mask) - then_code = print_func(then_expression) - - code = f"(({mask_code}) >> ({flag_bit_code}) & 1) * ({then_code})" - - if len(self.args) > 3: - else_expression = self.args[3] - else_code = print_func(else_expression) - code += f" + (({mask_code}) >> ({flag_bit_code}) ^ 1) * ({else_code})" + The ``bitpos`` and ``mask`` arguments must both be of the same integer type. + When in doubt, fix the type using `tcast`. + """ - return code + nargs = (3, 4) diff --git a/tests/frontend/test_abs.py b/tests/frontend/test_abs.py deleted file mode 100644 index daab354a3804332c5d36358a0b850f13b8d4f0a6..0000000000000000000000000000000000000000 --- a/tests/frontend/test_abs.py +++ /dev/null @@ -1,21 +0,0 @@ -import pytest - -import pystencils as ps -import sympy - - -@pytest.mark.parametrize('target', (ps.Target.CPU, ps.Target.GPU)) -def test_abs(target): - if target == ps.Target.GPU: - # FIXME - pytest.xfail("GPU target not ready yet") - - x, y, z = ps.fields('x, y, z: int64[2d]') - - assignments = ps.AssignmentCollection({x[0, 0]: sympy.Abs(y[0, 0])}) - - config = ps.CreateKernelConfig(target=target) - ast = ps.create_kernel(assignments, config=config) - code = ps.get_code_str(ast) - print(code) - assert 'fabs(' not in code diff --git a/tests/frontend/test_bit_masks.py b/tests/frontend/test_bit_masks.py index 6c8d6e85e96a4bf9f98301928be4eacead3bc841..664374c3ea87a92a93a13fa48aa84dcc16416beb 100644 --- a/tests/frontend/test_bit_masks.py +++ b/tests/frontend/test_bit_masks.py @@ -3,36 +3,37 @@ import numpy as np import pystencils as ps from pystencils import Field, Assignment, create_kernel -from pystencils.sympyextensions.bit_masks import flag_cond +from pystencils.sympyextensions import bit_conditional +from pystencils.backend.exceptions import TypificationError -@pytest.mark.parametrize('mask_type', [np.uint8, np.uint16, np.uint32, np.uint64]) -@pytest.mark.xfail(reason="Bit masks not yet supported by the new backend") -def test_flag_condition(mask_type): + +@pytest.mark.parametrize("mask_type", [np.uint8, np.uint16, np.uint32, np.uint64]) +def test_bit_conditional(mask_type): f_arr = np.zeros((2, 2, 2), dtype=np.float64) mask_arr = np.zeros((2, 2), dtype=mask_type) - mask_arr[0, 1] = (1 << 3) - mask_arr[1, 0] = (1 << 5) + mask_arr[0, 1] = 1 << 3 + mask_arr[1, 0] = (1 << 5) + (1 << 7) mask_arr[1, 1] = (1 << 3) + (1 << 5) - f = Field.create_from_numpy_array('f', f_arr, index_dimensions=1) - mask = Field.create_from_numpy_array('mask', mask_arr) + f = Field.create_from_numpy_array("f", f_arr, index_dimensions=1) + mask = Field.create_from_numpy_array("mask", mask_arr) v1 = 42.3 v2 = 39.7 v3 = 119 assignments = [ - Assignment(f(0), flag_cond(3, mask(0), v1)), - Assignment(f(1), flag_cond(5, mask(0), v2, v3)) + Assignment(f(0), bit_conditional(3, mask(0), v1)), + Assignment(f(1), bit_conditional(5, mask(0), v2, v3)), ] kernel = create_kernel(assignments).compile() kernel(f=f_arr, mask=mask_arr) code = ps.get_code_str(kernel) - assert '119.0' in code + assert "119.0" in code reference = np.zeros((2, 2, 2), dtype=np.float64) reference[0, 1, 0] = v1 @@ -45,3 +46,19 @@ def test_flag_condition(mask_type): reference[1, 1, 1] = v2 np.testing.assert_array_equal(f_arr, reference) + + +def test_invalid_mask_type(): + f, invalid_mask = ps.fields("f(1), mask: double[2D]") + asm = Assignment(f(0), bit_conditional(2, invalid_mask(0), 3, 5)) + + with pytest.raises(TypificationError): + _ = create_kernel(asm) + + asm = Assignment( + f(0), + bit_conditional(ps.TypedSymbol("x", "float32"), ps.tcast(0xFE, "uint32"), 3, 5), + ) + + with pytest.raises(TypificationError): + _ = create_kernel(asm) diff --git a/tests/frontend/test_simplifications.py b/tests/frontend/test_simplifications.py index 45cde724108fe7578d8ff2dc9b8a2509a9add728..c39b2fac743a623e0ddddf7132812732321cb298 100644 --- a/tests/frontend/test_simplifications.py +++ b/tests/frontend/test_simplifications.py @@ -147,47 +147,5 @@ def test_add_subexpressions_for_field_reads(): assert len(ac3.subexpressions) == 2 assert isinstance(ac3.subexpressions[0].lhs, TypedSymbol) assert ac3.subexpressions[0].lhs.dtype == create_type("float32") - - -# TODO: What does this test mean to accomplish? -# @pytest.mark.parametrize('target', (ps.Target.CPU, ps.Target.GPU)) -# @pytest.mark.parametrize('dtype', ('float32', 'float64')) -# @pytest.mark.skipif((vs.major, vs.minor, vs.micro) == (3, 8, 2), reason="does not work on python 3.8.2 for some reason") -# def test_sympy_optimizations(target, dtype): -# if target == ps.Target.GPU: -# pytest.importorskip("cupy") -# src, dst = ps.fields(f'src, dst: {dtype}[2d]') - -# assignments = ps.AssignmentCollection({ -# src[0, 0]: 1.0 * (sp.exp(dst[0, 0]) - 1) -# }) - -# config = pystencils.config.CreateKernelConfig(target=target, default_dtype=dtype) -# ast = ps.create_kernel(assignments, config=config) - -# ps.show_code(ast) - -# code = ps.get_code_str(ast) -# if dtype == 'float32': -# assert 'expf(' in code -# elif dtype == 'float64': -# assert 'exp(' in code - - -@pytest.mark.parametrize('target', (ps.Target.CPU, ps.Target.GPU)) -@pytest.mark.skipif((vs.major, vs.minor, vs.micro) == (3, 8, 2), reason="does not work on python 3.8.2 for some reason") -@pytest.mark.xfail(reason="The new backend does not (yet) evaluate transcendental functions") -def test_evaluate_constant_terms(target): - if target == ps.Target.GPU: - pytest.importorskip("cupy") - src, dst = ps.fields('src, dst: float32[2d]') - - # cos of a number will always be simplified - assignments = ps.AssignmentCollection({ - src[0, 0]: -sp.cos(1) + dst[0, 0] - }) - - config = ps.CreateKernelConfig(target=target) - ast = ps.create_kernel(assignments, config=config) - code = ps.get_code_str(ast) - assert 'cos(' not in code and 'cosf(' not in code + assert isinstance(ac3.subexpressions[0].rhs, ps.tcast) + assert ac3.subexpressions[0].rhs.dtype == create_type("float32") diff --git a/tests/kernelcreation/test_buffer_gpu.py b/tests/kernelcreation/test_buffer_gpu.py index bd9d2156b451e10f7a91ed21b76a64e60af9bd03..2f148e4693e3859ef32e80f377957f45c20824ea 100644 --- a/tests/kernelcreation/test_buffer_gpu.py +++ b/tests/kernelcreation/test_buffer_gpu.py @@ -6,7 +6,7 @@ import pytest import pystencils from pystencils import Assignment, Field, FieldType, Target, CreateKernelConfig, create_kernel, fields -from pystencils.sympyextensions.bit_masks import flag_cond +from pystencils.sympyextensions.bit_masks import bit_conditional from pystencils.field import create_numpy_array_with_layout, layout_string_to_tuple from pystencils.slicing import ( add_ghost_layers, get_ghost_region_slice, get_slice_before_ghost_layer) @@ -16,14 +16,13 @@ try: # noinspection PyUnresolvedReferences import cupy as cp except ImportError: - pass - + pytest.skip("Cupy not available", allow_module_level=True) + FIELD_SIZES = [(4, 3), (9, 3, 7)] def _generate_fields(dt=np.uint8, stencil_directions=1, layout='numpy'): - pytest.importorskip('cupy') field_sizes = FIELD_SIZES if stencil_directions > 1: field_sizes = [s + (stencil_directions,) for s in field_sizes] @@ -235,7 +234,6 @@ def test_field_layouts(): unpack_kernel(buffer=gpu_buffer_arr, dst_field=gpu_dst_arr) -@pytest.mark.xfail(reason="flag_cond is not available yet") def test_buffer_indexing(): src_field, dst_field = fields(f'pdfs_src(19), pdfs_dst(19) :double[3D]') mask_field = fields(f'mask : uint32 [3D]') @@ -246,9 +244,9 @@ def test_buffer_indexing(): src_field_size = src_field.spatial_shape mask_field_size = mask_field.spatial_shape - up = Assignment(buffer(0), flag_cond(1, mask_field.center, src_field[0, 1, 0](1))) + up = Assignment(buffer(0), bit_conditional(1, mask_field.center, src_field[0, 1, 0](1))) iteration_slice = tuple(slice(None, None, 2) for _ in range(3)) - config = CreateKernelConfig(target=Target.GPU) + config = CreateKernelConfig(target=Target.CurrentGPU) config = replace(config, iteration_slice=iteration_slice) ast = create_kernel(up, config=config) @@ -268,11 +266,8 @@ def test_buffer_indexing(): assert len(spatial_shape_symbols) <= 3 -@pytest.mark.parametrize('gpu_indexing', ("block", "line")) -def test_iteration_slices(gpu_indexing): - if gpu_indexing == "line": - pytest.xfail("Line indexing not available yet") - +@pytest.mark.parametrize('indexing_scheme', ("linear3d", "blockwise4d")) +def test_iteration_slices(indexing_scheme): num_cell_values = 19 dt = np.uint64 fields = _generate_fields(dt=dt, stencil_directions=num_cell_values) @@ -300,6 +295,7 @@ def test_iteration_slices(gpu_indexing): gpu_dst_arr.fill(0) config = CreateKernelConfig(target=Target.CurrentGPU, iteration_slice=pack_slice) + config.gpu.indexing_scheme = indexing_scheme pack_code = create_kernel(pack_eqs, config=config) pack_kernel = pack_code.compile() @@ -312,6 +308,7 @@ def test_iteration_slices(gpu_indexing): unpack_eqs.append(eq) config = CreateKernelConfig(target=Target.CurrentGPU, iteration_slice=pack_slice) + config.gpu.indexing_scheme = indexing_scheme unpack_code = create_kernel(unpack_eqs, config=config) unpack_kernel = unpack_code.compile() diff --git a/tests/kernelcreation/test_functions.py b/tests/kernelcreation/test_functions.py index 182a590056d68a9677a657877416574db9f81e25..ac684928467fe8e8cebf1edd0f82dc5ed5d115af 100644 --- a/tests/kernelcreation/test_functions.py +++ b/tests/kernelcreation/test_functions.py @@ -16,6 +16,15 @@ from pystencils.backend.ast import dfs_preorder from pystencils.backend.ast.expressions import PsCall +def constant(name, dtype): + return { + "pi": (sp.pi, dtype(np.pi)), + "e": (sp.E, dtype(np.e)), + "infinity": (sp.core.numbers.Infinity(), dtype(np.inf)), + "neg_infinity": (sp.core.numbers.NegativeInfinity(), -dtype(np.inf)), + }[name] + + def unary_function(name, xp): return { "exp": (sp.exp, xp.exp), @@ -211,14 +220,14 @@ def test_binary_functions(gen_config, xp, function_name, dtype, function_domain) dtype_and_target_for_integer_funcs = pytest.mark.parametrize( "dtype, target", - list(product([np.int32], [t for t in AVAIL_TARGETS if not t.is_gpu()])) + list(product([np.int32], AVAIL_TARGETS)) + list( product( [np.int64], [ t for t in AVAIL_TARGETS - if t not in (Target.X86_SSE, Target.X86_AVX) and not t.is_gpu() + if t not in (Target.X86_SSE, Target.X86_AVX) ], ) ), @@ -282,6 +291,37 @@ def test_integer_binary_functions(gen_config, xp, function_name, dtype): xp.testing.assert_array_equal(outp, reference) +@pytest.mark.parametrize("c_name", ["pi", "e", "infinity", "neg_infinity"]) +@pytest.mark.parametrize( + "target, dtype", + list(product(AVAIL_TARGETS, [np.float32, np.float64])) + + [ + (t, np.float16) + for t in AVAIL_TARGETS + if t.is_gpu() or t in (Target.X86_AVX512_FP16,) + ], +) +def test_constants(c_name, dtype, gen_config, xp): + c_sp, c_np = constant(c_name, dtype) + + outp = xp.zeros( + (17,), dtype=dtype + ) # 17 entries to run both vectorized and remainder loops + reference = xp.zeros_like(outp) + reference[:] = c_np + + outp_field = Field.create_from_numpy_array("outp", outp) + asm = Assignment(outp_field(0), c_sp) + + gen_config = replace(gen_config, default_dtype=dtype) + + kernel = create_kernel(asm, gen_config) + kfunc = kernel.compile() + kfunc(outp=outp) + + xp.testing.assert_array_equal(outp, reference) + + @pytest.mark.parametrize("a", [sp.Symbol("a"), fields("a: float64[2d]").center]) def test_avoid_pow(a): x = fields("x: float64[2d]") diff --git a/tests/kernelcreation/test_struct_types.py b/tests/kernelcreation/test_struct_types.py index de50527a753abb73001fc69496dd8eb15b06c363..ae2418220882bcaf447506c588eafcad57797522 100644 --- a/tests/kernelcreation/test_struct_types.py +++ b/tests/kernelcreation/test_struct_types.py @@ -5,11 +5,8 @@ from pystencils import Assignment, Field, create_kernel @pytest.mark.parametrize("order", ['c', 'f']) -@pytest.mark.parametrize("align", [True, False]) -def test_fixed_sized_field(order, align): - if not align: - pytest.xfail("Non-Aligned structs not supported") - dt = np.dtype([('e1', np.float32), ('e2', np.double), ('e3', np.double)], align=align) +def test_fixed_sized_field(order): + dt = np.dtype([('e1', np.float32), ('e2', np.double), ('e3', np.double)], align=True) arr = np.zeros((3, 2), dtype=dt, order=order) f = Field.create_from_numpy_array("f", arr) @@ -27,11 +24,8 @@ def test_fixed_sized_field(order, align): @pytest.mark.parametrize("order", ['c', 'f']) -@pytest.mark.parametrize("align", [True, False]) -def test_variable_sized_field(order, align): - if not align: - pytest.xfail("Non-Aligned structs not supported") - dt = np.dtype([('e1', np.float32), ('e2', np.double), ('e3', np.double)], align=align) +def test_variable_sized_field(order): + dt = np.dtype([('e1', np.float32), ('e2', np.double), ('e3', np.double)], align=True) f = Field.create_generic("f", 2, dt, layout=order) d = Field.create_generic("d", 2, dt, layout=order) diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py index fe31cf94cdfe5b07780baab166c812551f866701..679a7cebcd74e794946e78efecd2e0bdedb97862 100644 --- a/tests/nbackend/kernelcreation/test_freeze.py +++ b/tests/nbackend/kernelcreation/test_freeze.py @@ -10,7 +10,7 @@ from pystencils import ( DynamicType, KernelConstraintsError, ) -from pystencils.sympyextensions import tcast +from pystencils.sympyextensions import tcast, bit_conditional from pystencils.sympyextensions.pointers import mem_acc from pystencils.backend.ast.structural import ( @@ -48,7 +48,7 @@ from pystencils.backend.ast.expressions import ( PsSymbolExpr, ) from pystencils.backend.constants import PsConstant -from pystencils.backend.functions import PsMathFunction, MathFunctions +from pystencils.backend.functions import PsMathFunction, MathFunctions, PsConstantFunction, ConstantFunctions from pystencils.backend.kernelcreation import ( KernelCreationContext, FreezeExpressions, @@ -131,6 +131,31 @@ def test_freeze_fields(): assert fasm.structurally_equal(should) +def test_freeze_constants(): + ctx = KernelCreationContext() + freeze = FreezeExpressions(ctx) + + expr = freeze(sp.pi) + assert isinstance(expr, PsCall) + assert isinstance(expr.function, PsConstantFunction) + assert expr.function.func == ConstantFunctions.Pi + + expr = freeze(sp.E) + assert isinstance(expr, PsCall) + assert isinstance(expr.function, PsConstantFunction) + assert expr.function.func == ConstantFunctions.E + + expr = freeze(sp.oo) + assert isinstance(expr, PsCall) + assert isinstance(expr.function, PsConstantFunction) + assert expr.function.func == ConstantFunctions.PosInfinity + + expr = freeze(- sp.oo) + assert isinstance(expr, PsCall) + assert isinstance(expr.function, PsConstantFunction) + assert expr.function.func == ConstantFunctions.NegInfinity + + def test_freeze_integer_binops(): ctx = KernelCreationContext() freeze = FreezeExpressions(ctx) @@ -603,3 +628,63 @@ def test_indexed(): expr = freeze(a[x, y, z]) assert expr.structurally_equal(PsSubscript(a2, (x2, y2, z2))) + + +def test_freeze_bit_conditional(): + ctx = KernelCreationContext() + freeze = FreezeExpressions(ctx) + + x, y, z = sp.symbols("x, y, z") + expr = freeze(bit_conditional(x, y, z)) + one = freeze(sp.Integer(1)) + + assert expr.structurally_equal( + PsMul( + PsCast( + None, + PsBitwiseAnd( + PsRightShift( + PsExpression.make(ctx.get_symbol("y")), + PsExpression.make(ctx.get_symbol("x")), + ), + one, + ), + ), + PsExpression.make(ctx.get_symbol("z")), + ) + ) + + expr = freeze(bit_conditional(x, y, z, z)) + assert expr.structurally_equal( + PsAdd( + PsMul( + PsCast( + None, + PsBitwiseAnd( + PsRightShift( + PsExpression.make(ctx.get_symbol("y")), + PsExpression.make(ctx.get_symbol("x")), + ), + one, + ), + ), + PsExpression.make(ctx.get_symbol("z")), + ), + PsMul( + PsCast( + None, + PsBitwiseXor( + PsBitwiseAnd( + PsRightShift( + PsExpression.make(ctx.get_symbol("y")), + PsExpression.make(ctx.get_symbol("x")), + ), + one, + ), + one, + ), + ), + PsExpression.make(ctx.get_symbol("z")), + ), + ) + ) diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py index 31df7090d45e9326e9e625d6afdb4c170229580e..62b0106b9ab7fafc9febb25fef00ba404a3f1282 100644 --- a/tests/nbackend/kernelcreation/test_typification.py +++ b/tests/nbackend/kernelcreation/test_typification.py @@ -36,9 +36,11 @@ from pystencils.backend.ast.expressions import ( PsMemAcc ) from pystencils.backend.ast.vector import PsVecBroadcast, PsVecHorizontal +from pystencils.backend.ast import dfs_preorder +from pystencils.backend.ast.expressions import PsAdd from pystencils.backend.constants import PsConstant -from pystencils.backend.functions import CFunction -from pystencils.types import constify, create_type, create_numeric_type, PsVectorType +from pystencils.backend.functions import CFunction, PsConstantFunction, ConstantFunctions +from pystencils.types import constify, create_type, create_numeric_type, PsVectorType, PsTypeError from pystencils.types.quick import Fp, Int, Bool, Arr, Ptr from pystencils.backend.kernelcreation.context import KernelCreationContext from pystencils.backend.kernelcreation.freeze import FreezeExpressions @@ -85,6 +87,60 @@ def test_typify_simple(): check(fasm.rhs) +def test_typify_constants(): + ctx = KernelCreationContext(default_dtype=Fp(32)) + freeze = FreezeExpressions(ctx) + typify = Typifier(ctx) + + for constant in [sp.sympify(0), sp.sympify(1), sp.Rational(1, 2), sp.pi, sp.E, sp.oo, - sp.oo]: + # Constant on its own + expr, _ = typify.typify_expression(freeze(constant), ctx.default_dtype) + + for node in dfs_preorder(expr): + assert isinstance(node, PsExpression) + assert node.dtype == ctx.default_dtype + match node: + case PsConstantExpr(c): + assert c.dtype == constify(ctx.default_dtype) + case PsCall(func) if isinstance(func, PsConstantFunction): + assert func.dtype == constify(ctx.default_dtype) + + +def test_constants_contextual_typing(): + ctx = KernelCreationContext(default_dtype=Fp(32)) + freeze = FreezeExpressions(ctx) + typify = Typifier(ctx) + + fp16 = Fp(16) + x = TypedSymbol("x", fp16) + + for constant in [sp.sympify(0), sp.sympify(1), sp.Rational(1, 2), sp.pi, sp.E, sp.oo, - sp.oo]: + expr = freeze(constant) + freeze(x) # Freeze separately such that SymPy does not simplify + expr = typify(expr) + + assert isinstance(expr, PsAdd) + + for node in dfs_preorder(expr): + assert isinstance(node, PsExpression) + assert node.dtype == fp16 + match node: + case PsConstantExpr(c): + assert c.dtype == constify(fp16) + case PsCall(func) if isinstance(func, PsConstantFunction): + assert func.dtype == constify(fp16) + + +def test_no_integer_infinities_and_transcendentals(): + ctx = KernelCreationContext(default_dtype=Fp(32)) + freeze = FreezeExpressions(ctx) + typify = Typifier(ctx) + + for sp_expr in [sp.oo, - sp.oo, sp.pi, sp.E]: + expr = freeze(sp_expr) + with pytest.raises(PsTypeError): + typify.typify_expression(expr, Int(32)) + + def test_lhs_constness(): default_type = Fp(32) ctx = KernelCreationContext(default_dtype=default_type) @@ -622,6 +678,29 @@ def test_cfunction(): _ = typify(PsCall(threeway, (x, p))) +def test_typify_typecast(): + ctx = KernelCreationContext() + typify = Typifier(ctx) + + x, y = [PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy"] + p, q = [PsExpression.make(ctx.get_symbol(name, Int(32))) for name in "pq"] + + # Explicit target type + expr = typify(PsCast(Int(64), x)) + assert expr.dtype == expr.target_type == Int(64) + + # Infer target type from context + cast_expr = PsCast(None, p) + expr = typify(y + cast_expr) + assert expr.dtype == Fp(32) + assert cast_expr.dtype == cast_expr.target_type == Fp(32) + + # Invalid target type + expr = p + PsCast(Fp(64), q) + with pytest.raises(TypificationError): + typify(expr) + + def test_typify_integer_vectors(): ctx = KernelCreationContext() typify = Typifier(ctx) @@ -650,6 +729,22 @@ def test_typify_bool_vectors(): assert result.get_dtype() == PsVectorType(Bool(), 4) +def test_propagate_constant_type_in_broadcast(): + fp16 = Fp(16) + + for constant in [ + PsConstantFunction(ConstantFunctions.E, fp16)(), + PsConstantFunction(ConstantFunctions.PosInfinity, fp16)(), + PsConstantExpr(PsConstant(3.5, fp16)) + ]: + ctx = KernelCreationContext(default_dtype=Fp(32)) + typify = Typifier(ctx) + + expr = PsVecBroadcast(4, constant) + expr = typify(expr) + assert expr.dtype == PsVectorType(fp16, 4) + + def test_typify_horizontal_vector_reductions(): ctx = KernelCreationContext() typify = Typifier(ctx) diff --git a/tests/nbackend/transformations/test_ast_vectorizer.py b/tests/nbackend/transformations/test_ast_vectorizer.py index 3ccb479e5552bcd02954b9ed8518ef3ad0f90bfb..248f7275a9793dfe584cdb280b343279a0de1a3a 100644 --- a/tests/nbackend/transformations/test_ast_vectorizer.py +++ b/tests/nbackend/transformations/test_ast_vectorizer.py @@ -97,6 +97,22 @@ def test_vectorize_expressions(): assert subexpr.dtype == vector_type +def test_broadcast_constants(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + + ctr = ctx.get_symbol("ctr", ctx.index_dtype) + axis = VectorizationAxis(ctr) + vc = VectorizationContext(ctx, 4, axis) + vectorize = AstVectorizer(ctx) + + for constant in (sp.sympify(14), sp.pi, - sp.oo): + expr = factory.parse_sympy(constant) + vec_expr = vectorize(expr, vc) + assert isinstance(vec_expr, PsVecBroadcast) + assert vec_expr.dtype == PsVectorType(ctx.default_dtype, 4) + + def test_vectorize_casts_and_counter(): ctx = KernelCreationContext() factory = AstFactory(ctx) diff --git a/tests/runtime/test_boundary.py b/tests/runtime/test_boundary.py index 226510b83d8832a5a189552df5c8760235f0d598..422553bcafb0ca1278f70f63a725d6f1cba8f496 100644 --- a/tests/runtime/test_boundary.py +++ b/tests/runtime/test_boundary.py @@ -222,15 +222,17 @@ def test_boundary_data_setter(): assert np.all(data_setter.link_positions(1) == 6.) +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) @pytest.mark.parametrize('with_indices', ('with_indices', False)) -def test_dirichlet(with_indices): +def test_dirichlet(dtype, with_indices): value = (1, 20, 3) if with_indices else 1 dh = SerialDataHandling(domain_size=(7, 7)) - src = dh.add_array('src', values_per_cell=3 if with_indices else 1) - dh.cpu_arrays.src[...] = np.random.rand(*src.shape) + src = dh.add_array('src', values_per_cell=3 if with_indices else 1, dtype=dtype) + rng = np.random.default_rng() + dh.cpu_arrays.src[...] = rng.random(src.shape, dtype=dtype) boundary_stencil = [(1, 0), (-1, 0), (0, 1), (0, -1)] - boundary_handling = BoundaryHandling(dh, src.name, boundary_stencil) + boundary_handling = BoundaryHandling(dh, src.name, boundary_stencil, default_dtype=dtype) dirichlet = Dirichlet(value) assert dirichlet.name == 'Dirichlet' dirichlet.name = "wall"