From 64ac6ebf82f8753958d8f487710183e9b1cb280d Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Mon, 17 Mar 2025 16:20:52 +0100 Subject: [PATCH] make field extraction protocl runtime-checkable. fix extraction of trivial index dimensions. --- src/pystencilssfg/ir/postprocessing.py | 21 +++++++++++++++------ src/pystencilssfg/lang/extractions.py | 3 ++- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/src/pystencilssfg/ir/postprocessing.py b/src/pystencilssfg/ir/postprocessing.py index 8966933..7222550 100644 --- a/src/pystencilssfg/ir/postprocessing.py +++ b/src/pystencilssfg/ir/postprocessing.py @@ -232,6 +232,9 @@ class SfgDeferredFieldMapping(SfgDeferredNode): self._field.strides ) + index_dims_are_trivial = self._field.index_shape == (1,) + rank = len(self._field.shape) + for param in ppc.live_variables: if isinstance(param, SfgKernelParamVar): for prop in param.wrapped.properties: @@ -280,9 +283,12 @@ class SfgDeferredFieldMapping(SfgDeferredNode): expr = self._extraction._extract_size(coord) if expr is None: - raise SfgException( - f"Cannot extract shape in coordinate {coord} from {self._extraction}" - ) + if index_dims_are_trivial and coord == rank - 1: + expr = AugExpr.format("1") + else: + raise SfgException( + f"Cannot extract shape in coordinate {coord} from {self._extraction}" + ) if isinstance(symb, SfgKernelParamVar) and symb not in done: done.add(symb) @@ -300,9 +306,12 @@ class SfgDeferredFieldMapping(SfgDeferredNode): expr = self._extraction._extract_stride(coord) if expr is None: - raise SfgException( - f"Cannot extract stride in coordinate {coord} from {self._extraction}" - ) + if index_dims_are_trivial and coord == rank - 1: + expr = AugExpr.format("1") + else: + raise SfgException( + f"Cannot extract stride in coordinate {coord} from {self._extraction}" + ) if isinstance(symb, SfgKernelParamVar) and symb not in done: done.add(symb) diff --git a/src/pystencilssfg/lang/extractions.py b/src/pystencilssfg/lang/extractions.py index e920fcb..a9da563 100644 --- a/src/pystencilssfg/lang/extractions.py +++ b/src/pystencilssfg/lang/extractions.py @@ -1,10 +1,11 @@ from __future__ import annotations -from typing import Protocol +from typing import Protocol, runtime_checkable from abc import abstractmethod from .expressions import AugExpr +@runtime_checkable class SupportsFieldExtraction(Protocol): """Protocol for field pointer and indexing extraction. -- GitLab