diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index 536c73c7f81ede3ea36629e361faa691f0631db6..358b5ff6cdeb6cca1cfc232c5129ae239bdb3be9 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -49,6 +49,8 @@ FieldArrayPair = namedtuple("FieldArrayPair", ("field", "array")) @dataclass(frozen=True) class ReductionInfo: + """Information about a reduction operation, its neutral element in form of an initial value + and the pointer used by the kernel as write-back argument.""" op: ReductionOp init_val: PsExpression diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 63e9ea5b15f4fd15f87ff254862d1e6c17d1b626..9dc3928b3549e7d0a9eae4ae2e78244e4e9e4b4c 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -13,7 +13,7 @@ from ...sympyextensions import ( integer_functions, ConditionalFieldAccess, ) -from ...compound_op_mapping import compound_op_to_expr +from ...reduction_op_mapping import reduction_op_to_expr from ...sympyextensions.typed_sympy import TypedSymbol, TypeCast, DynamicType from ...sympyextensions.pointers import AddressOf, mem_acc from ...sympyextensions.reduction import ReductionAssignment, ReductionOp @@ -174,15 +174,17 @@ class FreezeExpressions: assert isinstance(lhs, PsExpression) assert isinstance(rhs, PsExpression) - _str_to_compound_op: dict[str, ReductionOp] = { + # transform augmented assignment to reduction op + str_to_reduction_op: dict[str, ReductionOp] = { "+=": ReductionOp.Add, "-=": ReductionOp.Sub, "*=": ReductionOp.Mul, "/=": ReductionOp.Div, } + # reuse existing handling for transforming reduction ops to expressions return PsAssignment( - lhs, compound_op_to_expr(_str_to_compound_op[expr.op], lhs.clone(), rhs) + lhs, reduction_op_to_expr(str_to_reduction_op[expr.op], lhs.clone(), rhs) ) def map_ReductionAssignment(self, expr: ReductionAssignment): @@ -198,7 +200,8 @@ class FreezeExpressions: orig_lhs_symb = lhs.symbol dtype = lhs.dtype - assert isinstance(dtype, PsNumericType) + assert isinstance(dtype, PsNumericType), \ + "Reduction assignments require type information of the lhs symbol." # replace original symbol with pointer-based type used for export orig_lhs_symb_as_ptr = PsSymbol(orig_lhs_symb.name, PsPointerType(dtype)) @@ -208,7 +211,7 @@ class FreezeExpressions: new_lhs = PsSymbolExpr(new_lhs_symb) # get new rhs from augmented assignment - new_rhs: PsExpression = compound_op_to_expr(op, new_lhs.clone(), rhs) + new_rhs: PsExpression = reduction_op_to_expr(op, new_lhs.clone(), rhs) # match for reduction operation and set neutral init_val init_val: PsExpression @@ -224,7 +227,7 @@ class FreezeExpressions: case ReductionOp.Max: init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Min), []) case _: - raise FreezeError(f"Unsupported reduced assignment: {op}.") + raise FreezeError(f"Unsupported kind of reduction assignment: {op}.") reduction_info = ReductionInfo(op, init_val, orig_lhs_symb_as_ptr) diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py index da8375c5ea853f785bb07410d3c0f8ee1db6c18f..8c3cd45faf2a210bf8de16e3f25d9ba7897d4e0f 100644 --- a/src/pystencils/backend/platforms/cuda.py +++ b/src/pystencils/backend/platforms/cuda.py @@ -23,7 +23,7 @@ from ..constants import PsConstant from ..exceptions import MaterializationError from ..functions import NumericLimitsFunctions, CFunction from ..literals import PsLiteral -from ...compound_op_mapping import compound_op_to_expr +from ...reduction_op_mapping import reduction_op_to_expr from ...sympyextensions import ReductionOp from ...types import PsType, PsIeeeFloatType, PsCustomType, PsPointerType, PsScalarType from ...types.quick import SInt, UInt @@ -87,7 +87,7 @@ class CudaPlatform(GenericGpu): shuffles = tuple( PsAssignment( symbol_expr, - compound_op_to_expr( + reduction_op_to_expr( actual_reduction_op, symbol_expr, gen_shuffle_instr(pow(2, i - 1)), diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py index ccef6181793af2c8fa3271edcf5f6b62716863a1..3de7cf696ff47d65950e376e4aa38c63e1528dcf 100644 --- a/src/pystencils/backend/platforms/generic_cpu.py +++ b/src/pystencils/backend/platforms/generic_cpu.py @@ -13,7 +13,7 @@ from ..functions import ( PsReductionFunction, ) from ..literals import PsLiteral -from ...compound_op_mapping import compound_op_to_expr +from ...reduction_op_mapping import reduction_op_to_expr from ...sympyextensions import ReductionOp from ...types import PsIntegerType, PsIeeeFloatType, PsScalarType, PsPointerType @@ -97,7 +97,7 @@ class GenericCpu(Platform): actual_op = ReductionOp.Add if op is ReductionOp.Sub else op # create binop and potentially select corresponding function for e.g. min or max - potential_call = compound_op_to_expr(actual_op, ptr_access, symbol_expr) + potential_call = reduction_op_to_expr(actual_op, ptr_access, symbol_expr) if isinstance(potential_call, PsCall): potential_call.dtype = symbol_expr.dtype return self.select_function(potential_call) diff --git a/src/pystencils/compound_op_mapping.py b/src/pystencils/reduction_op_mapping.py similarity index 82% rename from src/pystencils/compound_op_mapping.py rename to src/pystencils/reduction_op_mapping.py index 193b308d01de6c12c40e1104f7704b97d02ecba3..06fb8aa3e981d661f8e4226a34307f03cd6d9a70 100644 --- a/src/pystencils/compound_op_mapping.py +++ b/src/pystencils/reduction_op_mapping.py @@ -11,7 +11,7 @@ _available_operator_interface: set[ReductionOp] = { } -def compound_op_to_expr(op: ReductionOp, op1, op2) -> PsExpression: +def reduction_op_to_expr(op: ReductionOp, op1, op2) -> PsExpression: if op in _available_operator_interface: match op: case ReductionOp.Add: @@ -24,7 +24,7 @@ def compound_op_to_expr(op: ReductionOp, op1, op2) -> PsExpression: operator = PsDiv case _: raise FreezeError( - f"Found unsupported operation type for compound assignments: {op}." + f"Found unsupported operation type for reduction assignments: {op}." ) return operator(op1, op2) else: @@ -35,5 +35,5 @@ def compound_op_to_expr(op: ReductionOp, op1, op2) -> PsExpression: return PsCall(PsMathFunction(MathFunctions.Max), [op1, op2]) case _: raise FreezeError( - f"Found unsupported operation type for compound assignments: {op}." + f"Found unsupported operation type for reduction assignments: {op}." )