diff --git a/src/pystencils_autodiff/framework_integration/astnodes.py b/src/pystencils_autodiff/framework_integration/astnodes.py index 351709aeed813468ed2931c327f169c9569cd588..c2081513b59779befd6d42eb6f78a0e4330c0f39 100644 --- a/src/pystencils_autodiff/framework_integration/astnodes.py +++ b/src/pystencils_autodiff/framework_integration/astnodes.py @@ -151,12 +151,11 @@ class DestructuringBindingsForFieldClass(Node): undefined_field_symbols = self.symbols_defined corresponding_field_names = {s.field_name for s in undefined_field_symbols if hasattr(s, 'field_name')} corresponding_field_names |= {s.field_names[0] for s in undefined_field_symbols if hasattr(s, 'field_names')} - return {TypedSymbol(f, - self.CLASS_NAME_TEMPLATE.format(dtype=field_map[f].dtype, - ndim=field_map[f].ndim) + ('&' - if self.ARGS_AS_REFERENCE - else '')) - for f in corresponding_field_names} | (self.body.undefined_symbols - undefined_field_symbols) + return {TypedSymbol(f, self.CLASS_NAME_TEMPLATE.format(dtype=(field_map.get(f) or field_map.get('diff' + f)).dtype, + ndim=(field_map.get(f) or field_map.get('diff' + f)).ndim) + ('&' + if self.ARGS_AS_REFERENCE + else '')) + for f in corresponding_field_names} | (self.body.undefined_symbols - undefined_field_symbols) def subs(self, subs_dict) -> None: """Inplace! substitute, similar to sympy's but modifies the AST inplace.""" diff --git a/src/pystencils_autodiff/framework_integration/printer.py b/src/pystencils_autodiff/framework_integration/printer.py index 40e8e0b8e6d2cad2f94a70502a4385d6e85a4f3b..e181fdaa189be82f74320266867439bfa53b1781 100644 --- a/src/pystencils_autodiff/framework_integration/printer.py +++ b/src/pystencils_autodiff/framework_integration/printer.py @@ -123,9 +123,12 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend): + (node.field_suffix if hasattr(node, 'field_suffix') else ''), node.CLASS_TO_MEMBER_DICT[u.__class__].format( dtype=(u.dtype.base_type if type(u) == FieldPointerSymbol - else fields_dtype[u.field_name - if hasattr(u, 'field_name') - else u.field_names[0]]), + else ((fields_dtype.get(u.field_name + if hasattr(u, 'field_name') + else u.field_names[0])) + or (fields_dtype.get('diff' + u.field_name + if hasattr(u, 'field_name') + else 'diff' + u.field_names[0])))), field_name=(u.field_name if hasattr(u, "field_name") else ""), dim=("" if type(u) == FieldPointerSymbol else u.coordinate), dim_letter=("" if type(u) == FieldPointerSymbol else 'xyz'[u.coordinate])