From b57eabce6ff0f5ea61b49369ccd6c6a106e67bf5 Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti <rafaelravedutti@gmail.com>
Date: Wed, 16 Feb 2022 00:57:23 +0100
Subject: [PATCH] Remove KernelBlock and store varying properties in modules

Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com>
---
 src/pairs/code_gen/cgen.py                    | 11 ++---
 src/pairs/ir/block.py                         | 43 ++-----------------
 src/pairs/ir/module.py                        | 37 ++++++++++------
 src/pairs/ir/mutator.py                       |  4 --
 src/pairs/sim/simulation.py                   |  7 +--
 .../transformations/add_device_copies.py      | 39 +++--------------
 src/pairs/transformations/modules.py          |  8 ++--
 7 files changed, 44 insertions(+), 105 deletions(-)

diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py
index af7b1a6..2c188e7 100644
--- a/src/pairs/code_gen/cgen.py
+++ b/src/pairs/code_gen/cgen.py
@@ -1,6 +1,6 @@
 from pairs.ir.assign import Assign
 from pairs.ir.arrays import Array, ArrayAccess, ArrayDecl
-from pairs.ir.block import Block, KernelBlock
+from pairs.ir.block import Block
 from pairs.ir.branches import Branch
 from pairs.ir.cast import Cast
 from pairs.ir.bin_op import BinOp, Decl, VectorAccess
@@ -12,7 +12,7 @@ from pairs.ir.lit import Lit
 from pairs.ir.loops import For, Iter, ParticleFor, While
 from pairs.ir.math import Ceil, Sqrt
 from pairs.ir.memory import Malloc, Realloc
-from pairs.ir.module import Module_Call
+from pairs.ir.module import ModuleCall
 from pairs.ir.properties import Property, PropertyAccess, PropertyList, RegisterProperty, UpdateProperty
 from pairs.ir.select import Select
 from pairs.ir.sizeof import Sizeof
@@ -176,11 +176,6 @@ class CGen:
         if isinstance(ast_node, DeviceCopy):
             self.print(f"pairs::copy_to_device({ast_node.prop.name()})")
 
-        if isinstance(ast_node, KernelBlock):
-            self.print.add_indent(-4)
-            self.generate_statement(ast_node.block)
-            self.print.add_indent(4) # Workaround for fixing indentation of kernels
-
         if isinstance(ast_node, For):
             iterator = self.generate_expression(ast_node.iterator)
             lower_range = None
@@ -210,7 +205,7 @@ class CGen:
             else:
                 self.print(f"{array_name} = ({tkw} *) malloc({size});")
 
-        if isinstance(ast_node, Module_Call):
+        if isinstance(ast_node, ModuleCall):
             module = ast_node.module
             module_params = ""
             for var in module.read_only_variables():
diff --git a/src/pairs/ir/block.py b/src/pairs/ir/block.py
index 42ae050..2d65656 100644
--- a/src/pairs/ir/block.py
+++ b/src/pairs/ir/block.py
@@ -19,20 +19,10 @@ def pairs_device_block(func):
         func(*args, **kwargs)
         return Module(sim,
             name=sim._module_name,
-            block=KernelBlock(sim, sim._block),
+            block=Block(sim, sim._block),
             resizes_to_check=sim._resizes_to_check,
-            check_properties_resize=sim._check_properties_resize)
-
-    return inner
-
-
-# TODO: Is this really useful? Or just pairs_block is enough?
-def pairs_host_block(func):
-    def inner(*args, **kwargs):
-        sim = args[0].sim # self.sim
-        sim.init_block()
-        func(*args, **kwargs)
-        return KernelBlock(sim, sim._block, run_on_host=True)
+            check_properties_resize=sim._check_properties_resize,
+            run_on_device=True)
 
     return inner
 
@@ -68,8 +58,6 @@ class Block(ASTNode):
     def merge_blocks(block1, block2):
         assert isinstance(block1, Block), "First block type is not Block!"
         assert isinstance(block2, Block), "Second block type is not Block!"
-        assert not isinstance(block1, KernelBlock), "Kernel blocks cannot be merged!"
-        assert not isinstance(block2, KernelBlock), "Kernel blocks cannot be merged!"
         return Block(block1.sim, block1.statements() + block2.statements())
 
     def from_list(sim, block_list):
@@ -83,28 +71,3 @@ class Block(ASTNode):
                 result_block.add_statement(block)
 
         return result_block
-
-
-class KernelBlock(ASTNode):
-    def __init__(self, sim, block, run_on_host=False):
-        super().__init__(sim)
-        self.block = block if isinstance(block, Block) else Block(sim, block)
-        self.run_on_host = run_on_host
-        self.props_accessed = {}
-
-    def add_property_access(self, prop, oper):
-        prop_key = prop.name()
-        if prop_key not in self.props_accessed:
-            self.props_accessed[prop_key] = oper
-
-        elif oper not in self.props_accessed[prop_key]:
-            self.props_accessed[prop_key] += oper
-
-    def children(self):
-        return [self.block]
-
-    def properties_to_synchronize(self):
-        return {p for p in self.props_accessed if self.props_accessed[p][0] == 'r'}
-
-    def writing_properties(self):
-        return {p for p in self.props_accessed if 'w' in self.props_accessed[p][0]}
diff --git a/src/pairs/ir/module.py b/src/pairs/ir/module.py
index 70aac81..d8ffcaf 100644
--- a/src/pairs/ir/module.py
+++ b/src/pairs/ir/module.py
@@ -7,15 +7,16 @@ from pairs.ir.variables import Var
 class Module(ASTNode):
     last_module = 0
 
-    def __init__(self, sim, name=None, block=None, resizes_to_check={}, check_properties_resize=False):
+    def __init__(self, sim, name=None, block=None, resizes_to_check={}, check_properties_resize=False, run_on_device=False):
         super().__init__(sim)
         self._name = name if name is not None else "module_" + str(Module.last_module)
         self._variables = {}
-        self._arrays = set()
-        self._properties = set()
+        self._arrays = {}
+        self._properties = {}
         self._block = block
         self._resizes_to_check = resizes_to_check
         self._check_properties_resize = check_properties_resize
+        self._run_on_device = run_on_device
         sim.add_module(self)
         Module.last_module += 1
 
@@ -27,14 +28,18 @@ class Module(ASTNode):
     def block(self):
         return self._block
 
+    @property
+    def run_on_device(self):
+        return self._run_on_device
+
     def variables(self):
         return self._variables
 
     def read_only_variables(self):
-        return [v for v in self._variables if not self._variables[v]]
+        return [v for v in self._variables if 'w' not in self._variables[v]]
 
     def write_variables(self):
-        return [v for v in self._variables if self._variables[v]]
+        return [v for v in self._variables if 'w' in self._variables[v]]
 
     def arrays(self):
         return self._arrays
@@ -42,34 +47,40 @@ class Module(ASTNode):
     def properties(self):
         return self._properties
 
+    def properties_to_synchronize(self):
+        return {p for p in self._properties if self._properties[p][0] == 'r'}
+
+    def write_properties(self):
+        return {p for p in self._properties if 'w' in self._properties[p]}
+
     def add_array(self, array, write=False):
         array_list = array if isinstance(array, list) else [array]
+        character = 'w' if write else 'r'
         for a in array_list:
             assert isinstance(a, Array), "Module.add_array(): given element is not of type Array!"
-            self._arrays.add(a)
+            self._arrays[a] = character if a not in self._arrays else self._arrays[a] + character
 
     def add_variable(self, variable, write=False):
         variable_list = variable if isinstance(variable, list) else [variable]
+        character = 'w' if write else 'r'
         for v in variable_list:
             assert isinstance(v, Var), "Module.add_variable(): given element is not of type Var!"
-            if v not in self._variables:
-                self._variables[v] = write
-            else:
-                self._variables[v] = self._variables[v] or write
+            self._variables[v] = character if v not in self._variables else self._variables[v] + character
 
     def add_property(self, prop, write=False):
         prop_list = prop if isinstance(prop, list) else [prop]
+        character = 'w' if write else 'r'
         for p in prop_list:
             assert isinstance(p, Property), "Module.add_property(): given element is not of type Property!"
-            self._properties.add(p)
+            self._properties[p] = character if p not in self._properties else self._properties[p] + character
 
     def children(self):
         return [self._block]
 
 
-class Module_Call(ASTNode):
+class ModuleCall(ASTNode):
     def __init__(self, sim, module):
-        assert isinstance(module, Module), "Module_Call(): given parameter is not of type Module!"
+        assert isinstance(module, Module), "ModuleCall(): given parameter is not of type Module!"
         super().__init__(sim)
         self._module = module
 
diff --git a/src/pairs/ir/mutator.py b/src/pairs/ir/mutator.py
index 09c9f73..7118b07 100644
--- a/src/pairs/ir/mutator.py
+++ b/src/pairs/ir/mutator.py
@@ -66,10 +66,6 @@ class Mutator:
         ast_node.block = self.mutate(ast_node.block)
         return ast_node
 
-    def mutate_KernelBlock(self, ast_node):
-        ast_node.block = self.mutate(ast_node.block)
-        return ast_node
-
     def mutate_ParticleFor(self, ast_node):
         return self.mutate_For(ast_node)
 
diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py
index 8ef7b4f..19692ed 100644
--- a/src/pairs/sim/simulation.py
+++ b/src/pairs/sim/simulation.py
@@ -1,5 +1,5 @@
 from pairs.ir.arrays import Arrays
-from pairs.ir.block import Block, KernelBlock
+from pairs.ir.block import Block
 from pairs.ir.branches import Filter
 from pairs.ir.data_types import Type_Int, Type_Float, Type_Vector
 from pairs.ir.layouts import Layout_AoS
