Coverage for src/pystencilssfg/extensions/sycl.py: 80%

153 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-04 07:16 +0000

1from __future__ import annotations 

2from typing import Sequence 

3from enum import Enum 

4import re 

5 

6from pystencils.types import UserTypeSpec, PsType, PsCustomType, create_type 

7from pystencils import Target 

8 

9from pystencilssfg.composer.basic_composer import SequencerArg 

10 

11from ..config import CodeStyle 

12from ..exceptions import SfgException 

13from ..context import SfgContext 

14from ..composer import ( 

15 SfgBasicComposer, 

16 SfgClassComposer, 

17 SfgComposer, 

18 SfgComposerMixIn, 

19 make_sequence, 

20) 

21from ..ir import ( 

22 SfgKernelHandle, 

23 SfgCallTreeNode, 

24 SfgCallTreeLeaf, 

25 SfgKernelCallNode, 

26) 

27 

28from ..lang import SfgVar, AugExpr, cpptype, Ref, VarLike, _VarLike, asvar 

29from ..lang.cpp.sycl_accessor import SyclAccessor 

30 

31 

#: Public short-hand alias for :class:`SyclAccessor`.
accessor = SyclAccessor

33 

34 

class SyclComposerMixIn(SfgComposerMixIn):
    """Composer mix-in for SYCL code generation"""

    def sycl_handler(self, name: str) -> SyclHandler:
        """Obtain a `SyclHandler`, which represents a ``sycl::handler`` object.

        Args:
            name: Name of the handler variable in the generated code
        """
        return SyclHandler(self._ctx).var(name)

    def sycl_group(self, dims: int, name: str) -> SyclGroup:
        """Obtain a `SyclGroup`, which represents a ``sycl::group< dims >`` object.

        Args:
            dims: Dimensionality of the group
            name: Name of the group variable in the generated code
        """
        return SyclGroup(dims, self._ctx).var(name)

    def sycl_range(
        self, dims: int, name: str, ref: bool = False, const: bool = False
    ) -> SyclRange:
        """Obtain a `SyclRange`, which represents a ``sycl::range< dims >`` object.

        Args:
            dims: Dimensionality of the range
            name: Name of the range variable in the generated code
            ref: Whether the variable's type is a reference
            const: Whether the variable's type is ``const``-qualified
        """
        return SyclRange(dims, const=const, ref=ref).var(name)

48 

49 

class SyclComposer(SfgBasicComposer, SfgClassComposer, SyclComposerMixIn):
    """Composer extension providing SYCL code generation capabilities

    Combines the basic composer and the class composer with the SYCL
    mix-in methods (`sycl_handler`, `sycl_group`, `sycl_range`).
    """

    def __init__(self, sfg: SfgContext | SfgComposer):
        # Initialization is delegated entirely along the MRO.
        super().__init__(sfg)

55 

56 

class SyclRange(AugExpr):
    """Expression of type ``sycl::range< dims >``, optionally const and/or a reference."""

    _template = cpptype("sycl::range< {dims} >", "<sycl/sycl.hpp>")

    def __init__(self, dims: int, const: bool = False, ref: bool = False):
        super().__init__(self._template(dims=dims, const=const, ref=ref))

63 

64 

class SyclHandler(AugExpr):
    """Represents a SYCL command group handler (``sycl::handler``).

    The expression's type is a reference to ``sycl::handler``.
    """

    _type = cpptype("sycl::handler", "<sycl/sycl.hpp>")

    #: Matches the work-item index types a ``parallel_for`` kernel may take
    _ID_REGEX = re.compile(r"sycl::(id|item|nd_item)<\s*[0-9]\s*>")

    def __init__(self, ctx: SfgContext):
        dtype = Ref(self._type())
        super().__init__(dtype)

        self._ctx = ctx

    def parallel_for(
        self,
        range: VarLike | Sequence[int],
    ):
        """Generate a ``parallel_for`` kernel invocation using this command group handler.
        The syntax of this uses a chain of two calls to mimic C++ syntax:

        .. code-block:: Python

            handler = sfg.sycl_handler("h")
            handler.parallel_for(range)(
                # Body
            )

        The body is constructed via sequencing (see `make_sequence`).
        Every kernel call in the body must be a SYCL kernel, and all of them
        must share the same index parameter (``sycl::id``, ``sycl::item`` or
        ``sycl::nd_item``), which becomes the parameter of the kernel lambda.

        Args:
            range: Object, or tuple of integers, indicating the kernel's iteration range

        Returns:
            A sequencer function; calling it with the body's elements returns
            the resulting `SyclKernelInvoke` node.

        Raises:
            SfgException: (from the sequencer) if a kernel in the body is not a
                SYCL kernel or has no index parameter, or if the body contains
                no kernel call at all.
            ValueError: (from the sequencer) if the kernels in the body use
                different index parameters.
        """
        if isinstance(range, _VarLike):
            range = asvar(range)

        id_regex = self._ID_REGEX

        def check_kernel(khandle: SfgKernelHandle):
            kfunc = khandle.kernel
            if kfunc.target != Target.SYCL:
                raise SfgException(
                    f"Kernel given to `parallel_for` is no SYCL kernel: {khandle.fqname}"
                )

        def find_id_param(khandle: SfgKernelHandle) -> SfgVar:
            # First scalar parameter whose type is a SYCL index type
            for param in khandle.scalar_parameters:
                if (
                    isinstance(param.dtype, PsCustomType)
                    and id_regex.search(param.dtype.c_string()) is not None
                ):
                    return param
            raise SfgException(
                "Kernel given to `parallel_for` has no SYCL index parameter "
                f"(sycl::id, sycl::item or sycl::nd_item): {khandle.fqname}"
            )

        def sequencer(*args: SequencerArg):
            id_params: list[SfgVar] = []
            for arg in args:
                if isinstance(arg, SfgKernelCallNode):
                    check_kernel(arg._kernel_handle)
                    id_params.append(find_id_param(arg._kernel_handle))

            # Without a kernel call there is no index parameter for the lambda
            if not id_params:
                raise SfgException(
                    "The body of `parallel_for` must contain at least one kernel call"
                )

            if not all(item == id_params[0] for item in id_params):
                raise ValueError(
                    "id_param should be the same for all kernels in parallel_for"
                )
            tree = make_sequence(*args)

            kernel_lambda = SfgLambda(("=",), (id_params[0],), tree, None)
            return SyclKernelInvoke(
                self, SyclInvokeType.ParallelFor, range, kernel_lambda
            )

        return sequencer

133 

134 

class SyclGroup(AugExpr):
    """Represents a SYCL group (``sycl::group``)."""

    _template = cpptype("sycl::group< {dims} >", "<sycl/sycl.hpp>")

    #: Matches the ``sycl::id`` parameter of a work-item kernel
    _ID_REGEX = re.compile(r"sycl::id<\s*[0-9]\s*>")

    def __init__(self, dimensions: int, ctx: SfgContext):
        dtype = Ref(self._template(dims=dimensions))
        super().__init__(dtype)

        self._dimensions = dimensions
        self._ctx = ctx

    def parallel_for_work_item(
        self, range: VarLike | Sequence[int], khandle: SfgKernelHandle
    ):
        """Generate a ``parallel_for_work_item`` kernel invocation on this group.

        The kernel is called inside a lambda taking a ``sycl::h_item`` named
        ``item`` with the same dimensionality as this group; the kernel's
        ``sycl::id`` parameter is set to ``item.get_local_id()``.

        Args:
            range: Object, or tuple of integers, indicating the kernel's iteration range
            khandle: Handle to the pystencils-kernel to be executed

        Returns:
            The `SyclKernelInvoke` node representing the invocation.

        Raises:
            SfgException: If the kernel is not a SYCL kernel or has no
                ``sycl::id`` parameter.
        """
        if isinstance(range, _VarLike):
            range = asvar(range)

        kfunc = khandle.kernel
        if kfunc.target != Target.SYCL:
            raise SfgException(
                "Kernel given to `parallel_for_work_item` is no SYCL kernel: "
                f"{khandle.fqname}"
            )

        id_param = next(
            (
                param
                for param in khandle.scalar_parameters
                if isinstance(param.dtype, PsCustomType)
                and self._ID_REGEX.search(param.dtype.c_string()) is not None
            ),
            None,
        )
        if id_param is None:
            raise SfgException(
                "Kernel given to `parallel_for_work_item` has no sycl::id parameter: "
                f"{khandle.fqname}"
            )

        # The work-item function of a sycl::group< D > receives a
        # sycl::h_item< D >, so its dimensionality must follow the group's.
        h_item = SfgVar(
            "item", PsCustomType(f"sycl::h_item< {self._dimensions} >")
        )

        comp = SfgComposer(self._ctx)
        tree = comp.seq(
            comp.set_param(id_param, AugExpr.format("{}.get_local_id()", h_item)),
            SfgKernelCallNode(khandle),
        )

        kernel_lambda = SfgLambda(("=",), (h_item,), tree, None)
        return SyclKernelInvoke(
            self, SyclInvokeType.ParallelForWorkItem, range, kernel_lambda
        )

187 

188 

class SfgLambda:
    """Models a C++ lambda expression

    `get_code` renders it as ``[captures] (params) -> rtype {body}``; the
    ``-> rtype`` part is omitted when no return type is given.
    """

    def __init__(
        self,
        captures: Sequence[str],
        params: Sequence[SfgVar],
        tree: SfgCallTreeNode,
        return_type: UserTypeSpec | None = None,
    ) -> None:
        """
        Args:
            captures: Capture-list entries, e.g. ``("=",)``
            params: Lambda parameters, each rendered as ``<type> <name>``
            tree: Call tree forming the lambda body
            return_type: Optional explicit trailing return type
        """
        self._captures = tuple(captures)
        self._params = tuple(params)
        self._tree = tree
        self._return_type: PsType | None = (
            create_type(return_type) if return_type is not None else None
        )

        # Local import kept from the original; presumably avoids an import cycle (TODO confirm)
        from ..ir.postprocessing import CallTreePostProcessing

        # Parameters the body needs that are not supplied as lambda parameters
        postprocess = CallTreePostProcessing()
        self._required_params = postprocess(self._tree).function_params - set(
            self._params
        )

    @property
    def captures(self) -> tuple[str, ...]:
        return self._captures

    @property
    def parameters(self) -> tuple[SfgVar, ...]:
        return self._params

    @property
    def body(self) -> SfgCallTreeNode:
        return self._tree

    @property
    def return_type(self) -> PsType | None:
        return self._return_type

    @property
    def required_parameters(self) -> set[SfgVar]:
        return self._required_params

    def get_code(self, cstyle: CodeStyle) -> str:
        """Render the lambda as C++ code, with the body indented by ``cstyle``."""
        captures = ", ".join(self._captures)
        params = ", ".join(f"{p.dtype.c_string()} {p.name}" for p in self._params)
        body = self._tree.get_code(cstyle)
        body = cstyle.indent(body)
        rtype = (
            f"-> {self._return_type.c_string()} "
            if self._return_type is not None
            else ""
        )

        # The lambda body's literal braces must be doubled inside the f-string
        return f"[{captures}] ({params}) {rtype}{{\n{body}\n}}"

245 

246 

class SyclInvokeType(Enum):
    """Kinds of SYCL kernel invocation.

    Each member's value is a ``(method, invoker_class)`` pair: the name of the
    C++ member function to call, and the Python class of the object it must be
    called on.
    """

    ParallelFor = ("parallel_for", SyclHandler)
    ParallelForWorkItem = ("parallel_for_work_item", SyclGroup)

    @property
    def method(self) -> str:
        """Name of the C++ member function used for the invocation."""
        method_name, _ = self.value
        return method_name

    @property
    def invoker_class(self) -> type:
        """Class the invoking object must be an instance of."""
        _, cls = self.value
        return cls

258 

259 

class SyclKernelInvoke(SfgCallTreeLeaf):
    """A SYCL kernel invocation on a given handler or group

    Rendered as ``<invoker>.<method>(<range>, <lambda>);`` where ``<range>`` is
    either a variable name or a brace-initializer list of integers.
    """

    def __init__(
        self,
        invoker: SyclHandler | SyclGroup,
        invoke_type: SyclInvokeType,
        range: SfgVar | Sequence[int],
        lamb: SfgLambda,
    ):
        expected_cls = invoke_type.invoker_class
        if not isinstance(invoker, expected_cls):
            raise SfgException(
                f"Cannot invoke kernel via `{invoke_type.method}` on a {type(invoker)}"
            )

        super().__init__()
        range_is_var = isinstance(range, SfgVar)

        self._invoker = invoker
        self._invoke_type = invoke_type
        self._range: SfgVar | tuple[int, ...] = range if range_is_var else tuple(range)
        self._lambda = lamb

        # Collect everything the invocation depends on
        required = set(invoker.depends | lamb.required_parameters)
        if range_is_var:
            required.add(range)
        self._required_params = required

    @property
    def invoker(self) -> SyclHandler | SyclGroup:
        return self._invoker

    @property
    def range(self) -> SfgVar | tuple[int, ...]:
        return self._range

    @property
    def kernel(self) -> SfgLambda:
        return self._lambda

    @property
    def depends(self) -> set[SfgVar]:
        return self._required_params

    def get_code(self, cstyle: CodeStyle) -> str:
        rng = self._range
        if isinstance(rng, SfgVar):
            range_code = rng.name
        else:
            range_code = " ".join(["{", ", ".join(map(str, rng)), "}"])

        kernel_code = self._lambda.get_code(cstyle)
        invoker_code = str(self._invoker)
        method_name = self._invoke_type.method

        return f"{invoker_code}.{method_name}({range_code}, {kernel_code});"