Skip to content
Snippets Groups Projects
Commit 4be6e795 authored by Rafael Ravedutti's avatar Rafael Ravedutti
Browse files

Fix prioritaze_scalar_ops transformation

parent 59f749c9
No related branches found
No related tags found
No related merge requests found
...@@ -68,7 +68,7 @@ class CGen: ...@@ -68,7 +68,7 @@ class CGen:
if isinstance(ast_node, BinOpDef): if isinstance(ast_node, BinOpDef):
bin_op = ast_node.bin_op bin_op = ast_node.bin_op
if not isinstance(bin_op, BinOp): if not isinstance(bin_op, BinOp) or not ast_node.used:
return None return None
if bin_op.inlined is False and bin_op.operator() != '[]' and bin_op.generated is False: if bin_op.inlined is False and bin_op.operator() != '[]' and bin_op.generated is False:
......
...@@ -173,6 +173,9 @@ class ArrayAccess(ASTTerm): ...@@ -173,6 +173,9 @@ class ArrayAccess(ASTTerm):
return self.index.scope() return self.index.scope()
def children(self): def children(self):
if self.index is not None:
return [self.array, self.index]
return [self.array] + self.indexes return [self.array] + self.indexes
......
...@@ -10,7 +10,7 @@ class BinOpDef(ASTNode): ...@@ -10,7 +10,7 @@ class BinOpDef(ASTNode):
super().__init__(bin_op.sim) super().__init__(bin_op.sim)
self.bin_op = bin_op self.bin_op = bin_op
self.bin_op.sim.add_statement(self) self.bin_op.sim.add_statement(self)
self.used = False self.used = not bin_op.sim.check_bin_ops_usage
def __str__(self): def __str__(self):
return f"BinOpDef<bin_op: self.bin_op>" return f"BinOpDef<bin_op: self.bin_op>"
...@@ -85,6 +85,9 @@ class BinOp(ASTNode): ...@@ -85,6 +85,9 @@ class BinOp(ASTNode):
mapping = self.vector_index_mapping mapping = self.vector_index_mapping
return mapping[index] if index in mapping else as_lit_ast(self.sim, index) return mapping[index] if index in mapping else as_lit_ast(self.sim, index)
def mapped_expressions(self):
return self.vector_index_mapping.values()
@property @property
def vector_indexes(self): def vector_indexes(self):
return self._vector_indexes return self._vector_indexes
......
...@@ -67,7 +67,7 @@ class For(ASTNode): ...@@ -67,7 +67,7 @@ class For(ASTNode):
self.block.add_statement(stmt) self.block.add_statement(stmt)
def children(self): def children(self):
return [self.iterator, self.block] return [self.iterator, self.block, self.min, self.max]
class ParticleFor(For): class ParticleFor(For):
......
...@@ -42,6 +42,7 @@ class ParticleSimulation: ...@@ -42,6 +42,7 @@ class ParticleSimulation:
self.scope = [] self.scope = []
self.nested_count = 0 self.nested_count = 0
self.nest = False self.nest = False
self.check_bin_ops_usage = True
self.block = Block(self, []) self.block = Block(self, [])
self.setups = SetupWrapper() self.setups = SetupWrapper()
self.kernels = KernelWrapper() self.kernels = KernelWrapper()
...@@ -199,11 +200,14 @@ class ParticleSimulation: ...@@ -199,11 +200,14 @@ class ParticleSimulation:
self.global_scope = program self.global_scope = program
# Transformations # Transformations
#prioritaze_scalar_ops(program) prioritaze_scalar_ops(program)
flatten_property_accesses(program) flatten_property_accesses(program)
simplify_expressions(program) simplify_expressions(program)
move_loop_invariant_code(program) move_loop_invariant_code(program)
#set_used_bin_ops(program) set_used_bin_ops(program)
# For this part on, all bin ops are generated without usage verification
self.check_bin_ops_usage = False
ASTGraph(self.kernels.lower(), "kernels").render() ASTGraph(self.kernels.lower(), "kernels").render()
self.code_gen.generate_program(program) self.code_gen.generate_program(program)
...@@ -11,3 +11,6 @@ class VTKWrite(ASTNode): ...@@ -11,3 +11,6 @@ class VTKWrite(ASTNode):
self.filename = filename self.filename = filename
self.timestep = as_lit_ast(sim, timestep) self.timestep = as_lit_ast(sim, timestep)
VTKWrite.vtk_id += 1 VTKWrite.vtk_id += 1
def children(self):
return [self.timestep]
...@@ -7,8 +7,15 @@ class SetUsedBinOps(Visitor): ...@@ -7,8 +7,15 @@ class SetUsedBinOps(Visitor):
super().__init__(ast) super().__init__(ast)
self.bin_ops = [] self.bin_ops = []
def visit_BinOpDef(self, ast_node):
pass
def visit_BinOp(self, ast_node): def visit_BinOp(self, ast_node):
ast_node.bin_op_def.used = True ast_node.bin_op_def.used = True
self.visit_children(ast_node)
# TODO: These expressions could be automatically included in visitor traversal
for vidxs in ast_node.mapped_expressions():
self.visit(vidxs)
def set_used_bin_ops(ast): def set_used_bin_ops(ast):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment