diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py index eff88df7e2dcb317403ff417bd996f5e8acb09c0..6f32102deafac5a40c37fe13bcb62bc5426107c0 100644 --- a/src/pystencils/backend/platforms/cuda.py +++ b/src/pystencils/backend/platforms/cuda.py @@ -1,9 +1,16 @@ from __future__ import annotations + +import math +import operator from abc import ABC, abstractmethod +from functools import reduce from ..ast import PsAstNode +from ..constants import PsConstant +from ...compound_op_mapping import compound_op_to_expr from ...sympyextensions.reduction import ReductionOp from ...types import constify, deconstify, PsPointerType, PsScalarType, PsCustomType +from ...types.quick import UInt, SInt from ..exceptions import MaterializationError from .generic_gpu import GenericGpu @@ -17,14 +24,14 @@ from ..kernelcreation import ( ) from ..kernelcreation.context import KernelCreationContext -from ..ast.structural import PsBlock, PsConditional, PsDeclaration, PsStatement +from ..ast.structural import PsBlock, PsConditional, PsDeclaration, PsStatement, PsAssignment from ..ast.expressions import ( PsExpression, PsLiteralExpr, PsCast, PsCall, PsLookup, - PsBufferAcc, PsSymbolExpr + PsBufferAcc, PsSymbolExpr, PsConstantExpr, PsAdd, PsRem, PsEq ) from ..ast.expressions import PsLt, PsAnd from ...types import PsSignedIntegerType, PsIeeeFloatType @@ -292,26 +299,58 @@ class CudaPlatform(GenericGpu): case ReductionFunctions.WriteBackToPtr: ptr_expr, symbol_expr = call.args op = call.function.reduction_op + stype = symbol_expr.dtype + ptrtype = ptr_expr.dtype + + warp_size = 32 # TODO: get from platform/user config + + assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptrtype, PsPointerType) + assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(stype, PsScalarType) - assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptr_expr.dtype, PsPointerType) - assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(symbol_expr.dtype, 
PsScalarType) + if not isinstance(stype, PsIeeeFloatType) or stype.width not in (32, 64): + raise NotImplementedError("atomic operations are only available for float32/64 datatypes") + def gen_shuffle_instr(offset: int): + return PsCall(CFunction("__shfl_xor_sync", [UInt(32), stype, SInt(32)], stype), + [PsLiteralExpr(PsLiteral("0xffffffff", UInt(32))), + symbol_expr, + PsConstantExpr(PsConstant(offset, SInt(32)))]) + + # workaround for subtractions -> use additions for reducing intermediate results + # similar to OpenMP reductions: local copies (negative sign) are added at the end match op: case ReductionOp.Sub: - # workaround for unsupported atomicSub: use atomic add - # similar to OpenMP reductions: local copies (negative sign) are added at the end - call.function = CFunction("atomicAdd", [ptr_expr.dtype, symbol_expr.dtype], - PsCustomType("void")) - call.args = (ptr_expr, symbol_expr) + actual_op = ReductionOp.Add case _: - call.function = CFunction(f"atomic{op.name}", [ptr_expr.dtype, symbol_expr.dtype], - PsCustomType("void")) - call.args = (ptr_expr, symbol_expr) - - if not isinstance(symbol_expr.dtype, PsIeeeFloatType) or symbol_expr.dtype.width not in (32, 64): - NotImplementedError("atomicMul is only available for float32/64 datatypes") - - return PsStatement(call) + actual_op = op + + # perform local warp reductions with power-of-two XOR lane masks (16, 8, 4, 2, 1) + num_shuffles = math.frexp(warp_size)[1] + shuffles = [PsAssignment(symbol_expr, compound_op_to_expr(actual_op, symbol_expr, gen_shuffle_instr(pow(2, i - 1)))) + for i in reversed(range(1, num_shuffles))] + + # find first thread in warp + ispace = self._ctx.get_iteration_space() # TODO: receive as argument in unfold_function? 
+ is_valid_thread = self._get_condition_for_translation(ispace) + thread_indices_per_dim = [ + idx * PsConstantExpr(PsConstant(reduce(operator.mul, BLOCK_DIM[:i], 1), SInt(32))) + for i, idx in enumerate(THREAD_IDX[:ispace.rank]) + ] + tid: PsExpression = thread_indices_per_dim[0] + for t in thread_indices_per_dim[1:]: + tid = PsAdd(tid, t) + first_thread_in_warp = PsEq(PsRem(tid, PsConstantExpr(PsConstant(warp_size, SInt(32)))), + PsConstantExpr(PsConstant(0, SInt(32)))) + cond = PsAnd(is_valid_thread, first_thread_in_warp) if is_valid_thread else first_thread_in_warp + + # use atomic operation on first thread of warp to sync + call.function = CFunction(f"atomic{actual_op.name}", [ptrtype, stype], PsCustomType("void")) + call.args = (ptr_expr, symbol_expr) + + # assemble warp reduction + return PsBlock( + shuffles + + [PsConditional(cond, PsBlock([PsStatement(call)]))]) # Internals