From 4e5c89b9cd23610f27d61612b14a12e968724e3f Mon Sep 17 00:00:00 2001
From: zy69guqi <richard.angersbach@fau.de>
Date: Mon, 24 Mar 2025 16:54:50 +0100
Subject: [PATCH] Fix lint, typecheck

---
 src/pystencils/backend/platforms/cuda.py        | 5 +++--
 src/pystencils/backend/platforms/generic_gpu.py | 4 ++--
 src/pystencils/backend/platforms/hip.py         | 4 ++--
 3 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py
index b5b3478e4..7a5074677 100644
--- a/src/pystencils/backend/platforms/cuda.py
+++ b/src/pystencils/backend/platforms/cuda.py
@@ -2,7 +2,8 @@ from __future__ import annotations
 
 from .generic_gpu import GenericGpu
 from ..ast.expressions import PsExpression, PsLiteralExpr
-from ..functions import PsFunction, NumericLimitsFunctions
+from ..exceptions import MaterializationError
+from ..functions import NumericLimitsFunctions
 from ..literals import PsLiteral
 from ...types import PsType, PsIeeeFloatType
 
@@ -16,7 +17,7 @@ class CudaPlatform(GenericGpu):
             '"npp.h"',
         }
 
-    def resolve_numeric_limits(self, func: PsMathFunction, dtype: PsType) -> PsExpression:
+    def resolve_numeric_limits(self, func: NumericLimitsFunctions, dtype: PsType) -> PsExpression:
         assert isinstance(dtype, PsIeeeFloatType)
 
         match func:
diff --git a/src/pystencils/backend/platforms/generic_gpu.py b/src/pystencils/backend/platforms/generic_gpu.py
index 787b390fe..8b7eead8d 100644
--- a/src/pystencils/backend/platforms/generic_gpu.py
+++ b/src/pystencils/backend/platforms/generic_gpu.py
@@ -206,7 +206,7 @@ class GenericGpu(Platform):
         }
 
     @abstractmethod
-    def resolve_numeric_limits(self, func: PsMathFunction, dtype: PsType) -> PsExpression:
+    def resolve_numeric_limits(self, func: NumericLimitsFunctions, dtype: PsType) -> PsExpression:
         pass
 
     def __init__(
@@ -371,7 +371,7 @@ class GenericGpu(Platform):
         dtype = call.get_dtype()
         arg_types = (dtype,) * func.num_args
 
-        if isinstance(dtype, PsScalarType) and func in NumericLimitsFunctions:
+        if isinstance(dtype, PsScalarType) and isinstance(func, NumericLimitsFunctions):
             return self.resolve_numeric_limits(func, dtype)
 
         if isinstance(dtype, PsIeeeFloatType) and func in MathFunctions:
diff --git a/src/pystencils/backend/platforms/hip.py b/src/pystencils/backend/platforms/hip.py
index 60e249aeb..45d60452b 100644
--- a/src/pystencils/backend/platforms/hip.py
+++ b/src/pystencils/backend/platforms/hip.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from .generic_gpu import GenericGpu
 from ..ast.expressions import PsExpression, PsLiteralExpr
-from ..functions import PsMathFunction
+from ..functions import NumericLimitsFunctions
 from ..literals import PsLiteral
 from ...types import PsType, PsIeeeFloatType
 
@@ -17,7 +17,7 @@ class HipPlatform(GenericGpu):
             "<limits>"
         }
 
-    def resolve_numeric_limits(self, func: PsMathFunction, dtype: PsType) -> PsExpression:
+    def resolve_numeric_limits(self, func: NumericLimitsFunctions, dtype: PsType) -> PsExpression:
         assert isinstance(dtype, PsIeeeFloatType)
 
         return PsLiteralExpr(
-- 
GitLab