diff --git a/src/pystencils_autodiff/computationgraph.py b/src/pystencils_autodiff/computationgraph.py
new file mode 100644
index 0000000000000000000000000000000000000000..25c2f8590927a87c09f4f0b7fcb0e3e52ae1c3af
--- /dev/null
+++ b/src/pystencils_autodiff/computationgraph.py
@@ -0,0 +1,163 @@
+#
+# Distributed under terms of the GPLv3 license.
+
+"""
+
+"""
+
+from itertools import chain
+
+import pystencils
+from pystencils import Field
+from pystencils.astnodes import KernelFunction
+from pystencils.kernel_wrapper import KernelWrapper
+from pystencils_autodiff.graph_datahandling import KernelCall, Swap, TimeloopRun
+
+
+class ComputationGraph:
+    class FieldWriteCounter:
+        def __init__(self, field, counter=0):
+            self.field = field
+            self.counter = counter
+
+        def __hash__(self):
+            return hash((self.field, self.counter))
+
+        def next(self):
+            return self.__class__(self.field, self.counter + 1)
+
+        @property
+        def name(self):
+            return self.field.name
+
+        def __str__(self):
+            return self.__repr__()
+
+        def __repr__(self):
+            return f'{self.field} #{self.counter}'
+
+        def __eq__(self, other):
+            return hash(self) == hash(other)
+
+    def __init__(self, call_list, write_counter={}):
+        self.call_list = call_list
+        self.write_counter = write_counter
+        self.reads = {}
+        self.writes = {}
+        self.computation_nodes = set()
+        self.input_nodes = []
+        self.output_nodes = []
+
+        for c in call_list:
+            if isinstance(c, KernelCall):  # TODO get rid of this one
+                c = c.kernel
+
+            if isinstance(c, KernelWrapper):
+                c = c.ast
+
+            if isinstance(c, KernelFunction):
+                output_fields = c.fields_written
+                input_fields = c.fields_read
+
+                computation_node = self.ComputationNode(c)
+                self.read(input_fields, computation_node)
+                self.write(output_fields, computation_node)
+                self.computation_nodes.add(computation_node)
+            elif isinstance(c, Swap):
+                computation_node = self.ComputationNode(c)
+                self.read([c.field, c.destination], computation_node)
+                self.write([c.field, c.destination], computation_node)
+                self.computation_nodes.add(computation_node)
+            elif isinstance(c, TimeloopRun):
+                computation_node = ComputationGraph(c.timeloop._single_step_asts, self.write_counter)
+
+                self.computation_nodes.add(computation_node)
+            else:
+                print(c)
+
+        for c in self.computation_nodes:
+            if isinstance(c, ComputationGraph):
+                reads = set(c.reads.keys())
+                writes = set(c.writes.keys())
+                known = set(chain(self.writes.keys(), self.reads.keys()))
+                c.input_nodes = [self.ArrayNode(a) for a in (known & reads)]
+                c.output_nodes = [self.ArrayNode(a) for a in (known & writes)]
+
+    def read(self, fields, kernel):
+        fields = [self.FieldWriteCounter(f, self.write_counter.get(f.name, 0)) for f in fields]
+        for f in fields:
+            read_node = {**self.writes, **self.reads}.get(f, self.ArrayNode(f))
+            read_node.destination_nodes.append(kernel)
+            self.reads[f] = read_node
+            kernel.input_nodes.append(read_node)
+
+    def write(self, fields, kernel):
+        for f in fields:
+            field_snapshot = self.FieldWriteCounter(f, self.write_counter.get(f.name, 0) + 1)
+            write_node = self.ArrayNode(field_snapshot)
+            write_node.source_node = kernel
+            self.writes[field_snapshot] = write_node
+            kernel.output_nodes.append(write_node)
+            self.write_counter[f.name] = self.write_counter.get(f.name, 0) + 1
+
+    def to_dot(self, graph_style=None, with_code=False):
+        import graphviz
+        graph_style = {} if graph_style is None else graph_style
+
+        fields = {**self.reads, **self.writes}
+        dot = graphviz.Digraph(str(id(self)))
+
+        for field, node in fields.items():
+            label = f'{field.name} #{field.counter}'
+            dot.node(label, style='filled', fillcolor='#a056db', label=label)
+
+        for node in self.computation_nodes:
+            if isinstance(node, ComputationGraph):
+                subgraph = node.to_dot(with_code=with_code)
+                dot.subgraph(subgraph)
+                continue
+            elif isinstance(node.kernel, Swap):
+                name = f'Swap {id(node)}'
+                dot.node(str(id(node)), style='filled', fillcolor='#ff5600', label=name)
+            elif isinstance(node.kernel, KernelFunction):
+                if with_code:
+                    name = str(pystencils.show_code(node.kernel))
+                else:
+                    name = node.kernel.function_name
+
+                dot.node(str(id(node)), style='filled', fillcolor='#0056db', label=name)
+            else:
+                raise 'foo'
+
+            for input in node.input_nodes:
+                field = input.field
+                label = f'{field.name} #{field.counter}'
+                dot.edge(label, str(id(node)))
+            for output in node.output_nodes:
+                field = output.field
+                label = f'{field.name} #{field.counter}'
+                dot.edge(str(id(node)), label)
+
+        return dot
+
+    def to_dot_file(self, path, graph_style=None, with_code=False):
+        with open(path, 'w') as f:
+            f.write(str(self.to_dot(graph_style, with_code)))
+
+    class ComputationNode:
+        def __init__(self, kernel):
+            self.kernel = kernel
+            self.input_nodes = []
+            self.output_nodes = []
+
+        def __hash__(self):
+            return id(self)
+
+    class ArrayNode:
+        def __init__(self, field: Field):
+            self.field = field
+            self.source_node = None
+            self.destination_nodes = []
+
+        def __hash__(self):
+            return id(self)
diff --git a/src/pystencils_autodiff/graph_datahandling.py b/src/pystencils_autodiff/graph_datahandling.py
index 2125ac229ccda3b3989714638b9e9bdfbbc9c9da..59c31c912afcf9ebd000e8042bda7c5f08535273 100644
--- a/src/pystencils_autodiff/graph_datahandling.py
+++ b/src/pystencils_autodiff/graph_datahandling.py
@@ -45,6 +45,9 @@ class DataTransfer:
     def __str__(self):
         return f'DataTransferKind: {self.kind} with {self.field}'
 
