From ea3310d727ee1ba1c554039f793f841df814a94f Mon Sep 17 00:00:00 2001 From: Rafael Ravedutti <rafaelravedutti@gmail.com> Date: Sat, 27 Nov 2021 00:33:08 +0100 Subject: [PATCH] Fix references in modules Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com> --- src/pairs/ir/loops.py | 2 +- src/pairs/ir/visitor.py | 9 ++++++++- src/pairs/sim/simulation.py | 2 ++ src/pairs/sim/timestep.py | 5 ++++- 4 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/pairs/ir/loops.py b/src/pairs/ir/loops.py index 9065940..1e2253f 100644 --- a/src/pairs/ir/loops.py +++ b/src/pairs/ir/loops.py @@ -76,7 +76,7 @@ class ParticleFor(For): return f"ParticleFor<>" def children(self): - return [self.sim.nlocal] + ([] if self.local_only else [self.sim.pbc.npbc]) + return [self.block, self.sim.nlocal] + ([] if self.local_only else [self.sim.pbc.npbc]) class While(ASTNode): diff --git a/src/pairs/ir/visitor.py b/src/pairs/ir/visitor.py index b40a8d5..e75c5bb 100644 --- a/src/pairs/ir/visitor.py +++ b/src/pairs/ir/visitor.py @@ -19,7 +19,14 @@ class Visitor: if method is not None: method(ast_node) else: - self.visit_children(ast_node) + for b in type(ast_node).__bases__: + method = self.get_method(f"visit_{b.__name__}") + if method is not None: + method(ast_node) + break + + if method is None: + self.visit_children(ast_node) def visit_children(self, ast_node): for c in ast_node.children(): diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py index 4187b4d..8e30895 100644 --- a/src/pairs/sim/simulation.py +++ b/src/pairs/sim/simulation.py @@ -216,7 +216,9 @@ class Simulation: self.kernels ]) + self.enter(timestep.block) timestep.add(VTKWrite(self, self.vtk_file, timestep.timestep() + 1)) + self.leave() body = Block.from_list(self, [ self.setups, diff --git a/src/pairs/sim/timestep.py b/src/pairs/sim/timestep.py index 3e312f2..606fc0e 100644 --- a/src/pairs/sim/timestep.py +++ b/src/pairs/sim/timestep.py @@ -28,16 +28,19 @@ class Timestep: stmts = item if not isinstance(item, Block) else item.statements() stmts_else = None ts = self.timestep_loop.iter() + self.sim.enter(self.block) if item_else is not None: stmts_else = item_else if not isinstance(item_else, Block) else item_else.statements() if exec_every > 0: self.block.add_statement( - Branch(self.sim, BinOp.cmp(ts % exec_every, 0), True if stmts_else is None else False, + Branch(self.sim, BinOp.inline(BinOp.cmp(ts % exec_every, 0)), True if stmts_else is None else False, Block(self.sim, stmts), None if stmts_else is None else Block(self.sim, stmts_else))) else: self.block.add_statement(stmts) + self.sim.leave() + def as_block(self): return Block(self.sim, [self.timestep_loop]) -- GitLab