diff --git a/docs/source/api/lang.rst b/docs/source/api/lang.rst index bd5e4fa7993f30d4c8f6e5c532eb3d234c64061d..cdf8f242cf58b3f0d9889ad94cef6c7a8b7a8155 100644 --- a/docs/source/api/lang.rst +++ b/docs/source/api/lang.rst @@ -41,3 +41,9 @@ Implementation .. automodule:: pystencilssfg.lang.cpp :members: + +GPU Runtime APIs +---------------- + +.. automodule:: pystencilssfg.lang.gpu + :members: diff --git a/src/pystencilssfg/lang/gpu.py b/src/pystencilssfg/lang/gpu.py index ccf86d96bd74e242a5e11991b53233ac5580bd5f..0ca2a6dea3846800b7f719268ff537474ee045f7 100644 --- a/src/pystencilssfg/lang/gpu.py +++ b/src/pystencilssfg/lang/gpu.py @@ -5,38 +5,50 @@ from typing import Protocol from .expressions import CppClass, cpptype, AugExpr -class _Dim3Base(CppClass): +class Dim3Interface(CppClass): + """Interface definition for the ``dim3`` struct of Cuda and HIP.""" + def ctor(self, dim0=1, dim1=1, dim2=1): + """Constructor invocation of ``dim3``""" return self.ctor_bind(dim0, dim1, dim2) @property - def x(self): + def x(self) -> AugExpr: + """The `x` coordinate member.""" return AugExpr.format("{}.x", self) @property - def y(self): + def y(self) -> AugExpr: + """The `y` coordinate member.""" return AugExpr.format("{}.y", self) @property - def z(self): + def z(self) -> AugExpr: + """The `z` coordinate member.""" return AugExpr.format("{}.z", self) @property - def dims(self): - """The dims property.""" - return [self.x, self.y, self.z] + def dims(self) -> tuple[AugExpr, AugExpr, AugExpr]: + """`x`, `y`, and `z` as a tuple.""" + return (self.x, self.y, self.z) class ProvidesGpuRuntimeAPI(Protocol): + """Protocol definition for a GPU runtime API provider.""" - dim3: type[_Dim3Base] + dim3: type[Dim3Interface] + """The ``dim3`` struct type for this GPU runtime""" stream_t: type[AugExpr] + """The ``stream_t`` type for this GPU runtime""" class CudaAPI(ProvidesGpuRuntimeAPI): + """Reflection of the CUDA runtime API""" + + class dim3(Dim3Interface): + """Implements `Dim3Interface` for CUDA""" - class dim3(_Dim3Base): template = cpptype("dim3", "<cuda_runtime.h>") class stream_t(CppClass): @@ -44,8 +56,11 @@ class CudaAPI(ProvidesGpuRuntimeAPI): class HipAPI(ProvidesGpuRuntimeAPI): + """Reflection of the HIP runtime API""" + + class dim3(Dim3Interface): + """Implements `Dim3Interface` for HIP""" - class dim3(_Dim3Base): template = cpptype("dim3", "<hip/hip_runtime.h>") class stream_t(CppClass):