From 90837d04eaad5b3ad049c11fa5af48cf0942e812 Mon Sep 17 00:00:00 2001
From: zy69guqi <richard.angersbach@fau.de>
Date: Thu, 20 Mar 2025 15:36:01 +0100
Subject: [PATCH] Merge handling for GPU reductions into generic_gpu.py for the time being

---
 .../backend/platforms/generic_gpu.py | 147 +++++++++++++++---
 src/pystencils/codegen/driver.py     |   5 +
 2 files changed, 130 insertions(+), 22 deletions(-)

diff --git a/src/pystencils/backend/platforms/generic_gpu.py b/src/pystencils/backend/platforms/generic_gpu.py
index 11425d923..f16c28e8c 100644
--- a/src/pystencils/backend/platforms/generic_gpu.py
+++ b/src/pystencils/backend/platforms/generic_gpu.py
@@ -1,7 +1,16 @@
 from __future__ import annotations
-from abc import ABC, abstractmethod
-from ...types import constify, deconstify
+import math
+import operator
+from abc import ABC, abstractmethod
+from functools import reduce
+
+from ..ast import PsAstNode
+from ..constants import PsConstant
+from ...compound_op_mapping import compound_op_to_expr
+from ...sympyextensions.reduction import ReductionOp
+from ...types import constify, deconstify, PsPointerType, PsScalarType, PsCustomType
+from ...types.quick import UInt, SInt
 
 from ..exceptions import MaterializationError
 from .platform import Platform
@@ -15,7 +24,7 @@ from ..kernelcreation import (
 )
 
 from ..kernelcreation.context import KernelCreationContext
-from ..ast.structural import PsBlock, PsConditional, PsDeclaration
+from ..ast.structural import PsBlock, PsConditional, PsDeclaration, PsStatement, PsAssignment
 from ..ast.expressions import (
     PsExpression,
     PsLiteralExpr,
@@ -23,12 +32,17 @@ from ..kernelcreation.context import KernelCreationContext
     PsCall,
     PsLookup,
     PsBufferAcc,
+    PsSymbolExpr,
+    PsConstantExpr,
+    PsAdd,
+    PsRem,
+    PsEq
 )
 from ..ast.expressions import PsLt, PsAnd
 from ...types import PsSignedIntegerType, PsIeeeFloatType
 from ..literals import PsLiteral
-from ..functions import PsMathFunction, MathFunctions, CFunction
-
+from ..functions import MathFunctions, CFunction, ReductionFunctions, NumericLimitsFunctions, PsReductionFunction, \
+    PsMathFunction
 
 
 int32 = PsSignedIntegerType(width=32, const=False)
@@ -174,10 +188,15 @@ class GenericGpu(Platform):
     def __init__(
         self,
         ctx: KernelCreationContext,
+        assume_warp_aligned_block_size: bool,
+        warp_size: int | None,
         thread_mapping: ThreadMapping | None = None,
     ) -> None:
         super().__init__(ctx)
 
+        self._assume_warp_aligned_block_size = assume_warp_aligned_block_size
+        self._warp_size = warp_size
+
         self._thread_mapping = (
             thread_mapping if thread_mapping is not None else Linear3DMapping()
         )
@@ -194,14 +213,107 @@ class GenericGpu(Platform):
         else:
             raise MaterializationError(f"Unknown type of iteration space: {ispace}")
 
-    def select_function(self, call: PsCall) -> PsExpression:
-        assert isinstance(call.function, PsMathFunction)
+    @staticmethod
+    def _get_condition_for_translation(ispace: IterationSpace):
+
+        if isinstance(ispace, FullIterationSpace):
+            conds = []
+
+            dimensions = ispace.dimensions_in_loop_order()
+
+            for dim in dimensions:
+                ctr_expr = PsExpression.make(dim.counter)
+                conds.append(PsLt(ctr_expr, dim.stop))
+
+            condition: PsExpression = conds[0]
+            for cond in conds[1:]:
+                condition = PsAnd(condition, cond)
+
+            return condition
+        elif isinstance(ispace, SparseIterationSpace):
+            sparse_ctr_expr = PsExpression.make(ispace.sparse_counter)
+            stop = PsExpression.make(ispace.index_list.shape[0])
+
+            return PsLt(sparse_ctr_expr.clone(), stop)
+        else:
+            raise MaterializationError(f"Unknown type of iteration space: {ispace}")
+
+    def select_function(self,
+                        call: PsCall) -> PsExpression | tuple[tuple[PsAstNode, ...], PsAstNode]:
+        call_func = call.function
+        assert isinstance(call_func, PsReductionFunction | PsMathFunction)
+
+        func = call_func.func
+
+        if isinstance(call_func, PsReductionFunction) and func is ReductionFunctions.WriteBackToPtr:
+            ptr_expr, symbol_expr = call.args
+            op = call_func.reduction_op
+            stype = symbol_expr.dtype
+            ptrtype = ptr_expr.dtype
+
+            assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptrtype, PsPointerType)
+            assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(stype, PsScalarType)
+
+            if not isinstance(stype, PsIeeeFloatType) or stype.width not in (32, 64):
+                raise NotImplementedError("atomic operations are only available for float32/64 datatypes")
+
+            # workaround for subtractions -> use additions for reducing intermediate results
+            # similar to OpenMP reductions: local copies (negative sign) are added at the end
+            match op:
+                case ReductionOp.Sub:
+                    actual_op = ReductionOp.Add
+                case _:
+                    actual_op = op
+
+            # perform local warp reductions
+            def gen_shuffle_instr(offset: int):
+                full_mask = PsLiteralExpr(PsLiteral("0xffffffff", UInt(32)))
+                return PsCall(CFunction("__shfl_xor_sync", [UInt(32), stype, SInt(32)], stype),
+                              [full_mask, symbol_expr, PsConstantExpr(PsConstant(offset, SInt(32)))])
+
+            num_shuffles = math.frexp(self._warp_size)[1]
+            shuffles = tuple(PsAssignment(symbol_expr,
+                                          compound_op_to_expr(actual_op,
+                                                              symbol_expr, gen_shuffle_instr(pow(2, i - 1))))
+                             for i in reversed(range(1, num_shuffles)))
+
+            # find first thread in warp
+            ispace = self._ctx.get_iteration_space()
+            is_valid_thread = self._get_condition_for_translation(ispace)
+            thread_indices_per_dim = [
+                idx * PsConstantExpr(PsConstant(reduce(operator.mul, BLOCK_DIM[:i], 1), SInt(32)))
+                for i, idx in enumerate(THREAD_IDX[:ispace.rank])
+            ]
+            tid: PsExpression = thread_indices_per_dim[0]
+            for t in thread_indices_per_dim[1:]:
+                tid = PsAdd(tid, t)
+            first_thread_in_warp = PsEq(PsRem(tid, PsConstantExpr(PsConstant(self._warp_size, SInt(32)))),
+                                        PsConstantExpr(PsConstant(0, SInt(32))))
+            cond = PsAnd(is_valid_thread, first_thread_in_warp) if is_valid_thread else first_thread_in_warp
+
+            # use atomic operation on first thread of warp to sync
+            call.function = CFunction(f"atomic{actual_op.name}", [ptrtype, stype], PsCustomType("void"))
+            call.args = (ptr_expr, symbol_expr)
+
+            # assemble warp reduction
+            return shuffles, PsConditional(cond, PsBlock([PsStatement(call)]))
 
-        func = call.function.func
         dtype = call.get_dtype()
         arg_types = (dtype,) * func.num_args
 
-        if isinstance(dtype, PsIeeeFloatType):
+        if isinstance(dtype, PsScalarType) and func in NumericLimitsFunctions:
+            assert isinstance(dtype, PsIeeeFloatType)
+
+            match func:
+                case NumericLimitsFunctions.Min:
+                    define = "NEG_INFINITY"
+                case NumericLimitsFunctions.Max:
+                    define = "POS_INFINITY"
+                case _:
+                    raise MaterializationError(f"Cannot materialize call to function {func}")
+
+            return PsLiteralExpr(PsLiteral(define, dtype))
+
+        if isinstance(dtype, PsIeeeFloatType) and func in MathFunctions:
             match func:
                 case (
                     MathFunctions.Exp
@@ -262,7 +374,6 @@ class GenericGpu(Platform):
         ctr_mapping = self._thread_mapping(ispace)
 
         indexing_decls = []
-        conds = []
 
         dimensions = ispace.dimensions_in_loop_order()
 
@@ -276,14 +387,9 @@ class GenericGpu(Platform):
             indexing_decls.append(
                 self._typify(PsDeclaration(ctr_expr, ctr_mapping[dim.counter]))
             )
-            conds.append(PsLt(ctr_expr, dim.stop))
-
-        condition: PsExpression = conds[0]
-        for cond in conds[1:]:
-            condition = PsAnd(condition, cond)
-
-        ast = PsBlock(indexing_decls + [PsConditional(condition, body)])
-        return ast
+
+        cond = self._get_condition_for_translation(ispace)
+        return PsBlock(indexing_decls + [PsConditional(cond, body)])
 
     def _prepend_sparse_translation(
         self, body: PsBlock, ispace: SparseIterationSpace
@@ -313,8 +419,5 @@ class GenericGpu(Platform):
         ]
         body.statements = mappings + body.statements
 
-        stop = PsExpression.make(ispace.index_list.shape[0])
-        condition = PsLt(sparse_ctr_expr.clone(), stop)
-        ast = PsBlock([sparse_idx_decl, PsConditional(condition, body)])
-
-        return ast
+        cond = self._get_condition_for_translation(ispace)
+        return PsBlock([sparse_idx_decl, PsConditional(cond, body)])
diff --git a/src/pystencils/codegen/driver.py b/src/pystencils/codegen/driver.py
index c2bee0ad2..3962c316b 100644
--- a/src/pystencils/codegen/driver.py
+++ b/src/pystencils/codegen/driver.py
@@ -475,6 +475,9 @@ class DefaultKernelCreationDriver:
             else None
         )
 
+        assume_warp_aligned_block_size: bool = self._cfg.gpu.get_option("assume_warp_aligned_block_size")
+        warp_size: int | None = self._cfg.gpu.get_option("warp_size")
+
         GpuPlatform: type
         match self._target:
             case Target.CUDA:
@@ -486,6 +489,8 @@ class DefaultKernelCreationDriver:
 
         return GpuPlatform(
             self._ctx,
+            assume_warp_aligned_block_size,
+            warp_size,
             thread_mapping=thread_mapping,
         )
 
-- 
GitLab
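
Note on the warp-level reduction in select_function: the shuffle count comes from math.frexp(self._warp_size)[1], which for a warp size of 32 yields offsets 16, 8, 4, 2, 1, i.e. a butterfly reduction over XOR-ed lane indices. The following pure-Python sketch (illustration only, assuming a power-of-two warp size and an additive reduction) emulates the data exchange performed by the emitted __shfl_xor_sync calls and checks that every lane ends up holding the full warp sum:

    import math

    def emulate_warp_reduce(lane_values, warp_size=32):
        # emulate the XOR-shuffle butterfly emitted via __shfl_xor_sync
        vals = list(lane_values)
        num_shuffles = math.frexp(warp_size)[1]      # 6 for warp_size == 32
        for i in reversed(range(1, num_shuffles)):   # offsets 16, 8, 4, 2, 1
            offset = 2 ** (i - 1)
            # each lane reads its partner lane's value and accumulates it
            vals = [vals[lane] + vals[lane ^ offset] for lane in range(warp_size)]
        return vals

    assert all(v == sum(range(32)) for v in emulate_warp_reduce(range(32)))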
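
Note on the first-thread-in-warp test: the generated condition linearizes the thread index as threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y (the products of BLOCK_DIM prefixes built above) and compares it modulo the warp size against zero, so exactly one lane per warp issues the atomic write-back. A small sketch of that arithmetic (the helper name is made up for illustration; it assumes a 3D block):

    import operator
    from functools import reduce

    def is_first_thread_in_warp(thread_idx, block_dim, warp_size=32):
        # strides (1, blockDim.x, blockDim.x * blockDim.y), mirroring the BLOCK_DIM prefix products
        strides = [reduce(operator.mul, block_dim[:i], 1) for i in range(len(thread_idx))]
        tid = sum(idx * stride for idx, stride in zip(thread_idx, strides))
        return tid % warp_size == 0

    # in a (16, 16, 1) block, thread (0, 2, 0) has linear id 32 -> first lane of its warp
    assert is_first_thread_in_warp((0, 2, 0), (16, 16, 1))
    assert not is_first_thread_in_warp((8, 1, 0), (16, 16, 1))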
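
Note on the Sub -> Add rewrite: the workaround relies on the identity init - (x0 + x1 + ...) == init + (-x0) + (-x1) + ..., so per-thread contributions to a subtraction reduction can carry their sign locally and still be combined with an additive atomic, mirroring how OpenMP treats the '-' reduction operator. A tiny numeric check of that identity (plain Python, purely illustrative):

    init = 100.0
    contributions = [3.0, 5.0, 7.0]

    sequential = init - sum(contributions)                 # direct subtraction reduction
    atomic_style = init + sum(-c for c in contributions)   # sign carried by the local copies

    assert sequential == atomic_style == 85.0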
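
Note on the numeric-limits mapping: lowering NumericLimitsFunctions.Min/Max to the NEG_INFINITY/POS_INFINITY device literals presumably provides the values used to seed reduction accumulators, since -inf is the neutral element of a maximum and +inf of a minimum. A short check with Python floats standing in for the CUDA/HIP defines (illustrative only):

    neg_inf, pos_inf = float("-inf"), float("inf")
    data = [4.0, -2.5, 7.25]

    # seeding with the opposite infinity does not disturb the reduction result
    assert max([neg_inf] + data) == max(data) == 7.25
    assert min([pos_inf] + data) == min(data) == -2.5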