From e4c257df0c0087103d5d638baaa75519d16462ce Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Tue, 11 Mar 2025 13:51:30 +0100 Subject: [PATCH] add lang.gpu to api docs --- docs/source/api/lang.rst | 6 ++++++ src/pystencilssfg/lang/gpu.py | 35 +++++++++++++++++++++++++---------- 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/docs/source/api/lang.rst b/docs/source/api/lang.rst index bd5e4fa..cdf8f24 100644 --- a/docs/source/api/lang.rst +++ b/docs/source/api/lang.rst @@ -41,3 +41,9 @@ Implementation .. automodule:: pystencilssfg.lang.cpp :members: + +GPU Runtime APIs +---------------- + +.. automodule:: pystencilssfg.lang.gpu + :members: diff --git a/src/pystencilssfg/lang/gpu.py b/src/pystencilssfg/lang/gpu.py index ccf86d9..0ca2a6d 100644 --- a/src/pystencilssfg/lang/gpu.py +++ b/src/pystencilssfg/lang/gpu.py @@ -5,38 +5,50 @@ from typing import Protocol from .expressions import CppClass, cpptype, AugExpr -class _Dim3Base(CppClass): +class Dim3Interface(CppClass): + """Interface definition for the ``dim3`` struct of Cuda and HIP.""" + def ctor(self, dim0=1, dim1=1, dim2=1): + """Constructor invocation of ``dim3``""" return self.ctor_bind(dim0, dim1, dim2) @property - def x(self): + def x(self) -> AugExpr: + """The `x` coordinate member.""" return AugExpr.format("{}.x", self) @property - def y(self): + def y(self) -> AugExpr: + """The `y` coordinate member.""" return AugExpr.format("{}.y", self) @property - def z(self): + def z(self) -> AugExpr: + """The `z` coordinate member.""" return AugExpr.format("{}.z", self) @property - def dims(self): - """The dims property.""" - return [self.x, self.y, self.z] + def dims(self) -> tuple[AugExpr, AugExpr, AugExpr]: + """`x`, `y`, and `z` as a tuple.""" + return (self.x, self.y, self.z) class ProvidesGpuRuntimeAPI(Protocol): + """Protocol definition for a GPU runtime API provider.""" - dim3: type[_Dim3Base] + dim3: type[Dim3Interface] + """The ``dim3`` struct type for this GPU runtime""" stream_t: type[AugExpr] + """The ``stream_t`` type for this GPU runtime""" class CudaAPI(ProvidesGpuRuntimeAPI): + """Reflection of the CUDA runtime API""" + + class dim3(Dim3Interface): + """Implements `Dim3Interface` for CUDA""" - class dim3(_Dim3Base): template = cpptype("dim3", "<cuda_runtime.h>") class stream_t(CppClass): @@ -44,8 +56,11 @@ class CudaAPI(ProvidesGpuRuntimeAPI): class HipAPI(ProvidesGpuRuntimeAPI): + """Reflection of the HIP runtime API""" + + class dim3(Dim3Interface): + """Implements `Dim3Interface` for HIP""" - class dim3(_Dim3Base): template = cpptype("dim3", "<hip/hip_runtime.h>") class stream_t(CppClass): -- GitLab