From ef05762956075fd8d58021e92466f993d2a94f08 Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Mon, 28 Oct 2019 17:29:15 +0100 Subject: [PATCH] Support dtype in DestructuringBindingsForFieldClass --- src/pystencils_autodiff/framework_integration/printer.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/pystencils_autodiff/framework_integration/printer.py b/src/pystencils_autodiff/framework_integration/printer.py index 486dc38..f66da0d 100644 --- a/src/pystencils_autodiff/framework_integration/printer.py +++ b/src/pystencils_autodiff/framework_integration/printer.py @@ -56,18 +56,24 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend): def _print_DestructuringBindingsForFieldClass(self, node): # Define all undefined symbols undefined_field_symbols = node.symbols_defined + fields_dtype = {u.field_name: + u.dtype.base_type for u in undefined_field_symbols if isinstance(u, FieldPointerSymbol)} destructuring_bindings = ["%s %s = %s.%s;" % (u.dtype, u.name, u.field_name if hasattr(u, 'field_name') else u.field_names[0], node.CLASS_TO_MEMBER_DICT[u.__class__].format( - dtype=(u.dtype.base_type if type(u) == FieldPointerSymbol else ""), + dtype=(u.dtype.base_type if type(u) == FieldPointerSymbol + else fields_dtype[u.field_name + if hasattr(u, 'field_name') + else u.field_names[0]]), field_name=(u.field_name if hasattr(u, "field_name") else ""), dim=("" if type(u) == FieldPointerSymbol else u.coordinate) ) ) for u in undefined_field_symbols ] + destructuring_bindings.sort() # only for code aesthetics return "{\n" + self._indent + \ ("\n" + self._indent).join(destructuring_bindings) + \ -- GitLab