diff --git a/src/pyronn_torch/conebeam.py b/src/pyronn_torch/conebeam.py
index 09e4f4073a4b64806251d2c10d1573de71403d42..ece50fe6031e33688250117311eeb450767d7a49 100644
--- a/src/pyronn_torch/conebeam.py
+++ b/src/pyronn_torch/conebeam.py
@@ -37,7 +37,7 @@ class ConeBeamProjector:
 
         def forward(self, volume):
             volume = volume.cuda().contiguous()
-            projection = torch.zeros(self.projection_shape, device='cuda')
+            projection = torch.zeros(self.projection_shape, device='cuda', requires_grad=volume.requires_grad)
 
             assert pyronn_torch.cpp_extension
             if self.with_texture:
@@ -133,7 +133,7 @@ class ConeBeamProjector:
                                       self._volume_shape,
                                       self._projection_multiplier,
                                       step_size,
-                                      use_texture).forward(volume)
+                                      use_texture).forward(volume)[0]
 
     def project_backward(self, projection_stack):
         return self.BackwardProjection(projection_stack)
diff --git a/tests/test_projection.py b/tests/test_projection.py
index fad1a3ad1e1ff3c0fdc7b4fd3f77e72dfe72834d..ba5cb02621320e1796b9b9fd851e558cb33c4941 100644
--- a/tests/test_projection.py
+++ b/tests/test_projection.py
@@ -16,12 +16,20 @@ def test_init():
 
 
 @pytest.mark.parametrize('with_texture', ('with_texture', False))
-def test_projection(with_texture):
+@pytest.mark.parametrize('with_backward', ('with_backward', False))
+def test_projection(with_texture, with_backward):
     projector = pyronn_torch.ConeBeamProjector.from_conrad_config()
 
-    volume = projector.new_volume_tensor()
+    volume = projector.new_volume_tensor(requires_grad=bool(with_backward))
 
     volume += 1.
     result = projector.project_forward(volume, use_texture=False)
 
-    print(result)
+    assert result is not None
+    if with_backward:
+        assert volume.requires_grad
+        assert result.requires_grad
+
+        loss = result.mean()
+        loss.backward()
+