From 2f324c5cfe413dc6a744ead47c436744c89022ae Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti <rafaelravedutti@gmail.com>
Date: Thu, 25 Nov 2021 22:41:38 +0100
Subject: [PATCH] Adjust PBC kernels to not perform redundant accesses and
 improve resize transformations

Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com>
---
 src/pairs/sim/pbc.py                 | 31 ++++++-----
 src/pairs/sim/simulation.py          |  5 +-
 src/pairs/transformations/modules.py | 80 ++++++++++++++++++----------
 3 files changed, 72 insertions(+), 44 deletions(-)

diff --git a/src/pairs/sim/pbc.py b/src/pairs/sim/pbc.py
index 90e76d7..08a5d96 100644
--- a/src/pairs/sim/pbc.py
+++ b/src/pairs/sim/pbc.py
@@ -59,12 +59,13 @@ class EnforcePBC(Lowerable):
 
         for i in ParticleFor(sim):
             # TODO: VecFilter?
+            pos = positions[i]
             for d in range(0, ndims):
-                for _ in Filter(sim, positions[i][d] < grid.min(d)):
-                    positions[i][d].add(grid.length(d))
+                for _ in Filter(sim, pos[d] < grid.min(d)):
+                    pos[d].add(grid.length(d))
 
-                for _ in Filter(sim, positions[i][d] > grid.max(d)):
-                    positions[i][d].sub(grid.length(d))
+                for _ in Filter(sim, pos[d] > grid.max(d)):
+                    pos[d].sub(grid.length(d))
 
 
 class SetupPBC(Lowerable):
@@ -88,28 +89,32 @@ class SetupPBC(Lowerable):
         sim.check_resize(pbc_capacity, npbc)
 
         npbc.set(0)
-        for d in range(0, ndims):
-            for i in For(sim, 0, nlocal + npbc):
-                last_id = nlocal + npbc
+        for i in For(sim, 0, nlocal + npbc):
+            pos = positions[i]
+            last_id = nlocal + npbc
+            last_pos = positions[last_id]
+
+            for d in range(0, ndims):
+                grid_length = grid.length(d)
                 # TODO: VecFilter?
-                for _ in Filter(sim, positions[i][d] < grid.min(d) + cutneigh):
+                for _ in Filter(sim, pos[d] < grid.min(d) + cutneigh):
                     pbc_map[npbc].set(i)
                     pbc_mult[npbc][d].set(1)
-                    positions[last_id][d].set(positions[i][d] + grid.length(d))
+                    last_pos[d].set(pos[d] + grid_length)
 
                     for d_ in [x for x in range(0, ndims) if x != d]:
                         pbc_mult[npbc][d_].set(0)
-                        positions[last_id][d_].set(positions[i][d_])
+                        last_pos[d_].set(pos[d_])
 
                     npbc.add(1)
 
-                for _ in Filter(sim, positions[i][d] > grid.max(d) - cutneigh):
+                for _ in Filter(sim, pos[d] > grid.max(d) - cutneigh):
                     pbc_map[npbc].set(i)
                     pbc_mult[npbc][d].set(-1)
-                    positions[last_id][d].set(positions[i][d] - grid.length(d))
+                    last_pos[d].set(pos[d] - grid_length)
 
                     for d_ in [x for x in range(0, ndims) if x != d]:
                         pbc_mult[npbc][d_].set(0)
-                        positions[last_id][d_].set(positions[i][d_])
+                        last_pos[d_].set(pos[d_])
 
                     npbc.add(1)
diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py
index 8d6717b..6acf3d6 100644
--- a/src/pairs/sim/simulation.py
+++ b/src/pairs/sim/simulation.py
@@ -164,11 +164,10 @@ class Simulation:
         self._check_properties_resize = True
 
     def check_resize(self, capacity, size):
-        size_array = [size] if not isinstance(size, list) else size
         if capacity not in self._resizes_to_check:
-            self._resizes_to_check[capacity] = size_array
+            self._resizes_to_check[capacity] = size
         else:
