Skip to content
Snippets Groups Projects
Commit 72bd2049 authored by Frederik Hennig
Browse files

add basic SYCL support

parent d0654625
No related branches found
No related tags found
1 merge request: !384 Fundamental GPU Support
......@@ -46,7 +46,7 @@ class IterationSpace(ABC):
return self._spatial_indices
@property
def dim(self) -> int:
def rank(self) -> int:
return len(self._spatial_indices)
......@@ -223,7 +223,7 @@ class FullIterationSpace(IterationSpace):
"""Expression counting the actual number of items processed at the iteration defined by the counter tuple.
Used primarily for indexing buffers."""
actual_iters = [self.actual_iterations(d) for d in range(self.dim)]
actual_iters = [self.actual_iterations(d) for d in range(self.rank)]
compressed_counters = [
(PsExpression.make(dim.counter) - dim.start) / dim.step
for dim in self.dimensions
......
......@@ -2,6 +2,7 @@ from .platform import Platform
from .generic_cpu import GenericCpu, GenericVectorCpu
from .generic_gpu import GenericGpu
from .x86 import X86VectorCpu, X86VectorArch
from .sycl import SyclPlatform
__all__ = [
"Platform",
......@@ -10,4 +11,5 @@ __all__ = [
"X86VectorCpu",
"X86VectorArch",
"GenericGpu",
"SyclPlatform"
]
......@@ -49,7 +49,7 @@ class GenericCpu(Platform):
elif isinstance(ispace, SparseIterationSpace):
return self._create_sparse_loop(body, ispace)
else:
assert False, "unreachable code"
raise MaterializationError(f"Unknown type of iteration space: {ispace}")
def select_function(
self, math_function: PsMathFunction, dtype: PsType
......
from ..functions import CFunction, PsMathFunction, MathFunctions
from ..ast.structural import PsBlock
from ..kernelcreation.iteration_space import (
IterationSpace,
FullIterationSpace,
SparseIterationSpace,
)
from ..ast.structural import PsDeclaration
from ..ast.expressions import (
PsExpression,
PsSubscript,
)
from ..constants import PsConstant
from .platform import Platform
from ..exceptions import MaterializationError
from ...types import PsType, PsCustomType, PsIeeeFloatType, constify
class SyclPlatform(Platform):
    """Platform that prepares kernel bodies for use with SYCL.

    No loops are generated by this class. The iteration space is materialized
    by prepending declarations to the kernel body which initialize the
    iteration counters from the entries of a ``sycl::id`` symbol named ``id``.
    """

    @property
    def required_headers(self) -> set[str]:
        """Set of headers required by this platform."""
        return {"<sycl/sycl.hpp>"}

    def materialize_iteration_space(
        self, body: PsBlock, ispace: IterationSpace
    ) -> PsBlock:
        """Prepend the counter declarations matching the kind of ``ispace``.

        Raises:
            MaterializationError: If ``ispace`` is neither a
                `FullIterationSpace` nor a `SparseIterationSpace`.
        """
        if isinstance(ispace, FullIterationSpace):
            return self._prepend_dense_translation(body, ispace)
        if isinstance(ispace, SparseIterationSpace):
            return self._prepend_sparse_translation(body, ispace)
        raise MaterializationError(f"Unknown type of iteration space: {ispace}")

    def select_function(
        self, math_function: PsMathFunction, dtype: PsType
    ) -> CFunction:
        """Pick the ``sycl::`` function implementing ``math_function``.

        Only IEEE floating-point types of width 16, 32 or 64 are handled.

        Raises:
            MaterializationError: If no implementation is available for the
                given function and data type.
        """
        func = math_function.func
        is_supported_float = isinstance(dtype, PsIeeeFloatType) and dtype.width in (
            16,
            32,
            64,
        )

        if is_supported_float:
            # Functions whose SYCL name equals the function name
            same_name = (
                MathFunctions.Exp,
                MathFunctions.Sin,
                MathFunctions.Cos,
                MathFunctions.Tan,
                MathFunctions.Pow,
            )
            # Functions whose SYCL name carries an additional `f` prefix
            f_prefixed = (MathFunctions.Abs, MathFunctions.Min, MathFunctions.Max)

            if func in same_name:
                return CFunction(f"sycl::{func.function_name}", func.num_args)
            if func in f_prefixed:
                return CFunction(f"sycl::f{func.function_name}", func.num_args)

        raise MaterializationError(
            f"No implementation available for function {math_function} on data type {dtype}"
        )

    def _declare_counter_from_id(
        self, id_symbol: PsExpression, position: int, counter
    ) -> PsDeclaration:
        """Build the declaration ``counter = id[position]``.

        As a side effect, the data type of ``counter`` is replaced by its
        constified version; the subscript expression receives the same type.
        """
        index = PsExpression.make(PsConstant(position, self._ctx.index_dtype))
        entry = PsSubscript(id_symbol, index)

        counter.dtype = constify(counter.get_dtype())
        entry.dtype = counter.get_dtype()

        return PsDeclaration(PsExpression.make(counter), entry)

    def _prepend_dense_translation(
        self, body: PsBlock, ispace: FullIterationSpace
    ) -> PsBlock:
        """Declare one counter per dimension of ``ispace`` at the top of ``body``.

        The counter of the ``i``-th dimension is initialized from ``id[i]``.
        ``body`` is modified in place and returned.
        """
        id_type = PsCustomType(f"sycl::id< {ispace.rank} >", const=True)
        id_symbol = PsExpression.make(self._ctx.get_symbol("id", id_type))

        unpackings = [
            self._declare_counter_from_id(id_symbol, i, dim.counter)
            for i, dim in enumerate(ispace.dimensions)
        ]

        body.statements = unpackings + body.statements
        return body

    def _prepend_sparse_translation(
        self, body: PsBlock, ispace: SparseIterationSpace
    ) -> PsBlock:
        """Declare the sparse counter of ``ispace`` at the top of ``body``.

        The counter is initialized from ``id[0]`` of a one-dimensional
        ``sycl::id``. ``body`` is modified in place and returned.
        """
        id_type = PsCustomType("sycl::id< 1 >", const=True)
        id_symbol = PsExpression.make(self._ctx.get_symbol("id", id_type))

        unpacking = self._declare_counter_from_id(id_symbol, 0, ispace.sparse_counter)

        body.statements = [unpacking] + body.statements
        return body
......@@ -25,6 +25,8 @@ class Target(Flag):
_CUDA = auto()
_SYCL = auto()
_AUTOMATIC = auto()
# ------------------ Actual Targets -------------------------------------------------------------------
......@@ -63,16 +65,14 @@ class Target(Flag):
"""ARM architecture with SVE vector extensions"""
CurrentGPU = _GPU | _AUTOMATIC
"""
Auto-best GPU target.
"""Auto-best GPU target.
`CurrentGPU` causes the code generator to automatically select a GPU target according to GPU devices
found on the current machine and runtime environment.
"""
GenericCUDA = _GPU | _CUDA
"""
Generic CUDA GPU target.
"""Generic CUDA GPU target.
Generate a CUDA kernel for a generic Nvidia GPU.
"""
......@@ -80,6 +80,12 @@ class Target(Flag):
GPU = GenericCUDA
"""Alias for backward compatibility."""
SYCL = _SYCL
"""SYCL kernel target.
Generate a function to be called within a SYCL parallel command.
"""
def is_automatic(self) -> bool:
return Target._AUTOMATIC in self
......
......@@ -19,6 +19,7 @@ from .backend.kernelcreation import (
FreezeExpressions,
Typifier,
)
from .backend.platforms import Platform
from .backend.kernelcreation.iteration_space import (
create_sparse_iteration_space,
create_full_iteration_space,
......@@ -90,9 +91,11 @@ def create_kernel(
from .backend.platforms import GenericCpu
platform = GenericCpu(ctx)
case Target.SYCL:
from .backend.platforms import SyclPlatform
platform = SyclPlatform(ctx)
case _:
# TODO: CUDA/HIP platform
# TODO: SYCL platform (?)
raise NotImplementedError("Target platform not implemented")
kernel_ast = platform.materialize_iteration_space(kernel_body, ispace)
......@@ -114,12 +117,13 @@ def create_kernel(
assert config.jit is not None
return create_kernel_function(
ctx, kernel_ast, config.function_name, config.target, config.jit
ctx, platform, kernel_ast, config.function_name, config.target, config.jit
)
def create_kernel_function(
ctx: KernelCreationContext,
platform: Platform,
body: PsBlock,
function_name: str,
target_spec: Target,
......@@ -145,6 +149,7 @@ def create_kernel_function(
params.sort(key=lambda p: p.name)
req_headers = collect_required_headers(body)
req_headers |= platform.required_headers
req_headers |= ctx.required_headers
return KernelFunction(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment