diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py
index 66da00c114860aa4be7084311251de7f5c3f04bb..d85c5341073e548301ca74131ad091ef5db35595 100644
--- a/src/pystencils/backend/kernelcreation/typification.py
+++ b/src/pystencils/backend/kernelcreation/typification.py
@@ -646,9 +646,8 @@ class Typifier:
                 self.visit_expr(op2, args_tc)
 
                 if args_tc.target_type is None:
-                    raise TypificationError(
-                        f"Unable to determine type of arguments to relation: {expr}"
-                    )
+                    args_tc.apply_hint(ToDefault(self._ctx.default_dtype))
+                
                 if not isinstance(args_tc.target_type, PsNumericType):
                     raise TypificationError(
                         f"Invalid type in arguments to relation\n"
@@ -738,9 +737,7 @@ class Typifier:
                 self.visit_expr(arg, arg_tc)
 
                 if arg_tc.target_type is None:
-                    raise TypificationError(
-                        f"Unable to determine type of argument to Cast: {arg}"
-                    )
+                    arg_tc.apply_hint(ToDefault(self._ctx.default_dtype))
 
                 tc.apply_dtype(dtype, expr)
 
diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py
index 7424977be4b90ea0794d2dd3f2c82aef372acbf7..0c707e33125d05cf5c1261ec68e899090f182bbb 100644
--- a/tests/nbackend/kernelcreation/test_typification.py
+++ b/tests/nbackend/kernelcreation/test_typification.py
@@ -6,6 +6,7 @@ from typing import cast
 
 from pystencils import Assignment, TypedSymbol, Field, FieldType, AddAugmentedAssignment
 
+from pystencils.backend.ast import dfs_preorder
 from pystencils.backend.ast.structural import (
     PsDeclaration,
     PsAssignment,
@@ -14,7 +15,6 @@ from pystencils.backend.ast.structural import (
     PsBlock,
 )
 from pystencils.backend.ast.expressions import (
-    PsAddressOf,
     PsArrayInitList,
     PsCast,
     PsConstantExpr,
@@ -32,7 +32,6 @@ from pystencils.backend.ast.expressions import (
     PsLt,
     PsCall,
     PsTernary,
-    PsArrayInitList,
 )
 from pystencils.backend.constants import PsConstant
 from pystencils.backend.functions import CFunction
@@ -218,6 +217,37 @@ def test_default_typing():
     check(expr)
 
 
+def test_nested_contexts_defaults():
+    ctx = KernelCreationContext(default_dtype=Fp(16))
+    freeze = FreezeExpressions(ctx)
+    typify = Typifier(ctx)
+
+    x, y, z = sp.symbols("x, y, z")
+
+    fortytwo = PsExpression.make(PsConstant(42))
+
+    expr = typify(PsEq(fortytwo, fortytwo))
+    assert expr.dtype == Bool(const=True)
+    assert fortytwo.dtype == Fp(16, const=True)
+    assert fortytwo.constant.dtype == Fp(16, const=True)
+
+    decl = freeze(Assignment(x, sp.Ge(3, 4, evaluate=False)))
+    decl = typify(decl)
+
+    assert ctx.get_symbol("x").dtype == Bool()
+    for c in dfs_preorder(decl.rhs, lambda n: isinstance(n, PsConstantExpr)):
+        assert c.dtype == Fp(16, const=True)
+        assert c.constant.dtype == Fp(16, const=True)
+
+    thirtyone = PsExpression.make(PsConstant(31))
+    decl = PsDeclaration(freeze(y), PsCast(Fp(32), thirtyone))
+    decl = typify(decl)
+
+    assert ctx.get_symbol("y").dtype == Fp(32)
+    assert thirtyone.dtype == Fp(16, const=True)
+    assert thirtyone.constant.dtype == Fp(16, const=True)
+
+
 def test_constant_decls():
     ctx = KernelCreationContext(default_dtype=Fp(16))
     freeze = FreezeExpressions(ctx)
@@ -571,19 +601,3 @@ def test_cfunction():
 
     with pytest.raises(TypificationError):
         _ = typify(PsCall(threeway, (x, p)))
-
-
-def test_inference_fails():
-    ctx = KernelCreationContext()
-    typify = Typifier(ctx)
-
-    x = PsExpression.make(PsConstant(42))
-
-    with pytest.raises(TypificationError):
-        typify(PsEq(x, x))
-
-    with pytest.raises(TypificationError):
-        typify(PsArrayInitList([x]))
-
-    with pytest.raises(TypificationError):
-        typify(PsCast(ctx.default_dtype, x))