-            self._resizes_to_check[capacity] += size_array
+            raise Exception("Two sizes assigned to same capacity!")
 
     def build_kernel_block_with_statements(self):
         self.kernels.add_statement(
diff --git a/src/pairs/transformations/modules.py b/src/pairs/transformations/modules.py
index 8a34cec..c369694 100644
--- a/src/pairs/transformations/modules.py
+++ b/src/pairs/transformations/modules.py
@@ -1,9 +1,15 @@
 from pairs.ir.bin_op import BinOp
-from pairs.ir.branches import Branch
+from pairs.ir.branches import Filter
+from pairs.ir.data_types import Type_Vector
+from pairs.ir.loops import While
+from pairs.ir.memory import Realloc
 from pairs.ir.module import Module, Module_Call
 from pairs.ir.mutator import Mutator
+from pairs.ir.properties import UpdateProperty
 from pairs.ir.variables import Deref
 from pairs.ir.visitor import Visitor
+from functools import reduce
+import operator
 
 
 class FetchModulesReferences(Visitor):
@@ -68,14 +74,22 @@ class AddResizeLogic(Mutator):
         self.check_properties_resize = False
         self.match_capacity = None
         self.update = {}
-        self.resize_buffers = {}
         self.nresize_buffers = 0
 
-    def mutate_Array(self, ast_node):
-        for capacity, size in self.resizes_to_check.items():
-            if size == ast_node.name():
-                self.match_capacity = capacity
+    def get_capacity_for_size(self, size):
+        for _capacity, _size in self.resizes_to_check.items():
+            if _size.name() == size.name():
+                return _capacity
+
+        return None
 
+    def look_for_match_capacity(self, size):
+        capacity = self.get_capacity_for_size(size)
+        if capacity is not None:
+            self.match_capacity = capacity
+
+    def mutate_Array(self, ast_node):
+        self.look_for_match_capacity(ast_node)
         return ast_node
 
     def mutate_Assignment(self, ast_node):
@@ -88,23 +102,23 @@ class AddResizeLogic(Mutator):
 
                 # Resize var is used in index, this statement should be checked for safety
                 if self.match_capacity is not None:
+                    module = self.module_stack[-1]
                     size = self.resizes_to_check[match_capacity]
                     check_value = self.update[size] if size in self.update else size
-                    resize_id = self.resize_buffers[match_capacity]
+                    resize_id = self.module_resizes[module].keys()[self.module_resizes[module].values().index(match_capacity)]
                     return Branch(ast_node.sim, check_value < match_capacity,
                                   Block(ast_node.sim, ast_node),
                                   Block(ast_node.sim, ast_node.resizes[resize_id].set(check_value)))
 
+
                 # Size is changed here, assigned value must be used for further checkings
-                for capacity, size in self.resizes_to_check.items():
-                    if size == dest.array.name():
-                        self.update[size] = src
+                # When size is array (i.e. neighbor list size), just use last assignment to it
+                # without checking accessed index (maybe this has to be changed at some point)
+                self.update[dest.array] = src
 
             if isinstance(dest, Var):
                 # Size is changed here, assigned value must be used for further checkings
-                for capacity, size in self.resizes_to_check.items():
-                    if size == dest.name():
-                        self.update[size] = src
+                self.update[dest] = src
 
         return ast_node
 
@@ -119,15 +133,13 @@ class AddResizeLogic(Mutator):
         saved_resizes_to_check = self.resizes_to_check
         saved_check_properties_resize = self.check_properties_resize
         saved_update = self.update
-        saved_resize_buffers = self.resize_buffers
         saved_nresize_buffers = self.nresize_buffers
 
         # Update state and keep traversing tree
-        self.module_resizes[ast_node] = []
+        self.module_resizes[ast_node] = {}
         self.module_stack.append(ast_node)
         for capacity in ast_node._resizes_to_check.keys():
-            self.module_resizes[ast_node].append(self.nresize_buffers)
-            self.resize_buffers[capacity] = self.nresize_buffers
+            self.module_resizes[ast_node][self.nresize_buffers] = capacity
             self.nresize_buffers += 1
 
         self.resizes_to_check = ast_node._resizes_to_check
@@ -140,41 +152,53 @@ class AddResizeLogic(Mutator):
         self.resizes_to_check = saved_resizes_to_check
         self.check_properties_resize = saved_check_properties_resize
         self.update = saved_update
-        self.resize_buffers = saved_resize_buffers
         self.nresize_buffers = saved_nresize_buffers
         return ast_node
 
     def mutate_Var(self, ast_node):
-        for capacity, size in self.resizes_to_check.items():
-            if size == ast_node.name():
-                self.match_capacity = capacity
-
+        self.look_for_match_capacity(ast_node)
         return ast_node
 
 
 class ReplaceModulesByCalls(Mutator):
-    def __init__(self, ast, module_resizes):
+    def __init__(self, ast, module_resizes, grow_fn=None):
         super().__init__(ast)
         self.module_resizes = module_resizes
+        self.grow_fn = grow_fn if grow_fn is not None else (lambda x: x * 2)
 
     def mutate_Module(self, ast_node):
         ast_node._block = self.mutate(ast_node._block)
         if ast_node.name == 'main':
             return ast_node
 
-        call = Module_Call(ast_node.sim, ast_node)
+        sim = ast_node.sim
+        call = Module_Call(sim, ast_node)
         if self.module_resizes[ast_node]:
+            properties = sim.properties
             init_stmts = []
             reset_stmts = []
+            resize_stmts = []
             branch_cond = None
 
-            for r in self.module_resizes[ast_node]:
-                init_stmts.append(Assign(ast_node.resizes[r], 1))
-                reset_stmts.append(Assign(ast_node.resizes[r], 0))
+            for r, c in self.module_resizes[ast_node].items():
+                init_stmts.append(Assign(sim, ast_node.resizes[r], 1))
+                reset_stmts.append(Assign(sim, ast_node.resizes[r], 0))
                 cond = ast_node.resizes[r] > 0
                 branch_cond = cond if branch_cond is None else BinOp.or_op(cond, branch_cond)
+                props_realloc = []
+
+                if properties.is_capacity(c):
+                    for p in properties.all():
+                        sizes = [capacity, sim.ndims()] if p.type() == Type_Vector else [capacity]
+                        props_realloc.append([Realloc(sim, p, reduce(operator.mul, sizes)), UpdateProperty(sim, p, sizes)])
+
+                resize_stmts.append(
+                    Filter(sim, ast_node.resizes[r] > 0,
+                        [Assign(sim, c, self.grow_fn(ast_node.resizes[r]))] +
+                        [a.realloc() for a in c.bonded_arrays()] +
+                        props_realloc))
 
-            return Block(ast_node.sim, init_stmts + Filter(ast_node.sim, branch_cond, reset_stmts + [call]))
+            return Block(sim, init_stmts + While(sim, branch_cond, reset_stmts + [call, resize_stmts]))
 
         return call
 
-- 
GitLab