diff --git a/src/pairs/ir/block.py b/src/pairs/ir/block.py index ff759039c9e6ea30696fa06c713e1fdc4e056686..b9c04b829c14dc76ee5e56a869075333ebab2840 100644 --- a/src/pairs/ir/block.py +++ b/src/pairs/ir/block.py @@ -74,10 +74,8 @@ class Block(ASTNode): for block in block_list: if isinstance(block, Block): result_block = Block.merge_blocks(result_block, block) - elif isinstance(block, KernelBlock): - result_block.add_statement(block) else: - raise Exception("Element in list is not Block!") + result_block.add_statement(block) return result_block diff --git a/src/pairs/ir/mutator.py b/src/pairs/ir/mutator.py index 293c5ec87550667d67123f472a3cb5350e23b528..cb394cf368f5f905da7199a22ef0d23c575985e0 100644 --- a/src/pairs/ir/mutator.py +++ b/src/pairs/ir/mutator.py @@ -15,6 +15,10 @@ class Mutator: if method is not None: return method(ast_node) + method_unknown = self.get_method("mutate_Unknown") + if method_unknown is not None: + return method_unknown(ast_node) + return ast_node def mutate_ArrayAccess(self, ast_node): diff --git a/src/pairs/sim/arrays.py b/src/pairs/sim/arrays.py index 7167baf724ba560d3a01d7d354c21c6d3e0ba1c0..5adcaecccb9d2eb69ff8d24cf20a3e1ed2a7c4d3 100644 --- a/src/pairs/sim/arrays.py +++ b/src/pairs/sim/arrays.py @@ -1,11 +1,12 @@ from pairs.ir.block import pairs_block from pairs.ir.memory import Malloc from pairs.ir.arrays import ArrayDecl +from pairs.sim.lowerable import Lowerable -class ArraysDecl: +class ArraysDecl(Lowerable): def __init__(self, sim): - self.sim = sim + super().__init__(sim) @pairs_block def lower(self): diff --git a/src/pairs/sim/cell_lists.py b/src/pairs/sim/cell_lists.py index 6657cd60da15e50164094150ac87e84c163c56d8..1a36a04bde66ad4728151195fe288db878cfd4a6 100644 --- a/src/pairs/sim/cell_lists.py +++ b/src/pairs/sim/cell_lists.py @@ -8,6 +8,7 @@ from pairs.ir.data_types import Type_Int from pairs.ir.math import Ceil from pairs.ir.loops import For, ParticleFor from pairs.ir.utils import Print +from pairs.sim.lowerable import Lowerable from pairs.sim.resize import Resize @@ -30,9 +31,9 @@ class CellLists: self.particle_cell = self.sim.add_array('particle_cell', self.sim.particle_capacity, Type_Int) -class CellListsStencilBuild: +class CellListsStencilBuild(Lowerable): def __init__(self, sim, cell_lists): - self.sim = sim + super().__init__(sim) self.cell_lists = cell_lists @pairs_device_block @@ -63,9 +64,9 @@ class CellListsStencilBuild: cl.nstencil.set(cl.nstencil + 1) -class CellListsBuild: +class CellListsBuild(Lowerable): def __init__(self, sim, cell_lists): - self.sim = sim + super().__init__(sim) self.cell_lists = cell_lists @pairs_device_block diff --git a/src/pairs/sim/lattice.py b/src/pairs/sim/lattice.py index 8977f8727f3a77fb18ba8f83901004580f359b02..e1faa010edbfcf0ff8f50ffd8e6eb6e5517a5db5 100644 --- a/src/pairs/sim/lattice.py +++ b/src/pairs/sim/lattice.py @@ -1,11 +1,12 @@ from pairs.ir.block import pairs_block from pairs.ir.data_types import Type_Vector from pairs.ir.loops import For +from pairs.sim.lowerable import Lowerable -class ParticleLattice(): +class ParticleLattice(Lowerable): def __init__(self, sim, grid, spacing, props, positions): - self.sim = sim + super().__init__(sim) self.grid = grid self.spacing = spacing self.props = props diff --git a/src/pairs/sim/lowerable.py b/src/pairs/sim/lowerable.py new file mode 100644 index 0000000000000000000000000000000000000000..9a2698c30809723989fd7ad1cfef97e82379155e --- /dev/null +++ b/src/pairs/sim/lowerable.py @@ -0,0 +1,6 @@ +class Lowerable: + def __init__(self, sim): + self.sim = sim + + def lower(self): + raise Exception("Error: lower() method must be implemented for Lowerable inherited classes!") diff --git a/src/pairs/sim/neighbor_lists.py b/src/pairs/sim/neighbor_lists.py index 13b568272bcea18de50bfaf51c8023912edf828e..798a243a22f33c7cc4c7f1208077299466ca24ce 100644 --- a/src/pairs/sim/neighbor_lists.py +++ b/src/pairs/sim/neighbor_lists.py @@ -3,6 +3,7 @@ from pairs.ir.branches import Branch, Filter from pairs.ir.data_types import Type_Int from pairs.ir.loops import For, ParticleFor, NeighborFor from pairs.ir.utils import Print +from pairs.sim.lowerable import Lowerable from pairs.sim.resize import Resize @@ -15,9 +16,9 @@ class NeighborLists: self.numneighs = self.sim.add_array('numneighs', self.sim.particle_capacity, Type_Int) -class NeighborListsBuild: +class NeighborListsBuild(Lowerable): def __init__(self, sim, neighbor_lists): - self.sim = sim + super().__init__(sim) self.neighbor_lists = neighbor_lists @pairs_device_block diff --git a/src/pairs/sim/pbc.py b/src/pairs/sim/pbc.py index 76886ade244874f692260de3981c95c5fbca1c16..1a5d4809423800fb670e635924d23de8318549ae 100644 --- a/src/pairs/sim/pbc.py +++ b/src/pairs/sim/pbc.py @@ -4,6 +4,7 @@ from pairs.ir.data_types import Type_Int from pairs.ir.loops import For, ParticleFor from pairs.ir.utils import Print from pairs.ir.select import Select +from pairs.sim.lowerable import Lowerable from pairs.sim.resize import Resize @@ -19,9 +20,9 @@ class PBC: self.pbc_mult = sim.add_array('pbc_mult', [self.pbc_capacity, sim.ndims()], Type_Int) -class UpdatePBC: +class UpdatePBC(Lowerable): def __init__(self, sim, pbc): - self.sim = sim + super().__init__(sim) self.pbc = pbc @pairs_device_block @@ -42,9 +43,9 @@ class UpdatePBC: positions[nlocal + i][d].set(positions[pbc_map[i]][d] + pbc_mult[i][d] * grid.length(d)) -class EnforcePBC: +class EnforcePBC(Lowerable): def __init__(self, sim, pbc): - self.sim = sim + super().__init__(sim) self.pbc = pbc @pairs_device_block @@ -64,9 +65,9 @@ class EnforcePBC: positions[i][d].sub(grid.length(d)) -class SetupPBC: +class SetupPBC(Lowerable): def __init__(self, sim, pbc): - self.sim = sim + super().__init__(sim) self.pbc = pbc @pairs_device_block diff --git a/src/pairs/sim/properties.py b/src/pairs/sim/properties.py index 3892bbf372a876757af641d92c478f62f686fc4a..35a24f3cbbd4e616f10ffb58f1e24d46eb4a3620 100644 --- a/src/pairs/sim/properties.py +++ b/src/pairs/sim/properties.py @@ -4,11 +4,12 @@ from pairs.ir.loops import ParticleFor from pairs.ir.memory import Malloc, Realloc from pairs.ir.properties import RegisterProperty, UpdateProperty from pairs.ir.utils import Print +from pairs.sim.lowerable import Lowerable from functools import reduce import operator -class PropertiesAlloc: +class PropertiesAlloc(Lowerable): def __init__(self, sim, realloc=False): self.sim = sim self.realloc = realloc @@ -33,7 +34,7 @@ class PropertiesAlloc: RegisterProperty(self.sim, p, sizes) -class PropertiesResetVolatile: +class PropertiesResetVolatile(Lowerable): def __init__(self, sim): self.sim = sim diff --git a/src/pairs/sim/read_from_file.py b/src/pairs/sim/read_from_file.py index 6f16cb0556a235f7786952947ecbc164f820e7d0..606f79af41b09000a26fdc3c4f33407e8c4ff7e8 100644 --- a/src/pairs/sim/read_from_file.py +++ b/src/pairs/sim/read_from_file.py @@ -3,11 +3,12 @@ from pairs.ir.data_types import Type_Float from pairs.ir.functions import Call_Int from pairs.ir.properties import PropertyList from pairs.sim.grid import MutableGrid +from pairs.sim.lowerable import Lowerable -class ReadFromFile(): +class ReadFromFile(Lowerable): def __init__(self, sim, filename, props): - self.sim = sim + super().__init__(sim) self.filename = filename self.props = PropertyList(sim, props) self.grid = MutableGrid(sim, sim.ndims()) diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py index 238e39242be802593b658b9d4a06eb1f83062fe6..f8c34613a9b28f4f4114b27748e06e16e245cb42 100644 --- a/src/pairs/sim/simulation.py +++ b/src/pairs/sim/simulation.py @@ -20,10 +20,12 @@ from pairs.sim.timestep import Timestep from pairs.sim.variables import VariablesDecl from pairs.sim.vtk import VTKWrite from pairs.transformations.add_device_copies import add_device_copies -from pairs.transformations.prioritize_scalar_ops import prioritaze_scalar_ops +from pairs.transformations.prioritize_scalar_ops import prioritize_scalar_ops from pairs.transformations.set_used_bin_ops import set_used_bin_ops from pairs.transformations.simplify import simplify_expressions from pairs.transformations.LICM import move_loop_invariant_code +from pairs.transformations.lower import lower_everything +from pairs.transformations.merge_adjacent_blocks import merge_adjacent_blocks class Simulation: @@ -107,12 +109,12 @@ class Simulation: def create_particle_lattice(self, grid, spacing, props={}): positions = self.property('position') lattice = ParticleLattice(self, grid, spacing, props, positions) - self.setups.add_statement(lattice.lower()) + self.setups.add_statement(lattice) def from_file(self, filename, prop_names): props = [self.property(prop_name) for prop_name in prop_names] read_object = ReadFromFile(self, filename, props) - self.setups.add_statement(read_object.lower()) + self.setups.add_statement(read_object) self.grid = read_object.grid def create_cell_lists(self, spacing, cutoff_radius): @@ -186,34 +188,36 @@ class Simulation: def generate(self): timestep = Timestep(self, self.ntimesteps, [ - (EnforcePBC(self, self.pbc).lower(), 20), - (SetupPBC(self, self.pbc).lower(), UpdatePBC(self, self.pbc).lower(), 20), - (CellListsBuild(self, self.cell_lists).lower(), 20), - (NeighborListsBuild(self, self.neighbor_lists).lower(), 20), - PropertiesResetVolatile(self).lower(), + (EnforcePBC(self, self.pbc), 20), + (SetupPBC(self, self.pbc), UpdatePBC(self, self.pbc), 20), + (CellListsBuild(self, self.cell_lists), 20), + (NeighborListsBuild(self, self.neighbor_lists), 20), + PropertiesResetVolatile(self), self.kernels ]) - timestep.add(VTKWrite(self, self.vtk_file, timestep.timestep() + 1).lower()) + timestep.add(VTKWrite(self, self.vtk_file, timestep.timestep() + 1)) body = Block.from_list(self, [ self.setups, - CellListsStencilBuild(self, self.cell_lists).lower(), - VTKWrite(self, self.vtk_file, 0).lower(), + CellListsStencilBuild(self, self.cell_lists), + VTKWrite(self, self.vtk_file, 0), timestep.as_block() ]) decls = Block.from_list(self, [ - VariablesDecl(self).lower(), - ArraysDecl(self).lower(), - PropertiesAlloc(self).lower(), + VariablesDecl(self), + ArraysDecl(self), + PropertiesAlloc(self), ]) program = Block.merge_blocks(decls, body) self.global_scope = program # Transformations - prioritaze_scalar_ops(program) + lower_everything(program) + merge_adjacent_blocks(program) + prioritize_scalar_ops(program) simplify_expressions(program) move_loop_invariant_code(program) set_used_bin_ops(program) diff --git a/src/pairs/sim/variables.py b/src/pairs/sim/variables.py index 1d9168ffdedb5ae67b84b17ea9f6a81f65c3ff39..f7b512fe96ae50d9854be229c3ba1d202051db9c 100644 --- a/src/pairs/sim/variables.py +++ b/src/pairs/sim/variables.py @@ -1,10 +1,11 @@ from pairs.ir.block import pairs_block from pairs.ir.variables import VarDecl +from pairs.sim.lowerable import Lowerable -class VariablesDecl: +class VariablesDecl(Lowerable): def __init__(self, sim): - self.sim = sim + super().__init__(sim) @pairs_block def lower(self): diff --git a/src/pairs/sim/vtk.py b/src/pairs/sim/vtk.py index 86de55122918cf7d4df1c264861bcc4f0f1ff34f..3fbf8883c67b7a280752df7972851c025c6d28c7 100644 --- a/src/pairs/sim/vtk.py +++ b/src/pairs/sim/vtk.py @@ -1,22 +1,20 @@ from pairs.ir.ast_node import ASTNode +from pairs.ir.block import pairs_block from pairs.ir.functions import Call_Void from pairs.ir.lit import as_lit_ast +from pairs.sim.lowerable import Lowerable -class VTKWrite(ASTNode): +class VTKWrite(Lowerable): def __init__(self, sim, filename, timestep): super().__init__(sim) self.filename = filename self.timestep = as_lit_ast(sim, timestep) + @pairs_block def lower(self): nlocal = self.sim.nlocal npbc = self.sim.pbc.npbc - self.sim.clear_block() nall = nlocal + npbc Call_Void(self.sim, "pairs::vtk_write_data", [self.filename + "_local", 0, nlocal, self.timestep]) Call_Void(self.sim, "pairs::vtk_write_data", [self.filename + "_pbc", nlocal, nall, self.timestep]) - return self.sim.block - - def children(self): - return [self.timestep] diff --git a/src/pairs/transformations/lower.py b/src/pairs/transformations/lower.py new file mode 100644 index 0000000000000000000000000000000000000000..c2fb14f14d87671439bbf74197b7986938aba749 --- /dev/null +++ b/src/pairs/transformations/lower.py @@ -0,0 +1,23 @@ +from pairs.ir.mutator import Mutator +from pairs.sim.lowerable import Lowerable + + +class Lower(Mutator): + def __init__(self, ast): + super().__init__(ast) + self.lowered_nodes = 0 + + def mutate_Unknown(self, ast_node): + if isinstance(ast_node, Lowerable): + self.lowered_nodes += 1 + return ast_node.lower() + + return ast_node + + +def lower_everything(ast): + nlowered = 1 + while nlowered > 0: + lower = Lower(ast) + lower.mutate() + nlowered = lower.lowered_nodes diff --git a/src/pairs/transformations/merge_adjacent_blocks.py b/src/pairs/transformations/merge_adjacent_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..10951443a57202ee34bae5e02fade8635843a764 --- /dev/null +++ b/src/pairs/transformations/merge_adjacent_blocks.py @@ -0,0 +1,25 @@ +from pairs.ir.block import Block +from pairs.ir.mutator import Mutator + + +class MergeAdjacentBlocks(Mutator): + def __init__(self, ast): + super().__init__(ast) + + def mutate_Block(self, ast_node): + new_stmts = [] + stmts = [self.mutate(s) for s in ast_node.stmts] + + for s in stmts: + if isinstance(s, Block): + new_stmts = new_stmts + s.statements() + else: + new_stmts.append(s) + + ast_node.stmts = new_stmts + return ast_node + + +def merge_adjacent_blocks(ast): + merge = MergeAdjacentBlocks(ast) + merge.mutate() diff --git a/src/pairs/transformations/prioritize_scalar_ops.py b/src/pairs/transformations/prioritize_scalar_ops.py index 47957f64c78da0edf214af51fadc5b4f9e7da94e..caeb87d1386b5a47d6ac499bac6c0c24a0b28af8 100644 --- a/src/pairs/transformations/prioritize_scalar_ops.py +++ b/src/pairs/transformations/prioritize_scalar_ops.py @@ -3,7 +3,7 @@ from pairs.ir.data_types import Type_Float, Type_Vector from pairs.ir.mutator import Mutator -class PrioritazeScalarOps(Mutator): +class PrioritizeScalarOps(Mutator): def __init__(self, ast): super().__init__(ast) @@ -21,7 +21,7 @@ class PrioritazeScalarOps(Mutator): op = ast_node.op if( isinstance(lhs, BinOp) and lhs.type() == Type_Vector and rhs.type() == Type_Float and \ - PrioritazeScalarOps.can_rearrange(op, lhs.op) ): + PrioritizeScalarOps.can_rearrange(op, lhs.op) ): if lhs.lhs.type() == Type_Vector and lhs.rhs.type() == Type_Float: ast_node.reassign(lhs.lhs, BinOp(sim, lhs.rhs, rhs, op), op) @@ -32,7 +32,7 @@ class PrioritazeScalarOps(Mutator): #return BinOp(sim, lhs.rhs, BinOp(sim, lhs.lhs, rhs, op), op) if( isinstance(rhs, BinOp) and rhs.type() == Type_Vector and lhs.type() == Type_Float and \ - PrioritazeScalarOps.can_rearrange(op, rhs.op) ): + PrioritizeScalarOps.can_rearrange(op, rhs.op) ): if rhs.lhs.type() == Type_Vector and rhs.rhs.type() == Type_Float: ast_node.reassign(rhs.lhs, BinOp(sim, rhs.rhs, lhs, op), op) @@ -45,6 +45,6 @@ class PrioritazeScalarOps(Mutator): return ast_node -def prioritaze_scalar_ops(ast_node): - prioritaze = PrioritazeScalarOps(ast_node) - prioritaze.mutate() +def prioritize_scalar_ops(ast_node): + prioritize = PrioritizeScalarOps(ast_node) + prioritize.mutate()