diff --git a/ast/assign.py b/ast/assign.py
index a6cc919c8e528b16f6160549c2461ef9c05fa08d..08f6a7200bc61d523a4e05375718342cd33a9130 100644
--- a/ast/assign.py
+++ b/ast/assign.py
@@ -7,7 +7,6 @@ from functools import reduce
 class Assign(ASTNode):
     def __init__(self, sim, dest, src):
         super().__init__(sim)
-        self.parent_block = None
         self.type = dest.type()
         src = as_lit_ast(sim, src)
 
@@ -16,18 +15,19 @@ class Assign(ASTNode):
 
             for i in range(0, sim.dimensions):
                 from ast.bin_op import BinOp
-                dsrc = (src if (not isinstance(src, BinOp) or
-                                src.type() != Type_Vector)
-                        else src[i])
-
-                self.assignments.append((dest[i], dsrc))
+                dim_src = src if not isinstance(src, BinOp) or src.type() != Type_Vector else src[i]
+                self.assignments.append((dest[i], dim_src))
         else:
             self.assignments = [(dest, src)]
 
     def __str__(self):
         return f"Assign<{self.assignments}>"
 
+    def destinations(self):
+        return [a[0] for a in self.assignments]
+
+    def sources(self):
+        return [a[1] for a in self.assignments]
+
     def children(self):
-        return reduce((lambda x, y: x + y), [
-                      [self.assignments[i][0], self.assignments[i][1]]
-                      for i in range(0, len(self.assignments))])
+        return reduce((lambda x, y: x + y), [[a[0], a[1]] for a in self.assignments])
diff --git a/ast/ast_node.py b/ast/ast_node.py
index e58afc26693907e38fd322ff5969d7a37271e675..8d37b6961bd96c0268fc92581d90aec561c3832b 100644
--- a/ast/ast_node.py
+++ b/ast/ast_node.py
@@ -4,6 +4,7 @@ from ast.data_types import Type_Invalid
 class ASTNode:
     def __init__(self, sim):
         self.sim = sim
+        self._parent_block = None # Set during SetParentBlock transformation
 
     def __str__(self):
         return "ASTNode<>"
@@ -14,5 +15,9 @@ class ASTNode:
     def scope(self):
         return self.sim.global_scope
 
+    @property
+    def parent_block(self):
+        return self._parent_block
+
     def children(self):
         return []
diff --git a/ast/block.py b/ast/block.py
index e85facec636e02ee94235d762cef1f85f355db07..85ddd1fe13f644a47caf0e92c803232e4e714a9a 100644
--- a/ast/block.py
+++ b/ast/block.py
@@ -5,7 +5,7 @@ class Block(ASTNode):
     def __init__(self, sim, stmts):
         super().__init__(sim)
         self.level = 0
-        self.expressions = []
+        self.variants = []
 
         if isinstance(stmts, Block):
             self.stmts = stmts.statements()
@@ -29,24 +29,19 @@ class Block(ASTNode):
 
     def add_statement(self, stmt):
         if isinstance(stmt, list):
-            for s in stmt:
-                s.parent_block = self
-
             self.stmts = self.stmts + stmt
-
         else:
-            stmt.parent_block = self
             self.stmts.append(stmt)
 
+    def add_variant(self, variant):
+        if isinstance(variant, list):
+            self.variants = self.variants + variant
+        else:
+            self.variants.append(variant)
+
     def statements(self):
         return self.stmts
 
-    def add_expression(self, expr):
-        if isinstance(expr, list):
-            self.expressions = self.expressions + expr
-        else:
-            self.expressions.append(expr)
-
     def children(self):
         return self.stmts
 
diff --git a/ast/branches.py b/ast/branches.py
index a6302609d42b7ed6b25513a2ce66bdbc5717e027..209dd12b67fcb43c15015d628b682a40d6c8a9fc 100644
--- a/ast/branches.py
+++ b/ast/branches.py
@@ -6,7 +6,6 @@ from ast.lit import as_lit_ast
 class Branch(ASTNode):
     def __init__(self, sim, cond, one_way=False, blk_if=None, blk_else=None):
         self.sim = sim
-        self.parent_block = None
         self.cond = as_lit_ast(sim, cond)
         self.switch = True
         self.block_if = Block(sim, []) if blk_if is None else blk_if
diff --git a/ast/loops.py b/ast/loops.py
index e0c645cfe256a281b509a14f1d4df8487da911b1..2dc06ee35db21c2ce2f683c1f974197ea10dc8d8 100644
--- a/ast/loops.py
+++ b/ast/loops.py
@@ -46,7 +46,6 @@ class For(ASTNode):
         self.iterator = Iter(sim, self)
         self.min = as_lit_ast(sim, range_min)
         self.max = as_lit_ast(sim, range_max)
-        self.parent_block = None
         self.block = Block(sim, []) if block is None else block
 
     def __str__(self):
@@ -80,7 +79,6 @@ class ParticleFor(For):
 class While(ASTNode):
     def __init__(self, sim, cond, block=None):
         super().__init__(sim)
-        self.parent_block = None
         self.cond = BinOp.inline(cond)
         self.block = Block(sim, []) if block is None else block
 
@@ -103,7 +101,6 @@ class While(ASTNode):
 class NeighborFor():
     def __init__(self, sim, particle, cell_lists):
         self.sim = sim
-        self.parent_block = None
         self.particle = particle
         self.cell_lists = cell_lists
 
diff --git a/ast/memory.py b/ast/memory.py
index 3458bf9fa464a1a4cde36b555b85324fff40b4b1..e853a0dd77b6df1827b6eeaf2dd0034409d0263d 100644
--- a/ast/memory.py
+++ b/ast/memory.py
@@ -8,7 +8,6 @@ import operator
 class Malloc(ASTNode):
     def __init__(self, sim, array, sizes, decl=False):
         super().__init__(sim)
-        self.parent_block = None
         self.array = array
         self.decl = decl
         self.prim_size = Sizeof(sim, array.type())
@@ -22,7 +21,6 @@ class Malloc(ASTNode):
 class Realloc(ASTNode):
     def __init__(self, sim, array, size):
         super().__init__(sim)
-        self.parent_block = None
         self.array = array
         self.prim_size = Sizeof(sim, array.type())
         self.size = BinOp.inline(self.prim_size * size)
diff --git a/ast/mutator.py b/ast/mutator.py
index f450773eb5081d27f4fa8f5f1df8546024802323..6ec8727c7713586dfc4653c3b9db9890b47b29e1 100644
--- a/ast/mutator.py
+++ b/ast/mutator.py
@@ -1,66 +1,19 @@
-from ast.arrays import ArrayAccess
-from ast.assign import Assign
-from ast.bin_op import BinOp, BinOpDef
-from ast.block import Block
-from ast.branches import Branch
-from ast.cast import Cast
-from ast.loops import For, While
-from ast.math import Sqrt
-from ast.memory import Malloc, Realloc
-from ast.select import Select
-from sim.timestep import Timestep
-
-
 class Mutator:
     def __init__(self, ast, max_depth=0):
         self.ast = ast
         self.max_depth = 0
 
+    def get_method(self, method_name):
+        method = getattr(self, method_name, None)
+        return method if callable(method) else None
+
     def mutate(self, ast_node=None):
         if ast_node is None:
             ast_node = self.ast
 
-        if isinstance(ast_node, ArrayAccess):
-            return self.mutate_ArrayAccess(ast_node)
-
-        elif isinstance(ast_node, Assign):
-            return self.mutate_Assign(ast_node)
-
-        elif isinstance(ast_node, BinOp):
-            return self.mutate_BinOp(ast_node)
-
-        elif isinstance(ast_node, BinOpDef):
-            return self.mutate_BinOpDef(ast_node)
-
-        elif isinstance(ast_node, Block):
-            return self.mutate_Block(ast_node)
-
-        elif isinstance(ast_node, Branch):
-            return self.mutate_Branch(ast_node)
-
-        elif isinstance(ast_node, Cast):
-            return self.mutate_Cast(ast_node)
-
-        elif isinstance(ast_node, For):
-            return self.mutate_For(ast_node)
-
-        elif isinstance(ast_node, Malloc):
-            return self.mutate_Malloc(ast_node)
-
-        elif isinstance(ast_node, Realloc):
-            return self.mutate_Realloc(ast_node)
-
-        elif isinstance(ast_node, Select):
-            return self.mutate_Select(ast_node)
-
-        elif isinstance(ast_node, Sqrt):
-            return self.mutate_Sqrt(ast_node)
-
-        elif isinstance(ast_node, Timestep):
-            return self.mutate_Timestep(ast_node)
-
-        elif isinstance(ast_node, While):
-            return self.mutate_While(ast_node)
+        method = self.get_method(f"mutate_{type(ast_node).__name__}")
+        if method is not None:
+            return method(ast_node)
 
         return ast_node
 
@@ -74,10 +27,7 @@ class Mutator:
         return ast_node 
 
     def mutate_Assign(self, ast_node):
-        ast_node.assignments = [
-            (self.mutate(ast_node.assignments[i][0]), self.mutate(ast_node.assignments[i][1]))
-            for i in range(0, len(ast_node.assignments))
-        ]
+        ast_node.assignments = [(self.mutate(a[0]), self.mutate(a[1])) for a in ast_node.assignments]
         return ast_node
 
     def mutate_BinOp(self, ast_node):
@@ -100,6 +50,9 @@ class Mutator:
         ast_node.block_else = None if ast_node.block_else is None else self.mutate(ast_node.block_else)
         return ast_node
 
+    def mutate_Filter(self, ast_node):
+        return self.mutate_Branch(ast_node)
+
     def mutate_Cast(self, ast_node):
         ast_node.expr = self.mutate(ast_node.expr)
         return ast_node
@@ -109,6 +62,9 @@ class Mutator:
         ast_node.block = self.mutate(ast_node.block)
         return ast_node
 
+    def mutate_ParticleFor(self, ast_node):
+        return self.mutate_For(ast_node)
+
     def mutate_Malloc(self, ast_node):
         ast_node.array = self.mutate(ast_node.array)
         ast_node.size = self.mutate(ast_node.size)
diff --git a/ast/visitor.py b/ast/visitor.py
index 419b5e7df9e69eb7992eb49300bd7d925e38fd7b..da2ff9c04b4827a115162144df995956fb4916a1 100644
--- a/ast/visitor.py
+++ b/ast/visitor.py
@@ -1,49 +1,53 @@
-from ast.bin_op import BinOp
-from ast.loops import For, ParticleFor, While
+from collections import deque
 
 
 class Visitor:
-    def __init__(self, ast, max_depth=0):
+    def __init__(self, ast, max_depth=0, breadth_first=False):
         self.ast = ast
-        self.max_depth = 0
+        self.max_depth = max_depth
+        self.breadth_first = breadth_first
 
-    def visit(ast_node):
-        for c in ast_node.children():
-            if isinstance(c, Array):
-                self.visit_Array(c)
-            elif isinstance(c, BinOp):
-                self.visit_BinOp(c)
-            elif isinstance(c, (For, ParticleFor, While)):
-                self.visit_Loop(c)
-            else:
-                self.visit(c)
+    def get_method(self, method_name):
+        method = getattr(self, method_name, None)
+        return method if callable(method) else None
 
-    def visit_Array(self, ast_node):
-        return self.visit(ast_node)
+    def visit(self, ast_node=None):
+        if ast_node is None:
+            ast_node = self.ast
 
-    def visit_BinOp(self, ast_node):
-        return self.visit(ast_node)
+        method = self.get_method(f"visit_{type(ast_node).__name__}")
+        if method is not None:
+            method(ast_node)
+        else:
+            self.keep_visiting(ast_node)
 
-    def visit_Loop(self, ast_node):
-        return self.visit(ast_node)
+    def keep_visiting(self, ast_node):
+        for c in ast_node.children():
+            self.visit(c)
 
-    def list_ast(self):
-        self.list_elements(self.ast)
+    def yield_elements_breadth_first(self, ast_node=None):
+        nodes_to_visit = deque()
 
-    def list_elements(self, ast):
-        ast_list = [ast]
+        if ast_node is None:
+            ast_node = self.ast
 
-        for c in self.ast.children():
-            ast_list += self.list_elements(c)
+        nodes_to_visit.append(ast_node)
 
-        return ast_list
+        while nodes_to_visit:
+            next_node = nodes_to_visit.popleft() # nodes_to_visit.pop() for depth-first traversal
+            yield next_node
+            for c in next_node.children():
+                nodes_to_visit.append(c)
 
-    def yield_elements(ast, depth, max_depth):
-        yield ast
-        if depth < max_depth or max_depth == 0:
+    def yield_elements(self, ast, depth):
+        if depth < self.max_depth or self.max_depth == 0:
             for child in ast.children():
-                for child_node in Visitor.yield_elements(child, depth + 1, max_depth):
-                    yield child_node
+                yield child
+                yield from self.yield_elements(child, depth + 1)
 
     def __iter__(self):
-        yield from Visitor.yield_elements(self.ast, 0, self.max_depth)
+        if self.breadth_first:
+            yield from self.yield_elements_breadth_first(self.ast)
+        else:
+            yield self.ast
+            yield from self.yield_elements(self.ast, 1)
diff --git a/code_gen/cgen.py b/code_gen/cgen.py
index ea87d2a6cffda99eb69d73386ad1eadc7ee89a83..c28e353d69cf64fe502dd8000c14feac669f0dd0 100644
--- a/code_gen/cgen.py
+++ b/code_gen/cgen.py
@@ -183,7 +183,7 @@ class CGen:
                 return f"({lhs} {ast_node.operator()} {rhs})"
 
             # Some expressions can be defined on-the-fly during transformations, hence they do not have
-            # a definition statement in the tree, so we generate them right before usage
+            # a definition statement in the tree, so we generate them right before use
             if not ast_node.generated:
                 CGen.generate_statement(sim, ast_node.definition())
 
diff --git a/transformations/LICM.py b/transformations/LICM.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0edaf702142b0584b33f9800b32ff77a92340c3
--- /dev/null
+++ b/transformations/LICM.py
@@ -0,0 +1,117 @@
+from ast.mutator import Mutator
+from ast.visitor import Visitor
+
+
+class SetBlockVariants(Mutator):
+    def __init__(self, ast):
+        super().__init__(ast)
+        self.current_block = None
+        self.in_assignment = None
+
+    def mutate_Block(self, ast_node):
+        self.current_block = ast_node
+        ast_node.stmts = [self.mutate(s) for s in ast_node.stmts]
+        return ast_node
+
+    def mutate_Assign(self, ast_node):
+        self.in_assignment = ast_node
+        for dest in ast_node.destinations():
+            self.mutate(dest)
+        self.in_assignment = None
+        return ast_node
+
+    def mutate_Array(self, ast_node):
+        if self.in_assignment is not None:
+            self.in_assignment.parent_block.add_variant(ast_node)
+
+        return ast_node
+
+    def mutate_For(self, ast_node):
+        ast_node.block.add_variant(ast_node.iterator)
+        ast_node.iterator = self.mutate(ast_node.iterator)
+        ast_node.block = self.mutate(ast_node.block)
+        return ast_node
+
+    def mutate_Property(self, ast_node):
+        if self.in_assignment is not None:
+            self.in_assignment.parent_block.add_variant(ast_node)
+
+        return ast_node
+
+    def mutate_Variable(self, ast_node):
+        if self.in_assignment is not None:
+            self.in_assignment.parent_block.add_variant(ast_node)
+
+        return ast_node
+
+
+class SetParentBlock(Visitor):
+    def __init__(self, ast):
+        super().__init__(ast)
+        self.blocks = []
+
+    def current_block(self):
+        return self.blocks[-1]
+
+    def visit_Assign(self, ast_node):
+        ast_node.parent_block = self.current_block
+        self.keep_visiting(ast_node)
+
+    def visit_Block(self, ast_node):
+        ast_node.parent_block = self.current_block
+        self.blocks.append(ast_node)
+        self.keep_visiting(ast_node)
+        self.blocks.pop()
+
+    def visit_BinOpDef(self, ast_node):
+        ast_node.parent_block = self.current_block
+        self.keep_visiting(ast_node)
+
+    def visit_Branch(self, ast_node):
+        ast_node.parent_block = self.current_block
+        self.keep_visiting(ast_node)
+
+    def visit_For(self, ast_node):
+        ast_node.parent_block = self.current_block
+        self.keep_visiting(ast_node)
+
+    def visit_Malloc(self, ast_node):
+        ast_node.parent_block = self.current_block
+        self.keep_visiting(ast_node)
+
+    def visit_Realloc(self, ast_node):
+        ast_node.parent_block = self.current_block
+        self.keep_visiting(ast_node)
+
+    def visit_While(self, ast_node):
+        ast_node.parent_block = self.current_block
+        self.keep_visiting(ast_node)
+
+    def get_loop_parent_block(self, ast_node):
+        assert isinstance(ast_node, (For, While)), "Node must be a loop!"
+        loop_id = id(ast_node)
+        return self.parents[loop_id] if loop_id in self.parents else None
+
+
+class LICM(Mutator):
+    def __init__(self, ast, loop_parents):
+        super().__init__(ast)
+        self.loop_parents = loop_parents
+
+    def mutate_For(self, ast_node):
+        ast_node.iterator = self.mutate(ast_node.iterator)
+        ast_node.block = self.mutate(ast_node.block)
+        return ast_node
+
+    def mutate_Block(self, ast_node):
+        ast_node.stmts = [self.mutate(s) for s in ast_node.stmts]
+        return ast_node
+
+
+def move_loop_invariant_code(ast):
+    set_parent_block = SetParentBlock(ast)
+    set_parent_block.visit()
+    set_block_variants = SetBlockVariants(ast)
+    set_block_variants.mutate()
+    licm = LICM(ast, set_loop_parents)
+    licm.mutate()