diff --git a/src/pystencils_autodiff/framework_integration/astnodes.py b/src/pystencils_autodiff/framework_integration/astnodes.py index b7c78dd7d0ee2167021381c0ab53130a0fefde8e..f530c975fb064053403867de7758bab59a64db92 100644 --- a/src/pystencils_autodiff/framework_integration/astnodes.py +++ b/src/pystencils_autodiff/framework_integration/astnodes.py @@ -206,7 +206,11 @@ class JinjaCppFile(Node): def atoms(self, arg_type) -> Set[Any]: """Returns a set of all descendants recursively, which are an instance of the given type.""" - result = set() + if isinstance(self, arg_type): + result = {self} + else: + result = set() + for arg in self.args: if isinstance(arg, arg_type): result.add(arg) @@ -374,12 +378,13 @@ class CustomFunctionDeclaration(JinjaCppFile): class CustomFunctionCall(JinjaCppFile): TEMPLATE = jinja2.Template("""{{function_name}}({{ args | join(', ') }});""", undefined=jinja2.StrictUndefined) - def __init__(self, function_name, *args, fields_accessed=[], custom_signature=None): + def __init__(self, function_name, *args, fields_accessed=[], custom_signature=None, backend='c'): ast_dict = { 'function_name': function_name, 'args': args, 'fields_accessed': [f.center for f in fields_accessed] } + self._backend = backend super().__init__(ast_dict) if custom_signature: self.required_global_declarations = [CustomCodeNode(custom_signature, (), ())] @@ -387,6 +392,10 @@ class CustomFunctionCall(JinjaCppFile): self.required_global_declarations = [CustomFunctionDeclaration( self.ast_dict.function_name, self.ast_dict.args)] + @property + def backend(self): + return self._backend + @property def symbols_defined(self): return set(self.ast_dict.fields_accessed)