From ce816539159408b26b09b4bf6e1df0cbff437829 Mon Sep 17 00:00:00 2001
From: zy69guqi <richard.angersbach@fau.de>
Date: Wed, 19 Feb 2025 16:33:25 +0100
Subject: [PATCH] Encapsulate fetching of kernel conditions for iteration
 spaces in separate function

---
 src/pystencils/backend/platforms/cuda.py | 55 ++++++++++++++++++------
 1 file changed, 42 insertions(+), 13 deletions(-)

diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py
index fb613347a..eff88df7e 100644
--- a/src/pystencils/backend/platforms/cuda.py
+++ b/src/pystencils/backend/platforms/cuda.py
@@ -315,13 +315,47 @@ class CudaPlatform(GenericGpu):
 
     #   Internals
 
+    # TODO: SYCL platform has very similar code for fetching conditionals -> move to GenericGpu?
+
+    def _get_condition_for_translation(
+        self, ispace: IterationSpace
+    ) -> PsExpression | None:
+        """Build the guard that keeps a thread inside the iteration space.
+
+        Returns ``None`` if range checks are omitted, or if a dense iteration
+        space has no dimensions; otherwise the conjunction of ``ctr < stop``
+        over all dimensions (dense) or ``sparse_ctr < index_list.shape[0]``.
+        """
+        if self._omit_range_check:
+            return None
+
+        match ispace:
+            case FullIterationSpace():
+                conds = [
+                    PsLt(PsExpression.make(dim.counter), dim.stop)
+                    for dim in ispace.dimensions_in_loop_order()
+                ]
+                if not conds:
+                    return None
+                # Combine only after all per-dimension checks are collected
+                condition: PsExpression = conds[0]
+                for cond in conds[1:]:
+                    condition = PsAnd(condition, cond)
+                return condition
+            case SparseIterationSpace():
+                sparse_ctr_expr = PsExpression.make(ispace.sparse_counter)
+                stop = PsExpression.make(ispace.index_list.shape[0])
+                return PsLt(sparse_ctr_expr.clone(), stop)
+            case _:
+                assert False, "Unknown iteration space"
+
     def _prepend_dense_translation(
         self, body: PsBlock, ispace: FullIterationSpace
     ) -> PsBlock:
         ctr_mapping = self._thread_mapping(ispace)
 
         indexing_decls = []
-        conds = []
+        cond = self._get_condition_for_translation(ispace)
 
         dimensions = ispace.dimensions_in_loop_order()
 
@@ -335,14 +369,9 @@ class CudaPlatform(GenericGpu):
             indexing_decls.append(
                 self._typify(PsDeclaration(ctr_expr, ctr_mapping[dim.counter]))
             )
-            if not self._omit_range_check:
-                conds.append(PsLt(ctr_expr, dim.stop))
-
-        if conds:
-            condition: PsExpression = conds[0]
-            for cond in conds[1:]:
-                condition = PsAnd(condition, cond)
-            ast = PsBlock(indexing_decls + [PsConditional(condition, body)])
+
+        if cond:
+            ast = PsBlock(indexing_decls + [PsConditional(cond, body)])
         else:
             body.statements = indexing_decls + body.statements
             ast = body
@@ -355,6 +384,8 @@ class CudaPlatform(GenericGpu):
         factory = AstFactory(self._ctx)
         ispace.sparse_counter.dtype = constify(ispace.sparse_counter.get_dtype())
 
+        cond = self._get_condition_for_translation(ispace)
+
         sparse_ctr_expr = PsExpression.make(ispace.sparse_counter)
         ctr_mapping = self._thread_mapping(ispace)
 
@@ -377,10 +408,8 @@ class CudaPlatform(GenericGpu):
         ]
         body.statements = mappings + body.statements
 
-        if not self._omit_range_check:
-            stop = PsExpression.make(ispace.index_list.shape[0])
-            condition = PsLt(sparse_ctr_expr.clone(), stop)
-            ast = PsBlock([sparse_idx_decl, PsConditional(condition, body)])
+        if cond:
+            ast = PsBlock([sparse_idx_decl, PsConditional(cond, body)])
         else:
             body.statements = [sparse_idx_decl] + body.statements
             ast = body
-- 
GitLab