Skip to content
Snippets Groups Projects
Commit f7790954 authored by Richard Angersbach's avatar Richard Angersbach
Browse files

Use NPP library for numeric limits for CUDA, use std limits for HIP

parent c31d4074
No related branches found
No related tags found
1 merge request: !438 "Reduction Support"
Pipeline #77216 failed
from __future__ import annotations

from .generic_gpu import GenericGpu
from ..ast.expressions import PsExpression, PsLiteralExpr
from ..exceptions import MaterializationError
from ..functions import PsFunction, PsMathFunction, NumericLimitsFunctions
from ..literals import PsLiteral
from ...types import PsType, PsIeeeFloatType
class CudaPlatform(GenericGpu): class CudaPlatform(GenericGpu):
"""Platform for the CUDA GPU taret.""" """Platform for the CUDA GPU target."""
@property @property
def required_headers(self) -> set[str]: def required_headers(self) -> set[str]:
return super().required_headers return super().required_headers | {
'"npp.h"',
}
def resolve_numeric_limits(self, func: PsMathFunction, dtype: PsType) -> PsExpression:
assert isinstance(dtype, PsIeeeFloatType)
match func:
case NumericLimitsFunctions.Min:
define = f"NPP_MINABS_{dtype.width}F"
case NumericLimitsFunctions.Max:
define = f"NPP_MAXABS_{dtype.width}F"
case _:
raise MaterializationError(
f"Cannot materialize call to function {func}"
)
return PsLiteralExpr(PsLiteral(define, dtype))
...@@ -9,7 +9,7 @@ from ..ast import PsAstNode ...@@ -9,7 +9,7 @@ from ..ast import PsAstNode
from ..constants import PsConstant from ..constants import PsConstant
from ...compound_op_mapping import compound_op_to_expr from ...compound_op_mapping import compound_op_to_expr
from ...sympyextensions.reduction import ReductionOp from ...sympyextensions.reduction import ReductionOp
from ...types import constify, deconstify, PsPointerType, PsScalarType, PsCustomType from ...types import constify, deconstify, PsPointerType, PsScalarType, PsCustomType, PsType
from ...types.quick import UInt, SInt from ...types.quick import UInt, SInt
from ..exceptions import MaterializationError from ..exceptions import MaterializationError
from .platform import Platform from .platform import Platform
...@@ -203,9 +203,12 @@ class GenericGpu(Platform): ...@@ -203,9 +203,12 @@ class GenericGpu(Platform):
def required_headers(self) -> set[str]:
    """Return the headers required by code generated for this platform.

    Platform subclasses extend this set with their own headers.
    """
    # "<cmath>" is no longer listed here: the numeric limits that used the
    # INFINITY macro are now resolved per platform in resolve_numeric_limits.
    return {
        '"gpu_atomics.h"',
    }
@abstractmethod
def resolve_numeric_limits(
    self, func: PsMathFunction, dtype: PsType
) -> PsExpression:
    """Produce the platform-specific expression for a numeric-limits function.

    Concrete GPU platforms must override this hook; it is invoked when a
    function from ``NumericLimitsFunctions`` is selected for a scalar type.

    Args:
        func: The numeric-limits function to materialize.
        dtype: Data type the limit is requested for.

    Returns:
        An expression evaluating to the requested limit.
    """
def __init__( def __init__(
self, self,
ctx: KernelCreationContext, ctx: KernelCreationContext,
...@@ -369,19 +372,7 @@ class GenericGpu(Platform): ...@@ -369,19 +372,7 @@ class GenericGpu(Platform):
arg_types = (dtype,) * func.num_args arg_types = (dtype,) * func.num_args
if isinstance(dtype, PsScalarType) and func in NumericLimitsFunctions: if isinstance(dtype, PsScalarType) and func in NumericLimitsFunctions:
assert isinstance(dtype, PsIeeeFloatType) return self.resolve_numeric_limits(func, dtype)
match func:
case NumericLimitsFunctions.Min:
define = "-INFINITY"
case NumericLimitsFunctions.Max:
define = "INFINITY"
case _:
raise MaterializationError(
f"Cannot materialize call to function {func}"
)
return PsLiteralExpr(PsLiteral(define, dtype))
if isinstance(dtype, PsIeeeFloatType) and func in MathFunctions: if isinstance(dtype, PsIeeeFloatType) and func in MathFunctions:
match func: match func:
......
from __future__ import annotations from __future__ import annotations
from .generic_gpu import GenericGpu from .generic_gpu import GenericGpu
from ..ast.expressions import PsExpression, PsLiteralExpr
from ..functions import PsMathFunction
from ..literals import PsLiteral
from ...types import PsType, PsIeeeFloatType
class HipPlatform(GenericGpu):
    """Platform for the HIP GPU target."""

    @property
    def required_headers(self) -> set[str]:
        """Return the headers required by code generated for this platform.

        Extends the header set of the base platform with the HIP runtime
        header and ``<limits>``, which declares the ``std::numeric_limits``
        template referenced by `resolve_numeric_limits`.
        """
        return super().required_headers | {
            '"pystencils_runtime/hip.h"',
            "<limits>",
        }

    def resolve_numeric_limits(
        self, func: PsMathFunction, dtype: PsType
    ) -> PsExpression:
        """Map a numeric-limits function onto ``std::numeric_limits``.

        Args:
            func: The numeric-limits function to materialize; its
                ``function_name`` attribute is used verbatim as the name of
                the ``std::numeric_limits`` member function.
            dtype: Data type the limit is requested for. Must be an
                IEEE floating-point type.

        Returns:
            A literal expression of type ``dtype`` spelling
            ``std::numeric_limits<T>::<name>()``.
        """
        assert isinstance(dtype, PsIeeeFloatType)

        # NOTE(review): for floating-point T, std::numeric_limits<T>::min()
        # is the smallest positive normalized value, not the most negative
        # one (that is lowest()) - confirm this matches what callers expect.
        limit_expr = (
            f"std::numeric_limits<{dtype.c_string()}>::{func.function_name}()"
        )
        return PsLiteralExpr(PsLiteral(limit_expr, dtype))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment