From c1d8f645991e9a8dfa48cd12225d4f11f50f6a1e Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti <rafaelravedutti@gmail.com>
Date: Wed, 3 May 2023 14:37:05 +0200
Subject: [PATCH] Include feature properties into module and kernel parameters

Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com>
---
 runtime/feature_property.hpp         |  4 ++--
 src/pairs/analysis/devices.py        |  4 ++++
 src/pairs/analysis/modules.py        |  6 ++++++
 src/pairs/code_gen/cgen.py           | 28 ++++++++++++++++++++++++++--
 src/pairs/ir/kernel.py               | 11 +++++++++++
 src/pairs/ir/module.py               | 11 +++++++++++
 src/pairs/transformations/devices.py |  7 +++++++
 7 files changed, 67 insertions(+), 4 deletions(-)

diff --git a/runtime/feature_property.hpp b/runtime/feature_property.hpp
index 7ee5ce6..025a514 100644
--- a/runtime/feature_property.hpp
+++ b/runtime/feature_property.hpp
@@ -27,8 +27,8 @@ public:
     void *getHostPointer() { return h_ptr; }
     void *getDevicePointer() { return d_ptr; }
     PropertyType getType() { return type; }
-    const int getNumberOfKinds() { return nkinds; }
-    const int getArraySize() { return array_size; }
+    size_t getNumberOfKinds() { return nkinds; }
+    size_t getArraySize() { return array_size; }
 };
 
 }
diff --git a/src/pairs/analysis/devices.py b/src/pairs/analysis/devices.py
index c75b3a6..ab47b00 100644
--- a/src/pairs/analysis/devices.py
+++ b/src/pairs/analysis/devices.py
@@ -83,6 +83,10 @@ class FetchKernelReferences(Visitor):
         for k in self.kernel_stack:
             k.add_property(ast_node, self.writing)
 
+    def visit_FeatureProperty(self, ast_node):
+        for k in self.kernel_stack:
+            k.add_feature_property(ast_node)
+
     def visit_Var(self, ast_node):
         for k in self.kernel_stack:
             k.add_variable(ast_node, self.writing)
diff --git a/src/pairs/analysis/modules.py b/src/pairs/analysis/modules.py
index c75ded4..23303a4 100644
--- a/src/pairs/analysis/modules.py
+++ b/src/pairs/analysis/modules.py
@@ -64,6 +64,12 @@ class FetchModulesReferences(Visitor):
             if m.run_on_device:
                 ast_node.device_flag = True
 
+    def visit_FeatureProperty(self, ast_node):
+        for m in self.module_stack:
+            m.add_feature_property(ast_node)
+            if m.run_on_device:
+                ast_node.device_flag = True
+
     def visit_Var(self, ast_node):
         for m in self.module_stack:
             if not ast_node.temporary():
diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py
index 3f2b947..ffbe414 100644
--- a/src/pairs/code_gen/cgen.py
+++ b/src/pairs/code_gen/cgen.py
@@ -77,7 +77,7 @@ class CGen:
                     size = self.generate_expression(BinOp.inline(array.alloc_size()))
                     self.print(f"__constant__ {tkw} d_{array.name()}[{size}];")
 
-            for feature_prop in self.sim.feature_properties():
+            for feature_prop in self.sim.feature_properties:
                 if feature_prop.device_flag:
                     t = feature_prop.type()
                     tkw = Types.c_keyword(t)
@@ -136,6 +136,15 @@ class CGen:
                     decl = f"{type_kw} *h_{prop.name()}"
                     module_params += f", {decl}"
 
+            for feature_prop in module.feature_properties():
+                type_kw = Types.c_keyword(feature_prop.type())
+                decl = f"{type_kw} *{feature_prop.name()}"
+                module_params += f", {decl}"
+
+                if feature_prop in module.host_references():
+                    decl = f"{type_kw} *h_{feature_prop.name()}"
+                    module_params += f", {decl}"
+
             self.print(f"void {module.name}({module_params}) {{")
 
             if self.debug:
@@ -173,6 +182,11 @@ class CGen:
             decl = f"{type_kw} *{prop.name()}"
             kernel_params += f", {decl}"
 
+        for feature_prop in kernel.feature_properties():
+            type_kw = Types.c_keyword(feature_prop.type())
+            decl = f"{type_kw} *{feature_prop.name()}"
+            kernel_params += f", {decl}"
+
         for array_access in kernel.array_accesses():
             type_kw = Types.c_keyword(array_access.type())
             decl = f"{type_kw} a{array_access.id()}"
@@ -400,6 +414,9 @@ class CGen:
             for prop in kernel.properties():
                 kernel_params += f", {prop.name()}"
 
+            for prop in kernel.feature_properties():
+                kernel_params += f", {prop.name()}"
+
             for array_access in kernel.array_accesses():
                 kernel_params += f", {self.generate_expression(array_access)}"
 
@@ -442,6 +459,13 @@ class CGen:
                     decl = prop.name()
                     module_params += f", {decl}"
 
+            for feature_prop in module.feature_properties():
+                decl = f"d_{feature_prop.name()}" if device_cond else feature_prop.name()
+                module_params += f", {decl}"
+                if feature_prop in module.host_references():
+                    decl = feature_prop.name()
+                    module_params += f", {decl}"
+
             self.print(f"{module.name}({module_params});")
 
         if isinstance(ast_node, Print):
@@ -503,7 +527,7 @@ class CGen:
         if isinstance(ast_node, RegisterFeatureProperty):
             fp = ast_node.feature_property()
             ptr = fp.name()
-            d_ptr = f"d_{ptr}" if self.target.is_gpu() and p.device_flag else "nullptr"
+            d_ptr = f"d_{ptr}" if self.target.is_gpu() and fp.device_flag else "nullptr"
             array_size = fp.array_size()
             nkinds = fp.feature().nkinds()
             tkw = Types.c_keyword(fp.type())
diff --git a/src/pairs/ir/kernel.py b/src/pairs/ir/kernel.py
index 63bd8d2..acfd096 100644
--- a/src/pairs/ir/kernel.py
+++ b/src/pairs/ir/kernel.py
@@ -1,6 +1,7 @@
 from pairs.ir.arrays import Array, ArrayAccess
 from pairs.ir.ast_node import ASTNode
 from pairs.ir.bin_op import BinOp
+from pairs.ir.features import FeatureProperty
 from pairs.ir.lit import Lit
 from pairs.ir.properties import Property
 from pairs.ir.variables import Var
@@ -16,6 +17,7 @@ class Kernel(ASTNode):
         self._variables = {}
         self._arrays = {}
         self._properties = {}
+        self._feature_properties = {}
         self._array_accesses = set()
         self._bin_ops = []
         self._block = block
@@ -54,6 +56,9 @@ class Kernel(ASTNode):
     def properties(self):
         return self._properties
 
+    def feature_properties(self):
+        return self._feature_properties
+
     def properties_to_synchronize(self):
         return {p for p in self._properties if self._properties[p][0] == 'r'}
 
@@ -87,6 +92,12 @@ class Kernel(ASTNode):
             assert isinstance(p, Property), "Kernel.add_property(): given element is not of type Property!"
             self._properties[p] = character if p not in self._properties else self._properties[p] + character
 
+    def add_feature_property(self, feature_prop):
+        feature_prop_list = feature_prop if isinstance(feature_prop, list) else [feature_prop]
+        for fp in feature_prop_list:
+            assert isinstance(fp, FeatureProperty), "Kernel.add_feature_property(): given element is not of type FeatureProperty!"
+            self._feature_properties[fp] = 'r'
+
     def add_array_access(self, array_access):
         array_access_list = array_access if isinstance(array_access, list) else [array_access]
         for a in array_access_list:
diff --git a/src/pairs/ir/module.py b/src/pairs/ir/module.py
index 51f026f..6111d03 100644
--- a/src/pairs/ir/module.py
+++ b/src/pairs/ir/module.py
@@ -1,5 +1,6 @@
 from pairs.ir.arrays import Array
 from pairs.ir.ast_node import ASTNode
+from pairs.ir.features import FeatureProperty
 from pairs.ir.properties import Property
 from pairs.ir.variables import Var
 
@@ -15,6 +16,7 @@ class Module(ASTNode):
         self._variables = {}
         self._arrays = {}
         self._properties = {}
+        self._feature_properties = {}
         self._host_references = set()
         self._block = block
         self._resizes_to_check = resizes_to_check
@@ -61,6 +63,9 @@ class Module(ASTNode):
     def properties(self):
         return self._properties
 
+    def feature_properties(self):
+        return self._feature_properties
+
     def host_references(self):
         return self._host_references
 
@@ -99,6 +104,12 @@ class Module(ASTNode):
             assert isinstance(p, Property), "Module.add_property(): given element is not of type Property!"
             self._properties[p] = character if p not in self._properties else self._properties[p] + character
 
+    def add_feature_property(self, feature_prop):
+        feature_prop_list = feature_prop if isinstance(feature_prop, list) else [feature_prop]
+        for fp in feature_prop_list:
+            assert isinstance(fp, FeatureProperty), "Module.add_feature_property(): given element is not of type FeatureProperty!"
+            self._feature_properties[fp] = 'r'
+
     def add_host_reference(self, elem):
         self._host_references.add(elem)
 
diff --git a/src/pairs/transformations/devices.py b/src/pairs/transformations/devices.py
index 8094b69..fcb118b 100644
--- a/src/pairs/transformations/devices.py
+++ b/src/pairs/transformations/devices.py
@@ -128,6 +128,13 @@ class AddHostReferencesToModules(Mutator):
     def mutate_Decl(self, ast_node):
         return ast_node
 
+    def mutate_FeatureProperty(self, ast_node):
+        if self.device_context:
+            self.module_stack[-1].add_host_reference(ast_node)
+            return HostRef(ast_node.sim, ast_node)
+
+        return ast_node
+
     def mutate_HostRef(self, ast_node):
         return ast_node
 
-- 
GitLab