diff --git a/conftest.py b/conftest.py index 1c85902fdb05c4799fd18c2b4ac1710fca3ae933..661e722446f25fca1230c55989f00372bf66802c 100644 --- a/conftest.py +++ b/conftest.py @@ -3,13 +3,15 @@ from os import path @pytest.fixture(autouse=True) -def prepare_composer(doctest_namespace): +def prepare_doctest_namespace(doctest_namespace): from pystencilssfg import SfgContext, SfgComposer + from pystencilssfg import lang # Place a composer object in the environment for doctests sfg = SfgComposer(SfgContext()) doctest_namespace["sfg"] = sfg + doctest_namespace["lang"] = lang DATA_DIR = path.join(path.split(__file__)[0], "tests/data") diff --git a/docs/source/conf.py b/docs/source/conf.py index 4bdf700da8e949023ddc1a56fd44d915a9f842db..da6f4d729898cb0215658f99bf1ecea2b018edb5 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -24,7 +24,7 @@ html_title = f"pystencils-sfg v{version} Documentation" # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration extensions = [ - "myst_parser", + "myst_nb", "sphinx.ext.autodoc", "sphinx.ext.napoleon", "sphinx.ext.autosummary", @@ -37,16 +37,8 @@ extensions = [ templates_path = ["_templates"] exclude_patterns = [] -source_suffix = { - ".rst": "restructuredtext", - ".md": "markdown", -} master_doc = "index" nitpicky = True -myst_enable_extensions = [ - "colon_fence", - "dollarmath" -] # -- Options for HTML output ------------------------------------------------- @@ -89,6 +81,15 @@ sfg = SfgComposer(SfgContext()) ''' +# -- Options for MyST / MyST-NB ---------------------------------------------- + +nb_execution_mode = "cache" # do not execute notebooks by default + +myst_enable_extensions = [ + "dollarmath", + "colon_fence", +] + # Prepare code generation examples def build_examples(): diff --git a/docs/source/index.md b/docs/source/index.md index ca35b36efb434430b4155a292784a092620db8d5..0cab08335824b9c6247052f243127b8d72287dfa 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -6,6 +6,7 @@ :caption: User Guide usage/generator_scripts +C++ API Modelling <usage/api_modelling> usage/project_integration usage/tips_n_tricks ``` diff --git a/docs/source/usage/api_modelling.md b/docs/source/usage/api_modelling.md new file mode 100644 index 0000000000000000000000000000000000000000..60d37aee23b1b92a720e1c1d2991ff0b072f01de --- /dev/null +++ b/docs/source/usage/api_modelling.md @@ -0,0 +1,230 @@ +--- +file_format: mystnb +kernelspec: + name: python3 +--- + +# Modelling C++ APIs in pystencils-sfg + +Pystencils-SFG is designed to help you generate C++ code that interfaces with pystencils on the one side, +and with your handwritten code on the other side. +This requires that the C++ classes and APIs of your framework or application be represented within the SFG system. +This guide shows how you can use the facilities of the {any}`pystencilssfg.lang` module to model your C++ interfaces +for use with the code generator. + +To begin, import the `lang` module: + +```{code-cell} ipython3 +from pystencilssfg import lang +``` + +## Defining C++ Types and Type Templates + +The first C++ entities that need to be mirrored for the SFGs are the types and type templates a library +or application uses or exposes. + +### Non-Templated Types + +To define a C++ type, we use {any}`pystencilssfg.lang.cpptype <pystencilssfg.lang.types.cpptype>`: + +```{code-cell} ipython3 +MyClassTypeFactory = lang.cpptype("my_namespace::MyClass", "MyClass.hpp") +MyClassTypeFactory +``` + +This defines two properties of the type: its fully qualified name, and the set of headers +that need to be included when working with the type. +Now, whenever this type occurs as the type of a variable given to pystencils-sfg, +the code generator will make sure that `MyClass.hpp` is included into the respective +generated code file. + +The object returned by `cpptype` is not the type itself, but a factory for instances of the type. +Even as `MyClass` does not have any template parameters, we can create different instances of it: +`const` and non-`const`, as well as references and non-references. +We do this by calling the factory: + +```{code-cell} ipython3 +MyClass = MyClassTypeFactory() +str(MyClass) +``` + +To produce a `const`-qualified version of the type: + +```{code-cell} ipython3 +MyClassConst = MyClassTypeFactory(const=True) +str(MyClassConst) +``` + +And finally, to produce a reference instead: + +```{code-cell} ipython3 +MyClassRef = MyClassTypeFactory(ref=True) +str(MyClassRef) +``` + +Of course, `const` and `ref` can also be combined to create a reference-to-const. + +### Types with Template Parameters + +We can add template parameters to our type by the use of +[Python format strings](https://docs.python.org/3/library/string.html#formatstrings): + +```{code-cell} ipython3 +MyClassTemplate = lang.cpptype("my_namespace::MyClass< {T1}, {T2} >", "MyClass.hpp") +MyClassTemplate +``` + +Here, the type parameters `T1` and `T2` are specified in braces. +For them, values must be provided when calling the factory to instantiate the type: + +```{code-cell} ipython3 +MyClassIntDouble = MyClassTemplate(T1="int", T2="double") +str(MyClassIntDouble) +``` + +The way type parameters are passed to the factory is identical to the behavior of {any}`str.format`, +except that it does not support attribute or element accesses. +In particular, this means that we can also use unnamed, implicit positional parameters: + +```{code-cell} ipython3 +MyClassTemplate = lang.cpptype("my_namespace::MyClass< {}, {} >", "MyClass.hpp") +MyClassIntDouble = MyClassTemplate("int", "double") +str(MyClassIntDouble) +``` + +## Creating Variables and Expressions + +Type templates and types will not get us far on their own. +To use them in APIs, as function or constructor parameters, +or as class members and local objects, +we need to create *variables* with certain types. + +To do so, we need to inject our defined types into the expression framework of pystencils-sfg. +We wrap the type in an interface that allows us to create variables and, later, more complex expressions, +using {any}`lang.CppClass <pystencilssfg.lang.expressions.CppClass>`: + +```{code-cell} ipython3 +class MyClass(lang.CppClass): + template = lang.cpptype("my_namespace::MyClass< {T1}, {T2} >", "MyClass.hpp") +``` + +Instances of `MyClass` can now be created via constructor call, in the same way as above. +This gives us an unbound `MyClass` object, which we can bind to a variable name by calling `var` on it: + +```{code-cell} ipython3 +my_obj = MyClass(T1="int", T2="void").var("my_obj") +my_obj, str(my_obj.dtype) +``` + +## Reflecting C++ Class APIs + +In the previous section, we showed how to reflect a C++ class in pystencils-sfg in order to create +a variable representing an object of that class. +We can now extend this to reflect the public API of the class, in order to create complex expressions +involving objects of `MyClass` during code generation. + +### Public Methods + +Assume `MyClass` has the following public interface: + +```C++ +template< typename T1, typename T2 > +class MyClass { +public: + T1 & getA(); + std::tuple< T1, T2 > getBoth(); + + void replace(T1 a_new, T2 b_new); +} +``` + +We mirror this in our Python reflection of `CppClass` using methods that create `AugExpr` objects, +which represent C++ expressions annotated with variables they depend on. +A possible implementation might look like this: + +```{code-cell} ipython3 +--- +tags: [remove-cell] +--- + +class MyClass(lang.CppClass): + template = lang.cpptype("my_namespace::MyClass< {T1}, {T2} >", "MyClass.hpp") + + def ctor(self, a: lang.AugExpr, b: lang.AugExpr) -> MyClass: + return self.ctor_bind(a, b) + + def getA(self) -> lang.AugExpr: + return lang.AugExpr.format("{}.getA()", self) + + def getBoth(self) -> lang.AugExpr: + return lang.AugExpr.format("{}.getBoth()", self) + + def replace(self, a_new: lang.AugExpr, b_new: lang.AugExpr) -> lang.AugExpr: + return lang.AugExpr.format("{}.replace({}, {})", self, a_new, b_new) +``` + +```{code-block} python +class MyClass(lang.CppClass): + template = lang.cpptype("my_namespace::MyClass< {T1}, {T2} >", "MyClass.hpp") + + def getA(self) -> lang.AugExpr: + return lang.AugExpr.format("{}.getA()", self) + + def getBoth(self) -> lang.AugExpr: + return lang.AugExpr.format("{}.getBoth()", self) + + def replace(self, a_new: lang.AugExpr, b_new: lang.AugExpr) -> lang.AugExpr: + return lang.AugExpr.format("{}.replace({}, {})", self, a_new, b_new) +``` + +Each method of `MyClass` reflects a method of the same name in its public C++ API. +These methods do not return values, but *expressions*; +here, we use the generic `AugExpr` class to model expressions that we don't know anything +about except how they should be constructed. + +We create these expressions using `AugExpr.format`, which takes a format string +and interpolation arguments in the same way as `cpptype`. +Internally, it will analyze the format arguments (e.g. `self`, `a_new` and `b_new` in `replace`), +and combine information from any `AugExpr`s found among them. +These are: + - **Variables**: If any of the input expression depend on variables, the resulting expression will + depend on the union of all these variable sets + - **Headers**: If any of the input expression requires certain header files to be evaluated, + the resulting expression will require the same header files. + +We can see this in action by calling one of the methods on a variable of type `MyClass`: + +```{code-cell} ipython3 +my_obj = MyClass(T1="int", T2="void").var("my_obj") +expr = my_obj.getBoth() +expr, lang.depends(expr), lang.includes(expr) +``` + +We can see: the newly created expression `my_obj.getBoth()` depends on the variable `my_obj` and +requires the header `MyClass.hpp` to be included; this header it has inherited from `my_obj`. + +### Constructors + +Using the `AugExpr` system, we can also model constructors of `MyClass`. +Assume `MyClass` has the constructor `MyClass(T1 a, T2 b)`. +We implement this by adding a `ctor` method to our Python interface: + +```{code-block} python +class MyClass(lang.CppClass): + ... + + def ctor(self, a: lang.AugExpr, b: lang.AugExpr) -> MyClass: + return self.ctor_bind(a, b) +``` + +Here, we don't use `AugExpr.format`; instead, we use `ctor_bind`, which is exposed by `CppClass`. +This will generate the correct constructor invocation from the type of our `MyClass` object +and also ensure the headers required by `MyClass` are correctly attached to the resulting +expression: + +```{code-cell} ipython3 +a = lang.AugExpr("int").var("a") +b = lang.AugExpr("double").var("b") +expr = MyClass(T1="int", T2="double").ctor(a, b) +expr, lang.depends(expr), lang.includes(expr) +``` diff --git a/pyproject.toml b/pyproject.toml index da36a11c4a19d510faadc491045485594790c535..6ac0327d728b8732d86186e213892ae60134a2ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ docs = [ "sphinx", "pydata-sphinx-theme==0.15.4", "sphinx-book-theme==1.1.3", # workaround for https://github.com/executablebooks/sphinx-book-theme/issues/865 - "myst-parser", + "myst-nb", "sphinx_design", "sphinx_autodoc_typehints", "sphinx-copybutton", diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py index 7f1b58fde5078725eb93ad5c8cfc5f15308be4e3..b96d559a1733ec69b6f40db04f41437cc80e3ad7 100644 --- a/src/pystencilssfg/composer/basic_composer.py +++ b/src/pystencilssfg/composer/basic_composer.py @@ -7,13 +7,7 @@ from functools import reduce from pystencils import Field from pystencils.codegen import Kernel -from pystencils.types import ( - create_type, - UserTypeSpec, - PsCustomType, - PsPointerType, - PsType, -) +from pystencils.types import create_type, UserTypeSpec from ..context import SfgContext from .custom import CustomGenerator @@ -325,30 +319,6 @@ class SfgBasicComposer(SfgIComposer): """Use inside a function body to require the inclusion of headers.""" return SfgRequireIncludes((HeaderFile.parse(incl) for incl in incls)) - def cpptype( - self, - typename: UserTypeSpec, - ptr: bool = False, - ref: bool = False, - const: bool = False, - ) -> PsType: - if ptr and ref: - raise SfgException("Create either a pointer, or a ref type, not both!") - - ref_qual = "&" if ref else "" - try: - base_type = create_type(typename) - except ValueError: - if not isinstance(typename, str): - raise ValueError(f"Could not parse type: {typename}") - - base_type = PsCustomType(typename + ref_qual, const=const) - - if ptr: - return PsPointerType(base_type) - else: - return base_type - def var(self, name: str, dtype: UserTypeSpec) -> AugExpr: """Create a variable with given name and data type.""" return AugExpr(create_type(dtype)).var(name) diff --git a/src/pystencilssfg/composer/class_composer.py b/src/pystencilssfg/composer/class_composer.py index 1f4c4865987c1f10ec133f95d11e4bccc1ef8b76..489823b9ce619be88e3220ce4b941cf49c62b298 100644 --- a/src/pystencilssfg/composer/class_composer.py +++ b/src/pystencilssfg/composer/class_composer.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import Sequence -from pystencils.types import PsCustomType, UserTypeSpec +from pystencils.types import PsCustomType, UserTypeSpec, create_type from ..lang import ( _VarLike, @@ -177,7 +177,7 @@ class SfgClassComposer(SfgComposerMixIn): return SfgMethod( name, tree, - return_type=self._composer.cpptype(returns), + return_type=create_type(returns), inline=inline, const=const, ) diff --git a/src/pystencilssfg/emission/printers.py b/src/pystencilssfg/emission/printers.py index adf7508e705fc9d909e362433bbc13d5048d6475..9d7c97e7ce732066c91eda3e3cbf887dcb552f77 100644 --- a/src/pystencilssfg/emission/printers.py +++ b/src/pystencilssfg/emission/printers.py @@ -166,7 +166,7 @@ class SfgHeaderPrinter(SfgGeneralPrinter): @visit.case(SfgMethod) def sfg_method(self, method: SfgMethod): - code = f"{method.return_type} {method.name} ({self.param_list(method)})" + code = f"{method.return_type.c_string()} {method.name} ({self.param_list(method)})" code += "const" if method.const else "" if method.inline: code += ( diff --git a/src/pystencilssfg/extensions/sycl.py b/src/pystencilssfg/extensions/sycl.py index 349e030c6c9057472b9f513258c8a979bba16b8a..88dbc9be2e215b1fdce5833ef18eac6eab336d74 100644 --- a/src/pystencilssfg/extensions/sycl.py +++ b/src/pystencilssfg/extensions/sycl.py @@ -57,9 +57,7 @@ class SyclRange(AugExpr): _template = cpptype("sycl::range< {dims} >", "<sycl/sycl.hpp>") def __init__(self, dims: int, const: bool = False, ref: bool = False): - dtype = self._template(dims=dims, const=const) - if ref: - dtype = Ref(dtype) + dtype = self._template(dims=dims, const=const, ref=ref) super().__init__(dtype) diff --git a/src/pystencilssfg/lang/__init__.py b/src/pystencilssfg/lang/__init__.py index 2ad8f9384be3e5f93e78b2504e2e90b730d345da..9218ec2b7d7f94517e35a2c9a8e4e4ddaa7c3a2a 100644 --- a/src/pystencilssfg/lang/__init__.py +++ b/src/pystencilssfg/lang/__init__.py @@ -10,6 +10,7 @@ from .expressions import ( asvar, depends, includes, + CppClass, IFieldExtraction, SrcField, SrcVector, @@ -32,7 +33,8 @@ __all__ = [ "SrcField", "SrcVector", "cpptype", + "CppClass", "void", "Ref", - "strip_ptr_ref" + "strip_ptr_ref", ] diff --git a/src/pystencilssfg/lang/cpp/std_mdspan.py b/src/pystencilssfg/lang/cpp/std_mdspan.py index 5f80ad6f15d71ecc12d6f5b55b8671e1dd1124d2..68308850f9dd7d5c3486265fb6c9b634d97c6fb7 100644 --- a/src/pystencilssfg/lang/cpp/std_mdspan.py +++ b/src/pystencilssfg/lang/cpp/std_mdspan.py @@ -11,7 +11,7 @@ from pystencils.types import ( from pystencilssfg.lang.expressions import AugExpr -from ...lang import SrcField, IFieldExtraction, cpptype, Ref, HeaderFile, ExprLike +from ...lang import SrcField, IFieldExtraction, cpptype, HeaderFile, ExprLike class StdMdspan(SrcField): @@ -111,11 +111,8 @@ class StdMdspan(SrcField): layout_policy = f"{self._namespace}::{layout_policy}" dtype = self._template( - T=T, extents=extents_str, layout_policy=layout_policy, const=const + T=T, extents=extents_str, layout_policy=layout_policy, const=const, ref=ref ) - - if ref: - dtype = Ref(dtype) super().__init__(dtype) self._extents_type = extents_str diff --git a/src/pystencilssfg/lang/cpp/std_span.py b/src/pystencilssfg/lang/cpp/std_span.py index 861a4c4bb1ea81b0cbaaef4cb683316274ab2edd..f161f4874f627fa8943f4e24c2a1082780259572 100644 --- a/src/pystencilssfg/lang/cpp/std_span.py +++ b/src/pystencilssfg/lang/cpp/std_span.py @@ -1,7 +1,7 @@ from pystencils.field import Field from pystencils.types import UserTypeSpec, create_type, PsType -from ...lang import SrcField, IFieldExtraction, AugExpr, cpptype, Ref +from ...lang import SrcField, IFieldExtraction, AugExpr, cpptype class StdSpan(SrcField): @@ -9,10 +9,7 @@ class StdSpan(SrcField): def __init__(self, T: UserTypeSpec, ref=False, const=False): T = create_type(T) - dtype = self._template(T=T, const=const) - if ref: - dtype = Ref(dtype) - + dtype = self._template(T=T, const=const, ref=ref) self._element_type = T super().__init__(dtype) diff --git a/src/pystencilssfg/lang/cpp/std_tuple.py b/src/pystencilssfg/lang/cpp/std_tuple.py index bbf2ba33b8f1a19081593501885d0dc935fc3055..58a3530b9e98e2c39e205fd7dac9845b4ff35bda 100644 --- a/src/pystencilssfg/lang/cpp/std_tuple.py +++ b/src/pystencilssfg/lang/cpp/std_tuple.py @@ -2,7 +2,7 @@ from typing import Sequence from pystencils.types import UserTypeSpec, create_type -from ...lang import SrcVector, AugExpr, cpptype, Ref +from ...lang import SrcVector, AugExpr, cpptype class StdTuple(SrcVector): @@ -18,10 +18,7 @@ class StdTuple(SrcVector): self._length = len(element_types) elt_type_strings = tuple(t.c_string() for t in self._element_types) - dtype = self._template(ts=", ".join(elt_type_strings), const=const) - if ref: - dtype = Ref(dtype) - + dtype = self._template(ts=", ".join(elt_type_strings), const=const, ref=ref) super().__init__(dtype) def extract_component(self, coordinate: int) -> AugExpr: diff --git a/src/pystencilssfg/lang/cpp/std_vector.py b/src/pystencilssfg/lang/cpp/std_vector.py index 5696b32dd55c6092db0ff298fdcbd79fa7df69f5..7e9291eab670a1f4f45996b60ea8b8b3e8f49ff4 100644 --- a/src/pystencilssfg/lang/cpp/std_vector.py +++ b/src/pystencilssfg/lang/cpp/std_vector.py @@ -1,7 +1,7 @@ from pystencils.field import Field from pystencils.types import UserTypeSpec, create_type, PsType -from ...lang import SrcField, SrcVector, AugExpr, IFieldExtraction, cpptype, Ref +from ...lang import SrcField, SrcVector, AugExpr, IFieldExtraction, cpptype class StdVector(SrcVector, SrcField): @@ -15,9 +15,7 @@ class StdVector(SrcVector, SrcField): const: bool = False, ): T = create_type(T) - dtype = self._template(T=T, const=const) - if ref: - dtype = Ref(dtype) + dtype = self._template(T=T, const=const, ref=ref) super().__init__(dtype) self._element_type = T diff --git a/src/pystencilssfg/lang/cpp/sycl_accessor.py b/src/pystencilssfg/lang/cpp/sycl_accessor.py index f01c53d24750dc3b4f0350134e9bccd6f8ea4c26..4bcad56cd4ef109faa66757075eeefb6a5b416d3 100644 --- a/src/pystencilssfg/lang/cpp/sycl_accessor.py +++ b/src/pystencilssfg/lang/cpp/sycl_accessor.py @@ -3,7 +3,7 @@ from ...lang import SrcField, IFieldExtraction from pystencils import Field from pystencils.types import UserTypeSpec, create_type -from ...lang import AugExpr, cpptype, Ref +from ...lang import AugExpr, cpptype class SyclAccessor(SrcField): @@ -29,9 +29,7 @@ class SyclAccessor(SrcField): T = create_type(T) if dimensions > 3: raise ValueError("sycl accessors can only have dims 1, 2 or 3") - dtype = self._template(T=T, dims=dimensions, const=const) - if ref: - dtype = Ref(dtype) + dtype = self._template(T=T, dims=dimensions, const=const, ref=ref) super().__init__(dtype) diff --git a/src/pystencilssfg/lang/expressions.py b/src/pystencilssfg/lang/expressions.py index 53064bd13a201b0716545bf1e84677baa8df3be9..f86140ee7a3775caab19f69c34ef97822975b95e 100644 --- a/src/pystencilssfg/lang/expressions.py +++ b/src/pystencilssfg/lang/expressions.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Iterable, TypeAlias, Any +from typing import Iterable, TypeAlias, Any, cast from itertools import chain from abc import ABC, abstractmethod @@ -10,11 +10,12 @@ from pystencils.types import PsType, UserTypeSpec, create_type from ..exceptions import SfgException from .headers import HeaderFile -from .types import strip_ptr_ref, CppType +from .types import strip_ptr_ref, CppType, CppTypeFactory __all__ = [ "SfgVar", "AugExpr", + "CppClass", "VarLike", "ExprLike", "asvar", @@ -147,7 +148,7 @@ class VarExpr(DependentExpression): incls: Iterable[HeaderFile] match base_type: case CppType(): - incls = base_type.includes + incls = base_type.class_includes case _: incls = ( HeaderFile.parse(header) for header in var.dtype.required_headers @@ -220,7 +221,13 @@ class AugExpr: """Create a new `AugExpr` by combining existing expressions.""" return AugExpr().bind(fmt, *deps, **kwdeps) - def bind(self, fmt: str | AugExpr, *deps, **kwdeps): + def bind( + self, + fmt: str | AugExpr, + *deps, + require_headers: Iterable[str | HeaderFile] = (), + **kwdeps, + ): """Bind an unbound `AugExpr` instance to an expression.""" if isinstance(fmt, AugExpr): if bool(deps) or bool(kwdeps): @@ -232,7 +239,7 @@ class AugExpr: self._bind(fmt._bound) else: dependencies: set[SfgVar] = set() - incls: set[HeaderFile] = set() + incls: set[HeaderFile] = set(HeaderFile.parse(h) for h in require_headers) from pystencils.sympyextensions import is_constant @@ -310,6 +317,40 @@ class AugExpr: return self._bound is not None +class CppClass(AugExpr): + """Convenience base class for C++ API mirroring. + + Example: + To reflect a C++ class (template) in pystencils-sfg, you may create a subclass + of `CppClass` like this: + + >>> class MyClassTemplate(CppClass): + ... template = lang.cpptype("mynamespace::MyClassTemplate< {T} >", "MyHeader.hpp") + + + Then use `AugExpr` initialization and binding to create variables or expressions with + this class: + + >>> var = MyClassTemplate(T="float").var("myObj") + >>> var + myObj + + >>> str(var.dtype).strip() + 'mynamespace::MyClassTemplate< float >' + """ + + template: CppTypeFactory + + def __init__(self, *args, const: bool = False, ref: bool = False, **kwargs): + dtype = self.template(*args, **kwargs, const=const, ref=ref) + super().__init__(dtype) + + def ctor_bind(self, *args): + fstr = self.get_dtype().c_string() + "{{" + ", ".join(["{}"] * len(args)) + "}}" + dtype = cast(CppType, self.get_dtype()) + return self.bind(fstr, *args, require_headers=dtype.includes) + + _VarLike = (AugExpr, SfgVar, TypedSymbol) VarLike: TypeAlias = AugExpr | SfgVar | TypedSymbol """Things that may act as a variable. @@ -408,7 +449,11 @@ def includes(expr: ExprLike) -> set[HeaderFile]: match expr: case SfgVar(_, dtype): - return set(HeaderFile.parse(h) for h in dtype.required_headers) + match dtype: + case CppType(): + return set(dtype.includes) + case _: + return set(HeaderFile.parse(h) for h in dtype.required_headers) case TypedSymbol(): return includes(asvar(expr)) case str(): @@ -424,16 +469,13 @@ class IFieldExtraction(ABC): from high-level data structures.""" @abstractmethod - def ptr(self) -> AugExpr: - ... + def ptr(self) -> AugExpr: ... @abstractmethod - def size(self, coordinate: int) -> AugExpr | None: - ... + def size(self, coordinate: int) -> AugExpr | None: ... @abstractmethod - def stride(self, coordinate: int) -> AugExpr | None: - ... + def stride(self, coordinate: int) -> AugExpr | None: ... class SrcField(AugExpr): @@ -444,8 +486,7 @@ class SrcField(AugExpr): """ @abstractmethod - def get_extraction(self) -> IFieldExtraction: - ... + def get_extraction(self) -> IFieldExtraction: ... class SrcVector(AugExpr, ABC): @@ -456,5 +497,4 @@ class SrcVector(AugExpr, ABC): """ @abstractmethod - def extract_component(self, coordinate: int) -> AugExpr: - ... + def extract_component(self, coordinate: int) -> AugExpr: ... diff --git a/src/pystencilssfg/lang/types.py b/src/pystencilssfg/lang/types.py index b3a634fbb63bb93d6c8e504030590df477cd200d..71c198a3003308e713a31d87a155b99d01e92a07 100644 --- a/src/pystencilssfg/lang/types.py +++ b/src/pystencilssfg/lang/types.py @@ -1,5 +1,11 @@ -from typing import Any, Iterable +from __future__ import annotations +from typing import Any, Iterable, Sequence, Mapping, TypeVar, Generic from abc import ABC +from dataclasses import dataclass +from itertools import chain + +import string + from pystencils.types import PsType, PsPointerType, PsCustomType from .headers import HeaderFile @@ -23,15 +29,163 @@ class VoidType(PsType): void = VoidType() +class _TemplateArgFormatter(string.Formatter): + + def format_field(self, arg, format_spec): + if isinstance(arg, PsType): + arg = arg.c_string() + return super().format_field(arg, format_spec) + + def check_unused_args( + self, used_args: set[int | str], args: Sequence, kwargs: Mapping[str, Any] + ) -> None: + max_args_len: int = ( + max((k for k in used_args if isinstance(k, int)), default=-1) + 1 + ) + if len(args) > max_args_len: + raise ValueError( + f"Too many positional arguments: Expected {max_args_len}, but got {len(args)}" + ) + + extra_keys = set(kwargs.keys()) - used_args # type: ignore + if extra_keys: + raise ValueError(f"Extraneous keyword arguments: {extra_keys}") + + +@dataclass(frozen=True) +class _TemplateArgs: + pargs: tuple[Any, ...] + kwargs: tuple[tuple[str, Any], ...] + + class CppType(PsCustomType, ABC): - includes: frozenset[HeaderFile] + class_includes: frozenset[HeaderFile] + template_string: str + + def __init__(self, *template_args, const: bool = False, **template_kwargs): + # Support for cloning CppTypes + if template_args and isinstance(template_args[0], _TemplateArgs): + assert not template_kwargs + targs = template_args[0] + pargs = targs.pargs + kwargs = dict(targs.kwargs) + else: + pargs = template_args + kwargs = template_kwargs + targs = _TemplateArgs( + pargs, tuple(sorted(kwargs.items(), key=lambda t: t[0])) + ) + + formatter = _TemplateArgFormatter() + name = formatter.format(self.template_string, *pargs, **kwargs) + + self._targs = targs + self._includes = self.class_includes + + for arg in chain(pargs, kwargs.values()): + match arg: + case CppType(): + self._includes |= arg.includes + case PsType(): + self._includes |= { + HeaderFile.parse(h) for h in arg.required_headers + } + + super().__init__(name, const=const) + + def __args__(self) -> tuple[Any, ...]: + return (self._targs,) + + @property + def includes(self) -> frozenset[HeaderFile]: + return self._includes @property def required_headers(self) -> set[str]: - return set(str(h) for h in self.includes) + return set(str(h) for h in self.class_includes) -def cpptype(typestr: str, include: str | HeaderFile | Iterable[str | HeaderFile] = ()): +TypeClass_T = TypeVar("TypeClass_T", bound=CppType) +"""Python type variable bound to `CppType`.""" + + +class CppTypeFactory(Generic[TypeClass_T]): + """Type Factory returned by `cpptype`.""" + + def __init__(self, tclass: type[TypeClass_T]) -> None: + self._type_class = tclass + + @property + def includes(self) -> frozenset[HeaderFile]: + """Set of headers required by this factory's type""" + return self._type_class.class_includes + + @property + def template_string(self) -> str: + """Template string of this factory's type""" + return self._type_class.template_string + + def __str__(self) -> str: + return f"Factory for {self.template_string}` defined in {self.includes}" + + def __repr__(self) -> str: + return f"CppTypeFactory({self.template_string}, includes={{ {', '.join(str(i) for i in self.includes)} }})" + + def __call__(self, *args, ref: bool = False, **kwargs) -> TypeClass_T | Ref: + """Create a type object of this factory's C++ type template. + + Args: + args, kwargs: Positional and keyword arguments are forwarded to the template string formatter + ref: If ``True``, return a reference type + + Returns: + An instantiated type object + """ + + obj = self._type_class(*args, **kwargs) + if ref: + return Ref(obj) + else: + return obj + + +def cpptype( + template_str: str, include: str | HeaderFile | Iterable[str | HeaderFile] = () +) -> CppTypeFactory: + """Describe a C++ type template, associated with a set of required header files. + + This function allows users to define C++ type templates using + `Python format string syntax <https://docs.python.org/3/library/string.html#formatstrings>`_. + The types may furthermore be annotated with a set of header files that must be included + in order to use the type. + + >>> opt_template = lang.cpptype("std::optional< {T} >", "<optional>") + >>> opt_template.template_string + 'std::optional< {T} >' + + This function returns a `CppTypeFactory` object, which in turn can be called to create + an instance of the C++ type template. + Therein, the ``template_str`` argument is treated as a Python format string: + The positional and keyword arguments passed to the returned type factory are passed + through machinery that is based on `str.format` to produce the actual type name. + + >>> int_option = opt_template(T="int") + >>> int_option.c_string().strip() + 'std::optional< int >' + + The factory may also create reference types when the ``ref=True`` is specified. + + >>> int_option_ref = opt_template(T="int", ref=True) + >>> int_option_ref.c_string().strip() + 'std::optional< int >&' + + Args: + template_str: Format string defining the type template + include: Either the name of a header file, or a sequence of names of header files + + Returns: + CppTypeFactory: A factory used to instantiate the type template + """ headers: list[str | HeaderFile] if isinstance(include, (str, HeaderFile)): @@ -41,25 +195,11 @@ def cpptype(typestr: str, include: str | HeaderFile | Iterable[str | HeaderFile] else: headers = list(include) - def _fixarg(template_arg): - if isinstance(template_arg, PsType): - return template_arg.c_string() - else: - return str(template_arg) - class TypeClass(CppType): - includes = frozenset(HeaderFile.parse(h) for h in headers) - - def __init__(self, *template_args, const: bool = False, **template_kwargs): - template_args = tuple(_fixarg(arg) for arg in template_args) - template_kwargs = { - key: _fixarg(value) for key, value in template_kwargs.items() - } - - name = typestr.format(*template_args, **template_kwargs) - super().__init__(name, const) + template_string = template_str + class_includes = frozenset(HeaderFile.parse(h) for h in headers) - return TypeClass + return CppTypeFactory[TypeClass](TypeClass) class Ref(PsType): diff --git a/tests/generator_scripts/source/Conditionals.py b/tests/generator_scripts/source/Conditionals.py index d1088a96925bf52ab3634cf439538ce1de2b58a7..9016b73744f78fef504bb4f09b6742a630a6b12d 100644 --- a/tests/generator_scripts/source/Conditionals.py +++ b/tests/generator_scripts/source/Conditionals.py @@ -1,4 +1,5 @@ from pystencilssfg import SourceFileGenerator +from pystencils.types import PsCustomType with SourceFileGenerator() as sfg: sfg.namespace("gen") @@ -6,7 +7,7 @@ with SourceFileGenerator() as sfg: sfg.include("<iostream>") sfg.code(r"enum class Noodles { RIGATONI, RAMEN, SPAETZLE, SPAGHETTI };") - noodle = sfg.var("noodle", sfg.cpptype("Noodles")) + noodle = sfg.var("noodle", PsCustomType("Noodles")) sfg.function("printOpinion")( sfg.switch(noodle) diff --git a/tests/generator_scripts/source/SimpleClasses.py b/tests/generator_scripts/source/SimpleClasses.py index 64093f5744c61918bf505518e2c39dffbb525fad..454f1a26f8103f7c8b330f1a5b70b1b79f96ebbc 100644 --- a/tests/generator_scripts/source/SimpleClasses.py +++ b/tests/generator_scripts/source/SimpleClasses.py @@ -16,7 +16,7 @@ with SourceFileGenerator() as sfg: .init(y_)(y) .init(z_)(z), - sfg.method("getX", returns="const int64_t &", const=True, inline=True)( + sfg.method("getX", returns="const int64_t", const=True, inline=True)( "return this->x_;" ) ), diff --git a/tests/lang/test_expressions.py b/tests/lang/test_expressions.py index f7cbebdf2d07a6a87edaaf0774ba393cf2264e06..04c1c98ecd40841fb0ee47e883ff2308fa478bc7 100644 --- a/tests/lang/test_expressions.py +++ b/tests/lang/test_expressions.py @@ -1,11 +1,12 @@ import pytest from pystencilssfg import SfgException -from pystencilssfg.lang import asvar, SfgVar, AugExpr, cpptype, HeaderFile +from pystencilssfg.lang import asvar, SfgVar, AugExpr, cpptype, HeaderFile, CppClass import sympy as sp from pystencils import TypedSymbol, DynamicType +from pystencils.types import PsCustomType def test_asvar(): @@ -103,3 +104,17 @@ def test_headers(): expr = AugExpr().bind("std::get< int >({})", var) assert expr.includes == {HeaderFile("tuple", system_header=True)} + + +def test_cppclass(): + class MyClass(CppClass): + template = cpptype("mynamespace::MyClass< {T} >", "MyHeader.hpp") + + def ctor(self, arg: AugExpr): + return self.ctor_bind(arg) + + unbound = MyClass(T="bogus") + assert unbound.get_dtype() == MyClass.template(T="bogus") + + ctor_expr = unbound.ctor(AugExpr(PsCustomType("bogus")).var("foo")) + assert str(ctor_expr).strip() == r"mynamespace::MyClass< bogus >{foo}" diff --git a/tests/lang/test_types.py b/tests/lang/test_types.py index 44b9a7c23a273b703a5c4d49386258f72bb2f726..62c4bb0bd8ecafc8cc5b2613122838dc56a80637 100644 --- a/tests/lang/test_types.py +++ b/tests/lang/test_types.py @@ -1,14 +1,77 @@ -from pystencilssfg.lang import cpptype, HeaderFile +import pytest + +from pystencilssfg.lang import cpptype, HeaderFile, Ref, strip_ptr_ref from pystencils import create_type +from pystencils.types import constify, deconstify def test_cpptypes(): - tclass = cpptype("std::vector< {T}, {Allocator} >", "<vector>") + tfactory = cpptype("std::vector< {}, {} >", "<vector>") - vec_type = tclass(T=create_type("float32"), Allocator="std::allocator< float >") + vec_type = tfactory(create_type("float32"), "std::allocator< float >") assert str(vec_type).strip() == "std::vector< float, std::allocator< float > >" assert ( - tclass.includes - == vec_type.includes + vec_type.includes == {HeaderFile("vector", system_header=True)} ) + + # Cloning + assert deconstify(constify(vec_type)) == vec_type + + # Duplicate Equality + assert tfactory(create_type("float32"), "std::allocator< float >") == vec_type + # Not equal with different argument even though it produces the same string + assert tfactory("float", "std::allocator< float >") != vec_type + + # The same with keyword arguments + tfactory = cpptype("std::vector< {T}, {Allocator} >", "<vector>") + + vec_type = tfactory(T=create_type("float32"), Allocator="std::allocator< float >") + assert str(vec_type).strip() == "std::vector< float, std::allocator< float > >" + + assert deconstify(constify(vec_type)) == vec_type + + +def test_cpptype_invalid_construction(): + tfactory = cpptype("std::vector< {}, {Allocator} >", "<vector>") + + with pytest.raises(IndexError): + _ = tfactory(Allocator="SomeAlloc") + + with pytest.raises(KeyError): + _ = tfactory("int") + + with pytest.raises(ValueError, match="Too many positional arguments"): + _ = tfactory("int", "bogus", Allocator="SomeAlloc") + + with pytest.raises(ValueError, match="Extraneous keyword arguments"): + _ = tfactory("int", Allocator="SomeAlloc", bogus=2) + + +def test_cpptype_const(): + tfactory = cpptype("std::vector< {T} >", "<vector>") + + vec_type = tfactory(T=create_type("uint32")) + assert constify(vec_type) == tfactory(T=create_type("uint32"), const=True) + + vec_type = tfactory(T=create_type("uint32"), const=True) + assert deconstify(vec_type) == tfactory(T=create_type("uint32"), const=False) + + +def test_cpptype_ref(): + tfactory = cpptype("std::vector< {T} >", "<vector>") + + vec_type = tfactory(T=create_type("uint32"), ref=True) + assert isinstance(vec_type, Ref) + assert strip_ptr_ref(vec_type) == tfactory(T=create_type("uint32")) + + +def test_cpptype_inherits_headers(): + optional_tfactory = cpptype("std::optional< {T} >", "<optional>") + vec_tfactory = cpptype("std::vector< {T} >", "<vector>") + + vec_type = vec_tfactory(T=optional_tfactory(T="int")) + assert vec_type.includes == { + HeaderFile.parse("<optional>"), + HeaderFile.parse("<vector>"), + }