Skip to content
Snippets Groups Projects
Commit 02be4d5e authored by Richard Angersbach's avatar Richard Angersbach
Browse files

Fix typecheck

parent 89a6f36a
1 merge request!438Reduction Support
Pipeline #74416 failed with stages
in 8 minutes and 6 seconds
...@@ -36,8 +36,8 @@ from ..ast.expressions import ( ...@@ -36,8 +36,8 @@ from ..ast.expressions import (
from ..ast.expressions import PsLt, PsAnd from ..ast.expressions import PsLt, PsAnd
from ...types import PsSignedIntegerType, PsIeeeFloatType from ...types import PsSignedIntegerType, PsIeeeFloatType
from ..literals import PsLiteral from ..literals import PsLiteral
from ..functions import MathFunctions, CFunction, ReductionFunctions, NumericLimitsFunctions from ..functions import MathFunctions, CFunction, ReductionFunctions, NumericLimitsFunctions, PsReductionFunction, \
PsMathFunction
int32 = PsSignedIntegerType(width=32, const=False) int32 = PsSignedIntegerType(width=32, const=False)
...@@ -209,65 +209,66 @@ class CudaPlatform(GenericGpu): ...@@ -209,65 +209,66 @@ class CudaPlatform(GenericGpu):
else: else:
raise MaterializationError(f"Unknown type of iteration space: {ispace}") raise MaterializationError(f"Unknown type of iteration space: {ispace}")
def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsAstNode], PsExpression]: def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsAstNode, ...], PsAstNode]:
func = call.function.func call_func = call.function
assert isinstance(call_func, PsReductionFunction | PsMathFunction)
if func in ReductionFunctions: func = call_func.func
match func:
case ReductionFunctions.WriteBackToPtr: if isinstance(call_func, PsReductionFunction) and func is ReductionFunctions.WriteBackToPtr:
ptr_expr, symbol_expr = call.args ptr_expr, symbol_expr = call.args
op = call.function.reduction_op op = call_func.reduction_op
stype = symbol_expr.dtype stype = symbol_expr.dtype
ptrtype = ptr_expr.dtype ptrtype = ptr_expr.dtype
warp_size = 32 # TODO: get from platform/user config warp_size = 32 # TODO: get from platform/user config
assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptrtype, PsPointerType) assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptrtype, PsPointerType)
assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(stype, PsScalarType) assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(stype, PsScalarType)
if not isinstance(stype, PsIeeeFloatType) or stype.width not in (32, 64): if not isinstance(stype, PsIeeeFloatType) or stype.width not in (32, 64):
NotImplementedError("atomic operations are only available for float32/64 datatypes") NotImplementedError("atomic operations are only available for float32/64 datatypes")
# workaround for subtractions -> use additions for reducing intermediate results # workaround for subtractions -> use additions for reducing intermediate results
# similar to OpenMP reductions: local copies (negative sign) are added at the end # similar to OpenMP reductions: local copies (negative sign) are added at the end
match op: match op:
case ReductionOp.Sub: case ReductionOp.Sub:
actual_op = ReductionOp.Add actual_op = ReductionOp.Add
case _: case _:
actual_op = op actual_op = op
# perform local warp reductions # perform local warp reductions
def gen_shuffle_instr(offset: int): def gen_shuffle_instr(offset: int):
full_mask = PsLiteralExpr(PsLiteral("0xffffffff", UInt(32))) full_mask = PsLiteralExpr(PsLiteral("0xffffffff", UInt(32)))
return PsCall(CFunction("__shfl_xor_sync", [UInt(32), stype, SInt(32)], stype), return PsCall(CFunction("__shfl_xor_sync", [UInt(32), stype, SInt(32)], stype),
[full_mask, symbol_expr, PsConstantExpr(PsConstant(offset, SInt(32)))]) [full_mask, symbol_expr, PsConstantExpr(PsConstant(offset, SInt(32)))])
num_shuffles = math.frexp(warp_size)[1] num_shuffles = math.frexp(warp_size)[1]
shuffles = [PsAssignment(symbol_expr, shuffles = tuple(PsAssignment(symbol_expr,
compound_op_to_expr(actual_op, compound_op_to_expr(actual_op,
symbol_expr, gen_shuffle_instr(pow(2, i - 1)))) symbol_expr, gen_shuffle_instr(pow(2, i - 1))))
for i in reversed(range(1, num_shuffles))] for i in reversed(range(1, num_shuffles)))
# find first thread in warp # find first thread in warp
ispace = self._ctx.get_iteration_space() ispace = self._ctx.get_iteration_space()
is_valid_thread = self._get_condition_for_translation(ispace) is_valid_thread = self._get_condition_for_translation(ispace)
thread_indices_per_dim = [ thread_indices_per_dim = [
idx * PsConstantExpr(PsConstant(reduce(operator.mul, BLOCK_DIM[:i], 1), SInt(32))) idx * PsConstantExpr(PsConstant(reduce(operator.mul, BLOCK_DIM[:i], 1), SInt(32)))
for i, idx in enumerate(THREAD_IDX[:ispace.rank]) for i, idx in enumerate(THREAD_IDX[:ispace.rank])
] ]
tid: PsExpression = thread_indices_per_dim[0] tid: PsExpression = thread_indices_per_dim[0]
for t in thread_indices_per_dim[1:]: for t in thread_indices_per_dim[1:]:
tid = PsAdd(tid, t) tid = PsAdd(tid, t)
first_thread_in_warp = PsEq(PsRem(tid, PsConstantExpr(PsConstant(warp_size, SInt(32)))), first_thread_in_warp = PsEq(PsRem(tid, PsConstantExpr(PsConstant(warp_size, SInt(32)))),
PsConstantExpr(PsConstant(0, SInt(32)))) PsConstantExpr(PsConstant(0, SInt(32))))
cond = PsAnd(is_valid_thread, first_thread_in_warp) if is_valid_thread else first_thread_in_warp cond = PsAnd(is_valid_thread, first_thread_in_warp) if is_valid_thread else first_thread_in_warp
# use atomic operation on first thread of warp to sync # use atomic operation on first thread of warp to sync
call.function = CFunction(f"atomic{actual_op.name}", [ptrtype, stype], PsCustomType("void")) call.function = CFunction(f"atomic{actual_op.name}", [ptrtype, stype], PsCustomType("void"))
call.args = (ptr_expr, symbol_expr) call.args = (ptr_expr, symbol_expr)
# assemble warp reduction # assemble warp reduction
return (shuffles, PsConditional(cond, PsBlock([PsStatement(call)]))) return shuffles, PsConditional(cond, PsBlock([PsStatement(call)]))
dtype = call.get_dtype() dtype = call.get_dtype()
arg_types = (dtype,) * func.num_args arg_types = (dtype,) * func.num_args
......
...@@ -4,7 +4,8 @@ from typing import Sequence ...@@ -4,7 +4,8 @@ from typing import Sequence
from ..ast.expressions import PsCall, PsMemAcc, PsConstantExpr from ..ast.expressions import PsCall, PsMemAcc, PsConstantExpr
from ..ast import PsAstNode from ..ast import PsAstNode
from ..functions import CFunction, MathFunctions, NumericLimitsFunctions, ReductionFunctions from ..functions import CFunction, MathFunctions, NumericLimitsFunctions, ReductionFunctions, PsMathFunction, \
PsReductionFunction
from ..literals import PsLiteral from ..literals import PsLiteral
from ...compound_op_mapping import compound_op_to_expr from ...compound_op_mapping import compound_op_to_expr
from ...sympyextensions import ReductionOp from ...sympyextensions import ReductionOp
...@@ -59,30 +60,31 @@ class GenericCpu(Platform): ...@@ -59,30 +60,31 @@ class GenericCpu(Platform):
else: else:
raise MaterializationError(f"Unknown type of iteration space: {ispace}") raise MaterializationError(f"Unknown type of iteration space: {ispace}")
def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsAstNode], PsExpression]: def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsAstNode, ...], PsAstNode]:
func = call.function.func call_func = call.function
assert isinstance(call_func, PsReductionFunction | PsMathFunction)
if func in ReductionFunctions: func = call_func.func
match func:
case ReductionFunctions.WriteBackToPtr: if isinstance(call_func, PsReductionFunction) and func is ReductionFunctions.WriteBackToPtr:
ptr_expr, symbol_expr = call.args ptr_expr, symbol_expr = call.args
op = call.function.reduction_op op = call_func.reduction_op
assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptr_expr.dtype, PsPointerType) assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptr_expr.dtype, PsPointerType)
assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(symbol_expr.dtype, PsScalarType) assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(symbol_expr.dtype, PsScalarType)
ptr_access = PsMemAcc(ptr_expr, PsConstantExpr(PsConstant(0, self._ctx.index_dtype))) ptr_access = PsMemAcc(ptr_expr, PsConstantExpr(PsConstant(0, self._ctx.index_dtype)))
# inspired by OpenMP: local reduction variable (negative sign) is added at the end # inspired by OpenMP: local reduction variable (negative sign) is added at the end
actual_op = ReductionOp.Add if op is ReductionOp.Sub else op actual_op = ReductionOp.Add if op is ReductionOp.Sub else op
# TODO: can this be avoided somehow? # create binop and potentially select corresponding function for e.g. min or max
potential_call = compound_op_to_expr(actual_op, ptr_access, symbol_expr) potential_call = compound_op_to_expr(actual_op, ptr_access, symbol_expr)
if isinstance(potential_call, PsCall): if isinstance(potential_call, PsCall):
potential_call.dtype = symbol_expr.dtype potential_call.dtype = symbol_expr.dtype
potential_call = self.select_function(potential_call) return self.select_function(potential_call)
return potential_call return potential_call
dtype = call.get_dtype() dtype = call.get_dtype()
arg_types = (dtype,) * func.num_args arg_types = (dtype,) * func.num_args
......
...@@ -38,7 +38,7 @@ class Platform(ABC): ...@@ -38,7 +38,7 @@ class Platform(ABC):
@abstractmethod @abstractmethod
def select_function( def select_function(
self, call: PsCall self, call: PsCall
) -> PsExpression | tuple[tuple[PsAstNode, ...], PsExpression]: ) -> PsExpression | tuple[tuple[PsAstNode, ...], PsAstNode]:
"""Select an implementation for the given function on the given data type. """Select an implementation for the given function on the given data type.
If no viable implementation exists, raise a `MaterializationError`. If no viable implementation exists, raise a `MaterializationError`.
......
from ..ast.structural import PsStatement, PsAssignment, PsBlock from ..ast.structural import PsAssignment, PsBlock
from ..exceptions import MaterializationError from ..exceptions import MaterializationError
from ..platforms import Platform, CudaPlatform from ..platforms import Platform
from ..ast import PsAstNode from ..ast import PsAstNode
from ..ast.expressions import PsCall, PsExpression from ..ast.expressions import PsCall, PsExpression
from ..functions import PsMathFunction, PsReductionFunction from ..functions import PsMathFunction, PsReductionFunction
...@@ -25,13 +25,17 @@ class SelectFunctions: ...@@ -25,13 +25,17 @@ class SelectFunctions:
resolved_func = self._platform.select_function(rhs) resolved_func = self._platform.select_function(rhs)
match resolved_func: match resolved_func:
case ((prepend), expr): case (prepend, new_rhs):
match self._platform: assert isinstance(prepend, tuple)
case CudaPlatform():
# special case: produces statement with atomic operation writing value back to ptr match new_rhs:
return PsBlock(prepend + [PsStatement(expr)]) case PsExpression():
return PsBlock(prepend + (PsAssignment(node.lhs, new_rhs),))
case PsAstNode():
# special case: produces structural with atomic operation writing value back to ptr
return PsBlock(prepend + (new_rhs,))
case _: case _:
return PsBlock(prepend + [PsAssignment(node.lhs, expr)]) assert False, "Unexpected output from SelectFunctions."
case PsExpression(): case PsExpression():
return PsAssignment(node.lhs, resolved_func) return PsAssignment(node.lhs, resolved_func)
case _: case _:
......
...@@ -187,16 +187,16 @@ class DefaultKernelCreationDriver: ...@@ -187,16 +187,16 @@ class DefaultKernelCreationDriver:
symbol_expr = typify(PsSymbolExpr(symbol)) symbol_expr = typify(PsSymbolExpr(symbol))
ptr_symbol_expr = typify(PsSymbolExpr(reduction_info.ptr_symbol)) ptr_symbol_expr = typify(PsSymbolExpr(reduction_info.ptr_symbol))
init_val = typify(reduction_info.init_val) init_val = typify(reduction_info.init_val)
ptr_access = PsMemAcc(ptr_symbol_expr, PsConstantExpr(PsConstant(0, self._ctx.index_dtype)))
decl_local_copy = PsDeclaration(symbol_expr, init_val) ptr_access = PsMemAcc(ptr_symbol_expr, PsConstantExpr(PsConstant(0, self._ctx.index_dtype)))
write_back_ptr = PsCall(PsReductionFunction(ReductionFunctions.WriteBackToPtr, reduction_info.op), write_back_ptr = PsCall(PsReductionFunction(ReductionFunctions.WriteBackToPtr, reduction_info.op),
[ptr_symbol_expr, symbol_expr]) [ptr_symbol_expr, symbol_expr])
prepend_ast = [decl_local_copy] # declare and init local copy with neutral element prepend_ast = [PsDeclaration(symbol_expr, init_val)] # declare and init local copy with neutral element
append_ast = [PsAssignment(ptr_access, write_back_ptr)] # write back result to reduction target variable append_ast = [PsAssignment(ptr_access, write_back_ptr)] # write back result to reduction target variable
kernel_ast.statements = prepend_ast + kernel_ast.statements + append_ast kernel_ast.statements = prepend_ast + kernel_ast.statements
kernel_ast.statements += append_ast
# Target-Specific optimizations # Target-Specific optimizations
if self._target.is_cpu(): if self._target.is_cpu():
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment