From c126045e24f9280eef5bc951507703a006120da2 Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti <rafaelravedutti@gmail.com>
Date: Fri, 30 Jun 2023 00:17:59 +0200
Subject: [PATCH] Fix accesses to constant memory on force parameters

Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com>
---
 examples/lj.py                           |  1 -
 src/pairs/code_gen/cgen.py               |  8 ++-
 src/pairs/ir/device.py                   | 12 +++++
 src/pairs/ir/mutator.py                  |  6 ++-
 src/pairs/transformations/__init__.py    |  7 ++-
 src/pairs/transformations/devices.py     | 65 +++++++++++++++++++++++-
 src/pairs/transformations/expressions.py |  2 +-
 7 files changed, 94 insertions(+), 7 deletions(-)

diff --git a/examples/lj.py b/examples/lj.py
index 5d4a5d5..c7392da 100644
--- a/examples/lj.py
+++ b/examples/lj.py
@@ -18,7 +18,6 @@ target = sys.argv[1] if len(sys.argv[1]) > 1 else "none"
 if target != 'cpu' and target != 'gpu':
     print(f"Invalid target, use {cmd} <cpu/gpu>")
 
-
 dt = 0.005
 cutoff_radius = 2.5
 skin = 0.3
diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py
index 4e7c431..7fd9480 100644
--- a/src/pairs/code_gen/cgen.py
+++ b/src/pairs/code_gen/cgen.py
@@ -6,7 +6,7 @@ from pairs.ir.branches import Branch
 from pairs.ir.cast import Cast
 from pairs.ir.contexts import Contexts
 from pairs.ir.bin_op import BinOp, Decl, VectorAccess
-from pairs.ir.device import ClearArrayFlag, ClearPropertyFlag, CopyArray, CopyProperty, CopyVar, SetArrayFlag, SetPropertyFlag, HostRef
+from pairs.ir.device import ClearArrayFlag, ClearPropertyFlag, CopyArray, CopyProperty, CopyVar, DeviceStaticRef, SetArrayFlag, SetPropertyFlag, HostRef
 from pairs.ir.features import FeatureProperty, FeaturePropertyAccess, RegisterFeatureProperty
 from pairs.ir.functions import Call
 from pairs.ir.kernel import Kernel, KernelLaunch
@@ -635,6 +635,10 @@ class CGen:
             var = self.generate_expression(ast_node.var)
             return f"(*{var})"
 
+        if isinstance(ast_node, DeviceStaticRef):
+            elem = self.generate_expression(ast_node.elem)
+            return f"d_{elem}"
+
         if isinstance(ast_node, FeatureProperty):
             return ast_node.name()
 
@@ -668,7 +672,7 @@ class CGen:
             return f"p{ast_node.id()}" + (f"_{index}" if ast_node.is_vector_kind() else "")
 
         if isinstance(ast_node, FeaturePropertyAccess):
-            feature_name = self.generate_expression(ast_node.feature_prop.name())
+            feature_name = self.generate_expression(ast_node.feature_prop)
             if mem or ast_node.inlined is True:
                 index = self.generate_expression(ast_node.index)
                 return f"{feature_name}[{index}]"
diff --git a/src/pairs/ir/device.py b/src/pairs/ir/device.py
index 6ba75d8..90a3ad6 100644
--- a/src/pairs/ir/device.py
+++ b/src/pairs/ir/device.py
@@ -17,6 +17,18 @@ class HostRef(ASTNode):
         return [self.elem]
 
 
+class DeviceStaticRef(ASTNode):  # wrapper node around `elem`; CGen emits it as the wrapped expression prefixed with "d_"
+    def __init__(self, sim, elem):
+        super().__init__(sim)
+        self.elem = elem  # the wrapped AST node
+
+    def type(self):
+        return self.elem.type()  # the wrapper reports the type of the wrapped node
+
+    def children(self):
+        return [self.elem]  # the wrapped node is the only child
+
+
 class CopyArray(ASTNode):
     def __init__(self, sim, array, ctx):
         super().__init__(sim)
diff --git a/src/pairs/ir/mutator.py b/src/pairs/ir/mutator.py
index 2e5acb9..d87b2f8 100644
--- a/src/pairs/ir/mutator.py
+++ b/src/pairs/ir/mutator.py
@@ -83,6 +83,10 @@ class Mutator:
         ast_node.elem = self.mutate(ast_node.elem)
         return ast_node
 
+    def mutate_DeviceStaticRef(self, ast_node):
+        ast_node.elem = self.mutate(ast_node.elem)  # mutate the wrapped node in place; the wrapper itself is kept
+        return ast_node
+
     def mutate_Filter(self, ast_node):
         return self.mutate_Branch(ast_node)
 
@@ -120,7 +124,7 @@ class Mutator:
         return ast_node
 
     def mutate_FeaturePropertyAccess(self, ast_node):
-        ast_node.prop = self.mutate(ast_node.feature_prop)
+        ast_node.feature_prop = self.mutate(ast_node.feature_prop)
         ast_node.index = self.mutate(ast_node.index)
         ast_node.expressions = {i: self.mutate(e) for i, e in ast_node.expressions.items()}
         return ast_node
diff --git a/src/pairs/transformations/__init__.py b/src/pairs/transformations/__init__.py
index e1cf7d3..538d416 100644
--- a/src/pairs/transformations/__init__.py
+++ b/src/pairs/transformations/__init__.py
@@ -1,6 +1,6 @@
 from pairs.analysis import Analysis
 from pairs.transformations.blocks import LiftExprOwnerBlocks, MergeAdjacentBlocks
-from pairs.transformations.devices import AddDeviceCopies, AddDeviceKernels, AddHostReferencesToModules
+from pairs.transformations.devices import AddDeviceCopies, AddDeviceKernels, AddHostReferencesToModules, AddDeviceReferencesToModules
 from pairs.transformations.expressions import ReplaceSymbols, SimplifyExpressions, PrioritizeScalarOps, AddExpressionDeclarations
 from pairs.transformations.loops import LICM
 from pairs.transformations.lower import Lower
@@ -74,6 +74,10 @@ class Transformations:
         if self._target.is_gpu():
             self.apply(AddHostReferencesToModules())
 
+    def add_device_references_to_modules(self):
+        if self._target.is_gpu():  # the pass is applied only when the target is a GPU
+            self.apply(AddDeviceReferencesToModules())
+
     def apply_all(self):
         self.lower()
         self.optimize_expressions()
@@ -86,3 +90,4 @@ class Transformations:
         self.lower(True)
         self.add_expression_declarations()
         self.add_host_references_to_modules()
+        self.add_device_references_to_modules()
diff --git a/src/pairs/transformations/devices.py b/src/pairs/transformations/devices.py
index fcb118b..836f98f 100644
--- a/src/pairs/transformations/devices.py
+++ b/src/pairs/transformations/devices.py
@@ -5,7 +5,7 @@ from pairs.ir.block import Block
 from pairs.ir.branches import Filter
 from pairs.ir.cast import Cast
 from pairs.ir.contexts import Contexts
-from pairs.ir.device import ClearArrayFlag, ClearPropertyFlag, CopyArray, CopyProperty, CopyVar, SetArrayFlag, SetPropertyFlag, HostRef
+from pairs.ir.device import ClearArrayFlag, ClearPropertyFlag, CopyArray, CopyProperty, CopyVar, DeviceStaticRef, SetArrayFlag, SetPropertyFlag, HostRef
 from pairs.ir.kernel import Kernel, KernelLaunch
 from pairs.ir.lit import Lit
 from pairs.ir.loops import For
@@ -161,3 +161,66 @@ class AddHostReferencesToModules(Mutator):
             return HostRef(ast_node.sim, ast_node)
 
         return ast_node
+
+
+class AddDeviceReferencesToModules(Mutator):  # wraps static arrays / feature properties accessed in kernels in DeviceStaticRef
+    def __init__(self, ast=None):
+        super().__init__(ast)
+        self.kernel_context = False  # True while the body of a Kernel is being visited
+        self.within_decl = False  # True while the expression of a Decl is being visited
+        self.add_reference = False  # when True, the ArrayStatic / FeatureProperty handlers wrap the node they visit
+        self.declared_objects = set()  # id()s of expressions already bound by a Decl; a set keeps the membership test O(1)
+
+    def must_add_reference(self, ast_node):  # inside a kernel, not recorded as declared, and either inlined or inside a Decl
+        return id(ast_node) not in self.declared_objects and self.kernel_context and \
+               (ast_node.inlined is True or self.within_decl)
+
+    def mutate_ArrayAccess(self, ast_node):
+        if isinstance(ast_node.array, (DeviceStaticRef, HostRef)):  # array is already wrapped in a reference node
+            return ast_node
+
+        _add_reference = self.add_reference  # saved so the flag only affects the child visited below
+        self.add_reference = ast_node.array.is_static() and self.must_add_reference(ast_node)  # only static arrays qualify
+        ast_node.array = self.mutate(ast_node.array)  # NOTE(review): ast_node.index is not visited by this override - confirm intended
+        self.add_reference = _add_reference
+        return ast_node
+
+    def mutate_ArrayStatic(self, ast_node):
+        if self.add_reference:  # flag is set by mutate_ArrayAccess for the array it is visiting
+            return DeviceStaticRef(ast_node.sim, ast_node)
+
+        return ast_node
+
+    def mutate_FeaturePropertyAccess(self, ast_node):
+        _add_reference = self.add_reference  # saved so the flag only affects the child visited below
+        self.add_reference = self.must_add_reference(ast_node)
+        ast_node.feature_prop = self.mutate(ast_node.feature_prop)  # NOTE(review): index / expressions are not visited here - confirm intended
+        self.add_reference = _add_reference
+        return ast_node
+
+    def mutate_FeatureProperty(self, ast_node):
+        if self.add_reference:  # flag is set by mutate_FeaturePropertyAccess for the property it is visiting
+            return DeviceStaticRef(ast_node.sim, ast_node)
+
+        return ast_node
+
+    def mutate_DeviceStaticRef(self, ast_node):
+        return ast_node  # already a device reference: returned unchanged, wrapped node is not visited
+
+    def mutate_Decl(self, ast_node):
+        _within_decl = self.within_decl  # saved and restored so the flag covers only this Decl's expression
+        self.within_decl = True
+        ast_node.elem = self.mutate(ast_node.elem)
+        self.declared_objects.add(id(ast_node.elem))  # recorded after the visit, so this Decl's own expression is not excluded
+        self.within_decl = _within_decl
+        return ast_node
+
+    def mutate_HostRef(self, ast_node):
+        return ast_node  # host references are returned unchanged, wrapped node is not visited
+
+    def mutate_Kernel(self, ast_node):
+        _kernel_context = self.kernel_context  # saved and restored around the kernel body
+        self.kernel_context = True
+        ast_node._block = self.mutate(ast_node._block)
+        self.kernel_context = _kernel_context
+        return ast_node
diff --git a/src/pairs/transformations/expressions.py b/src/pairs/transformations/expressions.py
index 7a366dd..15dc688 100644
--- a/src/pairs/transformations/expressions.py
+++ b/src/pairs/transformations/expressions.py
@@ -193,7 +193,7 @@ class AddExpressionDeclarations(Mutator):
 
     def mutate_FeaturePropertyAccess(self, ast_node):
         assert self.writing is False, "Cannot change feature property!"
-        ast_node.prop = self.mutate(ast_node.prop)
+        ast_node.feature_prop = self.mutate(ast_node.feature_prop)
         ast_node.index = self.mutate(ast_node.index)
         ast_node.expressions = {i: self.mutate(e) for i, e in ast_node.expressions.items()}
 
-- 
GitLab