Coverage for src/pystencilssfg/lang/gpu.py: 95%
37 statements
Report generated by coverage.py v7.8.0 at 2025-04-04 07:16 +0000
1from __future__ import annotations
3from typing import Protocol
5from .expressions import CppClass, cpptype, AugExpr
class Dim3Interface(CppClass):
    """Shared interface of the ``dim3`` struct found in both Cuda and HIP.

    The runtime-specific ``dim3`` reflections in this module derive from this
    class and only supply their own ``template``.
    """

    def ctor(self, dim0=1, dim1=1, dim2=1):
        """Produce a ``dim3`` constructor invocation.

        Args:
            dim0: First constructor argument (``x``); ``1`` when omitted.
            dim1: Second constructor argument (``y``); ``1`` when omitted.
            dim2: Third constructor argument (``z``); ``1`` when omitted.

        Returns:
            The result of ``ctor_bind`` applied to the three arguments.
        """
        components = (dim0, dim1, dim2)
        return self.ctor_bind(*components)

    @property
    def x(self) -> AugExpr:
        """Expression ``<self>.x`` reading the `x` member."""
        return AugExpr.format("{}.x", self)

    @property
    def y(self) -> AugExpr:
        """Expression ``<self>.y`` reading the `y` member."""
        return AugExpr.format("{}.y", self)

    @property
    def z(self) -> AugExpr:
        """Expression ``<self>.z`` reading the `z` member."""
        return AugExpr.format("{}.z", self)

    @property
    def dims(self) -> tuple[AugExpr, AugExpr, AugExpr]:
        """The member expressions `x`, `y` and `z`, returned as a tuple in that order."""
        x_expr, y_expr, z_expr = self.x, self.y, self.z
        return x_expr, y_expr, z_expr
class ProvidesGpuRuntimeAPI(Protocol):
    """Structural interface implemented by every GPU runtime API reflection.

    An implementation exposes two attributes: the runtime's ``dim3`` struct
    type and its ``stream_t`` type.
    """

    #: The ``dim3`` struct type of this GPU runtime.
    dim3: type[Dim3Interface]

    #: The ``stream_t`` type of this GPU runtime.
    stream_t: type[AugExpr]
class CudaAPI(ProvidesGpuRuntimeAPI):
    """Reflection of the CUDA runtime API.

    Each type reflected here is associated with the ``<cuda_runtime.h>`` header.
    """

    class dim3(Dim3Interface):
        """CUDA's ``dim3`` struct; the CUDA implementation of `Dim3Interface`."""

        template = cpptype("dim3", "<cuda_runtime.h>")

    class stream_t(CppClass):
        """CUDA's stream type, ``cudaStream_t``."""

        template = cpptype("cudaStream_t", "<cuda_runtime.h>")


#: Short alias for `CudaAPI`.
cuda = CudaAPI
class HipAPI(ProvidesGpuRuntimeAPI):
    """Reflection of the HIP runtime API.

    Each type reflected here is associated with the ``<hip/hip_runtime.h>`` header.
    """

    class dim3(Dim3Interface):
        """HIP's ``dim3`` struct; the HIP implementation of `Dim3Interface`."""

        template = cpptype("dim3", "<hip/hip_runtime.h>")

    class stream_t(CppClass):
        """HIP's stream type, ``hipStream_t``."""

        template = cpptype("hipStream_t", "<hip/hip_runtime.h>")


#: Short alias for `HipAPI`.
hip = HipAPI