diff --git a/src/pystencils/backend/ast/vector.py b/src/pystencils/backend/ast/vector.py index 5121987a8129d682cb07e4b9cae5d6c9d6741817..4f52241330c1570efb4824c018251006ab9de02f 100644 --- a/src/pystencils/backend/ast/vector.py +++ b/src/pystencils/backend/ast/vector.py @@ -46,7 +46,7 @@ class PsVecBroadcast(PsUnOp, PsVectorOp): class PsVecHorizontal(PsBinOp, PsVectorOp): """Extracts scalar value from N vector lanes.""" - __match_args__ = ("lanes", "scalar_operand", "vector_operand", "operation") + __match_args__ = ("lanes", "scalar_operand", "vector_operand", "reduction_op") def __init__(self, lanes: int, scalar_operand: PsExpression, vector_operand: PsExpression, reduction_op: ReductionOp): diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py index e8c8f6a3a2a9466405376d8b898ec8cece4dedea..12a18b41b2da5c2c1a0b89aee50d0e9a4860d054 100644 --- a/src/pystencils/backend/platforms/cuda.py +++ b/src/pystencils/backend/platforms/cuda.py @@ -87,11 +87,18 @@ class CudaPlatform(GenericGpu): dtype = call.get_dtype() arg_types = (dtype,) * func.num_args - if isinstance(dtype, PsScalarType) and func in (NumericLimitsFunctions.Min, NumericLimitsFunctions.Max): + if isinstance(dtype, PsScalarType) and func in NumericLimitsFunctions: assert isinstance(dtype, PsIeeeFloatType) - defines = {NumericLimitsFunctions.Min: "NEG_INFINITY", NumericLimitsFunctions.Max: "POS_INFINITY"} - return PsLiteralExpr(PsLiteral(defines[func], dtype)) + match func: + case NumericLimitsFunctions.Min: + define = "NEG_INFINITY" + case NumericLimitsFunctions.Max: + define = "POS_INFINITY" + case _: + raise MaterializationError(f"Cannot materialize call to function {func}") + + return PsLiteralExpr(PsLiteral(define, dtype)) if isinstance(dtype, PsIeeeFloatType): match func: diff --git a/src/pystencils/jit/cpu_extension_module.py b/src/pystencils/jit/cpu_extension_module.py index bdf99b7adf7c39eb99daff37cf80840728e8a646..03260f6496ed69abb685b69e708e747724c30558 100644 --- a/src/pystencils/jit/cpu_extension_module.py +++ b/src/pystencils/jit/cpu_extension_module.py @@ -286,7 +286,11 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ def extract_ptr(self, param: Parameter) -> str: if param not in self._pointer_extractions: ptr = param.symbol - self._buffer_types[ptr] = ptr.dtype.base_type + ptr_dtype = ptr.dtype + + assert isinstance(ptr_dtype, PsPointerType) + + self._buffer_types[ptr] = ptr_dtype.base_type self.extract_buffer(ptr, param.name) buffer = self.get_buffer(param.name) code = f"{param.dtype.c_string()} {param.name} = ({param.dtype}) {buffer}.buf;" diff --git a/src/pystencils/jit/gpu_cupy.py b/src/pystencils/jit/gpu_cupy.py index 0792b6c01a526ab9fc24f5ac471a8456ee9f2745..6b0ccf02f3f77ef33bf13cb487f2b992b7088e1f 100644 --- a/src/pystencils/jit/gpu_cupy.py +++ b/src/pystencils/jit/gpu_cupy.py @@ -197,7 +197,7 @@ class CupyKernelWrapper(KernelWrapper): args.append(val) else: # scalar parameter - val: Any = kwargs[kparam.name] + val = kwargs[kparam.name] add_arg(kparam.name, val, kparam.dtype) # Determine launch grid