diff --git a/ast/ast_node.py b/ast/ast_node.py
index 8d37b6961bd96c0268fc92581d90aec561c3832b..aa1c2202daa09a30ee576d8e401cd128875036b6 100644
--- a/ast/ast_node.py
+++ b/ast/ast_node.py
@@ -4,7 +4,7 @@ from ast.data_types import Type_Invalid
 class ASTNode:
     def __init__(self, sim):
         self.sim = sim
-        self._parent_block = None # Set during SetParentBlock transformation
+        self.parent_block = None # Set during SetParentBlock transformation
 
     def __str__(self):
         return "ASTNode<>"
@@ -15,9 +15,5 @@ class ASTNode:
     def scope(self):
         return self.sim.global_scope
 
-    @property
-    def parent_block(self):
-        return self._parent_block
-
     def children(self):
         return []
diff --git a/ast/bin_op.py b/ast/bin_op.py
index ba826203772cb6bae73e7396b097c58d0cce12f4..6d48c6d72f2e5d03a4c82640c968ec523679c011 100644
--- a/ast/bin_op.py
+++ b/ast/bin_op.py
@@ -46,8 +46,9 @@ class BinOp(ASTNode):
         self.generated = False
         self.bin_op_type = BinOp.infer_type(self.lhs, self.rhs, self.op)
         self.bin_op_scope = None
-        self.bin_op_vector_indexes = set()
-        self.bin_op_vector_index_mapping = {}
+        self.terminals = set()
+        self._vector_indexes = set()
+        self.vector_index_mapping = {}
         self.bin_op_def = BinOpDef(self)
 
     def __str__(self):
@@ -68,17 +69,18 @@ class BinOp(ASTNode):
         return self.__getitem__(2)
 
     def map_vector_index(self, index, expr):
-        self.bin_op_vector_index_mapping[index] = expr
+        self.vector_index_mapping[index] = expr
 
     def mapped_vector_index(self, index):
-        mapping = self.bin_op_vector_index_mapping
+        mapping = self.vector_index_mapping
         return mapping[index] if index in mapping else as_lit_ast(self.sim, index)
 
+    @property
     def vector_indexes(self):
-        return self.bin_op_vector_indexes
+        return self._vector_indexes
 
     def propagate_vector_access(self, index):
-        self.bin_op_vector_indexes.add(index)
+        self.vector_indexes.add(index)
 
         if isinstance(self.lhs, BinOp) and self.lhs.kind() == BinOp.Kind_Vector:
             self.lhs.propagate_vector_access(index)
@@ -162,6 +164,9 @@ class BinOp(ASTNode):
     def kind(self):
         return BinOp.Kind_Vector if self.type() == Type_Vector else BinOp.Kind_Scalar
 
+    def add_terminal(self, terminal):
+        self.terminals.add(terminal)
+
     def scope(self):
         if self.bin_op_scope is None:
             lhs_scp = self.lhs.scope()
diff --git a/ast/block.py b/ast/block.py
index 85ddd1fe13f644a47caf0e92c803232e4e714a9a..58a90926edfa09168d54b535387a47a0db13bb03 100644
--- a/ast/block.py
+++ b/ast/block.py
@@ -5,7 +5,7 @@ class Block(ASTNode):
     def __init__(self, sim, stmts):
         super().__init__(sim)
         self.level = 0
-        self.variants = []
+        self.variants = set()
 
         if isinstance(stmts, Block):
             self.stmts = stmts.statements()
@@ -34,10 +34,8 @@ class Block(ASTNode):
             self.stmts.append(stmt)
 
     def add_variant(self, variant):
-        if isinstance(variant, list):
-            self.variants = self.variants + variant
-        else:
-            self.variants.append(variant)
+        for v in variant if isinstance(variant, list) else [variant]:
+            self.variants.add(v)
 
     def statements(self):
         return self.stmts
diff --git a/ast/mutator.py b/ast/mutator.py
index 6ec8727c7713586dfc4653c3b9db9890b47b29e1..b856fbf103b558993f5a7e23169f0e35c467ba09 100644
--- a/ast/mutator.py
+++ b/ast/mutator.py
@@ -33,7 +33,7 @@ class Mutator:
     def mutate_BinOp(self, ast_node):
         ast_node.lhs = self.mutate(ast_node.lhs)
         ast_node.rhs = self.mutate(ast_node.rhs)
-        ast_node.bin_op_vector_index_mapping = {i: self.mutate(e) for i, e in ast_node.bin_op_vector_index_mapping.items()}
+        ast_node.vector_index_mapping = {i: self.mutate(e) for i, e in ast_node.vector_index_mapping.items()}
         return ast_node
 
     def mutate_BinOpDef(self, ast_node):
diff --git a/code_gen/cgen.py b/code_gen/cgen.py
index c28e353d69cf64fe502dd8000c14feac669f0dd0..8fccf2454abfe43318de3a86849fa6e9fe13ad6f 100644
--- a/code_gen/cgen.py
+++ b/code_gen/cgen.py
@@ -16,7 +16,7 @@ from ast.utils import Print
 from ast.variables import Var, VarDecl
 from sim.timestep import Timestep
 from sim.vtk import VTKWrite
-from code_gen.printer import printer
+from code_gen.printer import Printer
 
 
 class CGen:
@@ -27,34 +27,43 @@ class CGen:
             else 'bool'
         )
 
-    def generate_program(sim, ast_node):
-        printer.print("#include <stdio.h>")
-        printer.print("#include <stdlib.h>")
-        printer.print("#include <stdbool.h>")
-        printer.print("")
-        printer.print("int main() {")
-        CGen.generate_statement(sim, ast_node)
-        printer.print("}")
-
-    def generate_statement(sim, ast_node):
+    def __init__(self, output):
+        self.sim = None
+        self.print = Printer(output)
+
+    def assign_simulation(self, sim):
+        self.sim = sim
+
+    def generate_program(self, ast_node):
+        self.print.start()
+        self.print("#include <stdio.h>")
+        self.print("#include <stdlib.h>")
+        self.print("#include <stdbool.h>")
+        self.print("")
+        self.print("int main() {")
+        self.generate_statement(ast_node)
+        self.print("}")
+        self.print.end()
+
+    def generate_statement(self, ast_node):
         if isinstance(ast_node, ArrayDecl):
             tkw = CGen.type2keyword(ast_node.array.type())
-            size = CGen.generate_expression(sim, BinOp.inline(ast_node.array.alloc_size()))
-            printer.print(f"{tkw} {ast_node.array.name()}[{size}];")
+            size = self.generate_expression(BinOp.inline(ast_node.array.alloc_size()))
+            self.print(f"{tkw} {ast_node.array.name()}[{size}];")
 
         if isinstance(ast_node, Assign):
             for assign_dest, assign_src in ast_node.assignments:
-                dest = CGen.generate_expression(sim, assign_dest, mem=True)
-                src = CGen.generate_expression(sim, assign_src)
-                printer.print(f"{dest} = {src};")
+                dest = self.generate_expression(assign_dest, mem=True)
+                src = self.generate_expression(assign_src)
+                self.print(f"{dest} = {src};")
 
         if isinstance(ast_node, Block):
-            printer.add_ind(4)
+            self.print.add_ind(4)
 
             for stmt in ast_node.statements():
-                CGen.generate_statement(sim, stmt)
+                self.generate_statement(stmt)
 
-            printer.add_ind(-4)
+            self.print.add_ind(-4)
 
         if isinstance(ast_node, BinOpDef):
             bin_op = ast_node.bin_op
@@ -64,16 +73,16 @@ class CGen:
 
             if bin_op.inlined is False and bin_op.operator() != '[]' and bin_op.generated is False:
                 if bin_op.kind() == BinOp.Kind_Scalar:
-                    lhs = CGen.generate_expression(sim, bin_op.lhs, bin_op.mem)
-                    rhs = CGen.generate_expression(sim, bin_op.rhs)
+                    lhs = self.generate_expression(bin_op.lhs, bin_op.mem)
+                    rhs = self.generate_expression(bin_op.rhs)
                     tkw = CGen.type2keyword(bin_op.type())
-                    printer.print(f"const {tkw} e{bin_op.id()} = {lhs} {bin_op.operator()} {rhs};")
+                    self.print(f"const {tkw} e{bin_op.id()} = {lhs} {bin_op.operator()} {rhs};")
 
                 elif bin_op.kind() == BinOp.Kind_Vector:
-                    for i in bin_op.vector_indexes():
-                        lhs = CGen.generate_expression(sim, bin_op.lhs, bin_op.mem, index=i)
-                        rhs = CGen.generate_expression(sim, bin_op.rhs, index=i)
-                        printer.print(f"const double e{bin_op.id()}_{i} = {lhs} {bin_op.operator()} {rhs};")
+                    for i in bin_op.vector_indexes:
+                        lhs = self.generate_expression(bin_op.lhs, bin_op.mem, index=i)
+                        rhs = self.generate_expression(bin_op.rhs, index=i)
+                        self.print(f"const double e{bin_op.id()}_{i} = {lhs} {bin_op.operator()} {rhs};")
 
                 else:
                     raise Exception("Invalid BinOp kind!")
@@ -81,79 +90,79 @@ class CGen:
                 bin_op.generated = True
 
         if isinstance(ast_node, Branch):
-            cond = CGen.generate_expression(sim, ast_node.cond)
-            printer.print(f"if({cond}) {{")
-            CGen.generate_statement(sim, ast_node.block_if)
+            cond = self.generate_expression(ast_node.cond)
+            self.print(f"if({cond}) {{")
+            self.generate_statement(ast_node.block_if)
 
             if ast_node.block_else is not None:
-                printer.print("} else {")
-                CGen.generate_statement(sim, ast_node.block_else)
+                self.print("} else {")
+                self.generate_statement(ast_node.block_else)
 
-            printer.print("}") 
+            self.print("}") 
 
         if isinstance(ast_node, For):
-            iterator = CGen.generate_expression(sim, ast_node.iterator)
+            iterator = self.generate_expression(ast_node.iterator)
             lower_range = None
             upper_range = None
 
             if isinstance(ast_node, ParticleFor):
-                n = sim.nlocal if ast_node.local_only else sim.nlocal + sim.pbc.npbc
+                n = self.sim.nlocal if ast_node.local_only else self.sim.nlocal + self.sim.pbc.npbc
                 lower_range = 0
-                upper_range = CGen.generate_expression(sim, n)
+                upper_range = self.generate_expression(n)
 
             else:
-                lower_range = CGen.generate_expression(sim, ast_node.min)
-                upper_range = CGen.generate_expression(sim, ast_node.max)
+                lower_range = self.generate_expression(ast_node.min)
+                upper_range = self.generate_expression(ast_node.max)
 
-            printer.print(f"for(int {iterator} = {lower_range}; {iterator} < {upper_range}; {iterator}++) {{")
-            CGen.generate_statement(sim, ast_node.block)
-            printer.print("}")
+            self.print(f"for(int {iterator} = {lower_range}; {iterator} < {upper_range}; {iterator}++) {{")
+            self.generate_statement(ast_node.block)
+            self.print("}")
 
 
         if isinstance(ast_node, Malloc):
             tkw = CGen.type2keyword(ast_node.array.type())
-            size = CGen.generate_expression(sim, ast_node.size)
+            size = self.generate_expression(ast_node.size)
             array_name = ast_node.array.name()
 
             if ast_node.decl:
-                printer.print(f"{tkw} *{array_name} = ({tkw} *) malloc({size});")
+                self.print(f"{tkw} *{array_name} = ({tkw} *) malloc({size});")
             else:
-                printer.print(f"{array_name} = ({tkw} *) malloc({size});")
+                self.print(f"{array_name} = ({tkw} *) malloc({size});")
 
         if isinstance(ast_node, Print):
-            printer.print(f"fprintf(stdout, \"{ast_node.string}\\n\");")
-            printer.print(f"fflush(stdout);")
+            self.print(f"fprintf(stdout, \"{ast_node.string}\\n\");")
+            self.print(f"fflush(stdout);")
 
         if isinstance(ast_node, Realloc):
             tkw = CGen.type2keyword(ast_node.array.type())
-            size = CGen.generate_expression(sim, ast_node.size)
+            size = self.generate_expression(ast_node.size)
             array_name = ast_node.array.name()
-            printer.print(f"{array_name} = ({tkw} *) realloc({array_name}, {size});")
+            self.print(f"{array_name} = ({tkw} *) realloc({array_name}, {size});")
 
         if isinstance(ast_node, Timestep):
-            CGen.generate_statement(sim, ast_node.block)
+            self.generate_statement(ast_node.block)
 
         if isinstance(ast_node, VarDecl):
             tkw = CGen.type2keyword(ast_node.var.type())
-            printer.print(f"{tkw} {ast_node.var.name()} = {ast_node.var.init_value()};")
+            self.print(f"{tkw} {ast_node.var.name()} = {ast_node.var.init_value()};")
 
         if isinstance(ast_node, VTKWrite):
-            nlocal = CGen.generate_expression(sim, sim.nlocal)
-            npbc = CGen.generate_expression(sim, sim.pbc.npbc)
-            nall = CGen.generate_expression(sim, sim.nlocal + sim.pbc.npbc)
-            timestep = CGen.generate_expression(sim, ast_node.timestep)
-            CGen.generate_vtk_writing(ast_node.vtk_id * 2, f"{ast_node.filename}_local", 0, nlocal, nlocal, timestep)
-            CGen.generate_vtk_writing(ast_node.vtk_id * 2 + 1, f"{ast_node.filename}_pbc", nlocal, nall, npbc, timestep)
+            nlocal = self.generate_expression(self.sim.nlocal)
+            npbc = self.generate_expression(self.sim.pbc.npbc)
+            nall = self.generate_expression(self.sim.nlocal + self.sim.pbc.npbc)
+            timestep = self.generate_expression(ast_node.timestep)
+            self.generate_vtk_writing(ast_node.vtk_id * 2, f"{ast_node.filename}_local", 0, nlocal, nlocal, timestep)
+            self.generate_vtk_writing(ast_node.vtk_id * 2 + 1, f"{ast_node.filename}_pbc", nlocal, nall, npbc, timestep)
 
         if isinstance(ast_node, While):
