diff --git a/src/pystencils/backend/ast/analysis.py b/src/pystencils/backend/ast/analysis.py index 0c3233af41cb867621252503b36e2a10e8c216ff..3c6d2ef557e44a882edf4e104df4bd4e2a8830fd 100644 --- a/src/pystencils/backend/ast/analysis.py +++ b/src/pystencils/backend/ast/analysis.py @@ -28,6 +28,8 @@ from .expressions import ( PsSub, PsSymbolExpr, PsTernary, + PsSubscript, + PsMemAcc ) from ..memory import PsSymbol @@ -282,8 +284,14 @@ class OperationCounter: case PsSymbolExpr(_) | PsConstantExpr(_) | PsLiteralExpr(_): return OperationCounts() - case PsBufferAcc(_, index): - return self.visit_expr(index) + case PsBufferAcc(_, indices) | PsSubscript(_, indices): + return reduce( + operator.add, + (self.visit_expr(idx) for idx in indices) + ) + + case PsMemAcc(_, offset): + return self.visit_expr(offset) case PsCall(_, args): return OperationCounts(calls=1) + reduce( diff --git a/src/pystencils/backend/kernelcreation/ast_factory.py b/src/pystencils/backend/kernelcreation/ast_factory.py index d6084dbc7bca97aa8ff6bf3f1f9766e4e70c0561..2462e5e66ea1a55cd638df07f645b213dd37d68f 100644 --- a/src/pystencils/backend/kernelcreation/ast_factory.py +++ b/src/pystencils/backend/kernelcreation/ast_factory.py @@ -170,6 +170,8 @@ class AstFactory: raise ValueError( "Cannot parse a slice with `stop == None` if no normalization limit is given" ) + + assert stop is not None # for mypy return start, stop, step diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py index 75c9b7a8ff405a37b855c328226562f3a3d979c8..6100a371b34fa1986b377f01854cd63f4a52888d 100644 --- a/src/pystencils/backend/platforms/cuda.py +++ b/src/pystencils/backend/platforms/cuda.py @@ -173,7 +173,7 @@ class CudaPlatform(GenericGpu): PsLookup( PsBufferAcc( ispace.index_list.base_pointer, - sparse_ctr, + (sparse_ctr,), ), coord.name, ), diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py index 4ea1d6d4c673214a337ab37892d7870839f79747..95aaf50c4b06472dcf4ceb0edfbeea623e0b2e04 100644 --- a/src/pystencils/backend/platforms/generic_cpu.py +++ b/src/pystencils/backend/platforms/generic_cpu.py @@ -130,7 +130,7 @@ class GenericCpu(Platform): PsLookup( PsBufferAcc( ispace.index_list.base_pointer, - PsExpression.make(ispace.sparse_counter), + (PsExpression.make(ispace.sparse_counter),), ), coord.name, ), diff --git a/src/pystencils/backend/platforms/sycl.py b/src/pystencils/backend/platforms/sycl.py index 7c34689322e132ed90d080e887767dd3b08afc21..b8684ce22b43cf24337e5dbeb00f38b2763ed77d 100644 --- a/src/pystencils/backend/platforms/sycl.py +++ b/src/pystencils/backend/platforms/sycl.py @@ -165,7 +165,7 @@ class SyclPlatform(GenericGpu): PsLookup( PsBufferAcc( ispace.index_list.base_pointer, - sparse_ctr, + (sparse_ctr,), ), coord.name, ), diff --git a/src/pystencils/backend/platforms/x86.py b/src/pystencils/backend/platforms/x86.py index 5f5ad4a05bad8b02d18e5032e0d52e6daad2a48c..33838df08bcdb13094a96387fd3db565e4ba5932 100644 --- a/src/pystencils/backend/platforms/x86.py +++ b/src/pystencils/backend/platforms/x86.py @@ -145,7 +145,7 @@ class X86VectorCpu(GenericVectorCpu): if acc.stride == 1: load_func = _x86_packed_load(self._vector_arch, acc.dtype, False) return load_func( - PsAddressOf(PsMemAcc(PsExpression.make(acc.base_ptr), acc.index)) + PsAddressOf(PsMemAcc(acc.pointer, acc.offset)) ) else: raise NotImplementedError("Gather loads not implemented yet.") @@ -154,7 +154,7 @@ class X86VectorCpu(GenericVectorCpu): if acc.stride == 1: store_func = _x86_packed_store(self._vector_arch, acc.dtype, False) return store_func( - PsAddressOf(PsMemAcc(PsExpression.make(acc.base_ptr), acc.index)), + PsAddressOf(PsMemAcc(acc.pointer, acc.offset)), arg, ) else: