From 61cd1818118f2c9b26257cce74b9d05511eab531 Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti <rafaelravedutti@gmail.com>
Date: Wed, 30 Mar 2022 23:35:51 +0200
Subject: [PATCH] Add first draft code to generate CUDA kernels

Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com>
---
 src/pairs/analysis/__init__.py        |  5 ++
 src/pairs/analysis/devices.py         |  1 +
 src/pairs/code_gen/cgen.py            | 66 ++++++++++++++++++++++++++-
 src/pairs/coupling/parse_cpp.py       |  2 +-
 src/pairs/ir/kernel.py                | 16 +++++++
 src/pairs/mapping/funcs.py            |  2 +-
 src/pairs/sim/simulation.py           | 20 +++++---
 src/pairs/transformations/__init__.py |  9 +++-
 src/pairs/transformations/devices.py  |  2 +
 9 files changed, 113 insertions(+), 10 deletions(-)

diff --git a/src/pairs/analysis/__init__.py b/src/pairs/analysis/__init__.py
index 5cd89c2..8752a40 100644
--- a/src/pairs/analysis/__init__.py
+++ b/src/pairs/analysis/__init__.py
@@ -1,5 +1,6 @@
 from pairs.analysis.bin_ops import SetBinOpTerminals, SetUsedBinOps
 from pairs.analysis.blocks import SetBlockVariants, SetParentBlock
+from pairs.analysis.devices import FetchKernelReferences
 from pairs.analysis.modules import FetchModulesReferences
 
 
@@ -10,6 +11,7 @@ class Analysis:
         self._set_bin_op_terminals = SetBinOpTerminals(ast)
         self._set_block_variants = SetBlockVariants(ast)
         self._set_parent_block = SetParentBlock(ast)
+        self._fetch_kernel_references = FetchKernelReferences(ast)
         self._fetch_modules_references = FetchModulesReferences(ast)
 
     def set_used_bin_ops(self):
@@ -24,5 +26,8 @@ class Analysis:
     def set_parent_block(self):
         self._set_parent_block.visit()
 
+    def fetch_kernel_references(self):
+        self._fetch_kernel_references.visit()
+
     def fetch_modules_references(self):
         self._fetch_modules_references.visit()
diff --git a/src/pairs/analysis/devices.py b/src/pairs/analysis/devices.py
index 94d13ba..4558953 100644
--- a/src/pairs/analysis/devices.py
+++ b/src/pairs/analysis/devices.py
@@ -1,3 +1,4 @@
+from pairs.ir.bin_op import BinOp
 from pairs.ir.visitor import Visitor
 
 
diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py
index 99d37fc..a7f3996 100644
--- a/src/pairs/code_gen/cgen.py
+++ b/src/pairs/code_gen/cgen.py
@@ -6,6 +6,7 @@ from pairs.ir.cast import Cast
 from pairs.ir.bin_op import BinOp, Decl, VectorAccess
 from pairs.ir.device import CopyToDevice, CopyToHost
 from pairs.ir.functions import Call
+from pairs.ir.kernel import Kernel, KernelLaunch
 from pairs.ir.layouts import Layouts
 from pairs.ir.lit import Lit
 from pairs.ir.loops import For, Iter, ParticleFor, While
@@ -60,8 +61,13 @@ class CGen:
                     self.print(f"__constant__ {tkw} d_{array.name()}[{size}];")
 
         self.print("")
+
+        for kernel in self.sim.kernels():
+            self.generate_kernel(kernel)
+
         for module in self.sim.modules():
             self.generate_module(module)
+
         self.print.end()
 
     def generate_module(self, module):
@@ -110,6 +116,37 @@ class CGen:
 
             self.print("}")
 
+    def generate_kernel(self, kernel):
+        kernel_params = ""
+        for var in kernel.read_only_variables():
+            type_kw = Types.c_keyword(var.type())
+            decl = f"{type_kw} {var.name()}"
+            kernel_params += decl if len(kernel_params) <= 0 else f", {decl}"
+
+        for var in kernel.write_variables():
+            type_kw = Types.c_keyword(var.type())
+            decl = f"{type_kw} *{var.name()}"
+            kernel_params += decl if len(kernel_params) <= 0 else f", {decl}"
+
+        for array in kernel.arrays():
+            type_kw = Types.c_keyword(array.type())
+            decl = f"{type_kw} *{array.name()}"
+            kernel_params += decl if len(kernel_params) <= 0 else f", {decl}"
+
+        for prop in kernel.properties():
+            type_kw = Types.c_keyword(prop.type())
+            decl = f"{type_kw} *{prop.name()}"
+            kernel_params += decl if len(kernel_params) <= 0 else f", {decl}"
+
+        for bin_op in kernel.bin_ops():
+            type_kw = Types.c_keyword(bin_op.type())
+            decl = f"{type_kw} e{bin_op.id()}"
+            kernel_params += decl if len(kernel_params) <= 0 else f", {decl}"
+
+        self.print(f"__global__ void {kernel.name}({kernel_params}) {{")
+        self.generate_statement(kernel.block)
+        self.print("}")
+
     def generate_statement(self, ast_node, bypass_checking=False):
         if isinstance(ast_node, ArrayDecl):
             t = ast_node.array.type()
@@ -230,6 +267,34 @@ class CGen:
                 if self.target.is_gpu() and ast_node.array.device_flag:
                     self.print(f"d_{array_name} = ({tkw} *) pairs::device_alloc({size});")
 
+        if isinstance(ast_node, KernelLaunch):
+            kernel = ast_node.kernel
+            kernel_params = ""
+            for var in kernel.read_only_variables():
+                decl = var.name()
+                kernel_params += decl if len(kernel_params) <= 0 else f", {decl}"
+
+            for var in kernel.write_variables():
+                decl = f"&{var.name()}"
+                kernel_params += decl if len(kernel_params) <= 0 else f", {decl}"
+
+            for array in kernel.arrays():
+                decl = f"d_{array.name()}"
+                kernel_params += decl if len(kernel_params) <= 0 else f", {decl}"
+
+            for prop in kernel.properties():
+                decl = f"d_{prop.name()}"
+                kernel_params += decl if len(kernel_params) <= 0 else f", {decl}"
+
+            for bin_op in kernel.bin_ops():
+                decl = self.generate_expression(bin_op)
+                kernel_params += decl if len(kernel_params) <= 0 else f", {decl}"
+
+            elems = ast_node.max - ast_node.min
+            threads_per_block = self.generate_expression(ast_node.threads_per_block)
+            blocks = self.generate_expression((elems + ast_node.threads_per_block - 1) // ast_node.threads_per_block)
+            self.print(f"{kernel.name}<<<{blocks}, {threads_per_block}>>>({kernel_params});")
+
         if isinstance(ast_node, ModuleCall):
             module = ast_node.module
             module_params = ""
@@ -289,7 +354,6 @@ class CGen:
 
         if isinstance(ast_node, UpdateProperty):
             p = ast_node.property()
-
             if p.type() != Types.Vector or p.layout() == Layouts.Invalid:
                 self.print(f"ps->updateProperty({p.id()}, {p.name()});")
             else:
diff --git a/src/pairs/coupling/parse_cpp.py b/src/pairs/coupling/parse_cpp.py
index a1e9493..385a855 100644
--- a/src/pairs/coupling/parse_cpp.py
+++ b/src/pairs/coupling/parse_cpp.py
@@ -176,7 +176,7 @@ def map_kernel_to_simulation(sim, node):
             }
         })
 
-    self.build_kernel_block_with_statements()
+    sim.build_module_with_statements()
 
 
 def map_method_tree(sim, node, assignments={}, mappings={}):
diff --git a/src/pairs/ir/kernel.py b/src/pairs/ir/kernel.py
index 8f7770c..f1040ec 100644
--- a/src/pairs/ir/kernel.py
+++ b/src/pairs/ir/kernel.py
@@ -49,6 +49,9 @@ class Kernel(ASTNode):
     def properties_to_synchronize(self):
         return {p for p in self._properties if self._properties[p][0] == 'r'}
 
+    def bin_ops(self):
+        return self._bin_ops
+
     def write_properties(self):
         return {p for p in self._properties if 'w' in self._properties[p]}
 
@@ -91,7 +94,20 @@ class KernelLaunch(ASTNode):
         self._iterator = iterator
         self._range_min = range_min
         self._range_max = range_max
+        self._threads_per_block = 32
 
     @property
     def kernel(self):
         return self._kernel
+
+    @property
+    def min(self):
+        return self._range_min
+
+    @property
+    def max(self):
+        return self._range_max
+
+    @property
+    def threads_per_block(self):
+        return self._threads_per_block
diff --git a/src/pairs/mapping/funcs.py b/src/pairs/mapping/funcs.py
index 164a3de..5ba5fa2 100644
--- a/src/pairs/mapping/funcs.py
+++ b/src/pairs/mapping/funcs.py
@@ -148,4 +148,4 @@ def compute(sim, func, cutoff_radius=None, symbols={}):
             ir.add_symbols({params[0]: i, params[1]: j, 'delta': pairs.delta(), 'rsq': pairs.squared_distance()})
             ir.visit(tree)
 
-    sim.build_kernel_block_with_statements()
+    sim.build_module_with_statements()
diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py
index 7cca2fc..e8bd7b2 100644
--- a/src/pairs/sim/simulation.py
+++ b/src/pairs/sim/simulation.py
@@ -45,8 +45,9 @@ class Simulation:
         self.check_decl_usage = True
         self._block = Block(self, [])
         self.setups = Block(self, [])
-        self.kernels = Block(self, [])
+        self.functions = Block(self, [])
         self.module_list = []
+        self.kernel_list = []
         self._check_properties_resize = False
         self._resizes_to_check = {}
         self._module_name = None
@@ -74,6 +75,13 @@ class Simulation:
 
         return sorted_mods + [main_mod]
 
+    def add_kernel(self, kernel):
+        assert isinstance(kernel, Kernel), "add_kernel(): Given parameter is not of type Kernel!"
+        self.kernel_list.append(kernel)
+
+    def kernels(self):
+        return self.kernel_list
+
     def ndims(self):
         return self.dims
 
@@ -170,14 +178,14 @@ class Simulation:
         else:
             raise Exception("Two sizes assigned to same capacity!")
 
-    def build_kernel_block_with_statements(self):
-        self.kernels.add_statement(
+    def build_module_with_statements(self, run_on_device=True):
+        self.functions.add_statement(
             Module(self,
                 name=self._module_name,
                 block=Block(self, self._block),
                 resizes_to_check=self._resizes_to_check,
                 check_properties_resize=self._check_properties_resize,
-                run_on_device=True))
+                run_on_device=run_on_device))
 
     def add_statement(self, stmt):
         if not self.scope:
@@ -220,7 +228,7 @@ class Simulation:
             (CellListsBuild(self, self.cell_lists), 20),
             (NeighborListsBuild(self, self.neighbor_lists), 20),
             PropertiesResetVolatile(self),
-            self.kernels
+            self.functions
         ])
 
         self.enter(timestep.block)
@@ -249,5 +257,5 @@ class Simulation:
         # For this part on, all bin ops are generated without usage verification
         self.check_decl_usage = False
 
-        ASTGraph(self.kernels, "kernels").render()
+        ASTGraph(self.functions, "functions").render()
         self.code_gen.generate_program(program)
diff --git a/src/pairs/transformations/__init__.py b/src/pairs/transformations/__init__.py
index edd3b33..08cf0e0 100644
--- a/src/pairs/transformations/__init__.py
+++ b/src/pairs/transformations/__init__.py
@@ -1,6 +1,6 @@
 from pairs.analysis import Analysis
 from pairs.transformations.blocks import MergeAdjacentBlocks
-from pairs.transformations.devices import AddDeviceCopies
+from pairs.transformations.devices import AddDeviceCopies, AddDeviceKernels
 from pairs.transformations.expressions import ReplaceSymbols, SimplifyExpressions, PrioritizeScalarOps
 from pairs.transformations.loops import LICM
 from pairs.transformations.lower import Lower
@@ -23,6 +23,7 @@ class Transformations:
 
         if target.is_gpu():
             self._add_device_copies = AddDeviceCopies(ast)
+            self._add_device_kernels = AddDeviceKernels(ast)
 
     def lower_everything(self):
         nlowered = 1
@@ -58,9 +59,15 @@ class Transformations:
         if self._target.is_gpu():
             self._add_device_copies.mutate()
 
+    def add_device_kernels(self):
+        if self._target.is_gpu():
+            self._analysis.fetch_kernel_references()
+            self._add_device_kernels.mutate()
+
     def apply_all(self):
         self.lower_everything()
         self.optimize_expressions()
         self.licm()
         self.modularize()
         self.add_device_copies()
+        self.add_device_kernels()
diff --git a/src/pairs/transformations/devices.py b/src/pairs/transformations/devices.py
index 1a9a037..e493d94 100644
--- a/src/pairs/transformations/devices.py
+++ b/src/pairs/transformations/devices.py
@@ -83,4 +83,6 @@ class AddDeviceKernels(Mutator):
                     else:
                         new_stmts.append(s)
 
+            ast_node._block_stmts = new_stmts
+
         return ast_node
-- 
GitLab