Skip to content
Snippets Groups Projects
Commit b1c47558 authored by Christoph Alt's avatar Christoph Alt
Browse files

Merge branch 'fhennig/postprocessing-fixes' into 'master'

Fixes to postprocessing: Remove unused code, test vector extraction, unify treatment of scalar fields

See merge request !26
parents 020347e7 3e0c00c4
1 merge request!26Fixes to postprocessing: Remove unused code, test vector extraction, unify treatment of scalar fields
Pipeline #77278 passed with stages
in 3 minutes and 21 seconds
...@@ -27,38 +27,6 @@ from ..lang import ( ...@@ -27,38 +27,6 @@ from ..lang import (
) )
class FlattenSequences:
"""Flattens any nested sequences occuring in a kernel call tree."""
def __call__(self, node: SfgCallTreeNode) -> None:
self.visit(node)
def visit(self, node: SfgCallTreeNode):
match node:
case SfgSequence():
self.flatten(node)
case _:
for c in node.children:
self.visit(c)
def flatten(self, sequence: SfgSequence) -> None:
children_flattened: list[SfgCallTreeNode] = []
def flatten(seq: SfgSequence):
for c in seq.children:
if isinstance(c, SfgSequence):
flatten(c)
else:
children_flattened.append(c)
flatten(sequence)
for c in children_flattened:
self.visit(c)
sequence.children = children_flattened
class PostProcessingContext: class PostProcessingContext:
def __init__(self) -> None: def __init__(self) -> None:
self._live_variables: dict[str, SfgVar] = dict() self._live_variables: dict[str, SfgVar] = dict()
...@@ -129,9 +97,6 @@ class PostProcessingResult: ...@@ -129,9 +97,6 @@ class PostProcessingResult:
class CallTreePostProcessing: class CallTreePostProcessing:
def __init__(self):
self._flattener = FlattenSequences()
def __call__(self, ast: SfgCallTreeNode) -> PostProcessingResult: def __call__(self, ast: SfgCallTreeNode) -> PostProcessingResult:
live_vars = self.get_live_variables(ast) live_vars = self.get_live_variables(ast)
return PostProcessingResult(live_vars) return PostProcessingResult(live_vars)
...@@ -214,6 +179,15 @@ class SfgDeferredParamSetter(SfgDeferredNode): ...@@ -214,6 +179,15 @@ class SfgDeferredParamSetter(SfgDeferredNode):
class SfgDeferredFieldMapping(SfgDeferredNode): class SfgDeferredFieldMapping(SfgDeferredNode):
"""Deferred mapping of a pystencils field to a field data structure.""" """Deferred mapping of a pystencils field to a field data structure."""
# NOTE ON Scalar Fields
#
# pystencils permits explicit (`index_shape = (1,)`) and implicit (`index_shape = ()`)
# scalar fields. In order to handle both equivalently,
# we ignore the trivial explicit scalar dimension in field extraction.
# This makes sure that explicit D-dimensional scalar fields
# can be mapped onto D-dimensional data structures, and do not require that
# D+1st dimension.
def __init__( def __init__(
self, self,
psfield: Field, psfield: Field,
...@@ -227,10 +201,16 @@ class SfgDeferredFieldMapping(SfgDeferredNode): ...@@ -227,10 +201,16 @@ class SfgDeferredFieldMapping(SfgDeferredNode):
def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode: def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode:
# Find field pointer # Find field pointer
ptr: SfgKernelParamVar | None = None ptr: SfgKernelParamVar | None = None
shape: list[SfgKernelParamVar | str | None] = [None] * len(self._field.shape) rank: int
strides: list[SfgKernelParamVar | str | None] = [None] * len(
self._field.strides if self._field.index_shape == (1,):
) # explicit scalar field -> ignore index dimensions
rank = self._field.spatial_dimensions
else:
rank = len(self._field.shape)
shape: list[SfgKernelParamVar | str | None] = [None] * rank
strides: list[SfgKernelParamVar | str | None] = [None] * rank
for param in ppc.live_variables: for param in ppc.live_variables:
if isinstance(param, SfgKernelParamVar): if isinstance(param, SfgKernelParamVar):
...@@ -244,12 +224,12 @@ class SfgDeferredFieldMapping(SfgDeferredNode): ...@@ -244,12 +224,12 @@ class SfgDeferredFieldMapping(SfgDeferredNode):
strides[coord] = param # type: ignore strides[coord] = param # type: ignore
# Find constant or otherwise determined sizes # Find constant or otherwise determined sizes
for coord, s in enumerate(self._field.shape): for coord, s in enumerate(self._field.shape[:rank]):
if shape[coord] is None: if shape[coord] is None:
shape[coord] = str(s) shape[coord] = str(s)
# Find constant or otherwise determined strides # Find constant or otherwise determined strides
for coord, s in enumerate(self._field.strides): for coord, s in enumerate(self._field.strides[:rank]):
if strides[coord] is None: if strides[coord] is None:
strides[coord] = str(s) strides[coord] = str(s)
......
from __future__ import annotations from __future__ import annotations
from typing import Protocol from typing import Protocol, runtime_checkable
from abc import abstractmethod from abc import abstractmethod
from .expressions import AugExpr from .expressions import AugExpr
@runtime_checkable
class SupportsFieldExtraction(Protocol): class SupportsFieldExtraction(Protocol):
"""Protocol for field pointer and indexing extraction. """Protocol for field pointer and indexing extraction.
...@@ -13,7 +14,7 @@ class SupportsFieldExtraction(Protocol): ...@@ -13,7 +14,7 @@ class SupportsFieldExtraction(Protocol):
They can therefore be passed to `sfg.map_field <SfgBasicComposer.map_field>`. They can therefore be passed to `sfg.map_field <SfgBasicComposer.map_field>`.
""" """
# how-to-guide begin # how-to-guide begin
@abstractmethod @abstractmethod
def _extract_ptr(self) -> AugExpr: def _extract_ptr(self) -> AugExpr:
"""Extract the field base pointer. """Extract the field base pointer.
...@@ -47,9 +48,12 @@ class SupportsFieldExtraction(Protocol): ...@@ -47,9 +48,12 @@ class SupportsFieldExtraction(Protocol):
:meta public: :meta public:
""" """
# how-to-guide end # how-to-guide end
@runtime_checkable
class SupportsVectorExtraction(Protocol): class SupportsVectorExtraction(Protocol):
"""Protocol for component extraction from a vector. """Protocol for component extraction from a vector.
......
...@@ -84,6 +84,7 @@ NestedNamespaces: ...@@ -84,6 +84,7 @@ NestedNamespaces:
ScaleKernel: ScaleKernel:
JacobiMdspan: JacobiMdspan:
StlContainers1D: StlContainers1D:
VectorExtraction:
# std::mdspan # std::mdspan
......
#include "VectorExtraction.hpp"
#include <experimental/mdspan>
#include <memory>
#include <vector>
#undef NDEBUG
#include <cassert>
namespace stdex = std::experimental;
using extents_t = stdex::extents<std::int64_t, std::dynamic_extent, std::dynamic_extent, 3>;
using vector_field_t = stdex::mdspan<double, extents_t, stdex::layout_right>;
constexpr size_t N{41};
int main(void)
{
auto u_data = std::make_unique<double[]>(N * N * 3);
vector_field_t u_field{u_data.get(), extents_t{N, N}};
std::vector<double> v{3.1, 3.2, 3.4};
gen::invoke(u_field, v);
for (size_t j = 0; j < N; ++j)
for (size_t i = 0; i < N; ++i)
{
assert(u_field(j, i, 0) == v[0]);
assert(u_field(j, i, 1) == v[1]);
assert(u_field(j, i, 2) == v[2]);
}
}
\ No newline at end of file
from pystencilssfg import SourceFileGenerator
from pystencilssfg.lang.cpp import std
import pystencils as ps
import sympy as sp
std.mdspan.configure(namespace="std::experimental", header="<experimental/mdspan>")
with SourceFileGenerator() as sfg:
sfg.namespace("gen")
u_field = ps.fields("u(3): double[2D]", layout="c")
u = sp.symbols("u_:3")
asms = [ps.Assignment(u_field(i), u[i]) for i in range(3)]
ker = sfg.kernels.create(asms)
sfg.function("invoke")(
sfg.map_field(u_field, std.mdspan.from_field(u_field, layout_policy="layout_right")),
sfg.map_vector(u, std.vector("double", const=True, ref=True).var("vel")),
sfg.call(ker)
)
import sympy as sp import sympy as sp
from pystencils import fields, kernel, TypedSymbol, Field, FieldType, create_type from pystencils import (
fields,
kernel,
TypedSymbol,
Field,
FieldType,
create_type,
Assignment,
)
from pystencils.types import PsCustomType from pystencils.types import PsCustomType
from pystencilssfg.composer import make_sequence from pystencilssfg.composer import make_sequence
from pystencilssfg.lang import AugExpr, SupportsFieldExtraction from pystencilssfg.lang import AugExpr, SupportsFieldExtraction
from pystencilssfg.lang.cpp import std
from pystencilssfg.ir import SfgStatements, SfgSequence from pystencilssfg.ir import SfgStatements, SfgSequence
from pystencilssfg.ir.postprocessing import CallTreePostProcessing from pystencilssfg.ir.postprocessing import CallTreePostProcessing
...@@ -100,7 +109,9 @@ def test_field_extraction(sfg): ...@@ -100,7 +109,9 @@ def test_field_extraction(sfg):
khandle = sfg.kernels.create(set_constant) khandle = sfg.kernels.create(set_constant)
extraction = DemoFieldExtraction("f") extraction = DemoFieldExtraction("f")
call_tree = make_sequence(sfg.map_field(f, extraction, cast_indexing_symbols=False), sfg.call(khandle)) call_tree = make_sequence(
sfg.map_field(f, extraction, cast_indexing_symbols=False), sfg.call(khandle)
)
pp = CallTreePostProcessing() pp = CallTreePostProcessing()
free_vars = pp.get_live_variables(call_tree) free_vars = pp.get_live_variables(call_tree)
...@@ -165,3 +176,39 @@ def test_duplicate_field_shapes(sfg): ...@@ -165,3 +176,39 @@ def test_duplicate_field_shapes(sfg):
for line, stmt in zip(lines_f, call_tree.children[1].children, strict=True): for line, stmt in zip(lines_f, call_tree.children[1].children, strict=True):
assert isinstance(stmt, SfgStatements) assert isinstance(stmt, SfgStatements)
assert stmt.code_string == line assert stmt.code_string == line
def test_scalar_fields(sfg):
sc_expl = Field.create_generic("f", 1, "double", index_shape=(1,))
sc_impl = Field.create_generic("f", 1, "double", index_shape=())
asm_expl = Assignment(sc_expl.center(0), 3)
asm_impl = Assignment(sc_impl.center(), 3)
k_expl = sfg.kernels.create(asm_expl, "expl")
k_impl = sfg.kernels.create(asm_impl, "impl")
tree_expl = make_sequence(
sfg.map_field(sc_expl, std.span.from_field(sc_expl)), sfg.call(k_expl)
)
tree_impl = make_sequence(
sfg.map_field(sc_impl, std.span.from_field(sc_impl)), sfg.call(k_impl)
)
pp = CallTreePostProcessing()
_ = pp.get_live_variables(tree_expl)
_ = pp.get_live_variables(tree_impl)
extraction_expl = tree_expl.children[0]
assert isinstance(extraction_expl, SfgSequence)
extraction_impl = tree_impl.children[0]
assert isinstance(extraction_impl, SfgSequence)
for node1, node2 in zip(
extraction_expl.children, extraction_impl.children, strict=True
):
assert isinstance(node1, SfgStatements)
assert isinstance(node2, SfgStatements)
assert node1.code_string == node2.code_string
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment