diff --git a/docs/source/usage/api_modelling.md b/docs/source/usage/api_modelling.md index 565da7ad07510e01d2fb5a7855a4ee2f69532dd8..60d37aee23b1b92a720e1c1d2991ff0b072f01de 100644 --- a/docs/source/usage/api_modelling.md +++ b/docs/source/usage/api_modelling.md @@ -90,4 +90,141 @@ In particular, this means that we can also use unnamed, implicit positional para MyClassTemplate = lang.cpptype("my_namespace::MyClass< {}, {} >", "MyClass.hpp") MyClassIntDouble = MyClassTemplate("int", "double") str(MyClassIntDouble) -``` \ No newline at end of file +``` + +## Creating Variables and Expressions + +Type templates and types will not get us far on their own. +To use them in APIs, as function or constructor parameters, +or as class members and local objects, +we need to create *variables* with certain types. + +To do so, we need to inject our defined types into the expression framework of pystencils-sfg. +We wrap the type in an interface that allows us to create variables and, later, more complex expressions, +using {any}`lang.CppClass <pystencilssfg.lang.expressions.CppClass>`: + +```{code-cell} ipython3 +class MyClass(lang.CppClass): + template = lang.cpptype("my_namespace::MyClass< {T1}, {T2} >", "MyClass.hpp") +``` + +Instances of `MyClass` can now be created via constructor call, in the same way as above. +This gives us an unbound `MyClass` object, which we can bind to a variable name by calling `var` on it: + +```{code-cell} ipython3 +my_obj = MyClass(T1="int", T2="void").var("my_obj") +my_obj, str(my_obj.dtype) +``` + +## Reflecting C++ Class APIs + +In the previous section, we showed how to reflect a C++ class in pystencils-sfg in order to create +a variable representing an object of that class. +We can now extend this to reflect the public API of the class, in order to create complex expressions +involving objects of `MyClass` during code generation. + +### Public Methods + +Assume `MyClass` has the following public interface: + +```C++ +template< typename T1, typename T2 > +class MyClass { +public: + T1 & getA(); + std::tuple< T1, T2 > getBoth(); + + void replace(T1 a_new, T2 b_new); +} +``` + +We mirror this in our Python reflection of `CppClass` using methods that create `AugExpr` objects, +which represent C++ expressions annotated with variables they depend on. +A possible implementation might look like this: + +```{code-cell} ipython3 +--- +tags: [remove-cell] +--- + +class MyClass(lang.CppClass): + template = lang.cpptype("my_namespace::MyClass< {T1}, {T2} >", "MyClass.hpp") + + def ctor(self, a: lang.AugExpr, b: lang.AugExpr) -> MyClass: + return self.ctor_bind(a, b) + + def getA(self) -> lang.AugExpr: + return lang.AugExpr.format("{}.getA()", self) + + def getBoth(self) -> lang.AugExpr: + return lang.AugExpr.format("{}.getBoth()", self) + + def replace(self, a_new: lang.AugExpr, b_new: lang.AugExpr) -> lang.AugExpr: + return lang.AugExpr.format("{}.replace({}, {})", self, a_new, b_new) +``` + +```{code-block} python +class MyClass(lang.CppClass): + template = lang.cpptype("my_namespace::MyClass< {T1}, {T2} >", "MyClass.hpp") + + def getA(self) -> lang.AugExpr: + return lang.AugExpr.format("{}.getA()", self) + + def getBoth(self) -> lang.AugExpr: + return lang.AugExpr.format("{}.getBoth()", self) + + def replace(self, a_new: lang.AugExpr, b_new: lang.AugExpr) -> lang.AugExpr: + return lang.AugExpr.format("{}.replace({}, {})", self, a_new, b_new) +``` + +Each method of `MyClass` reflects a method of the same name in its public C++ API. +These methods do not return values, but *expressions*; +here, we use the generic `AugExpr` class to model expressions that we don't know anything +about except how they should be constructed. + +We create these expressions using `AugExpr.format`, which takes a format string +and interpolation arguments in the same way as `cpptype`. +Internally, it will analyze the format arguments (e.g. `self`, `a_new` and `b_new` in `replace`), +and combine information from any `AugExpr`s found among them. +These are: + - **Variables**: If any of the input expression depend on variables, the resulting expression will + depend on the union of all these variable sets + - **Headers**: If any of the input expression requires certain header files to be evaluated, + the resulting expression will require the same header files. + +We can see this in action by calling one of the methods on a variable of type `MyClass`: + +```{code-cell} ipython3 +my_obj = MyClass(T1="int", T2="void").var("my_obj") +expr = my_obj.getBoth() +expr, lang.depends(expr), lang.includes(expr) +``` + +We can see: the newly created expression `my_obj.getBoth()` depends on the variable `my_obj` and +requires the header `MyClass.hpp` to be included; this header it has inherited from `my_obj`. + +### Constructors + +Using the `AugExpr` system, we can also model constructors of `MyClass`. +Assume `MyClass` has the constructor `MyClass(T1 a, T2 b)`. +We implement this by adding a `ctor` method to our Python interface: + +```{code-block} python +class MyClass(lang.CppClass): + ... + + def ctor(self, a: lang.AugExpr, b: lang.AugExpr) -> MyClass: + return self.ctor_bind(a, b) +``` + +Here, we don't use `AugExpr.format`; instead, we use `ctor_bind`, which is exposed by `CppClass`. +This will generate the correct constructor invocation from the type of our `MyClass` object +and also ensure the headers required by `MyClass` are correctly attached to the resulting +expression: + +```{code-cell} ipython3 +a = lang.AugExpr("int").var("a") +b = lang.AugExpr("double").var("b") +expr = MyClass(T1="int", T2="double").ctor(a, b) +expr, lang.depends(expr), lang.includes(expr) +``` diff --git a/src/pystencilssfg/lang/__init__.py b/src/pystencilssfg/lang/__init__.py index 2ad8f9384be3e5f93e78b2504e2e90b730d345da..9218ec2b7d7f94517e35a2c9a8e4e4ddaa7c3a2a 100644 --- a/src/pystencilssfg/lang/__init__.py +++ b/src/pystencilssfg/lang/__init__.py @@ -10,6 +10,7 @@ from .expressions import ( asvar, depends, includes, + CppClass, IFieldExtraction, SrcField, SrcVector, @@ -32,7 +33,8 @@ __all__ = [ "SrcField", "SrcVector", "cpptype", + "CppClass", "void", "Ref", - "strip_ptr_ref" + "strip_ptr_ref", ] diff --git a/src/pystencilssfg/lang/expressions.py b/src/pystencilssfg/lang/expressions.py index 03818c6b481a5436deddb1a0266597e927e982eb..f86140ee7a3775caab19f69c34ef97822975b95e 100644 --- a/src/pystencilssfg/lang/expressions.py +++ b/src/pystencilssfg/lang/expressions.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Iterable, TypeAlias, Any +from typing import Iterable, TypeAlias, Any, cast from itertools import chain from abc import ABC, abstractmethod @@ -10,11 +10,12 @@ from pystencils.types import PsType, UserTypeSpec, create_type from ..exceptions import SfgException from .headers import HeaderFile -from .types import strip_ptr_ref, CppType +from .types import strip_ptr_ref, CppType, CppTypeFactory __all__ = [ "SfgVar", "AugExpr", + "CppClass", "VarLike", "ExprLike", "asvar", @@ -220,7 +221,13 @@ class AugExpr: """Create a new `AugExpr` by combining existing expressions.""" return AugExpr().bind(fmt, *deps, **kwdeps) - def bind(self, fmt: str | AugExpr, *deps, **kwdeps): + def bind( + self, + fmt: str | AugExpr, + *deps, + require_headers: Iterable[str | HeaderFile] = (), + **kwdeps, + ): """Bind an unbound `AugExpr` instance to an expression.""" if isinstance(fmt, AugExpr): if bool(deps) or bool(kwdeps): @@ -232,7 +239,7 @@ class AugExpr: self._bind(fmt._bound) else: dependencies: set[SfgVar] = set() - incls: set[HeaderFile] = set() + incls: set[HeaderFile] = set(HeaderFile.parse(h) for h in require_headers) from pystencils.sympyextensions import is_constant @@ -310,6 +317,40 @@ class AugExpr: return self._bound is not None +class CppClass(AugExpr): + """Convenience base class for C++ API mirroring. + + Example: + To reflect a C++ class (template) in pystencils-sfg, you may create a subclass + of `CppClass` like this: + + >>> class MyClassTemplate(CppClass): + ... template = lang.cpptype("mynamespace::MyClassTemplate< {T} >", "MyHeader.hpp") + + + Then use `AugExpr` initialization and binding to create variables or expressions with + this class: + + >>> var = MyClassTemplate(T="float").var("myObj") + >>> var + myObj + + >>> str(var.dtype).strip() + 'mynamespace::MyClassTemplate< float >' + """ + + template: CppTypeFactory + + def __init__(self, *args, const: bool = False, ref: bool = False, **kwargs): + dtype = self.template(*args, **kwargs, const=const, ref=ref) + super().__init__(dtype) + + def ctor_bind(self, *args): + fstr = self.get_dtype().c_string() + "{{" + ", ".join(["{}"] * len(args)) + "}}" + dtype = cast(CppType, self.get_dtype()) + return self.bind(fstr, *args, require_headers=dtype.includes) + + _VarLike = (AugExpr, SfgVar, TypedSymbol) VarLike: TypeAlias = AugExpr | SfgVar | TypedSymbol """Things that may act as a variable. @@ -428,16 +469,13 @@ class IFieldExtraction(ABC): from high-level data structures.""" @abstractmethod - def ptr(self) -> AugExpr: - ... + def ptr(self) -> AugExpr: ... @abstractmethod - def size(self, coordinate: int) -> AugExpr | None: - ... + def size(self, coordinate: int) -> AugExpr | None: ... @abstractmethod - def stride(self, coordinate: int) -> AugExpr | None: - ... + def stride(self, coordinate: int) -> AugExpr | None: ... class SrcField(AugExpr): @@ -448,8 +486,7 @@ class SrcField(AugExpr): """ @abstractmethod - def get_extraction(self) -> IFieldExtraction: - ... + def get_extraction(self) -> IFieldExtraction: ... class SrcVector(AugExpr, ABC): @@ -460,5 +497,4 @@ class SrcVector(AugExpr, ABC): """ @abstractmethod - def extract_component(self, coordinate: int) -> AugExpr: - ... + def extract_component(self, coordinate: int) -> AugExpr: ... diff --git a/tests/lang/test_expressions.py b/tests/lang/test_expressions.py index f7cbebdf2d07a6a87edaaf0774ba393cf2264e06..1a5700e815aba20ef125ec2baae78788acdf8d4f 100644 --- a/tests/lang/test_expressions.py +++ b/tests/lang/test_expressions.py @@ -1,11 +1,12 @@ import pytest from pystencilssfg import SfgException -from pystencilssfg.lang import asvar, SfgVar, AugExpr, cpptype, HeaderFile +from pystencilssfg.lang import asvar, SfgVar, AugExpr, cpptype, HeaderFile, CppClass import sympy as sp from pystencils import TypedSymbol, DynamicType +from pystencils.types import PsCustomType def test_asvar(): @@ -103,3 +104,17 @@ def test_headers(): expr = AugExpr().bind("std::get< int >({})", var) assert expr.includes == {HeaderFile("tuple", system_header=True)} + + +def test_cppclass(): + class MyClass(CppClass): + template = cpptype("mynamespace::MyClass< {T} >", "MyHeader.hpp") + + def ctor(self, arg: AugExpr): + return self.ctor_bind(arg) + + unbound = MyClass(T="bogus") + assert unbound.get_dtype() == MyClass.template(T="bogus") + + ctor_expr = unbound.ctor(AugExpr(PsCustomType("bogus")).var("foo")) + assert str(ctor_expr) == r"mynamespace::MyClass< bogus >{foo}"