From 9a5eb9e1e90a3deeeea95c892b48fdce65ca742c Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Tue, 22 Oct 2024 12:24:33 +0200 Subject: [PATCH] fix usages of PsBufferAcc api --- src/pystencils/backend/ast/analysis.py | 12 ++++++++++-- src/pystencils/backend/kernelcreation/ast_factory.py | 2 ++ src/pystencils/backend/platforms/cuda.py | 2 +- src/pystencils/backend/platforms/generic_cpu.py | 2 +- src/pystencils/backend/platforms/sycl.py | 2 +- src/pystencils/backend/platforms/x86.py | 4 ++-- 6 files changed, 17 insertions(+), 7 deletions(-) diff --git a/src/pystencils/backend/ast/analysis.py b/src/pystencils/backend/ast/analysis.py index 0c3233af4..3c6d2ef55 100644 --- a/src/pystencils/backend/ast/analysis.py +++ b/src/pystencils/backend/ast/analysis.py @@ -28,6 +28,8 @@ from .expressions import ( PsSub, PsSymbolExpr, PsTernary, + PsSubscript, + PsMemAcc ) from ..memory import PsSymbol @@ -282,8 +284,14 @@ class OperationCounter: case PsSymbolExpr(_) | PsConstantExpr(_) | PsLiteralExpr(_): return OperationCounts() - case PsBufferAcc(_, index): - return self.visit_expr(index) + case PsBufferAcc(_, indices) | PsSubscript(_, indices): + return reduce( + operator.add, + (self.visit_expr(idx) for idx in indices) + ) + + case PsMemAcc(_, offset): + return self.visit_expr(offset) case PsCall(_, args): return OperationCounts(calls=1) + reduce( diff --git a/src/pystencils/backend/kernelcreation/ast_factory.py b/src/pystencils/backend/kernelcreation/ast_factory.py index d6084dbc7..2462e5e66 100644 --- a/src/pystencils/backend/kernelcreation/ast_factory.py +++ b/src/pystencils/backend/kernelcreation/ast_factory.py @@ -170,6 +170,8 @@ class AstFactory: raise ValueError( "Cannot parse a slice with `stop == None` if no normalization limit is given" ) + + assert stop is not None # for mypy return start, stop, step diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py index 75c9b7a8f..6100a371b 100644 --- a/src/pystencils/backend/platforms/cuda.py +++ b/src/pystencils/backend/platforms/cuda.py @@ -173,7 +173,7 @@ class CudaPlatform(GenericGpu): PsLookup( PsBufferAcc( ispace.index_list.base_pointer, - sparse_ctr, + (sparse_ctr,), ), coord.name, ), diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py index 4ea1d6d4c..95aaf50c4 100644 --- a/src/pystencils/backend/platforms/generic_cpu.py +++ b/src/pystencils/backend/platforms/generic_cpu.py @@ -130,7 +130,7 @@ class GenericCpu(Platform): PsLookup( PsBufferAcc( ispace.index_list.base_pointer, - PsExpression.make(ispace.sparse_counter), + (PsExpression.make(ispace.sparse_counter),), ), coord.name, ), diff --git a/src/pystencils/backend/platforms/sycl.py b/src/pystencils/backend/platforms/sycl.py index 7c3468932..b8684ce22 100644 --- a/src/pystencils/backend/platforms/sycl.py +++ b/src/pystencils/backend/platforms/sycl.py @@ -165,7 +165,7 @@ class SyclPlatform(GenericGpu): PsLookup( PsBufferAcc( ispace.index_list.base_pointer, - sparse_ctr, + (sparse_ctr,), ), coord.name, ), diff --git a/src/pystencils/backend/platforms/x86.py b/src/pystencils/backend/platforms/x86.py index 5f5ad4a05..33838df08 100644 --- a/src/pystencils/backend/platforms/x86.py +++ b/src/pystencils/backend/platforms/x86.py @@ -145,7 +145,7 @@ class X86VectorCpu(GenericVectorCpu): if acc.stride == 1: load_func = _x86_packed_load(self._vector_arch, acc.dtype, False) return load_func( - PsAddressOf(PsMemAcc(PsExpression.make(acc.base_ptr), acc.index)) + PsAddressOf(PsMemAcc(acc.pointer, acc.offset)) ) else: raise NotImplementedError("Gather loads not implemented yet.") @@ -154,7 +154,7 @@ class X86VectorCpu(GenericVectorCpu): if acc.stride == 1: store_func = _x86_packed_store(self._vector_arch, acc.dtype, False) return store_func( - PsAddressOf(PsMemAcc(PsExpression.make(acc.base_ptr), acc.index)), + PsAddressOf(PsMemAcc(acc.pointer, acc.offset)), arg, ) else: -- GitLab