diff --git a/src/pystencils_autodiff/backends/_torch_native.py b/src/pystencils_autodiff/backends/_torch_native.py
index 4822425f1165aec4cd45cb35a83c986287c5555a..4bef57312a4ca5c8e3e4d0f52330e3a3b60cf530 100644
--- a/src/pystencils_autodiff/backends/_torch_native.py
+++ b/src/pystencils_autodiff/backends/_torch_native.py
@@ -1,6 +1,5 @@
 from collections import OrderedDict
 
-import sympy as sp
 from pystencils_autodiff.backends._pytorch import numpy_dtype_to_torch
 from pystencils_autodiff.backends.astnodes import TorchModule
 from pystencils_autodiff.tensorflow_jit import _hash
@@ -33,31 +32,47 @@ def create_autograd_function(autodiff_obj, use_cuda):
     # backward_wrapper_ast = [w for w in wrapper if w.function_name == "call_" + backward_ast.function_name][0]
     # forward_parameters = [str(p.symbol) for p in forward_wrapper_ast.get_parameters()]
     # backward_parameters = [str(p.symbol) for p in backward_wrapper_ast.get_parameters()]
+    class_kwargs = dict()
 
-    def forward(self, *args):
+    def forward(self, *args, **kwargs):
 
+        kwargs.update(class_kwargs)
+        # TODO: drop contiguous requirement
         if use_cuda:
             args = [a.contiguous().cuda() for a in args]
+            kwargs = {k: v.contiguous().cuda() for k, v in kwargs.items()}
         else:
             args = [a.contiguous().cpu() for a in args]
+            kwargs = {k: v.contiguous().cpu() for k, v in kwargs.items()}
+
+        # assert all(f.shape == args[i].shape for i, f in enumerate(autodiff_obj.forward_input_fields)
+        # if not any(isinstance(s, sp.Symbol) for s in args[i].shape))
+        # assert all(f.strides == tuple(args[i].stride(j) for j in range(args[i].ndim))
+        # for i, f in enumerate(autodiff_obj.forward_input_fields))
+        # for field in autodiff_obj.forward_output_fields:
+        # field_to_tensor_dict[field] = torch.zeros(
+        # field.shape,
+        # dtype=numpy_dtype_to_torch(field.dtype.numpy_dtype),
+        # device=args[0].device)
+
+        kwargs.update({f.name: args[i] for i, f in enumerate(
+            autodiff_obj.forward_input_fields) if f in forward_ast.fields_accessed and i < len(args)})
 
-        input_tensors = dict()
-        input_tensors.update({f.name: args[i] for i, f in enumerate(
-            autodiff_obj.forward_input_fields) if f in forward_ast.fields_accessed})
-        assert all(f.shape == args[i].shape for i, f in enumerate(autodiff_obj.forward_input_fields)
-                   if not any(isinstance(s, sp.Symbol) for s in args[i].shape))
-        assert all(f.strides == tuple(args[i].stride(j) for j in range(args[i].ndim))
-                   for i, f in enumerate(autodiff_obj.forward_input_fields))
         for field in autodiff_obj.forward_output_fields:
-            field_to_tensor_dict[field] = torch.zeros(
-                field.shape,
-                dtype=numpy_dtype_to_torch(field.dtype.numpy_dtype),
-                device=args[0].device)
-        output_tensors = OrderedDict({f.name: field_to_tensor_dict[f] for f in autodiff_obj.forward_output_fields})
+            if field.name not in kwargs:
+                kwargs[field.name] = torch.zeros(
+                    field.shape,
+                    dtype=numpy_dtype_to_torch(field.dtype.numpy_dtype),
+                    device=args[0].device)
+        output_tensors = OrderedDict({f.name:
+                                      field_to_tensor_dict.get(f, kwargs[f.name])
+                                      for f in autodiff_obj.forward_output_fields})
+        field_to_tensor_dict.update(kwargs)
+        kwargs.update(output_tensors)
 
-        self.save_for_backward(*args)
+        self.saved_for_backward = kwargs
 
-        getattr(compiled_op, "call_" + forward_ast.function_name)(**input_tensors, **output_tensors)
+        getattr(compiled_op, "call_" + forward_ast.function_name)(**kwargs)
 
         return tuple(output_tensors.values())
 
@@ -70,23 +85,24 @@ def create_autograd_function(autodiff_obj, use_cuda):
         assert all(f.shape == grad_outputs[i].shape for i, f in enumerate(autodiff_obj.backward_input_fields))
         assert all(f.strides == tuple(grad_outputs[i].stride(j) for j in range(grad_outputs[i].ndim))
                    for i, f in enumerate(autodiff_obj.backward_input_fields))
-        assert all(a.is_cuda == use_cuda for a in grad_outputs), f"Some of the tensors where on the wrong device. " \
-            f"Op was compiled for CUDA: {str(use_cuda)}"
-        saved = {f.name: self.saved_tensors[i] for i, f in enumerate(
-            autodiff_obj.forward_input_fields) if f in backward_ast.fields_accessed}
-        for field in autodiff_obj.backward_output_fields:
-            field_to_tensor_dict[field] = torch.zeros(
-                field.shape,
-                dtype=numpy_dtype_to_torch(field.dtype.numpy_dtype),
-                device=grad_outputs[0].device)
+        assert all(a.is_cuda == use_cuda for a in grad_outputs), "Some of the tensors were on the wrong device. " \
+            f"Op was compiled for CUDA: {str(use_cuda)}"
 
-        backward_output_tensors = OrderedDict(
-            {f.name: field_to_tensor_dict[f] for f in autodiff_obj.backward_output_fields})
-        getattr(compiled_op, "call_" + backward_ast.function_name)(**gradients, **saved, **backward_output_tensors)
+        backward_output_tensors = OrderedDict(
+            {f.name: torch.zeros(f.shape,
+                                 dtype=numpy_dtype_to_torch(
+                                     f.dtype.numpy_dtype),
+                                 device=grad_outputs[0].device)
+             for f in autodiff_obj.backward_output_fields})
+        field_names = [f.name for f in backward_ast.fields_accessed]
+        kwargs = {**gradients, **self.saved_for_backward, **backward_output_tensors}
+        kwargs = {k: v for k, v in kwargs.items() if k in field_names}
+        getattr(compiled_op, "call_" + backward_ast.function_name)(**kwargs)
 
         return tuple(backward_output_tensors.values())
 
     cls = type(op_name, (torch.autograd.Function,), {})
+    cls.class_kwargs = class_kwargs
     cls.forward = forward
     cls.backward = backward
     cls.kernel = forward
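
A minimal, self-contained sketch of the calling convention the patched forward implements: positional tensors, class-level defaults from class_kwargs, and freshly allocated output tensors are merged into one name-keyed kwargs dict and handed to the compiled call by field name. The field names (src, dst, offset) and mock_compiled_call below are illustrative stand-ins, not part of the pystencils_autodiff API.

    import torch

    def mock_compiled_call(**kwargs):
        # stand-in for getattr(compiled_op, "call_" + forward_ast.function_name)
        kwargs["dst"].copy_(2 * kwargs["src"] + kwargs["offset"])

    class_kwargs = {"offset": torch.full((4, 4), 0.5)}  # registered once, like cls.class_kwargs

    def forward(src, **kwargs):
        kwargs.update(class_kwargs)           # merge per-class defaults
        kwargs["src"] = src.contiguous()      # positional args mapped to field names
        if "dst" not in kwargs:               # allocate any output tensor not supplied
            kwargs["dst"] = torch.zeros_like(src)
        mock_compiled_call(**kwargs)          # everything is passed by field name
        return kwargs["dst"]

    print(forward(torch.ones(4, 4)))          # tensor filled with 2.5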