diff --git a/src/pystencils/backend/kernelcreation/iteration_space.py b/src/pystencils/backend/kernelcreation/iteration_space.py index 6adac2a519ffc04505c0e0adac3484d78f30d013..2a3d2774e03160fe2012f68ecb3ded4803353304 100644 --- a/src/pystencils/backend/kernelcreation/iteration_space.py +++ b/src/pystencils/backend/kernelcreation/iteration_space.py @@ -11,7 +11,7 @@ from ...field import Field, FieldType from ..symbols import PsSymbol from ..constants import PsConstant -from ..ast.expressions import PsExpression, PsConstantExpr +from ..ast.expressions import PsExpression, PsConstantExpr, PsTernary, PsEq, PsRem from ..arrays import PsLinearizedArray from ..ast.util import failing_cast from ...types import PsStructType, constify @@ -210,14 +210,37 @@ class FullIterationSpace(IterationSpace): return self._archetype_field def actual_iterations(self, dimension: int | None = None) -> PsExpression: + from .typification import Typifier + from ..transformations import EliminateConstants + + typify = Typifier(self._ctx) + fold = EliminateConstants(self._ctx) + if dimension is None: - return reduce( - mul, (self.actual_iterations(d) for d in range(len(self.dimensions))) + return fold( + typify( + reduce( + mul, + ( + self.actual_iterations(d) + for d in range(len(self.dimensions)) + ), + ) + ) ) else: dim = self.dimensions[dimension] one = PsConstantExpr(PsConstant(1, self._ctx.index_dtype)) - return one + (dim.stop - dim.start - one) / dim.step + zero = PsConstantExpr(PsConstant(0, self._ctx.index_dtype)) + return fold( + typify( + PsTernary( + PsEq(PsRem((dim.stop - dim.start), dim.step), zero), + (dim.stop - dim.start) / dim.step, + (dim.stop - dim.start) / dim.step + one, + ) + ) + ) def compressed_counter(self) -> PsExpression: """Expression counting the actual number of items processed at the iteration defined by the counter tuple. diff --git a/tests/nbackend/kernelcreation/test_iteration_space.py b/tests/nbackend/kernelcreation/test_iteration_space.py index 7fd6d778ff62f7fb2fcbc24a55af5225fb9f870e..f9646afc26d11bddfb49c5a178096f9d2157d5f6 100644 --- a/tests/nbackend/kernelcreation/test_iteration_space.py +++ b/tests/nbackend/kernelcreation/test_iteration_space.py @@ -1,13 +1,13 @@ import pytest -from pystencils.field import Field -from pystencils.sympyextensions.typed_sympy import TypedSymbol, create_type +from pystencils import make_slice, Field, create_type +from pystencils.sympyextensions.typed_sympy import TypedSymbol +from pystencils.backend.constants import PsConstant from pystencils.backend.kernelcreation import KernelCreationContext, FullIterationSpace - from pystencils.backend.ast.expressions import PsAdd, PsConstantExpr, PsExpression from pystencils.backend.kernelcreation.typification import TypificationError -from pystencils.types import PsTypeError +from pystencils.types.quick import Int def test_slices(): @@ -36,12 +36,12 @@ def test_slices(): op.structurally_equal(PsExpression.make(archetype_arr.shape[0])) for op in dims[0].stop.children ) - + assert isinstance(dims[1].stop, PsAdd) and any( op.structurally_equal(PsExpression.make(archetype_arr.shape[1])) for op in dims[1].stop.children ) - + assert dims[2].stop.structurally_equal(PsExpression.make(archetype_arr.shape[2])) @@ -58,3 +58,28 @@ def test_invalid_slices(): islice = (slice(1, -1, TypedSymbol("w", dtype=create_type("double"))),) with pytest.raises(TypificationError): FullIterationSpace.create_from_slice(ctx, islice, archetype_field) + + +def test_iteration_count(): + ctx = KernelCreationContext() + + i, j, k = [PsExpression.make(ctx.get_symbol(x, ctx.index_dtype)) for x in "ijk"] + zero = PsExpression.make(PsConstant(0, ctx.index_dtype)) + two = PsExpression.make(PsConstant(2, ctx.index_dtype)) + three = PsExpression.make(PsConstant(3, ctx.index_dtype)) + + ispace = FullIterationSpace.create_from_slice( + ctx, make_slice[three : i-two, 1:8:3] + ) + + iters = [ispace.actual_iterations(coord) for coord in range(2)] + assert iters[0].structurally_equal((i - two) - three) + assert iters[1].structurally_equal(three) + + empty_ispace = FullIterationSpace.create_from_slice( + ctx, make_slice[4:4:1, 4:4:7] + ) + + iters = [empty_ispace.actual_iterations(coord) for coord in range(2)] + assert iters[0].structurally_equal(zero) + assert iters[1].structurally_equal(zero)