diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py index 49b6c73e8bd06b0d5fb44402b8af285362072b29..0f9a339137b201f24405a6d06a190d20326986f9 100644 --- a/src/pystencilssfg/composer/basic_composer.py +++ b/src/pystencilssfg/composer/basic_composer.py @@ -54,9 +54,8 @@ from ..lang import ( includes, SfgVar, AugExpr, - SrcField, - IFieldExtraction, - SrcVector, + SupportsFieldExtraction, + SupportsVectorExtraction, void, ) from ..exceptions import SfgException @@ -511,7 +510,7 @@ class SfgBasicComposer(SfgIComposer): def map_field( self, field: Field, - index_provider: IFieldExtraction | SrcField, + index_provider: SupportsFieldExtraction, cast_indexing_symbols: bool = True, ) -> SfgDeferredFieldMapping: """Map a pystencils field to a field data structure, from which pointers, sizes @@ -536,7 +535,7 @@ class SfgBasicComposer(SfgIComposer): var: SfgVar | sp.Symbol = asvar(param) if isinstance(param, _VarLike) else param return SfgDeferredParamSetter(var, expr) - def map_vector(self, lhs_components: Sequence[VarLike | sp.Symbol], rhs: SrcVector): + def map_vector(self, lhs_components: Sequence[VarLike | sp.Symbol], rhs: SupportsVectorExtraction): """Extracts scalar numerical values from a vector data type. Args: diff --git a/src/pystencilssfg/ir/postprocessing.py b/src/pystencilssfg/ir/postprocessing.py index 469383c123d7403ca448efb40dad5692d46d0a3c..1e692b0aa9f37368da1688ec2d3bac6892c5ac60 100644 --- a/src/pystencilssfg/ir/postprocessing.py +++ b/src/pystencilssfg/ir/postprocessing.py @@ -19,9 +19,8 @@ from .call_tree import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence, SfgStateme from ..lang.expressions import SfgKernelParamVar from ..lang import ( SfgVar, - IFieldExtraction, - SrcField, - SrcVector, + SupportsFieldExtraction, + SupportsVectorExtraction, ExprLike, AugExpr, depends, @@ -211,7 +210,9 @@ class SfgDeferredParamSetter(SfgDeferredNode): live_var = ppc.get_live_variable(self._lhs.name) if live_var is not None: code = f"{live_var.dtype.c_string()} {live_var.name} = {self._rhs};" - return SfgStatements(code, (live_var,), depends(self._rhs), includes(self._rhs)) + return SfgStatements( + code, (live_var,), depends(self._rhs), includes(self._rhs) + ) else: return SfgSequence([]) @@ -222,15 +223,11 @@ class SfgDeferredFieldMapping(SfgDeferredNode): def __init__( self, psfield: Field, - extraction: IFieldExtraction | SrcField, + extraction: SupportsFieldExtraction, cast_indexing_symbols: bool = True, ): self._field = psfield - self._extraction: IFieldExtraction = ( - extraction - if isinstance(extraction, IFieldExtraction) - else extraction.get_extraction() - ) + self._extraction = extraction self._cast_indexing_symbols = cast_indexing_symbols def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode: @@ -267,7 +264,7 @@ class SfgDeferredFieldMapping(SfgDeferredNode): done: set[SfgKernelParamVar] = set() if ptr is not None: - expr = self._extraction.ptr() + expr = self._extraction._extract_ptr() nodes.append( SfgStatements( f"{ptr.dtype.c_string()} {ptr.name} {{ {expr} }};", @@ -286,7 +283,7 @@ class SfgDeferredFieldMapping(SfgDeferredNode): return expr def get_shape(coord, symb: SfgKernelParamVar | str): - expr = self._extraction.size(coord) + expr = self._extraction._extract_size(coord) if expr is None: raise SfgException( @@ -306,7 +303,7 @@ class SfgDeferredFieldMapping(SfgDeferredNode): return SfgStatements(f"/* {expr} == {symb} */", (), ()) def get_stride(coord, symb: SfgKernelParamVar | str): - expr = self._extraction.stride(coord) + expr = self._extraction._extract_stride(coord) if expr is None: raise SfgException( @@ -332,7 +329,11 @@ class SfgDeferredFieldMapping(SfgDeferredNode): class SfgDeferredVectorMapping(SfgDeferredNode): - def __init__(self, scalars: Sequence[sp.Symbol | SfgVar], vector: SrcVector): + def __init__( + self, + scalars: Sequence[sp.Symbol | SfgVar], + vector: SupportsVectorExtraction, + ): self._scalars = {sc.name: (i, sc) for i, sc in enumerate(scalars)} self._vector = vector @@ -342,7 +343,7 @@ class SfgDeferredVectorMapping(SfgDeferredNode): for param in ppc.live_variables: if param.name in self._scalars: idx, _ = self._scalars[param.name] - expr = self._vector.extract_component(idx) + expr = self._vector._extract_component(idx) nodes.append( SfgStatements( f"{param.dtype.c_string()} {param.name} {{ {expr} }};", diff --git a/src/pystencilssfg/lang/__init__.py b/src/pystencilssfg/lang/__init__.py index 980d2bdf9d543ce764e1f739926ee460b5fe9699..ff8f3ced07deed67e1be6e1f4760c6991e7a724f 100644 --- a/src/pystencilssfg/lang/__init__.py +++ b/src/pystencilssfg/lang/__init__.py @@ -13,11 +13,10 @@ from .expressions import ( includes, CppClass, cppclass, - IFieldExtraction, - SrcField, - SrcVector, ) +from .extractions import SupportsFieldExtraction, SupportsVectorExtraction + from .types import cpptype, void, Ref, strip_ptr_ref __all__ = [ @@ -32,13 +31,12 @@ __all__ = [ "asvar", "depends", "includes", - "IFieldExtraction", - "SrcField", - "SrcVector", "cpptype", "CppClass", "cppclass", "void", "Ref", "strip_ptr_ref", + "SupportsFieldExtraction", + "SupportsVectorExtraction", ] diff --git a/src/pystencilssfg/lang/cpp/std_mdspan.py b/src/pystencilssfg/lang/cpp/std_mdspan.py index 5e552e957c51db36a91ce17d3cfe96506343cdba..ad5feedaf6ca43a6ef1b88a5b53e2b6bf454167d 100644 --- a/src/pystencilssfg/lang/cpp/std_mdspan.py +++ b/src/pystencilssfg/lang/cpp/std_mdspan.py @@ -11,10 +11,10 @@ from pystencils.types import ( from pystencilssfg.lang.expressions import AugExpr -from ...lang import SrcField, IFieldExtraction, cpptype, HeaderFile, ExprLike +from ...lang import SupportsFieldExtraction, cpptype, HeaderFile, ExprLike -class StdMdspan(SrcField): +class StdMdspan(AugExpr, SupportsFieldExtraction): """Represents an `std::mdspan` instance. The `std::mdspan <https://en.cppreference.com/w/cpp/container/mdspan>`_ @@ -141,26 +141,22 @@ class StdMdspan(SrcField): def data_handle(self) -> AugExpr: return AugExpr.format("{}.data_handle()", self) - def get_extraction(self) -> IFieldExtraction: - mdspan = self + # SupportsFieldExtraction protocol - class Extraction(IFieldExtraction): - def ptr(self) -> AugExpr: - return mdspan.data_handle() + def _extract_ptr(self) -> AugExpr: + return self.data_handle() - def size(self, coordinate: int) -> AugExpr | None: - if coordinate > mdspan._dim: - return None - else: - return mdspan.extent(coordinate) + def _extract_size(self, coordinate: int) -> AugExpr | None: + if coordinate > self._dim: + return None + else: + return self.extent(coordinate) - def stride(self, coordinate: int) -> AugExpr | None: - if coordinate > mdspan._dim: - return None - else: - return mdspan.stride(coordinate) - - return Extraction() + def _extract_stride(self, coordinate: int) -> AugExpr | None: + if coordinate > self._dim: + return None + else: + return self.stride(coordinate) @staticmethod def from_field( diff --git a/src/pystencilssfg/lang/cpp/std_span.py b/src/pystencilssfg/lang/cpp/std_span.py index ea4b520ee23717701bc70f615ace3ed8103a2937..7661ea91151e15f171eb63e290f1353d0047ee00 100644 --- a/src/pystencilssfg/lang/cpp/std_span.py +++ b/src/pystencilssfg/lang/cpp/std_span.py @@ -1,10 +1,10 @@ from pystencils import Field, DynamicType from pystencils.types import UserTypeSpec, create_type, PsType -from ...lang import SrcField, IFieldExtraction, AugExpr, cpptype +from ...lang import SupportsFieldExtraction, AugExpr, cpptype -class StdSpan(SrcField): +class StdSpan(AugExpr, SupportsFieldExtraction): _template = cpptype("std::span< {T} >", "<span>") def __init__(self, T: UserTypeSpec, ref=False, const=False): @@ -17,26 +17,20 @@ class StdSpan(SrcField): def element_type(self) -> PsType: return self._element_type - def get_extraction(self) -> IFieldExtraction: - span = self + def _extract_ptr(self) -> AugExpr: + return AugExpr.format("{}.data()", self) - class Extraction(IFieldExtraction): - def ptr(self) -> AugExpr: - return AugExpr.format("{}.data()", span) + def _extract_size(self, coordinate: int) -> AugExpr | None: + if coordinate > 0: + return None + else: + return AugExpr.format("{}.size()", self) - def size(self, coordinate: int) -> AugExpr | None: - if coordinate > 0: - return None - else: - return AugExpr.format("{}.size()", span) - - def stride(self, coordinate: int) -> AugExpr | None: - if coordinate > 0: - return None - else: - return AugExpr.format("1") - - return Extraction() + def _extract_stride(self, coordinate: int) -> AugExpr | None: + if coordinate > 0: + return None + else: + return AugExpr.format("1") @staticmethod def from_field(field: Field, ref: bool = False, const: bool = False): diff --git a/src/pystencilssfg/lang/cpp/std_tuple.py b/src/pystencilssfg/lang/cpp/std_tuple.py index 7ea0e2416b2280042521b871b1b40a623971b187..645b6b56fbeb515d6324feafdb8588f7ca22e992 100644 --- a/src/pystencilssfg/lang/cpp/std_tuple.py +++ b/src/pystencilssfg/lang/cpp/std_tuple.py @@ -1,9 +1,9 @@ from pystencils.types import UserTypeSpec, create_type -from ...lang import SrcVector, AugExpr, cpptype +from ...lang import SupportsVectorExtraction, AugExpr, cpptype -class StdTuple(SrcVector): +class StdTuple(AugExpr, SupportsVectorExtraction): _template = cpptype("std::tuple< {ts} >", "<tuple>") def __init__( @@ -19,7 +19,7 @@ class StdTuple(SrcVector): dtype = self._template(ts=", ".join(elt_type_strings), const=const, ref=ref) super().__init__(dtype) - def extract_component(self, coordinate: int) -> AugExpr: + def _extract_component(self, coordinate: int) -> AugExpr: if coordinate < 0 or coordinate >= self._length: raise ValueError( f"Index {coordinate} out-of-bounds for std::tuple with {self._length} entries." diff --git a/src/pystencilssfg/lang/cpp/std_vector.py b/src/pystencilssfg/lang/cpp/std_vector.py index 7356f942047941b429381d8b2bc111c17b7e6f9d..5ce626dea2b29be50e168f8493612625a82ff002 100644 --- a/src/pystencilssfg/lang/cpp/std_vector.py +++ b/src/pystencilssfg/lang/cpp/std_vector.py @@ -1,10 +1,10 @@ from pystencils import Field, DynamicType from pystencils.types import UserTypeSpec, create_type, PsType -from ...lang import SrcField, SrcVector, AugExpr, IFieldExtraction, cpptype +from ...lang import SupportsFieldExtraction, SupportsVectorExtraction, AugExpr, cpptype -class StdVector(SrcVector, SrcField): +class StdVector(AugExpr, SupportsFieldExtraction, SupportsVectorExtraction): _template = cpptype("std::vector< {T} >", "<vector>") def __init__( @@ -25,28 +25,22 @@ class StdVector(SrcVector, SrcField): def element_type(self) -> PsType: return self._element_type - def get_extraction(self) -> IFieldExtraction: - vec = self + def _extract_ptr(self) -> AugExpr: + return AugExpr.format("{}.data()", self) - class Extraction(IFieldExtraction): - def ptr(self) -> AugExpr: - return AugExpr.format("{}.data()", vec) - - def size(self, coordinate: int) -> AugExpr | None: - if coordinate > 0: - return None - else: - return AugExpr.format("{}.size()", vec) - - def stride(self, coordinate: int) -> AugExpr | None: - if coordinate > 0: - return None - else: - return AugExpr.format("1") + def _extract_size(self, coordinate: int) -> AugExpr | None: + if coordinate > 0: + return None + else: + return AugExpr.format("{}.size()", self) - return Extraction() + def _extract_stride(self, coordinate: int) -> AugExpr | None: + if coordinate > 0: + return None + else: + return AugExpr.format("1") - def extract_component(self, coordinate: int) -> AugExpr: + def _extract_component(self, coordinate: int) -> AugExpr: if self._unsafe: return AugExpr.format("{}[{}]", self, coordinate) else: diff --git a/src/pystencilssfg/lang/cpp/sycl_accessor.py b/src/pystencilssfg/lang/cpp/sycl_accessor.py index 0052302adef8e8d98914073e124da24891c301e9..08a92ce6610d558db85794455ba64414087b2051 100644 --- a/src/pystencilssfg/lang/cpp/sycl_accessor.py +++ b/src/pystencilssfg/lang/cpp/sycl_accessor.py @@ -1,12 +1,10 @@ -from ...lang import SrcField, IFieldExtraction - from pystencils import Field, DynamicType from pystencils.types import UserTypeSpec, create_type -from ...lang import AugExpr, cpptype +from ...lang import AugExpr, cpptype, SupportsFieldExtraction -class SyclAccessor(SrcField): +class SyclAccessor(AugExpr, SupportsFieldExtraction): """Represent a `SYCL Accessor <https://registry.khronos.org/SYCL/specs/sycl-2020/html/sycl-2020.html#subsec:accessors>`_. @@ -36,40 +34,32 @@ class SyclAccessor(SrcField): self._dim = dimensions self._inner_stride = 1 - def get_extraction(self) -> IFieldExtraction: - accessor = self - - class Extraction(IFieldExtraction): - def ptr(self) -> AugExpr: - return AugExpr.format( - "{}.get_multi_ptr<sycl::access::decorated::no>().get()", - accessor, - ) - - def size(self, coordinate: int) -> AugExpr | None: - if coordinate > accessor._dim: - return None - else: - return AugExpr.format( - "{}.get_range().get({})", accessor, coordinate - ) - - def stride(self, coordinate: int) -> AugExpr | None: - if coordinate > accessor._dim: - return None - elif coordinate == accessor._dim - 1: - return AugExpr.format("{}", accessor._inner_stride) - else: - exprs = [] - args = [] - for d in range(coordinate + 1, accessor._dim): - args.extend([accessor, d]) - exprs.append("{}.get_range().get({})") - expr = " * ".join(exprs) - expr += " * {}" - return AugExpr.format(expr, *args, accessor._inner_stride) - - return Extraction() + def _extract_ptr(self) -> AugExpr: + return AugExpr.format( + "{}.get_multi_ptr<sycl::access::decorated::no>().get()", + self, + ) + + def _extract_size(self, coordinate: int) -> AugExpr | None: + if coordinate > self._dim: + return None + else: + return AugExpr.format("{}.get_range().get({})", self, coordinate) + + def _extract_stride(self, coordinate: int) -> AugExpr | None: + if coordinate > self._dim: + return None + elif coordinate == self._dim - 1: + return AugExpr.format("{}", self._inner_stride) + else: + exprs = [] + args = [] + for d in range(coordinate + 1, self._dim): + args.extend([self, d]) + exprs.append("{}.get_range().get({})") + expr = " * ".join(exprs) + expr += " * {}" + return AugExpr.format(expr, *args, self._inner_stride) @staticmethod def from_field(field: Field, ref: bool = True): diff --git a/src/pystencilssfg/lang/expressions.py b/src/pystencilssfg/lang/expressions.py index 67bd3eb977e7b786c2dcff7b97e8b8a729ab49a2..135a54eed92e4ba214244c8f46323ea81f6610db 100644 --- a/src/pystencilssfg/lang/expressions.py +++ b/src/pystencilssfg/lang/expressions.py @@ -2,7 +2,6 @@ from __future__ import annotations from typing import Iterable, TypeAlias, Any, cast from itertools import chain -from abc import ABC, abstractmethod import sympy as sp @@ -501,39 +500,3 @@ def includes(obj: ExprLike | PsType) -> set[HeaderFile]: case _: raise ValueError(f"Invalid expression: {obj}") - - -class IFieldExtraction(ABC): - """Interface for objects defining how to extract low-level field parameters - from high-level data structures.""" - - @abstractmethod - def ptr(self) -> AugExpr: ... - - @abstractmethod - def size(self, coordinate: int) -> AugExpr | None: ... - - @abstractmethod - def stride(self, coordinate: int) -> AugExpr | None: ... - - -class SrcField(AugExpr): - """Represents a C++ data structure that can be mapped to a *pystencils* field. - - Args: - dtype: Data type of the field data structure - """ - - @abstractmethod - def get_extraction(self) -> IFieldExtraction: ... - - -class SrcVector(AugExpr, ABC): - """Represents a C++ data structure that represents a mathematical vector. - - Args: - dtype: Data type of the vector data structure - """ - - @abstractmethod - def extract_component(self, coordinate: int) -> AugExpr: ... diff --git a/src/pystencilssfg/lang/extractions.py b/src/pystencilssfg/lang/extractions.py new file mode 100644 index 0000000000000000000000000000000000000000..40c69220fd40f2ddc6358d21feea647f8aa350ba --- /dev/null +++ b/src/pystencilssfg/lang/extractions.py @@ -0,0 +1,35 @@ +from __future__ import annotations +from typing import Protocol +from abc import abstractmethod + +from .expressions import AugExpr + + +class SupportsFieldExtraction(Protocol): + """Protocol for field pointer and indexing extraction. + + Objects adhering to this protocol are understood to provide expressions + for the base pointer, shape, and stride properties of a field. + They can therefore be passed to `sfg.map_field <SfgBasicComposer.map_field>`. + """ + + @abstractmethod + def _extract_ptr(self) -> AugExpr: ... + + @abstractmethod + def _extract_size(self, coordinate: int) -> AugExpr | None: ... + + @abstractmethod + def _extract_stride(self, coordinate: int) -> AugExpr | None: ... + + +class SupportsVectorExtraction(Protocol): + """Protocol for component extraction from a vector. + + Objects adhering to this protocol are understood to provide + access to the entries of a vector + and can therefore be passed to `sfg.map_vector <SfgBasicComposer.map_vector>`. + """ + + @abstractmethod + def _extract_component(self, coordinate: int) -> AugExpr: ... diff --git a/tests/generator_scripts/source/MdSpanFixedShapeLayouts.py b/tests/generator_scripts/source/MdSpanFixedShapeLayouts.py index c89fe2455e3ff117596bdd63d538d56d2afcc3b5..c11fdac8f2fb5c3175ba3d66b1f126702c634de7 100644 --- a/tests/generator_scripts/source/MdSpanFixedShapeLayouts.py +++ b/tests/generator_scripts/source/MdSpanFixedShapeLayouts.py @@ -22,7 +22,7 @@ with SourceFileGenerator() as sfg: ), sfg.expr( 'assert({} == {} && "Stride mismatch at coordinate {}");', - mdspan.stride(d), + mdspan._extract_stride(d), field.strides[d], d, ), diff --git a/tests/ir/test_postprocessing.py b/tests/ir/test_postprocessing.py index 9d51c8fa6ef944a51d9c60219c1460061d6514b1..5a9150b63b2a11dc910b9bbfc1c917487f5d1196 100644 --- a/tests/ir/test_postprocessing.py +++ b/tests/ir/test_postprocessing.py @@ -4,7 +4,7 @@ from pystencils.types import PsCustomType from pystencilssfg.composer import make_sequence -from pystencilssfg.lang import IFieldExtraction, AugExpr +from pystencilssfg.lang import AugExpr, SupportsFieldExtraction from pystencilssfg.ir import SfgStatements, SfgSequence from pystencilssfg.ir.postprocessing import CallTreePostProcessing @@ -73,17 +73,17 @@ def test_find_sympy_symbols(sfg): assert call_tree.children[1].code_string == "const double y = x / a;" -class DemoFieldExtraction(IFieldExtraction): +class DemoFieldExtraction(SupportsFieldExtraction): def __init__(self, name: str): self.obj = AugExpr(PsCustomType("MyField")).var(name) - def ptr(self) -> AugExpr: + def _extract_ptr(self) -> AugExpr: return AugExpr.format("{}.ptr()", self.obj) - def size(self, coordinate: int) -> AugExpr | None: + def _extract_size(self, coordinate: int) -> AugExpr | None: return AugExpr.format("{}.size({})", self.obj, coordinate) - def stride(self, coordinate: int) -> AugExpr | None: + def _extract_stride(self, coordinate: int) -> AugExpr | None: return AugExpr.format("{}.stride({})", self.obj, coordinate)