diff --git a/src/pystencilssfg/extensions/gpu.py b/src/pystencilssfg/extensions/gpu.py new file mode 100644 index 0000000000000000000000000000000000000000..b4242ac81aa166085dca20b4936cc56e1b829720 --- /dev/null +++ b/src/pystencilssfg/extensions/gpu.py @@ -0,0 +1,38 @@ +from pystencilssfg import lang + + +def dim3class(gpu_runtime_header: str, *, cls_name: str = "dim3"): + """ + >>> dim3 = dim3class("<hip/hip_runtime.h>") + >>> dim3().ctor(64, 1, 1) + dim3{64, 1, 1} + + Args: + gpu_runtime_header: String with the name of the gpu runtime header + cls_name: String with the acutal name (default "dim3") + Returns: + Dim3Class: A `lang.CppClass` that mimics cuda's/hip's `dim3` + """ + @lang.cppclass(cls_name, gpu_runtime_header) + class Dim3Class: + def ctor(self, dim0=1, dim1=1, dim2=1): + return self.ctor_bind(dim0, dim1, dim2) + + @property + def x(self): + return lang.AugExpr.format("{}.x", self) + + @property + def y(self): + return lang.AugExpr.format("{}.y", self) + + @property + def z(self): + return lang.AugExpr.format("{}.z", self) + + @property + def dims(self): + """The dims property.""" + return [self.x, self.y, self.z] + + return Dim3Class diff --git a/src/pystencilssfg/lang/__init__.py b/src/pystencilssfg/lang/__init__.py index a8de86be10ce44c2ac2d49cc3b5fba0e1549de50..980d2bdf9d543ce764e1f739926ee460b5fe9699 100644 --- a/src/pystencilssfg/lang/__init__.py +++ b/src/pystencilssfg/lang/__init__.py @@ -12,6 +12,7 @@ from .expressions import ( depends, includes, CppClass, + cppclass, IFieldExtraction, SrcField, SrcVector, @@ -36,6 +37,7 @@ __all__ = [ "SrcVector", "cpptype", "CppClass", + "cppclass", "void", "Ref", "strip_ptr_ref", diff --git a/src/pystencilssfg/lang/expressions.py b/src/pystencilssfg/lang/expressions.py index 72287eaad73afc08770d451c0847c362ef561519..d0b4978a49c9247eb57edb44c5568468d5a38442 100644 --- a/src/pystencilssfg/lang/expressions.py +++ b/src/pystencilssfg/lang/expressions.py @@ -11,7 +11,7 @@ from pystencils.types import PsType, PsIntegerType, UserTypeSpec, create_type from ..exceptions import SfgException from .headers import HeaderFile -from .types import strip_ptr_ref, CppType, CppTypeFactory +from .types import strip_ptr_ref, CppType, CppTypeFactory, cpptype __all__ = [ "SfgVar", @@ -369,6 +369,26 @@ class CppClass(AugExpr): return self.bind(fstr, *args, require_headers=dtype.includes) +def cppclass( + template_str: str, include: str | HeaderFile | Iterable[str | HeaderFile] = () +): + """ + Convience class decorator for CppClass. + It adds to the decorate class the variable `template` via `lang.cpptype` + and sets `lang.CppClass` as a base clase. + >>> @cppclass("MyClass", "MyClass.hpp") + ... class MyClass: + ... pass + """ + + def wrapper(cls): + new_cls = type(cls.__name__, (cls, CppClass), {}) + new_cls.template = cpptype(template_str, include) + return new_cls + + return wrapper + + _VarLike = (AugExpr, SfgVar, TypedSymbol) VarLike: TypeAlias = AugExpr | SfgVar | TypedSymbol """Things that may act as a variable. diff --git a/tests/extensions/test_gpu.py b/tests/extensions/test_gpu.py new file mode 100644 index 0000000000000000000000000000000000000000..2e8d133e80d0b8f9487e39290f4528aa1f11533e --- /dev/null +++ b/tests/extensions/test_gpu.py @@ -0,0 +1,26 @@ +from pystencilssfg.extensions.gpu import dim3class +from pystencilssfg.lang import HeaderFile, AugExpr + + +def test_dim3(): + cuda_runtime = "<cuda_runtime.h>" + dim3 = dim3class(cuda_runtime, cls_name="dim3") + assert HeaderFile.parse(cuda_runtime) in dim3.template.includes + assert str(dim3().ctor(128, 1, 1)) == "dim3{128, 1, 1}" + assert str(dim3().ctor()) == "dim3{1, 1, 1}" + assert str(dim3().ctor(1, 1, 128)) == "dim3{1, 1, 128}" + + block = dim3(ref=True, const=True).var("block") + + dims = [ + AugExpr.format( + "uint32_t(({} + {} - 1)/ {})", + 1024, + block.dims[i], + block.dims[i], + ) + for i in range(3) + ] + + grid = dim3().ctor(*dims) + assert str(grid) == f"dim3{{{', '.join((str(d) for d in dims))}}}" diff --git a/tests/lang/test_expressions.py b/tests/lang/test_expressions.py index 04c1c98ecd40841fb0ee47e883ff2308fa478bc7..94661e51634993dcf0c4b3cfad90cdb9ff243901 100644 --- a/tests/lang/test_expressions.py +++ b/tests/lang/test_expressions.py @@ -1,7 +1,15 @@ import pytest from pystencilssfg import SfgException -from pystencilssfg.lang import asvar, SfgVar, AugExpr, cpptype, HeaderFile, CppClass +from pystencilssfg.lang import ( + asvar, + SfgVar, + AugExpr, + cpptype, + HeaderFile, + CppClass, + cppclass, +) import sympy as sp @@ -118,3 +126,17 @@ def test_cppclass(): ctor_expr = unbound.ctor(AugExpr(PsCustomType("bogus")).var("foo")) assert str(ctor_expr).strip() == r"mynamespace::MyClass< bogus >{foo}" + + +def test_cppclass_decorator(): + + @cppclass("mynamespace::MyClass< {T} >", "MyHeader.hpp") + class MyClass(CppClass): + def ctor(self, arg: AugExpr): + return self.ctor_bind(arg) + + unbound = MyClass(T="bogus") + assert unbound.get_dtype() == MyClass.template(T="bogus") + + ctor_expr = unbound.ctor(AugExpr(PsCustomType("bogus")).var("foo")) + assert str(ctor_expr).strip() == r"mynamespace::MyClass< bogus >{foo}"