Coverage for src/pystencilssfg/lang/gpu.py: 95%

37 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-04 07:16 +0000

1from __future__ import annotations 

2 

3from typing import Protocol 

4 

5from .expressions import CppClass, cpptype, AugExpr 

6 

7 

class Dim3Interface(CppClass):
    """Shared interface of the ``dim3`` struct found in both CUDA and HIP.

    Runtime-specific subclasses supply the concrete C++ type through their
    ``template`` attribute; this class only adds a constructor helper and
    accessors for the three coordinate members.
    """

    def ctor(self, dim0=1, dim1=1, dim2=1):
        """Build a ``dim3`` constructor invocation from three extents.

        Any extent that is not given defaults to ``1``; the three values are
        forwarded, in order, to ``ctor_bind``.
        """
        extents = (dim0, dim1, dim2)
        return self.ctor_bind(*extents)

    @property
    def x(self) -> AugExpr:
        """Expression accessing the ``x`` member of this ``dim3``."""
        member_expr = AugExpr.format("{}.x", self)
        return member_expr

    @property
    def y(self) -> AugExpr:
        """Expression accessing the ``y`` member of this ``dim3``."""
        member_expr = AugExpr.format("{}.y", self)
        return member_expr

    @property
    def z(self) -> AugExpr:
        """Expression accessing the ``z`` member of this ``dim3``."""
        member_expr = AugExpr.format("{}.z", self)
        return member_expr

    @property
    def dims(self) -> tuple[AugExpr, AugExpr, AugExpr]:
        """The ``x``, ``y`` and ``z`` member expressions, in that order."""
        x_expr, y_expr, z_expr = self.x, self.y, self.z
        return x_expr, y_expr, z_expr

34 

35 

class ProvidesGpuRuntimeAPI(Protocol):
    """Protocol definition for a GPU runtime API provider.

    An implementation (such as `CudaAPI` or `HipAPI` in this module) exposes
    the runtime's ``dim3`` struct type and its stream handle type as the
    attributes ``dim3`` and ``stream_t``.
    """

    # Reflected ``dim3`` type of the runtime; a subclass of `Dim3Interface`.
    dim3: type[Dim3Interface]
    """The ``dim3`` struct type for this GPU runtime"""

    # Reflected stream handle type of the runtime (the implementations in this
    # module bind it to ``cudaStream_t`` and ``hipStream_t`` respectively).
    stream_t: type[AugExpr]
    """The ``stream_t`` type for this GPU runtime"""

44 

45 

class CudaAPI(ProvidesGpuRuntimeAPI):
    """Reflection of the CUDA runtime API.

    Provides the CUDA ``dim3`` struct and the CUDA stream handle type, both
    associated with the ``<cuda_runtime.h>`` header.
    """

    class dim3(Dim3Interface):
        """Implements `Dim3Interface` for CUDA"""

        template = cpptype("dim3", "<cuda_runtime.h>")

    class stream_t(CppClass):
        """Reflection of the CUDA stream handle type ``cudaStream_t``"""

        template = cpptype("cudaStream_t", "<cuda_runtime.h>")


cuda = CudaAPI
"""Alias for `CudaAPI`"""

60 

61 

class HipAPI(ProvidesGpuRuntimeAPI):
    """Reflection of the HIP runtime API.

    Provides the HIP ``dim3`` struct and the HIP stream handle type, both
    associated with the ``<hip/hip_runtime.h>`` header.
    """

    class dim3(Dim3Interface):
        """Implements `Dim3Interface` for HIP"""

        template = cpptype("dim3", "<hip/hip_runtime.h>")

    class stream_t(CppClass):
        """Reflection of the HIP stream handle type ``hipStream_t``"""

        template = cpptype("hipStream_t", "<hip/hip_runtime.h>")


hip = HipAPI
"""Alias for `HipAPI`"""