diff --git a/src/pairs/transformations/modules.py b/src/pairs/transformations/modules.py index f1fe0cb9ae5b0bfe03d7e980eda5cc384ecddf88..c463beadab29a0b0b6b32d08b535546fc005f49e 100644 --- a/src/pairs/transformations/modules.py +++ b/src/pairs/transformations/modules.py @@ -4,6 +4,7 @@ from pairs.ir.bin_op import BinOp from pairs.ir.block import Block from pairs.ir.branches import Branch, Filter from pairs.ir.data_types import Type_Vector +from pairs.ir.lit import Lit from pairs.ir.loops import While from pairs.ir.memory import Realloc from pairs.ir.module import Module, Module_Call @@ -89,6 +90,9 @@ class AddResizeLogic(Mutator): for node in nodes: if isinstance(node, (Array, Var)): capacity = self.get_capacity_for_size(node) + elif isinstance(node, ArrayAccess): + # We just want to look into mutable elements, not indexes + capacity = self.lookup_capacity([node.array]) else: capacity = self.lookup_capacity(node.children()) @@ -99,8 +103,11 @@ class AddResizeLogic(Mutator): def mutate_Assign(self, ast_node): for dest, src in ast_node.assignments: - if isinstance(dest, ArrayAccess): - match_capacity = self.lookup_capacity(ast_node.children()) + if not isinstance(src, Lit): + match_capacity = None + + if isinstance(dest, (ArrayAccess, Var)): + match_capacity = self.lookup_capacity([dest]) # Resize var is used in index, this statement should be checked for safety if match_capacity is not None: @@ -108,7 +115,7 @@ class AddResizeLogic(Mutator): resizes = list(self.module_resizes[module].keys()) capacities = list(self.module_resizes[module].values()) resize_id = resizes[capacities.index(match_capacity)] - return Branch(ast_node.sim, dest < match_capacity, + return Branch(ast_node.sim, src < match_capacity, blk_if=Block(ast_node.sim, ast_node), blk_else=Block(ast_node.sim, ast_node.sim.resizes[resize_id].set(src)))