From 86c30d14bba61f3fa32467a8401c701029992478 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Mon, 20 Jan 2025 16:12:30 +0100
Subject: [PATCH] Improve versatility and robustness of `cpptype`, and document
 it in the user guide

---
 conftest.py                                   |   4 +-
 docs/source/conf.py                           |  19 +-
 docs/source/index.md                          |   1 +
 docs/source/usage/api_modelling.md            | 230 ++++++++++++++++++
 pyproject.toml                                |   2 +-
 src/pystencilssfg/composer/basic_composer.py  |  32 +--
 src/pystencilssfg/composer/class_composer.py  |   4 +-
 src/pystencilssfg/emission/printers.py        |   2 +-
 src/pystencilssfg/extensions/sycl.py          |   4 +-
 src/pystencilssfg/lang/__init__.py            |   4 +-
 src/pystencilssfg/lang/cpp/std_mdspan.py      |   7 +-
 src/pystencilssfg/lang/cpp/std_span.py        |   7 +-
 src/pystencilssfg/lang/cpp/std_tuple.py       |   7 +-
 src/pystencilssfg/lang/cpp/std_vector.py      |   6 +-
 src/pystencilssfg/lang/cpp/sycl_accessor.py   |   6 +-
 src/pystencilssfg/lang/expressions.py         |  72 ++++--
 src/pystencilssfg/lang/types.py               | 182 ++++++++++++--
 .../generator_scripts/source/Conditionals.py  |   3 +-
 .../generator_scripts/source/SimpleClasses.py |   2 +-
 tests/lang/test_expressions.py                |  17 +-
 tests/lang/test_types.py                      |  73 +++++-
 21 files changed, 567 insertions(+), 117 deletions(-)
 create mode 100644 docs/source/usage/api_modelling.md

diff --git a/conftest.py b/conftest.py
index 1c85902..661e722 100644
--- a/conftest.py
+++ b/conftest.py
@@ -3,13 +3,15 @@ from os import path
 
 
 @pytest.fixture(autouse=True)
-def prepare_composer(doctest_namespace):
+def prepare_doctest_namespace(doctest_namespace):
     from pystencilssfg import SfgContext, SfgComposer
+    from pystencilssfg import lang
 
     #   Place a composer object in the environment for doctests
 
     sfg = SfgComposer(SfgContext())
     doctest_namespace["sfg"] = sfg
+    doctest_namespace["lang"] = lang
 
 
 DATA_DIR = path.join(path.split(__file__)[0], "tests/data")
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 4bdf700..da6f4d7 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -24,7 +24,7 @@ html_title = f"pystencils-sfg v{version} Documentation"
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
 
 extensions = [
-    "myst_parser",
+    "myst_nb",
     "sphinx.ext.autodoc",
     "sphinx.ext.napoleon",
     "sphinx.ext.autosummary",
@@ -37,16 +37,8 @@ extensions = [
 
 templates_path = ["_templates"]
 exclude_patterns = []
-source_suffix = {
-    ".rst": "restructuredtext",
-    ".md": "markdown",
-}
 master_doc = "index"
 nitpicky = True
-myst_enable_extensions = [
-    "colon_fence",
-    "dollarmath"
-]
 
 
 # -- Options for HTML output -------------------------------------------------
@@ -89,6 +81,15 @@ sfg = SfgComposer(SfgContext())
 '''
 
 
+# -- Options for MyST / MyST-NB ----------------------------------------------
+
+nb_execution_mode = "cache"  # do not execute notebooks by default
+
+myst_enable_extensions = [
+    "dollarmath",
+    "colon_fence",
+]
+
 #   Prepare code generation examples
 
 def build_examples():
diff --git a/docs/source/index.md b/docs/source/index.md
index ca35b36..0cab083 100644
--- a/docs/source/index.md
+++ b/docs/source/index.md
@@ -6,6 +6,7 @@
 :caption: User Guide
 
 usage/generator_scripts
+C++ API Modelling <usage/api_modelling>
 usage/project_integration
 usage/tips_n_tricks
 ```
diff --git a/docs/source/usage/api_modelling.md b/docs/source/usage/api_modelling.md
new file mode 100644
index 0000000..60d37ae
--- /dev/null
+++ b/docs/source/usage/api_modelling.md
@@ -0,0 +1,230 @@
+---
+file_format: mystnb
+kernelspec:
+  name: python3
+---
+
+# Modelling C++ APIs in pystencils-sfg
+
+Pystencils-SFG is designed to help you generate C++ code that interfaces with pystencils on the one side,
+and with your handwritten code on the other side.
+This requires that the C++ classes and APIs of your framework or application be represented within the SFG system.
+This guide shows how you can use the facilities of the {any}`pystencilssfg.lang` module to model your C++ interfaces
+for use with the code generator.
+
+To begin, import the `lang` module:
+
+```{code-cell} ipython3
+from pystencilssfg import lang
+```
+
+## Defining C++ Types and Type Templates
+
+The first C++ entities that need to be mirrored for the SFGs are the types and type templates a library
+or application uses or exposes.
+
+### Non-Templated Types
+
+To define a C++ type, we use {any}`pystencilssfg.lang.cpptype <pystencilssfg.lang.types.cpptype>`:
+
+```{code-cell} ipython3
+MyClassTypeFactory = lang.cpptype("my_namespace::MyClass", "MyClass.hpp")
+MyClassTypeFactory
+```
+
+This defines two properties of the type: its fully qualified name, and the set of headers
+that need to be included when working with the type.
+Now, whenever this type occurs as the type of a variable given to pystencils-sfg,
+the code generator will make sure that `MyClass.hpp` is included into the respective
+generated code file.
+
+The object returned by `cpptype` is not the type itself, but a factory for instances of the type.
+Even as `MyClass` does not have any template parameters, we can create different instances of it:
+`const` and non-`const`, as well as references and non-references.
+We do this by calling the factory:
+
+```{code-cell} ipython3
+MyClass = MyClassTypeFactory()
+str(MyClass)
+```
+
+To produce a `const`-qualified version of the type:
+
+```{code-cell} ipython3
+MyClassConst = MyClassTypeFactory(const=True)
+str(MyClassConst)
+```
+
+And finally, to produce a reference instead:
+
+```{code-cell} ipython3
+MyClassRef = MyClassTypeFactory(ref=True)
+str(MyClassRef)
+```
+
+Of course, `const` and `ref` can also be combined to create a reference-to-const.
+
+### Types with Template Parameters
+
+We can add template parameters to our type by the use of
+[Python format strings](https://docs.python.org/3/library/string.html#formatstrings):
+
+```{code-cell} ipython3
+MyClassTemplate = lang.cpptype("my_namespace::MyClass< {T1}, {T2} >", "MyClass.hpp")
+MyClassTemplate
+```
+
+Here, the type parameters `T1` and `T2` are specified in braces.
+For them, values must be provided when calling the factory to instantiate the type:
+
+```{code-cell} ipython3
+MyClassIntDouble = MyClassTemplate(T1="int", T2="double")
+str(MyClassIntDouble)
+```
+
+The way type parameters are passed to the factory is identical to the behavior of {any}`str.format`,
+except that it does not support attribute or element accesses.
+In particular, this means that we can also use unnamed, implicit positional parameters:
+
+```{code-cell} ipython3
+MyClassTemplate = lang.cpptype("my_namespace::MyClass< {}, {} >", "MyClass.hpp")
+MyClassIntDouble = MyClassTemplate("int", "double")
+str(MyClassIntDouble)
+```
+
+## Creating Variables and Expressions
+
+Type templates and types will not get us far on their own.
+To use them in APIs, as function or constructor parameters,
+or as class members and local objects,
+we need to create *variables* with certain types.
+
+To do so, we need to inject our defined types into the expression framework of pystencils-sfg.
+We wrap the type in an interface that allows us to create variables and, later, more complex expressions,
+using {any}`lang.CppClass <pystencilssfg.lang.expressions.CppClass>`:
+
+```{code-cell} ipython3
+class MyClass(lang.CppClass):
+    template = lang.cpptype("my_namespace::MyClass< {T1}, {T2} >", "MyClass.hpp")
+```
+
+Instances of `MyClass` can now be created via constructor call, in the same way as above.
+This gives us an unbound `MyClass` object, which we can bind to a variable name by calling `var` on it:
+
+```{code-cell} ipython3
+my_obj = MyClass(T1="int", T2="void").var("my_obj")
+my_obj, str(my_obj.dtype)
+```
+
+## Reflecting C++ Class APIs
+
+In the previous section, we showed how to reflect a C++ class in pystencils-sfg in order to create
+a variable representing an object of that class.
+We can now extend this to reflect the public API of the class, in order to create complex expressions
+involving objects of `MyClass` during code generation.
+
+### Public Methods
+
+Assume `MyClass` has the following public interface:
+
+```C++
+template< typename T1, typename T2 >
+class MyClass {
+public:
+  T1 & getA();
+  std::tuple< T1, T2 > getBoth();
+
+  void replace(T1 a_new, T2 b_new);
+}
+```
+
+We mirror this in our Python reflection of `CppClass` using methods that create `AugExpr` objects,
+which represent C++ expressions annotated with variables they depend on.
+A possible implementation might look like this:
+
+```{code-cell} ipython3
+---
+tags: [remove-cell]
+---
+
+class MyClass(lang.CppClass):
+    template = lang.cpptype("my_namespace::MyClass< {T1}, {T2} >", "MyClass.hpp")
+
+    def ctor(self, a: lang.AugExpr, b: lang.AugExpr) -> MyClass:
+        return self.ctor_bind(a, b)
+
+    def getA(self) -> lang.AugExpr:
+        return lang.AugExpr.format("{}.getA()", self)
+
+    def getBoth(self) -> lang.AugExpr:
+        return lang.AugExpr.format("{}.getBoth()", self)
+
+    def replace(self, a_new: lang.AugExpr, b_new: lang.AugExpr) -> lang.AugExpr:
+        return lang.AugExpr.format("{}.replace({}, {})", self, a_new, b_new)
+```
+
+```{code-block} python
+class MyClass(lang.CppClass):
+    template = lang.cpptype("my_namespace::MyClass< {T1}, {T2} >", "MyClass.hpp")
+
+    def getA(self) -> lang.AugExpr:
+        return lang.AugExpr.format("{}.getA()", self)
+
+    def getBoth(self) -> lang.AugExpr:
+        return lang.AugExpr.format("{}.getBoth()", self)
+
+    def replace(self, a_new: lang.AugExpr, b_new: lang.AugExpr) -> lang.AugExpr:
+        return lang.AugExpr.format("{}.replace({}, {})", self, a_new, b_new)
+```
+
+Each method of `MyClass` reflects a method of the same name in its public C++ API.
+These methods do not return values, but *expressions*;
+here, we use the generic `AugExpr` class to model expressions that we don't know anything
+about except how they should be constructed.
+
+We create these expressions using `AugExpr.format`, which takes a format string
+and interpolation arguments in the same way as `cpptype`.
+Internally, it will analyze the format arguments (e.g. `self`, `a_new` and `b_new` in `replace`),
+and combine information from any `AugExpr`s found among them.
+These are:
+ - **Variables**: If any of the input expression depend on variables, the resulting expression will
+   depend on the union of all these variable sets
+ - **Headers**: If any of the input expression requires certain header files to be evaluated,
+   the resulting expression will require the same header files.
+
+We can see this in action by calling one of the methods on a variable of type `MyClass`:
+
+```{code-cell} ipython3
+my_obj = MyClass(T1="int", T2="void").var("my_obj")
+expr = my_obj.getBoth()
+expr, lang.depends(expr), lang.includes(expr)
+```
+
+We can see: the newly created expression `my_obj.getBoth()` depends on the variable `my_obj` and
+requires the header `MyClass.hpp` to be included; this header it has inherited from `my_obj`.
+
+### Constructors
+
+Using the `AugExpr` system, we can also model constructors of `MyClass`.
+Assume `MyClass` has the constructor `MyClass(T1 a, T2 b)`.
+We implement this by adding a `ctor` method to our Python interface:
+
+```{code-block} python
+class MyClass(lang.CppClass):
+    ...
+    
+    def ctor(self, a: lang.AugExpr, b: lang.AugExpr) -> MyClass:
+        return self.ctor_bind(a, b)
+```
+
+Here, we don't use `AugExpr.format`; instead, we use `ctor_bind`, which is exposed by `CppClass`.
+This will generate the correct constructor invocation from the type of our `MyClass` object
+and also ensure the headers required by `MyClass` are correctly attached to the resulting
+expression:
+
+```{code-cell} ipython3
+a = lang.AugExpr("int").var("a")
+b = lang.AugExpr("double").var("b")
+expr = MyClass(T1="int", T2="double").ctor(a, b)
+expr, lang.depends(expr), lang.includes(expr)
+```
diff --git a/pyproject.toml b/pyproject.toml
index da36a11..6ac0327 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,7 +40,7 @@ docs = [
     "sphinx",
     "pydata-sphinx-theme==0.15.4",
     "sphinx-book-theme==1.1.3",  # workaround for https://github.com/executablebooks/sphinx-book-theme/issues/865
-    "myst-parser",
+    "myst-nb",
     "sphinx_design",
     "sphinx_autodoc_typehints",
     "sphinx-copybutton",
diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py
index 7f1b58f..b96d559 100644
--- a/src/pystencilssfg/composer/basic_composer.py
+++ b/src/pystencilssfg/composer/basic_composer.py
@@ -7,13 +7,7 @@ from functools import reduce
 
 from pystencils import Field
 from pystencils.codegen import Kernel
-from pystencils.types import (
-    create_type,
-    UserTypeSpec,
-    PsCustomType,
-    PsPointerType,
-    PsType,
-)
+from pystencils.types import create_type, UserTypeSpec
 
 from ..context import SfgContext
 from .custom import CustomGenerator
@@ -325,30 +319,6 @@ class SfgBasicComposer(SfgIComposer):
         """Use inside a function body to require the inclusion of headers."""
         return SfgRequireIncludes((HeaderFile.parse(incl) for incl in incls))
 
-    def cpptype(
-        self,
-        typename: UserTypeSpec,
-        ptr: bool = False,
-        ref: bool = False,
-        const: bool = False,
-    ) -> PsType:
-        if ptr and ref:
-            raise SfgException("Create either a pointer, or a ref type, not both!")
-
-        ref_qual = "&" if ref else ""
-        try:
-            base_type = create_type(typename)
-        except ValueError:
-            if not isinstance(typename, str):
-                raise ValueError(f"Could not parse type: {typename}")
-
-            base_type = PsCustomType(typename + ref_qual, const=const)
-
-        if ptr:
-            return PsPointerType(base_type)
-        else:
-            return base_type
-
     def var(self, name: str, dtype: UserTypeSpec) -> AugExpr:
         """Create a variable with given name and data type."""
         return AugExpr(create_type(dtype)).var(name)
diff --git a/src/pystencilssfg/composer/class_composer.py b/src/pystencilssfg/composer/class_composer.py
index 1f4c486..489823b 100644
--- a/src/pystencilssfg/composer/class_composer.py
+++ b/src/pystencilssfg/composer/class_composer.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 from typing import Sequence
 
-from pystencils.types import PsCustomType, UserTypeSpec
+from pystencils.types import PsCustomType, UserTypeSpec, create_type
 
 from ..lang import (
     _VarLike,
@@ -177,7 +177,7 @@ class SfgClassComposer(SfgComposerMixIn):
             return SfgMethod(
                 name,
                 tree,
-                return_type=self._composer.cpptype(returns),
+                return_type=create_type(returns),
                 inline=inline,
                 const=const,
             )
diff --git a/src/pystencilssfg/emission/printers.py b/src/pystencilssfg/emission/printers.py
index adf7508..9d7c97e 100644
--- a/src/pystencilssfg/emission/printers.py
+++ b/src/pystencilssfg/emission/printers.py
@@ -166,7 +166,7 @@ class SfgHeaderPrinter(SfgGeneralPrinter):
 
     @visit.case(SfgMethod)
     def sfg_method(self, method: SfgMethod):
-        code = f"{method.return_type} {method.name} ({self.param_list(method)})"
+        code = f"{method.return_type.c_string()} {method.name} ({self.param_list(method)})"
         code += "const" if method.const else ""
         if method.inline:
             code += (
diff --git a/src/pystencilssfg/extensions/sycl.py b/src/pystencilssfg/extensions/sycl.py
index 349e030..88dbc9b 100644
--- a/src/pystencilssfg/extensions/sycl.py
+++ b/src/pystencilssfg/extensions/sycl.py
@@ -57,9 +57,7 @@ class SyclRange(AugExpr):
     _template = cpptype("sycl::range< {dims} >", "<sycl/sycl.hpp>")
 
     def __init__(self, dims: int, const: bool = False, ref: bool = False):
-        dtype = self._template(dims=dims, const=const)
-        if ref:
-            dtype = Ref(dtype)
+        dtype = self._template(dims=dims, const=const, ref=ref)
         super().__init__(dtype)
 
 
diff --git a/src/pystencilssfg/lang/__init__.py b/src/pystencilssfg/lang/__init__.py
index 2ad8f93..9218ec2 100644
--- a/src/pystencilssfg/lang/__init__.py
+++ b/src/pystencilssfg/lang/__init__.py
@@ -10,6 +10,7 @@ from .expressions import (
     asvar,
     depends,
     includes,
+    CppClass,
     IFieldExtraction,
     SrcField,
     SrcVector,
@@ -32,7 +33,8 @@ __all__ = [
     "SrcField",
     "SrcVector",
     "cpptype",
+    "CppClass",
     "void",
     "Ref",
-    "strip_ptr_ref"
+    "strip_ptr_ref",
 ]
diff --git a/src/pystencilssfg/lang/cpp/std_mdspan.py b/src/pystencilssfg/lang/cpp/std_mdspan.py
index 5f80ad6..6830885 100644
--- a/src/pystencilssfg/lang/cpp/std_mdspan.py
+++ b/src/pystencilssfg/lang/cpp/std_mdspan.py
@@ -11,7 +11,7 @@ from pystencils.types import (
 
 from pystencilssfg.lang.expressions import AugExpr
 
-from ...lang import SrcField, IFieldExtraction, cpptype, Ref, HeaderFile, ExprLike
+from ...lang import SrcField, IFieldExtraction, cpptype, HeaderFile, ExprLike
 
 
 class StdMdspan(SrcField):
@@ -111,11 +111,8 @@ class StdMdspan(SrcField):
             layout_policy = f"{self._namespace}::{layout_policy}"
 
         dtype = self._template(
-            T=T, extents=extents_str, layout_policy=layout_policy, const=const
+            T=T, extents=extents_str, layout_policy=layout_policy, const=const, ref=ref
         )
-
-        if ref:
-            dtype = Ref(dtype)
         super().__init__(dtype)
 
         self._extents_type = extents_str
diff --git a/src/pystencilssfg/lang/cpp/std_span.py b/src/pystencilssfg/lang/cpp/std_span.py
index 861a4c4..f161f48 100644
--- a/src/pystencilssfg/lang/cpp/std_span.py
+++ b/src/pystencilssfg/lang/cpp/std_span.py
@@ -1,7 +1,7 @@
 from pystencils.field import Field
 from pystencils.types import UserTypeSpec, create_type, PsType
 
-from ...lang import SrcField, IFieldExtraction, AugExpr, cpptype, Ref
+from ...lang import SrcField, IFieldExtraction, AugExpr, cpptype
 
 
 class StdSpan(SrcField):
@@ -9,10 +9,7 @@ class StdSpan(SrcField):
 
     def __init__(self, T: UserTypeSpec, ref=False, const=False):
         T = create_type(T)
-        dtype = self._template(T=T, const=const)
-        if ref:
-            dtype = Ref(dtype)
-
+        dtype = self._template(T=T, const=const, ref=ref)
         self._element_type = T
         super().__init__(dtype)
 
diff --git a/src/pystencilssfg/lang/cpp/std_tuple.py b/src/pystencilssfg/lang/cpp/std_tuple.py
index bbf2ba3..58a3530 100644
--- a/src/pystencilssfg/lang/cpp/std_tuple.py
+++ b/src/pystencilssfg/lang/cpp/std_tuple.py
@@ -2,7 +2,7 @@ from typing import Sequence
 
 from pystencils.types import UserTypeSpec, create_type
 
-from ...lang import SrcVector, AugExpr, cpptype, Ref
+from ...lang import SrcVector, AugExpr, cpptype
 
 
 class StdTuple(SrcVector):
@@ -18,10 +18,7 @@ class StdTuple(SrcVector):
         self._length = len(element_types)
         elt_type_strings = tuple(t.c_string() for t in self._element_types)
 
-        dtype = self._template(ts=", ".join(elt_type_strings), const=const)
-        if ref:
-            dtype = Ref(dtype)
-
+        dtype = self._template(ts=", ".join(elt_type_strings), const=const, ref=ref)
         super().__init__(dtype)
 
     def extract_component(self, coordinate: int) -> AugExpr:
diff --git a/src/pystencilssfg/lang/cpp/std_vector.py b/src/pystencilssfg/lang/cpp/std_vector.py
index 5696b32..7e9291e 100644
--- a/src/pystencilssfg/lang/cpp/std_vector.py
+++ b/src/pystencilssfg/lang/cpp/std_vector.py
@@ -1,7 +1,7 @@
 from pystencils.field import Field
 from pystencils.types import UserTypeSpec, create_type, PsType
 
-from ...lang import SrcField, SrcVector, AugExpr, IFieldExtraction, cpptype, Ref
+from ...lang import SrcField, SrcVector, AugExpr, IFieldExtraction, cpptype
 
 
 class StdVector(SrcVector, SrcField):
@@ -15,9 +15,7 @@ class StdVector(SrcVector, SrcField):
         const: bool = False,
     ):
         T = create_type(T)
-        dtype = self._template(T=T, const=const)
-        if ref:
-            dtype = Ref(dtype)
+        dtype = self._template(T=T, const=const, ref=ref)
         super().__init__(dtype)
 
         self._element_type = T
diff --git a/src/pystencilssfg/lang/cpp/sycl_accessor.py b/src/pystencilssfg/lang/cpp/sycl_accessor.py
index f01c53d..4bcad56 100644
--- a/src/pystencilssfg/lang/cpp/sycl_accessor.py
+++ b/src/pystencilssfg/lang/cpp/sycl_accessor.py
@@ -3,7 +3,7 @@ from ...lang import SrcField, IFieldExtraction
 from pystencils import Field
 from pystencils.types import UserTypeSpec, create_type
 
-from ...lang import AugExpr, cpptype, Ref
+from ...lang import AugExpr, cpptype
 
 
 class SyclAccessor(SrcField):
@@ -29,9 +29,7 @@ class SyclAccessor(SrcField):
         T = create_type(T)
         if dimensions > 3:
             raise ValueError("sycl accessors can only have dims 1, 2 or 3")
-        dtype = self._template(T=T, dims=dimensions, const=const)
-        if ref:
-            dtype = Ref(dtype)
+        dtype = self._template(T=T, dims=dimensions, const=const, ref=ref)
 
         super().__init__(dtype)
 
diff --git a/src/pystencilssfg/lang/expressions.py b/src/pystencilssfg/lang/expressions.py
index 53064bd..f86140e 100644
--- a/src/pystencilssfg/lang/expressions.py
+++ b/src/pystencilssfg/lang/expressions.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import Iterable, TypeAlias, Any
+from typing import Iterable, TypeAlias, Any, cast
 from itertools import chain
 from abc import ABC, abstractmethod
 
@@ -10,11 +10,12 @@ from pystencils.types import PsType, UserTypeSpec, create_type
 
 from ..exceptions import SfgException
 from .headers import HeaderFile
-from .types import strip_ptr_ref, CppType
+from .types import strip_ptr_ref, CppType, CppTypeFactory
 
 __all__ = [
     "SfgVar",
     "AugExpr",
+    "CppClass",
     "VarLike",
     "ExprLike",
     "asvar",
@@ -147,7 +148,7 @@ class VarExpr(DependentExpression):
         incls: Iterable[HeaderFile]
         match base_type:
             case CppType():
-                incls = base_type.includes
+                incls = base_type.class_includes
             case _:
                 incls = (
                     HeaderFile.parse(header) for header in var.dtype.required_headers
@@ -220,7 +221,13 @@ class AugExpr:
         """Create a new `AugExpr` by combining existing expressions."""
         return AugExpr().bind(fmt, *deps, **kwdeps)
 
-    def bind(self, fmt: str | AugExpr, *deps, **kwdeps):
+    def bind(
+        self,
+        fmt: str | AugExpr,
+        *deps,
+        require_headers: Iterable[str | HeaderFile] = (),
+        **kwdeps,
+    ):
         """Bind an unbound `AugExpr` instance to an expression."""
         if isinstance(fmt, AugExpr):
             if bool(deps) or bool(kwdeps):
@@ -232,7 +239,7 @@ class AugExpr:
             self._bind(fmt._bound)
         else:
             dependencies: set[SfgVar] = set()
-            incls: set[HeaderFile] = set()
+            incls: set[HeaderFile] = set(HeaderFile.parse(h) for h in require_headers)
 
             from pystencils.sympyextensions import is_constant
 
@@ -310,6 +317,40 @@ class AugExpr:
         return self._bound is not None
 
 
+class CppClass(AugExpr):
+    """Convenience base class for C++ API mirroring.
+
+    Example:
+        To reflect a C++ class (template) in pystencils-sfg, you may create a subclass
+        of `CppClass` like this:
+
+        >>> class MyClassTemplate(CppClass):
+        ...     template = lang.cpptype("mynamespace::MyClassTemplate< {T} >", "MyHeader.hpp")
+
+
+        Then use `AugExpr` initialization and binding to create variables or expressions with
+        this class:
+
+        >>> var = MyClassTemplate(T="float").var("myObj")
+        >>> var
+        myObj
+
+        >>> str(var.dtype).strip()
+        'mynamespace::MyClassTemplate< float >'
+    """
+
+    template: CppTypeFactory
+
+    def __init__(self, *args, const: bool = False, ref: bool = False, **kwargs):
+        dtype = self.template(*args, **kwargs, const=const, ref=ref)
+        super().__init__(dtype)
+
+    def ctor_bind(self, *args):
+        fstr = self.get_dtype().c_string() + "{{" + ", ".join(["{}"] * len(args)) + "}}"
+        dtype = cast(CppType, self.get_dtype())
+        return self.bind(fstr, *args, require_headers=dtype.includes)
+
+
 _VarLike = (AugExpr, SfgVar, TypedSymbol)
 VarLike: TypeAlias = AugExpr | SfgVar | TypedSymbol
 """Things that may act as a variable.
@@ -408,7 +449,11 @@ def includes(expr: ExprLike) -> set[HeaderFile]:
 
     match expr:
         case SfgVar(_, dtype):
-            return set(HeaderFile.parse(h) for h in dtype.required_headers)
+            match dtype:
+                case CppType():
+                    return set(dtype.includes)
+                case _:
+                    return set(HeaderFile.parse(h) for h in dtype.required_headers)
         case TypedSymbol():
             return includes(asvar(expr))
         case str():
@@ -424,16 +469,13 @@ class IFieldExtraction(ABC):
     from high-level data structures."""
 
     @abstractmethod
-    def ptr(self) -> AugExpr:
-        ...
+    def ptr(self) -> AugExpr: ...
 
     @abstractmethod
-    def size(self, coordinate: int) -> AugExpr | None:
-        ...
+    def size(self, coordinate: int) -> AugExpr | None: ...
 
     @abstractmethod
-    def stride(self, coordinate: int) -> AugExpr | None:
-        ...
+    def stride(self, coordinate: int) -> AugExpr | None: ...
 
 
 class SrcField(AugExpr):
@@ -444,8 +486,7 @@ class SrcField(AugExpr):
     """
 
     @abstractmethod
-    def get_extraction(self) -> IFieldExtraction:
-        ...
+    def get_extraction(self) -> IFieldExtraction: ...
 
 
 class SrcVector(AugExpr, ABC):
@@ -456,5 +497,4 @@ class SrcVector(AugExpr, ABC):
     """
 
     @abstractmethod
-    def extract_component(self, coordinate: int) -> AugExpr:
-        ...
+    def extract_component(self, coordinate: int) -> AugExpr: ...
diff --git a/src/pystencilssfg/lang/types.py b/src/pystencilssfg/lang/types.py
index b3a634f..71c198a 100644
--- a/src/pystencilssfg/lang/types.py
+++ b/src/pystencilssfg/lang/types.py
@@ -1,5 +1,11 @@
-from typing import Any, Iterable
+from __future__ import annotations
+from typing import Any, Iterable, Sequence, Mapping, TypeVar, Generic
 from abc import ABC
+from dataclasses import dataclass
+from itertools import chain
+
+import string
+
 from pystencils.types import PsType, PsPointerType, PsCustomType
 from .headers import HeaderFile
 
@@ -23,15 +29,163 @@ class VoidType(PsType):
 void = VoidType()
 
 
+class _TemplateArgFormatter(string.Formatter):
+
+    def format_field(self, arg, format_spec):
+        if isinstance(arg, PsType):
+            arg = arg.c_string()
+        return super().format_field(arg, format_spec)
+
+    def check_unused_args(
+        self, used_args: set[int | str], args: Sequence, kwargs: Mapping[str, Any]
+    ) -> None:
+        max_args_len: int = (
+            max((k for k in used_args if isinstance(k, int)), default=-1) + 1
+        )
+        if len(args) > max_args_len:
+            raise ValueError(
+                f"Too many positional arguments: Expected {max_args_len}, but got {len(args)}"
+            )
+
+        extra_keys = set(kwargs.keys()) - used_args  # type: ignore
+        if extra_keys:
+            raise ValueError(f"Extraneous keyword arguments: {extra_keys}")
+
+
+@dataclass(frozen=True)
+class _TemplateArgs:
+    pargs: tuple[Any, ...]
+    kwargs: tuple[tuple[str, Any], ...]
+
+
 class CppType(PsCustomType, ABC):
-    includes: frozenset[HeaderFile]
+    class_includes: frozenset[HeaderFile]
+    template_string: str
+
+    def __init__(self, *template_args, const: bool = False, **template_kwargs):
+        #   Support for cloning CppTypes
+        if template_args and isinstance(template_args[0], _TemplateArgs):
+            assert not template_kwargs
+            targs = template_args[0]
+            pargs = targs.pargs
+            kwargs = dict(targs.kwargs)
+        else:
+            pargs = template_args
+            kwargs = template_kwargs
+            targs = _TemplateArgs(
+                pargs, tuple(sorted(kwargs.items(), key=lambda t: t[0]))
+            )
+
+        formatter = _TemplateArgFormatter()
+        name = formatter.format(self.template_string, *pargs, **kwargs)
+
+        self._targs = targs
+        self._includes = self.class_includes
+
+        for arg in chain(pargs, kwargs.values()):
+            match arg:
+                case CppType():
+                    self._includes |= arg.includes
+                case PsType():
+                    self._includes |= {
+                        HeaderFile.parse(h) for h in arg.required_headers
+                    }
+
+        super().__init__(name, const=const)
+
+    def __args__(self) -> tuple[Any, ...]:
+        return (self._targs,)
+
+    @property
+    def includes(self) -> frozenset[HeaderFile]:
+        return self._includes
 
     @property
     def required_headers(self) -> set[str]:
-        return set(str(h) for h in self.includes)
+        return set(str(h) for h in self.class_includes)
 
 
-def cpptype(typestr: str, include: str | HeaderFile | Iterable[str | HeaderFile] = ()):
+TypeClass_T = TypeVar("TypeClass_T", bound=CppType)
+"""Python type variable bound to `CppType`."""
+
+
+class CppTypeFactory(Generic[TypeClass_T]):
+    """Type Factory returned by `cpptype`."""
+
+    def __init__(self, tclass: type[TypeClass_T]) -> None:
+        self._type_class = tclass
+
+    @property
+    def includes(self) -> frozenset[HeaderFile]:
+        """Set of headers required by this factory's type"""
+        return self._type_class.class_includes
+
+    @property
+    def template_string(self) -> str:
+        """Template string of this factory's type"""
+        return self._type_class.template_string
+
+    def __str__(self) -> str:
+        return f"Factory for {self.template_string}` defined in {self.includes}"
+
+    def __repr__(self) -> str:
+        return f"CppTypeFactory({self.template_string}, includes={{ {', '.join(str(i) for i in self.includes)} }})"
+
+    def __call__(self, *args, ref: bool = False, **kwargs) -> TypeClass_T | Ref:
+        """Create a type object of this factory's C++ type template.
+
+        Args:
+            args, kwargs: Positional and keyword arguments are forwarded to the template string formatter
+            ref: If ``True``, return a reference type
+
+        Returns:
+            An instantiated type object
+        """
+
+        obj = self._type_class(*args, **kwargs)
+        if ref:
+            return Ref(obj)
+        else:
+            return obj
+
+
+def cpptype(
+    template_str: str, include: str | HeaderFile | Iterable[str | HeaderFile] = ()
+) -> CppTypeFactory:
+    """Describe a C++ type template, associated with a set of required header files.
+
+    This function allows users to define C++ type templates using
+    `Python format string syntax <https://docs.python.org/3/library/string.html#formatstrings>`_.
+    The types may furthermore be annotated with a set of header files that must be included
+    in order to use the type.
+
+    >>> opt_template = lang.cpptype("std::optional< {T} >", "<optional>")
+    >>> opt_template.template_string
+    'std::optional< {T} >'
+
+    This function returns a `CppTypeFactory` object, which in turn can be called to create
+    an instance of the C++ type template.
+    Therein, the ``template_str`` argument is treated as a Python format string:
+    The positional and keyword arguments passed to the returned type factory are passed
+    through machinery that is based on `str.format` to produce the actual type name.
+
+    >>> int_option = opt_template(T="int")
+    >>> int_option.c_string().strip()
+    'std::optional< int >'
+
+    The factory may also create reference types when the ``ref=True`` is specified.
+
+    >>> int_option_ref = opt_template(T="int", ref=True)
+    >>> int_option_ref.c_string().strip()
+    'std::optional< int >&'
+
+    Args:
+        template_str: Format string defining the type template
+        include: Either the name of a header file, or a sequence of names of header files
+
+    Returns:
+        CppTypeFactory: A factory used to instantiate the type template
+    """
     headers: list[str | HeaderFile]
 
     if isinstance(include, (str, HeaderFile)):
@@ -41,25 +195,11 @@ def cpptype(typestr: str, include: str | HeaderFile | Iterable[str | HeaderFile]
     else:
         headers = list(include)
 
-    def _fixarg(template_arg):
-        if isinstance(template_arg, PsType):
-            return template_arg.c_string()
-        else:
-            return str(template_arg)
-
     class TypeClass(CppType):
-        includes = frozenset(HeaderFile.parse(h) for h in headers)
-
-        def __init__(self, *template_args, const: bool = False, **template_kwargs):
-            template_args = tuple(_fixarg(arg) for arg in template_args)
-            template_kwargs = {
-                key: _fixarg(value) for key, value in template_kwargs.items()
-            }
-
-            name = typestr.format(*template_args, **template_kwargs)
-            super().__init__(name, const)
+        template_string = template_str
+        class_includes = frozenset(HeaderFile.parse(h) for h in headers)
 
-    return TypeClass
+    return CppTypeFactory[TypeClass](TypeClass)
 
 
 class Ref(PsType):
diff --git a/tests/generator_scripts/source/Conditionals.py b/tests/generator_scripts/source/Conditionals.py
index d1088a9..9016b73 100644
--- a/tests/generator_scripts/source/Conditionals.py
+++ b/tests/generator_scripts/source/Conditionals.py
@@ -1,4 +1,5 @@
 from pystencilssfg import SourceFileGenerator
+from pystencils.types import PsCustomType
 
 with SourceFileGenerator() as sfg:
     sfg.namespace("gen")
@@ -6,7 +7,7 @@ with SourceFileGenerator() as sfg:
     sfg.include("<iostream>")
     sfg.code(r"enum class Noodles { RIGATONI, RAMEN, SPAETZLE, SPAGHETTI };")
 
-    noodle = sfg.var("noodle", sfg.cpptype("Noodles"))
+    noodle = sfg.var("noodle", PsCustomType("Noodles"))
 
     sfg.function("printOpinion")(
         sfg.switch(noodle)
diff --git a/tests/generator_scripts/source/SimpleClasses.py b/tests/generator_scripts/source/SimpleClasses.py
index 64093f5..454f1a2 100644
--- a/tests/generator_scripts/source/SimpleClasses.py
+++ b/tests/generator_scripts/source/SimpleClasses.py
@@ -16,7 +16,7 @@ with SourceFileGenerator() as sfg:
             .init(y_)(y)
             .init(z_)(z),
 
-            sfg.method("getX", returns="const int64_t &", const=True, inline=True)(
+            sfg.method("getX", returns="const int64_t", const=True, inline=True)(
                 "return this->x_;"
             )
         ),
diff --git a/tests/lang/test_expressions.py b/tests/lang/test_expressions.py
index f7cbebd..04c1c98 100644
--- a/tests/lang/test_expressions.py
+++ b/tests/lang/test_expressions.py
@@ -1,11 +1,12 @@
 import pytest
 
 from pystencilssfg import SfgException
-from pystencilssfg.lang import asvar, SfgVar, AugExpr, cpptype, HeaderFile
+from pystencilssfg.lang import asvar, SfgVar, AugExpr, cpptype, HeaderFile, CppClass
 
 import sympy as sp
 
 from pystencils import TypedSymbol, DynamicType
+from pystencils.types import PsCustomType
 
 
 def test_asvar():
@@ -103,3 +104,17 @@ def test_headers():
 
     expr = AugExpr().bind("std::get< int >({})", var)
     assert expr.includes == {HeaderFile("tuple", system_header=True)}
+
+
+def test_cppclass():
+    class MyClass(CppClass):
+        template = cpptype("mynamespace::MyClass< {T} >", "MyHeader.hpp")
+
+        def ctor(self, arg: AugExpr):
+            return self.ctor_bind(arg)
+
+    unbound = MyClass(T="bogus")
+    assert unbound.get_dtype() == MyClass.template(T="bogus")
+
+    ctor_expr = unbound.ctor(AugExpr(PsCustomType("bogus")).var("foo"))
+    assert str(ctor_expr).strip() == r"mynamespace::MyClass< bogus >{foo}"
diff --git a/tests/lang/test_types.py b/tests/lang/test_types.py
index 44b9a7c..62c4bb0 100644
--- a/tests/lang/test_types.py
+++ b/tests/lang/test_types.py
@@ -1,14 +1,77 @@
-from pystencilssfg.lang import cpptype, HeaderFile
+import pytest
+
+from pystencilssfg.lang import cpptype, HeaderFile, Ref, strip_ptr_ref
 from pystencils import create_type
+from pystencils.types import constify, deconstify
 
 
 def test_cpptypes():
-    tclass = cpptype("std::vector< {T}, {Allocator} >", "<vector>")
+    tfactory = cpptype("std::vector< {}, {} >", "<vector>")
 
-    vec_type = tclass(T=create_type("float32"), Allocator="std::allocator< float >")
+    vec_type = tfactory(create_type("float32"), "std::allocator< float >")
     assert str(vec_type).strip() == "std::vector< float, std::allocator< float > >"
     assert (
-        tclass.includes
-        == vec_type.includes
+        vec_type.includes
         == {HeaderFile("vector", system_header=True)}
     )
+
+    #   Cloning
+    assert deconstify(constify(vec_type)) == vec_type
+
+    #   Duplicate Equality
+    assert tfactory(create_type("float32"), "std::allocator< float >") == vec_type
+    #   Not equal with different argument even though it produces the same string
+    assert tfactory("float", "std::allocator< float >") != vec_type
+
+    #   The same with keyword arguments
+    tfactory = cpptype("std::vector< {T}, {Allocator} >", "<vector>")
+
+    vec_type = tfactory(T=create_type("float32"), Allocator="std::allocator< float >")
+    assert str(vec_type).strip() == "std::vector< float, std::allocator< float > >"
+
+    assert deconstify(constify(vec_type)) == vec_type
+
+
+def test_cpptype_invalid_construction():
+    tfactory = cpptype("std::vector< {}, {Allocator} >", "<vector>")
+
+    with pytest.raises(IndexError):
+        _ = tfactory(Allocator="SomeAlloc")
+
+    with pytest.raises(KeyError):
+        _ = tfactory("int")
+
+    with pytest.raises(ValueError, match="Too many positional arguments"):
+        _ = tfactory("int", "bogus", Allocator="SomeAlloc")
+
+    with pytest.raises(ValueError, match="Extraneous keyword arguments"):
+        _ = tfactory("int", Allocator="SomeAlloc", bogus=2)
+
+
+def test_cpptype_const():
+    tfactory = cpptype("std::vector< {T} >", "<vector>")
+
+    vec_type = tfactory(T=create_type("uint32"))
+    assert constify(vec_type) == tfactory(T=create_type("uint32"), const=True)
+
+    vec_type = tfactory(T=create_type("uint32"), const=True)
+    assert deconstify(vec_type) == tfactory(T=create_type("uint32"), const=False)
+
+
+def test_cpptype_ref():
+    tfactory = cpptype("std::vector< {T} >", "<vector>")
+
+    vec_type = tfactory(T=create_type("uint32"), ref=True)
+    assert isinstance(vec_type, Ref)
+    assert strip_ptr_ref(vec_type) == tfactory(T=create_type("uint32"))
+
+
+def test_cpptype_inherits_headers():
+    optional_tfactory = cpptype("std::optional< {T} >", "<optional>")
+    vec_tfactory = cpptype("std::vector< {T} >", "<vector>")
+
+    vec_type = vec_tfactory(T=optional_tfactory(T="int"))
+    assert vec_type.includes == {
+        HeaderFile.parse("<optional>"),
+        HeaderFile.parse("<vector>"),
+    }
-- 
GitLab