diff --git a/ast/ast_node.py b/ast/ast_node.py index 8d37b6961bd96c0268fc92581d90aec561c3832b..aa1c2202daa09a30ee576d8e401cd128875036b6 100644 --- a/ast/ast_node.py +++ b/ast/ast_node.py @@ -4,7 +4,7 @@ from ast.data_types import Type_Invalid class ASTNode: def __init__(self, sim): self.sim = sim - self._parent_block = None # Set during SetParentBlock transformation + self.parent_block = None # Set during SetParentBlock transformation def __str__(self): return "ASTNode<>" @@ -15,9 +15,5 @@ class ASTNode: def scope(self): return self.sim.global_scope - @property - def parent_block(self): - return self._parent_block - def children(self): return [] diff --git a/ast/bin_op.py b/ast/bin_op.py index ba826203772cb6bae73e7396b097c58d0cce12f4..6d48c6d72f2e5d03a4c82640c968ec523679c011 100644 --- a/ast/bin_op.py +++ b/ast/bin_op.py @@ -46,8 +46,9 @@ class BinOp(ASTNode): self.generated = False self.bin_op_type = BinOp.infer_type(self.lhs, self.rhs, self.op) self.bin_op_scope = None - self.bin_op_vector_indexes = set() - self.bin_op_vector_index_mapping = {} + self.terminals = set() + self._vector_indexes = set() + self.vector_index_mapping = {} self.bin_op_def = BinOpDef(self) def __str__(self): @@ -68,17 +69,18 @@ class BinOp(ASTNode): return self.__getitem__(2) def map_vector_index(self, index, expr): - self.bin_op_vector_index_mapping[index] = expr + self.vector_index_mapping[index] = expr def mapped_vector_index(self, index): - mapping = self.bin_op_vector_index_mapping + mapping = self.vector_index_mapping return mapping[index] if index in mapping else as_lit_ast(self.sim, index) + @property def vector_indexes(self): - return self.bin_op_vector_indexes + return self._vector_indexes def propagate_vector_access(self, index): - self.bin_op_vector_indexes.add(index) + self.vector_indexes.add(index) if isinstance(self.lhs, BinOp) and self.lhs.kind() == BinOp.Kind_Vector: self.lhs.propagate_vector_access(index) @@ -162,6 +164,9 @@ class BinOp(ASTNode): def kind(self): return BinOp.Kind_Vector if self.type() == Type_Vector else BinOp.Kind_Scalar + def add_terminal(self, terminal): + self.terminals.add(terminal) + def scope(self): if self.bin_op_scope is None: lhs_scp = self.lhs.scope() diff --git a/ast/block.py b/ast/block.py index 85ddd1fe13f644a47caf0e92c803232e4e714a9a..58a90926edfa09168d54b535387a47a0db13bb03 100644 --- a/ast/block.py +++ b/ast/block.py @@ -5,7 +5,7 @@ class Block(ASTNode): def __init__(self, sim, stmts): super().__init__(sim) self.level = 0 - self.variants = [] + self.variants = set() if isinstance(stmts, Block): self.stmts = stmts.statements() @@ -34,10 +34,8 @@ class Block(ASTNode): self.stmts.append(stmt) def add_variant(self, variant): - if isinstance(variant, list): - self.variants = self.variants + variant - else: - self.variants.append(variant) + for v in variant if isinstance(variant, list) else [variant]: + self.variants.add(v) def statements(self): return self.stmts diff --git a/ast/mutator.py b/ast/mutator.py index 6ec8727c7713586dfc4653c3b9db9890b47b29e1..b856fbf103b558993f5a7e23169f0e35c467ba09 100644 --- a/ast/mutator.py +++ b/ast/mutator.py @@ -33,7 +33,7 @@ class Mutator: def mutate_BinOp(self, ast_node): ast_node.lhs = self.mutate(ast_node.lhs) ast_node.rhs = self.mutate(ast_node.rhs) - ast_node.bin_op_vector_index_mapping = {i: self.mutate(e) for i, e in ast_node.bin_op_vector_index_mapping.items()} + ast_node.vector_index_mapping = {i: self.mutate(e) for i, e in ast_node.vector_index_mapping.items()} return ast_node def mutate_BinOpDef(self, ast_node): diff --git a/code_gen/cgen.py b/code_gen/cgen.py index c28e353d69cf64fe502dd8000c14feac669f0dd0..8fccf2454abfe43318de3a86849fa6e9fe13ad6f 100644 --- a/code_gen/cgen.py +++ b/code_gen/cgen.py @@ -16,7 +16,7 @@ from ast.utils import Print from ast.variables import Var, VarDecl from sim.timestep import Timestep from sim.vtk import VTKWrite -from code_gen.printer import printer +from code_gen.printer import Printer class CGen: @@ -27,34 +27,43 @@ class CGen: else 'bool' ) - def generate_program(sim, ast_node): - printer.print("#include <stdio.h>") - printer.print("#include <stdlib.h>") - printer.print("#include <stdbool.h>") - printer.print("") - printer.print("int main() {") - CGen.generate_statement(sim, ast_node) - printer.print("}") - - def generate_statement(sim, ast_node): + def __init__(self, output): + self.sim = None + self.print = Printer(output) + + def assign_simulation(self, sim): + self.sim = sim + + def generate_program(self, ast_node): + self.print.start() + self.print("#include <stdio.h>") + self.print("#include <stdlib.h>") + self.print("#include <stdbool.h>") + self.print("") + self.print("int main() {") + self.generate_statement(ast_node) + self.print("}") + self.print.end() + + def generate_statement(self, ast_node): if isinstance(ast_node, ArrayDecl): tkw = CGen.type2keyword(ast_node.array.type()) - size = CGen.generate_expression(sim, BinOp.inline(ast_node.array.alloc_size())) - printer.print(f"{tkw} {ast_node.array.name()}[{size}];") + size = self.generate_expression(BinOp.inline(ast_node.array.alloc_size())) + self.print(f"{tkw} {ast_node.array.name()}[{size}];") if isinstance(ast_node, Assign): for assign_dest, assign_src in ast_node.assignments: - dest = CGen.generate_expression(sim, assign_dest, mem=True) - src = CGen.generate_expression(sim, assign_src) - printer.print(f"{dest} = {src};") + dest = self.generate_expression(assign_dest, mem=True) + src = self.generate_expression(assign_src) + self.print(f"{dest} = {src};") if isinstance(ast_node, Block): - printer.add_ind(4) + self.print.add_ind(4) for stmt in ast_node.statements(): - CGen.generate_statement(sim, stmt) + self.generate_statement(stmt) - printer.add_ind(-4) + self.print.add_ind(-4) if isinstance(ast_node, BinOpDef): bin_op = ast_node.bin_op @@ -64,16 +73,16 @@ class CGen: if bin_op.inlined is False and bin_op.operator() != '[]' and bin_op.generated is False: if bin_op.kind() == BinOp.Kind_Scalar: - lhs = CGen.generate_expression(sim, bin_op.lhs, bin_op.mem) - rhs = CGen.generate_expression(sim, bin_op.rhs) + lhs = self.generate_expression(bin_op.lhs, bin_op.mem) + rhs = self.generate_expression(bin_op.rhs) tkw = CGen.type2keyword(bin_op.type()) - printer.print(f"const {tkw} e{bin_op.id()} = {lhs} {bin_op.operator()} {rhs};") + self.print(f"const {tkw} e{bin_op.id()} = {lhs} {bin_op.operator()} {rhs};") elif bin_op.kind() == BinOp.Kind_Vector: - for i in bin_op.vector_indexes(): - lhs = CGen.generate_expression(sim, bin_op.lhs, bin_op.mem, index=i) - rhs = CGen.generate_expression(sim, bin_op.rhs, index=i) - printer.print(f"const double e{bin_op.id()}_{i} = {lhs} {bin_op.operator()} {rhs};") + for i in bin_op.vector_indexes: + lhs = self.generate_expression(bin_op.lhs, bin_op.mem, index=i) + rhs = self.generate_expression(bin_op.rhs, index=i) + self.print(f"const double e{bin_op.id()}_{i} = {lhs} {bin_op.operator()} {rhs};") else: raise Exception("Invalid BinOp kind!") @@ -81,79 +90,79 @@ class CGen: bin_op.generated = True if isinstance(ast_node, Branch): - cond = CGen.generate_expression(sim, ast_node.cond) - printer.print(f"if({cond}) {{") - CGen.generate_statement(sim, ast_node.block_if) + cond = self.generate_expression(ast_node.cond) + self.print(f"if({cond}) {{") + self.generate_statement(ast_node.block_if) if ast_node.block_else is not None: - printer.print("} else {") - CGen.generate_statement(sim, ast_node.block_else) + self.print("} else {") + self.generate_statement(ast_node.block_else) - printer.print("}") + self.print("}") if isinstance(ast_node, For): - iterator = CGen.generate_expression(sim, ast_node.iterator) + iterator = self.generate_expression(ast_node.iterator) lower_range = None upper_range = None if isinstance(ast_node, ParticleFor): - n = sim.nlocal if ast_node.local_only else sim.nlocal + sim.pbc.npbc + n = self.sim.nlocal if ast_node.local_only else self.sim.nlocal + self.sim.pbc.npbc lower_range = 0 - upper_range = CGen.generate_expression(sim, n) + upper_range = self.generate_expression(n) else: - lower_range = CGen.generate_expression(sim, ast_node.min) - upper_range = CGen.generate_expression(sim, ast_node.max) + lower_range = self.generate_expression(ast_node.min) + upper_range = self.generate_expression(ast_node.max) - printer.print(f"for(int {iterator} = {lower_range}; {iterator} < {upper_range}; {iterator}++) {{") - CGen.generate_statement(sim, ast_node.block) - printer.print("}") + self.print(f"for(int {iterator} = {lower_range}; {iterator} < {upper_range}; {iterator}++) {{") + self.generate_statement(ast_node.block) + self.print("}") if isinstance(ast_node, Malloc): tkw = CGen.type2keyword(ast_node.array.type()) - size = CGen.generate_expression(sim, ast_node.size) + size = self.generate_expression(ast_node.size) array_name = ast_node.array.name() if ast_node.decl: - printer.print(f"{tkw} *{array_name} = ({tkw} *) malloc({size});") + self.print(f"{tkw} *{array_name} = ({tkw} *) malloc({size});") else: - printer.print(f"{array_name} = ({tkw} *) malloc({size});") + self.print(f"{array_name} = ({tkw} *) malloc({size});") if isinstance(ast_node, Print): - printer.print(f"fprintf(stdout, \"{ast_node.string}\\n\");") - printer.print(f"fflush(stdout);") + self.print(f"fprintf(stdout, \"{ast_node.string}\\n\");") + self.print(f"fflush(stdout);") if isinstance(ast_node, Realloc): tkw = CGen.type2keyword(ast_node.array.type()) - size = CGen.generate_expression(sim, ast_node.size) + size = self.generate_expression(ast_node.size) array_name = ast_node.array.name() - printer.print(f"{array_name} = ({tkw} *) realloc({array_name}, {size});") + self.print(f"{array_name} = ({tkw} *) realloc({array_name}, {size});") if isinstance(ast_node, Timestep): - CGen.generate_statement(sim, ast_node.block) + self.generate_statement(ast_node.block) if isinstance(ast_node, VarDecl): tkw = CGen.type2keyword(ast_node.var.type()) - printer.print(f"{tkw} {ast_node.var.name()} = {ast_node.var.init_value()};") + self.print(f"{tkw} {ast_node.var.name()} = {ast_node.var.init_value()};") if isinstance(ast_node, VTKWrite): - nlocal = CGen.generate_expression(sim, sim.nlocal) - npbc = CGen.generate_expression(sim, sim.pbc.npbc) - nall = CGen.generate_expression(sim, sim.nlocal + sim.pbc.npbc) - timestep = CGen.generate_expression(sim, ast_node.timestep) - CGen.generate_vtk_writing(ast_node.vtk_id * 2, f"{ast_node.filename}_local", 0, nlocal, nlocal, timestep) - CGen.generate_vtk_writing(ast_node.vtk_id * 2 + 1, f"{ast_node.filename}_pbc", nlocal, nall, npbc, timestep) + nlocal = self.generate_expression(self.sim.nlocal) + npbc = self.generate_expression(self.sim.pbc.npbc) + nall = self.generate_expression(self.sim.nlocal + self.sim.pbc.npbc) + timestep = self.generate_expression(ast_node.timestep) + self.generate_vtk_writing(ast_node.vtk_id * 2, f"{ast_node.filename}_local", 0, nlocal, nlocal, timestep) + self.generate_vtk_writing(ast_node.vtk_id * 2 + 1, f"{ast_node.filename}_pbc", nlocal, nall, npbc, timestep) if isinstance(ast_node, While): - cond = CGen.generate_expression(sim, ast_node.cond) - printer.print(f"while({cond}) {{") - CGen.generate_statement(sim, ast_node.block) - printer.print("}") + cond = self.generate_expression(ast_node.cond) + self.print(f"while({cond}) {{") + self.generate_statement(ast_node.block) + self.print("}") - def generate_expression(sim, ast_node, mem=False, index=None): + def generate_expression(self, ast_node, mem=False, index=None): if isinstance(ast_node, ArrayAccess): - index = CGen.generate_expression(sim, ast_node.index) + index = self.generate_expression(ast_node.index) array_name = ast_node.array.name() if mem: @@ -162,20 +171,20 @@ class CGen: acc_ref = f"a{ast_node.id()}" if ast_node.generated is False: tkw = CGen.type2keyword(ast_node.type()) - printer.print(f"const {tkw} {acc_ref} = {array_name}[{index}];") + self.print(f"const {tkw} {acc_ref} = {array_name}[{index}];") ast_node.generated = True return acc_ref if isinstance(ast_node, BinOp): if isinstance(ast_node.lhs, BinOp) and ast_node.lhs.kind() == BinOp.Kind_Vector and ast_node.operator() == '[]': - return CGen.generate_expression(sim, ast_node.lhs, ast_node.mem, CGen.generate_expression(sim, ast_node.rhs)) + return self.generate_expression(ast_node.lhs, ast_node.mem, self.generate_expression(ast_node.rhs)) - lhs = CGen.generate_expression(sim, ast_node.lhs, mem, index) - rhs = CGen.generate_expression(sim, ast_node.rhs, index=index) + lhs = self.generate_expression(ast_node.lhs, mem, index) + rhs = self.generate_expression(ast_node.rhs, index=index) if ast_node.operator() == '[]': - idx = CGen.generate_expression(sim, ast_node.mapped_vector_index(index)) if ast_node.is_vector_property_access() else rhs + idx = self.generate_expression(ast_node.mapped_vector_index(index)) if ast_node.is_vector_property_access() else rhs return f"{lhs}[{idx}]" if ast_node.mem else f"{lhs}_{idx}" if ast_node.inlined is True: @@ -185,7 +194,7 @@ class CGen: # Some expressions can be defined on-the-fly during transformations, hence they do not have # a definition statement in the tree, so we generate them right before use if not ast_node.generated: - CGen.generate_statement(sim, ast_node.definition()) + self.generate_statement(ast_node.definition()) if ast_node.kind() == BinOp.Kind_Vector: assert index is not None, "Index must be set for vector reference!" @@ -195,7 +204,7 @@ class CGen: if isinstance(ast_node, Cast): tkw = CGen.type2keyword(ast_node.cast_type) - expr = CGen.generate_expression(sim, ast_node.expr) + expr = self.generate_expression(ast_node.expr) return f"({tkw})({expr})" if isinstance(ast_node, Iter): @@ -216,20 +225,20 @@ class CGen: if isinstance(ast_node, Sqrt): assert mem is False, "Square root call is not lvalue!" - expr = CGen.generate_expression(sim, ast_node.expr) + expr = self.generate_expression(ast_node.expr) return f"sqrt({expr})" if isinstance(ast_node, Select): assert mem is False, "Select expression is not lvalue!" - cond = CGen.generate_expression(sim, ast_node.cond) - expr_if = CGen.generate_expression(sim, ast_node.expr_if) - expr_else = CGen.generate_expression(sim, ast_node.expr_else) + cond = self.generate_expression(ast_node.cond) + expr_if = self.generate_expression(ast_node.expr_if) + expr_else = self.generate_expression(ast_node.expr_else) return f"({cond}) ? ({expr_if}) : ({expr_else})" if isinstance(ast_node, Var): return ast_node.name() - def generate_vtk_writing(id, filename, start, end, n, timestep): + def generate_vtk_writing(self, id, filename, start, end, n, timestep): # TODO: Do this in a more elegant way, without hard coded stuff header = "# vtk DataFile Version 2.0\n" \ "Particle data\n" \ @@ -238,50 +247,49 @@ class CGen: filename_var = f"filename{id}" filehandle_var = f"vtk{id}" - printer.print(f"char {filename_var}[128];") - printer.print(f"snprintf({filename_var}, sizeof {filename_var}, \"{filename}_%d.vtk\", {timestep});") - printer.print(f"FILE *{filehandle_var} = fopen({filename_var}, \"w\");") + self.print(f"char {filename_var}[128];") + self.print(f"snprintf({filename_var}, sizeof {filename_var}, \"{filename}_%d.vtk\", {timestep});") + self.print(f"FILE *{filehandle_var} = fopen({filename_var}, \"w\");") for line in header.split('\n'): if len(line) > 0: - printer.print(f"fwrite(\"{line}\\n\", 1, {len(line) + 1}, {filehandle_var});") + self.print(f"fwrite(\"{line}\\n\", 1, {len(line) + 1}, {filehandle_var});") # Write positions - printer.print(f"fprintf({filehandle_var}, \"POINTS %d double\\n\", {n});") - printer.print(f"for(int i = {start}; i < {end}; i++) {{") - printer.add_ind(4) - printer.print(f"fprintf({filehandle_var}, \"%.4f %.4f %.4f\\n\", position[i * 3], position[i * 3 + 1], position[i * 3 + 2]);") - printer.add_ind(-4) - printer.print("}") - printer.print(f"fwrite(\"\\n\\n\", 1, 2, {filehandle_var});") + self.print(f"fprintf({filehandle_var}, \"POINTS %d double\\n\", {n});") + self.print(f"for(int i = {start}; i < {end}; i++) {{") + self.print.add_ind(4) + self.print(f"fprintf({filehandle_var}, \"%.4f %.4f %.4f\\n\", position[i * 3], position[i * 3 + 1], position[i * 3 + 2]);") + self.print.add_ind(-4) + self.print("}") + self.print(f"fwrite(\"\\n\\n\", 1, 2, {filehandle_var});") # Write cells - printer.print(f"fprintf({filehandle_var}, \"CELLS %d %d\\n\", {n}, {n} * 2);") - printer.print(f"for(int i = {start}; i < {end}; i++) {{") - printer.add_ind(4) - printer.print(f"fprintf({filehandle_var}, \"1 %d\\n\", i - {start});") - printer.add_ind(-4) - printer.print("}") - printer.print(f"fwrite(\"\\n\\n\", 1, 2, {filehandle_var});") + self.print(f"fprintf({filehandle_var}, \"CELLS %d %d\\n\", {n}, {n} * 2);") + self.print(f"for(int i = {start}; i < {end}; i++) {{") + self.print.add_ind(4) + self.print(f"fprintf({filehandle_var}, \"1 %d\\n\", i - {start});") + self.print.add_ind(-4) + self.print("}") + self.print(f"fwrite(\"\\n\\n\", 1, 2, {filehandle_var});") # Write cell types - printer.print(f"fprintf({filehandle_var}, \"CELL_TYPES %d\\n\", {n});") - printer.print(f"for(int i = {start}; i < {end}; i++) {{") - printer.add_ind(4) - printer.print(f"fwrite(\"1\\n\", 1, 2, {filehandle_var});") - printer.add_ind(-4) - printer.print("}") - printer.print(f"fwrite(\"\\n\\n\", 1, 2, {filehandle_var});") + self.print(f"fprintf({filehandle_var}, \"CELL_TYPES %d\\n\", {n});") + self.print(f"for(int i = {start}; i < {end}; i++) {{") + self.print.add_ind(4) + self.print(f"fwrite(\"1\\n\", 1, 2, {filehandle_var});") + self.print.add_ind(-4) + self.print("}") + self.print(f"fwrite(\"\\n\\n\", 1, 2, {filehandle_var});") # Write masses - printer.print(f"fprintf({filehandle_var}, \"POINT_DATA %d\\n\", {n});") - printer.print(f"fprintf({filehandle_var}, \"SCALARS mass double\\n\");") - printer.print(f"fprintf({filehandle_var}, \"LOOKUP_TABLE default\\n\");") - printer.print(f"for(int i = {start}; i < {end}; i++) {{") - printer.add_ind(4) - #printer.print(f"fprintf({filehandle_var}, \"%4.f\\n\", mass[i]);") - printer.print(f"fprintf({filehandle_var}, \"1.0\\n\");") - printer.add_ind(-4) - printer.print("}") - printer.print(f"fwrite(\"\\n\\n\", 1, 2, {filehandle_var});") - - printer.print(f"fclose({filehandle_var});") + self.print(f"fprintf({filehandle_var}, \"POINT_DATA %d\\n\", {n});") + self.print(f"fprintf({filehandle_var}, \"SCALARS mass double\\n\");") + self.print(f"fprintf({filehandle_var}, \"LOOKUP_TABLE default\\n\");") + self.print(f"for(int i = {start}; i < {end}; i++) {{") + self.print.add_ind(4) + #self.print(f"fprintf({filehandle_var}, \"%4.f\\n\", mass[i]);") + self.print(f"fprintf({filehandle_var}, \"1.0\\n\");") + self.print.add_ind(-4) + self.print("}") + self.print(f"fwrite(\"\\n\\n\", 1, 2, {filehandle_var});") + self.print(f"fclose({filehandle_var});") diff --git a/code_gen/printer.py b/code_gen/printer.py index 458713b51d35bdb8c4a4f6b7fd576cf92244bdcf..95299469b67fe70a6cb89c308a8a91990ee5cb0a 100644 --- a/code_gen/printer.py +++ b/code_gen/printer.py @@ -1,12 +1,19 @@ class Printer: - def __init__(self): + def __init__(self, output): + self.output = output + self.stream = None self.indent = 0 def add_ind(self, offset): self.indent += offset - def print(self, text): - print(self.indent * ' ' + text) + def start(self): + self.stream = open(self.output, 'w') + def end(self): + self.stream.close() + self.stream = None -printer = Printer() + def __call__(self, text): + assert self.stream is not None, "Invalid stream!" + self.stream.write(self.indent * ' ' + text + '\n') diff --git a/part_prot.py b/part_prot.py index 7c1bcc54cee4b330451d2e007dab59c3947e1fec..dab286c5177b26f3d468ebc9da6254341246e2b8 100644 --- a/part_prot.py +++ b/part_prot.py @@ -2,5 +2,5 @@ from code_gen.cgen import CGen from sim.particle_simulation import ParticleSimulation -def simulation(dims=3, timesteps=100): - return ParticleSimulation(CGen, dims, timesteps) +def simulation(ref, dims=3, timesteps=100): + return ParticleSimulation(CGen(f"{ref}.c"), dims, timesteps) diff --git a/particle.py b/particle.py index d9b95bf4a97ba01b139f976444d85b500c859864..f46612644d8d87433b43e5dc49aba3cc98754671 100644 --- a/particle.py +++ b/particle.py @@ -8,7 +8,7 @@ sigma = 1.0 epsilon = 1.0 sigma6 = sigma ** 6 -psim = pt.simulation() +psim = pt.simulation("lj") mass = psim.add_real_property('mass', 1.0) position = psim.add_vector_property('position') velocity = psim.add_vector_property('velocity') diff --git a/sim/particle_simulation.py b/sim/particle_simulation.py index ae2f814c0ad8e90dfa6c031962474af1e53e0253..b21b5e37a0a8ce693eec2b847aeab38da64e530e 100644 --- a/sim/particle_simulation.py +++ b/sim/particle_simulation.py @@ -21,11 +21,13 @@ from sim.variables import VariablesDecl from sim.vtk import VTKWrite from transformations.flatten import flatten_property_accesses from transformations.simplify import simplify_expressions +from transformations.LICM import move_loop_invariant_code class ParticleSimulation: def __init__(self, code_gen, dims=3, timesteps=100): self.code_gen = code_gen + self.code_gen.assign_simulation(self) self.global_scope = None self.properties = Properties(self) self.vars = Variables(self) @@ -198,6 +200,7 @@ class ParticleSimulation: # Transformations flatten_property_accesses(program) simplify_expressions(program) + move_loop_invariant_code(program) ASTGraph(self.kernels.lower(), "kernels").render() - self.code_gen.generate_program(self, program) + self.code_gen.generate_program(program) diff --git a/transformations/LICM.py b/transformations/LICM.py index 07517b50cb4ddbab03925e119738244efc4b9512..4742f19adf852e065cb7ba65733b5ce837c2dbf3 100644 --- a/transformations/LICM.py +++ b/transformations/LICM.py @@ -1,3 +1,4 @@ +from ast.loops import For, While from ast.mutator import Mutator from ast.visitor import Visitor @@ -5,42 +6,42 @@ from ast.visitor import Visitor class SetBlockVariants(Mutator): def __init__(self, ast): super().__init__(ast) - self.current_block = None self.in_assignment = None - def mutate_Block(self, ast_node): - self.current_block = ast_node - ast_node.stmts = [self.mutate(s) for s in ast_node.stmts] - return ast_node - def mutate_Assign(self, ast_node): - self.in_assignment = ast_node + self.in_assignment = ast_node if ast_node.parent_block is not None else None for dest in ast_node.destinations(): self.mutate(dest) self.in_assignment = None return ast_node + def mutate_For(self, ast_node): + ast_node.block.add_variant(id(ast_node.iterator)) + ast_node.iterator = self.mutate(ast_node.iterator) + ast_node.block = self.mutate(ast_node.block) + return ast_node + def mutate_Array(self, ast_node): if self.in_assignment is not None: - self.in_assignment.parent_block.add_variant(ast_node) + self.in_assignment.parent_block.add_variant(id(ast_node)) return ast_node - def mutate_For(self, ast_node): - ast_node.block.add_variant(ast_node.iterator) - ast_node.iterator = self.mutate(ast_node.iterator) - ast_node.block = self.mutate(ast_node.block) + def mutate_Iter(self, ast_node): + if self.in_assignment is not None: + self.in_assignment.parent_block.add_variant(id(ast_node)) + return ast_node def mutate_Property(self, ast_node): if self.in_assignment is not None: - self.in_assignment.parent_block.add_variant(ast_node) + self.in_assignment.parent_block.add_variant(id(ast_node)) return ast_node def mutate_Variable(self, ast_node): if self.in_assignment is not None: - self.in_assignment.parent_block.add_variant(ast_node) + self.in_assignment.parent_block.add_variant(id(ast_node)) return ast_node @@ -51,48 +52,48 @@ class SetParentBlock(Visitor): self.blocks = [] def current_block(self): - return self.blocks[-1] + return self.blocks[-1] if self.blocks else None def visit_Assign(self, ast_node): - ast_node.parent_block = self.current_block + ast_node.parent_block = self.current_block() self.visit_children(ast_node) def visit_Block(self, ast_node): - ast_node.parent_block = self.current_block + ast_node.parent_block = self.current_block() self.blocks.append(ast_node) self.visit_children(ast_node) self.blocks.pop() def visit_BinOpDef(self, ast_node): - ast_node.parent_block = self.current_block + ast_node.parent_block = self.current_block() self.visit_children(ast_node) def visit_Branch(self, ast_node): - ast_node.parent_block = self.current_block + ast_node.parent_block = self.current_block() self.visit_children(ast_node) def visit_Filter(self, ast_node): - ast_node.parent_block = self.current_block + ast_node.parent_block = self.current_block() self.visit_children(ast_node) def visit_For(self, ast_node): - ast_node.parent_block = self.current_block + ast_node.parent_block = self.current_block() self.visit_children(ast_node) def visit_ParticleFor(self, ast_node): - ast_node.parent_block = self.current_block + ast_node.parent_block = self.current_block() self.visit_children(ast_node) def visit_Malloc(self, ast_node): - ast_node.parent_block = self.current_block + ast_node.parent_block = self.current_block() self.visit_children(ast_node) def visit_Realloc(self, ast_node): - ast_node.parent_block = self.current_block + ast_node.parent_block = self.current_block() self.visit_children(ast_node) def visit_While(self, ast_node): - ast_node.parent_block = self.current_block + ast_node.parent_block = self.current_block() self.visit_children(ast_node) def get_loop_parent_block(self, ast_node): @@ -101,18 +102,71 @@ class SetParentBlock(Visitor): return self.parents[loop_id] if loop_id in self.parents else None +class SetBinOpTerminals(Visitor): + def __init__(self, ast): + super().__init__(ast) + self.bin_ops = [] + + def visit_BinOp(self, ast_node): + self.bin_ops.append(ast_node) + self.visit_children(ast_node) + self.bin_ops.pop() + + def visit_Array(self, ast_node): + for bin_op in self.bin_ops: + bin_op.add_terminal(id(ast_node)) + + def visit_Iter(self, ast_node): + for bin_op in self.bin_ops: + bin_op.add_terminal(id(ast_node)) + + def visit_Property(self, ast_node): + for bin_op in self.bin_ops: + bin_op.add_terminal(id(ast_node)) + + def visit_Variable(self, ast_node): + for bin_op in self.bin_ops: + bin_op.add_terminal(id(ast_node)) + + class LICM(Mutator): - def __init__(self, ast, loop_parents): + def __init__(self, ast): super().__init__(ast) - self.loop_parents = loop_parents + self.lifts = {} + self.loops = [] def mutate_For(self, ast_node): + self.lifts[id(ast_node)] = [] + self.loops.append(ast_node) ast_node.iterator = self.mutate(ast_node.iterator) ast_node.block = self.mutate(ast_node.block) + self.loops.pop() + return ast_node + + def mutate_BinOpDef(self, ast_node): + if self.loops: + last_loop = self.loops[-1] + print(f"Checking lifting for {ast_node.id()}") + if not last_loop.block.variants.intersect(ast_node.bin_op.terminals): + self.lifts[id(last_loop)].append(ast_node) + print(f"Lifting {ast_node.id()}") + return None + return ast_node def mutate_Block(self, ast_node): - ast_node.stmts = [self.mutate(s) for s in ast_node.stmts] + new_stmts = [] + stmts = self.mutate(ast_node.stmts) + + for s in stmts: + if s is not None: + s_id = id(s) + if isinstance(s, (For, While)) and s_id in self.lifts: + new_stmts = new_stmts + self.lifts[s_id] + + new_stmts.append(s) + + ast_node.stmts = new_stmts return ast_node @@ -121,5 +175,7 @@ def move_loop_invariant_code(ast): set_parent_block.visit() set_block_variants = SetBlockVariants(ast) set_block_variants.mutate() + set_bin_op_terminals = SetBinOpTerminals(ast) + set_bin_op_terminals.visit() licm = LICM(ast) licm.mutate() diff --git a/transformations/flatten.py b/transformations/flatten.py index b06ebcb07d6a54c5b4367abdb2bcc4d3093f0e7e..b414079e8cc961fbfbaa9e282e9d40b3a50f53b7 100644 --- a/transformations/flatten.py +++ b/transformations/flatten.py @@ -13,7 +13,7 @@ class FlattenPropertyAccesses(Mutator): if ast_node.is_vector_property_access(): layout = ast_node.lhs.layout() - for i in ast_node.vector_indexes(): + for i in ast_node.vector_indexes: flat_index = None if layout == Layout_AoS: diff --git a/transformations/simplify.py b/transformations/simplify.py index ff673159bcb09af7a372f959c97bbe852e00c835..d59dd59123c2129d15886173c9f3c4d78826bfc6 100644 --- a/transformations/simplify.py +++ b/transformations/simplify.py @@ -11,7 +11,7 @@ class SimplifyExpressions(Mutator): sim = ast_node.lhs.sim ast_node.lhs = self.mutate(ast_node.lhs) ast_node.rhs = self.mutate(ast_node.rhs) - ast_node.bin_op_vector_index_mapping = {i: self.mutate(e) for i, e in ast_node.bin_op_vector_index_mapping.items()} + ast_node.vector_index_mapping = {i: self.mutate(e) for i, e in ast_node.vector_index_mapping.items()} if ast_node.op in ['+', '-'] and ast_node.rhs == 0: return ast_node.lhs