From 48675a8fa69ff5e49f51e3bf23a6d6a49b9678db Mon Sep 17 00:00:00 2001 From: Rafael Ravedutti <rafaelravedutti@gmail.com> Date: Thu, 14 Oct 2021 02:23:56 +0200 Subject: [PATCH] Fix add_device_copies transformation Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com> --- src/pairs/code_gen/cgen.py | 9 ++++++--- src/pairs/ir/block.py | 12 +++++++++--- src/pairs/ir/mutator.py | 12 ++++++++---- src/pairs/sim/kernel_wrapper.py | 4 ++-- src/pairs/sim/particle_simulation.py | 13 ++++++------- src/pairs/transformations/add_device_copies.py | 7 ++++--- 6 files changed, 35 insertions(+), 22 deletions(-) diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py index 850d74c..d3501b0 100644 --- a/src/pairs/code_gen/cgen.py +++ b/src/pairs/code_gen/cgen.py @@ -1,6 +1,6 @@ from pairs.ir.assign import Assign from pairs.ir.arrays import Array, ArrayAccess, ArrayDecl -from pairs.ir.block import Block +from pairs.ir.block import Block, KernelBlock from pairs.ir.branches import Branch from pairs.ir.cast import Cast from pairs.ir.bin_op import BinOp, Decl, VectorAccess @@ -71,10 +71,8 @@ class CGen: if isinstance(ast_node, Block): self.print.add_ind(4) - for stmt in ast_node.statements(): self.generate_statement(stmt) - self.print.add_ind(-4) # TODO: Why there are Decls for other types? @@ -134,6 +132,11 @@ class CGen: if isinstance(ast_node, DeviceCopy): self.print(f"pairs::copy_to_device({ast_node.prop.name()})") + if isinstance(ast_node, KernelBlock): + self.print.add_ind(-4) + self.generate_statement(ast_node.block) + self.print.add_ind(4) # Workaround for fixing indentation of kernels + if isinstance(ast_node, For): iterator = self.generate_expression(ast_node.iterator) lower_range = None diff --git a/src/pairs/ir/block.py b/src/pairs/ir/block.py index cc3c50c..d46aa91 100644 --- a/src/pairs/ir/block.py +++ b/src/pairs/ir/block.py @@ -30,6 +30,8 @@ class Block(ASTNode): def merge_blocks(block1, block2): assert isinstance(block1, Block), "First block type is not Block!" assert isinstance(block2, Block), "Second block type is not Block!" + assert not isinstance(block1, KernelBlock), "Kernel blocks cannot be merged!" + assert not isinstance(block2, KernelBlock), "Kernel blocks cannot be merged!" return Block(block1.sim, block1.statements() + block2.statements()) def from_list(sim, block_list): @@ -43,9 +45,10 @@ class Block(ASTNode): return result_block -class KernelBlock(Block): - def __init__(self, sim, stmts, run_on_host=False): - super().__init__(sim, stmts) +class KernelBlock(ASTNode): + def __init__(self, sim, block, run_on_host=False): + super().__init__(sim) + self.block = block if isinstance(block, Block) else Block(sim, block) self.run_on_host = run_on_host self.props_accessed = {} @@ -57,6 +60,9 @@ class KernelBlock(Block): elif oper not in self.props_accessed[prop_key]: self.props_accessed[prop_key] += oper + def children(self): + return [self.block] + def properties_to_synchronize(self): return {p for p in self.props_accessed if self.props_accessed[p][0] == 'r'} diff --git a/src/pairs/ir/mutator.py b/src/pairs/ir/mutator.py index 9421736..293c5ec 100644 --- a/src/pairs/ir/mutator.py +++ b/src/pairs/ir/mutator.py @@ -46,6 +46,10 @@ class Mutator: ast_node.block_else = None if ast_node.block_else is None else self.mutate(ast_node.block_else) return ast_node + def mutate_Cast(self, ast_node): + ast_node.expr = self.mutate(ast_node.expr) + return ast_node + def mutate_Decl(self, ast_node): ast_node.elem = self.mutate(ast_node.elem) return ast_node @@ -53,15 +57,15 @@ class Mutator: def mutate_Filter(self, ast_node): return self.mutate_Branch(ast_node) - def mutate_Cast(self, ast_node): - ast_node.expr = self.mutate(ast_node.expr) - return ast_node - def mutate_For(self, ast_node): ast_node.iterator = self.mutate(ast_node.iterator) ast_node.block = self.mutate(ast_node.block) return ast_node + def mutate_KernelBlock(self, ast_node): + ast_node.block = self.mutate(ast_node.block) + return ast_node + def mutate_ParticleFor(self, ast_node): return self.mutate_For(ast_node) diff --git a/src/pairs/sim/kernel_wrapper.py b/src/pairs/sim/kernel_wrapper.py index 53cd870..2e66b55 100644 --- a/src/pairs/sim/kernel_wrapper.py +++ b/src/pairs/sim/kernel_wrapper.py @@ -4,10 +4,10 @@ from pairs.ir.block import Block, KernelBlock class KernelWrapper(): def __init__(self, sim): self.sim = sim - self.kernels = Block(sim, []) + self.kernels = [] def add_kernel_block(self, block): - self.kernels = Block.merge_blocks(self.kernels, KernelBlock(self.sim, block)) + self.kernels.append(KernelBlock(self.sim, block)) def lower(self): return self.kernels diff --git a/src/pairs/sim/particle_simulation.py b/src/pairs/sim/particle_simulation.py index 1456f41..cfcc3a1 100644 --- a/src/pairs/sim/particle_simulation.py +++ b/src/pairs/sim/particle_simulation.py @@ -1,5 +1,5 @@ from pairs.ir.arrays import Arrays -from pairs.ir.block import Block +from pairs.ir.block import Block, KernelBlock from pairs.ir.branches import Filter from pairs.ir.data_types import Type_Int, Type_Float, Type_Vector from pairs.ir.layouts import Layout_AoS @@ -11,7 +11,6 @@ from pairs.mapping.funcs import compute from pairs.sim.arrays import ArraysDecl from pairs.sim.cell_lists import CellLists, CellListsBuild, CellListsStencilBuild from pairs.sim.grid import Grid2D, Grid3D -from pairs.sim.kernel_wrapper import KernelWrapper from pairs.sim.lattice import ParticleLattice from pairs.sim.neighbor_lists import NeighborLists, NeighborListsBuild from pairs.sim.pbc import PBC, UpdatePBC, EnforcePBC, SetupPBC @@ -49,7 +48,7 @@ class ParticleSimulation: self.check_decl_usage = True self.block = Block(self, []) self.setups = SetupWrapper(self) - self.kernels = KernelWrapper(self) + self.kernels = Block(self, []) self.dims = dims self.ntimesteps = timesteps self.expr_id = 0 @@ -146,14 +145,14 @@ class ParticleSimulation: else: yield i, j - self.kernels.add_kernel_block(self.block) + self.kernels.add_statement(KernelBlock(self, self.block)) def particles(self): self.clear_block() for i in ParticleFor(self): yield i - self.kernels.add_kernel_block(self.block) + self.kernels.add_statement(KernelBlock(self, self.block)) def clear_block(self): self.block = Block(self, []) @@ -193,7 +192,7 @@ class ParticleSimulation: (CellListsBuild(self.cell_lists).lower(), 20), (NeighborListsBuild(self.neighbor_lists).lower(), 20), PropertiesResetVolatile(self).lower(), - self.kernels.lower() + self.kernels ]) timestep.add(VTKWrite(self, self.vtk_file, timestep.timestep() + 1).lower()) @@ -224,5 +223,5 @@ class ParticleSimulation: # For this part on, all bin ops are generated without usage verification self.check_decl_usage = False - ASTGraph(self.kernels.lower(), "kernels").render() + ASTGraph(self.kernels, "kernels").render() self.code_gen.generate_program(program) diff --git a/src/pairs/transformations/add_device_copies.py b/src/pairs/transformations/add_device_copies.py index b8d431f..004e504 100644 --- a/src/pairs/transformations/add_device_copies.py +++ b/src/pairs/transformations/add_device_copies.py @@ -42,8 +42,7 @@ class AddDeviceCopies(Mutator): if s is not None: s_id = id(s) if isinstance(s, KernelBlock) and s_id in self.props_to_copy: - for p in self.props.to_copy[s_id]: - new_stmts = new_stmts + DeviceCopy(ast_node.sim, p) + new_stmts = new_stmts + [DeviceCopy(ast_node.sim, ast_node.sim.property(p)) for p in self.props_to_copy[s_id]] new_stmts.append(s) @@ -51,10 +50,12 @@ class AddDeviceCopies(Mutator): return ast_node def mutate_KernelBlock(self, ast_node): - copying_properties = {p for p in ast_node.properties_to_synchronize() if p not in synchronized_props} + ast_node.block = self.mutate(ast_node.block) + copying_properties = {p for p in ast_node.properties_to_synchronize() if p not in self.synchronized_props} self.props_to_copy[id(ast_node)] = copying_properties self.synchronized_props.update(copying_properties) self.synchronized_props -= ast_node.writing_properties() + return ast_node def add_device_copies(ast): -- GitLab