diff --git a/src/pystencilssfg/ir/postprocessing.py b/src/pystencilssfg/ir/postprocessing.py index 896693317c02dee67302221f810d64c01b5eb233..7222550196ccbfbc0f8b3916b9c91d7c82cb2a81 100644 --- a/src/pystencilssfg/ir/postprocessing.py +++ b/src/pystencilssfg/ir/postprocessing.py @@ -232,6 +232,9 @@ class SfgDeferredFieldMapping(SfgDeferredNode): self._field.strides ) + index_dims_are_trivial = self._field.index_shape == (1,) + rank = len(self._field.shape) + for param in ppc.live_variables: if isinstance(param, SfgKernelParamVar): for prop in param.wrapped.properties: @@ -280,9 +283,12 @@ class SfgDeferredFieldMapping(SfgDeferredNode): expr = self._extraction._extract_size(coord) if expr is None: - raise SfgException( - f"Cannot extract shape in coordinate {coord} from {self._extraction}" - ) + if index_dims_are_trivial and coord == rank - 1: + expr = AugExpr.format("1") + else: + raise SfgException( + f"Cannot extract shape in coordinate {coord} from {self._extraction}" + ) if isinstance(symb, SfgKernelParamVar) and symb not in done: done.add(symb) @@ -300,9 +306,12 @@ class SfgDeferredFieldMapping(SfgDeferredNode): expr = self._extraction._extract_stride(coord) if expr is None: - raise SfgException( - f"Cannot extract stride in coordinate {coord} from {self._extraction}" - ) + if index_dims_are_trivial and coord == rank - 1: + expr = AugExpr.format("1") + else: + raise SfgException( + f"Cannot extract stride in coordinate {coord} from {self._extraction}" + ) if isinstance(symb, SfgKernelParamVar) and symb not in done: done.add(symb) diff --git a/src/pystencilssfg/lang/extractions.py b/src/pystencilssfg/lang/extractions.py index e920fcbfc453d53c22f0486ab7c051ed6c5a7c7f..a9da563ce69358b9b94a3ab4e36fca4c967dede4 100644 --- a/src/pystencilssfg/lang/extractions.py +++ b/src/pystencilssfg/lang/extractions.py @@ -1,10 +1,11 @@ from __future__ import annotations -from typing import Protocol +from typing import Protocol, runtime_checkable from abc import abstractmethod from .expressions import AugExpr +@runtime_checkable class SupportsFieldExtraction(Protocol): """Protocol for field pointer and indexing extraction.