diff --git a/src/pystencils_autodiff/graph_datahandling.py b/src/pystencils_autodiff/graph_datahandling.py index 0e68177b70a83775270a84ca2743b67c5b1c89ff..2a5b7cd8f3c684e0dd7d70e9699d328ca91b2ac8 100644 --- a/src/pystencils_autodiff/graph_datahandling.py +++ b/src/pystencils_autodiff/graph_datahandling.py @@ -168,20 +168,21 @@ class GraphDataHandling(pystencils.datahandling.SerialDataHandling): alignment=False, field_type=FieldType.GENERIC): - super().add_array(name, - values_per_cell, - dtype, - latex_name, - ghost_layers, - layout, - cpu, - gpu, - alignment, - field_type) + rtn = super().add_array(name, + values_per_cell, + dtype, + latex_name, + ghost_layers, + layout, + cpu, + gpu, + alignment, + field_type) if cpu: self.call_queue.append(DataTransfer(self._fields[name], DataTransferKind.HOST_ALLOC)) if gpu: self.call_queue.append(DataTransfer(self._fields[name], DataTransferKind.DEVICE_ALLOC)) + return rtn def add_custom_data(self, name, cpu_creation_function, gpu_creation_function=None, cpu_to_gpu_transfer_func=None, gpu_to_cpu_transfer_func=None): @@ -200,10 +201,10 @@ class GraphDataHandling(pystencils.datahandling.SerialDataHandling): self.call_queue.append(Swap(self._fields[name1], self._fields[name2], gpu)) super().swap(name1, name2, gpu) - def run_kernel(self, kernel_function, **kwargs): + def run_kernel(self, kernel_function, simulate_only=False, **kwargs): self.call_queue.append(KernelCall(kernel_function, kwargs)) - super().run_kernel(kernel_function, **kwargs) - # skip calling super + if not simulate_only: + super().run_kernel(kernel_function, **kwargs) def to_cpu(self, name): super().to_cpu(name)