Skip to content
Snippets Groups Projects
Commit b1c47558 authored by Christoph Alt's avatar Christoph Alt
Browse files

Merge branch 'fhennig/postprocessing-fixes' into 'master'

Fixes to postprocessing: Remove unused code, test vector extraction, unify treatment of scalar fields

See merge request !26
parents 020347e7 3e0c00c4
1 merge request!26Fixes to postprocessing: Remove unused code, test vector extraction, unify treatment of scalar fields
Pipeline #77278 passed with stages
in 3 minutes and 21 seconds
......@@ -27,38 +27,6 @@ from ..lang import (
)
class FlattenSequences:
"""Flattens any nested sequences occuring in a kernel call tree."""
def __call__(self, node: SfgCallTreeNode) -> None:
self.visit(node)
def visit(self, node: SfgCallTreeNode):
match node:
case SfgSequence():
self.flatten(node)
case _:
for c in node.children:
self.visit(c)
def flatten(self, sequence: SfgSequence) -> None:
children_flattened: list[SfgCallTreeNode] = []
def flatten(seq: SfgSequence):
for c in seq.children:
if isinstance(c, SfgSequence):
flatten(c)
else:
children_flattened.append(c)
flatten(sequence)
for c in children_flattened:
self.visit(c)
sequence.children = children_flattened
class PostProcessingContext:
def __init__(self) -> None:
self._live_variables: dict[str, SfgVar] = dict()
......@@ -129,9 +97,6 @@ class PostProcessingResult:
class CallTreePostProcessing:
def __init__(self):
self._flattener = FlattenSequences()
def __call__(self, ast: SfgCallTreeNode) -> PostProcessingResult:
live_vars = self.get_live_variables(ast)
return PostProcessingResult(live_vars)
......@@ -214,6 +179,15 @@ class SfgDeferredParamSetter(SfgDeferredNode):
class SfgDeferredFieldMapping(SfgDeferredNode):
"""Deferred mapping of a pystencils field to a field data structure."""
# NOTE ON Scalar Fields
#
# pystencils permits explicit (`index_shape = (1,)`) and implicit (`index_shape = ()`)
# scalar fields. In order to handle both equivalently,
# we ignore the trivial explicit scalar dimension in field extraction.
# This makes sure that explicit D-dimensional scalar fields
# can be mapped onto D-dimensional data structures, and do not require that
# D+1st dimension.
def __init__(
self,
psfield: Field,
......@@ -227,10 +201,16 @@ class SfgDeferredFieldMapping(SfgDeferredNode):
def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode:
# Find field pointer
ptr: SfgKernelParamVar | None = None
shape: list[SfgKernelParamVar | str | None] = [None] * len(self._field.shape)
strides: list[SfgKernelParamVar | str | None] = [None] * len(
self._field.strides
)
rank: int
if self._field.index_shape == (1,):
# explicit scalar field -> ignore index dimensions
rank = self._field.spatial_dimensions
else:
rank = len(self._field.shape)
shape: list[SfgKernelParamVar | str | None] = [None] * rank
strides: list[SfgKernelParamVar | str | None] = [None] * rank
for param in ppc.live_variables:
if isinstance(param, SfgKernelParamVar):
......@@ -244,12 +224,12 @@ class SfgDeferredFieldMapping(SfgDeferredNode):
strides[coord] = param # type: ignore
# Find constant or otherwise determined sizes
for coord, s in enumerate(self._field.shape):
for coord, s in enumerate(self._field.shape[:rank]):
if shape[coord] is None:
shape[coord] = str(s)
# Find constant or otherwise determined strides
for coord, s in enumerate(self._field.strides):
for coord, s in enumerate(self._field.strides[:rank]):
if strides[coord] is None:
strides[coord] = str(s)
......
from __future__ import annotations
from typing import Protocol
from typing import Protocol, runtime_checkable
from abc import abstractmethod
from .expressions import AugExpr
@runtime_checkable
class SupportsFieldExtraction(Protocol):
"""Protocol for field pointer and indexing extraction.
......@@ -13,7 +14,7 @@ class SupportsFieldExtraction(Protocol):
They can therefore be passed to `sfg.map_field <SfgBasicComposer.map_field>`.
"""
# how-to-guide begin
# how-to-guide begin
@abstractmethod
def _extract_ptr(self) -> AugExpr:
"""Extract the field base pointer.
......@@ -47,9 +48,12 @@ class SupportsFieldExtraction(Protocol):
:meta public:
"""
# how-to-guide end
@runtime_checkable
class SupportsVectorExtraction(Protocol):
"""Protocol for component extraction from a vector.
......
......@@ -84,6 +84,7 @@ NestedNamespaces:
ScaleKernel:
JacobiMdspan:
StlContainers1D:
VectorExtraction:
# std::mdspan
......
#include "VectorExtraction.hpp"
#include <experimental/mdspan>
#include <memory>
#include <vector>
#undef NDEBUG
#include <cassert>
namespace stdex = std::experimental;
using extents_t = stdex::extents<std::int64_t, std::dynamic_extent, std::dynamic_extent, 3>;
using vector_field_t = stdex::mdspan<double, extents_t, stdex::layout_right>;
constexpr size_t N{41};
int main(void)
{
auto u_data = std::make_unique<double[]>(N * N * 3);
vector_field_t u_field{u_data.get(), extents_t{N, N}};
std::vector<double> v{3.1, 3.2, 3.4};
gen::invoke(u_field, v);
for (size_t j = 0; j < N; ++j)
for (size_t i = 0; i < N; ++i)
{
assert(u_field(j, i, 0) == v[0]);
assert(u_field(j, i, 1) == v[1]);
assert(u_field(j, i, 2) == v[2]);
}
}
\ No newline at end of file
from pystencilssfg import SourceFileGenerator
from pystencilssfg.lang.cpp import std
import pystencils as ps
import sympy as sp
std.mdspan.configure(namespace="std::experimental", header="<experimental/mdspan>")
with SourceFileGenerator() as sfg:
sfg.namespace("gen")
u_field = ps.fields("u(3): double[2D]", layout="c")
u = sp.symbols("u_:3")
asms = [ps.Assignment(u_field(i), u[i]) for i in range(3)]
ker = sfg.kernels.create(asms)
sfg.function("invoke")(
sfg.map_field(u_field, std.mdspan.from_field(u_field, layout_policy="layout_right")),
sfg.map_vector(u, std.vector("double", const=True, ref=True).var("vel")),
sfg.call(ker)
)
import sympy as sp
from pystencils import fields, kernel, TypedSymbol, Field, FieldType, create_type
from pystencils import (
fields,
kernel,
TypedSymbol,
Field,
FieldType,
create_type,
Assignment,
)
from pystencils.types import PsCustomType
from pystencilssfg.composer import make_sequence
from pystencilssfg.lang import AugExpr, SupportsFieldExtraction
from pystencilssfg.lang.cpp import std
from pystencilssfg.ir import SfgStatements, SfgSequence
from pystencilssfg.ir.postprocessing import CallTreePostProcessing
......@@ -100,7 +109,9 @@ def test_field_extraction(sfg):
khandle = sfg.kernels.create(set_constant)
extraction = DemoFieldExtraction("f")
call_tree = make_sequence(sfg.map_field(f, extraction, cast_indexing_symbols=False), sfg.call(khandle))
call_tree = make_sequence(
sfg.map_field(f, extraction, cast_indexing_symbols=False), sfg.call(khandle)
)
pp = CallTreePostProcessing()
free_vars = pp.get_live_variables(call_tree)
......@@ -165,3 +176,39 @@ def test_duplicate_field_shapes(sfg):
for line, stmt in zip(lines_f, call_tree.children[1].children, strict=True):
assert isinstance(stmt, SfgStatements)
assert stmt.code_string == line
def test_scalar_fields(sfg):
sc_expl = Field.create_generic("f", 1, "double", index_shape=(1,))
sc_impl = Field.create_generic("f", 1, "double", index_shape=())
asm_expl = Assignment(sc_expl.center(0), 3)
asm_impl = Assignment(sc_impl.center(), 3)
k_expl = sfg.kernels.create(asm_expl, "expl")
k_impl = sfg.kernels.create(asm_impl, "impl")
tree_expl = make_sequence(
sfg.map_field(sc_expl, std.span.from_field(sc_expl)), sfg.call(k_expl)
)
tree_impl = make_sequence(
sfg.map_field(sc_impl, std.span.from_field(sc_impl)), sfg.call(k_impl)
)
pp = CallTreePostProcessing()
_ = pp.get_live_variables(tree_expl)
_ = pp.get_live_variables(tree_impl)
extraction_expl = tree_expl.children[0]
assert isinstance(extraction_expl, SfgSequence)
extraction_impl = tree_impl.children[0]
assert isinstance(extraction_impl, SfgSequence)
for node1, node2 in zip(
extraction_expl.children, extraction_impl.children, strict=True
):
assert isinstance(node1, SfgStatements)
assert isinstance(node2, SfgStatements)
assert node1.code_string == node2.code_string
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment