diff --git a/src/pystencils_autodiff/backends/_torch_native.py b/src/pystencils_autodiff/backends/_torch_native.py
index 61022bdc49717b86fd024de299212b9c3da246c4..b574c301ee4e41bac8b035f8b378f8b19e7a25b3 100644
--- a/src/pystencils_autodiff/backends/_torch_native.py
+++ b/src/pystencils_autodiff/backends/_torch_native.py
@@ -1,8 +1,8 @@
-import uuid
 from collections import OrderedDict
 
 from pystencils_autodiff.backends._pytorch import numpy_dtype_to_torch
 from pystencils_autodiff.backends.astnodes import TorchModule
+from pystencils_autodiff.tensorflow_jit import _hash
 
 try:
     import torch
@@ -30,28 +30,55 @@ def create_autograd_function(autodiff_obj, inputfield_to_tensor_dict):
         forward_ast = autodiff_obj.forward_ast_cpu
         backward_ast = autodiff_obj.backward_ast_cpu
 
-    op_name = autodiff_obj.op_name + str(uuid.uuid4())
-    compiled_op = TorchModule(op_name, [forward_ast, backward_ast])
-
-    output_tensors = OrderedDict({f.name: field_to_tensor_dict[f] for f in autodiff_obj.forward_output_fields})
-    backward_output_tensors = OrderedDict(
-        {f.name: field_to_tensor_dict[f] for f in autodiff_obj.backward_output_fields})
-
-    def forward(self, **input_tensors):
-
-        self.save_for_backward(**input_tensors)
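+    # Name the op with a stable hash of the autodiff object instead of a random
+    # uuid, so recompiling identical kernels can reuse the cached extension.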
+    op_name = f'{autodiff_obj.op_name}_{_hash(str(autodiff_obj).encode()).hexdigest()}'
+    module = TorchModule(op_name, [forward_ast, backward_ast])
+    compiled_op = module.compile()
+
+
+    def forward(self, *args):
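+        # args are ordered like autodiff_obj.forward_input_fields.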
+        input_tensors = {f.name: args[i] for i, f in enumerate(
+            autodiff_obj.forward_input_fields) if f in forward_ast.fields_accessed}
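+        # The kernels were generated for fixed shapes and strides, so check the layout matches.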
+        assert all(f.shape == args[i].shape for i, f in enumerate(autodiff_obj.forward_input_fields))
+        assert all(f.strides == args[i].stride() for i, f in enumerate(autodiff_obj.forward_input_fields))
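+        # Allocate output tensors on the same device as the inputs.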
+        for field in autodiff_obj.forward_output_fields:
+            field_to_tensor_dict[field] = torch.zeros(
+                field.shape,
+                dtype=numpy_dtype_to_torch(field.dtype.numpy_dtype),
+                device=next(iter(inputfield_to_tensor_dict.values())).device)
+        output_tensors = OrderedDict({f.name: field_to_tensor_dict[f] for f in autodiff_obj.forward_output_fields})
+
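+        # Save the inputs so backward can feed them to the backward kernel.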
+        self.save_for_backward(*args)
 
         getattr(compiled_op, "call_" + forward_ast.function_name)(**input_tensors, **output_tensors)
 
-        return output_tensors.values()
+        return tuple(output_tensors.values())
 
     def backward(self, *grad_outputs):
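+        # grad_outputs are ordered like autodiff_obj.backward_input_fields.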
         gradients = {f.name: grad_outputs[i] for i, f in enumerate(autodiff_obj.backward_input_fields)}
-        saved = self.saved_tensors
-
+        assert all(f.shape == grad_outputs[i].shape for i, f in enumerate(autodiff_obj.backward_input_fields))
+        assert all(f.strides == grad_outputs[i].stride()
+                   for i, f in enumerate(autodiff_obj.backward_input_fields))
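+        # Pass on only those saved inputs that the backward kernel actually reads.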
+        saved = {f.name: self.saved_tensors[i] for i, f in enumerate(
+            autodiff_obj.forward_input_fields) if f in backward_ast.fields_accessed}
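+        # Allocate tensors for the gradients with respect to the forward inputs.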
+        for field in autodiff_obj.backward_output_fields:
+            field_to_tensor_dict[field] = torch.zeros(
+                field.shape,
+                dtype=numpy_dtype_to_torch(field.dtype.numpy_dtype),
+                device=next(iter(inputfield_to_tensor_dict.values())).device)
+
+        backward_output_tensors = OrderedDict(
+            {f.name: field_to_tensor_dict[f] for f in autodiff_obj.backward_output_fields})
         getattr(compiled_op, "call_" + backward_ast.function_name)(**gradients, **saved, **backward_output_tensors)
 
-        return backward_output_tensors.values()
+        return tuple(backward_output_tensors.values())
 
     cls = type(op_name, (torch.autograd.Function,), {})
     cls.forward = forward
diff --git a/tests/backends/test_torch_native_compilation.py b/tests/backends/test_torch_native_compilation.py
index 2d01db730754f57f15013431a342232ae87478d0..5814ca50eb93dd1af056909636360f73fcbfcb02 100644
--- a/tests/backends/test_torch_native_compilation.py
+++ b/tests/backends/test_torch_native_compilation.py
@@ -13,7 +13,7 @@ import sympy
 
 import pystencils
 from pystencils_autodiff import create_backward_assignments
-from pystencils_autodiff._file_io import write_cached_content, write_file
+from pystencils_autodiff._file_io import write_cached_content
 from pystencils_autodiff.backends.astnodes import PybindModule, TorchModule
 
 torch = pytest.importorskip('torch')
diff --git a/tests/test_tfmad.py b/tests/test_tfmad.py
index 7dd10e1d6a642b127744f07593fa02332e3cccb4..7065d9c0c3e65f14f5b02f074f99153566367304 100644
--- a/tests/test_tfmad.py
+++ b/tests/test_tfmad.py
@@ -165,10 +165,10 @@ def test_tfmad_gradient_check_torch():
 
     a, b, out = ps.fields("a, b, out: float[21,13]")
 
-    cont = ps.fd.Diff(a, 0) - ps.fd.Diff(a, 1) - \
-        ps.fd.Diff(b, 0) + ps.fd.Diff(b, 1)
+    cont = 2*ps.fd.Diff(a, 0) - 1.5*ps.fd.Diff(a, 1) - \
+        ps.fd.Diff(b, 0) + 3 * ps.fd.Diff(b, 1)
     discretize = ps.fd.Discretization2ndOrder(dx=1)
-    discretization = discretize(cont)
+    discretization = discretize(cont) + 1.2*a.center
 
     assignment = ps.Assignment(out.center(), discretization)
     assignment_collection = ps.AssignmentCollection([assignment], [])
@@ -189,12 +189,52 @@
     function = auto_diff.create_tensorflow_op({
         a: a_tensor,
         b: b_tensor
-    },
-                                              backend='torch')
+    }, backend='torch')
 
     torch.autograd.gradcheck(function.apply, [a_tensor, b_tensor])
 
 
+@pytest.mark.parametrize('with_offsets', (True, False))
+def test_tfmad_gradient_check_torch_native(with_offsets):
+    torch = pytest.importorskip('torch')
+
+    a, b, out = ps.fields("a, b, out: float64[21,13]")
+
+    if with_offsets:
+        cont = 2*ps.fd.Diff(a, 0) - 1.5*ps.fd.Diff(a, 1) - ps.fd.Diff(b, 0) + 3 * ps.fd.Diff(b, 1)
+        discretize = ps.fd.Discretization2ndOrder(dx=1)
+        discretization = discretize(cont)
+
+        assignment = ps.Assignment(out.center(), discretization + 1.2*a.center())
+    else:
+        assignment = ps.Assignment(out.center(), 1.2*a.center + 0.1*b.center)
+    assignment_collection = ps.AssignmentCollection([assignment], [])
+    print('Forward')
+    print(assignment_collection)
+
+    print('Backward')
+    auto_diff = pystencils_autodiff.AutoDiffOp(assignment_collection,
+                                               diff_mode='transposed-forward')
+    backward = auto_diff.backward_assignments
+    print(backward)
+    print('Forward input fields (to check order)')
+    print(auto_diff.forward_input_fields)
+
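+    # gradcheck uses finite differences, so the inputs are created in float64.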
+    a_tensor = torch.zeros(*a.shape, dtype=torch.float64, requires_grad=True)
+    b_tensor = torch.zeros(*b.shape, dtype=torch.float64, requires_grad=True)
+
+    field_to_tensor = {
+        a: a_tensor,
+        b: b_tensor
+    }
+    function = auto_diff.create_tensorflow_op(field_to_tensor, backend='torch_native')
+
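+    # gradcheck must receive the tensors in forward_input_fields order.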
+    torch.autograd.gradcheck(function.apply, tuple(
+        field_to_tensor[f] for f in auto_diff.forward_input_fields), atol=1e-4, raise_exception=True)
+
+
 def get_curl(input_field: ps.Field, curl_field: ps.Field):
     """Return a ps.AssignmentCollection describing the calculation of
     the curl given a 2d or 3d vector field [z,y,x](f) or [y,x](f)
@@ -256,3 +296,7 @@
 
     print('Backward')
     print(curl_op.backward_assignments)
+
+
+if __name__ == '__main__':
+    test_tfmad_gradient_check_torch_native(with_offsets=True)