From 02327683ca0bc51926b0974bf92ea3698d6f48e9 Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti <rafaelravedutti@gmail.com>
Date: Thu, 26 May 2022 00:21:24 +0200
Subject: [PATCH] Add new block transformations and fix small issues

Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com>
---
 examples/lj_func.py                   |  4 +-
 runtime/devices/cuda.hpp              |  4 +-
 runtime/read_from_file.hpp            |  2 +-
 src/pairs/analysis/__init__.py        | 13 +++++-
 src/pairs/analysis/bin_ops.py         | 19 +++++++++
 src/pairs/analysis/blocks.py          | 60 +++++++++++++++++++++++++++
 src/pairs/analysis/devices.py         |  9 ++--
 src/pairs/ir/bin_op.py                |  1 +
 src/pairs/ir/module.py                |  5 +++
 src/pairs/sim/pbc.py                  | 16 ++++---
 src/pairs/transformations/__init__.py |  9 +++-
 src/pairs/transformations/blocks.py   | 18 ++++++++
 12 files changed, 140 insertions(+), 20 deletions(-)

diff --git a/examples/lj_func.py b/examples/lj_func.py
index c7e5618..61c8119 100644
--- a/examples/lj_func.py
+++ b/examples/lj_func.py
@@ -31,6 +31,6 @@ psim.periodic(2.8)
 psim.vtk_output("output/test")
 psim.compute(lj, cutoff_radius, {'sigma6': sigma6, 'epsilon': epsilon})
 psim.compute(euler, symbols={'dt': dt})
-#psim.target(pairs.target_cpu())
-psim.target(pairs.target_gpu())
+psim.target(pairs.target_cpu())
+#psim.target(pairs.target_gpu())
 psim.generate()
diff --git a/runtime/devices/cuda.hpp b/runtime/devices/cuda.hpp
index 873f040..2e8236a 100644
--- a/runtime/devices/cuda.hpp
+++ b/runtime/devices/cuda.hpp
@@ -16,13 +16,13 @@ inline void cuda_assert(cudaError_t err, const char *file, int line) {
     }
 }
 
-__host__ __device__ void *device_alloc(size_t size) {
+__host__ void *device_alloc(size_t size) {
     void *ptr;
     CUDA_ASSERT(cudaMalloc(&ptr, size));
     return ptr;
 }
 
-__host__ __device__ void *device_realloc(void *ptr, size_t size) {
+__host__ void *device_realloc(void *ptr, size_t size) {
     void *new_ptr;
     CUDA_ASSERT(cudaFree(ptr));
     CUDA_ASSERT(cudaMalloc(&new_ptr, size));
diff --git a/runtime/read_from_file.hpp b/runtime/read_from_file.hpp
index 3e47403..749f984 100644
--- a/runtime/read_from_file.hpp
+++ b/runtime/read_from_file.hpp
@@ -47,7 +47,7 @@ size_t read_particle_data(PairsSim *ps, const char *filename, double *grid_buffe
                         float_ptr(n) = std::stod(in0);
                     } else {
                         std::cerr << "read_particle_data(): Invalid property type!" << std::endl;
-                        return -1;
+                        return 0;
                     }
                 }
 
diff --git a/src/pairs/analysis/__init__.py b/src/pairs/analysis/__init__.py
index 8752a40..169676c 100644
--- a/src/pairs/analysis/__init__.py
+++ b/src/pairs/analysis/__init__.py
@@ -1,5 +1,5 @@
-from pairs.analysis.bin_ops import SetBinOpTerminals, SetUsedBinOps
-from pairs.analysis.blocks import SetBlockVariants, SetParentBlock
+from pairs.analysis.bin_ops import ResetInPlaceBinOps, SetBinOpTerminals, SetInPlaceBinOps, SetUsedBinOps
+from pairs.analysis.blocks import SetBlockVariants, SetExprOwnerBlock, SetParentBlock
 from pairs.analysis.devices import FetchKernelReferences
 from pairs.analysis.modules import FetchModulesReferences
 
@@ -11,6 +11,9 @@ class Analysis:
         self._set_bin_op_terminals = SetBinOpTerminals(ast)
         self._set_block_variants = SetBlockVariants(ast)
         self._set_parent_block = SetParentBlock(ast)
+        self._set_expressions_owner_block = SetExprOwnerBlock(ast)
+        self._reset_in_place_bin_ops = ResetInPlaceBinOps(ast)
+        self._set_in_place_bin_ops = SetInPlaceBinOps(ast)
         self._fetch_kernel_references = FetchKernelReferences(ast)
         self._fetch_modules_references = FetchModulesReferences(ast)
 
@@ -26,7 +29,13 @@ class Analysis:
     def set_parent_block(self):
         self._set_parent_block.visit()
 
+    def set_expressions_owner_block(self):
+        self._set_expressions_owner_block.visit()
+        return (self._set_expressions_owner_block.ownership, self._set_expressions_owner_block.expressions_to_lift)
+
     def fetch_kernel_references(self):
+        self._reset_in_place_bin_ops.visit()
+        self._set_in_place_bin_ops.visit()
         self._fetch_kernel_references.visit()
 
     def fetch_modules_references(self):
diff --git a/src/pairs/analysis/bin_ops.py b/src/pairs/analysis/bin_ops.py
index e19a3a2..f012b7b 100644
--- a/src/pairs/analysis/bin_ops.py
+++ b/src/pairs/analysis/bin_ops.py
@@ -1,3 +1,4 @@
+from pairs.ir.bin_op import BinOp
 from pairs.ir.visitor import Visitor
 
 
@@ -60,3 +61,21 @@ class SetUsedBinOps(Visitor):
         ast_node.decl.used = not self.writing
         self.writing = False
         self.visit_children(ast_node)
+
+
+class ResetInPlaceBinOps(Visitor):
+    """Flag every BinOp reached by the traversal as in-place."""
+
+    def visit_BinOp(self, ast_node):
+        # Flag this node, then continue into its children
+        ast_node.in_place = True
+        self.visit_children(ast_node)
+
+
+class SetInPlaceBinOps(Visitor):
+    """Clear the in-place flag of every BinOp that is the element of a visited Decl."""
+
+    def visit_Decl(self, ast_node):
+        declared = ast_node.elem
+        if isinstance(declared, BinOp):
+            declared.in_place = False
diff --git a/src/pairs/analysis/blocks.py b/src/pairs/analysis/blocks.py
index ec4b7b6..8c8f5f3 100644
--- a/src/pairs/analysis/blocks.py
+++ b/src/pairs/analysis/blocks.py
@@ -119,3 +119,63 @@ class SetParentBlock(Visitor):
         assert isinstance(ast_node, (For, While)), "Node must be a loop!"
         loop_id = id(ast_node)
         return self.parents[loop_id] if loop_id in self.parents else None
+
+
+class SetExprOwnerBlock(Visitor):
+    """Track, for each visited BinOp and PropertyAccess, the block that owns it."""
+
+    def __init__(self, ast):
+        super().__init__(ast)
+        self.ownership = {}            # expression -> owner block computed so far
+        self.expressions_to_lift = []  # expressions for which must_lift was reported
+        self.block_level = {}          # block -> depth of block_stack when it was entered
+        self.block_parent = {}         # block -> top of block_stack when entered, or None
+        self.block_stack = []          # blocks currently being traversed
+
+    def common_parent_block(self, block1, block2):
+        """Walk both blocks up through block_parent and return (block, must_lift).
+
+        A None argument yields the other block; must_lift is True only when the
+        walk meets at a block that is neither block1 nor block2.
+        """
+        if block1 is None:
+            return (block2, False)
+        if block2 is None:
+            return (block1, False)
+        first, second = block1, block2
+        while first != second:
+            depth1, depth2 = self.block_level[first], self.block_level[second]
+            # Only the deeper side steps up; on equal depths both sides do
+            if depth1 >= depth2:
+                if depth1 == 0:
+                    return (first, False)
+                first = self.block_parent[first]
+            if depth2 >= depth1:
+                if depth2 == 0:
+                    return (second, False)
+                second = self.block_parent[second]
+        return (first, first != block1 and first != block2)
+
+    def set_ownership(self, ast_node):
+        """Fold the innermost open block into the owner recorded for ast_node."""
+        known_owner = self.ownership.setdefault(ast_node, None)
+        owner, must_lift = self.common_parent_block(known_owner, self.block_stack[-1])
+        self.ownership[ast_node] = owner
+        if must_lift and ast_node not in self.expressions_to_lift:
+            self.expressions_to_lift.append(ast_node)
+
+    def visit_Block(self, ast_node):
+        # Depth and parent come from the stack as it is before this block is pushed
+        self.block_level[ast_node] = len(self.block_stack)
+        self.block_parent[ast_node] = self.block_stack[-1] if self.block_stack else None
+        self.block_stack.append(ast_node)
+        self.visit_children(ast_node)
+        self.block_stack.pop()
+
+    def visit_BinOp(self, ast_node):
+        self.set_ownership(ast_node)
+        self.visit_children(ast_node)
+
+    def visit_PropertyAccess(self, ast_node):
+        self.set_ownership(ast_node)
+        self.visit_children(ast_node)
diff --git a/src/pairs/analysis/devices.py b/src/pairs/analysis/devices.py
index 0e61269..51b2f0c 100644
--- a/src/pairs/analysis/devices.py
+++ b/src/pairs/analysis/devices.py
@@ -33,7 +33,7 @@ class FetchKernelReferences(Visitor):
         self.kernel_stack.append(ast_node)
         self.visit_children(ast_node)
         self.kernel_stack.pop()
-        ast_node.add_bin_op([b for b in self.kernel_used_bin_ops[kernel_id] if b not in self.kernel_decls[kernel_id]])
+        ast_node.add_bin_op([b for b in self.kernel_used_bin_ops[kernel_id] if b not in self.kernel_decls[kernel_id] and not b.in_place])
 
     def visit_PropertyAccess(self, ast_node):
         # Visit property and save current writing state
@@ -51,8 +51,11 @@ class FetchKernelReferences(Visitor):
                 self.kernel_decls[k.kernel_id].append(ast_node.elem)
 
     def visit_BinOp(self, ast_node):
-        for k in self.kernel_stack:
-            self.kernel_used_bin_ops[k.kernel_id].append(ast_node)
+        if ast_node.inlined is False:
+            for k in self.kernel_stack:
+                self.kernel_used_bin_ops[k.kernel_id].append(ast_node)
+
+        self.visit_children(ast_node)
 
     def visit_Array(self, ast_node):
         for k in self.kernel_stack:
diff --git a/src/pairs/ir/bin_op.py b/src/pairs/ir/bin_op.py
index 5aa6d77..91a02d3 100644
--- a/src/pairs/ir/bin_op.py
+++ b/src/pairs/ir/bin_op.py
@@ -41,6 +41,7 @@ class BinOp(VectorExpression):
         self.mem = mem
         self.inlined = False
         self.generated = False
+        self.in_place = False
         self.bin_op_type = BinOp.infer_type(self.lhs, self.rhs, self.op)
         self.terminals = set()
         self.decl = Decl(sim, self)
diff --git a/src/pairs/ir/module.py b/src/pairs/ir/module.py
index da2f225..8db5573 100644
--- a/src/pairs/ir/module.py
+++ b/src/pairs/ir/module.py
@@ -9,6 +9,7 @@ class Module(ASTNode):
 
     def __init__(self, sim, name=None, block=None, resizes_to_check={}, check_properties_resize=False, run_on_device=False):
         super().__init__(sim)
+        self._id = Module.last_module
         self._name = name if name is not None else "module" + str(Module.last_module)
         self._variables = {}
         self._arrays = {}
@@ -20,6 +21,10 @@ class Module(ASTNode):
         sim.add_module(self)
         Module.last_module += 1
 
+    @property
+    def module_id(self):
+        return self._id
+
     @property
     def name(self):
         return self._name
diff --git a/src/pairs/sim/pbc.py b/src/pairs/sim/pbc.py
index f4e50b5..3c812b3 100644
--- a/src/pairs/sim/pbc.py
+++ b/src/pairs/sim/pbc.py
@@ -58,13 +58,12 @@ class EnforcePBC(Lowerable):
 
         for i in ParticleFor(sim):
             # TODO: VecFilter?
-            pos = positions[i]
             for d in range(0, ndims):
-                for _ in Filter(sim, pos[d] < grid.min(d)):
-                    pos[d].add(grid.length(d))
+                for _ in Filter(sim, positions[i][d] < grid.min(d)):
+                    positions[i][d].add(grid.length(d))
 
-                for _ in Filter(sim, pos[d] > grid.max(d)):
-                    pos[d].sub(grid.length(d))
+                for _ in Filter(sim, positions[i][d] > grid.max(d)):
+                    positions[i][d].sub(grid.length(d))
 
 
 class SetupPBC(Lowerable):
@@ -91,12 +90,11 @@ class SetupPBC(Lowerable):
         for d in range(0, ndims):
             for i in For(sim, 0, nlocal + npbc):
                 pos = positions[i]
-                last_id = nlocal + npbc
-
                 grid_length = grid.length(d)
+
                 # TODO: VecFilter?
                 for _ in Filter(sim, pos[d] < grid.min(d) + cutneigh):
-                    last_pos = positions[last_id]
+                    last_pos = positions[nlocal + npbc]
                     pbc_map[npbc].set(i)
                     pbc_mult[npbc][d].set(1)
                     last_pos[d].set(pos[d] + grid_length)
@@ -108,7 +106,7 @@ class SetupPBC(Lowerable):
                     npbc.add(1)
 
                 for _ in Filter(sim, pos[d] > grid.max(d) - cutneigh):
-                    last_pos = positions[last_id]
+                    last_pos = positions[nlocal + npbc]
                     pbc_map[npbc].set(i)
                     pbc_mult[npbc][d].set(-1)
                     last_pos[d].set(pos[d] - grid_length)
diff --git a/src/pairs/transformations/__init__.py b/src/pairs/transformations/__init__.py
index 8615eb2..cf279fd 100644
--- a/src/pairs/transformations/__init__.py
+++ b/src/pairs/transformations/__init__.py
@@ -1,5 +1,5 @@
 from pairs.analysis import Analysis
-from pairs.transformations.blocks import MergeAdjacentBlocks
+from pairs.transformations.blocks import LiftExprOwnerBlocks, MergeAdjacentBlocks
 from pairs.transformations.devices import AddDeviceCopies, AddDeviceKernels
 from pairs.transformations.expressions import ReplaceSymbols, SimplifyExpressions, PrioritizeScalarOps
 from pairs.transformations.loops import LICM
@@ -16,6 +16,7 @@ class Transformations:
         self._replace_symbols = ReplaceSymbols(ast)
         self._simplify_expressions = SimplifyExpressions(ast)
         self._prioritize_scalar_ops = PrioritizeScalarOps(ast)
+        self._lift_expressions_to_owner_blocks = LiftExprOwnerBlocks(ast)
         self._licm = LICM(ast)
         self._dereference_write_variables = DereferenceWriteVariables(ast)
         self._add_resize_logic = AddResizeLogic(ast)
@@ -41,6 +42,11 @@ class Transformations:
         self._simplify_expressions.mutate()
         self._analysis.set_used_bin_ops()
 
+    def lift_expressions_to_owner_blocks(self):
+        """Run the owner-block analysis and hand its results to the lifting mutator."""
+        self._lift_expressions_to_owner_blocks.set_data(*self._analysis.set_expressions_owner_block())
+        self._lift_expressions_to_owner_blocks.mutate()
+
     def licm(self):
         self._analysis.set_parent_block()
         self._analysis.set_block_variants()
@@ -67,6 +73,7 @@ class Transformations:
     def apply_all(self):
         self.lower_everything()
         self.optimize_expressions()
+        self.lift_expressions_to_owner_blocks()
         self.licm()
         self.modularize()
         self.add_device_copies()
diff --git a/src/pairs/transformations/blocks.py b/src/pairs/transformations/blocks.py
index e2f7964..21070a3 100644
--- a/src/pairs/transformations/blocks.py
+++ b/src/pairs/transformations/blocks.py
@@ -1,4 +1,5 @@
 from pairs.ir.block import Block
+from pairs.ir.bin_op import Decl
 from pairs.ir.mutator import Mutator
 
 
@@ -18,3 +19,20 @@ class MergeAdjacentBlocks(Mutator):
 
         ast_node.stmts = new_stmts 
         return ast_node
+
+
+class LiftExprOwnerBlocks(Mutator):
+    """Prepend a Decl to each block for every lifted expression it owns."""
+
+    def __init__(self, ast):
+        super().__init__(ast)
+        self.ownership = self.expressions_to_lift = None  # filled by set_data()
+
+    def set_data(self, ownership, expressions_to_lift):
+        self.ownership, self.expressions_to_lift = ownership, expressions_to_lift
+
+    def mutate_Block(self, ast_node):
+        lifted = [Decl(ast_node.sim, expr) for expr in self.ownership
+                  if self.ownership[expr] == ast_node and expr in self.expressions_to_lift]
+        ast_node.stmts = lifted + [self.mutate(stmt) for stmt in ast_node.stmts]
+        return ast_node
-- 
GitLab