From 1934248a45fb6c4cbaddc25c9c7f2b17ab2c55c9 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Sat, 16 Dec 2023 13:30:45 +0100 Subject: [PATCH] std::tuple mapping --- .../source_concepts/cpp/__init__.py | 9 +++- .../source_concepts/cpp/std_tuple.py | 53 +++++++++++++++++++ 2 files changed, 60 insertions(+), 2 deletions(-) create mode 100644 src/pystencilssfg/source_concepts/cpp/std_tuple.py diff --git a/src/pystencilssfg/source_concepts/cpp/__init__.py b/src/pystencilssfg/source_concepts/cpp/__init__.py index b6e5878..7a2a910 100644 --- a/src/pystencilssfg/source_concepts/cpp/__init__.py +++ b/src/pystencilssfg/source_concepts/cpp/__init__.py @@ -1,7 +1,12 @@ from .std_mdspan import StdMdspan, mdspan_ref from .std_vector import StdVector, std_vector_ref +from .std_tuple import StdTuple, std_tuple_ref __all__ = [ - "StdMdspan", "StdVector", "std_vector_ref", - "mdspan_ref" + "StdMdspan", + "mdspan_ref", + "StdVector", + "std_vector_ref", + "StdTuple", + "std_tuple_ref", ] diff --git a/src/pystencilssfg/source_concepts/cpp/std_tuple.py b/src/pystencilssfg/source_concepts/cpp/std_tuple.py new file mode 100644 index 0000000..717a83c --- /dev/null +++ b/src/pystencilssfg/source_concepts/cpp/std_tuple.py @@ -0,0 +1,53 @@ +from typing import Sequence + +from pystencils.typing import BasicType, TypedSymbol + +from ...tree import SfgStatements +from ..source_objects import SrcVector +from ..source_objects import TypedSymbolOrObject +from ...types import SrcType, cpp_typename +from ...source_components import SfgHeaderInclude + + +class StdTuple(SrcVector): + def __init__( + self, + identifier: str, + element_types: Sequence[BasicType], + const: bool = False, + ref: bool = False, + ): + self._element_types = element_types + self._length = len(element_types) + elt_type_strings = tuple(cpp_typename(t) for t in self._element_types) + src_type = f"{'const' if const else ''} std::tuple< {', '.join(elt_type_strings)} > {'&' if ref else ''}" + super().__init__(identifier, SrcType(src_type)) + + @property + def required_includes(self) -> set[SfgHeaderInclude]: + return {SfgHeaderInclude("tuple", system_header=True)} + + def extract_component(self, destination: TypedSymbolOrObject, coordinate: int): + if coordinate < 0 or coordinate >= self._length: + raise ValueError( + f"Index {coordinate} out-of-bounds for std::tuple with {self._length} entries." + ) + + if destination.dtype != self._element_types[coordinate]: + raise ValueError( + f"Cannot extract type {destination.dtype} from std::tuple entry " + "of type {self._element_types[coordinate]}" + ) + + return SfgStatements( + f"{destination.dtype} {destination.name} = std::get< {coordinate} >({self.identifier});", + (destination,), + (self,), + ) + + +def std_tuple_ref( + identifier: str, components: Sequence[TypedSymbol], const: bool = True +): + elt_types = tuple(c.dtype for c in components) + return StdTuple(identifier, elt_types, const=const, ref=True) -- GitLab