diff --git a/src/pystencils/backend/kernelcreation/ast_factory.py b/src/pystencils/backend/kernelcreation/ast_factory.py index a5f68433732ebdf8e300b899fe39e7e37ccc5653..a0328a123893a9f103c0ef66aa98028cc5437708 100644 --- a/src/pystencils/backend/kernelcreation/ast_factory.py +++ b/src/pystencils/backend/kernelcreation/ast_factory.py @@ -97,10 +97,14 @@ class AstFactory: return PsExpression.make(PsConstant(idx, self._ctx.index_dtype)) def _parse_any_index(self, idx: Any) -> PsExpression: - return self.parse_index(cast(IndexParsable, idx)) + if not isinstance(idx, _IndexParsable): + raise TypeError(f"Cannot parse {idx} as an index expression") + return self.parse_index(idx) def parse_slice( - self, iter_slice: int | slice, upper_limit: Any | None = None + self, + iter_slice: IndexParsable | slice, + normalize_to: IndexParsable | None = None, ) -> tuple[PsExpression, PsExpression, PsExpression]: """Parse a slice to obtain start, stop and step expressions for a loop or iteration space dimension. @@ -109,45 +113,62 @@ class AstFactory: They may also be sympy expressions or integer constants, in which case they are parsed to AST objects and must also typify with the kernel creation context's ``index_dtype``. - If the slice's ``stop`` member is `None` or a negative `int`, `upper_limit` must be specified, which is then - used as the upper iteration limit as either ``upper_limit`` or ``upper_limit - stop``. - The `step` member of the slice, if it is constant, must be positive. + The slice may optionally be normalized with respect to an upper iteration limit. + If `normalize_to` is specified, negative integers in `iter_slice.start` and `iter_slice.stop` will + be added to that normalization limit. + Args: iter_slice: The iteration slice - upper_limit: Optionally, the upper iteration limit + normalize_to: The upper iteration limit with respect to which the slice should be normalized """ - if isinstance(iter_slice, int): - iter_slice = slice(iter_slice, iter_slice + 1) + from pystencils.backend.transformations import EliminateConstants - start = self._parse_any_index(iter_slice.start if iter_slice.start is not None else 0) - stop = ( - self._parse_any_index(iter_slice.stop) - if iter_slice.stop is not None - else self._parse_any_index(upper_limit) - ) - step = self._parse_any_index(iter_slice.step if iter_slice.step is not None else 1) + fold = EliminateConstants(self._ctx) - if isinstance(start, PsConstantExpr) and start.constant.value < 0: - if upper_limit is None: - raise ValueError( - "Must specify an upper iteration limit if `slice.start` is negative" - ) + start: PsExpression + stop: PsExpression | None + step: PsExpression - start = self._parse_any_index(upper_limit) + start + if not isinstance(iter_slice, slice): + start = self.parse_index(iter_slice) + stop = fold( + self._typify(self.parse_index(iter_slice) + self.parse_index(1)) + ) + step = self.parse_index(1) + else: + start = self._parse_any_index( + iter_slice.start if iter_slice.start is not None else 0 + ) + stop = ( + self._parse_any_index(iter_slice.stop) + if iter_slice.stop is not None + else None + ) + step = self._parse_any_index( + iter_slice.step if iter_slice.step is not None else 1 + ) - if isinstance(stop, PsConstantExpr) and stop.constant.value < 0: - if upper_limit is None: + if isinstance(step, PsConstantExpr) and step.constant.value <= 0: raise ValueError( - "Must specify an upper iteration limit if `slice.stop` is negative" + f"Invalid value for `slice.step`: {step.constant.value}" ) - stop = self._parse_any_index(upper_limit) + stop - if isinstance(step, PsConstantExpr) and step.constant.value <= 0: + if normalize_to is not None: + upper_limit = self.parse_index(normalize_to) + if isinstance(start, PsConstantExpr) and start.constant.value < 0: + start = fold(self._typify(upper_limit.clone() + start)) + + if stop is None: + stop = upper_limit + elif isinstance(stop, PsConstantExpr) and stop.constant.value < 0: + stop = fold(self._typify(upper_limit.clone() + stop)) + + elif stop is None: raise ValueError( - f"Invalid value for `slice.step`: {step.constant.value}" + "Cannot parse a slice with `stop == None` if no normalization limit is given" ) return start, stop, step diff --git a/src/pystencils/backend/kernelcreation/iteration_space.py b/src/pystencils/backend/kernelcreation/iteration_space.py index bfebd5af6925f560eae9bbddc7a75f41d2e5a876..8175fffed9782d9271262b7bc220e4dcdc208705 100644 --- a/src/pystencils/backend/kernelcreation/iteration_space.py +++ b/src/pystencils/backend/kernelcreation/iteration_space.py @@ -121,7 +121,7 @@ class FullIterationSpace(IterationSpace): @staticmethod def create_from_slice( ctx: KernelCreationContext, - iteration_slice: slice | Sequence[slice], + iteration_slice: int | slice | tuple[int | slice, ...], archetype_field: Field | None = None, ): """Create an iteration space from a sequence of slices, optionally over an archetype field. @@ -131,7 +131,7 @@ class FullIterationSpace(IterationSpace): iteration_slice: The iteration slices for each dimension; for valid formats, see `AstFactory.parse_slice` archetype_field: Optionally, an archetype field that dictates the upper slice limits and loop order. """ - if isinstance(iteration_slice, slice): + if not isinstance(iteration_slice, tuple): iteration_slice = (iteration_slice,) dim = len(iteration_slice) @@ -163,7 +163,9 @@ class FullIterationSpace(IterationSpace): factory = AstFactory(ctx) - def to_dim(slic: slice, size: PsSymbol | PsConstant | None, ctr: PsSymbol): + def to_dim( + slic: int | slice, size: PsSymbol | PsConstant | None, ctr: PsSymbol + ): start, stop, step = factory.parse_slice(slic, size) return FullIterationSpace.Dimension(start, stop, step, ctr) @@ -393,7 +395,7 @@ def create_full_iteration_space( ctx: KernelCreationContext, assignments: AssignmentCollection, ghost_layers: None | int | Sequence[int | tuple[int, int]] = None, - iteration_slice: None | Sequence[slice] = None, + iteration_slice: None | int | slice | tuple[int | slice, ...] = None, ) -> IterationSpace: assert not ctx.fields.index_fields @@ -443,7 +445,9 @@ def create_full_iteration_space( ) else: if len(domain_field_accesses) > 0: - inferred_gls = max([fa.required_ghost_layers for fa in domain_field_accesses]) + inferred_gls = max( + [fa.required_ghost_layers for fa in domain_field_accesses] + ) else: inferred_gls = 0 diff --git a/tests/nbackend/kernelcreation/test_iteration_space.py b/tests/nbackend/kernelcreation/test_iteration_space.py index d816397a2bfa052922927853b7dd71458f648db1..8ff678fbad398af18b71f2f30145ff34ab269685 100644 --- a/tests/nbackend/kernelcreation/test_iteration_space.py +++ b/tests/nbackend/kernelcreation/test_iteration_space.py @@ -1,16 +1,20 @@ import pytest +import numpy as np from pystencils import make_slice, Field, create_type from pystencils.sympyextensions.typed_sympy import TypedSymbol from pystencils.backend.constants import PsConstant -from pystencils.backend.kernelcreation import KernelCreationContext, FullIterationSpace, AstFactory +from pystencils.backend.kernelcreation import ( + KernelCreationContext, + FullIterationSpace, + AstFactory, +) from pystencils.backend.ast.expressions import PsAdd, PsConstantExpr, PsExpression from pystencils.backend.kernelcreation.typification import TypificationError -from pystencils.types.quick import Int -def test_slices(): +def test_slices_over_field(): ctx = KernelCreationContext() archetype_field = Field.create_generic("f", spatial_dimensions=3, layout="fzyx") @@ -23,7 +27,7 @@ def test_slices(): dims = ispace.dimensions - for sl, size, dim in zip(islice, archetype_arr.shape, dims): + for sl, dim in zip(islice, dims): assert ( isinstance(dim.start, PsConstantExpr) and dim.start.constant.value == sl.start @@ -45,7 +49,39 @@ def test_slices(): assert dims[2].stop.structurally_equal(PsExpression.make(archetype_arr.shape[2])) -def test_singular_slice(): +def test_slices_with_fixed_size_field(): + ctx = KernelCreationContext() + + archetype_field = Field.create_fixed_size("f", (4, 5, 6), layout="fzyx") + ctx.add_field(archetype_field) + + islice = (slice(1, -1, 1), slice(3, -3, 3), slice(0, None, 1)) + ispace = FullIterationSpace.create_from_slice(ctx, islice, archetype_field) + + archetype_arr = ctx.get_array(archetype_field) + + dims = ispace.dimensions + + for sl, size, dim in zip(islice, archetype_arr.shape, dims): + assert ( + isinstance(dim.start, PsConstantExpr) + and dim.start.constant.value == sl.start + ) + + assert isinstance(size, PsConstant) + + assert isinstance( + dim.stop, PsConstantExpr + ) and dim.stop.constant.value == np.int64( + size.value + sl.stop if sl.stop is not None else size.value + ) + + assert ( + isinstance(dim.step, PsConstantExpr) and dim.step.constant.value == sl.step + ) + + +def test_singular_slice_over_field(): ctx = KernelCreationContext() factory = AstFactory(ctx) @@ -93,6 +129,25 @@ def test_slices_with_negative_start(): ) +def test_field_independent_slices(): + ctx = KernelCreationContext() + + islice = (slice(-3, -1, 1), slice(-4, 7, 2)) + ispace = FullIterationSpace.create_from_slice(ctx, islice) + + dims = ispace.dimensions + + for sl, dim in zip(islice, dims): + assert isinstance(dim.start, PsConstantExpr) + assert dim.start.constant.value == np.int64(sl.start) + + assert isinstance(dim.stop, PsConstantExpr) + assert dim.stop.constant.value == np.int64(sl.stop) + + assert isinstance(dim.step, PsConstantExpr) + assert dim.step.constant.value == np.int64(sl.step) + + def test_invalid_slices(): ctx = KernelCreationContext() @@ -125,16 +180,14 @@ def test_iteration_count(): three = PsExpression.make(PsConstant(3, ctx.index_dtype)) ispace = FullIterationSpace.create_from_slice( - ctx, make_slice[three : i-two, 1:8:3] + ctx, make_slice[three : i - two, 1:8:3] ) iters = [ispace.actual_iterations(coord) for coord in range(2)] assert iters[0].structurally_equal((i - two) - three) assert iters[1].structurally_equal(three) - empty_ispace = FullIterationSpace.create_from_slice( - ctx, make_slice[4:4:1, 4:4:7] - ) + empty_ispace = FullIterationSpace.create_from_slice(ctx, make_slice[4:4:1, 4:4:7]) iters = [empty_ispace.actual_iterations(coord) for coord in range(2)] assert iters[0].structurally_equal(zero)