diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py index bc3d1c2c0248e72e6f61f533529717d8688f4b25..b59579e77a253ff284ff791e55bca2c5d22499b7 100644 --- a/src/pairs/code_gen/cgen.py +++ b/src/pairs/code_gen/cgen.py @@ -202,6 +202,7 @@ class CGen: self.print(f"{array_name} = ({tkw} *) malloc({size});") if isinstance(ast_node, Module_Call): + module = ast_node.module module_params = "" for var in module.read_only_variables(): decl = var.name() diff --git a/src/pairs/ir/block.py b/src/pairs/ir/block.py index 4789b3852603d41b5b002967e84da4d9edb63dd3..38092da8671f581584218f4b7d506ef3b3f4acf8 100644 --- a/src/pairs/ir/block.py +++ b/src/pairs/ir/block.py @@ -17,9 +17,7 @@ def pairs_device_block(func): sim = args[0].sim # self.sim sim.clear_block() func(*args, **kwargs) - module = Module(sim, block=KernelBlock(sim, sim.block)) - sim.add_module(module) - return module + return Module(sim, block=KernelBlock(sim, sim.block)) return inner diff --git a/src/pairs/ir/loops.py b/src/pairs/ir/loops.py index 099c5c432f13d4389ece066b953dafc377389981..9065940a988b1bcb75738f948fee13ab25b146b6 100644 --- a/src/pairs/ir/loops.py +++ b/src/pairs/ir/loops.py @@ -75,6 +75,9 @@ class ParticleFor(For): def __str__(self): return f"ParticleFor<>" + def children(self): + return [self.sim.nlocal] + ([] if self.local_only else [self.sim.pbc.npbc]) + class While(ASTNode): def __init__(self, sim, cond, block=None): diff --git a/src/pairs/ir/module.py b/src/pairs/ir/module.py index 900e565869b82f35ec1ae6fb1b81fa44fdd12fd0..752152a858ca609fdd3d5bca9a183c0422028d6e 100644 --- a/src/pairs/ir/module.py +++ b/src/pairs/ir/module.py @@ -69,4 +69,8 @@ class Module_Call(ASTNode): def __init__(self, sim, module): assert isinstance(module, Module), "Module_Call(): given parameter is not of type Module!" super().__init__(sim) - self.module = module + self._module = module + + @property + def module(self): + return self._module diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py index 70f11b088c04b2db747eed7499fccff47797dc10..734a7b181b3181203a1941f73ad64f254233bcf4 100644 --- a/src/pairs/sim/simulation.py +++ b/src/pairs/sim/simulation.py @@ -119,12 +119,6 @@ class Simulation: def add_symbol(self, sym_type): return Symbol(self, sym_type) - def add_temporary_vector(self): - return self.vars.add(f"tmp{self.temp_id}", Type_Vector) - - def add_temporary_real(self): - return self.vars.add(f"tmp{self.temp_id}", Type_Float) - def var(self, var_name): return self.vars.find(var_name) diff --git a/src/pairs/transformations/fetch_modules_references.py b/src/pairs/transformations/fetch_modules_references.py index 534ddfbd3eedef8b204154d48f109aeebca77121..d9fa43d92fdaaf7a714ce2c47b323ed9dffbd973 100644 --- a/src/pairs/transformations/fetch_modules_references.py +++ b/src/pairs/transformations/fetch_modules_references.py @@ -49,7 +49,8 @@ class AddDereferencesToWriteVariables(Mutator): return ast_node def mutate_Var(self, ast_node): - if ast_node in self.module_stack[-1].write_variables(): + parent_module = self.module_stack[-1] + if parent_module.name != 'main' and ast_node in parent_module.write_variables(): return Deref(ast_node.sim, ast_node) return ast_node diff --git a/src/pairs/transformations/replace_modules_by_calls.py b/src/pairs/transformations/replace_modules_by_calls.py index 4b2e6c7a405eb8bc091f3aca1c113c46c15d4c68..1e81694b936fb9783ca86903cf890b9f420a87ae 100644 --- a/src/pairs/transformations/replace_modules_by_calls.py +++ b/src/pairs/transformations/replace_modules_by_calls.py @@ -7,7 +7,8 @@ class ReplaceModulesByCalls(Mutator): super().__init__(ast) def mutate_Module(self, ast_node): - return Module_Call(ast_node.sim, ast_node) + ast_node._block = self.mutate(ast_node._block) + return Module_Call(ast_node.sim, ast_node) if ast_node.name != 'main' else ast_node def replace_modules_by_calls(ast):