diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index 3e79bf24a43e202304b16cd74bb4b16b7c4e7cd5..48e2f4a3a7f349c407472acfef60ec94ecea8f4c 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -207,7 +207,7 @@ class KernelCreationContext: def add_reduction_info( self, lhs_name: str, - lhs_dtype: PsType, + lhs_dtype: PsNumericType, reduction_op: ReductionOp, ): """Create ReductionInfo instance and add to its corresponding lookup table for a given symbol name.""" @@ -215,16 +215,15 @@ class KernelCreationContext: # make sure that lhs symbol never occurred before ReductionAssignment if self.find_symbol(lhs_name): raise KernelConstraintsError( - f"Left-hand side {lhs_name} of ReductionAssignment already exists in symbol table. " - f"Make sure that it is only used once in a kernel's ReductionAssignment." + f"Cannot create reduction with symbol {lhs_name}: " + "Another symbol with the same name already exist." ) # add symbol for lhs with pointer datatype for write-back mechanism pointer_symb = self.get_symbol(lhs_name, PsPointerType(lhs_dtype)) # create kernel-local copy of lhs symbol - local_symb = PsSymbol(f"{lhs_name}_local", lhs_dtype) - self.add_symbol(local_symb) + local_symb = self.get_new_symbol(f"{lhs_name}_local", lhs_dtype) # match for reduction operation and set neutral init_val init_val: PsExpression diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index b1bb4cd4a9084310f1eb066ab636eea2a331d91d..5987165678327eb0e49ca10f4ee66bb7f7fffd39 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -198,9 +198,8 @@ class FreezeExpressions: lhs_dtype = lhs_symbol.dtype lhs_name = lhs_symbol.name - assert isinstance( - lhs_dtype, PsNumericType - ), "Reduction assignments require type information of the lhs symbol." + if not isinstance(lhs_dtype, PsNumericType): + raise FreezeError("Reduction symbol must have a numeric data type.") # get reduction info from context reduction_info = self._ctx.add_reduction_info(