From 1593bd166359a78269bf1decf173e31c167bcd7d Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti Lucio Machado <rafael.r.ravedutti@fau.de>
Date: Tue, 12 Jan 2021 03:39:24 +0100
Subject: [PATCH] Provide better way to write analyses and transformations

Signed-off-by: Rafael Ravedutti Lucio Machado <rafael.r.ravedutti@fau.de>
---
 ast/arrays.py               |   9 ---
 ast/assign.py               |   8 ---
 ast/ast_node.py             |   3 -
 ast/bin_op.py               |  10 ---
 ast/block.py                |   7 --
 ast/branches.py             |   8 ---
 ast/cast.py                 |   4 --
 ast/loops.py                |  10 ---
 ast/math.py                 |   4 --
 ast/memory.py               |  10 ---
 ast/mutator.py              | 139 ++++++++++++++++++++++++++++++++++++
 ast/select.py               |   7 --
 ast/visitor.py              |  43 ++++++-----
 sim/particle_simulation.py  |  15 ++--
 sim/timestep.py             |   3 -
 sim/vtk.py                  |  12 ++--
 transformations/flatten.py  |  35 +++++++++
 transformations/simplify.py |  36 ++++++++++
 18 files changed, 247 insertions(+), 116 deletions(-)
 create mode 100644 ast/mutator.py
 create mode 100644 transformations/flatten.py
 create mode 100644 transformations/simplify.py

diff --git a/ast/arrays.py b/ast/arrays.py
index 308bd12..1ea9014 100644
--- a/ast/arrays.py
+++ b/ast/arrays.py
@@ -175,15 +175,6 @@ class ArrayAccess(ASTTerm):
     def children(self):
         return [self.array] + self.indexes
 
-    def transform(self, fn):
-        self.array = self.array.transform(fn)
-        self.indexes = [i.transform(fn) for i in self.indexes]
-
-        if self.index is not None:
-            self.index = self.index.transform(fn)
-
-        return fn(self)
-
 
 class ArrayDecl(ASTNode):
     def __init__(self, sim, array):
diff --git a/ast/assign.py b/ast/assign.py
index 8ce1d7b..a6cc919 100644
--- a/ast/assign.py
+++ b/ast/assign.py
@@ -31,11 +31,3 @@ class Assign(ASTNode):
         return reduce((lambda x, y: x + y), [
                       [self.assignments[i][0], self.assignments[i][1]]
                       for i in range(0, len(self.assignments))])
-
-    def transform(self, fn):
-        self.assignments = [(
-            self.assignments[i][0].transform(fn),
-            self.assignments[i][1].transform(fn))
-            for i in range(0, len(self.assignments))]
-
-        return fn(self)
diff --git a/ast/ast_node.py b/ast/ast_node.py
index 7e42676..e58afc2 100644
--- a/ast/ast_node.py
+++ b/ast/ast_node.py
@@ -16,6 +16,3 @@ class ASTNode:
 
     def children(self):
         return []
-
-    def transform(self, fn):
-        return fn(self)
diff --git a/ast/bin_op.py b/ast/bin_op.py
index 840ca43..ba82620 100644
--- a/ast/bin_op.py
+++ b/ast/bin_op.py
@@ -17,10 +17,6 @@ class BinOpDef(ASTNode):
     def children(self):
         return [self.bin_op]
 
-    def transform(self, fn):
-        self.bin_op = self.bin_op.transform(fn)
-        return fn(self)
-
 
 class BinOp(ASTNode):
     # BinOp kinds
@@ -177,12 +173,6 @@ class BinOp(ASTNode):
     def children(self):
         return [self.lhs, self.rhs]
 
-    def transform(self, fn):
-        self.lhs = self.lhs.transform(fn)
-        self.rhs = self.rhs.transform(fn)
-        self.bin_op_vector_index_mapping = {i: e.transform(fn) for i, e in self.bin_op_vector_index_mapping.items()}
-        return fn(self)
-
     def __add__(self, other):
         return BinOp(self.sim, self, other, '+')
 
diff --git a/ast/block.py b/ast/block.py
index b7f287e..e85face 100644
--- a/ast/block.py
+++ b/ast/block.py
@@ -1,5 +1,4 @@
 from ast.ast_node import ASTNode
