From 1593bd166359a78269bf1decf173e31c167bcd7d Mon Sep 17 00:00:00 2001 From: Rafael Ravedutti Lucio Machado <rafael.r.ravedutti@fau.de> Date: Tue, 12 Jan 2021 03:39:24 +0100 Subject: [PATCH] Provide better way to write analyses and transformations Signed-off-by: Rafael Ravedutti Lucio Machado <rafael.r.ravedutti@fau.de> --- ast/arrays.py | 9 --- ast/assign.py | 8 --- ast/ast_node.py | 3 - ast/bin_op.py | 10 --- ast/block.py | 7 -- ast/branches.py | 8 --- ast/cast.py | 4 -- ast/loops.py | 10 --- ast/math.py | 4 -- ast/memory.py | 10 --- ast/mutator.py | 139 ++++++++++++++++++++++++++++++++++++ ast/select.py | 7 -- ast/visitor.py | 43 ++++++----- sim/particle_simulation.py | 15 ++-- sim/timestep.py | 3 - sim/vtk.py | 12 ++-- transformations/flatten.py | 35 +++++++++ transformations/simplify.py | 36 ++++++++++ 18 files changed, 247 insertions(+), 116 deletions(-) create mode 100644 ast/mutator.py create mode 100644 transformations/flatten.py create mode 100644 transformations/simplify.py diff --git a/ast/arrays.py b/ast/arrays.py index 308bd12..1ea9014 100644 --- a/ast/arrays.py +++ b/ast/arrays.py @@ -175,15 +175,6 @@ class ArrayAccess(ASTTerm): def children(self): return [self.array] + self.indexes - def transform(self, fn): - self.array = self.array.transform(fn) - self.indexes = [i.transform(fn) for i in self.indexes] - - if self.index is not None: - self.index = self.index.transform(fn) - - return fn(self) - class ArrayDecl(ASTNode): def __init__(self, sim, array): diff --git a/ast/assign.py b/ast/assign.py index 8ce1d7b..a6cc919 100644 --- a/ast/assign.py +++ b/ast/assign.py @@ -31,11 +31,3 @@ class Assign(ASTNode): return reduce((lambda x, y: x + y), [ [self.assignments[i][0], self.assignments[i][1]] for i in range(0, len(self.assignments))]) - - def transform(self, fn): - self.assignments = [( - self.assignments[i][0].transform(fn), - self.assignments[i][1].transform(fn)) - for i in range(0, len(self.assignments))] - - return fn(self) diff --git a/ast/ast_node.py b/ast/ast_node.py index 7e42676..e58afc2 100644 --- a/ast/ast_node.py +++ b/ast/ast_node.py @@ -16,6 +16,3 @@ class ASTNode: def children(self): return [] - - def transform(self, fn): - return fn(self) diff --git a/ast/bin_op.py b/ast/bin_op.py index 840ca43..ba82620 100644 --- a/ast/bin_op.py +++ b/ast/bin_op.py @@ -17,10 +17,6 @@ class BinOpDef(ASTNode): def children(self): return [self.bin_op] - def transform(self, fn): - self.bin_op = self.bin_op.transform(fn) - return fn(self) - class BinOp(ASTNode): # BinOp kinds @@ -177,12 +173,6 @@ class BinOp(ASTNode): def children(self): return [self.lhs, self.rhs] - def transform(self, fn): - self.lhs = self.lhs.transform(fn) - self.rhs = self.rhs.transform(fn) - self.bin_op_vector_index_mapping = {i: e.transform(fn) for i, e in self.bin_op_vector_index_mapping.items()} - return fn(self) - def __add__(self, other): return BinOp(self.sim, self, other, '+') diff --git a/ast/block.py b/ast/block.py index b7f287e..e85face 100644 --- a/ast/block.py +++ b/ast/block.py @@ -1,5 +1,4 @@ from ast.ast_node import ASTNode -from ast.visitor import Visitor class Block(ASTNode): @@ -51,12 +50,6 @@ class Block(ASTNode): def children(self): return self.stmts - def transform(self, fn): - for i in range(0, len(self.stmts)): - self.stmts[i] = self.stmts[i].transform(fn) - - return fn(self) - def merge_blocks(block1, block2): assert isinstance(block1, Block), "First block type is not Block!" assert isinstance(block2, Block), "Second block type is not Block!" diff --git a/ast/branches.py b/ast/branches.py index ab31b0b..a630260 100644 --- a/ast/branches.py +++ b/ast/branches.py @@ -37,14 +37,6 @@ class Branch(ASTNode): return [self.cond, self.block_if] + \ ([] if self.block_else is None else [self.block_else]) - def transform(self, fn): - self.cond = self.cond.transform(fn) - self.block_if = self.block_if.transform(fn) - self.block_else = \ - None if self.block_else is None \ - else self.block_else.transform(fn) - return fn(self) - class Filter(Branch): def __init__(self, sim, cond): diff --git a/ast/cast.py b/ast/cast.py index 0a30700..39a8538 100644 --- a/ast/cast.py +++ b/ast/cast.py @@ -25,7 +25,3 @@ class Cast(ASTNode): def children(self): return [self.expr] - - def transform(self, fn): - self.expr = self.expr.transform(fn) - return fn(self) diff --git a/ast/loops.py b/ast/loops.py index 902cec3..e0c645c 100644 --- a/ast/loops.py +++ b/ast/loops.py @@ -67,11 +67,6 @@ class For(ASTNode): def children(self): return [self.iterator, self.block] - def transform(self, fn): - self.iterator = self.iterator.transform(fn) - self.block = self.block.transform(fn) - return fn(self) - class ParticleFor(For): def __init__(self, sim, block=None, local_only=True): @@ -104,11 +99,6 @@ class While(ASTNode): def children(self): return [self.cond, self.block] - def transform(self, fn): - self.cond = self.cond.transform(fn) - self.block = self.block.transform(fn) - return fn(self) - class NeighborFor(): def __init__(self, sim, particle, cell_lists): diff --git a/ast/math.py b/ast/math.py index 90977b8..4e610cb 100644 --- a/ast/math.py +++ b/ast/math.py @@ -18,7 +18,3 @@ class Sqrt(ASTNode): def children(self): return [self.expr] - - def transform(self, fn): - self.expr = self.expr.transform(fn) - return fn(self) diff --git a/ast/memory.py b/ast/memory.py index ef8abfe..3458bf9 100644 --- a/ast/memory.py +++ b/ast/memory.py @@ -18,11 +18,6 @@ class Malloc(ASTNode): def children(self): return [self.array, self.size] - def transform(self, fn): - self.array = self.array.transform(fn) - self.size = self.size.transform(fn) - return fn(self) - class Realloc(ASTNode): def __init__(self, sim, array, size): @@ -35,8 +30,3 @@ class Realloc(ASTNode): def children(self): return [self.array, self.size] - - def transform(self, fn): - self.array = self.array.transform(fn) - self.size = self.size.transform(fn) - return fn(self) diff --git a/ast/mutator.py b/ast/mutator.py new file mode 100644 index 0000000..f450773 --- /dev/null +++ b/ast/mutator.py @@ -0,0 +1,139 @@ +from ast.arrays import ArrayAccess +from ast.assign import Assign +from ast.bin_op import BinOp, BinOpDef +from ast.block import Block +from ast.branches import Branch +from ast.cast import Cast +from ast.loops import For, While +from ast.math import Sqrt +from ast.memory import Malloc, Realloc +from ast.select import Select +from sim.timestep import Timestep + + +class Mutator: + def __init__(self, ast, max_depth=0): + self.ast = ast + self.max_depth = 0 + + def mutate(self, ast_node=None): + if ast_node is None: + ast_node = self.ast + + if isinstance(ast_node, ArrayAccess): + return self.mutate_ArrayAccess(ast_node) + + elif isinstance(ast_node, Assign): + return self.mutate_Assign(ast_node) + + elif isinstance(ast_node, BinOp): + return self.mutate_BinOp(ast_node) + + elif isinstance(ast_node, BinOpDef): + return self.mutate_BinOpDef(ast_node) + + elif isinstance(ast_node, Block): + return self.mutate_Block(ast_node) + + elif isinstance(ast_node, Branch): + return self.mutate_Branch(ast_node) + + elif isinstance(ast_node, Cast): + return self.mutate_Cast(ast_node) + + elif isinstance(ast_node, For): + return self.mutate_For(ast_node) + + elif isinstance(ast_node, Malloc): + return self.mutate_Malloc(ast_node) + + elif isinstance(ast_node, Realloc): + return self.mutate_Realloc(ast_node) + + elif isinstance(ast_node, Select): + return self.mutate_Select(ast_node) + + elif isinstance(ast_node, Sqrt): + return self.mutate_Sqrt(ast_node) + + elif isinstance(ast_node, Timestep): + return self.mutate_Timestep(ast_node) + + elif isinstance(ast_node, While): + return self.mutate_While(ast_node) + + return ast_node + + def mutate_ArrayAccess(self, ast_node): + ast_node.array = self.mutate(ast_node.array) + ast_node.indexes = [self.mutate(i) for i in ast_node.indexes] + + if ast_node.index is not None: + ast_node.index = self.mutate(ast_node.index) + + return ast_node + + def mutate_Assign(self, ast_node): + ast_node.assignments = [ + (self.mutate(ast_node.assignments[i][0]), self.mutate(ast_node.assignments[i][1])) + for i in range(0, len(ast_node.assignments)) + ] + return ast_node + + def mutate_BinOp(self, ast_node): + ast_node.lhs = self.mutate(ast_node.lhs) + ast_node.rhs = self.mutate(ast_node.rhs) + ast_node.bin_op_vector_index_mapping = {i: self.mutate(e) for i, e in ast_node.bin_op_vector_index_mapping.items()} + return ast_node + + def mutate_BinOpDef(self, ast_node): + ast_node.bin_op = self.mutate(ast_node.bin_op) + return ast_node + + def mutate_Block(self, ast_node): + ast_node.stmts = [self.mutate(s) for s in ast_node.stmts] + return ast_node + + def mutate_Branch(self, ast_node): + ast_node.cond = self.mutate(ast_node.cond) + ast_node.block_if = self.mutate(ast_node.block_if) + ast_node.block_else = None if ast_node.block_else is None else self.mutate(ast_node.block_else) + return ast_node + + def mutate_Cast(self, ast_node): + ast_node.expr = self.mutate(ast_node.expr) + return ast_node + + def mutate_For(self, ast_node): + ast_node.iterator = self.mutate(ast_node.iterator) + ast_node.block = self.mutate(ast_node.block) + return ast_node + + def mutate_Malloc(self, ast_node): + ast_node.array = self.mutate(ast_node.array) + ast_node.size = self.mutate(ast_node.size) + return ast_node + + def mutate_Realloc(self, ast_node): + ast_node.array = self.mutate(ast_node.array) + ast_node.size = self.mutate(ast_node.size) + return ast_node + + def mutate_Select(self, ast_node): + ast_node.cond = self.mutate(ast_node.cond) + ast_node.expr_if = self.mutate(ast_node.expr_if) + ast_node.expr_else = self.mutate(ast_node.expr_else) + return ast_node + + def mutate_Sqrt(self, ast_node): + ast_node.expr = self.mutate(ast_node.expr) + return ast_node + + def mutate_Timestep(self, ast_node): + ast_node.block = self.mutate(ast_node.block) + return ast_node + + def mutate_While(self, ast_node): + ast_node.cond = self.mutate(ast_node.cond) + ast_node.block = self.mutate(ast_node.block) + return ast_node diff --git a/ast/select.py b/ast/select.py index 702bb8f..1fe9d08 100644 --- a/ast/select.py +++ b/ast/select.py @@ -12,10 +12,3 @@ class Select(ASTNode): def children(self): return [self.cond, self.expr_if, self.expr_else] - - def transform(self, fn): - self.cond = self.cond.transform(fn) - self.expr_if = self.expr_if.transform(fn) - self.expr_else = self.expr_else.transform(fn) - return fn(self) - diff --git a/ast/visitor.py b/ast/visitor.py index 8cf72e9..419b5e7 100644 --- a/ast/visitor.py +++ b/ast/visitor.py @@ -1,22 +1,31 @@ -class Visitor: - def __init__(self, ast, enter_fn=None, leave_fn=None, max_depth=0): - self.ast = ast - self.enter_fn = enter_fn - self.leave_fn = leave_fn - self.max_depth = max_depth - - def visit(self): - self.visit_rec(self.ast) +from ast.bin_op import BinOp +from ast.loops import For, ParticleFor, While - def visit_rec(self, ast): - if self.enter_fn is not None: - self.enter_fn(ast) - for c in ast.children(): - self.visit_rec(c) - - if self.leave_fn is not None: - self.leave_fn(ast) +class Visitor: + def __init__(self, ast, max_depth=0): + self.ast = ast + self.max_depth = 0 + + def visit(ast_node): + for c in ast_node.children(): + if isinstance(c, Array): + self.visit_Array(c) + elif isinstance(c, BinOp): + self.visit_BinOp(c) + elif isinstance(c, (For, ParticleFor, While)): + self.visit_Loop(c) + else: + self.visit(c) + + def visit_Array(self, ast_node): + return self.visit(ast_node) + + def visit_BinOp(self, ast_node): + return self.visit(ast_node) + + def visit_Loop(self, ast_node): + return self.visit(ast_node) def list_ast(self): self.list_elements(self.ast) diff --git a/sim/particle_simulation.py b/sim/particle_simulation.py index 12104a3..53f8ef1 100644 --- a/sim/particle_simulation.py +++ b/sim/particle_simulation.py @@ -5,10 +5,9 @@ from ast.data_types import Type_Int, Type_Float, Type_Vector from ast.layouts import Layout_AoS from ast.loops import ParticleFor, NeighborFor from ast.properties import Properties -from ast.transform import Transform from ast.variables import Variables -from sim.arrays import ArraysDecl from graph.graphviz import ASTGraph +from sim.arrays import ArraysDecl from sim.cell_lists import CellLists, CellListsBuild, CellListsStencilBuild from sim.grid import Grid2D, Grid3D from sim.kernel_wrapper import KernelWrapper @@ -20,6 +19,8 @@ from sim.setup_wrapper import SetupWrapper from sim.timestep import Timestep from sim.variables import VariablesDecl from sim.vtk import VTKWrite +from transformations.flatten import flatten_property_accesses +from transformations.simplify import simplify_expressions class ParticleSimulation: @@ -193,12 +194,10 @@ class ParticleSimulation: program = Block.merge_blocks(decls, body) self.global_scope = program - Transform.apply(program, Transform.flatten) - Transform.apply(program, Transform.simplify) - #Transform.apply(program, Transform.reuse_index_expressions) - #Transform.apply(program, Transform.reuse_expr_expressions) - #Transform.apply(program, Transform.reuse_array_access_expressions) - #Transform.apply(program, Transform.move_loop_invariant_expressions) + + # Transformations + flatten_property_accesses(program) + simplify_expressions(program) ASTGraph(self.kernels.lower(), "kernels").generate_and_view() self.code_gen.generate_program(self, program) diff --git a/sim/timestep.py b/sim/timestep.py index 0d54133..74aa241 100644 --- a/sim/timestep.py +++ b/sim/timestep.py @@ -41,6 +41,3 @@ class Timestep: def as_block(self): return Block(self.sim, [self.timestep_loop]) - - def transform(self, fn): - self.block = self.block.transform(fn) diff --git a/sim/vtk.py b/sim/vtk.py index c1620f8..284533a 100644 --- a/sim/vtk.py +++ b/sim/vtk.py @@ -1,17 +1,13 @@ from ast.lit import as_lit_ast +from ast.ast_node import ASTNode -class VTKWrite: + +class VTKWrite(ASTNode): vtk_id = 0 def __init__(self, sim, filename, timestep): - self.sim = sim + super().__init__(sim) self.vtk_id = VTKWrite.vtk_id self.filename = filename self.timestep = as_lit_ast(sim, timestep) VTKWrite.vtk_id += 1 - - def children(self): - return [] - - def transform(self, fn): - return fn(self) diff --git a/transformations/flatten.py b/transformations/flatten.py new file mode 100644 index 0000000..b06ebcb --- /dev/null +++ b/transformations/flatten.py @@ -0,0 +1,35 @@ +from ast.layouts import Layout_AoS, Layout_SoA +from ast.mutator import Mutator + + +class FlattenPropertyAccesses(Mutator): + def __init__(self, ast): + super().__init__(ast) + + def mutate_BinOp(self, ast_node): + ast_node.lhs = self.mutate(ast_node.lhs) + ast_node.rhs = self.mutate(ast_node.rhs) + + if ast_node.is_vector_property_access(): + layout = ast_node.lhs.layout() + + for i in ast_node.vector_indexes(): + flat_index = None + + if layout == Layout_AoS: + flat_index = ast_node.rhs * ast_node.sim.dimensions + i + + elif layout == Layout_SoA: + flat_index = i * ast_node.sim.particle_capacity + ast_node.rhs + + else: + raise Exception("Invalid property layout!") + + ast_node.map_vector_index(i, flat_index) + + return ast_node + + +def flatten_property_accesses(ast_node): + flatten = FlattenPropertyAccesses(ast_node) + flatten.mutate() diff --git a/transformations/simplify.py b/transformations/simplify.py new file mode 100644 index 0000000..ff67315 --- /dev/null +++ b/transformations/simplify.py @@ -0,0 +1,36 @@ +from ast.data_types import Type_Int +from ast.lit import Lit +from ast.mutator import Mutator + + +class SimplifyExpressions(Mutator): + def __init__(self, ast): + super().__init__(ast) + + def mutate_BinOp(self, ast_node): + sim = ast_node.lhs.sim + ast_node.lhs = self.mutate(ast_node.lhs) + ast_node.rhs = self.mutate(ast_node.rhs) + ast_node.bin_op_vector_index_mapping = {i: self.mutate(e) for i, e in ast_node.bin_op_vector_index_mapping.items()} + + if ast_node.op in ['+', '-'] and ast_node.rhs == 0: + return ast_node.lhs + + if ast_node.op in ['+'] and ast_node.lhs == 0: + return ast_node.rhs + + if ast_node.op in ['*', '/'] and ast_node.rhs == 1: + return ast_node.lhs + + if ast_node.op == '*' and ast_node.lhs == 1: + return ast_node.rhs + + if ast_node.op == '*' and ast_node.lhs == 0: + return Lit(sim, 0 if ast_node.type() == Type_Int else 0.0) + + return ast_node + + +def simplify_expressions(ast_node): + simplify = SimplifyExpressions(ast_node) + simplify.mutate() -- GitLab