diff --git a/ast/arrays.py b/ast/arrays.py
index 308bd12bf7bfedbec1f42d6a3ef7ae1cdb96ee00..1ea90143c4a2cd15fe0cbc4ac5b8d3c345bb0da7 100644
--- a/ast/arrays.py
+++ b/ast/arrays.py
@@ -175,15 +175,6 @@ class ArrayAccess(ASTTerm):
     def children(self):
         return [self.array] + self.indexes
 
-    def transform(self, fn):
-        self.array = self.array.transform(fn)
-        self.indexes = [i.transform(fn) for i in self.indexes]
-
-        if self.index is not None:
-            self.index = self.index.transform(fn)
-
-        return fn(self)
-
 
 class ArrayDecl(ASTNode):
     def __init__(self, sim, array):
diff --git a/ast/assign.py b/ast/assign.py
index 8ce1d7b8ac8e3027dc3f2b2e4a78cefb210e04b6..a6cc919c8e528b16f6160549c2461ef9c05fa08d 100644
--- a/ast/assign.py
+++ b/ast/assign.py
@@ -31,11 +31,3 @@ class Assign(ASTNode):
         return reduce((lambda x, y: x + y), [
                       [self.assignments[i][0], self.assignments[i][1]]
                       for i in range(0, len(self.assignments))])
-
-    def transform(self, fn):
-        self.assignments = [(
-            self.assignments[i][0].transform(fn),
-            self.assignments[i][1].transform(fn))
-            for i in range(0, len(self.assignments))]
-
-        return fn(self)
diff --git a/ast/ast_node.py b/ast/ast_node.py
index 7e426762ba235fba865137a30c45c939eeb88638..e58afc26693907e38fd322ff5969d7a37271e675 100644
--- a/ast/ast_node.py
+++ b/ast/ast_node.py
@@ -16,6 +16,3 @@ class ASTNode:
 
     def children(self):
         return []
-
-    def transform(self, fn):
-        return fn(self)
diff --git a/ast/bin_op.py b/ast/bin_op.py
index 840ca43fa81fd8bcf28acc90f7d43fc1dfd01c7d..ba826203772cb6bae73e7396b097c58d0cce12f4 100644
--- a/ast/bin_op.py
+++ b/ast/bin_op.py
@@ -17,10 +17,6 @@ class BinOpDef(ASTNode):
     def children(self):
         return [self.bin_op]
 
-    def transform(self, fn):
-        self.bin_op = self.bin_op.transform(fn)
-        return fn(self)
-
 
 class BinOp(ASTNode):
     # BinOp kinds
@@ -177,12 +173,6 @@ class BinOp(ASTNode):
     def children(self):
         return [self.lhs, self.rhs]
 
-    def transform(self, fn):
-        self.lhs = self.lhs.transform(fn)
-        self.rhs = self.rhs.transform(fn)
-        self.bin_op_vector_index_mapping = {i: e.transform(fn) for i, e in self.bin_op_vector_index_mapping.items()}
-        return fn(self)
-
     def __add__(self, other):
         return BinOp(self.sim, self, other, '+')
 
diff --git a/ast/block.py b/ast/block.py
index b7f287e1b055dc1d49160858c49e631c550856ce..e85facec636e02ee94235d762cef1f85f355db07 100644
--- a/ast/block.py
+++ b/ast/block.py
@@ -1,5 +1,4 @@
 from ast.ast_node import ASTNode
-from ast.visitor import Visitor
 
 
 class Block(ASTNode):
@@ -51,12 +50,6 @@ class Block(ASTNode):
     def children(self):
         return self.stmts
 
-    def transform(self, fn):
-        for i in range(0, len(self.stmts)):
-            self.stmts[i] = self.stmts[i].transform(fn)
-
-        return fn(self)
-
     def merge_blocks(block1, block2):
         assert isinstance(block1, Block), "First block type is not Block!"
         assert isinstance(block2, Block), "Second block type is not Block!"
diff --git a/ast/branches.py b/ast/branches.py
index ab31b0b17dbe0ab049c79d63d26026d0fa6f32ff..a6302609d42b7ed6b25513a2ce66bdbc5717e027 100644
--- a/ast/branches.py
+++ b/ast/branches.py
@@ -37,14 +37,6 @@ class Branch(ASTNode):
         return [self.cond, self.block_if] + \
                ([] if self.block_else is None else [self.block_else])
 
-    def transform(self, fn):
-        self.cond = self.cond.transform(fn)
-        self.block_if = self.block_if.transform(fn)
-        self.block_else = \
-            None if self.block_else is None \
-            else self.block_else.transform(fn)
-        return fn(self)
-
 
 class Filter(Branch):
     def __init__(self, sim, cond):
diff --git a/ast/cast.py b/ast/cast.py
index 0a30700b05311ed3bdc6e3d116bb0c89a9764391..39a8538ae44163d61c110852dc4647bb7dd81ede 100644
--- a/ast/cast.py
+++ b/ast/cast.py
@@ -25,7 +25,3 @@ class Cast(ASTNode):
 
     def children(self):
         return [self.expr]
-
-    def transform(self, fn):
-        self.expr = self.expr.transform(fn)
-        return fn(self)
diff --git a/ast/loops.py b/ast/loops.py
index 902cec345ad65dc2968c8b9d7861e90f46a5c1cb..e0c645cfe256a281b509a14f1d4df8487da911b1 100644
--- a/ast/loops.py
+++ b/ast/loops.py
@@ -67,11 +67,6 @@ class For(ASTNode):
     def children(self):
         return [self.iterator, self.block]
 
-    def transform(self, fn):
-        self.iterator = self.iterator.transform(fn)
-        self.block = self.block.transform(fn)
-        return fn(self)
-
 
 class ParticleFor(For):
     def __init__(self, sim, block=None, local_only=True):
@@ -104,11 +99,6 @@ class While(ASTNode):
     def children(self):
         return [self.cond, self.block]
 
-    def transform(self, fn):
-        self.cond = self.cond.transform(fn)
-        self.block = self.block.transform(fn)
-        return fn(self)
-
 
 class NeighborFor():
     def __init__(self, sim, particle, cell_lists):
diff --git a/ast/math.py b/ast/math.py
index 90977b8af5d53e8716245ad867b8787bab8dec66..4e610cb378277962039096f342acc553ece932d0 100644
--- a/ast/math.py
+++ b/ast/math.py
@@ -18,7 +18,3 @@ class Sqrt(ASTNode):
 
     def children(self):
         return [self.expr]
-
-    def transform(self, fn):
-        self.expr = self.expr.transform(fn)
-        return fn(self)
diff --git a/ast/memory.py b/ast/memory.py
index ef8abfed3c2f7a736861d04c4e188e457e56af16..3458bf9fa464a1a4cde36b555b85324fff40b4b1 100644
--- a/ast/memory.py
+++ b/ast/memory.py
@@ -18,11 +18,6 @@ class Malloc(ASTNode):
     def children(self):
         return [self.array, self.size]
 
-    def transform(self, fn):
-        self.array = self.array.transform(fn)
-        self.size = self.size.transform(fn)
-        return fn(self)
-
 
 class Realloc(ASTNode):
     def __init__(self, sim, array, size):
@@ -35,8 +30,3 @@ class Realloc(ASTNode):
 
     def children(self):
         return [self.array, self.size]
-
-    def transform(self, fn):
-        self.array = self.array.transform(fn)
-        self.size = self.size.transform(fn)
-        return fn(self)
diff --git a/ast/mutator.py b/ast/mutator.py
new file mode 100644
index 0000000000000000000000000000000000000000..f450773eb5081d27f4fa8f5f1df8546024802323
--- /dev/null
+++ b/ast/mutator.py
@@ -0,0 +1,139 @@
+from ast.arrays import ArrayAccess
+from ast.assign import Assign
+from ast.bin_op import BinOp, BinOpDef
+from ast.block import Block
+from ast.branches import Branch
+from ast.cast import Cast
+from ast.loops import For, While
+from ast.math import Sqrt
+from ast.memory import Malloc, Realloc
+from ast.select import Select
+from sim.timestep import Timestep
+
+
+class Mutator:
+    def __init__(self, ast, max_depth=0):
+        self.ast = ast
+        self.max_depth = 0
+
+    def mutate(self, ast_node=None):
+        if ast_node is None:
+            ast_node = self.ast
+
+        if isinstance(ast_node, ArrayAccess):
+            return self.mutate_ArrayAccess(ast_node)
+
+        elif isinstance(ast_node, Assign):
+            return self.mutate_Assign(ast_node)
+
+        elif isinstance(ast_node, BinOp):
+            return self.mutate_BinOp(ast_node)
+
+        elif isinstance(ast_node, BinOpDef):
+            return self.mutate_BinOpDef(ast_node)
+
+        elif isinstance(ast_node, Block):
+            return self.mutate_Block(ast_node)
+
+        elif isinstance(ast_node, Branch):
+            return self.mutate_Branch(ast_node)
+
+        elif isinstance(ast_node, Cast):
+            return self.mutate_Cast(ast_node)
+
+        elif isinstance(ast_node, For):
+            return self.mutate_For(ast_node)
+
+        elif isinstance(ast_node, Malloc):
+            return self.mutate_Malloc(ast_node)
+
+        elif isinstance(ast_node, Realloc):
+            return self.mutate_Realloc(ast_node)
+
+        elif isinstance(ast_node, Select):
+            return self.mutate_Select(ast_node)
+
+        elif isinstance(ast_node, Sqrt):
+            return self.mutate_Sqrt(ast_node)
+
+        elif isinstance(ast_node, Timestep):
+            return self.mutate_Timestep(ast_node)
+
+        elif isinstance(ast_node, While):
+            return self.mutate_While(ast_node)
+
+        return ast_node
+
+    def mutate_ArrayAccess(self, ast_node):
+        ast_node.array = self.mutate(ast_node.array)
+        ast_node.indexes = [self.mutate(i) for i in ast_node.indexes]
+
+        if ast_node.index is not None:
+            ast_node.index = self.mutate(ast_node.index)
+
+        return ast_node 
+
+    def mutate_Assign(self, ast_node):
+        ast_node.assignments = [
+            (self.mutate(ast_node.assignments[i][0]), self.mutate(ast_node.assignments[i][1]))
+            for i in range(0, len(ast_node.assignments))
+        ]
+        return ast_node
+
+    def mutate_BinOp(self, ast_node):
+        ast_node.lhs = self.mutate(ast_node.lhs)
+        ast_node.rhs = self.mutate(ast_node.rhs)
+        ast_node.bin_op_vector_index_mapping = {i: self.mutate(e) for i, e in ast_node.bin_op_vector_index_mapping.items()}
+        return ast_node
+
+    def mutate_BinOpDef(self, ast_node):
+        ast_node.bin_op = self.mutate(ast_node.bin_op)
+        return ast_node
+
+    def mutate_Block(self, ast_node):
+        ast_node.stmts = [self.mutate(s) for s in ast_node.stmts]
+        return ast_node
+
+    def mutate_Branch(self, ast_node):
+        ast_node.cond = self.mutate(ast_node.cond)
+        ast_node.block_if = self.mutate(ast_node.block_if)
+        ast_node.block_else = None if ast_node.block_else is None else self.mutate(ast_node.block_else)
+        return ast_node
+
+    def mutate_Cast(self, ast_node):
+        ast_node.expr = self.mutate(ast_node.expr)
+        return ast_node
+
+    def mutate_For(self, ast_node):
+        ast_node.iterator = self.mutate(ast_node.iterator)
+        ast_node.block = self.mutate(ast_node.block)
+        return ast_node
+
+    def mutate_Malloc(self, ast_node):
+        ast_node.array = self.mutate(ast_node.array)
+        ast_node.size = self.mutate(ast_node.size)
+        return ast_node
+
+    def mutate_Realloc(self, ast_node):
+        ast_node.array = self.mutate(ast_node.array)
+        ast_node.size = self.mutate(ast_node.size)
+        return ast_node
+
+    def mutate_Select(self, ast_node):
+        ast_node.cond = self.mutate(ast_node.cond)
+        ast_node.expr_if = self.mutate(ast_node.expr_if)
+        ast_node.expr_else = self.mutate(ast_node.expr_else)
+        return ast_node
+
+    def mutate_Sqrt(self, ast_node):
+        ast_node.expr = self.mutate(ast_node.expr)
+        return ast_node
+
+    def mutate_Timestep(self, ast_node):
+        ast_node.block = self.mutate(ast_node.block)
+        return ast_node
+
+    def mutate_While(self, ast_node):
+        ast_node.cond = self.mutate(ast_node.cond)
+        ast_node.block = self.mutate(ast_node.block)
+        return ast_node
diff --git a/ast/select.py b/ast/select.py
index 702bb8fb4a16e54fed93e717728e3439f35bdd21..1fe9d088b1bbb504279966769a7302897e701e6c 100644
--- a/ast/select.py
+++ b/ast/select.py
@@ -12,10 +12,3 @@ class Select(ASTNode):
 
     def children(self):
         return [self.cond, self.expr_if, self.expr_else]
-
-    def transform(self, fn):
-        self.cond = self.cond.transform(fn)
-        self.expr_if = self.expr_if.transform(fn)
-        self.expr_else = self.expr_else.transform(fn)
-        return fn(self)
-
diff --git a/ast/visitor.py b/ast/visitor.py
index 8cf72e97d3a6ce6952a83f9e50767f7fa272f276..419b5e7df9e69eb7992eb49300bd7d925e38fd7b 100644
--- a/ast/visitor.py
+++ b/ast/visitor.py
@@ -1,22 +1,31 @@
-class Visitor:
-    def __init__(self, ast, enter_fn=None, leave_fn=None, max_depth=0):
-        self.ast = ast
-        self.enter_fn = enter_fn
-        self.leave_fn = leave_fn
-        self.max_depth = max_depth
-
-    def visit(self):
-        self.visit_rec(self.ast)
+from ast.bin_op import BinOp
+from ast.loops import For, ParticleFor, While
 
-    def visit_rec(self, ast):
-        if self.enter_fn is not None:
-            self.enter_fn(ast)
 
-        for c in ast.children():
-            self.visit_rec(c)
-
-        if self.leave_fn is not None:
-            self.leave_fn(ast)
+class Visitor:
+    def __init__(self, ast, max_depth=0):
+        self.ast = ast
+        self.max_depth = 0
+
+    def visit(ast_node):
+        for c in ast_node.children():
+            if isinstance(c, Array):
+                self.visit_Array(c)
+            elif isinstance(c, BinOp):
+                self.visit_BinOp(c)
+            elif isinstance(c, (For, ParticleFor, While)):
+                self.visit_Loop(c)
+            else:
+                self.visit(c)
+
+    def visit_Array(self, ast_node):
+        return self.visit(ast_node)
+
+    def visit_BinOp(self, ast_node):
+        return self.visit(ast_node)
+
+    def visit_Loop(self, ast_node):
+        return self.visit(ast_node)
 
     def list_ast(self):
         self.list_elements(self.ast)
diff --git a/sim/particle_simulation.py b/sim/particle_simulation.py
index 12104a3acdecbe431db74f844f56e7be7e9c1dd6..53f8ef1716087fa78291eaabbb7702b4ac629ecb 100644
--- a/sim/particle_simulation.py
+++ b/sim/particle_simulation.py
@@ -5,10 +5,9 @@ from ast.data_types import Type_Int, Type_Float, Type_Vector
 from ast.layouts import Layout_AoS
 from ast.loops import ParticleFor, NeighborFor
 from ast.properties import Properties
-from ast.transform import Transform
 from ast.variables import Variables
-from sim.arrays import ArraysDecl
 from graph.graphviz import ASTGraph
+from sim.arrays import ArraysDecl
 from sim.cell_lists import CellLists, CellListsBuild, CellListsStencilBuild
 from sim.grid import Grid2D, Grid3D
 from sim.kernel_wrapper import KernelWrapper
@@ -20,6 +19,8 @@ from sim.setup_wrapper import SetupWrapper
 from sim.timestep import Timestep
 from sim.variables import VariablesDecl
 from sim.vtk import VTKWrite
+from transformations.flatten import flatten_property_accesses
+from transformations.simplify import simplify_expressions
 
 
 class ParticleSimulation:
@@ -193,12 +194,10 @@ class ParticleSimulation:
 
         program = Block.merge_blocks(decls, body)
         self.global_scope = program
-        Transform.apply(program, Transform.flatten)
-        Transform.apply(program, Transform.simplify)
-        #Transform.apply(program, Transform.reuse_index_expressions)
-        #Transform.apply(program, Transform.reuse_expr_expressions)
-        #Transform.apply(program, Transform.reuse_array_access_expressions)
-        #Transform.apply(program, Transform.move_loop_invariant_expressions)
+
+        # Transformations
+        flatten_property_accesses(program)
+        simplify_expressions(program)
 
         ASTGraph(self.kernels.lower(), "kernels").generate_and_view()
         self.code_gen.generate_program(self, program)
diff --git a/sim/timestep.py b/sim/timestep.py
index 0d541330bce00d894ae82442f6c19f4c3649ae77..74aa2419a8bd7d85338399807f6644ee3d593129 100644
--- a/sim/timestep.py
+++ b/sim/timestep.py
@@ -41,6 +41,3 @@ class Timestep:
 
     def as_block(self):
         return Block(self.sim, [self.timestep_loop])
-
-    def transform(self, fn):
-        self.block = self.block.transform(fn)
diff --git a/sim/vtk.py b/sim/vtk.py
index c1620f85e0bc79238e28ca8b849741ea73540399..284533abf61ef096865ab3161061a77809cd08bf 100644
--- a/sim/vtk.py
+++ b/sim/vtk.py
@@ -1,17 +1,13 @@
 from ast.lit import as_lit_ast
+from ast.ast_node import ASTNode
 
-class VTKWrite:
+
+class VTKWrite(ASTNode):
     vtk_id = 0
 
     def __init__(self, sim, filename, timestep):
-        self.sim = sim
+        super().__init__(sim)
         self.vtk_id = VTKWrite.vtk_id
         self.filename = filename
         self.timestep = as_lit_ast(sim, timestep)
         VTKWrite.vtk_id += 1
-
-    def children(self):
-        return []
-
-    def transform(self, fn):
-        return fn(self)
diff --git a/transformations/flatten.py b/transformations/flatten.py
new file mode 100644
index 0000000000000000000000000000000000000000..b06ebcb07d6a54c5b4367abdb2bcc4d3093f0e7e
--- /dev/null
+++ b/transformations/flatten.py
@@ -0,0 +1,35 @@
+from ast.layouts import Layout_AoS, Layout_SoA
+from ast.mutator import Mutator
+
+
+class FlattenPropertyAccesses(Mutator):
+    def __init__(self, ast):
+        super().__init__(ast)
+
+    def mutate_BinOp(self, ast_node):
+        ast_node.lhs = self.mutate(ast_node.lhs)
+        ast_node.rhs = self.mutate(ast_node.rhs)
+
+        if ast_node.is_vector_property_access():
+            layout = ast_node.lhs.layout()
+
+            for i in ast_node.vector_indexes():
+                flat_index = None
+
+                if layout == Layout_AoS:
+                    flat_index = ast_node.rhs * ast_node.sim.dimensions + i
+
+                elif layout == Layout_SoA:
+                    flat_index = i * ast_node.sim.particle_capacity + ast_node.rhs
+
+                else:
+                    raise Exception("Invalid property layout!")
+
+                ast_node.map_vector_index(i, flat_index)
+
+        return ast_node
+
+
+def flatten_property_accesses(ast_node):
+    flatten = FlattenPropertyAccesses(ast_node)
+    flatten.mutate()
diff --git a/transformations/simplify.py b/transformations/simplify.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff673159bcb09af7a372f959c97bbe852e00c835
--- /dev/null
+++ b/transformations/simplify.py
@@ -0,0 +1,36 @@
+from ast.data_types import Type_Int
+from ast.lit import Lit
+from ast.mutator import Mutator
+
+
+class SimplifyExpressions(Mutator):
+    def __init__(self, ast):
+        super().__init__(ast)
+
+    def mutate_BinOp(self, ast_node):
+        sim = ast_node.lhs.sim
+        ast_node.lhs = self.mutate(ast_node.lhs)
+        ast_node.rhs = self.mutate(ast_node.rhs)
+        ast_node.bin_op_vector_index_mapping = {i: self.mutate(e) for i, e in ast_node.bin_op_vector_index_mapping.items()}
+
+        if ast_node.op in ['+', '-'] and ast_node.rhs == 0:
+            return ast_node.lhs
+
+        if ast_node.op in ['+'] and ast_node.lhs == 0:
+            return ast_node.rhs
+
+        if ast_node.op in ['*', '/'] and ast_node.rhs == 1:
+            return ast_node.lhs
+
+        if ast_node.op == '*' and ast_node.lhs == 1:
+            return ast_node.rhs
+
+        if ast_node.op == '*' and ast_node.lhs == 0:
+            return Lit(sim, 0 if ast_node.type() == Type_Int else 0.0)
+
+        return ast_node
+
+
+def simplify_expressions(ast_node):
+    simplify = SimplifyExpressions(ast_node)
+    simplify.mutate()