-from ast.visitor import Visitor
 
 
 class Block(ASTNode):
@@ -51,12 +50,6 @@ class Block(ASTNode):
     def children(self):
         return self.stmts
 
-    def transform(self, fn):
-        for i in range(0, len(self.stmts)):
-            self.stmts[i] = self.stmts[i].transform(fn)
-
-        return fn(self)
-
     def merge_blocks(block1, block2):
         assert isinstance(block1, Block), "First block type is not Block!"
         assert isinstance(block2, Block), "Second block type is not Block!"
diff --git a/ast/branches.py b/ast/branches.py
index ab31b0b..a630260 100644
--- a/ast/branches.py
+++ b/ast/branches.py
@@ -37,14 +37,6 @@ class Branch(ASTNode):
         return [self.cond, self.block_if] + \
                ([] if self.block_else is None else [self.block_else])
 
-    def transform(self, fn):
-        self.cond = self.cond.transform(fn)
-        self.block_if = self.block_if.transform(fn)
-        self.block_else = \
-            None if self.block_else is None \
-            else self.block_else.transform(fn)
-        return fn(self)
-
 
 class Filter(Branch):
     def __init__(self, sim, cond):
diff --git a/ast/cast.py b/ast/cast.py
index 0a30700..39a8538 100644
--- a/ast/cast.py
+++ b/ast/cast.py
@@ -25,7 +25,3 @@ class Cast(ASTNode):
 
     def children(self):
         return [self.expr]
-
-    def transform(self, fn):
-        self.expr = self.expr.transform(fn)
-        return fn(self)
diff --git a/ast/loops.py b/ast/loops.py
index 902cec3..e0c645c 100644
--- a/ast/loops.py
+++ b/ast/loops.py
@@ -67,11 +67,6 @@ class For(ASTNode):
     def children(self):
         return [self.iterator, self.block]
 
-    def transform(self, fn):
-        self.iterator = self.iterator.transform(fn)
-        self.block = self.block.transform(fn)
-        return fn(self)
-
 
 class ParticleFor(For):
     def __init__(self, sim, block=None, local_only=True):
@@ -104,11 +99,6 @@ class While(ASTNode):
     def children(self):
         return [self.cond, self.block]
 
-    def transform(self, fn):
-        self.cond = self.cond.transform(fn)
-        self.block = self.block.transform(fn)
-        return fn(self)
-
 
 class NeighborFor():
     def __init__(self, sim, particle, cell_lists):
diff --git a/ast/math.py b/ast/math.py
index 90977b8..4e610cb 100644
--- a/ast/math.py
+++ b/ast/math.py
@@ -18,7 +18,3 @@ class Sqrt(ASTNode):
 
     def children(self):
         return [self.expr]
-
-    def transform(self, fn):
-        self.expr = self.expr.transform(fn)
-        return fn(self)
diff --git a/ast/memory.py b/ast/memory.py
index ef8abfe..3458bf9 100644
--- a/ast/memory.py
+++ b/ast/memory.py
@@ -18,11 +18,6 @@ class Malloc(ASTNode):
     def children(self):
         return [self.array, self.size]
 
-    def transform(self, fn):
-        self.array = self.array.transform(fn)
-        self.size = self.size.transform(fn)
-        return fn(self)
-
 
 class Realloc(ASTNode):
     def __init__(self, sim, array, size):
@@ -35,8 +30,3 @@ class Realloc(ASTNode):
 
     def children(self):
         return [self.array, self.size]
-
-    def transform(self, fn):
-        self.array = self.array.transform(fn)
-        self.size = self.size.transform(fn)
-        return fn(self)
diff --git a/ast/mutator.py b/ast/mutator.py
new file mode 100644
index 0000000..f450773
--- /dev/null
+++ b/ast/mutator.py
@@ -0,0 +1,139 @@
+from ast.arrays import ArrayAccess
+from ast.assign import Assign
+from ast.bin_op import BinOp, BinOpDef
+from ast.block import Block
+from ast.branches import Branch
+from ast.cast import Cast
+from ast.loops import For, While
+from ast.math import Sqrt
+from ast.memory import Malloc, Realloc
+from ast.select import Select
+from sim.timestep import Timestep
+
+
+class Mutator:
+    def __init__(self, ast, max_depth=0):
+        self.ast = ast
+        self.max_depth = 0
+
+    def mutate(self, ast_node=None):
+        if ast_node is None:
+            ast_node = self.ast
+
+        if isinstance(ast_node, ArrayAccess):
+            return self.mutate_ArrayAccess(ast_node)
+
+        elif isinstance(ast_node, Assign):
+            return self.mutate_Assign(ast_node)
+
+        elif isinstance(ast_node, BinOp):
+            return self.mutate_BinOp(ast_node)
+
+        elif isinstance(ast_node, BinOpDef):
+            return self.mutate_BinOpDef(ast_node)
+
+        elif isinstance(ast_node, Block):
+            return self.mutate_Block(ast_node)
+
+        elif isinstance(ast_node, Branch):
+            return self.mutate_Branch(ast_node)
+
+        elif isinstance(ast_node, Cast):
+            return self.mutate_Cast(ast_node)
+
+        elif isinstance(ast_node, For):
+            return self.mutate_For(ast_node)
+
+        elif isinstance(ast_node, Malloc):
+            return self.mutate_Malloc(ast_node)
+
+        elif isinstance(ast_node, Realloc):
+            return self.mutate_Realloc(ast_node)
+
+        elif isinstance(ast_node, Select):
+            return self.mutate_Select(ast_node)
+
+        elif isinstance(ast_node, Sqrt):
+            return self.mutate_Sqrt(ast_node)
+
+        elif isinstance(ast_node, Timestep):
+            return self.mutate_Timestep(ast_node)
+
+        elif isinstance(ast_node, While):
+            return self.mutate_While(ast_node)
+
+        return ast_node
+
+    def mutate_ArrayAccess(self, ast_node):
+        ast_node.array = self.mutate(ast_node.array)
+        ast_node.indexes = [self.mutate(i) for i in ast_node.indexes]
+
+        if ast_node.index is not None:
+            ast_node.index = self.mutate(ast_node.index)
+
+        return ast_node 
+
+    def mutate_Assign(self, ast_node):
+        ast_node.assignments = [
+            (self.mutate(ast_node.assignments[i][0]), self.mutate(ast_node.assignments[i][1]))
+            for i in range(0, len(ast_node.assignments))
+        ]
+        return ast_node
+
+    def mutate_BinOp(self, ast_node):
+        ast_node.lhs = self.mutate(ast_node.lhs)
+        ast_node.rhs = self.mutate(ast_node.rhs)
+        ast_node.bin_op_vector_index_mapping = {i: self.mutate(e) for i, e in ast_node.bin_op_vector_index_mapping.items()}
+        return ast_node
+
+    def mutate_BinOpDef(self, ast_node):
+        ast_node.bin_op = self.mutate(ast_node.bin_op)
+        return ast_node
+
+    def mutate_Block(self, ast_node):
+        ast_node.stmts = [self.mutate(s) for s in ast_node.stmts]
+        return ast_node
+
+    def mutate_Branch(self, ast_node):
+        ast_node.cond = self.mutate(ast_node.cond)
+        ast_node.block_if = self.mutate(ast_node.block_if)
+        ast_node.block_else = None if ast_node.block_else is None else self.mutate(ast_node.block_else)
+        return ast_node
+
+    def mutate_Cast(self, ast_node):
+        ast_node.expr = self.mutate(ast_node.expr)
+        return ast_node
+
+    def mutate_For(self, ast_node):
+        ast_node.iterator = self.mutate(ast_node.iterator)
+        ast_node.block = self.mutate(ast_node.block)
+        return ast_node
+
+    def mutate_Malloc(self, ast_node):
+        ast_node.array = self.mutate(ast_node.array)
+        ast_node.size = self.mutate(ast_node.size)
+        return ast_node
+
+    def mutate_Realloc(self, ast_node):
+        ast_node.array = self.mutate(ast_node.array)
+        ast_node.size = self.mutate(ast_node.size)
+        return ast_node
+
+    def mutate_Select(self, ast_node):
+        ast_node.cond = self.mutate(ast_node.cond)
+        ast_node.expr_if = self.mutate(ast_node.expr_if)
+        ast_node.expr_else = self.mutate(ast_node.expr_else)
+        return ast_node
+
+    def mutate_Sqrt(self, ast_node):
+        ast_node.expr = self.mutate(ast_node.expr)
+        return ast_node
+
+    def mutate_Timestep(self, ast_node):
+        ast_node.block = self.mutate(ast_node.block)
+        return ast_node
+
+    def mutate_While(self, ast_node):
+        ast_node.cond = self.mutate(ast_node.cond)
+        ast_node.block = self.mutate(ast_node.block)
+        return ast_node
diff --git a/ast/select.py b/ast/select.py
index 702bb8f..1fe9d08 100644
--- a/ast/select.py
+++ b/ast/select.py
@@ -12,10 +12,3 @@ class Select(ASTNode):
 
     def children(self):
         return [self.cond, self.expr_if, self.expr_else]
