diff --git a/src/pairs/analysis/modules.py b/src/pairs/analysis/modules.py index 6cf31a5d594bf35d2a8db006bb0007f4a6f70e47..e9192804898c2b6579732f297b8af0840b2b3410 100644 --- a/src/pairs/analysis/modules.py +++ b/src/pairs/analysis/modules.py @@ -7,20 +7,37 @@ class FetchModulesReferences(Visitor): self.module_stack = [] self.writing = False + def visit_ArrayAccess(self, ast_node): + # Visit array and save current writing state + self.visit(ast_node.array) + writing_state = self.writing + + # Index elements are read-only + self.writing = False + self.visit([roc for roc in ast_node.children() if roc != ast_node.array]) + self.writing = writing_state + def visit_Assign(self, ast_node): self.writing = True - for c in ast_node.destinations(): - self.visit(c) - + self.visit(ast_node.destinations()) self.writing = False - for c in ast_node.sources(): - self.visit(c) + self.visit(ast_node.sources()) def visit_Module(self, ast_node): self.module_stack.append(ast_node) self.visit_children(ast_node) self.module_stack.pop() + def visit_PropertyAccess(self, ast_node): + # Visit property and save current writing state + self.visit(ast_node.prop) + writing_state = self.writing + + # Index elements are read-only + self.writing = False + self.visit([roc for roc in ast_node.children() if roc != ast_node.prop]) + self.writing = writing_state + def visit_Array(self, ast_node): for m in self.module_stack: m.add_array(ast_node, self.writing) diff --git a/src/pairs/ir/visitor.py b/src/pairs/ir/visitor.py index e75c5bb9b55979d4f0826e12b5dfceedac2437f6..190e05eeab3160335a500e03b6f96149edcf488e 100644 --- a/src/pairs/ir/visitor.py +++ b/src/pairs/ir/visitor.py @@ -11,26 +11,29 @@ class Visitor: method = getattr(self, method_name, None) return method if callable(method) else None - def visit(self, ast_node=None): - if ast_node is None: - ast_node = self.ast - - method = self.get_method(f"visit_{type(ast_node).__name__}") - if method is not None: - method(ast_node) - else: - for b in type(ast_node).__bases__: - method = self.get_method(f"visit_{b.__name__}") - if method is not None: - method(ast_node) - break - - if method is None: - self.visit_children(ast_node) + def visit(self, ast_nodes=None): + if ast_nodes is None: + ast_nodes = [self.ast] + + if not isinstance(ast_nodes, list): + ast_nodes = [ast_nodes] + + for node in ast_nodes: + method = self.get_method(f"visit_{type(node).__name__}") + if method is not None: + method(node) + else: + for b in type(node).__bases__: + method = self.get_method(f"visit_{b.__name__}") + if method is not None: + method(node) + break + + if method is None: + self.visit(node.children()) def visit_children(self, ast_node): - for c in ast_node.children(): - self.visit(c) + self.visit(ast_node.children()) def yield_elements_breadth_first(self, ast_node=None): nodes_to_visit = deque() @@ -39,7 +42,6 @@ class Visitor: ast_node = self.ast nodes_to_visit.append(ast_node) - while nodes_to_visit: next_node = nodes_to_visit.popleft() # nodes_to_visit.pop() for depth-first traversal yield next_node diff --git a/src/pairs/sim/interaction.py b/src/pairs/sim/interaction.py index f597b942b1811f3f42ae6b31515f0aac8282a027..f84777b5d033755bc7e662b357728be8870a8f1b 100644 --- a/src/pairs/sim/interaction.py +++ b/src/pairs/sim/interaction.py @@ -1,5 +1,5 @@ from pairs.ir.bin_op import BinOp -from pairs.ir.block import Block, pairs_device_block +from pairs.ir.block import Block, pairs_inline from pairs.ir.branches import Branch, Filter from pairs.ir.loops import For, ParticleFor from pairs.ir.types import Types @@ -60,7 +60,7 @@ class ParticleInteraction(Lowerable): yield self.i, self.j self.sim.leave() - @pairs_device_block + @pairs_inline def lower(self): if self.nbody == 2: position = self.sim.position()