From 48675a8fa69ff5e49f51e3bf23a6d6a49b9678db Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti <rafaelravedutti@gmail.com>
Date: Thu, 14 Oct 2021 02:23:56 +0200
Subject: [PATCH] Fix add_device_copies transformation

Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com>
---
 src/pairs/code_gen/cgen.py                     |  9 ++++++---
 src/pairs/ir/block.py                          | 12 +++++++++---
 src/pairs/ir/mutator.py                        | 12 ++++++++----
 src/pairs/sim/kernel_wrapper.py                |  4 ++--
 src/pairs/sim/particle_simulation.py           | 13 ++++++-------
 src/pairs/transformations/add_device_copies.py |  7 ++++---
 6 files changed, 35 insertions(+), 22 deletions(-)

diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py
index 850d74c..d3501b0 100644
--- a/src/pairs/code_gen/cgen.py
+++ b/src/pairs/code_gen/cgen.py
@@ -1,6 +1,6 @@
 from pairs.ir.assign import Assign
 from pairs.ir.arrays import Array, ArrayAccess, ArrayDecl
-from pairs.ir.block import Block
+from pairs.ir.block import Block, KernelBlock
 from pairs.ir.branches import Branch
 from pairs.ir.cast import Cast
 from pairs.ir.bin_op import BinOp, Decl, VectorAccess
@@ -71,10 +71,8 @@ class CGen:
 
         if isinstance(ast_node, Block):
             self.print.add_ind(4)
-
             for stmt in ast_node.statements():
                 self.generate_statement(stmt)
-
             self.print.add_ind(-4)
 
         # TODO: Why there are Decls for other types?
@@ -134,6 +132,11 @@ class CGen:
         if isinstance(ast_node, DeviceCopy):
             self.print(f"pairs::copy_to_device({ast_node.prop.name()})")
 
+        if isinstance(ast_node, KernelBlock):
+            self.print.add_ind(-4)
+            self.generate_statement(ast_node.block)
+            self.print.add_ind(4) # Workaround for fixing indentation of kernels
+
         if isinstance(ast_node, For):
             iterator = self.generate_expression(ast_node.iterator)
             lower_range = None
diff --git a/src/pairs/ir/block.py b/src/pairs/ir/block.py
index cc3c50c..d46aa91 100644
--- a/src/pairs/ir/block.py
+++ b/src/pairs/ir/block.py
@@ -30,6 +30,8 @@ class Block(ASTNode):
     def merge_blocks(block1, block2):
         assert isinstance(block1, Block), "First block type is not Block!"
         assert isinstance(block2, Block), "Second block type is not Block!"
+        assert not isinstance(block1, KernelBlock), "Kernel blocks cannot be merged!"
+        assert not isinstance(block2, KernelBlock), "Kernel blocks cannot be merged!"
         return Block(block1.sim, block1.statements() + block2.statements())
 
     def from_list(sim, block_list):
@@ -43,9 +45,10 @@ class Block(ASTNode):
         return result_block
 
 
-class KernelBlock(Block):
-    def __init__(self, sim, stmts, run_on_host=False):
-        super().__init__(sim, stmts)
+class KernelBlock(ASTNode):
+    def __init__(self, sim, block, run_on_host=False):
+        super().__init__(sim)
+        self.block = block if isinstance(block, Block) else Block(sim, block)
         self.run_on_host = run_on_host
         self.props_accessed = {}
 
@@ -57,6 +60,9 @@ class KernelBlock(Block):
         elif oper not in self.props_accessed[prop_key]:
             self.props_accessed[prop_key] += oper
 
+    def children(self):
+        return [self.block]
+
     def properties_to_synchronize(self):
         return {p for p in self.props_accessed if self.props_accessed[p][0] == 'r'}
 
diff --git a/src/pairs/ir/mutator.py b/src/pairs/ir/mutator.py
index 9421736..293c5ec 100644
--- a/src/pairs/ir/mutator.py
+++ b/src/pairs/ir/mutator.py
@@ -46,6 +46,10 @@ class Mutator:
         ast_node.block_else = None if ast_node.block_else is None else self.mutate(ast_node.block_else)
         return ast_node
 
+    def mutate_Cast(self, ast_node):
+        ast_node.expr = self.mutate(ast_node.expr)
+        return ast_node
+
     def mutate_Decl(self, ast_node):
         ast_node.elem = self.mutate(ast_node.elem)
         return ast_node
@@ -53,15 +57,15 @@ class Mutator:
     def mutate_Filter(self, ast_node):
         return self.mutate_Branch(ast_node)
 
-    def mutate_Cast(self, ast_node):
-        ast_node.expr = self.mutate(ast_node.expr)
-        return ast_node
-
     def mutate_For(self, ast_node):
         ast_node.iterator = self.mutate(ast_node.iterator)
         ast_node.block = self.mutate(ast_node.block)
         return ast_node
 
+    def mutate_KernelBlock(self, ast_node):
+        ast_node.block = self.mutate(ast_node.block)
+        return ast_node
+
     def mutate_ParticleFor(self, ast_node):
         return self.mutate_For(ast_node)
 
