From 7f0d066ff341c426098a6e2fc6d7e8514c9a7f55 Mon Sep 17 00:00:00 2001 From: Rafael Ravedutti <rafaelravedutti@gmail.com> Date: Thu, 18 Nov 2021 03:00:11 +0100 Subject: [PATCH] Fix first bugs with the introduction of modules Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com> --- src/pairs/code_gen/cgen.py | 1 + src/pairs/ir/block.py | 4 +--- src/pairs/ir/loops.py | 3 +++ src/pairs/ir/module.py | 6 +++++- src/pairs/sim/simulation.py | 6 ------ src/pairs/transformations/fetch_modules_references.py | 3 ++- src/pairs/transformations/replace_modules_by_calls.py | 3 ++- 7 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py index bc3d1c2..b59579e 100644 --- a/src/pairs/code_gen/cgen.py +++ b/src/pairs/code_gen/cgen.py @@ -202,6 +202,7 @@ class CGen: self.print(f"{array_name} = ({tkw} *) malloc({size});") if isinstance(ast_node, Module_Call): + module = ast_node.module module_params = "" for var in module.read_only_variables(): decl = var.name() diff --git a/src/pairs/ir/block.py b/src/pairs/ir/block.py index 4789b38..38092da 100644 --- a/src/pairs/ir/block.py +++ b/src/pairs/ir/block.py @@ -17,9 +17,7 @@ def pairs_device_block(func): sim = args[0].sim # self.sim sim.clear_block() func(*args, **kwargs) - module = Module(sim, block=KernelBlock(sim, sim.block)) - sim.add_module(module) - return module + return Module(sim, block=KernelBlock(sim, sim.block)) return inner diff --git a/src/pairs/ir/loops.py b/src/pairs/ir/loops.py index 099c5c4..9065940 100644 --- a/src/pairs/ir/loops.py +++ b/src/pairs/ir/loops.py @@ -75,6 +75,9 @@ class ParticleFor(For): def __str__(self): return f"ParticleFor<>" + def children(self): + return [self.sim.nlocal] + ([] if self.local_only else [self.sim.pbc.npbc]) + class While(ASTNode): def __init__(self, sim, cond, block=None): diff --git a/src/pairs/ir/module.py b/src/pairs/ir/module.py index 900e565..752152a 100644 --- a/src/pairs/ir/module.py +++ b/src/pairs/ir/module.py @@ -69,4 +69,8 @@ class Module_Call(ASTNode): def __init__(self, sim, module): assert isinstance(module, Module), "Module_Call(): given parameter is not of type Module!" super().__init__(sim) - self.module = module + self._module = module + + @property + def module(self): + return self._module diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py index 70f11b0..734a7b1 100644 --- a/src/pairs/sim/simulation.py +++ b/src/pairs/sim/simulation.py @@ -119,12 +119,6 @@ class Simulation: def add_symbol(self, sym_type): return Symbol(self, sym_type) - def add_temporary_vector(self): - return self.vars.add(f"tmp{self.temp_id}", Type_Vector) - - def add_temporary_real(self): - return self.vars.add(f"tmp{self.temp_id}", Type_Float) - def var(self, var_name): return self.vars.find(var_name) diff --git a/src/pairs/transformations/fetch_modules_references.py b/src/pairs/transformations/fetch_modules_references.py index 534ddfb..d9fa43d 100644 --- a/src/pairs/transformations/fetch_modules_references.py +++ b/src/pairs/transformations/fetch_modules_references.py @@ -49,7 +49,8 @@ class AddDereferencesToWriteVariables(Mutator): return ast_node def mutate_Var(self, ast_node): - if ast_node in self.module_stack[-1].write_variables(): + parent_module = self.module_stack[-1] + if parent_module.name != 'main' and ast_node in parent_module.write_variables(): return Deref(ast_node.sim, ast_node) return ast_node diff --git a/src/pairs/transformations/replace_modules_by_calls.py b/src/pairs/transformations/replace_modules_by_calls.py index 4b2e6c7..1e81694 100644 --- a/src/pairs/transformations/replace_modules_by_calls.py +++ b/src/pairs/transformations/replace_modules_by_calls.py @@ -7,7 +7,8 @@ class ReplaceModulesByCalls(Mutator): super().__init__(ast) def mutate_Module(self, ast_node): - return Module_Call(ast_node.sim, ast_node) + ast_node._block = self.mutate(ast_node._block) + return Module_Call(ast_node.sim, ast_node) if ast_node.name != 'main' else ast_node def replace_modules_by_calls(ast): -- GitLab