Skip to content
Snippets Groups Projects
Commit 58cdb792 authored by Richard Angersbach
Browse files

Fix bug with doubly inverted sign for subtraction reductions

parent 7306f4dd
No related branches found
No related tags found
1 merge request: !438 "Reduction Support"
...@@ -168,10 +168,11 @@ class CudaPlatform(GenericGpu): ...@@ -168,10 +168,11 @@ class CudaPlatform(GenericGpu):
match op: match op:
case ReductionOp.Sub: case ReductionOp.Sub:
# workaround for unsupported atomicSub: use atomic add and invert sign # workaround for unsupported atomicSub: use atomic add
# similar to OpenMP reductions: local copies (negative sign) are added at the end
call.function = CFunction(f"atomicAdd", [ptr_expr.dtype, symbol_expr.dtype], call.function = CFunction(f"atomicAdd", [ptr_expr.dtype, symbol_expr.dtype],
PsCustomType("void")) PsCustomType("void"))
call.args = (ptr_expr, -symbol_expr) call.args = (ptr_expr, symbol_expr)
case _: case _:
call.function = CFunction(f"atomic{op.name}", [ptr_expr.dtype, symbol_expr.dtype], call.function = CFunction(f"atomic{op.name}", [ptr_expr.dtype, symbol_expr.dtype],
PsCustomType("void")) PsCustomType("void"))
......
...@@ -8,6 +8,7 @@ from ..functions import CFunction, PsMathFunction, MathFunctions, NumericLimitsF ...@@ -8,6 +8,7 @@ from ..functions import CFunction, PsMathFunction, MathFunctions, NumericLimitsF
PsReductionFunction PsReductionFunction
from ..literals import PsLiteral from ..literals import PsLiteral
from ...compound_op_mapping import compound_op_to_expr from ...compound_op_mapping import compound_op_to_expr
from ...sympyextensions import ReductionOp
from ...types import PsIntegerType, PsIeeeFloatType, PsScalarType, PsPointerType from ...types import PsIntegerType, PsIeeeFloatType, PsScalarType, PsPointerType
from .platform import Platform from .platform import Platform
...@@ -81,8 +82,11 @@ class GenericCpu(Platform): ...@@ -81,8 +82,11 @@ class GenericCpu(Platform):
ptr_access = PsMemAcc(ptr_expr, PsConstantExpr(PsConstant(0, self._ctx.index_dtype))) ptr_access = PsMemAcc(ptr_expr, PsConstantExpr(PsConstant(0, self._ctx.index_dtype)))
# inspired by OpenMP: local reduction variable (negative sign) is added at the end
actual_op = ReductionOp.Add if op is ReductionOp.Sub else op
# TODO: can this be avoided somehow? # TODO: can this be avoided somehow?
potential_call = compound_op_to_expr(op, ptr_access, symbol_expr) potential_call = compound_op_to_expr(actual_op, ptr_access, symbol_expr)
if isinstance(potential_call, PsCall): if isinstance(potential_call, PsCall):
potential_call.dtype = symbol_expr.dtype potential_call.dtype = symbol_expr.dtype
potential_call = self.select_function(potential_call) potential_call = self.select_function(potential_call)
......
...@@ -18,6 +18,7 @@ from ..ast.expressions import ( ...@@ -18,6 +18,7 @@ from ..ast.expressions import (
PsCall, PsCall,
) )
from ..ast.vector import PsVecMemAcc, PsVecBroadcast, PsVecHorizontal from ..ast.vector import PsVecMemAcc, PsVecBroadcast, PsVecHorizontal
from ...sympyextensions import ReductionOp
from ...types import PsCustomType, PsVectorType, PsPointerType, PsType from ...types import PsCustomType, PsVectorType, PsPointerType, PsType
from ..constants import PsConstant from ..constants import PsConstant
...@@ -357,7 +358,9 @@ def _x86_op_intrin( ...@@ -357,7 +358,9 @@ def _x86_op_intrin(
suffix += "x" suffix += "x"
atype = vtype.scalar_type atype = vtype.scalar_type
case PsVecHorizontal(): case PsVecHorizontal():
opstr = f"horizontal_{op.reduction_op.name.lower()}" # horizontal add instead of sub avoids double inversion of sign
actual_op = ReductionOp.Add if op.reduction_op == ReductionOp.Sub else op.reduction_op
opstr = f"horizontal_{actual_op.name.lower()}"
rtype = vtype.scalar_type rtype = vtype.scalar_type
atypes = (vtype.scalar_type, vtype) atypes = (vtype.scalar_type, vtype)
case PsAdd(): case PsAdd():
......
...@@ -10,7 +10,7 @@ INIT_ARR = 2 ...@@ -10,7 +10,7 @@ INIT_ARR = 2
SIZE = 15 SIZE = 15
SOLUTION = { SOLUTION = {
"+": INIT_W + INIT_ARR * SIZE, "+": INIT_W + INIT_ARR * SIZE,
"-": INIT_W - INIT_ARR * -SIZE, "-": INIT_W - INIT_ARR * SIZE,
"*": INIT_W * INIT_ARR ** SIZE, "*": INIT_W * INIT_ARR ** SIZE,
"min": min(INIT_W, INIT_ARR), "min": min(INIT_W, INIT_ARR),
"max": max(INIT_W, INIT_ARR), "max": max(INIT_W, INIT_ARR),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment