diff --git a/src/pystencils/backend/kernelcreation/ast_factory.py b/src/pystencils/backend/kernelcreation/ast_factory.py index 2462e5e66ea1a55cd638df07f645b213dd37d68f..5a7084457c4e251c83b588948a283fa3123773f9 100644 --- a/src/pystencils/backend/kernelcreation/ast_factory.py +++ b/src/pystencils/backend/kernelcreation/ast_factory.py @@ -138,6 +138,13 @@ class AstFactory: self._typify(self.parse_index(iter_slice) + self.parse_index(1)) ) step = self.parse_index(1) + + if normalize_to is not None: + upper_limit = self.parse_index(normalize_to) + if isinstance(start, PsConstantExpr) and start.constant.value < 0: + start = fold(self._typify(upper_limit.clone() + start)) + stop = fold(self._typify(upper_limit.clone() + stop)) + else: start = self._parse_any_index( iter_slice.start if iter_slice.start is not None else 0 @@ -156,21 +163,21 @@ class AstFactory: f"Invalid value for `slice.step`: {step.constant.value}" ) - if normalize_to is not None: - upper_limit = self.parse_index(normalize_to) - if isinstance(start, PsConstantExpr) and start.constant.value < 0: - start = fold(self._typify(upper_limit.clone() + start)) + if normalize_to is not None: + upper_limit = self.parse_index(normalize_to) + if isinstance(start, PsConstantExpr) and start.constant.value < 0: + start = fold(self._typify(upper_limit.clone() + start)) - if stop is None: - stop = upper_limit - elif isinstance(stop, PsConstantExpr) and stop.constant.value < 0: - stop = fold(self._typify(upper_limit.clone() + stop)) + if stop is None: + stop = upper_limit + elif isinstance(stop, PsConstantExpr) and stop.constant.value < 0: + stop = fold(self._typify(upper_limit.clone() + stop)) + + elif stop is None: + raise ValueError( + "Cannot parse a slice with `stop == None` if no normalization limit is given" + ) - elif stop is None: - raise ValueError( - "Cannot parse a slice with `stop == None` if no normalization limit is given" - ) - assert stop is not None # for mypy return start, stop, step diff --git a/tests/nbackend/kernelcreation/test_iteration_space.py b/tests/nbackend/kernelcreation/test_iteration_space.py index 5d56abd2b818fa74fbd48aac0216d472112f8c64..abc1c9820002eb08454ef5bac7c0e1ba2bfca3ba 100644 --- a/tests/nbackend/kernelcreation/test_iteration_space.py +++ b/tests/nbackend/kernelcreation/test_iteration_space.py @@ -129,6 +129,36 @@ def test_slices_with_negative_start(): ) +def test_negative_singular_slices(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + + archetype_field = Field.create_generic("f", spatial_dimensions=2, layout="fzyx") + ctx.add_field(archetype_field) + archetype_arr = ctx.get_buffer(archetype_field) + + islice = (-2, -1) + ispace = FullIterationSpace.create_from_slice(ctx, islice, archetype_field) + + dims = ispace.dimensions + + assert dims[0].start.structurally_equal( + PsExpression.make(archetype_arr.shape[0]) + factory.parse_index(-2) + ) + + assert dims[0].stop.structurally_equal( + PsExpression.make(archetype_arr.shape[0]) + factory.parse_index(-1) + ) + + assert dims[1].start.structurally_equal( + PsExpression.make(archetype_arr.shape[1]) + factory.parse_index(-1) + ) + + assert dims[1].stop.structurally_equal( + PsExpression.make(archetype_arr.shape[1]) + ) + + def test_field_independent_slices(): ctx = KernelCreationContext()