From a3e7154ad15bcb3445c3ef00865c77410fab14af Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti <rafaelravedutti@gmail.com>
Date: Fri, 18 Feb 2022 16:12:25 +0100
Subject: [PATCH] Inline particle interactions and fix modules references

Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com>
---
 src/pairs/analysis/modules.py | 27 ++++++++++++++++++-----
 src/pairs/ir/visitor.py       | 40 ++++++++++++++++++-----------------
 src/pairs/sim/interaction.py  |  4 ++--
 3 files changed, 45 insertions(+), 26 deletions(-)

diff --git a/src/pairs/analysis/modules.py b/src/pairs/analysis/modules.py
index 6cf31a5..e919280 100644
--- a/src/pairs/analysis/modules.py
+++ b/src/pairs/analysis/modules.py
@@ -7,20 +7,37 @@ class FetchModulesReferences(Visitor):
         self.module_stack = []
         self.writing = False
 
+    def visit_ArrayAccess(self, ast_node):
+        # Visit array and save current writing state
+        self.visit(ast_node.array)
+        writing_state = self.writing
+
+        # Index elements are read-only
+        self.writing = False
+        self.visit([roc for roc in ast_node.children() if roc != ast_node.array])
+        self.writing = writing_state
+
     def visit_Assign(self, ast_node):
         self.writing = True
-        for c in ast_node.destinations():
-            self.visit(c)
-
+        self.visit(ast_node.destinations())
         self.writing = False
-        for c in ast_node.sources():
-            self.visit(c)
+        self.visit(ast_node.sources())
 
     def visit_Module(self, ast_node):
         self.module_stack.append(ast_node)
         self.visit_children(ast_node)
         self.module_stack.pop()
 
+    def visit_PropertyAccess(self, ast_node):
+        # Visit property and save current writing state
+        self.visit(ast_node.prop)
+        writing_state = self.writing
+
+        # Index elements are read-only
+        self.writing = False
+        self.visit([roc for roc in ast_node.children() if roc != ast_node.prop])
+        self.writing = writing_state
+
     def visit_Array(self, ast_node):
         for m in self.module_stack:
             m.add_array(ast_node, self.writing)
diff --git a/src/pairs/ir/visitor.py b/src/pairs/ir/visitor.py
index e75c5bb..190e05e 100644
--- a/src/pairs/ir/visitor.py
+++ b/src/pairs/ir/visitor.py
@@ -11,26 +11,29 @@ class Visitor:
         method = getattr(self, method_name, None)
         return method if callable(method) else None
 
-    def visit(self, ast_node=None):
-        if ast_node is None:
-            ast_node = self.ast
-
-        method = self.get_method(f"visit_{type(ast_node).__name__}")
-        if method is not None:
-            method(ast_node)
-        else:
-            for b in type(ast_node).__bases__:
-                method = self.get_method(f"visit_{b.__name__}")
-                if method is not None:
-                    method(ast_node)
-                    break
-
-            if method is None:
-                self.visit_children(ast_node)
+    def visit(self, ast_nodes=None):
+        if ast_nodes is None:
+            ast_nodes = [self.ast]
+
+        if not isinstance(ast_nodes, list):
+            ast_nodes = [ast_nodes]
+
+        for node in ast_nodes:
+            method = self.get_method(f"visit_{type(node).__name__}")
+            if method is not None:
+                method(node)
+            else:
+                for b in type(node).__bases__:
+                    method = self.get_method(f"visit_{b.__name__}")
+                    if method is not None:
+                        method(node)
+                        break
+
+                if method is None:
+                    self.visit(node.children())
 
     def visit_children(self, ast_node):
-        for c in ast_node.children():
-            self.visit(c)
+        self.visit(ast_node.children())
 
     def yield_elements_breadth_first(self, ast_node=None):
         nodes_to_visit = deque()
@@ -39,7 +42,6 @@ class Visitor:
             ast_node = self.ast
 
         nodes_to_visit.append(ast_node)
-
         while nodes_to_visit:
             next_node = nodes_to_visit.popleft() # nodes_to_visit.pop() for depth-first traversal
             yield next_node
diff --git a/src/pairs/sim/interaction.py b/src/pairs/sim/interaction.py
index f597b94..f84777b 100644
--- a/src/pairs/sim/interaction.py
+++ b/src/pairs/sim/interaction.py
@@ -1,5 +1,5 @@
 from pairs.ir.bin_op import BinOp
-from pairs.ir.block import Block, pairs_device_block
+from pairs.ir.block import Block, pairs_inline
 from pairs.ir.branches import Branch, Filter
 from pairs.ir.loops import For, ParticleFor
 from pairs.ir.types import Types
@@ -60,7 +60,7 @@ class ParticleInteraction(Lowerable):
         yield self.i, self.j
         self.sim.leave()
 
-    @pairs_device_block
+    @pairs_inline
     def lower(self):
         if self.nbody == 2:
             position = self.sim.position()
-- 
GitLab