diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py
index f507fa6449c94e80172177c11e080d412a5afb53..442a3a5eefb5fabb865a9a676eb24f6786dd80d8 100644
--- a/src/pystencils_autodiff/backends/astnodes.py
+++ b/src/pystencils_autodiff/backends/astnodes.py
@@ -19,8 +19,6 @@ from pystencils_autodiff.backends.python_bindings import (
 from pystencils_autodiff.framework_integration.astnodes import (
     DestructuringBindingsForFieldClass, JinjaCppFile, WrapperFunction, generate_kernel_call)
 
-# Torch
-
 
 class TorchTensorDestructuring(DestructuringBindingsForFieldClass):
     CLASS_TO_MEMBER_DICT = {
diff --git a/src/pystencils_autodiff/backends/python_bindings.py b/src/pystencils_autodiff/backends/python_bindings.py
index 80a5a2e68f965ccf00fa256f46ed58192f1c726c..3cae1a83d7f6b09203018aab4452a807b5a074ef 100644
--- a/src/pystencils_autodiff/backends/python_bindings.py
+++ b/src/pystencils_autodiff/backends/python_bindings.py
@@ -101,8 +101,6 @@ REGISTER_KERNEL_BUILDER(Name("{{ python_name }}").Device({{ device }}), {{ pytho
         parameters = function_node.get_parameters()
         output_shape = str(output_fields[0].shape).replace('(', '{').replace(')', '}')  # noqa,  TODO make work for flexible sizes
 
-        print([f for f in function_node.atoms(Node)])
-
         docstring = "TODO"  # TODO
 
         # this looks almost like lisp 😕
diff --git a/src/pystencils_autodiff/tensorflow_jit.py b/src/pystencils_autodiff/tensorflow_jit.py
index b68abd1a8b79345a857ddc49777d7a0099eadba6..76e887c83264c36f9cbd8ec17cf568fe406c50d6 100644
--- a/src/pystencils_autodiff/tensorflow_jit.py
+++ b/src/pystencils_autodiff/tensorflow_jit.py
@@ -7,6 +7,7 @@
 """
 
 """
+import json
 import subprocess
 import sysconfig
 from itertools import chain
@@ -24,11 +25,13 @@ if get_compiler_config()['os'] != 'windows':
     _output_flag = '-o'
     _include_flags = ['-I' + sysconfig.get_paths()['include'], '-I' + get_pystencils_include_path()]
     _do_not_link_flag = "-c"
+    _position_independent_flag = "-fPIC"
 else:
     _do_not_link_flag = "/c"
     _output_flag = '/OUT:'
     _shared_object_flag = '/DLL'
     _include_flags = ['/I' + sysconfig.get_paths()['include'], '/I' + get_pystencils_include_path()]
+    _position_independent_flag = "/DTHIS_FLAG_DOES_NOTHING"
 
 
 try:
@@ -40,7 +43,7 @@ except ImportError:
     pass
 
 
-def link_and_load(object_files, destination_file=None, link_cudart=False, overwrite_destination_file=True):
+def link_and_load(object_files, destination_file=None, overwrite_destination_file=True):
     """Compiles given :param:`source_file` to a Tensorflow shared Library.
 
     .. warning::
@@ -58,15 +61,13 @@ def link_and_load(object_files, destination_file=None, link_cudart=False, overwr
 
     if not exists(destination_file) or overwrite_destination_file:
         command = [get_compiler_config()['command'],
-                   *(get_compiler_config()['flags']).split(' '),
+                   _position_independent_flag,
                    *object_files,
                    *_tf_link_flags,
-                   *_tf_compile_flags,
                    *_include_flags,
                    _shared_object_flag,
-                   _output_flag + destination_file]  # /out: for msvc???
-        if link_cudart:
-            command.append('-lcudart')
+                   _output_flag,
+                   destination_file]  # NOTE(review): works for gcc/nvcc (-o file), but MSVC requires /OUT: attached to the filename (/OUT:file.dll) — passing it as a separate argv entry will break the Windows link step
 
         subprocess.check_call(command)
 
@@ -81,34 +82,46 @@ def try_get_cuda_arch_flag():
     except Exception:
         arch = None
     if arch:
-        return "-arch " + arch
+        return "-arch=" + arch
     else:
         return None
 
 
 _cuda_arch_flag = try_get_cuda_arch_flag()
 
+_nvcc_flags = ["-w", "-std=c++14", "-Wno-deprecated-gpu-targets"]
+
+
+if _cuda_arch_flag:
+    _nvcc_flags.append(_cuda_arch_flag)
+if pystencils.gpucuda.cudajit.USE_FAST_MATH:
+    _nvcc_flags.append('-use_fast_math')
+
 
 def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=True):
+    if 'tensorflow_host_compiler' not in get_compiler_config():
+        get_compiler_config()['tensorflow_host_compiler'] = get_compiler_config()['command']
+        write_file(pystencils.cpu.cpujit.get_configuration_file_path(), json.dumps(pystencils.cpu.cpujit._config))
 
     destination_file = file + '.o'
     if use_nvcc:
         command = [nvcc,
                    '--expt-relaxed-constexpr',
                    '-ccbin',
-                   get_compiler_config()['command'],
-                   *(get_compiler_config()['flags']).split(' '),
+                   get_compiler_config()['tensorflow_host_compiler'],
+                   '-Xcompiler',
+                   get_compiler_config()['flags'].replace('c++11', 'c++14'),
+                   *_nvcc_flags,
                    file,
                    '-x',
                    'cu',
                    '-Xcompiler',
-                   '-fPIC',  # TODO: msvc!
+                   _position_independent_flag,
                    _do_not_link_flag,
                    *_tf_compile_flags,
                    *_include_flags,
-                   _output_flag + destination_file]
-        if _cuda_arch_flag:
-            command.append(_cuda_arch_flag)
+                   _output_flag,
+                   destination_file]
     else:
         command = [get_compiler_config()['command'],
                    *(get_compiler_config()['flags']).split(' '),
@@ -116,7 +129,9 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T
                    _do_not_link_flag,
                    *_tf_compile_flags,
                    *_include_flags,
-                   _output_flag + destination_file]
+                   _output_flag,
+                   destination_file]
+
     if not exists(destination_file) or overwrite_destination_file:
         subprocess.check_call(command)
     return destination_file
@@ -126,7 +141,7 @@ def compile_sources_and_load(host_sources, cuda_sources=[]):
 
     object_files = []
 
-    for source in tqdm(chain(host_sources, cuda_sources), desc='Compiling Tensorflow module...'):
+    for source in tqdm(host_sources + cuda_sources, desc='Compiling Tensorflow module'):
         is_cuda = source in cuda_sources
 
         if exists(source):
@@ -139,10 +154,10 @@ def compile_sources_and_load(host_sources, cuda_sources=[]):
         write_file(file_name, source_code)
 
         compile_file(file_name, use_nvcc=is_cuda, overwrite_destination_file=False)
-        object_files.append(file_name)
+        object_files.append(file_name + '.o')
 
     print('Linking Tensorflow module...')
-    module = link_and_load(object_files, overwrite_destination_file=False, link_cudart=cuda_sources or False)
+    module = link_and_load(object_files, overwrite_destination_file=False)
     if module:
-        print('Loaded Tensorflow module')
+        print('Loaded Tensorflow module.')
     return module
diff --git a/tests/backends/test_torch_native_compilation.py b/tests/backends/test_torch_native_compilation.py
index 9a1b9060901fa6959ba968c378d0d4d35ac07e3d..4e88075a585bae7c619c3407f4068d665a30ac92 100644
--- a/tests/backends/test_torch_native_compilation.py
+++ b/tests/backends/test_torch_native_compilation.py
@@ -5,11 +5,11 @@
 
 import os
 import subprocess
+import tempfile
 from os.path import dirname, isfile, join
 
 import pytest
 import sympy
-import tempfile
 
 import pystencils
 from pystencils_autodiff import create_backward_assignments
@@ -22,6 +22,7 @@ pytestmark = pytest.mark.skipif(subprocess.call(['ninja', '--v']) != 0,
 
 PROJECT_ROOT = dirname
 
+
 @pytest.mark.skipif("TRAVIS" in os.environ, reason="nvcc compilation currently not working on TRAVIS")
 def test_torch_jit():
     """
diff --git a/tests/test_tensorflow_jit.py b/tests/test_tensorflow_jit.py
index 6bc183e8c729ccc3b7afcc6c9c99ba363fd62671..251cbb2b71312ef905e7b37a964a004e00a28d74 100644
--- a/tests/test_tensorflow_jit.py
+++ b/tests/test_tensorflow_jit.py
@@ -39,8 +39,36 @@ def test_tensorflow_jit_cpu():
     backward_ast = pystencils.create_kernel(backward_assignments, target)
     backward_ast.function_name = 'backward'
     module = TensorflowModule(module_name, [forward_ast, backward_ast])
-    print(module)
 
     lib = pystencils_autodiff.tensorflow_jit.compile_sources_and_load([str(module)])
     assert 'call_forward' in dir(lib)
     assert 'call_backward' in dir(lib)
+
+
+def test_tensorflow_jit_gpu():
+
+    pytest.importorskip('tensorflow')
+
+    module_name = "Ololol"
+
+    target = 'gpu'
+
+    z, y, x = pystencils.fields("z, y, x: [20,40]")
+    a = sympy.Symbol('a')
+
+    forward_assignments = pystencils.AssignmentCollection({
+        z[0, 0]: x[0, 0] * sympy.log(a * x[0, 0] * y[0, 0])
+    })
+
+    backward_assignments = create_backward_assignments(forward_assignments)
+
+    forward_ast = pystencils.create_kernel(forward_assignments, target)
+    forward_ast.function_name = 'forward'
+    backward_ast = pystencils.create_kernel(backward_assignments, target)
+    backward_ast.function_name = 'backward'
+    module = TensorflowModule(module_name, [forward_ast, backward_ast])
+
+    lib = pystencils_autodiff.tensorflow_jit.compile_sources_and_load([], [str(module)])
+    assert 'call_forward' in dir(lib)
+    assert 'call_backward' in dir(lib)
+