From ef05762956075fd8d58021e92466f993d2a94f08 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Mon, 28 Oct 2019 17:29:15 +0100
Subject: [PATCH] Support dtype in DestructuringBindingsForFieldClass

---
 src/pystencils_autodiff/framework_integration/printer.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/src/pystencils_autodiff/framework_integration/printer.py b/src/pystencils_autodiff/framework_integration/printer.py
index 486dc38..f66da0d 100644
--- a/src/pystencils_autodiff/framework_integration/printer.py
+++ b/src/pystencils_autodiff/framework_integration/printer.py
@@ -56,18 +56,24 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend):
     def _print_DestructuringBindingsForFieldClass(self, node):
         # Define all undefined symbols
         undefined_field_symbols = node.symbols_defined
+        fields_dtype = {u.field_name:
+                        u.dtype.base_type for u in undefined_field_symbols if isinstance(u, FieldPointerSymbol)}
         destructuring_bindings = ["%s %s = %s.%s;" %
                                   (u.dtype,
                                    u.name,
                                    u.field_name if hasattr(u, 'field_name') else u.field_names[0],
                                    node.CLASS_TO_MEMBER_DICT[u.__class__].format(
-                                       dtype=(u.dtype.base_type if type(u) == FieldPointerSymbol else ""),
+                                       dtype=(u.dtype.base_type if type(u) == FieldPointerSymbol
+                                              else fields_dtype[u.field_name
+                                                                if hasattr(u, 'field_name')
+                                                                else u.field_names[0]]),
                                        field_name=(u.field_name if hasattr(u, "field_name") else ""),
                                        dim=("" if type(u) == FieldPointerSymbol else u.coordinate)
                                    )
                                    )
                                   for u in undefined_field_symbols
                                   ]
+
         destructuring_bindings.sort()  # only for code aesthetics
         return "{\n" + self._indent + \
             ("\n" + self._indent).join(destructuring_bindings) + \
-- 
GitLab