diff --git a/src/pystencilssfg/emission/file_printer.py b/src/pystencilssfg/emission/file_printer.py index 8216a7b1923b8b23e19a4fbece8d507aa6260d27..6ab98eb89aba143dd567fc0c176199ff9d1444fe 100644 --- a/src/pystencilssfg/emission/file_printer.py +++ b/src/pystencilssfg/emission/file_printer.py @@ -154,7 +154,10 @@ class SfgFilePrinter: return code case SfgClassBody(cls, vblocks): - code = f"{cls.class_keyword} {cls.name} {{\n" + code = f"{cls.class_keyword} {cls.name}" + if cls.base_classes: + code += " : " + ", ".join(cls.base_classes) + code += " {\n" vblocks_str = [self._visibility_block(b) for b in vblocks] code += "\n\n".join(vblocks_str) code += "\n};\n" diff --git a/src/pystencilssfg/extensions/gpu.py b/src/pystencilssfg/extensions/gpu.py new file mode 100644 index 0000000000000000000000000000000000000000..b4242ac81aa166085dca20b4936cc56e1b829720 --- /dev/null +++ b/src/pystencilssfg/extensions/gpu.py @@ -0,0 +1,38 @@ +from pystencilssfg import lang + + +def dim3class(gpu_runtime_header: str, *, cls_name: str = "dim3"): + """ + >>> dim3 = dim3class("<hip/hip_runtime.h>") + >>> dim3().ctor(64, 1, 1) + dim3{64, 1, 1} + + Args: + gpu_runtime_header: String with the name of the gpu runtime header + cls_name: String with the acutal name (default "dim3") + Returns: + Dim3Class: A `lang.CppClass` that mimics cuda's/hip's `dim3` + """ + @lang.cppclass(cls_name, gpu_runtime_header) + class Dim3Class: + def ctor(self, dim0=1, dim1=1, dim2=1): + return self.ctor_bind(dim0, dim1, dim2) + + @property + def x(self): + return lang.AugExpr.format("{}.x", self) + + @property + def y(self): + return lang.AugExpr.format("{}.y", self) + + @property + def z(self): + return lang.AugExpr.format("{}.z", self) + + @property + def dims(self): + """The dims property.""" + return [self.x, self.y, self.z] + + return Dim3Class diff --git a/src/pystencilssfg/lang/__init__.py b/src/pystencilssfg/lang/__init__.py index a8de86be10ce44c2ac2d49cc3b5fba0e1549de50..980d2bdf9d543ce764e1f739926ee460b5fe9699 100644 --- a/src/pystencilssfg/lang/__init__.py +++ b/src/pystencilssfg/lang/__init__.py @@ -12,6 +12,7 @@ from .expressions import ( depends, includes, CppClass, + cppclass, IFieldExtraction, SrcField, SrcVector, @@ -36,6 +37,7 @@ __all__ = [ "SrcVector", "cpptype", "CppClass", + "cppclass", "void", "Ref", "strip_ptr_ref", diff --git a/src/pystencilssfg/lang/expressions.py b/src/pystencilssfg/lang/expressions.py index 72287eaad73afc08770d451c0847c362ef561519..d0b4978a49c9247eb57edb44c5568468d5a38442 100644 --- a/src/pystencilssfg/lang/expressions.py +++ b/src/pystencilssfg/lang/expressions.py @@ -11,7 +11,7 @@ from pystencils.types import PsType, PsIntegerType, UserTypeSpec, create_type from ..exceptions import SfgException from .headers import HeaderFile -from .types import strip_ptr_ref, CppType, CppTypeFactory +from .types import strip_ptr_ref, CppType, CppTypeFactory, cpptype __all__ = [ "SfgVar", @@ -369,6 +369,26 @@ class CppClass(AugExpr): return self.bind(fstr, *args, require_headers=dtype.includes) +def cppclass( + template_str: str, include: str | HeaderFile | Iterable[str | HeaderFile] = () +): + """ + Convience class decorator for CppClass. + It adds to the decorate class the variable `template` via `lang.cpptype` + and sets `lang.CppClass` as a base clase. + >>> @cppclass("MyClass", "MyClass.hpp") + ... class MyClass: + ... pass + """ + + def wrapper(cls): + new_cls = type(cls.__name__, (cls, CppClass), {}) + new_cls.template = cpptype(template_str, include) + return new_cls + + return wrapper + + _VarLike = (AugExpr, SfgVar, TypedSymbol) VarLike: TypeAlias = AugExpr | SfgVar | TypedSymbol """Things that may act as a variable. diff --git a/tests/extensions/test_gpu.py b/tests/extensions/test_gpu.py new file mode 100644 index 0000000000000000000000000000000000000000..2e8d133e80d0b8f9487e39290f4528aa1f11533e --- /dev/null +++ b/tests/extensions/test_gpu.py @@ -0,0 +1,26 @@ +from pystencilssfg.extensions.gpu import dim3class +from pystencilssfg.lang import HeaderFile, AugExpr + + +def test_dim3(): + cuda_runtime = "<cuda_runtime.h>" + dim3 = dim3class(cuda_runtime, cls_name="dim3") + assert HeaderFile.parse(cuda_runtime) in dim3.template.includes + assert str(dim3().ctor(128, 1, 1)) == "dim3{128, 1, 1}" + assert str(dim3().ctor()) == "dim3{1, 1, 1}" + assert str(dim3().ctor(1, 1, 128)) == "dim3{1, 1, 128}" + + block = dim3(ref=True, const=True).var("block") + + dims = [ + AugExpr.format( + "uint32_t(({} + {} - 1)/ {})", + 1024, + block.dims[i], + block.dims[i], + ) + for i in range(3) + ] + + grid = dim3().ctor(*dims) + assert str(grid) == f"dim3{{{', '.join((str(d) for d in dims))}}}" diff --git a/tests/generator_scripts/source/SimpleClasses.harness.cpp b/tests/generator_scripts/source/SimpleClasses.harness.cpp index ecfc065f2fc27423403dab1c0b03200af766a4d7..176f05e55d47f31b6100027e9ad1b87b4e5e63c9 100644 --- a/tests/generator_scripts/source/SimpleClasses.harness.cpp +++ b/tests/generator_scripts/source/SimpleClasses.harness.cpp @@ -7,4 +7,7 @@ int main(void){ Point p { 3, 1, -4 }; assert(p.getX() == 3); + + SpecialPoint q { 0, 1, 2 }; + assert(q.getY() == 1); } diff --git a/tests/generator_scripts/source/SimpleClasses.py b/tests/generator_scripts/source/SimpleClasses.py index 454f1a26f8103f7c8b330f1a5b70b1b79f96ebbc..26502f0e149c11d470e700269f5ff526aff3ce85 100644 --- a/tests/generator_scripts/source/SimpleClasses.py +++ b/tests/generator_scripts/source/SimpleClasses.py @@ -11,18 +11,19 @@ with SourceFileGenerator() as sfg: sfg.klass("Point")( sfg.public( - sfg.constructor(x, y, z) - .init(x_)(x) - .init(y_)(y) - .init(z_)(z), - + sfg.constructor(x, y, z).init(x_)(x).init(y_)(y).init(z_)(z), sfg.method("getX", returns="const int64_t", const=True, inline=True)( "return this->x_;" - ) + ), ), - sfg.private( - x_, - y_, - z_ + sfg.protected(x_, y_, z_), + ) + + sfg.klass("SpecialPoint", bases=["public Point"])( + sfg.public( + "using Point::Point;", + sfg.method("getY", returns="const int64_t", const=True, inline=True)( + "return this->y_;" + ), ) ) diff --git a/tests/lang/test_expressions.py b/tests/lang/test_expressions.py index 04c1c98ecd40841fb0ee47e883ff2308fa478bc7..94661e51634993dcf0c4b3cfad90cdb9ff243901 100644 --- a/tests/lang/test_expressions.py +++ b/tests/lang/test_expressions.py @@ -1,7 +1,15 @@ import pytest from pystencilssfg import SfgException -from pystencilssfg.lang import asvar, SfgVar, AugExpr, cpptype, HeaderFile, CppClass +from pystencilssfg.lang import ( + asvar, + SfgVar, + AugExpr, + cpptype, + HeaderFile, + CppClass, + cppclass, +) import sympy as sp @@ -118,3 +126,17 @@ def test_cppclass(): ctor_expr = unbound.ctor(AugExpr(PsCustomType("bogus")).var("foo")) assert str(ctor_expr).strip() == r"mynamespace::MyClass< bogus >{foo}" + + +def test_cppclass_decorator(): + + @cppclass("mynamespace::MyClass< {T} >", "MyHeader.hpp") + class MyClass(CppClass): + def ctor(self, arg: AugExpr): + return self.ctor_bind(arg) + + unbound = MyClass(T="bogus") + assert unbound.get_dtype() == MyClass.template(T="bogus") + + ctor_expr = unbound.ctor(AugExpr(PsCustomType("bogus")).var("foo")) + assert str(ctor_expr).strip() == r"mynamespace::MyClass< bogus >{foo}"