From d14898373c0503bd9bbde0d3c0ee35888519f11f Mon Sep 17 00:00:00 2001
From: zy69guqi <richard.angersbach@fau.de>
Date: Mon, 17 Feb 2025 17:54:40 +0100
Subject: [PATCH] Fix typecheck

---
 src/pystencils/backend/ast/vector.py       |  2 +-
 src/pystencils/backend/platforms/cuda.py   | 13 ++++++++++---
 src/pystencils/jit/cpu_extension_module.py |  6 +++++-
 src/pystencils/jit/gpu_cupy.py             |  2 +-
 4 files changed, 17 insertions(+), 6 deletions(-)

diff --git a/src/pystencils/backend/ast/vector.py b/src/pystencils/backend/ast/vector.py
index 5121987a8..4f5224133 100644
--- a/src/pystencils/backend/ast/vector.py
+++ b/src/pystencils/backend/ast/vector.py
@@ -46,7 +46,7 @@ class PsVecBroadcast(PsUnOp, PsVectorOp):
 class PsVecHorizontal(PsBinOp, PsVectorOp):
     """Extracts scalar value from N vector lanes."""
 
-    __match_args__ = ("lanes", "scalar_operand", "vector_operand", "operation")
+    __match_args__ = ("lanes", "scalar_operand", "vector_operand", "reduction_op")
 
     def __init__(self, lanes: int, scalar_operand: PsExpression, vector_operand: PsExpression,
                  reduction_op: ReductionOp):
diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py
index e8c8f6a3a..12a18b41b 100644
--- a/src/pystencils/backend/platforms/cuda.py
+++ b/src/pystencils/backend/platforms/cuda.py
@@ -87,11 +87,18 @@ class CudaPlatform(GenericGpu):
         dtype = call.get_dtype()
         arg_types = (dtype,) * func.num_args
 
-        if isinstance(dtype, PsScalarType) and func in (NumericLimitsFunctions.Min, NumericLimitsFunctions.Max):
+        if isinstance(dtype, PsScalarType) and func in NumericLimitsFunctions:
             assert isinstance(dtype, PsIeeeFloatType)
-            defines = {NumericLimitsFunctions.Min: "NEG_INFINITY", NumericLimitsFunctions.Max: "POS_INFINITY"}
 
-            return PsLiteralExpr(PsLiteral(defines[func], dtype))
+            match func:
+                case NumericLimitsFunctions.Min:
+                    define = "NEG_INFINITY"
+                case NumericLimitsFunctions.Max:
+                    define = "POS_INFINITY"
+                case _:
+                    raise MaterializationError(f"Cannot materialize call to function {func}")
+
+            return PsLiteralExpr(PsLiteral(define, dtype))
 
         if isinstance(dtype, PsIeeeFloatType):
             match func:
diff --git a/src/pystencils/jit/cpu_extension_module.py b/src/pystencils/jit/cpu_extension_module.py
index bdf99b7ad..03260f649 100644
--- a/src/pystencils/jit/cpu_extension_module.py
+++ b/src/pystencils/jit/cpu_extension_module.py
@@ -286,7 +286,11 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
     def extract_ptr(self, param: Parameter) -> str:
         if param not in self._pointer_extractions:
             ptr = param.symbol
-            self._buffer_types[ptr] = ptr.dtype.base_type
+            ptr_dtype = ptr.dtype
+
+            assert isinstance(ptr_dtype, PsPointerType)
+
+            self._buffer_types[ptr] = ptr_dtype.base_type
             self.extract_buffer(ptr, param.name)
             buffer = self.get_buffer(param.name)
             code = f"{param.dtype.c_string()} {param.name} = ({param.dtype}) {buffer}.buf;"
diff --git a/src/pystencils/jit/gpu_cupy.py b/src/pystencils/jit/gpu_cupy.py
index 0792b6c01..6b0ccf02f 100644
--- a/src/pystencils/jit/gpu_cupy.py
+++ b/src/pystencils/jit/gpu_cupy.py
@@ -197,7 +197,7 @@ class CupyKernelWrapper(KernelWrapper):
                 args.append(val)
             else:
                 #   scalar parameter
-                val: Any = kwargs[kparam.name]
+                val = kwargs[kparam.name]
                 add_arg(kparam.name, val, kparam.dtype)
 
         #   Determine launch grid
-- 
GitLab