diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py
index eff88df7e2dcb317403ff417bd996f5e8acb09c0..6f32102deafac5a40c37fe13bcb62bc5426107c0 100644
--- a/src/pystencils/backend/platforms/cuda.py
+++ b/src/pystencils/backend/platforms/cuda.py
@@ -1,9 +1,16 @@
 from __future__ import annotations
+
+import math
+import operator
 from abc import ABC, abstractmethod
+from functools import reduce
 
 from ..ast import PsAstNode
+from ..constants import PsConstant
+from ...compound_op_mapping import compound_op_to_expr
 from ...sympyextensions.reduction import ReductionOp
 from ...types import constify, deconstify, PsPointerType, PsScalarType, PsCustomType
+from ...types.quick import UInt, SInt
 from ..exceptions import MaterializationError
 from .generic_gpu import GenericGpu
 
@@ -17,14 +24,14 @@ from ..kernelcreation import (
 )
 
 from ..kernelcreation.context import KernelCreationContext
-from ..ast.structural import PsBlock, PsConditional, PsDeclaration, PsStatement
+from ..ast.structural import PsBlock, PsConditional, PsDeclaration, PsStatement, PsAssignment
 from ..ast.expressions import (
     PsExpression,
     PsLiteralExpr,
     PsCast,
     PsCall,
     PsLookup,
-    PsBufferAcc, PsSymbolExpr
+    PsBufferAcc, PsSymbolExpr, PsConstantExpr, PsAdd, PsRem, PsEq
 )
 from ..ast.expressions import PsLt, PsAnd
 from ...types import PsSignedIntegerType, PsIeeeFloatType
@@ -292,26 +299,58 @@ class CudaPlatform(GenericGpu):
             case ReductionFunctions.WriteBackToPtr:
                 ptr_expr, symbol_expr = call.args
                 op = call.function.reduction_op
+                stype = symbol_expr.dtype
+                ptrtype = ptr_expr.dtype
+
+                warp_size = 32   # TODO: get from platform/user config
+
+                assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptrtype, PsPointerType)
+                assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(stype, PsScalarType)
 
-                assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptr_expr.dtype, PsPointerType)
-                assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(symbol_expr.dtype, PsScalarType)
+                if not isinstance(stype, PsIeeeFloatType) or stype.width not in (32, 64):
+                    raise NotImplementedError("atomic operations are only available for float32/64 datatypes")
 
+                def gen_shuffle_instr(offset: int):
+                    return PsCall(CFunction("__shfl_xor_sync", [UInt(32), stype, SInt(32)], stype),
+                                  [PsLiteralExpr(PsLiteral("0xffffffff", UInt(32))),
+                                   symbol_expr,
+                                   PsConstantExpr(PsConstant(offset, SInt(32)))])
+
+                # workaround for subtractions -> use additions for reducing intermediate results
+                # similar to OpenMP reductions: local copies (negative sign) are added at the end
                 match op:
                     case ReductionOp.Sub:
-                        # workaround for unsupported atomicSub: use atomic add
-                        # similar to OpenMP reductions: local copies (negative sign) are added at the end
-                        call.function = CFunction("atomicAdd", [ptr_expr.dtype, symbol_expr.dtype],
-                                                  PsCustomType("void"))
-                        call.args = (ptr_expr, symbol_expr)
+                        actual_op = ReductionOp.Add
                     case _:
-                        call.function = CFunction(f"atomic{op.name}", [ptr_expr.dtype, symbol_expr.dtype],
-                                                  PsCustomType("void"))
-                        call.args = (ptr_expr, symbol_expr)
-
-                if not isinstance(symbol_expr.dtype, PsIeeeFloatType) or symbol_expr.dtype.width not in (32, 64):
-                    NotImplementedError("atomicMul is only available for float32/64 datatypes")
-
-                return PsStatement(call)
+                        actual_op = op
+
+                # perform local warp reductions
+                num_shuffles = math.frexp(warp_size)[1]
+                shuffles = [PsAssignment(symbol_expr, compound_op_to_expr(actual_op, symbol_expr, gen_shuffle_instr(pow(2, i - 1))))
+                            for i in reversed(range(1, num_shuffles))]
+
+                # find first thread in warp
+                ispace = self._ctx.get_iteration_space()  # TODO: receive as argument in unfold_function?
+                is_valid_thread = self._get_condition_for_translation(ispace)
+                thread_indices_per_dim = [
+                    idx * PsConstantExpr(PsConstant(reduce(operator.mul, BLOCK_DIM[:i], 1), SInt(32)))
+                    for i, idx in enumerate(THREAD_IDX[:ispace.rank])
+                ]
+                tid: PsExpression = thread_indices_per_dim[0]
+                for t in thread_indices_per_dim[1:]:
+                    tid = PsAdd(tid, t)
+                first_thread_in_warp = PsEq(PsRem(tid, PsConstantExpr(PsConstant(warp_size, SInt(32)))),
+                                            PsConstantExpr(PsConstant(0, SInt(32))))
+                cond = PsAnd(is_valid_thread, first_thread_in_warp) if is_valid_thread else first_thread_in_warp
+
+                # use atomic operation on first thread of warp to sync
+                call.function = CFunction(f"atomic{actual_op.name}", [ptrtype, stype], PsCustomType("void"))
+                call.args = (ptr_expr, symbol_expr)
+
+                # assemble warp reduction
+                return PsBlock(
+                    shuffles
+                    + [PsConditional(cond, PsBlock([PsStatement(call)]))])
 
     #   Internals