Skip to content
Snippets Groups Projects
Commit bd3da8b1 authored by Christoph Alt's avatar Christoph Alt Committed by Frederik Hennig
Browse files

Adding Dim3 class

parent 09d4e4c3
No related branches found
No related tags found
1 merge request!19Adding Dim3 class
from pystencilssfg import lang
def dim3class(gpu_runtime_header: str, *, cls_name: str = "dim3"):
"""
>>> dim3 = dim3class("<hip/hip_runtime.h>")
>>> dim3().ctor(64, 1, 1)
dim3{64, 1, 1}
Args:
gpu_runtime_header: String with the name of the gpu runtime header
cls_name: String with the acutal name (default "dim3")
Returns:
Dim3Class: A `lang.CppClass` that mimics cuda's/hip's `dim3`
"""
@lang.cppclass(cls_name, gpu_runtime_header)
class Dim3Class:
def ctor(self, dim0=1, dim1=1, dim2=1):
return self.ctor_bind(dim0, dim1, dim2)
@property
def x(self):
return lang.AugExpr.format("{}.x", self)
@property
def y(self):
return lang.AugExpr.format("{}.y", self)
@property
def z(self):
return lang.AugExpr.format("{}.z", self)
@property
def dims(self):
"""The dims property."""
return [self.x, self.y, self.z]
return Dim3Class
......@@ -12,6 +12,7 @@ from .expressions import (
depends,
includes,
CppClass,
cppclass,
IFieldExtraction,
SrcField,
SrcVector,
......@@ -36,6 +37,7 @@ __all__ = [
"SrcVector",
"cpptype",
"CppClass",
"cppclass",
"void",
"Ref",
"strip_ptr_ref",
......
......@@ -11,7 +11,7 @@ from pystencils.types import PsType, PsIntegerType, UserTypeSpec, create_type
from ..exceptions import SfgException
from .headers import HeaderFile
from .types import strip_ptr_ref, CppType, CppTypeFactory
from .types import strip_ptr_ref, CppType, CppTypeFactory, cpptype
__all__ = [
"SfgVar",
......@@ -369,6 +369,26 @@ class CppClass(AugExpr):
return self.bind(fstr, *args, require_headers=dtype.includes)
def cppclass(
template_str: str, include: str | HeaderFile | Iterable[str | HeaderFile] = ()
):
"""
Convience class decorator for CppClass.
It adds to the decorate class the variable `template` via `lang.cpptype`
and sets `lang.CppClass` as a base clase.
>>> @cppclass("MyClass", "MyClass.hpp")
... class MyClass:
... pass
"""
def wrapper(cls):
new_cls = type(cls.__name__, (cls, CppClass), {})
new_cls.template = cpptype(template_str, include)
return new_cls
return wrapper
_VarLike = (AugExpr, SfgVar, TypedSymbol)
VarLike: TypeAlias = AugExpr | SfgVar | TypedSymbol
"""Things that may act as a variable.
......
from pystencilssfg.extensions.gpu import dim3class
from pystencilssfg.lang import HeaderFile, AugExpr
def test_dim3():
cuda_runtime = "<cuda_runtime.h>"
dim3 = dim3class(cuda_runtime, cls_name="dim3")
assert HeaderFile.parse(cuda_runtime) in dim3.template.includes
assert str(dim3().ctor(128, 1, 1)) == "dim3{128, 1, 1}"
assert str(dim3().ctor()) == "dim3{1, 1, 1}"
assert str(dim3().ctor(1, 1, 128)) == "dim3{1, 1, 128}"
block = dim3(ref=True, const=True).var("block")
dims = [
AugExpr.format(
"uint32_t(({} + {} - 1)/ {})",
1024,
block.dims[i],
block.dims[i],
)
for i in range(3)
]
grid = dim3().ctor(*dims)
assert str(grid) == f"dim3{{{', '.join((str(d) for d in dims))}}}"
import pytest
from pystencilssfg import SfgException
from pystencilssfg.lang import asvar, SfgVar, AugExpr, cpptype, HeaderFile, CppClass
from pystencilssfg.lang import (
asvar,
SfgVar,
AugExpr,
cpptype,
HeaderFile,
CppClass,
cppclass,
)
import sympy as sp
......@@ -118,3 +126,17 @@ def test_cppclass():
ctor_expr = unbound.ctor(AugExpr(PsCustomType("bogus")).var("foo"))
assert str(ctor_expr).strip() == r"mynamespace::MyClass< bogus >{foo}"
def test_cppclass_decorator():
@cppclass("mynamespace::MyClass< {T} >", "MyHeader.hpp")
class MyClass(CppClass):
def ctor(self, arg: AugExpr):
return self.ctor_bind(arg)
unbound = MyClass(T="bogus")
assert unbound.get_dtype() == MyClass.template(T="bogus")
ctor_expr = unbound.ctor(AugExpr(PsCustomType("bogus")).var("foo"))
assert str(ctor_expr).strip() == r"mynamespace::MyClass< bogus >{foo}"
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment