Skip to content
Snippets Groups Projects
Commit 64ac6ebf authored by Frederik Hennig's avatar Frederik Hennig
Browse files

make field extraction protocl runtime-checkable. fix extraction of trivial index dimensions.

parent 51c03215
No related branches found
No related tags found
No related merge requests found
Pipeline #76396 failed
...@@ -232,6 +232,9 @@ class SfgDeferredFieldMapping(SfgDeferredNode): ...@@ -232,6 +232,9 @@ class SfgDeferredFieldMapping(SfgDeferredNode):
self._field.strides self._field.strides
) )
index_dims_are_trivial = self._field.index_shape == (1,)
rank = len(self._field.shape)
for param in ppc.live_variables: for param in ppc.live_variables:
if isinstance(param, SfgKernelParamVar): if isinstance(param, SfgKernelParamVar):
for prop in param.wrapped.properties: for prop in param.wrapped.properties:
...@@ -280,6 +283,9 @@ class SfgDeferredFieldMapping(SfgDeferredNode): ...@@ -280,6 +283,9 @@ class SfgDeferredFieldMapping(SfgDeferredNode):
expr = self._extraction._extract_size(coord) expr = self._extraction._extract_size(coord)
if expr is None: if expr is None:
if index_dims_are_trivial and coord == rank - 1:
expr = AugExpr.format("1")
else:
raise SfgException( raise SfgException(
f"Cannot extract shape in coordinate {coord} from {self._extraction}" f"Cannot extract shape in coordinate {coord} from {self._extraction}"
) )
...@@ -300,6 +306,9 @@ class SfgDeferredFieldMapping(SfgDeferredNode): ...@@ -300,6 +306,9 @@ class SfgDeferredFieldMapping(SfgDeferredNode):
expr = self._extraction._extract_stride(coord) expr = self._extraction._extract_stride(coord)
if expr is None: if expr is None:
if index_dims_are_trivial and coord == rank - 1:
expr = AugExpr.format("1")
else:
raise SfgException( raise SfgException(
f"Cannot extract stride in coordinate {coord} from {self._extraction}" f"Cannot extract stride in coordinate {coord} from {self._extraction}"
) )
......
from __future__ import annotations from __future__ import annotations
from typing import Protocol from typing import Protocol, runtime_checkable
from abc import abstractmethod from abc import abstractmethod
from .expressions import AugExpr from .expressions import AugExpr
@runtime_checkable
class SupportsFieldExtraction(Protocol): class SupportsFieldExtraction(Protocol):
"""Protocol for field pointer and indexing extraction. """Protocol for field pointer and indexing extraction.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment