From 5955f84be8e5fdbe89753e23043b8570c468ba8b Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Tue, 17 Nov 2020 13:55:49 +0100 Subject: [PATCH] Apply dirty hack when adjoint calculation depends on size of forward fields but not on the forward fields them selves (just take size of corresponding adjoint fields) --- .../framework_integration/astnodes.py | 11 +++++------ .../framework_integration/printer.py | 9 ++++++--- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/pystencils_autodiff/framework_integration/astnodes.py b/src/pystencils_autodiff/framework_integration/astnodes.py index 351709a..c208151 100644 --- a/src/pystencils_autodiff/framework_integration/astnodes.py +++ b/src/pystencils_autodiff/framework_integration/astnodes.py @@ -151,12 +151,11 @@ class DestructuringBindingsForFieldClass(Node): undefined_field_symbols = self.symbols_defined corresponding_field_names = {s.field_name for s in undefined_field_symbols if hasattr(s, 'field_name')} corresponding_field_names |= {s.field_names[0] for s in undefined_field_symbols if hasattr(s, 'field_names')} - return {TypedSymbol(f, - self.CLASS_NAME_TEMPLATE.format(dtype=field_map[f].dtype, - ndim=field_map[f].ndim) + ('&' - if self.ARGS_AS_REFERENCE - else '')) - for f in corresponding_field_names} | (self.body.undefined_symbols - undefined_field_symbols) + return {TypedSymbol(f, self.CLASS_NAME_TEMPLATE.format(dtype=(field_map.get(f) or field_map.get('diff' + f)).dtype, + ndim=(field_map.get(f) or field_map.get('diff' + f)).ndim) + ('&' + if self.ARGS_AS_REFERENCE + else '')) + for f in corresponding_field_names} | (self.body.undefined_symbols - undefined_field_symbols) def subs(self, subs_dict) -> None: """Inplace! substitute, similar to sympy's but modifies the AST inplace.""" diff --git a/src/pystencils_autodiff/framework_integration/printer.py b/src/pystencils_autodiff/framework_integration/printer.py index 40e8e0b..e181fda 100644 --- a/src/pystencils_autodiff/framework_integration/printer.py +++ b/src/pystencils_autodiff/framework_integration/printer.py @@ -123,9 +123,12 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend): + (node.field_suffix if hasattr(node, 'field_suffix') else ''), node.CLASS_TO_MEMBER_DICT[u.__class__].format( dtype=(u.dtype.base_type if type(u) == FieldPointerSymbol - else fields_dtype[u.field_name - if hasattr(u, 'field_name') - else u.field_names[0]]), + else ((fields_dtype.get(u.field_name + if hasattr(u, 'field_name') + else u.field_names[0])) + or (fields_dtype.get('diff' + u.field_name + if hasattr(u, 'field_name') + else 'diff' + u.field_names[0])))), field_name=(u.field_name if hasattr(u, "field_name") else ""), dim=("" if type(u) == FieldPointerSymbol else u.coordinate), dim_letter=("" if type(u) == FieldPointerSymbol else 'xyz'[u.coordinate]) -- GitLab