diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 045aca1d16021f6f164943873aafd9d144a75c42..e0ed0f1f7a11ed7d63d900a19f5e264980578cfb 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -189,22 +189,12 @@ class FreezeExpressions: def map_ReductionAssignment(self, expr: ReductionAssignment): assert isinstance(expr.lhs, TypedSymbol) - # make sure that either: - # 1) lhs symbol never occurred - # 2) that it is at least known as lhs of an existing reduction operation + # make sure that lhs symbol never occurred before ReductionAssignment if self._ctx.find_symbol(expr.lhs.name): - # make sure that reduction operations are not mixed within a kernel - if info := self._ctx.find_reduction_info(expr.lhs.name): - if info.op is not expr.reduction_op: - raise FreezeError( - f"Different reduction operation {info.op} already exists " - f"for {expr.lhs} with target reduction op {expr.reduction_op}." - ) - else: - raise FreezeError( - f"Left-hand side {expr.lhs} of ReductionAssignment already exists in symbol table." - f"Make sure that it is exclusively used within the kernel to conduct ReductionAssignment's." - ) + raise FreezeError( + f"Left-hand side {expr.lhs} of ReductionAssignment already exists in symbol table. " + f"Make sure that it is only used once in a kernel's ReductionAssignment." + ) lhs = self.visit(expr.lhs) rhs = self.visit(expr.rhs) @@ -340,6 +330,16 @@ class FreezeExpressions: def map_TypedSymbol(self, expr: TypedSymbol): dtype = self._ctx.resolve_dynamic_type(expr.dtype) + + # check if symbol is referenced after freezing a ReductionAssignment + if self._ctx.find_reduction_info(expr.name): + # check if types do not align since a ReductionAssignment modifies + # the symbol's type to PsPointerType in the context's symbol table + if (symbol := self._ctx.find_symbol(expr.name)) and symbol.dtype != dtype: + raise FreezeError( + f"Illegal access to reduction symbol {symbol.name} after freezing a kernel's ReductionAssignment. " + ) + symb = self._ctx.get_symbol(expr.name, dtype) return PsSymbolExpr(symb)