Commit 02be4d5e authored by Richard Angersbach

Fix typecheck

parent 89a6f36a
1 merge request: !438 Reduction Support
Pipeline #74416 failed
@@ -36,8 +36,8 @@ from ..ast.expressions import (
from ..ast.expressions import PsLt, PsAnd
from ...types import PsSignedIntegerType, PsIeeeFloatType
from ..literals import PsLiteral
from ..functions import MathFunctions, CFunction, ReductionFunctions, NumericLimitsFunctions
from ..functions import MathFunctions, CFunction, ReductionFunctions, NumericLimitsFunctions, PsReductionFunction, \
PsMathFunction
int32 = PsSignedIntegerType(width=32, const=False)
@@ -209,65 +209,66 @@ class CudaPlatform(GenericGpu):
else:
raise MaterializationError(f"Unknown type of iteration space: {ispace}")
def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsAstNode], PsExpression]:
func = call.function.func
def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsAstNode, ...], PsAstNode]:
call_func = call.function
assert isinstance(call_func, PsReductionFunction | PsMathFunction)
if func in ReductionFunctions:
match func:
case ReductionFunctions.WriteBackToPtr:
ptr_expr, symbol_expr = call.args
op = call.function.reduction_op
stype = symbol_expr.dtype
ptrtype = ptr_expr.dtype
warp_size = 32 # TODO: get from platform/user config
assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptrtype, PsPointerType)
assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(stype, PsScalarType)
if not isinstance(stype, PsIeeeFloatType) or stype.width not in (32, 64):
raise NotImplementedError("atomic operations are only available for float32/64 datatypes")
# workaround for subtractions -> use additions for reducing intermediate results
# similar to OpenMP reductions: local copies (negative sign) are added at the end
match op:
case ReductionOp.Sub:
actual_op = ReductionOp.Add
case _:
actual_op = op
# perform local warp reductions
def gen_shuffle_instr(offset: int):
full_mask = PsLiteralExpr(PsLiteral("0xffffffff", UInt(32)))
return PsCall(CFunction("__shfl_xor_sync", [UInt(32), stype, SInt(32)], stype),
[full_mask, symbol_expr, PsConstantExpr(PsConstant(offset, SInt(32)))])
num_shuffles = math.frexp(warp_size)[1]
shuffles = [PsAssignment(symbol_expr,
compound_op_to_expr(actual_op,
symbol_expr, gen_shuffle_instr(pow(2, i - 1))))
for i in reversed(range(1, num_shuffles))]
# find first thread in warp
ispace = self._ctx.get_iteration_space()
is_valid_thread = self._get_condition_for_translation(ispace)
thread_indices_per_dim = [
idx * PsConstantExpr(PsConstant(reduce(operator.mul, BLOCK_DIM[:i], 1), SInt(32)))
for i, idx in enumerate(THREAD_IDX[:ispace.rank])
]
tid: PsExpression = thread_indices_per_dim[0]
for t in thread_indices_per_dim[1:]:
tid = PsAdd(tid, t)
first_thread_in_warp = PsEq(PsRem(tid, PsConstantExpr(PsConstant(warp_size, SInt(32)))),
PsConstantExpr(PsConstant(0, SInt(32))))
cond = PsAnd(is_valid_thread, first_thread_in_warp) if is_valid_thread else first_thread_in_warp
# use atomic operation on first thread of warp to sync
call.function = CFunction(f"atomic{actual_op.name}", [ptrtype, stype], PsCustomType("void"))
call.args = (ptr_expr, symbol_expr)
# assemble warp reduction
return (shuffles, PsConditional(cond, PsBlock([PsStatement(call)])))
func = call_func.func
if isinstance(call_func, PsReductionFunction) and func is ReductionFunctions.WriteBackToPtr:
ptr_expr, symbol_expr = call.args
op = call_func.reduction_op
stype = symbol_expr.dtype
ptrtype = ptr_expr.dtype
warp_size = 32 # TODO: get from platform/user config
assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptrtype, PsPointerType)
assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(stype, PsScalarType)
if not isinstance(stype, PsIeeeFloatType) or stype.width not in (32, 64):
raise NotImplementedError("atomic operations are only available for float32/64 datatypes")
# workaround for subtractions -> use additions for reducing intermediate results
# similar to OpenMP reductions: local copies (negative sign) are added at the end
match op:
case ReductionOp.Sub:
actual_op = ReductionOp.Add
case _:
actual_op = op
# perform local warp reductions
def gen_shuffle_instr(offset: int):
full_mask = PsLiteralExpr(PsLiteral("0xffffffff", UInt(32)))
return PsCall(CFunction("__shfl_xor_sync", [UInt(32), stype, SInt(32)], stype),
[full_mask, symbol_expr, PsConstantExpr(PsConstant(offset, SInt(32)))])
num_shuffles = math.frexp(warp_size)[1]
shuffles = tuple(PsAssignment(symbol_expr,
compound_op_to_expr(actual_op,
symbol_expr, gen_shuffle_instr(pow(2, i - 1))))
for i in reversed(range(1, num_shuffles)))
# find first thread in warp
ispace = self._ctx.get_iteration_space()
is_valid_thread = self._get_condition_for_translation(ispace)
thread_indices_per_dim = [
idx * PsConstantExpr(PsConstant(reduce(operator.mul, BLOCK_DIM[:i], 1), SInt(32)))
for i, idx in enumerate(THREAD_IDX[:ispace.rank])
]
tid: PsExpression = thread_indices_per_dim[0]
for t in thread_indices_per_dim[1:]:
tid = PsAdd(tid, t)
first_thread_in_warp = PsEq(PsRem(tid, PsConstantExpr(PsConstant(warp_size, SInt(32)))),
PsConstantExpr(PsConstant(0, SInt(32))))
cond = PsAnd(is_valid_thread, first_thread_in_warp) if is_valid_thread else first_thread_in_warp
# use atomic operation on first thread of warp to sync
call.function = CFunction(f"atomic{actual_op.name}", [ptrtype, stype], PsCustomType("void"))
call.args = (ptr_expr, symbol_expr)
# assemble warp reduction
return shuffles, PsConditional(cond, PsBlock([PsStatement(call)]))
dtype = call.get_dtype()
arg_types = (dtype,) * func.num_args
...
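As an aside, the warp-level pattern that the generated `__shfl_xor_sync` calls implement can be sketched in plain Python: every lane repeatedly combines its value with the lane whose index differs by a power-of-two offset (16, 8, 4, 2, 1 for a warp size of 32), after which each lane holds the full warp result and only the first lane of the warp performs the atomic write-back. The sketch below is illustrative only and not part of this commit; `warp_reduce` and the lane list are hypothetical.

```python
import operator

def warp_reduce(lane_values, combine=operator.add, warp_size=32):
    # Butterfly reduction: at each step, lane i combines with lane i ^ offset,
    # mirroring the __shfl_xor_sync calls emitted above.
    vals = list(lane_values)
    offset = warp_size // 2
    while offset >= 1:
        vals = [combine(vals[i], vals[i ^ offset]) for i in range(warp_size)]
        offset //= 2
    return vals[0]  # after log2(warp_size) steps every lane holds the result

lanes = [float(i) for i in range(32)]
assert warp_reduce(lanes) == sum(lanes)  # 496.0
```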
@@ -4,7 +4,8 @@ from typing import Sequence
from ..ast.expressions import PsCall, PsMemAcc, PsConstantExpr
from ..ast import PsAstNode
from ..functions import CFunction, MathFunctions, NumericLimitsFunctions, ReductionFunctions
from ..functions import CFunction, MathFunctions, NumericLimitsFunctions, ReductionFunctions, PsMathFunction, \
PsReductionFunction
from ..literals import PsLiteral
from ...compound_op_mapping import compound_op_to_expr
from ...sympyextensions import ReductionOp
@@ -59,30 +60,31 @@ class GenericCpu(Platform):
else:
raise MaterializationError(f"Unknown type of iteration space: {ispace}")
def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsAstNode], PsExpression]:
func = call.function.func
def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsAstNode, ...], PsAstNode]:
call_func = call.function
assert isinstance(call_func, PsReductionFunction | PsMathFunction)
if func in ReductionFunctions:
match func:
case ReductionFunctions.WriteBackToPtr:
ptr_expr, symbol_expr = call.args
op = call.function.reduction_op
func = call_func.func
if isinstance(call_func, PsReductionFunction) and func is ReductionFunctions.WriteBackToPtr:
ptr_expr, symbol_expr = call.args
op = call_func.reduction_op
assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptr_expr.dtype, PsPointerType)
assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(symbol_expr.dtype, PsScalarType)
assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptr_expr.dtype, PsPointerType)
assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(symbol_expr.dtype, PsScalarType)
ptr_access = PsMemAcc(ptr_expr, PsConstantExpr(PsConstant(0, self._ctx.index_dtype)))
ptr_access = PsMemAcc(ptr_expr, PsConstantExpr(PsConstant(0, self._ctx.index_dtype)))
# inspired by OpenMP: local reduction variable (negative sign) is added at the end
actual_op = ReductionOp.Add if op is ReductionOp.Sub else op
# inspired by OpenMP: local reduction variable (negative sign) is added at the end
actual_op = ReductionOp.Add if op is ReductionOp.Sub else op
# TODO: can this be avoided somehow?
potential_call = compound_op_to_expr(actual_op, ptr_access, symbol_expr)
if isinstance(potential_call, PsCall):
potential_call.dtype = symbol_expr.dtype
potential_call = self.select_function(potential_call)
# create binop and potentially select corresponding function for e.g. min or max
potential_call = compound_op_to_expr(actual_op, ptr_access, symbol_expr)
if isinstance(potential_call, PsCall):
potential_call.dtype = symbol_expr.dtype
return self.select_function(potential_call)
return potential_call
return potential_call
dtype = call.get_dtype()
arg_types = (dtype,) * func.num_args
...
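The subtraction workaround mentioned in the comments is easiest to see with a small worked example: for a `Sub` reduction the thread-local copies already carry the negative sign, so adding them to the target value gives the same result as subtracting the contributions one by one. This is a purely illustrative sketch, not code from the commit.

```python
initial, contributions = 100.0, [3.0, 5.0, 7.0]

# Serial subtraction of each contribution from the target value.
serial = initial
for c in contributions:
    serial -= c

# OpenMP-style workaround: local copies hold the negated contributions and
# are combined with the target via addition, which is what Sub is mapped to.
local_copies = [-c for c in contributions]
parallel = initial + sum(local_copies)

assert serial == parallel == 85.0
```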
@@ -38,7 +38,7 @@ class Platform(ABC):
@abstractmethod
def select_function(
self, call: PsCall
) -> PsExpression | tuple[tuple[PsAstNode, ...], PsExpression]:
) -> PsExpression | tuple[tuple[PsAstNode, ...], PsAstNode]:
"""Select an implementation for the given function on the given data type.
If no viable implementation exists, raise a `MaterializationError`.
...
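The only change in the base class is the annotation: the second element of the returned pair may now be any AST node rather than an expression, since the CUDA path above returns a `PsConditional` wrapping the atomic write-back. A rough sketch of the widened contract, using hypothetical stub classes instead of the real pystencils types:

```python
# Hypothetical stubs standing in for PsAstNode / PsExpression (illustration only).
class AstNodeStub: ...
class ExpressionStub(AstNodeStub): ...

# Old: ExpressionStub | tuple[tuple[AstNodeStub, ...], ExpressionStub]
# New: the trailing element may be any AST node, e.g. a conditional statement.
SelectFunctionResult = ExpressionStub | tuple[tuple[AstNodeStub, ...], AstNodeStub]
```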
from ..ast.structural import PsStatement, PsAssignment, PsBlock
from ..ast.structural import PsAssignment, PsBlock
from ..exceptions import MaterializationError
from ..platforms import Platform, CudaPlatform
from ..platforms import Platform
from ..ast import PsAstNode
from ..ast.expressions import PsCall, PsExpression
from ..functions import PsMathFunction, PsReductionFunction
@@ -25,13 +25,17 @@ class SelectFunctions:
resolved_func = self._platform.select_function(rhs)
match resolved_func:
case ((prepend), expr):
match self._platform:
case CudaPlatform():
# special case: produces statement with atomic operation writing value back to ptr
return PsBlock(prepend + [PsStatement(expr)])
case (prepend, new_rhs):
assert isinstance(prepend, tuple)
match new_rhs:
case PsExpression():
return PsBlock(prepend + (PsAssignment(node.lhs, new_rhs),))
case PsAstNode():
# special case: produces a structural node with an atomic operation writing the value back to ptr
return PsBlock(prepend + (new_rhs,))
case _:
return PsBlock(prepend + [PsAssignment(node.lhs, expr)])
assert False, "Unexpected output from SelectFunctions."
case PsExpression():
return PsAssignment(node.lhs, resolved_func)
case _:
...
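With the widened contract, the consumer no longer needs to special-case the CUDA platform; it can dispatch purely on the shape and type of the value returned by `select_function`. A minimal, hypothetical sketch of that dispatch (stub classes only, not the real pystencils API):

```python
class Node: ...                       # stand-in for PsAstNode
class Expr(Node): ...                 # stand-in for PsExpression
class Assign(Node):                   # stand-in for PsAssignment
    def __init__(self, lhs, rhs): self.lhs, self.rhs = lhs, rhs
class Block(Node):                    # stand-in for PsBlock
    def __init__(self, stmts): self.stmts = list(stmts)

def materialize(lhs: Node, resolved) -> Node:
    match resolved:
        case Expr():
            return Assign(lhs, resolved)                 # plain replacement expression
        case (prepend, Expr() as rhs):
            return Block([*prepend, Assign(lhs, rhs)])   # prepend statements, then assign
        case (prepend, Node() as stmt):
            return Block([*prepend, stmt])               # e.g. conditional atomic write-back
        case _:
            raise AssertionError("Unexpected output from select_function")
```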
@@ -187,16 +187,16 @@ class DefaultKernelCreationDriver:
symbol_expr = typify(PsSymbolExpr(symbol))
ptr_symbol_expr = typify(PsSymbolExpr(reduction_info.ptr_symbol))
init_val = typify(reduction_info.init_val)
ptr_access = PsMemAcc(ptr_symbol_expr, PsConstantExpr(PsConstant(0, self._ctx.index_dtype)))
decl_local_copy = PsDeclaration(symbol_expr, init_val)
ptr_access = PsMemAcc(ptr_symbol_expr, PsConstantExpr(PsConstant(0, self._ctx.index_dtype)))
write_back_ptr = PsCall(PsReductionFunction(ReductionFunctions.WriteBackToPtr, reduction_info.op),
[ptr_symbol_expr, symbol_expr])
prepend_ast = [decl_local_copy] # declare and init local copy with neutral element
prepend_ast = [PsDeclaration(symbol_expr, init_val)] # declare and init local copy with neutral element
append_ast = [PsAssignment(ptr_access, write_back_ptr)] # write back result to reduction target variable
kernel_ast.statements = prepend_ast + kernel_ast.statements + append_ast
kernel_ast.statements = prepend_ast + kernel_ast.statements
kernel_ast.statements += append_ast
# Target-Specific optimizations
if self._target.is_cpu():
...
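Schematically, the driver now brackets the kernel body between the declaration of the local reduction copy and the write-back assignment. The sketch below uses plain placeholder strings instead of AST nodes and is not code from the commit.

```python
body = ["<original kernel statements>"]

prepend_ast = ["local_copy = <neutral element>"]      # PsDeclaration(symbol_expr, init_val)
append_ast = ["*ptr = write_back(ptr, local_copy)"]   # PsAssignment(ptr_access, write_back_ptr)

statements = prepend_ast + body
statements += append_ast

assert statements[0].startswith("local_copy")
assert statements[-1].startswith("*ptr")
```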