diff --git a/src/pairs/sim/kernel_wrapper.py b/src/pairs/sim/kernel_wrapper.py
index 53cd870..2e66b55 100644
--- a/src/pairs/sim/kernel_wrapper.py
+++ b/src/pairs/sim/kernel_wrapper.py
@@ -4,10 +4,10 @@ from pairs.ir.block import Block, KernelBlock
 class KernelWrapper():
     def __init__(self, sim):
         self.sim = sim
-        self.kernels = Block(sim, [])
+        self.kernels = []
 
     def add_kernel_block(self, block):
-        self.kernels = Block.merge_blocks(self.kernels, KernelBlock(self.sim, block))
+        self.kernels.append(KernelBlock(self.sim, block))
 
     def lower(self):
         return self.kernels
diff --git a/src/pairs/sim/particle_simulation.py b/src/pairs/sim/particle_simulation.py
index 1456f41..cfcc3a1 100644
--- a/src/pairs/sim/particle_simulation.py
+++ b/src/pairs/sim/particle_simulation.py
@@ -1,5 +1,5 @@
 from pairs.ir.arrays import Arrays
-from pairs.ir.block import Block
+from pairs.ir.block import Block, KernelBlock
 from pairs.ir.branches import Filter
 from pairs.ir.data_types import Type_Int, Type_Float, Type_Vector
 from pairs.ir.layouts import Layout_AoS
@@ -11,7 +11,6 @@ from pairs.mapping.funcs import compute
 from pairs.sim.arrays import ArraysDecl
 from pairs.sim.cell_lists import CellLists, CellListsBuild, CellListsStencilBuild
 from pairs.sim.grid import Grid2D, Grid3D
-from pairs.sim.kernel_wrapper import KernelWrapper
 from pairs.sim.lattice import ParticleLattice
 from pairs.sim.neighbor_lists import NeighborLists, NeighborListsBuild
 from pairs.sim.pbc import PBC, UpdatePBC, EnforcePBC, SetupPBC
@@ -49,7 +48,7 @@ class ParticleSimulation:
         self.check_decl_usage = True
         self.block = Block(self, [])
         self.setups = SetupWrapper(self)
-        self.kernels = KernelWrapper(self)
+        self.kernels = Block(self, [])
         self.dims = dims
         self.ntimesteps = timesteps
         self.expr_id = 0
@@ -146,14 +145,14 @@ class ParticleSimulation:
                 else:
                     yield i, j
 
-        self.kernels.add_kernel_block(self.block)
+        self.kernels.add_statement(KernelBlock(self, self.block))
 
     def particles(self):
         self.clear_block()
         for i in ParticleFor(self):
             yield i
 
-        self.kernels.add_kernel_block(self.block)
+        self.kernels.add_statement(KernelBlock(self, self.block))
 
     def clear_block(self):
         self.block = Block(self, [])
@@ -193,7 +192,7 @@ class ParticleSimulation:
             (CellListsBuild(self.cell_lists).lower(), 20),
             (NeighborListsBuild(self.neighbor_lists).lower(), 20),
             PropertiesResetVolatile(self).lower(),
-            self.kernels.lower()
+            self.kernels
         ])
 
         timestep.add(VTKWrite(self, self.vtk_file, timestep.timestep() + 1).lower())
@@ -224,5 +223,5 @@ class ParticleSimulation:
         # For this part on, all bin ops are generated without usage verification
         self.check_decl_usage = False
 
-        ASTGraph(self.kernels.lower(), "kernels").render()
+        ASTGraph(self.kernels, "kernels").render()
         self.code_gen.generate_program(program)
diff --git a/src/pairs/transformations/add_device_copies.py b/src/pairs/transformations/add_device_copies.py
index b8d431f..004e504 100644
--- a/src/pairs/transformations/add_device_copies.py
+++ b/src/pairs/transformations/add_device_copies.py
@@ -42,8 +42,7 @@ class AddDeviceCopies(Mutator):
             if s is not None:
                 s_id = id(s)
                 if isinstance(s, KernelBlock) and s_id in self.props_to_copy:
-                    for p in self.props.to_copy[s_id]:
-                        new_stmts = new_stmts + DeviceCopy(ast_node.sim, p)
+                    new_stmts = new_stmts + [DeviceCopy(ast_node.sim, ast_node.sim.property(p)) for p in self.props_to_copy[s_id]]
 
                 new_stmts.append(s)
 
@@ -51,10 +50,12 @@ class AddDeviceCopies(Mutator):
         return ast_node
 
     def mutate_KernelBlock(self, ast_node):
-        copying_properties = {p for p in ast_node.properties_to_synchronize() if p not in synchronized_props}
+        ast_node.block = self.mutate(ast_node.block)
+        copying_properties = {p for p in ast_node.properties_to_synchronize() if p not in self.synchronized_props}
         self.props_to_copy[id(ast_node)] = copying_properties
         self.synchronized_props.update(copying_properties)
         self.synchronized_props -= ast_node.writing_properties()
+        return ast_node
 
 
 def add_device_copies(ast):
-- 
GitLab