diff --git a/src/pyronn_torch/conebeam.py b/src/pyronn_torch/conebeam.py
index 278d2a5b54c265583a8bee08b2556b882682b2ea..c9409fb2b5e6bff73e8b07694f33ecf375142009 100644
--- a/src/pyronn_torch/conebeam.py
+++ b/src/pyronn_torch/conebeam.py
@@ -12,7 +12,7 @@ import pyronn_torch
 import sympy as sp
 
 
-class _ForwardProjection(torch.autograd.Function):
+class State:
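+    # Plain container for the projector geometry and kernel parameters that
+    # the static autograd Functions below read from.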
     def __init__(self,
                  projection_shape,
                  volume_shape,
@@ -35,40 +35,63 @@ class _ForwardProjection(torch.autograd.Function):
         self.with_texture = with_texture
         self.step_size = step_size
 
-    def forward(self, volume):
+
+class _ForwardProjection(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, volume, state=None):
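+        # ctx is the autograd context.  When called via .apply(volume, state)
+        # the State carrying the projector geometry is passed explicitly and
+        # stored on ctx; when autograd re-enters this function as the backward
+        # pass of _BackwardProjection, state is None and is read back from ctx.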
+        if state is None:
+            state = ctx.state
+            return_none = True
+        else:
+            return_none = False
+
         volume = volume.float().cuda().contiguous()
-        projection = torch.zeros(self.projection_shape,
+        projection = torch.zeros(state.projection_shape,
                                  device='cuda',
                                  requires_grad=volume.requires_grad)
 
         assert pyronn_torch.cpp_extension
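+        # Use the texture-interpolation projection kernel when the projector
+        # was configured with use_texture, otherwise the plain kernel.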
-        if self.with_texture:
+        if state.with_texture:
             pyronn_torch.cpp_extension.call_Cone_Projection_Kernel_Tex_Interp_Launcher(
-                self.inverse_matrices, projection, self.source_points,
-                self.step_size, volume, *self.volume_spacing)
+                state.inverse_matrices, projection, state.source_points,
+                state.step_size, volume, *state.volume_spacing)
         else:
             pyronn_torch.cpp_extension.call_Cone_Projection_Kernel_Launcher(
-                self.inverse_matrices, projection, self.source_points,
-                self.step_size, volume, *self.volume_spacing)
+                state.inverse_matrices, projection, state.source_points,
+                state.step_size, volume, *state.volume_spacing)
 
-        return projection,
+        ctx.state = state
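+        # When acting as the backward pass of _BackwardProjection, autograd
+        # expects one gradient per input (projection_stack, state), so a
+        # placeholder None is returned for the State.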
+        if return_none:
+            return projection, None
+        else:
+            return projection,
+
+    @staticmethod
+    def backward(ctx, projection_grad, state=None, *args):
+        if state is None:
+            state = ctx.state
+            return_none = True
+        else:
+            return_none = False
 
-    def backward(self, *projection_grad):
-        projection_grad = projection_grad[0]
         projection_grad = projection_grad.float().cuda().contiguous()
-        volume_grad = torch.zeros(self.volume_shape, device='cuda', requires_grad=projection_grad.requires_grad)
+        volume_grad = torch.zeros(state.volume_shape,
+                                  device='cuda',
+                                  requires_grad=projection_grad.requires_grad)
 
         assert pyronn_torch.cpp_extension
         pyronn_torch.cpp_extension.call_Cone_Backprojection3D_Kernel_Launcher(
-            self.projection_matrices, projection_grad,
-            self.projection_multiplier, volume_grad, *self.volume_origin,
-            *self.volume_spacing)
+            state.projection_matrices, projection_grad,
+            state.projection_multiplier, volume_grad, *state.volume_origin,
+            *state.volume_spacing)
 
-        return volume_grad,
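+        # Mirror forward(): keep the State on the context and, when invoked by
+        # autograd as _ForwardProjection's backward pass, return an extra None
+        # gradient for the State input.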
+        ctx.state = state
+        if return_none:
+            return volume_grad, None
+        else:
+            return volume_grad,
 
 
 class _BackwardProjection(torch.autograd.Function):
-    __init__ = _ForwardProjection.__init__
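+    # Backprojection is implemented by swapping the two static methods: its
+    # forward pass runs the backprojection kernel and its backward pass the
+    # projection kernel.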
     backward = _ForwardProjection.forward
     forward = _ForwardProjection.backward
 
@@ -118,29 +141,33 @@ class ConeBeamProjector:
                            requires_grad=requires_grad).cuda()
 
     def project_forward(self, volume, step_size=1., use_texture=True):
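+        # Bundle the projector geometry into a State and run the autograd
+        # Function; apply() returns a one-element tuple, hence the [0].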
-        return _ForwardProjection(self._projection_shape, self._volume_shape,
-                                  self._source_points, self._inverse_matrices,
-                                  self._projection_matrices,
-                                  self._volume_origin, self._volume_spacing,
-                                  self._projection_multiplier, step_size,
-                                  use_texture).forward(volume)[0]
+        return _ForwardProjection.apply(
+            volume,
+            State(self._projection_shape, self._volume_shape,
+                  self._source_points, self._inverse_matrices,
+                  self._projection_matrices, self._volume_origin,
+                  self._volume_spacing, self._projection_multiplier, step_size,
+                  use_texture))[0]
 
     def project_backward(self,
                          projection_stack,
                          step_size=1.,
                          use_texture=True):
-        return _BackwardProjection(self._projection_shape, self._volume_shape,
-                                   self._source_points, self._inverse_matrices,
-                                   self._projection_matrices,
-                                   self._volume_origin, self._volume_spacing,
-                                   self._projection_multiplier, step_size,
-                                   use_texture).forward(projection_stack)[0]
+        return _BackwardProjection.apply(
+            projection_stack,
+            State(self._projection_shape, self._volume_shape,
+                  self._source_points, self._inverse_matrices,
+                  self._projection_matrices, self._volume_origin,
+                  self._volume_spacing, self._projection_multiplier, step_size,
+                  use_texture))[0]
 
     def _calc_inverse_matrices(self):
         if self._projection_matrices_numpy is None:
             return
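+        # Stack the per-projection matrices into a single contiguous float32
+        # tensor on the GPU.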
         self._projection_matrices = torch.stack(
-            tuple(torch.from_numpy(p.astype(np.float32)) for p in self._projection_matrices_numpy)).cuda().contiguous()
+            tuple(
+                torch.from_numpy(p.astype(np.float32))
+                for p in self._projection_matrices_numpy)).cuda().contiguous()
 
         inv_spacing = np.array([1 / s for s in reversed(self._volume_spacing)],
                                np.float32)