Skip to content
Snippets Groups Projects
Commit 770c3ea2 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Apply crazy state hack

parent e3458864
No related branches found
No related tags found
Loading
...@@ -12,7 +12,7 @@ import pyronn_torch ...@@ -12,7 +12,7 @@ import pyronn_torch
import sympy as sp import sympy as sp
class _ForwardProjection(torch.autograd.Function): class State:
def __init__(self, def __init__(self,
projection_shape, projection_shape,
volume_shape, volume_shape,
...@@ -35,40 +35,63 @@ class _ForwardProjection(torch.autograd.Function): ...@@ -35,40 +35,63 @@ class _ForwardProjection(torch.autograd.Function):
self.with_texture = with_texture self.with_texture = with_texture
self.step_size = step_size self.step_size = step_size
class _ForwardProjection(torch.autograd.Function):
    """Autograd function computing the cone-beam forward projection.

    Projection parameters travel in a ``State`` object (the "state hack"
    this commit introduces): ``apply(volume, state)`` passes the state as
    a formal autograd input, and ``forward`` caches it on the context so
    a paired call made without an explicit state can reuse it.
    """

    @staticmethod
    def forward(self, volume, state=None):
        # NOTE(review): despite the name, `self` here is the autograd
        # context object, not a class instance (staticmethod semantics).
        if state is None:
            # Reuse the state cached by an earlier call.  The extra
            # `None` returned below keeps the output arity aligned with
            # the two-input (volume, state) form — TODO confirm this is
            # what downstream autograd plumbing expects.
            state = self.state
            return_none = True
        else:
            return_none = False
        # The CUDA kernels require float32, GPU-resident, contiguous data.
        volume = volume.float().cuda().contiguous()
        projection = torch.zeros(state.projection_shape,
                                 device='cuda',
                                 requires_grad=volume.requires_grad)
        assert pyronn_torch.cpp_extension
        if state.with_texture:
            # Texture-interpolation variant of the projection kernel.
            pyronn_torch.cpp_extension.call_Cone_Projection_Kernel_Tex_Interp_Launcher(
                state.inverse_matrices, projection, state.source_points,
                state.step_size, volume, *state.volume_spacing)
        else:
            pyronn_torch.cpp_extension.call_Cone_Projection_Kernel_Launcher(
                state.inverse_matrices, projection, state.source_points,
                state.step_size, volume, *state.volume_spacing)
        # Cache the state on the context so backward() can retrieve it.
        self.state = state
        if return_none:
            return projection, None
        else:
            return projection,

    @staticmethod
    def backward(self, projection_grad, state=None, *args):
        # NOTE(review): `state` occupies the gradient slot of forward's
        # second input; when it is None the state cached by forward() is
        # used instead — part of the same state hack, verify against
        # callers before relying on it.
        if state is None:
            state = self.state
            return_none = True
        else:
            return_none = False
        projection_grad = projection_grad.float().cuda().contiguous()
        volume_grad = torch.zeros(state.volume_shape,
                                  device='cuda',
                                  requires_grad=projection_grad.requires_grad)
        assert pyronn_torch.cpp_extension
        # Backprojection of the incoming gradient yields the volume grad.
        pyronn_torch.cpp_extension.call_Cone_Backprojection3D_Kernel_Launcher(
            state.projection_matrices, projection_grad,
            state.projection_multiplier, volume_grad, *state.volume_origin,
            *state.volume_spacing)
        if return_none:
            return volume_grad, None
        else:
            return volume_grad,
class _BackwardProjection(torch.autograd.Function):
    """Autograd function for cone-beam backprojection.

    Reuses ``_ForwardProjection``'s implementations with the roles
    swapped: the backprojection kernel is this function's forward pass,
    and the forward projection serves as its gradient.
    """
    backward = _ForwardProjection.forward
    forward = _ForwardProjection.backward
...@@ -118,29 +141,33 @@ class ConeBeamProjector: ...@@ -118,29 +141,33 @@ class ConeBeamProjector:
requires_grad=requires_grad).cuda() requires_grad=requires_grad).cuda()
def project_forward(self, volume, step_size=1., use_texture=True): def project_forward(self, volume, step_size=1., use_texture=True):
return _ForwardProjection(self._projection_shape, self._volume_shape, return _ForwardProjection().apply(
self._source_points, self._inverse_matrices, volume,
self._projection_matrices, State(self._projection_shape, self._volume_shape,
self._volume_origin, self._volume_spacing, self._source_points, self._inverse_matrices,
self._projection_multiplier, step_size, self._projection_matrices, self._volume_origin,
use_texture).forward(volume)[0] self._volume_spacing, self._projection_multiplier, step_size,
use_texture))[0]
def project_backward(self, def project_backward(self,
projection_stack, projection_stack,
step_size=1., step_size=1.,
use_texture=True): use_texture=True):
return _BackwardProjection(self._projection_shape, self._volume_shape, return _BackwardProjection().apply(
self._source_points, self._inverse_matrices, projection_stack,
self._projection_matrices, State(self._projection_shape, self._volume_shape,
self._volume_origin, self._volume_spacing, self._source_points, self._inverse_matrices,
self._projection_multiplier, step_size, self._projection_matrices, self._volume_origin,
use_texture).forward(projection_stack)[0] self._volume_spacing, self._projection_multiplier, step_size,
use_texture))[0]
def _calc_inverse_matrices(self): def _calc_inverse_matrices(self):
if self._projection_matrices_numpy is None: if self._projection_matrices_numpy is None:
return return
self._projection_matrices = torch.stack( self._projection_matrices = torch.stack(
tuple(torch.from_numpy(p.astype(np.float32)) for p in self._projection_matrices_numpy)).cuda().contiguous() tuple(
torch.from_numpy(p.astype(np.float32))
for p in self._projection_matrices_numpy)).cuda().contiguous()
inv_spacing = np.array([1 / s for s in reversed(self._volume_spacing)], inv_spacing = np.array([1 / s for s in reversed(self._volume_spacing)],
np.float32) np.float32)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment