diff --git a/src/pystencils/backend/functions.py b/src/pystencils/backend/functions.py
index 18c2277cf76102f9265114853f97b8e2eb50cc67..201321693395ca3716e70170fbbe511af3a4fca8 100644
--- a/src/pystencils/backend/functions.py
+++ b/src/pystencils/backend/functions.py
@@ -30,6 +30,7 @@ from typing import Any, Sequence, TYPE_CHECKING
 from abc import ABC
 from enum import Enum
 
+from ..sympyextensions import ReductionOp
 from ..types import PsType
 from .exceptions import PsInternalCompilerError
 
@@ -134,6 +135,48 @@ class PsMathFunction(PsFunction):
         return hash(self._func)
 
 
+class ReductionFunctions(Enum):
+    """Function representing different steps in kernels with reductions supported by the backend.
+
+    Each platform has to materialize these functions to a concrete implementation.
+    """
+
+    InitLocalCopy = ("InitLocalCopy", 2)
+    WriteBackToPtr = ("WriteBackToPtr", 2)
+
+    def __init__(self, func_name, num_args):
+        self.function_name = func_name
+        self.num_args = num_args
+
+
+class PsReductionFunction(PsFunction):
+
+    def __init__(self, func: ReductionFunctions, op: ReductionOp) -> None:
+        super().__init__(func.function_name, func.num_args)
+        self._func = func
+        self._op = op
+
+    @property
+    def func(self) -> ReductionFunctions:
+        return self._func
+
+    @property
+    def op(self) -> ReductionOp:
+        return self._op
+
+    def __str__(self) -> str:
+        return f"{self._func.function_name}"
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, PsReductionFunction):
+            return False
+
+        return self._func == other._func
+
+    def __hash__(self) -> int:
+        return hash(self._func)
+
+
 class CFunction(PsFunction):
     """A concrete C function.
 
diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py
index 62feca26504b753a9898153b4e6fb85e3fc5b2e7..059817bfda92d4714896a86a110cb257ca4cb823 100644
--- a/src/pystencils/backend/kernelcreation/typification.py
+++ b/src/pystencils/backend/kernelcreation/typification.py
@@ -50,7 +50,7 @@ from ..ast.expressions import (
     PsNot,
 )
 from ..ast.vector import PsVecBroadcast, PsVecMemAcc
-from ..functions import PsMathFunction, CFunction
+from ..functions import PsMathFunction, CFunction, PsReductionFunction
 from ..ast.util import determine_memory_object
 from ..exceptions import TypificationError
 
@@ -590,7 +590,7 @@ class Typifier:
 
             case PsCall(function, args):
                 match function:
-                    case PsMathFunction():
+                    case PsMathFunction() | PsReductionFunction():
                         for arg in args:
                             self.visit_expr(arg, tc)
                         tc.infer_dtype(expr)
diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py
index b145b6f76389fb75c5f94eafd3462cb084247b35..33cb28711e59b173ce3a120571653b69db40f122 100644
--- a/src/pystencils/backend/platforms/generic_cpu.py
+++ b/src/pystencils/backend/platforms/generic_cpu.py
@@ -1,11 +1,14 @@
 from abc import ABC, abstractmethod
 from typing import Sequence
 
-from pystencils.backend.ast.expressions import PsCall
+from ..ast.expressions import PsCall, PsMemAcc, PsConstantExpr
 
-from ..functions import CFunction, PsMathFunction, MathFunctions, NumericLimitsFunctions
+from ..ast import PsAstNode
+from ..functions import CFunction, PsMathFunction, MathFunctions, NumericLimitsFunctions, ReductionFunctions, \
+    PsReductionFunction
 from ..literals import PsLiteral
-from ...types import PsIntegerType, PsIeeeFloatType, PsScalarType
+from ...compound_op_mapping import compound_op_to_expr
+from ...types import PsIntegerType, PsIeeeFloatType, PsScalarType, PsPointerType
 
 from .platform import Platform
 from ..exceptions import MaterializationError
@@ -18,7 +21,7 @@ from ..kernelcreation.iteration_space import (
 )
 
 from ..constants import PsConstant
-from ..ast.structural import PsDeclaration, PsLoop, PsBlock
+from ..ast.structural import PsDeclaration, PsLoop, PsBlock, PsAssignment
 from ..ast.expressions import (
     PsSymbolExpr,
     PsExpression,
@@ -56,6 +59,36 @@ class GenericCpu(Platform):
         else:
             raise MaterializationError(f"Unknown type of iteration space: {ispace}")
 
+    def unfold_function(
+        self, call: PsCall
+    ) -> PsAstNode:
+        assert isinstance(call.function, PsReductionFunction)
+
+        func = call.function.func
+
+        match func:
+            case ReductionFunctions.InitLocalCopy:
+                symbol_expr, init_val = call.args
+                assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(init_val, PsExpression)
+
+                return PsDeclaration(symbol_expr, init_val)
+            case ReductionFunctions.WriteBackToPtr:
+                ptr_expr, symbol_expr = call.args
+                op = call.function.op
+
+                assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptr_expr.dtype, PsPointerType)
+                assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(symbol_expr.dtype, PsScalarType)
+
+                ptr_access = PsMemAcc(ptr_expr, PsConstantExpr(PsConstant(0, self._ctx.index_dtype)))
+
+                # TODO: can this be avoided somehow?
+                potential_call = compound_op_to_expr(op, ptr_access, symbol_expr)
+                if isinstance(potential_call, PsCall):
+                    potential_call.dtype = symbol_expr.dtype
+                    potential_call = self.select_function(potential_call)
+
+                return PsAssignment(ptr_access, potential_call)
+
     def select_function(self, call: PsCall) -> PsExpression:
         assert isinstance(call.function, PsMathFunction)
 
diff --git a/src/pystencils/backend/platforms/platform.py b/src/pystencils/backend/platforms/platform.py
index 2c7ee1c5f4750eac0375bc31a3f44b9eea50642b..732f37bbcd75e6f839728a16add34949dfc95c05 100644
--- a/src/pystencils/backend/platforms/platform.py
+++ b/src/pystencils/backend/platforms/platform.py
@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
 from typing import Any
 
+from ..ast import PsAstNode
 from ..ast.structural import PsBlock
 from ..ast.expressions import PsCall, PsExpression
 
@@ -40,3 +41,13 @@ class Platform(ABC):
         If no viable implementation exists, raise a `MaterializationError`.
         """
         pass
+
+    @abstractmethod
+    def unfold_function(
+        self, call: PsCall
+    ) -> PsAstNode:
+        """Unfolds an implementation for the given function on the given data type.
+
+        If no viable implementation exists, raise a `MaterializationError`.
+        """
+        pass
diff --git a/src/pystencils/backend/transformations/select_functions.py b/src/pystencils/backend/transformations/select_functions.py
index e41c345ae4ed71101d07fcaa5b9df88b1e0f54e2..0045de87b7b5a4203430fd74a3641ca831826419 100644
--- a/src/pystencils/backend/transformations/select_functions.py
+++ b/src/pystencils/backend/transformations/select_functions.py
@@ -1,7 +1,7 @@
 from ..platforms import Platform
 from ..ast import PsAstNode
 from ..ast.expressions import PsCall
-from ..functions import PsMathFunction
+from ..functions import PsMathFunction, PsReductionFunction
 
 
 class SelectFunctions:
@@ -19,5 +19,7 @@ class SelectFunctions:
 
         if isinstance(node, PsCall) and isinstance(node.function, PsMathFunction):
             return self._platform.select_function(node)
+        elif isinstance(node, PsCall) and isinstance(node.function, PsReductionFunction):
+            return self._platform.unfold_function(node)
         else:
             return node
diff --git a/src/pystencils/codegen/driver.py b/src/pystencils/codegen/driver.py
index ba7df317acab94e30c26025d246bd06d2bea5490..9a80439e7227e4f35b47556087a4281f2922de00 100644
--- a/src/pystencils/codegen/driver.py
+++ b/src/pystencils/codegen/driver.py
@@ -7,14 +7,14 @@ from .config import CreateKernelConfig, OpenMpConfig, VectorizationConfig, AUTO
 from .kernel import Kernel, GpuKernel, GpuThreadsRange
 from .properties import PsSymbolProperty, FieldShape, FieldStride, FieldBasePtr
 from .parameters import Parameter
-from ..compound_op_mapping import compound_op_to_expr
-from ..backend.ast.expressions import PsSymbolExpr, PsMemAcc, PsConstantExpr
+from ..backend.functions import PsReductionFunction, ReductionFunctions
+from ..backend.ast.expressions import PsSymbolExpr, PsCall
 
 from ..types import create_numeric_type, PsIntegerType, PsScalarType
 
 from ..backend.memory import PsSymbol
 from ..backend.ast import PsAstNode
-from ..backend.ast.structural import PsBlock, PsLoop, PsAssignment, PsDeclaration
+from ..backend.ast.structural import PsBlock, PsLoop
 from ..backend.ast.analysis import collect_undefined_symbols, collect_required_headers
 from ..backend.kernelcreation import (
     KernelCreationContext,
@@ -156,19 +156,20 @@ class DefaultKernelCreationDriver:
 
         #   Extensions for reductions
         for symbol, reduction_info in self._ctx.symbols_reduction_info.items():
-            # Init local reduction variable copy
-            kernel_ast.statements = [PsDeclaration(PsSymbolExpr(symbol),
-                                                   reduction_info.init_val)] + kernel_ast.statements
+            typify = Typifier(self._ctx)
+            symbol_expr = typify(PsSymbolExpr(symbol))
+            ptr_symbol_expr = typify(PsSymbolExpr(reduction_info.ptr_symbol))
+            init_val = typify(reduction_info.init_val)
 
-            # Write back result to reduction target variable
-            ptr_access = PsMemAcc(PsSymbolExpr(reduction_info.ptr_symbol),
-                                  PsConstantExpr(PsConstant(0)))
-            kernel_ast.statements += [PsAssignment(
-                ptr_access, compound_op_to_expr(reduction_info.op, ptr_access, PsSymbolExpr(symbol)))]
+            init_local_copy = PsCall(PsReductionFunction(ReductionFunctions.InitLocalCopy, reduction_info.op),
+                                     [symbol_expr, init_val])
+            write_back_ptr = PsCall(PsReductionFunction(ReductionFunctions.WriteBackToPtr, reduction_info.op),
+                                    [ptr_symbol_expr, symbol_expr])
 
-            # TODO: only newly introduced nodes
-            typify = Typifier(self._ctx)
-            kernel_ast = typify(kernel_ast)
+            # Init local reduction variable copy
+            kernel_ast.statements = [init_local_copy] + kernel_ast.statements
+            # Write back result to reduction target variable
+            kernel_ast.statements += [write_back_ptr]
 
         #   Target-Specific optimizations
         if self._cfg.target.is_cpu():
diff --git a/tests/kernelcreation/test_reduction.py b/tests/kernelcreation/test_reduction.py
index 69b75e711c1e8c306486f74bc0a43b8daea924b1..b24058571bfd87f2d50b02e1726f23077753b922 100644
--- a/tests/kernelcreation/test_reduction.py
+++ b/tests/kernelcreation/test_reduction.py
@@ -32,11 +32,11 @@ def test_reduction(dtype, op):
     config = ps.CreateKernelConfig(target=ps.Target.GPU) if gpu_avail else ps.CreateKernelConfig(cpu_openmp=True)
 
     ast_reduction = ps.create_kernel([red_assign], config, default_dtype=dtype)
+    ps.show_code(ast_reduction)
+
     # code_reduction = ps.get_code_str(ast_reduction)
     kernel_reduction = ast_reduction.compile()
 
-    ps.show_code(ast_reduction)
-
     array = np.full((SIZE,), INIT_ARR, dtype=dtype)
     reduction_array = np.full((1,), INIT_W, dtype=dtype)