From 8bc762ad4397aecee8dde8738cca6a15db011c38 Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Wed, 26 Feb 2020 14:11:32 +0100 Subject: [PATCH] Backward gradient now working --- src/pyronn_torch/conebeam.py | 4 ++-- tests/test_projection.py | 14 +++++++++++--- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/pyronn_torch/conebeam.py b/src/pyronn_torch/conebeam.py index 09e4f40..ece50fe 100644 --- a/src/pyronn_torch/conebeam.py +++ b/src/pyronn_torch/conebeam.py @@ -37,7 +37,7 @@ class ConeBeamProjector: def forward(self, volume): volume = volume.cuda().contiguous() - projection = torch.zeros(self.projection_shape, device='cuda') + projection = torch.zeros(self.projection_shape, device='cuda', requires_grad=volume.requires_grad) assert pyronn_torch.cpp_extension if self.with_texture: @@ -133,7 +133,7 @@ class ConeBeamProjector: self._volume_shape, self._projection_multiplier, step_size, - use_texture).forward(volume) + use_texture).forward(volume)[0] def project_backward(self, projection_stack): return self.BackwardProjection(projection_stack) diff --git a/tests/test_projection.py b/tests/test_projection.py index fad1a3a..ba5cb02 100644 --- a/tests/test_projection.py +++ b/tests/test_projection.py @@ -16,12 +16,20 @@ def test_init(): @pytest.mark.parametrize('with_texture', ('with_texture', False)) -def test_projection(with_texture): +@pytest.mark.parametrize('with_backward', ('with_backward', False)) +def test_projection(with_texture, with_backward): projector = pyronn_torch.ConeBeamProjector.from_conrad_config() - volume = projector.new_volume_tensor() + volume = projector.new_volume_tensor(requires_grad=True if with_backward else False) volume += 1. result = projector.project_forward(volume, use_texture=False) - print(result) + assert result is not None + if with_backward: + assert volume.requires_grad + assert result.requires_grad + + loss = result.mean() + loss.backward() + -- GitLab