Skip to content
Snippets Groups Projects
Commit a792c556 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Implement ConeBeamProjector.project_backward

parent 5cdcc048
No related branches found
No related tags found
No related merge requests found
Pipeline #22275 failed
......@@ -13,8 +13,7 @@ import torch
import pyronn_torch
class ConeBeamProjector:
class ForwardProjection(torch.autograd.Function):
class _ForwardProjection(torch.autograd.Function):
def __init__(self, projection_shape,
source_points,
......@@ -75,8 +74,14 @@ class ConeBeamProjector:
return volume_grad,
class BackwardProjection(torch.autograd.Function):
pass
class _BackwardProjection(torch.autograd.Function):
__init__ = _ForwardProjection.__init__
backward = _ForwardProjection.forward
forward = _ForwardProjection.backward
class ConeBeamProjector:
def __init__(self,
volume_shape,
......@@ -125,7 +130,7 @@ class ConeBeamProjector:
return torch.zeros(self._projection_shape, requires_grad=requires_grad).cuda()
def project_forward(self, volume, step_size=1., use_texture=True):
return self.ForwardProjection(self._projection_shape,
return _ForwardProjection(self._projection_shape,
self._source_points,
self._inverse_matrices,
self._projection_matrices,
......@@ -135,8 +140,16 @@ class ConeBeamProjector:
step_size,
use_texture).forward(volume)[0]
def project_backward(self, projection_stack):
return self.BackwardProjection(projection_stack)
def project_backward(self, projection_stack, step_size=1., use_texture=True):
return _BackwardProjection(self._projection_shape,
self._source_points,
self._inverse_matrices,
self._projection_matrices,
self._volume_origin,
self._volume_shape,
self._projection_multiplier,
step_size,
use_texture).backward(projection_stack)[0]
def _calc_inverse_matrices(self):
if self._projection_matrices_numpy is None:
......
......@@ -53,6 +53,41 @@ def test_projection(with_texture, with_backward):
loss.backward()
@pytest.mark.parametrize('with_texture', ('with_texture', False))
@pytest.mark.parametrize('with_backward', ('with_backward', False))
def test_projection_backward(with_texture, with_backward):
projector = pyronn_torch.ConeBeamProjector(
(128, 128, 128),
(2.0, 2.0, 2.0),
(-127.5, -127.5, -127.5),
(2, 480, 620),
[1.0, 1.0],
(0, 0),
np.array([[[-3.10e+2, -1.20e+03, 0.00e+00, 1.86e+5],
[-2.40e+2, 0.00e+00, 1.20e+03, 1.44e+5],
[-1.00e+00, 0.00e+00, 0.00e+00, 6.00e+2]],
[[-2.89009888e+2, -1.20522754e+3, -1.02473585e-13,
1.86000000e+5],
[-2.39963440e+2, -4.18857765e+0, 1.20000000e+3,
1.44000000e+5],
[-9.99847710e-01, -1.74524058e-2, 0.00000000e+0,
6.00000000e+2]]])
)
projection = projector.new_projection_tensor(requires_grad=True if with_backward else False)
projection += 1.
result = projector.project_backward(projection, use_texture=with_texture)
assert result is not None
if with_backward:
assert projection.requires_grad
assert result.requires_grad
loss = result.mean()
loss.backward()
@pytest.mark.skipif('CI' in os.environ, reason="No conrad config on CI")
@pytest.mark.parametrize('with_backward', ('with_backward', False))
def test_conrad_config(with_backward, with_texture=True):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment