From d5f02c1f271818ed8e0425f5a7f8de4a65b68b1c Mon Sep 17 00:00:00 2001
From: Behzad Safaei <iwia103h@alex1.nhr.fau.de>
Date: Sun, 12 Jan 2025 15:55:22 +0100
Subject: [PATCH] Enable outer Iter nodes to be referenced by inner kernels

---
 runtime/domain/block_forest.hpp |  2 +-
 src/pairs/analysis/devices.py   |  6 ++++++
 src/pairs/code_gen/cgen.py      |  8 ++++++++
 src/pairs/ir/kernel.py          | 16 ++++++++++++++++
 src/pairs/ir/loops.py           | 15 ++++++++++++++-
 5 files changed, 45 insertions(+), 2 deletions(-)

diff --git a/runtime/domain/block_forest.hpp b/runtime/domain/block_forest.hpp
index 1c7bf53..d5d9d54 100644
--- a/runtime/domain/block_forest.hpp
+++ b/runtime/domain/block_forest.hpp
@@ -42,7 +42,7 @@ private:
     real_t *subdom;
     const bool globalPBC[3];
     int world_size, rank, nranks, total_aabbs;
-    bool balance_workload;
+    bool balance_workload = false;
 
 public:
     BlockForest(
diff --git a/src/pairs/analysis/devices.py b/src/pairs/analysis/devices.py
index 9fdddc2..23a3bed 100644
--- a/src/pairs/analysis/devices.py
+++ b/src/pairs/analysis/devices.py
@@ -18,6 +18,7 @@ class MarkCandidateLoops(Visitor):
         if self.device_module:
             if ast_node.not_kernel:
                 self.visit(ast_node.block)
+                ast_node.mark_iter_as_ref_candidate()
             else:
                 if not isinstance(ast_node.min, Lit) or not isinstance(ast_node.max, Lit):
                     ast_node.mark_as_kernel_candidate()
@@ -195,3 +196,8 @@ class FetchKernelReferences(Visitor):
             # Variables only have a device version when changed within kernels
             if self.writing:
                 ast_node.device_flag = True
+    
+    def visit_Iter(self, ast_node):
+        for k in self.kernel_stack:
+            if ast_node.is_ref_candidate():
+                k.add_iter(ast_node, self.writing)
diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py
index a7166be..c9e2c2e 100644
--- a/src/pairs/code_gen/cgen.py
+++ b/src/pairs/code_gen/cgen.py
@@ -626,6 +626,11 @@ class CGen:
             decl = f"{type_kw} *{var.name()}"
             kernel_params += f", {decl}"
 
+        for it in kernel.iters():
+            type_kw = Types.c_keyword(self.sim, it.type())
+            decl = f"{type_kw} {it.name()}"
+            kernel_params += f", {decl}"
+
         for array in kernel.arrays():
             if array.is_static():
                 continue
@@ -982,6 +987,9 @@ class CGen:
             for var in kernel.write_variables():
                 kernel_params += f", {var.name()}"
 
+            for it in kernel.iters():
+                kernel_params += f", {it.name()}"
+
             for array in kernel.arrays():
                 if array.is_static():
                     continue
diff --git a/src/pairs/ir/kernel.py b/src/pairs/ir/kernel.py
index 04def29..5faaee4 100644
--- a/src/pairs/ir/kernel.py
+++ b/src/pairs/ir/kernel.py
@@ -9,6 +9,7 @@ from pairs.ir.properties import Property, ContactProperty
 from pairs.ir.quaternions import QuaternionOp
 from pairs.ir.variables import Var
 from pairs.ir.vectors import VectorOp
+from pairs.ir.loops import Iter
 
 
 class Kernel(ASTNode):
@@ -19,6 +20,7 @@ class Kernel(ASTNode):
         self._id = Kernel.last_kernel
         self._name = name if name is not None else "kernel" + str(Kernel.last_kernel)
         self._variables = {}
+        self._iters = {}
         self._arrays = {}
         self._properties = {}
         self._contact_properties = {}
@@ -50,6 +52,9 @@ class Kernel(ASTNode):
     def variables(self):
         return self._variables
 
+    def iters(self):
+        return self._iters
+    
     def read_only_variables(self):
         return [var for var in self._variables if self._variables[var] == Actions.ReadOnly]
 
@@ -99,6 +104,17 @@ class Kernel(ASTNode):
 
                 action = Actions.NoAction if var not in self._variables else self._variables[var]
                 self._variables[var] = Actions.update_rule(action, new_op)
+    
+    def add_iter(self, iter, write=False):
+        iter_list = iter if isinstance(iter, list) else [iter]
+        new_op = 'w' if write else 'r'
+
+        for it in iter_list:
+            assert isinstance(it, Iter), \
+                "Kernel.add_iter(): Element is not of type Iter."
+
+            action = Actions.NoAction if it not in self._iters else self._iters[it]
+            self._iters[it] = Actions.update_rule(action, new_op)
 
     def add_property(self, prop, write=False):
         prop_list = prop if isinstance(prop, list) else [prop]
diff --git a/src/pairs/ir/loops.py b/src/pairs/ir/loops.py
index 228bd40..8842818 100644
--- a/src/pairs/ir/loops.py
+++ b/src/pairs/ir/loops.py
@@ -18,6 +18,7 @@ class Iter(ASTTerm):
         super().__init__(sim, ScalarOp)
         self.loop = loop
         self.iter_id = Iter.new_id()
+        self._ref_candidate = False
 
     def id(self):
         return self.iter_id
@@ -27,7 +28,16 @@ class Iter(ASTTerm):
 
     def type(self):
         return Types.Int32
-
+    
+    def mark_as_ref_candidate(self):
+        self._ref_candidate = True
+
+    def is_ref_candidate(self):
+        return self._ref_candidate
+    
+    def __hash__(self):
+        return hash(self.iter_id)
+    
     def __eq__(self, other):
         return isinstance(other, Iter) and self.iter_id == other.iter_id
 
@@ -64,6 +74,9 @@ class For(ASTNode):
     def mark_as_kernel_candidate(self):
         self._kernel_candidate = True
 
+    def mark_iter_as_ref_candidate(self):
+        self.iterator.mark_as_ref_candidate()
+
     def is_kernel_candidate(self):
         return self._kernel_candidate
 
-- 
GitLab