diff --git a/src/pystencils_autodiff/graph_datahandling.py b/src/pystencils_autodiff/graph_datahandling.py index 59c31c912afcf9ebd000e8042bda7c5f08535273..f33b89f5797c525e4a6539bd684a453aac2d9c90 100644 --- a/src/pystencils_autodiff/graph_datahandling.py +++ b/src/pystencils_autodiff/graph_datahandling.py @@ -179,14 +179,14 @@ class GraphDataHandling(pystencils.datahandling.SerialDataHandling): def to_cpu(self, name): super().to_cpu(name) - self.call_queue.append(DataTransfer(self._fields[name], DataTransferKind.HOST_TO_DEVICE)) + self.call_queue.append(DataTransfer(self._fields[name], DataTransferKind.DEVICE_TO_HOST)) def to_gpu(self, name): super().to_gpu(name) if name in self._custom_data_transfer_functions: self.call_queue.append('Custom Tranfer Function') else: - self.call_queue.append(DataTransfer(self._fields[name], DataTransferKind.DEVICE_TO_HOST)) + self.call_queue.append(DataTransfer(self._fields[name], DataTransferKind.HOST_TO_DEVICE)) def synchronization_function(self, names, stencil=None, target=None, **_): for name in names: