diff --git a/src/pystencils_autodiff/framework_integration/astnodes.py b/src/pystencils_autodiff/framework_integration/astnodes.py index 24872a82d2f9240990308ae484b0f3b633e54f95..131816316810cf4987b65cb1b3c92c28441a681f 100644 --- a/src/pystencils_autodiff/framework_integration/astnodes.py +++ b/src/pystencils_autodiff/framework_integration/astnodes.py @@ -166,7 +166,7 @@ class JinjaCppFile(Node): TEMPLATE: jinja2.Template = None def __init__(self, ast_dict): - self.ast_dict = ast_dict + self.ast_dict = pystencils.utils.DotDict(ast_dict) self.printer = FrameworkIntegrationPrinter() Node.__init__(self) @@ -209,15 +209,16 @@ class JinjaCppFile(Node): def __str__(self): assert self.TEMPLATE, f"Template of {self.__class__} must be set" - render_dict = {k: (self._print(v) if not isinstance(v, (pystencils.Field, pystencils.TypedSymbol)) else v) + render_dict = {k: (self._print(v) + if not isinstance(v, (pystencils.Field, pystencils.TypedSymbol)) and v is not None + else v) if not isinstance(v, Iterable) or isinstance(v, str) else [(self._print(a) - if not isinstance(a, (pystencils.Field, pystencils.TypedSymbol)) + if not isinstance(a, (pystencils.Field, pystencils.TypedSymbol) and a is not None) else a) for a in v] for k, v in self.ast_dict.items()} - # TODO: possibly costly tree traversal render_dict.update({"headers": pystencils.backends.cbackend.get_headers(self)}) render_dict.update({"globals": sorted({ self.printer(g) for g in pystencils.backends.cbackend.get_global_declarations(self)