diff --git a/src/pairs/ir/block.py b/src/pairs/ir/block.py index 27c7d5716177fa9c019a167885a15748affcfbfe..58037b313355b4067a4385fbe09aa87ea483fb28 100644 --- a/src/pairs/ir/block.py +++ b/src/pairs/ir/block.py @@ -66,6 +66,9 @@ class Block(ASTNode): for v in variant if isinstance(variant, list) else [variant]: self.variants.add(v) + def clear(self): + self.stmts = [] + def statements(self): return self.stmts diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py index f03b9f653aff09471d253377d8f459e0bf70218c..dad95eddf3e8149776b841198d4a0e13ea20590f 100644 --- a/src/pairs/sim/simulation.py +++ b/src/pairs/sim/simulation.py @@ -46,6 +46,7 @@ class Simulation: self.nested_count = 0 self.nest = False self.check_decl_usage = True + self._capture_statements = True self._block = Block(self, []) self.setups = Block(self, []) self.functions = Block(self, []) @@ -196,11 +197,15 @@ class Simulation: run_on_device=run_on_device, temps=self._module_temps)) + def capture_statements(self, capture=True): + self._capture_statements = capture + def add_statement(self, stmt): - if not self.scope: - self._block.add_statement(stmt) - else: - self.scope[-1].add_statement(stmt) + if self._capture_statements: + if not self.scope: + self._block.add_statement(stmt) + else: + self.scope[-1].add_statement(stmt) return stmt @@ -237,6 +242,15 @@ class Simulation: dom_part = DimensionRanges(self) comm = Comm(self, dom_part) + self.capture_statements(False) + grid_array = [[self.grid.min(d), self.grid.max(d)] for d in range(self.ndims())] + self.setups.add_statement([ + Call_Void(self, "pairs::initDomain", [param for delim in grid_array for param in delim]), + Call_Void(self, "pairs::fillCommunicationArrays", [dom_part.neighbor_ranks, dom_part.pbc, dom_part.subdom]) + ]) + + self.capture_statements() # TODO: check if this is actually required + timestep = Timestep(self, self.ntimesteps, [ (comm.exchange(), 20), (EnforcePBC(self), 20), @@ -251,11 +265,8 @@ class Simulation: timestep.add(VTKWrite(self, self.vtk_file, timestep.timestep() + 1)) self.leave() - grid_array = [[self.grid.min(d), self.grid.max(d)] for d in range(self.ndims())] body = Block.from_list(self, [ self.setups, - Call_Void(self, "pairs::initDomain", [param for delim in grid_array for param in delim]), - Call_Void(self, "pairs::fillCommunicationArrays", [dom_part.neighbor_ranks, dom_part.pbc, dom_part.subdom]), CellListsStencilBuild(self, self.cell_lists), VTKWrite(self, self.vtk_file, 0), timestep.as_block()