Skip to content
Snippets Groups Projects
Commit 95907f7b authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Basic CUDA support.

 - Implement GPU index translation for CUDA target
 - Expose CUDA code generation through `create_kernel`
 - enable CUDA kernel printing
parent ff396a6a
No related branches found
No related tags found
1 merge request!384Fundamental GPU Support
Pipeline #67159 failed
...@@ -18,6 +18,7 @@ from .config import ( ...@@ -18,6 +18,7 @@ from .config import (
from .kernel_decorator import kernel, kernel_config from .kernel_decorator import kernel, kernel_config
from .kernelcreation import create_kernel from .kernelcreation import create_kernel
from .backend.kernelfunction import KernelFunction from .backend.kernelfunction import KernelFunction
from .backend.jit import no_jit
from .slicing import make_slice from .slicing import make_slice
from .spatial_coordinates import ( from .spatial_coordinates import (
x_, x_,
...@@ -51,6 +52,7 @@ __all__ = [ ...@@ -51,6 +52,7 @@ __all__ = [
"create_kernel", "create_kernel",
"KernelFunction", "KernelFunction",
"Target", "Target",
"no_jit",
"show_code", "show_code",
"to_dot", "to_dot",
"get_code_obj", "get_code_obj",
......
...@@ -5,6 +5,7 @@ from .kernelfunction import ( ...@@ -5,6 +5,7 @@ from .kernelfunction import (
FieldStrideParam, FieldStrideParam,
FieldPointerParam, FieldPointerParam,
KernelFunction, KernelFunction,
GpuKernelFunction,
) )
from .constraints import KernelParamsConstraint from .constraints import KernelParamsConstraint
...@@ -16,5 +17,6 @@ __all__ = [ ...@@ -16,5 +17,6 @@ __all__ = [
"FieldStrideParam", "FieldStrideParam",
"FieldPointerParam", "FieldPointerParam",
"KernelFunction", "KernelFunction",
"GpuKernelFunction",
"KernelParamsConstraint", "KernelParamsConstraint",
] ]
from __future__ import annotations from __future__ import annotations
from enum import Enum from enum import Enum
from ..enums import Target
from .ast.structural import ( from .ast.structural import (
PsAstNode, PsAstNode,
PsBlock, PsBlock,
...@@ -53,7 +55,7 @@ from .extensions.foreign_ast import PsForeignExpression ...@@ -53,7 +55,7 @@ from .extensions.foreign_ast import PsForeignExpression
from .symbols import PsSymbol from .symbols import PsSymbol
from ..types import PsScalarType, PsArrayType from ..types import PsScalarType, PsArrayType
from .kernelfunction import KernelFunction from .kernelfunction import KernelFunction, GpuKernelFunction
__all__ = ["emit_code", "CAstPrinter"] __all__ = ["emit_code", "CAstPrinter"]
...@@ -167,10 +169,13 @@ class CAstPrinter: ...@@ -167,10 +169,13 @@ class CAstPrinter:
def __call__(self, obj: PsAstNode | KernelFunction) -> str: def __call__(self, obj: PsAstNode | KernelFunction) -> str:
if isinstance(obj, KernelFunction): if isinstance(obj, KernelFunction):
prefix = self._func_prefix(obj)
params_str = ", ".join( params_str = ", ".join(
f"{p.dtype.c_string()} {p.name}" for p in obj.parameters f"{p.dtype.c_string()} {p.name}" for p in obj.parameters
) )
decl = f"FUNC_PREFIX void {obj.name} ({params_str})"
decl = " ".join([prefix, "void", obj.name, f"({params_str})"])
body_code = self.visit(obj.body, PrinterCtx()) body_code = self.visit(obj.body, PrinterCtx())
return f"{decl}\n{body_code}" return f"{decl}\n{body_code}"
else: else:
...@@ -336,7 +341,7 @@ class CAstPrinter: ...@@ -336,7 +341,7 @@ class CAstPrinter:
items_str = ", ".join(self.visit(item, pc) for item in items) items_str = ", ".join(self.visit(item, pc) for item in items)
pc.pop_op() pc.pop_op()
return "{ " + items_str + " }" return "{ " + items_str + " }"
case PsForeignExpression(children): case PsForeignExpression(children):
pc.push_op(Ops.Weakest, LR.Middle) pc.push_op(Ops.Weakest, LR.Middle)
foreign_code = node.get_code(self.visit(c, pc) for c in children) foreign_code = node.get_code(self.visit(c, pc) for c in children)
...@@ -346,6 +351,12 @@ class CAstPrinter: ...@@ -346,6 +351,12 @@ class CAstPrinter:
case _: case _:
raise NotImplementedError(f"Don't know how to print {node}") raise NotImplementedError(f"Don't know how to print {node}")
def _func_prefix(self, func: KernelFunction):
if isinstance(func, GpuKernelFunction) and func.target == Target.CUDA:
return "__global__"
else:
return "FUNC_PREFIX"
def _symbol_decl(self, symb: PsSymbol): def _symbol_decl(self, symb: PsSymbol):
dtype = symb.get_dtype() dtype = symb.get_dtype()
......
...@@ -11,7 +11,7 @@ from ...types import PsType ...@@ -11,7 +11,7 @@ from ...types import PsType
class PsForeignExpression(PsExpression, ABC): class PsForeignExpression(PsExpression, ABC):
"""Base class for foreign expressions. """Base class for foreign expressions.
Foreign expressions are expressions whose properties are not modelled by the pystencils AST, Foreign expressions are expressions whose properties are not modelled by the pystencils AST,
and which pystencils therefore does not understand. and which pystencils therefore does not understand.
...@@ -24,9 +24,7 @@ class PsForeignExpression(PsExpression, ABC): ...@@ -24,9 +24,7 @@ class PsForeignExpression(PsExpression, ABC):
__match_args__ = ("children",) __match_args__ = ("children",)
def __init__( def __init__(self, children: Iterable[PsExpression], dtype: PsType | None = None):
self, children: Iterable[PsExpression], dtype: PsType | None = None
):
self._children = list(children) self._children = list(children)
super().__init__(dtype) super().__init__(dtype)
......
...@@ -5,6 +5,7 @@ from .typification import Typifier ...@@ -5,6 +5,7 @@ from .typification import Typifier
from .ast_factory import AstFactory from .ast_factory import AstFactory
from .iteration_space import ( from .iteration_space import (
IterationSpace,
FullIterationSpace, FullIterationSpace,
SparseIterationSpace, SparseIterationSpace,
create_full_iteration_space, create_full_iteration_space,
...@@ -19,6 +20,7 @@ __all__ = [ ...@@ -19,6 +20,7 @@ __all__ = [
"FreezeExpressions", "FreezeExpressions",
"Typifier", "Typifier",
"AstFactory", "AstFactory",
"IterationSpace",
"FullIterationSpace", "FullIterationSpace",
"SparseIterationSpace", "SparseIterationSpace",
"create_full_iteration_space", "create_full_iteration_space",
......
...@@ -208,7 +208,7 @@ class FullIterationSpace(IterationSpace): ...@@ -208,7 +208,7 @@ class FullIterationSpace(IterationSpace):
@property @property
def archetype_field(self) -> Field | None: def archetype_field(self) -> Field | None:
return self._archetype_field return self._archetype_field
def dimensions_in_loop_order(self) -> Sequence[FullIterationSpace.Dimension]: def dimensions_in_loop_order(self) -> Sequence[FullIterationSpace.Dimension]:
"""Return the dimensions of this iteration space ordered from the fastest to the slowest coordinate. """Return the dimensions of this iteration space ordered from the fastest to the slowest coordinate.
...@@ -220,7 +220,9 @@ class FullIterationSpace(IterationSpace): ...@@ -220,7 +220,9 @@ class FullIterationSpace(IterationSpace):
else: else:
return self._dimensions return self._dimensions
def actual_iterations(self, dimension: int | FullIterationSpace.Dimension | None = None) -> PsExpression: def actual_iterations(
self, dimension: int | FullIterationSpace.Dimension | None = None
) -> PsExpression:
if dimension is None: if dimension is None:
return reduce( return reduce(
mul, (self.actual_iterations(d) for d in range(len(self.dimensions))) mul, (self.actual_iterations(d) for d in range(len(self.dimensions)))
......
from pystencils.backend.functions import CFunction, PsMathFunction from pystencils.backend.functions import CFunction, PsMathFunction
from pystencils.types.types import PsType from pystencils.backend.kernelcreation.context import KernelCreationContext
from .platform import Platform from pystencils.types import PsType, constify
from ..exceptions import MaterializationError
from .generic_gpu import GenericGpu, GpuThreadsRange
from ..kernelcreation.iteration_space import ( from ..kernelcreation import (
Typifier,
IterationSpace, IterationSpace,
FullIterationSpace, FullIterationSpace,
# SparseIterationSpace, SparseIterationSpace,
) )
from ..ast.structural import PsBlock, PsConditional from ..ast.structural import PsBlock, PsConditional, PsDeclaration
from ..ast.expressions import ( from ..ast.expressions import PsExpression, PsLiteralExpr, PsCast
PsExpression,
PsLiteralExpr,
PsAdd,
)
from ..ast.expressions import PsLt, PsAnd from ..ast.expressions import PsLt, PsAnd
from ...types import PsSignedIntegerType from ...types import PsSignedIntegerType
from ..literals import PsLiteral from ..literals import PsLiteral
from ...config import GpuIndexingConfig
int32 = PsSignedIntegerType(width=32, const=False) int32 = PsSignedIntegerType(width=32, const=False)
...@@ -34,7 +34,14 @@ GRID_DIM = [ ...@@ -34,7 +34,14 @@ GRID_DIM = [
] ]
class CudaPlatform(Platform): class CudaPlatform(GenericGpu):
def __init__(
self, ctx: KernelCreationContext, indexing_cfg: GpuIndexingConfig | None
) -> None:
super().__init__(ctx)
self._cfg = indexing_cfg if indexing_cfg is not None else GpuIndexingConfig()
self._typify = Typifier(ctx)
@property @property
def required_headers(self) -> set[str]: def required_headers(self) -> set[str]:
...@@ -42,20 +49,13 @@ class CudaPlatform(Platform): ...@@ -42,20 +49,13 @@ class CudaPlatform(Platform):
def materialize_iteration_space( def materialize_iteration_space(
self, body: PsBlock, ispace: IterationSpace self, body: PsBlock, ispace: IterationSpace
) -> PsBlock: ) -> tuple[PsBlock, GpuThreadsRange]:
if isinstance(ispace, FullIterationSpace): if isinstance(ispace, FullIterationSpace):
return self._guard_full_iteration_space(body, ispace) return self._prepend_dense_translation(body, ispace)
elif isinstance(ispace, SparseIterationSpace):
return self._prepend_sparse_translation(body, ispace)
else: else:
assert False, "unreachable code" raise MaterializationError(f"Unknown type of iteration space: {ispace}")
def cuda_indices(self, dim):
block_size = BLOCK_DIM
indices = [
block_index * bs + thread_idx
for block_index, bs, thread_idx in zip(BLOCK_IDX, block_size, THREAD_IDX)
]
return indices[:dim]
def select_function( def select_function(
self, math_function: PsMathFunction, dtype: PsType self, math_function: PsMathFunction, dtype: PsType
...@@ -63,26 +63,57 @@ class CudaPlatform(Platform): ...@@ -63,26 +63,57 @@ class CudaPlatform(Platform):
raise NotImplementedError() raise NotImplementedError()
# Internals # Internals
def _guard_full_iteration_space(
def _prepend_dense_translation(
self, body: PsBlock, ispace: FullIterationSpace self, body: PsBlock, ispace: FullIterationSpace
) -> PsBlock: ) -> tuple[PsBlock, GpuThreadsRange]:
dimensions = ispace.dimensions_in_loop_order()
launch_config = GpuThreadsRange.from_ispace(ispace)
indexing_decls = []
conds = []
for i, dim in enumerate(dimensions[::-1]):
dim.counter.dtype = constify(dim.counter.get_dtype())
ctr = PsExpression.make(dim.counter)
indexing_decls.append(
self._typify(
PsDeclaration(
ctr,
dim.start
+ dim.step
* PsCast(ctr.get_dtype(), self._linear_thread_idx(i)),
)
)
)
if not self._cfg.omit_range_check:
conds.append(PsLt(ctr, dim.stop))
if conds:
condition: PsExpression = conds[0]
for cond in conds[1:]:
condition = PsAnd(condition, cond)
ast = PsBlock(indexing_decls + [PsConditional(condition, body)])
else:
body.statements = indexing_decls + body.statements
ast = body
dimensions = ispace.dimensions return ast, launch_config
# Determine loop order by permuting dimensions def _prepend_sparse_translation(
archetype_field = ispace.archetype_field self, body: PsBlock, ispace: SparseIterationSpace
if archetype_field is not None: ) -> tuple[PsBlock, GpuThreadsRange]:
loop_order = archetype_field.layout ispace.sparse_counter.dtype = constify(ispace.sparse_counter.get_dtype())
dimensions = [dimensions[coordinate] for coordinate in loop_order]
start = [ ctr = PsExpression.make(ispace.sparse_counter)
PsAdd(c, d.start) thread_idx = self._linear_thread_idx(0)
for c, d in zip(self.cuda_indices(len(dimensions)), dimensions[::-1]) idx_decl = self._typify(PsDeclaration(ctr, PsCast(ctr.get_dtype(), thread_idx)))
] body.statements = [idx_decl] + body.statements
conditions = [PsLt(c, d.stop) for c, d in zip(start, dimensions[::-1])]
condition: PsExpression = conditions[0] return body, GpuThreadsRange.from_ispace(ispace)
for c in conditions[1:]:
condition = PsAnd(condition, c)
return PsBlock([PsConditional(condition, body)]) def _linear_thread_idx(self, coord: int):
block_size = BLOCK_DIM[coord]
block_idx = BLOCK_IDX[coord]
thread_idx = THREAD_IDX[coord]
return block_idx * block_size + thread_idx
...@@ -89,7 +89,7 @@ class SyclPlatform(GenericGpu): ...@@ -89,7 +89,7 @@ class SyclPlatform(GenericGpu):
if conds: if conds:
condition: PsExpression = conds[0] condition: PsExpression = conds[0]
for cond in conds: for cond in conds[1:]:
condition = PsAnd(condition, cond) condition = PsAnd(condition, cond)
ast = PsBlock(indexing_decls + [PsConditional(condition, body)]) ast = PsBlock(indexing_decls + [PsConditional(condition, body)])
else: else:
......
...@@ -71,13 +71,13 @@ class Target(Flag): ...@@ -71,13 +71,13 @@ class Target(Flag):
found on the current machine and runtime environment. found on the current machine and runtime environment.
""" """
GenericCUDA = _GPU | _CUDA CUDA = _GPU | _CUDA
"""Generic CUDA GPU target. """Generic CUDA GPU target.
Generate a CUDA kernel for a generic Nvidia GPU. Generate a CUDA kernel for a generic Nvidia GPU.
""" """
GPU = GenericCUDA GPU = CUDA
"""Alias for backward compatibility.""" """Alias for backward compatibility."""
SYCL = _GPU | _SYCL SYCL = _GPU | _SYCL
......
...@@ -86,16 +86,30 @@ def create_kernel( ...@@ -86,16 +86,30 @@ def create_kernel(
platform = GenericCpu(ctx) platform = GenericCpu(ctx)
kernel_ast = platform.materialize_iteration_space(kernel_body, ispace) kernel_ast = platform.materialize_iteration_space(kernel_body, ispace)
case Target.SYCL:
from .backend.platforms import SyclPlatform
platform = SyclPlatform(ctx, config.gpu_indexing) case target if target.is_gpu():
match target:
case Target.SYCL:
from .backend.platforms import SyclPlatform
platform = SyclPlatform(ctx, config.gpu_indexing)
case Target.CUDA:
from .backend.platforms import CudaPlatform
platform = CudaPlatform(ctx, config.gpu_indexing)
case _:
raise NotImplementedError(
f"Code generation for target {target} not implemented"
)
kernel_ast, gpu_threads = platform.materialize_iteration_space( kernel_ast, gpu_threads = platform.materialize_iteration_space(
kernel_body, ispace kernel_body, ispace
) )
case _: case _:
# TODO: CUDA/HIP platform raise NotImplementedError(
raise NotImplementedError("Target platform not implemented") f"Code generation for target {target} not implemented"
)
# Simplifying transformations # Simplifying transformations
elim_constants = EliminateConstants(ctx, extract_constant_exprs=True) elim_constants = EliminateConstants(ctx, extract_constant_exprs=True)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment