From 973f7a5ce6172df70388e046892738cf3e3e2ff3 Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti <rafaelravedutti@gmail.com>
Date: Mon, 26 Sep 2022 16:19:10 +0200
Subject: [PATCH] Avoid duplicates of domain initialization calls

Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com>
---
 src/pairs/ir/block.py       |  3 +++
 src/pairs/sim/simulation.py | 25 ++++++++++++++++++-------
 2 files changed, 21 insertions(+), 7 deletions(-)

diff --git a/src/pairs/ir/block.py b/src/pairs/ir/block.py
index 27c7d57..58037b3 100644
--- a/src/pairs/ir/block.py
+++ b/src/pairs/ir/block.py
@@ -66,6 +66,9 @@ class Block(ASTNode):
         for v in variant if isinstance(variant, list) else [variant]:
             self.variants.add(v)
 
+    def clear(self):
+        self.stmts = []
+
     def statements(self):
         return self.stmts
 
diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py
index f03b9f6..dad95ed 100644
--- a/src/pairs/sim/simulation.py
+++ b/src/pairs/sim/simulation.py
@@ -46,6 +46,7 @@ class Simulation:
         self.nested_count = 0
         self.nest = False
         self.check_decl_usage = True
+        self._capture_statements = True
         self._block = Block(self, [])
         self.setups = Block(self, [])
         self.functions = Block(self, [])
@@ -196,11 +197,15 @@ class Simulation:
                 run_on_device=run_on_device,
                 temps=self._module_temps))
 
+    def capture_statements(self, capture=True):
+        self._capture_statements = capture
+
     def add_statement(self, stmt):
-        if not self.scope:
-            self._block.add_statement(stmt)
-        else:
-            self.scope[-1].add_statement(stmt)
+        if self._capture_statements:
+            if not self.scope:
+                self._block.add_statement(stmt)
+            else:
+                self.scope[-1].add_statement(stmt)
 
         return stmt
 
@@ -237,6 +242,15 @@ class Simulation:
         dom_part = DimensionRanges(self)
         comm = Comm(self, dom_part)
 
+        self.capture_statements(False)
+        grid_array = [[self.grid.min(d), self.grid.max(d)] for d in range(self.ndims())]
+        self.setups.add_statement([
+            Call_Void(self, "pairs::initDomain", [param for delim in grid_array for param in delim]),
+            Call_Void(self, "pairs::fillCommunicationArrays", [dom_part.neighbor_ranks, dom_part.pbc, dom_part.subdom])
+        ])
+
+        self.capture_statements() # TODO: check if this is actually required
+
         timestep = Timestep(self, self.ntimesteps, [
             (comm.exchange(), 20),
             (EnforcePBC(self), 20),
@@ -251,11 +265,8 @@ class Simulation:
         timestep.add(VTKWrite(self, self.vtk_file, timestep.timestep() + 1))
         self.leave()
 
-        grid_array = [[self.grid.min(d), self.grid.max(d)] for d in range(self.ndims())]
         body = Block.from_list(self, [
             self.setups,
-            Call_Void(self, "pairs::initDomain", [param for delim in grid_array for param in delim]),
-            Call_Void(self, "pairs::fillCommunicationArrays", [dom_part.neighbor_ranks, dom_part.pbc, dom_part.subdom]),
             CellListsStencilBuild(self, self.cell_lists),
             VTKWrite(self, self.vtk_file, 0),
             timestep.as_block()
-- 
GitLab