diff --git a/src/pystencils/backend/functions.py b/src/pystencils/backend/functions.py index 18c2277cf76102f9265114853f97b8e2eb50cc67..201321693395ca3716e70170fbbe511af3a4fca8 100644 --- a/src/pystencils/backend/functions.py +++ b/src/pystencils/backend/functions.py @@ -30,6 +30,7 @@ from typing import Any, Sequence, TYPE_CHECKING from abc import ABC from enum import Enum +from ..sympyextensions import ReductionOp from ..types import PsType from .exceptions import PsInternalCompilerError @@ -134,6 +135,48 @@ class PsMathFunction(PsFunction): return hash(self._func) +class ReductionFunctions(Enum): + """Function representing different steps in kernels with reductions supported by the backend. + + Each platform has to materialize these functions to a concrete implementation. + """ + + InitLocalCopy = ("InitLocalCopy", 2) + WriteBackToPtr = ("WriteBackToPtr", 2) + + def __init__(self, func_name, num_args): + self.function_name = func_name + self.num_args = num_args + + +class PsReductionFunction(PsFunction): + + def __init__(self, func: ReductionFunctions, op: ReductionOp) -> None: + super().__init__(func.function_name, func.num_args) + self._func = func + self._op = op + + @property + def func(self) -> ReductionFunctions: + return self._func + + @property + def op(self) -> ReductionOp: + return self._op + + def __str__(self) -> str: + return f"{self._func.function_name}" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, PsReductionFunction): + return False + + return self._func == other._func + + def __hash__(self) -> int: + return hash(self._func) + + class CFunction(PsFunction): """A concrete C function. diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index 62feca26504b753a9898153b4e6fb85e3fc5b2e7..059817bfda92d4714896a86a110cb257ca4cb823 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -50,7 +50,7 @@ from ..ast.expressions import ( PsNot, ) from ..ast.vector import PsVecBroadcast, PsVecMemAcc -from ..functions import PsMathFunction, CFunction +from ..functions import PsMathFunction, CFunction, PsReductionFunction from ..ast.util import determine_memory_object from ..exceptions import TypificationError @@ -590,7 +590,7 @@ class Typifier: case PsCall(function, args): match function: - case PsMathFunction(): + case PsMathFunction() | PsReductionFunction(): for arg in args: self.visit_expr(arg, tc) tc.infer_dtype(expr) diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py index b145b6f76389fb75c5f94eafd3462cb084247b35..33cb28711e59b173ce3a120571653b69db40f122 100644 --- a/src/pystencils/backend/platforms/generic_cpu.py +++ b/src/pystencils/backend/platforms/generic_cpu.py @@ -1,11 +1,14 @@ from abc import ABC, abstractmethod from typing import Sequence -from pystencils.backend.ast.expressions import PsCall +from ..ast.expressions import PsCall, PsMemAcc, PsConstantExpr -from ..functions import CFunction, PsMathFunction, MathFunctions, NumericLimitsFunctions +from ..ast import PsAstNode +from ..functions import CFunction, PsMathFunction, MathFunctions, NumericLimitsFunctions, ReductionFunctions, \ + PsReductionFunction from ..literals import PsLiteral -from ...types import PsIntegerType, PsIeeeFloatType, PsScalarType +from ...compound_op_mapping import compound_op_to_expr +from ...types import PsIntegerType, PsIeeeFloatType, PsScalarType, PsPointerType from .platform import Platform from ..exceptions import MaterializationError @@ -18,7 +21,7 @@ from ..kernelcreation.iteration_space import ( ) from ..constants import PsConstant -from ..ast.structural import PsDeclaration, PsLoop, PsBlock +from ..ast.structural import PsDeclaration, PsLoop, PsBlock, PsAssignment from ..ast.expressions import ( PsSymbolExpr, PsExpression, @@ -56,6 +59,36 @@ class GenericCpu(Platform): else: raise MaterializationError(f"Unknown type of iteration space: {ispace}") + def unfold_function( + self, call: PsCall + ) -> PsAstNode: + assert isinstance(call.function, PsReductionFunction) + + func = call.function.func + + match func: + case ReductionFunctions.InitLocalCopy: + symbol_expr, init_val = call.args + assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(init_val, PsExpression) + + return PsDeclaration(symbol_expr, init_val) + case ReductionFunctions.WriteBackToPtr: + ptr_expr, symbol_expr = call.args + op = call.function.op + + assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptr_expr.dtype, PsPointerType) + assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(symbol_expr.dtype, PsScalarType) + + ptr_access = PsMemAcc(ptr_expr, PsConstantExpr(PsConstant(0, self._ctx.index_dtype))) + + # TODO: can this be avoided somehow? + potential_call = compound_op_to_expr(op, ptr_access, symbol_expr) + if isinstance(potential_call, PsCall): + potential_call.dtype = symbol_expr.dtype + potential_call = self.select_function(potential_call) + + return PsAssignment(ptr_access, potential_call) + def select_function(self, call: PsCall) -> PsExpression: assert isinstance(call.function, PsMathFunction) diff --git a/src/pystencils/backend/platforms/platform.py b/src/pystencils/backend/platforms/platform.py index 2c7ee1c5f4750eac0375bc31a3f44b9eea50642b..732f37bbcd75e6f839728a16add34949dfc95c05 100644 --- a/src/pystencils/backend/platforms/platform.py +++ b/src/pystencils/backend/platforms/platform.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from typing import Any +from ..ast import PsAstNode from ..ast.structural import PsBlock from ..ast.expressions import PsCall, PsExpression @@ -40,3 +41,13 @@ class Platform(ABC): If no viable implementation exists, raise a `MaterializationError`. """ pass + + @abstractmethod + def unfold_function( + self, call: PsCall + ) -> PsAstNode: + """Unfolds an implementation for the given function on the given data type. + + If no viable implementation exists, raise a `MaterializationError`. + """ + pass diff --git a/src/pystencils/backend/transformations/select_functions.py b/src/pystencils/backend/transformations/select_functions.py index e41c345ae4ed71101d07fcaa5b9df88b1e0f54e2..0045de87b7b5a4203430fd74a3641ca831826419 100644 --- a/src/pystencils/backend/transformations/select_functions.py +++ b/src/pystencils/backend/transformations/select_functions.py @@ -1,7 +1,7 @@ from ..platforms import Platform from ..ast import PsAstNode from ..ast.expressions import PsCall -from ..functions import PsMathFunction +from ..functions import PsMathFunction, PsReductionFunction class SelectFunctions: @@ -19,5 +19,7 @@ class SelectFunctions: if isinstance(node, PsCall) and isinstance(node.function, PsMathFunction): return self._platform.select_function(node) + elif isinstance(node, PsCall) and isinstance(node.function, PsReductionFunction): + return self._platform.unfold_function(node) else: return node diff --git a/src/pystencils/codegen/driver.py b/src/pystencils/codegen/driver.py index ba7df317acab94e30c26025d246bd06d2bea5490..9a80439e7227e4f35b47556087a4281f2922de00 100644 --- a/src/pystencils/codegen/driver.py +++ b/src/pystencils/codegen/driver.py @@ -7,14 +7,14 @@ from .config import CreateKernelConfig, OpenMpConfig, VectorizationConfig, AUTO from .kernel import Kernel, GpuKernel, GpuThreadsRange from .properties import PsSymbolProperty, FieldShape, FieldStride, FieldBasePtr from .parameters import Parameter -from ..compound_op_mapping import compound_op_to_expr -from ..backend.ast.expressions import PsSymbolExpr, PsMemAcc, PsConstantExpr +from ..backend.functions import PsReductionFunction, ReductionFunctions +from ..backend.ast.expressions import PsSymbolExpr, PsCall from ..types import create_numeric_type, PsIntegerType, PsScalarType from ..backend.memory import PsSymbol from ..backend.ast import PsAstNode -from ..backend.ast.structural import PsBlock, PsLoop, PsAssignment, PsDeclaration +from ..backend.ast.structural import PsBlock, PsLoop from ..backend.ast.analysis import collect_undefined_symbols, collect_required_headers from ..backend.kernelcreation import ( KernelCreationContext, @@ -156,19 +156,20 @@ class DefaultKernelCreationDriver: # Extensions for reductions for symbol, reduction_info in self._ctx.symbols_reduction_info.items(): - # Init local reduction variable copy - kernel_ast.statements = [PsDeclaration(PsSymbolExpr(symbol), - reduction_info.init_val)] + kernel_ast.statements + typify = Typifier(self._ctx) + symbol_expr = typify(PsSymbolExpr(symbol)) + ptr_symbol_expr = typify(PsSymbolExpr(reduction_info.ptr_symbol)) + init_val = typify(reduction_info.init_val) - # Write back result to reduction target variable - ptr_access = PsMemAcc(PsSymbolExpr(reduction_info.ptr_symbol), - PsConstantExpr(PsConstant(0))) - kernel_ast.statements += [PsAssignment( - ptr_access, compound_op_to_expr(reduction_info.op, ptr_access, PsSymbolExpr(symbol)))] + init_local_copy = PsCall(PsReductionFunction(ReductionFunctions.InitLocalCopy, reduction_info.op), + [symbol_expr, init_val]) + write_back_ptr = PsCall(PsReductionFunction(ReductionFunctions.WriteBackToPtr, reduction_info.op), + [ptr_symbol_expr, symbol_expr]) - # TODO: only newly introduced nodes - typify = Typifier(self._ctx) - kernel_ast = typify(kernel_ast) + # Init local reduction variable copy + kernel_ast.statements = [init_local_copy] + kernel_ast.statements + # Write back result to reduction target variable + kernel_ast.statements += [write_back_ptr] # Target-Specific optimizations if self._cfg.target.is_cpu(): diff --git a/tests/kernelcreation/test_reduction.py b/tests/kernelcreation/test_reduction.py index 69b75e711c1e8c306486f74bc0a43b8daea924b1..b24058571bfd87f2d50b02e1726f23077753b922 100644 --- a/tests/kernelcreation/test_reduction.py +++ b/tests/kernelcreation/test_reduction.py @@ -32,11 +32,11 @@ def test_reduction(dtype, op): config = ps.CreateKernelConfig(target=ps.Target.GPU) if gpu_avail else ps.CreateKernelConfig(cpu_openmp=True) ast_reduction = ps.create_kernel([red_assign], config, default_dtype=dtype) + ps.show_code(ast_reduction) + # code_reduction = ps.get_code_str(ast_reduction) kernel_reduction = ast_reduction.compile() - ps.show_code(ast_reduction) - array = np.full((SIZE,), INIT_ARR, dtype=dtype) reduction_array = np.full((1,), INIT_W, dtype=dtype)