From ae595f96046f17685b915b57e94522c713c258bb Mon Sep 17 00:00:00 2001
From: Behzad Safaei <iwia103h@alex2.nhr.fau.de>
Date: Wed, 29 Jan 2025 01:14:54 +0100
Subject: [PATCH] Fix For to Kernel transformation, fix whole-program gen

---
 examples/dem_sd.py                   |   1 +
 src/pairs/analysis/devices.py        |  21 +++---
 src/pairs/code_gen/cgen.py           | 102 +++++++++++++++------------
 src/pairs/mapping/funcs.py           |   4 ++
 src/pairs/transformations/devices.py |   6 +-
 5 files changed, 79 insertions(+), 55 deletions(-)

diff --git a/examples/dem_sd.py b/examples/dem_sd.py
index 15ee699..94bc93f 100644
--- a/examples/dem_sd.py
+++ b/examples/dem_sd.py
@@ -181,4 +181,5 @@ psim.compute(linear_spring_dashpot,
                       'collisionTime_SI': collisionTime_SI})
 
 psim.compute(euler, parameters={'dt' : pairs.real()})
+# psim.compute(euler, symbols={'dt' : dt_SI})
 psim.generate()
diff --git a/src/pairs/analysis/devices.py b/src/pairs/analysis/devices.py
index 1c009af..29e554e 100644
--- a/src/pairs/analysis/devices.py
+++ b/src/pairs/analysis/devices.py
@@ -15,18 +15,21 @@ class MarkCandidateLoops(Visitor):
         self.device_module = False
 
     def visit_For(self, ast_node):
-        if self.device_module:
-            if ast_node.not_kernel:
-                self.visit(ast_node.block)
-                ast_node.mark_iter_as_ref_candidate()
-            else:
-                if not isinstance(ast_node.min, Lit) or not isinstance(ast_node.max, Lit):
-                    ast_node.mark_as_kernel_candidate()
+        if self.device_module and not ast_node.not_kernel and (not isinstance(ast_node.min, Lit) or not isinstance(ast_node.max, Lit)):
+            ast_node.mark_as_kernel_candidate()
+        else:
+            ast_node.mark_iter_as_ref_candidate()
+            self.visit(ast_node.block)
+
 
     def visit_Module(self, ast_node):
-        self.device_module = ast_node.run_on_device
-        self.visit_children(ast_node)
+        parent_runs_on_device = self.device_module
+        if ast_node.run_on_device:
+            self.device_module = True
 
+        self.visit_children(ast_node)
+        self.device_module = parent_runs_on_device
+        
 
 class FetchKernelReferences(Visitor):
     def __init__(self, ast=None):
diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py
index 7649f4f..6795e2d 100644
--- a/src/pairs/code_gen/cgen.py
+++ b/src/pairs/code_gen/cgen.py
@@ -372,11 +372,22 @@ class CGen:
         self.generate_pairs_object_structure()
         self.generate_module_headers()
 
+        self.print("namespace pairs::internal {")
+        self.print.add_indent(4)
+
         for kernel in self.sim.kernels():
             self.generate_kernel(kernel)
 
         for module in self.sim.modules():
-            self.generate_module(module)
+            if module.name!='main':
+                self.generate_module(module)
+
+        self.print.add_indent(-4)
+        self.print("}")
+
+        for module in self.sim.modules():
+            if module.name=='main':
+                self.generate_main(module)
 
         self.print.end()
 
@@ -589,57 +600,58 @@ class CGen:
             if feature_prop in module.host_references():
                 self.print(f"{type_kw} *h_{feature_prop.name()} = pobj->{feature_prop.name()};")
 
-    def generate_module(self, module):
-        if module.name == 'main':
-            ndims = module.sim.ndims()
-            nprops = module.sim.properties.nprops()
-            ncontactprops = module.sim.contact_properties.nprops()
-            narrays = module.sim.arrays.narrays()
-            part = DomainPartitioners.c_keyword(module.sim.partitioner())
-
-            self.generate_full_object_names = True
-            self.print("int main(int argc, char **argv) {")
-            self.print(f"    PairsRuntime *pairs_runtime = new PairsRuntime({nprops}, {ncontactprops}, {narrays}, {part});")
-            self.print(f"    struct PairsObjects *pobj = new PairsObjects();")
-
-            if module.sim._enable_profiler:
-                self.print("    LIKWID_MARKER_INIT;")
-
-            self.generate_statement(module.block)
-
-            if module.sim._enable_profiler:
-                self.print("    LIKWID_MARKER_CLOSE;")
-
-            self.print("    pairs::print_timers(pairs_runtime);")
-            self.print("    pairs::print_stats(pairs_runtime, pobj->nlocal, pobj->nghost);")
-            self.print("    delete pobj;")
-            self.print("    delete pairs_runtime;")
-            self.print("    return 0;")
-            self.print("}")
-            self.generate_full_object_names = False
+    def generate_main(self, module):
+        assert module.name=='main'
+
+        ndims = module.sim.ndims()
+        nprops = module.sim.properties.nprops()
+        ncontactprops = module.sim.contact_properties.nprops()
+        narrays = module.sim.arrays.narrays()
+        part = DomainPartitioners.c_keyword(module.sim.partitioner())
+
+        self.generate_full_object_names = True
+        self.print("int main(int argc, char **argv) {")
+        self.print(f"    PairsRuntime *pairs_runtime = new PairsRuntime({nprops}, {ncontactprops}, {narrays}, {part});")
+        self.print(f"    struct PairsObjects *pobj = new PairsObjects();")
+
+        if module.sim._enable_profiler:
+            self.print("    LIKWID_MARKER_INIT;")
+
+        self.generate_statement(module.block)
 
+        if module.sim._enable_profiler:
+            self.print("    LIKWID_MARKER_CLOSE;")
+
+        self.print("    pairs::print_timers(pairs_runtime);")
+        self.print("    pairs::print_stats(pairs_runtime, pobj->nlocal, pobj->nghost);")
+        self.print("    delete pobj;")
+        self.print("    delete pairs_runtime;")
+        self.print("    return 0;")
+        self.print("}")
+        self.generate_full_object_names = False
+
+    def generate_module(self, module):
+        module_params = ", ".join(f"{Types.c_keyword(self.sim, param.type())} {param.name()}"
+                                            for param in module.parameters())
+        if not module.user_defined:
+            module_params = ", " + module_params if module_params else ""
+            self.print(f"void {module.name}(PairsRuntime *pairs_runtime, struct PairsObjects *pobj{module_params}) {{")
         else:
-            module_params = ", ".join(f"{Types.c_keyword(self.sim, param.type())} {param.name()}"
-                                                for param in module.parameters())
-            if not module.user_defined:
-                module_params = ", " + module_params if module_params else ""
-                self.print(f"void {module.name}(PairsRuntime *pairs_runtime, struct PairsObjects *pobj{module_params}) {{")
-            else:
-                
-                self.print(f"void {module.name}({module_params}) {{")
+            
+            self.print(f"void {module.name}({module_params}) {{")
 
 
-            self.print.add_indent(4)
+        self.print.add_indent(4)
 
-            if self.debug:
-                self.print(f"PAIRS_DEBUG(\"\\n{module.name}\\n\");")
+        if self.debug:
+            self.print(f"PAIRS_DEBUG(\"\\n{module.name}\\n\");")
 
-            if not module.user_defined:
-                self.generate_module_declerations(module)
+        if not module.user_defined:
+            self.generate_module_declerations(module)
 
-            self.print.add_indent(-4)
-            self.generate_statement(module.block)
-            self.print("}")
+        self.print.add_indent(-4)
+        self.generate_statement(module.block)
+        self.print("}")
 
     def generate_kernel(self, kernel):
         kernel_params = "int range_start"
diff --git a/src/pairs/mapping/funcs.py b/src/pairs/mapping/funcs.py
index efcb825..6258aa1 100644
--- a/src/pairs/mapping/funcs.py
+++ b/src/pairs/mapping/funcs.py
@@ -286,6 +286,10 @@ class BuildParticleIR(ast.NodeVisitor):
 
 
 def compute(sim, func, cutoff_radius=None, symbols={}, parameters={}, pre_step=False, skip_first=False):
+    if sim._generate_whole_program:
+        assert not parameters, "Compute functions can't take custom parameters when generating whole program."
+    
+
     src = inspect.getsource(func)
     tree = ast.parse(src, mode='exec')
     #print(ast.dump(ast.parse(src, mode='exec')))
diff --git a/src/pairs/transformations/devices.py b/src/pairs/transformations/devices.py
index dceb1e3..21898c1 100644
--- a/src/pairs/transformations/devices.py
+++ b/src/pairs/transformations/devices.py
@@ -88,6 +88,7 @@ class AddDeviceKernels(Mutator):
         super().__init__(ast)
         self._module_name = None
         self._kernel_id = 0
+        self._device_module = False
 
     def create_kernel(self, sim, iterator, rmax, block):
         kernel_name = f"{self._module_name}_kernel{self._kernel_id}"
@@ -101,7 +102,7 @@ class AddDeviceKernels(Mutator):
         return kernel
     
     def mutate_For(self, ast_node):
-        if ast_node.is_kernel_candidate():
+        if ast_node.is_kernel_candidate() and self._device_module:
             kernel = self.create_kernel(ast_node.sim, ast_node.iterator, ast_node.max, ast_node.block)
             ast_node = KernelLaunch(ast_node.sim, kernel, ast_node.iterator, ast_node.min, ast_node.max)
 
@@ -111,11 +112,14 @@ class AddDeviceKernels(Mutator):
         return ast_node
 
     def mutate_Module(self, ast_node):
+        parent_runs_on_device = self._device_module
         if ast_node.run_on_device:
+            self._device_module = True
             self._module_name = ast_node.name
             self._kernel_id = 0
 
         ast_node._block = self.mutate(ast_node._block)
+        self._device_module = parent_runs_on_device
         return ast_node
 
 class AddHostReferencesToModules(Mutator):
-- 
GitLab