Skip to content
Snippets Groups Projects
Commit d3e347f2 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Adapt to field-related API changes in pystencils

 - Replace `SfgSymbolLike` by `SfgKernelParam`
 - Update postprocessing to work with parameter properties
 - Add tests

Squashed commit of the following:

commit d017185f
Author: Frederik Hennig <frederik.hennig@fau.de>
Date:   Wed Oct 23 10:13:06 2024 +0200

    adapt to KernelParameter API changes

commit b2857481
Author: Frederik Hennig <frederik.hennig@fau.de>
Date:   Tue Oct 22 15:14:01 2024 +0200

    don't ignore the type

commit 6d02cb47
Author: Frederik Hennig <frederik.hennig@fau.de>
Date:   Tue Oct 22 15:12:49 2024 +0200

    Adapt field parameter collection to changes in pystencils.
parent f191fe83
No related branches found
No related tags found
No related merge requests found
Pipeline #69902 passed
......@@ -5,7 +5,6 @@ import re
from pystencils.types import PsType, PsCustomType
from pystencils.enums import Target
from pystencils.backend.kernelfunction import KernelParameter
from ..exceptions import SfgException
from ..context import SfgContext
......@@ -15,8 +14,7 @@ from ..composer import (
SfgComposer,
SfgComposerMixIn,
)
from ..ir.source_components import SfgKernelHandle, SfgHeaderInclude
from ..ir.source_components import SfgSymbolLike
from ..ir.source_components import SfgKernelHandle, SfgHeaderInclude, SfgKernelParamVar
from ..ir import (
SfgCallTreeNode,
SfgCallTreeLeaf,
......@@ -75,7 +73,7 @@ class SyclHandler(AugExpr):
id_regex = re.compile(r"sycl::(id|item|nd_item)<\s*[0-9]\s*>")
def filter_id(param: SfgSymbolLike[KernelParameter]) -> bool:
def filter_id(param: SfgKernelParamVar) -> bool:
return (
isinstance(param.dtype, PsCustomType)
and id_regex.search(param.dtype.c_string()) is not None
......@@ -119,7 +117,7 @@ class SyclGroup(AugExpr):
id_regex = re.compile(r"sycl::id<\s*[0-9]\s*>")
def filter_id(param: SfgSymbolLike[KernelParameter]) -> bool:
def filter_id(param: SfgKernelParamVar) -> bool:
return (
isinstance(param.dtype, PsCustomType)
and id_regex.search(param.dtype.c_string()) is not None
......
......@@ -19,7 +19,7 @@ from .source_components import (
SfgEmptyLines,
SfgKernelNamespace,
SfgKernelHandle,
SfgSymbolLike,
SfgKernelParamVar,
SfgFunction,
SfgVisibility,
SfgClassKeyword,
......@@ -50,7 +50,7 @@ __all__ = [
"SfgEmptyLines",
"SfgKernelNamespace",
"SfgKernelHandle",
"SfgSymbolLike",
"SfgKernelParamVar",
"SfgFunction",
"SfgVisibility",
"SfgClassKeyword",
......
......@@ -8,18 +8,14 @@ from abc import ABC, abstractmethod
import sympy as sp
from pystencils import Field, TypedSymbol
from pystencils import Field
from pystencils.types import deconstify
from pystencils.backend.kernelfunction import (
FieldPointerParam,
FieldShapeParam,
FieldStrideParam,
)
from pystencils.backend.properties import FieldBasePtr, FieldShape, FieldStride
from ..exceptions import SfgException
from .call_tree import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence, SfgStatements
from ..ir.source_components import SfgSymbolLike
from ..ir.source_components import SfgKernelParamVar
from ..lang import SfgVar, IFieldExtraction, SrcField, SrcVector
if TYPE_CHECKING:
......@@ -252,43 +248,38 @@ class SfgDeferredFieldMapping(SfgDeferredNode):
else extraction.get_extraction()
)
# type: ignore
def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode:
# Find field pointer
ptr: SfgSymbolLike[FieldPointerParam] | None = None
shape: list[SfgSymbolLike[FieldShapeParam] | int | None] = [None] * len(
self._field.shape
)
strides: list[SfgSymbolLike[FieldStrideParam] | int | None] = [None] * len(
ptr: SfgKernelParamVar | None = None
shape: list[SfgKernelParamVar | str | None] = [None] * len(self._field.shape)
strides: list[SfgKernelParamVar | str | None] = [None] * len(
self._field.strides
)
for param in ppc.live_variables:
# idk why, but mypy does not understand these pattern matches
match param:
case SfgSymbolLike(FieldPointerParam(_, _, field)) if field == self._field: # type: ignore
ptr = param
case SfgSymbolLike(
FieldShapeParam(_, _, field, coord) # type: ignore
) if field == self._field: # type: ignore
shape[coord] = param # type: ignore
case SfgSymbolLike(
FieldStrideParam(_, _, field, coord) # type: ignore
) if field == self._field: # type: ignore
strides[coord] = param # type: ignore
# Find constant sizes
if isinstance(param, SfgKernelParamVar):
for prop in param.wrapped.properties:
match prop:
case FieldBasePtr(field) if field == self._field:
ptr = param
case FieldShape(field, coord) if field == self._field: # type: ignore
shape[coord] = param # type: ignore
case FieldStride(field, coord) if field == self._field: # type: ignore
strides[coord] = param # type: ignore
# Find constant or otherwise determined sizes
for coord, s in enumerate(self._field.shape):
if not isinstance(s, TypedSymbol):
shape[coord] = s
if shape[coord] is None:
shape[coord] = str(s)
# Find constant strides
# Find constant or otherwise determined strides
for coord, s in enumerate(self._field.strides):
if not isinstance(s, TypedSymbol):
strides[coord] = s
if strides[coord] is None:
strides[coord] = str(s)
# Now we have all the symbols, start extracting them
nodes = []
done: set[SfgKernelParamVar] = set()
if ptr is not None:
expr = self._extraction.ptr()
......@@ -298,7 +289,7 @@ class SfgDeferredFieldMapping(SfgDeferredNode):
)
)
def get_shape(coord, symb: SfgSymbolLike | int):
def get_shape(coord, symb: SfgKernelParamVar | str):
expr = self._extraction.size(coord)
if expr is None:
......@@ -306,14 +297,15 @@ class SfgDeferredFieldMapping(SfgDeferredNode):
f"Cannot extract shape in coordinate {coord} from {self._extraction}"
)
if isinstance(symb, SfgSymbolLike):
if isinstance(symb, SfgKernelParamVar) and symb not in done:
done.add(symb)
return SfgStatements(
f"{symb.dtype} {symb.name} {{ {expr} }};", (symb,), expr.depends
)
else:
return SfgStatements(f"/* {expr} == {symb} */", (), ())
def get_stride(coord, symb: SfgSymbolLike | int):
def get_stride(coord, symb: SfgKernelParamVar | str):
expr = self._extraction.stride(coord)
if expr is None:
......@@ -321,7 +313,8 @@ class SfgDeferredFieldMapping(SfgDeferredNode):
f"Cannot extract stride in coordinate {coord} from {self._extraction}"
)
if isinstance(symb, SfgSymbolLike):
if isinstance(symb, SfgKernelParamVar) and symb not in done:
done.add(symb)
return SfgStatements(
f"{symb.dtype} {symb.name} {{ {expr} }};", (symb,), expr.depends
)
......
......@@ -2,7 +2,7 @@ from __future__ import annotations
from abc import ABC
from enum import Enum, auto
from typing import TYPE_CHECKING, Sequence, Generator, TypeVar, Generic
from typing import TYPE_CHECKING, Sequence, Generator, TypeVar
from dataclasses import replace
from itertools import chain
......@@ -10,7 +10,6 @@ from pystencils import CreateKernelConfig, create_kernel, Field
from pystencils.backend.kernelfunction import (
KernelFunction,
KernelParameter,
FieldParameter,
)
from pystencils.types import PsType, PsCustomType
......@@ -162,14 +161,14 @@ class SfgKernelHandle:
self._ctx = ctx
self._name = name
self._namespace = namespace
self._parameters = [SfgSymbolLike(p) for p in parameters]
self._parameters = [SfgKernelParamVar(p) for p in parameters]
self._scalar_params: set[SfgSymbolLike] = set()
self._scalar_params: set[SfgKernelParamVar] = set()
self._fields: set[Field] = set()
for param in self._parameters:
if isinstance(param.wrapped, FieldParameter):
self._fields.add(param.wrapped.field)
if param.wrapped.is_field_parameter:
self._fields |= set(param.wrapped.fields)
else:
self._scalar_params.add(param)
......@@ -190,7 +189,7 @@ class SfgKernelHandle:
return f"{fqn}::{self.kernel_namespace.name}::{self.kernel_name}"
@property
def parameters(self) -> Sequence[SfgSymbolLike]:
def parameters(self) -> Sequence[SfgKernelParamVar]:
return self._parameters
@property
......@@ -208,17 +207,17 @@ class SfgKernelHandle:
SymbolLike_T = TypeVar("SymbolLike_T", bound=KernelParameter)
class SfgSymbolLike(SfgVar, Generic[SymbolLike_T]):
class SfgKernelParamVar(SfgVar):
__match_args__ = ("wrapped",)
"""Cast pystencils- or SymPy-native symbol-like objects as a `SfgVar`."""
def __init__(self, param: SymbolLike_T):
def __init__(self, param: KernelParameter):
self._param = param
super().__init__(param.name, param.dtype)
@property
def wrapped(self) -> SymbolLike_T:
def wrapped(self) -> KernelParameter:
return self._param
def _args(self):
......
......@@ -9,5 +9,5 @@ private:
float alpha;
public:
Scale(float alpha) : alpha{ alpha } {}
void operator() (float *const _data_f, float *const _data_g);
void operator() (float *RESTRICT const _data_f, float *RESTRICT const _data_g);
};
import sympy as sp
from pystencils import fields, kernel, TypedSymbol
from pystencils import fields, kernel, TypedSymbol, Field, FieldType, create_type
from pystencils.types import PsCustomType
from pystencilssfg import SfgContext, SfgComposer
from pystencilssfg.composer import make_sequence
from pystencilssfg.ir import SfgStatements
from pystencilssfg.lang import IFieldExtraction, AugExpr
from pystencilssfg.ir import SfgStatements, SfgSequence
from pystencilssfg.ir.postprocessing import CallTreePostProcessing
......@@ -75,3 +78,101 @@ def test_find_sympy_symbols():
assert isinstance(call_tree.children[1], SfgStatements)
assert call_tree.children[1].code_string == "const double y = x / a;"
class TestFieldExtraction(IFieldExtraction):
def __init__(self, name: str):
self.obj = AugExpr(PsCustomType("MyField")).var(name)
def ptr(self) -> AugExpr:
return AugExpr.format("{}.ptr()", self.obj)
def size(self, coordinate: int) -> AugExpr | None:
return AugExpr.format("{}.size({})", self.obj, coordinate)
def stride(self, coordinate: int) -> AugExpr | None:
return AugExpr.format("{}.stride({})", self.obj, coordinate)
def test_field_extraction():
sx, sy, tx, ty = [
TypedSymbol(n, create_type("int64")) for n in ("sx", "sy", "tx", "ty")
]
f = Field("f", FieldType.GENERIC, "double", (1, 0), (sx, sy), (tx, ty))
@kernel
def set_constant():
f.center @= 13.2
sfg = SfgComposer(SfgContext())
khandle = sfg.kernels.create(set_constant)
extraction = TestFieldExtraction("f")
call_tree = make_sequence(sfg.map_field(f, extraction), sfg.call(khandle))
pp = CallTreePostProcessing()
free_vars = pp.get_live_variables(call_tree)
assert free_vars == {extraction.obj.as_variable()}
lines = [
r"double * RESTRICT const _data_f { f.ptr() };",
r"const int64_t sx { f.size(0) };",
r"const int64_t sy { f.size(1) };",
r"const int64_t tx { f.stride(0) };",
r"const int64_t ty { f.stride(1) };",
]
assert isinstance(call_tree.children[0], SfgSequence)
for line, stmt in zip(lines, call_tree.children[0].children, strict=True):
assert isinstance(stmt, SfgStatements)
assert stmt.code_string == line
def test_duplicate_field_shapes():
N, tx, ty = [TypedSymbol(n, create_type("int64")) for n in ("N", "tx", "ty")]
f = Field("f", FieldType.GENERIC, "double", (1, 0), (N, N), (tx, ty))
g = Field("g", FieldType.GENERIC, "double", (1, 0), (N, N), (tx, ty))
@kernel
def set_constant():
f.center @= g.center(0)
sfg = SfgComposer(SfgContext())
khandle = sfg.kernels.create(set_constant)
call_tree = make_sequence(
sfg.map_field(g, TestFieldExtraction("g")),
sfg.map_field(f, TestFieldExtraction("f")),
sfg.call(khandle),
)
pp = CallTreePostProcessing()
_ = pp.get_live_variables(call_tree)
lines_g = [
r"double * RESTRICT const _data_g { g.ptr() };",
r"/* g.size(0) == N */",
r"/* g.size(1) == N */",
r"/* g.stride(0) == tx */",
r"/* g.stride(1) == ty */",
]
assert isinstance(call_tree.children[0], SfgSequence)
for line, stmt in zip(lines_g, call_tree.children[0].children, strict=True):
assert isinstance(stmt, SfgStatements)
assert stmt.code_string == line
lines_f = [
r"double * RESTRICT const _data_f { f.ptr() };",
r"const int64_t N { f.size(0) };",
r"/* f.size(1) == N */",
r"const int64_t tx { f.stride(0) };",
r"const int64_t ty { f.stride(1) };",
]
assert isinstance(call_tree.children[1], SfgSequence)
for line, stmt in zip(lines_f, call_tree.children[1].children, strict=True):
assert isinstance(stmt, SfgStatements)
assert stmt.code_string == line
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment