diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py
index fc8b3c08c6574fc22ccffb22f2b52adeb4199c66..ad50e3e7235c6184ebda49ffc68c776f755ee4f5 100644
--- a/src/pairs/code_gen/cgen.py
+++ b/src/pairs/code_gen/cgen.py
@@ -225,7 +225,7 @@ class CGen:
                 array_access = ast_node.elem
                 array_name = self.generate_expression(array_access.array)
                 tkw = Types.c_keyword(array_access.type())
-                acc_index = self.generate_expression(array_access.index)
+                acc_index = self.generate_expression(array_access.flat_index)
                 acc_ref = f"a{array_access.id()}"
                 self.print(f"const {tkw} {acc_ref} = {array_name}[{acc_index}];")
 
@@ -486,7 +486,7 @@ class CGen:
 
         if isinstance(ast_node, ArrayAccess):
             array_name = self.generate_expression(ast_node.array)
-            acc_index = self.generate_expression(ast_node.index)
+            acc_index = self.generate_expression(ast_node.flat_index)
             if mem or ast_node.inlined is True:
                 return f"{array_name}[{acc_index}]"
 
diff --git a/src/pairs/ir/arrays.py b/src/pairs/ir/arrays.py
index 7e144bdc7abc86bdb6d432a6939022c00b384e22..593651567b4e62d89349245528bc96dc7237dc7c 100644
--- a/src/pairs/ir/arrays.py
+++ b/src/pairs/ir/arrays.py
@@ -1,7 +1,7 @@
 from functools import reduce
 from pairs.ir.assign import Assign
 from pairs.ir.ast_node import ASTNode
-from pairs.ir.bin_op import BinOp, ASTTerm
+from pairs.ir.bin_op import ASTTerm, BinOp
 from pairs.ir.layouts import Layouts
 from pairs.ir.lit import Lit
 from pairs.ir.memory import Realloc
@@ -130,44 +130,47 @@ class ArrayAccess(ASTTerm):
         super().__init__(sim)
         self.acc_id = ArrayAccess.new_id()
         self.array = array
-        self.indexes = [Lit.cvt(sim, index)]
-        self.index = None
+        self.partial_indexes = [Lit.cvt(sim, index)]
+        self.flat_index = None
         self.inlined = False
         self.terminals = set()
-        self.check_and_set_index()
+        self.check_and_set_flat_index()
 
     def __str__(self):
-        return f"ArrayAccess<{self.array}, {self.indexes}>"
+        return f"ArrayAccess<{self.array}, {self.partial_indexes}>"
 
     def __getitem__(self, index):
-        assert self.index is None, "Number of indexes higher than array dimension!"
-        self.indexes.append(Lit.cvt(self.sim, index))
-        self.check_and_set_index()
+        assert self.flat_index is None, "Number of partial indexes higher than array dimension!"
+        self.partial_indexes.append(Lit.cvt(self.sim, index))
+        self.check_and_set_flat_index()
         return self
 
     def inline_rec(self):
         self.inlined = True
         return self
 
-    def check_and_set_index(self):
-        if len(self.indexes) == self.array.ndims():
+    def check_and_set_flat_index(self):
+        if len(self.partial_indexes) == self.array.ndims():
             sizes = self.array.sizes()
             layout = self.array.layout()
 
             if layout == Layouts.AoS:
                 for s in range(0, len(sizes)):
-                    self.index = (self.indexes[s] if self.index is None
-                                  else self.index * sizes[s] + self.indexes[s])
+                    self.flat_index = (self.partial_indexes[s] if self.flat_index is None
+                                       else self.flat_index * sizes[s] + self.partial_indexes[s])
 
             elif layout == Layouts.SoA:
                 for s in reversed(range(0, len(sizes))):
-                    self.index = (self.indexes[s] if self.index is None
-                                  else self.index * sizes[s] + self.indexes[s])
+                    self.flat_index = (self.partial_indexes[s] if self.flat_index is None
+                                       else self.flat_index * sizes[s] + self.partial_indexes[s])
 
             else:
                 raise Exception("Invalid data layout!")
 
-            self.index = Lit.cvt(self.sim, self.index)
+            self.flat_index = Lit.cvt(self.sim, self.flat_index)
+            return True
+
+        return False
 
     def set(self, other):
         return self.sim.add_statement(Assign(self.sim, self, other))
@@ -180,16 +183,15 @@ class ArrayAccess(ASTTerm):
 
     def type(self):
         return self.array.type()
-        # return self.array.type() if self.index is None else Types.Array
 
     def add_terminal(self, terminal):
         self.terminals.add(terminal)
 
     def children(self):
-        if self.index is not None:
-            return [self.array, self.index]
+        if self.flat_index is not None:
+            return [self.array, self.flat_index]
 
-        return [self.array] + self.indexes
+        return [self.array] + self.partial_indexes
 
 
 class ArrayDecl(ASTNode):
diff --git a/src/pairs/ir/mutator.py b/src/pairs/ir/mutator.py
index 887d29a3caceffd406695b56eda44bee0cbef8b4..6e86148840149218224037e4e94a249a48555bc8 100644
--- a/src/pairs/ir/mutator.py
+++ b/src/pairs/ir/mutator.py
@@ -26,10 +26,10 @@ class Mutator:
 
     def mutate_ArrayAccess(self, ast_node):
         ast_node.array = self.mutate(ast_node.array)
-        ast_node.indexes = [self.mutate(i) for i in ast_node.indexes]
+        ast_node.partial_indexes = [self.mutate(i) for i in ast_node.partial_indexes]
 
-        if ast_node.index is not None:
-            ast_node.index = self.mutate(ast_node.index)
+        if ast_node.flat_index is not None:
+            ast_node.flat_index = self.mutate(ast_node.flat_index)
 
         return ast_node 
 
diff --git a/src/pairs/transformations/expressions.py b/src/pairs/transformations/expressions.py
index 2921ebd54dba2380679900bd440035a8dcc16ac5..37367637b96e06cee8787ef45a4a768312052fef 100644
--- a/src/pairs/transformations/expressions.py
+++ b/src/pairs/transformations/expressions.py
@@ -103,12 +103,13 @@ class AddExpressionDeclarations(Mutator):
     def mutate_ArrayAccess(self, ast_node):
         writing = self.writing
         ast_node.array = self.mutate(ast_node.array)
+
         self.writing = False
-        ast_node.indexes = [self.mutate(i) for i in ast_node.indexes]
-        if ast_node.index is not None:
-            ast_node.index = self.mutate(ast_node.index)
-        self.writing = writing
+        ast_node.partial_indexes = [self.mutate(i) for i in ast_node.partial_indexes]
+        if ast_node.flat_index is not None:
+            ast_node.flat_index = self.mutate(ast_node.flat_index)
 
+        self.writing = writing
         if self.writing is False and ast_node.inlined is False:
             array_access_id = id(ast_node)
             if array_access_id not in self.declared_exprs and array_access_id not in self.params: