From 4be6e79548d68b6027b06cd06fc1b0af2a2897c8 Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti <rafaelravedutti@gmail.com>
Date: Tue, 22 Jun 2021 19:12:19 +0200
Subject: [PATCH] Fix prioritaze_scalar_ops transformation

Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com>
---
 code_gen/cgen.py                    | 2 +-
 ir/arrays.py                        | 3 +++
 ir/bin_op.py                        | 5 ++++-
 ir/loops.py                         | 2 +-
 sim/particle_simulation.py          | 8 ++++++--
 sim/vtk.py                          | 3 +++
 transformations/set_used_bin_ops.py | 7 +++++++
 7 files changed, 25 insertions(+), 5 deletions(-)

diff --git a/code_gen/cgen.py b/code_gen/cgen.py
index 6235723..b39bc2f 100644
--- a/code_gen/cgen.py
+++ b/code_gen/cgen.py
@@ -68,7 +68,7 @@ class CGen:
         if isinstance(ast_node, BinOpDef):
             bin_op = ast_node.bin_op
 
-            if not isinstance(bin_op, BinOp):
+            if not isinstance(bin_op, BinOp) or not ast_node.used:
                 return None
 
             if bin_op.inlined is False and bin_op.operator() != '[]' and bin_op.generated is False:
diff --git a/ir/arrays.py b/ir/arrays.py
index 27ffac1..a49e39a 100644
--- a/ir/arrays.py
+++ b/ir/arrays.py
@@ -173,6 +173,9 @@ class ArrayAccess(ASTTerm):
         return self.index.scope()
 
     def children(self):
+        if self.index is not None:
+            return [self.array, self.index]
+
         return [self.array] + self.indexes
 
 
diff --git a/ir/bin_op.py b/ir/bin_op.py
index 38e2d85..34081bb 100644
--- a/ir/bin_op.py
+++ b/ir/bin_op.py
@@ -10,7 +10,7 @@ class BinOpDef(ASTNode):
         super().__init__(bin_op.sim)
         self.bin_op = bin_op
         self.bin_op.sim.add_statement(self)
-        self.used = False
+        self.used = not bin_op.sim.check_bin_ops_usage
 
     def __str__(self):
         return f"BinOpDef<bin_op: self.bin_op>"
@@ -85,6 +85,9 @@ class BinOp(ASTNode):
         mapping = self.vector_index_mapping
         return mapping[index] if index in mapping else as_lit_ast(self.sim, index)
 
+    def mapped_expressions(self):
+        return self.vector_index_mapping.values()
+
     @property
     def vector_indexes(self):
         return self._vector_indexes
diff --git a/ir/loops.py b/ir/loops.py
index a1af262..b4ce423 100644
--- a/ir/loops.py
+++ b/ir/loops.py
@@ -67,7 +67,7 @@ class For(ASTNode):
         self.block.add_statement(stmt)
 
     def children(self):
-        return [self.iterator, self.block]
+        return [self.iterator, self.block, self.min, self.max]
 
 
 class ParticleFor(For):
diff --git a/sim/particle_simulation.py b/sim/particle_simulation.py
index a641fdd..20efef8 100644
--- a/sim/particle_simulation.py
+++ b/sim/particle_simulation.py
@@ -42,6 +42,7 @@ class ParticleSimulation:
         self.scope = []
         self.nested_count = 0
         self.nest = False
+        self.check_bin_ops_usage = True
         self.block = Block(self, [])
         self.setups = SetupWrapper()
         self.kernels = KernelWrapper()
@@ -199,11 +200,14 @@ class ParticleSimulation:
         self.global_scope = program
 
         # Transformations
-        #prioritaze_scalar_ops(program)
+        prioritaze_scalar_ops(program)
         flatten_property_accesses(program)
         simplify_expressions(program)
         move_loop_invariant_code(program)
-        #set_used_bin_ops(program)
+        set_used_bin_ops(program)
+
+        # For this part on, all bin ops are generated without usage verification
+        self.check_bin_ops_usage = False
 
         ASTGraph(self.kernels.lower(), "kernels").render()
         self.code_gen.generate_program(program)
diff --git a/sim/vtk.py b/sim/vtk.py
index c6125c4..6a35b6a 100644
--- a/sim/vtk.py
+++ b/sim/vtk.py
@@ -11,3 +11,6 @@ class VTKWrite(ASTNode):
         self.filename = filename
         self.timestep = as_lit_ast(sim, timestep)
         VTKWrite.vtk_id += 1
+
+    def children(self):
+        return [self.timestep]
diff --git a/transformations/set_used_bin_ops.py b/transformations/set_used_bin_ops.py
index 64f3d04..d26d1bc 100644
--- a/transformations/set_used_bin_ops.py
+++ b/transformations/set_used_bin_ops.py
@@ -7,8 +7,15 @@ class SetUsedBinOps(Visitor):
         super().__init__(ast)
         self.bin_ops = []
 
+    def visit_BinOpDef(self, ast_node):
+        pass
+
     def visit_BinOp(self, ast_node):
         ast_node.bin_op_def.used = True
+        self.visit_children(ast_node)
+        # TODO: These expressions could be automatically included in visitor traversal
+        for vidxs in ast_node.mapped_expressions():
+            self.visit(vidxs)
 
 
 def set_used_bin_ops(ast):
-- 
GitLab