Skip to content
Snippets Groups Projects
Commit a86d9e09 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Fix: Parsing of negative integer slices

parent 63f396e7
No related branches found
No related tags found
3 merge requests!433Consolidate codegen and JIT modules.,!430Jupyter Inspection Framework, Book Theme, and Initial Drafts for Codegen Reference Guides,!429Iteration Slices: Extended GPU support + bugfixes
Pipeline #70376 passed
......@@ -138,6 +138,13 @@ class AstFactory:
self._typify(self.parse_index(iter_slice) + self.parse_index(1))
)
step = self.parse_index(1)
if normalize_to is not None:
upper_limit = self.parse_index(normalize_to)
if isinstance(start, PsConstantExpr) and start.constant.value < 0:
start = fold(self._typify(upper_limit.clone() + start))
stop = fold(self._typify(upper_limit.clone() + stop))
else:
start = self._parse_any_index(
iter_slice.start if iter_slice.start is not None else 0
......@@ -156,21 +163,21 @@ class AstFactory:
f"Invalid value for `slice.step`: {step.constant.value}"
)
if normalize_to is not None:
upper_limit = self.parse_index(normalize_to)
if isinstance(start, PsConstantExpr) and start.constant.value < 0:
start = fold(self._typify(upper_limit.clone() + start))
if normalize_to is not None:
upper_limit = self.parse_index(normalize_to)
if isinstance(start, PsConstantExpr) and start.constant.value < 0:
start = fold(self._typify(upper_limit.clone() + start))
if stop is None:
stop = upper_limit
elif isinstance(stop, PsConstantExpr) and stop.constant.value < 0:
stop = fold(self._typify(upper_limit.clone() + stop))
if stop is None:
stop = upper_limit
elif isinstance(stop, PsConstantExpr) and stop.constant.value < 0:
stop = fold(self._typify(upper_limit.clone() + stop))
elif stop is None:
raise ValueError(
"Cannot parse a slice with `stop == None` if no normalization limit is given"
)
elif stop is None:
raise ValueError(
"Cannot parse a slice with `stop == None` if no normalization limit is given"
)
assert stop is not None # for mypy
return start, stop, step
......
......@@ -129,6 +129,36 @@ def test_slices_with_negative_start():
)
def test_negative_singular_slices():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
archetype_field = Field.create_generic("f", spatial_dimensions=2, layout="fzyx")
ctx.add_field(archetype_field)
archetype_arr = ctx.get_buffer(archetype_field)
islice = (-2, -1)
ispace = FullIterationSpace.create_from_slice(ctx, islice, archetype_field)
dims = ispace.dimensions
assert dims[0].start.structurally_equal(
PsExpression.make(archetype_arr.shape[0]) + factory.parse_index(-2)
)
assert dims[0].stop.structurally_equal(
PsExpression.make(archetype_arr.shape[0]) + factory.parse_index(-1)
)
assert dims[1].start.structurally_equal(
PsExpression.make(archetype_arr.shape[1]) + factory.parse_index(-1)
)
assert dims[1].stop.structurally_equal(
PsExpression.make(archetype_arr.shape[1])
)
def test_field_independent_slices():
ctx = KernelCreationContext()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment