Skip to content
Snippets Groups Projects
Commit 9672cf57 authored by Christoph Alt's avatar Christoph Alt
Browse files

Reworked the API of `parallel_for` to look like the `sfg.function` API

and extended the `sycl_accessor` mapping to handle index dimensions.
parent 670d03c2
No related branches found
No related tags found
1 merge request!3Add support for sycl accessors
......@@ -9,6 +9,7 @@
# dev environment
**/.venv
**/venv
# build artifacts
dist
......@@ -22,4 +23,4 @@ htmlcov
coverage.xml
# mkdocs
site
\ No newline at end of file
site
from pystencils import Target, CreateKernelConfig, no_jit
from lbmpy import create_lb_update_rule, LBMOptimisation
from pystencilssfg import SourceFileGenerator, SfgConfiguration, SfgOutputMode
from pystencilssfg.lang.cpp.sycl_accessor import sycl_accessor_ref
import pystencilssfg.extensions.sycl as sycl
from itertools import chain
# Generator settings for this test: output directory, enclosing namespace,
# implementation-file extension and inline output mode.
sfg_config = SfgConfiguration(
    output_directory="out/test_sycl_buffer",
    outer_namespace="gen_code",
    impl_extension="ipp",
    output_mode=SfgOutputMode.INLINE,
)

with SourceFileGenerator(sfg_config) as base_sfg:
    # Wrap the plain generator so the SYCL helpers (handler, range,
    # parallel_for) become available on the composer.
    sfg = sycl.SyclComposer(base_sfg)

    # Build the LBM update rule and register it as a SYCL kernel.
    kernel_config = CreateKernelConfig(target=Target.SYCL, jit=no_jit)
    lbm_opt = LBMOptimisation(field_layout="fzyx")
    update = create_lb_update_rule(lbm_optimisation=lbm_opt)
    kernel = sfg.kernels.create(update, "lbm_update", kernel_config)

    # Objects referenced by the generated function: the command group handler
    # and the iteration range (one entry per dimension of the LB method).
    handler = sfg.sycl_handler("handler")
    iteration_range = sfg.sycl_range(update.method.dim, "range")

    # Every field touched by the update rule is mapped to a SYCL accessor reference.
    all_fields = chain.from_iterable((update.free_fields, update.bound_fields))
    field_mappings = [sfg.map_field(f, sycl_accessor_ref(f)) for f in all_fields]

    # Emit `lb_update`: a parallel_for whose body first maps the fields and
    # then calls the kernel. The call order matches the nested one-liner form.
    define_lb_update = sfg.function("lb_update")
    parallel_body = handler.parallel_for(iteration_range)
    define_lb_update(parallel_body(*field_mappings, sfg.call(kernel)))
......@@ -62,24 +62,29 @@ class SyclHandler(AugExpr):
def parallel_for(
self,
range: SfgVar | Sequence[int],
kernel: SfgKernelHandle,
*,
extras: Sequence[SequencerArg] = [],
):
"""Generate a ``parallel_for`` kernel invocation using this command group handler.
The syntax of this uses a chain of two calls to mimic C++ syntax:
.. code-block:: Python
sfg.parallel_for(range)(
# Body
)
The body is constructed via sequencing (see `make_sequence`).
Args:
range: Object, or tuple of integers, indicating the kernel's iteration range
kernel: Handle to the pystencils-kernel to be executed
extras: Statements that should be in the parallel_for but before the kernel call
"""
self._ctx.add_include(SfgHeaderInclude("sycl/sycl.hpp", system_header=True))
kfunc = kernel.get_kernel_function()
if kfunc.target != Target.SYCL:
raise SfgException(
f"Kernel given to `parallel_for` is no SYCL kernel: {kernel.kernel_name}"
)
def check_kernel(kernel: SfgKernelHandle):
kfunc = kernel.get_kernel_function()
if kfunc.target != Target.SYCL:
raise SfgException(
f"Kernel given to `parallel_for` is no SYCL kernel: {kernel.kernel_name}"
)
id_regex = re.compile(r"sycl::(id|item|nd_item)<\s*[0-9]\s*>")
......@@ -89,12 +94,25 @@ class SyclHandler(AugExpr):
and id_regex.search(param.dtype.c_string()) is not None
)
id_param = list(filter(filter_id, kernel.scalar_parameters))[0]
tree = make_sequence(*extras, SfgKernelCallNode(kernel))
def sequencer(*args: SequencerArg):
id_param = []
for arg in args:
if isinstance(arg, SfgKernelCallNode):
check_kernel(arg._kernel_handle)
id_param.append(list(filter(filter_id, arg._kernel_handle.scalar_parameters))[0])
if not all(item == id_param[0] for item in id_param):
raise ValueError(
"id_param should be the same for all kernels in parallel_for"
)
tree = make_sequence(*args)
kernel_lambda = SfgLambda(("=",), (id_param[0],), tree, None)
return SyclKernelInvoke(
self, SyclInvokeType.ParallelFor, range, kernel_lambda
)
kernel_lambda = SfgLambda(("=",), (id_param,), tree, None)
return SyclKernelInvoke(self, SyclInvokeType.ParallelFor, range, kernel_lambda)
return sequencer
class SyclGroup(AugExpr):
......
import math
from ...lang import SrcField, IFieldExtraction
from ...ir.source_components import SfgHeaderInclude
from typing import Sequence
from pystencils import Field
from pystencils.types import (
......@@ -15,6 +17,7 @@ class SyclAccessor(SrcField):
self,
T: PsType,
dimensions: int,
index_shape: Sequence[int],
reference: bool = False,
):
cpp_typestr = T.c_string()
......@@ -25,7 +28,11 @@ class SyclAccessor(SrcField):
)
super().__init__(PsCustomType(typestring))
self._dim = dimensions
self._spatial_dimensions = dimensions
self._index_dimensions = len(index_shape)
self._index_shape = index_shape
self._index_size = math.prod(index_shape)
self._total_dimensions_ = self._spatial_dimensions + self._index_dimensions
@property
def required_includes(self) -> set[SfgHeaderInclude]:
......@@ -42,7 +49,7 @@ class SyclAccessor(SrcField):
)
def size(self, coordinate: int) -> AugExpr | None:
if coordinate > accessor._dim:
if coordinate > accessor._spatial_dimensions:
return None
else:
return AugExpr.format(
......@@ -50,28 +57,36 @@ class SyclAccessor(SrcField):
)
def stride(self, coordinate: int) -> AugExpr | None:
if coordinate > accessor._dim:
if coordinate > accessor._total_dimensions_:
return None
elif coordinate >= accessor._spatial_dimensions - 1:
start = (coordinate - accessor._spatial_dimensions) + 1
return AugExpr.format(
"{}", math.prod(accessor._index_shape[start:])
)
else:
if coordinate == accessor._dim - 1:
return AugExpr.format("1")
else:
exprs = []
args = []
for d in range(coordinate + 1, accessor._dim):
args.extend([accessor, d])
exprs.append("{}.get_range().get({})")
expr = " * ".join(exprs)
return AugExpr.format(expr, *args)
exprs = []
args = []
for d in range(coordinate + 1, accessor._spatial_dimensions):
args.extend([accessor, d])
exprs.append("{}.get_range().get({})")
expr = " * ".join(exprs)
expr += " * {}"
return AugExpr.format(expr, *args, accessor._index_size)
return Extraction()
def sycl_accessor_ref(field: Field):
    """Create a ``sycl::accessor &`` variable for the given pystencils field.

    The accessor is built from the field's element type, its number of
    spatial dimensions and its index shape; the resulting variable carries
    the field's name.
    """
    # SYCL accessors support at most three dimensions. Because the index
    # dimensions are mapped onto the accessor as well, only 2D LBM setups
    # can be handled this way.
    # In principle the field could instead be mapped to something like
    # sycl::buffer<std::array<double, 19>, 3>, but that would require
    # generating kernels that take SYCL accessors as arguments.
    accessor = SyclAccessor(
        T=field.dtype,
        dimensions=field.spatial_dimensions,
        index_shape=field.index_shape,
        reference=True,
    )
    return accessor.var(field.name)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment