Skip to content
Snippets Groups Projects
Commit 58cdb792 authored by Richard Angersbach
Browse files

Fix bug with doubly inverted sign for subtraction reductions

parent 7306f4dd
No related branches found
No related tags found
1 merge request: !438 "Reduction Support"
......@@ -168,10 +168,11 @@ class CudaPlatform(GenericGpu):
match op:
case ReductionOp.Sub:
# workaround for unsupported atomicSub: use atomic add and invert sign
# workaround for unsupported atomicSub: use atomic add
# similar to OpenMP reductions: local copies (negative sign) are added at the end
call.function = CFunction(f"atomicAdd", [ptr_expr.dtype, symbol_expr.dtype],
PsCustomType("void"))
call.args = (ptr_expr, -symbol_expr)
call.args = (ptr_expr, symbol_expr)
case _:
call.function = CFunction(f"atomic{op.name}", [ptr_expr.dtype, symbol_expr.dtype],
PsCustomType("void"))
......
......@@ -8,6 +8,7 @@ from ..functions import CFunction, PsMathFunction, MathFunctions, NumericLimitsF
PsReductionFunction
from ..literals import PsLiteral
from ...compound_op_mapping import compound_op_to_expr
from ...sympyextensions import ReductionOp
from ...types import PsIntegerType, PsIeeeFloatType, PsScalarType, PsPointerType
from .platform import Platform
......@@ -81,8 +82,11 @@ class GenericCpu(Platform):
ptr_access = PsMemAcc(ptr_expr, PsConstantExpr(PsConstant(0, self._ctx.index_dtype)))
# inspired by OpenMP: local reduction variable (negative sign) is added at the end
actual_op = ReductionOp.Add if op is ReductionOp.Sub else op
# TODO: can this be avoided somehow?
potential_call = compound_op_to_expr(op, ptr_access, symbol_expr)
potential_call = compound_op_to_expr(actual_op, ptr_access, symbol_expr)
if isinstance(potential_call, PsCall):
potential_call.dtype = symbol_expr.dtype
potential_call = self.select_function(potential_call)
......
......@@ -18,6 +18,7 @@ from ..ast.expressions import (
PsCall,
)
from ..ast.vector import PsVecMemAcc, PsVecBroadcast, PsVecHorizontal
from ...sympyextensions import ReductionOp
from ...types import PsCustomType, PsVectorType, PsPointerType, PsType
from ..constants import PsConstant
......@@ -357,7 +358,9 @@ def _x86_op_intrin(
suffix += "x"
atype = vtype.scalar_type
case PsVecHorizontal():
opstr = f"horizontal_{op.reduction_op.name.lower()}"
# horizontal add instead of sub avoids double inversion of sign
actual_op = ReductionOp.Add if op.reduction_op == ReductionOp.Sub else op.reduction_op
opstr = f"horizontal_{actual_op.name.lower()}"
rtype = vtype.scalar_type
atypes = (vtype.scalar_type, vtype)
case PsAdd():
......
......@@ -10,7 +10,7 @@ INIT_ARR = 2
SIZE = 15
SOLUTION = {
"+": INIT_W + INIT_ARR * SIZE,
"-": INIT_W - INIT_ARR * -SIZE,
"-": INIT_W - INIT_ARR * SIZE,
"*": INIT_W * INIT_ARR ** SIZE,
"min": min(INIT_W, INIT_ARR),
"max": max(INIT_W, INIT_ARR),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment