Skip to content
Snippets Groups Projects
Commit af52d21f authored by Frederik Hennig's avatar Frederik Hennig
Browse files

fix iteration space iteration counts.

parent 9717fd66
No related branches found
No related tags found
1 merge request!393Ternary Expressions, Improved Integer Divisions, and Iteration Space Fix
Pipeline #67189 passed
...@@ -11,7 +11,7 @@ from ...field import Field, FieldType ...@@ -11,7 +11,7 @@ from ...field import Field, FieldType
from ..symbols import PsSymbol from ..symbols import PsSymbol
from ..constants import PsConstant from ..constants import PsConstant
from ..ast.expressions import PsExpression, PsConstantExpr from ..ast.expressions import PsExpression, PsConstantExpr, PsTernary, PsEq, PsRem
from ..arrays import PsLinearizedArray from ..arrays import PsLinearizedArray
from ..ast.util import failing_cast from ..ast.util import failing_cast
from ...types import PsStructType, constify from ...types import PsStructType, constify
...@@ -210,14 +210,37 @@ class FullIterationSpace(IterationSpace): ...@@ -210,14 +210,37 @@ class FullIterationSpace(IterationSpace):
return self._archetype_field return self._archetype_field
def actual_iterations(self, dimension: int | None = None) -> PsExpression: def actual_iterations(self, dimension: int | None = None) -> PsExpression:
from .typification import Typifier
from ..transformations import EliminateConstants
typify = Typifier(self._ctx)
fold = EliminateConstants(self._ctx)
if dimension is None: if dimension is None:
return reduce( return fold(
mul, (self.actual_iterations(d) for d in range(len(self.dimensions))) typify(
reduce(
mul,
(
self.actual_iterations(d)
for d in range(len(self.dimensions))
),
)
)
) )
else: else:
dim = self.dimensions[dimension] dim = self.dimensions[dimension]
one = PsConstantExpr(PsConstant(1, self._ctx.index_dtype)) one = PsConstantExpr(PsConstant(1, self._ctx.index_dtype))
return one + (dim.stop - dim.start - one) / dim.step zero = PsConstantExpr(PsConstant(0, self._ctx.index_dtype))
return fold(
typify(
PsTernary(
PsEq(PsRem((dim.stop - dim.start), dim.step), zero),
(dim.stop - dim.start) / dim.step,
(dim.stop - dim.start) / dim.step + one,
)
)
)
def compressed_counter(self) -> PsExpression: def compressed_counter(self) -> PsExpression:
"""Expression counting the actual number of items processed at the iteration defined by the counter tuple. """Expression counting the actual number of items processed at the iteration defined by the counter tuple.
......
import pytest import pytest
from pystencils.field import Field from pystencils import make_slice, Field, create_type
from pystencils.sympyextensions.typed_sympy import TypedSymbol, create_type from pystencils.sympyextensions.typed_sympy import TypedSymbol
from pystencils.backend.constants import PsConstant
from pystencils.backend.kernelcreation import KernelCreationContext, FullIterationSpace from pystencils.backend.kernelcreation import KernelCreationContext, FullIterationSpace
from pystencils.backend.ast.expressions import PsAdd, PsConstantExpr, PsExpression from pystencils.backend.ast.expressions import PsAdd, PsConstantExpr, PsExpression
from pystencils.backend.kernelcreation.typification import TypificationError from pystencils.backend.kernelcreation.typification import TypificationError
from pystencils.types import PsTypeError from pystencils.types.quick import Int
def test_slices(): def test_slices():
...@@ -36,12 +36,12 @@ def test_slices(): ...@@ -36,12 +36,12 @@ def test_slices():
op.structurally_equal(PsExpression.make(archetype_arr.shape[0])) op.structurally_equal(PsExpression.make(archetype_arr.shape[0]))
for op in dims[0].stop.children for op in dims[0].stop.children
) )
assert isinstance(dims[1].stop, PsAdd) and any( assert isinstance(dims[1].stop, PsAdd) and any(
op.structurally_equal(PsExpression.make(archetype_arr.shape[1])) op.structurally_equal(PsExpression.make(archetype_arr.shape[1]))
for op in dims[1].stop.children for op in dims[1].stop.children
) )
assert dims[2].stop.structurally_equal(PsExpression.make(archetype_arr.shape[2])) assert dims[2].stop.structurally_equal(PsExpression.make(archetype_arr.shape[2]))
...@@ -58,3 +58,28 @@ def test_invalid_slices(): ...@@ -58,3 +58,28 @@ def test_invalid_slices():
islice = (slice(1, -1, TypedSymbol("w", dtype=create_type("double"))),) islice = (slice(1, -1, TypedSymbol("w", dtype=create_type("double"))),)
with pytest.raises(TypificationError): with pytest.raises(TypificationError):
FullIterationSpace.create_from_slice(ctx, islice, archetype_field) FullIterationSpace.create_from_slice(ctx, islice, archetype_field)
def test_iteration_count():
ctx = KernelCreationContext()
i, j, k = [PsExpression.make(ctx.get_symbol(x, ctx.index_dtype)) for x in "ijk"]
zero = PsExpression.make(PsConstant(0, ctx.index_dtype))
two = PsExpression.make(PsConstant(2, ctx.index_dtype))
three = PsExpression.make(PsConstant(3, ctx.index_dtype))
ispace = FullIterationSpace.create_from_slice(
ctx, make_slice[three : i-two, 1:8:3]
)
iters = [ispace.actual_iterations(coord) for coord in range(2)]
assert iters[0].structurally_equal((i - two) - three)
assert iters[1].structurally_equal(three)
empty_ispace = FullIterationSpace.create_from_slice(
ctx, make_slice[4:4:1, 4:4:7]
)
iters = [empty_ispace.actual_iterations(coord) for coord in range(2)]
assert iters[0].structurally_equal(zero)
assert iters[1].structurally_equal(zero)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment