Skip to content
Snippets Groups Projects
Commit dcb77258 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Merge branch 'gpu_extensions' into 'master'

Adding Dim3 class

See merge request !19
parents 09d4e4c3 bd3da8b1
No related branches found
No related tags found
1 merge request!19Adding Dim3 class
Pipeline #74030 passed
from pystencilssfg import lang
def dim3class(gpu_runtime_header: str, *, cls_name: str = "dim3"):
"""
>>> dim3 = dim3class("<hip/hip_runtime.h>")
>>> dim3().ctor(64, 1, 1)
dim3{64, 1, 1}
Args:
gpu_runtime_header: String with the name of the gpu runtime header
cls_name: String with the acutal name (default "dim3")
Returns:
Dim3Class: A `lang.CppClass` that mimics cuda's/hip's `dim3`
"""
@lang.cppclass(cls_name, gpu_runtime_header)
class Dim3Class:
def ctor(self, dim0=1, dim1=1, dim2=1):
return self.ctor_bind(dim0, dim1, dim2)
@property
def x(self):
return lang.AugExpr.format("{}.x", self)
@property
def y(self):
return lang.AugExpr.format("{}.y", self)
@property
def z(self):
return lang.AugExpr.format("{}.z", self)
@property
def dims(self):
"""The dims property."""
return [self.x, self.y, self.z]
return Dim3Class
...@@ -12,6 +12,7 @@ from .expressions import ( ...@@ -12,6 +12,7 @@ from .expressions import (
depends, depends,
includes, includes,
CppClass, CppClass,
cppclass,
IFieldExtraction, IFieldExtraction,
SrcField, SrcField,
SrcVector, SrcVector,
...@@ -36,6 +37,7 @@ __all__ = [ ...@@ -36,6 +37,7 @@ __all__ = [
"SrcVector", "SrcVector",
"cpptype", "cpptype",
"CppClass", "CppClass",
"cppclass",
"void", "void",
"Ref", "Ref",
"strip_ptr_ref", "strip_ptr_ref",
......
...@@ -11,7 +11,7 @@ from pystencils.types import PsType, PsIntegerType, UserTypeSpec, create_type ...@@ -11,7 +11,7 @@ from pystencils.types import PsType, PsIntegerType, UserTypeSpec, create_type
from ..exceptions import SfgException from ..exceptions import SfgException
from .headers import HeaderFile from .headers import HeaderFile
from .types import strip_ptr_ref, CppType, CppTypeFactory from .types import strip_ptr_ref, CppType, CppTypeFactory, cpptype
__all__ = [ __all__ = [
"SfgVar", "SfgVar",
...@@ -369,6 +369,26 @@ class CppClass(AugExpr): ...@@ -369,6 +369,26 @@ class CppClass(AugExpr):
return self.bind(fstr, *args, require_headers=dtype.includes) return self.bind(fstr, *args, require_headers=dtype.includes)
def cppclass(
template_str: str, include: str | HeaderFile | Iterable[str | HeaderFile] = ()
):
"""
Convience class decorator for CppClass.
It adds to the decorate class the variable `template` via `lang.cpptype`
and sets `lang.CppClass` as a base clase.
>>> @cppclass("MyClass", "MyClass.hpp")
... class MyClass:
... pass
"""
def wrapper(cls):
new_cls = type(cls.__name__, (cls, CppClass), {})
new_cls.template = cpptype(template_str, include)
return new_cls
return wrapper
_VarLike = (AugExpr, SfgVar, TypedSymbol) _VarLike = (AugExpr, SfgVar, TypedSymbol)
VarLike: TypeAlias = AugExpr | SfgVar | TypedSymbol VarLike: TypeAlias = AugExpr | SfgVar | TypedSymbol
"""Things that may act as a variable. """Things that may act as a variable.
......
from pystencilssfg.extensions.gpu import dim3class
from pystencilssfg.lang import HeaderFile, AugExpr
def test_dim3():
cuda_runtime = "<cuda_runtime.h>"
dim3 = dim3class(cuda_runtime, cls_name="dim3")
assert HeaderFile.parse(cuda_runtime) in dim3.template.includes
assert str(dim3().ctor(128, 1, 1)) == "dim3{128, 1, 1}"
assert str(dim3().ctor()) == "dim3{1, 1, 1}"
assert str(dim3().ctor(1, 1, 128)) == "dim3{1, 1, 128}"
block = dim3(ref=True, const=True).var("block")
dims = [
AugExpr.format(
"uint32_t(({} + {} - 1)/ {})",
1024,
block.dims[i],
block.dims[i],
)
for i in range(3)
]
grid = dim3().ctor(*dims)
assert str(grid) == f"dim3{{{', '.join((str(d) for d in dims))}}}"
import pytest import pytest
from pystencilssfg import SfgException from pystencilssfg import SfgException
from pystencilssfg.lang import asvar, SfgVar, AugExpr, cpptype, HeaderFile, CppClass from pystencilssfg.lang import (
asvar,
SfgVar,
AugExpr,
cpptype,
HeaderFile,
CppClass,
cppclass,
)
import sympy as sp import sympy as sp
...@@ -118,3 +126,17 @@ def test_cppclass(): ...@@ -118,3 +126,17 @@ def test_cppclass():
ctor_expr = unbound.ctor(AugExpr(PsCustomType("bogus")).var("foo")) ctor_expr = unbound.ctor(AugExpr(PsCustomType("bogus")).var("foo"))
assert str(ctor_expr).strip() == r"mynamespace::MyClass< bogus >{foo}" assert str(ctor_expr).strip() == r"mynamespace::MyClass< bogus >{foo}"
def test_cppclass_decorator():
@cppclass("mynamespace::MyClass< {T} >", "MyHeader.hpp")
class MyClass(CppClass):
def ctor(self, arg: AugExpr):
return self.ctor_bind(arg)
unbound = MyClass(T="bogus")
assert unbound.get_dtype() == MyClass.template(T="bogus")
ctor_expr = unbound.ctor(AugExpr(PsCustomType("bogus")).var("foo"))
assert str(ctor_expr).strip() == r"mynamespace::MyClass< bogus >{foo}"
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment