Coverage for src/pystencilssfg/extensions/sycl.py: 80%
153 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-04 07:16 +0000
1from __future__ import annotations
2from typing import Sequence
3from enum import Enum
4import re
6from pystencils.types import UserTypeSpec, PsType, PsCustomType, create_type
7from pystencils import Target
9from pystencilssfg.composer.basic_composer import SequencerArg
11from ..config import CodeStyle
12from ..exceptions import SfgException
13from ..context import SfgContext
14from ..composer import (
15 SfgBasicComposer,
16 SfgClassComposer,
17 SfgComposer,
18 SfgComposerMixIn,
19 make_sequence,
20)
21from ..ir import (
22 SfgKernelHandle,
23 SfgCallTreeNode,
24 SfgCallTreeLeaf,
25 SfgKernelCallNode,
26)
28from ..lang import SfgVar, AugExpr, cpptype, Ref, VarLike, _VarLike, asvar
29from ..lang.cpp.sycl_accessor import SyclAccessor
# Module-level alias so that `SyclAccessor` is also reachable as ``accessor``
# from this extension module.
accessor = SyclAccessor
class SyclComposerMixIn(SfgComposerMixIn):
    """Composer mix-in for SYCL code generation"""

    def sycl_handler(self, name: str) -> SyclHandler:
        """Obtain a `SyclHandler`, which represents a ``sycl::handler`` object.

        Args:
            name: Name of the handler variable in the generated code
        """
        return SyclHandler(self._ctx).var(name)

    def sycl_group(self, dims: int, name: str) -> SyclGroup:
        """Obtain a `SyclGroup`, which represents a ``sycl::group< dims >`` object.

        Args:
            dims: Dimensionality of the group
            name: Name of the group variable in the generated code
        """
        return SyclGroup(dims, self._ctx).var(name)

    def sycl_range(self, dims: int, name: str, ref: bool = False) -> SyclRange:
        """Obtain a `SyclRange`, which represents a ``sycl::range< dims >`` object.

        Args:
            dims: Dimensionality of the range
            name: Name of the range variable in the generated code
            ref: Whether the variable's type should be a reference type
                (forwarded to `SyclRange`)
        """
        return SyclRange(dims, ref=ref).var(name)
class SyclComposer(SfgBasicComposer, SfgClassComposer, SyclComposerMixIn):
    """Composer extension providing SYCL code generation capabilities.

    Combines the basic and class composers with the SYCL-specific factory
    methods of `SyclComposerMixIn`.
    """

    def __init__(self, sfg: SfgContext | SfgComposer):
        # All construction work is delegated along the MRO.
        super(SyclComposer, self).__init__(sfg)
class SyclRange(AugExpr):
    """Represents a SYCL range object (``sycl::range< dims >``).

    Args:
        dims: Dimensionality of the range
        const: Forwarded to the type template as its ``const`` qualifier flag
        ref: Forwarded to the type template as its ``ref`` (reference) flag
    """

    _template = cpptype("sycl::range< {dims} >", "<sycl/sycl.hpp>")

    def __init__(self, dims: int, const: bool = False, ref: bool = False):
        super().__init__(self._template(dims=dims, const=const, ref=ref))
class SyclHandler(AugExpr):
    """Represents a SYCL command group handler (``sycl::handler``)."""

    _type = cpptype("sycl::handler", "<sycl/sycl.hpp>")

    def __init__(self, ctx: SfgContext):
        dtype = Ref(self._type())
        super().__init__(dtype)

        self._ctx = ctx

    def parallel_for(
        self,
        range: VarLike | Sequence[int],
    ):
        """Generate a ``parallel_for`` kernel invocation using this command group handler.
        The syntax of this uses a chain of two calls to mimic C++ syntax:

        .. code-block:: Python

            h = sfg.sycl_handler("h")
            h.parallel_for(range)(
                # Body
            )

        The body is constructed via sequencing (see `make_sequence`).
        Every pystencils kernel call in the body must be a SYCL kernel taking a
        work-item index parameter (``sycl::id``, ``sycl::item`` or ``sycl::nd_item``).
        All kernel calls must share the same such parameter, which becomes the
        parameter of the generated C++ lambda.

        Args:
            range: Object, or tuple of integers, indicating the kernel's iteration range

        Returns:
            A callable taking the body's sequencer arguments and returning a
            `SyclKernelInvoke` node.

        The returned callable raises:
            SfgException: if a kernel is not a SYCL kernel, if a kernel has no
                work-item index parameter, or if the body contains no kernel call
            ValueError: if the kernels in the body use different index parameters
        """
        if isinstance(range, _VarLike):
            range = asvar(range)

        # Matches the C++ type of a work-item index parameter:
        # ``sycl::id<N>``, ``sycl::item<N>`` or ``sycl::nd_item<N>``.
        id_regex = re.compile(r"sycl::(id|item|nd_item)<\s*[0-9]\s*>")

        def check_kernel(khandle: SfgKernelHandle):
            kfunc = khandle.kernel
            if kfunc.target != Target.SYCL:
                raise SfgException(
                    f"Kernel given to `parallel_for` is no SYCL kernel: {khandle.fqname}"
                )

        def find_id_param(khandle: SfgKernelHandle) -> SfgVar:
            """Return the first scalar parameter of the kernel whose type is a SYCL index type."""
            for param in khandle.scalar_parameters:
                if (
                    isinstance(param.dtype, PsCustomType)
                    and id_regex.search(param.dtype.c_string()) is not None
                ):
                    return param
            # Previously this surfaced as an opaque IndexError.
            raise SfgException(
                "SYCL kernel given to `parallel_for` has no work-item index "
                f"parameter (sycl::id, sycl::item or sycl::nd_item): {khandle.fqname}"
            )

        def sequencer(*args: SequencerArg):
            id_params: list[SfgVar] = []
            for arg in args:
                if isinstance(arg, SfgKernelCallNode):
                    check_kernel(arg._kernel_handle)
                    id_params.append(find_id_param(arg._kernel_handle))

            # Without any kernel call there is no index parameter for the lambda.
            if not id_params:
                raise SfgException(
                    "The body of `parallel_for` must contain at least one SYCL kernel call"
                )

            if not all(item == id_params[0] for item in id_params):
                raise ValueError(
                    "id_param should be the same for all kernels in parallel_for"
                )
            tree = make_sequence(*args)

            # Capture everything by value; the shared index parameter is the
            # lambda's single parameter.
            kernel_lambda = SfgLambda(("=",), (id_params[0],), tree, None)
            return SyclKernelInvoke(
                self, SyclInvokeType.ParallelFor, range, kernel_lambda
            )

        return sequencer
class SyclGroup(AugExpr):
    """Represents a SYCL group (``sycl::group``)."""

    _template = cpptype("sycl::group< {dims} >", "<sycl/sycl.hpp>")

    def __init__(self, dimensions: int, ctx: SfgContext):
        dtype = Ref(self._template(dims=dimensions))
        super().__init__(dtype)

        self._dimensions = dimensions
        self._ctx = ctx

    def parallel_for_work_item(
        self, range: VarLike | Sequence[int], khandle: SfgKernelHandle
    ):
        """Generate a ``parallel_for_work_item`` kernel invocation on this group.

        The kernel is wrapped in a lambda taking a ``sycl::h_item< dims >``
        (``dims`` being this group's dimensionality). The kernel's
        ``sycl::id`` parameter is set from that item's ``get_local_id()``
        before the kernel call.

        Args:
            range: Object, or tuple of integers, indicating the kernel's iteration range
            khandle: Handle to the pystencils-kernel to be executed

        Returns:
            A `SyclKernelInvoke` node for the invocation.

        Raises:
            SfgException: if the kernel is not a SYCL kernel or has no
                ``sycl::id`` parameter
        """
        if isinstance(range, _VarLike):
            range = asvar(range)

        kfunc = khandle.kernel
        if kfunc.target != Target.SYCL:
            raise SfgException(
                "Kernel given to `parallel_for_work_item` is no SYCL kernel: "
                f"{khandle.fqname}"
            )

        id_regex = re.compile(r"sycl::id<\s*[0-9]\s*>")

        id_param = next(
            (
                param
                for param in khandle.scalar_parameters
                if isinstance(param.dtype, PsCustomType)
                and id_regex.search(param.dtype.c_string()) is not None
            ),
            None,
        )
        if id_param is None:
            raise SfgException(
                "SYCL kernel given to `parallel_for_work_item` has no sycl::id "
                f"parameter: {khandle.fqname}"
            )

        # SYCL requires the work-item function of group<Dims> to take
        # h_item<Dims>; hard-coding 3 broke 1D/2D groups.
        h_item = SfgVar(
            "item", PsCustomType(f"sycl::h_item< {self._dimensions} >")
        )

        comp = SfgComposer(self._ctx)
        tree = comp.seq(
            comp.set_param(id_param, AugExpr.format("{}.get_local_id()", h_item)),
            SfgKernelCallNode(khandle),
        )

        kernel_lambda = SfgLambda(("=",), (h_item,), tree, None)
        invoke = SyclKernelInvoke(
            self, SyclInvokeType.ParallelForWorkItem, range, kernel_lambda
        )
        return invoke
class SfgLambda:
    """Models a C++ lambda expression.

    Args:
        captures: Capture-list entries, joined with ``", "`` inside ``[...]``
        params: Lambda parameters, rendered as ``<type> <name>``
        tree: Call tree forming the lambda body
        return_type: Optional explicit return type, rendered as ``-> <type>``
    """

    def __init__(
        self,
        captures: Sequence[str],
        params: Sequence[SfgVar],
        tree: SfgCallTreeNode,
        return_type: UserTypeSpec | None = None,
    ) -> None:
        self._captures = tuple(captures)
        self._params = tuple(params)
        self._tree = tree
        self._return_type: PsType | None = (
            create_type(return_type) if return_type is not None else None
        )

        from ..ir.postprocessing import CallTreePostProcessing

        # Variables the body needs that are not supplied as lambda parameters.
        postprocess = CallTreePostProcessing()
        self._required_params = postprocess(self._tree).function_params - set(
            self._params
        )

    @property
    def captures(self) -> tuple[str, ...]:
        return self._captures

    @property
    def parameters(self) -> tuple[SfgVar, ...]:
        return self._params

    @property
    def body(self) -> SfgCallTreeNode:
        return self._tree

    @property
    def return_type(self) -> PsType | None:
        return self._return_type

    @property
    def required_parameters(self) -> set[SfgVar]:
        return self._required_params

    def get_code(self, cstyle: CodeStyle):
        """Render the lambda as C++ code: ``[captures] (params) -> rtype {body}``."""
        captures = ", ".join(self._captures)
        params = ", ".join(f"{p.dtype.c_string()} {p.name}" for p in self._params)
        body = self._tree.get_code(cstyle)
        body = cstyle.indent(body)
        rtype = (
            f"-> {self._return_type.c_string()} "
            if self._return_type is not None
            else ""
        )

        # Literal braces must be doubled inside an f-string.
        return f"[{captures}] ({params}) {rtype}{{\n{body}\n}}"
class SyclInvokeType(Enum):
    """Kinds of SYCL kernel invocations.

    Each member's value is a pair ``(method, invoker_class)``:
    the name of the C++ member function that launches the kernel, and the
    Python class whose instances are allowed to perform that launch.
    """

    ParallelFor = ("parallel_for", SyclHandler)
    ParallelForWorkItem = ("parallel_for_work_item", SyclGroup)

    @property
    def method(self) -> str:
        """Name of the C++ member function used for this invocation."""
        method_name, _unused = self.value
        return method_name

    @property
    def invoker_class(self) -> type:
        """Class that an invoker object must be an instance of."""
        _unused, cls = self.value
        return cls
class SyclKernelInvoke(SfgCallTreeLeaf):
    """A SYCL kernel invocation on a given handler or group.

    Renders as ``<invoker>.<method>(<range>, <lambda>);``. The range is
    emitted either as a variable name or as a brace-enclosed list of integers.

    Raises:
        SfgException: if ``invoker`` is not an instance of the class
            required by ``invoke_type``
    """

    def __init__(
        self,
        invoker: SyclHandler | SyclGroup,
        invoke_type: SyclInvokeType,
        range: SfgVar | Sequence[int],
        lamb: SfgLambda,
    ):
        required_cls = invoke_type.invoker_class
        if not isinstance(invoker, required_cls):
            raise SfgException(
                f"Cannot invoke kernel via `{invoke_type.method}` on a {type(invoker)}"
            )

        super().__init__()
        self._invoker = invoker
        self._invoke_type = invoke_type

        range_is_var = isinstance(range, SfgVar)
        if range_is_var:
            stored_range: SfgVar | tuple[int, ...] = range
        else:
            stored_range = tuple(range)
        self._range = stored_range
        self._lambda = lamb

        # Dependencies: the invoker's, those of the lambda body, and the
        # range variable if one was given.
        deps = set(invoker.depends | lamb.required_parameters)
        if range_is_var:
            deps.add(range)
        self._required_params = deps

    @property
    def invoker(self) -> SyclHandler | SyclGroup:
        return self._invoker

    @property
    def range(self) -> SfgVar | tuple[int, ...]:
        return self._range

    @property
    def kernel(self) -> SfgLambda:
        return self._lambda

    @property
    def depends(self) -> set[SfgVar]:
        return self._required_params

    def get_code(self, cstyle: CodeStyle) -> str:
        rng = self._range
        range_code = (
            rng.name
            if isinstance(rng, SfgVar)
            else "{ " + ", ".join(map(str, rng)) + " }"
        )

        lambda_code = self._lambda.get_code(cstyle)
        invoker_code = str(self._invoker)
        method_name = self._invoke_type.method

        return f"{invoker_code}.{method_name}({range_code}, {lambda_code});"