Skip to content
Snippets Groups Projects
Commit 6d02cb47 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Adapt field parameter collection to changes in pystencils.

parent f191fe83
Branches
Tags
No related merge requests found
Pipeline #69771 failed
......@@ -5,7 +5,6 @@ import re
from pystencils.types import PsType, PsCustomType
from pystencils.enums import Target
from pystencils.backend.kernelfunction import KernelParameter
from ..exceptions import SfgException
from ..context import SfgContext
......@@ -15,8 +14,7 @@ from ..composer import (
SfgComposer,
SfgComposerMixIn,
)
from ..ir.source_components import SfgKernelHandle, SfgHeaderInclude
from ..ir.source_components import SfgSymbolLike
from ..ir.source_components import SfgKernelHandle, SfgHeaderInclude, SfgKernelParamVar
from ..ir import (
SfgCallTreeNode,
SfgCallTreeLeaf,
......@@ -75,7 +73,7 @@ class SyclHandler(AugExpr):
id_regex = re.compile(r"sycl::(id|item|nd_item)<\s*[0-9]\s*>")
def filter_id(param: SfgSymbolLike[KernelParameter]) -> bool:
def filter_id(param: SfgKernelParamVar) -> bool:
return (
isinstance(param.dtype, PsCustomType)
and id_regex.search(param.dtype.c_string()) is not None
......@@ -119,7 +117,7 @@ class SyclGroup(AugExpr):
id_regex = re.compile(r"sycl::id<\s*[0-9]\s*>")
def filter_id(param: SfgSymbolLike[KernelParameter]) -> bool:
def filter_id(param: SfgKernelParamVar) -> bool:
return (
isinstance(param.dtype, PsCustomType)
and id_regex.search(param.dtype.c_string()) is not None
......
......@@ -19,7 +19,7 @@ from .source_components import (
SfgEmptyLines,
SfgKernelNamespace,
SfgKernelHandle,
SfgSymbolLike,
SfgKernelParamVar,
SfgFunction,
SfgVisibility,
SfgClassKeyword,
......@@ -50,7 +50,7 @@ __all__ = [
"SfgEmptyLines",
"SfgKernelNamespace",
"SfgKernelHandle",
"SfgSymbolLike",
"SfgKernelParamVar",
"SfgFunction",
"SfgVisibility",
"SfgClassKeyword",
......
......@@ -8,18 +8,14 @@ from abc import ABC, abstractmethod
import sympy as sp
from pystencils import Field, TypedSymbol
from pystencils import Field
from pystencils.types import deconstify
from pystencils.backend.kernelfunction import (
FieldPointerParam,
FieldShapeParam,
FieldStrideParam,
)
from pystencils.backend.properties import FieldBasePtr, FieldShape, FieldStride
from ..exceptions import SfgException
from .call_tree import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence, SfgStatements
from ..ir.source_components import SfgSymbolLike
from ..ir.source_components import SfgKernelParamVar
from ..lang import SfgVar, IFieldExtraction, SrcField, SrcVector
if TYPE_CHECKING:
......@@ -255,40 +251,36 @@ class SfgDeferredFieldMapping(SfgDeferredNode):
# type: ignore
def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode:
# Find field pointer
ptr: SfgSymbolLike[FieldPointerParam] | None = None
shape: list[SfgSymbolLike[FieldShapeParam] | int | None] = [None] * len(
self._field.shape
)
strides: list[SfgSymbolLike[FieldStrideParam] | int | None] = [None] * len(
ptr: SfgKernelParamVar | None = None
shape: list[SfgKernelParamVar | str | None] = [None] * len(self._field.shape)
strides: list[SfgKernelParamVar | str | None] = [None] * len(
self._field.strides
)
for param in ppc.live_variables:
# idk why, but mypy does not understand these pattern matches
match param:
case SfgSymbolLike(FieldPointerParam(_, _, field)) if field == self._field: # type: ignore
ptr = param
case SfgSymbolLike(
FieldShapeParam(_, _, field, coord) # type: ignore
) if field == self._field: # type: ignore
shape[coord] = param # type: ignore
case SfgSymbolLike(
FieldStrideParam(_, _, field, coord) # type: ignore
) if field == self._field: # type: ignore
strides[coord] = param # type: ignore
# Find constant sizes
if isinstance(param, SfgKernelParamVar):
for prop in param.wrapped.properties:
match prop:
case FieldBasePtr(field) if field == self._field:
ptr = param
case FieldShape(field, coord) if field == self._field: # type: ignore
shape[coord] = param # type: ignore
case FieldStride(field, coord) if field == self._field: # type: ignore
strides[coord] = param # type: ignore
# Find constant or otherwise determined sizes
for coord, s in enumerate(self._field.shape):
if not isinstance(s, TypedSymbol):
shape[coord] = s
if shape[coord] is None:
shape[coord] = str(s)
# Find constant strides
# Find constant or otherwise determined strides
for coord, s in enumerate(self._field.strides):
if not isinstance(s, TypedSymbol):
strides[coord] = s
if strides[coord] is None:
strides[coord] = str(s)
# Now we have all the symbols, start extracting them
nodes = []
done: set[SfgKernelParamVar] = set()
if ptr is not None:
expr = self._extraction.ptr()
......@@ -298,7 +290,7 @@ class SfgDeferredFieldMapping(SfgDeferredNode):
)
)
def get_shape(coord, symb: SfgSymbolLike | int):
def get_shape(coord, symb: SfgKernelParamVar | str):
expr = self._extraction.size(coord)
if expr is None:
......@@ -306,14 +298,15 @@ class SfgDeferredFieldMapping(SfgDeferredNode):
f"Cannot extract shape in coordinate {coord} from {self._extraction}"
)
if isinstance(symb, SfgSymbolLike):
if isinstance(symb, SfgKernelParamVar) and symb not in done:
done.add(symb)
return SfgStatements(
f"{symb.dtype} {symb.name} {{ {expr} }};", (symb,), expr.depends
)
else:
return SfgStatements(f"/* {expr} == {symb} */", (), ())
def get_stride(coord, symb: SfgSymbolLike | int):
def get_stride(coord, symb: SfgKernelParamVar | str):
expr = self._extraction.stride(coord)
if expr is None:
......@@ -321,7 +314,8 @@ class SfgDeferredFieldMapping(SfgDeferredNode):
f"Cannot extract stride in coordinate {coord} from {self._extraction}"
)
if isinstance(symb, SfgSymbolLike):
if isinstance(symb, SfgKernelParamVar) and symb not in done:
done.add(symb)
return SfgStatements(
f"{symb.dtype} {symb.name} {{ {expr} }};", (symb,), expr.depends
)
......
......@@ -2,7 +2,7 @@ from __future__ import annotations
from abc import ABC
from enum import Enum, auto
from typing import TYPE_CHECKING, Sequence, Generator, TypeVar, Generic
from typing import TYPE_CHECKING, Sequence, Generator, TypeVar
from dataclasses import replace
from itertools import chain
......@@ -10,7 +10,6 @@ from pystencils import CreateKernelConfig, create_kernel, Field
from pystencils.backend.kernelfunction import (
KernelFunction,
KernelParameter,
FieldParameter,
)
from pystencils.types import PsType, PsCustomType
......@@ -162,14 +161,14 @@ class SfgKernelHandle:
self._ctx = ctx
self._name = name
self._namespace = namespace
self._parameters = [SfgSymbolLike(p) for p in parameters]
self._parameters = [SfgKernelParamVar(p) for p in parameters]
self._scalar_params: set[SfgSymbolLike] = set()
self._scalar_params: set[SfgKernelParamVar] = set()
self._fields: set[Field] = set()
for param in self._parameters:
if isinstance(param.wrapped, FieldParameter):
self._fields.add(param.wrapped.field)
if param.wrapped.is_field_parameter:
self._fields |= param.wrapped.fields
else:
self._scalar_params.add(param)
......@@ -190,7 +189,7 @@ class SfgKernelHandle:
return f"{fqn}::{self.kernel_namespace.name}::{self.kernel_name}"
@property
def parameters(self) -> Sequence[SfgSymbolLike]:
def parameters(self) -> Sequence[SfgKernelParamVar]:
return self._parameters
@property
......@@ -208,17 +207,17 @@ class SfgKernelHandle:
SymbolLike_T = TypeVar("SymbolLike_T", bound=KernelParameter)
class SfgSymbolLike(SfgVar, Generic[SymbolLike_T]):
class SfgKernelParamVar(SfgVar):
__match_args__ = ("wrapped",)
"""Cast pystencils- or SymPy-native symbol-like objects as a `SfgVar`."""
def __init__(self, param: SymbolLike_T):
def __init__(self, param: KernelParameter):
self._param = param
super().__init__(param.name, param.dtype)
@property
def wrapped(self) -> SymbolLike_T:
def wrapped(self) -> KernelParameter:
return self._param
def _args(self):
......
......@@ -9,5 +9,5 @@ private:
float alpha;
public:
Scale(float alpha) : alpha{ alpha } {}
void operator() (float *const _data_f, float *const _data_g);
void operator() (float *RESTRICT const _data_f, float *RESTRICT const _data_g);
};
import sympy as sp
from pystencils import fields, kernel, TypedSymbol
from pystencils import fields, kernel, TypedSymbol, Field, FieldType, create_type
from pystencils.types import PsCustomType
from pystencilssfg import SfgContext, SfgComposer
from pystencilssfg.composer import make_sequence
from pystencilssfg.ir import SfgStatements
from pystencilssfg.lang import IFieldExtraction, AugExpr
from pystencilssfg.ir import SfgStatements, SfgSequence
from pystencilssfg.ir.postprocessing import CallTreePostProcessing
......@@ -75,3 +78,101 @@ def test_find_sympy_symbols():
assert isinstance(call_tree.children[1], SfgStatements)
assert call_tree.children[1].code_string == "const double y = x / a;"
class TestFieldExtraction(IFieldExtraction):
def __init__(self, name: str):
self.obj = AugExpr(PsCustomType("MyField")).var(name)
def ptr(self) -> AugExpr:
return AugExpr.format("{}.ptr()", self.obj)
def size(self, coordinate: int) -> AugExpr | None:
return AugExpr.format("{}.size({})", self.obj, coordinate)
def stride(self, coordinate: int) -> AugExpr | None:
return AugExpr.format("{}.stride({})", self.obj, coordinate)
def test_field_extraction():
sx, sy, tx, ty = [
TypedSymbol(n, create_type("int64")) for n in ("sx", "sy", "tx", "ty")
]
f = Field("f", FieldType.GENERIC, "double", (1, 0), (sx, sy), (tx, ty))
@kernel
def set_constant():
f.center @= 13.2
sfg = SfgComposer(SfgContext())
khandle = sfg.kernels.create(set_constant)
extraction = TestFieldExtraction("f")
call_tree = make_sequence(sfg.map_field(f, extraction), sfg.call(khandle))
pp = CallTreePostProcessing()
free_vars = pp.get_live_variables(call_tree)
assert free_vars == {extraction.obj.as_variable()}
lines = [
r"double * RESTRICT const _data_f { f.ptr() };",
r"const int64_t sx { f.size(0) };",
r"const int64_t sy { f.size(1) };",
r"const int64_t tx { f.stride(0) };",
r"const int64_t ty { f.stride(1) };",
]
assert isinstance(call_tree.children[0], SfgSequence)
for line, stmt in zip(lines, call_tree.children[0].children, strict=True):
assert isinstance(stmt, SfgStatements)
assert stmt.code_string == line
def test_duplicate_field_shapes():
N, tx, ty = [TypedSymbol(n, create_type("int64")) for n in ("N", "tx", "ty")]
f = Field("f", FieldType.GENERIC, "double", (1, 0), (N, N), (tx, ty))
g = Field("g", FieldType.GENERIC, "double", (1, 0), (N, N), (tx, ty))
@kernel
def set_constant():
f.center @= g.center(0)
sfg = SfgComposer(SfgContext())
khandle = sfg.kernels.create(set_constant)
call_tree = make_sequence(
sfg.map_field(g, TestFieldExtraction("g")),
sfg.map_field(f, TestFieldExtraction("f")),
sfg.call(khandle),
)
pp = CallTreePostProcessing()
_ = pp.get_live_variables(call_tree)
lines_g = [
r"double * RESTRICT const _data_g { g.ptr() };",
r"/* g.size(0) == N */",
r"/* g.size(1) == N */",
r"/* g.stride(0) == tx */",
r"/* g.stride(1) == ty */",
]
assert isinstance(call_tree.children[0], SfgSequence)
for line, stmt in zip(lines_g, call_tree.children[0].children, strict=True):
assert isinstance(stmt, SfgStatements)
assert stmt.code_string == line
lines_f = [
r"double * RESTRICT const _data_f { f.ptr() };",
r"const int64_t N { f.size(0) };",
r"/* f.size(1) == N */",
r"const int64_t tx { f.stride(0) };",
r"const int64_t ty { f.stride(1) };",
]
assert isinstance(call_tree.children[1], SfgSequence)
for line, stmt in zip(lines_f, call_tree.children[1].children, strict=True):
assert isinstance(stmt, SfgStatements)
assert stmt.code_string == line
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment