From 5955f84be8e5fdbe89753e23043b8570c468ba8b Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Tue, 17 Nov 2020 13:55:49 +0100
Subject: [PATCH] Apply dirty hack when adjoint calculation depends on size of
 forward fields but not on the forward fields them selves (just take size of
 corresponding adjoint fields)

---
 .../framework_integration/astnodes.py                 | 11 +++++------
 .../framework_integration/printer.py                  |  9 ++++++---
 2 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/src/pystencils_autodiff/framework_integration/astnodes.py b/src/pystencils_autodiff/framework_integration/astnodes.py
index 351709a..c208151 100644
--- a/src/pystencils_autodiff/framework_integration/astnodes.py
+++ b/src/pystencils_autodiff/framework_integration/astnodes.py
@@ -151,12 +151,11 @@ class DestructuringBindingsForFieldClass(Node):
         undefined_field_symbols = self.symbols_defined
         corresponding_field_names = {s.field_name for s in undefined_field_symbols if hasattr(s, 'field_name')}
         corresponding_field_names |= {s.field_names[0] for s in undefined_field_symbols if hasattr(s, 'field_names')}
-        return {TypedSymbol(f,
-                            self.CLASS_NAME_TEMPLATE.format(dtype=field_map[f].dtype,
-                                                            ndim=field_map[f].ndim) + ('&'
-                                                                                       if self.ARGS_AS_REFERENCE
-                                                                                       else ''))
-                for f in corresponding_field_names} | (self.body.undefined_symbols - undefined_field_symbols)
+        return {TypedSymbol(f, self.CLASS_NAME_TEMPLATE.format(dtype=(field_map.get(f) or field_map.get('diff' + f)).dtype,
+            ndim=(field_map.get(f) or field_map.get('diff' + f)).ndim) + ('&'
+                if self.ARGS_AS_REFERENCE
+                else ''))
+            for f in corresponding_field_names} | (self.body.undefined_symbols - undefined_field_symbols)
 
     def subs(self, subs_dict) -> None:
         """Inplace! substitute, similar to sympy's but modifies the AST inplace."""
diff --git a/src/pystencils_autodiff/framework_integration/printer.py b/src/pystencils_autodiff/framework_integration/printer.py
index 40e8e0b..e181fda 100644
--- a/src/pystencils_autodiff/framework_integration/printer.py
+++ b/src/pystencils_autodiff/framework_integration/printer.py
@@ -123,9 +123,12 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend):
                                    + (node.field_suffix if hasattr(node, 'field_suffix') else ''),
                                    node.CLASS_TO_MEMBER_DICT[u.__class__].format(
                                        dtype=(u.dtype.base_type if type(u) == FieldPointerSymbol
-                                              else fields_dtype[u.field_name
-                                                                if hasattr(u, 'field_name')
-                                                                else u.field_names[0]]),
+                                              else ((fields_dtype.get(u.field_name
+                                                                      if hasattr(u, 'field_name')
+                                                                      else u.field_names[0]))
+                                                    or (fields_dtype.get('diff' + u.field_name
+                                                                         if hasattr(u, 'field_name')
+                                                                         else 'diff' + u.field_names[0])))),
                                        field_name=(u.field_name if hasattr(u, "field_name") else ""),
                                        dim=("" if type(u) == FieldPointerSymbol else u.coordinate),
                                        dim_letter=("" if type(u) == FieldPointerSymbol else 'xyz'[u.coordinate])
-- 
GitLab