Skip to content
Snippets Groups Projects
Commit d1489837 authored by Richard Angersbach's avatar Richard Angersbach
Browse files

Fix typecheck

parent 13569a61
No related branches found
No related tags found
1 merge request!438Reduction Support
......@@ -46,7 +46,7 @@ class PsVecBroadcast(PsUnOp, PsVectorOp):
class PsVecHorizontal(PsBinOp, PsVectorOp):
"""Extracts scalar value from N vector lanes."""
__match_args__ = ("lanes", "scalar_operand", "vector_operand", "operation")
__match_args__ = ("lanes", "scalar_operand", "vector_operand", "reduction_op")
def __init__(self, lanes: int, scalar_operand: PsExpression, vector_operand: PsExpression,
reduction_op: ReductionOp):
......
......@@ -87,11 +87,18 @@ class CudaPlatform(GenericGpu):
dtype = call.get_dtype()
arg_types = (dtype,) * func.num_args
if isinstance(dtype, PsScalarType) and func in (NumericLimitsFunctions.Min, NumericLimitsFunctions.Max):
if isinstance(dtype, PsScalarType) and func in NumericLimitsFunctions:
assert isinstance(dtype, PsIeeeFloatType)
defines = {NumericLimitsFunctions.Min: "NEG_INFINITY", NumericLimitsFunctions.Max: "POS_INFINITY"}
return PsLiteralExpr(PsLiteral(defines[func], dtype))
match func:
case NumericLimitsFunctions.Min:
define = "NEG_INFINITY"
case NumericLimitsFunctions.Max:
define = "POS_INFINITY"
case _:
raise MaterializationError(f"Cannot materialize call to function {func}")
return PsLiteralExpr(PsLiteral(define, dtype))
if isinstance(dtype, PsIeeeFloatType):
match func:
......
......@@ -286,7 +286,11 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
def extract_ptr(self, param: Parameter) -> str:
if param not in self._pointer_extractions:
ptr = param.symbol
self._buffer_types[ptr] = ptr.dtype.base_type
ptr_dtype = ptr.dtype
assert isinstance(ptr_dtype, PsPointerType)
self._buffer_types[ptr] = ptr_dtype.base_type
self.extract_buffer(ptr, param.name)
buffer = self.get_buffer(param.name)
code = f"{param.dtype.c_string()} {param.name} = ({param.dtype}) {buffer}.buf;"
......
......@@ -197,7 +197,7 @@ class CupyKernelWrapper(KernelWrapper):
args.append(val)
else:
# scalar parameter
val: Any = kwargs[kparam.name]
val = kwargs[kparam.name]
add_arg(kparam.name, val, kparam.dtype)
# Determine launch grid
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment