diff --git a/src/pystencils/backend/functions.py b/src/pystencils/backend/functions.py index 66b93f134c7f5c3c84062229a76ab1b55167a8cc..7a95c4183fce019db75c3332f14570a4a143fdc6 100644 --- a/src/pystencils/backend/functions.py +++ b/src/pystencils/backend/functions.py @@ -15,10 +15,13 @@ TODO: Figure out the best way to describe function signatures and overloads for """ from __future__ import annotations -from typing import Any, TYPE_CHECKING +from typing import Any, Sequence, TYPE_CHECKING from abc import ABC from enum import Enum +from ..types import PsType +from .exceptions import PsInternalCompilerError + if TYPE_CHECKING: from .ast.expressions import PsExpression @@ -71,8 +74,38 @@ class PsFunction(ABC): class CFunction(PsFunction): """A concrete C function.""" - def __init__(self, name: str, arg_count: int): - super().__init__(name, arg_count) + __match_args__ = ("name", "argument_types", "return_type") + + @staticmethod + def parse(obj) -> CFunction: + import inspect + from pystencils.types import create_type + + if not inspect.isfunction(obj): + raise PsInternalCompilerError(f"Cannot parse object {obj} as a function") + + func_sig = inspect.signature(obj) + func_name = obj.__name__ + arg_types = [ + create_type(param.annotation) for param in func_sig.parameters.values() + ] + ret_type = create_type(func_sig.return_annotation) + + return CFunction(func_name, arg_types, ret_type) + + def __init__(self, name: str, arg_types: Sequence[PsType], return_type: PsType): + super().__init__(name, len(arg_types)) + + self._arg_types = tuple(arg_types) + self._return_type = return_type + + @property + def argument_types(self) -> tuple[PsType, ...]: + return self._arg_types + + @property + def return_type(self) -> PsType: + return self._return_type class PsMathFunction(PsFunction): diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index dea542b197c0a9c31622f1d5b1e8f1dc8990265f..d933004fc502831a0946fd129b26758fe5fbf545 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -43,7 +43,7 @@ from ..ast.expressions import ( PsNeg, PsNot, ) -from ..functions import PsMathFunction +from ..functions import PsMathFunction, CFunction __all__ = ["Typifier"] @@ -467,6 +467,14 @@ class Typifier: for arg in args: self.visit_expr(arg, tc) tc.infer_dtype(expr) + + case CFunction(_, arg_types, ret_type): + tc.apply_dtype(ret_type, expr) + + for arg, arg_type in zip(args, arg_types, strict=True): + arg_tc = TypeContext(arg_type) + self.visit_expr(arg, arg_tc) + case _: raise TypificationError( f"Don't know how to typify calls to {function}" diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py index 6899ac9474d303623f79e2bdc7c3765c64380a6c..17589bf27d3109a3ce891acfdc248e0120da2f77 100644 --- a/src/pystencils/backend/platforms/generic_cpu.py +++ b/src/pystencils/backend/platforms/generic_cpu.py @@ -55,6 +55,7 @@ class GenericCpu(Platform): self, math_function: PsMathFunction, dtype: PsType ) -> CFunction: func = math_function.func + arg_types = (dtype,) * func.num_args if isinstance(dtype, PsIeeeFloatType) and dtype.width in (32, 64): match func: case ( @@ -64,9 +65,9 @@ class GenericCpu(Platform): | MathFunctions.Tan | MathFunctions.Pow ): - return CFunction(func.function_name, func.num_args) + return CFunction(func.function_name, arg_types, dtype) case MathFunctions.Abs | MathFunctions.Min | MathFunctions.Max: - return CFunction("f" + func.function_name, func.num_args) + return CFunction("f" + func.function_name, arg_types, dtype) raise MaterializationError( f"No implementation available for function {math_function} on data type {dtype}" diff --git a/src/pystencils/backend/platforms/x86.py b/src/pystencils/backend/platforms/x86.py index fa5af4655810943c47f503329a8c41ce3baa36c5..ccaf9fbe99f46ce4b0ecbb81c775c9f274678026 100644 --- a/src/pystencils/backend/platforms/x86.py +++ b/src/pystencils/backend/platforms/x86.py @@ -10,7 +10,7 @@ from ..ast.expressions import ( PsSubscript, ) from ..transformations.select_intrinsics import IntrinsicOps -from ...types import PsCustomType, PsVectorType +from ...types import PsCustomType, PsVectorType, PsPointerType from ..constants import PsConstant from ..exceptions import MaterializationError @@ -124,10 +124,13 @@ class X86VectorCpu(GenericVectorCpu): def constant_vector(self, c: PsConstant) -> PsExpression: vtype = c.dtype assert isinstance(vtype, PsVectorType) + stype = vtype.scalar_type prefix = self._vector_arch.intrin_prefix(vtype) suffix = self._vector_arch.intrin_suffix(vtype) - set_func = CFunction(f"{prefix}_set_{suffix}", vtype.vector_entries) + set_func = CFunction( + f"{prefix}_set_{suffix}", (stype,) * vtype.vector_entries, vtype + ) values = c.value return set_func(*values) @@ -164,7 +167,10 @@ def _x86_packed_load( ) -> CFunction: prefix = varch.intrin_prefix(vtype) suffix = varch.intrin_suffix(vtype) - return CFunction(f"{prefix}_load{'' if aligned else 'u'}_{suffix}", 1) + ptr_type = PsPointerType(vtype.scalar_type, const=True) + return CFunction( + f"{prefix}_load{'' if aligned else 'u'}_{suffix}", (ptr_type,), vtype + ) @cache @@ -173,7 +179,12 @@ def _x86_packed_store( ) -> CFunction: prefix = varch.intrin_prefix(vtype) suffix = varch.intrin_suffix(vtype) - return CFunction(f"{prefix}_store{'' if aligned else 'u'}_{suffix}", 2) + ptr_type = PsPointerType(vtype.scalar_type, const=True) + return CFunction( + f"{prefix}_store{'' if aligned else 'u'}_{suffix}", + (ptr_type, vtype), + PsCustomType("void"), + ) @cache @@ -197,4 +208,5 @@ def _x86_op_intrin( case _: assert False - return CFunction(f"{prefix}_{opstr}_{suffix}", 3 if op == IntrinsicOps.FMA else 2) + num_args = 3 if op == IntrinsicOps.FMA else 2 + return CFunction(f"{prefix}_{opstr}_{suffix}", (vtype,) * num_args, vtype) diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py index 0afa5b9e8da6dc18bb47fb2fdf49933b3f7d28d9..0ca35f6ee1c7a3fee28b091c4c04c9fca7745f6f 100644 --- a/tests/nbackend/kernelcreation/test_typification.py +++ b/tests/nbackend/kernelcreation/test_typification.py @@ -26,10 +26,12 @@ from pystencils.backend.ast.expressions import ( PsLe, PsGt, PsLt, + PsCall, ) from pystencils.backend.constants import PsConstant +from pystencils.backend.functions import CFunction from pystencils.types import constify -from pystencils.types.quick import Fp, Bool, create_type, create_numeric_type +from pystencils.types.quick import Fp, Int, Bool, create_type, create_numeric_type from pystencils.backend.kernelcreation.context import KernelCreationContext from pystencils.backend.kernelcreation.freeze import FreezeExpressions from pystencils.backend.kernelcreation.typification import Typifier, TypificationError @@ -354,7 +356,7 @@ def test_invalid_conditions(): x, y = [PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy"] p, q = [PsExpression.make(ctx.get_symbol(name, Bool())) for name in "pq"] - + cond = PsConditional(x + y, PsBlock([])) with pytest.raises(TypificationError): typify(cond) @@ -362,3 +364,24 @@ def test_invalid_conditions(): cond = PsConditional(PsAnd(p, PsOr(x, q)), PsBlock([])) with pytest.raises(TypificationError): typify(cond) + + +def test_cfunction(): + ctx = KernelCreationContext() + typify = Typifier(ctx) + x, y = [PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy"] + p, q = [PsExpression.make(ctx.get_symbol(name, Int(32))) for name in "pq"] + + def _threeway(x: np.float32, y: np.float32) -> np.int32: + assert False + + threeway = CFunction.parse(_threeway) + + result = typify(PsCall(threeway, [x, y])) + + assert result.get_dtype() == Int(32, const=True) + assert result.args[0].get_dtype() == Fp(32, const=True) + assert result.args[1].get_dtype() == Fp(32, const=True) + + with pytest.raises(TypificationError): + _ = typify(PsCall(threeway, (x, p)))