diff --git a/src/pystencils_autodiff/framework_integration/astnodes.py b/src/pystencils_autodiff/framework_integration/astnodes.py index 51101c3e5eac361d1d232d874006eb00297472b2..18a909282791b3d533d7bd16f0c5a5bdfb460b49 100644 --- a/src/pystencils_autodiff/framework_integration/astnodes.py +++ b/src/pystencils_autodiff/framework_integration/astnodes.py @@ -35,6 +35,7 @@ class DestructuringBindingsForFieldClass(Node): FieldStrideSymbol: "stride[{dim}]" } CLASS_NAME_TEMPLATE = "PyStencilsField<{dtype}, {ndim}>" + ARGS_AS_REFERENCE = True @property def fields_accessed(self) -> Set['ResolvedFieldAccess']: @@ -70,7 +71,11 @@ class DestructuringBindingsForFieldClass(Node): undefined_field_symbols = self.symbols_defined corresponding_field_names = {s.field_name for s in undefined_field_symbols if hasattr(s, 'field_name')} corresponding_field_names |= {s.field_names[0] for s in undefined_field_symbols if hasattr(s, 'field_names')} - return {TypedSymbol(f, self.CLASS_NAME_TEMPLATE.format(dtype=field_map[f].dtype, ndim=field_map[f].ndim) + '&') + return {TypedSymbol(f, + self.CLASS_NAME_TEMPLATE.format(dtype=field_map[f].dtype, + ndim=field_map[f].ndim) + ('&' + if self.ARGS_AS_REFERENCE + else '')) for f in corresponding_field_names} | (self.body.undefined_symbols - undefined_field_symbols) def subs(self, subs_dict) -> None: