Skip to content
Snippets Groups Projects
Commit da71823f authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Add `CppClass` convenience class. Add user guide on using it.

parent fcfb59f1
No related branches found
No related tags found
1 merge request!12Improve versatility and robustness of `cpptype`, and document it in the user guide
Pipeline #71977 failed
...@@ -90,4 +90,141 @@ In particular, this means that we can also use unnamed, implicit positional para ...@@ -90,4 +90,141 @@ In particular, this means that we can also use unnamed, implicit positional para
MyClassTemplate = lang.cpptype("my_namespace::MyClass< {}, {} >", "MyClass.hpp") MyClassTemplate = lang.cpptype("my_namespace::MyClass< {}, {} >", "MyClass.hpp")
MyClassIntDouble = MyClassTemplate("int", "double") MyClassIntDouble = MyClassTemplate("int", "double")
str(MyClassIntDouble) str(MyClassIntDouble)
``` ```
\ No newline at end of file
## Creating Variables and Expressions
Type templates and types will not get us far on their own.
To use them in APIs, as function or constructor parameters,
or as class members and local objects,
we need to create *variables* with certain types.
To do so, we need to inject our defined types into the expression framework of pystencils-sfg.
We wrap the type in an interface that allows us to create variables and, later, more complex expressions,
using {any}`lang.CppClass <pystencilssfg.lang.expressions.CppClass>`:
```{code-cell} ipython3
class MyClass(lang.CppClass):
template = lang.cpptype("my_namespace::MyClass< {T1}, {T2} >", "MyClass.hpp")
```
Instances of `MyClass` can now be created via constructor call, in the same way as above.
This gives us an unbound `MyClass` object, which we can bind to a variable name by calling `var` on it:
```{code-cell} ipython3
my_obj = MyClass(T1="int", T2="void").var("my_obj")
my_obj, str(my_obj.dtype)
```
## Reflecting C++ Class APIs
In the previous section, we showed how to reflect a C++ class in pystencils-sfg in order to create
a variable representing an object of that class.
We can now extend this to reflect the public API of the class, in order to create complex expressions
involving objects of `MyClass` during code generation.
### Public Methods
Assume `MyClass` has the following public interface:
```C++
template< typename T1, typename T2 >
class MyClass {
public:
T1 & getA();
std::tuple< T1, T2 > getBoth();
void replace(T1 a_new, T2 b_new);
}
```
We mirror this in our Python reflection of `CppClass` using methods that create `AugExpr` objects,
which represent C++ expressions annotated with variables they depend on.
A possible implementation might look like this:
```{code-cell} ipython3
---
tags: [remove-cell]
---
class MyClass(lang.CppClass):
template = lang.cpptype("my_namespace::MyClass< {T1}, {T2} >", "MyClass.hpp")
def ctor(self, a: lang.AugExpr, b: lang.AugExpr) -> MyClass:
return self.ctor_bind(a, b)
def getA(self) -> lang.AugExpr:
return lang.AugExpr.format("{}.getA()", self)
def getBoth(self) -> lang.AugExpr:
return lang.AugExpr.format("{}.getBoth()", self)
def replace(self, a_new: lang.AugExpr, b_new: lang.AugExpr) -> lang.AugExpr:
return lang.AugExpr.format("{}.replace({}, {})", self, a_new, b_new)
```
```{code-block} python
class MyClass(lang.CppClass):
template = lang.cpptype("my_namespace::MyClass< {T1}, {T2} >", "MyClass.hpp")
def getA(self) -> lang.AugExpr:
return lang.AugExpr.format("{}.getA()", self)
def getBoth(self) -> lang.AugExpr:
return lang.AugExpr.format("{}.getBoth()", self)
def replace(self, a_new: lang.AugExpr, b_new: lang.AugExpr) -> lang.AugExpr:
return lang.AugExpr.format("{}.replace({}, {})", self, a_new, b_new)
```
Each method of `MyClass` reflects a method of the same name in its public C++ API.
These methods do not return values, but *expressions*;
here, we use the generic `AugExpr` class to model expressions that we don't know anything
about except how they should be constructed.
We create these expressions using `AugExpr.format`, which takes a format string
and interpolation arguments in the same way as `cpptype`.
Internally, it will analyze the format arguments (e.g. `self`, `a_new` and `b_new` in `replace`),
and combine information from any `AugExpr`s found among them.
These are:
- **Variables**: If any of the input expression depend on variables, the resulting expression will
depend on the union of all these variable sets
- **Headers**: If any of the input expression requires certain header files to be evaluated,
the resulting expression will require the same header files.
We can see this in action by calling one of the methods on a variable of type `MyClass`:
```{code-cell} ipython3
my_obj = MyClass(T1="int", T2="void").var("my_obj")
expr = my_obj.getBoth()
expr, lang.depends(expr), lang.includes(expr)
```
We can see: the newly created expression `my_obj.getBoth()` depends on the variable `my_obj` and
requires the header `MyClass.hpp` to be included; this header it has inherited from `my_obj`.
### Constructors
Using the `AugExpr` system, we can also model constructors of `MyClass`.
Assume `MyClass` has the constructor `MyClass(T1 a, T2 b)`.
We implement this by adding a `ctor` method to our Python interface:
```{code-block} python
class MyClass(lang.CppClass):
...
def ctor(self, a: lang.AugExpr, b: lang.AugExpr) -> MyClass:
return self.ctor_bind(a, b)
```
Here, we don't use `AugExpr.format`; instead, we use `ctor_bind`, which is exposed by `CppClass`.
This will generate the correct constructor invocation from the type of our `MyClass` object
and also ensure the headers required by `MyClass` are correctly attached to the resulting
expression:
```{code-cell} ipython3
a = lang.AugExpr("int").var("a")
b = lang.AugExpr("double").var("b")
expr = MyClass(T1="int", T2="double").ctor(a, b)
expr, lang.depends(expr), lang.includes(expr)
```
...@@ -10,6 +10,7 @@ from .expressions import ( ...@@ -10,6 +10,7 @@ from .expressions import (
asvar, asvar,
depends, depends,
includes, includes,
CppClass,
IFieldExtraction, IFieldExtraction,
SrcField, SrcField,
SrcVector, SrcVector,
...@@ -32,7 +33,8 @@ __all__ = [ ...@@ -32,7 +33,8 @@ __all__ = [
"SrcField", "SrcField",
"SrcVector", "SrcVector",
"cpptype", "cpptype",
"CppClass",
"void", "void",
"Ref", "Ref",
"strip_ptr_ref" "strip_ptr_ref",
] ]
from __future__ import annotations from __future__ import annotations
from typing import Iterable, TypeAlias, Any from typing import Iterable, TypeAlias, Any, cast
from itertools import chain from itertools import chain
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
...@@ -10,11 +10,12 @@ from pystencils.types import PsType, UserTypeSpec, create_type ...@@ -10,11 +10,12 @@ from pystencils.types import PsType, UserTypeSpec, create_type
from ..exceptions import SfgException from ..exceptions import SfgException
from .headers import HeaderFile from .headers import HeaderFile
from .types import strip_ptr_ref, CppType from .types import strip_ptr_ref, CppType, CppTypeFactory
__all__ = [ __all__ = [
"SfgVar", "SfgVar",
"AugExpr", "AugExpr",
"CppClass",
"VarLike", "VarLike",
"ExprLike", "ExprLike",
"asvar", "asvar",
...@@ -220,7 +221,13 @@ class AugExpr: ...@@ -220,7 +221,13 @@ class AugExpr:
"""Create a new `AugExpr` by combining existing expressions.""" """Create a new `AugExpr` by combining existing expressions."""
return AugExpr().bind(fmt, *deps, **kwdeps) return AugExpr().bind(fmt, *deps, **kwdeps)
def bind(self, fmt: str | AugExpr, *deps, **kwdeps): def bind(
self,
fmt: str | AugExpr,
*deps,
require_headers: Iterable[str | HeaderFile] = (),
**kwdeps,
):
"""Bind an unbound `AugExpr` instance to an expression.""" """Bind an unbound `AugExpr` instance to an expression."""
if isinstance(fmt, AugExpr): if isinstance(fmt, AugExpr):
if bool(deps) or bool(kwdeps): if bool(deps) or bool(kwdeps):
...@@ -232,7 +239,7 @@ class AugExpr: ...@@ -232,7 +239,7 @@ class AugExpr:
self._bind(fmt._bound) self._bind(fmt._bound)
else: else:
dependencies: set[SfgVar] = set() dependencies: set[SfgVar] = set()
incls: set[HeaderFile] = set() incls: set[HeaderFile] = set(HeaderFile.parse(h) for h in require_headers)
from pystencils.sympyextensions import is_constant from pystencils.sympyextensions import is_constant
...@@ -310,6 +317,40 @@ class AugExpr: ...@@ -310,6 +317,40 @@ class AugExpr:
return self._bound is not None return self._bound is not None
class CppClass(AugExpr):
"""Convenience base class for C++ API mirroring.
Example:
To reflect a C++ class (template) in pystencils-sfg, you may create a subclass
of `CppClass` like this:
>>> class MyClassTemplate(CppClass):
... template = lang.cpptype("mynamespace::MyClassTemplate< {T} >", "MyHeader.hpp")
Then use `AugExpr` initialization and binding to create variables or expressions with
this class:
>>> var = MyClassTemplate(T="float").var("myObj")
>>> var
myObj
>>> str(var.dtype).strip()
'mynamespace::MyClassTemplate< float >'
"""
template: CppTypeFactory
def __init__(self, *args, const: bool = False, ref: bool = False, **kwargs):
dtype = self.template(*args, **kwargs, const=const, ref=ref)
super().__init__(dtype)
def ctor_bind(self, *args):
fstr = self.get_dtype().c_string() + "{{" + ", ".join(["{}"] * len(args)) + "}}"
dtype = cast(CppType, self.get_dtype())
return self.bind(fstr, *args, require_headers=dtype.includes)
_VarLike = (AugExpr, SfgVar, TypedSymbol) _VarLike = (AugExpr, SfgVar, TypedSymbol)
VarLike: TypeAlias = AugExpr | SfgVar | TypedSymbol VarLike: TypeAlias = AugExpr | SfgVar | TypedSymbol
"""Things that may act as a variable. """Things that may act as a variable.
...@@ -428,16 +469,13 @@ class IFieldExtraction(ABC): ...@@ -428,16 +469,13 @@ class IFieldExtraction(ABC):
from high-level data structures.""" from high-level data structures."""
@abstractmethod @abstractmethod
def ptr(self) -> AugExpr: def ptr(self) -> AugExpr: ...
...
@abstractmethod @abstractmethod
def size(self, coordinate: int) -> AugExpr | None: def size(self, coordinate: int) -> AugExpr | None: ...
...
@abstractmethod @abstractmethod
def stride(self, coordinate: int) -> AugExpr | None: def stride(self, coordinate: int) -> AugExpr | None: ...
...
class SrcField(AugExpr): class SrcField(AugExpr):
...@@ -448,8 +486,7 @@ class SrcField(AugExpr): ...@@ -448,8 +486,7 @@ class SrcField(AugExpr):
""" """
@abstractmethod @abstractmethod
def get_extraction(self) -> IFieldExtraction: def get_extraction(self) -> IFieldExtraction: ...
...
class SrcVector(AugExpr, ABC): class SrcVector(AugExpr, ABC):
...@@ -460,5 +497,4 @@ class SrcVector(AugExpr, ABC): ...@@ -460,5 +497,4 @@ class SrcVector(AugExpr, ABC):
""" """
@abstractmethod @abstractmethod
def extract_component(self, coordinate: int) -> AugExpr: def extract_component(self, coordinate: int) -> AugExpr: ...
...
import pytest import pytest
from pystencilssfg import SfgException from pystencilssfg import SfgException
from pystencilssfg.lang import asvar, SfgVar, AugExpr, cpptype, HeaderFile from pystencilssfg.lang import asvar, SfgVar, AugExpr, cpptype, HeaderFile, CppClass
import sympy as sp import sympy as sp
from pystencils import TypedSymbol, DynamicType from pystencils import TypedSymbol, DynamicType
from pystencils.types import PsCustomType
def test_asvar(): def test_asvar():
...@@ -103,3 +104,17 @@ def test_headers(): ...@@ -103,3 +104,17 @@ def test_headers():
expr = AugExpr().bind("std::get< int >({})", var) expr = AugExpr().bind("std::get< int >({})", var)
assert expr.includes == {HeaderFile("tuple", system_header=True)} assert expr.includes == {HeaderFile("tuple", system_header=True)}
def test_cppclass():
class MyClass(CppClass):
template = cpptype("mynamespace::MyClass< {T} >", "MyHeader.hpp")
def ctor(self, arg: AugExpr):
return self.ctor_bind(arg)
unbound = MyClass(T="bogus")
assert unbound.get_dtype() == MyClass.template(T="bogus")
ctor_expr = unbound.ctor(AugExpr(PsCustomType("bogus")).var("foo"))
assert str(ctor_expr) == r"mynamespace::MyClass< bogus >{foo}"
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment