Skip to content
Snippets Groups Projects
Commit 567bcea0 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add computationgraph.py

parent 865d2274
Branches
Tags
No related merge requests found
Pipeline #21223 failed
#
# Distributed under terms of the GPLv3 license.
"""
"""
from itertools import chain
import pystencils
from pystencils import Field
from pystencils.astnodes import KernelFunction
from pystencils.kernel_wrapper import KernelWrapper
from pystencils_autodiff.graph_datahandling import KernelCall, Swap, TimeloopRun
class ComputationGraph:
class FieldWriteCounter:
def __init__(self, field, counter=0):
self.field = field
self.counter = counter
def __hash__(self):
return hash((self.field, self.counter))
def next(self):
return self.__class__(self.field, self.counter + 1)
@property
def name(self):
return self.field.name
def __str__(self):
return self.__repr__()
def __repr__(self):
return f'{self.field} #{self.counter}'
def __eq__(self, other):
return hash(self) == hash(other)
def __init__(self, call_list, write_counter={}):
self.call_list = call_list
self.write_counter = write_counter
self.reads = {}
self.writes = {}
self.computation_nodes = set()
self.input_nodes = []
self.output_nodes = []
for c in call_list:
if isinstance(c, KernelCall): # TODO get rid of this one
c = c.kernel
if isinstance(c, KernelWrapper):
c = c.ast
if isinstance(c, KernelFunction):
output_fields = c.fields_written
input_fields = c.fields_read
computation_node = self.ComputationNode(c)
self.read(input_fields, computation_node)
self.write(output_fields, computation_node)
self.computation_nodes.add(computation_node)
elif isinstance(c, Swap):
computation_node = self.ComputationNode(c)
self.read([c.field, c.destination], computation_node)
self.write([c.field, c.destination], computation_node)
self.computation_nodes.add(computation_node)
elif isinstance(c, TimeloopRun):
computation_node = ComputationGraph(c.timeloop._single_step_asts, self.write_counter)
self.computation_nodes.add(computation_node)
else:
print(c)
for c in self.computation_nodes:
if isinstance(c, ComputationGraph):
reads = set(c.reads.keys())
writes = set(c.writes.keys())
known = set(chain(self.writes.keys(), self.reads.keys()))
c.input_nodes = [self.ArrayNode(a) for a in (known & reads)]
c.output_nodes = [self.ArrayNode(a) for a in (known & writes)]
def read(self, fields, kernel):
fields = [self.FieldWriteCounter(f, self.write_counter.get(f.name, 0)) for f in fields]
for f in fields:
read_node = {**self.writes, **self.reads}.get(f, self.ArrayNode(f))
read_node.destination_nodes.append(kernel)
self.reads[f] = read_node
kernel.input_nodes.append(read_node)
def write(self, fields, kernel):
for f in fields:
field_snapshot = self.FieldWriteCounter(f, self.write_counter.get(f.name, 0) + 1)
write_node = self.ArrayNode(field_snapshot)
write_node.source_node = kernel
self.writes[field_snapshot] = write_node
kernel.output_nodes.append(write_node)
self.write_counter[f.name] = self.write_counter.get(f.name, 0) + 1
def to_dot(self, graph_style=None, with_code=False):
import graphviz
graph_style = {} if graph_style is None else graph_style
fields = {**self.reads, **self.writes}
dot = graphviz.Digraph(str(id(self)))
for field, node in fields.items():
label = f'{field.name} #{field.counter}'
dot.node(label, style='filled', fillcolor='#a056db', label=label)
for node in self.computation_nodes:
if isinstance(node, ComputationGraph):
subgraph = node.to_dot(with_code=with_code)
dot.subgraph(subgraph)
continue
elif isinstance(node.kernel, Swap):
name = f'Swap {id(node)}'
dot.node(str(id(node)), style='filled', fillcolor='#ff5600', label=name)
elif isinstance(node.kernel, KernelFunction):
if with_code:
name = str(pystencils.show_code(node.kernel))
else:
name = node.kernel.function_name
dot.node(str(id(node)), style='filled', fillcolor='#0056db', label=name)
else:
raise 'foo'
for input in node.input_nodes:
field = input.field
label = f'{field.name} #{field.counter}'
dot.edge(label, str(id(node)))
for output in node.output_nodes:
field = output.field
label = f'{field.name} #{field.counter}'
dot.edge(str(id(node)), label)
return dot
def to_dot_file(self, path, graph_style=None, with_code=False):
with open(path, 'w') as f:
f.write(str(self.to_dot(graph_style, with_code)))
class ComputationNode:
def __init__(self, kernel):
self.kernel = kernel
self.input_nodes = []
self.output_nodes = []
def __hash__(self):
return id(self)
class ArrayNode:
def __init__(self, field: Field):
self.field = field
self.source_node = None
self.destination_nodes = []
def __hash__(self):
return id(self)
...@@ -45,6 +45,9 @@ class DataTransfer: ...@@ -45,6 +45,9 @@ class DataTransfer:
def __str__(self): def __str__(self):
return f'DataTransferKind: {self.kind} with {self.field}' return f'DataTransferKind: {self.kind} with {self.field}'
def __repr__(self):
return f'DataTransferKind: {self.kind} with {self.field}'
class Swap(DataTransfer): class Swap(DataTransfer):
def __init__(self, source, destination, gpu): def __init__(self, source, destination, gpu):
...@@ -52,7 +55,7 @@ class Swap(DataTransfer): ...@@ -52,7 +55,7 @@ class Swap(DataTransfer):
self.field = source self.field = source
self.destination = destination self.destination = destination
def __str__(self): def __repr__(self):
return f'Swap: {self.field} with {self.destination}' return f'Swap: {self.field} with {self.destination}'
...@@ -122,6 +125,13 @@ class GraphDataHandling(pystencils.datahandling.SerialDataHandling): ...@@ -122,6 +125,13 @@ class GraphDataHandling(pystencils.datahandling.SerialDataHandling):
self.parent.call_queue.append(TimeloopRun(self, time_steps)) self.parent.call_queue.append(TimeloopRun(self, time_steps))
super().run(time_steps) super().run(time_steps)
def swap(self, src, dst, is_gpu):
if isinstance(src, str):
src = self.parent.fields[src]
if isinstance(dst, str):
dst = self.parent.fields[dst]
self._single_step_asts.append(Swap(src, dst, is_gpu))
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.call_queue = [] self.call_queue = []
......
...@@ -15,6 +15,8 @@ try: ...@@ -15,6 +15,8 @@ try:
from lbmpy.lbstep import LatticeBoltzmannStep from lbmpy.lbstep import LatticeBoltzmannStep
from pystencils_autodiff.graph_datahandling import GraphDataHandling from pystencils_autodiff.graph_datahandling import GraphDataHandling
from pystencils.slicing import slice_from_direction from pystencils.slicing import slice_from_direction
from pystencils_autodiff.computationgraph import ComputationGraph
except ImportError: except ImportError:
pass pass
...@@ -65,10 +67,22 @@ def ldc_setup(**kwargs): ...@@ -65,10 +67,22 @@ def ldc_setup(**kwargs):
def test_graph_datahandling(): def test_graph_datahandling():
print("--- LDC 2D test ---")
opt_params = {'target': 'gpu', 'gpu_indexing_params': {'block_size': (8, 4, 2)}} opt_params = {'target': 'gpu', 'gpu_indexing_params': {'block_size': (8, 4, 2)}}
lbm_step: LatticeBoltzmannStep = ldc_setup(domain_size=(10, 15), optimization=opt_params) lbm_step: LatticeBoltzmannStep = ldc_setup(domain_size=(10, 15), optimization=opt_params)
print(lbm_step._data_handling) print(lbm_step._data_handling)
print(lbm_step._data_handling.call_queue) print(lbm_step._data_handling.call_queue)
def test_graph_generation():
opt_params = {'target': 'gpu', 'gpu_indexing_params': {'block_size': (8, 4, 2)}}
lbm_step: LatticeBoltzmannStep = ldc_setup(domain_size=(10, 15), optimization=opt_params)
graph = ComputationGraph(lbm_step._data_handling.call_queue)
print("graph.writes: " + str(graph.writes))
print("graph.reads: " + str(graph.reads))
print(graph.to_dot())
graph.to_dot_file('/tmp/foo.dot', with_code=False)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment