diff --git a/src/pystencils/backend/kernelcreation/analysis.py b/src/pystencils/backend/kernelcreation/analysis.py index 1365e1ef327afd756e17eac282efb7c0f8e65052..e5f8b921e78d703d52f507e5d577a9385adb2393 100644 --- a/src/pystencils/backend/kernelcreation/analysis.py +++ b/src/pystencils/backend/kernelcreation/analysis.py @@ -13,6 +13,8 @@ from ...simp import AssignmentCollection from sympy.codegen.ast import AssignmentBase from ..exceptions import PsInternalCompilerError, KernelConstraintsError +from ...sympyextensions.reduction import ReductionAssignment +from ...sympyextensions.typed_sympy import TypedSymbol class KernelAnalysis: @@ -54,6 +56,8 @@ class KernelAnalysis: self._check_access_independence = check_access_independence self._check_double_writes = check_double_writes + self._reduction_symbols: set[TypedSymbol] = set() + # Map pairs of fields and indices to offsets self._field_writes: dict[KernelAnalysis.FieldAndIndex, set[Any]] = defaultdict( set @@ -88,6 +92,14 @@ class KernelAnalysis: for asm in asms: self._visit(asm) + case ReductionAssignment(): + assert isinstance(obj.lhs, TypedSymbol) + + self._reduction_symbols.add(obj.lhs) + + self._handle_rhs(obj.rhs) + self._handle_lhs(obj.lhs) + case AssignmentBase(): self._handle_rhs(obj.rhs) self._handle_lhs(obj.lhs) @@ -152,6 +164,11 @@ class KernelAnalysis: f"{field} is read at {offsets} and written at {write_offset}" ) case sp.Symbol(): + if expr in self._reduction_symbols: + raise KernelConstraintsError( + f"Illegal access to reduction symbol {expr.name} outside of ReductionAssignment. " + ) + self._scopes.access_symbol(expr) for arg in expr.args: diff --git a/tests/nbackend/kernelcreation/test_analysis.py b/tests/nbackend/kernelcreation/test_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..d68c0a5f3c88fbc53c157ea8051403c9bb3c56b3 --- /dev/null +++ b/tests/nbackend/kernelcreation/test_analysis.py @@ -0,0 +1,38 @@ +import pytest + +from pystencils import fields, TypedSymbol, AddReductionAssignment, Assignment, KernelConstraintsError +from pystencils.backend.kernelcreation import KernelCreationContext, KernelAnalysis +from pystencils.sympyextensions import mem_acc +from pystencils.types.quick import Ptr, Fp + + +def test_invalid_reduction_symbol_reassign(): + dtype = Fp(64) + ctx = KernelCreationContext(default_dtype=dtype) + analysis = KernelAnalysis(ctx) + + x = fields(f"x: [1d]") + w = TypedSymbol("w", dtype) + + # illegal reassign to already locally defined symbol (here: reduction symbol) + with pytest.raises(KernelConstraintsError): + analysis([ + AddReductionAssignment(w, 3 * x.center()), + Assignment(w, 0) + ]) + +def test_invalid_reduction_symbol_reference(): + dtype = Fp(64) + ctx = KernelCreationContext(default_dtype=dtype) + analysis = KernelAnalysis(ctx) + + x = fields(f"x: [1d]") + v = TypedSymbol("v", dtype) + w = TypedSymbol("w", dtype) + + # do not allow reduction symbol to be referenced on rhs of other assignments + with pytest.raises(KernelConstraintsError): + analysis([ + AddReductionAssignment(w, 3 * x.center()), + Assignment(v, w) + ]) \ No newline at end of file