diff --git a/src/pystencils_autodiff/framework_integration/printer.py b/src/pystencils_autodiff/framework_integration/printer.py index 486dc38d005f95274b3142bf2f610291d0bd1c1d..f66da0d73dc2da9a058ea02ac8fbfd05a16185cb 100644 --- a/src/pystencils_autodiff/framework_integration/printer.py +++ b/src/pystencils_autodiff/framework_integration/printer.py @@ -56,18 +56,24 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend): def _print_DestructuringBindingsForFieldClass(self, node): # Define all undefined symbols undefined_field_symbols = node.symbols_defined + fields_dtype = {u.field_name: + u.dtype.base_type for u in undefined_field_symbols if isinstance(u, FieldPointerSymbol)} destructuring_bindings = ["%s %s = %s.%s;" % (u.dtype, u.name, u.field_name if hasattr(u, 'field_name') else u.field_names[0], node.CLASS_TO_MEMBER_DICT[u.__class__].format( - dtype=(u.dtype.base_type if type(u) == FieldPointerSymbol else ""), + dtype=(u.dtype.base_type if type(u) == FieldPointerSymbol + else fields_dtype[u.field_name + if hasattr(u, 'field_name') + else u.field_names[0]]), field_name=(u.field_name if hasattr(u, "field_name") else ""), dim=("" if type(u) == FieldPointerSymbol else u.coordinate) ) ) for u in undefined_field_symbols ] + destructuring_bindings.sort() # only for code aesthetics return "{\n" + self._indent + \ ("\n" + self._indent).join(destructuring_bindings) + \