diff --git a/src/pyronn_torch/conebeam.py b/src/pyronn_torch/conebeam.py index ece50fe6031e33688250117311eeb450767d7a49..5343f8c7edce53e5f8c476f96c3bc40513330463 100644 --- a/src/pyronn_torch/conebeam.py +++ b/src/pyronn_torch/conebeam.py @@ -13,70 +13,75 @@ import torch import pyronn_torch +class _ForwardProjection(torch.autograd.Function): + + def __init__(self, projection_shape, + source_points, + inverse_matrices, + projection_matrices, + volume_origin, + volume_spacing, + projection_multiplier, + step_size=1., + with_texture=True): + self.projection_shape = projection_shape + self.source_points = source_points + self.inverse_matrices = inverse_matrices + self.projection_matrices = projection_matrices + self.volume_origin = volume_origin + self.volume_spacing = volume_spacing + self.projection_multiplier = projection_multiplier + self.with_texture = with_texture + self.step_size = step_size + + def forward(self, volume): + volume = volume.cuda().contiguous() + projection = torch.zeros(self.projection_shape, device='cuda', requires_grad=volume.requires_grad) + + assert pyronn_torch.cpp_extension + if self.with_texture: + pyronn_torch.cpp_extension.call_Cone_Projection_Kernel_Tex_Interp_Launcher( + self.inverse_matrices, + projection, + self.source_points, + self.step_size, + volume, + *self.volume_spacing) + else: + pyronn_torch.cpp_extension.call_Cone_Projection_Kernel_Launcher( + self.inverse_matrices, + projection, + self.source_points, + self.step_size, + volume, + *self.volume_spacing) + + return projection, + + def backward(self, *projection_grad): + projection_grad = projection_grad[0] + self.projection_matrices + volume_grad = torch.zeros(self.volume_shape, device='cuda') + + assert pyronn_torch.cpp_extension + pyronn_torch.cpp_extension.call_Cone_Backprojection3D_Kernel_Launcher( + self.projection_matrices, + projection_grad, + self.projection_multiplier, + volume_grad, + self.volume_origin, + self.volume_spacing) + + return volume_grad, + + +class _BackwardProjection(torch.autograd.Function): + __init__ = _ForwardProjection.__init__ + backward = _ForwardProjection.forward + forward = _ForwardProjection.backward + + class ConeBeamProjector: - class ForwardProjection(torch.autograd.Function): - - def __init__(self, projection_shape, - source_points, - inverse_matrices, - projection_matrices, - volume_origin, - volume_spacing, - projection_multiplier, - step_size=1., - with_texture=True): - self.projection_shape = projection_shape - self.source_points = source_points - self.inverse_matrices = inverse_matrices - self.projection_matrices = projection_matrices - self.volume_origin = volume_origin - self.volume_spacing = volume_spacing - self.projection_multiplier = projection_multiplier - self.with_texture = with_texture - self.step_size = step_size - - def forward(self, volume): - volume = volume.cuda().contiguous() - projection = torch.zeros(self.projection_shape, device='cuda', requires_grad=volume.requires_grad) - - assert pyronn_torch.cpp_extension - if self.with_texture: - pyronn_torch.cpp_extension.call_Cone_Projection_Kernel_Tex_Interp_Launcher( - self.inverse_matrices, - projection, - self.source_points, - self.step_size, - volume, - *self.volume_spacing) - else: - pyronn_torch.cpp_extension.call_Cone_Projection_Kernel_Launcher( - self.inverse_matrices, - projection, - self.source_points, - self.step_size, - volume, - *self.volume_spacing) - - return projection, - - def backward(self, *projection_grad): - projection_grad = projection_grad[0] - self.projection_matrices - volume_grad = torch.zeros(self.volume_shape, device='cuda') - - assert pyronn_torch.cpp_extension - pyronn_torch.cpp_extension.call_Cone_Backprojection3D_Kernel_Launcher( - self.projection_matrices, - projection_grad, - self.projection_multiplier, - volume_grad, - self.volume_origin, - self.volume_spacing) - - return volume_grad, - - class BackwardProjection(torch.autograd.Function): - pass def __init__(self, volume_shape, @@ -125,18 +130,26 @@ class ConeBeamProjector: return torch.zeros(self._projection_shape, requires_grad=requires_grad).cuda() def project_forward(self, volume, step_size=1., use_texture=True): - return self.ForwardProjection(self._projection_shape, - self._source_points, - self._inverse_matrices, - self._projection_matrices, - self._volume_origin, - self._volume_shape, - self._projection_multiplier, - step_size, - use_texture).forward(volume)[0] - - def project_backward(self, projection_stack): - return self.BackwardProjection(projection_stack) + return _ForwardProjection(self._projection_shape, + self._source_points, + self._inverse_matrices, + self._projection_matrices, + self._volume_origin, + self._volume_shape, + self._projection_multiplier, + step_size, + use_texture).forward(volume)[0] + + def project_backward(self, projection_stack, step_size=1., use_texture=True): + return _BackwardProjection(self._projection_shape, + self._source_points, + self._inverse_matrices, + self._projection_matrices, + self._volume_origin, + self._volume_shape, + self._projection_multiplier, + step_size, + use_texture).backward(projection_stack)[0] def _calc_inverse_matrices(self): if self._projection_matrices_numpy is None: diff --git a/tests/test_projection.py b/tests/test_projection.py index f72f3cc20ff02e8226db1b36de3ad1d8ac22e6df..9562fa8b1488c42131da51d027aaeac4f395b80c 100644 --- a/tests/test_projection.py +++ b/tests/test_projection.py @@ -53,6 +53,41 @@ def test_projection(with_texture, with_backward): loss.backward() +@pytest.mark.parametrize('with_texture', ('with_texture', False)) +@pytest.mark.parametrize('with_backward', ('with_backward', False)) +def test_projection_backward(with_texture, with_backward): + projector = pyronn_torch.ConeBeamProjector( + (128, 128, 128), + (2.0, 2.0, 2.0), + (-127.5, -127.5, -127.5), + (2, 480, 620), + [1.0, 1.0], + (0, 0), + np.array([[[-3.10e+2, -1.20e+03, 0.00e+00, 1.86e+5], + [-2.40e+2, 0.00e+00, 1.20e+03, 1.44e+5], + [-1.00e+00, 0.00e+00, 0.00e+00, 6.00e+2]], + [[-2.89009888e+2, -1.20522754e+3, -1.02473585e-13, + 1.86000000e+5], + [-2.39963440e+2, -4.18857765e+0, 1.20000000e+3, + 1.44000000e+5], + [-9.99847710e-01, -1.74524058e-2, 0.00000000e+0, + 6.00000000e+2]]]) + ) + + projection = projector.new_projection_tensor(requires_grad=True if with_backward else False) + + projection += 1. + result = projector.project_backward(projection, use_texture=with_texture) + + assert result is not None + if with_backward: + assert projection.requires_grad + assert result.requires_grad + + loss = result.mean() + loss.backward() + + @pytest.mark.skipif('CI' in os.environ, reason="No conrad config on CI") @pytest.mark.parametrize('with_backward', ('with_backward', False)) def test_conrad_config(with_backward, with_texture=True):