Skip to content
Snippets Groups Projects
Commit a7e21796 authored by Rafael Ravedutti's avatar Rafael Ravedutti
Browse files

Use Visitor for discovering block variants

parent a2ec33c2
No related branches found
No related tags found
1 merge request!1Implement DEM and many other features
......@@ -21,7 +21,7 @@ class Analysis:
self.apply(DetermineExpressionsTerminals())
def discover_block_variants(self):
DiscoverBlockVariants(self._ast).mutate()
self.apply(DiscoverBlockVariants())
def determine_parent_blocks(self):
self.apply(DetermineParentBlocks())
......
from pairs.ir.loops import For, While
from pairs.ir.mutator import Mutator
from pairs.ir.visitor import Visitor
class DiscoverBlockVariants(Mutator):
class DiscoverBlockVariants(Visitor):
def __init__(self, ast=None):
super().__init__(ast)
self.in_assignment = None
......@@ -14,98 +13,73 @@ class DiscoverBlockVariants(Mutator):
for block in self.blocks:
block.add_variant(ast_node.name())
return ast_node
def mutate_Block(self, ast_node):
def visit_Block(self, ast_node):
self.blocks.append(ast_node)
ast_node.stmts = [self.mutate(s) for s in ast_node.stmts]
self.visit_children(ast_node)
self.clear_visited_nodes()
self.blocks.pop()
return ast_node
def mutate_Assign(self, ast_node):
def visit_Assign(self, ast_node):
self.in_assignment = ast_node if len(self.blocks) > 0 else None
ast_node._dest = self.mutate(ast_node._dest)
self.visit(ast_node._dest)
self.in_assignment = None
return ast_node
def mutate_AtomicAdd(self, ast_node):
def visit_AtomicAdd(self, ast_node):
self.in_assignment = ast_node
ast_node.elem = self.mutate(ast_node.elem)
self.visit(ast_node.elem)
self.in_assignment = None
ast_node.value = self.mutate(ast_node.value)
self.visit(ast_node.value)
if ast_node.check_for_resize():
ast_node.resize = self.mutate(ast_node.resize)
ast_node.capacity = self.mutate(ast_node.capacity)
self.visit(ast_node.resize)
self.visit(ast_node.capacity)
return ast_node
def mutate_For(self, ast_node):
def visit_For(self, ast_node):
self.push_variant(ast_node.iterator)
ast_node.block.add_variant(ast_node.iterator.name())
ast_node.iterator = self.mutate(ast_node.iterator)
ast_node.block = self.mutate(ast_node.block)
return ast_node
def mutate_ScalarOp(self, ast_node):
ast_node.lhs = self.mutate(ast_node.lhs)
if not ast_node.operator().is_unary():
ast_node.rhs = self.mutate(ast_node.rhs)
return ast_node
def mutate_VectorOp(self, ast_node):
ast_node.lhs = self.mutate(ast_node.lhs)
if not ast_node.operator().is_unary():
ast_node.rhs = self.mutate(ast_node.rhs)
return ast_node
self.visit_children(ast_node)
def mutate_ArrayAccess(self, ast_node):
def visit_ArrayAccess(self, ast_node):
# For array accesses, we only want to include the array name, and not
# the index that is also present in the access node
ast_node.array = self.mutate(ast_node.array)
return ast_node
self.visit(ast_node.array)
def mutate_Array(self, ast_node):
return self.push_variant(ast_node)
def visit_Array(self, ast_node):
self.push_variant(ast_node)
# TODO: Array should be enough
def mutate_ArrayND(self, ast_node):
return self.push_variant(ast_node)
def visit_ArrayND(self, ast_node):
self.push_variant(ast_node)
def mutate_Iter(self, ast_node):
return self.push_variant(ast_node)
def visit_Iter(self, ast_node):
self.push_variant(ast_node)
def mutate_Property(self, ast_node):
return self.push_variant(ast_node)
def visit_Property(self, ast_node):
self.push_variant(ast_node)
def mutate_ContactProperty(self, ast_node):
return self.push_variant(ast_node)
def visit_ContactProperty(self, ast_node):
self.push_variant(ast_node)
def mutate_FeatureProperty(self, ast_node):
return self.push_variant(ast_node)
def visit_FeatureProperty(self, ast_node):
self.push_variant(ast_node)
def mutate_PropertyAccess(self, ast_node):
def visit_PropertyAccess(self, ast_node):
# For property accesses, we only want to include the property name, and not
# the index that is also present in the access node
ast_node.prop = self.mutate(ast_node.prop)
return ast_node
self.visit(ast_node.prop)
def mutate_ContactPropertyAccess(self, ast_node):
def visit_ContactPropertyAccess(self, ast_node):
# For property accesses, we only want to include the property name, and not
# the index that is also present in the access node
ast_node.contact_prop = self.mutate(ast_node.contact_prop)
return ast_node
self.visit(ast_node.contact_prop)
def mutate_FeaturePropertyAccess(self, ast_node):
def visit_FeaturePropertyAccess(self, ast_node):
# For property accesses, we only want to include the property name, and not
# the index that is also present in the access node
ast_node.feature_prop = self.mutate(ast_node.feature_prop)
return ast_node
self.visit(ast_node.feature_prop)
def mutate_Var(self, ast_node):
return self.push_variant(ast_node)
def visit_Var(self, ast_node):
self.push_variant(ast_node)
class DetermineParentBlocks(Visitor):
......
......@@ -69,7 +69,6 @@ class Simulation:
self.vtk_file = None
self._target = None
self._dom_part = DimensionRanges(self)
self.nparticles = self.nlocal + self.nghost
def add_module(self, module):
assert isinstance(module, Module), "add_module(): Given parameter is not of type Module!"
......
......@@ -46,11 +46,11 @@ class LICM(Mutator):
if self.loops and isinstance(ast_node.elem, elems_to_check):
last_loop = self.loops[-1]
loop_lifts = self.lifts[id(last_loop)]
#print(f"variants = {last_loop.block.variants}, terminals = {ast_node.elem.terminals}")
#print(f"id = {ast_node.elem.id()}, variants = {last_loop.block.variants}, terminals = {ast_node.elem.terminals}")
if not last_loop.block.variants.intersection(ast_node.elem.terminals):
found = False
for d in loop_lifts:
if ast_node.elem == d.elem:
for lifted_decls in loop_lifts:
if ast_node.elem == lifted_decls.elem:
found = True
if not found:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment