From 8c1a76cd1324ee1248cce3b6cb589ec9818fe8e0 Mon Sep 17 00:00:00 2001 From: Rafael Ravedutti <rafaelravedutti@gmail.com> Date: Fri, 15 Oct 2021 00:58:29 +0200 Subject: [PATCH] Define lowering and adjacent blocks merging as transformations Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com> --- src/pairs/ir/block.py | 4 +-- src/pairs/ir/mutator.py | 4 +++ src/pairs/sim/arrays.py | 5 +-- src/pairs/sim/cell_lists.py | 9 ++--- src/pairs/sim/lattice.py | 5 +-- src/pairs/sim/lowerable.py | 6 ++++ src/pairs/sim/neighbor_lists.py | 5 +-- src/pairs/sim/pbc.py | 13 +++---- src/pairs/sim/properties.py | 5 +-- src/pairs/sim/read_from_file.py | 5 +-- src/pairs/sim/simulation.py | 34 +++++++++++-------- src/pairs/sim/variables.py | 5 +-- src/pairs/sim/vtk.py | 10 +++--- src/pairs/transformations/lower.py | 23 +++++++++++++ .../transformations/merge_adjacent_blocks.py | 25 ++++++++++++++ .../transformations/prioritize_scalar_ops.py | 12 +++---- 16 files changed, 118 insertions(+), 52 deletions(-) create mode 100644 src/pairs/sim/lowerable.py create mode 100644 src/pairs/transformations/lower.py create mode 100644 src/pairs/transformations/merge_adjacent_blocks.py diff --git a/src/pairs/ir/block.py b/src/pairs/ir/block.py index ff75903..b9c04b8 100644 --- a/src/pairs/ir/block.py +++ b/src/pairs/ir/block.py @@ -74,10 +74,8 @@ class Block(ASTNode): for block in block_list: if isinstance(block, Block): result_block = Block.merge_blocks(result_block, block) - elif isinstance(block, KernelBlock): - result_block.add_statement(block) else: - raise Exception("Element in list is not Block!") + result_block.add_statement(block) return result_block diff --git a/src/pairs/ir/mutator.py b/src/pairs/ir/mutator.py index 293c5ec..cb394cf 100644 --- a/src/pairs/ir/mutator.py +++ b/src/pairs/ir/mutator.py @@ -15,6 +15,10 @@ class Mutator: if method is not None: return method(ast_node) + method_unknown = self.get_method("mutate_Unknown") + if method_unknown is not None: + return method_unknown(ast_node) + return ast_node def mutate_ArrayAccess(self, ast_node): diff --git a/src/pairs/sim/arrays.py b/src/pairs/sim/arrays.py index 7167baf..5adcaec 100644 --- a/src/pairs/sim/arrays.py +++ b/src/pairs/sim/arrays.py @@ -1,11 +1,12 @@ from pairs.ir.block import pairs_block from pairs.ir.memory import Malloc from pairs.ir.arrays import ArrayDecl +from pairs.sim.lowerable import Lowerable -class ArraysDecl: +class ArraysDecl(Lowerable): def __init__(self, sim): - self.sim = sim + super().__init__(sim) @pairs_block def lower(self): diff --git a/src/pairs/sim/cell_lists.py b/src/pairs/sim/cell_lists.py index 6657cd6..1a36a04 100644 --- a/src/pairs/sim/cell_lists.py +++ b/src/pairs/sim/cell_lists.py @@ -8,6 +8,7 @@ from pairs.ir.data_types import Type_Int from pairs.ir.math import Ceil from pairs.ir.loops import For, ParticleFor from pairs.ir.utils import Print +from pairs.sim.lowerable import Lowerable from pairs.sim.resize import Resize @@ -30,9 +31,9 @@ class CellLists: self.particle_cell = self.sim.add_array('particle_cell', self.sim.particle_capacity, Type_Int) -class CellListsStencilBuild: +class CellListsStencilBuild(Lowerable): def __init__(self, sim, cell_lists): - self.sim = sim + super().__init__(sim) self.cell_lists = cell_lists @pairs_device_block @@ -63,9 +64,9 @@ class CellListsStencilBuild: cl.nstencil.set(cl.nstencil + 1) -class CellListsBuild: +class CellListsBuild(Lowerable): def __init__(self, sim, cell_lists): - self.sim = sim + super().__init__(sim) self.cell_lists = cell_lists @pairs_device_block diff --git a/src/pairs/sim/lattice.py b/src/pairs/sim/lattice.py index 8977f87..e1faa01 100644 --- a/src/pairs/sim/lattice.py +++ b/src/pairs/sim/lattice.py @@ -1,11 +1,12 @@ from pairs.ir.block import pairs_block from pairs.ir.data_types import Type_Vector from pairs.ir.loops import For +from pairs.sim.lowerable import Lowerable -class ParticleLattice(): +class ParticleLattice(Lowerable): def __init__(self, sim, grid, spacing, props, positions): - self.sim = sim + super().__init__(sim) self.grid = grid self.spacing = spacing self.props = props diff --git a/src/pairs/sim/lowerable.py b/src/pairs/sim/lowerable.py new file mode 100644 index 0000000..9a2698c --- /dev/null +++ b/src/pairs/sim/lowerable.py @@ -0,0 +1,6 @@ +class Lowerable: + def __init__(self, sim): + self.sim = sim + + def lower(self): + raise Exception("Error: lower() method must be implemented for Lowerable inherited classes!") diff --git a/src/pairs/sim/neighbor_lists.py b/src/pairs/sim/neighbor_lists.py index 13b5682..798a243 100644 --- a/src/pairs/sim/neighbor_lists.py +++ b/src/pairs/sim/neighbor_lists.py @@ -3,6 +3,7 @@ from pairs.ir.branches import Branch, Filter from pairs.ir.data_types import Type_Int from pairs.ir.loops import For, ParticleFor, NeighborFor from pairs.ir.utils import Print +from pairs.sim.lowerable import Lowerable from pairs.sim.resize import Resize @@ -15,9 +16,9 @@ class NeighborLists: self.numneighs = self.sim.add_array('numneighs', self.sim.particle_capacity, Type_Int) -class NeighborListsBuild: +class NeighborListsBuild(Lowerable): def __init__(self, sim, neighbor_lists): - self.sim = sim + super().__init__(sim) self.neighbor_lists = neighbor_lists @pairs_device_block diff --git a/src/pairs/sim/pbc.py b/src/pairs/sim/pbc.py index 76886ad..1a5d480 100644 --- a/src/pairs/sim/pbc.py +++ b/src/pairs/sim/pbc.py @@ -4,6 +4,7 @@ from pairs.ir.data_types import Type_Int from pairs.ir.loops import For, ParticleFor from pairs.ir.utils import Print from pairs.ir.select import Select +from pairs.sim.lowerable import Lowerable from pairs.sim.resize import Resize @@ -19,9 +20,9 @@ class PBC: self.pbc_mult = sim.add_array('pbc_mult', [self.pbc_capacity, sim.ndims()], Type_Int) -class UpdatePBC: +class UpdatePBC(Lowerable): def __init__(self, sim, pbc): - self.sim = sim + super().__init__(sim) self.pbc = pbc @pairs_device_block @@ -42,9 +43,9 @@ class UpdatePBC: positions[nlocal + i][d].set(positions[pbc_map[i]][d] + pbc_mult[i][d] * grid.length(d)) -class EnforcePBC: +class EnforcePBC(Lowerable): def __init__(self, sim, pbc): - self.sim = sim + super().__init__(sim) self.pbc = pbc @pairs_device_block @@ -64,9 +65,9 @@ class EnforcePBC: positions[i][d].sub(grid.length(d)) -class SetupPBC: +class SetupPBC(Lowerable): def __init__(self, sim, pbc): - self.sim = sim + super().__init__(sim) self.pbc = pbc @pairs_device_block diff --git a/src/pairs/sim/properties.py b/src/pairs/sim/properties.py index 3892bbf..35a24f3 100644 --- a/src/pairs/sim/properties.py +++ b/src/pairs/sim/properties.py @@ -4,11 +4,12 @@ from pairs.ir.loops import ParticleFor from pairs.ir.memory import Malloc, Realloc from pairs.ir.properties import RegisterProperty, UpdateProperty from pairs.ir.utils import Print +from pairs.sim.lowerable import Lowerable from functools import reduce import operator -class PropertiesAlloc: +class PropertiesAlloc(Lowerable): def __init__(self, sim, realloc=False): self.sim = sim self.realloc = realloc @@ -33,7 +34,7 @@ class PropertiesAlloc: RegisterProperty(self.sim, p, sizes) -class PropertiesResetVolatile: +class PropertiesResetVolatile(Lowerable): def __init__(self, sim): self.sim = sim diff --git a/src/pairs/sim/read_from_file.py b/src/pairs/sim/read_from_file.py index 6f16cb0..606f79a 100644 --- a/src/pairs/sim/read_from_file.py +++ b/src/pairs/sim/read_from_file.py @@ -3,11 +3,12 @@ from pairs.ir.data_types import Type_Float from pairs.ir.functions import Call_Int from pairs.ir.properties import PropertyList from pairs.sim.grid import MutableGrid +from pairs.sim.lowerable import Lowerable -class ReadFromFile(): +class ReadFromFile(Lowerable): def __init__(self, sim, filename, props): - self.sim = sim + super().__init__(sim) self.filename = filename self.props = PropertyList(sim, props) self.grid = MutableGrid(sim, sim.ndims()) diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py index 238e392..f8c3461 100644 --- a/src/pairs/sim/simulation.py +++ b/src/pairs/sim/simulation.py @@ -20,10 +20,12 @@ from pairs.sim.timestep import Timestep from pairs.sim.variables import VariablesDecl from pairs.sim.vtk import VTKWrite from pairs.transformations.add_device_copies import add_device_copies -from pairs.transformations.prioritize_scalar_ops import prioritaze_scalar_ops +from pairs.transformations.prioritize_scalar_ops import prioritize_scalar_ops from pairs.transformations.set_used_bin_ops import set_used_bin_ops from pairs.transformations.simplify import simplify_expressions from pairs.transformations.LICM import move_loop_invariant_code +from pairs.transformations.lower import lower_everything +from pairs.transformations.merge_adjacent_blocks import merge_adjacent_blocks class Simulation: @@ -107,12 +109,12 @@ class Simulation: def create_particle_lattice(self, grid, spacing, props={}): positions = self.property('position') lattice = ParticleLattice(self, grid, spacing, props, positions) - self.setups.add_statement(lattice.lower()) + self.setups.add_statement(lattice) def from_file(self, filename, prop_names): props = [self.property(prop_name) for prop_name in prop_names] read_object = ReadFromFile(self, filename, props) - self.setups.add_statement(read_object.lower()) + self.setups.add_statement(read_object) self.grid = read_object.grid def create_cell_lists(self, spacing, cutoff_radius): @@ -186,34 +188,36 @@ class Simulation: def generate(self): timestep = Timestep(self, self.ntimesteps, [ - (EnforcePBC(self, self.pbc).lower(), 20), - (SetupPBC(self, self.pbc).lower(), UpdatePBC(self, self.pbc).lower(), 20), - (CellListsBuild(self, self.cell_lists).lower(), 20), - (NeighborListsBuild(self, self.neighbor_lists).lower(), 20), - PropertiesResetVolatile(self).lower(), + (EnforcePBC(self, self.pbc), 20), + (SetupPBC(self, self.pbc), UpdatePBC(self, self.pbc), 20), + (CellListsBuild(self, self.cell_lists), 20), + (NeighborListsBuild(self, self.neighbor_lists), 20), + PropertiesResetVolatile(self), self.kernels ]) - timestep.add(VTKWrite(self, self.vtk_file, timestep.timestep() + 1).lower()) + timestep.add(VTKWrite(self, self.vtk_file, timestep.timestep() + 1)) body = Block.from_list(self, [ self.setups, - CellListsStencilBuild(self, self.cell_lists).lower(), - VTKWrite(self, self.vtk_file, 0).lower(), + CellListsStencilBuild(self, self.cell_lists), + VTKWrite(self, self.vtk_file, 0), timestep.as_block() ]) decls = Block.from_list(self, [ - VariablesDecl(self).lower(), - ArraysDecl(self).lower(), - PropertiesAlloc(self).lower(), + VariablesDecl(self), + ArraysDecl(self), + PropertiesAlloc(self), ]) program = Block.merge_blocks(decls, body) self.global_scope = program # Transformations - prioritaze_scalar_ops(program) + lower_everything(program) + merge_adjacent_blocks(program) + prioritize_scalar_ops(program) simplify_expressions(program) move_loop_invariant_code(program) set_used_bin_ops(program) diff --git a/src/pairs/sim/variables.py b/src/pairs/sim/variables.py index 1d9168f..f7b512f 100644 --- a/src/pairs/sim/variables.py +++ b/src/pairs/sim/variables.py @@ -1,10 +1,11 @@ from pairs.ir.block import pairs_block from pairs.ir.variables import VarDecl +from pairs.sim.lowerable import Lowerable -class VariablesDecl: +class VariablesDecl(Lowerable): def __init__(self, sim): - self.sim = sim + super().__init__(sim) @pairs_block def lower(self): diff --git a/src/pairs/sim/vtk.py b/src/pairs/sim/vtk.py index 86de551..3fbf888 100644 --- a/src/pairs/sim/vtk.py +++ b/src/pairs/sim/vtk.py @@ -1,22 +1,20 @@ from pairs.ir.ast_node import ASTNode +from pairs.ir.block import pairs_block from pairs.ir.functions import Call_Void from pairs.ir.lit import as_lit_ast +from pairs.sim.lowerable import Lowerable -class VTKWrite(ASTNode): +class VTKWrite(Lowerable): def __init__(self, sim, filename, timestep): super().__init__(sim) self.filename = filename self.timestep = as_lit_ast(sim, timestep) + @pairs_block def lower(self): nlocal = self.sim.nlocal npbc = self.sim.pbc.npbc - self.sim.clear_block() nall = nlocal + npbc Call_Void(self.sim, "pairs::vtk_write_data", [self.filename + "_local", 0, nlocal, self.timestep]) Call_Void(self.sim, "pairs::vtk_write_data", [self.filename + "_pbc", nlocal, nall, self.timestep]) - return self.sim.block - - def children(self): - return [self.timestep] diff --git a/src/pairs/transformations/lower.py b/src/pairs/transformations/lower.py new file mode 100644 index 0000000..c2fb14f --- /dev/null +++ b/src/pairs/transformations/lower.py @@ -0,0 +1,23 @@ +from pairs.ir.mutator import Mutator +from pairs.sim.lowerable import Lowerable + + +class Lower(Mutator): + def __init__(self, ast): + super().__init__(ast) + self.lowered_nodes = 0 + + def mutate_Unknown(self, ast_node): + if isinstance(ast_node, Lowerable): + self.lowered_nodes += 1 + return ast_node.lower() + + return ast_node + + +def lower_everything(ast): + nlowered = 1 + while nlowered > 0: + lower = Lower(ast) + lower.mutate() + nlowered = lower.lowered_nodes diff --git a/src/pairs/transformations/merge_adjacent_blocks.py b/src/pairs/transformations/merge_adjacent_blocks.py new file mode 100644 index 0000000..1095144 --- /dev/null +++ b/src/pairs/transformations/merge_adjacent_blocks.py @@ -0,0 +1,25 @@ +from pairs.ir.block import Block +from pairs.ir.mutator import Mutator + + +class MergeAdjacentBlocks(Mutator): + def __init__(self, ast): + super().__init__(ast) + + def mutate_Block(self, ast_node): + new_stmts = [] + stmts = [self.mutate(s) for s in ast_node.stmts] + + for s in stmts: + if isinstance(s, Block): + new_stmts = new_stmts + s.statements() + else: + new_stmts.append(s) + + ast_node.stmts = new_stmts + return ast_node + + +def merge_adjacent_blocks(ast): + merge = MergeAdjacentBlocks(ast) + merge.mutate() diff --git a/src/pairs/transformations/prioritize_scalar_ops.py b/src/pairs/transformations/prioritize_scalar_ops.py index 47957f6..caeb87d 100644 --- a/src/pairs/transformations/prioritize_scalar_ops.py +++ b/src/pairs/transformations/prioritize_scalar_ops.py @@ -3,7 +3,7 @@ from pairs.ir.data_types import Type_Float, Type_Vector from pairs.ir.mutator import Mutator -class PrioritazeScalarOps(Mutator): +class PrioritizeScalarOps(Mutator): def __init__(self, ast): super().__init__(ast) @@ -21,7 +21,7 @@ class PrioritazeScalarOps(Mutator): op = ast_node.op if( isinstance(lhs, BinOp) and lhs.type() == Type_Vector and rhs.type() == Type_Float and \ - PrioritazeScalarOps.can_rearrange(op, lhs.op) ): + PrioritizeScalarOps.can_rearrange(op, lhs.op) ): if lhs.lhs.type() == Type_Vector and lhs.rhs.type() == Type_Float: ast_node.reassign(lhs.lhs, BinOp(sim, lhs.rhs, rhs, op), op) @@ -32,7 +32,7 @@ class PrioritazeScalarOps(Mutator): #return BinOp(sim, lhs.rhs, BinOp(sim, lhs.lhs, rhs, op), op) if( isinstance(rhs, BinOp) and rhs.type() == Type_Vector and lhs.type() == Type_Float and \ - PrioritazeScalarOps.can_rearrange(op, rhs.op) ): + PrioritizeScalarOps.can_rearrange(op, rhs.op) ): if rhs.lhs.type() == Type_Vector and rhs.rhs.type() == Type_Float: ast_node.reassign(rhs.lhs, BinOp(sim, rhs.rhs, lhs, op), op) @@ -45,6 +45,6 @@ class PrioritazeScalarOps(Mutator): return ast_node -def prioritaze_scalar_ops(ast_node): - prioritaze = PrioritazeScalarOps(ast_node) - prioritaze.mutate() +def prioritize_scalar_ops(ast_node): + prioritize = PrioritizeScalarOps(ast_node) + prioritize.mutate() -- GitLab