Commit 8bc762ad authored by Stephan Seitz

Backward gradient now working

parent 6a5b3fef
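In short: the pre-allocated projection tensor now inherits `requires_grad` from the input volume, `project_forward` unpacks the projection stack from the tuple returned by the forward-projection layer, and the projection test gains a `with_backward` variant that runs a full backward pass through the projector.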
@@ -37,7 +37,7 @@ class ConeBeamProjector:
     def forward(self, volume):
         volume = volume.cuda().contiguous()
-        projection = torch.zeros(self.projection_shape, device='cuda')
+        projection = torch.zeros(self.projection_shape, device='cuda', requires_grad=volume.requires_grad)
         assert pyronn_torch.cpp_extension
         if self.with_texture:
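The first change matters because the output buffer is created with `torch.zeros` rather than derived from `volume`, so it does not inherit `requires_grad` automatically; the flag has to be requested explicitly. A minimal illustration of that semantics (plain PyTorch, independent of pyronn-torch):

    import torch

    volume = torch.ones(8, 8, 8, requires_grad=True)

    # Freshly allocated tensors never inherit requires_grad from other
    # tensors; without the explicit flag the projection buffer would be
    # detached from the autograd graph.
    projection = torch.zeros(4, 8, 8, requires_grad=volume.requires_grad)
    assert projection.requires_grad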
@@ -133,7 +133,7 @@ class ConeBeamProjector:
             self._volume_shape,
             self._projection_multiplier,
             step_size,
-            use_texture).forward(volume)
+            use_texture).forward(volume)[0]

     def project_backward(self, projection_stack):
         return self.BackwardProjection(projection_stack)
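The added `[0]` indicates that the underlying projection layer returns a tuple, of which only the first element, the projection stack, is wanted. For orientation, the standard way to couple a forward projector with its matched backprojector in PyTorch is a custom `torch.autograd.Function` whose backward applies the adjoint operator. The sketch below uses a dense matrix in place of the CUDA kernels; it is not pyronn-torch's actual implementation:

    import torch

    class Project(torch.autograd.Function):
        # Sketch: the gradient of a linear forward projection is its
        # adjoint (the backprojection). A dense matrix A stands in for
        # the CUDA ray-casting kernel.
        @staticmethod
        def forward(ctx, volume, A):
            ctx.save_for_backward(A)
            return A @ volume

        @staticmethod
        def backward(ctx, grad_projection):
            A, = ctx.saved_tensors
            # Backproject the projection-space gradient into the volume;
            # the operator A itself receives no gradient, hence None.
            return A.t() @ grad_projection, None

    A = torch.rand(12, 27)                    # toy projection geometry
    volume = torch.ones(27, requires_grad=True)
    projection = Project.apply(volume, A)
    projection.mean().backward()
    assert volume.grad.shape == volume.shape  # gradient reached the input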
@@ -16,12 +16,20 @@ def test_init():
 @pytest.mark.parametrize('with_texture', ('with_texture', False))
-def test_projection(with_texture):
+@pytest.mark.parametrize('with_backward', ('with_backward', False))
+def test_projection(with_texture, with_backward):
     projector = pyronn_torch.ConeBeamProjector.from_conrad_config()
-    volume = projector.new_volume_tensor()
+    volume = projector.new_volume_tensor(requires_grad=True if with_backward else False)
     volume += 1.

     result = projector.project_forward(volume, use_texture=False)
     print(result)
     assert result is not None
+
+    if with_backward:
+        assert volume.requires_grad
+        assert result.requires_grad
+
+        loss = result.mean()
+        loss.backward()
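After `loss.backward()` the backprojected gradient accumulates in `volume.grad`; a stricter version of the test could additionally check that. A sketch extending the test above, not part of this commit:

    if with_backward:
        loss = result.mean()
        loss.backward()
        # Hypothetical extra assertions: the backprojected gradient
        # should land on the input volume and match its shape.
        assert volume.grad is not None
        assert volume.grad.shape == volume.shape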