Skip to content
Snippets Groups Projects
Commit 770c3ea2 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Apply crazy state hack

parent e3458864
Branches
No related merge requests found
......@@ -12,7 +12,7 @@ import pyronn_torch
import sympy as sp
class _ForwardProjection(torch.autograd.Function):
class State:
def __init__(self,
projection_shape,
volume_shape,
......@@ -35,40 +35,63 @@ class _ForwardProjection(torch.autograd.Function):
self.with_texture = with_texture
self.step_size = step_size
def forward(self, volume):
class _ForwardProjection(torch.autograd.Function):
@staticmethod
def forward(self, volume, state=None):
if state is None:
state = self.state
return_none = True
else:
return_none = False
volume = volume.float().cuda().contiguous()
projection = torch.zeros(self.projection_shape,
projection = torch.zeros(state.projection_shape,
device='cuda',
requires_grad=volume.requires_grad)
assert pyronn_torch.cpp_extension
if self.with_texture:
if state.with_texture:
pyronn_torch.cpp_extension.call_Cone_Projection_Kernel_Tex_Interp_Launcher(
self.inverse_matrices, projection, self.source_points,
self.step_size, volume, *self.volume_spacing)
state.inverse_matrices, projection, state.source_points,
state.step_size, volume, *state.volume_spacing)
else:
pyronn_torch.cpp_extension.call_Cone_Projection_Kernel_Launcher(
self.inverse_matrices, projection, self.source_points,
self.step_size, volume, *self.volume_spacing)
state.inverse_matrices, projection, state.source_points,
state.step_size, volume, *state.volume_spacing)
return projection,
self.state = state
if return_none:
return projection, None
else:
return projection,
@staticmethod
def backward(self, projection_grad, state=None, *args):
if state is None:
state = self.state
return_none = True
else:
return_none = False
def backward(self, *projection_grad):
projection_grad = projection_grad[0]
projection_grad = projection_grad.float().cuda().contiguous()
volume_grad = torch.zeros(self.volume_shape, device='cuda', requires_grad=projection_grad.requires_grad)
volume_grad = torch.zeros(state.volume_shape,
device='cuda',
requires_grad=projection_grad.requires_grad)
assert pyronn_torch.cpp_extension
pyronn_torch.cpp_extension.call_Cone_Backprojection3D_Kernel_Launcher(
self.projection_matrices, projection_grad,
self.projection_multiplier, volume_grad, *self.volume_origin,
*self.volume_spacing)
state.projection_matrices, projection_grad,
state.projection_multiplier, volume_grad, *state.volume_origin,
*state.volume_spacing)
return volume_grad,
if return_none:
return volume_grad, None
else:
return volume_grad,
class _BackwardProjection(torch.autograd.Function):
__init__ = _ForwardProjection.__init__
backward = _ForwardProjection.forward
forward = _ForwardProjection.backward
......@@ -118,29 +141,33 @@ class ConeBeamProjector:
requires_grad=requires_grad).cuda()
def project_forward(self, volume, step_size=1., use_texture=True):
return _ForwardProjection(self._projection_shape, self._volume_shape,
self._source_points, self._inverse_matrices,
self._projection_matrices,
self._volume_origin, self._volume_spacing,
self._projection_multiplier, step_size,
use_texture).forward(volume)[0]
return _ForwardProjection().apply(
volume,
State(self._projection_shape, self._volume_shape,
self._source_points, self._inverse_matrices,
self._projection_matrices, self._volume_origin,
self._volume_spacing, self._projection_multiplier, step_size,
use_texture))[0]
def project_backward(self,
projection_stack,
step_size=1.,
use_texture=True):
return _BackwardProjection(self._projection_shape, self._volume_shape,
self._source_points, self._inverse_matrices,
self._projection_matrices,
self._volume_origin, self._volume_spacing,
self._projection_multiplier, step_size,
use_texture).forward(projection_stack)[0]
return _BackwardProjection().apply(
projection_stack,
State(self._projection_shape, self._volume_shape,
self._source_points, self._inverse_matrices,
self._projection_matrices, self._volume_origin,
self._volume_spacing, self._projection_multiplier, step_size,
use_texture))[0]
def _calc_inverse_matrices(self):
if self._projection_matrices_numpy is None:
return
self._projection_matrices = torch.stack(
tuple(torch.from_numpy(p.astype(np.float32)) for p in self._projection_matrices_numpy)).cuda().contiguous()
tuple(
torch.from_numpy(p.astype(np.float32))
for p in self._projection_matrices_numpy)).cuda().contiguous()
inv_spacing = np.array([1 / s for s in reversed(self._volume_spacing)],
np.float32)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment