diff --git a/src/pyronn_torch/conebeam.py b/src/pyronn_torch/conebeam.py index 278d2a5b54c265583a8bee08b2556b882682b2ea..c9409fb2b5e6bff73e8b07694f33ecf375142009 100644 --- a/src/pyronn_torch/conebeam.py +++ b/src/pyronn_torch/conebeam.py @@ -12,7 +12,7 @@ import pyronn_torch import sympy as sp -class _ForwardProjection(torch.autograd.Function): +class State: def __init__(self, projection_shape, volume_shape, @@ -35,40 +35,63 @@ class _ForwardProjection(torch.autograd.Function): self.with_texture = with_texture self.step_size = step_size - def forward(self, volume): + +class _ForwardProjection(torch.autograd.Function): + @staticmethod + def forward(self, volume, state=None): + if state is None: + state = self.state + return_none = True + else: + return_none = False + volume = volume.float().cuda().contiguous() - projection = torch.zeros(self.projection_shape, + projection = torch.zeros(state.projection_shape, device='cuda', requires_grad=volume.requires_grad) assert pyronn_torch.cpp_extension - if self.with_texture: + if state.with_texture: pyronn_torch.cpp_extension.call_Cone_Projection_Kernel_Tex_Interp_Launcher( - self.inverse_matrices, projection, self.source_points, - self.step_size, volume, *self.volume_spacing) + state.inverse_matrices, projection, state.source_points, + state.step_size, volume, *state.volume_spacing) else: pyronn_torch.cpp_extension.call_Cone_Projection_Kernel_Launcher( - self.inverse_matrices, projection, self.source_points, - self.step_size, volume, *self.volume_spacing) + state.inverse_matrices, projection, state.source_points, + state.step_size, volume, *state.volume_spacing) - return projection, + self.state = state + if return_none: + return projection, None + else: + return projection, + + @staticmethod + def backward(self, projection_grad, state=None, *args): + if state is None: + state = self.state + return_none = True + else: + return_none = False - def backward(self, *projection_grad): - projection_grad = projection_grad[0] projection_grad = projection_grad.float().cuda().contiguous() - volume_grad = torch.zeros(self.volume_shape, device='cuda', requires_grad=projection_grad.requires_grad) + volume_grad = torch.zeros(state.volume_shape, + device='cuda', + requires_grad=projection_grad.requires_grad) assert pyronn_torch.cpp_extension pyronn_torch.cpp_extension.call_Cone_Backprojection3D_Kernel_Launcher( - self.projection_matrices, projection_grad, - self.projection_multiplier, volume_grad, *self.volume_origin, - *self.volume_spacing) + state.projection_matrices, projection_grad, + state.projection_multiplier, volume_grad, *state.volume_origin, + *state.volume_spacing) - return volume_grad, + if return_none: + return volume_grad, None + else: + return volume_grad, class _BackwardProjection(torch.autograd.Function): - __init__ = _ForwardProjection.__init__ backward = _ForwardProjection.forward forward = _ForwardProjection.backward @@ -118,29 +141,33 @@ class ConeBeamProjector: requires_grad=requires_grad).cuda() def project_forward(self, volume, step_size=1., use_texture=True): - return _ForwardProjection(self._projection_shape, self._volume_shape, - self._source_points, self._inverse_matrices, - self._projection_matrices, - self._volume_origin, self._volume_spacing, - self._projection_multiplier, step_size, - use_texture).forward(volume)[0] + return _ForwardProjection().apply( + volume, + State(self._projection_shape, self._volume_shape, + self._source_points, self._inverse_matrices, + self._projection_matrices, self._volume_origin, + self._volume_spacing, self._projection_multiplier, step_size, + use_texture))[0] def project_backward(self, projection_stack, step_size=1., use_texture=True): - return _BackwardProjection(self._projection_shape, self._volume_shape, - self._source_points, self._inverse_matrices, - self._projection_matrices, - self._volume_origin, self._volume_spacing, - self._projection_multiplier, step_size, - use_texture).forward(projection_stack)[0] + return _BackwardProjection().apply( + projection_stack, + State(self._projection_shape, self._volume_shape, + self._source_points, self._inverse_matrices, + self._projection_matrices, self._volume_origin, + self._volume_spacing, self._projection_multiplier, step_size, + use_texture))[0] def _calc_inverse_matrices(self): if self._projection_matrices_numpy is None: return self._projection_matrices = torch.stack( - tuple(torch.from_numpy(p.astype(np.float32)) for p in self._projection_matrices_numpy)).cuda().contiguous() + tuple( + torch.from_numpy(p.astype(np.float32)) + for p in self._projection_matrices_numpy)).cuda().contiguous() inv_spacing = np.array([1 / s for s in reversed(self._volume_spacing)], np.float32)