diff --git a/src/pystencils/backend/functions.py b/src/pystencils/backend/functions.py
index 6a9d3e4f45815eb8995e55469e7d13dbff0ed924..a3da9a1de1aa0639c5b04acfa3020bc551497b51 100644
--- a/src/pystencils/backend/functions.py
+++ b/src/pystencils/backend/functions.py
@@ -95,45 +95,31 @@ class PsMathFunction(PsFunction):
         return hash(self._func)
 
 
-class ReductionFunctions(Enum):
-    """Function representing different steps in kernels with reductions supported by the backend.
+class PsReductionWriteBack(PsFunction):
+    """Function representing a reduction kernel's write-back step supported by the backend.
 
-    Each platform has to materialize these functions to a concrete implementation.
+    Each platform has to materialize this function to a concrete implementation.
     """
 
-    WriteBackToPtr = ("WriteBackToPtr", 2)
-
-    def __init__(self, func_name, num_args):
-        self.function_name = func_name
-        self.num_args = num_args
-
-
-class PsReductionFunction(PsFunction):
-
-    def __init__(self, func: ReductionFunctions, reduction_op: ReductionOp) -> None:
-        super().__init__(func.function_name, func.num_args)
-        self._func = func
+    def __init__(self, reduction_op: ReductionOp) -> None:
+        super().__init__("WriteBackToPtr", 2)
         self._reduction_op = reduction_op
 
-    @property
-    def func(self) -> ReductionFunctions:
-        return self._func
-
     @property
     def reduction_op(self) -> ReductionOp:
         return self._reduction_op
 
     def __str__(self) -> str:
-        return f"{self._func.function_name}"
+        return f"{super().name}"
 
     def __eq__(self, other: object) -> bool:
-        if not isinstance(other, PsReductionFunction):
+        if not isinstance(other, PsReductionWriteBack):
             return False
 
-        return self._func == other._func and self._reduction_op == other._reduction_op
+        return self._reduction_op == other._reduction_op
 
     def __hash__(self) -> int:
-        return hash(self._func)
+        return hash(self._reduction_op)
 
 
 class ConstantFunctions(Enum):
diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py
index 1dc2914b9f587a11180e0df426eb2bb38a7377b8..cbde419d4eff508915b607df8a699632c1ba4416 100644
--- a/src/pystencils/backend/platforms/generic_cpu.py
+++ b/src/pystencils/backend/platforms/generic_cpu.py
@@ -8,11 +8,10 @@ from ..ast import PsAstNode
 from ..functions import (
     CFunction,
     MathFunctions,
-    ReductionFunctions,
     PsMathFunction,
-    PsReductionFunction,
     PsConstantFunction,
     ConstantFunctions,
+    PsReductionWriteBack,
 )
 from ..reduction_op_mapping import reduction_op_to_expr
 from ...sympyextensions import ReductionOp
@@ -69,14 +68,9 @@ class GenericCpu(Platform):
         self, call: PsCall
     ) -> PsExpression | tuple[tuple[PsStructuralNode, ...], PsAstNode]:
         call_func = call.function
-        assert isinstance(call_func, (PsReductionFunction | PsMathFunction | PsConstantFunction))
+        assert isinstance(call_func, (PsReductionWriteBack | PsMathFunction | PsConstantFunction))
 
-        func = call_func.func
-
-        if (
-            isinstance(call_func, PsReductionFunction)
-            and func is ReductionFunctions.WriteBackToPtr
-        ):
+        if isinstance(call_func, PsReductionWriteBack):
             ptr_expr, symbol_expr = call.args
             op = call_func.reduction_op
 
@@ -103,6 +97,7 @@ class GenericCpu(Platform):
             return potential_call
 
         dtype = call.get_dtype()
+        func = call_func.func
         arg_types = (dtype,) * call.function.arg_count
 
         expr: PsExpression | None = None
diff --git a/src/pystencils/backend/platforms/generic_gpu.py b/src/pystencils/backend/platforms/generic_gpu.py
index 06b230454c289bf4e1bce294e725cb54cae624ad..876da4c912c68fed993b7d864e69748c47345277 100644
--- a/src/pystencils/backend/platforms/generic_gpu.py
+++ b/src/pystencils/backend/platforms/generic_gpu.py
@@ -52,8 +52,7 @@ from ..literals import PsLiteral
 from ..functions import (
     MathFunctions,
     CFunction,
-    ReductionFunctions,
-    PsReductionFunction,
+    PsReductionWriteBack,
     PsMathFunction,
     PsConstantFunction,
     ConstantFunctions,
@@ -296,20 +295,16 @@ class GenericGpu(Platform):
         self, call: PsCall
     ) -> PsExpression | tuple[tuple[PsStructuralNode, ...], PsAstNode]:
         call_func = call.function
-        assert isinstance(call_func, (PsReductionFunction | PsMathFunction | PsConstantFunction))
+        assert isinstance(call_func, (PsReductionWriteBack | PsMathFunction | PsConstantFunction))
 
-        func = call_func.func
-
-        if (
-            isinstance(call_func, PsReductionFunction)
-            and func is ReductionFunctions.WriteBackToPtr
-        ):
+        if isinstance(call_func, PsReductionWriteBack):
             ptr_expr, symbol_expr = call.args
             op = call_func.reduction_op
 
             return self.resolve_reduction(ptr_expr, symbol_expr, op)
 
         dtype = call.get_dtype()
+        func = call_func.func
         arg_types = (dtype,) * call.function.arg_count
         expr: PsExpression | None = None
 
diff --git a/src/pystencils/backend/transformations/select_functions.py b/src/pystencils/backend/transformations/select_functions.py
index d005acb4bcf3042473826383384e16d9fc7dd4fc..5953bd47db53e716fffef9b2cd1be1d2897fcf52 100644
--- a/src/pystencils/backend/transformations/select_functions.py
+++ b/src/pystencils/backend/transformations/select_functions.py
@@ -3,7 +3,7 @@ from ..exceptions import MaterializationError
 from ..platforms import Platform
 from ..ast import PsAstNode
 from ..ast.expressions import PsCall, PsExpression
-from ..functions import PsMathFunction, PsConstantFunction, PsReductionFunction
+from ..functions import PsMathFunction, PsConstantFunction, PsReductionWriteBack
 
 
 class SelectFunctions:
@@ -22,7 +22,7 @@ class SelectFunctions:
         if isinstance(node, PsAssignment):
             rhs = node.rhs
             if isinstance(rhs, PsCall) and isinstance(
-                rhs.function, PsReductionFunction
+                rhs.function, PsReductionWriteBack
             ):
                 resolved_func = self._platform.select_function(rhs)
 
diff --git a/src/pystencils/codegen/driver.py b/src/pystencils/codegen/driver.py
index 74a07b902a6494d0c528602326286a2aaab01cc5..68ab59bd9265d5cdadac253f45e61f8259cfa2be 100644
--- a/src/pystencils/codegen/driver.py
+++ b/src/pystencils/codegen/driver.py
@@ -26,7 +26,7 @@ from ..types import PsIntegerType, PsScalarType
 
 from ..backend.memory import PsSymbol
 from ..backend.ast import PsAstNode
-from ..backend.functions import PsReductionFunction, ReductionFunctions
+from ..backend.functions import PsReductionWriteBack
 from ..backend.ast.expressions import (
     PsExpression,
     PsSymbolExpr,
@@ -308,9 +308,7 @@ class DefaultKernelCreationDriver:
             ptr_symbol_expr, PsConstantExpr(PsConstant(0, self._ctx.index_dtype))
         )
         write_back_ptr = PsCall(
-            PsReductionFunction(
-                ReductionFunctions.WriteBackToPtr, reduction_info.op
-            ),
+            PsReductionWriteBack(reduction_info.op),
             [ptr_symbol_expr, symbol_expr],
         )