-            cond = CGen.generate_expression(sim, ast_node.cond)
-            printer.print(f"while({cond}) {{")
-            CGen.generate_statement(sim, ast_node.block)
-            printer.print("}")
+            cond = self.generate_expression(ast_node.cond)
+            self.print(f"while({cond}) {{")
+            self.generate_statement(ast_node.block)
+            self.print("}")
 
-    def generate_expression(sim, ast_node, mem=False, index=None):
+    def generate_expression(self, ast_node, mem=False, index=None):
         if isinstance(ast_node, ArrayAccess):
-            index = CGen.generate_expression(sim, ast_node.index)
+            index = self.generate_expression(ast_node.index)
             array_name = ast_node.array.name()
 
             if mem:
@@ -162,20 +171,20 @@ class CGen:
             acc_ref = f"a{ast_node.id()}"
             if ast_node.generated is False:
                 tkw = CGen.type2keyword(ast_node.type())
-                printer.print(f"const {tkw} {acc_ref} = {array_name}[{index}];")
+                self.print(f"const {tkw} {acc_ref} = {array_name}[{index}];")
                 ast_node.generated = True
 
             return acc_ref
 
         if isinstance(ast_node, BinOp):
             if isinstance(ast_node.lhs, BinOp) and ast_node.lhs.kind() == BinOp.Kind_Vector and ast_node.operator() == '[]':
-                return CGen.generate_expression(sim, ast_node.lhs, ast_node.mem, CGen.generate_expression(sim, ast_node.rhs))
+                return self.generate_expression(ast_node.lhs, ast_node.mem, self.generate_expression(ast_node.rhs))
 
-            lhs = CGen.generate_expression(sim, ast_node.lhs, mem, index)
-            rhs = CGen.generate_expression(sim, ast_node.rhs, index=index)
+            lhs = self.generate_expression(ast_node.lhs, mem, index)
+            rhs = self.generate_expression(ast_node.rhs, index=index)
 
             if ast_node.operator() == '[]':
-                idx = CGen.generate_expression(sim, ast_node.mapped_vector_index(index)) if ast_node.is_vector_property_access() else rhs
+                idx = self.generate_expression(ast_node.mapped_vector_index(index)) if ast_node.is_vector_property_access() else rhs
                 return f"{lhs}[{idx}]" if ast_node.mem else f"{lhs}_{idx}"
 
             if ast_node.inlined is True:
@@ -185,7 +194,7 @@ class CGen:
             # Some expressions can be defined on-the-fly during transformations, hence they do not have
             # a definition statement in the tree, so we generate them right before use
             if not ast_node.generated:
-                CGen.generate_statement(sim, ast_node.definition())
+                self.generate_statement(ast_node.definition())
 
             if ast_node.kind() == BinOp.Kind_Vector:
                 assert index is not None, "Index must be set for vector reference!"
@@ -195,7 +204,7 @@ class CGen:
 
         if isinstance(ast_node, Cast):
             tkw = CGen.type2keyword(ast_node.cast_type)
-            expr = CGen.generate_expression(sim, ast_node.expr)
+            expr = self.generate_expression(ast_node.expr)
             return f"({tkw})({expr})"
 
         if isinstance(ast_node, Iter):
@@ -216,20 +225,20 @@ class CGen:
 
         if isinstance(ast_node, Sqrt):
             assert mem is False, "Square root call is not lvalue!"
-            expr = CGen.generate_expression(sim, ast_node.expr)
+            expr = self.generate_expression(ast_node.expr)
             return f"sqrt({expr})"
 
         if isinstance(ast_node, Select):
             assert mem is False, "Select expression is not lvalue!"
-            cond = CGen.generate_expression(sim, ast_node.cond)
-            expr_if = CGen.generate_expression(sim, ast_node.expr_if)
-            expr_else = CGen.generate_expression(sim, ast_node.expr_else)
+            cond = self.generate_expression(ast_node.cond)
+            expr_if = self.generate_expression(ast_node.expr_if)
+            expr_else = self.generate_expression(ast_node.expr_else)
             return f"({cond}) ? ({expr_if}) : ({expr_else})"
 
         if isinstance(ast_node, Var):
             return ast_node.name()
 
-    def generate_vtk_writing(id, filename, start, end, n, timestep):
+    def generate_vtk_writing(self, id, filename, start, end, n, timestep):
         # TODO: Do this in a more elegant way, without hard coded stuff
         header = "# vtk DataFile Version 2.0\n" \
                  "Particle data\n" \
@@ -238,50 +247,49 @@ class CGen:
 
         filename_var = f"filename{id}"
         filehandle_var = f"vtk{id}"
-        printer.print(f"char {filename_var}[128];")
-        printer.print(f"snprintf({filename_var}, sizeof {filename_var}, \"{filename}_%d.vtk\", {timestep});")
-        printer.print(f"FILE *{filehandle_var} = fopen({filename_var}, \"w\");")
+        self.print(f"char {filename_var}[128];")
+        self.print(f"snprintf({filename_var}, sizeof {filename_var}, \"{filename}_%d.vtk\", {timestep});")
+        self.print(f"FILE *{filehandle_var} = fopen({filename_var}, \"w\");")
         for line in header.split('\n'):
             if len(line) > 0:
-                printer.print(f"fwrite(\"{line}\\n\", 1, {len(line) + 1}, {filehandle_var});")
+                self.print(f"fwrite(\"{line}\\n\", 1, {len(line) + 1}, {filehandle_var});")
 
         # Write positions
-        printer.print(f"fprintf({filehandle_var}, \"POINTS %d double\\n\", {n});")
-        printer.print(f"for(int i = {start}; i < {end}; i++) {{")
-        printer.add_ind(4)
-        printer.print(f"fprintf({filehandle_var}, \"%.4f %.4f %.4f\\n\", position[i * 3], position[i * 3 + 1], position[i * 3 + 2]);")
-        printer.add_ind(-4)
-        printer.print("}")
-        printer.print(f"fwrite(\"\\n\\n\", 1, 2, {filehandle_var});")
+        self.print(f"fprintf({filehandle_var}, \"POINTS %d double\\n\", {n});")
+        self.print(f"for(int i = {start}; i < {end}; i++) {{")
+        self.print.add_ind(4)
+        self.print(f"fprintf({filehandle_var}, \"%.4f %.4f %.4f\\n\", position[i * 3], position[i * 3 + 1], position[i * 3 + 2]);")
+        self.print.add_ind(-4)
+        self.print("}")
+        self.print(f"fwrite(\"\\n\\n\", 1, 2, {filehandle_var});")
 
         # Write cells
-        printer.print(f"fprintf({filehandle_var}, \"CELLS %d %d\\n\", {n}, {n} * 2);")
-        printer.print(f"for(int i = {start}; i < {end}; i++) {{")
-        printer.add_ind(4)
-        printer.print(f"fprintf({filehandle_var}, \"1 %d\\n\", i - {start});")
-        printer.add_ind(-4)
-        printer.print("}")
-        printer.print(f"fwrite(\"\\n\\n\", 1, 2, {filehandle_var});")
+        self.print(f"fprintf({filehandle_var}, \"CELLS %d %d\\n\", {n}, {n} * 2);")
+        self.print(f"for(int i = {start}; i < {end}; i++) {{")
+        self.print.add_ind(4)
+        self.print(f"fprintf({filehandle_var}, \"1 %d\\n\", i - {start});")
+        self.print.add_ind(-4)
+        self.print("}")
+        self.print(f"fwrite(\"\\n\\n\", 1, 2, {filehandle_var});")
 
         # Write cell types
-        printer.print(f"fprintf({filehandle_var}, \"CELL_TYPES %d\\n\", {n});")
-        printer.print(f"for(int i = {start}; i < {end}; i++) {{")
-        printer.add_ind(4)
-        printer.print(f"fwrite(\"1\\n\", 1, 2, {filehandle_var});")
-        printer.add_ind(-4)
-        printer.print("}")
-        printer.print(f"fwrite(\"\\n\\n\", 1, 2, {filehandle_var});")
+        self.print(f"fprintf({filehandle_var}, \"CELL_TYPES %d\\n\", {n});")
+        self.print(f"for(int i = {start}; i < {end}; i++) {{")
+        self.print.add_ind(4)
+        self.print(f"fwrite(\"1\\n\", 1, 2, {filehandle_var});")
+        self.print.add_ind(-4)
+        self.print("}")
+        self.print(f"fwrite(\"\\n\\n\", 1, 2, {filehandle_var});")
 
         # Write masses
-        printer.print(f"fprintf({filehandle_var}, \"POINT_DATA %d\\n\", {n});")
-        printer.print(f"fprintf({filehandle_var}, \"SCALARS mass double\\n\");")
-        printer.print(f"fprintf({filehandle_var}, \"LOOKUP_TABLE default\\n\");")
-        printer.print(f"for(int i = {start}; i < {end}; i++) {{")
-        printer.add_ind(4)
-        #printer.print(f"fprintf({filehandle_var}, \"%4.f\\n\", mass[i]);")
-        printer.print(f"fprintf({filehandle_var}, \"1.0\\n\");")
-        printer.add_ind(-4)
-        printer.print("}")
-        printer.print(f"fwrite(\"\\n\\n\", 1, 2, {filehandle_var});")
-
-        printer.print(f"fclose({filehandle_var});")
+        self.print(f"fprintf({filehandle_var}, \"POINT_DATA %d\\n\", {n});")
+        self.print(f"fprintf({filehandle_var}, \"SCALARS mass double\\n\");")
+        self.print(f"fprintf({filehandle_var}, \"LOOKUP_TABLE default\\n\");")
+        self.print(f"for(int i = {start}; i < {end}; i++) {{")
+        self.print.add_ind(4)
+        #self.print(f"fprintf({filehandle_var}, \"%4.f\\n\", mass[i]);")
+        self.print(f"fprintf({filehandle_var}, \"1.0\\n\");")
+        self.print.add_ind(-4)
+        self.print("}")
+        self.print(f"fwrite(\"\\n\\n\", 1, 2, {filehandle_var});")
+        self.print(f"fclose({filehandle_var});")
diff --git a/code_gen/printer.py b/code_gen/printer.py
index 458713b51d35bdb8c4a4f6b7fd576cf92244bdcf..95299469b67fe70a6cb89c308a8a91990ee5cb0a 100644
--- a/code_gen/printer.py
+++ b/code_gen/printer.py
@@ -1,12 +1,19 @@
 class Printer:
-    def __init__(self):
+    def __init__(self, output):
+        self.output = output
+        self.stream = None
         self.indent = 0
 
     def add_ind(self, offset):
         self.indent += offset
 
-    def print(self, text):
-        print(self.indent * ' ' + text)
+    def start(self):
+        self.stream = open(self.output, 'w')
 
+    def end(self):
+        self.stream.close()
+        self.stream = None
 
-printer = Printer()
+    def __call__(self, text):
+        assert self.stream is not None, "Invalid stream!"
+        self.stream.write(self.indent * ' ' + text + '\n')
diff --git a/part_prot.py b/part_prot.py
index 7c1bcc54cee4b330451d2e007dab59c3947e1fec..dab286c5177b26f3d468ebc9da6254341246e2b8 100644
--- a/part_prot.py
+++ b/part_prot.py
@@ -2,5 +2,5 @@ from code_gen.cgen import CGen
 from sim.particle_simulation import ParticleSimulation
 
 
-def simulation(dims=3, timesteps=100):
-    return ParticleSimulation(CGen, dims, timesteps)
+def simulation(ref, dims=3, timesteps=100):
+    return ParticleSimulation(CGen(f"{ref}.c"), dims, timesteps)
diff --git a/particle.py b/particle.py
index d9b95bf4a97ba01b139f976444d85b500c859864..f46612644d8d87433b43e5dc49aba3cc98754671 100644
--- a/particle.py
+++ b/particle.py
@@ -8,7 +8,7 @@ sigma = 1.0
 epsilon = 1.0
 sigma6 = sigma ** 6
 
-psim = pt.simulation()
+psim = pt.simulation("lj")
 mass = psim.add_real_property('mass', 1.0)
 position = psim.add_vector_property('position')
 velocity = psim.add_vector_property('velocity')
diff --git a/sim/particle_simulation.py b/sim/particle_simulation.py
index ae2f814c0ad8e90dfa6c031962474af1e53e0253..b21b5e37a0a8ce693eec2b847aeab38da64e530e 100644
--- a/sim/particle_simulation.py
+++ b/sim/particle_simulation.py
@@ -21,11 +21,13 @@ from sim.variables import VariablesDecl
 from sim.vtk import VTKWrite
 from transformations.flatten import flatten_property_accesses
 from transformations.simplify import simplify_expressions
+from transformations.LICM import move_loop_invariant_code
 
 
 class ParticleSimulation:
     def __init__(self, code_gen, dims=3, timesteps=100):
         self.code_gen = code_gen
+        self.code_gen.assign_simulation(self)
         self.global_scope = None
         self.properties = Properties(self)
         self.vars = Variables(self)
@@ -198,6 +200,7 @@ class ParticleSimulation:
         # Transformations
         flatten_property_accesses(program)
         simplify_expressions(program)
+        move_loop_invariant_code(program)
 
         ASTGraph(self.kernels.lower(), "kernels").render()
-        self.code_gen.generate_program(self, program)
+        self.code_gen.generate_program(program)
diff --git a/transformations/LICM.py b/transformations/LICM.py
index 07517b50cb4ddbab03925e119738244efc4b9512..4742f19adf852e065cb7ba65733b5ce837c2dbf3 100644
--- a/transformations/LICM.py
+++ b/transformations/LICM.py
@@ -1,3 +1,4 @@
+from ast.loops import For, While
 from ast.mutator import Mutator
 from ast.visitor import Visitor
 
@@ -5,42 +6,42 @@ from ast.visitor import Visitor
 class SetBlockVariants(Mutator):
     def __init__(self, ast):
         super().__init__(ast)
-        self.current_block = None
         self.in_assignment = None
 
-    def mutate_Block(self, ast_node):
-        self.current_block = ast_node
-        ast_node.stmts = [self.mutate(s) for s in ast_node.stmts]
-        return ast_node
-
     def mutate_Assign(self, ast_node):
-        self.in_assignment = ast_node
+        self.in_assignment = ast_node if ast_node.parent_block is not None else None
         for dest in ast_node.destinations():
             self.mutate(dest)
         self.in_assignment = None
         return ast_node
 
+    def mutate_For(self, ast_node):
+        ast_node.block.add_variant(id(ast_node.iterator))
+        ast_node.iterator = self.mutate(ast_node.iterator)
+        ast_node.block = self.mutate(ast_node.block)
+        return ast_node
+
     def mutate_Array(self, ast_node):
         if self.in_assignment is not None:
-            self.in_assignment.parent_block.add_variant(ast_node)
+            self.in_assignment.parent_block.add_variant(id(ast_node))
 
         return ast_node
 
-    def mutate_For(self, ast_node):
-        ast_node.block.add_variant(ast_node.iterator)
-        ast_node.iterator = self.mutate(ast_node.iterator)
-        ast_node.block = self.mutate(ast_node.block)
+    def mutate_Iter(self, ast_node):
+        if self.in_assignment is not None:
+            self.in_assignment.parent_block.add_variant(id(ast_node))
+
         return ast_node
 
     def mutate_Property(self, ast_node):
         if self.in_assignment is not None:
-            self.in_assignment.parent_block.add_variant(ast_node)
+            self.in_assignment.parent_block.add_variant(id(ast_node))
 
         return ast_node
 
     def mutate_Variable(self, ast_node):
         if self.in_assignment is not None:
-            self.in_assignment.parent_block.add_variant(ast_node)
+            self.in_assignment.parent_block.add_variant(id(ast_node))
 
         return ast_node
 
@@ -51,48 +52,48 @@ class SetParentBlock(Visitor):
         self.blocks = []
 
     def current_block(self):
-        return self.blocks[-1]
+        return self.blocks[-1] if self.blocks else None
 
     def visit_Assign(self, ast_node):
-        ast_node.parent_block = self.current_block
+        ast_node.parent_block = self.current_block()
         self.visit_children(ast_node)
 
     def visit_Block(self, ast_node):
-        ast_node.parent_block = self.current_block
+        ast_node.parent_block = self.current_block()
         self.blocks.append(ast_node)
         self.visit_children(ast_node)
         self.blocks.pop()
 
     def visit_BinOpDef(self, ast_node):
-        ast_node.parent_block = self.current_block
+        ast_node.parent_block = self.current_block()
         self.visit_children(ast_node)
 
     def visit_Branch(self, ast_node):
-        ast_node.parent_block = self.current_block
+        ast_node.parent_block = self.current_block()
         self.visit_children(ast_node)
 
     def visit_Filter(self, ast_node):
-        ast_node.parent_block = self.current_block
+        ast_node.parent_block = self.current_block()
         self.visit_children(ast_node)
 
     def visit_For(self, ast_node):
-        ast_node.parent_block = self.current_block
+        ast_node.parent_block = self.current_block()
         self.visit_children(ast_node)
 
     def visit_ParticleFor(self, ast_node):
-        ast_node.parent_block = self.current_block
+        ast_node.parent_block = self.current_block()
         self.visit_children(ast_node)
 
     def visit_Malloc(self, ast_node):
-        ast_node.parent_block = self.current_block
+        ast_node.parent_block = self.current_block()
         self.visit_children(ast_node)
 
     def visit_Realloc(self, ast_node):
-        ast_node.parent_block = self.current_block
+        ast_node.parent_block = self.current_block()
         self.visit_children(ast_node)
 
     def visit_While(self, ast_node):
-        ast_node.parent_block = self.current_block
+        ast_node.parent_block = self.current_block()
         self.visit_children(ast_node)
 
     def get_loop_parent_block(self, ast_node):
@@ -101,18 +102,71 @@ class SetParentBlock(Visitor):
         return self.parents[loop_id] if loop_id in self.parents else None
 
 
+class SetBinOpTerminals(Visitor):
+    def __init__(self, ast):
+        super().__init__(ast)
+        self.bin_ops = []
+
+    def visit_BinOp(self, ast_node):
+        self.bin_ops.append(ast_node)
+        self.visit_children(ast_node)
+        self.bin_ops.pop()
+
+    def visit_Array(self, ast_node):
+        for bin_op in self.bin_ops:
+            bin_op.add_terminal(id(ast_node))
+
+    def visit_Iter(self, ast_node):
+        for bin_op in self.bin_ops:
+            bin_op.add_terminal(id(ast_node))
+
+    def visit_Property(self, ast_node):
+        for bin_op in self.bin_ops:
+            bin_op.add_terminal(id(ast_node))
+
+    def visit_Variable(self, ast_node):
+        for bin_op in self.bin_ops:
+            bin_op.add_terminal(id(ast_node))
+
+
 class LICM(Mutator):
-    def __init__(self, ast, loop_parents):
+    def __init__(self, ast):
         super().__init__(ast)
-        self.loop_parents = loop_parents
+        self.lifts = {}
+        self.loops = []
 
     def mutate_For(self, ast_node):
+        self.lifts[id(ast_node)] = []
+        self.loops.append(ast_node)
         ast_node.iterator = self.mutate(ast_node.iterator)
         ast_node.block = self.mutate(ast_node.block)
+        self.loops.pop()
+        return ast_node
+
+    def mutate_BinOpDef(self, ast_node):
+        if self.loops:
+            last_loop = self.loops[-1]
+            print(f"Checking lifting for {ast_node.id()}")
+            if not last_loop.block.variants.intersect(ast_node.bin_op.terminals):
+                self.lifts[id(last_loop)].append(ast_node)
+                print(f"Lifting {ast_node.id()}")
+                return None
+
         return ast_node
 
     def mutate_Block(self, ast_node):
-        ast_node.stmts = [self.mutate(s) for s in ast_node.stmts]
+        new_stmts = []
+        stmts = self.mutate(ast_node.stmts)
+
+        for s in stmts:
+            if s is not None:
+                s_id = id(s)
+                if isinstance(s, (For, While)) and s_id in self.lifts:
+                    new_stmts = new_stmts + self.lifts[s_id]
+
+                new_stmts.append(s)
+
+        ast_node.stmts = new_stmts
         return ast_node
 
 
@@ -121,5 +175,7 @@ def move_loop_invariant_code(ast):
     set_parent_block.visit()
     set_block_variants = SetBlockVariants(ast)
     set_block_variants.mutate()
+    set_bin_op_terminals = SetBinOpTerminals(ast)
+    set_bin_op_terminals.visit()
     licm = LICM(ast)
     licm.mutate()
diff --git a/transformations/flatten.py b/transformations/flatten.py
index b06ebcb07d6a54c5b4367abdb2bcc4d3093f0e7e..b414079e8cc961fbfbaa9e282e9d40b3a50f53b7 100644
--- a/transformations/flatten.py
+++ b/transformations/flatten.py
@@ -13,7 +13,7 @@ class FlattenPropertyAccesses(Mutator):
         if ast_node.is_vector_property_access():
             layout = ast_node.lhs.layout()
 
-            for i in ast_node.vector_indexes():
+            for i in ast_node.vector_indexes:
                 flat_index = None
 
                 if layout == Layout_AoS:
diff --git a/transformations/simplify.py b/transformations/simplify.py
index ff673159bcb09af7a372f959c97bbe852e00c835..d59dd59123c2129d15886173c9f3c4d78826bfc6 100644
--- a/transformations/simplify.py
+++ b/transformations/simplify.py
@@ -11,7 +11,7 @@ class SimplifyExpressions(Mutator):
         sim = ast_node.lhs.sim
         ast_node.lhs = self.mutate(ast_node.lhs)
         ast_node.rhs = self.mutate(ast_node.rhs)
-        ast_node.bin_op_vector_index_mapping = {i: self.mutate(e) for i, e in ast_node.bin_op_vector_index_mapping.items()}
+        ast_node.vector_index_mapping = {i: self.mutate(e) for i, e in ast_node.vector_index_mapping.items()}
 
         if ast_node.op in ['+', '-'] and ast_node.rhs == 0:
             return ast_node.lhs