From 2b8884d99cfe2d83219368688c2b04862f143bd9 Mon Sep 17 00:00:00 2001 From: Daniel Bauer <daniel.j.bauer@fau.de> Date: Fri, 26 Jul 2024 16:15:24 +0200 Subject: [PATCH] test freeze of cast of bare constant and failing type inference --- tests/nbackend/kernelcreation/test_freeze.py | 4 ++++ .../kernelcreation/test_typification.py | 19 +++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py index f16a468e7..21c79979d 100644 --- a/tests/nbackend/kernelcreation/test_freeze.py +++ b/tests/nbackend/kernelcreation/test_freeze.py @@ -29,6 +29,7 @@ from pystencils.backend.ast.expressions import ( PsGe, PsCall, PsCast, + PsConstantExpr, ) from pystencils.backend.constants import PsConstant from pystencils.backend.functions import PsMathFunction, MathFunctions @@ -286,3 +287,6 @@ def test_cast_func(): expr = freeze(CastFunc.as_index(z)) assert expr.structurally_equal(PsCast(ctx.index_dtype, z2)) + + expr = freeze(CastFunc(42, create_type("int16"))) + assert expr.structurally_equal(PsConstantExpr(PsConstant(42, create_type("int16")))) diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py index 2ebfa2ec8..4c6a4d602 100644 --- a/tests/nbackend/kernelcreation/test_typification.py +++ b/tests/nbackend/kernelcreation/test_typification.py @@ -14,6 +14,9 @@ from pystencils.backend.ast.structural import ( PsBlock, ) from pystencils.backend.ast.expressions import ( + PsAddressOf, + PsArrayInitList, + PsCast, PsConstantExpr, PsSymbolExpr, PsSubscript, @@ -478,3 +481,19 @@ def test_cfunction(): with pytest.raises(TypificationError): _ = typify(PsCall(threeway, (x, p))) + + +def test_inference_fails(): + ctx = KernelCreationContext() + typify = Typifier(ctx) + + x = PsExpression.make(PsConstant(42)) + + with pytest.raises(TypificationError): + typify(PsEq(x, x)) + + with pytest.raises(TypificationError): + typify(PsArrayInitList([x])) + + with pytest.raises(TypificationError): + typify(PsCast(ctx.default_dtype, x)) -- GitLab