diff --git a/src/pystencils/backend/functions.py b/src/pystencils/backend/functions.py index 6a9d3e4f45815eb8995e55469e7d13dbff0ed924..a3da9a1de1aa0639c5b04acfa3020bc551497b51 100644 --- a/src/pystencils/backend/functions.py +++ b/src/pystencils/backend/functions.py @@ -95,45 +95,31 @@ class PsMathFunction(PsFunction): return hash(self._func) -class ReductionFunctions(Enum): - """Function representing different steps in kernels with reductions supported by the backend. +class PsReductionWriteBack(PsFunction): + """Function representing a reduction kernel's write-back step supported by the backend. - Each platform has to materialize these functions to a concrete implementation. + Each platform has to materialize this function to a concrete implementation. """ - WriteBackToPtr = ("WriteBackToPtr", 2) - - def __init__(self, func_name, num_args): - self.function_name = func_name - self.num_args = num_args - - -class PsReductionFunction(PsFunction): - - def __init__(self, func: ReductionFunctions, reduction_op: ReductionOp) -> None: - super().__init__(func.function_name, func.num_args) - self._func = func + def __init__(self, reduction_op: ReductionOp) -> None: + super().__init__("WriteBackToPtr", 2) self._reduction_op = reduction_op - @property - def func(self) -> ReductionFunctions: - return self._func - @property def reduction_op(self) -> ReductionOp: return self._reduction_op def __str__(self) -> str: - return f"{self._func.function_name}" + return f"{super().name}" def __eq__(self, other: object) -> bool: - if not isinstance(other, PsReductionFunction): + if not isinstance(other, PsReductionWriteBack): return False - return self._func == other._func and self._reduction_op == other._reduction_op + return self._reduction_op == other._reduction_op def __hash__(self) -> int: - return hash(self._func) + return hash(self._reduction_op) class ConstantFunctions(Enum): diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py index 1dc2914b9f587a11180e0df426eb2bb38a7377b8..cbde419d4eff508915b607df8a699632c1ba4416 100644 --- a/src/pystencils/backend/platforms/generic_cpu.py +++ b/src/pystencils/backend/platforms/generic_cpu.py @@ -8,11 +8,10 @@ from ..ast import PsAstNode from ..functions import ( CFunction, MathFunctions, - ReductionFunctions, PsMathFunction, - PsReductionFunction, PsConstantFunction, ConstantFunctions, + PsReductionWriteBack, ) from ..reduction_op_mapping import reduction_op_to_expr from ...sympyextensions import ReductionOp @@ -69,14 +68,9 @@ class GenericCpu(Platform): self, call: PsCall ) -> PsExpression | tuple[tuple[PsStructuralNode, ...], PsAstNode]: call_func = call.function - assert isinstance(call_func, (PsReductionFunction | PsMathFunction | PsConstantFunction)) + assert isinstance(call_func, (PsReductionWriteBack | PsMathFunction | PsConstantFunction)) - func = call_func.func - - if ( - isinstance(call_func, PsReductionFunction) - and func is ReductionFunctions.WriteBackToPtr - ): + if isinstance(call_func, PsReductionWriteBack): ptr_expr, symbol_expr = call.args op = call_func.reduction_op @@ -103,6 +97,7 @@ class GenericCpu(Platform): return potential_call dtype = call.get_dtype() + func = call_func.func arg_types = (dtype,) * call.function.arg_count expr: PsExpression | None = None diff --git a/src/pystencils/backend/platforms/generic_gpu.py b/src/pystencils/backend/platforms/generic_gpu.py index 06b230454c289bf4e1bce294e725cb54cae624ad..876da4c912c68fed993b7d864e69748c47345277 100644 --- a/src/pystencils/backend/platforms/generic_gpu.py +++ b/src/pystencils/backend/platforms/generic_gpu.py @@ -52,8 +52,7 @@ from ..literals import PsLiteral from ..functions import ( MathFunctions, CFunction, - ReductionFunctions, - PsReductionFunction, + PsReductionWriteBack, PsMathFunction, PsConstantFunction, ConstantFunctions, @@ -296,20 +295,16 @@ class GenericGpu(Platform): self, call: PsCall ) -> PsExpression | tuple[tuple[PsStructuralNode, ...], PsAstNode]: call_func = call.function - assert isinstance(call_func, (PsReductionFunction | PsMathFunction | PsConstantFunction)) + assert isinstance(call_func, (PsReductionWriteBack | PsMathFunction | PsConstantFunction)) - func = call_func.func - - if ( - isinstance(call_func, PsReductionFunction) - and func is ReductionFunctions.WriteBackToPtr - ): + if isinstance(call_func, PsReductionWriteBack): ptr_expr, symbol_expr = call.args op = call_func.reduction_op return self.resolve_reduction(ptr_expr, symbol_expr, op) dtype = call.get_dtype() + func = call_func.func arg_types = (dtype,) * call.function.arg_count expr: PsExpression | None = None diff --git a/src/pystencils/backend/transformations/select_functions.py b/src/pystencils/backend/transformations/select_functions.py index d005acb4bcf3042473826383384e16d9fc7dd4fc..5953bd47db53e716fffef9b2cd1be1d2897fcf52 100644 --- a/src/pystencils/backend/transformations/select_functions.py +++ b/src/pystencils/backend/transformations/select_functions.py @@ -3,7 +3,7 @@ from ..exceptions import MaterializationError from ..platforms import Platform from ..ast import PsAstNode from ..ast.expressions import PsCall, PsExpression -from ..functions import PsMathFunction, PsConstantFunction, PsReductionFunction +from ..functions import PsMathFunction, PsConstantFunction, PsReductionWriteBack class SelectFunctions: @@ -22,7 +22,7 @@ class SelectFunctions: if isinstance(node, PsAssignment): rhs = node.rhs if isinstance(rhs, PsCall) and isinstance( - rhs.function, PsReductionFunction + rhs.function, PsReductionWriteBack ): resolved_func = self._platform.select_function(rhs) diff --git a/src/pystencils/codegen/driver.py b/src/pystencils/codegen/driver.py index 74a07b902a6494d0c528602326286a2aaab01cc5..68ab59bd9265d5cdadac253f45e61f8259cfa2be 100644 --- a/src/pystencils/codegen/driver.py +++ b/src/pystencils/codegen/driver.py @@ -26,7 +26,7 @@ from ..types import PsIntegerType, PsScalarType from ..backend.memory import PsSymbol from ..backend.ast import PsAstNode -from ..backend.functions import PsReductionFunction, ReductionFunctions +from ..backend.functions import PsReductionWriteBack from ..backend.ast.expressions import ( PsExpression, PsSymbolExpr, @@ -308,9 +308,7 @@ class DefaultKernelCreationDriver: ptr_symbol_expr, PsConstantExpr(PsConstant(0, self._ctx.index_dtype)) ) write_back_ptr = PsCall( - PsReductionFunction( - ReductionFunctions.WriteBackToPtr, reduction_info.op - ), + PsReductionWriteBack(reduction_info.op), [ptr_symbol_expr, symbol_expr], )