diff --git a/src/pystencils_autodiff/framework_integration/astnodes.py b/src/pystencils_autodiff/framework_integration/astnodes.py index f530c975fb064053403867de7758bab59a64db92..efe5e6aeedbc001768b0db246a79bc72e84ef83e 100644 --- a/src/pystencils_autodiff/framework_integration/astnodes.py +++ b/src/pystencils_autodiff/framework_integration/astnodes.py @@ -400,10 +400,23 @@ class CustomFunctionCall(JinjaCppFile): def symbols_defined(self): return set(self.ast_dict.fields_accessed) + @property + def fields_accessed(self): + return [f.name for f in self.ast_dict.fields_accessed] + @property def function_name(self): return self.ast_dict.function_name @property def undefined_symbols(self): - return set(self.ast_dict.args) + return set(self.ast_dict.args) + + def subs(self, subs_dict): + self.ast_dict.args = list(map(lambda x: x.subs(subs_dict), self.ast_dict.args)) + + def atoms(self, types=None): + if types: + return set(a for a in self.args if isinstance(a, types)) + else: + return set(self.args)