diff --git a/docs/source/backend/index.rst b/docs/source/backend/index.rst index 1e3968bc0e4137b7a43796911de646cdddcb9ab6..e0e914b4d423fb5b9e32950185c6aa3474976d39 100644 --- a/docs/source/backend/index.rst +++ b/docs/source/backend/index.rst @@ -9,11 +9,12 @@ who wish to customize or extend the behaviour of the code generator in their app .. toctree:: :maxdepth: 1 - symbols + objects ast iteration_space translation platforms + transformations jit Internal Representation diff --git a/docs/source/backend/symbols.rst b/docs/source/backend/objects.rst similarity index 80% rename from docs/source/backend/symbols.rst rename to docs/source/backend/objects.rst index 66c8c43ba63c7740f033e7409cb5fc6f6be9bc07..b0c3af6db67ff3cfb1e6a3d3603e84e6c4abb6cb 100644 --- a/docs/source/backend/symbols.rst +++ b/docs/source/backend/objects.rst @@ -8,5 +8,8 @@ Symbols, Constants and Arrays .. autoclass:: pystencils.backend.constants.PsConstant :members: +.. autoclass:: pystencils.backend.literals.PsLiteral + :members: + .. automodule:: pystencils.backend.arrays :members: diff --git a/docs/source/backend/transformations.rst b/docs/source/backend/transformations.rst new file mode 100644 index 0000000000000000000000000000000000000000..44bf4da23e160edaac5e2dd9918fbd389aba94d6 --- /dev/null +++ b/docs/source/backend/transformations.rst @@ -0,0 +1,7 @@ +******************* +AST Transformations +******************* + +`pystencils.backend.transformations` + +.. automodule:: pystencils.backend.transformations diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index 0666d96873d4bdd3d722a7912b6e704b4aee1cf8..7bcf62b973d8ace8e9ad9847ae165c398f1cbb0e 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -5,6 +5,7 @@ import operator from ..symbols import PsSymbol from ..constants import PsConstant +from ..literals import PsLiteral from ..arrays import PsLinearizedArray, PsArrayBasePointer from ..functions import PsFunction from ...types import ( @@ -76,12 +77,19 @@ class PsExpression(PsAstNode, ABC): def make(obj: PsConstant) -> PsConstantExpr: pass + @overload + @staticmethod + def make(obj: PsLiteral) -> PsLiteralExpr: + pass + @staticmethod - def make(obj: PsSymbol | PsConstant) -> PsSymbolExpr | PsConstantExpr: + def make(obj: PsSymbol | PsConstant | PsLiteral) -> PsExpression: if isinstance(obj, PsSymbol): return PsSymbolExpr(obj) elif isinstance(obj, PsConstant): return PsConstantExpr(obj) + elif isinstance(obj, PsLiteral): + return PsLiteralExpr(obj) else: raise ValueError(f"Cannot make expression out of {obj}") @@ -150,6 +158,34 @@ class PsConstantExpr(PsLeafMixIn, PsExpression): def __repr__(self) -> str: return f"PsConstantExpr({repr(self._constant)})" + + +class PsLiteralExpr(PsLeafMixIn, PsExpression): + __match_args__ = ("literal",) + + def __init__(self, literal: PsLiteral): + super().__init__(literal.dtype) + self._literal = literal + + @property + def literal(self) -> PsLiteral: + return self._literal + + @literal.setter + def literal(self, lit: PsLiteral): + self._literal = lit + + def clone(self) -> PsLiteralExpr: + return PsLiteralExpr(self._literal) + + def structurally_equal(self, other: PsAstNode) -> bool: + if not isinstance(other, PsLiteralExpr): + return False + + return self._literal == other._literal + + def __repr__(self) -> str: + return f"PsLiteralExpr({repr(self._literal)})" class PsSubscript(PsLvalue, PsExpression): diff --git a/src/pystencils/backend/emission.py b/src/pystencils/backend/emission.py index 588ac410a6118b668eacd08114fdea3c7853ba6f..f3d56c6c4c20e5969ee10d08ee42b6803a2e0b1c 100644 --- a/src/pystencils/backend/emission.py +++ b/src/pystencils/backend/emission.py @@ -34,6 +34,7 @@ from .ast.expressions import ( PsSub, PsSubscript, PsSymbolExpr, + PsLiteralExpr, PsVectorArrayAccess, PsAnd, PsOr, @@ -245,6 +246,9 @@ class CAstPrinter: ) return dtype.create_literal(constant.value) + + case PsLiteralExpr(lit): + return lit.text case PsVectorArrayAccess(): raise EmissionError("Cannot print vectorized array accesses") diff --git a/src/pystencils/backend/functions.py b/src/pystencils/backend/functions.py index 313b622beaa151d294b8bcf6c66d730830ce2497..e420deaa657e0569996837c64a10187e323cf511 100644 --- a/src/pystencils/backend/functions.py +++ b/src/pystencils/backend/functions.py @@ -15,10 +15,13 @@ TODO: Figure out the best way to describe function signatures and overloads for """ from __future__ import annotations -from typing import Any, TYPE_CHECKING +from typing import Any, Sequence, TYPE_CHECKING from abc import ABC from enum import Enum +from ..types import PsType +from .exceptions import PsInternalCompilerError + if TYPE_CHECKING: from .ast.expressions import PsExpression @@ -69,19 +72,59 @@ class PsFunction(ABC): class CFunction(PsFunction): - """A concrete C function.""" + """A concrete C function. + + Instances of this class represent a C function by its name, parameter types, and return type. + + Args: + name: Function name + param_types: Types of the function parameters + return_type: The function's return type + """ + + __match_args__ = ("name", "parameter_types", "return_type") + + @staticmethod + def parse(obj) -> CFunction: + """Parse the signature of a Python callable object to obtain a CFunction object. + + The callable must be fully annotated with type-like objects convertible by `create_type`. + """ + import inspect + from pystencils.types import create_type - def __init__(self, qualified_name: str, arg_count: int): - self._qname = qualified_name - self._arg_count = arg_count + if not inspect.isfunction(obj): + raise PsInternalCompilerError(f"Cannot parse object {obj} as a function") + + func_sig = inspect.signature(obj) + func_name = obj.__name__ + arg_types = [ + create_type(param.annotation) for param in func_sig.parameters.values() + ] + ret_type = create_type(func_sig.return_annotation) + + return CFunction(func_name, arg_types, ret_type) + + def __init__(self, name: str, param_types: Sequence[PsType], return_type: PsType): + super().__init__(name, len(param_types)) + + self._param_types = tuple(param_types) + self._return_type = return_type @property - def qualified_name(self) -> str: - return self._qname + def parameter_types(self) -> tuple[PsType, ...]: + return self._param_types @property - def arg_count(self) -> int: - return self._arg_count + def return_type(self) -> PsType: + return self._return_type + + def __str__(self) -> str: + param_types = ", ".join(str(t) for t in self._param_types) + return f"{self._return_type} {self._name}({param_types})" + + def __repr__(self) -> str: + return f"CFunction({self._name}, {self._param_types}, {self._return_type})" class PsMathFunction(PsFunction): diff --git a/src/pystencils/backend/kernelcreation/ast_factory.py b/src/pystencils/backend/kernelcreation/ast_factory.py index c2334f54c34d476207eddc5466b2b13bff0d39d8..83c406b0a99d52cee9599f321d6c32477f6dbf8a 100644 --- a/src/pystencils/backend/kernelcreation/ast_factory.py +++ b/src/pystencils/backend/kernelcreation/ast_factory.py @@ -1,10 +1,11 @@ from typing import Any, Sequence, cast, overload +import numpy as np import sympy as sp from sympy.codegen.ast import AssignmentBase from ..ast import PsAstNode -from ..ast.expressions import PsExpression, PsSymbolExpr +from ..ast.expressions import PsExpression, PsSymbolExpr, PsConstantExpr from ..ast.structural import PsLoop, PsBlock, PsAssignment from ..symbols import PsSymbol @@ -16,6 +17,10 @@ from .typification import Typifier from .iteration_space import FullIterationSpace +IndexParsable = PsExpression | PsSymbol | PsConstant | sp.Expr | int | np.integer +_IndexParsable = (PsExpression, PsSymbol, PsConstant, sp.Expr, int, np.integer) + + class AstFactory: """Factory providing a convenient interface for building syntax trees. @@ -51,6 +56,45 @@ class AstFactory: """ return self._typify(self._freeze(sp_obj)) + @overload + def parse_index(self, idx: sp.Symbol | PsSymbol | PsSymbolExpr) -> PsSymbolExpr: + pass + + @overload + def parse_index( + self, idx: int | np.integer | PsConstant | PsConstantExpr + ) -> PsConstantExpr: + pass + + @overload + def parse_index(self, idx: sp.Expr | PsExpression) -> PsExpression: + pass + + def parse_index(self, idx: IndexParsable): + """Parse the given object as an expression with data type `ctx.index_dtype`.""" + + if not isinstance(idx, _IndexParsable): + raise TypeError( + f"Cannot parse object of type {type(idx)} as an index expression" + ) + + match idx: + case PsExpression(): + return self._typify.typify_expression(idx, self._ctx.index_dtype)[0] + case PsSymbol() | PsConstant(): + return self._typify.typify_expression( + PsExpression.make(idx), self._ctx.index_dtype + )[0] + case sp.Expr(): + return self._typify.typify_expression( + self._freeze(idx), self._ctx.index_dtype + )[0] + case _: + return PsExpression.make(PsConstant(idx, self._ctx.index_dtype)) + + def _parse_any_index(self, idx: Any) -> PsExpression: + return self.parse_index(cast(IndexParsable, idx)) + def parse_slice( self, slic: slice, upper_limit: Any | None = None ) -> tuple[PsExpression, PsExpression, PsExpression]: @@ -75,27 +119,16 @@ class AstFactory: "Must specify an upper iteration limit if `slice.stop` is `None` or a negative `int`" ) - def make_expr(val: Any) -> PsExpression: - match val: - case PsExpression(): - return self._typify.typify_expression(val, self._ctx.index_dtype)[0] - case PsSymbol() | PsConstant(): - return self._typify.typify_expression( - PsExpression.make(val), self._ctx.index_dtype - )[0] - case sp.Expr(): - return self._typify.typify_expression( - self._freeze(val), self._ctx.index_dtype - )[0] - case _: - return PsExpression.make(PsConstant(val, self._ctx.index_dtype)) - - start = make_expr(slic.start if slic.start is not None else 0) - stop = make_expr(slic.stop) if slic.stop is not None else make_expr(upper_limit) - step = make_expr(slic.step if slic.step is not None else 1) + start = self._parse_any_index(slic.start if slic.start is not None else 0) + stop = ( + self._parse_any_index(slic.stop) + if slic.stop is not None + else self._parse_any_index(upper_limit) + ) + step = self._parse_any_index(slic.step if slic.step is not None else 1) if isinstance(slic.stop, int) and slic.stop < 0: - stop = make_expr(upper_limit) + stop + stop = self._parse_any_index(upper_limit) + stop return start, stop, step diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index d48953a5b2ccaf1325006f09ec63f5f8fea90f26..263c2f48ecfff15dd4c6271f77ca2a7578b86d09 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -1,9 +1,10 @@ from __future__ import annotations from typing import Iterable, Iterator -from itertools import chain +from itertools import chain, count from types import EllipsisType -from collections import namedtuple +from collections import namedtuple, defaultdict +import re from ...defaults import DEFAULTS from ...field import Field, FieldType @@ -67,6 +68,9 @@ class KernelCreationContext: self._symbols: dict[str, PsSymbol] = dict() + self._symbol_ctr_pattern = re.compile(r"__[0-9]+$") + self._symbol_dup_table: defaultdict[str, int] = defaultdict(lambda: 0) + self._fields_and_arrays: dict[str, FieldArrayPair] = dict() self._fields_collection = FieldsInKernel() @@ -95,6 +99,21 @@ class KernelCreationContext: # Symbols def get_symbol(self, name: str, dtype: PsType | None = None) -> PsSymbol: + """Retrieve the symbol with the given name and data type from the symbol table. + + If no symbol named ``name`` exists, a new symbol with the given data type is created. + + If a symbol with the given ``name`` already exists and ``dtype`` is not `None`, + the given data type will be applied to it, and it is returned. + If the symbol already has a different data type, an error will be raised. + + If the symbol already exists and ``dtype`` is `None`, the existing symbol is returned + without checking or altering its data type. + + Args: + name: The symbol's name + dtype: The symbol's data type, or `None` + """ if name not in self._symbols: symb = PsSymbol(name, None) self._symbols[name] = symb @@ -115,12 +134,20 @@ class KernelCreationContext: return self._symbols.get(name, None) def add_symbol(self, symbol: PsSymbol): + """Add an existing symbol to the symbol table. + + If a symbol with the same name already exists, an error will be raised. + """ if symbol.name in self._symbols: raise PsInternalCompilerError(f"Duplicate symbol: {symbol.name}") self._symbols[symbol.name] = symbol def replace_symbol(self, old: PsSymbol, new: PsSymbol): + """Replace one symbol by another. + + The two symbols ``old`` and ``new`` must have the same name, but may have different data types. + """ if old.name != new.name: raise PsInternalCompilerError( "replace_symbol: Old and new symbol must have the same name" @@ -131,8 +158,30 @@ class KernelCreationContext: self._symbols[old.name] = new + def duplicate_symbol(self, symb: PsSymbol) -> PsSymbol: + """Canonically duplicates the given symbol. + + A new symbol with the same data type, and new name ``symb.name + "__<counter>"`` is created, + added to the symbol table, and returned. + The ``counter`` reflects the number of previously created duplicates of this symbol. + """ + if (result := self._symbol_ctr_pattern.search(symb.name)) is not None: + span = result.span() + basename = symb.name[: span[0]] + else: + basename = symb.name + + initial_count = self._symbol_dup_table[basename] + for i in count(initial_count): + dup_name = f"{basename}__{i}" + if self.find_symbol(dup_name) is None: + self._symbol_dup_table[basename] = i + 1 + return self.get_symbol(dup_name, symb.dtype) + assert False, "unreachable code" + @property def symbols(self) -> Iterable[PsSymbol]: + """Return an iterable of all symbols listed in the symbol table.""" return self._symbols.values() # Fields and Arrays diff --git a/src/pystencils/backend/kernelcreation/iteration_space.py b/src/pystencils/backend/kernelcreation/iteration_space.py index ba215f822ea7372211bf764425d44e44487cc46b..6adac2a519ffc04505c0e0adac3484d78f30d013 100644 --- a/src/pystencils/backend/kernelcreation/iteration_space.py +++ b/src/pystencils/backend/kernelcreation/iteration_space.py @@ -121,7 +121,7 @@ class FullIterationSpace(IterationSpace): @staticmethod def create_from_slice( ctx: KernelCreationContext, - iteration_slice: Sequence[slice], + iteration_slice: slice | Sequence[slice], archetype_field: Field | None = None, ): """Create an iteration space from a sequence of slices, optionally over an archetype field. @@ -131,6 +131,9 @@ class FullIterationSpace(IterationSpace): iteration_slice: The iteration slices for each dimension; for valid formats, see `AstFactory.parse_slice` archetype_field: Optionally, an archetype field that dictates the upper slice limits and loop order. """ + if isinstance(iteration_slice, slice): + iteration_slice = (iteration_slice,) + dim = len(iteration_slice) if dim == 0: raise ValueError( diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index 1bf3c49807ff52a34bc9ab319f5da67e4fa59ebc..dbec20235f0a37cfab763771e2e2fbed05a3c196 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -39,11 +39,12 @@ from ..ast.expressions import ( PsLookup, PsSubscript, PsSymbolExpr, + PsLiteralExpr, PsRel, PsNeg, PsNot, ) -from ..functions import PsMathFunction +from ..functions import PsMathFunction, CFunction __all__ = ["Typifier"] @@ -158,6 +159,14 @@ class TypeContext: f" Constant type: {c.dtype}\n" f" Target type: {self._target_type}" ) + + case PsLiteralExpr(lit): + if not self._compatible(lit.dtype): + raise TypificationError( + f"Type mismatch at literal {lit}: Literal type did not match the context's target type\n" + f" Literal type: {lit.dtype}\n" + f" Target type: {self._target_type}" + ) case PsSymbolExpr(symb): assert symb.dtype is not None @@ -356,6 +365,9 @@ class Typifier: else: tc.infer_dtype(expr) + case PsLiteralExpr(lit): + tc.apply_dtype(lit.dtype, expr) + case PsArrayAccess(bptr, idx): tc.apply_dtype(bptr.array.element_type, expr) @@ -467,6 +479,14 @@ class Typifier: for arg in args: self.visit_expr(arg, tc) tc.infer_dtype(expr) + + case CFunction(_, arg_types, ret_type): + tc.apply_dtype(ret_type, expr) + + for arg, arg_type in zip(args, arg_types, strict=True): + arg_tc = TypeContext(arg_type) + self.visit_expr(arg, arg_tc) + case _: raise TypificationError( f"Don't know how to typify calls to {function}" @@ -494,6 +514,7 @@ class Typifier: ) else: items_tc.apply_dtype(tc.target_type.base_type) + tc.infer_dtype(expr) else: arr_type = PsArrayType(items_tc.target_type, len(items)) tc.apply_dtype(arr_type, expr) diff --git a/src/pystencils/backend/literals.py b/src/pystencils/backend/literals.py new file mode 100644 index 0000000000000000000000000000000000000000..dc7504f520f8950b46df76b0359aaad371244b19 --- /dev/null +++ b/src/pystencils/backend/literals.py @@ -0,0 +1,43 @@ +from __future__ import annotations +from ..types import PsType, constify + + +class PsLiteral: + """Representation of literal code. + + Instances of this class represent code literals inside the AST. + These literals are not to be confused with C literals; the name `Literal` refers to the fact that + the code generator takes them "literally", printing them as they are. + + Each literal has to be annotated with a type, and is considered constant within the scope of a kernel. + Instances of `PsLiteral` are immutable. + """ + + __match_args__ = ("text", "dtype") + + def __init__(self, text: str, dtype: PsType) -> None: + self._text = text + self._dtype = constify(dtype) + + @property + def text(self) -> str: + return self._text + + @property + def dtype(self) -> PsType: + return self._dtype + + def __str__(self) -> str: + return f"{self._text}: {self._dtype}" + + def __repr__(self) -> str: + return f"PsLiteral({repr(self._text)}, {repr(self._dtype)})" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, PsLiteral): + return False + + return self._text == other._text and self._dtype == other._dtype + + def __hash__(self) -> int: + return hash((PsLiteral, self._text, self._dtype)) diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py index 6899ac9474d303623f79e2bdc7c3765c64380a6c..17589bf27d3109a3ce891acfdc248e0120da2f77 100644 --- a/src/pystencils/backend/platforms/generic_cpu.py +++ b/src/pystencils/backend/platforms/generic_cpu.py @@ -55,6 +55,7 @@ class GenericCpu(Platform): self, math_function: PsMathFunction, dtype: PsType ) -> CFunction: func = math_function.func + arg_types = (dtype,) * func.num_args if isinstance(dtype, PsIeeeFloatType) and dtype.width in (32, 64): match func: case ( @@ -64,9 +65,9 @@ class GenericCpu(Platform): | MathFunctions.Tan | MathFunctions.Pow ): - return CFunction(func.function_name, func.num_args) + return CFunction(func.function_name, arg_types, dtype) case MathFunctions.Abs | MathFunctions.Min | MathFunctions.Max: - return CFunction("f" + func.function_name, func.num_args) + return CFunction("f" + func.function_name, arg_types, dtype) raise MaterializationError( f"No implementation available for function {math_function} on data type {dtype}" diff --git a/src/pystencils/backend/platforms/generic_gpu.py b/src/pystencils/backend/platforms/generic_gpu.py index 839cd34f4c9fada060a0ac6253b635f7c7812948..6e860d32b07de8d7993ee96b103ce2d0983766d8 100644 --- a/src/pystencils/backend/platforms/generic_gpu.py +++ b/src/pystencils/backend/platforms/generic_gpu.py @@ -11,26 +11,26 @@ from ..kernelcreation.iteration_space import ( from ..ast.structural import PsBlock, PsConditional from ..ast.expressions import ( PsExpression, - PsSymbolExpr, + PsLiteralExpr, PsAdd, ) from ..ast.expressions import PsLt, PsAnd from ...types import PsSignedIntegerType -from ..symbols import PsSymbol +from ..literals import PsLiteral int32 = PsSignedIntegerType(width=32, const=False) BLOCK_IDX = [ - PsSymbolExpr(PsSymbol(f"blockIdx.{coord}", int32)) for coord in ("x", "y", "z") + PsLiteralExpr(PsLiteral(f"blockIdx.{coord}", int32)) for coord in ("x", "y", "z") ] THREAD_IDX = [ - PsSymbolExpr(PsSymbol(f"threadIdx.{coord}", int32)) for coord in ("x", "y", "z") + PsLiteralExpr(PsLiteral(f"threadIdx.{coord}", int32)) for coord in ("x", "y", "z") ] BLOCK_DIM = [ - PsSymbolExpr(PsSymbol(f"blockDim.{coord}", int32)) for coord in ("x", "y", "z") + PsLiteralExpr(PsLiteral(f"blockDim.{coord}", int32)) for coord in ("x", "y", "z") ] GRID_DIM = [ - PsSymbolExpr(PsSymbol(f"gridDim.{coord}", int32)) for coord in ("x", "y", "z") + PsLiteralExpr(PsLiteral(f"gridDim.{coord}", int32)) for coord in ("x", "y", "z") ] diff --git a/src/pystencils/backend/platforms/x86.py b/src/pystencils/backend/platforms/x86.py index fa5af4655810943c47f503329a8c41ce3baa36c5..ccaf9fbe99f46ce4b0ecbb81c775c9f274678026 100644 --- a/src/pystencils/backend/platforms/x86.py +++ b/src/pystencils/backend/platforms/x86.py @@ -10,7 +10,7 @@ from ..ast.expressions import ( PsSubscript, ) from ..transformations.select_intrinsics import IntrinsicOps -from ...types import PsCustomType, PsVectorType +from ...types import PsCustomType, PsVectorType, PsPointerType from ..constants import PsConstant from ..exceptions import MaterializationError @@ -124,10 +124,13 @@ class X86VectorCpu(GenericVectorCpu): def constant_vector(self, c: PsConstant) -> PsExpression: vtype = c.dtype assert isinstance(vtype, PsVectorType) + stype = vtype.scalar_type prefix = self._vector_arch.intrin_prefix(vtype) suffix = self._vector_arch.intrin_suffix(vtype) - set_func = CFunction(f"{prefix}_set_{suffix}", vtype.vector_entries) + set_func = CFunction( + f"{prefix}_set_{suffix}", (stype,) * vtype.vector_entries, vtype + ) values = c.value return set_func(*values) @@ -164,7 +167,10 @@ def _x86_packed_load( ) -> CFunction: prefix = varch.intrin_prefix(vtype) suffix = varch.intrin_suffix(vtype) - return CFunction(f"{prefix}_load{'' if aligned else 'u'}_{suffix}", 1) + ptr_type = PsPointerType(vtype.scalar_type, const=True) + return CFunction( + f"{prefix}_load{'' if aligned else 'u'}_{suffix}", (ptr_type,), vtype + ) @cache @@ -173,7 +179,12 @@ def _x86_packed_store( ) -> CFunction: prefix = varch.intrin_prefix(vtype) suffix = varch.intrin_suffix(vtype) - return CFunction(f"{prefix}_store{'' if aligned else 'u'}_{suffix}", 2) + ptr_type = PsPointerType(vtype.scalar_type, const=True) + return CFunction( + f"{prefix}_store{'' if aligned else 'u'}_{suffix}", + (ptr_type, vtype), + PsCustomType("void"), + ) @cache @@ -197,4 +208,5 @@ def _x86_op_intrin( case _: assert False - return CFunction(f"{prefix}_{opstr}_{suffix}", 3 if op == IntrinsicOps.FMA else 2) + num_args = 3 if op == IntrinsicOps.FMA else 2 + return CFunction(f"{prefix}_{opstr}_{suffix}", (vtype,) * num_args, vtype) diff --git a/src/pystencils/backend/transformations/__init__.py b/src/pystencils/backend/transformations/__init__.py index 01b69509991eaa762a093f50f427f6e4050dc34a..518c402d27e8828cc129c06a20bd10e2ba3d3168 100644 --- a/src/pystencils/backend/transformations/__init__.py +++ b/src/pystencils/backend/transformations/__init__.py @@ -1,16 +1,94 @@ +""" +This module contains various transformation and optimization passes that can be +executed on the backend AST. + +Canonical Form +============== + +Many transformations in this module require that their input AST is in *canonical form*. +This means that: + +- Each symbol, constant, and expression node is annotated with a data type; +- Each symbol has at most one declaration; +- Each symbol that is never written to apart from its declaration has a ``const`` type; and +- Each symbol whose type is *not* ``const`` has at least one non-declaring assignment. + +The first requirement can be ensured by running the `Typifier` on each newly constructed subtree. +The other three requirements are ensured by the `CanonicalizeSymbols` pass, +which should be run first before applying any optimizing transformations. +All transformations in this module retain canonicality of the AST. + +Canonicality allows transformations to forego various checks that would otherwise be necessary +to prove their legality. + +Certain transformations, like the auto-vectorizer (TODO), state additional requirements, e.g. +the absence of loop-carried dependencies. + +Transformations +=============== + +Canonicalization +---------------- + +.. autoclass:: CanonicalizeSymbols + :members: __call__ + +AST Cloning +----------- + +.. autoclass:: CanonicalClone + :members: __call__ + +Simplifying Transformations +--------------------------- + +.. autoclass:: EliminateConstants + :members: __call__ + +.. autoclass:: EliminateBranches + :members: __call__ + +Code Motion +----------- + +.. autoclass:: HoistLoopInvariantDeclarations + :members: __call__ + +Loop Reshaping Transformations +------------------------------ + +.. autoclass:: ReshapeLoops + :members: + + +Code Lowering and Materialization +--------------------------------- + +.. autoclass:: EraseAnonymousStructTypes + :members: __call__ + +.. autoclass:: SelectFunctions + :members: __call__ + +""" + +from .canonicalize_symbols import CanonicalizeSymbols +from .canonical_clone import CanonicalClone from .eliminate_constants import EliminateConstants from .eliminate_branches import EliminateBranches -from .canonicalize_symbols import CanonicalizeSymbols from .hoist_loop_invariant_decls import HoistLoopInvariantDeclarations +from .reshape_loops import ReshapeLoops from .erase_anonymous_structs import EraseAnonymousStructTypes from .select_functions import SelectFunctions from .select_intrinsics import MaterializeVectorIntrinsics __all__ = [ + "CanonicalizeSymbols", + "CanonicalClone", "EliminateConstants", "EliminateBranches", - "CanonicalizeSymbols", "HoistLoopInvariantDeclarations", + "ReshapeLoops", "EraseAnonymousStructTypes", "SelectFunctions", "MaterializeVectorIntrinsics", diff --git a/src/pystencils/backend/transformations/canonical_clone.py b/src/pystencils/backend/transformations/canonical_clone.py new file mode 100644 index 0000000000000000000000000000000000000000..538bb2779314fc0fe1d7b83810dd6a4b031ca46a --- /dev/null +++ b/src/pystencils/backend/transformations/canonical_clone.py @@ -0,0 +1,112 @@ +from typing import TypeVar, cast + +from ..kernelcreation import KernelCreationContext +from ..symbols import PsSymbol +from ..exceptions import PsInternalCompilerError + +from ..ast import PsAstNode +from ..ast.structural import ( + PsBlock, + PsConditional, + PsLoop, + PsDeclaration, + PsAssignment, + PsComment, +) +from ..ast.expressions import PsExpression, PsSymbolExpr + +__all__ = ["CanonicalClone"] + + +class CloneContext: + def __init__(self, ctx: KernelCreationContext) -> None: + self._ctx = ctx + self._dup_table: dict[PsSymbol, PsSymbol] = dict() + + def symbol_decl(self, declared_symbol: PsSymbol): + self._dup_table[declared_symbol] = self._ctx.duplicate_symbol(declared_symbol) + + def get_replacement(self, symb: PsSymbol): + return self._dup_table.get(symb, symb) + + +Node_T = TypeVar("Node_T", bound=PsAstNode) + + +class CanonicalClone: + """Clone a subtree, and rename all symbols declared inside it to retain canonicality.""" + + def __init__(self, ctx: KernelCreationContext) -> None: + self._ctx = ctx + + def __call__(self, node: Node_T) -> Node_T: + return self.visit(node, CloneContext(self._ctx)) + + def visit(self, node: Node_T, cc: CloneContext) -> Node_T: + match node: + case PsBlock(statements): + return cast( + Node_T, PsBlock([self.visit(stmt, cc) for stmt in statements]) + ) + + case PsLoop(ctr, start, stop, step, body): + cc.symbol_decl(ctr.symbol) + return cast( + Node_T, + PsLoop( + self.visit(ctr, cc), + self.visit(start, cc), + self.visit(stop, cc), + self.visit(step, cc), + self.visit(body, cc), + ), + ) + + case PsConditional(cond, then, els): + return cast( + Node_T, + PsConditional( + self.visit(cond, cc), + self.visit(then, cc), + self.visit(els, cc) if els is not None else None, + ), + ) + + case PsComment(): + return cast(Node_T, node.clone()) + + case PsDeclaration(lhs, rhs): + cc.symbol_decl(node.declared_symbol) + return cast( + Node_T, + PsDeclaration( + cast(PsSymbolExpr, self.visit(lhs, cc)), + self.visit(rhs, cc), + ), + ) + + case PsAssignment(lhs, rhs): + return cast( + Node_T, + PsAssignment( + self.visit(lhs, cc), + self.visit(rhs, cc), + ), + ) + + case PsExpression(): + expr_clone = node.clone() + self._replace_symbols(expr_clone, cc) + return cast(Node_T, expr_clone) + + case _: + raise PsInternalCompilerError( + f"Don't know how to canonically clone {type(node)}" + ) + + def _replace_symbols(self, expr: PsExpression, cc: CloneContext): + if isinstance(expr, PsSymbolExpr): + expr.symbol = cc.get_replacement(expr.symbol) + else: + for c in expr.children: + self._replace_symbols(cast(PsExpression, c), cc) diff --git a/src/pystencils/backend/transformations/canonicalize_symbols.py b/src/pystencils/backend/transformations/canonicalize_symbols.py index 6fe922f28cfbfd0a42c77dbd3d1606c910abf298..3900105b8f64b7cd33e154de02eaec7cf826d0fb 100644 --- a/src/pystencils/backend/transformations/canonicalize_symbols.py +++ b/src/pystencils/backend/transformations/canonicalize_symbols.py @@ -1,5 +1,3 @@ -from itertools import count - from ..kernelcreation import KernelCreationContext from ..symbols import PsSymbol from ..exceptions import PsInternalCompilerError @@ -29,14 +27,10 @@ class CanonContext: self.live_symbols_map[symb] = symb return symb else: - for i in count(): - replacement_name = f"{symb.name}__{i}" - if self._ctx.find_symbol(replacement_name) is None: - replacement = self._ctx.get_symbol(replacement_name, symb.dtype) - self.live_symbols_map[symb] = replacement - self.encountered_symbols.add(replacement) - return replacement - assert False, "unreachable code" + replacement = self._ctx.duplicate_symbol(symb) + self.live_symbols_map[symb] = replacement + self.encountered_symbols.add(replacement) + return replacement def mark_as_updated(self, symb: PsSymbol): self.updated_symbols.add(self.deduplicate(symb)) diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py index 7678dbd8c6ce783585fb7095b201e9f92e65e485..ddfa33f08272d59d032e1e657e66baa96fb41d04 100644 --- a/src/pystencils/backend/transformations/eliminate_constants.py +++ b/src/pystencils/backend/transformations/eliminate_constants.py @@ -1,4 +1,4 @@ -from typing import cast, Iterable +from typing import cast, Iterable, overload from collections import defaultdict from ..kernelcreation import KernelCreationContext, Typifier @@ -9,6 +9,7 @@ from ..ast.expressions import ( PsExpression, PsConstantExpr, PsSymbolExpr, + PsLiteralExpr, PsBinOp, PsAdd, PsSub, @@ -116,6 +117,14 @@ class EliminateConstants: self._fold_floats = False self._extract_constant_exprs = extract_constant_exprs + @overload + def __call__(self, node: PsExpression) -> PsExpression: + pass + + @overload + def __call__(self, node: PsAstNode) -> PsAstNode: + pass + def __call__(self, node: PsAstNode) -> PsAstNode: ecc = ECContext(self._ctx) @@ -151,8 +160,8 @@ class EliminateConstants: Returns: (transformed_expr, is_const): The tranformed expression, and a flag indicating whether it is constant """ - # Return constants as they are - if isinstance(expr, PsConstantExpr): + # Return constants and literals as they are + if isinstance(expr, (PsConstantExpr, PsLiteralExpr)): return expr, True # Shortcut symbols @@ -243,7 +252,6 @@ class EliminateConstants: # Detect constant expressions if all(subtree_constness): dtype = expr.get_dtype() - assert isinstance(dtype, PsNumericType) is_int = isinstance(dtype, PsIntegerType) is_float = isinstance(dtype, PsIeeeFloatType) @@ -266,6 +274,7 @@ class EliminateConstants: py_operator = expr.python_operator if do_fold and py_operator is not None: + assert isinstance(dtype, PsNumericType) folded = PsConstant(py_operator(val), dtype) return self._typify(PsConstantExpr(folded)), True @@ -279,6 +288,7 @@ class EliminateConstants: v2 = op2.constant.value if do_fold: + assert isinstance(dtype, PsNumericType) py_operator = expr.python_operator folded = None @@ -308,7 +318,7 @@ class EliminateConstants: # If required, extract constant subexpressions if self._extract_constant_exprs: for i, (child, is_const) in enumerate(subtree_results): - if is_const and not isinstance(child, PsConstantExpr): + if is_const and not isinstance(child, (PsConstantExpr, PsLiteralExpr)): replacement = ecc.extract_expression(child) expr.set_child(i, replacement) diff --git a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py index 5920038150ee6c5295c3f46f4530639ed02fca25..cb9c9e92064d2061198f66cf5dc893d6491e1b34 100644 --- a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py +++ b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py @@ -7,6 +7,7 @@ from ..ast.expressions import ( PsExpression, PsSymbolExpr, PsConstantExpr, + PsLiteralExpr, PsCall, PsDeref, PsSubscript, @@ -40,7 +41,7 @@ class HoistContext: symbol in self.invariant_symbols ) - case PsConstantExpr(): + case PsConstantExpr() | PsLiteralExpr(): return True case PsCall(func): @@ -69,7 +70,7 @@ class HoistLoopInvariantDeclarations: in particular, each symbol may have at most one declaration. To ensure this, a `CanonicalizeSymbols` pass should be run before `HoistLoopInvariantDeclarations`. - `HoistLoopInvariants` assumes that all `PsMathFunction`s are pure (have no side effects), + `HoistLoopInvariantDeclarations` assumes that all `PsMathFunction` s are pure (have no side effects), but makes no such assumption about instances of `CFunction`. """ diff --git a/src/pystencils/backend/transformations/reshape_loops.py b/src/pystencils/backend/transformations/reshape_loops.py new file mode 100644 index 0000000000000000000000000000000000000000..6963bee0b2e43bc6bac58a6c96de5f4a35e57148 --- /dev/null +++ b/src/pystencils/backend/transformations/reshape_loops.py @@ -0,0 +1,138 @@ +from typing import Sequence + +from ..kernelcreation import KernelCreationContext, Typifier +from ..kernelcreation.ast_factory import AstFactory, IndexParsable + +from ..ast.structural import PsLoop, PsBlock, PsConditional, PsDeclaration +from ..ast.expressions import PsExpression, PsConstantExpr, PsLt +from ..constants import PsConstant + +from .canonical_clone import CanonicalClone, CloneContext +from .eliminate_constants import EliminateConstants + + +class ReshapeLoops: + """Various transformations for reshaping loop nests.""" + + def __init__(self, ctx: KernelCreationContext) -> None: + self._ctx = ctx + self._typify = Typifier(ctx) + self._factory = AstFactory(ctx) + self._canon_clone = CanonicalClone(ctx) + self._elim_constants = EliminateConstants(ctx) + + def peel_loop_front( + self, loop: PsLoop, num_iterations: int, omit_range_check: bool = False + ) -> tuple[Sequence[PsBlock], PsLoop]: + """Peel off iterations from the front of a loop. + + Removes ``num_iterations`` from the front of the given loop and returns them as a sequence of + independent blocks. + + Args: + loop: The loop node from which to peel iterations + num_iterations: The number of iterations to peel off + omit_range_check: If set to `True`, assume that the peeled-off iterations will always + be executed, and omit their enclosing conditional. + + Returns: + Tuple containing the peeled-off iterations as a sequence of blocks, + and the remaining loop. + """ + + peeled_iters: list[PsBlock] = [] + + for i in range(num_iterations): + cc = CloneContext(self._ctx) + cc.symbol_decl(loop.counter.symbol) + peeled_ctr = self._factory.parse_index( + cc.get_replacement(loop.counter.symbol) + ) + peeled_idx = self._typify(loop.start + PsExpression.make(PsConstant(i))) + + counter_decl = PsDeclaration(peeled_ctr, peeled_idx) + peeled_block = self._canon_clone.visit(loop.body, cc) + + if omit_range_check: + peeled_block.statements = [counter_decl] + peeled_block.statements + else: + iter_condition = PsLt(peeled_ctr, loop.stop) + peeled_block.statements = [ + counter_decl, + PsConditional(iter_condition, PsBlock(peeled_block.statements)), + ] + + peeled_iters.append(peeled_block) + + loop.start = self._elim_constants( + self._typify(loop.start + PsExpression.make(PsConstant(num_iterations))) + ) + + return peeled_iters, loop + + def cut_loop( + self, loop: PsLoop, cutting_points: Sequence[IndexParsable] + ) -> Sequence[PsLoop | PsBlock]: + """Cut a loop at the given cutting points. + + Cut the given loop at the iterations specified by the given cutting points, + producing ``n`` new subtrees representing the iterations + ``(loop.start:cutting_points[0]), (cutting_points[0]:cutting_points[1]), ..., (cutting_points[-1]:loop.stop)``. + + Resulting subtrees representing zero iterations are dropped; subtrees representing exactly one iteration are + returned without the trivial loop structure. + + Currently, `cut_loop` performs no checks to ensure that the given cutting points are in fact inside + the loop's iteration range. + + Returns: + Sequence of ``n`` subtrees representing the respective iteration ranges + """ + + if not ( + isinstance(loop.step, PsConstantExpr) and loop.step.constant.value == 1 + ): + raise NotImplementedError( + "Loop cutting for loops with step != 1 is not implemented" + ) + + result: list[PsLoop | PsBlock] = [] + new_start = loop.start + cutting_points = [self._factory.parse_index(idx) for idx in cutting_points] + [ + loop.stop + ] + + for new_end in cutting_points: + if new_end.structurally_equal(new_start): + continue + + num_iters = self._elim_constants(self._typify(new_end - new_start)) + skip = False + + if isinstance(num_iters, PsConstantExpr): + if num_iters.constant.value == 0: + skip = True + elif num_iters.constant.value == 1: + skip = True + cc = CloneContext(self._ctx) + cc.symbol_decl(loop.counter.symbol) + local_counter = self._factory.parse_index( + cc.get_replacement(loop.counter.symbol) + ) + ctr_decl = PsDeclaration( + local_counter, + new_start, + ) + cloned_body = self._canon_clone.visit(loop.body, cc) + cloned_body.statements = [ctr_decl] + cloned_body.statements + result.append(cloned_body) + + if not skip: + loop_clone = self._canon_clone(loop) + loop_clone.start = new_start.clone() + loop_clone.stop = new_end.clone() + result.append(loop_clone) + + new_start = new_end + + return result diff --git a/tests/nbackend/kernelcreation/test_iteration_space.py b/tests/nbackend/kernelcreation/test_iteration_space.py index 1dfcfea2a9bfd8e9db70dab7ca61732523855843..7fd6d778ff62f7fb2fcbc24a55af5225fb9f870e 100644 --- a/tests/nbackend/kernelcreation/test_iteration_space.py +++ b/tests/nbackend/kernelcreation/test_iteration_space.py @@ -52,7 +52,7 @@ def test_invalid_slices(): ctx.add_field(archetype_field) islice = (slice(1, -1, 0.5),) - with pytest.raises(PsTypeError): + with pytest.raises(TypeError): FullIterationSpace.create_from_slice(ctx, islice, archetype_field) islice = (slice(1, -1, TypedSymbol("w", dtype=create_type("double"))),) diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py index 60d0d6e7424bdfea730cafe18995afdb7dc253df..01f68c0a3e637e3139990f9208710e9861243e9d 100644 --- a/tests/nbackend/kernelcreation/test_typification.py +++ b/tests/nbackend/kernelcreation/test_typification.py @@ -26,10 +26,12 @@ from pystencils.backend.ast.expressions import ( PsLe, PsGt, PsLt, + PsCall, ) from pystencils.backend.constants import PsConstant +from pystencils.backend.functions import CFunction from pystencils.types import constify, create_type, create_numeric_type -from pystencils.types.quick import Fp, Bool +from pystencils.types.quick import Fp, Int, Bool from pystencils.backend.kernelcreation.context import KernelCreationContext from pystencils.backend.kernelcreation.freeze import FreezeExpressions from pystencils.backend.kernelcreation.typification import Typifier, TypificationError @@ -354,7 +356,7 @@ def test_invalid_conditions(): x, y = [PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy"] p, q = [PsExpression.make(ctx.get_symbol(name, Bool())) for name in "pq"] - + cond = PsConditional(x + y, PsBlock([])) with pytest.raises(TypificationError): typify(cond) @@ -362,3 +364,24 @@ def test_invalid_conditions(): cond = PsConditional(PsAnd(p, PsOr(x, q)), PsBlock([])) with pytest.raises(TypificationError): typify(cond) + + +def test_cfunction(): + ctx = KernelCreationContext() + typify = Typifier(ctx) + x, y = [PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy"] + p, q = [PsExpression.make(ctx.get_symbol(name, Int(32))) for name in "pq"] + + def _threeway(x: np.float32, y: np.float32) -> np.int32: + assert False + + threeway = CFunction.parse(_threeway) + + result = typify(PsCall(threeway, [x, y])) + + assert result.get_dtype() == Int(32, const=True) + assert result.args[0].get_dtype() == Fp(32, const=True) + assert result.args[1].get_dtype() == Fp(32, const=True) + + with pytest.raises(TypificationError): + _ = typify(PsCall(threeway, (x, p))) diff --git a/tests/nbackend/test_functions.py b/tests/nbackend/test_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..e88e51e48244f557d19e8c3ebbc5129d335c40a5 --- /dev/null +++ b/tests/nbackend/test_functions.py @@ -0,0 +1,77 @@ +import sympy as sp +import numpy as np +import pytest + +from pystencils import create_kernel, CreateKernelConfig, Target, Assignment, Field + +UNARY_FUNCTIONS = { + "exp": (sp.exp, np.exp), + "sin": (sp.sin, np.sin), + "cos": (sp.cos, np.cos), + "tan": (sp.tan, np.tan), + "abs": (sp.Abs, np.abs), +} + +BINARY_FUNCTIONS = { + "min": (sp.Min, np.fmin), + "max": (sp.Max, np.fmax), + "pow": (sp.Pow, np.power), +} + + +@pytest.mark.parametrize("target", (Target.GenericCPU,)) +@pytest.mark.parametrize("function_name", UNARY_FUNCTIONS.keys()) +@pytest.mark.parametrize("dtype", (np.float32, np.float64)) +def test_unary_functions(target, function_name, dtype): + sp_func, np_func = UNARY_FUNCTIONS[function_name] + resolution: dtype = np.finfo(dtype).resolution + + inp = np.array( + [[0.1, 0.2, 0.3], [-0.8, -1.6, -12.592], [np.pi, np.e, 0.0]], dtype=dtype + ) + outp = np.zeros_like(inp) + + reference = np_func(inp) + + inp_field = Field.create_from_numpy_array("inp", inp) + outp_field = inp_field.new_field_with_different_name("outp") + + asms = [Assignment(outp_field.center(), sp_func(inp_field.center()))] + gen_config = CreateKernelConfig(target=target, default_dtype=dtype) + + kernel = create_kernel(asms, gen_config) + kfunc = kernel.compile() + kfunc(inp=inp, outp=outp) + + np.testing.assert_allclose(outp, reference, rtol=resolution) + + +@pytest.mark.parametrize("target", (Target.GenericCPU,)) +@pytest.mark.parametrize("function_name", BINARY_FUNCTIONS.keys()) +@pytest.mark.parametrize("dtype", (np.float32, np.float64)) +def test_binary_functions(target, function_name, dtype): + sp_func, np_func = BINARY_FUNCTIONS[function_name] + resolution: dtype = np.finfo(dtype).resolution + + inp = np.array( + [[0.1, 0.2, 0.3], [-0.8, -1.6, -12.592], [np.pi, np.e, 0.0]], dtype=dtype + ) + inp2 = np.array( + [[3.1, -0.5, 21.409], [11.0, 1.0, -14e3], [2.0 * np.pi, - np.e, 0.0]], dtype=dtype + ) + outp = np.zeros_like(inp) + + reference = np_func(inp, inp2) + + inp_field = Field.create_from_numpy_array("inp", inp) + inp2_field = Field.create_from_numpy_array("inp2", inp) + outp_field = inp_field.new_field_with_different_name("outp") + + asms = [Assignment(outp_field.center(), sp_func(inp_field.center(), inp2_field.center()))] + gen_config = CreateKernelConfig(target=target, default_dtype=dtype) + + kernel = create_kernel(asms, gen_config) + kfunc = kernel.compile() + kfunc(inp=inp, inp2=inp2, outp=outp) + + np.testing.assert_allclose(outp, reference, rtol=resolution) diff --git a/tests/nbackend/transformations/test_canonical_clone.py b/tests/nbackend/transformations/test_canonical_clone.py new file mode 100644 index 0000000000000000000000000000000000000000..b158b91781b49f8d589a3da3b266e8c2137fceab --- /dev/null +++ b/tests/nbackend/transformations/test_canonical_clone.py @@ -0,0 +1,63 @@ +import sympy as sp +from pystencils import Field, Assignment, make_slice, TypedSymbol +from pystencils.types.quick import Arr + +from pystencils.backend.kernelcreation import ( + KernelCreationContext, + AstFactory, + FullIterationSpace, +) +from pystencils.backend.transformations import CanonicalClone +from pystencils.backend.ast.structural import PsBlock, PsComment +from pystencils.backend.ast.expressions import PsSymbolExpr +from pystencils.backend.ast.iteration import dfs_preorder + + +def test_clone_entire_ast(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + canon_clone = CanonicalClone(ctx) + + f = Field.create_generic("f", 2, index_shape=(5,)) + rho = sp.Symbol("rho") + u = sp.symbols("u_:2") + + cx = TypedSymbol("cx", Arr(ctx.default_dtype)) + cy = TypedSymbol("cy", Arr(ctx.default_dtype)) + cxs = sp.IndexedBase(cx, shape=(5,)) + cys = sp.IndexedBase(cy, shape=(5,)) + + rho_out = Field.create_generic("rho", 2, index_shape=(1,)) + u_out = Field.create_generic("u", 2, index_shape=(2,)) + + ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:, :], f) + ctx.set_iteration_space(ispace) + + asms = [ + Assignment(cx, (0, 1, -1, 0, 0)), + Assignment(cy, (0, 0, 0, 1, -1)), + Assignment(rho, sum(f.center(i) for i in range(5))), + Assignment(u[0], 1 / rho * sum((f.center(i) * cxs[i]) for i in range(5))), + Assignment(u[1], 1 / rho * sum((f.center(i) * cys[i]) for i in range(5))), + Assignment(rho_out.center(0), rho), + Assignment(u_out.center(0), u[0]), + Assignment(u_out.center(1), u[1]), + ] + + body = PsBlock( + [PsComment("Compute and export density and velocity")] + + [factory.parse_sympy(asm) for asm in asms] + ) + + ast = factory.loops_from_ispace(ispace, body) + ast_clone = canon_clone(ast) + + for orig, clone in zip(dfs_preorder(ast), dfs_preorder(ast_clone), strict=True): + assert type(orig) is type(clone) + assert orig is not clone + + if isinstance(orig, PsSymbolExpr): + assert isinstance(clone, PsSymbolExpr) + + if orig.symbol.name in ("ctr_0", "ctr_1", "rho", "u_0", "u_1", "cx", "cy"): + assert clone.symbol.name == orig.symbol.name + "__0" diff --git a/tests/nbackend/transformations/test_reshape_loops.py b/tests/nbackend/transformations/test_reshape_loops.py new file mode 100644 index 0000000000000000000000000000000000000000..e68cff1b64acbb4f9bbf30dee9ef3f2abe9e59d3 --- /dev/null +++ b/tests/nbackend/transformations/test_reshape_loops.py @@ -0,0 +1,101 @@ +import sympy as sp + +from pystencils import Field, Assignment, make_slice +from pystencils.backend.kernelcreation import ( + KernelCreationContext, + AstFactory, + FullIterationSpace, +) +from pystencils.backend.transformations import ReshapeLoops + +from pystencils.backend.ast.structural import PsDeclaration, PsBlock, PsLoop, PsConditional +from pystencils.backend.ast.expressions import PsConstantExpr, PsLt + + +def test_loop_cutting(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + reshape = ReshapeLoops(ctx) + + x, y, z = sp.symbols("x, y, z") + + f = Field.create_generic("f", 1, index_shape=(2,)) + ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:], archetype_field=f) + ctx.set_iteration_space(ispace) + + loop_body = PsBlock( + [ + factory.parse_sympy(Assignment(x, 2 * z)), + factory.parse_sympy(Assignment(f.center(0), x + y)), + ] + ) + + loop = factory.loops_from_ispace(ispace, loop_body) + + subloops = reshape.cut_loop(loop, [1, 1, 3]) + assert len(subloops) == 3 + + subloop = subloops[0] + assert isinstance(subloop, PsBlock) + assert isinstance(subloop.statements[0], PsDeclaration) + assert subloop.statements[0].declared_symbol.name == "ctr_0__0" + + x_decl = subloop.statements[1] + assert isinstance(x_decl, PsDeclaration) + assert x_decl.declared_symbol.name == "x__0" + + subloop = subloops[1] + assert isinstance(subloop, PsLoop) + assert isinstance(subloop.start, PsConstantExpr) and subloop.start.constant.value == 1 + assert isinstance(subloop.stop, PsConstantExpr) and subloop.stop.constant.value == 3 + + x_decl = subloop.body.statements[0] + assert isinstance(x_decl, PsDeclaration) + assert x_decl.declared_symbol.name == "x__1" + + subloop = subloops[2] + assert isinstance(subloop, PsLoop) + assert isinstance(subloop.start, PsConstantExpr) and subloop.start.constant.value == 3 + assert subloop.stop.structurally_equal(loop.stop) + + +def test_loop_peeling(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + reshape = ReshapeLoops(ctx) + + x, y, z = sp.symbols("x, y, z") + + f = Field.create_generic("f", 1, index_shape=(2,)) + ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:], archetype_field=f) + ctx.set_iteration_space(ispace) + + loop_body = PsBlock([ + factory.parse_sympy(Assignment(x, 2 * z)), + factory.parse_sympy(Assignment(f.center(0), x + y)), + ]) + + loop = factory.loops_from_ispace(ispace, loop_body) + + num_iters = 3 + peeled_iters, peeled_loop = reshape.peel_loop_front(loop, num_iters) + assert len(peeled_iters) == 3 + + for i, iter in enumerate(peeled_iters): + assert isinstance(iter, PsBlock) + + ctr_decl = iter.statements[0] + assert isinstance(ctr_decl, PsDeclaration) + assert ctr_decl.declared_symbol.name == f"ctr_0__{i}" + + cond = iter.statements[1] + assert isinstance(cond, PsConditional) + assert cond.condition.structurally_equal(PsLt(ctr_decl.lhs, loop.stop)) + + subblock = cond.branch_true + assert isinstance(subblock.statements[0], PsDeclaration) + assert subblock.statements[0].declared_symbol.name == f"x__{i}" + + assert peeled_loop.start.structurally_equal(factory.parse_index(num_iters)) + assert peeled_loop.stop.structurally_equal(loop.stop) + assert peeled_loop.body.structurally_equal(loop.body)