diff --git a/src/pyronn_torch/conebeam.py b/src/pyronn_torch/conebeam.py index 95791748653e23fdce21546c223f78dcaa353277..32f126f431f325ab5076621db0561da8c9b84020 100644 --- a/src/pyronn_torch/conebeam.py +++ b/src/pyronn_torch/conebeam.py @@ -36,7 +36,7 @@ class _ForwardProjection(torch.autograd.Function): self.step_size = step_size def forward(self, volume): - volume = volume.cuda().contiguous() + volume = volume.float().cuda().contiguous() projection = torch.zeros(self.projection_shape, device='cuda', requires_grad=volume.requires_grad) @@ -55,6 +55,7 @@ class _ForwardProjection(torch.autograd.Function): def backward(self, *projection_grad): projection_grad = projection_grad[0] + projection_grad = projection_grad.float().cuda().contiguous() volume_grad = torch.zeros(self.volume_shape, device='cuda', requires_grad=projection_grad.requires_grad) assert pyronn_torch.cpp_extension