From 8bc762ad4397aecee8dde8738cca6a15db011c38 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Wed, 26 Feb 2020 14:11:32 +0100
Subject: [PATCH] Backward gradient now working

---
 src/pyronn_torch/conebeam.py |  4 ++--
 tests/test_projection.py     | 14 +++++++++++---
 2 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/src/pyronn_torch/conebeam.py b/src/pyronn_torch/conebeam.py
index 09e4f40..ece50fe 100644
--- a/src/pyronn_torch/conebeam.py
+++ b/src/pyronn_torch/conebeam.py
@@ -37,7 +37,7 @@ class ConeBeamProjector:
 
         def forward(self, volume):
             volume = volume.cuda().contiguous()
-            projection = torch.zeros(self.projection_shape, device='cuda')
+            projection = torch.zeros(self.projection_shape, device='cuda', requires_grad=volume.requires_grad)
 
             assert pyronn_torch.cpp_extension
             if self.with_texture:
@@ -133,7 +133,7 @@ class ConeBeamProjector:
                                       self._volume_shape,
                                       self._projection_multiplier,
                                       step_size,
-                                      use_texture).forward(volume)
+                                      use_texture).forward(volume)[0]
 
     def project_backward(self, projection_stack):
         return self.BackwardProjection(projection_stack)
diff --git a/tests/test_projection.py b/tests/test_projection.py
index fad1a3a..ba5cb02 100644
--- a/tests/test_projection.py
+++ b/tests/test_projection.py
@@ -16,12 +16,20 @@ def test_init():
 
 
 @pytest.mark.parametrize('with_texture', ('with_texture', False))
-def test_projection(with_texture):
+@pytest.mark.parametrize('with_backward', ('with_backward', False))
+def test_projection(with_texture, with_backward):
     projector = pyronn_torch.ConeBeamProjector.from_conrad_config()
 
-    volume = projector.new_volume_tensor()
+    volume = projector.new_volume_tensor(requires_grad=True if with_backward else False)
 
     volume += 1.
     result = projector.project_forward(volume, use_texture=False)
 
-    print(result)
+    assert result is not None
+    if with_backward:
+        assert volume.requires_grad
+        assert result.requires_grad
+
+        loss = result.mean()
+        loss.backward()
+
-- 
GitLab