diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py index fc8b3c08c6574fc22ccffb22f2b52adeb4199c66..ad50e3e7235c6184ebda49ffc68c776f755ee4f5 100644 --- a/src/pairs/code_gen/cgen.py +++ b/src/pairs/code_gen/cgen.py @@ -225,7 +225,7 @@ class CGen: array_access = ast_node.elem array_name = self.generate_expression(array_access.array) tkw = Types.c_keyword(array_access.type()) - acc_index = self.generate_expression(array_access.index) + acc_index = self.generate_expression(array_access.flat_index) acc_ref = f"a{array_access.id()}" self.print(f"const {tkw} {acc_ref} = {array_name}[{acc_index}];") @@ -486,7 +486,7 @@ class CGen: if isinstance(ast_node, ArrayAccess): array_name = self.generate_expression(ast_node.array) - acc_index = self.generate_expression(ast_node.index) + acc_index = self.generate_expression(ast_node.flat_index) if mem or ast_node.inlined is True: return f"{array_name}[{acc_index}]" diff --git a/src/pairs/ir/arrays.py b/src/pairs/ir/arrays.py index 7e144bdc7abc86bdb6d432a6939022c00b384e22..593651567b4e62d89349245528bc96dc7237dc7c 100644 --- a/src/pairs/ir/arrays.py +++ b/src/pairs/ir/arrays.py @@ -1,7 +1,7 @@ from functools import reduce from pairs.ir.assign import Assign from pairs.ir.ast_node import ASTNode -from pairs.ir.bin_op import BinOp, ASTTerm +from pairs.ir.bin_op import ASTTerm, BinOp from pairs.ir.layouts import Layouts from pairs.ir.lit import Lit from pairs.ir.memory import Realloc @@ -130,44 +130,47 @@ class ArrayAccess(ASTTerm): super().__init__(sim) self.acc_id = ArrayAccess.new_id() self.array = array - self.indexes = [Lit.cvt(sim, index)] - self.index = None + self.partial_indexes = [Lit.cvt(sim, index)] + self.flat_index = None self.inlined = False self.terminals = set() - self.check_and_set_index() + self.check_and_set_flat_index() def __str__(self): - return f"ArrayAccess<{self.array}, {self.indexes}>" + return f"ArrayAccess<{self.array}, {self.partial_indexes}>" def __getitem__(self, index): - assert self.index is None, "Number of indexes higher than array dimension!" - self.indexes.append(Lit.cvt(self.sim, index)) - self.check_and_set_index() + assert self.flat_index is None, "Number of partial indexes higher than array dimension!" + self.partial_indexes.append(Lit.cvt(self.sim, index)) + self.check_and_set_flat_index() return self def inline_rec(self): self.inlined = True return self - def check_and_set_index(self): - if len(self.indexes) == self.array.ndims(): + def check_and_set_flat_index(self): + if len(self.partial_indexes) == self.array.ndims(): sizes = self.array.sizes() layout = self.array.layout() if layout == Layouts.AoS: for s in range(0, len(sizes)): - self.index = (self.indexes[s] if self.index is None - else self.index * sizes[s] + self.indexes[s]) + self.flat_index = (self.partial_indexes[s] if self.flat_index is None + else self.flat_index * sizes[s] + self.partial_indexes[s]) elif layout == Layouts.SoA: for s in reversed(range(0, len(sizes))): - self.index = (self.indexes[s] if self.index is None - else self.index * sizes[s] + self.indexes[s]) + self.flat_index = (self.partial_indexes[s] if self.flat_index is None + else self.flat_index * sizes[s] + self.partial_indexes[s]) else: raise Exception("Invalid data layout!") - self.index = Lit.cvt(self.sim, self.index) + self.flat_index = Lit.cvt(self.sim, self.flat_index) + return True + + return False def set(self, other): return self.sim.add_statement(Assign(self.sim, self, other)) @@ -180,16 +183,15 @@ class ArrayAccess(ASTTerm): def type(self): return self.array.type() - # return self.array.type() if self.index is None else Types.Array def add_terminal(self, terminal): self.terminals.add(terminal) def children(self): - if self.index is not None: - return [self.array, self.index] + if self.flat_index is not None: + return [self.array, self.flat_index] - return [self.array] + self.indexes + return [self.array] + self.partial_indexes class ArrayDecl(ASTNode): diff --git a/src/pairs/ir/mutator.py b/src/pairs/ir/mutator.py index 887d29a3caceffd406695b56eda44bee0cbef8b4..6e86148840149218224037e4e94a249a48555bc8 100644 --- a/src/pairs/ir/mutator.py +++ b/src/pairs/ir/mutator.py @@ -26,10 +26,10 @@ class Mutator: def mutate_ArrayAccess(self, ast_node): ast_node.array = self.mutate(ast_node.array) - ast_node.indexes = [self.mutate(i) for i in ast_node.indexes] + ast_node.partial_indexes = [self.mutate(i) for i in ast_node.partial_indexes] - if ast_node.index is not None: - ast_node.index = self.mutate(ast_node.index) + if ast_node.flat_index is not None: + ast_node.flat_index = self.mutate(ast_node.flat_index) return ast_node diff --git a/src/pairs/transformations/expressions.py b/src/pairs/transformations/expressions.py index 2921ebd54dba2380679900bd440035a8dcc16ac5..37367637b96e06cee8787ef45a4a768312052fef 100644 --- a/src/pairs/transformations/expressions.py +++ b/src/pairs/transformations/expressions.py @@ -103,12 +103,13 @@ class AddExpressionDeclarations(Mutator): def mutate_ArrayAccess(self, ast_node): writing = self.writing ast_node.array = self.mutate(ast_node.array) + self.writing = False - ast_node.indexes = [self.mutate(i) for i in ast_node.indexes] - if ast_node.index is not None: - ast_node.index = self.mutate(ast_node.index) - self.writing = writing + ast_node.partial_indexes = [self.mutate(i) for i in ast_node.partial_indexes] + if ast_node.flat_index is not None: + ast_node.flat_index = self.mutate(ast_node.flat_index) + self.writing = writing if self.writing is False and ast_node.inlined is False: array_access_id = id(ast_node) if array_access_id not in self.declared_exprs and array_access_id not in self.params: