diff --git a/src/pairs/sim/instrumentation.py b/src/pairs/sim/instrumentation.py index dedc7c18e940796e035aa97ccd3ff527e370438f..7281f13fc9f848038c4df34248d65db7e8e13c8c 100644 --- a/src/pairs/sim/instrumentation.py +++ b/src/pairs/sim/instrumentation.py @@ -13,7 +13,7 @@ class RegisterTimers(FinalLowerable): Call_Void(self.sim, "pairs::register_timer", [t, Timers.name(t)]) for m in self.sim.module_list: - if m.name != 'main': + if m.name != 'main' and m.name != 'initialize': Call_Void(self.sim, "pairs::register_timer", [m.module_id + Timers.Offset, m.name]) @@ -25,5 +25,5 @@ class RegisterMarkers(FinalLowerable): def lower(self): if self.sim._enable_profiler: for m in self.sim.module_list: - if m.name != 'main' and m.must_profile(): + if m.name != 'main' and m.name != 'initialize' and m.must_profile(): Call_Void(self.sim, "LIKWID_MARKER_REGISTER", [m.name]) diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py index 66d5f6dbf1715779b19490e3b789b453f0c16e45..89cfb5925224de67d0a0580238f6bdf2a9b785a2 100644 --- a/src/pairs/sim/simulation.py +++ b/src/pairs/sim/simulation.py @@ -591,7 +591,7 @@ class Simulation: ]) setup_sim_module = Module(self, name='setup_sim', block=setup_sim) - communicate_module = Module(self, name='communicate', block=Timestep(self, 0, comm_routine).as_block()) + communicate_module = Module(self, name='communicate', block=Timestep(self, 0, comm_routine).block) reset_volatiles_module = Module(self, name='reset_volatiles', block=Block(self, ResetVolatileProperties(self))) modules_list = [ @@ -610,9 +610,12 @@ class Simulation: # user defined modules are transformed seperately as indvidual modules # i.e. they are transformed once again if already transformed in setup_sim or do_timestep - user_defined_modules = self.setup_functions + self.pre_step_functions + self.functions - user_defined_modules = [m[0] if isinstance(m, tuple) else m for m in user_defined_modules] - user_defined_modules = [Module(self, name=m.name, block=Block(self, m), user_defined=True) for m in user_defined_modules] + udf_internal = self.setup_functions + self.pre_step_functions + self.functions + udf_internal = [m[0] if isinstance(m, tuple) else m for m in udf_internal] + user_defined_modules = [Module(self, name=m.name, block=Block(self, m), user_defined=True) for m in udf_internal] + for i, m in enumerate(user_defined_modules): + m._id = udf_internal[i]._id + Transformations(user_defined_modules, self._target).apply_all() # Generate library diff --git a/src/pairs/transformations/__init__.py b/src/pairs/transformations/__init__.py index 36df8512d6f047ec6ef8de764e14854aad618311..3a902ef2348c041c3c1734bec3bcbb83e21d5fda 100644 --- a/src/pairs/transformations/__init__.py +++ b/src/pairs/transformations/__init__.py @@ -104,8 +104,9 @@ class Transformations: self.modularize() self.add_device_kernels() self.add_device_copies() - self.add_instrumentation() self.lower(True) self.add_expression_declarations() self.add_host_references_to_modules() self.add_device_references_to_modules() + self.add_instrumentation() + diff --git a/src/pairs/transformations/instrumentation.py b/src/pairs/transformations/instrumentation.py index 1e70bdb7ebe2753d3031cf3a1ed87dd047b69650..88b73c0d8406b97392d267ace3b1e1bbbb3ca068 100644 --- a/src/pairs/transformations/instrumentation.py +++ b/src/pairs/transformations/instrumentation.py @@ -12,16 +12,17 @@ class AddModulesInstrumentation(Mutator): def mutate_ModuleCall(self, ast_node): ast_node._module = self.mutate(ast_node._module) module = ast_node._module - if module.name == 'main': + if module.name == 'main' or module.name == 'initialize': return ast_node - timer_id = module.module_id + Timers.Offset - start_timer = Call_Void(ast_node.sim, "pairs::start_timer", [timer_id]) - stop_timer = Call_Void(ast_node.sim, "pairs::stop_timer", [timer_id]) - if module.must_profile(): start_marker = Call_Void(ast_node.sim, "LIKWID_MARKER_START", [module.name]) stop_marker = Call_Void(ast_node.sim, "LIKWID_MARKER_STOP", [module.name]) - return Block(ast_node.sim, [start_timer, start_marker, ast_node, stop_marker, stop_timer]) + module._block = Block.from_list(ast_node.sim, [start_marker, module._block, stop_marker]) + + timer_id = module.module_id + Timers.Offset + start_timer = Call_Void(ast_node.sim, "pairs::start_timer", [timer_id]) + stop_timer = Call_Void(ast_node.sim, "pairs::stop_timer", [timer_id]) + module._block = Block.from_list(ast_node.sim, [start_timer, module._block, stop_timer]) - return Block(ast_node.sim, [start_timer, ast_node, stop_timer]) + return ast_node