Skip to content
Snippets Groups Projects

Add support for sycl accessors

Merged Christoph Alt requested to merge ob28imeq/pystencils-sfg:sycl_accessor into master
All threads resolved!
Viewing commit 218c3334
Show latest version
2 files
+ 83
2
Preferences
Compare changes
Files
2
@@ -6,6 +6,8 @@ import re
@@ -6,6 +6,8 @@ import re
from pystencils.types import PsType, PsCustomType
from pystencils.types import PsType, PsCustomType
from pystencils.enums import Target
from pystencils.enums import Target
 
from pystencilssfg.composer.basic_composer import SequencerArg
 
from ..exceptions import SfgException
from ..exceptions import SfgException
from ..context import SfgContext
from ..context import SfgContext
from ..composer import (
from ..composer import (
@@ -13,6 +15,7 @@ from ..composer import (
@@ -13,6 +15,7 @@ from ..composer import (
SfgClassComposer,
SfgClassComposer,
SfgComposer,
SfgComposer,
SfgComposerMixIn,
SfgComposerMixIn,
 
make_sequence,
)
)
from ..ir.source_components import SfgKernelHandle, SfgHeaderInclude
from ..ir.source_components import SfgKernelHandle, SfgHeaderInclude
from ..ir import (
from ..ir import (
@@ -56,12 +59,13 @@ class SyclHandler(AugExpr):
@@ -56,12 +59,13 @@ class SyclHandler(AugExpr):
self._ctx = ctx
self._ctx = ctx
def parallel_for(self, range: SfgVar | Sequence[int], kernel: SfgKernelHandle):
def parallel_for(self, range: SfgVar | Sequence[int], kernel: SfgKernelHandle, *, extras: Sequence[SequencerArg]=[]):
"""Generate a ``parallel_for`` kernel invocation using this command group handler.
"""Generate a ``parallel_for`` kernel invocation using this command group handler.
Args:
Args:
range: Object, or tuple of integers, indicating the kernel's iteration range
range: Object, or tuple of integers, indicating the kernel's iteration range
kernel: Handle to the pystencils-kernel to be executed
kernel: Handle to the pystencils-kernel to be executed
 
extras: Statements that should be in the parallel_for but before the kernel call
"""
"""
self._ctx.add_include(SfgHeaderInclude("sycl/sycl.hpp", system_header=True))
self._ctx.add_include(SfgHeaderInclude("sycl/sycl.hpp", system_header=True))
@@ -81,7 +85,7 @@ class SyclHandler(AugExpr):
@@ -81,7 +85,7 @@ class SyclHandler(AugExpr):
id_param = list(filter(filter_id, kernel.scalar_parameters))[0]
id_param = list(filter(filter_id, kernel.scalar_parameters))[0]
tree = SfgKernelCallNode(kernel)
tree = make_sequence(*extras, SfgKernelCallNode(kernel))
kernel_lambda = SfgLambda(("=",), (id_param,), tree, None)
kernel_lambda = SfgLambda(("=",), (id_param,), tree, None)
return SyclKernelInvoke(self, SyclInvokeType.ParallelFor, range, kernel_lambda)
return SyclKernelInvoke(self, SyclInvokeType.ParallelFor, range, kernel_lambda)