diff --git a/src/pystencils_autodiff/tensorflow_jit.py b/src/pystencils_autodiff/tensorflow_jit.py
index 1dd21e587bc275d84e5d8d86f9d67d7bde80f293..11c0d15394bae88ca39ae5589c7fe1495b6cab9a 100644
--- a/src/pystencils_autodiff/tensorflow_jit.py
+++ b/src/pystencils_autodiff/tensorflow_jit.py
@@ -40,6 +40,7 @@ if get_compiler_config()['os'] != 'windows':
     _compile_env = os.environ.copy()
     _object_file_extension = '.o'
     _link_cudart = '-lcudart'
+    _openmp_flag = '-fopenmp'
 else:
     _do_not_link_flag = '-c'
     _output_flag = '-o'
@@ -54,6 +55,7 @@ else:
     _compile_env.update(config_env)
     _object_file_extension = '.obj'
     _link_cudart = '/link cudart'  # ???
+    _openmp_flag = '/openmp'  # ???
 
 
 def link(object_files,
@@ -153,6 +155,7 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T
                           'cu',
                           '-Xcompiler',
                           _position_independent_flag,
+                          _openmp_flag,
                           _do_not_link_flag,
                           *tf.sysconfig.get_compile_flags(),
                           *_include_flags,
@@ -163,6 +166,7 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T
                           *(get_compiler_config()['flags']).split(' '),
                           file,
                           _do_not_link_flag,
+                          _openmp_flag,
                           *tf.sysconfig.get_compile_flags(),
                           *_include_flags,
                           *additional_compile_flags,