From 64ac6ebf82f8753958d8f487710183e9b1cb280d Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Mon, 17 Mar 2025 16:20:52 +0100
Subject: [PATCH] make field extraction protocl runtime-checkable. fix
 extraction of trivial index dimensions.

---
 src/pystencilssfg/ir/postprocessing.py | 21 +++++++++++++++------
 src/pystencilssfg/lang/extractions.py  |  3 ++-
 2 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/src/pystencilssfg/ir/postprocessing.py b/src/pystencilssfg/ir/postprocessing.py
index 8966933..7222550 100644
--- a/src/pystencilssfg/ir/postprocessing.py
+++ b/src/pystencilssfg/ir/postprocessing.py
@@ -232,6 +232,9 @@ class SfgDeferredFieldMapping(SfgDeferredNode):
             self._field.strides
         )
 
+        index_dims_are_trivial = self._field.index_shape == (1,)
+        rank = len(self._field.shape)
+
         for param in ppc.live_variables:
             if isinstance(param, SfgKernelParamVar):
                 for prop in param.wrapped.properties:
@@ -280,9 +283,12 @@ class SfgDeferredFieldMapping(SfgDeferredNode):
             expr = self._extraction._extract_size(coord)
 
             if expr is None:
-                raise SfgException(
-                    f"Cannot extract shape in coordinate {coord} from {self._extraction}"
-                )
+                if index_dims_are_trivial and coord == rank - 1:
+                    expr = AugExpr.format("1")
+                else:
+                    raise SfgException(
+                        f"Cannot extract shape in coordinate {coord} from {self._extraction}"
+                    )
 
             if isinstance(symb, SfgKernelParamVar) and symb not in done:
                 done.add(symb)
@@ -300,9 +306,12 @@ class SfgDeferredFieldMapping(SfgDeferredNode):
             expr = self._extraction._extract_stride(coord)
 
             if expr is None:
-                raise SfgException(
-                    f"Cannot extract stride in coordinate {coord} from {self._extraction}"
-                )
+                if index_dims_are_trivial and coord == rank - 1:
+                    expr = AugExpr.format("1")
+                else:
+                    raise SfgException(
+                        f"Cannot extract stride in coordinate {coord} from {self._extraction}"
+                    )
 
             if isinstance(symb, SfgKernelParamVar) and symb not in done:
                 done.add(symb)
diff --git a/src/pystencilssfg/lang/extractions.py b/src/pystencilssfg/lang/extractions.py
index e920fcb..a9da563 100644
--- a/src/pystencilssfg/lang/extractions.py
+++ b/src/pystencilssfg/lang/extractions.py
@@ -1,10 +1,11 @@
 from __future__ import annotations
-from typing import Protocol
+from typing import Protocol, runtime_checkable
 from abc import abstractmethod
 
 from .expressions import AugExpr
 
 
+@runtime_checkable
 class SupportsFieldExtraction(Protocol):
     """Protocol for field pointer and indexing extraction.
 
-- 
GitLab