-
-    def transform(self, fn):
-        self.cond = self.cond.transform(fn)
-        self.expr_if = self.expr_if.transform(fn)
-        self.expr_else = self.expr_else.transform(fn)
-        return fn(self)
-
diff --git a/ast/visitor.py b/ast/visitor.py
index 8cf72e9..419b5e7 100644
--- a/ast/visitor.py
+++ b/ast/visitor.py
@@ -1,22 +1,31 @@
-class Visitor:
-    def __init__(self, ast, enter_fn=None, leave_fn=None, max_depth=0):
-        self.ast = ast
-        self.enter_fn = enter_fn
-        self.leave_fn = leave_fn
-        self.max_depth = max_depth
-
-    def visit(self):
-        self.visit_rec(self.ast)
+from ast.bin_op import BinOp
+from ast.loops import For, ParticleFor, While
 
-    def visit_rec(self, ast):
-        if self.enter_fn is not None:
-            self.enter_fn(ast)
 
-        for c in ast.children():
-            self.visit_rec(c)
-
-        if self.leave_fn is not None:
-            self.leave_fn(ast)
+class Visitor:
+    def __init__(self, ast, max_depth=0):
+        self.ast = ast
+        self.max_depth = 0
+
+    def visit(ast_node):
+        for c in ast_node.children():
+            if isinstance(c, Array):
+                self.visit_Array(c)
+            elif isinstance(c, BinOp):
+                self.visit_BinOp(c)
+            elif isinstance(c, (For, ParticleFor, While)):
+                self.visit_Loop(c)
+            else:
+                self.visit(c)
+
+    def visit_Array(self, ast_node):
+        return self.visit(ast_node)
+
+    def visit_BinOp(self, ast_node):
+        return self.visit(ast_node)
+
+    def visit_Loop(self, ast_node):
+        return self.visit(ast_node)
 
     def list_ast(self):
         self.list_elements(self.ast)
diff --git a/sim/particle_simulation.py b/sim/particle_simulation.py
index 12104a3..53f8ef1 100644
--- a/sim/particle_simulation.py
+++ b/sim/particle_simulation.py
@@ -5,10 +5,9 @@ from ast.data_types import Type_Int, Type_Float, Type_Vector
 from ast.layouts import Layout_AoS
 from ast.loops import ParticleFor, NeighborFor
 from ast.properties import Properties
-from ast.transform import Transform
 from ast.variables import Variables
-from sim.arrays import ArraysDecl
 from graph.graphviz import ASTGraph
+from sim.arrays import ArraysDecl
 from sim.cell_lists import CellLists, CellListsBuild, CellListsStencilBuild
 from sim.grid import Grid2D, Grid3D
 from sim.kernel_wrapper import KernelWrapper
@@ -20,6 +19,8 @@ from sim.setup_wrapper import SetupWrapper
 from sim.timestep import Timestep
 from sim.variables import VariablesDecl
 from sim.vtk import VTKWrite
+from transformations.flatten import flatten_property_accesses
+from transformations.simplify import simplify_expressions
 
 
 class ParticleSimulation:
@@ -193,12 +194,10 @@ class ParticleSimulation:
 
         program = Block.merge_blocks(decls, body)
         self.global_scope = program
-        Transform.apply(program, Transform.flatten)
-        Transform.apply(program, Transform.simplify)
-        #Transform.apply(program, Transform.reuse_index_expressions)
-        #Transform.apply(program, Transform.reuse_expr_expressions)
-        #Transform.apply(program, Transform.reuse_array_access_expressions)
-        #Transform.apply(program, Transform.move_loop_invariant_expressions)
+
+        # Transformations
+        flatten_property_accesses(program)
+        simplify_expressions(program)
 
         ASTGraph(self.kernels.lower(), "kernels").generate_and_view()
         self.code_gen.generate_program(self, program)
diff --git a/sim/timestep.py b/sim/timestep.py
index 0d54133..74aa241 100644
--- a/sim/timestep.py
+++ b/sim/timestep.py
@@ -41,6 +41,3 @@ class Timestep:
 
     def as_block(self):
         return Block(self.sim, [self.timestep_loop])
-
-    def transform(self, fn):
-        self.block = self.block.transform(fn)
diff --git a/sim/vtk.py b/sim/vtk.py
index c1620f8..284533a 100644
--- a/sim/vtk.py
+++ b/sim/vtk.py
@@ -1,17 +1,13 @@
 from ast.lit import as_lit_ast
+from ast.ast_node import ASTNode
 
-class VTKWrite:
+
+class VTKWrite(ASTNode):
     vtk_id = 0
 
     def __init__(self, sim, filename, timestep):
-        self.sim = sim
+        super().__init__(sim)
         self.vtk_id = VTKWrite.vtk_id
         self.filename = filename
         self.timestep = as_lit_ast(sim, timestep)
         VTKWrite.vtk_id += 1
-
-    def children(self):
-        return []
-
-    def transform(self, fn):
-        return fn(self)
diff --git a/transformations/flatten.py b/transformations/flatten.py
new file mode 100644
index 0000000..b06ebcb
--- /dev/null
+++ b/transformations/flatten.py
@@ -0,0 +1,35 @@
+from ast.layouts import Layout_AoS, Layout_SoA
+from ast.mutator import Mutator
+
+
+class FlattenPropertyAccesses(Mutator):
+    def __init__(self, ast):
+        super().__init__(ast)
+
+    def mutate_BinOp(self, ast_node):
+        ast_node.lhs = self.mutate(ast_node.lhs)
+        ast_node.rhs = self.mutate(ast_node.rhs)
+
+        if ast_node.is_vector_property_access():
+            layout = ast_node.lhs.layout()
+
+            for i in ast_node.vector_indexes():
+                flat_index = None
+
+                if layout == Layout_AoS:
+                    flat_index = ast_node.rhs * ast_node.sim.dimensions + i
+
+                elif layout == Layout_SoA:
+                    flat_index = i * ast_node.sim.particle_capacity + ast_node.rhs
+
+                else:
+                    raise Exception("Invalid property layout!")
+
+                ast_node.map_vector_index(i, flat_index)
+
+        return ast_node
+
+
+def flatten_property_accesses(ast_node):
+    flatten = FlattenPropertyAccesses(ast_node)
+    flatten.mutate()
diff --git a/transformations/simplify.py b/transformations/simplify.py
new file mode 100644
index 0000000..ff67315
--- /dev/null
+++ b/transformations/simplify.py
@@ -0,0 +1,36 @@
+from ast.data_types import Type_Int
+from ast.lit import Lit
+from ast.mutator import Mutator
+
+
+class SimplifyExpressions(Mutator):
+    def __init__(self, ast):
+        super().__init__(ast)
+
+    def mutate_BinOp(self, ast_node):
+        sim = ast_node.lhs.sim
+        ast_node.lhs = self.mutate(ast_node.lhs)
+        ast_node.rhs = self.mutate(ast_node.rhs)
+        ast_node.bin_op_vector_index_mapping = {i: self.mutate(e) for i, e in ast_node.bin_op_vector_index_mapping.items()}
+
+        if ast_node.op in ['+', '-'] and ast_node.rhs == 0:
+            return ast_node.lhs
+
+        if ast_node.op in ['+'] and ast_node.lhs == 0:
+            return ast_node.rhs
+
+        if ast_node.op in ['*', '/'] and ast_node.rhs == 1:
+            return ast_node.lhs
+
+        if ast_node.op == '*' and ast_node.lhs == 1:
+            return ast_node.rhs
+
+        if ast_node.op == '*' and ast_node.lhs == 0:
+            return Lit(sim, 0 if ast_node.type() == Type_Int else 0.0)
+
+        return ast_node
+
+
+def simplify_expressions(ast_node):
+    simplify = SimplifyExpressions(ast_node)
+    simplify.mutate()
-- 
GitLab