diff --git a/src/pystencils/backend/ast/analysis.py b/src/pystencils/backend/ast/analysis.py index 4943ac4095b9e0f1e5763bb73a70e6756b2950a2..0c3233af41cb867621252503b36e2a10e8c216ff 100644 --- a/src/pystencils/backend/ast/analysis.py +++ b/src/pystencils/backend/ast/analysis.py @@ -16,7 +16,7 @@ from .structural import ( ) from .expressions import ( PsAdd, - PsArrayAccess, + PsBufferAcc, PsCall, PsConstantExpr, PsDiv, @@ -282,7 +282,7 @@ class OperationCounter: case PsSymbolExpr(_) | PsConstantExpr(_) | PsLiteralExpr(_): return OperationCounts() - case PsArrayAccess(_, index): + case PsBufferAcc(_, index): return self.visit_expr(index) case PsCall(_, args): diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index 286958ee2f226eb2b6a92b2d7f5959e399418ddf..e658108e08148a78dd1fd8e834d78c3edcce7c9b 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -7,16 +7,13 @@ import operator import numpy as np from numpy.typing import NDArray -from ..memory import PsSymbol +from ..memory import PsSymbol, PsBuffer, BufferBasePtr from ..constants import PsConstant from ..literals import PsLiteral -from ..memory import PsBuffer from ..functions import PsFunction from ...types import ( PsType, - PsScalarType, PsVectorType, - PsTypeError, ) from .util import failing_cast from ..exceptions import PsInternalCompilerError @@ -193,6 +190,63 @@ class PsLiteralExpr(PsLeafMixIn, PsExpression): return f"PsLiteralExpr({repr(self._literal)})" +class PsBufferAcc(PsLvalue, PsExpression): + """Access into a `PsBuffer`.""" + + __match_args__ = ("base_pointer", "index") + + def __init__(self, base_ptr: PsSymbol, index: Sequence[PsExpression]): + super().__init__() + bptr_prop = cast(BufferBasePtr, base_ptr.get_properties(BufferBasePtr).pop()) + + if len(index) != bptr_prop.buffer.dim: + raise ValueError("Number of index expressions must equal buffer shape.") + + self._base_ptr = PsExpression.make(base_ptr) + self._index = list(index) + self._dtype = bptr_prop.buffer.element_type + + @property + def base_pointer(self) -> PsSymbolExpr: + return self._base_ptr + + @base_pointer.setter + def base_pointer(self, expr: PsSymbolExpr): + bptr_prop = cast(BufferBasePtr, expr.symbol.get_properties(BufferBasePtr).pop()) + if bptr_prop.buffer != self.buffer: + raise ValueError( + "Cannot replace a buffer access's base pointer with one belonging to a different buffer." + ) + + self._base_ptr = expr + + @property + def buffer(self) -> PsBuffer: + return cast( + BufferBasePtr, self._base_ptr.symbol.get_properties(BufferBasePtr).pop() + ).buffer + + @property + def index(self) -> list[PsExpression]: + return self._index + + def get_children(self) -> tuple[PsAstNode, ...]: + return (self._base_ptr,) + tuple(self._index) + + def set_child(self, idx: int, c: PsAstNode): + idx = range(len(self._index) + 1)[idx] + if idx == 0: + self.base_pointer = failing_cast(PsSymbolExpr, c) + else: + self._index[idx - 1] = failing_cast(PsExpression, c) + + def clone(self) -> PsBufferAcc: + return PsBufferAcc(self._base_ptr.symbol, [i.clone() for i in self._index]) + + def __repr__(self) -> str: + return f"PsArrayAccess({repr(self._base_ptr)}, {repr(self._index)})" + + class PsSubscript(PsLvalue, PsExpression): """N-dimensional subscript into an array.""" @@ -239,7 +293,7 @@ class PsSubscript(PsLvalue, PsExpression): def __repr__(self) -> str: idx = ", ".join(repr(i) for i in self._index) - return f"PsSubscript({self._arr}, ({idx}))" + return f"PsSubscript({repr(self._arr)}, {repr(idx)})" class PsMemAcc(PsLvalue, PsExpression): @@ -286,83 +340,28 @@ class PsMemAcc(PsLvalue, PsExpression): return f"PsMemAcc({repr(self._ptr)}, {repr(self._offset)})" -class PsArrayAccess(PsMemAcc): - __match_args__ = ("base_ptr", "index") - - def __init__(self, base_ptr: PsArrayBasePointer, index: PsExpression): - super().__init__(PsExpression.make(base_ptr), index) - self._base_ptr = base_ptr - self._dtype = base_ptr.array.element_type - - @property - def base_ptr(self) -> PsArrayBasePointer: - return self._base_ptr +class PsVectorMemAcc(PsMemAcc): + """Pointer-based vectorized memory access.""" - @property - def pointer(self) -> PsExpression: - return self._ptr - - @pointer.setter - def pointer(self, expr: PsExpression): - if not isinstance(expr, PsSymbolExpr) or not isinstance( - expr.symbol, PsArrayBasePointer - ): - raise ValueError( - "Base expression of PsArrayAccess must be an array base pointer" - ) - - self._base_ptr = expr.symbol - self._ptr = expr - - @property - def array(self) -> PsBuffer: - return self._base_ptr.array - - @property - def index(self) -> PsExpression: - return self._offset - - @index.setter - def index(self, expr: PsExpression): - self._offset = expr - - def clone(self) -> PsArrayAccess: - return PsArrayAccess(self._base_ptr, self._offset.clone()) - - def __repr__(self) -> str: - return f"PsArrayAccess({repr(self._base_ptr)}, {repr(self._offset)})" - - -class PsVectorArrayAccess(PsArrayAccess): __match_args__ = ("base_ptr", "base_index") def __init__( self, - base_ptr: PsArrayBasePointer, + base_ptr: PsExpression, base_index: PsExpression, vector_entries: int, stride: int = 1, alignment: int = 0, ): super().__init__(base_ptr, base_index) - element_type = base_ptr.array.element_type - - if not isinstance(element_type, PsScalarType): - raise PsTypeError( - "Cannot generate vector accesses to arrays with non-scalar elements" - ) - self._vector_type = PsVectorType( - element_type, vector_entries, const=element_type.const - ) + self._vector_entries = vector_entries self._stride = stride self._alignment = alignment - self._dtype = self._vector_type - @property def vector_entries(self) -> int: - return self._vector_type.vector_entries + return self._vector_entries @property def stride(self) -> int: @@ -375,9 +374,9 @@ class PsVectorArrayAccess(PsArrayAccess): def get_vector_type(self) -> PsVectorType: return cast(PsVectorType, self._dtype) - def clone(self) -> PsVectorArrayAccess: - return PsVectorArrayAccess( - self._base_ptr, + def clone(self) -> PsVectorMemAcc: + return PsVectorMemAcc( + self._ptr.clone(), self._offset.clone(), self.vector_entries, self._stride, @@ -385,12 +384,12 @@ class PsVectorArrayAccess(PsArrayAccess): ) def structurally_equal(self, other: PsAstNode) -> bool: - if not isinstance(other, PsVectorArrayAccess): + if not isinstance(other, PsVectorMemAcc): return False return ( super().structurally_equal(other) - and self._vector_type == other._vector_type + and self._vector_entries == other._vector_entries and self._stride == other._stride and self._alignment == other._alignment ) @@ -591,14 +590,15 @@ class PsNeg(PsUnOp, PsNumericOpTrait): class PsAddressOf(PsUnOp): """Take the address of a memory location. - + .. DANGER:: Taking the address of a memory location owned by a symbol or field array introduces an alias to that memory location. As pystencils assumes its symbols and fields to never be aliased, this can - subtly change the semantics of a kernel. + subtly change the semantics of a kernel. Use the address-of operator with utmost care. """ + pass @@ -813,18 +813,21 @@ class PsArrayInitList(PsExpression): __match_args__ = ("items",) - def __init__(self, items: Sequence[PsExpression | Sequence[PsExpression | Sequence[PsExpression]]]): + def __init__( + self, + items: Sequence[PsExpression | Sequence[PsExpression | Sequence[PsExpression]]], + ): super().__init__() self._items = np.array(items, dtype=np.object_) @property def items_grid(self) -> NDArray[np.object_]: return self._items - + @property def shape(self) -> tuple[int, ...]: return self._items.shape - + @property def items(self) -> tuple[PsExpression, ...]: return tuple(self._items.flat) # type: ignore @@ -836,7 +839,7 @@ class PsArrayInitList(PsExpression): self._items.flat[idx] = failing_cast(PsExpression, c) def clone(self) -> PsExpression: - return PsArrayInitList( + return PsArrayInitList( np.array([expr.clone() for expr in self.children]).reshape( # type: ignore self._items.shape ) diff --git a/src/pystencils/backend/ast/util.py b/src/pystencils/backend/ast/util.py index b7bde603f030b0814a118899dd7c45246be22418..288097a901e3f11f4a6f12c47799b25ec672151e 100644 --- a/src/pystencils/backend/ast/util.py +++ b/src/pystencils/backend/ast/util.py @@ -59,7 +59,7 @@ def determine_memory_object( PsLookup, PsSymbolExpr, PsMemAcc, - PsArrayAccess, + PsBufferAcc, ) while isinstance(expr, (PsSubscript, PsLookup)): @@ -74,9 +74,9 @@ def determine_memory_object( return symb, symb.get_dtype().const case PsMemAcc(ptr, _): return None, cast(PsDereferencableType, ptr.get_dtype()).base_type.const - case PsArrayAccess(ptr, _): + case PsBufferAcc(ptr, _): return ( - expr.array, + expr.buffer, cast(PsDereferencableType, ptr.get_dtype()).base_type.const, ) case _: diff --git a/src/pystencils/backend/emission.py b/src/pystencils/backend/emission.py index e0a1f4242f2dadde44b23501ecd6af6c1076b834..e976159cfb3c11db19765a26d068190340d5ce9c 100644 --- a/src/pystencils/backend/emission.py +++ b/src/pystencils/backend/emission.py @@ -39,7 +39,7 @@ from .ast.expressions import ( PsSub, PsSymbolExpr, PsLiteralExpr, - PsVectorArrayAccess, + PsVectorMemAcc, PsTernary, PsAnd, PsOr, @@ -268,7 +268,7 @@ class CAstPrinter: case PsLiteralExpr(lit): return lit.text - case PsVectorArrayAccess(): + case PsVectorMemAcc(): raise EmissionError("Cannot print vectorized array accesses") case PsMemAcc(base, offset): diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 11ac929bfadb9cb9e62ddd8bab29d10c8d956d8b..1cadf1fa4835d6b2e3b89174028df5b6ae69d36f 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -28,7 +28,7 @@ from ..ast.structural import ( PsSymbolExpr, ) from ..ast.expressions import ( - PsArrayAccess, + PsBufferAcc, PsArrayInitList, PsBitwiseAnd, PsBitwiseOr, @@ -43,7 +43,7 @@ from ..ast.expressions import ( PsLookup, PsRightShift, PsSubscript, - PsVectorArrayAccess, + PsVectorMemAcc, PsTernary, PsRel, PsEq, @@ -158,7 +158,7 @@ class FreezeExpressions: if isinstance(lhs, PsSymbolExpr): return PsDeclaration(lhs, rhs) - elif isinstance(lhs, (PsArrayAccess, PsLookup, PsVectorArrayAccess)): # todo + elif isinstance(lhs, (PsBufferAcc, PsLookup, PsVectorMemAcc)): # todo return PsAssignment(lhs, rhs) else: raise FreezeError( @@ -372,9 +372,9 @@ class FreezeExpressions: if struct_member_name is not None: # Produce a Lookup here, don't check yet if the member name is valid. That's the typifier's job. - return PsLookup(PsArrayAccess(ptr, index), struct_member_name) + return PsLookup(PsBufferAcc(ptr, index), struct_member_name) else: - return PsArrayAccess(ptr, index) + return PsBufferAcc(ptr, index) def map_ConditionalFieldAccess(self, acc: ConditionalFieldAccess): facc = self.visit_expr(acc.access) diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index 819d4a12b5446dd7a330d69cee6248ab204d4a64..222b2cac3ed1966d46847ffd00f48e1f7ad2875d 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -26,7 +26,7 @@ from ..ast.structural import ( PsEmptyLeafMixIn, ) from ..ast.expressions import ( - PsArrayAccess, + PsBufferAcc, PsArrayInitList, PsBinOp, PsIntOpTrait, @@ -413,7 +413,7 @@ class Typifier: case PsLiteralExpr(lit): tc.apply_dtype(lit.dtype, expr) - case PsArrayAccess(bptr, idx): + case PsBufferAcc(bptr, idx): tc.apply_dtype(bptr.array.element_type, expr) self._handle_idx(idx) diff --git a/src/pystencils/backend/memory.py b/src/pystencils/backend/memory.py index 6bc7039f52f63bb43ecd4cf2101e4f19c1950791..f45635150848c9ba97ee5c0b42df74d7e497af08 100644 --- a/src/pystencils/backend/memory.py +++ b/src/pystencils/backend/memory.py @@ -194,6 +194,10 @@ class PsBuffer: @property def strides(self) -> tuple[PsSymbol | PsConstant, ...]: return self._strides + + @property + def dim(self) -> int: + return len(self._shape) @property def index_type(self) -> PsIntegerType: diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py index c89d2278881bc48cd81df44ab49adefda4fc99f3..75c9b7a8ff405a37b855c328226562f3a3d979c8 100644 --- a/src/pystencils/backend/platforms/cuda.py +++ b/src/pystencils/backend/platforms/cuda.py @@ -17,7 +17,7 @@ from ..ast.expressions import ( PsCast, PsCall, PsLookup, - PsArrayAccess, + PsBufferAcc, ) from ..ast.expressions import PsLt, PsAnd from ...types import PsSignedIntegerType, PsIeeeFloatType @@ -171,7 +171,7 @@ class CudaPlatform(GenericGpu): PsDeclaration( PsExpression.make(ctr), PsLookup( - PsArrayAccess( + PsBufferAcc( ispace.index_list.base_pointer, sparse_ctr, ), diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py index a1505e672862264a50a8b035a83dc8dcdfb0769d..4ea1d6d4c673214a337ab37892d7870839f79747 100644 --- a/src/pystencils/backend/platforms/generic_cpu.py +++ b/src/pystencils/backend/platforms/generic_cpu.py @@ -21,8 +21,8 @@ from ..ast.structural import PsDeclaration, PsLoop, PsBlock from ..ast.expressions import ( PsSymbolExpr, PsExpression, - PsArrayAccess, - PsVectorArrayAccess, + PsBufferAcc, + PsVectorMemAcc, PsLookup, PsGe, PsLe, @@ -128,7 +128,7 @@ class GenericCpu(Platform): PsDeclaration( PsSymbolExpr(ctr), PsLookup( - PsArrayAccess( + PsBufferAcc( ispace.index_list.base_pointer, PsExpression.make(ispace.sparse_counter), ), @@ -173,11 +173,11 @@ class GenericVectorCpu(GenericCpu, ABC): or raise an `MaterializationError` if not supported.""" @abstractmethod - def vector_load(self, acc: PsVectorArrayAccess) -> PsExpression: + def vector_load(self, acc: PsVectorMemAcc) -> PsExpression: """Return an expression intrinsically performing a vector load, or raise an `MaterializationError` if not supported.""" @abstractmethod - def vector_store(self, acc: PsVectorArrayAccess, arg: PsExpression) -> PsExpression: + def vector_store(self, acc: PsVectorMemAcc, arg: PsExpression) -> PsExpression: """Return an expression intrinsically performing a vector store, or raise an `MaterializationError` if not supported.""" diff --git a/src/pystencils/backend/platforms/sycl.py b/src/pystencils/backend/platforms/sycl.py index fa42ed021b315e205f24e0735ea1bb113ffdbc26..7c34689322e132ed90d080e887767dd3b08afc21 100644 --- a/src/pystencils/backend/platforms/sycl.py +++ b/src/pystencils/backend/platforms/sycl.py @@ -16,7 +16,7 @@ from ..ast.expressions import ( PsLe, PsTernary, PsLookup, - PsArrayAccess + PsBufferAcc ) from ..extensions.cpp import CppMethodCall @@ -163,7 +163,7 @@ class SyclPlatform(GenericGpu): PsDeclaration( PsExpression.make(ctr), PsLookup( - PsArrayAccess( + PsBufferAcc( ispace.index_list.base_pointer, sparse_ctr, ), diff --git a/src/pystencils/backend/platforms/x86.py b/src/pystencils/backend/platforms/x86.py index 0c6f6883d3581725987684a402fcba41c03ed3d9..5f5ad4a05bad8b02d18e5032e0d52e6daad2a48c 100644 --- a/src/pystencils/backend/platforms/x86.py +++ b/src/pystencils/backend/platforms/x86.py @@ -5,7 +5,7 @@ from typing import Sequence from ..ast.expressions import ( PsExpression, - PsVectorArrayAccess, + PsVectorMemAcc, PsAddressOf, PsMemAcc, ) @@ -141,7 +141,7 @@ class X86VectorCpu(GenericVectorCpu): func = _x86_op_intrin(self._vector_arch, op, vtype) return func(*args) - def vector_load(self, acc: PsVectorArrayAccess) -> PsExpression: + def vector_load(self, acc: PsVectorMemAcc) -> PsExpression: if acc.stride == 1: load_func = _x86_packed_load(self._vector_arch, acc.dtype, False) return load_func( @@ -150,7 +150,7 @@ class X86VectorCpu(GenericVectorCpu): else: raise NotImplementedError("Gather loads not implemented yet.") - def vector_store(self, acc: PsVectorArrayAccess, arg: PsExpression) -> PsExpression: + def vector_store(self, acc: PsVectorMemAcc, arg: PsExpression) -> PsExpression: if acc.stride == 1: store_func = _x86_packed_store(self._vector_arch, acc.dtype, False) return store_func( diff --git a/src/pystencils/backend/transformations/erase_anonymous_structs.py b/src/pystencils/backend/transformations/erase_anonymous_structs.py index 08fd6bfa59bec56baa8a8206ee8bd3cd1d3882af..60787759231641ad3aefba8481b4506ada26dc58 100644 --- a/src/pystencils/backend/transformations/erase_anonymous_structs.py +++ b/src/pystencils/backend/transformations/erase_anonymous_structs.py @@ -5,7 +5,7 @@ from ..kernelcreation.context import KernelCreationContext from ..constants import PsConstant from ..ast.structural import PsAstNode from ..ast.expressions import ( - PsArrayAccess, + PsBufferAcc, PsLookup, PsExpression, PsMemAcc, @@ -58,10 +58,10 @@ class EraseAnonymousStructTypes: def handle_lookup(self, lookup: PsLookup) -> PsExpression: aggr = lookup.aggregate - if not isinstance(aggr, PsArrayAccess): + if not isinstance(aggr, PsBufferAcc): return lookup - arr = aggr.array + arr = aggr.buffer if ( not isinstance(arr.element_type, PsStructType) or not arr.element_type.anonymous @@ -96,7 +96,7 @@ class EraseAnonymousStructTypes: byte_index = base_index + PsExpression.make( PsConstant(member_offset, self._ctx.index_dtype) ) - type_erased_access = PsArrayAccess(type_erased_bp, byte_index) + type_erased_access = PsBufferAcc(type_erased_bp, byte_index) deref = PsMemAcc( PsCast(PsPointerType(member.dtype), PsAddressOf(type_erased_access)), diff --git a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py index c10a696ab7ae5435b30f3ecf5c707df112d8956a..f0e4cc9f19f1a046125bb3e8aab5302a9df2790c 100644 --- a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py +++ b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py @@ -9,7 +9,7 @@ from ..ast.expressions import ( PsConstantExpr, PsLiteralExpr, PsCall, - PsArrayAccess, + PsBufferAcc, PsSubscript, PsLookup, PsUnOp, @@ -53,7 +53,7 @@ class HoistContext: case PsSubscript() | PsLookup(): return determine_memory_object(expr)[1] and args_invariant(expr) - case PsArrayAccess(ptr, _): + case PsBufferAcc(ptr, _): # Regular pointer derefs are never invariant, since we cannot reason about aliasing ptr_type = cast(PsDereferencableType, ptr.get_dtype()) return ptr_type.base_type.const and args_invariant(expr) diff --git a/src/pystencils/backend/transformations/select_intrinsics.py b/src/pystencils/backend/transformations/select_intrinsics.py index 7972de0699f52b1230a5e9c9e00b43d5d122f61f..3fb484c154fbb4ab873deea3e9b1d83c2f4354e6 100644 --- a/src/pystencils/backend/transformations/select_intrinsics.py +++ b/src/pystencils/backend/transformations/select_intrinsics.py @@ -6,7 +6,7 @@ from ..ast.structural import PsAstNode, PsAssignment, PsStatement from ..ast.expressions import PsExpression from ...types import PsVectorType, deconstify from ..ast.expressions import ( - PsVectorArrayAccess, + PsVectorMemAcc, PsSymbolExpr, PsConstantExpr, PsBinOp, @@ -66,7 +66,7 @@ class MaterializeVectorIntrinsics: def visit(self, node: PsAstNode) -> PsAstNode: match node: - case PsAssignment(lhs, rhs) if isinstance(lhs, PsVectorArrayAccess): + case PsAssignment(lhs, rhs) if isinstance(lhs, PsVectorMemAcc): vc = VecTypeCtx() vc.set(lhs.get_vector_type()) store_arg = self.visit_expr(rhs, vc) @@ -94,7 +94,7 @@ class MaterializeVectorIntrinsics: else: return expr - case PsVectorArrayAccess(): + case PsVectorMemAcc(): vc.set(expr.get_vector_type()) return self._platform.vector_load(expr) diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py index 65bf57e787c8cecdc6f1d279a5ca09fc374a07e4..67e5d2319557a679403093f057e5eea5cc1c7122 100644 --- a/tests/nbackend/kernelcreation/test_freeze.py +++ b/tests/nbackend/kernelcreation/test_freeze.py @@ -17,7 +17,7 @@ from pystencils.backend.ast.structural import ( PsDeclaration, ) from pystencils.backend.ast.expressions import ( - PsArrayAccess, + PsBufferAcc, PsBitwiseAnd, PsBitwiseOr, PsBitwiseXor, @@ -113,12 +113,12 @@ def test_freeze_fields(): zero = PsExpression.make(PsConstant(0)) - lhs = PsArrayAccess( + lhs = PsBufferAcc( f_arr.base_pointer, (PsExpression.make(counter) + zero) * PsExpression.make(f_arr.strides[0]) + zero * one, ) - rhs = PsArrayAccess( + rhs = PsBufferAcc( g_arr.base_pointer, (PsExpression.make(counter) + zero) * PsExpression.make(g_arr.strides[0]) + zero * one, diff --git a/tests/nbackend/test_ast.py b/tests/nbackend/test_ast.py index 88fdd3c8d1e85bc8365bed7268882632aa904891..cf7fd3f31b13f0fbbac3b350f769e6993ab44d9d 100644 --- a/tests/nbackend/test_ast.py +++ b/tests/nbackend/test_ast.py @@ -1,4 +1,7 @@ -from pystencils.backend.memory import PsSymbol +import pytest + +from pystencils.backend.kernelcreation import KernelCreationContext, AstFactory +from pystencils.backend.memory import PsSymbol, BufferBasePtr from pystencils.backend.constants import PsConstant from pystencils.backend.ast.expressions import ( PsExpression, @@ -6,6 +9,8 @@ from pystencils.backend.ast.expressions import ( PsMemAcc, PsArrayInitList, PsSubscript, + PsBufferAcc, + PsSymbolExpr, ) from pystencils.backend.ast.structural import ( PsStatement, @@ -72,3 +77,45 @@ def test_cloning(): ]: ast_clone = ast.clone() check(ast, ast_clone) + + +def test_buffer_acc(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + + from pystencils import fields + + f, g = fields("f, g(3): [2D]") + a, b = [ctx.get_symbol(n, ctx.index_dtype) for n in "ab"] + + f_buf = ctx.get_buffer(f) + + f_acc = PsBufferAcc(f_buf.base_pointer, [PsExpression.make(i) for i in (a, b)] + [factory.parse_index(0)]) + assert f_acc.buffer == f_buf + assert f_acc.base_pointer.structurally_equal(PsSymbolExpr(f_buf.base_pointer)) + + f_acc_clone = f_acc.clone() + assert f_acc_clone is not f_acc + + assert f_acc_clone.buffer == f_buf + assert f_acc_clone.base_pointer.structurally_equal(PsSymbolExpr(f_buf.base_pointer)) + assert len(f_acc_clone.index) == 3 + assert f_acc_clone.index[0].structurally_equal(PsSymbolExpr(ctx.get_symbol("a"))) + assert f_acc_clone.index[1].structurally_equal(PsSymbolExpr(ctx.get_symbol("b"))) + + g_buf = ctx.get_buffer(g) + + g_acc = PsBufferAcc(g_buf.base_pointer, [PsExpression.make(i) for i in (a, b)] + [factory.parse_index(2)]) + assert g_acc.buffer == g_buf + assert g_acc.base_pointer.structurally_equal(PsSymbolExpr(g_buf.base_pointer)) + + second_bptr = PsExpression.make(ctx.get_symbol("data_g_interior", g_buf.base_pointer.dtype)) + second_bptr.symbol.add_property(BufferBasePtr(g_buf)) + g_acc.base_pointer = second_bptr + + assert g_acc.base_pointer == second_bptr + assert g_acc.buffer == g_buf + + # cannot change base pointer to different buffer + with pytest.raises(ValueError): + g_acc.base_pointer = PsExpression.make(f_buf.base_pointer) diff --git a/tests/nbackend/test_cpujit.py b/tests/nbackend/test_cpujit.py index dc321848645e65e62c605b6110951534e678492b..648112ef95bf5d6c3181f5c3c2527dd870220f0e 100644 --- a/tests/nbackend/test_cpujit.py +++ b/tests/nbackend/test_cpujit.py @@ -6,7 +6,7 @@ from pystencils import Target from pystencils.backend.memory import PsSymbol, PsBuffer from pystencils.backend.constants import PsConstant -from pystencils.backend.ast.expressions import PsArrayAccess, PsExpression +from pystencils.backend.ast.expressions import PsBufferAcc, PsExpression from pystencils.backend.ast.structural import PsAssignment, PsBlock, PsLoop from pystencils.backend.kernelfunction import KernelFunction @@ -33,8 +33,8 @@ def test_pairwise_addition(): two = PsExpression.make(PsConstant(2, idx_type)) update = PsAssignment( - PsArrayAccess(v_data, loop_ctr), - PsArrayAccess(u_data, two * loop_ctr) + PsArrayAccess(u_data, two * loop_ctr + one) + PsBufferAcc(v_data, loop_ctr), + PsBufferAcc(u_data, two * loop_ctr) + PsBufferAcc(u_data, two * loop_ctr + one) ) loop = PsLoop(