From 2b8884d99cfe2d83219368688c2b04862f143bd9 Mon Sep 17 00:00:00 2001
From: Daniel Bauer <daniel.j.bauer@fau.de>
Date: Fri, 26 Jul 2024 16:15:24 +0200
Subject: [PATCH] test freeze of cast of bare constant and failing type
 inference

---
 tests/nbackend/kernelcreation/test_freeze.py  |  4 ++++
 .../kernelcreation/test_typification.py       | 19 +++++++++++++++++++
 2 files changed, 23 insertions(+)

diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py
index f16a468e7..21c79979d 100644
--- a/tests/nbackend/kernelcreation/test_freeze.py
+++ b/tests/nbackend/kernelcreation/test_freeze.py
@@ -29,6 +29,7 @@ from pystencils.backend.ast.expressions import (
     PsGe,
     PsCall,
     PsCast,
+    PsConstantExpr,
 )
 from pystencils.backend.constants import PsConstant
 from pystencils.backend.functions import PsMathFunction, MathFunctions
@@ -286,3 +287,6 @@ def test_cast_func():
 
     expr = freeze(CastFunc.as_index(z))
     assert expr.structurally_equal(PsCast(ctx.index_dtype, z2))
+
+    expr = freeze(CastFunc(42, create_type("int16")))
+    assert expr.structurally_equal(PsConstantExpr(PsConstant(42, create_type("int16"))))
diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py
index 2ebfa2ec8..4c6a4d602 100644
--- a/tests/nbackend/kernelcreation/test_typification.py
+++ b/tests/nbackend/kernelcreation/test_typification.py
@@ -14,6 +14,9 @@ from pystencils.backend.ast.structural import (
     PsBlock,
 )
 from pystencils.backend.ast.expressions import (
+    PsAddressOf,
+    PsArrayInitList,
+    PsCast,
     PsConstantExpr,
     PsSymbolExpr,
     PsSubscript,
@@ -478,3 +481,19 @@ def test_cfunction():
 
     with pytest.raises(TypificationError):
         _ = typify(PsCall(threeway, (x, p)))
+
+
+def test_inference_fails():
+    ctx = KernelCreationContext()
+    typify = Typifier(ctx)
+
+    x = PsExpression.make(PsConstant(42))
+
+    with pytest.raises(TypificationError):
+        typify(PsEq(x, x))
+
+    with pytest.raises(TypificationError):
+        typify(PsArrayInitList([x]))
+
+    with pytest.raises(TypificationError):
+        typify(PsCast(ctx.default_dtype, x))
-- 
GitLab