From dcccfc94a350d92a4f7a342f6dafd12a7e408bca Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti <rafaelravedutti@gmail.com>
Date: Thu, 24 Jun 2021 02:24:15 +0200
Subject: [PATCH] Add neighbor lists and consider ArrayND in LICM

Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com>
---
 ir/loops.py                | 28 ++++++++++++++---------
 new_syntax.py              |  1 +
 particle.py                |  1 +
 sim/cell_lists.py          |  1 +
 sim/neighbor_lists.py      | 46 ++++++++++++++++++++++++++++++++++++++
 sim/particle_simulation.py |  9 +++++++-
 transformations/LICM.py    | 14 +++++++++++-
 7 files changed, 87 insertions(+), 13 deletions(-)
 create mode 100644 sim/neighbor_lists.py

diff --git a/ir/loops.py b/ir/loops.py
index b4ce423..f035afa 100644
--- a/ir/loops.py
+++ b/ir/loops.py
@@ -102,22 +102,28 @@ class While(ASTNode):
 
 
 class NeighborFor():
-    def __init__(self, sim, particle, cell_lists):
+    def __init__(self, sim, particle, cell_lists, neighbor_lists=None):
         self.sim = sim
         self.particle = particle
         self.cell_lists = cell_lists
+        self.neighbor_lists = neighbor_lists
 
     def __str__(self):
         return f"NeighborFor<particle: {self.particle}>"
 
     def __iter__(self):
-        cl = self.cell_lists
-        for s in For(self.sim, 0, cl.nstencil):
-            neigh_cell = cl.particle_cell[self.particle] + cl.stencil[s]
-            for _ in Filter(self.sim,
-                            BinOp.and_op(neigh_cell >= 0,
-                                         neigh_cell <= cl.ncells_all)):
-                for nc in For(self.sim, 0, cl.cell_sizes[neigh_cell]):
-                    it = cl.cell_particles[neigh_cell][nc]
-                    for _ in Filter(self.sim, BinOp.neq(it, self.particle)):
-                            yield it
+        if self.neighbor_lists is None:
+            cl = self.cell_lists
+            for s in For(self.sim, 0, cl.nstencil):
+                neigh_cell = cl.particle_cell[self.particle] + cl.stencil[s]
+                for _ in Filter(self.sim,
+                                BinOp.and_op(neigh_cell >= 0,
+                                             neigh_cell <= cl.ncells_all)):
+                    for nc in For(self.sim, 0, cl.cell_sizes[neigh_cell]):
+                        it = cl.cell_particles[neigh_cell][nc]
+                        for _ in Filter(self.sim, BinOp.neq(it, self.particle)):
+                                yield it
+        else:
+            neighbor_lists = self.neighbor_lists
+            for k in For(self.sim, 0, neighbor_lists.numneighs[self.particle]):
+                yield neighbor_lists.neighborlists[self.particle][k]
diff --git a/new_syntax.py b/new_syntax.py
index 28a1e06..bdbeb0d 100644
--- a/new_syntax.py
+++ b/new_syntax.py
@@ -184,6 +184,7 @@ psim.add_vector_property('velocity')
 psim.add_vector_property('force', vol=True)
 psim.from_file("data/minimd_setup_4x4x4.input", ['mass', 'position', 'velocity'])
 psim.create_cell_lists(2.8, 2.8)
+psim.create_neighbor_lists()
 psim.periodic(2.8)
 psim.vtk_output("output/test")
 
diff --git a/particle.py b/particle.py
index 312a132..8680d2b 100644
--- a/particle.py
+++ b/particle.py
@@ -30,6 +30,7 @@ force = psim.add_vector_property('force', vol=True)
 #psim.create_particle_lattice(grid, spacing=[0.82323, 0.82323, 0.82323])
 psim.from_file("data/minimd_setup_4x4x4.input", ['mass', 'position', 'velocity'])
 psim.create_cell_lists(2.8, 2.8)
+psim.create_neighbor_lists()
 psim.periodic(2.8)
 psim.vtk_output("output/test")
 
diff --git a/sim/cell_lists.py b/sim/cell_lists.py
index 923c4de..4c1400a 100644
--- a/sim/cell_lists.py
+++ b/sim/cell_lists.py
@@ -14,6 +14,7 @@ class CellLists:
         self.sim = sim
         self.grid = grid
         self.spacing = spacing
+        self.cutoff_radius = cutoff_radius
 
         self.nneighbor_cells = [
             math.ceil(cutoff_radius / (
diff --git a/sim/neighbor_lists.py b/sim/neighbor_lists.py
new file mode 100644
index 0000000..f2fa90b
--- /dev/null
+++ b/sim/neighbor_lists.py
@@ -0,0 +1,46 @@
+from ir.branches import Branch, Filter
+from ir.data_types import Type_Int
+from ir.loops import For, ParticleFor, NeighborFor
+from ir.utils import Print
+from sim.resize import Resize
+
+
+class NeighborLists:
+    def __init__(self, cell_lists):
+        self.sim = cell_lists.sim
+        self.cell_lists = cell_lists
+        self.capacity = self.sim.add_var('neighborlist_capacity', Type_Int, 32)
+        self.neighborlists = self.sim.add_array('neighborlists', [self.sim.particle_capacity, self.capacity], Type_Int)
+        self.numneighs = self.sim.add_array('numneighs', self.sim.particle_capacity, Type_Int)
+
+
+class NeighborListsBuild:
+    def __init__(self, neighbor_lists):
+        self.neighbor_lists = neighbor_lists
+
+    def lower(self):
+        neighbor_lists = self.neighbor_lists
+        sim = neighbor_lists.sim
+        cell_lists = neighbor_lists.cell_lists
+        cutoff_radius = cell_lists.cutoff_radius
+        position = sim.property('position')
+
+        sim.clear_block()
+        sim.add_statement(Print(sim, "NeighborListsBuild"))
+        for resize in Resize(sim, neighbor_lists.capacity):
+            for i in ParticleFor(sim):
+                neighbor_lists.numneighs[i].set(0)
+                for j in NeighborFor(sim, i, cell_lists):
+                    # TODO: find a way to not repeat this (already present in particle_pairs)
+                    dp = position[i] - position[j]
+                    rsq = dp.x() * dp.x() + dp.y() * dp.y() + dp.z() * dp.z()
+                    for _ in Filter(sim, rsq < cutoff_radius):
+                        numneighs = neighbor_lists.numneighs[i]
+                        for cond in Branch(sim, numneighs >= neighbor_lists.capacity):
+                            if cond:
+                                resize.set(numneighs)
+                            else:
+                                neighbor_lists.neighborlists[i][numneighs].set(j)
+                                neighbor_lists.numneighs[i].set(numneighs + 1)
+
+        return sim.block
diff --git a/sim/particle_simulation.py b/sim/particle_simulation.py
index 20efef8..5b82cd4 100644
--- a/sim/particle_simulation.py
+++ b/sim/particle_simulation.py
@@ -12,6 +12,7 @@ from sim.cell_lists import CellLists, CellListsBuild, CellListsStencilBuild
 from sim.grid import Grid2D, Grid3D
 from sim.kernel_wrapper import KernelWrapper
 from sim.lattice import ParticleLattice
+from sim.neighbor_lists import NeighborLists, NeighborListsBuild
 from sim.pbc import PBC, UpdatePBC, EnforcePBC, SetupPBC
 from sim.properties import PropertiesAlloc, PropertiesResetVolatile
 from sim.read_from_file import ReadFromFile
@@ -38,6 +39,7 @@ class ParticleSimulation:
         self.nghost = self.add_var('nghost', Type_Int)
         self.grid = None
         self.cell_lists = None
+        self.neighbor_lists = None
         self.pbc = None
         self.scope = []
         self.nested_count = 0
@@ -114,6 +116,10 @@ class ParticleSimulation:
         self.cell_lists = CellLists(self, self.grid, spacing, cutoff_radius)
         return self.cell_lists
 
+    def create_neighbor_lists(self):
+        self.neighbor_lists = NeighborLists(self.cell_lists)
+        return self.neighbor_lists
+
     def periodic(self, cutneigh, flags=[1, 1, 1]):
         self.pbc = PBC(self, self.grid, cutneigh, flags)
         self.properties.add_capacity(self.pbc.pbc_capacity)
@@ -122,7 +128,7 @@ class ParticleSimulation:
     def particle_pairs(self, cutoff_radius=None, position=None):
         self.clear_block()
         for i in ParticleFor(self):
-            for j in NeighborFor(self, i, self.cell_lists):
+            for j in NeighborFor(self, i, self.cell_lists, self.neighbor_lists):
                 if cutoff_radius is not None and position is not None:
                     dp = position[i] - position[j]
                     rsq = dp.x() * dp.x() + dp.y() * dp.y() + dp.z() * dp.z()
@@ -177,6 +183,7 @@ class ParticleSimulation:
             (EnforcePBC(self.pbc).lower(), 20),
             (SetupPBC(self.pbc).lower(), UpdatePBC(self.pbc).lower(), 20),
             (CellListsBuild(self.cell_lists).lower(), 20),
+            (NeighborListsBuild(self.neighbor_lists).lower(), 20),
             PropertiesResetVolatile(self).lower(),
             self.kernels.lower()
         ])
diff --git a/transformations/LICM.py b/transformations/LICM.py
index 74f9969..5b42d8c 100644
--- a/transformations/LICM.py
+++ b/transformations/LICM.py
@@ -38,9 +38,13 @@ class SetBlockVariants(Mutator):
         return ast_node
 
     def mutate_BinOp(self, ast_node):
+        ast_node.lhs = self.mutate(ast_node.lhs)
+
         # For property accesses, we only want to include the property name, and not
         # the index that is also present in the expression
-        ast_node.lhs = self.mutate(ast_node.lhs)
+        if not ast_node.is_property_access():
+            ast_node.rhs = self.mutate(ast_node.rhs)
+
         return ast_node
 
     def mutate_ArrayAccess(self, ast_node):
@@ -52,6 +56,10 @@ class SetBlockVariants(Mutator):
     def mutate_Array(self, ast_node):
         return self.push_variant(ast_node)
 
+    # TODO: Array should be enough
+    def mutate_ArrayND(self, ast_node):
+        return self.push_variant(ast_node)
+
     def mutate_Iter(self, ast_node):
         return self.push_variant(ast_node)
 
@@ -130,6 +138,10 @@ class SetBinOpTerminals(Visitor):
     def visit_Array(self, ast_node):
         self.push_terminal(ast_node)
 
+    # TODO: Array should be enough
+    def visit_ArrayND(self, ast_node):
+        self.push_terminal(ast_node)
+
     def visit_Iter(self, ast_node):
         self.push_terminal(ast_node)
 
-- 
GitLab