From 8c1a76cd1324ee1248cce3b6cb589ec9818fe8e0 Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti <rafaelravedutti@gmail.com>
Date: Fri, 15 Oct 2021 00:58:29 +0200
Subject: [PATCH] Define lowering and adjacent blocks merging as
 transformations

Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com>
---
 src/pairs/ir/block.py                         |  4 +--
 src/pairs/ir/mutator.py                       |  4 +++
 src/pairs/sim/arrays.py                       |  5 +--
 src/pairs/sim/cell_lists.py                   |  9 ++---
 src/pairs/sim/lattice.py                      |  5 +--
 src/pairs/sim/lowerable.py                    |  6 ++++
 src/pairs/sim/neighbor_lists.py               |  5 +--
 src/pairs/sim/pbc.py                          | 13 +++----
 src/pairs/sim/properties.py                   |  5 +--
 src/pairs/sim/read_from_file.py               |  5 +--
 src/pairs/sim/simulation.py                   | 34 +++++++++++--------
 src/pairs/sim/variables.py                    |  5 +--
 src/pairs/sim/vtk.py                          | 10 +++---
 src/pairs/transformations/lower.py            | 23 +++++++++++++
 .../transformations/merge_adjacent_blocks.py  | 25 ++++++++++++++
 .../transformations/prioritize_scalar_ops.py  | 12 +++----
 16 files changed, 118 insertions(+), 52 deletions(-)
 create mode 100644 src/pairs/sim/lowerable.py
 create mode 100644 src/pairs/transformations/lower.py
 create mode 100644 src/pairs/transformations/merge_adjacent_blocks.py

diff --git a/src/pairs/ir/block.py b/src/pairs/ir/block.py
index ff75903..b9c04b8 100644
--- a/src/pairs/ir/block.py
+++ b/src/pairs/ir/block.py
@@ -74,10 +74,8 @@ class Block(ASTNode):
         for block in block_list:
             if isinstance(block, Block):
                 result_block = Block.merge_blocks(result_block, block)
-            elif isinstance(block, KernelBlock):
-                result_block.add_statement(block)
             else:
-                raise Exception("Element in list is not Block!")
+                result_block.add_statement(block)
 
         return result_block
 
diff --git a/src/pairs/ir/mutator.py b/src/pairs/ir/mutator.py
index 293c5ec..cb394cf 100644
--- a/src/pairs/ir/mutator.py
+++ b/src/pairs/ir/mutator.py
@@ -15,6 +15,10 @@ class Mutator:
         if method is not None:
             return method(ast_node)
 
+        method_unknown = self.get_method("mutate_Unknown")
+        if method_unknown is not None:
+            return method_unknown(ast_node)
+
         return ast_node
 
     def mutate_ArrayAccess(self, ast_node):
diff --git a/src/pairs/sim/arrays.py b/src/pairs/sim/arrays.py
index 7167baf..5adcaec 100644
--- a/src/pairs/sim/arrays.py
+++ b/src/pairs/sim/arrays.py
@@ -1,11 +1,12 @@
 from pairs.ir.block import pairs_block
 from pairs.ir.memory import Malloc
 from pairs.ir.arrays import ArrayDecl
+from pairs.sim.lowerable import Lowerable
 
 
-class ArraysDecl:
+class ArraysDecl(Lowerable):
     def __init__(self, sim):
-        self.sim = sim
+        super().__init__(sim)
 
     @pairs_block
     def lower(self):
diff --git a/src/pairs/sim/cell_lists.py b/src/pairs/sim/cell_lists.py
index 6657cd6..1a36a04 100644
--- a/src/pairs/sim/cell_lists.py
+++ b/src/pairs/sim/cell_lists.py
@@ -8,6 +8,7 @@ from pairs.ir.data_types import Type_Int
 from pairs.ir.math import Ceil
 from pairs.ir.loops import For, ParticleFor
 from pairs.ir.utils import Print
+from pairs.sim.lowerable import Lowerable
 from pairs.sim.resize import Resize
 
 
@@ -30,9 +31,9 @@ class CellLists:
         self.particle_cell = self.sim.add_array('particle_cell', self.sim.particle_capacity, Type_Int)
 
 
-class CellListsStencilBuild:
+class CellListsStencilBuild(Lowerable):
     def __init__(self, sim, cell_lists):
-        self.sim = sim
+        super().__init__(sim)
         self.cell_lists = cell_lists
 
     @pairs_device_block
@@ -63,9 +64,9 @@ class CellListsStencilBuild:
                         cl.nstencil.set(cl.nstencil + 1)
 
 
-class CellListsBuild:
+class CellListsBuild(Lowerable):
     def __init__(self, sim, cell_lists):
-        self.sim = sim
+        super().__init__(sim)
         self.cell_lists = cell_lists
 
     @pairs_device_block
diff --git a/src/pairs/sim/lattice.py b/src/pairs/sim/lattice.py
index 8977f87..e1faa01 100644
--- a/src/pairs/sim/lattice.py
+++ b/src/pairs/sim/lattice.py
@@ -1,11 +1,12 @@
 from pairs.ir.block import pairs_block
 from pairs.ir.data_types import Type_Vector
 from pairs.ir.loops import For
+from pairs.sim.lowerable import Lowerable
 
 
-class ParticleLattice():
+class ParticleLattice(Lowerable):
     def __init__(self, sim, grid, spacing, props, positions):
-        self.sim = sim
+        super().__init__(sim)
         self.grid = grid
         self.spacing = spacing
         self.props = props
diff --git a/src/pairs/sim/lowerable.py b/src/pairs/sim/lowerable.py
new file mode 100644
index 0000000..9a2698c
--- /dev/null
+++ b/src/pairs/sim/lowerable.py
@@ -0,0 +1,6 @@
+class Lowerable:
+    def __init__(self, sim):
+        self.sim = sim
+
+    def lower(self):
+        raise Exception("Error: lower() method must be implemented for Lowerable inherited classes!")
diff --git a/src/pairs/sim/neighbor_lists.py b/src/pairs/sim/neighbor_lists.py
index 13b5682..798a243 100644
--- a/src/pairs/sim/neighbor_lists.py
+++ b/src/pairs/sim/neighbor_lists.py
@@ -3,6 +3,7 @@ from pairs.ir.branches import Branch, Filter
 from pairs.ir.data_types import Type_Int
 from pairs.ir.loops import For, ParticleFor, NeighborFor
 from pairs.ir.utils import Print
+from pairs.sim.lowerable import Lowerable
 from pairs.sim.resize import Resize
 
 
@@ -15,9 +16,9 @@ class NeighborLists:
         self.numneighs = self.sim.add_array('numneighs', self.sim.particle_capacity, Type_Int)
 
 
-class NeighborListsBuild:
+class NeighborListsBuild(Lowerable):
     def __init__(self, sim, neighbor_lists):
-        self.sim = sim
+        super().__init__(sim)
         self.neighbor_lists = neighbor_lists
 
     @pairs_device_block
diff --git a/src/pairs/sim/pbc.py b/src/pairs/sim/pbc.py
index 76886ad..1a5d480 100644
--- a/src/pairs/sim/pbc.py
+++ b/src/pairs/sim/pbc.py
@@ -4,6 +4,7 @@ from pairs.ir.data_types import Type_Int
 from pairs.ir.loops import For, ParticleFor
 from pairs.ir.utils import Print
 from pairs.ir.select import Select
+from pairs.sim.lowerable import Lowerable
 from pairs.sim.resize import Resize
 
 
@@ -19,9 +20,9 @@ class PBC:
         self.pbc_mult = sim.add_array('pbc_mult', [self.pbc_capacity, sim.ndims()], Type_Int)
 
 
-class UpdatePBC:
+class UpdatePBC(Lowerable):
     def __init__(self, sim, pbc):
-        self.sim = sim
+        super().__init__(sim)
         self.pbc = pbc
 
     @pairs_device_block
@@ -42,9 +43,9 @@ class UpdatePBC:
                 positions[nlocal + i][d].set(positions[pbc_map[i]][d] + pbc_mult[i][d] * grid.length(d))
 
 