+    def __repr__(self):
+        return f'DataTransferKind: {self.kind} with {self.field}'
+
 
 class Swap(DataTransfer):
     def __init__(self, source, destination, gpu):
@@ -52,7 +55,7 @@ class Swap(DataTransfer):
         self.field = source
         self.destination = destination
 
-    def __str__(self):
+    def __repr__(self):
         return f'Swap: {self.field} with {self.destination}'
 
 
@@ -122,6 +125,13 @@ class GraphDataHandling(pystencils.datahandling.SerialDataHandling):
             self.parent.call_queue.append(TimeloopRun(self, time_steps))
             super().run(time_steps)
 
+        def swap(self, src, dst, is_gpu):
+            if isinstance(src, str):
+                src = self.parent.fields[src]
+            if isinstance(dst, str):
+                dst = self.parent.fields[dst]
+            self._single_step_asts.append(Swap(src, dst, is_gpu))
+
     def __init__(self, *args, **kwargs):
 
         self.call_queue = []
diff --git a/tests/test_graph_datahandling.py b/tests/test_graph_datahandling.py
index 5b22070b8231b9ecca3e31b106ae93de690b3e48..17d3229c321f1fe7ed6e27cc2cde8068c6aef18e 100644
--- a/tests/test_graph_datahandling.py
+++ b/tests/test_graph_datahandling.py
@@ -15,6 +15,8 @@ try:
     from lbmpy.lbstep import LatticeBoltzmannStep
     from pystencils_autodiff.graph_datahandling import GraphDataHandling
     from pystencils.slicing import slice_from_direction
+    from pystencils_autodiff.computationgraph import ComputationGraph
+
 except ImportError:
     pass
 
@@ -65,10 +67,22 @@ def ldc_setup(**kwargs):
 
 def test_graph_datahandling():
 
-    print("--- LDC 2D test ---")
-
     opt_params = {'target': 'gpu', 'gpu_indexing_params': {'block_size': (8, 4, 2)}}
     lbm_step: LatticeBoltzmannStep = ldc_setup(domain_size=(10, 15), optimization=opt_params)
     print(lbm_step._data_handling)
 
     print(lbm_step._data_handling.call_queue)
+
+
+def test_graph_generation():
+
+    opt_params = {'target': 'gpu', 'gpu_indexing_params': {'block_size': (8, 4, 2)}}
+    lbm_step: LatticeBoltzmannStep = ldc_setup(domain_size=(10, 15), optimization=opt_params)
+
+    graph = ComputationGraph(lbm_step._data_handling.call_queue)
+    print("graph.writes: " + str(graph.writes))
+    print("graph.reads: " + str(graph.reads))
+
+    print(graph.to_dot())
+
+    graph.to_dot_file('/tmp/foo.dot', with_code=False)