From 2e6077b1354cfedb54045b5ea5bc41ac05d40336 Mon Sep 17 00:00:00 2001
From: Christoph Alt <christoph.alt@fau.de>
Date: Thu, 13 Feb 2025 16:15:43 +0100
Subject: [PATCH 1/3] Added a class decorator for CppClasses

---
 src/pystencilssfg/lang/__init__.py    |  2 ++
 src/pystencilssfg/lang/expressions.py | 22 +++++++++++++++++++++-
 tests/lang/test_expressions.py        | 24 +++++++++++++++++++++++-
 3 files changed, 46 insertions(+), 2 deletions(-)

diff --git a/src/pystencilssfg/lang/__init__.py b/src/pystencilssfg/lang/__init__.py
index a8de86b..980d2bd 100644
--- a/src/pystencilssfg/lang/__init__.py
+++ b/src/pystencilssfg/lang/__init__.py
@@ -12,6 +12,7 @@ from .expressions import (
     depends,
     includes,
     CppClass,
+    cppclass,
     IFieldExtraction,
     SrcField,
     SrcVector,
@@ -36,6 +37,7 @@ __all__ = [
     "SrcVector",
     "cpptype",
     "CppClass",
+    "cppclass",
     "void",
     "Ref",
     "strip_ptr_ref",
diff --git a/src/pystencilssfg/lang/expressions.py b/src/pystencilssfg/lang/expressions.py
index 72287ea..63bda20 100644
--- a/src/pystencilssfg/lang/expressions.py
+++ b/src/pystencilssfg/lang/expressions.py
@@ -11,7 +11,7 @@ from pystencils.types import PsType, PsIntegerType, UserTypeSpec, create_type
 
 from ..exceptions import SfgException
 from .headers import HeaderFile
-from .types import strip_ptr_ref, CppType, CppTypeFactory
+from .types import strip_ptr_ref, CppType, CppTypeFactory, cpptype
 
 __all__ = [
     "SfgVar",
@@ -369,6 +369,26 @@ class CppClass(AugExpr):
         return self.bind(fstr, *args, require_headers=dtype.includes)
 
 
+def cppclass(
+    template_str: str, include: str | HeaderFile | Iterable[str | HeaderFile] = ()
+):
+    """
+    Convience class decorator for CppClass.
+    It adds to the decorate class the variable `template` via `lang.cpptype`
+    and sets `lang.CppClass` as a base clase.
+    >>> @cppclass("MyClass", "MyClass.hpp")
+    ...     class MyClass:
+    ...         pass
+    """
+
+    def wrapper(cls):
+        new_cls = type(cls.__name__, (cls, CppClass), {})
+        new_cls.template = cpptype(template_str, include)
+        return new_cls
+
+    return wrapper
+
+
 _VarLike = (AugExpr, SfgVar, TypedSymbol)
 VarLike: TypeAlias = AugExpr | SfgVar | TypedSymbol
 """Things that may act as a variable.
diff --git a/tests/lang/test_expressions.py b/tests/lang/test_expressions.py
index 04c1c98..94661e5 100644
--- a/tests/lang/test_expressions.py
+++ b/tests/lang/test_expressions.py
@@ -1,7 +1,15 @@
 import pytest
 
 from pystencilssfg import SfgException
-from pystencilssfg.lang import asvar, SfgVar, AugExpr, cpptype, HeaderFile, CppClass
+from pystencilssfg.lang import (
+    asvar,
+    SfgVar,
+    AugExpr,
+    cpptype,
+    HeaderFile,
+    CppClass,
+    cppclass,
+)
 
 import sympy as sp
 
@@ -118,3 +126,17 @@ def test_cppclass():
 
     ctor_expr = unbound.ctor(AugExpr(PsCustomType("bogus")).var("foo"))
     assert str(ctor_expr).strip() == r"mynamespace::MyClass< bogus >{foo}"
+
+
+def test_cppclass_decorator():
+
+    @cppclass("mynamespace::MyClass< {T} >", "MyHeader.hpp")
+    class MyClass(CppClass):
+        def ctor(self, arg: AugExpr):
+            return self.ctor_bind(arg)
+
+    unbound = MyClass(T="bogus")
+    assert unbound.get_dtype() == MyClass.template(T="bogus")
+
+    ctor_expr = unbound.ctor(AugExpr(PsCustomType("bogus")).var("foo"))
+    assert str(ctor_expr).strip() == r"mynamespace::MyClass< bogus >{foo}"
-- 
GitLab


From e67fb36e92b26d02650e7591b4398b596a548f95 Mon Sep 17 00:00:00 2001
From: Christoph Alt <christoph.alt@fau.de>
Date: Thu, 13 Feb 2025 16:16:33 +0100
Subject: [PATCH 2/3] added a dim3 implementation

---
 src/pystencilssfg/extensions/gpu.py | 37 +++++++++++++++++++++++++++++
 tests/extensions/test_gpu.py        | 26 ++++++++++++++++++++
 2 files changed, 63 insertions(+)
 create mode 100644 src/pystencilssfg/extensions/gpu.py
 create mode 100644 tests/extensions/test_gpu.py

diff --git a/src/pystencilssfg/extensions/gpu.py b/src/pystencilssfg/extensions/gpu.py
new file mode 100644
index 0000000..d2b2864
--- /dev/null
+++ b/src/pystencilssfg/extensions/gpu.py
@@ -0,0 +1,37 @@
+from pystencilssfg import lang
+
+
+def dim3class(gpu_runtime_header: str, *, cls_name: str = "dim3"):
+    """
+    >>> dim3 = dim3class("<hip/hip_runtime.h>")
+    >>> dim3.ctor(64, 1, 1)
+    'dim3{64, 1, 1}'
+    Args:
+        gpu_runtime_header: String with the name of the gpu runtime header
+        cls_name: String with the acutal name (default "dim3")
+    Returns:
+        Dim3Class: A `lang.CppClass` that mimics cuda's/hip's `dim3`
+    """
+    @lang.cppclass(cls_name, gpu_runtime_header)
+    class Dim3Class:
+        def ctor(self, dim0=1, dim1=1, dim2=1):
+            return self.ctor_bind(dim0, dim1, dim2)
+
+        @property
+        def x(self):
+            return lang.AugExpr.format("{}.x", self)
+
+        @property
+        def y(self):
+            return lang.AugExpr.format("{}.y", self)
+
+        @property
+        def z(self):
+            return lang.AugExpr.format("{}.z", self)
+
+        @property
+        def dims(self):
+            """The dims property."""
+            return [self.x, self.y, self.z]
+
+    return Dim3Class
diff --git a/tests/extensions/test_gpu.py b/tests/extensions/test_gpu.py
new file mode 100644
index 0000000..2e8d133
--- /dev/null
+++ b/tests/extensions/test_gpu.py
@@ -0,0 +1,26 @@
+from pystencilssfg.extensions.gpu import dim3class
+from pystencilssfg.lang import HeaderFile, AugExpr
+
+
+def test_dim3():
+    cuda_runtime = "<cuda_runtime.h>"
+    dim3 = dim3class(cuda_runtime, cls_name="dim3")
+    assert HeaderFile.parse(cuda_runtime) in dim3.template.includes
+    assert str(dim3().ctor(128, 1, 1)) == "dim3{128, 1, 1}"
+    assert str(dim3().ctor()) == "dim3{1, 1, 1}"
+    assert str(dim3().ctor(1, 1, 128)) == "dim3{1, 1, 128}"
+
+    block = dim3(ref=True, const=True).var("block")
+
+    dims = [
+        AugExpr.format(
+            "uint32_t(({} + {} - 1)/ {})",
+            1024,
+            block.dims[i],
+            block.dims[i],
+        )
+        for i in range(3)
+    ]
+
+    grid = dim3().ctor(*dims)
+    assert str(grid) == f"dim3{{{', '.join((str(d) for d in dims))}}}"
-- 
GitLab


From 80afd8c1ed830b48cb69837812728f7684dce88d Mon Sep 17 00:00:00 2001
From: Christoph Alt <christoph.alt@fau.de>
Date: Thu, 13 Feb 2025 16:48:55 +0100
Subject: [PATCH 3/3] fixed the doctests

---
 src/pystencilssfg/extensions/gpu.py   | 5 +++--
 src/pystencilssfg/lang/expressions.py | 4 ++--
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/src/pystencilssfg/extensions/gpu.py b/src/pystencilssfg/extensions/gpu.py
index d2b2864..b4242ac 100644
--- a/src/pystencilssfg/extensions/gpu.py
+++ b/src/pystencilssfg/extensions/gpu.py
@@ -4,8 +4,9 @@ from pystencilssfg import lang
 def dim3class(gpu_runtime_header: str, *, cls_name: str = "dim3"):
     """
     >>> dim3 = dim3class("<hip/hip_runtime.h>")
-    >>> dim3.ctor(64, 1, 1)
-    'dim3{64, 1, 1}'
+    >>> dim3().ctor(64, 1, 1)
+    dim3{64, 1, 1}
+
     Args:
         gpu_runtime_header: String with the name of the gpu runtime header
         cls_name: String with the acutal name (default "dim3")
diff --git a/src/pystencilssfg/lang/expressions.py b/src/pystencilssfg/lang/expressions.py
index 63bda20..d0b4978 100644
--- a/src/pystencilssfg/lang/expressions.py
+++ b/src/pystencilssfg/lang/expressions.py
@@ -377,8 +377,8 @@ def cppclass(
     It adds to the decorate class the variable `template` via `lang.cpptype`
     and sets `lang.CppClass` as a base clase.
     >>> @cppclass("MyClass", "MyClass.hpp")
-    ...     class MyClass:
-    ...         pass
+    ... class MyClass:
+    ...    pass
     """
 
     def wrapper(cls):
-- 
GitLab