Skip to content
Snippets Groups Projects
Commit 87224ea3 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Merge branch 'master' into fhennig/composer-improvements

parents dde18985 dcb77258
No related branches found
No related tags found
1 merge request!21Composer API Extensions and How-To Guide
...@@ -154,7 +154,10 @@ class SfgFilePrinter: ...@@ -154,7 +154,10 @@ class SfgFilePrinter:
return code return code
case SfgClassBody(cls, vblocks): case SfgClassBody(cls, vblocks):
code = f"{cls.class_keyword} {cls.name} {{\n" code = f"{cls.class_keyword} {cls.name}"
if cls.base_classes:
code += " : " + ", ".join(cls.base_classes)
code += " {\n"
vblocks_str = [self._visibility_block(b) for b in vblocks] vblocks_str = [self._visibility_block(b) for b in vblocks]
code += "\n\n".join(vblocks_str) code += "\n\n".join(vblocks_str)
code += "\n};\n" code += "\n};\n"
......
from pystencilssfg import lang
def dim3class(gpu_runtime_header: str, *, cls_name: str = "dim3"):
"""
>>> dim3 = dim3class("<hip/hip_runtime.h>")
>>> dim3().ctor(64, 1, 1)
dim3{64, 1, 1}
Args:
gpu_runtime_header: String with the name of the gpu runtime header
cls_name: String with the acutal name (default "dim3")
Returns:
Dim3Class: A `lang.CppClass` that mimics cuda's/hip's `dim3`
"""
@lang.cppclass(cls_name, gpu_runtime_header)
class Dim3Class:
def ctor(self, dim0=1, dim1=1, dim2=1):
return self.ctor_bind(dim0, dim1, dim2)
@property
def x(self):
return lang.AugExpr.format("{}.x", self)
@property
def y(self):
return lang.AugExpr.format("{}.y", self)
@property
def z(self):
return lang.AugExpr.format("{}.z", self)
@property
def dims(self):
"""The dims property."""
return [self.x, self.y, self.z]
return Dim3Class
...@@ -12,6 +12,7 @@ from .expressions import ( ...@@ -12,6 +12,7 @@ from .expressions import (
depends, depends,
includes, includes,
CppClass, CppClass,
cppclass,
IFieldExtraction, IFieldExtraction,
SrcField, SrcField,
SrcVector, SrcVector,
...@@ -36,6 +37,7 @@ __all__ = [ ...@@ -36,6 +37,7 @@ __all__ = [
"SrcVector", "SrcVector",
"cpptype", "cpptype",
"CppClass", "CppClass",
"cppclass",
"void", "void",
"Ref", "Ref",
"strip_ptr_ref", "strip_ptr_ref",
......
...@@ -11,7 +11,7 @@ from pystencils.types import PsType, PsIntegerType, UserTypeSpec, create_type ...@@ -11,7 +11,7 @@ from pystencils.types import PsType, PsIntegerType, UserTypeSpec, create_type
from ..exceptions import SfgException from ..exceptions import SfgException
from .headers import HeaderFile from .headers import HeaderFile
from .types import strip_ptr_ref, CppType, CppTypeFactory from .types import strip_ptr_ref, CppType, CppTypeFactory, cpptype
__all__ = [ __all__ = [
"SfgVar", "SfgVar",
...@@ -369,6 +369,26 @@ class CppClass(AugExpr): ...@@ -369,6 +369,26 @@ class CppClass(AugExpr):
return self.bind(fstr, *args, require_headers=dtype.includes) return self.bind(fstr, *args, require_headers=dtype.includes)
def cppclass(
template_str: str, include: str | HeaderFile | Iterable[str | HeaderFile] = ()
):
"""
Convience class decorator for CppClass.
It adds to the decorate class the variable `template` via `lang.cpptype`
and sets `lang.CppClass` as a base clase.
>>> @cppclass("MyClass", "MyClass.hpp")
... class MyClass:
... pass
"""
def wrapper(cls):
new_cls = type(cls.__name__, (cls, CppClass), {})
new_cls.template = cpptype(template_str, include)
return new_cls
return wrapper
_VarLike = (AugExpr, SfgVar, TypedSymbol) _VarLike = (AugExpr, SfgVar, TypedSymbol)
VarLike: TypeAlias = AugExpr | SfgVar | TypedSymbol VarLike: TypeAlias = AugExpr | SfgVar | TypedSymbol
"""Things that may act as a variable. """Things that may act as a variable.
......
from pystencilssfg.extensions.gpu import dim3class
from pystencilssfg.lang import HeaderFile, AugExpr
def test_dim3():
cuda_runtime = "<cuda_runtime.h>"
dim3 = dim3class(cuda_runtime, cls_name="dim3")
assert HeaderFile.parse(cuda_runtime) in dim3.template.includes
assert str(dim3().ctor(128, 1, 1)) == "dim3{128, 1, 1}"
assert str(dim3().ctor()) == "dim3{1, 1, 1}"
assert str(dim3().ctor(1, 1, 128)) == "dim3{1, 1, 128}"
block = dim3(ref=True, const=True).var("block")
dims = [
AugExpr.format(
"uint32_t(({} + {} - 1)/ {})",
1024,
block.dims[i],
block.dims[i],
)
for i in range(3)
]
grid = dim3().ctor(*dims)
assert str(grid) == f"dim3{{{', '.join((str(d) for d in dims))}}}"
...@@ -7,4 +7,7 @@ int main(void){ ...@@ -7,4 +7,7 @@ int main(void){
Point p { 3, 1, -4 }; Point p { 3, 1, -4 };
assert(p.getX() == 3); assert(p.getX() == 3);
SpecialPoint q { 0, 1, 2 };
assert(q.getY() == 1);
} }
...@@ -11,18 +11,19 @@ with SourceFileGenerator() as sfg: ...@@ -11,18 +11,19 @@ with SourceFileGenerator() as sfg:
sfg.klass("Point")( sfg.klass("Point")(
sfg.public( sfg.public(
sfg.constructor(x, y, z) sfg.constructor(x, y, z).init(x_)(x).init(y_)(y).init(z_)(z),
.init(x_)(x)
.init(y_)(y)
.init(z_)(z),
sfg.method("getX", returns="const int64_t", const=True, inline=True)( sfg.method("getX", returns="const int64_t", const=True, inline=True)(
"return this->x_;" "return this->x_;"
) ),
), ),
sfg.private( sfg.protected(x_, y_, z_),
x_, )
y_,
z_ sfg.klass("SpecialPoint", bases=["public Point"])(
sfg.public(
"using Point::Point;",
sfg.method("getY", returns="const int64_t", const=True, inline=True)(
"return this->y_;"
),
) )
) )
import pytest import pytest
from pystencilssfg import SfgException from pystencilssfg import SfgException
from pystencilssfg.lang import asvar, SfgVar, AugExpr, cpptype, HeaderFile, CppClass from pystencilssfg.lang import (
asvar,
SfgVar,
AugExpr,
cpptype,
HeaderFile,
CppClass,
cppclass,
)
import sympy as sp import sympy as sp
...@@ -118,3 +126,17 @@ def test_cppclass(): ...@@ -118,3 +126,17 @@ def test_cppclass():
ctor_expr = unbound.ctor(AugExpr(PsCustomType("bogus")).var("foo")) ctor_expr = unbound.ctor(AugExpr(PsCustomType("bogus")).var("foo"))
assert str(ctor_expr).strip() == r"mynamespace::MyClass< bogus >{foo}" assert str(ctor_expr).strip() == r"mynamespace::MyClass< bogus >{foo}"
def test_cppclass_decorator():
@cppclass("mynamespace::MyClass< {T} >", "MyHeader.hpp")
class MyClass(CppClass):
def ctor(self, arg: AugExpr):
return self.ctor_bind(arg)
unbound = MyClass(T="bogus")
assert unbound.get_dtype() == MyClass.template(T="bogus")
ctor_expr = unbound.ctor(AugExpr(PsCustomType("bogus")).var("foo"))
assert str(ctor_expr).strip() == r"mynamespace::MyClass< bogus >{foo}"
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment