diff --git a/src/pystencils_autodiff/backends/_torch_native.py b/src/pystencils_autodiff/backends/_torch_native.py
index 9fae28c5f8ebf28bca5f5f95e4c10d04b20edd9b..48417af13d375c9c624d7b783bb207b437a07d5d 100644
--- a/src/pystencils_autodiff/backends/_torch_native.py
+++ b/src/pystencils_autodiff/backends/_torch_native.py
@@ -19,15 +19,16 @@ def create_autograd_function(autodiff_obj, use_cuda):
 
     if use_cuda:
         forward_ast = autodiff_obj.forward_ast_gpu
-        backward_ast = autodiff_obj.backward_ast_gpu
+        backward_ast = autodiff_obj.backward_ast_gpu if autodiff_obj.backward_output_fields else None
     else:
         forward_ast = autodiff_obj.forward_ast_cpu
-        backward_ast = autodiff_obj.backward_ast_cpu
+        backward_ast = autodiff_obj.backward_ast_cpu if autodiff_obj.backward_output_fields else None
 
-    op_name = f'{autodiff_obj.op_name}_{_hash((str(pystencils.show_code(forward_ast))+ str(autodiff_obj)).encode()).hexdigest()}'  # noqa
+    op_name = f'{autodiff_obj.op_name}_{_hash((str(pystencils.show_code(forward_ast)) + str(autodiff_obj) + str(autodiff_obj.constant_fields)).encode()).hexdigest()}'  # noqa
     forward_ast.function_name = f'{op_name}_{forward_ast.function_name}'
-    backward_ast.function_name = f'{op_name}_{backward_ast.function_name}'
-    module = TorchModule(op_name, [forward_ast, backward_ast])
+    if backward_ast:
+        backward_ast.function_name = f'{op_name}_{backward_ast.function_name}'
+    module = TorchModule(op_name, [forward_ast, backward_ast] if backward_ast else [forward_ast])
     compiled_op = module.compile()
 
     # print(TorchModule(op_name, [forward_ast, backward_ast]))
@@ -103,16 +104,26 @@ def create_autograd_function(autodiff_obj, use_cuda):
 
         return tuple(backward_output_tensors.values())
 
+    def call(self, **kwargs):
+        rtn = self.apply(*[kwargs[p.symbol.name] for p in self.forward_parameters])
+        if len(rtn) == 1:
+            rtn = rtn[0]
+        return rtn
+
     cls = type(op_name, (torch.autograd.Function,), {})
     cls.class_kwargs = class_kwargs
     cls.forward = forward
     cls.backward = backward
     cls.kernel = forward
     cls.ast = module
-    cls.parameters = forward_ast.get_parameters()
+    cls.parameters = [f for f in module.kernel_wrappers
+                      if f.function_name == "call_" + forward_ast.function_name][0].get_parameters()
+    cls.forward_parameters = [p for p in cls.parameters if p.symbol.name in [
+        f.name for f in autodiff_obj.forward_input_fields]]
     cls.forward_ast = forward_ast
     cls.backward_ast = backward_ast
     cls.num_regs = None
+    cls.call = call
     cls.code = str(module)
 
     return cls
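
Usage sketch (illustrative, not part of the patch): after this change, kernels without backward output fields compile only the forward AST, and the generated autograd Function gains a `call` helper that maps keyword arguments onto `cls.forward_parameters`. The field names, the assignment, and the `AutoDiffOp` construction below are assumptions for illustration; only `create_autograd_function` comes from the patched module, and whether `call` is meant to be used on the class or an instance is not pinned down by the patch.

# Usage sketch (illustrative; field names and AutoDiffOp construction are assumptions).
import torch
import pystencils
import pystencils_autodiff

from pystencils_autodiff.backends._torch_native import create_autograd_function

# Hypothetical forward kernel: y = 2 * x on 2D float32 fields.
x, y = pystencils.fields('x, y: float32[2D]')
assignments = pystencils.AssignmentCollection(
    [pystencils.Assignment(y.center, 2 * x.center)])
autodiff_obj = pystencils_autodiff.AutoDiffOp(assignments)  # assumed constructor

op = create_autograd_function(autodiff_obj, use_cuda=False)

# The new `call` helper maps keyword arguments onto cls.forward_parameters and
# unwraps a single output from the returned tuple; the class is instantiated
# here so that `self` is bound when `call` uses `self.apply`.
result = op().call(x=torch.ones(20, 30, requires_grad=True))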