diff --git a/src/pyronn_torch/conebeam.py b/src/pyronn_torch/conebeam.py index c9409fb2b5e6bff73e8b07694f33ecf375142009..846ef6306b0c71372dfaa890a3dc0a8954cf0b14 100644 --- a/src/pyronn_torch/conebeam.py +++ b/src/pyronn_torch/conebeam.py @@ -85,6 +85,7 @@ class _ForwardProjection(torch.autograd.Function): state.projection_multiplier, volume_grad, *state.volume_origin, *state.volume_spacing) + self.state = state if return_none: return volume_grad, None else: @@ -92,8 +93,8 @@ class _ForwardProjection(torch.autograd.Function): class _BackwardProjection(torch.autograd.Function): - backward = _ForwardProjection.forward - forward = _ForwardProjection.backward + backward = staticmethod(_ForwardProjection.forward) + forward = staticmethod(_ForwardProjection.backward) class ConeBeamProjector: