From cdfbfb11096e3784e598b4d20802adcc81189d4d Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti <rafaelravedutti@gmail.com>
Date: Mon, 14 Feb 2022 23:40:20 +0100
Subject: [PATCH] First fixes to the introduction of resize logic

Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com>
---
 src/pairs/sim/cell_lists.py          | 19 ++++----
 src/pairs/transformations/modules.py | 65 ++++++++++------------------
 2 files changed, 34 insertions(+), 50 deletions(-)

diff --git a/src/pairs/sim/cell_lists.py b/src/pairs/sim/cell_lists.py
index 82a9f7e..0ba7481 100644
--- a/src/pairs/sim/cell_lists.py
+++ b/src/pairs/sim/cell_lists.py
@@ -19,16 +19,17 @@ class CellLists:
         self.spacing = spacing if isinstance(spacing, list) else [spacing for d in range(sim.ndims())]
         self.cutoff_radius = cutoff_radius
         self.nneighbor_cells = [math.ceil(cutoff_radius / self.spacing[d]) for d in range(sim.ndims())]
-        self.nstencil = self.sim.add_var('nstencil', Type_Int)
         self.nstencil_max = reduce((lambda x, y: x * y), [self.nneighbor_cells[d] * 2 + 1 for d in range(sim.ndims())])
-        self.ncells = self.sim.add_var('ncells', Type_Int, 1)
-        self.ncells_capacity = self.sim.add_var('ncells_capacity', Type_Int, 100)
-        self.dim_ncells = self.sim.add_static_array('dim_cells', self.sim.ndims(), Type_Int)
-        self.cell_capacity = self.sim.add_var('cell_capacity', Type_Int, 20)
-        self.cell_particles = self.sim.add_array('cell_particles', [self.ncells_capacity, self.cell_capacity], Type_Int)
-        self.cell_sizes = self.sim.add_array('cell_sizes', self.ncells_capacity, Type_Int)
-        self.stencil = self.sim.add_array('stencil', self.nstencil_max, Type_Int)
-        self.particle_cell = self.sim.add_array('particle_cell', self.sim.particle_capacity, Type_Int)
+        # Data introduced in the simulation
+        self.nstencil           =   self.sim.add_var('nstencil', Type_Int)
+        self.ncells             =   self.sim.add_var('ncells', Type_Int, 1)
+        self.ncells_capacity    =   self.sim.add_var('ncells_capacity', Type_Int, 100)
+        self.cell_capacity      =   self.sim.add_var('cell_capacity', Type_Int, 20)
+        self.dim_ncells         =   self.sim.add_static_array('dim_cells', self.sim.ndims(), Type_Int)
+        self.cell_particles     =   self.sim.add_array('cell_particles', [self.ncells_capacity, self.cell_capacity], Type_Int)
+        self.cell_sizes         =   self.sim.add_array('cell_sizes', self.ncells_capacity, Type_Int)
+        self.stencil            =   self.sim.add_array('stencil', self.nstencil_max, Type_Int)
+        self.particle_cell      =   self.sim.add_array('particle_cell', self.sim.particle_capacity, Type_Int)
 
 
 class CellListsStencilBuild(Lowerable):
diff --git a/src/pairs/transformations/modules.py b/src/pairs/transformations/modules.py
index b462519..f1fe0cb 100644
--- a/src/pairs/transformations/modules.py
+++ b/src/pairs/transformations/modules.py
@@ -1,14 +1,15 @@
+from pairs.ir.arrays import Array, ArrayAccess
 from pairs.ir.assign import Assign
 from pairs.ir.bin_op import BinOp
 from pairs.ir.block import Block
-from pairs.ir.branches import Filter
+from pairs.ir.branches import Branch, Filter
 from pairs.ir.data_types import Type_Vector
 from pairs.ir.loops import While
 from pairs.ir.memory import Realloc
 from pairs.ir.module import Module, Module_Call
 from pairs.ir.mutator import Mutator
 from pairs.ir.properties import UpdateProperty
-from pairs.ir.variables import Deref
+from pairs.ir.variables import Var, Deref
 from pairs.ir.visitor import Visitor
 from functools import reduce
 import operator
@@ -74,8 +75,6 @@ class AddResizeLogic(Mutator):
         self.module_resizes = {}
         self.resizes_to_check = {}
         self.check_properties_resize = False
-        self.match_capacity = None
-        self.update = {}
         self.nresize_buffers = 0
 
     def get_capacity_for_size(self, size):
@@ -85,42 +84,33 @@ class AddResizeLogic(Mutator):
 
         return None
 
-    def look_for_match_capacity(self, size):
-        capacity = self.get_capacity_for_size(size)
-        if capacity is not None:
-            self.match_capacity = capacity
+    def lookup_capacity(self, nodes):
+        capacity = None
+        for node in nodes:
+            if isinstance(node, (Array, Var)):
+                capacity = self.get_capacity_for_size(node)
+            else:
+                capacity = self.lookup_capacity(node.children())
 
-    def mutate_Array(self, ast_node):
-        self.look_for_match_capacity(ast_node)
-        return ast_node
+            if capacity is not None:
+                return capacity
+
+        return None
 
-    def mutate_Assignment(self, ast_node):
-        for dest, src in ast_node.assignments.items():
+    def mutate_Assign(self, ast_node):
+        for dest, src in ast_node.assignments:
             if isinstance(dest, ArrayAccess):
-                self.match_capacity = None
-                ast_node.indexes = [self.mutate(i) for i in ast_node.indexes]
-                if ast_node.index is not None:
-                    ast_node.index = self.mutate(ast_node.index)
+                match_capacity = self.lookup_capacity(ast_node.children())
 
                 # Resize var is used in index, this statement should be checked for safety
-                if self.match_capacity is not None:
+                if match_capacity is not None:
                     module = self.module_stack[-1]
-                    size = self.resizes_to_check[match_capacity]
-                    check_value = self.update[size] if size in self.update else size
-                    resize_id = self.module_resizes[module].keys()[self.module_resizes[module].values().index(match_capacity)]
-                    return Branch(ast_node.sim, check_value < match_capacity,
-                                  Block(ast_node.sim, ast_node),
-                                  Block(ast_node.sim, sim.resizes[resize_id].set(check_value)))
-
-
-                # Size is changed here, assigned value must be used for further checkings
-                # When size is of type array (i.e. neighbor list size), just use last assignment to it
-                # without checking accessed index (maybe this has to be changed at some point)
-                self.update[dest.array] = src
-
-            if isinstance(dest, Var):
-                # Size is changed here, assigned value must be used for further checkings
-                self.update[dest] = src
+                    resizes = list(self.module_resizes[module].keys())
+                    capacities = list(self.module_resizes[module].values())
+                    resize_id = resizes[capacities.index(match_capacity)]
+                    return Branch(ast_node.sim, dest < match_capacity,
+                                  blk_if=Block(ast_node.sim, ast_node),
+                                  blk_else=Block(ast_node.sim, ast_node.sim.resizes[resize_id].set(src)))
 
         return ast_node
 
@@ -134,7 +124,6 @@ class AddResizeLogic(Mutator):
         # Save current state
         saved_resizes_to_check = self.resizes_to_check
         saved_check_properties_resize = self.check_properties_resize
-        saved_update = self.update
         saved_nresize_buffers = self.nresize_buffers
 
         # Update state and keep traversing tree
@@ -146,21 +135,15 @@ class AddResizeLogic(Mutator):
 
         self.resizes_to_check = ast_node._resizes_to_check
         self.check_properties_resize = ast_node._check_properties_resize
-        self.update = {}
         ast_node._block = self.mutate(ast_node._block)
         self.module_stack.pop()
 
         # Restore saved state
         self.resizes_to_check = saved_resizes_to_check
         self.check_properties_resize = saved_check_properties_resize
-        self.update = saved_update
         self.nresize_buffers = saved_nresize_buffers
         return ast_node
 
-    def mutate_Var(self, ast_node):
-        self.look_for_match_capacity(ast_node)
-        return ast_node
-
 
 class ReplaceModulesByCalls(Mutator):
     def __init__(self, ast, module_resizes, grow_fn=None):
-- 
GitLab