Skip to content
Snippets Groups Projects
Commit a3e7154a authored by Rafael Ravedutti's avatar Rafael Ravedutti
Browse files

Inline particle interactions and fix modules references


Signed-off-by: default avatarRafael Ravedutti <rafaelravedutti@gmail.com>
parent 216461eb
No related merge requests found
...@@ -7,20 +7,37 @@ class FetchModulesReferences(Visitor): ...@@ -7,20 +7,37 @@ class FetchModulesReferences(Visitor):
self.module_stack = [] self.module_stack = []
self.writing = False self.writing = False
def visit_ArrayAccess(self, ast_node):
# Visit array and save current writing state
self.visit(ast_node.array)
writing_state = self.writing
# Index elements are read-only
self.writing = False
self.visit([roc for roc in ast_node.children() if roc != ast_node.array])
self.writing = writing_state
def visit_Assign(self, ast_node): def visit_Assign(self, ast_node):
self.writing = True self.writing = True
for c in ast_node.destinations(): self.visit(ast_node.destinations())
self.visit(c)
self.writing = False self.writing = False
for c in ast_node.sources(): self.visit(ast_node.sources())
self.visit(c)
def visit_Module(self, ast_node): def visit_Module(self, ast_node):
self.module_stack.append(ast_node) self.module_stack.append(ast_node)
self.visit_children(ast_node) self.visit_children(ast_node)
self.module_stack.pop() self.module_stack.pop()
def visit_PropertyAccess(self, ast_node):
# Visit property and save current writing state
self.visit(ast_node.prop)
writing_state = self.writing
# Index elements are read-only
self.writing = False
self.visit([roc for roc in ast_node.children() if roc != ast_node.prop])
self.writing = writing_state
def visit_Array(self, ast_node): def visit_Array(self, ast_node):
for m in self.module_stack: for m in self.module_stack:
m.add_array(ast_node, self.writing) m.add_array(ast_node, self.writing)
......
...@@ -11,26 +11,29 @@ class Visitor: ...@@ -11,26 +11,29 @@ class Visitor:
method = getattr(self, method_name, None) method = getattr(self, method_name, None)
return method if callable(method) else None return method if callable(method) else None
def visit(self, ast_node=None): def visit(self, ast_nodes=None):
if ast_node is None: if ast_nodes is None:
ast_node = self.ast ast_nodes = [self.ast]
method = self.get_method(f"visit_{type(ast_node).__name__}") if not isinstance(ast_nodes, list):
if method is not None: ast_nodes = [ast_nodes]
method(ast_node)
else: for node in ast_nodes:
for b in type(ast_node).__bases__: method = self.get_method(f"visit_{type(node).__name__}")
method = self.get_method(f"visit_{b.__name__}") if method is not None:
if method is not None: method(node)
method(ast_node) else:
break for b in type(node).__bases__:
method = self.get_method(f"visit_{b.__name__}")
if method is None: if method is not None:
self.visit_children(ast_node) method(node)
break
if method is None:
self.visit(node.children())
def visit_children(self, ast_node): def visit_children(self, ast_node):
for c in ast_node.children(): self.visit(ast_node.children())
self.visit(c)
def yield_elements_breadth_first(self, ast_node=None): def yield_elements_breadth_first(self, ast_node=None):
nodes_to_visit = deque() nodes_to_visit = deque()
...@@ -39,7 +42,6 @@ class Visitor: ...@@ -39,7 +42,6 @@ class Visitor:
ast_node = self.ast ast_node = self.ast
nodes_to_visit.append(ast_node) nodes_to_visit.append(ast_node)
while nodes_to_visit: while nodes_to_visit:
next_node = nodes_to_visit.popleft() # nodes_to_visit.pop() for depth-first traversal next_node = nodes_to_visit.popleft() # nodes_to_visit.pop() for depth-first traversal
yield next_node yield next_node
......
from pairs.ir.bin_op import BinOp from pairs.ir.bin_op import BinOp
from pairs.ir.block import Block, pairs_device_block from pairs.ir.block import Block, pairs_inline
from pairs.ir.branches import Branch, Filter from pairs.ir.branches import Branch, Filter
from pairs.ir.loops import For, ParticleFor from pairs.ir.loops import For, ParticleFor
from pairs.ir.types import Types from pairs.ir.types import Types
...@@ -60,7 +60,7 @@ class ParticleInteraction(Lowerable): ...@@ -60,7 +60,7 @@ class ParticleInteraction(Lowerable):
yield self.i, self.j yield self.i, self.j
self.sim.leave() self.sim.leave()
@pairs_device_block @pairs_inline
def lower(self): def lower(self):
if self.nbody == 2: if self.nbody == 2:
position = self.sim.position() position = self.sim.position()
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment