Skip to content
Snippets Groups Projects
Commit 218c3334 authored by Christoph Alt's avatar Christoph Alt
Browse files

First draft to add support for sycl accessors:

Added a function to add sycl accessors from a ps.field.
Extended the parrallel_for to make it possible add more statements
before the actual kernel call
parent a47eeeeb
No related branches found
No related tags found
1 merge request!3Add support for sycl accessors
...@@ -6,6 +6,8 @@ import re ...@@ -6,6 +6,8 @@ import re
from pystencils.types import PsType, PsCustomType from pystencils.types import PsType, PsCustomType
from pystencils.enums import Target from pystencils.enums import Target
from pystencilssfg.composer.basic_composer import SequencerArg
from ..exceptions import SfgException from ..exceptions import SfgException
from ..context import SfgContext from ..context import SfgContext
from ..composer import ( from ..composer import (
...@@ -13,6 +15,7 @@ from ..composer import ( ...@@ -13,6 +15,7 @@ from ..composer import (
SfgClassComposer, SfgClassComposer,
SfgComposer, SfgComposer,
SfgComposerMixIn, SfgComposerMixIn,
make_sequence,
) )
from ..ir.source_components import SfgKernelHandle, SfgHeaderInclude from ..ir.source_components import SfgKernelHandle, SfgHeaderInclude
from ..ir import ( from ..ir import (
...@@ -56,12 +59,13 @@ class SyclHandler(AugExpr): ...@@ -56,12 +59,13 @@ class SyclHandler(AugExpr):
self._ctx = ctx self._ctx = ctx
def parallel_for(self, range: SfgVar | Sequence[int], kernel: SfgKernelHandle): def parallel_for(self, range: SfgVar | Sequence[int], kernel: SfgKernelHandle, *, extras: Sequence[SequencerArg]=[]):
"""Generate a ``parallel_for`` kernel invocation using this command group handler. """Generate a ``parallel_for`` kernel invocation using this command group handler.
Args: Args:
range: Object, or tuple of integers, indicating the kernel's iteration range range: Object, or tuple of integers, indicating the kernel's iteration range
kernel: Handle to the pystencils-kernel to be executed kernel: Handle to the pystencils-kernel to be executed
extras: Statements that should be in the parallel_for but before the kernel call
""" """
self._ctx.add_include(SfgHeaderInclude("sycl/sycl.hpp", system_header=True)) self._ctx.add_include(SfgHeaderInclude("sycl/sycl.hpp", system_header=True))
...@@ -81,7 +85,7 @@ class SyclHandler(AugExpr): ...@@ -81,7 +85,7 @@ class SyclHandler(AugExpr):
id_param = list(filter(filter_id, kernel.scalar_parameters))[0] id_param = list(filter(filter_id, kernel.scalar_parameters))[0]
tree = SfgKernelCallNode(kernel) tree = make_sequence(*extras, SfgKernelCallNode(kernel))
kernel_lambda = SfgLambda(("=",), (id_param,), tree, None) kernel_lambda = SfgLambda(("=",), (id_param,), tree, None)
return SyclKernelInvoke(self, SyclInvokeType.ParallelFor, range, kernel_lambda) return SyclKernelInvoke(self, SyclInvokeType.ParallelFor, range, kernel_lambda)
......
from ...lang import SrcField, IFieldExtraction
from ...ir.source_components import SfgHeaderInclude
from pystencils import Field
from pystencils.types import (
PsType,
PsCustomType,
)
from pystencilssfg.lang.expressions import AugExpr
class SyclAccessor(SrcField):
def __init__(
self,
T: PsType,
dimensions: int,
reference: bool = False,
):
cpp_typestr = T.c_string()
if dimensions not in [1, 2, 3]:
raise ValueError("sycl accessors can only have dims 1, 2 or 3")
typestring = (
f"sycl::accessor< {cpp_typestr}, {dimensions} > {'&' if reference else ''}"
)
super().__init__(PsCustomType(typestring))
self._dim = dimensions
@property
def required_includes(self) -> set[SfgHeaderInclude]:
return {SfgHeaderInclude("sycl/sycl.hpp", system_header=True)}
def get_extraction(self) -> IFieldExtraction:
accessor = self
class Extraction(IFieldExtraction):
def ptr(self) -> AugExpr:
return AugExpr.format(
"{}.get_multi_ptr<sycl::access::decorated::no>().get()",
accessor,
)
def size(self, coordinate: int) -> AugExpr | None:
if coordinate > accessor._dim:
return None
else:
return AugExpr.format(
"{}.get_range().get({})", accessor, coordinate
)
def stride(self, coordinate: int) -> AugExpr | None:
if coordinate > accessor._dim:
return None
else:
if coordinate == accessor._dim - 1:
return AugExpr.format("1")
else:
exprs = []
args = []
for d in range(coordinate + 1, accessor._dim):
args.extend([accessor, d])
exprs.append("{}.get_range().get({})")
expr = " * ".join(exprs)
return AugExpr.format(expr, *args)
return Extraction()
def sycl_accessor_ref(field: Field):
"""Creates a `sycl::accessor &` for a given pystencils field."""
return SyclAccessor(
field.dtype,
field.spatial_dimensions,
reference=True,
).var(field.name)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment