Skip to content
Snippets Groups Projects
Commit 957b2b7e authored by Christoph Alt's avatar Christoph Alt
Browse files

First draft to add support for sycl accessors:

Added a function to add sycl accessors from a ps.field.
Extended the parrallel_for to make it possible add more statements
before the actual kernel call
parent ba2ff160
No related branches found
No related tags found
1 merge request!3Add support for sycl accessors
......@@ -6,6 +6,8 @@ import re
from pystencils.types import PsType, PsCustomType
from pystencils.enums import Target
from pystencilssfg.composer.basic_composer import SequencerArg
from ..exceptions import SfgException
from ..context import SfgContext
from ..composer import (
......@@ -13,6 +15,7 @@ from ..composer import (
SfgClassComposer,
SfgComposer,
SfgComposerMixIn,
make_sequence,
)
from ..ir.source_components import SfgKernelHandle, SfgHeaderInclude
from ..ir import (
......@@ -56,12 +59,13 @@ class SyclHandler(AugExpr):
self._ctx = ctx
def parallel_for(self, range: SfgVar | Sequence[int], kernel: SfgKernelHandle):
def parallel_for(self, range: SfgVar | Sequence[int], kernel: SfgKernelHandle, *, extras: Sequence[SequencerArg]=[]):
"""Generate a ``parallel_for`` kernel invocation using this command group handler.
Args:
range: Object, or tuple of integers, indicating the kernel's iteration range
kernel: Handle to the pystencils-kernel to be executed
extras: Statements that should be in the parallel_for but before the kernel call
"""
self._ctx.add_include(SfgHeaderInclude("sycl/sycl.hpp", system_header=True))
......@@ -81,7 +85,7 @@ class SyclHandler(AugExpr):
id_param = list(filter(filter_id, kernel.scalar_parameters))[0]
tree = SfgKernelCallNode(kernel)
tree = make_sequence(*extras, SfgKernelCallNode(kernel))
kernel_lambda = SfgLambda(("=",), (id_param,), tree, None)
return SyclKernelInvoke(self, SyclInvokeType.ParallelFor, range, kernel_lambda)
......
from ...lang import SrcField, IFieldExtraction
from ...ir.source_components import SfgHeaderInclude
from pystencils import Field
from pystencils.types import (
PsType,
PsCustomType,
)
from pystencilssfg.lang.expressions import AugExpr
class SyclAccessor(SrcField):
def __init__(
self,
T: PsType,
dimensions: int,
reference: bool = False,
):
cpp_typestr = T.c_string()
if dimensions not in [1, 2, 3]:
raise ValueError("sycl accessors can only have dims 1, 2 or 3")
typestring = (
f"sycl::accessor< {cpp_typestr}, {dimensions} > {'&' if reference else ''}"
)
super().__init__(PsCustomType(typestring))
self._dim = dimensions
@property
def required_includes(self) -> set[SfgHeaderInclude]:
return {SfgHeaderInclude("sycl/sycl.hpp", system_header=True)}
def get_extraction(self) -> IFieldExtraction:
accessor = self
class Extraction(IFieldExtraction):
def ptr(self) -> AugExpr:
return AugExpr.format(
"{}.get_multi_ptr<sycl::access::decorated::no>().get()",
accessor,
)
def size(self, coordinate: int) -> AugExpr | None:
if coordinate > accessor._dim:
return None
else:
return AugExpr.format(
"{}.get_range().get({})", accessor, coordinate
)
def stride(self, coordinate: int) -> AugExpr | None:
if coordinate > accessor._dim:
return None
else:
if coordinate == accessor._dim - 1:
return AugExpr.format("1")
else:
exprs = []
args = []
for d in range(coordinate + 1, accessor._dim):
args.extend([accessor, d])
exprs.append("{}.get_range().get({})")
expr = " * ".join(exprs)
return AugExpr.format(expr, *args)
return Extraction()
def sycl_accessor_ref(field: Field):
"""Creates a `sycl::accessor &` for a given pystencils field."""
return SyclAccessor(
field.dtype,
field.spatial_dimensions,
reference=True,
).var(field.name)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment