Skip to content
Snippets Groups Projects
Commit a86d9e09 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Fix: Parsing of negative integer slices

parent 63f396e7
No related branches found
No related tags found
3 merge requests!433Consolidate codegen and JIT modules.,!430Jupyter Inspection Framework, Book Theme, and Initial Drafts for Codegen Reference Guides,!429Iteration Slices: Extended GPU support + bugfixes
Pipeline #70376 passed
...@@ -138,6 +138,13 @@ class AstFactory: ...@@ -138,6 +138,13 @@ class AstFactory:
self._typify(self.parse_index(iter_slice) + self.parse_index(1)) self._typify(self.parse_index(iter_slice) + self.parse_index(1))
) )
step = self.parse_index(1) step = self.parse_index(1)
if normalize_to is not None:
upper_limit = self.parse_index(normalize_to)
if isinstance(start, PsConstantExpr) and start.constant.value < 0:
start = fold(self._typify(upper_limit.clone() + start))
stop = fold(self._typify(upper_limit.clone() + stop))
else: else:
start = self._parse_any_index( start = self._parse_any_index(
iter_slice.start if iter_slice.start is not None else 0 iter_slice.start if iter_slice.start is not None else 0
...@@ -156,21 +163,21 @@ class AstFactory: ...@@ -156,21 +163,21 @@ class AstFactory:
f"Invalid value for `slice.step`: {step.constant.value}" f"Invalid value for `slice.step`: {step.constant.value}"
) )
if normalize_to is not None: if normalize_to is not None:
upper_limit = self.parse_index(normalize_to) upper_limit = self.parse_index(normalize_to)
if isinstance(start, PsConstantExpr) and start.constant.value < 0: if isinstance(start, PsConstantExpr) and start.constant.value < 0:
start = fold(self._typify(upper_limit.clone() + start)) start = fold(self._typify(upper_limit.clone() + start))
if stop is None: if stop is None:
stop = upper_limit stop = upper_limit
elif isinstance(stop, PsConstantExpr) and stop.constant.value < 0: elif isinstance(stop, PsConstantExpr) and stop.constant.value < 0:
stop = fold(self._typify(upper_limit.clone() + stop)) stop = fold(self._typify(upper_limit.clone() + stop))
elif stop is None:
raise ValueError(
"Cannot parse a slice with `stop == None` if no normalization limit is given"
)
elif stop is None:
raise ValueError(
"Cannot parse a slice with `stop == None` if no normalization limit is given"
)
assert stop is not None # for mypy assert stop is not None # for mypy
return start, stop, step return start, stop, step
......
...@@ -129,6 +129,36 @@ def test_slices_with_negative_start(): ...@@ -129,6 +129,36 @@ def test_slices_with_negative_start():
) )
def test_negative_singular_slices():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
archetype_field = Field.create_generic("f", spatial_dimensions=2, layout="fzyx")
ctx.add_field(archetype_field)
archetype_arr = ctx.get_buffer(archetype_field)
islice = (-2, -1)
ispace = FullIterationSpace.create_from_slice(ctx, islice, archetype_field)
dims = ispace.dimensions
assert dims[0].start.structurally_equal(
PsExpression.make(archetype_arr.shape[0]) + factory.parse_index(-2)
)
assert dims[0].stop.structurally_equal(
PsExpression.make(archetype_arr.shape[0]) + factory.parse_index(-1)
)
assert dims[1].start.structurally_equal(
PsExpression.make(archetype_arr.shape[1]) + factory.parse_index(-1)
)
assert dims[1].stop.structurally_equal(
PsExpression.make(archetype_arr.shape[1])
)
def test_field_independent_slices(): def test_field_independent_slices():
ctx = KernelCreationContext() ctx = KernelCreationContext()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment