Skip to content
Snippets Groups Projects
Commit 536420be authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Merge branch 'fhennig/loop-splitting' into 'backend-rework'

Loop Transformations: Cutting and Peeling

See merge request pycodegen/pystencils!376
parents 4b24e9b8 f9f416e1
No related branches found
No related tags found
1 merge request!376Loop Transformations: Cutting and Peeling
Pipeline #65004 passed
Showing
with 625 additions and 38 deletions
...@@ -14,6 +14,7 @@ who wish to customize or extend the behaviour of the code generator in their app ...@@ -14,6 +14,7 @@ who wish to customize or extend the behaviour of the code generator in their app
iteration_space iteration_space
translation translation
platforms platforms
transformations
jit jit
Internal Representation Internal Representation
......
*******************
AST Transformations
*******************
`pystencils.backend.transformations`
.. automodule:: pystencils.backend.transformations
from typing import Any, Sequence, cast, overload from typing import Any, Sequence, cast, overload
import numpy as np
import sympy as sp import sympy as sp
from sympy.codegen.ast import AssignmentBase from sympy.codegen.ast import AssignmentBase
from ..ast import PsAstNode from ..ast import PsAstNode
from ..ast.expressions import PsExpression, PsSymbolExpr from ..ast.expressions import PsExpression, PsSymbolExpr, PsConstantExpr
from ..ast.structural import PsLoop, PsBlock, PsAssignment from ..ast.structural import PsLoop, PsBlock, PsAssignment
from ..symbols import PsSymbol from ..symbols import PsSymbol
...@@ -16,6 +17,10 @@ from .typification import Typifier ...@@ -16,6 +17,10 @@ from .typification import Typifier
from .iteration_space import FullIterationSpace from .iteration_space import FullIterationSpace
IndexParsable = PsExpression | PsSymbol | PsConstant | sp.Expr | int | np.integer
_IndexParsable = (PsExpression, PsSymbol, PsConstant, sp.Expr, int, np.integer)
class AstFactory: class AstFactory:
"""Factory providing a convenient interface for building syntax trees. """Factory providing a convenient interface for building syntax trees.
...@@ -51,6 +56,45 @@ class AstFactory: ...@@ -51,6 +56,45 @@ class AstFactory:
""" """
return self._typify(self._freeze(sp_obj)) return self._typify(self._freeze(sp_obj))
@overload
def parse_index(self, idx: sp.Symbol | PsSymbol | PsSymbolExpr) -> PsSymbolExpr:
pass
@overload
def parse_index(
self, idx: int | np.integer | PsConstant | PsConstantExpr
) -> PsConstantExpr:
pass
@overload
def parse_index(self, idx: sp.Expr | PsExpression) -> PsExpression:
pass
def parse_index(self, idx: IndexParsable):
"""Parse the given object as an expression with data type `ctx.index_dtype`."""
if not isinstance(idx, _IndexParsable):
raise TypeError(
f"Cannot parse object of type {type(idx)} as an index expression"
)
match idx:
case PsExpression():
return self._typify.typify_expression(idx, self._ctx.index_dtype)[0]
case PsSymbol() | PsConstant():
return self._typify.typify_expression(
PsExpression.make(idx), self._ctx.index_dtype
)[0]
case sp.Expr():
return self._typify.typify_expression(
self._freeze(idx), self._ctx.index_dtype
)[0]
case _:
return PsExpression.make(PsConstant(idx, self._ctx.index_dtype))
def _parse_any_index(self, idx: Any) -> PsExpression:
return self.parse_index(cast(IndexParsable, idx))
def parse_slice( def parse_slice(
self, slic: slice, upper_limit: Any | None = None self, slic: slice, upper_limit: Any | None = None
) -> tuple[PsExpression, PsExpression, PsExpression]: ) -> tuple[PsExpression, PsExpression, PsExpression]:
...@@ -75,27 +119,16 @@ class AstFactory: ...@@ -75,27 +119,16 @@ class AstFactory:
"Must specify an upper iteration limit if `slice.stop` is `None` or a negative `int`" "Must specify an upper iteration limit if `slice.stop` is `None` or a negative `int`"
) )
def make_expr(val: Any) -> PsExpression: start = self._parse_any_index(slic.start if slic.start is not None else 0)
match val: stop = (
case PsExpression(): self._parse_any_index(slic.stop)
return self._typify.typify_expression(val, self._ctx.index_dtype)[0] if slic.stop is not None
case PsSymbol() | PsConstant(): else self._parse_any_index(upper_limit)
return self._typify.typify_expression( )
PsExpression.make(val), self._ctx.index_dtype step = self._parse_any_index(slic.step if slic.step is not None else 1)
)[0]
case sp.Expr():
return self._typify.typify_expression(
self._freeze(val), self._ctx.index_dtype
)[0]
case _:
return PsExpression.make(PsConstant(val, self._ctx.index_dtype))
start = make_expr(slic.start if slic.start is not None else 0)
stop = make_expr(slic.stop) if slic.stop is not None else make_expr(upper_limit)
step = make_expr(slic.step if slic.step is not None else 1)
if isinstance(slic.stop, int) and slic.stop < 0: if isinstance(slic.stop, int) and slic.stop < 0:
stop = make_expr(upper_limit) + stop stop = self._parse_any_index(upper_limit) + stop
return start, stop, step return start, stop, step
......
from __future__ import annotations from __future__ import annotations
from typing import Iterable, Iterator from typing import Iterable, Iterator
from itertools import chain from itertools import chain, count
from types import EllipsisType from types import EllipsisType
from collections import namedtuple from collections import namedtuple, defaultdict
import re
from ...defaults import DEFAULTS from ...defaults import DEFAULTS
from ...field import Field, FieldType from ...field import Field, FieldType
...@@ -67,6 +68,9 @@ class KernelCreationContext: ...@@ -67,6 +68,9 @@ class KernelCreationContext:
self._symbols: dict[str, PsSymbol] = dict() self._symbols: dict[str, PsSymbol] = dict()
self._symbol_ctr_pattern = re.compile(r"__[0-9]+$")
self._symbol_dup_table: defaultdict[str, int] = defaultdict(lambda: 0)
self._fields_and_arrays: dict[str, FieldArrayPair] = dict() self._fields_and_arrays: dict[str, FieldArrayPair] = dict()
self._fields_collection = FieldsInKernel() self._fields_collection = FieldsInKernel()
...@@ -95,6 +99,21 @@ class KernelCreationContext: ...@@ -95,6 +99,21 @@ class KernelCreationContext:
# Symbols # Symbols
def get_symbol(self, name: str, dtype: PsType | None = None) -> PsSymbol: def get_symbol(self, name: str, dtype: PsType | None = None) -> PsSymbol:
"""Retrieve the symbol with the given name and data type from the symbol table.
If no symbol named ``name`` exists, a new symbol with the given data type is created.
If a symbol with the given ``name`` already exists and ``dtype`` is not `None`,
the given data type will be applied to it, and it is returned.
If the symbol already has a different data type, an error will be raised.
If the symbol already exists and ``dtype`` is `None`, the existing symbol is returned
without checking or altering its data type.
Args:
name: The symbol's name
dtype: The symbol's data type, or `None`
"""
if name not in self._symbols: if name not in self._symbols:
symb = PsSymbol(name, None) symb = PsSymbol(name, None)
self._symbols[name] = symb self._symbols[name] = symb
...@@ -115,12 +134,20 @@ class KernelCreationContext: ...@@ -115,12 +134,20 @@ class KernelCreationContext:
return self._symbols.get(name, None) return self._symbols.get(name, None)
def add_symbol(self, symbol: PsSymbol): def add_symbol(self, symbol: PsSymbol):
"""Add an existing symbol to the symbol table.
If a symbol with the same name already exists, an error will be raised.
"""
if symbol.name in self._symbols: if symbol.name in self._symbols:
raise PsInternalCompilerError(f"Duplicate symbol: {symbol.name}") raise PsInternalCompilerError(f"Duplicate symbol: {symbol.name}")
self._symbols[symbol.name] = symbol self._symbols[symbol.name] = symbol
def replace_symbol(self, old: PsSymbol, new: PsSymbol): def replace_symbol(self, old: PsSymbol, new: PsSymbol):
"""Replace one symbol by another.
The two symbols ``old`` and ``new`` must have the same name, but may have different data types.
"""
if old.name != new.name: if old.name != new.name:
raise PsInternalCompilerError( raise PsInternalCompilerError(
"replace_symbol: Old and new symbol must have the same name" "replace_symbol: Old and new symbol must have the same name"
...@@ -131,8 +158,30 @@ class KernelCreationContext: ...@@ -131,8 +158,30 @@ class KernelCreationContext:
self._symbols[old.name] = new self._symbols[old.name] = new
def duplicate_symbol(self, symb: PsSymbol) -> PsSymbol:
"""Canonically duplicates the given symbol.
A new symbol with the same data type, and new name ``symb.name + "__<counter>"`` is created,
added to the symbol table, and returned.
The ``counter`` reflects the number of previously created duplicates of this symbol.
"""
if (result := self._symbol_ctr_pattern.search(symb.name)) is not None:
span = result.span()
basename = symb.name[: span[0]]
else:
basename = symb.name
initial_count = self._symbol_dup_table[basename]
for i in count(initial_count):
dup_name = f"{basename}__{i}"
if self.find_symbol(dup_name) is None:
self._symbol_dup_table[basename] = i + 1
return self.get_symbol(dup_name, symb.dtype)
assert False, "unreachable code"
@property @property
def symbols(self) -> Iterable[PsSymbol]: def symbols(self) -> Iterable[PsSymbol]:
"""Return an iterable of all symbols listed in the symbol table."""
return self._symbols.values() return self._symbols.values()
# Fields and Arrays # Fields and Arrays
......
...@@ -121,7 +121,7 @@ class FullIterationSpace(IterationSpace): ...@@ -121,7 +121,7 @@ class FullIterationSpace(IterationSpace):
@staticmethod @staticmethod
def create_from_slice( def create_from_slice(
ctx: KernelCreationContext, ctx: KernelCreationContext,
iteration_slice: Sequence[slice], iteration_slice: slice | Sequence[slice],
archetype_field: Field | None = None, archetype_field: Field | None = None,
): ):
"""Create an iteration space from a sequence of slices, optionally over an archetype field. """Create an iteration space from a sequence of slices, optionally over an archetype field.
...@@ -131,6 +131,9 @@ class FullIterationSpace(IterationSpace): ...@@ -131,6 +131,9 @@ class FullIterationSpace(IterationSpace):
iteration_slice: The iteration slices for each dimension; for valid formats, see `AstFactory.parse_slice` iteration_slice: The iteration slices for each dimension; for valid formats, see `AstFactory.parse_slice`
archetype_field: Optionally, an archetype field that dictates the upper slice limits and loop order. archetype_field: Optionally, an archetype field that dictates the upper slice limits and loop order.
""" """
if isinstance(iteration_slice, slice):
iteration_slice = (iteration_slice,)
dim = len(iteration_slice) dim = len(iteration_slice)
if dim == 0: if dim == 0:
raise ValueError( raise ValueError(
......
"""
This module contains various transformation and optimization passes that can be
executed on the backend AST.
Canonical Form
==============
Many transformations in this module require that their input AST is in *canonical form*.
This means that:
- Each symbol, constant, and expression node is annotated with a data type;
- Each symbol has at most one declaration;
- Each symbol that is never written to apart from its declaration has a ``const`` type; and
- Each symbol whose type is *not* ``const`` has at least one non-declaring assignment.
The first requirement can be ensured by running the `Typifier` on each newly constructed subtree.
The other three requirements are ensured by the `CanonicalizeSymbols` pass,
which should be run first before applying any optimizing transformations.
All transformations in this module retain canonicality of the AST.
Canonicality allows transformations to forego various checks that would otherwise be necessary
to prove their legality.
Certain transformations, like the auto-vectorizer (TODO), state additional requirements, e.g.
the absence of loop-carried dependencies.
Transformations
===============
Canonicalization
----------------
.. autoclass:: CanonicalizeSymbols
:members: __call__
AST Cloning
-----------
.. autoclass:: CanonicalClone
:members: __call__
Simplifying Transformations
---------------------------
.. autoclass:: EliminateConstants
:members: __call__
.. autoclass:: EliminateBranches
:members: __call__
Code Motion
-----------
.. autoclass:: HoistLoopInvariantDeclarations
:members: __call__
Loop Reshaping Transformations
------------------------------
.. autoclass:: ReshapeLoops
:members:
Code Lowering and Materialization
---------------------------------
.. autoclass:: EraseAnonymousStructTypes
:members: __call__
.. autoclass:: SelectFunctions
:members: __call__
"""
from .canonicalize_symbols import CanonicalizeSymbols
from .canonical_clone import CanonicalClone
from .eliminate_constants import EliminateConstants from .eliminate_constants import EliminateConstants
from .eliminate_branches import EliminateBranches from .eliminate_branches import EliminateBranches
from .canonicalize_symbols import CanonicalizeSymbols
from .hoist_loop_invariant_decls import HoistLoopInvariantDeclarations from .hoist_loop_invariant_decls import HoistLoopInvariantDeclarations
from .reshape_loops import ReshapeLoops
from .erase_anonymous_structs import EraseAnonymousStructTypes from .erase_anonymous_structs import EraseAnonymousStructTypes
from .select_functions import SelectFunctions from .select_functions import SelectFunctions
from .select_intrinsics import MaterializeVectorIntrinsics from .select_intrinsics import MaterializeVectorIntrinsics
__all__ = [ __all__ = [
"CanonicalizeSymbols",
"CanonicalClone",
"EliminateConstants", "EliminateConstants",
"EliminateBranches", "EliminateBranches",
"CanonicalizeSymbols",
"HoistLoopInvariantDeclarations", "HoistLoopInvariantDeclarations",
"ReshapeLoops",
"EraseAnonymousStructTypes", "EraseAnonymousStructTypes",
"SelectFunctions", "SelectFunctions",
"MaterializeVectorIntrinsics", "MaterializeVectorIntrinsics",
......
from typing import TypeVar, cast
from ..kernelcreation import KernelCreationContext
from ..symbols import PsSymbol
from ..exceptions import PsInternalCompilerError
from ..ast import PsAstNode
from ..ast.structural import (
PsBlock,
PsConditional,
PsLoop,
PsDeclaration,
PsAssignment,
PsComment,
)
from ..ast.expressions import PsExpression, PsSymbolExpr
__all__ = ["CanonicalClone"]
class CloneContext:
def __init__(self, ctx: KernelCreationContext) -> None:
self._ctx = ctx
self._dup_table: dict[PsSymbol, PsSymbol] = dict()
def symbol_decl(self, declared_symbol: PsSymbol):
self._dup_table[declared_symbol] = self._ctx.duplicate_symbol(declared_symbol)
def get_replacement(self, symb: PsSymbol):
return self._dup_table.get(symb, symb)
Node_T = TypeVar("Node_T", bound=PsAstNode)
class CanonicalClone:
"""Clone a subtree, and rename all symbols declared inside it to retain canonicality."""
def __init__(self, ctx: KernelCreationContext) -> None:
self._ctx = ctx
def __call__(self, node: Node_T) -> Node_T:
return self.visit(node, CloneContext(self._ctx))
def visit(self, node: Node_T, cc: CloneContext) -> Node_T:
match node:
case PsBlock(statements):
return cast(
Node_T, PsBlock([self.visit(stmt, cc) for stmt in statements])
)
case PsLoop(ctr, start, stop, step, body):
cc.symbol_decl(ctr.symbol)
return cast(
Node_T,
PsLoop(
self.visit(ctr, cc),
self.visit(start, cc),
self.visit(stop, cc),
self.visit(step, cc),
self.visit(body, cc),
),
)
case PsConditional(cond, then, els):
return cast(
Node_T,
PsConditional(
self.visit(cond, cc),
self.visit(then, cc),
self.visit(els, cc) if els is not None else None,
),
)
case PsComment():
return cast(Node_T, node.clone())
case PsDeclaration(lhs, rhs):
cc.symbol_decl(node.declared_symbol)
return cast(
Node_T,
PsDeclaration(
cast(PsSymbolExpr, self.visit(lhs, cc)),
self.visit(rhs, cc),
),
)
case PsAssignment(lhs, rhs):
return cast(
Node_T,
PsAssignment(
self.visit(lhs, cc),
self.visit(rhs, cc),
),
)
case PsExpression():
expr_clone = node.clone()
self._replace_symbols(expr_clone, cc)
return cast(Node_T, expr_clone)
case _:
raise PsInternalCompilerError(
f"Don't know how to canonically clone {type(node)}"
)
def _replace_symbols(self, expr: PsExpression, cc: CloneContext):
if isinstance(expr, PsSymbolExpr):
expr.symbol = cc.get_replacement(expr.symbol)
else:
for c in expr.children:
self._replace_symbols(cast(PsExpression, c), cc)
from itertools import count
from ..kernelcreation import KernelCreationContext from ..kernelcreation import KernelCreationContext
from ..symbols import PsSymbol from ..symbols import PsSymbol
from ..exceptions import PsInternalCompilerError from ..exceptions import PsInternalCompilerError
...@@ -29,14 +27,10 @@ class CanonContext: ...@@ -29,14 +27,10 @@ class CanonContext:
self.live_symbols_map[symb] = symb self.live_symbols_map[symb] = symb
return symb return symb
else: else:
for i in count(): replacement = self._ctx.duplicate_symbol(symb)
replacement_name = f"{symb.name}__{i}" self.live_symbols_map[symb] = replacement
if self._ctx.find_symbol(replacement_name) is None: self.encountered_symbols.add(replacement)
replacement = self._ctx.get_symbol(replacement_name, symb.dtype) return replacement
self.live_symbols_map[symb] = replacement
self.encountered_symbols.add(replacement)
return replacement
assert False, "unreachable code"
def mark_as_updated(self, symb: PsSymbol): def mark_as_updated(self, symb: PsSymbol):
self.updated_symbols.add(self.deduplicate(symb)) self.updated_symbols.add(self.deduplicate(symb))
......
from typing import cast, Iterable from typing import cast, Iterable, overload
from collections import defaultdict from collections import defaultdict
from ..kernelcreation import KernelCreationContext, Typifier from ..kernelcreation import KernelCreationContext, Typifier
...@@ -116,6 +116,14 @@ class EliminateConstants: ...@@ -116,6 +116,14 @@ class EliminateConstants:
self._fold_floats = False self._fold_floats = False
self._extract_constant_exprs = extract_constant_exprs self._extract_constant_exprs = extract_constant_exprs
@overload
def __call__(self, node: PsExpression) -> PsExpression:
pass
@overload
def __call__(self, node: PsAstNode) -> PsAstNode:
pass
def __call__(self, node: PsAstNode) -> PsAstNode: def __call__(self, node: PsAstNode) -> PsAstNode:
ecc = ECContext(self._ctx) ecc = ECContext(self._ctx)
......
...@@ -69,7 +69,7 @@ class HoistLoopInvariantDeclarations: ...@@ -69,7 +69,7 @@ class HoistLoopInvariantDeclarations:
in particular, each symbol may have at most one declaration. in particular, each symbol may have at most one declaration.
To ensure this, a `CanonicalizeSymbols` pass should be run before `HoistLoopInvariantDeclarations`. To ensure this, a `CanonicalizeSymbols` pass should be run before `HoistLoopInvariantDeclarations`.
`HoistLoopInvariants` assumes that all `PsMathFunction`s are pure (have no side effects), `HoistLoopInvariantDeclarations` assumes that all `PsMathFunction` s are pure (have no side effects),
but makes no such assumption about instances of `CFunction`. but makes no such assumption about instances of `CFunction`.
""" """
......
from typing import Sequence
from ..kernelcreation import KernelCreationContext, Typifier
from ..kernelcreation.ast_factory import AstFactory, IndexParsable
from ..ast.structural import PsLoop, PsBlock, PsConditional, PsDeclaration
from ..ast.expressions import PsExpression, PsConstantExpr, PsLt
from ..constants import PsConstant
from .canonical_clone import CanonicalClone, CloneContext
from .eliminate_constants import EliminateConstants
class ReshapeLoops:
"""Various transformations for reshaping loop nests."""
def __init__(self, ctx: KernelCreationContext) -> None:
self._ctx = ctx
self._typify = Typifier(ctx)
self._factory = AstFactory(ctx)
self._canon_clone = CanonicalClone(ctx)
self._elim_constants = EliminateConstants(ctx)
def peel_loop_front(
self, loop: PsLoop, num_iterations: int, omit_range_check: bool = False
) -> tuple[Sequence[PsBlock], PsLoop]:
"""Peel off iterations from the front of a loop.
Removes ``num_iterations`` from the front of the given loop and returns them as a sequence of
independent blocks.
Args:
loop: The loop node from which to peel iterations
num_iterations: The number of iterations to peel off
omit_range_check: If set to `True`, assume that the peeled-off iterations will always
be executed, and omit their enclosing conditional.
Returns:
Tuple containing the peeled-off iterations as a sequence of blocks,
and the remaining loop.
"""
peeled_iters: list[PsBlock] = []
for i in range(num_iterations):
cc = CloneContext(self._ctx)
cc.symbol_decl(loop.counter.symbol)
peeled_ctr = self._factory.parse_index(
cc.get_replacement(loop.counter.symbol)
)
peeled_idx = self._typify(loop.start + PsExpression.make(PsConstant(i)))
counter_decl = PsDeclaration(peeled_ctr, peeled_idx)
peeled_block = self._canon_clone.visit(loop.body, cc)
if omit_range_check:
peeled_block.statements = [counter_decl] + peeled_block.statements
else:
iter_condition = PsLt(peeled_ctr, loop.stop)
peeled_block.statements = [
counter_decl,
PsConditional(iter_condition, PsBlock(peeled_block.statements)),
]
peeled_iters.append(peeled_block)
loop.start = self._elim_constants(
self._typify(loop.start + PsExpression.make(PsConstant(num_iterations)))
)
return peeled_iters, loop
def cut_loop(
self, loop: PsLoop, cutting_points: Sequence[IndexParsable]
) -> Sequence[PsLoop | PsBlock]:
"""Cut a loop at the given cutting points.
Cut the given loop at the iterations specified by the given cutting points,
producing ``n`` new subtrees representing the iterations
``(loop.start:cutting_points[0]), (cutting_points[0]:cutting_points[1]), ..., (cutting_points[-1]:loop.stop)``.
Resulting subtrees representing zero iterations are dropped; subtrees representing exactly one iteration are
returned without the trivial loop structure.
Currently, `cut_loop` performs no checks to ensure that the given cutting points are in fact inside
the loop's iteration range.
Returns:
Sequence of ``n`` subtrees representing the respective iteration ranges
"""
if not (
isinstance(loop.step, PsConstantExpr) and loop.step.constant.value == 1
):
raise NotImplementedError(
"Loop cutting for loops with step != 1 is not implemented"
)
result: list[PsLoop | PsBlock] = []
new_start = loop.start
cutting_points = [self._factory.parse_index(idx) for idx in cutting_points] + [
loop.stop
]
for new_end in cutting_points:
if new_end.structurally_equal(new_start):
continue
num_iters = self._elim_constants(self._typify(new_end - new_start))
skip = False
if isinstance(num_iters, PsConstantExpr):
if num_iters.constant.value == 0:
skip = True
elif num_iters.constant.value == 1:
skip = True
cc = CloneContext(self._ctx)
cc.symbol_decl(loop.counter.symbol)
local_counter = self._factory.parse_index(
cc.get_replacement(loop.counter.symbol)
)
ctr_decl = PsDeclaration(
local_counter,
new_start,
)
cloned_body = self._canon_clone.visit(loop.body, cc)
cloned_body.statements = [ctr_decl] + cloned_body.statements
result.append(cloned_body)
if not skip:
loop_clone = self._canon_clone(loop)
loop_clone.start = new_start.clone()
loop_clone.stop = new_end.clone()
result.append(loop_clone)
new_start = new_end
return result
...@@ -52,7 +52,7 @@ def test_invalid_slices(): ...@@ -52,7 +52,7 @@ def test_invalid_slices():
ctx.add_field(archetype_field) ctx.add_field(archetype_field)
islice = (slice(1, -1, 0.5),) islice = (slice(1, -1, 0.5),)
with pytest.raises(PsTypeError): with pytest.raises(TypeError):
FullIterationSpace.create_from_slice(ctx, islice, archetype_field) FullIterationSpace.create_from_slice(ctx, islice, archetype_field)
islice = (slice(1, -1, TypedSymbol("w", dtype=create_type("double"))),) islice = (slice(1, -1, TypedSymbol("w", dtype=create_type("double"))),)
......
import sympy as sp
from pystencils import Field, Assignment, make_slice, TypedSymbol
from pystencils.types.quick import Arr
from pystencils.backend.kernelcreation import (
KernelCreationContext,
AstFactory,
FullIterationSpace,
)
from pystencils.backend.transformations import CanonicalClone
from pystencils.backend.ast.structural import PsBlock, PsComment
from pystencils.backend.ast.expressions import PsSymbolExpr
from pystencils.backend.ast.iteration import dfs_preorder
def test_clone_entire_ast():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
canon_clone = CanonicalClone(ctx)
f = Field.create_generic("f", 2, index_shape=(5,))
rho = sp.Symbol("rho")
u = sp.symbols("u_:2")
cx = TypedSymbol("cx", Arr(ctx.default_dtype))
cy = TypedSymbol("cy", Arr(ctx.default_dtype))
cxs = sp.IndexedBase(cx, shape=(5,))
cys = sp.IndexedBase(cy, shape=(5,))
rho_out = Field.create_generic("rho", 2, index_shape=(1,))
u_out = Field.create_generic("u", 2, index_shape=(2,))
ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:, :], f)
ctx.set_iteration_space(ispace)
asms = [
Assignment(cx, (0, 1, -1, 0, 0)),
Assignment(cy, (0, 0, 0, 1, -1)),
Assignment(rho, sum(f.center(i) for i in range(5))),
Assignment(u[0], 1 / rho * sum((f.center(i) * cxs[i]) for i in range(5))),
Assignment(u[1], 1 / rho * sum((f.center(i) * cys[i]) for i in range(5))),
Assignment(rho_out.center(0), rho),
Assignment(u_out.center(0), u[0]),
Assignment(u_out.center(1), u[1]),
]
body = PsBlock(
[PsComment("Compute and export density and velocity")]
+ [factory.parse_sympy(asm) for asm in asms]
)
ast = factory.loops_from_ispace(ispace, body)
ast_clone = canon_clone(ast)
for orig, clone in zip(dfs_preorder(ast), dfs_preorder(ast_clone), strict=True):
assert type(orig) is type(clone)
assert orig is not clone
if isinstance(orig, PsSymbolExpr):
assert isinstance(clone, PsSymbolExpr)
if orig.symbol.name in ("ctr_0", "ctr_1", "rho", "u_0", "u_1", "cx", "cy"):
assert clone.symbol.name == orig.symbol.name + "__0"
import sympy as sp
from pystencils import Field, Assignment, make_slice
from pystencils.backend.kernelcreation import (
KernelCreationContext,
AstFactory,
FullIterationSpace,
)
from pystencils.backend.transformations import ReshapeLoops
from pystencils.backend.ast.structural import PsDeclaration, PsBlock, PsLoop, PsConditional
from pystencils.backend.ast.expressions import PsConstantExpr, PsLt
def test_loop_cutting():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
reshape = ReshapeLoops(ctx)
x, y, z = sp.symbols("x, y, z")
f = Field.create_generic("f", 1, index_shape=(2,))
ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:], archetype_field=f)
ctx.set_iteration_space(ispace)
loop_body = PsBlock(
[
factory.parse_sympy(Assignment(x, 2 * z)),
factory.parse_sympy(Assignment(f.center(0), x + y)),
]
)
loop = factory.loops_from_ispace(ispace, loop_body)
subloops = reshape.cut_loop(loop, [1, 1, 3])
assert len(subloops) == 3
subloop = subloops[0]
assert isinstance(subloop, PsBlock)
assert isinstance(subloop.statements[0], PsDeclaration)
assert subloop.statements[0].declared_symbol.name == "ctr_0__0"
x_decl = subloop.statements[1]
assert isinstance(x_decl, PsDeclaration)
assert x_decl.declared_symbol.name == "x__0"
subloop = subloops[1]
assert isinstance(subloop, PsLoop)
assert isinstance(subloop.start, PsConstantExpr) and subloop.start.constant.value == 1
assert isinstance(subloop.stop, PsConstantExpr) and subloop.stop.constant.value == 3
x_decl = subloop.body.statements[0]
assert isinstance(x_decl, PsDeclaration)
assert x_decl.declared_symbol.name == "x__1"
subloop = subloops[2]
assert isinstance(subloop, PsLoop)
assert isinstance(subloop.start, PsConstantExpr) and subloop.start.constant.value == 3
assert subloop.stop.structurally_equal(loop.stop)
def test_loop_peeling():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
reshape = ReshapeLoops(ctx)
x, y, z = sp.symbols("x, y, z")
f = Field.create_generic("f", 1, index_shape=(2,))
ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:], archetype_field=f)
ctx.set_iteration_space(ispace)
loop_body = PsBlock([
factory.parse_sympy(Assignment(x, 2 * z)),
factory.parse_sympy(Assignment(f.center(0), x + y)),
])
loop = factory.loops_from_ispace(ispace, loop_body)
num_iters = 3
peeled_iters, peeled_loop = reshape.peel_loop_front(loop, num_iters)
assert len(peeled_iters) == 3
for i, iter in enumerate(peeled_iters):
assert isinstance(iter, PsBlock)
ctr_decl = iter.statements[0]
assert isinstance(ctr_decl, PsDeclaration)
assert ctr_decl.declared_symbol.name == f"ctr_0__{i}"
cond = iter.statements[1]
assert isinstance(cond, PsConditional)
assert cond.condition.structurally_equal(PsLt(ctr_decl.lhs, loop.stop))
subblock = cond.branch_true
assert isinstance(subblock.statements[0], PsDeclaration)
assert subblock.statements[0].declared_symbol.name == f"x__{i}"
assert peeled_loop.start.structurally_equal(factory.parse_index(num_iters))
assert peeled_loop.stop.structurally_equal(loop.stop)
assert peeled_loop.body.structurally_equal(loop.body)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment