diff --git a/src/pyronn_torch/conebeam.py b/src/pyronn_torch/conebeam.py
index ece50fe6031e33688250117311eeb450767d7a49..5343f8c7edce53e5f8c476f96c3bc40513330463 100644
--- a/src/pyronn_torch/conebeam.py
+++ b/src/pyronn_torch/conebeam.py
@@ -13,70 +13,75 @@ import torch
 import pyronn_torch
 
 
+class _ForwardProjection(torch.autograd.Function):
+
+    def __init__(self, projection_shape,
+                 source_points,
+                 inverse_matrices,
+                 projection_matrices,
+                 volume_origin,
+                 volume_spacing,
+                 projection_multiplier,
+                 step_size=1.,
+                 with_texture=True):
+        self.projection_shape = projection_shape
+        self.source_points = source_points
+        self.inverse_matrices = inverse_matrices
+        self.projection_matrices = projection_matrices
+        self.volume_origin = volume_origin
+        self.volume_spacing = volume_spacing
+        self.projection_multiplier = projection_multiplier
+        self.with_texture = with_texture
+        self.step_size = step_size
+
+    def forward(self, volume):
+        volume = volume.cuda().contiguous()
+        projection = torch.zeros(self.projection_shape, device='cuda', requires_grad=volume.requires_grad)
+
+        assert pyronn_torch.cpp_extension
+        if self.with_texture:
+            pyronn_torch.cpp_extension.call_Cone_Projection_Kernel_Tex_Interp_Launcher(
+                self.inverse_matrices,
+                projection,
+                self.source_points,
+                self.step_size,
+                volume,
+                *self.volume_spacing)
+        else:
+            pyronn_torch.cpp_extension.call_Cone_Projection_Kernel_Launcher(
+                self.inverse_matrices,
+                projection,
+                self.source_points,
+                self.step_size,
+                volume,
+                *self.volume_spacing)
+
+        return projection,
+
+    def backward(self, *projection_grad):
+        projection_grad = projection_grad[0]
+        self.projection_matrices
+        volume_grad = torch.zeros(self.volume_shape, device='cuda')
+
+        assert pyronn_torch.cpp_extension
+        pyronn_torch.cpp_extension.call_Cone_Backprojection3D_Kernel_Launcher(
+            self.projection_matrices,
+            projection_grad,
+            self.projection_multiplier,
+            volume_grad,
+            self.volume_origin,
+            self.volume_spacing)
+
+        return volume_grad,
+
+
+class _BackwardProjection(torch.autograd.Function):
+    __init__ = _ForwardProjection.__init__
+    backward = _ForwardProjection.forward
+    forward = _ForwardProjection.backward
+
+
 class ConeBeamProjector:
-    class ForwardProjection(torch.autograd.Function):
-
-        def __init__(self, projection_shape,
-                     source_points,
-                     inverse_matrices,
-                     projection_matrices,
-                     volume_origin,
-                     volume_spacing,
-                     projection_multiplier,
-                     step_size=1.,
-                     with_texture=True):
-            self.projection_shape = projection_shape
-            self.source_points = source_points
-            self.inverse_matrices = inverse_matrices
-            self.projection_matrices = projection_matrices
-            self.volume_origin = volume_origin
-            self.volume_spacing = volume_spacing
-            self.projection_multiplier = projection_multiplier
-            self.with_texture = with_texture
-            self.step_size = step_size
-
-        def forward(self, volume):
-            volume = volume.cuda().contiguous()
-            projection = torch.zeros(self.projection_shape, device='cuda', requires_grad=volume.requires_grad)
-
-            assert pyronn_torch.cpp_extension
-            if self.with_texture:
-                pyronn_torch.cpp_extension.call_Cone_Projection_Kernel_Tex_Interp_Launcher(
-                    self.inverse_matrices,
-                    projection,
-                    self.source_points,
-                    self.step_size,
-                    volume,
-                    *self.volume_spacing)
-            else:
-                pyronn_torch.cpp_extension.call_Cone_Projection_Kernel_Launcher(
-                    self.inverse_matrices,
-                    projection,
-                    self.source_points,
-                    self.step_size,
-                    volume,
-                    *self.volume_spacing)
-
-            return projection,
-
-        def backward(self, *projection_grad):
-            projection_grad = projection_grad[0]
-            self.projection_matrices
-            volume_grad = torch.zeros(self.volume_shape, device='cuda')
-
-            assert pyronn_torch.cpp_extension
-            pyronn_torch.cpp_extension.call_Cone_Backprojection3D_Kernel_Launcher(
-                self.projection_matrices,
-                projection_grad,
-                self.projection_multiplier,
-                volume_grad,
-                self.volume_origin,
-                self.volume_spacing)
-
-            return volume_grad,
-
-    class BackwardProjection(torch.autograd.Function):
-        pass
 
     def __init__(self,
                  volume_shape,
@@ -125,18 +130,26 @@ class ConeBeamProjector:
         return torch.zeros(self._projection_shape, requires_grad=requires_grad).cuda()
 
     def project_forward(self, volume, step_size=1., use_texture=True):
-        return self.ForwardProjection(self._projection_shape,
-                                      self._source_points,
-                                      self._inverse_matrices,
-                                      self._projection_matrices,
-                                      self._volume_origin,
-                                      self._volume_shape,
-                                      self._projection_multiplier,
-                                      step_size,
-                                      use_texture).forward(volume)[0]
-
-    def project_backward(self, projection_stack):
-        return self.BackwardProjection(projection_stack)
+        return _ForwardProjection(self._projection_shape,
+                                  self._source_points,
+                                  self._inverse_matrices,
+                                  self._projection_matrices,
+                                  self._volume_origin,
+                                  self._volume_shape,
+                                  self._projection_multiplier,
+                                  step_size,
+                                  use_texture).forward(volume)[0]
+
+    def project_backward(self, projection_stack, step_size=1., use_texture=True):
+        return _BackwardProjection(self._projection_shape,
+                                   self._source_points,
+                                   self._inverse_matrices,
+                                   self._projection_matrices,
+                                   self._volume_origin,
+                                   self._volume_shape,
+                                   self._projection_multiplier,
+                                   step_size,
+                                   use_texture).backward(projection_stack)[0]
 
     def _calc_inverse_matrices(self):
         if self._projection_matrices_numpy is None:
diff --git a/tests/test_projection.py b/tests/test_projection.py
index f72f3cc20ff02e8226db1b36de3ad1d8ac22e6df..9562fa8b1488c42131da51d027aaeac4f395b80c 100644
--- a/tests/test_projection.py
+++ b/tests/test_projection.py
@@ -53,6 +53,41 @@ def test_projection(with_texture, with_backward):
         loss.backward()
 
 
+@pytest.mark.parametrize('with_texture', ('with_texture', False))
+@pytest.mark.parametrize('with_backward', ('with_backward', False))
+def test_projection_backward(with_texture, with_backward):
+    projector = pyronn_torch.ConeBeamProjector(
+        (128, 128, 128),
+        (2.0, 2.0, 2.0),
+        (-127.5, -127.5, -127.5),
+        (2, 480, 620),
+        [1.0, 1.0],
+        (0, 0),
+        np.array([[[-3.10e+2, -1.20e+03,  0.00e+00,  1.86e+5],
+                   [-2.40e+2,  0.00e+00,  1.20e+03,  1.44e+5],
+                   [-1.00e+00,  0.00e+00,  0.00e+00,  6.00e+2]],
+                  [[-2.89009888e+2, -1.20522754e+3, -1.02473585e-13,
+                    1.86000000e+5],
+                   [-2.39963440e+2, -4.18857765e+0,  1.20000000e+3,
+                    1.44000000e+5],
+                   [-9.99847710e-01, -1.74524058e-2,  0.00000000e+0,
+                    6.00000000e+2]]])
+    )
+
+    projection = projector.new_projection_tensor(requires_grad=True if with_backward else False)
+
+    projection += 1.
+    result = projector.project_backward(projection, use_texture=with_texture)
+
+    assert result is not None
+    if with_backward:
+        assert projection.requires_grad
+        assert result.requires_grad
+
+        loss = result.mean()
+        loss.backward()
+
+
 @pytest.mark.skipif('CI' in os.environ, reason="No conrad config on CI")
 @pytest.mark.parametrize('with_backward', ('with_backward', False))
 def test_conrad_config(with_backward, with_texture=True):