diff --git a/src/pyronn_torch/__init__.py b/src/pyronn_torch/__init__.py
index e76fcf215159029a8c08d722e9abeef09382e57d..57356e93074f55144ea6a3521f35ade374ee3dbe 100644
--- a/src/pyronn_torch/__init__.py
+++ b/src/pyronn_torch/__init__.py
@@ -1,10 +1,12 @@
 # -*- coding: utf-8 -*-
 
-from os.path import dirname, join
+import sys
+from os.path import dirname
 
-import torch
 from pkg_resources import DistributionNotFound, get_distribution
 
+from pyronn_torch.conebeam import ConeBeamProjector
+
 try:
     # Change here if project is renamed and does not equal the package name
     dist_name = 'pyronn-torch'
@@ -16,8 +18,12 @@ finally:
 
 
 try:
-    cpp_extension = torch.ops.load_library(join(dirname(__file__), 'pyronn_torch.so'))
-except Exception:
+    sys.path.append(dirname(__file__))
+    cpp_extension = __import__('pyronn_torch_cpp')
+except Exception as e:
+    import warnings
+    warnings.warn(str(e))
     import pyronn_torch.codegen
-    cpp_extension = pyronn_torch.codegen.compile_shared_object()
+    cpp_extension = pyronn_torch.codegen.generate_shared_object()
 
+__all__ = ['ConeBeamProjector', 'cpp_extension']
diff --git a/src/pyronn_torch/codegen.py b/src/pyronn_torch/codegen.py
index f4a3cfa841957edb0e8cb92a527c5db2756180eb..0021721409ecd6dbfaf0c7eccdb4653ecee79c04 100644
--- a/src/pyronn_torch/codegen.py
+++ b/src/pyronn_torch/codegen.py
@@ -151,7 +151,7 @@ def generate_shared_object(output_folder=None, source_files=None, show_code=Fals
 
     object_cache = get_cache_config()['object_cache']
 
-    module_name = 'PYRO_NN'
+    module_name = 'pyronn_torch_cpp'
 
     if not output_folder:
         output_folder = dirname(__file__)
@@ -176,10 +176,13 @@ def generate_shared_object(output_folder=None, source_files=None, show_code=Fals
     if show_code:
         pystencils.show_code(module, custom_backend=FrameworkIntegrationPrinter())
 
-    extension = module.compile(extra_source_files=cuda_sources, extra_cuda_flags=['-arch=sm_35'], with_cuda=True)
+    extension = module.compile(extra_source_files=cuda_sources,
+                               extra_cuda_flags=['-arch=sm_35'],
+                               with_cuda=True,
+                               compile_module_name=module_name)
 
-    shared_object_file = module.compiled_file.replace('.cpp', '.so')
-    copyfile(shared_object_file, join(output_folder, 'pyronn_torch.so'))
+    shared_object_file = module.compiled_file
+    copyfile(shared_object_file, join(output_folder, module_name + '.so'))
     copyfile(module.compiled_file, join(output_folder, 'pyronn_torch.cpp'))
 
     return extension
diff --git a/src/pyronn_torch/conebeam.py b/src/pyronn_torch/conebeam.py
new file mode 100644
index 0000000000000000000000000000000000000000..0eca8ea4d341e0ae4d524830cb0968c619661059
--- /dev/null
+++ b/src/pyronn_torch/conebeam.py
@@ -0,0 +1,149 @@
+#
+# Copyright © 2020 Stephan Seitz <stephan.seitz@fau.de>
+#
+# Distributed under terms of the GPLv3 license.
+
+"""
+
+"""
+import numpy as np
+import sympy as sp
+import torch
+
+import pyronn_torch
+
+
+class ConeBeamProjector:
+    class ForwardProjection(torch.autograd.Function):
+
+        def __init__(self, projection_shape,
+                     source_points,
+                     inverse_matrices,
+                     projection_matrices,
+                     volume_origin,
+                     volume_spacing,
+                     projection_multiplier,
+                     step_size=1.,
+                     with_texture=True):
+            self.projection_shape = projection_shape
+            self.source_points = source_points
+            self.inverse_matrices = inverse_matrices
+            self.projection_matrices = projection_matrices
+            self.volume_origin = volume_origin
+            self.volume_spacing = volume_spacing
+            self.projection_multiplier = projection_multiplier
+            self.with_texture = with_texture
+            self.step_size = step_size
+
+        def forward(self, volume):
+            volume = volume.cuda().contiguous()
+            projection = torch.zeros(self.projection_shape, device='cuda')
+
+            assert pyronn_torch.cpp_extension
+            if self.with_texture:
+                pyronn_torch.cpp_extension.call_Cone_Projection_Kernel_Tex_Interp_Launcher(
+                    self.inverse_matrices,
+                    projection,
+                    self.source_points,
+                    self.step_size,
+                    volume,
+                    *reversed(self.volume_spacing))
+            else:
+                pyronn_torch.cpp_extension.call_Cone_Projection_Kernel_Launcher(
+                    self.inverse_matrices,
+                    projection,
+                    self.source_points,
+                    self.step_size,
+                    volume,
+                    *reversed(self.volume_spacing))
+
+            return projection,
+
+        def backward(self, *projection_grad):
+            projection_grad = projection_grad[0]
+            self.projection_matrices
+            volume_grad = torch.zeros(self.volume_shape, device='cuda')
+
+            assert pyronn_torch.cpp_extension
+            pyronn_torch.cpp_extension.call_Cone_Backprojection3D_Kernel_Launcher(
+                self.projection_matrices,
+                projection_grad,
+                self.projection_multiplier,
+                volume_grad,
+                *reversed(self.volume_origin),
+                *reversed(self.volume_spacing))
+
+            return volume_grad,
+
+    class BackwardProjection(torch.autograd.Function):
+        pass
+
+    # def __init__(self,
+        # volume_shape,
+        # volume_spacing,
+        # volume_origin,
+        # projection_shape,
+        # projection_spacing,
+        # projection_origin,
+        # projection_matrices):
+        # self._volume_shape = volume_shape
+        # self._volume_origin = volume_origin
+        # self._volume_spacing = volume_spacing
+        # self._projection_shape = projection_shape
+        # self._projection_matrices = projection_matrices
+        # self._projection_spacing = projection_spacing
+        # self._projection_origin = projection_origin
+        # self._calc_inverse_matrices()
+
+    def __init__(self):
+        import pyconrad.autoinit
+        import pyconrad.config
+        self._volume_shape = pyconrad.config.get_reco_shape()
+        self._volume_spacing = pyconrad.config.get_reco_spacing()
+        self._volume_origin = pyconrad.config.get_reco_origin()
+        self._projection_shape = pyconrad.config.get_sino_shape()
+        self._projection_spacing = [pyconrad.config.get_geometry().getPixelDimensionY(),
+                                    pyconrad.config.get_geometry().getPixelDimensionX()]
+        self._projection_origin = [pyconrad.config.get_geometry().getDetectorOffsetV(),
+                                   pyconrad.config.get_geometry().getDetectorOffsetU()]
+        self._projection_matrices_numpy = pyconrad.config.get_projection_matrices()
+
+        self._calc_inverse_matrices()
+
+    def new_volume_tensor(self, requires_grad=False):
+        return torch.zeros(self._volume_shape, requires_grad=requires_grad).cuda()
+
+    def new_projection_tensor(self, requires_grad=False):
+        return torch.zeros(self._projection_shape, requires_grad=requires_grad).cuda()
+
+    def project_forward(self, volume, step_size=1., use_texture=True):
+        return self.ForwardProjection(self._projection_shape,
+                                      self._source_points,
+                                      self._inverse_matrices,
+                                      self._projection_matrices,
+                                      self._volume_origin,
+                                      self._volume_shape,
+                                      self._projection_multiplier,
+                                      step_size,
+                                      use_texture).forward(volume)
+
+    def project_backward(self, projection_stack):
+        return self.BackwardProjection(projection_stack)
+
+    def _calc_inverse_matrices(self):
+        self._projection_matrices = torch.stack(tuple(
+            map(torch.from_numpy, self._projection_matrices_numpy))).cuda().contiguous()
+
+        inv_spacing = np.array([1/s for s in reversed(self._volume_spacing)], np.float32)
+
+        camera_centers = map(lambda x: np.array(sp.Matrix(x).nullspace(), np.float32), self._projection_matrices_numpy)
+
+        source_points = map(lambda x: (x[0, :3] / x[0, 3] * inv_spacing - np.array(list(reversed(self._volume_origin)))
+                                       * inv_spacing).astype(np.float32), camera_centers)
+
+        inv_matrices = map(lambda x: (np.linalg.inv(x[:3, :3]) *
+                                      inv_spacing).astype(np.float32), self._projection_matrices_numpy)
+
+        self._inverse_matrices = torch.stack(tuple(map(torch.from_numpy, inv_matrices))).cuda().contiguous()
+        self._source_points = torch.stack(tuple(map(torch.from_numpy, source_points))).cuda().contiguous()
+        self._projection_multiplier = 1.
diff --git a/src/pyronn_torch/pyronn_torch.cpp b/src/pyronn_torch/pyronn_torch.cpp
index ecad67e8dd2aad521726c36ef04c52fea7f46a7e..e5defe1b03e85dd5e509bafb5dc48d231a72a526 100644
Binary files a/src/pyronn_torch/pyronn_torch.cpp and b/src/pyronn_torch/pyronn_torch.cpp differ
diff --git a/tests/test_projection.py b/tests/test_projection.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b733f126ea851b5059dede00df54b7a00a516b4
--- /dev/null
+++ b/tests/test_projection.py
@@ -0,0 +1,25 @@
+#
+# Copyright © 2020 Stephan Seitz <stephan.seitz@fau.de>
+#
+# Distributed under terms of the GPLv3 license.
+
+"""
+
+"""
+import pyronn_torch
+
+
+def init():
+    assert pyronn_torch.cpp_extension
+
+
+def test_projection():
+    breakpoint()
+    projector = pyronn_torch.ConeBeamProjector()
+
+    volume = projector.new_volume_tensor()
+
+    volume += 1.
+    result = projector.project_forward(volume, use_texture=False)
+
+    print(result)