diff --git a/src/pairs/ir/loops.py b/src/pairs/ir/loops.py index 9065940a988b1bcb75738f948fee13ab25b146b6..1e2253fe24c5e52b7b92bb49d94a46b976351ffe 100644 --- a/src/pairs/ir/loops.py +++ b/src/pairs/ir/loops.py @@ -76,7 +76,7 @@ class ParticleFor(For): return f"ParticleFor<>" def children(self): - return [self.sim.nlocal] + ([] if self.local_only else [self.sim.pbc.npbc]) + return [self.block, self.sim.nlocal] + ([] if self.local_only else [self.sim.pbc.npbc]) class While(ASTNode): diff --git a/src/pairs/ir/visitor.py b/src/pairs/ir/visitor.py index b40a8d52d41632df430738b116048d5666b3d1e0..e75c5bb9b55979d4f0826e12b5dfceedac2437f6 100644 --- a/src/pairs/ir/visitor.py +++ b/src/pairs/ir/visitor.py @@ -19,7 +19,14 @@ class Visitor: if method is not None: method(ast_node) else: - self.visit_children(ast_node) + for b in type(ast_node).__bases__: + method = self.get_method(f"visit_{b.__name__}") + if method is not None: + method(ast_node) + break + + if method is None: + self.visit_children(ast_node) def visit_children(self, ast_node): for c in ast_node.children(): diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py index 4187b4d42a6c6229232b0486adb8b94808bc6c0f..8e30895eb1a9d06c3f938502baec03fb08acbd9b 100644 --- a/src/pairs/sim/simulation.py +++ b/src/pairs/sim/simulation.py @@ -216,7 +216,9 @@ class Simulation: self.kernels ]) + self.enter(timestep.block) timestep.add(VTKWrite(self, self.vtk_file, timestep.timestep() + 1)) + self.leave() body = Block.from_list(self, [ self.setups, diff --git a/src/pairs/sim/timestep.py b/src/pairs/sim/timestep.py index 3e312f281a1657d17bd4244e7860adb1f7a03bba..606fc0e22221a5e860da2b7864934a6096d2b611 100644 --- a/src/pairs/sim/timestep.py +++ b/src/pairs/sim/timestep.py @@ -28,16 +28,19 @@ class Timestep: stmts = item if not isinstance(item, Block) else item.statements() stmts_else = None ts = self.timestep_loop.iter() + self.sim.enter(self.block) if item_else is not None: stmts_else = item_else if not isinstance(item_else, Block) else item_else.statements() if exec_every > 0: self.block.add_statement( - Branch(self.sim, BinOp.cmp(ts % exec_every, 0), True if stmts_else is None else False, + Branch(self.sim, BinOp.inline(BinOp.cmp(ts % exec_every, 0)), True if stmts_else is None else False, Block(self.sim, stmts), None if stmts_else is None else Block(self.sim, stmts_else))) else: self.block.add_statement(stmts) + self.sim.leave() + def as_block(self): return Block(self.sim, [self.timestep_loop])