diff --git a/src/pystencilssfg/ir/postprocessing.py b/src/pystencilssfg/ir/postprocessing.py index 896693317c02dee67302221f810d64c01b5eb233..d86dedc140a7cde8d18c78781aae1dc039757dad 100644 --- a/src/pystencilssfg/ir/postprocessing.py +++ b/src/pystencilssfg/ir/postprocessing.py @@ -214,6 +214,15 @@ class SfgDeferredParamSetter(SfgDeferredNode): class SfgDeferredFieldMapping(SfgDeferredNode): """Deferred mapping of a pystencils field to a field data structure.""" + # NOTE ON Scalar Fields + # + # pystencils permits explicit (`index_shape = (1,)`) and implicit (`index_shape = ()`) + # scalar fields. In order to handle both equivalently, + # we ignore the trivial explicit scalar dimension in field extraction. + # This makes sure that explicit D-dimensional scalar fields + # can be mapped onto D-dimensional data structures, and do not require that + # D+1st dimension. + def __init__( self, psfield: Field, @@ -227,10 +236,16 @@ class SfgDeferredFieldMapping(SfgDeferredNode): def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode: # Find field pointer ptr: SfgKernelParamVar | None = None - shape: list[SfgKernelParamVar | str | None] = [None] * len(self._field.shape) - strides: list[SfgKernelParamVar | str | None] = [None] * len( - self._field.strides - ) + rank: int + + if self._field.index_shape == (1,): + # explicit scalar field -> ignore index dimensions + rank = self._field.spatial_dimensions + else: + rank = len(self._field.shape) + + shape: list[SfgKernelParamVar | str | None] = [None] * rank + strides: list[SfgKernelParamVar | str | None] = [None] * rank for param in ppc.live_variables: if isinstance(param, SfgKernelParamVar): @@ -244,12 +259,12 @@ class SfgDeferredFieldMapping(SfgDeferredNode): strides[coord] = param # type: ignore # Find constant or otherwise determined sizes - for coord, s in enumerate(self._field.shape): + for coord, s in enumerate(self._field.shape[:rank]): if shape[coord] is None: shape[coord] = str(s) # Find constant or otherwise determined strides - for coord, s in enumerate(self._field.strides): + for coord, s in enumerate(self._field.strides[:rank]): if strides[coord] is None: strides[coord] = str(s) diff --git a/tests/ir/test_postprocessing.py b/tests/ir/test_postprocessing.py index 5a9150b63b2a11dc910b9bbfc1c917487f5d1196..1b057bc56ca1ec80cebb9edef2e12e6d6a5872d1 100644 --- a/tests/ir/test_postprocessing.py +++ b/tests/ir/test_postprocessing.py @@ -1,10 +1,19 @@ import sympy as sp -from pystencils import fields, kernel, TypedSymbol, Field, FieldType, create_type +from pystencils import ( + fields, + kernel, + TypedSymbol, + Field, + FieldType, + create_type, + Assignment, +) from pystencils.types import PsCustomType from pystencilssfg.composer import make_sequence from pystencilssfg.lang import AugExpr, SupportsFieldExtraction +from pystencilssfg.lang.cpp import std from pystencilssfg.ir import SfgStatements, SfgSequence from pystencilssfg.ir.postprocessing import CallTreePostProcessing @@ -100,7 +109,9 @@ def test_field_extraction(sfg): khandle = sfg.kernels.create(set_constant) extraction = DemoFieldExtraction("f") - call_tree = make_sequence(sfg.map_field(f, extraction, cast_indexing_symbols=False), sfg.call(khandle)) + call_tree = make_sequence( + sfg.map_field(f, extraction, cast_indexing_symbols=False), sfg.call(khandle) + ) pp = CallTreePostProcessing() free_vars = pp.get_live_variables(call_tree) @@ -165,3 +176,39 @@ def test_duplicate_field_shapes(sfg): for line, stmt in zip(lines_f, call_tree.children[1].children, strict=True): assert isinstance(stmt, SfgStatements) assert stmt.code_string == line + + +def test_scalar_fields(sfg): + sc_expl = Field.create_generic("f", 1, "double", index_shape=(1,)) + sc_impl = Field.create_generic("f", 1, "double", index_shape=()) + + asm_expl = Assignment(sc_expl.center(0), 3) + asm_impl = Assignment(sc_impl.center(), 3) + + k_expl = sfg.kernels.create(asm_expl, "expl") + k_impl = sfg.kernels.create(asm_impl, "impl") + + tree_expl = make_sequence( + sfg.map_field(sc_expl, std.span.from_field(sc_expl)), sfg.call(k_expl) + ) + + tree_impl = make_sequence( + sfg.map_field(sc_impl, std.span.from_field(sc_impl)), sfg.call(k_impl) + ) + + pp = CallTreePostProcessing() + _ = pp.get_live_variables(tree_expl) + _ = pp.get_live_variables(tree_impl) + + extraction_expl = tree_expl.children[0] + assert isinstance(extraction_expl, SfgSequence) + + extraction_impl = tree_impl.children[0] + assert isinstance(extraction_impl, SfgSequence) + + for node1, node2 in zip( + extraction_expl.children, extraction_impl.children, strict=True + ): + assert isinstance(node1, SfgStatements) + assert isinstance(node2, SfgStatements) + assert node1.code_string == node2.code_string