diff --git a/pystencils/datahandling/graph_datahandling.py b/pystencils/datahandling/graph_datahandling.py index 650bf51f6acf9f88c6660ef3ee388bd43a0daa47..15027fa4cadb0fc8c9b3eac538484bf9e61a2804 100644 --- a/pystencils/datahandling/graph_datahandling.py +++ b/pystencils/datahandling/graph_datahandling.py @@ -7,13 +7,13 @@ """ """ - from enum import Enum import numpy as np import pystencils.datahandling import pystencils.kernel_wrapper +import pystencils.timeloop from pystencils.field import FieldType @@ -181,7 +181,7 @@ class GraphDataHandling(pystencils.datahandling.SerialDataHandling): def synchronization_function(self, names, stencil=None, target=None, **_): for name in names: - gpu = target == 'cpu' + gpu = target == 'gpu' self.call_queue.append(Communication(self._fields[name], stencil, gpu)) super().synchronization_function(names, stencil=None, target=None, **_) diff --git a/pystencils/kernelcreation.py b/pystencils/kernelcreation.py index 613050544ebf8fec100bb76bfebc9208cbb239e6..298d01255cd638941998a76f9ec33073c02d3d0d 100644 --- a/pystencils/kernelcreation.py +++ b/pystencils/kernelcreation.py @@ -76,9 +76,9 @@ def create_kernel(assignments, [0., 4., 4., 4., 0.], [0., 0., 0., 0., 0.]]) """ - #save the original assignments + # save the original assignments assign = assignments - + # ---- Normalizing parameters split_groups = () if isinstance(assignments, AssignmentCollection): @@ -187,7 +187,7 @@ def create_indexed_kernel(assignments, """ assign = assignments indF = index_fields - + if isinstance(assignments, Assignment): assignments = [assignments] elif isinstance(assignments, AssignmentCollection):