diff --git a/src/pyronn_torch/conebeam.py b/src/pyronn_torch/conebeam.py index 09e4f4073a4b64806251d2c10d1573de71403d42..ece50fe6031e33688250117311eeb450767d7a49 100644 --- a/src/pyronn_torch/conebeam.py +++ b/src/pyronn_torch/conebeam.py @@ -37,7 +37,7 @@ class ConeBeamProjector: def forward(self, volume): volume = volume.cuda().contiguous() - projection = torch.zeros(self.projection_shape, device='cuda') + projection = torch.zeros(self.projection_shape, device='cuda', requires_grad=volume.requires_grad) assert pyronn_torch.cpp_extension if self.with_texture: @@ -133,7 +133,7 @@ class ConeBeamProjector: self._volume_shape, self._projection_multiplier, step_size, - use_texture).forward(volume) + use_texture).forward(volume)[0] def project_backward(self, projection_stack): return self.BackwardProjection(projection_stack) diff --git a/tests/test_projection.py b/tests/test_projection.py index fad1a3ad1e1ff3c0fdc7b4fd3f77e72dfe72834d..ba5cb02621320e1796b9b9fd851e558cb33c4941 100644 --- a/tests/test_projection.py +++ b/tests/test_projection.py @@ -16,12 +16,20 @@ def test_init(): @pytest.mark.parametrize('with_texture', ('with_texture', False)) -def test_projection(with_texture): +@pytest.mark.parametrize('with_backward', ('with_backward', False)) +def test_projection(with_texture, with_backward): projector = pyronn_torch.ConeBeamProjector.from_conrad_config() - volume = projector.new_volume_tensor() + volume = projector.new_volume_tensor(requires_grad=True if with_backward else False) volume += 1. result = projector.project_forward(volume, use_texture=False) - print(result) + assert result is not None + if with_backward: + assert volume.requires_grad + assert result.requires_grad + + loss = result.mean() + loss.backward() +