From 2e6077b1354cfedb54045b5ea5bc41ac05d40336 Mon Sep 17 00:00:00 2001 From: Christoph Alt <christoph.alt@fau.de> Date: Thu, 13 Feb 2025 16:15:43 +0100 Subject: [PATCH 1/3] Added a class decorator for CppClasses --- src/pystencilssfg/lang/__init__.py | 2 ++ src/pystencilssfg/lang/expressions.py | 22 +++++++++++++++++++++- tests/lang/test_expressions.py | 24 +++++++++++++++++++++++- 3 files changed, 46 insertions(+), 2 deletions(-) diff --git a/src/pystencilssfg/lang/__init__.py b/src/pystencilssfg/lang/__init__.py index a8de86b..980d2bd 100644 --- a/src/pystencilssfg/lang/__init__.py +++ b/src/pystencilssfg/lang/__init__.py @@ -12,6 +12,7 @@ from .expressions import ( depends, includes, CppClass, + cppclass, IFieldExtraction, SrcField, SrcVector, @@ -36,6 +37,7 @@ __all__ = [ "SrcVector", "cpptype", "CppClass", + "cppclass", "void", "Ref", "strip_ptr_ref", diff --git a/src/pystencilssfg/lang/expressions.py b/src/pystencilssfg/lang/expressions.py index 72287ea..63bda20 100644 --- a/src/pystencilssfg/lang/expressions.py +++ b/src/pystencilssfg/lang/expressions.py @@ -11,7 +11,7 @@ from pystencils.types import PsType, PsIntegerType, UserTypeSpec, create_type from ..exceptions import SfgException from .headers import HeaderFile -from .types import strip_ptr_ref, CppType, CppTypeFactory +from .types import strip_ptr_ref, CppType, CppTypeFactory, cpptype __all__ = [ "SfgVar", @@ -369,6 +369,26 @@ class CppClass(AugExpr): return self.bind(fstr, *args, require_headers=dtype.includes) +def cppclass( + template_str: str, include: str | HeaderFile | Iterable[str | HeaderFile] = () +): + """ + Convience class decorator for CppClass. + It adds to the decorate class the variable `template` via `lang.cpptype` + and sets `lang.CppClass` as a base clase. + >>> @cppclass("MyClass", "MyClass.hpp") + ... class MyClass: + ... pass + """ + + def wrapper(cls): + new_cls = type(cls.__name__, (cls, CppClass), {}) + new_cls.template = cpptype(template_str, include) + return new_cls + + return wrapper + + _VarLike = (AugExpr, SfgVar, TypedSymbol) VarLike: TypeAlias = AugExpr | SfgVar | TypedSymbol """Things that may act as a variable. diff --git a/tests/lang/test_expressions.py b/tests/lang/test_expressions.py index 04c1c98..94661e5 100644 --- a/tests/lang/test_expressions.py +++ b/tests/lang/test_expressions.py @@ -1,7 +1,15 @@ import pytest from pystencilssfg import SfgException -from pystencilssfg.lang import asvar, SfgVar, AugExpr, cpptype, HeaderFile, CppClass +from pystencilssfg.lang import ( + asvar, + SfgVar, + AugExpr, + cpptype, + HeaderFile, + CppClass, + cppclass, +) import sympy as sp @@ -118,3 +126,17 @@ def test_cppclass(): ctor_expr = unbound.ctor(AugExpr(PsCustomType("bogus")).var("foo")) assert str(ctor_expr).strip() == r"mynamespace::MyClass< bogus >{foo}" + + +def test_cppclass_decorator(): + + @cppclass("mynamespace::MyClass< {T} >", "MyHeader.hpp") + class MyClass(CppClass): + def ctor(self, arg: AugExpr): + return self.ctor_bind(arg) + + unbound = MyClass(T="bogus") + assert unbound.get_dtype() == MyClass.template(T="bogus") + + ctor_expr = unbound.ctor(AugExpr(PsCustomType("bogus")).var("foo")) + assert str(ctor_expr).strip() == r"mynamespace::MyClass< bogus >{foo}" -- GitLab From e67fb36e92b26d02650e7591b4398b596a548f95 Mon Sep 17 00:00:00 2001 From: Christoph Alt <christoph.alt@fau.de> Date: Thu, 13 Feb 2025 16:16:33 +0100 Subject: [PATCH 2/3] added a dim3 implementation --- src/pystencilssfg/extensions/gpu.py | 37 +++++++++++++++++++++++++++++ tests/extensions/test_gpu.py | 26 ++++++++++++++++++++ 2 files changed, 63 insertions(+) create mode 100644 src/pystencilssfg/extensions/gpu.py create mode 100644 tests/extensions/test_gpu.py diff --git a/src/pystencilssfg/extensions/gpu.py b/src/pystencilssfg/extensions/gpu.py new file mode 100644 index 0000000..d2b2864 --- /dev/null +++ b/src/pystencilssfg/extensions/gpu.py @@ -0,0 +1,37 @@ +from pystencilssfg import lang + + +def dim3class(gpu_runtime_header: str, *, cls_name: str = "dim3"): + """ + >>> dim3 = dim3class("<hip/hip_runtime.h>") + >>> dim3.ctor(64, 1, 1) + 'dim3{64, 1, 1}' + Args: + gpu_runtime_header: String with the name of the gpu runtime header + cls_name: String with the acutal name (default "dim3") + Returns: + Dim3Class: A `lang.CppClass` that mimics cuda's/hip's `dim3` + """ + @lang.cppclass(cls_name, gpu_runtime_header) + class Dim3Class: + def ctor(self, dim0=1, dim1=1, dim2=1): + return self.ctor_bind(dim0, dim1, dim2) + + @property + def x(self): + return lang.AugExpr.format("{}.x", self) + + @property + def y(self): + return lang.AugExpr.format("{}.y", self) + + @property + def z(self): + return lang.AugExpr.format("{}.z", self) + + @property + def dims(self): + """The dims property.""" + return [self.x, self.y, self.z] + + return Dim3Class diff --git a/tests/extensions/test_gpu.py b/tests/extensions/test_gpu.py new file mode 100644 index 0000000..2e8d133 --- /dev/null +++ b/tests/extensions/test_gpu.py @@ -0,0 +1,26 @@ +from pystencilssfg.extensions.gpu import dim3class +from pystencilssfg.lang import HeaderFile, AugExpr + + +def test_dim3(): + cuda_runtime = "<cuda_runtime.h>" + dim3 = dim3class(cuda_runtime, cls_name="dim3") + assert HeaderFile.parse(cuda_runtime) in dim3.template.includes + assert str(dim3().ctor(128, 1, 1)) == "dim3{128, 1, 1}" + assert str(dim3().ctor()) == "dim3{1, 1, 1}" + assert str(dim3().ctor(1, 1, 128)) == "dim3{1, 1, 128}" + + block = dim3(ref=True, const=True).var("block") + + dims = [ + AugExpr.format( + "uint32_t(({} + {} - 1)/ {})", + 1024, + block.dims[i], + block.dims[i], + ) + for i in range(3) + ] + + grid = dim3().ctor(*dims) + assert str(grid) == f"dim3{{{', '.join((str(d) for d in dims))}}}" -- GitLab From 80afd8c1ed830b48cb69837812728f7684dce88d Mon Sep 17 00:00:00 2001 From: Christoph Alt <christoph.alt@fau.de> Date: Thu, 13 Feb 2025 16:48:55 +0100 Subject: [PATCH 3/3] fixed the doctests --- src/pystencilssfg/extensions/gpu.py | 5 +++-- src/pystencilssfg/lang/expressions.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/pystencilssfg/extensions/gpu.py b/src/pystencilssfg/extensions/gpu.py index d2b2864..b4242ac 100644 --- a/src/pystencilssfg/extensions/gpu.py +++ b/src/pystencilssfg/extensions/gpu.py @@ -4,8 +4,9 @@ from pystencilssfg import lang def dim3class(gpu_runtime_header: str, *, cls_name: str = "dim3"): """ >>> dim3 = dim3class("<hip/hip_runtime.h>") - >>> dim3.ctor(64, 1, 1) - 'dim3{64, 1, 1}' + >>> dim3().ctor(64, 1, 1) + dim3{64, 1, 1} + Args: gpu_runtime_header: String with the name of the gpu runtime header cls_name: String with the acutal name (default "dim3") diff --git a/src/pystencilssfg/lang/expressions.py b/src/pystencilssfg/lang/expressions.py index 63bda20..d0b4978 100644 --- a/src/pystencilssfg/lang/expressions.py +++ b/src/pystencilssfg/lang/expressions.py @@ -377,8 +377,8 @@ def cppclass( It adds to the decorate class the variable `template` via `lang.cpptype` and sets `lang.CppClass` as a base clase. >>> @cppclass("MyClass", "MyClass.hpp") - ... class MyClass: - ... pass + ... class MyClass: + ... pass """ def wrapper(cls): -- GitLab