Skip to content
Snippets Groups Projects
Commit 9a5eb9e1 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

fix usages of PsBufferAcc api

parent 51e38626
No related branches found
No related tags found
1 merge request!421Refactor Field Modelling
......@@ -28,6 +28,8 @@ from .expressions import (
PsSub,
PsSymbolExpr,
PsTernary,
PsSubscript,
PsMemAcc
)
from ..memory import PsSymbol
......@@ -282,8 +284,14 @@ class OperationCounter:
case PsSymbolExpr(_) | PsConstantExpr(_) | PsLiteralExpr(_):
return OperationCounts()
case PsBufferAcc(_, index):
return self.visit_expr(index)
case PsBufferAcc(_, indices) | PsSubscript(_, indices):
return reduce(
operator.add,
(self.visit_expr(idx) for idx in indices)
)
case PsMemAcc(_, offset):
return self.visit_expr(offset)
case PsCall(_, args):
return OperationCounts(calls=1) + reduce(
......
......@@ -170,6 +170,8 @@ class AstFactory:
raise ValueError(
"Cannot parse a slice with `stop == None` if no normalization limit is given"
)
assert stop is not None # for mypy
return start, stop, step
......
......@@ -173,7 +173,7 @@ class CudaPlatform(GenericGpu):
PsLookup(
PsBufferAcc(
ispace.index_list.base_pointer,
sparse_ctr,
(sparse_ctr,),
),
coord.name,
),
......
......@@ -130,7 +130,7 @@ class GenericCpu(Platform):
PsLookup(
PsBufferAcc(
ispace.index_list.base_pointer,
PsExpression.make(ispace.sparse_counter),
(PsExpression.make(ispace.sparse_counter),),
),
coord.name,
),
......
......@@ -165,7 +165,7 @@ class SyclPlatform(GenericGpu):
PsLookup(
PsBufferAcc(
ispace.index_list.base_pointer,
sparse_ctr,
(sparse_ctr,),
),
coord.name,
),
......
......@@ -145,7 +145,7 @@ class X86VectorCpu(GenericVectorCpu):
if acc.stride == 1:
load_func = _x86_packed_load(self._vector_arch, acc.dtype, False)
return load_func(
PsAddressOf(PsMemAcc(PsExpression.make(acc.base_ptr), acc.index))
PsAddressOf(PsMemAcc(acc.pointer, acc.offset))
)
else:
raise NotImplementedError("Gather loads not implemented yet.")
......@@ -154,7 +154,7 @@ class X86VectorCpu(GenericVectorCpu):
if acc.stride == 1:
store_func = _x86_packed_store(self._vector_arch, acc.dtype, False)
return store_func(
PsAddressOf(PsMemAcc(PsExpression.make(acc.base_ptr), acc.index)),
PsAddressOf(PsMemAcc(acc.pointer, acc.offset)),
arg,
)
else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment