From c76b897f87b2a4841ab0894249cdaf771d7941b5 Mon Sep 17 00:00:00 2001 From: zy69guqi <richard.angersbach@fau.de> Date: Wed, 30 Apr 2025 16:09:09 +0200 Subject: [PATCH] Drop ReductionFunctions and introduce PsReductionWriteBack --- src/pystencils/backend/functions.py | 32 ++++++------------- .../backend/platforms/generic_cpu.py | 13 +++----- .../backend/platforms/generic_gpu.py | 13 +++----- .../transformations/select_functions.py | 4 +-- src/pystencils/codegen/driver.py | 6 ++-- 5 files changed, 21 insertions(+), 47 deletions(-) diff --git a/src/pystencils/backend/functions.py b/src/pystencils/backend/functions.py index 6a9d3e4f4..a3da9a1de 100644 --- a/src/pystencils/backend/functions.py +++ b/src/pystencils/backend/functions.py @@ -95,45 +95,31 @@ class PsMathFunction(PsFunction): return hash(self._func) -class ReductionFunctions(Enum): - """Function representing different steps in kernels with reductions supported by the backend. +class PsReductionWriteBack(PsFunction): + """Function representing a reduction kernel's write-back step supported by the backend. - Each platform has to materialize these functions to a concrete implementation. + Each platform has to materialize this function to a concrete implementation. """ - WriteBackToPtr = ("WriteBackToPtr", 2) - - def __init__(self, func_name, num_args): - self.function_name = func_name - self.num_args = num_args - - -class PsReductionFunction(PsFunction): - - def __init__(self, func: ReductionFunctions, reduction_op: ReductionOp) -> None: - super().__init__(func.function_name, func.num_args) - self._func = func + def __init__(self, reduction_op: ReductionOp) -> None: + super().__init__("WriteBackToPtr", 2) self._reduction_op = reduction_op - @property - def func(self) -> ReductionFunctions: - return self._func - @property def reduction_op(self) -> ReductionOp: return self._reduction_op def __str__(self) -> str: - return f"{self._func.function_name}" + return f"{super().name}" def __eq__(self, other: object) -> bool: - if not isinstance(other, PsReductionFunction): + if not isinstance(other, PsReductionWriteBack): return False - return self._func == other._func and self._reduction_op == other._reduction_op + return self._reduction_op == other._reduction_op def __hash__(self) -> int: - return hash(self._func) + return hash(self._reduction_op) class ConstantFunctions(Enum): diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py index 1dc2914b9..cbde419d4 100644 --- a/src/pystencils/backend/platforms/generic_cpu.py +++ b/src/pystencils/backend/platforms/generic_cpu.py @@ -8,11 +8,10 @@ from ..ast import PsAstNode from ..functions import ( CFunction, MathFunctions, - ReductionFunctions, PsMathFunction, - PsReductionFunction, PsConstantFunction, ConstantFunctions, + PsReductionWriteBack, ) from ..reduction_op_mapping import reduction_op_to_expr from ...sympyextensions import ReductionOp @@ -69,14 +68,9 @@ class GenericCpu(Platform): self, call: PsCall ) -> PsExpression | tuple[tuple[PsStructuralNode, ...], PsAstNode]: call_func = call.function - assert isinstance(call_func, (PsReductionFunction | PsMathFunction | PsConstantFunction)) + assert isinstance(call_func, (PsReductionWriteBack | PsMathFunction | PsConstantFunction)) - func = call_func.func - - if ( - isinstance(call_func, PsReductionFunction) - and func is ReductionFunctions.WriteBackToPtr - ): + if isinstance(call_func, PsReductionWriteBack): ptr_expr, symbol_expr = call.args op = call_func.reduction_op @@ -103,6 +97,7 @@ class GenericCpu(Platform): return potential_call dtype = call.get_dtype() + func = call_func.func arg_types = (dtype,) * call.function.arg_count expr: PsExpression | None = None diff --git a/src/pystencils/backend/platforms/generic_gpu.py b/src/pystencils/backend/platforms/generic_gpu.py index 06b230454..876da4c91 100644 --- a/src/pystencils/backend/platforms/generic_gpu.py +++ b/src/pystencils/backend/platforms/generic_gpu.py @@ -52,8 +52,7 @@ from ..literals import PsLiteral from ..functions import ( MathFunctions, CFunction, - ReductionFunctions, - PsReductionFunction, + PsReductionWriteBack, PsMathFunction, PsConstantFunction, ConstantFunctions, @@ -296,20 +295,16 @@ class GenericGpu(Platform): self, call: PsCall ) -> PsExpression | tuple[tuple[PsStructuralNode, ...], PsAstNode]: call_func = call.function - assert isinstance(call_func, (PsReductionFunction | PsMathFunction | PsConstantFunction)) + assert isinstance(call_func, (PsReductionWriteBack | PsMathFunction | PsConstantFunction)) - func = call_func.func - - if ( - isinstance(call_func, PsReductionFunction) - and func is ReductionFunctions.WriteBackToPtr - ): + if isinstance(call_func, PsReductionWriteBack): ptr_expr, symbol_expr = call.args op = call_func.reduction_op return self.resolve_reduction(ptr_expr, symbol_expr, op) dtype = call.get_dtype() + func = call_func.func arg_types = (dtype,) * call.function.arg_count expr: PsExpression | None = None diff --git a/src/pystencils/backend/transformations/select_functions.py b/src/pystencils/backend/transformations/select_functions.py index d005acb4b..5953bd47d 100644 --- a/src/pystencils/backend/transformations/select_functions.py +++ b/src/pystencils/backend/transformations/select_functions.py @@ -3,7 +3,7 @@ from ..exceptions import MaterializationError from ..platforms import Platform from ..ast import PsAstNode from ..ast.expressions import PsCall, PsExpression -from ..functions import PsMathFunction, PsConstantFunction, PsReductionFunction +from ..functions import PsMathFunction, PsConstantFunction, PsReductionWriteBack class SelectFunctions: @@ -22,7 +22,7 @@ class SelectFunctions: if isinstance(node, PsAssignment): rhs = node.rhs if isinstance(rhs, PsCall) and isinstance( - rhs.function, PsReductionFunction + rhs.function, PsReductionWriteBack ): resolved_func = self._platform.select_function(rhs) diff --git a/src/pystencils/codegen/driver.py b/src/pystencils/codegen/driver.py index 74a07b902..68ab59bd9 100644 --- a/src/pystencils/codegen/driver.py +++ b/src/pystencils/codegen/driver.py @@ -26,7 +26,7 @@ from ..types import PsIntegerType, PsScalarType from ..backend.memory import PsSymbol from ..backend.ast import PsAstNode -from ..backend.functions import PsReductionFunction, ReductionFunctions +from ..backend.functions import PsReductionWriteBack from ..backend.ast.expressions import ( PsExpression, PsSymbolExpr, @@ -308,9 +308,7 @@ class DefaultKernelCreationDriver: ptr_symbol_expr, PsConstantExpr(PsConstant(0, self._ctx.index_dtype)) ) write_back_ptr = PsCall( - PsReductionFunction( - ReductionFunctions.WriteBackToPtr, reduction_info.op - ), + PsReductionWriteBack(reduction_info.op), [ptr_symbol_expr, symbol_expr], ) -- GitLab