From b9f5f8bb18d3385050d121913526e25e256ba92b Mon Sep 17 00:00:00 2001
From: Behzad Safaei <iwia103h@a0228.nhr.fau.de>
Date: Mon, 17 Mar 2025 00:35:35 +0100
Subject: [PATCH] Simplify halo cell generation

---
 src/pairs/code_gen/interface.py |  1 +
 src/pairs/sim/cell_lists.py     | 51 +++++++++------------------------
 2 files changed, 14 insertions(+), 38 deletions(-)

diff --git a/src/pairs/code_gen/interface.py b/src/pairs/code_gen/interface.py
index 8c0b727..b263157 100644
--- a/src/pairs/code_gen/interface.py
+++ b/src/pairs/code_gen/interface.py
@@ -88,6 +88,7 @@ class InterfaceModules:
         # This update assumes all particles have been created exactly in the rank that contains them 
         self.sim.add_statement(UpdateDomain(self.sim))  
         self.sim.add_statement(BuildCellListsStencil(self.sim, self.sim.cell_lists))
+        self.sim.add_statement(self.sim.update_cells_procedures)
         
     @pairs_interface_block
     def update_domain(self):
diff --git a/src/pairs/sim/cell_lists.py b/src/pairs/sim/cell_lists.py
index 09eefb0..a45ed95 100644
--- a/src/pairs/sim/cell_lists.py
+++ b/src/pairs/sim/cell_lists.py
@@ -125,44 +125,19 @@ class BuildCellListsStencil(Lowerable):
             Assign(self.sim, layers_1, Ceil(self.sim, (cutoff_radius / spacing[1])) + 2)
             layers_2 = self.sim.add_temp_var(0)
             Assign(self.sim, layers_2, Ceil(self.sim, (cutoff_radius / spacing[2])) + 2)
-
-            # TODO: Merge these loops.
-            # X faces
-            for y in For(self.sim, 0, dim_ncells[1]):
-                for z in For(self.sim, 0, dim_ncells[2]):
-                    for x in For(self.sim, 0, layers_0):
-                        index = x*dim_ncells[1]*dim_ncells[2] + y*dim_ncells[2] + z
-                        Assign(self.sim, halo_cells[n], index + 1)
-                        Assign(self.sim, n, n+1)
-                    for x in For(self.sim, dim_ncells[0]-layers_0, dim_ncells[0]):
-                        index = x*dim_ncells[1]*dim_ncells[2] + y*dim_ncells[2] + z
-                        Assign(self.sim, halo_cells[n], index + 1)
-                        Assign(self.sim, n, n+1)
-            
-            # Y faces (excluding X edges)
-            for x in For(self.sim, layers_0, dim_ncells[0]-layers_0):
-                for z in For(self.sim, 0, dim_ncells[2]):
-                    for y in For(self.sim, 0, layers_1):
-                        index = x*dim_ncells[1]*dim_ncells[2] + y*dim_ncells[2] + z
-                        Assign(self.sim, halo_cells[n], index + 1)
-                        Assign(self.sim, n, n+1)
-                    for y in For(self.sim, dim_ncells[1]-layers_1, dim_ncells[1]):
-                        index = x*dim_ncells[1]*dim_ncells[2] + y*dim_ncells[2] + z
-                        Assign(self.sim, halo_cells[n], index + 1)
-                        Assign(self.sim, n, n+1)
-            
-            # Z faces (exluding X and Y edges)
-            for x in For(self.sim, layers_0, dim_ncells[0]-layers_0):
-                for y in For(self.sim, layers_1, dim_ncells[1]-layers_1):
-                    for z in For(self.sim, 0, layers_2):
-                        index = x*dim_ncells[1]*dim_ncells[2] + y*dim_ncells[2] + z
-                        Assign(self.sim, halo_cells[n], index + 1)
-                        Assign(self.sim, n, n+1)
-                    for z in For(self.sim, dim_ncells[2]-layers_2, dim_ncells[2]):
-                        index = x*dim_ncells[1]*dim_ncells[2] + y*dim_ncells[2] + z
-                        Assign(self.sim, halo_cells[n], index + 1)
-                        Assign(self.sim, n, n+1)
         
+            for x in For(self.sim, 0, dim_ncells[0]):
+                for y in For(self.sim, 0, dim_ncells[1]):
+                    for z in For(self.sim, 0, dim_ncells[2]):
+                        cond0 = ScalarOp.or_op(x<layers_0, x>=(dim_ncells[0]-layers_0)) 
+                        cond1 = ScalarOp.or_op(y<layers_1, y>=(dim_ncells[1]-layers_1)) 
+                        cond2 = ScalarOp.or_op(z<layers_2, z>=(dim_ncells[2]-layers_2)) 
+                        fullcond = ScalarOp.or_op(ScalarOp.or_op(cond0, cond1), cond2) 
+                        for _ in Filter(self.sim, fullcond):
+                            index = x*dim_ncells[1]*dim_ncells[2] + y*dim_ncells[2] + z
+                            Assign(self.sim, halo_cells[n], index + 1)
+                            Assign(self.sim, n, n+1)
+
 
 class BuildCellLists(Lowerable):
     def __init__(self, sim, cell_lists):
@@ -183,7 +158,7 @@ class BuildCellLists(Lowerable):
         positions = self.sim.position()
 
         self.sim.module_name("build_cell_lists")
-        self.sim.check_resize(cell_capacity, cell_sizes)
+        # self.sim.check_resize(cell_capacity, cell_sizes)  # TODO: Check resize for 2D arrays
 
         for c in For(self.sim, 0, ncells):
             Assign(self.sim, cell_sizes[c], 0)
-- 
GitLab