Skip to content
Snippets Groups Projects

WIP: Graph datahandling

Closed Stephan Seitz requested to merge seitz/pystencils:graph-datahandling into master
Files
2
@@ -7,13 +7,13 @@
"""
"""
from enum import Enum
import numpy as np
import pystencils.datahandling
import pystencils.kernel_wrapper
import pystencils.timeloop
from pystencils.field import FieldType
@@ -42,6 +42,9 @@ class DataTransfer:
self.field = field
self.kind = kind
def __str__(self):
return f'DataTransferKind: {self.kind} with {self.field}'
class Swap(DataTransfer):
def __init__(self, source, destination, gpu):
@@ -49,6 +52,9 @@ class Swap(DataTransfer):
self.field = source
self.destination = destination
def __str__(self):
return f'Swap: {self.field} with {self.destination}'
class Communication(DataTransfer):
def __init__(self, field, stencil, gpu):
@@ -80,6 +86,10 @@ class TimeloopRun:
+ '\nPost:\n'
+ '\n '.join(str(f) for f in self.timeloop._post_run_functions))
@property
def asts(self):
return self.timeloop._single_step_asts
class GraphDataHandling(pystencils.datahandling.SerialDataHandling):
@@ -171,7 +181,7 @@ class GraphDataHandling(pystencils.datahandling.SerialDataHandling):
def synchronization_function(self, names, stencil=None, target=None, **_):
for name in names:
gpu = target == 'cpu'
gpu = target == 'gpu'
self.call_queue.append(Communication(self._fields[name], stencil, gpu))
super().synchronization_function(names, stencil=None, target=None, **_)