diff --git a/src/pystencils_autodiff/graph_datahandling.py b/src/pystencils_autodiff/graph_datahandling.py index 8768edc5c7a696bc0c63f032c35aaf56261c0669..b6d03403cfe61c4d656513baf94f03ee767aa909 100644 --- a/src/pystencils_autodiff/graph_datahandling.py +++ b/src/pystencils_autodiff/graph_datahandling.py @@ -93,9 +93,18 @@ class Communication(DataTransfer): class KernelCall: def __init__(self, kernel: pystencils.kernel_wrapper.KernelWrapper, kwargs, tmp_field_swaps=[]): + tmp = None + src = None + for f in kernel.ast.fields_accessed: + if 'pdfTmp' in f.name: + tmp = f + if 'pdfSrc' in f.name: + src = f self.kernel = kernel self.kwargs = kwargs self.tmp_field_swaps = tmp_field_swaps + if tmp and src: + self.tmp_field_swaps.append((src, tmp)) def __str__(self): return "Call " + str(self.kernel.ast.function_name)