diff --git a/src/pystencils_autodiff/framework_integration/astnodes.py b/src/pystencils_autodiff/framework_integration/astnodes.py
index b7c78dd7d0ee2167021381c0ab53130a0fefde8e..f530c975fb064053403867de7758bab59a64db92 100644
--- a/src/pystencils_autodiff/framework_integration/astnodes.py
+++ b/src/pystencils_autodiff/framework_integration/astnodes.py
@@ -206,7 +206,11 @@ class JinjaCppFile(Node):
 
     def atoms(self, arg_type) -> Set[Any]:
         """Returns a set of all descendants recursively, which are an instance of the given type."""
-        result = set()
+        if isinstance(self, arg_type):
+            result = {self}
+        else:
+            result = set()
+
         for arg in self.args:
             if isinstance(arg, arg_type):
                 result.add(arg)
@@ -374,12 +378,13 @@ class CustomFunctionDeclaration(JinjaCppFile):
 class CustomFunctionCall(JinjaCppFile):
     TEMPLATE = jinja2.Template("""{{function_name}}({{ args | join(', ') }});""", undefined=jinja2.StrictUndefined)
 
-    def __init__(self, function_name, *args, fields_accessed=[], custom_signature=None):
+    def __init__(self, function_name, *args, fields_accessed=[], custom_signature=None, backend='c'):
         ast_dict = {
             'function_name': function_name,
             'args': args,
             'fields_accessed': [f.center for f in fields_accessed]
         }
+        self._backend = backend
         super().__init__(ast_dict)
         if custom_signature:
             self.required_global_declarations = [CustomCodeNode(custom_signature, (), ())]
@@ -387,6 +392,10 @@ class CustomFunctionCall(JinjaCppFile):
             self.required_global_declarations = [CustomFunctionDeclaration(
                 self.ast_dict.function_name, self.ast_dict.args)]
 
+    @property
+    def backend(self):
+        return self._backend
+
     @property
     def symbols_defined(self):
         return set(self.ast_dict.fields_accessed)