diff --git a/ast/arrays.py b/ast/arrays.py index 308bd12bf7bfedbec1f42d6a3ef7ae1cdb96ee00..1ea90143c4a2cd15fe0cbc4ac5b8d3c345bb0da7 100644 --- a/ast/arrays.py +++ b/ast/arrays.py @@ -175,15 +175,6 @@ class ArrayAccess(ASTTerm): def children(self): return [self.array] + self.indexes - def transform(self, fn): - self.array = self.array.transform(fn) - self.indexes = [i.transform(fn) for i in self.indexes] - - if self.index is not None: - self.index = self.index.transform(fn) - - return fn(self) - class ArrayDecl(ASTNode): def __init__(self, sim, array): diff --git a/ast/assign.py b/ast/assign.py index 8ce1d7b8ac8e3027dc3f2b2e4a78cefb210e04b6..a6cc919c8e528b16f6160549c2461ef9c05fa08d 100644 --- a/ast/assign.py +++ b/ast/assign.py @@ -31,11 +31,3 @@ class Assign(ASTNode): return reduce((lambda x, y: x + y), [ [self.assignments[i][0], self.assignments[i][1]] for i in range(0, len(self.assignments))]) - - def transform(self, fn): - self.assignments = [( - self.assignments[i][0].transform(fn), - self.assignments[i][1].transform(fn)) - for i in range(0, len(self.assignments))] - - return fn(self) diff --git a/ast/ast_node.py b/ast/ast_node.py index 7e426762ba235fba865137a30c45c939eeb88638..e58afc26693907e38fd322ff5969d7a37271e675 100644 --- a/ast/ast_node.py +++ b/ast/ast_node.py @@ -16,6 +16,3 @@ class ASTNode: def children(self): return [] - - def transform(self, fn): - return fn(self) diff --git a/ast/bin_op.py b/ast/bin_op.py index 840ca43fa81fd8bcf28acc90f7d43fc1dfd01c7d..ba826203772cb6bae73e7396b097c58d0cce12f4 100644 --- a/ast/bin_op.py +++ b/ast/bin_op.py @@ -17,10 +17,6 @@ class BinOpDef(ASTNode): def children(self): return [self.bin_op] - def transform(self, fn): - self.bin_op = self.bin_op.transform(fn) - return fn(self) - class BinOp(ASTNode): # BinOp kinds @@ -177,12 +173,6 @@ class BinOp(ASTNode): def children(self): return [self.lhs, self.rhs] - def transform(self, fn): - self.lhs = self.lhs.transform(fn) - self.rhs = self.rhs.transform(fn) - self.bin_op_vector_index_mapping = {i: e.transform(fn) for i, e in self.bin_op_vector_index_mapping.items()} - return fn(self) - def __add__(self, other): return BinOp(self.sim, self, other, '+') diff --git a/ast/block.py b/ast/block.py index b7f287e1b055dc1d49160858c49e631c550856ce..e85facec636e02ee94235d762cef1f85f355db07 100644 --- a/ast/block.py +++ b/ast/block.py @@ -1,5 +1,4 @@ from ast.ast_node import ASTNode -from ast.visitor import Visitor class Block(ASTNode): @@ -51,12 +50,6 @@ class Block(ASTNode): def children(self): return self.stmts - def transform(self, fn): - for i in range(0, len(self.stmts)): - self.stmts[i] = self.stmts[i].transform(fn) - - return fn(self) - def merge_blocks(block1, block2): assert isinstance(block1, Block), "First block type is not Block!" assert isinstance(block2, Block), "Second block type is not Block!" diff --git a/ast/branches.py b/ast/branches.py index ab31b0b17dbe0ab049c79d63d26026d0fa6f32ff..a6302609d42b7ed6b25513a2ce66bdbc5717e027 100644 --- a/ast/branches.py +++ b/ast/branches.py @@ -37,14 +37,6 @@ class Branch(ASTNode): return [self.cond, self.block_if] + \ ([] if self.block_else is None else [self.block_else]) - def transform(self, fn): - self.cond = self.cond.transform(fn) - self.block_if = self.block_if.transform(fn) - self.block_else = \ - None if self.block_else is None \ - else self.block_else.transform(fn) - return fn(self) - class Filter(Branch): def __init__(self, sim, cond): diff --git a/ast/cast.py b/ast/cast.py index 0a30700b05311ed3bdc6e3d116bb0c89a9764391..39a8538ae44163d61c110852dc4647bb7dd81ede 100644 --- a/ast/cast.py +++ b/ast/cast.py @@ -25,7 +25,3 @@ class Cast(ASTNode): def children(self): return [self.expr] - - def transform(self, fn): - self.expr = self.expr.transform(fn) - return fn(self) diff --git a/ast/loops.py b/ast/loops.py index 902cec345ad65dc2968c8b9d7861e90f46a5c1cb..e0c645cfe256a281b509a14f1d4df8487da911b1 100644 --- a/ast/loops.py +++ b/ast/loops.py @@ -67,11 +67,6 @@ class For(ASTNode): def children(self): return [self.iterator, self.block] - def transform(self, fn): - self.iterator = self.iterator.transform(fn) - self.block = self.block.transform(fn) - return fn(self) - class ParticleFor(For): def __init__(self, sim, block=None, local_only=True): @@ -104,11 +99,6 @@ class While(ASTNode): def children(self): return [self.cond, self.block] - def transform(self, fn): - self.cond = self.cond.transform(fn) - self.block = self.block.transform(fn) - return fn(self) - class NeighborFor(): def __init__(self, sim, particle, cell_lists): diff --git a/ast/math.py b/ast/math.py index 90977b8af5d53e8716245ad867b8787bab8dec66..4e610cb378277962039096f342acc553ece932d0 100644 --- a/ast/math.py +++ b/ast/math.py @@ -18,7 +18,3 @@ class Sqrt(ASTNode): def children(self): return [self.expr] - - def transform(self, fn): - self.expr = self.expr.transform(fn) - return fn(self) diff --git a/ast/memory.py b/ast/memory.py index ef8abfed3c2f7a736861d04c4e188e457e56af16..3458bf9fa464a1a4cde36b555b85324fff40b4b1 100644 --- a/ast/memory.py +++ b/ast/memory.py @@ -18,11 +18,6 @@ class Malloc(ASTNode): def children(self): return [self.array, self.size] - def transform(self, fn): - self.array = self.array.transform(fn) - self.size = self.size.transform(fn) - return fn(self) - class Realloc(ASTNode): def __init__(self, sim, array, size): @@ -35,8 +30,3 @@ class Realloc(ASTNode): def children(self): return [self.array, self.size] - - def transform(self, fn): - self.array = self.array.transform(fn) - self.size = self.size.transform(fn) - return fn(self) diff --git a/ast/mutator.py b/ast/mutator.py new file mode 100644 index 0000000000000000000000000000000000000000..f450773eb5081d27f4fa8f5f1df8546024802323 --- /dev/null +++ b/ast/mutator.py @@ -0,0 +1,139 @@ +from ast.arrays import ArrayAccess +from ast.assign import Assign +from ast.bin_op import BinOp, BinOpDef +from ast.block import Block +from ast.branches import Branch +from ast.cast import Cast +from ast.loops import For, While +from ast.math import Sqrt +from ast.memory import Malloc, Realloc +from ast.select import Select +from sim.timestep import Timestep + + +class Mutator: + def __init__(self, ast, max_depth=0): + self.ast = ast + self.max_depth = 0 + + def mutate(self, ast_node=None): + if ast_node is None: + ast_node = self.ast + + if isinstance(ast_node, ArrayAccess): + return self.mutate_ArrayAccess(ast_node) + + elif isinstance(ast_node, Assign): + return self.mutate_Assign(ast_node) + + elif isinstance(ast_node, BinOp): + return self.mutate_BinOp(ast_node) + + elif isinstance(ast_node, BinOpDef): + return self.mutate_BinOpDef(ast_node) + + elif isinstance(ast_node, Block): + return self.mutate_Block(ast_node) + + elif isinstance(ast_node, Branch): + return self.mutate_Branch(ast_node) + + elif isinstance(ast_node, Cast): + return self.mutate_Cast(ast_node) + + elif isinstance(ast_node, For): + return self.mutate_For(ast_node) + + elif isinstance(ast_node, Malloc): + return self.mutate_Malloc(ast_node) + + elif isinstance(ast_node, Realloc): + return self.mutate_Realloc(ast_node) + + elif isinstance(ast_node, Select): + return self.mutate_Select(ast_node) + + elif isinstance(ast_node, Sqrt): + return self.mutate_Sqrt(ast_node) + + elif isinstance(ast_node, Timestep): + return self.mutate_Timestep(ast_node) + + elif isinstance(ast_node, While): + return self.mutate_While(ast_node) + + return ast_node + + def mutate_ArrayAccess(self, ast_node): + ast_node.array = self.mutate(ast_node.array) + ast_node.indexes = [self.mutate(i) for i in ast_node.indexes] + + if ast_node.index is not None: + ast_node.index = self.mutate(ast_node.index) + + return ast_node + + def mutate_Assign(self, ast_node): + ast_node.assignments = [ + (self.mutate(ast_node.assignments[i][0]), self.mutate(ast_node.assignments[i][1])) + for i in range(0, len(ast_node.assignments)) + ] + return ast_node + + def mutate_BinOp(self, ast_node): + ast_node.lhs = self.mutate(ast_node.lhs) + ast_node.rhs = self.mutate(ast_node.rhs) + ast_node.bin_op_vector_index_mapping = {i: self.mutate(e) for i, e in ast_node.bin_op_vector_index_mapping.items()} + return ast_node + + def mutate_BinOpDef(self, ast_node): + ast_node.bin_op = self.mutate(ast_node.bin_op) + return ast_node + + def mutate_Block(self, ast_node): + ast_node.stmts = [self.mutate(s) for s in ast_node.stmts] + return ast_node + + def mutate_Branch(self, ast_node): + ast_node.cond = self.mutate(ast_node.cond) + ast_node.block_if = self.mutate(ast_node.block_if) + ast_node.block_else = None if ast_node.block_else is None else self.mutate(ast_node.block_else) + return ast_node + + def mutate_Cast(self, ast_node): + ast_node.expr = self.mutate(ast_node.expr) + return ast_node + + def mutate_For(self, ast_node): + ast_node.iterator = self.mutate(ast_node.iterator) + ast_node.block = self.mutate(ast_node.block) + return ast_node + + def mutate_Malloc(self, ast_node): + ast_node.array = self.mutate(ast_node.array) + ast_node.size = self.mutate(ast_node.size) + return ast_node + + def mutate_Realloc(self, ast_node): + ast_node.array = self.mutate(ast_node.array) + ast_node.size = self.mutate(ast_node.size) + return ast_node + + def mutate_Select(self, ast_node): + ast_node.cond = self.mutate(ast_node.cond) + ast_node.expr_if = self.mutate(ast_node.expr_if) + ast_node.expr_else = self.mutate(ast_node.expr_else) + return ast_node + + def mutate_Sqrt(self, ast_node): + ast_node.expr = self.mutate(ast_node.expr) + return ast_node + + def mutate_Timestep(self, ast_node): + ast_node.block = self.mutate(ast_node.block) + return ast_node + + def mutate_While(self, ast_node): + ast_node.cond = self.mutate(ast_node.cond) + ast_node.block = self.mutate(ast_node.block) + return ast_node diff --git a/ast/select.py b/ast/select.py index 702bb8fb4a16e54fed93e717728e3439f35bdd21..1fe9d088b1bbb504279966769a7302897e701e6c 100644 --- a/ast/select.py +++ b/ast/select.py @@ -12,10 +12,3 @@ class Select(ASTNode): def children(self): return [self.cond, self.expr_if, self.expr_else] - - def transform(self, fn): - self.cond = self.cond.transform(fn) - self.expr_if = self.expr_if.transform(fn) - self.expr_else = self.expr_else.transform(fn) - return fn(self) - diff --git a/ast/visitor.py b/ast/visitor.py index 8cf72e97d3a6ce6952a83f9e50767f7fa272f276..419b5e7df9e69eb7992eb49300bd7d925e38fd7b 100644 --- a/ast/visitor.py +++ b/ast/visitor.py @@ -1,22 +1,31 @@ -class Visitor: - def __init__(self, ast, enter_fn=None, leave_fn=None, max_depth=0): - self.ast = ast - self.enter_fn = enter_fn - self.leave_fn = leave_fn - self.max_depth = max_depth - - def visit(self): - self.visit_rec(self.ast) +from ast.bin_op import BinOp +from ast.loops import For, ParticleFor, While - def visit_rec(self, ast): - if self.enter_fn is not None: - self.enter_fn(ast) - for c in ast.children(): - self.visit_rec(c) - - if self.leave_fn is not None: - self.leave_fn(ast) +class Visitor: + def __init__(self, ast, max_depth=0): + self.ast = ast + self.max_depth = 0 + + def visit(ast_node): + for c in ast_node.children(): + if isinstance(c, Array): + self.visit_Array(c) + elif isinstance(c, BinOp): + self.visit_BinOp(c) + elif isinstance(c, (For, ParticleFor, While)): + self.visit_Loop(c) + else: + self.visit(c) + + def visit_Array(self, ast_node): + return self.visit(ast_node) + + def visit_BinOp(self, ast_node): + return self.visit(ast_node) + + def visit_Loop(self, ast_node): + return self.visit(ast_node) def list_ast(self): self.list_elements(self.ast) diff --git a/sim/particle_simulation.py b/sim/particle_simulation.py index 12104a3acdecbe431db74f844f56e7be7e9c1dd6..53f8ef1716087fa78291eaabbb7702b4ac629ecb 100644 --- a/sim/particle_simulation.py +++ b/sim/particle_simulation.py @@ -5,10 +5,9 @@ from ast.data_types import Type_Int, Type_Float, Type_Vector from ast.layouts import Layout_AoS from ast.loops import ParticleFor, NeighborFor from ast.properties import Properties -from ast.transform import Transform from ast.variables import Variables -from sim.arrays import ArraysDecl from graph.graphviz import ASTGraph +from sim.arrays import ArraysDecl from sim.cell_lists import CellLists, CellListsBuild, CellListsStencilBuild from sim.grid import Grid2D, Grid3D from sim.kernel_wrapper import KernelWrapper @@ -20,6 +19,8 @@ from sim.setup_wrapper import SetupWrapper from sim.timestep import Timestep from sim.variables import VariablesDecl from sim.vtk import VTKWrite +from transformations.flatten import flatten_property_accesses +from transformations.simplify import simplify_expressions class ParticleSimulation: @@ -193,12 +194,10 @@ class ParticleSimulation: program = Block.merge_blocks(decls, body) self.global_scope = program - Transform.apply(program, Transform.flatten) - Transform.apply(program, Transform.simplify) - #Transform.apply(program, Transform.reuse_index_expressions) - #Transform.apply(program, Transform.reuse_expr_expressions) - #Transform.apply(program, Transform.reuse_array_access_expressions) - #Transform.apply(program, Transform.move_loop_invariant_expressions) + + # Transformations + flatten_property_accesses(program) + simplify_expressions(program) ASTGraph(self.kernels.lower(), "kernels").generate_and_view() self.code_gen.generate_program(self, program) diff --git a/sim/timestep.py b/sim/timestep.py index 0d541330bce00d894ae82442f6c19f4c3649ae77..74aa2419a8bd7d85338399807f6644ee3d593129 100644 --- a/sim/timestep.py +++ b/sim/timestep.py @@ -41,6 +41,3 @@ class Timestep: def as_block(self): return Block(self.sim, [self.timestep_loop]) - - def transform(self, fn): - self.block = self.block.transform(fn) diff --git a/sim/vtk.py b/sim/vtk.py index c1620f85e0bc79238e28ca8b849741ea73540399..284533abf61ef096865ab3161061a77809cd08bf 100644 --- a/sim/vtk.py +++ b/sim/vtk.py @@ -1,17 +1,13 @@ from ast.lit import as_lit_ast +from ast.ast_node import ASTNode -class VTKWrite: + +class VTKWrite(ASTNode): vtk_id = 0 def __init__(self, sim, filename, timestep): - self.sim = sim + super().__init__(sim) self.vtk_id = VTKWrite.vtk_id self.filename = filename self.timestep = as_lit_ast(sim, timestep) VTKWrite.vtk_id += 1 - - def children(self): - return [] - - def transform(self, fn): - return fn(self) diff --git a/transformations/flatten.py b/transformations/flatten.py new file mode 100644 index 0000000000000000000000000000000000000000..b06ebcb07d6a54c5b4367abdb2bcc4d3093f0e7e --- /dev/null +++ b/transformations/flatten.py @@ -0,0 +1,35 @@ +from ast.layouts import Layout_AoS, Layout_SoA +from ast.mutator import Mutator + + +class FlattenPropertyAccesses(Mutator): + def __init__(self, ast): + super().__init__(ast) + + def mutate_BinOp(self, ast_node): + ast_node.lhs = self.mutate(ast_node.lhs) + ast_node.rhs = self.mutate(ast_node.rhs) + + if ast_node.is_vector_property_access(): + layout = ast_node.lhs.layout() + + for i in ast_node.vector_indexes(): + flat_index = None + + if layout == Layout_AoS: + flat_index = ast_node.rhs * ast_node.sim.dimensions + i + + elif layout == Layout_SoA: + flat_index = i * ast_node.sim.particle_capacity + ast_node.rhs + + else: + raise Exception("Invalid property layout!") + + ast_node.map_vector_index(i, flat_index) + + return ast_node + + +def flatten_property_accesses(ast_node): + flatten = FlattenPropertyAccesses(ast_node) + flatten.mutate() diff --git a/transformations/simplify.py b/transformations/simplify.py new file mode 100644 index 0000000000000000000000000000000000000000..ff673159bcb09af7a372f959c97bbe852e00c835 --- /dev/null +++ b/transformations/simplify.py @@ -0,0 +1,36 @@ +from ast.data_types import Type_Int +from ast.lit import Lit +from ast.mutator import Mutator + + +class SimplifyExpressions(Mutator): + def __init__(self, ast): + super().__init__(ast) + + def mutate_BinOp(self, ast_node): + sim = ast_node.lhs.sim + ast_node.lhs = self.mutate(ast_node.lhs) + ast_node.rhs = self.mutate(ast_node.rhs) + ast_node.bin_op_vector_index_mapping = {i: self.mutate(e) for i, e in ast_node.bin_op_vector_index_mapping.items()} + + if ast_node.op in ['+', '-'] and ast_node.rhs == 0: + return ast_node.lhs + + if ast_node.op in ['+'] and ast_node.lhs == 0: + return ast_node.rhs + + if ast_node.op in ['*', '/'] and ast_node.rhs == 1: + return ast_node.lhs + + if ast_node.op == '*' and ast_node.lhs == 1: + return ast_node.rhs + + if ast_node.op == '*' and ast_node.lhs == 0: + return Lit(sim, 0 if ast_node.type() == Type_Int else 0.0) + + return ast_node + + +def simplify_expressions(ast_node): + simplify = SimplifyExpressions(ast_node) + simplify.mutate()