diff --git a/src/pystencils/backend/functions.py b/src/pystencils/backend/functions.py
index d28ef5f44954b90ef8b7c48fd1b00ea35da05b9d..4e38de5e9f11ca1d971ae6659f04e6df7b47f64a 100644
--- a/src/pystencils/backend/functions.py
+++ b/src/pystencils/backend/functions.py
@@ -142,7 +142,6 @@ class ReductionFunctions(Enum):
     Each platform has to materialize these functions to a concrete implementation.
     """

-    InitLocalCopy = ("InitLocalCopy", 2)
     WriteBackToPtr = ("WriteBackToPtr", 2)

     def __init__(self, func_name, num_args):
diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py
index 544746ef641f003b35ac1ad6f031bd2883e2726c..284e80b9d718538f7960e099cc2a2dfbb257d34d 100644
--- a/src/pystencils/backend/kernelcreation/typification.py
+++ b/src/pystencils/backend/kernelcreation/typification.py
@@ -617,7 +617,7 @@ class Typifier:

             case PsCall(function, args):
                 match function:
-                    case PsMathFunction() | PsReductionFunction():
+                    case PsMathFunction():
                         for arg in args:
                             self.visit_expr(arg, tc)
                         tc.infer_dtype(expr)
diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py
index 8936bf73f90f5950a8a181d403f312a65e24ee07..1f6506c8f69a61b7822c21b6068c74164c4ff29f 100644
--- a/src/pystencils/backend/platforms/cuda.py
+++ b/src/pystencils/backend/platforms/cuda.py
@@ -212,10 +212,66 @@ class CudaPlatform(GenericGpu):
         else:
             raise MaterializationError(f"Unknown type of iteration space: {ispace}")

-    def select_function(self, call: PsCall) -> PsExpression:
-        assert isinstance(call.function, PsMathFunction)
-
+    def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsAstNode, ...], PsExpression]:
         func = call.function.func
+
+        if func in ReductionFunctions:
+            match func:
+                case ReductionFunctions.WriteBackToPtr:
+                    ptr_expr, symbol_expr = call.args
+                    op = call.function.reduction_op
+                    stype = symbol_expr.dtype
+                    ptrtype = ptr_expr.dtype
+
+                    warp_size = 32  # TODO: get from platform/user config
+
+                    assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptrtype, PsPointerType)
+                    assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(stype, PsScalarType)
+
+                    if not isinstance(stype, PsIeeeFloatType) or stype.width not in (32, 64):
+                        raise NotImplementedError("atomic operations are only available for float32/64 datatypes")
+
+                    # workaround for subtractions -> use additions for reducing intermediate results
+                    # similar to OpenMP reductions: local copies (negative sign) are added at the end
+                    match op:
+                        case ReductionOp.Sub:
+                            actual_op = ReductionOp.Add
+                        case _:
+                            actual_op = op
+
+                    # perform local warp reductions
+                    def gen_shuffle_instr(offset: int):
+                        full_mask = PsLiteralExpr(PsLiteral("0xffffffff", UInt(32)))
+                        return PsCall(CFunction("__shfl_xor_sync", [UInt(32), stype, SInt(32)], stype),
+                                      [full_mask, symbol_expr, PsConstantExpr(PsConstant(offset, SInt(32)))])
+
+                    num_shuffles = math.frexp(warp_size)[1]
+                    shuffles = [PsAssignment(symbol_expr,
+                                             compound_op_to_expr(actual_op,
+                                                                 symbol_expr, gen_shuffle_instr(pow(2, i - 1))))
+                                for i in reversed(range(1, num_shuffles))]
+
+                    # find first thread in warp
+                    ispace = self._ctx.get_iteration_space()
+                    is_valid_thread = self._get_condition_for_translation(ispace)
+                    thread_indices_per_dim = [
+                        idx * PsConstantExpr(PsConstant(reduce(operator.mul, BLOCK_DIM[:i], 1), SInt(32)))
+                        for i, idx in enumerate(THREAD_IDX[:ispace.rank])
+                    ]
+                    tid: PsExpression = thread_indices_per_dim[0]
+                    for t in thread_indices_per_dim[1:]:
+                        tid = PsAdd(tid, t)
+                    first_thread_in_warp = PsEq(PsRem(tid, PsConstantExpr(PsConstant(warp_size, SInt(32)))),
+                                                PsConstantExpr(PsConstant(0, SInt(32))))
+                    cond = PsAnd(is_valid_thread, first_thread_in_warp) if is_valid_thread else first_thread_in_warp
+
+                    # use atomic operation on first thread of warp to sync
+                    call.function = CFunction(f"atomic{actual_op.name}", [ptrtype, stype], PsCustomType("void"))
+                    call.args = (ptr_expr, symbol_expr)
+
+                    # assemble warp reduction
+                    return (shuffles, PsConditional(cond, PsBlock([PsStatement(call)])))
+
         dtype = call.get_dtype()
         arg_types = (dtype,) * func.num_args
@@ -232,7 +288,7 @@ class CudaPlatform(GenericGpu):

             return PsLiteralExpr(PsLiteral(define, dtype))

-        if isinstance(dtype, PsIeeeFloatType):
+        if isinstance(dtype, PsIeeeFloatType) and func in MathFunctions:
             match func:
                 case (
                     MathFunctions.Exp
@@ -285,84 +341,6 @@ class CudaPlatform(GenericGpu):
             f"No implementation available for function {func} on data type {dtype}"
         )

-    def unfold_function(
-        self, call: PsCall
-    ) -> PsAstNode:
-        assert isinstance(call.function, PsReductionFunction)
-
-        func = call.function.func
-
-        match func:
-            case ReductionFunctions.InitLocalCopy:
-                symbol_expr, init_val = call.args
-                assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(init_val, PsExpression)
-
-                return PsDeclaration(symbol_expr, init_val)
-            case ReductionFunctions.WriteBackToPtr:
-                ptr_expr, symbol_expr = call.args
-                op = call.function.reduction_op
-                stype = symbol_expr.dtype
-                ptrtype = ptr_expr.dtype
-
-                warp_size = 32  # TODO: get from platform/user config
-
-                assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptrtype, PsPointerType)
-                assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(stype, PsScalarType)
-
-                if not isinstance(stype, PsIeeeFloatType) or stype.width not in (32, 64):
-                    NotImplementedError("atomic operations are only available for float32/64 datatypes")
-
-                # set up mask symbol for active threads in warp
-                #mask = PsSymbol("__shfl_mask", UInt(32))
-                #self._ctx.add_symbol(mask)
-                full_mask = PsLiteralExpr(PsLiteral("0xffffffff", UInt(32)))
-
-                # workaround for subtractions -> use additions for reducing intermediate results
-                # similar to OpenMP reductions: local copies (negative sign) are added at the end
-                match op:
-                    case ReductionOp.Sub:
-                        actual_op = ReductionOp.Add
-                    case _:
-                        actual_op = op
-
-                # perform local warp reductions
-                def gen_shuffle_instr(offset: int):
-                    return PsCall(CFunction("__shfl_xor_sync", [UInt(32), stype, SInt(32)], stype),
-                                  [full_mask, symbol_expr, PsConstantExpr(PsConstant(offset, SInt(32)))])
-
-                num_shuffles = math.frexp(warp_size)[1]
-                shuffles = [PsAssignment(symbol_expr,
-                                         compound_op_to_expr(actual_op, symbol_expr, gen_shuffle_instr(pow(2, i - 1))))
-                            for i in reversed(range(1, num_shuffles))]
-
-                # find first thread in warp
-                ispace = self._ctx.get_iteration_space()  # TODO: receive as argument in unfold_function?
-                is_valid_thread = self._get_condition_for_translation(ispace)
-                thread_indices_per_dim = [
-                    idx * PsConstantExpr(PsConstant(reduce(operator.mul, BLOCK_DIM[:i], 1), SInt(32)))
-                    for i, idx in enumerate(THREAD_IDX[:ispace.rank])
-                ]
-                tid: PsExpression = thread_indices_per_dim[0]
-                for t in thread_indices_per_dim[1:]:
-                    tid = PsAdd(tid, t)
-                first_thread_in_warp = PsEq(PsRem(tid, PsConstantExpr(PsConstant(warp_size, SInt(32)))),
-                                            PsConstantExpr(PsConstant(0, SInt(32))))
-                cond = PsAnd(is_valid_thread, first_thread_in_warp) if is_valid_thread else first_thread_in_warp
-
-                #ballot_instr = PsCall(CFunction("__ballot_sync", [UInt(32), SInt(32)], SInt(32)),
-                #                      [full_mask, is_valid_thread])
-                #decl_mask = PsDeclaration(full_mask)
-
-                # use atomic operation on first thread of warp to sync
-                call.function = CFunction(f"atomic{actual_op.name}", [ptrtype, stype], PsCustomType("void"))
-                call.args = (ptr_expr, symbol_expr)
-
-                # assemble warp reduction
-                return PsBlock(
-                    #[decl_mask]
-                    shuffles
-                    + [PsConditional(cond, PsBlock([PsStatement(call)]))])
-
     # Internals

     # TODO: SYCL platform has very similar code for fetching conditionals -> move to GenericGPU?
diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py
index 1e7468e33768f45c989cb89cd617475762570bce..24692b25c9e76dc470a0432110e9a3031e3f6b2f 100644
--- a/src/pystencils/backend/platforms/generic_cpu.py
+++ b/src/pystencils/backend/platforms/generic_cpu.py
@@ -4,8 +4,7 @@ from typing import Sequence

 from ..ast.expressions import PsCall, PsMemAcc, PsConstantExpr
 from ..ast import PsAstNode
-from ..functions import CFunction, PsMathFunction, MathFunctions, NumericLimitsFunctions, ReductionFunctions, \
-    PsReductionFunction
+from ..functions import CFunction, MathFunctions, NumericLimitsFunctions, ReductionFunctions
 from ..literals import PsLiteral
 from ...compound_op_mapping import compound_op_to_expr
 from ...sympyextensions import ReductionOp
@@ -60,43 +59,31 @@ class GenericCpu(Platform):
         else:
             raise MaterializationError(f"Unknown type of iteration space: {ispace}")

-    def unfold_function(
-        self, call: PsCall
-    ) -> PsAstNode:
-        assert isinstance(call.function, PsReductionFunction)
-
+    def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsAstNode, ...], PsExpression]:
         func = call.function.func

-        match func:
-            case ReductionFunctions.InitLocalCopy:
-                symbol_expr, init_val = call.args
-                assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(init_val, PsExpression)
-
-                return PsDeclaration(symbol_expr, init_val)
-            case ReductionFunctions.WriteBackToPtr:
-                ptr_expr, symbol_expr = call.args
-                op = call.function.reduction_op
-
-                assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptr_expr.dtype, PsPointerType)
-                assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(symbol_expr.dtype, PsScalarType)
+        if func in ReductionFunctions:
+            match func:
+                case ReductionFunctions.WriteBackToPtr:
+                    ptr_expr, symbol_expr = call.args
+                    op = call.function.reduction_op

-                ptr_access = PsMemAcc(ptr_expr, PsConstantExpr(PsConstant(0, self._ctx.index_dtype)))
+                    assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptr_expr.dtype, PsPointerType)
+                    assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(symbol_expr.dtype, PsScalarType)

-                # inspired by OpenMP: local reduction variable (negative sign) is added at the end
-                actual_op = ReductionOp.Add if op is ReductionOp.Sub else op
+                    ptr_access = PsMemAcc(ptr_expr, PsConstantExpr(PsConstant(0, self._ctx.index_dtype)))

-                # TODO: can this be avoided somehow?
-                potential_call = compound_op_to_expr(actual_op, ptr_access, symbol_expr)
-                if isinstance(potential_call, PsCall):
-                    potential_call.dtype = symbol_expr.dtype
-                    potential_call = self.select_function(potential_call)
+                    # inspired by OpenMP: local reduction variable (negative sign) is added at the end
+                    actual_op = ReductionOp.Add if op is ReductionOp.Sub else op

-                return PsAssignment(ptr_access, potential_call)
+                    # TODO: can this be avoided somehow?
+                    potential_call = compound_op_to_expr(actual_op, ptr_access, symbol_expr)
+                    if isinstance(potential_call, PsCall):
+                        potential_call.dtype = symbol_expr.dtype
+                        potential_call = self.select_function(potential_call)

-    def select_function(self, call: PsCall) -> PsExpression:
-        assert isinstance(call.function, PsMathFunction)
+                    return potential_call

-        func = call.function.func
         dtype = call.get_dtype()
         arg_types = (dtype,) * func.num_args
diff --git a/src/pystencils/backend/platforms/platform.py b/src/pystencils/backend/platforms/platform.py
index e195d59bc104d7f57f356e2cb5bf0b219c54008c..90fd690840741ad8dd53dfdf3d4140b1a9c0b544 100644
--- a/src/pystencils/backend/platforms/platform.py
+++ b/src/pystencils/backend/platforms/platform.py
@@ -38,19 +38,9 @@ class Platform(ABC):
     @abstractmethod
     def select_function(
         self, call: PsCall
-    ) -> PsExpression:
+    ) -> PsExpression | tuple[tuple[PsAstNode, ...], PsExpression]:
         """Select an implementation for the given function on the given data type.

         If no viable implementation exists, raise a `MaterializationError`.
         """
         pass
-
-    @abstractmethod
-    def unfold_function(
-        self, call: PsCall
-    ) -> PsAstNode:
-        """Unfolds an implementation for the given function on the given data type.
-
-        If no viable implementation exists, raise a `MaterializationError`.
-        """
-        pass
diff --git a/src/pystencils/backend/platforms/sycl.py b/src/pystencils/backend/platforms/sycl.py
index 9b077fd2b719421962ae82c9af95ea24420f622e..eae2b7598bfa43cf5379fe8782233be11d0dfef2 100644
--- a/src/pystencils/backend/platforms/sycl.py
+++ b/src/pystencils/backend/platforms/sycl.py
@@ -57,7 +57,7 @@ class SyclPlatform(GenericGpu):
         else:
             raise MaterializationError(f"Unknown type of iteration space: {ispace}")

-    def select_function(self, call: PsCall) -> PsExpression:
+    def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsAstNode, ...], PsExpression]:
         assert isinstance(call.function, PsMathFunction)

         func = call.function.func
@@ -108,13 +108,6 @@ class SyclPlatform(GenericGpu):
             f"No implementation available for function {func} on data type {dtype}"
         )

-    def unfold_function(
-        self, call: PsCall
-    ) -> PsAstNode:
-        raise MaterializationError(
-            f"No implementation available for function {call.function.name}"
-        )
-
     def _prepend_dense_translation(
         self, body: PsBlock, ispace: FullIterationSpace
     ) -> PsBlock:
diff --git a/src/pystencils/backend/transformations/select_functions.py b/src/pystencils/backend/transformations/select_functions.py
index 0045de87b7b5a4203430fd74a3641ca831826419..288650698f29d2a290fc2f5a5f7ae6c2ab0206bc 100644
--- a/src/pystencils/backend/transformations/select_functions.py
+++ b/src/pystencils/backend/transformations/select_functions.py
@@ -1,6 +1,8 @@
-from ..platforms import Platform
+from ..ast.structural import PsStatement, PsAssignment, PsBlock
+from ..exceptions import MaterializationError
+from ..platforms import Platform, CudaPlatform
 from ..ast import PsAstNode
-from ..ast.expressions import PsCall
+from ..ast.expressions import PsCall, PsExpression
 from ..functions import PsMathFunction, PsReductionFunction


@@ -17,9 +19,31 @@ class SelectFunctions:
     def visit(self, node: PsAstNode) -> PsAstNode:
         node.children = [self.visit(c) for c in node.children]

-        if isinstance(node, PsCall) and isinstance(node.function, PsMathFunction):
-            return self._platform.select_function(node)
-        elif isinstance(node, PsCall) and isinstance(node.function, PsReductionFunction):
-            return self._platform.unfold_function(node)
+        if isinstance(node, PsAssignment):
+            rhs = node.rhs
+            if isinstance(rhs, PsCall) and isinstance(rhs.function, PsReductionFunction):
+                resolved_func = self._platform.select_function(rhs)
+
+                match resolved_func:
+                    case (prepend, expr):
+                        match self._platform:
+                            case CudaPlatform():
+                                # special case: produces statement with atomic operation writing value back to ptr
+                                return PsBlock(prepend + [PsStatement(expr)])
+                            case _:
+                                return PsBlock(prepend + [PsAssignment(node.lhs, expr)])
+                    case PsExpression():
+                        return PsAssignment(node.lhs, resolved_func)
+                    case _:
+                        raise MaterializationError(
+                            f"Wrong return type for resolved function {rhs.function.name} in SelectFunctions."
+                        )
+            else:
+                return node
+        elif isinstance(node, PsCall) and isinstance(node.function, PsMathFunction):
+            resolved_func = self._platform.select_function(node)
+            assert isinstance(resolved_func, PsExpression)
+
+            return resolved_func
         else:
             return node
diff --git a/src/pystencils/codegen/driver.py b/src/pystencils/codegen/driver.py
index 96e9b94edf040f9af36a727e4bb70874c7ce5ad5..9f04d074a60fcc038acd54d4443977945d8a2a7c 100644
--- a/src/pystencils/codegen/driver.py
+++ b/src/pystencils/codegen/driver.py
@@ -16,7 +16,7 @@ from .kernel import Kernel, GpuKernel
 from .properties import PsSymbolProperty, FieldBasePtr
 from .parameters import Parameter
 from ..backend.functions import PsReductionFunction, ReductionFunctions
-from ..backend.ast.expressions import PsSymbolExpr, PsCall
+from ..backend.ast.expressions import PsSymbolExpr, PsCall, PsMemAcc, PsConstantExpr
 from .gpu_indexing import GpuIndexing, GpuLaunchConfiguration

 from ..field import Field
@@ -24,7 +24,7 @@ from ..types import PsIntegerType, PsScalarType

 from ..backend.memory import PsSymbol
 from ..backend.ast import PsAstNode
-from ..backend.ast.structural import PsBlock, PsLoop
+from ..backend.ast.structural import PsBlock, PsLoop, PsDeclaration, PsAssignment
 from ..backend.ast.analysis import collect_undefined_symbols, collect_required_headers
 from ..backend.kernelcreation import (
     KernelCreationContext,
@@ -187,16 +187,16 @@ class DefaultKernelCreationDriver:
             symbol_expr = typify(PsSymbolExpr(symbol))
             ptr_symbol_expr = typify(PsSymbolExpr(reduction_info.ptr_symbol))
             init_val = typify(reduction_info.init_val)
+            ptr_access = PsMemAcc(ptr_symbol_expr, PsConstantExpr(PsConstant(0, self._ctx.index_dtype)))

-            init_local_copy = PsCall(PsReductionFunction(ReductionFunctions.InitLocalCopy, reduction_info.op),
-                                     [symbol_expr, init_val])
+            decl_local_copy = PsDeclaration(symbol_expr, init_val)
             write_back_ptr = PsCall(PsReductionFunction(ReductionFunctions.WriteBackToPtr, reduction_info.op),
                                     [ptr_symbol_expr, symbol_expr])

-            # Init local reduction variable copy
-            kernel_ast.statements = [init_local_copy] + kernel_ast.statements
-            # Write back result to reduction target variable
-            kernel_ast.statements += [write_back_ptr]
+            prepend_ast = [decl_local_copy]  # declare and init local copy with neutral element
+            append_ast = [PsAssignment(ptr_access, write_back_ptr)]  # write back result to reduction target variable
+
+            kernel_ast.statements = prepend_ast + kernel_ast.statements + append_ast

         # Target-Specific optimizations
         if self._target.is_cpu():
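
For reference, and not as part of the patch itself: the warp-level reduction that the new CudaPlatform.select_function assembles out of backend AST nodes corresponds roughly to the following handwritten CUDA, assuming a 32-thread warp, a float64 sum reduction, and an already linearized thread index; all identifiers below are illustrative only.

    // Sketch of the generated pattern: butterfly warp reduction followed by an
    // atomic write-back performed only by the first lane of each warp.
    __device__ void write_back_to_ptr(double *result, double local_sum, bool is_valid_thread, int tid) {
        // XOR shuffles with offsets 16, 8, 4, 2, 1 combine the per-thread partial results
        for (int offset = 16; offset > 0; offset /= 2) {
            local_sum += __shfl_xor_sync(0xffffffffu, local_sum, offset);
        }
        // only the first thread of each warp commits the warp-wide result
        if (is_valid_thread && tid % 32 == 0) {
            atomicAdd(result, local_sum);  // double-precision atomicAdd needs compute capability >= 6.0
        }
    }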