Skip to content
Snippets Groups Projects
Commit a792c556 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Implement ConeBeamProjector.project_backward

parent 5cdcc048
No related branches found
No related tags found
No related merge requests found
Pipeline #22275 failed
......@@ -13,70 +13,75 @@ import torch
import pyronn_torch
class _ForwardProjection(torch.autograd.Function):
def __init__(self, projection_shape,
source_points,
inverse_matrices,
projection_matrices,
volume_origin,
volume_spacing,
projection_multiplier,
step_size=1.,
with_texture=True):
self.projection_shape = projection_shape
self.source_points = source_points
self.inverse_matrices = inverse_matrices
self.projection_matrices = projection_matrices
self.volume_origin = volume_origin
self.volume_spacing = volume_spacing
self.projection_multiplier = projection_multiplier
self.with_texture = with_texture
self.step_size = step_size
def forward(self, volume):
volume = volume.cuda().contiguous()
projection = torch.zeros(self.projection_shape, device='cuda', requires_grad=volume.requires_grad)
assert pyronn_torch.cpp_extension
if self.with_texture:
pyronn_torch.cpp_extension.call_Cone_Projection_Kernel_Tex_Interp_Launcher(
self.inverse_matrices,
projection,
self.source_points,
self.step_size,
volume,
*self.volume_spacing)
else:
pyronn_torch.cpp_extension.call_Cone_Projection_Kernel_Launcher(
self.inverse_matrices,
projection,
self.source_points,
self.step_size,
volume,
*self.volume_spacing)
return projection,
def backward(self, *projection_grad):
projection_grad = projection_grad[0]
self.projection_matrices
volume_grad = torch.zeros(self.volume_shape, device='cuda')
assert pyronn_torch.cpp_extension
pyronn_torch.cpp_extension.call_Cone_Backprojection3D_Kernel_Launcher(
self.projection_matrices,
projection_grad,
self.projection_multiplier,
volume_grad,
self.volume_origin,
self.volume_spacing)
return volume_grad,
class _BackwardProjection(torch.autograd.Function):
__init__ = _ForwardProjection.__init__
backward = _ForwardProjection.forward
forward = _ForwardProjection.backward
class ConeBeamProjector:
class ForwardProjection(torch.autograd.Function):
def __init__(self, projection_shape,
source_points,
inverse_matrices,
projection_matrices,
volume_origin,
volume_spacing,
projection_multiplier,
step_size=1.,
with_texture=True):
self.projection_shape = projection_shape
self.source_points = source_points
self.inverse_matrices = inverse_matrices
self.projection_matrices = projection_matrices
self.volume_origin = volume_origin
self.volume_spacing = volume_spacing
self.projection_multiplier = projection_multiplier
self.with_texture = with_texture
self.step_size = step_size
def forward(self, volume):
volume = volume.cuda().contiguous()
projection = torch.zeros(self.projection_shape, device='cuda', requires_grad=volume.requires_grad)
assert pyronn_torch.cpp_extension
if self.with_texture:
pyronn_torch.cpp_extension.call_Cone_Projection_Kernel_Tex_Interp_Launcher(
self.inverse_matrices,
projection,
self.source_points,
self.step_size,
volume,
*self.volume_spacing)
else:
pyronn_torch.cpp_extension.call_Cone_Projection_Kernel_Launcher(
self.inverse_matrices,
projection,
self.source_points,
self.step_size,
volume,
*self.volume_spacing)
return projection,
def backward(self, *projection_grad):
projection_grad = projection_grad[0]
self.projection_matrices
volume_grad = torch.zeros(self.volume_shape, device='cuda')
assert pyronn_torch.cpp_extension
pyronn_torch.cpp_extension.call_Cone_Backprojection3D_Kernel_Launcher(
self.projection_matrices,
projection_grad,
self.projection_multiplier,
volume_grad,
self.volume_origin,
self.volume_spacing)
return volume_grad,
class BackwardProjection(torch.autograd.Function):
pass
def __init__(self,
volume_shape,
......@@ -125,18 +130,26 @@ class ConeBeamProjector:
return torch.zeros(self._projection_shape, requires_grad=requires_grad).cuda()
def project_forward(self, volume, step_size=1., use_texture=True):
return self.ForwardProjection(self._projection_shape,
self._source_points,
self._inverse_matrices,
self._projection_matrices,
self._volume_origin,
self._volume_shape,
self._projection_multiplier,
step_size,
use_texture).forward(volume)[0]
def project_backward(self, projection_stack):
return self.BackwardProjection(projection_stack)
return _ForwardProjection(self._projection_shape,
self._source_points,
self._inverse_matrices,
self._projection_matrices,
self._volume_origin,
self._volume_shape,
self._projection_multiplier,
step_size,
use_texture).forward(volume)[0]
def project_backward(self, projection_stack, step_size=1., use_texture=True):
return _BackwardProjection(self._projection_shape,
self._source_points,
self._inverse_matrices,
self._projection_matrices,
self._volume_origin,
self._volume_shape,
self._projection_multiplier,
step_size,
use_texture).backward(projection_stack)[0]
def _calc_inverse_matrices(self):
if self._projection_matrices_numpy is None:
......
......@@ -53,6 +53,41 @@ def test_projection(with_texture, with_backward):
loss.backward()
@pytest.mark.parametrize('with_texture', ('with_texture', False))
@pytest.mark.parametrize('with_backward', ('with_backward', False))
def test_projection_backward(with_texture, with_backward):
projector = pyronn_torch.ConeBeamProjector(
(128, 128, 128),
(2.0, 2.0, 2.0),
(-127.5, -127.5, -127.5),
(2, 480, 620),
[1.0, 1.0],
(0, 0),
np.array([[[-3.10e+2, -1.20e+03, 0.00e+00, 1.86e+5],
[-2.40e+2, 0.00e+00, 1.20e+03, 1.44e+5],
[-1.00e+00, 0.00e+00, 0.00e+00, 6.00e+2]],
[[-2.89009888e+2, -1.20522754e+3, -1.02473585e-13,
1.86000000e+5],
[-2.39963440e+2, -4.18857765e+0, 1.20000000e+3,
1.44000000e+5],
[-9.99847710e-01, -1.74524058e-2, 0.00000000e+0,
6.00000000e+2]]])
)
projection = projector.new_projection_tensor(requires_grad=True if with_backward else False)
projection += 1.
result = projector.project_backward(projection, use_texture=with_texture)
assert result is not None
if with_backward:
assert projection.requires_grad
assert result.requires_grad
loss = result.mean()
loss.backward()
@pytest.mark.skipif('CI' in os.environ, reason="No conrad config on CI")
@pytest.mark.parametrize('with_backward', ('with_backward', False))
def test_conrad_config(with_backward, with_texture=True):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment