diff --git a/src/pystencils_autodiff/computationgraph.py b/src/pystencils_autodiff/computationgraph.py new file mode 100644 index 0000000000000000000000000000000000000000..25c2f8590927a87c09f4f0b7fcb0e3e52ae1c3af --- /dev/null +++ b/src/pystencils_autodiff/computationgraph.py @@ -0,0 +1,163 @@ +# +# Distributed under terms of the GPLv3 license. + +""" + +""" + +from itertools import chain + +import pystencils +from pystencils import Field +from pystencils.astnodes import KernelFunction +from pystencils.kernel_wrapper import KernelWrapper +from pystencils_autodiff.graph_datahandling import KernelCall, Swap, TimeloopRun + + +class ComputationGraph: + class FieldWriteCounter: + def __init__(self, field, counter=0): + self.field = field + self.counter = counter + + def __hash__(self): + return hash((self.field, self.counter)) + + def next(self): + return self.__class__(self.field, self.counter + 1) + + @property + def name(self): + return self.field.name + + def __str__(self): + return self.__repr__() + + def __repr__(self): + return f'{self.field} #{self.counter}' + + def __eq__(self, other): + return hash(self) == hash(other) + + def __init__(self, call_list, write_counter={}): + self.call_list = call_list + self.write_counter = write_counter + self.reads = {} + self.writes = {} + self.computation_nodes = set() + self.input_nodes = [] + self.output_nodes = [] + + for c in call_list: + if isinstance(c, KernelCall): # TODO get rid of this one + c = c.kernel + + if isinstance(c, KernelWrapper): + c = c.ast + + if isinstance(c, KernelFunction): + output_fields = c.fields_written + input_fields = c.fields_read + + computation_node = self.ComputationNode(c) + self.read(input_fields, computation_node) + self.write(output_fields, computation_node) + self.computation_nodes.add(computation_node) + elif isinstance(c, Swap): + computation_node = self.ComputationNode(c) + self.read([c.field, c.destination], computation_node) + self.write([c.field, c.destination], computation_node) + self.computation_nodes.add(computation_node) + elif isinstance(c, TimeloopRun): + computation_node = ComputationGraph(c.timeloop._single_step_asts, self.write_counter) + + self.computation_nodes.add(computation_node) + else: + print(c) + + for c in self.computation_nodes: + if isinstance(c, ComputationGraph): + reads = set(c.reads.keys()) + writes = set(c.writes.keys()) + known = set(chain(self.writes.keys(), self.reads.keys())) + c.input_nodes = [self.ArrayNode(a) for a in (known & reads)] + c.output_nodes = [self.ArrayNode(a) for a in (known & writes)] + + def read(self, fields, kernel): + fields = [self.FieldWriteCounter(f, self.write_counter.get(f.name, 0)) for f in fields] + for f in fields: + read_node = {**self.writes, **self.reads}.get(f, self.ArrayNode(f)) + read_node.destination_nodes.append(kernel) + self.reads[f] = read_node + kernel.input_nodes.append(read_node) + + def write(self, fields, kernel): + for f in fields: + field_snapshot = self.FieldWriteCounter(f, self.write_counter.get(f.name, 0) + 1) + write_node = self.ArrayNode(field_snapshot) + write_node.source_node = kernel + self.writes[field_snapshot] = write_node + kernel.output_nodes.append(write_node) + self.write_counter[f.name] = self.write_counter.get(f.name, 0) + 1 + + def to_dot(self, graph_style=None, with_code=False): + import graphviz + graph_style = {} if graph_style is None else graph_style + + fields = {**self.reads, **self.writes} + dot = graphviz.Digraph(str(id(self))) + + for field, node in fields.items(): + label = f'{field.name} #{field.counter}' + dot.node(label, style='filled', fillcolor='#a056db', label=label) + + for node in self.computation_nodes: + if isinstance(node, ComputationGraph): + subgraph = node.to_dot(with_code=with_code) + dot.subgraph(subgraph) + continue + elif isinstance(node.kernel, Swap): + name = f'Swap {id(node)}' + dot.node(str(id(node)), style='filled', fillcolor='#ff5600', label=name) + elif isinstance(node.kernel, KernelFunction): + if with_code: + name = str(pystencils.show_code(node.kernel)) + else: + name = node.kernel.function_name + + dot.node(str(id(node)), style='filled', fillcolor='#0056db', label=name) + else: + raise 'foo' + + for input in node.input_nodes: + field = input.field + label = f'{field.name} #{field.counter}' + dot.edge(label, str(id(node))) + for output in node.output_nodes: + field = output.field + label = f'{field.name} #{field.counter}' + dot.edge(str(id(node)), label) + + return dot + + def to_dot_file(self, path, graph_style=None, with_code=False): + with open(path, 'w') as f: + f.write(str(self.to_dot(graph_style, with_code))) + + class ComputationNode: + def __init__(self, kernel): + self.kernel = kernel + self.input_nodes = [] + self.output_nodes = [] + + def __hash__(self): + return id(self) + + class ArrayNode: + def __init__(self, field: Field): + self.field = field + self.source_node = None + self.destination_nodes = [] + + def __hash__(self): + return id(self) diff --git a/src/pystencils_autodiff/graph_datahandling.py b/src/pystencils_autodiff/graph_datahandling.py index 2125ac229ccda3b3989714638b9e9bdfbbc9c9da..59c31c912afcf9ebd000e8042bda7c5f08535273 100644 --- a/src/pystencils_autodiff/graph_datahandling.py +++ b/src/pystencils_autodiff/graph_datahandling.py @@ -45,6 +45,9 @@ class DataTransfer: def __str__(self): return f'DataTransferKind: {self.kind} with {self.field}' + def __repr__(self): + return f'DataTransferKind: {self.kind} with {self.field}' + class Swap(DataTransfer): def __init__(self, source, destination, gpu): @@ -52,7 +55,7 @@ class Swap(DataTransfer): self.field = source self.destination = destination - def __str__(self): + def __repr__(self): return f'Swap: {self.field} with {self.destination}' @@ -122,6 +125,13 @@ class GraphDataHandling(pystencils.datahandling.SerialDataHandling): self.parent.call_queue.append(TimeloopRun(self, time_steps)) super().run(time_steps) + def swap(self, src, dst, is_gpu): + if isinstance(src, str): + src = self.parent.fields[src] + if isinstance(dst, str): + dst = self.parent.fields[dst] + self._single_step_asts.append(Swap(src, dst, is_gpu)) + def __init__(self, *args, **kwargs): self.call_queue = [] diff --git a/tests/test_graph_datahandling.py b/tests/test_graph_datahandling.py index 5b22070b8231b9ecca3e31b106ae93de690b3e48..17d3229c321f1fe7ed6e27cc2cde8068c6aef18e 100644 --- a/tests/test_graph_datahandling.py +++ b/tests/test_graph_datahandling.py @@ -15,6 +15,8 @@ try: from lbmpy.lbstep import LatticeBoltzmannStep from pystencils_autodiff.graph_datahandling import GraphDataHandling from pystencils.slicing import slice_from_direction + from pystencils_autodiff.computationgraph import ComputationGraph + except ImportError: pass @@ -65,10 +67,22 @@ def ldc_setup(**kwargs): def test_graph_datahandling(): - print("--- LDC 2D test ---") - opt_params = {'target': 'gpu', 'gpu_indexing_params': {'block_size': (8, 4, 2)}} lbm_step: LatticeBoltzmannStep = ldc_setup(domain_size=(10, 15), optimization=opt_params) print(lbm_step._data_handling) print(lbm_step._data_handling.call_queue) + + +def test_graph_generation(): + + opt_params = {'target': 'gpu', 'gpu_indexing_params': {'block_size': (8, 4, 2)}} + lbm_step: LatticeBoltzmannStep = ldc_setup(domain_size=(10, 15), optimization=opt_params) + + graph = ComputationGraph(lbm_step._data_handling.call_queue) + print("graph.writes: " + str(graph.writes)) + print("graph.reads: " + str(graph.reads)) + + print(graph.to_dot()) + + graph.to_dot_file('/tmp/foo.dot', with_code=False)