diff --git a/src/pystencils_autodiff/_autodiff.py b/src/pystencils_autodiff/_autodiff.py
index 29677281fe9d9ff0a642445edd35244056fb5ba0..7fb78f8fb16dc1a176d0ea066d683bc4d6f28f2c 100644
--- a/src/pystencils_autodiff/_autodiff.py
+++ b/src/pystencils_autodiff/_autodiff.py
@@ -134,7 +134,7 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix):
                                         ] += sp.diff(forward_assignment.rhs, ra) * diff_write_field[inverted_offset]
 
                 for index in range(diff_read_field.index_shape[0]):
-                    if forward_read_field in self._time_constant_fields:
+                    if True:
                         # Accumulate in case of time_constant_fields
                         assignment = ps.Assignment(
                             diff_read_field.center_vector[index],
diff --git a/src/pystencils_autodiff/backends/_torch_native.py b/src/pystencils_autodiff/backends/_torch_native.py
index d43cb905e18b405a10b6048a9ee647e9f03fa597..848da232acfef0d54b9a31fad6dd085567c8747f 100644
--- a/src/pystencils_autodiff/backends/_torch_native.py
+++ b/src/pystencils_autodiff/backends/_torch_native.py
@@ -59,11 +59,17 @@ def create_autograd_function(autodiff_obj, use_cuda):
 
         for field in autodiff_obj.forward_output_fields:
             if field.name not in kwargs:
-                kwargs[field.name] = torch.zeros(
-                    field.shape,
-                    dtype=numpy_dtype_to_torch(field.dtype.numpy_dtype),
-                    device='cuda' if use_cuda else 'cpu')  # use device of tensor
-
+                try:
+                    kwargs[field.name] = torch.zeros(
+                        field.shape,
+                        dtype=numpy_dtype_to_torch(field.dtype.numpy_dtype),
+                        device=next(chain(args, kwargs.values())).device)
+                except:
+                    shape = next(filter(lambda x: isinstance(x, torch.Tensor), chain(args, kwargs.values()))).shape
+                    kwargs[field.name] = torch.zeros(
+                        shape,
+                        dtype=numpy_dtype_to_torch(field.dtype.numpy_dtype),
+                        device=next(chain(args, kwargs.values())).device)
         output_tensors = OrderedDict({f.name:
                                       field_to_tensor_dict.get(f, kwargs[f.name])
                                       for f in autodiff_obj.forward_output_fields})
diff --git a/src/pystencils_autodiff/framework_integration/astnodes.py b/src/pystencils_autodiff/framework_integration/astnodes.py
index 1199e9cd5a44d935de2b79264cb78f49bee4df9c..4ace78066ff1475abcd34cf92ca72f02cb48a6de 100644
--- a/src/pystencils_autodiff/framework_integration/astnodes.py
+++ b/src/pystencils_autodiff/framework_integration/astnodes.py
@@ -25,6 +25,85 @@ from pystencils_autodiff.framework_integration.printer import FrameworkIntegrati
 from pystencils_autodiff.framework_integration.texture_astnodes import NativeTextureBinding
 
 
+class JinjaCppFile(Node):
+    TEMPLATE: jinja2.Template = None
+    NOT_PRINT_TYPES = (pystencils.Field, pystencils.TypedSymbol, bool)
+
+    def __init__(self, ast_dict={}):
+        self.ast_dict = pystencils.utils.DotDict(ast_dict)
+        self.printer = FrameworkIntegrationPrinter()
+        Node.__init__(self)
+
+    @property
+    def args(self):
+        """Returns all arguments/children of this node."""
+        ast_nodes = [a for a in self.ast_dict.values() if isinstance(a, (Node, sp.Expr))]
+        iterables_of_ast_nodes = [a for a in self.ast_dict.values() if isinstance(a, Iterable)
+                                  and not isinstance(a, str)]
+        return ast_nodes + list(itertools.chain.from_iterable(iterables_of_ast_nodes))
+
+    @property
+    def symbols_defined(self):
+        """Set of symbols which are defined by this node."""
+        return set(itertools.chain.from_iterable(a.symbols_defined
+                                                 for a in self.args
+                                                 if isinstance(a, Node)))
+
+    @property
+    def undefined_symbols(self):
+        """Symbols which are used but are not defined inside this node."""
+        return set(itertools.chain.from_iterable(a.undefined_symbols if isinstance(a, Node) else a.free_symbols
+                                                 for a in self.args
+                                                 if isinstance(a, (Node, sp.Expr)))) - self.symbols_defined
+
+    def _print(self, node):
+        if isinstance(node, (Node, sp.Expr)):
+            return self.printer(node)
+        else:
+            return str(node)
+
+    def atoms(self, arg_type) -> Set[Any]:
+        """Returns a set of all descendants recursively, which are an instance of the given type."""
+        if isinstance(self, arg_type):
+            result = {self}
+        else:
+            result = set()
+
+        for arg in self.args:
+            if isinstance(arg, arg_type):
+                result.add(arg)
+            if hasattr(arg, 'atoms'):
+                result.update(arg.atoms(arg_type))
+        return result
+
+    @property
+    def is_cuda(self):
+        return any(f.backend == 'gpucuda' for f in self.atoms(KernelFunction))
+
+    def __str__(self):
+        assert self.TEMPLATE, f"Template of {self.__class__} must be set"
+        render_dict = {k: (self._print(v)
+                           if not isinstance(v, self.NOT_PRINT_TYPES) and v is not None
+                           else v)
+                       if not isinstance(v, Iterable) or isinstance(v, str)
+                       else [(self._print(a)
+                              if not isinstance(a, self.NOT_PRINT_TYPES) and a is not None
+                              else a)
+                             for a in v]
+                       for k, v in self.ast_dict.items()}
+
+        render_dict.update({"headers": pystencils.backends.cbackend.get_headers(self)})
+        render_dict.update({"globals": sorted({
+            self.printer(g) for g in pystencils.backends.cbackend.get_global_declarations(self)
+        }, key=str)})
+        # self.TEMPLATE.environment = self.ENVIROMENT
+
+        return self.TEMPLATE.render(render_dict)
+
+    def __repr__(self):
+        return f'{str(self.__class__)}:\n {self.TEMPLATE.render(self.ast_dict)}'
+
+
 class DestructuringBindingsForFieldClass(Node):
     """
     Defines all variables needed for describing a field (shape, pointer, strides)
@@ -112,7 +191,15 @@ class FunctionCall(Node):
 
     @property
     def undefined_symbols(self) -> Set[sp.Symbol]:
-        return {p.symbol for p in self.kernel_function.get_parameters()}
+        rtn = {p.symbol for p in self.kernel_function.get_parameters()}
+        function = self.kernel_function
+        if function.backend == "gpucuda":
+            written_fields = function.fields_written
+            shape = list(written_fields)[0].spatial_shape
+            block_and_thread_numbers = function.indexing.call_parameters(shape)
+        rtn = rtn | set(itertools.chain.from_iterable(
+            (i.free_symbols for i in block_and_thread_numbers['block'] + block_and_thread_numbers['block'])))
+        return rtn
 
     def subs(self, subs_dict) -> None:
         for a in self.args:
@@ -167,85 +254,6 @@ def generate_kernel_call(kernel_function):
     return block
 
 
-class JinjaCppFile(Node):
-    TEMPLATE: jinja2.Template = None
-    NOT_PRINT_TYPES = (pystencils.Field, pystencils.TypedSymbol, bool)
-
-    def __init__(self, ast_dict={}):
-        self.ast_dict = pystencils.utils.DotDict(ast_dict)
-        self.printer = FrameworkIntegrationPrinter()
-        Node.__init__(self)
-
-    @property
-    def args(self):
-        """Returns all arguments/children of this node."""
-        ast_nodes = [a for a in self.ast_dict.values() if isinstance(a, (Node, sp.Expr))]
-        iterables_of_ast_nodes = [a for a in self.ast_dict.values() if isinstance(a, Iterable)
-                                  and not isinstance(a, str)]
-        return ast_nodes + list(itertools.chain.from_iterable(iterables_of_ast_nodes))
-
-    @property
-    def symbols_defined(self):
-        """Set of symbols which are defined by this node."""
-        return set(itertools.chain.from_iterable(a.symbols_defined
-                                                 for a in self.args
-                                                 if isinstance(a, Node)))
-
-    @property
-    def undefined_symbols(self):
-        """Symbols which are used but are not defined inside this node."""
-        return set(itertools.chain.from_iterable(a.undefined_symbols if isinstance(a, Node) else a.free_symbols
-                                                 for a in self.args
-                                                 if isinstance(a, (Node, sp.Expr)))) - self.symbols_defined
-
-    def _print(self, node):
-        if isinstance(node, (Node, sp.Expr)):
-            return self.printer(node)
-        else:
-            return str(node)
-
-    def atoms(self, arg_type) -> Set[Any]:
-        """Returns a set of all descendants recursively, which are an instance of the given type."""
-        if isinstance(self, arg_type):
-            result = {self}
-        else:
-            result = set()
-
-        for arg in self.args:
-            if isinstance(arg, arg_type):
-                result.add(arg)
-            if hasattr(arg, 'atoms'):
-                result.update(arg.atoms(arg_type))
-        return result
-
-    @property
-    def is_cuda(self):
-        return any(f.backend == 'gpucuda' for f in self.atoms(KernelFunction))
-
-    def __str__(self):
-        assert self.TEMPLATE, f"Template of {self.__class__} must be set"
-        render_dict = {k: (self._print(v)
-                           if not isinstance(v, self.NOT_PRINT_TYPES) and v is not None
-                           else v)
-                       if not isinstance(v, Iterable) or isinstance(v, str)
-                       else [(self._print(a)
-                              if not isinstance(a, self.NOT_PRINT_TYPES) and a is not None
-                              else a)
-                             for a in v]
-                       for k, v in self.ast_dict.items()}
-
-        render_dict.update({"headers": pystencils.backends.cbackend.get_headers(self)})
-        render_dict.update({"globals": sorted({
-            self.printer(g) for g in pystencils.backends.cbackend.get_global_declarations(self)
-        }, key=str)})
-        # self.TEMPLATE.environment = self.ENVIROMENT
-
-        return self.TEMPLATE.render(render_dict)
-
-    def __repr__(self):
-        return f'{str(self.__class__)}:\n {self.TEMPLATE.render(self.ast_dict)}'
-
-
 class CudaErrorCheckDefinition(CustomCodeNode):
     def __init__(self):
         super().__init__(self.code, [], [])
@@ -410,7 +418,7 @@ class CustomFunctionCall(JinjaCppFile):
 
     @property
     def undefined_symbols(self):
-        return set(self.ast_dict.args) 
+        return set(self.ast_dict.args)
 
     def subs(self, subs_dict):
         self.ast_dict.args = list(map(lambda x: x.subs(subs_dict), self.ast_dict.args))