Skip to content
Snippets Groups Projects
Commit af52d21f authored by Frederik Hennig's avatar Frederik Hennig
Browse files

fix iteration space iteration counts.

parent 9717fd66
No related branches found
No related tags found
1 merge request!393Ternary Expressions, Improved Integer Divisions, and Iteration Space Fix
Pipeline #67189 passed
......@@ -11,7 +11,7 @@ from ...field import Field, FieldType
from ..symbols import PsSymbol
from ..constants import PsConstant
from ..ast.expressions import PsExpression, PsConstantExpr
from ..ast.expressions import PsExpression, PsConstantExpr, PsTernary, PsEq, PsRem
from ..arrays import PsLinearizedArray
from ..ast.util import failing_cast
from ...types import PsStructType, constify
......@@ -210,14 +210,37 @@ class FullIterationSpace(IterationSpace):
return self._archetype_field
def actual_iterations(self, dimension: int | None = None) -> PsExpression:
from .typification import Typifier
from ..transformations import EliminateConstants
typify = Typifier(self._ctx)
fold = EliminateConstants(self._ctx)
if dimension is None:
return reduce(
mul, (self.actual_iterations(d) for d in range(len(self.dimensions)))
return fold(
typify(
reduce(
mul,
(
self.actual_iterations(d)
for d in range(len(self.dimensions))
),
)
)
)
else:
dim = self.dimensions[dimension]
one = PsConstantExpr(PsConstant(1, self._ctx.index_dtype))
return one + (dim.stop - dim.start - one) / dim.step
zero = PsConstantExpr(PsConstant(0, self._ctx.index_dtype))
return fold(
typify(
PsTernary(
PsEq(PsRem((dim.stop - dim.start), dim.step), zero),
(dim.stop - dim.start) / dim.step,
(dim.stop - dim.start) / dim.step + one,
)
)
)
def compressed_counter(self) -> PsExpression:
"""Expression counting the actual number of items processed at the iteration defined by the counter tuple.
......
import pytest
from pystencils.field import Field
from pystencils.sympyextensions.typed_sympy import TypedSymbol, create_type
from pystencils import make_slice, Field, create_type
from pystencils.sympyextensions.typed_sympy import TypedSymbol
from pystencils.backend.constants import PsConstant
from pystencils.backend.kernelcreation import KernelCreationContext, FullIterationSpace
from pystencils.backend.ast.expressions import PsAdd, PsConstantExpr, PsExpression
from pystencils.backend.kernelcreation.typification import TypificationError
from pystencils.types import PsTypeError
from pystencils.types.quick import Int
def test_slices():
......@@ -36,12 +36,12 @@ def test_slices():
op.structurally_equal(PsExpression.make(archetype_arr.shape[0]))
for op in dims[0].stop.children
)
assert isinstance(dims[1].stop, PsAdd) and any(
op.structurally_equal(PsExpression.make(archetype_arr.shape[1]))
for op in dims[1].stop.children
)
assert dims[2].stop.structurally_equal(PsExpression.make(archetype_arr.shape[2]))
......@@ -58,3 +58,28 @@ def test_invalid_slices():
islice = (slice(1, -1, TypedSymbol("w", dtype=create_type("double"))),)
with pytest.raises(TypificationError):
FullIterationSpace.create_from_slice(ctx, islice, archetype_field)
def test_iteration_count():
ctx = KernelCreationContext()
i, j, k = [PsExpression.make(ctx.get_symbol(x, ctx.index_dtype)) for x in "ijk"]
zero = PsExpression.make(PsConstant(0, ctx.index_dtype))
two = PsExpression.make(PsConstant(2, ctx.index_dtype))
three = PsExpression.make(PsConstant(3, ctx.index_dtype))
ispace = FullIterationSpace.create_from_slice(
ctx, make_slice[three : i-two, 1:8:3]
)
iters = [ispace.actual_iterations(coord) for coord in range(2)]
assert iters[0].structurally_equal((i - two) - three)
assert iters[1].structurally_equal(three)
empty_ispace = FullIterationSpace.create_from_slice(
ctx, make_slice[4:4:1, 4:4:7]
)
iters = [empty_ispace.actual_iterations(coord) for coord in range(2)]
assert iters[0].structurally_equal(zero)
assert iters[1].structurally_equal(zero)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment