From 3aa444a39d5f09573c5acbef6d486a4986a949ea Mon Sep 17 00:00:00 2001
From: Behzad Safaei <iwia103h@alex1.nhr.fau.de>
Date: Mon, 10 Feb 2025 15:35:52 +0100
Subject: [PATCH] Move module instrumentations inside the module block

---
 src/pairs/sim/instrumentation.py             |  4 ++--
 src/pairs/sim/simulation.py                  | 11 +++++++----
 src/pairs/transformations/__init__.py        |  3 ++-
 src/pairs/transformations/instrumentation.py | 15 ++++++++-------
 4 files changed, 19 insertions(+), 14 deletions(-)

diff --git a/src/pairs/sim/instrumentation.py b/src/pairs/sim/instrumentation.py
index dedc7c1..7281f13 100644
--- a/src/pairs/sim/instrumentation.py
+++ b/src/pairs/sim/instrumentation.py
@@ -13,7 +13,7 @@ class RegisterTimers(FinalLowerable):
             Call_Void(self.sim, "pairs::register_timer", [t, Timers.name(t)])
 
         for m in self.sim.module_list:
-            if m.name != 'main':
+            if m.name != 'main' and m.name != 'initialize':
                 Call_Void(self.sim, "pairs::register_timer", [m.module_id + Timers.Offset, m.name])
 
 
@@ -25,5 +25,5 @@ class RegisterMarkers(FinalLowerable):
     def lower(self):
         if self.sim._enable_profiler:
             for m in self.sim.module_list:
-                if m.name != 'main' and m.must_profile():
+                if m.name != 'main' and m.name != 'initialize' and m.must_profile():
                     Call_Void(self.sim, "LIKWID_MARKER_REGISTER", [m.name])
diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py
index 66d5f6d..89cfb59 100644
--- a/src/pairs/sim/simulation.py
+++ b/src/pairs/sim/simulation.py
@@ -591,7 +591,7 @@ class Simulation:
                 ])
 
             setup_sim_module = Module(self, name='setup_sim', block=setup_sim)
-            communicate_module = Module(self, name='communicate', block=Timestep(self, 0, comm_routine).as_block())
+            communicate_module = Module(self, name='communicate', block=Timestep(self, 0, comm_routine).block)
             reset_volatiles_module = Module(self, name='reset_volatiles', block=Block(self, ResetVolatileProperties(self)))
             
             modules_list = [
@@ -610,9 +610,12 @@ class Simulation:
 
             # user defined modules are transformed seperately as indvidual modules 
             # i.e. they are transformed once again if already transformed in setup_sim or do_timestep
-            user_defined_modules = self.setup_functions + self.pre_step_functions + self.functions
-            user_defined_modules = [m[0] if isinstance(m, tuple) else m for m in user_defined_modules]
-            user_defined_modules = [Module(self, name=m.name, block=Block(self, m), user_defined=True) for m in user_defined_modules]
+            udf_internal = self.setup_functions + self.pre_step_functions + self.functions
+            udf_internal = [m[0] if isinstance(m, tuple) else m for m in udf_internal]
+            user_defined_modules = [Module(self, name=m.name, block=Block(self, m), user_defined=True) for m in udf_internal]
+            for i, m in enumerate(user_defined_modules):
+                m._id = udf_internal[i]._id
+
             Transformations(user_defined_modules, self._target).apply_all()
 
             # Generate library
diff --git a/src/pairs/transformations/__init__.py b/src/pairs/transformations/__init__.py
index 36df851..3a902ef 100644
--- a/src/pairs/transformations/__init__.py
+++ b/src/pairs/transformations/__init__.py
@@ -104,8 +104,9 @@ class Transformations:
         self.modularize()
         self.add_device_kernels()
         self.add_device_copies()
-        self.add_instrumentation()
         self.lower(True)
         self.add_expression_declarations()
         self.add_host_references_to_modules()
         self.add_device_references_to_modules()
+        self.add_instrumentation()
+
diff --git a/src/pairs/transformations/instrumentation.py b/src/pairs/transformations/instrumentation.py
index 1e70bdb..88b73c0 100644
--- a/src/pairs/transformations/instrumentation.py
+++ b/src/pairs/transformations/instrumentation.py
@@ -12,16 +12,17 @@ class AddModulesInstrumentation(Mutator):
     def mutate_ModuleCall(self, ast_node):
         ast_node._module = self.mutate(ast_node._module)
         module = ast_node._module
-        if module.name == 'main':
+        if module.name == 'main' or module.name == 'initialize':
             return ast_node
 
-        timer_id = module.module_id + Timers.Offset
-        start_timer = Call_Void(ast_node.sim, "pairs::start_timer", [timer_id])
-        stop_timer = Call_Void(ast_node.sim, "pairs::stop_timer", [timer_id])
-
         if module.must_profile():
             start_marker = Call_Void(ast_node.sim, "LIKWID_MARKER_START", [module.name])
             stop_marker = Call_Void(ast_node.sim, "LIKWID_MARKER_STOP", [module.name])
-            return Block(ast_node.sim, [start_timer, start_marker, ast_node, stop_marker, stop_timer])
+            module._block =  Block.from_list(ast_node.sim, [start_marker, module._block, stop_marker])
+        
+        timer_id = module.module_id + Timers.Offset
+        start_timer = Call_Void(ast_node.sim, "pairs::start_timer", [timer_id])
+        stop_timer = Call_Void(ast_node.sim, "pairs::stop_timer", [timer_id])
+        module._block = Block.from_list(ast_node.sim, [start_timer, module._block, stop_timer])
 
-        return Block(ast_node.sim, [start_timer, ast_node, stop_timer])
+        return ast_node
-- 
GitLab