-class EnforcePBC:
+class EnforcePBC(Lowerable):
     def __init__(self, sim, pbc):
-        self.sim = sim
+        super().__init__(sim)
         self.pbc = pbc
 
     @pairs_device_block
@@ -64,9 +65,9 @@ class EnforcePBC:
                     positions[i][d].sub(grid.length(d))
 
 
-class SetupPBC:
+class SetupPBC(Lowerable):
     def __init__(self, sim, pbc):
-        self.sim = sim
+        super().__init__(sim)
         self.pbc = pbc
 
     @pairs_device_block
diff --git a/src/pairs/sim/properties.py b/src/pairs/sim/properties.py
index 3892bbf..35a24f3 100644
--- a/src/pairs/sim/properties.py
+++ b/src/pairs/sim/properties.py
@@ -4,11 +4,12 @@ from pairs.ir.loops import ParticleFor
 from pairs.ir.memory import Malloc, Realloc
 from pairs.ir.properties import RegisterProperty, UpdateProperty
 from pairs.ir.utils import Print
+from pairs.sim.lowerable import Lowerable
 from functools import reduce
 import operator
 
 
-class PropertiesAlloc:
+class PropertiesAlloc(Lowerable):
     def __init__(self, sim, realloc=False):
         self.sim = sim
         self.realloc = realloc
@@ -33,7 +34,7 @@ class PropertiesAlloc:
                 RegisterProperty(self.sim, p, sizes)
 
 
-class PropertiesResetVolatile:
+class PropertiesResetVolatile(Lowerable):
     def __init__(self, sim):
         self.sim = sim
 
diff --git a/src/pairs/sim/read_from_file.py b/src/pairs/sim/read_from_file.py
index 6f16cb0..606f79a 100644
--- a/src/pairs/sim/read_from_file.py
+++ b/src/pairs/sim/read_from_file.py
@@ -3,11 +3,12 @@ from pairs.ir.data_types import Type_Float
 from pairs.ir.functions import Call_Int
 from pairs.ir.properties import PropertyList
 from pairs.sim.grid import MutableGrid
+from pairs.sim.lowerable import Lowerable
 
 
-class ReadFromFile():
+class ReadFromFile(Lowerable):
     def __init__(self, sim, filename, props):
-        self.sim = sim
+        super().__init__(sim)
         self.filename = filename
         self.props = PropertyList(sim, props)
         self.grid = MutableGrid(sim, sim.ndims())
diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py
index 238e392..f8c3461 100644
--- a/src/pairs/sim/simulation.py
+++ b/src/pairs/sim/simulation.py
@@ -20,10 +20,12 @@ from pairs.sim.timestep import Timestep
 from pairs.sim.variables import VariablesDecl
 from pairs.sim.vtk import VTKWrite
 from pairs.transformations.add_device_copies import add_device_copies
-from pairs.transformations.prioritize_scalar_ops import prioritaze_scalar_ops
+from pairs.transformations.prioritize_scalar_ops import prioritize_scalar_ops
 from pairs.transformations.set_used_bin_ops import set_used_bin_ops
 from pairs.transformations.simplify import simplify_expressions
 from pairs.transformations.LICM import move_loop_invariant_code
+from pairs.transformations.lower import lower_everything
+from pairs.transformations.merge_adjacent_blocks import merge_adjacent_blocks
 
 
 class Simulation:
@@ -107,12 +109,12 @@ class Simulation:
     def create_particle_lattice(self, grid, spacing, props={}):
         positions = self.property('position')
         lattice = ParticleLattice(self, grid, spacing, props, positions)
-        self.setups.add_statement(lattice.lower())
+        self.setups.add_statement(lattice)
 
     def from_file(self, filename, prop_names):
         props = [self.property(prop_name) for prop_name in prop_names]
         read_object = ReadFromFile(self, filename, props)
-        self.setups.add_statement(read_object.lower())
+        self.setups.add_statement(read_object)
         self.grid = read_object.grid
 
     def create_cell_lists(self, spacing, cutoff_radius):
@@ -186,34 +188,36 @@ class Simulation:
 
     def generate(self):
         timestep = Timestep(self, self.ntimesteps, [
-            (EnforcePBC(self, self.pbc).lower(), 20),
-            (SetupPBC(self, self.pbc).lower(), UpdatePBC(self, self.pbc).lower(), 20),
-            (CellListsBuild(self, self.cell_lists).lower(), 20),
-            (NeighborListsBuild(self, self.neighbor_lists).lower(), 20),
-            PropertiesResetVolatile(self).lower(),
+            (EnforcePBC(self, self.pbc), 20),
+            (SetupPBC(self, self.pbc), UpdatePBC(self, self.pbc), 20),
+            (CellListsBuild(self, self.cell_lists), 20),
+            (NeighborListsBuild(self, self.neighbor_lists), 20),
+            PropertiesResetVolatile(self),
             self.kernels
         ])
 
-        timestep.add(VTKWrite(self, self.vtk_file, timestep.timestep() + 1).lower())
+        timestep.add(VTKWrite(self, self.vtk_file, timestep.timestep() + 1))
 
         body = Block.from_list(self, [
             self.setups,
-            CellListsStencilBuild(self, self.cell_lists).lower(),
-            VTKWrite(self, self.vtk_file, 0).lower(),
+            CellListsStencilBuild(self, self.cell_lists),
+            VTKWrite(self, self.vtk_file, 0),
             timestep.as_block()
         ])
 
         decls = Block.from_list(self, [
-            VariablesDecl(self).lower(),
-            ArraysDecl(self).lower(),
-            PropertiesAlloc(self).lower(),
+            VariablesDecl(self),
+            ArraysDecl(self),
+            PropertiesAlloc(self),
         ])
 
         program = Block.merge_blocks(decls, body)
         self.global_scope = program
 
         # Transformations
-        prioritaze_scalar_ops(program)
+        lower_everything(program)
+        merge_adjacent_blocks(program)
+        prioritize_scalar_ops(program)
         simplify_expressions(program)
         move_loop_invariant_code(program)
         set_used_bin_ops(program)
diff --git a/src/pairs/sim/variables.py b/src/pairs/sim/variables.py
index 1d9168f..f7b512f 100644
--- a/src/pairs/sim/variables.py
+++ b/src/pairs/sim/variables.py
@@ -1,10 +1,11 @@
 from pairs.ir.block import pairs_block
 from pairs.ir.variables import VarDecl
+from pairs.sim.lowerable import Lowerable
 
 
-class VariablesDecl:
+class VariablesDecl(Lowerable):
     def __init__(self, sim):
-        self.sim = sim
+        super().__init__(sim)
 
     @pairs_block
     def lower(self):
diff --git a/src/pairs/sim/vtk.py b/src/pairs/sim/vtk.py
index 86de551..3fbf888 100644
--- a/src/pairs/sim/vtk.py
+++ b/src/pairs/sim/vtk.py
@@ -1,22 +1,20 @@
 from pairs.ir.ast_node import ASTNode
+from pairs.ir.block import pairs_block
 from pairs.ir.functions import Call_Void
 from pairs.ir.lit import as_lit_ast
+from pairs.sim.lowerable import Lowerable
 
 
-class VTKWrite(ASTNode):
+class VTKWrite(Lowerable):
     def __init__(self, sim, filename, timestep):
         super().__init__(sim)
         self.filename = filename
         self.timestep = as_lit_ast(sim, timestep)
 
+    @pairs_block
     def lower(self):
         nlocal = self.sim.nlocal
         npbc = self.sim.pbc.npbc
-        self.sim.clear_block()
         nall = nlocal + npbc
         Call_Void(self.sim, "pairs::vtk_write_data", [self.filename + "_local", 0, nlocal, self.timestep])
         Call_Void(self.sim, "pairs::vtk_write_data", [self.filename + "_pbc", nlocal, nall, self.timestep])
-        return self.sim.block
-
-    def children(self):
-        return [self.timestep]
diff --git a/src/pairs/transformations/lower.py b/src/pairs/transformations/lower.py
new file mode 100644
index 0000000..c2fb14f
--- /dev/null
+++ b/src/pairs/transformations/lower.py
@@ -0,0 +1,23 @@
+from pairs.ir.mutator import Mutator
+from pairs.sim.lowerable import Lowerable
+
+
+class Lower(Mutator):
+    def __init__(self, ast):
+        super().__init__(ast)
+        self.lowered_nodes = 0
+
+    def mutate_Unknown(self, ast_node):
+        if isinstance(ast_node, Lowerable):
+            self.lowered_nodes += 1
+            return ast_node.lower()
+
+        return ast_node
+
+
+def lower_everything(ast):
+    nlowered = 1
+    while nlowered > 0:
+        lower = Lower(ast)
+        lower.mutate()
+        nlowered = lower.lowered_nodes
diff --git a/src/pairs/transformations/merge_adjacent_blocks.py b/src/pairs/transformations/merge_adjacent_blocks.py
new file mode 100644
index 0000000..1095144
--- /dev/null
+++ b/src/pairs/transformations/merge_adjacent_blocks.py
@@ -0,0 +1,25 @@
+from pairs.ir.block import Block
+from pairs.ir.mutator import Mutator
+
+
+class MergeAdjacentBlocks(Mutator):
+    def __init__(self, ast):
+        super().__init__(ast)
+
+    def mutate_Block(self, ast_node):
+        new_stmts = []
+        stmts = [self.mutate(s) for s in ast_node.stmts]
+
+        for s in stmts:
+            if isinstance(s, Block):
+                new_stmts = new_stmts + s.statements()
+            else:
+                new_stmts.append(s)
+
+        ast_node.stmts = new_stmts 
+        return ast_node
+
+
+def merge_adjacent_blocks(ast):
+    merge = MergeAdjacentBlocks(ast)
+    merge.mutate()
diff --git a/src/pairs/transformations/prioritize_scalar_ops.py b/src/pairs/transformations/prioritize_scalar_ops.py
index 47957f6..caeb87d 100644
--- a/src/pairs/transformations/prioritize_scalar_ops.py
+++ b/src/pairs/transformations/prioritize_scalar_ops.py
@@ -3,7 +3,7 @@ from pairs.ir.data_types import Type_Float, Type_Vector
 from pairs.ir.mutator import Mutator
 
 
-class PrioritazeScalarOps(Mutator):
+class PrioritizeScalarOps(Mutator):
     def __init__(self, ast):
         super().__init__(ast)
 
@@ -21,7 +21,7 @@ class PrioritazeScalarOps(Mutator):
             op = ast_node.op
 
             if( isinstance(lhs, BinOp) and lhs.type() == Type_Vector and rhs.type() == Type_Float and \
-                PrioritazeScalarOps.can_rearrange(op, lhs.op) ):
+                PrioritizeScalarOps.can_rearrange(op, lhs.op) ):
 
                 if lhs.lhs.type() == Type_Vector and lhs.rhs.type() == Type_Float:
                     ast_node.reassign(lhs.lhs, BinOp(sim, lhs.rhs, rhs, op), op)
@@ -32,7 +32,7 @@ class PrioritazeScalarOps(Mutator):
                     #return BinOp(sim, lhs.rhs, BinOp(sim, lhs.lhs, rhs, op), op)
 
             if( isinstance(rhs, BinOp) and rhs.type() == Type_Vector and lhs.type() == Type_Float and \
-                PrioritazeScalarOps.can_rearrange(op, rhs.op) ):
+                PrioritizeScalarOps.can_rearrange(op, rhs.op) ):
 
                 if rhs.lhs.type() == Type_Vector and rhs.rhs.type() == Type_Float:
                     ast_node.reassign(rhs.lhs, BinOp(sim, rhs.rhs, lhs, op), op)
@@ -45,6 +45,6 @@ class PrioritazeScalarOps(Mutator):
         return ast_node
 
 
-def prioritaze_scalar_ops(ast_node):
-    prioritaze = PrioritazeScalarOps(ast_node)
-    prioritaze.mutate()
+def prioritize_scalar_ops(ast_node):
+    prioritize = PrioritizeScalarOps(ast_node)
+    prioritize.mutate()
-- 
GitLab