From 7ae76a3637bb64235f9d24940e105b92dfcc1a7b Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti <rafaelravedutti@gmail.com>
Date: Fri, 10 Mar 2023 16:26:28 +0100
Subject: [PATCH] Fix communication for GPU

Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com>
---
 src/pairs/code_gen/cgen.py | 38 +++++++++++++++++---------------------
 1 file changed, 17 insertions(+), 21 deletions(-)

diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py
index ec40182..285f147 100644
--- a/src/pairs/code_gen/cgen.py
+++ b/src/pairs/code_gen/cgen.py
@@ -142,39 +142,39 @@ class CGen:
             self.print("}")
 
     def generate_kernel(self, kernel):
-        kernel_params = ""
+        kernel_params = "int range_start"
         for var in kernel.read_only_variables():
             type_kw = Types.c_keyword(var.type())
             decl = f"{type_kw} {var.name()}"
-            kernel_params += decl if len(kernel_params) <= 0 else f", {decl}"
+            kernel_params += f", {decl}"
 
         for var in kernel.write_variables():
             type_kw = Types.c_keyword(var.type())
             decl = f"{type_kw} *{var.name()}"
-            kernel_params += decl if len(kernel_params) <= 0 else f", {decl}"
+            kernel_params += f", {decl}"
 
         for array in kernel.arrays():
             type_kw = Types.c_keyword(array.type())
             decl = f"{type_kw} *{array.name()}"
-            kernel_params += decl if len(kernel_params) <= 0 else f", {decl}"
+            kernel_params += f", {decl}"
 
         for prop in kernel.properties():
             type_kw = Types.c_keyword(prop.type())
             decl = f"{type_kw} *{prop.name()}"
-            kernel_params += decl if len(kernel_params) <= 0 else f", {decl}"
+            kernel_params += f", {decl}"
 
         for array_access in kernel.array_accesses():
             type_kw = Types.c_keyword(array_access.type())
             decl = f"{type_kw} a{array_access.id()}"
-            kernel_params += decl if len(kernel_params) <= 0 else f", {decl}"
+            kernel_params += f", {decl}"
 
         for bin_op in kernel.bin_ops():
             type_kw = Types.c_keyword(bin_op.type())
             decl = f"{type_kw} e{bin_op.id()}"
-            kernel_params += decl if len(kernel_params) <= 0 else f", {decl}"
+            kernel_params += f", {decl}"
 
         self.print(f"__global__ void {kernel.name}({kernel_params}) {{")
-        self.print(f"    const int {kernel.iterator.name()} = blockIdx.x * blockDim.x + threadIdx.x;")
+        self.print(f"    const int {kernel.iterator.name()} = blockIdx.x * blockDim.x + threadIdx.x + range_start;")
         self.print.add_indent(4)
         self.generate_statement(kernel.block)
         self.print.add_indent(-4)
@@ -365,31 +365,27 @@ class CGen:
                     self.print(f"d_{array_name} = ({tkw} *) pairs::device_alloc({size});")
 
         if isinstance(ast_node, KernelLaunch):
+            range_start = self.generate_expression(BinOp.inline(ast_node.min))
             kernel = ast_node.kernel
-            kernel_params = ""
+            kernel_params = f"{range_start}"
+
             for var in kernel.read_only_variables():
-                decl = var.name()
-                kernel_params += decl if len(kernel_params) <= 0 else f", {decl}"
+                kernel_params += f", {var.name()}"
 
             for var in kernel.write_variables():
-                decl = var.name()
-                kernel_params += decl if len(kernel_params) <= 0 else f", {decl}"
+                kernel_params += f", {var.name()}"
 
             for array in kernel.arrays():
-                decl = array.name()
-                kernel_params += decl if len(kernel_params) <= 0 else f", {decl}"
+                kernel_params += f", {array.name()}"
 
             for prop in kernel.properties():
-                decl = prop.name()
-                kernel_params += decl if len(kernel_params) <= 0 else f", {decl}"
+                kernel_params += f", {prop.name()}"
 
             for array_access in kernel.array_accesses():
-                decl = self.generate_expression(array_access)
-                kernel_params += decl if len(kernel_params) <= 0 else f", {decl}"
+                kernel_params += f", {self.generate_expression(array_access)}"
 
             for bin_op in kernel.bin_ops():
-                decl = self.generate_expression(bin_op)
-                kernel_params += decl if len(kernel_params) <= 0 else f", {decl}"
+                kernel_params += f", {self.generate_expression(bin_op)}"
 
             threads_per_block = self.generate_expression(ast_node.threads_per_block)
             nblocks = self.generate_expression(ast_node.nblocks)
-- 
GitLab