diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index 7bcf62b973d8ace8e9ad9847ae165c398f1cbb0e..4063f7b539ab387d1f950a75e735f3c6201b5ef2 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -158,7 +158,7 @@ class PsConstantExpr(PsLeafMixIn, PsExpression): def __repr__(self) -> str: return f"PsConstantExpr({repr(self._constant)})" - + class PsLiteralExpr(PsLeafMixIn, PsExpression): __match_args__ = ("literal",) @@ -177,7 +177,7 @@ class PsLiteralExpr(PsLeafMixIn, PsExpression): def clone(self) -> PsLiteralExpr: return PsLiteralExpr(self._literal) - + def structurally_equal(self, other: PsAstNode) -> bool: if not isinstance(other, PsLiteralExpr): return False diff --git a/src/pystencils/backend/emission.py b/src/pystencils/backend/emission.py index f3d56c6c4c20e5969ee10d08ee42b6803a2e0b1c..b5f8a43549d989c568201ef43873eeca850bd966 100644 --- a/src/pystencils/backend/emission.py +++ b/src/pystencils/backend/emission.py @@ -246,7 +246,7 @@ class CAstPrinter: ) return dtype.create_literal(constant.value) - + case PsLiteralExpr(lit): return lit.text diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index dbec20235f0a37cfab763771e2e2fbed05a3c196..7e628631dba164ab8f1c756961f1f1112437f36f 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -159,7 +159,7 @@ class TypeContext: f" Constant type: {c.dtype}\n" f" Target type: {self._target_type}" ) - + case PsLiteralExpr(lit): if not self._compatible(lit.dtype): raise TypificationError( diff --git a/src/pystencils/backend/literals.py b/src/pystencils/backend/literals.py index dc7504f520f8950b46df76b0359aaad371244b19..dc254da0e340d518929d6eecb483defcdffbe185 100644 --- a/src/pystencils/backend/literals.py +++ b/src/pystencils/backend/literals.py @@ -4,7 +4,7 @@ from ..types import PsType, constify class PsLiteral: """Representation of literal code. - + Instances of this class represent code literals inside the AST. These literals are not to be confused with C literals; the name `Literal` refers to the fact that the code generator takes them "literally", printing them as they are. @@ -22,22 +22,22 @@ class PsLiteral: @property def text(self) -> str: return self._text - + @property def dtype(self) -> PsType: return self._dtype - + def __str__(self) -> str: return f"{self._text}: {self._dtype}" - + def __repr__(self) -> str: return f"PsLiteral({repr(self._text)}, {repr(self._dtype)})" - + def __eq__(self, other: object) -> bool: if not isinstance(other, PsLiteral): return False - + return self._text == other._text and self._dtype == other._dtype - + def __hash__(self) -> int: return hash((PsLiteral, self._text, self._dtype)) diff --git a/src/pystencils/backend/platforms/__init__.py b/src/pystencils/backend/platforms/__init__.py index 84ed6f9a24daecb60bcf8a9eed665462dd5cd567..af4d88e79aa22abcaeeaefb1ab660818bff09bcc 100644 --- a/src/pystencils/backend/platforms/__init__.py +++ b/src/pystencils/backend/platforms/__init__.py @@ -11,5 +11,5 @@ __all__ = [ "X86VectorCpu", "X86VectorArch", "GenericGpu", - "SyclPlatform" + "SyclPlatform", ] diff --git a/src/pystencils/backend/platforms/sycl.py b/src/pystencils/backend/platforms/sycl.py index 4cbd186c0536aa210addefdfa72e1fd7d22cb291..f646768d5184f5ab7bb080e67feece881a4e2111 100644 --- a/src/pystencils/backend/platforms/sycl.py +++ b/src/pystencils/backend/platforms/sycl.py @@ -14,10 +14,14 @@ from ..constants import PsConstant from .platform import Platform from ..exceptions import MaterializationError from ...types import PsType, PsCustomType, PsIeeeFloatType, constify +from ...config import SyclIndexingConfig class SyclPlatform(Platform): + def __init__(self, indexing_cfg: SyclIndexingConfig): + self._cfg = indexing_cfg + @property def required_headers(self) -> set[str]: return {"<sycl/sycl.hpp>"} @@ -63,14 +67,14 @@ class SyclPlatform(Platform): unpackings = [] for i, dim in enumerate(ispace.dimensions): - index = PsExpression.make(PsConstant(i, self._ctx.index_dtype)) - subscript = PsSubscript(id_symbol, index) + coord = PsExpression.make(PsConstant(i, self._ctx.index_dtype)) + work_item_idx = PsSubscript(id_symbol, coord) dim.counter.dtype = constify(dim.counter.get_dtype()) - subscript.dtype = dim.counter.get_dtype() + work_item_idx.dtype = dim.counter.get_dtype() ctr = PsExpression.make(dim.counter) - unpackings.append(PsDeclaration(ctr, subscript)) + unpackings.append(PsDeclaration(ctr, dim.start + work_item_idx * dim.step)) body.statements = unpackings + body.statements return body @@ -92,3 +96,13 @@ class SyclPlatform(Platform): body.statements = [unpacking] + body.statements return body + + def _id_type(self, rank: int): + if self._cfg.use_ndrange: + return PsCustomType(f"sycl::nd_item< {rank} >", const=True) + else: + return PsCustomType(f"sycl::item< {rank} >", const=True) + + def _get_id(self): + # TODO: Need AST support for member functions to model `get_id()` / `get_global_id()` + raise NotImplementedError() diff --git a/src/pystencils/config.py b/src/pystencils/config.py index c2e1451f5016f34785090de628db3204e58f9cd1..13df58682189b8ed22b1694f8b0b8a05f10438a7 100644 --- a/src/pystencils/config.py +++ b/src/pystencils/config.py @@ -99,6 +99,26 @@ class VectorizationConfig: """ +@dataclass +class SyclIndexingConfig: + """Configure index translation behaviour inside kernels generated for `Target.SYCL`.""" + + omit_range_check: bool = False + """If set to `True`, omit the iteration counter range check. + + By default, the code generator introduces a check if the iteration counters computed from GPU block and thread + indices are within the prescribed loop range. + This check can be discarded through this option, at your own peril. + """ + + use_ndrange: bool = False + """If set to `True` while generating for `Target.SYCL`, generate the kernel for execution with a ``sycl::ndrange``. + + If `use_ndrange` is set, the kernel will receive an `nd_item` instead of an `item` from which the iteration counters + are derived. + """ + + @dataclass class CreateKernelConfig: """Options for create_kernel.""" @@ -161,6 +181,12 @@ class CreateKernelConfig: If this parameter is set while `target` is a non-CPU target, an error will be raised. """ + sycl_indexing: None | SyclIndexingConfig = None + """Configure index translation for SYCL kernels. + + It this parameter is set while `target` is not `Target.SYCL`, an error will be raised. + """ + def __post_init__(self): # Check iteration space argument consistency if ( @@ -191,6 +217,10 @@ class CreateKernelConfig: if self.cpu_optim.vectorize is not False and not self.target.is_vector_cpu(): raise PsOptionsError(f"Cannot enable auto-vectorization for non-vector CPU target {self.target}") + if self.sycl_indexing is not None: + if self.target != Target.SYCL: + raise PsOptionsError(f"`sycl_indexing` cannot be set for non-SYCL target {self.target}") + # Infer JIT if self.jit is None: if self.target.is_cpu():