diff --git a/docs/source/api/lang.rst b/docs/source/api/lang.rst index c83317e506549eb9ca9d86bc71765d196306ef09..bd5e4fa7993f30d4c8f6e5c532eb3d234c64061d 100644 --- a/docs/source/api/lang.rst +++ b/docs/source/api/lang.rst @@ -21,6 +21,12 @@ Data Types .. automodule:: pystencilssfg.lang.types :members: +Extraction Protocols +-------------------- + +.. automodule:: pystencilssfg.lang.extractions + :members: + C++ Standard Library (``pystencilssfg.lang.cpp``) ------------------------------------------------- diff --git a/docs/source/usage/api_modelling.md b/docs/source/usage/api_modelling.md index cfb0e10776ae11072c13423c1ef5b78a72fd8459..bba61edbee19ef3eb207eacb3aef4e5f8dd7e369 100644 --- a/docs/source/usage/api_modelling.md +++ b/docs/source/usage/api_modelling.md @@ -7,6 +7,21 @@ kernelspec: (how_to_cpp_api_modelling)= # How To Reflect C++ APIs +```{code-cell} ipython3 +:tags: [remove-cell] + +from __future__ import annotations +import sys +from pathlib import Path + +mockup_path = Path("../_util").resolve() +sys.path.append(str(mockup_path)) + +from sfg_monkeypatch import DocsPatchedGenerator # monkeypatch SFG for docs + +from pystencilssfg import SourceFileGenerator +``` + Pystencils-SFG is designed to help you generate C++ code that interfaces with pystencils on the one side, and with your handwritten code on the other side. This requires that the C++ classes and APIs of your framework or application be represented within the SFG system. @@ -233,7 +248,173 @@ expr, lang.depends(expr), lang.includes(expr) (field_data_structure_reflection)= ## Reflecting Field Data Structures -:::{admonition} To Do +One key feature of pystencils-sfg is its ability to map symbolic fields +onto arbitrary array data structures +using the composer's {any}`map_field <SfgBasicComposer.map_field>` method. +The APIs of a custom field data structure can naturally be injected into pystencils-sfg +using the modelling framework described above. +However, for them to be recognized by `map_field`, +the reflection class also needs to implement the {any}`SupportsFieldExtraction` protocol. +This requires that the following three methods are implemented: + +```{code-block} python +def _extract_ptr(self) -> AugExpr: ... + +def _extract_size(self, coordinate: int) -> AugExpr | None: ... + +def _extract_stride(self, coordinate: int) -> AugExpr | None: ... +``` + +The first, `_extract_ptr`, must return an expression that evaluates +to the base pointer of the field's memory buffer. +This pointer has to point at the field entry which pystencils accesses +at all-zero index and offsets (see [](#note-on-ghost-layers)). +The other two, when called with a coordinate $c \ge 0$, shall return +the size and linearization stride of the field in that direction. +If the coordinate is equal or larger than the field's dimensionality, +return `None` instead. + +### Sample Field API Reflection + +Consider the following class template for a field, which takes its element type +and dimensionality as template parameters +and exposes its data pointer, shape, and strides through public methods: + +```{code-block} C++ +template< std::floating_point ElemType, size_t DIM > +class MyField { +public: + size_t get_shape(size_t coord); + size_t get_stride(size_t coord); + ElemType * data_ptr(); +} +``` + +It could be reflected by the following class. +Note that in this case we define a custom `__init__` method in order to +intercept the template arguments `elem_type` and `dim` +and store them as instance members. +Our `__init__` then forwards all its arguments up to `CppClass.__init__`. +We then define reflection methods for `shape`, `stride` and `data` - +the implementation of the field extraction protocol then simply calls these methods. + +```{code-cell} ipython3 +from pystencilssfg.lang import SupportsFieldExtraction +from pystencils.types import UserTypeSpec + +class MyField(lang.CppClass, SupportsFieldExtraction): + template = lang.cpptype( + "MyField< {ElemType}, {DIM} >", + "MyField.hpp" + ) + + def __init__( + self, + elem_type: UserTypeSpec, + dim: int, + **kwargs, + ) -> None: + self._elem_type = elem_type + self._dim = dim + super().__init__(ElemType=elem_type, DIM=dim, **kwargs) + + # Reflection of Public Methods + def get_shape(self, coord: int | lang.AugExpr) -> lang.AugExpr: + return lang.AugExpr.format("{}.get_shape({})", self, coord) + + def get_stride(self, coord: int | lang.AugExpr) -> lang.AugExpr: + return lang.AugExpr.format("{}.get_stride({})", self, coord) + + def data_ptr(self) -> lang.AugExpr: + return lang.AugExpr.format("{}.data_ptr()", self) + + # Field Extraction Protocol that uses the above interface + def _extract_ptr(self) -> lang.AugExpr: + return self.data_ptr() + + def _extract_size(self, coordinate: int) -> lang.AugExpr | None: + if coordinate > self._dim: + return None + else: + return self.get_shape(coordinate) + + def _extract_stride(self, coordinate: int) -> lang.AugExpr | None: + if coordinate > self._dim: + return None + else: + return self.get_stride(coordinate) +``` + +Our custom field reflection is now ready to be used. +The following generator script demonstrates what code is generated when an instance of `MyField` +is passed to `sfg.map_field`: + + +```{code-cell} ipython3 +import pystencils as ps +from pystencilssfg.lang.cpp import std + +with SourceFileGenerator() as sfg: + # Create symbolic fields + f = ps.fields("f: double[3D]") + f_myfield = MyField(f.dtype, f.ndim, ref=True).var(f.name) + + # Create the kernel + asm = ps.Assignment(f(0), 2 * f(0)) + khandle = sfg.kernels.create(asm) + + # Create the wrapper function + sfg.function("invoke")( + sfg.map_field(f, f_myfield), + sfg.call(khandle) + ) +``` + +### Add a Factory Function + +In the above example, an instance of `MyField` representing the field `f` is created by the +slightly verbose expression `MyField(f.dtype, f.ndim, ref=True).var(f.name)`. +Having to write this sequence every time, for every field, introduces unnecessary +cognitive load and lots of potential sources of error. +Whenever it is possible to create a field reflection using just information contained in a +pystencils {any}`Field <pystencils.field.Field>` object, +the API reflection should therefore implement a factory method `from_field`: + +```{code-cell} ipython3 +class MyField(lang.CppClass, SupportsFieldExtraction): + ... + + @classmethod + def from_field(cls, field: ps.Field, const: bool = False, ref: bool = False) -> MyField: + return cls(f.dtype, f.ndim, const=const, ref=ref).var(f.name) + +``` + +The above signature is idiomatic for `from_field`, and you should stick to it as far as possible. +We can now use it inside the generator script: + +```{code-block} python +f = ps.fields("f: double[3D]") +f_myfield = MyField.from_field(f) +``` -Write guide on field data structure reflection -::: +(note-on-ghost-layers)= +### A Note on Ghost Layers + +Some care has to be taken when reflecting data structures that model the notion +of ghost layers. +Consider an array with the index space $[0, N_x) \times [0, N_y)$, +its base pointer identifying the entry $(0, 0)$. +When a pystencils kernel is generated with a shell of $k$ ghost layers +(see {any}`CreateKernelConfig.ghost_layers <pystencils.codegen.config.CreateKernelConfig.ghost_layers>`), +it will process only the subspace $[k, N_x - k) \times [k, N_x - k)$. + +If your data structure is implemented such that ghost layer nodes have coordinates +$< 0$ and $\ge N_{x, y}$, +you must hence take care that + - either, `_extract_ptr` returns a pointer identifying the array entry at `(-k, -k)`; + - or, ensure that kernels operating on your data structure are always generated + with `ghost_layers = 0`. + +In either case, you must make sure that the number of ghost layers in your data structure +matches the expected number of ghost layers of the kernel. diff --git a/docs/source/usage/how_to_composer.md b/docs/source/usage/how_to_composer.md index 7f08829605abad35f5603b502e2cc454df90124a..966a9a661b8f7c5d4d863b07c2a9549a95032591 100644 --- a/docs/source/usage/how_to_composer.md +++ b/docs/source/usage/how_to_composer.md @@ -352,8 +352,8 @@ computing landscape, including [Kokkos Views][kokkos_view], [C++ std::mdspan][md [SYCL buffers][sycl_buffer], and many framework-specific custom-built classes. Using the protocols behind {any}`sfg.map_field <SfgBasicComposer.map_field>`, it is possible to automatically emit code -that extracts the indexing information required by a kernel from any of these classes -- provided a suitable API reflection is available. +that extracts the indexing information required by a kernel from any of these classes, +as long as a suitable API reflection is available. :::{seealso} [](#field_data_structure_reflection) for instructions on how to set up field API diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py index 49b6c73e8bd06b0d5fb44402b8af285362072b29..05ebc71ffdcc1879e1ccaa4e5636b064776a44b2 100644 --- a/src/pystencilssfg/composer/basic_composer.py +++ b/src/pystencilssfg/composer/basic_composer.py @@ -54,9 +54,8 @@ from ..lang import ( includes, SfgVar, AugExpr, - SrcField, - IFieldExtraction, - SrcVector, + SupportsFieldExtraction, + SupportsVectorExtraction, void, ) from ..exceptions import SfgException @@ -511,7 +510,7 @@ class SfgBasicComposer(SfgIComposer): def map_field( self, field: Field, - index_provider: IFieldExtraction | SrcField, + index_provider: SupportsFieldExtraction, cast_indexing_symbols: bool = True, ) -> SfgDeferredFieldMapping: """Map a pystencils field to a field data structure, from which pointers, sizes @@ -519,7 +518,7 @@ class SfgBasicComposer(SfgIComposer): Args: field: The pystencils field to be mapped - index_provider: An expression representing a field, or a field extraction provider instance + index_provider: An object that provides the field indexing information cast_indexing_symbols: Whether to always introduce explicit casts for indexing symbols """ return SfgDeferredFieldMapping( @@ -536,12 +535,12 @@ class SfgBasicComposer(SfgIComposer): var: SfgVar | sp.Symbol = asvar(param) if isinstance(param, _VarLike) else param return SfgDeferredParamSetter(var, expr) - def map_vector(self, lhs_components: Sequence[VarLike | sp.Symbol], rhs: SrcVector): + def map_vector(self, lhs_components: Sequence[VarLike | sp.Symbol], rhs: SupportsVectorExtraction): """Extracts scalar numerical values from a vector data type. Args: lhs_components: Vector components as a list of symbols. - rhs: A `SrcVector` object representing a vector data structure. + rhs: An object providing access to vector components """ components: list[SfgVar | sp.Symbol] = [ (asvar(c) if isinstance(c, _VarLike) else c) for c in lhs_components diff --git a/src/pystencilssfg/ir/postprocessing.py b/src/pystencilssfg/ir/postprocessing.py index 469383c123d7403ca448efb40dad5692d46d0a3c..1e692b0aa9f37368da1688ec2d3bac6892c5ac60 100644 --- a/src/pystencilssfg/ir/postprocessing.py +++ b/src/pystencilssfg/ir/postprocessing.py @@ -19,9 +19,8 @@ from .call_tree import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence, SfgStateme from ..lang.expressions import SfgKernelParamVar from ..lang import ( SfgVar, - IFieldExtraction, - SrcField, - SrcVector, + SupportsFieldExtraction, + SupportsVectorExtraction, ExprLike, AugExpr, depends, @@ -211,7 +210,9 @@ class SfgDeferredParamSetter(SfgDeferredNode): live_var = ppc.get_live_variable(self._lhs.name) if live_var is not None: code = f"{live_var.dtype.c_string()} {live_var.name} = {self._rhs};" - return SfgStatements(code, (live_var,), depends(self._rhs), includes(self._rhs)) + return SfgStatements( + code, (live_var,), depends(self._rhs), includes(self._rhs) + ) else: return SfgSequence([]) @@ -222,15 +223,11 @@ class SfgDeferredFieldMapping(SfgDeferredNode): def __init__( self, psfield: Field, - extraction: IFieldExtraction | SrcField, + extraction: SupportsFieldExtraction, cast_indexing_symbols: bool = True, ): self._field = psfield - self._extraction: IFieldExtraction = ( - extraction - if isinstance(extraction, IFieldExtraction) - else extraction.get_extraction() - ) + self._extraction = extraction self._cast_indexing_symbols = cast_indexing_symbols def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode: @@ -267,7 +264,7 @@ class SfgDeferredFieldMapping(SfgDeferredNode): done: set[SfgKernelParamVar] = set() if ptr is not None: - expr = self._extraction.ptr() + expr = self._extraction._extract_ptr() nodes.append( SfgStatements( f"{ptr.dtype.c_string()} {ptr.name} {{ {expr} }};", @@ -286,7 +283,7 @@ class SfgDeferredFieldMapping(SfgDeferredNode): return expr def get_shape(coord, symb: SfgKernelParamVar | str): - expr = self._extraction.size(coord) + expr = self._extraction._extract_size(coord) if expr is None: raise SfgException( @@ -306,7 +303,7 @@ class SfgDeferredFieldMapping(SfgDeferredNode): return SfgStatements(f"/* {expr} == {symb} */", (), ()) def get_stride(coord, symb: SfgKernelParamVar | str): - expr = self._extraction.stride(coord) + expr = self._extraction._extract_stride(coord) if expr is None: raise SfgException( @@ -332,7 +329,11 @@ class SfgDeferredFieldMapping(SfgDeferredNode): class SfgDeferredVectorMapping(SfgDeferredNode): - def __init__(self, scalars: Sequence[sp.Symbol | SfgVar], vector: SrcVector): + def __init__( + self, + scalars: Sequence[sp.Symbol | SfgVar], + vector: SupportsVectorExtraction, + ): self._scalars = {sc.name: (i, sc) for i, sc in enumerate(scalars)} self._vector = vector @@ -342,7 +343,7 @@ class SfgDeferredVectorMapping(SfgDeferredNode): for param in ppc.live_variables: if param.name in self._scalars: idx, _ = self._scalars[param.name] - expr = self._vector.extract_component(idx) + expr = self._vector._extract_component(idx) nodes.append( SfgStatements( f"{param.dtype.c_string()} {param.name} {{ {expr} }};", diff --git a/src/pystencilssfg/lang/__init__.py b/src/pystencilssfg/lang/__init__.py index 980d2bdf9d543ce764e1f739926ee460b5fe9699..ff8f3ced07deed67e1be6e1f4760c6991e7a724f 100644 --- a/src/pystencilssfg/lang/__init__.py +++ b/src/pystencilssfg/lang/__init__.py @@ -13,11 +13,10 @@ from .expressions import ( includes, CppClass, cppclass, - IFieldExtraction, - SrcField, - SrcVector, ) +from .extractions import SupportsFieldExtraction, SupportsVectorExtraction + from .types import cpptype, void, Ref, strip_ptr_ref __all__ = [ @@ -32,13 +31,12 @@ __all__ = [ "asvar", "depends", "includes", - "IFieldExtraction", - "SrcField", - "SrcVector", "cpptype", "CppClass", "cppclass", "void", "Ref", "strip_ptr_ref", + "SupportsFieldExtraction", + "SupportsVectorExtraction", ] diff --git a/src/pystencilssfg/lang/cpp/std_mdspan.py b/src/pystencilssfg/lang/cpp/std_mdspan.py index 5e552e957c51db36a91ce17d3cfe96506343cdba..ad5feedaf6ca43a6ef1b88a5b53e2b6bf454167d 100644 --- a/src/pystencilssfg/lang/cpp/std_mdspan.py +++ b/src/pystencilssfg/lang/cpp/std_mdspan.py @@ -11,10 +11,10 @@ from pystencils.types import ( from pystencilssfg.lang.expressions import AugExpr -from ...lang import SrcField, IFieldExtraction, cpptype, HeaderFile, ExprLike +from ...lang import SupportsFieldExtraction, cpptype, HeaderFile, ExprLike -class StdMdspan(SrcField): +class StdMdspan(AugExpr, SupportsFieldExtraction): """Represents an `std::mdspan` instance. The `std::mdspan <https://en.cppreference.com/w/cpp/container/mdspan>`_ @@ -141,26 +141,22 @@ class StdMdspan(SrcField): def data_handle(self) -> AugExpr: return AugExpr.format("{}.data_handle()", self) - def get_extraction(self) -> IFieldExtraction: - mdspan = self + # SupportsFieldExtraction protocol - class Extraction(IFieldExtraction): - def ptr(self) -> AugExpr: - return mdspan.data_handle() + def _extract_ptr(self) -> AugExpr: + return self.data_handle() - def size(self, coordinate: int) -> AugExpr | None: - if coordinate > mdspan._dim: - return None - else: - return mdspan.extent(coordinate) + def _extract_size(self, coordinate: int) -> AugExpr | None: + if coordinate > self._dim: + return None + else: + return self.extent(coordinate) - def stride(self, coordinate: int) -> AugExpr | None: - if coordinate > mdspan._dim: - return None - else: - return mdspan.stride(coordinate) - - return Extraction() + def _extract_stride(self, coordinate: int) -> AugExpr | None: + if coordinate > self._dim: + return None + else: + return self.stride(coordinate) @staticmethod def from_field( diff --git a/src/pystencilssfg/lang/cpp/std_span.py b/src/pystencilssfg/lang/cpp/std_span.py index ea4b520ee23717701bc70f615ace3ed8103a2937..7661ea91151e15f171eb63e290f1353d0047ee00 100644 --- a/src/pystencilssfg/lang/cpp/std_span.py +++ b/src/pystencilssfg/lang/cpp/std_span.py @@ -1,10 +1,10 @@ from pystencils import Field, DynamicType from pystencils.types import UserTypeSpec, create_type, PsType -from ...lang import SrcField, IFieldExtraction, AugExpr, cpptype +from ...lang import SupportsFieldExtraction, AugExpr, cpptype -class StdSpan(SrcField): +class StdSpan(AugExpr, SupportsFieldExtraction): _template = cpptype("std::span< {T} >", "<span>") def __init__(self, T: UserTypeSpec, ref=False, const=False): @@ -17,26 +17,20 @@ class StdSpan(SrcField): def element_type(self) -> PsType: return self._element_type - def get_extraction(self) -> IFieldExtraction: - span = self + def _extract_ptr(self) -> AugExpr: + return AugExpr.format("{}.data()", self) - class Extraction(IFieldExtraction): - def ptr(self) -> AugExpr: - return AugExpr.format("{}.data()", span) + def _extract_size(self, coordinate: int) -> AugExpr | None: + if coordinate > 0: + return None + else: + return AugExpr.format("{}.size()", self) - def size(self, coordinate: int) -> AugExpr | None: - if coordinate > 0: - return None - else: - return AugExpr.format("{}.size()", span) - - def stride(self, coordinate: int) -> AugExpr | None: - if coordinate > 0: - return None - else: - return AugExpr.format("1") - - return Extraction() + def _extract_stride(self, coordinate: int) -> AugExpr | None: + if coordinate > 0: + return None + else: + return AugExpr.format("1") @staticmethod def from_field(field: Field, ref: bool = False, const: bool = False): diff --git a/src/pystencilssfg/lang/cpp/std_tuple.py b/src/pystencilssfg/lang/cpp/std_tuple.py index 7ea0e2416b2280042521b871b1b40a623971b187..645b6b56fbeb515d6324feafdb8588f7ca22e992 100644 --- a/src/pystencilssfg/lang/cpp/std_tuple.py +++ b/src/pystencilssfg/lang/cpp/std_tuple.py @@ -1,9 +1,9 @@ from pystencils.types import UserTypeSpec, create_type -from ...lang import SrcVector, AugExpr, cpptype +from ...lang import SupportsVectorExtraction, AugExpr, cpptype -class StdTuple(SrcVector): +class StdTuple(AugExpr, SupportsVectorExtraction): _template = cpptype("std::tuple< {ts} >", "<tuple>") def __init__( @@ -19,7 +19,7 @@ class StdTuple(SrcVector): dtype = self._template(ts=", ".join(elt_type_strings), const=const, ref=ref) super().__init__(dtype) - def extract_component(self, coordinate: int) -> AugExpr: + def _extract_component(self, coordinate: int) -> AugExpr: if coordinate < 0 or coordinate >= self._length: raise ValueError( f"Index {coordinate} out-of-bounds for std::tuple with {self._length} entries." diff --git a/src/pystencilssfg/lang/cpp/std_vector.py b/src/pystencilssfg/lang/cpp/std_vector.py index 7356f942047941b429381d8b2bc111c17b7e6f9d..5ce626dea2b29be50e168f8493612625a82ff002 100644 --- a/src/pystencilssfg/lang/cpp/std_vector.py +++ b/src/pystencilssfg/lang/cpp/std_vector.py @@ -1,10 +1,10 @@ from pystencils import Field, DynamicType from pystencils.types import UserTypeSpec, create_type, PsType -from ...lang import SrcField, SrcVector, AugExpr, IFieldExtraction, cpptype +from ...lang import SupportsFieldExtraction, SupportsVectorExtraction, AugExpr, cpptype -class StdVector(SrcVector, SrcField): +class StdVector(AugExpr, SupportsFieldExtraction, SupportsVectorExtraction): _template = cpptype("std::vector< {T} >", "<vector>") def __init__( @@ -25,28 +25,22 @@ class StdVector(SrcVector, SrcField): def element_type(self) -> PsType: return self._element_type - def get_extraction(self) -> IFieldExtraction: - vec = self + def _extract_ptr(self) -> AugExpr: + return AugExpr.format("{}.data()", self) - class Extraction(IFieldExtraction): - def ptr(self) -> AugExpr: - return AugExpr.format("{}.data()", vec) - - def size(self, coordinate: int) -> AugExpr | None: - if coordinate > 0: - return None - else: - return AugExpr.format("{}.size()", vec) - - def stride(self, coordinate: int) -> AugExpr | None: - if coordinate > 0: - return None - else: - return AugExpr.format("1") + def _extract_size(self, coordinate: int) -> AugExpr | None: + if coordinate > 0: + return None + else: + return AugExpr.format("{}.size()", self) - return Extraction() + def _extract_stride(self, coordinate: int) -> AugExpr | None: + if coordinate > 0: + return None + else: + return AugExpr.format("1") - def extract_component(self, coordinate: int) -> AugExpr: + def _extract_component(self, coordinate: int) -> AugExpr: if self._unsafe: return AugExpr.format("{}[{}]", self, coordinate) else: diff --git a/src/pystencilssfg/lang/cpp/sycl_accessor.py b/src/pystencilssfg/lang/cpp/sycl_accessor.py index 0052302adef8e8d98914073e124da24891c301e9..08a92ce6610d558db85794455ba64414087b2051 100644 --- a/src/pystencilssfg/lang/cpp/sycl_accessor.py +++ b/src/pystencilssfg/lang/cpp/sycl_accessor.py @@ -1,12 +1,10 @@ -from ...lang import SrcField, IFieldExtraction - from pystencils import Field, DynamicType from pystencils.types import UserTypeSpec, create_type -from ...lang import AugExpr, cpptype +from ...lang import AugExpr, cpptype, SupportsFieldExtraction -class SyclAccessor(SrcField): +class SyclAccessor(AugExpr, SupportsFieldExtraction): """Represent a `SYCL Accessor <https://registry.khronos.org/SYCL/specs/sycl-2020/html/sycl-2020.html#subsec:accessors>`_. @@ -36,40 +34,32 @@ class SyclAccessor(SrcField): self._dim = dimensions self._inner_stride = 1 - def get_extraction(self) -> IFieldExtraction: - accessor = self - - class Extraction(IFieldExtraction): - def ptr(self) -> AugExpr: - return AugExpr.format( - "{}.get_multi_ptr<sycl::access::decorated::no>().get()", - accessor, - ) - - def size(self, coordinate: int) -> AugExpr | None: - if coordinate > accessor._dim: - return None - else: - return AugExpr.format( - "{}.get_range().get({})", accessor, coordinate - ) - - def stride(self, coordinate: int) -> AugExpr | None: - if coordinate > accessor._dim: - return None - elif coordinate == accessor._dim - 1: - return AugExpr.format("{}", accessor._inner_stride) - else: - exprs = [] - args = [] - for d in range(coordinate + 1, accessor._dim): - args.extend([accessor, d]) - exprs.append("{}.get_range().get({})") - expr = " * ".join(exprs) - expr += " * {}" - return AugExpr.format(expr, *args, accessor._inner_stride) - - return Extraction() + def _extract_ptr(self) -> AugExpr: + return AugExpr.format( + "{}.get_multi_ptr<sycl::access::decorated::no>().get()", + self, + ) + + def _extract_size(self, coordinate: int) -> AugExpr | None: + if coordinate > self._dim: + return None + else: + return AugExpr.format("{}.get_range().get({})", self, coordinate) + + def _extract_stride(self, coordinate: int) -> AugExpr | None: + if coordinate > self._dim: + return None + elif coordinate == self._dim - 1: + return AugExpr.format("{}", self._inner_stride) + else: + exprs = [] + args = [] + for d in range(coordinate + 1, self._dim): + args.extend([self, d]) + exprs.append("{}.get_range().get({})") + expr = " * ".join(exprs) + expr += " * {}" + return AugExpr.format(expr, *args, self._inner_stride) @staticmethod def from_field(field: Field, ref: bool = True): diff --git a/src/pystencilssfg/lang/expressions.py b/src/pystencilssfg/lang/expressions.py index 67bd3eb977e7b786c2dcff7b97e8b8a729ab49a2..135a54eed92e4ba214244c8f46323ea81f6610db 100644 --- a/src/pystencilssfg/lang/expressions.py +++ b/src/pystencilssfg/lang/expressions.py @@ -2,7 +2,6 @@ from __future__ import annotations from typing import Iterable, TypeAlias, Any, cast from itertools import chain -from abc import ABC, abstractmethod import sympy as sp @@ -501,39 +500,3 @@ def includes(obj: ExprLike | PsType) -> set[HeaderFile]: case _: raise ValueError(f"Invalid expression: {obj}") - - -class IFieldExtraction(ABC): - """Interface for objects defining how to extract low-level field parameters - from high-level data structures.""" - - @abstractmethod - def ptr(self) -> AugExpr: ... - - @abstractmethod - def size(self, coordinate: int) -> AugExpr | None: ... - - @abstractmethod - def stride(self, coordinate: int) -> AugExpr | None: ... - - -class SrcField(AugExpr): - """Represents a C++ data structure that can be mapped to a *pystencils* field. - - Args: - dtype: Data type of the field data structure - """ - - @abstractmethod - def get_extraction(self) -> IFieldExtraction: ... - - -class SrcVector(AugExpr, ABC): - """Represents a C++ data structure that represents a mathematical vector. - - Args: - dtype: Data type of the vector data structure - """ - - @abstractmethod - def extract_component(self, coordinate: int) -> AugExpr: ... diff --git a/src/pystencilssfg/lang/extractions.py b/src/pystencilssfg/lang/extractions.py new file mode 100644 index 0000000000000000000000000000000000000000..e920fcbfc453d53c22f0486ab7c051ed6c5a7c7f --- /dev/null +++ b/src/pystencilssfg/lang/extractions.py @@ -0,0 +1,62 @@ +from __future__ import annotations +from typing import Protocol +from abc import abstractmethod + +from .expressions import AugExpr + + +class SupportsFieldExtraction(Protocol): + """Protocol for field pointer and indexing extraction. + + Objects adhering to this protocol are understood to provide expressions + for the base pointer, shape, and stride properties of a field. + They can therefore be passed to `sfg.map_field <SfgBasicComposer.map_field>`. + """ + +# how-to-guide begin + @abstractmethod + def _extract_ptr(self) -> AugExpr: + """Extract the field base pointer. + + Return an expression which represents the base pointer + of this field data structure. + + :meta public: + """ + + @abstractmethod + def _extract_size(self, coordinate: int) -> AugExpr | None: + """Extract field size in a given coordinate. + + If ``coordinate`` is valid for this field (i.e. smaller than its dimensionality), + return an expression representing the logical size of this field + in the given dimension. + Otherwise, return `None`. + + :meta public: + """ + + @abstractmethod + def _extract_stride(self, coordinate: int) -> AugExpr | None: + """Extract field stride in a given coordinate. + + If ``coordinate`` is valid for this field (i.e. smaller than its dimensionality), + return an expression representing the memory linearization stride of this field + in the given dimension. + Otherwise, return `None`. + + :meta public: + """ +# how-to-guide end + + +class SupportsVectorExtraction(Protocol): + """Protocol for component extraction from a vector. + + Objects adhering to this protocol are understood to provide + access to the entries of a vector + and can therefore be passed to `sfg.map_vector <SfgBasicComposer.map_vector>`. + """ + + @abstractmethod + def _extract_component(self, coordinate: int) -> AugExpr: ... diff --git a/tests/generator_scripts/source/MdSpanFixedShapeLayouts.py b/tests/generator_scripts/source/MdSpanFixedShapeLayouts.py index c89fe2455e3ff117596bdd63d538d56d2afcc3b5..c11fdac8f2fb5c3175ba3d66b1f126702c634de7 100644 --- a/tests/generator_scripts/source/MdSpanFixedShapeLayouts.py +++ b/tests/generator_scripts/source/MdSpanFixedShapeLayouts.py @@ -22,7 +22,7 @@ with SourceFileGenerator() as sfg: ), sfg.expr( 'assert({} == {} && "Stride mismatch at coordinate {}");', - mdspan.stride(d), + mdspan._extract_stride(d), field.strides[d], d, ), diff --git a/tests/ir/test_postprocessing.py b/tests/ir/test_postprocessing.py index 9d51c8fa6ef944a51d9c60219c1460061d6514b1..5a9150b63b2a11dc910b9bbfc1c917487f5d1196 100644 --- a/tests/ir/test_postprocessing.py +++ b/tests/ir/test_postprocessing.py @@ -4,7 +4,7 @@ from pystencils.types import PsCustomType from pystencilssfg.composer import make_sequence -from pystencilssfg.lang import IFieldExtraction, AugExpr +from pystencilssfg.lang import AugExpr, SupportsFieldExtraction from pystencilssfg.ir import SfgStatements, SfgSequence from pystencilssfg.ir.postprocessing import CallTreePostProcessing @@ -73,17 +73,17 @@ def test_find_sympy_symbols(sfg): assert call_tree.children[1].code_string == "const double y = x / a;" -class DemoFieldExtraction(IFieldExtraction): +class DemoFieldExtraction(SupportsFieldExtraction): def __init__(self, name: str): self.obj = AugExpr(PsCustomType("MyField")).var(name) - def ptr(self) -> AugExpr: + def _extract_ptr(self) -> AugExpr: return AugExpr.format("{}.ptr()", self.obj) - def size(self, coordinate: int) -> AugExpr | None: + def _extract_size(self, coordinate: int) -> AugExpr | None: return AugExpr.format("{}.size({})", self.obj, coordinate) - def stride(self, coordinate: int) -> AugExpr | None: + def _extract_stride(self, coordinate: int) -> AugExpr | None: return AugExpr.format("{}.stride({})", self.obj, coordinate)