From 1f94548004b1b9f14f714d320d866e26a14693e8 Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti <rafaelravedutti@gmail.com>
Date: Fri, 18 Jun 2021 00:42:49 +0200
Subject: [PATCH] First version of prioritize_scalar_ops transformation

Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com>
---
 ir/bin_op.py                             |  8 ++++
 new_syntax.py                            |  2 +
 sim/particle_simulation.py               |  5 ++-
 transformations/prioritize_scalar_ops.py | 50 ++++++++++++++++++++++++
 transformations/set_used_bin_ops.py      | 16 ++++++++
 5 files changed, 80 insertions(+), 1 deletion(-)
 create mode 100644 transformations/prioritize_scalar_ops.py
 create mode 100644 transformations/set_used_bin_ops.py

diff --git a/ir/bin_op.py b/ir/bin_op.py
index 237a8cc..38e2d85 100644
--- a/ir/bin_op.py
+++ b/ir/bin_op.py
@@ -10,6 +10,7 @@ class BinOpDef(ASTNode):
         super().__init__(bin_op.sim)
         self.bin_op = bin_op
         self.bin_op.sim.add_statement(self)
+        self.used = False
 
     def __str__(self):
         return f"BinOpDef<bin_op: self.bin_op>"
@@ -51,6 +52,13 @@ class BinOp(ASTNode):
         self.vector_index_mapping = {}
         self.bin_op_def = BinOpDef(self)
 
+    def reassign(self, lhs, rhs, op):
+        assert self.generated is False, "Error on reassign: BinOp {} already generated!".format(self.bin_op_id)
+        self.lhs = as_lit_ast(self.sim, lhs)
+        self.rhs = as_lit_ast(self.sim, rhs)
+        self.op = op
+        self.bin_op_type = BinOp.infer_type(self.lhs, self.rhs, self.op)
+
     def __str__(self):
         a = self.lhs.id() if isinstance(self.lhs, BinOp) else self.lhs
         b = self.rhs.id() if isinstance(self.rhs, BinOp) else self.rhs
diff --git a/new_syntax.py b/new_syntax.py
index 829a0ea..b5e98b8 100644
--- a/new_syntax.py
+++ b/new_syntax.py
@@ -17,6 +17,8 @@ def rsq(i, j):
 def lj(i, j):
     sr2 = 1.0 / rsq
     sr6 = sr2 * sr2 * sr2 * sigma6
+    #f = 48.0 * sr6 * (sr6 - 0.5) * sr2 * epsilon
+    #force[i] += delta * f
     force[i] += delta * 48.0 * sr6 * (sr6 - 0.5) * sr2 * epsilon
 
 
diff --git a/sim/particle_simulation.py b/sim/particle_simulation.py
index 3dfa57c..a641fdd 100644
--- a/sim/particle_simulation.py
+++ b/sim/particle_simulation.py
@@ -20,10 +20,11 @@ from sim.timestep import Timestep
 from sim.variables import VariablesDecl
 from sim.vtk import VTKWrite
 from transformations.flatten import flatten_property_accesses
+from transformations.prioritize_scalar_ops import prioritaze_scalar_ops
+from transformations.set_used_bin_ops import set_used_bin_ops
 from transformations.simplify import simplify_expressions
 from transformations.LICM import move_loop_invariant_code
 
-
 class ParticleSimulation:
     def __init__(self, code_gen, dims=3, timesteps=100):
         self.code_gen = code_gen
@@ -198,9 +199,11 @@ class ParticleSimulation:
         self.global_scope = program
 
         # Transformations
+        #prioritaze_scalar_ops(program)
         flatten_property_accesses(program)
         simplify_expressions(program)
         move_loop_invariant_code(program)
+        #set_used_bin_ops(program)
 
         ASTGraph(self.kernels.lower(), "kernels").render()
         self.code_gen.generate_program(program)
diff --git a/transformations/prioritize_scalar_ops.py b/transformations/prioritize_scalar_ops.py
new file mode 100644
index 0000000..cf099a8
--- /dev/null
+++ b/transformations/prioritize_scalar_ops.py
@@ -0,0 +1,50 @@
+from ir.bin_op import BinOp
+from ir.data_types import Type_Float, Type_Vector
+from ir.mutator import Mutator
+
+
+class PrioritazeScalarOps(Mutator):
+    def __init__(self, ast):
+        super().__init__(ast)
+
+    def can_rearrange(op1, op2):
+        return op1 == op2 and op1 in ['+', '*']
+
+    def mutate_BinOp(self, ast_node):
+        sim = ast_node.sim
+        ast_node.lhs = self.mutate(ast_node.lhs)
+        ast_node.rhs = self.mutate(ast_node.rhs)
+
+        if ast_node.type() == Type_Vector:
+            lhs = ast_node.lhs
+            rhs = ast_node.rhs
+            op = ast_node.op
+
+            if( isinstance(lhs, BinOp) and lhs.type() == Type_Vector and rhs.type() == Type_Float and \
+                PrioritazeScalarOps.can_rearrange(op, lhs.op) ):
+
+                if lhs.lhs.type() == Type_Vector and lhs.rhs.type() == Type_Float:
+                    ast_node.reassign(lhs.lhs, BinOp(sim, lhs.rhs, rhs, op), op)
+                    #return BinOp(sim, lhs.lhs, BinOp(sim, lhs.rhs, rhs, op), op)
+
+                if lhs.rhs.type() == Type_Vector and lhs.lhs.type() == Type_Float:
+                    ast_node.reassign(lhs.rhs, BinOp(sim, lhs.lhs, rhs, op), op)
+                    #return BinOp(sim, lhs.rhs, BinOp(sim, lhs.lhs, rhs, op), op)
+
+            if( isinstance(rhs, BinOp) and rhs.type() == Type_Vector and lhs.type() == Type_Float and \
+                PrioritazeScalarOps.can_rearrange(op, rhs.op) ):
+
+                if rhs.lhs.type() == Type_Vector and rhs.rhs.type() == Type_Float:
+                    ast_node.reassign(rhs.lhs, BinOp(sim, rhs.rhs, lhs, op), op)
+                    #return BinOp(sim, rhs.lhs, BinOp(sim, rhs.rhs, lhs, op), op)
+
+                if rhs.rhs.type() == Type_Vector and rhs.lhs.type() == Type_Float:
+                    ast_node.reassign(rhs.rhs, BinOp(sim, rhs.lhs, lhs, op), op)
+                    #return BinOp(sim, rhs.rhs, BinOp(sim, rhs.lhs, lhs, op), op)
+
+        return ast_node
+
+
+def prioritaze_scalar_ops(ast_node):
+    prioritaze = PrioritazeScalarOps(ast_node)
+    prioritaze.mutate()
diff --git a/transformations/set_used_bin_ops.py b/transformations/set_used_bin_ops.py
new file mode 100644
index 0000000..64f3d04
--- /dev/null
+++ b/transformations/set_used_bin_ops.py
@@ -0,0 +1,16 @@
+from ir.bin_op import BinOp
+from ir.visitor import Visitor
+
+
+class SetUsedBinOps(Visitor):
+    def __init__(self, ast):
+        super().__init__(ast)
+        self.bin_ops = []
+
+    def visit_BinOp(self, ast_node):
+        ast_node.bin_op_def.used = True
+
+
+def set_used_bin_ops(ast):
+    set_used_binops = SetUsedBinOps(ast)
+    set_used_binops.visit()
-- 
GitLab