Skip to content
Snippets Groups Projects
Commit d16ae8a6 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

ignore trivial index dimension of explicit scalar fields

parent 020347e7
No related branches found
No related tags found
1 merge request!26Fixes to postprocessing: Remove unused code, test vector extraction, unify treatment of scalar fields
...@@ -214,6 +214,15 @@ class SfgDeferredParamSetter(SfgDeferredNode): ...@@ -214,6 +214,15 @@ class SfgDeferredParamSetter(SfgDeferredNode):
class SfgDeferredFieldMapping(SfgDeferredNode): class SfgDeferredFieldMapping(SfgDeferredNode):
"""Deferred mapping of a pystencils field to a field data structure.""" """Deferred mapping of a pystencils field to a field data structure."""
# NOTE ON Scalar Fields
#
# pystencils permits explicit (`index_shape = (1,)`) and implicit (`index_shape = ()`)
# scalar fields. In order to handle both equivalently,
# we ignore the trivial explicit scalar dimension in field extraction.
# This makes sure that explicit D-dimensional scalar fields
# can be mapped onto D-dimensional data structures, and do not require that
# D+1st dimension.
def __init__( def __init__(
self, self,
psfield: Field, psfield: Field,
...@@ -227,10 +236,16 @@ class SfgDeferredFieldMapping(SfgDeferredNode): ...@@ -227,10 +236,16 @@ class SfgDeferredFieldMapping(SfgDeferredNode):
def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode: def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode:
# Find field pointer # Find field pointer
ptr: SfgKernelParamVar | None = None ptr: SfgKernelParamVar | None = None
shape: list[SfgKernelParamVar | str | None] = [None] * len(self._field.shape) rank: int
strides: list[SfgKernelParamVar | str | None] = [None] * len(
self._field.strides if self._field.index_shape == (1,):
) # explicit scalar field -> ignore index dimensions
rank = self._field.spatial_dimensions
else:
rank = len(self._field.shape)
shape: list[SfgKernelParamVar | str | None] = [None] * rank
strides: list[SfgKernelParamVar | str | None] = [None] * rank
for param in ppc.live_variables: for param in ppc.live_variables:
if isinstance(param, SfgKernelParamVar): if isinstance(param, SfgKernelParamVar):
...@@ -244,12 +259,12 @@ class SfgDeferredFieldMapping(SfgDeferredNode): ...@@ -244,12 +259,12 @@ class SfgDeferredFieldMapping(SfgDeferredNode):
strides[coord] = param # type: ignore strides[coord] = param # type: ignore
# Find constant or otherwise determined sizes # Find constant or otherwise determined sizes
for coord, s in enumerate(self._field.shape): for coord, s in enumerate(self._field.shape[:rank]):
if shape[coord] is None: if shape[coord] is None:
shape[coord] = str(s) shape[coord] = str(s)
# Find constant or otherwise determined strides # Find constant or otherwise determined strides
for coord, s in enumerate(self._field.strides): for coord, s in enumerate(self._field.strides[:rank]):
if strides[coord] is None: if strides[coord] is None:
strides[coord] = str(s) strides[coord] = str(s)
......
import sympy as sp import sympy as sp
from pystencils import fields, kernel, TypedSymbol, Field, FieldType, create_type from pystencils import (
fields,
kernel,
TypedSymbol,
Field,
FieldType,
create_type,
Assignment,
)
from pystencils.types import PsCustomType from pystencils.types import PsCustomType
from pystencilssfg.composer import make_sequence from pystencilssfg.composer import make_sequence
from pystencilssfg.lang import AugExpr, SupportsFieldExtraction from pystencilssfg.lang import AugExpr, SupportsFieldExtraction
from pystencilssfg.lang.cpp import std
from pystencilssfg.ir import SfgStatements, SfgSequence from pystencilssfg.ir import SfgStatements, SfgSequence
from pystencilssfg.ir.postprocessing import CallTreePostProcessing from pystencilssfg.ir.postprocessing import CallTreePostProcessing
...@@ -100,7 +109,9 @@ def test_field_extraction(sfg): ...@@ -100,7 +109,9 @@ def test_field_extraction(sfg):
khandle = sfg.kernels.create(set_constant) khandle = sfg.kernels.create(set_constant)
extraction = DemoFieldExtraction("f") extraction = DemoFieldExtraction("f")
call_tree = make_sequence(sfg.map_field(f, extraction, cast_indexing_symbols=False), sfg.call(khandle)) call_tree = make_sequence(
sfg.map_field(f, extraction, cast_indexing_symbols=False), sfg.call(khandle)
)
pp = CallTreePostProcessing() pp = CallTreePostProcessing()
free_vars = pp.get_live_variables(call_tree) free_vars = pp.get_live_variables(call_tree)
...@@ -165,3 +176,39 @@ def test_duplicate_field_shapes(sfg): ...@@ -165,3 +176,39 @@ def test_duplicate_field_shapes(sfg):
for line, stmt in zip(lines_f, call_tree.children[1].children, strict=True): for line, stmt in zip(lines_f, call_tree.children[1].children, strict=True):
assert isinstance(stmt, SfgStatements) assert isinstance(stmt, SfgStatements)
assert stmt.code_string == line assert stmt.code_string == line
def test_scalar_fields(sfg):
sc_expl = Field.create_generic("f", 1, "double", index_shape=(1,))
sc_impl = Field.create_generic("f", 1, "double", index_shape=())
asm_expl = Assignment(sc_expl.center(0), 3)
asm_impl = Assignment(sc_impl.center(), 3)
k_expl = sfg.kernels.create(asm_expl, "expl")
k_impl = sfg.kernels.create(asm_impl, "impl")
tree_expl = make_sequence(
sfg.map_field(sc_expl, std.span.from_field(sc_expl)), sfg.call(k_expl)
)
tree_impl = make_sequence(
sfg.map_field(sc_impl, std.span.from_field(sc_impl)), sfg.call(k_impl)
)
pp = CallTreePostProcessing()
_ = pp.get_live_variables(tree_expl)
_ = pp.get_live_variables(tree_impl)
extraction_expl = tree_expl.children[0]
assert isinstance(extraction_expl, SfgSequence)
extraction_impl = tree_impl.children[0]
assert isinstance(extraction_impl, SfgSequence)
for node1, node2 in zip(
extraction_expl.children, extraction_impl.children, strict=True
):
assert isinstance(node1, SfgStatements)
assert isinstance(node2, SfgStatements)
assert node1.code_string == node2.code_string
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment