Skip to content
Snippets Groups Projects
Commit a3e7154a authored by Rafael Ravedutti's avatar Rafael Ravedutti
Browse files

Inline particle interactions and fix modules references

parent 216461eb
Branches
Tags
No related merge requests found
......@@ -7,20 +7,37 @@ class FetchModulesReferences(Visitor):
self.module_stack = []
self.writing = False
def visit_ArrayAccess(self, ast_node):
# Visit array and save current writing state
self.visit(ast_node.array)
writing_state = self.writing
# Index elements are read-only
self.writing = False
self.visit([roc for roc in ast_node.children() if roc != ast_node.array])
self.writing = writing_state
def visit_Assign(self, ast_node):
self.writing = True
for c in ast_node.destinations():
self.visit(c)
self.visit(ast_node.destinations())
self.writing = False
for c in ast_node.sources():
self.visit(c)
self.visit(ast_node.sources())
def visit_Module(self, ast_node):
self.module_stack.append(ast_node)
self.visit_children(ast_node)
self.module_stack.pop()
def visit_PropertyAccess(self, ast_node):
# Visit property and save current writing state
self.visit(ast_node.prop)
writing_state = self.writing
# Index elements are read-only
self.writing = False
self.visit([roc for roc in ast_node.children() if roc != ast_node.prop])
self.writing = writing_state
def visit_Array(self, ast_node):
for m in self.module_stack:
m.add_array(ast_node, self.writing)
......
......@@ -11,26 +11,29 @@ class Visitor:
method = getattr(self, method_name, None)
return method if callable(method) else None
def visit(self, ast_node=None):
if ast_node is None:
ast_node = self.ast
method = self.get_method(f"visit_{type(ast_node).__name__}")
if method is not None:
method(ast_node)
else:
for b in type(ast_node).__bases__:
method = self.get_method(f"visit_{b.__name__}")
if method is not None:
method(ast_node)
break
if method is None:
self.visit_children(ast_node)
def visit(self, ast_nodes=None):
if ast_nodes is None:
ast_nodes = [self.ast]
if not isinstance(ast_nodes, list):
ast_nodes = [ast_nodes]
for node in ast_nodes:
method = self.get_method(f"visit_{type(node).__name__}")
if method is not None:
method(node)
else:
for b in type(node).__bases__:
method = self.get_method(f"visit_{b.__name__}")
if method is not None:
method(node)
break
if method is None:
self.visit(node.children())
def visit_children(self, ast_node):
for c in ast_node.children():
self.visit(c)
self.visit(ast_node.children())
def yield_elements_breadth_first(self, ast_node=None):
nodes_to_visit = deque()
......@@ -39,7 +42,6 @@ class Visitor:
ast_node = self.ast
nodes_to_visit.append(ast_node)
while nodes_to_visit:
next_node = nodes_to_visit.popleft() # nodes_to_visit.pop() for depth-first traversal
yield next_node
......
from pairs.ir.bin_op import BinOp
from pairs.ir.block import Block, pairs_device_block
from pairs.ir.block import Block, pairs_inline
from pairs.ir.branches import Branch, Filter
from pairs.ir.loops import For, ParticleFor
from pairs.ir.types import Types
......@@ -60,7 +60,7 @@ class ParticleInteraction(Lowerable):
yield self.i, self.j
self.sim.leave()
@pairs_device_block
@pairs_inline
def lower(self):
if self.nbody == 2:
position = self.sim.position()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment