From 16b089f2b6191a4b9493eaa0332a830926d39fcb Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Tue, 29 Oct 2019 09:17:45 +0100 Subject: [PATCH] Extend DestructuringBindingsForFieldClass to enable arguments per value --- src/pystencils_autodiff/framework_integration/astnodes.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/pystencils_autodiff/framework_integration/astnodes.py b/src/pystencils_autodiff/framework_integration/astnodes.py index 51101c3..18a9092 100644 --- a/src/pystencils_autodiff/framework_integration/astnodes.py +++ b/src/pystencils_autodiff/framework_integration/astnodes.py @@ -35,6 +35,7 @@ class DestructuringBindingsForFieldClass(Node): FieldStrideSymbol: "stride[{dim}]" } CLASS_NAME_TEMPLATE = "PyStencilsField<{dtype}, {ndim}>" + ARGS_AS_REFERENCE = True @property def fields_accessed(self) -> Set['ResolvedFieldAccess']: @@ -70,7 +71,11 @@ class DestructuringBindingsForFieldClass(Node): undefined_field_symbols = self.symbols_defined corresponding_field_names = {s.field_name for s in undefined_field_symbols if hasattr(s, 'field_name')} corresponding_field_names |= {s.field_names[0] for s in undefined_field_symbols if hasattr(s, 'field_names')} - return {TypedSymbol(f, self.CLASS_NAME_TEMPLATE.format(dtype=field_map[f].dtype, ndim=field_map[f].ndim) + '&') + return {TypedSymbol(f, + self.CLASS_NAME_TEMPLATE.format(dtype=field_map[f].dtype, + ndim=field_map[f].ndim) + ('&' + if self.ARGS_AS_REFERENCE + else '')) for f in corresponding_field_names} | (self.body.undefined_symbols - undefined_field_symbols) def subs(self, subs_dict) -> None: -- GitLab