@@ -181,9 +181,10 @@ class Simulation:
         self.kernels.add_statement(
             Module(self,
                 name=self._module_name,
-                block=KernelBlock(self, self._block),
+                block=Block(self, self._block),
                 resizes_to_check=self._resizes_to_check,
-                check_properties_resize=self._check_properties_resize))
+                check_properties_resize=self._check_properties_resize,
+                run_on_device=True))
 
     def add_statement(self, stmt):
         if not self.scope:
diff --git a/src/pairs/transformations/add_device_copies.py b/src/pairs/transformations/add_device_copies.py
index 004e504..33db9c8 100644
--- a/src/pairs/transformations/add_device_copies.py
+++ b/src/pairs/transformations/add_device_copies.py
@@ -1,33 +1,9 @@
-from pairs.ir.block import KernelBlock
 from pairs.ir.device import DeviceCopy
+from pairs.ir.module import ModuleCall
 from pairs.ir.mutator import Mutator
 from pairs.ir.visitor import Visitor
 
 
-class AddAccessedProperties(Visitor):
-    def __init__(self, ast):
-        super().__init__(ast)
-        self.current_kernel_block = None
-        self.writing = False
-
-    def visit_Assign(self, ast_node):
-        for s in ast_node.sources():
-            self.visit(s)
-        self.writing = True
-
-        for d in ast_node.destinations():
-            self.visit(d)
-        self.writing = False
-
-    def visit_KernelBlock(self, ast_node):
-        self.current_kernel_block = ast_node
-        self.visit_children(ast_node)
-
-    def visit_PropertyAccess(self, ast_node):
-        if self.current_kernel_block is not None:
-            self.current_kernel_block.add_property_access(ast_node.prop, 'w' if self.writing else 'r')
-
-
 class AddDeviceCopies(Mutator):
     def __init__(self, ast):
         super().__init__(ast)
@@ -41,25 +17,22 @@ class AddDeviceCopies(Mutator):
         for s in stmts:
             if s is not None:
                 s_id = id(s)
-                if isinstance(s, KernelBlock) and s_id in self.props_to_copy:
-                    new_stmts = new_stmts + [DeviceCopy(ast_node.sim, ast_node.sim.property(p)) for p in self.props_to_copy[s_id]]
+                if isinstance(s, ModuleCall) and s_id in self.props_to_copy:
+                    new_stmts = new_stmts + [DeviceCopy(ast_node.sim, p) for p in self.props_to_copy[s_id]]
 
                 new_stmts.append(s)
 
         ast_node.stmts = new_stmts
         return ast_node
 
-    def mutate_KernelBlock(self, ast_node):
-        ast_node.block = self.mutate(ast_node.block)
-        copying_properties = {p for p in ast_node.properties_to_synchronize() if p not in self.synchronized_props}
+    def mutate_ModuleCall(self, ast_node):
+        copying_properties = {p for p in ast_node.module.properties_to_synchronize() if p not in self.synchronized_props}
         self.props_to_copy[id(ast_node)] = copying_properties
         self.synchronized_props.update(copying_properties)
-        self.synchronized_props -= ast_node.writing_properties()
+        self.synchronized_props -= ast_node.module.write_properties()
         return ast_node
 
 
 def add_device_copies(ast):
-    add_accessed_props = AddAccessedProperties(ast)
-    add_accessed_props.visit()
     add_copies = AddDeviceCopies(ast)
     add_copies.mutate()
diff --git a/src/pairs/transformations/modules.py b/src/pairs/transformations/modules.py
index 695bce0..aaaef3f 100644
--- a/src/pairs/transformations/modules.py
+++ b/src/pairs/transformations/modules.py
@@ -7,7 +7,7 @@ from pairs.ir.data_types import Type_Vector
 from pairs.ir.lit import Lit
 from pairs.ir.loops import While
 from pairs.ir.memory import Realloc
-from pairs.ir.module import Module, Module_Call
+from pairs.ir.module import Module, ModuleCall
 from pairs.ir.mutator import Mutator
 from pairs.ir.properties import UpdateProperty
 from pairs.ir.variables import Var, Deref
@@ -38,11 +38,11 @@ class FetchModulesReferences(Visitor):
 
     def visit_Array(self, ast_node):
         for m in self.module_stack:
-            m.add_array(ast_node)
+            m.add_array(ast_node, self.writing)
 
     def visit_Property(self, ast_node):
         for m in self.module_stack:
-            m.add_property(ast_node)
+            m.add_property(ast_node, self.writing)
 
     def visit_Var(self, ast_node):
         for m in self.module_stack:
@@ -164,7 +164,7 @@ class ReplaceModulesByCalls(Mutator):
             return ast_node
 
         sim = ast_node.sim
-        call = Module_Call(sim, ast_node)
+        call = ModuleCall(sim, ast_node)
         if self.module_resizes[ast_node]:
             properties = sim.properties
             init_stmts = []
-- 
GitLab