From 973f7a5ce6172df70388e046892738cf3e3e2ff3 Mon Sep 17 00:00:00 2001 From: Rafael Ravedutti <rafaelravedutti@gmail.com> Date: Mon, 26 Sep 2022 16:19:10 +0200 Subject: [PATCH] Avoid duplicates of domain initialization calls Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com> --- src/pairs/ir/block.py | 3 +++ src/pairs/sim/simulation.py | 25 ++++++++++++++++++------- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/src/pairs/ir/block.py b/src/pairs/ir/block.py index 27c7d57..58037b3 100644 --- a/src/pairs/ir/block.py +++ b/src/pairs/ir/block.py @@ -66,6 +66,9 @@ class Block(ASTNode): for v in variant if isinstance(variant, list) else [variant]: self.variants.add(v) + def clear(self): + self.stmts = [] + def statements(self): return self.stmts diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py index f03b9f6..dad95ed 100644 --- a/src/pairs/sim/simulation.py +++ b/src/pairs/sim/simulation.py @@ -46,6 +46,7 @@ class Simulation: self.nested_count = 0 self.nest = False self.check_decl_usage = True + self._capture_statements = True self._block = Block(self, []) self.setups = Block(self, []) self.functions = Block(self, []) @@ -196,11 +197,15 @@ class Simulation: run_on_device=run_on_device, temps=self._module_temps)) + def capture_statements(self, capture=True): + self._capture_statements = capture + def add_statement(self, stmt): - if not self.scope: - self._block.add_statement(stmt) - else: - self.scope[-1].add_statement(stmt) + if self._capture_statements: + if not self.scope: + self._block.add_statement(stmt) + else: + self.scope[-1].add_statement(stmt) return stmt @@ -237,6 +242,15 @@ class Simulation: dom_part = DimensionRanges(self) comm = Comm(self, dom_part) + self.capture_statements(False) + grid_array = [[self.grid.min(d), self.grid.max(d)] for d in range(self.ndims())] + self.setups.add_statement([ + Call_Void(self, "pairs::initDomain", [param for delim in grid_array for param in delim]), + Call_Void(self, "pairs::fillCommunicationArrays", [dom_part.neighbor_ranks, dom_part.pbc, dom_part.subdom]) + ]) + + self.capture_statements() # TODO: check if this is actually required + timestep = Timestep(self, self.ntimesteps, [ (comm.exchange(), 20), (EnforcePBC(self), 20), @@ -251,11 +265,8 @@ class Simulation: timestep.add(VTKWrite(self, self.vtk_file, timestep.timestep() + 1)) self.leave() - grid_array = [[self.grid.min(d), self.grid.max(d)] for d in range(self.ndims())] body = Block.from_list(self, [ self.setups, - Call_Void(self, "pairs::initDomain", [param for delim in grid_array for param in delim]), - Call_Void(self, "pairs::fillCommunicationArrays", [dom_part.neighbor_ranks, dom_part.pbc, dom_part.subdom]), CellListsStencilBuild(self, self.cell_lists), VTKWrite(self, self.vtk_file, 0), timestep.as_block() -- GitLab