From 4e5c89b9cd23610f27d61612b14a12e968724e3f Mon Sep 17 00:00:00 2001 From: zy69guqi <richard.angersbach@fau.de> Date: Mon, 24 Mar 2025 16:54:50 +0100 Subject: [PATCH] Fix lint, typecheck --- src/pystencils/backend/platforms/cuda.py | 5 +++-- src/pystencils/backend/platforms/generic_gpu.py | 4 ++-- src/pystencils/backend/platforms/hip.py | 4 ++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py index b5b3478e4..7a5074677 100644 --- a/src/pystencils/backend/platforms/cuda.py +++ b/src/pystencils/backend/platforms/cuda.py @@ -2,7 +2,8 @@ from __future__ import annotations from .generic_gpu import GenericGpu from ..ast.expressions import PsExpression, PsLiteralExpr -from ..functions import PsFunction, NumericLimitsFunctions +from ..exceptions import MaterializationError +from ..functions import NumericLimitsFunctions from ..literals import PsLiteral from ...types import PsType, PsIeeeFloatType @@ -16,7 +17,7 @@ class CudaPlatform(GenericGpu): '"npp.h"', } - def resolve_numeric_limits(self, func: PsMathFunction, dtype: PsType) -> PsExpression: + def resolve_numeric_limits(self, func: NumericLimitsFunctions, dtype: PsType) -> PsExpression: assert isinstance(dtype, PsIeeeFloatType) match func: diff --git a/src/pystencils/backend/platforms/generic_gpu.py b/src/pystencils/backend/platforms/generic_gpu.py index 787b390fe..8b7eead8d 100644 --- a/src/pystencils/backend/platforms/generic_gpu.py +++ b/src/pystencils/backend/platforms/generic_gpu.py @@ -206,7 +206,7 @@ class GenericGpu(Platform): } @abstractmethod - def resolve_numeric_limits(self, func: PsMathFunction, dtype: PsType) -> PsExpression: + def resolve_numeric_limits(self, func: NumericLimitsFunctions, dtype: PsType) -> PsExpression: pass def __init__( @@ -371,7 +371,7 @@ class GenericGpu(Platform): dtype = call.get_dtype() arg_types = (dtype,) * func.num_args - if isinstance(dtype, PsScalarType) and func in NumericLimitsFunctions: + if isinstance(dtype, PsScalarType) and isinstance(func, NumericLimitsFunctions): return self.resolve_numeric_limits(func, dtype) if isinstance(dtype, PsIeeeFloatType) and func in MathFunctions: diff --git a/src/pystencils/backend/platforms/hip.py b/src/pystencils/backend/platforms/hip.py index 60e249aeb..45d60452b 100644 --- a/src/pystencils/backend/platforms/hip.py +++ b/src/pystencils/backend/platforms/hip.py @@ -2,7 +2,7 @@ from __future__ import annotations from .generic_gpu import GenericGpu from ..ast.expressions import PsExpression, PsLiteralExpr -from ..functions import PsMathFunction +from ..functions import NumericLimitsFunctions from ..literals import PsLiteral from ...types import PsType, PsIeeeFloatType @@ -17,7 +17,7 @@ class HipPlatform(GenericGpu): "<limits>" } - def resolve_numeric_limits(self, func: PsMathFunction, dtype: PsType) -> PsExpression: + def resolve_numeric_limits(self, func: NumericLimitsFunctions, dtype: PsType) -> PsExpression: assert isinstance(dtype, PsIeeeFloatType) return PsLiteralExpr( -- GitLab