Skip to content
Snippets Groups Projects
Commit 64ac6ebf authored by Frederik Hennig's avatar Frederik Hennig
Browse files

make field extraction protocl runtime-checkable. fix extraction of trivial index dimensions.

parent 51c03215
No related merge requests found
Pipeline #76396 failed with stages
in 4 minutes and 35 seconds
......@@ -232,6 +232,9 @@ class SfgDeferredFieldMapping(SfgDeferredNode):
self._field.strides
)
index_dims_are_trivial = self._field.index_shape == (1,)
rank = len(self._field.shape)
for param in ppc.live_variables:
if isinstance(param, SfgKernelParamVar):
for prop in param.wrapped.properties:
......@@ -280,9 +283,12 @@ class SfgDeferredFieldMapping(SfgDeferredNode):
expr = self._extraction._extract_size(coord)
if expr is None:
raise SfgException(
f"Cannot extract shape in coordinate {coord} from {self._extraction}"
)
if index_dims_are_trivial and coord == rank - 1:
expr = AugExpr.format("1")
else:
raise SfgException(
f"Cannot extract shape in coordinate {coord} from {self._extraction}"
)
if isinstance(symb, SfgKernelParamVar) and symb not in done:
done.add(symb)
......@@ -300,9 +306,12 @@ class SfgDeferredFieldMapping(SfgDeferredNode):
expr = self._extraction._extract_stride(coord)
if expr is None:
raise SfgException(
f"Cannot extract stride in coordinate {coord} from {self._extraction}"
)
if index_dims_are_trivial and coord == rank - 1:
expr = AugExpr.format("1")
else:
raise SfgException(
f"Cannot extract stride in coordinate {coord} from {self._extraction}"
)
if isinstance(symb, SfgKernelParamVar) and symb not in done:
done.add(symb)
......
from __future__ import annotations
from typing import Protocol
from typing import Protocol, runtime_checkable
from abc import abstractmethod
from .expressions import AugExpr
@runtime_checkable
class SupportsFieldExtraction(Protocol):
"""Protocol for field pointer and indexing extraction.
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment