Skip to content
Snippets Groups Projects

Fixes to postprocessing: Remove unused code, test vector extraction, unify treatment of scalar fields

Merged Frederik Hennig requested to merge fhennig/postprocessing-fixes into master
6 files
+ 128
45
Compare changes
  • Side-by-side
  • Inline
Files
6
@@ -27,38 +27,6 @@ from ..lang import (
)
class FlattenSequences:
"""Flattens any nested sequences occuring in a kernel call tree."""
def __call__(self, node: SfgCallTreeNode) -> None:
self.visit(node)
def visit(self, node: SfgCallTreeNode):
match node:
case SfgSequence():
self.flatten(node)
case _:
for c in node.children:
self.visit(c)
def flatten(self, sequence: SfgSequence) -> None:
children_flattened: list[SfgCallTreeNode] = []
def flatten(seq: SfgSequence):
for c in seq.children:
if isinstance(c, SfgSequence):
flatten(c)
else:
children_flattened.append(c)
flatten(sequence)
for c in children_flattened:
self.visit(c)
sequence.children = children_flattened
class PostProcessingContext:
def __init__(self) -> None:
self._live_variables: dict[str, SfgVar] = dict()
@@ -129,9 +97,6 @@ class PostProcessingResult:
class CallTreePostProcessing:
def __init__(self):
self._flattener = FlattenSequences()
def __call__(self, ast: SfgCallTreeNode) -> PostProcessingResult:
live_vars = self.get_live_variables(ast)
return PostProcessingResult(live_vars)
@@ -214,6 +179,15 @@ class SfgDeferredParamSetter(SfgDeferredNode):
class SfgDeferredFieldMapping(SfgDeferredNode):
"""Deferred mapping of a pystencils field to a field data structure."""
# NOTE ON Scalar Fields
#
# pystencils permits explicit (`index_shape = (1,)`) and implicit (`index_shape = ()`)
# scalar fields. In order to handle both equivalently,
# we ignore the trivial explicit scalar dimension in field extraction.
# This makes sure that explicit D-dimensional scalar fields
# can be mapped onto D-dimensional data structures, and do not require that
# D+1st dimension.
def __init__(
self,
psfield: Field,
@@ -227,10 +201,16 @@ class SfgDeferredFieldMapping(SfgDeferredNode):
def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode:
# Find field pointer
ptr: SfgKernelParamVar | None = None
shape: list[SfgKernelParamVar | str | None] = [None] * len(self._field.shape)
strides: list[SfgKernelParamVar | str | None] = [None] * len(
self._field.strides
)
rank: int
if self._field.index_shape == (1,):
# explicit scalar field -> ignore index dimensions
rank = self._field.spatial_dimensions
else:
rank = len(self._field.shape)
shape: list[SfgKernelParamVar | str | None] = [None] * rank
strides: list[SfgKernelParamVar | str | None] = [None] * rank
for param in ppc.live_variables:
if isinstance(param, SfgKernelParamVar):
@@ -244,12 +224,12 @@ class SfgDeferredFieldMapping(SfgDeferredNode):
strides[coord] = param # type: ignore
# Find constant or otherwise determined sizes
for coord, s in enumerate(self._field.shape):
for coord, s in enumerate(self._field.shape[:rank]):
if shape[coord] is None:
shape[coord] = str(s)
# Find constant or otherwise determined strides
for coord, s in enumerate(self._field.strides):
for coord, s in enumerate(self._field.strides[:rank]):
if strides[coord] is None:
strides[coord] = str(s)
Loading