From c5736142b4891eee2f2a2ad854ef88e706f6b930 Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti <rafaelravedutti@gmail.com>
Date: Wed, 19 Oct 2022 20:41:11 +0200
Subject: [PATCH] Small improvements in code

Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com>
---
 src/pairs/code_gen/cgen.py           |  6 +-----
 src/pairs/ir/loops.py                |  1 +
 src/pairs/transformations/devices.py | 11 +++++++----
 3 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py
index 4efca1a..984cdc9 100644
--- a/src/pairs/code_gen/cgen.py
+++ b/src/pairs/code_gen/cgen.py
@@ -582,11 +582,7 @@ class CGen:
                 index_g = self.generate_expression(index_expr)
                 return f"{prop_name}[{index_g}]"
 
-            acc_ref = f"p{ast_node.id()}"
-            if ast_node.is_vector_kind():
-                acc_ref += f"_{index}"
-
-            return acc_ref
+            return f"p{ast_node.id()}" + (f"_{index}" if ast_node.is_vector_kind() else "")
 
         if isinstance(ast_node, PropertyList):
             tid = CGen.temp_id
diff --git a/src/pairs/ir/loops.py b/src/pairs/ir/loops.py
index b2ea4af..553b5ad 100644
--- a/src/pairs/ir/loops.py
+++ b/src/pairs/ir/loops.py
@@ -44,6 +44,7 @@ class For(ASTNode):
         self.min = Lit.cvt(sim, range_min)
         self.max = Lit.cvt(sim, range_max)
         self.block = Block(sim, []) if block is None else block
+        self.kernel = None
 
     def __str__(self):
         return f"For<{self.iterator}, {self.min} ... {self.max}>"
diff --git a/src/pairs/transformations/devices.py b/src/pairs/transformations/devices.py
index 9ced487..a5b18f6 100644
--- a/src/pairs/transformations/devices.py
+++ b/src/pairs/transformations/devices.py
@@ -72,10 +72,13 @@ class AddDeviceKernels(Mutator):
             for s in ast_node._block.stmts:
                 if s is not None:
                     if isinstance(s, For) and (not isinstance(s.min, Lit) or not isinstance(s.max, Lit)):
-                        kernel_block = Filter(ast_node.sim, BinOp.inline(s.iterator < s.max), s.block)
-                        kernel = Kernel(ast_node.sim, f"{ast_node.name}_kernel{kernel_id}", kernel_block, s.iterator)
-                        new_stmts.append(KernelLaunch(ast_node.sim, kernel, s.iterator, s.min, s.max))
-                        kernel_id += 1
+                        if s.kernel is None:
+                            s.kernel = Kernel(ast_node.sim, f"{ast_node.name}_kernel{kernel_id}",
+                                              Filter(ast_node.sim, BinOp.inline(s.iterator < s.max), s.block), s.iterator)
+                            kernel_id += 1
+
+                        new_stmts.append(KernelLaunch(ast_node.sim, s.kernel, s.iterator, s.min, s.max))
+
                     else:
                         new_stmts.append(s)
 
-- 
GitLab