From 02be4d5e994e458e0e42d72045ec53355e7820d1 Mon Sep 17 00:00:00 2001 From: zy69guqi <richard.angersbach@fau.de> Date: Fri, 21 Feb 2025 19:35:48 +0100 Subject: [PATCH] Fix typecheck --- src/pystencils/backend/platforms/cuda.py | 121 +++++++++--------- .../backend/platforms/generic_cpu.py | 40 +++--- src/pystencils/backend/platforms/platform.py | 2 +- .../transformations/select_functions.py | 20 +-- src/pystencils/codegen/driver.py | 8 +- 5 files changed, 99 insertions(+), 92 deletions(-) diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py index 6df502c1f..291858810 100644 --- a/src/pystencils/backend/platforms/cuda.py +++ b/src/pystencils/backend/platforms/cuda.py @@ -36,8 +36,8 @@ from ..ast.expressions import ( from ..ast.expressions import PsLt, PsAnd from ...types import PsSignedIntegerType, PsIeeeFloatType from ..literals import PsLiteral -from ..functions import MathFunctions, CFunction, ReductionFunctions, NumericLimitsFunctions - +from ..functions import MathFunctions, CFunction, ReductionFunctions, NumericLimitsFunctions, PsReductionFunction, \ + PsMathFunction int32 = PsSignedIntegerType(width=32, const=False) @@ -209,65 +209,66 @@ class CudaPlatform(GenericGpu): else: raise MaterializationError(f"Unknown type of iteration space: {ispace}") - def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsAstNode], PsExpression]: - func = call.function.func + def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsAstNode, ...], PsAstNode]: + call_func = call.function + assert isinstance(call_func, PsReductionFunction | PsMathFunction) - if func in ReductionFunctions: - match func: - case ReductionFunctions.WriteBackToPtr: - ptr_expr, symbol_expr = call.args - op = call.function.reduction_op - stype = symbol_expr.dtype - ptrtype = ptr_expr.dtype - - warp_size = 32 # TODO: get from platform/user config - - assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptrtype, PsPointerType) - assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(stype, PsScalarType) - - if not isinstance(stype, PsIeeeFloatType) or stype.width not in (32, 64): - NotImplementedError("atomic operations are only available for float32/64 datatypes") - - # workaround for subtractions -> use additions for reducing intermediate results - # similar to OpenMP reductions: local copies (negative sign) are added at the end - match op: - case ReductionOp.Sub: - actual_op = ReductionOp.Add - case _: - actual_op = op - - # perform local warp reductions - def gen_shuffle_instr(offset: int): - full_mask = PsLiteralExpr(PsLiteral("0xffffffff", UInt(32))) - return PsCall(CFunction("__shfl_xor_sync", [UInt(32), stype, SInt(32)], stype), - [full_mask, symbol_expr, PsConstantExpr(PsConstant(offset, SInt(32)))]) - - num_shuffles = math.frexp(warp_size)[1] - shuffles = [PsAssignment(symbol_expr, - compound_op_to_expr(actual_op, - symbol_expr, gen_shuffle_instr(pow(2, i - 1)))) - for i in reversed(range(1, num_shuffles))] - - # find first thread in warp - ispace = self._ctx.get_iteration_space() - is_valid_thread = self._get_condition_for_translation(ispace) - thread_indices_per_dim = [ - idx * PsConstantExpr(PsConstant(reduce(operator.mul, BLOCK_DIM[:i], 1), SInt(32))) - for i, idx in enumerate(THREAD_IDX[:ispace.rank]) - ] - tid: PsExpression = thread_indices_per_dim[0] - for t in thread_indices_per_dim[1:]: - tid = PsAdd(tid, t) - first_thread_in_warp = PsEq(PsRem(tid, PsConstantExpr(PsConstant(warp_size, SInt(32)))), - PsConstantExpr(PsConstant(0, SInt(32)))) - cond = PsAnd(is_valid_thread, first_thread_in_warp) if is_valid_thread else first_thread_in_warp - - # use atomic operation on first thread of warp to sync - call.function = CFunction(f"atomic{actual_op.name}", [ptrtype, stype], PsCustomType("void")) - call.args = (ptr_expr, symbol_expr) - - # assemble warp reduction - return (shuffles, PsConditional(cond, PsBlock([PsStatement(call)]))) + func = call_func.func + + if isinstance(call_func, PsReductionFunction) and func is ReductionFunctions.WriteBackToPtr: + ptr_expr, symbol_expr = call.args + op = call_func.reduction_op + stype = symbol_expr.dtype + ptrtype = ptr_expr.dtype + + warp_size = 32 # TODO: get from platform/user config + + assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptrtype, PsPointerType) + assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(stype, PsScalarType) + + if not isinstance(stype, PsIeeeFloatType) or stype.width not in (32, 64): + NotImplementedError("atomic operations are only available for float32/64 datatypes") + + # workaround for subtractions -> use additions for reducing intermediate results + # similar to OpenMP reductions: local copies (negative sign) are added at the end + match op: + case ReductionOp.Sub: + actual_op = ReductionOp.Add + case _: + actual_op = op + + # perform local warp reductions + def gen_shuffle_instr(offset: int): + full_mask = PsLiteralExpr(PsLiteral("0xffffffff", UInt(32))) + return PsCall(CFunction("__shfl_xor_sync", [UInt(32), stype, SInt(32)], stype), + [full_mask, symbol_expr, PsConstantExpr(PsConstant(offset, SInt(32)))]) + + num_shuffles = math.frexp(warp_size)[1] + shuffles = tuple(PsAssignment(symbol_expr, + compound_op_to_expr(actual_op, + symbol_expr, gen_shuffle_instr(pow(2, i - 1)))) + for i in reversed(range(1, num_shuffles))) + + # find first thread in warp + ispace = self._ctx.get_iteration_space() + is_valid_thread = self._get_condition_for_translation(ispace) + thread_indices_per_dim = [ + idx * PsConstantExpr(PsConstant(reduce(operator.mul, BLOCK_DIM[:i], 1), SInt(32))) + for i, idx in enumerate(THREAD_IDX[:ispace.rank]) + ] + tid: PsExpression = thread_indices_per_dim[0] + for t in thread_indices_per_dim[1:]: + tid = PsAdd(tid, t) + first_thread_in_warp = PsEq(PsRem(tid, PsConstantExpr(PsConstant(warp_size, SInt(32)))), + PsConstantExpr(PsConstant(0, SInt(32)))) + cond = PsAnd(is_valid_thread, first_thread_in_warp) if is_valid_thread else first_thread_in_warp + + # use atomic operation on first thread of warp to sync + call.function = CFunction(f"atomic{actual_op.name}", [ptrtype, stype], PsCustomType("void")) + call.args = (ptr_expr, symbol_expr) + + # assemble warp reduction + return shuffles, PsConditional(cond, PsBlock([PsStatement(call)])) dtype = call.get_dtype() arg_types = (dtype,) * func.num_args diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py index 3ffdfa22f..2f873ff29 100644 --- a/src/pystencils/backend/platforms/generic_cpu.py +++ b/src/pystencils/backend/platforms/generic_cpu.py @@ -4,7 +4,8 @@ from typing import Sequence from ..ast.expressions import PsCall, PsMemAcc, PsConstantExpr from ..ast import PsAstNode -from ..functions import CFunction, MathFunctions, NumericLimitsFunctions, ReductionFunctions +from ..functions import CFunction, MathFunctions, NumericLimitsFunctions, ReductionFunctions, PsMathFunction, \ + PsReductionFunction from ..literals import PsLiteral from ...compound_op_mapping import compound_op_to_expr from ...sympyextensions import ReductionOp @@ -59,30 +60,31 @@ class GenericCpu(Platform): else: raise MaterializationError(f"Unknown type of iteration space: {ispace}") - def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsAstNode], PsExpression]: - func = call.function.func + def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsAstNode, ...], PsAstNode]: + call_func = call.function + assert isinstance(call_func, PsReductionFunction | PsMathFunction) - if func in ReductionFunctions: - match func: - case ReductionFunctions.WriteBackToPtr: - ptr_expr, symbol_expr = call.args - op = call.function.reduction_op + func = call_func.func + + if isinstance(call_func, PsReductionFunction) and func is ReductionFunctions.WriteBackToPtr: + ptr_expr, symbol_expr = call.args + op = call_func.reduction_op - assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptr_expr.dtype, PsPointerType) - assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(symbol_expr.dtype, PsScalarType) + assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptr_expr.dtype, PsPointerType) + assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(symbol_expr.dtype, PsScalarType) - ptr_access = PsMemAcc(ptr_expr, PsConstantExpr(PsConstant(0, self._ctx.index_dtype))) + ptr_access = PsMemAcc(ptr_expr, PsConstantExpr(PsConstant(0, self._ctx.index_dtype))) - # inspired by OpenMP: local reduction variable (negative sign) is added at the end - actual_op = ReductionOp.Add if op is ReductionOp.Sub else op + # inspired by OpenMP: local reduction variable (negative sign) is added at the end + actual_op = ReductionOp.Add if op is ReductionOp.Sub else op - # TODO: can this be avoided somehow? - potential_call = compound_op_to_expr(actual_op, ptr_access, symbol_expr) - if isinstance(potential_call, PsCall): - potential_call.dtype = symbol_expr.dtype - potential_call = self.select_function(potential_call) + # create binop and potentially select corresponding function for e.g. min or max + potential_call = compound_op_to_expr(actual_op, ptr_access, symbol_expr) + if isinstance(potential_call, PsCall): + potential_call.dtype = symbol_expr.dtype + return self.select_function(potential_call) - return potential_call + return potential_call dtype = call.get_dtype() arg_types = (dtype,) * func.num_args diff --git a/src/pystencils/backend/platforms/platform.py b/src/pystencils/backend/platforms/platform.py index 90fd69084..437962172 100644 --- a/src/pystencils/backend/platforms/platform.py +++ b/src/pystencils/backend/platforms/platform.py @@ -38,7 +38,7 @@ class Platform(ABC): @abstractmethod def select_function( self, call: PsCall - ) -> PsExpression | tuple[tuple[PsAstNode, ...], PsExpression]: + ) -> PsExpression | tuple[tuple[PsAstNode, ...], PsAstNode]: """Select an implementation for the given function on the given data type. If no viable implementation exists, raise a `MaterializationError`. diff --git a/src/pystencils/backend/transformations/select_functions.py b/src/pystencils/backend/transformations/select_functions.py index 288650698..d5f731653 100644 --- a/src/pystencils/backend/transformations/select_functions.py +++ b/src/pystencils/backend/transformations/select_functions.py @@ -1,6 +1,6 @@ -from ..ast.structural import PsStatement, PsAssignment, PsBlock +from ..ast.structural import PsAssignment, PsBlock from ..exceptions import MaterializationError -from ..platforms import Platform, CudaPlatform +from ..platforms import Platform from ..ast import PsAstNode from ..ast.expressions import PsCall, PsExpression from ..functions import PsMathFunction, PsReductionFunction @@ -25,13 +25,17 @@ class SelectFunctions: resolved_func = self._platform.select_function(rhs) match resolved_func: - case ((prepend), expr): - match self._platform: - case CudaPlatform(): - # special case: produces statement with atomic operation writing value back to ptr - return PsBlock(prepend + [PsStatement(expr)]) + case (prepend, new_rhs): + assert isinstance(prepend, tuple) + + match new_rhs: + case PsExpression(): + return PsBlock(prepend + (PsAssignment(node.lhs, new_rhs),)) + case PsAstNode(): + # special case: produces structural with atomic operation writing value back to ptr + return PsBlock(prepend + (new_rhs,)) case _: - return PsBlock(prepend + [PsAssignment(node.lhs, expr)]) + assert False, "Unexpected output from SelectFunctions." case PsExpression(): return PsAssignment(node.lhs, resolved_func) case _: diff --git a/src/pystencils/codegen/driver.py b/src/pystencils/codegen/driver.py index 9f04d074a..cc3411249 100644 --- a/src/pystencils/codegen/driver.py +++ b/src/pystencils/codegen/driver.py @@ -187,16 +187,16 @@ class DefaultKernelCreationDriver: symbol_expr = typify(PsSymbolExpr(symbol)) ptr_symbol_expr = typify(PsSymbolExpr(reduction_info.ptr_symbol)) init_val = typify(reduction_info.init_val) - ptr_access = PsMemAcc(ptr_symbol_expr, PsConstantExpr(PsConstant(0, self._ctx.index_dtype))) - decl_local_copy = PsDeclaration(symbol_expr, init_val) + ptr_access = PsMemAcc(ptr_symbol_expr, PsConstantExpr(PsConstant(0, self._ctx.index_dtype))) write_back_ptr = PsCall(PsReductionFunction(ReductionFunctions.WriteBackToPtr, reduction_info.op), [ptr_symbol_expr, symbol_expr]) - prepend_ast = [decl_local_copy] # declare and init local copy with neutral element + prepend_ast = [PsDeclaration(symbol_expr, init_val)] # declare and init local copy with neutral element append_ast = [PsAssignment(ptr_access, write_back_ptr)] # write back result to reduction target variable - kernel_ast.statements = prepend_ast + kernel_ast.statements + append_ast + kernel_ast.statements = prepend_ast + kernel_ast.statements + kernel_ast.statements += append_ast # Target-Specific optimizations if self._target.is_cpu(): -- GitLab