Skip to content
Snippets Groups Projects
Commit d447eb1b authored by Frederik Hennig's avatar Frederik Hennig
Browse files

fix work item indices. Add sycl indexing config.

parent 03b9e02a
No related branches found
No related tags found
1 merge request!384Fundamental GPU Support
Pipeline #66666 passed
...@@ -158,7 +158,7 @@ class PsConstantExpr(PsLeafMixIn, PsExpression): ...@@ -158,7 +158,7 @@ class PsConstantExpr(PsLeafMixIn, PsExpression):
def __repr__(self) -> str: def __repr__(self) -> str:
return f"PsConstantExpr({repr(self._constant)})" return f"PsConstantExpr({repr(self._constant)})"
class PsLiteralExpr(PsLeafMixIn, PsExpression): class PsLiteralExpr(PsLeafMixIn, PsExpression):
__match_args__ = ("literal",) __match_args__ = ("literal",)
...@@ -177,7 +177,7 @@ class PsLiteralExpr(PsLeafMixIn, PsExpression): ...@@ -177,7 +177,7 @@ class PsLiteralExpr(PsLeafMixIn, PsExpression):
def clone(self) -> PsLiteralExpr: def clone(self) -> PsLiteralExpr:
return PsLiteralExpr(self._literal) return PsLiteralExpr(self._literal)
def structurally_equal(self, other: PsAstNode) -> bool: def structurally_equal(self, other: PsAstNode) -> bool:
if not isinstance(other, PsLiteralExpr): if not isinstance(other, PsLiteralExpr):
return False return False
......
...@@ -246,7 +246,7 @@ class CAstPrinter: ...@@ -246,7 +246,7 @@ class CAstPrinter:
) )
return dtype.create_literal(constant.value) return dtype.create_literal(constant.value)
case PsLiteralExpr(lit): case PsLiteralExpr(lit):
return lit.text return lit.text
......
...@@ -159,7 +159,7 @@ class TypeContext: ...@@ -159,7 +159,7 @@ class TypeContext:
f" Constant type: {c.dtype}\n" f" Constant type: {c.dtype}\n"
f" Target type: {self._target_type}" f" Target type: {self._target_type}"
) )
case PsLiteralExpr(lit): case PsLiteralExpr(lit):
if not self._compatible(lit.dtype): if not self._compatible(lit.dtype):
raise TypificationError( raise TypificationError(
......
...@@ -4,7 +4,7 @@ from ..types import PsType, constify ...@@ -4,7 +4,7 @@ from ..types import PsType, constify
class PsLiteral: class PsLiteral:
"""Representation of literal code. """Representation of literal code.
Instances of this class represent code literals inside the AST. Instances of this class represent code literals inside the AST.
These literals are not to be confused with C literals; the name `Literal` refers to the fact that These literals are not to be confused with C literals; the name `Literal` refers to the fact that
the code generator takes them "literally", printing them as they are. the code generator takes them "literally", printing them as they are.
...@@ -22,22 +22,22 @@ class PsLiteral: ...@@ -22,22 +22,22 @@ class PsLiteral:
@property @property
def text(self) -> str: def text(self) -> str:
return self._text return self._text
@property @property
def dtype(self) -> PsType: def dtype(self) -> PsType:
return self._dtype return self._dtype
def __str__(self) -> str: def __str__(self) -> str:
return f"{self._text}: {self._dtype}" return f"{self._text}: {self._dtype}"
def __repr__(self) -> str: def __repr__(self) -> str:
return f"PsLiteral({repr(self._text)}, {repr(self._dtype)})" return f"PsLiteral({repr(self._text)}, {repr(self._dtype)})"
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
if not isinstance(other, PsLiteral): if not isinstance(other, PsLiteral):
return False return False
return self._text == other._text and self._dtype == other._dtype return self._text == other._text and self._dtype == other._dtype
def __hash__(self) -> int: def __hash__(self) -> int:
return hash((PsLiteral, self._text, self._dtype)) return hash((PsLiteral, self._text, self._dtype))
...@@ -11,5 +11,5 @@ __all__ = [ ...@@ -11,5 +11,5 @@ __all__ = [
"X86VectorCpu", "X86VectorCpu",
"X86VectorArch", "X86VectorArch",
"GenericGpu", "GenericGpu",
"SyclPlatform" "SyclPlatform",
] ]
...@@ -14,10 +14,14 @@ from ..constants import PsConstant ...@@ -14,10 +14,14 @@ from ..constants import PsConstant
from .platform import Platform from .platform import Platform
from ..exceptions import MaterializationError from ..exceptions import MaterializationError
from ...types import PsType, PsCustomType, PsIeeeFloatType, constify from ...types import PsType, PsCustomType, PsIeeeFloatType, constify
from ...config import SyclIndexingConfig
class SyclPlatform(Platform): class SyclPlatform(Platform):
def __init__(self, indexing_cfg: SyclIndexingConfig) -> None:
    """Create a SYCL platform that uses the given indexing configuration.

    Args:
        indexing_cfg: SYCL indexing options; stored on the instance and
            consulted by `_id_type` (via `use_ndrange`).
    """
    # NOTE(review): this constructor does not call `super().__init__()`, yet other
    # methods of this class read `self._ctx` — confirm where `_ctx` is initialised.
    self._cfg = indexing_cfg
@property @property
def required_headers(self) -> set[str]: def required_headers(self) -> set[str]:
return {"<sycl/sycl.hpp>"} return {"<sycl/sycl.hpp>"}
...@@ -63,14 +67,14 @@ class SyclPlatform(Platform): ...@@ -63,14 +67,14 @@ class SyclPlatform(Platform):
unpackings = [] unpackings = []
for i, dim in enumerate(ispace.dimensions): for i, dim in enumerate(ispace.dimensions):
index = PsExpression.make(PsConstant(i, self._ctx.index_dtype)) coord = PsExpression.make(PsConstant(i, self._ctx.index_dtype))
subscript = PsSubscript(id_symbol, index) work_item_idx = PsSubscript(id_symbol, coord)
dim.counter.dtype = constify(dim.counter.get_dtype()) dim.counter.dtype = constify(dim.counter.get_dtype())
subscript.dtype = dim.counter.get_dtype() work_item_idx.dtype = dim.counter.get_dtype()
ctr = PsExpression.make(dim.counter) ctr = PsExpression.make(dim.counter)
unpackings.append(PsDeclaration(ctr, subscript)) unpackings.append(PsDeclaration(ctr, dim.start + work_item_idx * dim.step))
body.statements = unpackings + body.statements body.statements = unpackings + body.statements
return body return body
...@@ -92,3 +96,13 @@ class SyclPlatform(Platform): ...@@ -92,3 +96,13 @@ class SyclPlatform(Platform):
body.statements = [unpacking] + body.statements body.statements = [unpacking] + body.statements
return body return body
def _id_type(self, rank: int):
    """Return the const-qualified SYCL index type for a kernel of the given rank.

    The type is ``sycl::nd_item< rank >`` when `use_ndrange` is set on the
    indexing configuration, and ``sycl::item< rank >`` otherwise.
    """
    sycl_class = "nd_item" if self._cfg.use_ndrange else "item"
    return PsCustomType(f"sycl::{sycl_class}< {rank} >", const=True)
def _get_id(self):
    """Placeholder; not implemented yet.

    Per the TODO below, this is meant to model the `get_id()` /
    `get_global_id()` member function calls once the AST supports them.

    Raises:
        NotImplementedError: Always.
    """
    # TODO: Need AST support for member functions to model `get_id()` / `get_global_id()`
    raise NotImplementedError(
        "SyclPlatform._get_id is not available yet: the AST cannot model member "
        "function calls such as `get_id()` / `get_global_id()`"
    )
...@@ -99,6 +99,26 @@ class VectorizationConfig: ...@@ -99,6 +99,26 @@ class VectorizationConfig:
""" """
@dataclass
class SyclIndexingConfig:
    """Configure index translation behaviour inside kernels generated for `Target.SYCL`."""

    omit_range_check: bool = False
    """If set to `True`, omit the iteration counter range check.

    By default, the code generator introduces a check if the iteration counters computed from GPU block and thread
    indices are within the prescribed loop range.
    This check can be discarded through this option, at your own peril.
    """

    use_ndrange: bool = False
    """If set to `True` while generating for `Target.SYCL`, generate the kernel for execution with a ``sycl::nd_range``.

    If `use_ndrange` is set, the kernel will receive an `nd_item` instead of an `item` from which the iteration counters
    are derived.
    """
@dataclass @dataclass
class CreateKernelConfig: class CreateKernelConfig:
"""Options for create_kernel.""" """Options for create_kernel."""
...@@ -161,6 +181,12 @@ class CreateKernelConfig: ...@@ -161,6 +181,12 @@ class CreateKernelConfig:
If this parameter is set while `target` is a non-CPU target, an error will be raised. If this parameter is set while `target` is a non-CPU target, an error will be raised.
""" """
sycl_indexing: None | SyclIndexingConfig = None
"""Configure index translation for SYCL kernels.
If this parameter is set while `target` is not `Target.SYCL`, an error will be raised.
"""
def __post_init__(self): def __post_init__(self):
# Check iteration space argument consistency # Check iteration space argument consistency
if ( if (
...@@ -191,6 +217,10 @@ class CreateKernelConfig: ...@@ -191,6 +217,10 @@ class CreateKernelConfig:
if self.cpu_optim.vectorize is not False and not self.target.is_vector_cpu(): if self.cpu_optim.vectorize is not False and not self.target.is_vector_cpu():
raise PsOptionsError(f"Cannot enable auto-vectorization for non-vector CPU target {self.target}") raise PsOptionsError(f"Cannot enable auto-vectorization for non-vector CPU target {self.target}")
if self.sycl_indexing is not None:
if self.target != Target.SYCL:
raise PsOptionsError(f"`sycl_indexing` cannot be set for non-SYCL target {self.target}")
# Infer JIT # Infer JIT
if self.jit is None: if self.jit is None:
if self.target.is_cpu(): if self.target.is_cpu():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment