From d83be5b6caf0b63e98fc537e235182b7a7d7b73c Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Wed, 26 Feb 2020 10:46:15 +0100
Subject: [PATCH] Re-introduce ConeBeamProjector.from_conrad_config()

---
 src/pyronn_torch/conebeam.py | 63 +++++++++++++++++++-----------------
 tests/test_projection.py     |  5 ++-
 2 files changed, 36 insertions(+), 32 deletions(-)

diff --git a/src/pyronn_torch/conebeam.py b/src/pyronn_torch/conebeam.py
index 0eca8ea..7fb1204 100644
--- a/src/pyronn_torch/conebeam.py
+++ b/src/pyronn_torch/conebeam.py
@@ -78,37 +78,40 @@ class ConeBeamProjector:
     class BackwardProjection(torch.autograd.Function):
         pass
 
-    # def __init__(self,
-        # volume_shape,
-        # volume_spacing,
-        # volume_origin,
-        # projection_shape,
-        # projection_spacing,
-        # projection_origin,
-        # projection_matrices):
-        # self._volume_shape = volume_shape
-        # self._volume_origin = volume_origin
-        # self._volume_spacing = volume_spacing
-        # self._projection_shape = projection_shape
-        # self._projection_matrices = projection_matrices
-        # self._projection_spacing = projection_spacing
-        # self._projection_origin = projection_origin
-        # self._calc_inverse_matrices()
-
-    def __init__(self):
+    def __init__(self,
+                 volume_shape,
+                 volume_spacing,
+                 volume_origin,
+                 projection_shape,
+                 projection_spacing,
+                 projection_origin,
+                 projection_matrices):
+        self._volume_shape = volume_shape
+        self._volume_origin = volume_origin
+        self._volume_spacing = volume_spacing
+        self._projection_shape = projection_shape
+        self._projection_matrices_numpy = projection_matrices
+        self._projection_spacing = projection_spacing
+        self._projection_origin = projection_origin
+        self._calc_inverse_matrices()
+
+    @classmethod
+    def from_conrad_config(cls):
+        obj = cls(*([None]*7))
         import pyconrad.autoinit
         import pyconrad.config
-        self._volume_shape = pyconrad.config.get_reco_shape()
-        self._volume_spacing = pyconrad.config.get_reco_spacing()
-        self._volume_origin = pyconrad.config.get_reco_origin()
-        self._projection_shape = pyconrad.config.get_sino_shape()
-        self._projection_spacing = [pyconrad.config.get_geometry().getPixelDimensionY(),
-                                    pyconrad.config.get_geometry().getPixelDimensionX()]
-        self._projection_origin = [pyconrad.config.get_geometry().getDetectorOffsetV(),
-                                   pyconrad.config.get_geometry().getDetectorOffsetU()]
-        self._projection_matrices_numpy = pyconrad.config.get_projection_matrices()
-
-        self._calc_inverse_matrices()
+        obj._volume_shape = pyconrad.config.get_reco_shape()
+        obj._volume_spacing = pyconrad.config.get_reco_spacing()
+        obj._volume_origin = pyconrad.config.get_reco_origin()
+        obj._projection_shape = pyconrad.config.get_sino_shape()
+        obj._projection_spacing = [pyconrad.config.get_geometry().getPixelDimensionY(),
+                                   pyconrad.config.get_geometry().getPixelDimensionX()]
+        obj._projection_origin = [pyconrad.config.get_geometry().getDetectorOffsetV(),
+                                  pyconrad.config.get_geometry().getDetectorOffsetU()]
+        obj._projection_matrices_numpy = pyconrad.config.get_projection_matrices()
+
+        obj._calc_inverse_matrices()
+        return obj
 
     def new_volume_tensor(self, requires_grad=False):
         return torch.zeros(self._volume_shape, requires_grad=requires_grad).cuda()
@@ -131,6 +134,8 @@ class ConeBeamProjector:
         return self.BackwardProjection(projection_stack)
 
     def _calc_inverse_matrices(self):
+        if self._projection_matrices_numpy is None:
+            return
         self._projection_matrices = torch.stack(tuple(
             map(torch.from_numpy, self._projection_matrices_numpy))).cuda().contiguous()
 
diff --git a/tests/test_projection.py b/tests/test_projection.py
index 8b733f1..6fac28f 100644
--- a/tests/test_projection.py
+++ b/tests/test_projection.py
@@ -9,13 +9,12 @@
 import pyronn_torch
 
 
-def init():
+def test_init():
     assert pyronn_torch.cpp_extension
 
 
 def test_projection():
-    breakpoint()
-    projector = pyronn_torch.ConeBeamProjector()
+    projector = pyronn_torch.ConeBeamProjector.from_conrad_config()
 
     volume = projector.new_volume_tensor()
 
-- 
GitLab