From 9a5eb9e1e90a3deeeea95c892b48fdce65ca742c Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Tue, 22 Oct 2024 12:24:33 +0200
Subject: [PATCH] fix usages of PsBufferAcc api

---
 src/pystencils/backend/ast/analysis.py               | 12 ++++++++++--
 src/pystencils/backend/kernelcreation/ast_factory.py |  2 ++
 src/pystencils/backend/platforms/cuda.py             |  2 +-
 src/pystencils/backend/platforms/generic_cpu.py      |  2 +-
 src/pystencils/backend/platforms/sycl.py             |  2 +-
 src/pystencils/backend/platforms/x86.py              |  4 ++--
 6 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/src/pystencils/backend/ast/analysis.py b/src/pystencils/backend/ast/analysis.py
index 0c3233af4..3c6d2ef55 100644
--- a/src/pystencils/backend/ast/analysis.py
+++ b/src/pystencils/backend/ast/analysis.py
@@ -28,6 +28,8 @@ from .expressions import (
     PsSub,
     PsSymbolExpr,
     PsTernary,
+    PsSubscript,
+    PsMemAcc
 )
 
 from ..memory import PsSymbol
@@ -282,8 +284,14 @@ class OperationCounter:
             case PsSymbolExpr(_) | PsConstantExpr(_) | PsLiteralExpr(_):
                 return OperationCounts()
 
-            case PsBufferAcc(_, index):
-                return self.visit_expr(index)
+            case PsBufferAcc(_, indices) | PsSubscript(_, indices):
+                return reduce(
+                    operator.add,
+                    (self.visit_expr(idx) for idx in indices)
+                )
+            
+            case PsMemAcc(_, offset):
+                return self.visit_expr(offset)
 
             case PsCall(_, args):
                 return OperationCounts(calls=1) + reduce(
diff --git a/src/pystencils/backend/kernelcreation/ast_factory.py b/src/pystencils/backend/kernelcreation/ast_factory.py
index d6084dbc7..2462e5e66 100644
--- a/src/pystencils/backend/kernelcreation/ast_factory.py
+++ b/src/pystencils/backend/kernelcreation/ast_factory.py
@@ -170,6 +170,8 @@ class AstFactory:
             raise ValueError(
                 "Cannot parse a slice with `stop == None` if no normalization limit is given"
             )
+        
+        assert stop is not None  # for mypy
 
         return start, stop, step
 
diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py
index 75c9b7a8f..6100a371b 100644
--- a/src/pystencils/backend/platforms/cuda.py
+++ b/src/pystencils/backend/platforms/cuda.py
@@ -173,7 +173,7 @@ class CudaPlatform(GenericGpu):
                 PsLookup(
                     PsBufferAcc(
                         ispace.index_list.base_pointer,
-                        sparse_ctr,
+                        (sparse_ctr,),
                     ),
                     coord.name,
                 ),
diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py
index 4ea1d6d4c..95aaf50c4 100644
--- a/src/pystencils/backend/platforms/generic_cpu.py
+++ b/src/pystencils/backend/platforms/generic_cpu.py
@@ -130,7 +130,7 @@ class GenericCpu(Platform):
                 PsLookup(
                     PsBufferAcc(
                         ispace.index_list.base_pointer,
-                        PsExpression.make(ispace.sparse_counter),
+                        (PsExpression.make(ispace.sparse_counter),),
                     ),
                     coord.name,
                 ),
diff --git a/src/pystencils/backend/platforms/sycl.py b/src/pystencils/backend/platforms/sycl.py
index 7c3468932..b8684ce22 100644
--- a/src/pystencils/backend/platforms/sycl.py
+++ b/src/pystencils/backend/platforms/sycl.py
@@ -165,7 +165,7 @@ class SyclPlatform(GenericGpu):
                 PsLookup(
                     PsBufferAcc(
                         ispace.index_list.base_pointer,
-                        sparse_ctr,
+                        (sparse_ctr,),
                     ),
                     coord.name,
                 ),
diff --git a/src/pystencils/backend/platforms/x86.py b/src/pystencils/backend/platforms/x86.py
index 5f5ad4a05..33838df08 100644
--- a/src/pystencils/backend/platforms/x86.py
+++ b/src/pystencils/backend/platforms/x86.py
@@ -145,7 +145,7 @@ class X86VectorCpu(GenericVectorCpu):
         if acc.stride == 1:
             load_func = _x86_packed_load(self._vector_arch, acc.dtype, False)
             return load_func(
-                PsAddressOf(PsMemAcc(PsExpression.make(acc.base_ptr), acc.index))
+                PsAddressOf(PsMemAcc(acc.pointer, acc.offset))
             )
         else:
             raise NotImplementedError("Gather loads not implemented yet.")
@@ -154,7 +154,7 @@ class X86VectorCpu(GenericVectorCpu):
         if acc.stride == 1:
             store_func = _x86_packed_store(self._vector_arch, acc.dtype, False)
             return store_func(
-                PsAddressOf(PsMemAcc(PsExpression.make(acc.base_ptr), acc.index)),
+                PsAddressOf(PsMemAcc(acc.pointer, acc.offset)),
                 arg,
             )
         else:
-- 
GitLab