diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index 358b5ff6cdeb6cca1cfc232c5129ae239bdb3be9..58a4bd7d11c8150a5df120e8b0360fc25aec3f9c 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -54,7 +54,8 @@ class ReductionInfo: op: ReductionOp init_val: PsExpression - ptr_symbol: PsSymbol + local_symbol: PsSymbol + writeback_ptr_symbol: PsSymbol class KernelCreationContext: @@ -88,7 +89,7 @@ class KernelCreationContext: self._symbol_ctr_pattern = re.compile(r"__[0-9]+$") self._symbol_dup_table: defaultdict[str, int] = defaultdict(lambda: 0) - self._symbols_reduction_info: dict[PsSymbol, ReductionInfo] = dict() + self._reduction_data: dict[str, ReductionInfo] = dict() self._fields_and_arrays: dict[str, FieldArrayPair] = dict() self._fields_collection = FieldsInKernel() @@ -193,19 +194,39 @@ class KernelCreationContext: self._symbols[old.name] = new - def add_symbol_reduction_info( - self, local_symb: PsSymbol, reduction_info: ReductionInfo + def add_reduction_info( + self, + lhs_name: str, + lhs_dtype: PsType, + reduction_op: ReductionOp, + init_value: PsExpression, ): - """Adds entry for a symbol and its reduction info to its corresponding lookup table. + """Create ReductionInfo instance and add to its corresponding lookup table for a given symbol name.""" - The symbol ``symbol`` shall not exist in the symbol table already. - """ - if local_symb in self._symbols_reduction_info: - raise PsInternalCompilerError( - f"add_symbol_reduction_info: {local_symb.name} already exist in the symbol table" - ) + # replace datatype of lhs symbol with pointer datatype for write-back mechanism + symb = self.get_symbol(lhs_name, lhs_dtype) + pointer_symb = PsSymbol(lhs_name, PsPointerType(lhs_dtype)) + self.replace_symbol(symb, pointer_symb) + + # create kernel-local copy of lhs symbol + local_symb = PsSymbol(f"{lhs_name}_local", lhs_dtype) + self.add_symbol(local_symb) - self._symbols_reduction_info[local_symb] = reduction_info + # create reduction info and add to set + reduction_info = ReductionInfo( + reduction_op, init_value, local_symb, pointer_symb + ) + self._reduction_data[lhs_name] = reduction_info + + return reduction_info + + def find_reduction_info(self, name: str) -> ReductionInfo | None: + """Find a ReductionInfo with the given name in the lookup table, if it exists. + + Returns: + The ReductionInfo with the given name, or `None` if it does not exist. + """ + return self._reduction_data.get(name, None) def duplicate_symbol( self, symb: PsSymbol, new_dtype: PsType | None = None @@ -243,9 +264,9 @@ class KernelCreationContext: return self._symbols.values() @property - def symbols_reduction_info(self) -> dict[PsSymbol, ReductionInfo]: - """Return a dictionary holding kernel-local reduction symbols and their reduction properties.""" - return self._symbols_reduction_info + def reduction_data(self) -> dict[str, ReductionInfo]: + """Return a dictionary holding kernel-local reduction information for given symbol names.""" + return self._reduction_data # Fields and Arrays diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index c5ff43fb96d74bcc21f3c5a1b212d8635b5fa2ce..2f00df4e86575fe5e5d916f7c29cfb3cdc600a94 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -6,7 +6,6 @@ import sympy as sp import sympy.logic.boolalg from sympy.codegen.ast import AssignmentBase, AugmentedAssignment -from ..memory import PsSymbol from ...assignment import Assignment from ...simp import AssignmentCollection from ...sympyextensions import ( @@ -19,7 +18,7 @@ from ...sympyextensions.pointers import AddressOf, mem_acc from ...sympyextensions.reduction import ReductionAssignment, ReductionOp from ...field import Field, FieldType -from .context import KernelCreationContext, ReductionInfo +from .context import KernelCreationContext from ..ast.structural import ( PsAstNode, @@ -62,7 +61,7 @@ from ..ast.expressions import ( from ..ast.vector import PsVecMemAcc from ..constants import PsConstant -from ...types import PsNumericType, PsStructType, PsType, PsPointerType +from ...types import PsNumericType, PsStructType, PsType from ..exceptions import PsInputError from ..functions import PsMathFunction, MathFunctions, NumericLimitsFunctions from ..exceptions import FreezeError @@ -190,32 +189,41 @@ class FreezeExpressions: def map_ReductionAssignment(self, expr: ReductionAssignment): assert isinstance(expr.lhs, TypedSymbol) + # make sure that either: + # 1) lhs symbol never occurred + # 2) that it is at least known as lhs of an existing reduction operation + if self._ctx.find_symbol(expr.lhs.name): + # make sure that reduction operations are not mixed within a kernel + if info := self._ctx.find_reduction_info(expr.lhs.name): + if info.op is not expr.reduction_op: + raise FreezeError( + f"Different reduction operation {info.op} already exists " + f"for {expr.lhs} with target reduction op {expr.reduction_op}." + ) + else: + raise FreezeError( + f"Left-hand side {expr.lhs} of ReductionAssignment already exists in symbol table." + f"Make sure that it is exclusively used within the kernel to conduct ReductionAssignment's." + ) + lhs = self.visit(expr.lhs) rhs = self.visit(expr.rhs) assert isinstance(rhs, PsExpression) assert isinstance(lhs, PsSymbolExpr) - op = expr.reduction_op - orig_lhs_symb = lhs.symbol - dtype = lhs.dtype - - assert isinstance(dtype, PsNumericType), \ - "Reduction assignments require type information of the lhs symbol." - - # replace original symbol with pointer-based type used for export - orig_lhs_symb_as_ptr = PsSymbol(orig_lhs_symb.name, PsPointerType(dtype)) - - # create kernel-local copy of lhs symbol to work with - new_lhs_symb = PsSymbol(f"{orig_lhs_symb.name}_local", dtype) - new_lhs = PsSymbolExpr(new_lhs_symb) + reduction_op = expr.reduction_op + lhs_symbol = lhs.symbol + lhs_dtype = lhs_symbol.dtype + lhs_name = lhs_symbol.name - # get new rhs from augmented assignment - new_rhs: PsExpression = reduction_op_to_expr(op, new_lhs.clone(), rhs) + assert isinstance( + lhs_dtype, PsNumericType + ), "Reduction assignments require type information of the lhs symbol." # match for reduction operation and set neutral init_val init_val: PsExpression - match op: + match reduction_op: case ReductionOp.Add: init_val = PsConstantExpr(PsConstant(0)) case ReductionOp.Sub: @@ -227,14 +235,20 @@ class FreezeExpressions: case ReductionOp.Max: init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Min), []) case _: - raise FreezeError(f"Unsupported kind of reduction assignment: {op}.") + raise FreezeError( + f"Unsupported kind of reduction assignment: {reduction_op}." + ) - reduction_info = ReductionInfo(op, init_val, orig_lhs_symb_as_ptr) + # get reduction info from context + reduction_info = self._ctx.add_reduction_info( + lhs_name, lhs_dtype, reduction_op, init_val + ) + + # create new lhs from newly created local lhs symbol + new_lhs = PsSymbolExpr(reduction_info.local_symbol) - # add new symbol for local copy, replace original copy with pointer counterpart and add reduction info - self._ctx.add_symbol(new_lhs_symb) - self._ctx.add_symbol_reduction_info(new_lhs_symb, reduction_info) - self._ctx.replace_symbol(orig_lhs_symb, orig_lhs_symb_as_ptr) + # get new rhs from augmented assignment + new_rhs: PsExpression = reduction_op_to_expr(reduction_op, new_lhs, rhs) return PsAssignment(new_lhs, new_rhs) diff --git a/src/pystencils/backend/transformations/add_pragmas.py b/src/pystencils/backend/transformations/add_pragmas.py index fa466e495c826fca6c8f0ae1f338dfd868436ebd..1d1cb6a8de914b08f7bdb0657a126b54fe857a6c 100644 --- a/src/pystencils/backend/transformations/add_pragmas.py +++ b/src/pystencils/backend/transformations/add_pragmas.py @@ -123,11 +123,11 @@ class AddOpenMP: if num_threads is not None: pragma_text += f" num_threads({str(num_threads)})" - if bool(ctx.symbols_reduction_info): - for symbol, reduction_info in ctx.symbols_reduction_info.items(): - if isinstance(symbol.dtype, PsScalarType): + if bool(ctx.reduction_data): + for _, reduction_info in ctx.reduction_data.items(): + if isinstance(reduction_info.local_symbol.dtype, PsScalarType): pragma_text += ( - f" reduction({reduction_info.op.value}: {symbol.name})" + f" reduction({reduction_info.op.value}: {reduction_info.local_symbol.name})" ) else: NotImplementedError( diff --git a/src/pystencils/backend/transformations/loop_vectorizer.py b/src/pystencils/backend/transformations/loop_vectorizer.py index 09b0aa5dd2e28b62d22e73e74ab8216bdcaf502e..8061240b75fb413fc49bd3d93c7ab8127e36303e 100644 --- a/src/pystencils/backend/transformations/loop_vectorizer.py +++ b/src/pystencils/backend/transformations/loop_vectorizer.py @@ -143,16 +143,17 @@ class LoopVectorizer: # Prepare reductions simd_init_local_reduction_vars: list[PsStructuralNode] = [] simd_writeback_local_reduction_vars: list[PsStructuralNode] = [] - for symb, reduction_info in self._ctx.symbols_reduction_info.items(): + for _, reduction_info in self._ctx.reduction_data.items(): # Vectorize symbol for local copy - vector_symb = vc.vectorize_symbol(symb) + local_symbol = reduction_info.local_symbol + vector_symb = vc.vectorize_symbol(local_symbol) # Declare and init vector simd_init_local_reduction_vars += [ self._type_fold( PsDeclaration( PsSymbolExpr(vector_symb), - PsVecBroadcast(self._lanes, PsSymbolExpr(symb)), + PsVecBroadcast(self._lanes, PsSymbolExpr(local_symbol)), ) ) ] @@ -160,9 +161,9 @@ class LoopVectorizer: # Write back vectorization result simd_writeback_local_reduction_vars += [ PsAssignment( - PsSymbolExpr(symb), + PsSymbolExpr(local_symbol), PsVecHorizontal( - PsSymbolExpr(symb), + PsSymbolExpr(local_symbol), PsSymbolExpr(vector_symb), reduction_info.op, ), diff --git a/src/pystencils/codegen/driver.py b/src/pystencils/codegen/driver.py index 3d107eda386abd517929a6dbba51164936d4e111..74a07b902a6494d0c528602326286a2aaab01cc5 100644 --- a/src/pystencils/codegen/driver.py +++ b/src/pystencils/codegen/driver.py @@ -192,8 +192,8 @@ class DefaultKernelCreationDriver: self._intermediates.constants_eliminated = kernel_ast.clone() # Extensions for reductions - for symbol, reduction_info in self._ctx.symbols_reduction_info.items(): - self._modify_kernel_ast_for_reductions(symbol, reduction_info, kernel_ast) + for _, reduction_info in self._ctx.reduction_data.items(): + self._modify_kernel_ast_for_reductions(reduction_info, kernel_ast) # Target-Specific optimizations if self._target.is_cpu(): @@ -296,13 +296,12 @@ class DefaultKernelCreationDriver: return kernel_body def _modify_kernel_ast_for_reductions(self, - symbol: PsSymbol, reduction_info: ReductionInfo, kernel_ast: PsBlock): # typify local symbol and write-back pointer expressions and initial value typify = Typifier(self._ctx) - symbol_expr = typify(PsSymbolExpr(symbol)) - ptr_symbol_expr = typify(PsSymbolExpr(reduction_info.ptr_symbol)) + symbol_expr = typify(PsSymbolExpr(reduction_info.local_symbol)) + ptr_symbol_expr = typify(PsSymbolExpr(reduction_info.writeback_ptr_symbol)) init_val = typify(reduction_info.init_val) ptr_access = PsMemAcc(