Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
Commits on Source (2)
......@@ -27,38 +27,6 @@ from ..lang import (
)
class FlattenSequences:
"""Flattens any nested sequences occuring in a kernel call tree."""
def __call__(self, node: SfgCallTreeNode) -> None:
self.visit(node)
def visit(self, node: SfgCallTreeNode):
match node:
case SfgSequence():
self.flatten(node)
case _:
for c in node.children:
self.visit(c)
def flatten(self, sequence: SfgSequence) -> None:
children_flattened: list[SfgCallTreeNode] = []
def flatten(seq: SfgSequence):
for c in seq.children:
if isinstance(c, SfgSequence):
flatten(c)
else:
children_flattened.append(c)
flatten(sequence)
for c in children_flattened:
self.visit(c)
sequence.children = children_flattened
class PostProcessingContext:
def __init__(self) -> None:
self._live_variables: dict[str, SfgVar] = dict()
......@@ -129,9 +97,6 @@ class PostProcessingResult:
class CallTreePostProcessing:
def __init__(self):
self._flattener = FlattenSequences()
def __call__(self, ast: SfgCallTreeNode) -> PostProcessingResult:
live_vars = self.get_live_variables(ast)
return PostProcessingResult(live_vars)
......@@ -214,6 +179,15 @@ class SfgDeferredParamSetter(SfgDeferredNode):
class SfgDeferredFieldMapping(SfgDeferredNode):
"""Deferred mapping of a pystencils field to a field data structure."""
# NOTE ON Scalar Fields
#
# pystencils permits explicit (`index_shape = (1,)`) and implicit (`index_shape = ()`)
# scalar fields. In order to handle both equivalently,
# we ignore the trivial explicit scalar dimension in field extraction.
# This makes sure that explicit D-dimensional scalar fields
# can be mapped onto D-dimensional data structures, and do not require that
# D+1st dimension.
def __init__(
self,
psfield: Field,
......@@ -227,10 +201,16 @@ class SfgDeferredFieldMapping(SfgDeferredNode):
def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode:
# Find field pointer
ptr: SfgKernelParamVar | None = None
shape: list[SfgKernelParamVar | str | None] = [None] * len(self._field.shape)
strides: list[SfgKernelParamVar | str | None] = [None] * len(
self._field.strides
)
rank: int
if self._field.index_shape == (1,):
# explicit scalar field -> ignore index dimensions
rank = self._field.spatial_dimensions
else:
rank = len(self._field.shape)
shape: list[SfgKernelParamVar | str | None] = [None] * rank
strides: list[SfgKernelParamVar | str | None] = [None] * rank
for param in ppc.live_variables:
if isinstance(param, SfgKernelParamVar):
......@@ -244,12 +224,12 @@ class SfgDeferredFieldMapping(SfgDeferredNode):
strides[coord] = param # type: ignore
# Find constant or otherwise determined sizes
for coord, s in enumerate(self._field.shape):
for coord, s in enumerate(self._field.shape[:rank]):
if shape[coord] is None:
shape[coord] = str(s)
# Find constant or otherwise determined strides
for coord, s in enumerate(self._field.strides):
for coord, s in enumerate(self._field.strides[:rank]):
if strides[coord] is None:
strides[coord] = str(s)
......
from __future__ import annotations
from typing import Protocol
from typing import Protocol, runtime_checkable
from abc import abstractmethod
from .expressions import AugExpr
@runtime_checkable
class SupportsFieldExtraction(Protocol):
"""Protocol for field pointer and indexing extraction.
......@@ -13,7 +14,7 @@ class SupportsFieldExtraction(Protocol):
They can therefore be passed to `sfg.map_field <SfgBasicComposer.map_field>`.
"""
# how-to-guide begin
# how-to-guide begin
@abstractmethod
def _extract_ptr(self) -> AugExpr:
"""Extract the field base pointer.
......@@ -47,9 +48,12 @@ class SupportsFieldExtraction(Protocol):
:meta public:
"""
# how-to-guide end
@runtime_checkable
class SupportsVectorExtraction(Protocol):
"""Protocol for component extraction from a vector.
......
......@@ -84,6 +84,7 @@ NestedNamespaces:
ScaleKernel:
JacobiMdspan:
StlContainers1D:
VectorExtraction:
# std::mdspan
......
#include "VectorExtraction.hpp"
#include <experimental/mdspan>
#include <memory>
#include <vector>
#undef NDEBUG
#include <cassert>
namespace stdex = std::experimental;
using extents_t = stdex::extents<std::int64_t, std::dynamic_extent, std::dynamic_extent, 3>;
using vector_field_t = stdex::mdspan<double, extents_t, stdex::layout_right>;
constexpr size_t N{41};
int main(void)
{
auto u_data = std::make_unique<double[]>(N * N * 3);
vector_field_t u_field{u_data.get(), extents_t{N, N}};
std::vector<double> v{3.1, 3.2, 3.4};
gen::invoke(u_field, v);
for (size_t j = 0; j < N; ++j)
for (size_t i = 0; i < N; ++i)
{
assert(u_field(j, i, 0) == v[0]);
assert(u_field(j, i, 1) == v[1]);
assert(u_field(j, i, 2) == v[2]);
}
}
\ No newline at end of file
from pystencilssfg import SourceFileGenerator
from pystencilssfg.lang.cpp import std
import pystencils as ps
import sympy as sp
std.mdspan.configure(namespace="std::experimental", header="<experimental/mdspan>")
with SourceFileGenerator() as sfg:
sfg.namespace("gen")
u_field = ps.fields("u(3): double[2D]", layout="c")
u = sp.symbols("u_:3")
asms = [ps.Assignment(u_field(i), u[i]) for i in range(3)]
ker = sfg.kernels.create(asms)
sfg.function("invoke")(
sfg.map_field(u_field, std.mdspan.from_field(u_field, layout_policy="layout_right")),
sfg.map_vector(u, std.vector("double", const=True, ref=True).var("vel")),
sfg.call(ker)
)
import sympy as sp
from pystencils import fields, kernel, TypedSymbol, Field, FieldType, create_type
from pystencils import (
fields,
kernel,
TypedSymbol,
Field,
FieldType,
create_type,
Assignment,
)
from pystencils.types import PsCustomType
from pystencilssfg.composer import make_sequence
from pystencilssfg.lang import AugExpr, SupportsFieldExtraction
from pystencilssfg.lang.cpp import std
from pystencilssfg.ir import SfgStatements, SfgSequence
from pystencilssfg.ir.postprocessing import CallTreePostProcessing
......@@ -100,7 +109,9 @@ def test_field_extraction(sfg):
khandle = sfg.kernels.create(set_constant)
extraction = DemoFieldExtraction("f")
call_tree = make_sequence(sfg.map_field(f, extraction, cast_indexing_symbols=False), sfg.call(khandle))
call_tree = make_sequence(
sfg.map_field(f, extraction, cast_indexing_symbols=False), sfg.call(khandle)
)
pp = CallTreePostProcessing()
free_vars = pp.get_live_variables(call_tree)
......@@ -165,3 +176,39 @@ def test_duplicate_field_shapes(sfg):
for line, stmt in zip(lines_f, call_tree.children[1].children, strict=True):
assert isinstance(stmt, SfgStatements)
assert stmt.code_string == line
def test_scalar_fields(sfg):
sc_expl = Field.create_generic("f", 1, "double", index_shape=(1,))
sc_impl = Field.create_generic("f", 1, "double", index_shape=())
asm_expl = Assignment(sc_expl.center(0), 3)
asm_impl = Assignment(sc_impl.center(), 3)
k_expl = sfg.kernels.create(asm_expl, "expl")
k_impl = sfg.kernels.create(asm_impl, "impl")
tree_expl = make_sequence(
sfg.map_field(sc_expl, std.span.from_field(sc_expl)), sfg.call(k_expl)
)
tree_impl = make_sequence(
sfg.map_field(sc_impl, std.span.from_field(sc_impl)), sfg.call(k_impl)
)
pp = CallTreePostProcessing()
_ = pp.get_live_variables(tree_expl)
_ = pp.get_live_variables(tree_impl)
extraction_expl = tree_expl.children[0]
assert isinstance(extraction_expl, SfgSequence)
extraction_impl = tree_impl.children[0]
assert isinstance(extraction_impl, SfgSequence)
for node1, node2 in zip(
extraction_expl.children, extraction_impl.children, strict=True
):
assert isinstance(node1, SfgStatements)
assert isinstance(node2, SfgStatements)
assert node1.code_string == node2.code_string