diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py
index 49bc302867c51431210bb392e501e4269baee428..fc085e2be99f61204cde92438811a3b4e41c8bf7 100644
--- a/src/pystencils/backend/kernelcreation/typification.py
+++ b/src/pystencils/backend/kernelcreation/typification.py
@@ -68,18 +68,12 @@ class TypeContext:
      - A set of restrictions on the target type:
        - `require_nonconst` to make sure the target type is not `const`, as required on assignment left-hand sides
        - Additional restrictions may be added in the future.
-    
-    The type context also tracks the tree traversal of the typifier:
-    
-     - ``is_lhs`` is set to True while a left-hand side expression is being processed,
-       and `False` while a right-hand side expression is processed.
     """
 
     def __init__(
         self,
         target_type: PsType | None = None,
         require_nonconst: bool = False,
-        is_lhs: bool = False
     ):
         self._require_nonconst = require_nonconst
         self._deferred_exprs: list[PsExpression] = []
@@ -88,8 +82,6 @@ class TypeContext:
             self._fix_constness(target_type) if target_type is not None else None
         )
 
-        self._is_lhs = is_lhs
-
     @property
     def target_type(self) -> PsType | None:
         return self._target_type
@@ -97,14 +89,6 @@ class TypeContext:
     @property
     def require_nonconst(self) -> bool:
         return self._require_nonconst
-    
-    @property
-    def is_lhs(self) -> bool:
-        return self._is_lhs
-    
-    @is_lhs.setter
-    def is_lhs(self, value: bool):
-        self._is_lhs = value
 
     def apply_dtype(self, dtype: PsType, expr: PsExpression | None = None):
         """Applies the given ``dtype`` to this type context, and optionally to the given expression.
@@ -333,29 +317,41 @@ class Typifier:
                     self.visit(s)
 
             case PsDeclaration(lhs, rhs):
+                #   Only if the LHS is an untyped symbol, infer its type from the RHS
+                infer_lhs = isinstance(lhs, PsSymbolExpr) and lhs.symbol.dtype is None
+
                 tc = TypeContext()
 
-                tc.is_lhs = True
-                self.visit_expr(lhs, tc)
-                tc.is_lhs = False
+                if infer_lhs:
+                    tc.infer_dtype(lhs)
+                else:
+                    self.visit_expr(lhs, tc)
+                    assert tc.target_type is not None
 
                 self.visit_expr(rhs, tc)
 
-                if tc.target_type is None:
+                if infer_lhs and tc.target_type is None:
                     #   no type has been inferred -> use the default dtype
                     tc.apply_dtype(self._ctx.default_dtype)
 
             case PsAssignment(lhs, rhs):
-                tc_lhs = TypeContext(require_nonconst=True, is_lhs=True)
-                self.visit_expr(lhs, tc_lhs)
+                infer_lhs = isinstance(lhs, PsSymbolExpr) and lhs.symbol.dtype is None
+
+                tc_lhs = TypeContext(require_nonconst=True)
+
+                if infer_lhs:
+                    tc_lhs.infer_dtype(lhs)
+                else:
+                    self.visit_expr(lhs, tc_lhs)
+                    assert tc_lhs.target_type is not None
 
                 tc_rhs = TypeContext(target_type=tc_lhs.target_type)
                 self.visit_expr(rhs, tc_rhs)
 
-                if tc_rhs.target_type is None:
-                    tc_rhs.apply_dtype(self._ctx.default_dtype)
-
-                if tc_lhs.target_type is None:
+                if infer_lhs:
+                    if tc_rhs.target_type is None:
+                        tc_rhs.apply_dtype(self._ctx.default_dtype)
+                    
                     assert tc_rhs.target_type is not None
                     tc_lhs.apply_dtype(deconstify(tc_rhs.target_type))
 
@@ -398,15 +394,9 @@ class Typifier:
         """
         match expr:
             case PsSymbolExpr(symb):
-                if tc.is_lhs:
-                    if symb.dtype is not None:
-                        tc.apply_dtype(symb.dtype, expr)
-                    elif tc.is_lhs:
-                        tc.infer_dtype(expr)
-                else:
-                    if symb.dtype is None:
-                        symb.dtype = self._ctx.default_dtype
-                    tc.apply_dtype(symb.dtype, expr)
+                if symb.dtype is None:
+                    symb.dtype = self._ctx.default_dtype
+                tc.apply_dtype(symb.dtype, expr)
 
             case PsConstantExpr(c):
                 if c.dtype is not None:
diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py
index ca170401fcc78900e5fd9a6e43ba118dccce2c5a..d3da7e8881d631266d58ee6cc0d4d3612a2900a1 100644
--- a/tests/nbackend/kernelcreation/test_typification.py
+++ b/tests/nbackend/kernelcreation/test_typification.py
@@ -16,6 +16,7 @@ from pystencils.backend.ast.structural import (
 from pystencils.backend.ast.expressions import (
     PsConstantExpr,
     PsSymbolExpr,
+    PsSubscript,
     PsBinOp,
     PsAnd,
     PsOr,
@@ -27,12 +28,12 @@ from pystencils.backend.ast.expressions import (
     PsGt,
     PsLt,
     PsCall,
-    PsTernary
+    PsTernary,
 )
 from pystencils.backend.constants import PsConstant
 from pystencils.backend.functions import CFunction
 from pystencils.types import constify, create_type, create_numeric_type
-from pystencils.types.quick import Fp, Int, Bool
+from pystencils.types.quick import Fp, Int, Bool, Arr
 from pystencils.backend.kernelcreation.context import KernelCreationContext
 from pystencils.backend.kernelcreation.freeze import FreezeExpressions
 from pystencils.backend.kernelcreation.typification import Typifier, TypificationError
@@ -242,9 +243,11 @@ def test_lhs_inference():
     assert ctx.get_symbol("z").dtype == Fp(16)
     assert fasm.lhs.dtype == Fp(16)
 
-    fasm = PsDeclaration(PsExpression.make(ctx.get_symbol("r")), PsLe(freeze(q), freeze(2 * q)))
+    fasm = PsDeclaration(
+        PsExpression.make(ctx.get_symbol("r")), PsLe(freeze(q), freeze(2 * q))
+    )
     fasm = typify(fasm)
-    
+
     assert ctx.get_symbol("r").dtype == Bool()
     assert fasm.lhs.dtype == constify(Bool())
     assert fasm.rhs.dtype == constify(Bool())
@@ -282,6 +285,26 @@ def test_erronous_typing():
         typify(fasm)
 
 
+def test_invalid_indices():
+    ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64))
+    typify = Typifier(ctx)
+
+    arr = PsExpression.make(ctx.get_symbol("arr", Arr(Fp(64))))
+    x, y, z = [PsExpression.make(ctx.get_symbol(x)) for x in "xyz"]
+
+    #   Using default-typed symbols as array indices is illegal when the default type is a float
+
+    fasm = PsAssignment(PsSubscript(arr, x + y), z)
+
+    with pytest.raises(TypificationError):
+        typify(fasm)
+
+    fasm = PsAssignment(z, PsSubscript(arr, x + y))
+
+    with pytest.raises(TypificationError):
+        typify(fasm)
+
+
 def test_typify_integer_binops():
     ctx = KernelCreationContext()
     freeze = FreezeExpressions(ctx)
@@ -410,7 +433,7 @@ def test_invalid_conditions():
     with pytest.raises(TypificationError):
         typify(cond)
 
-    
+
 def test_typify_ternary():
     ctx = KernelCreationContext()
     typify = Typifier(ctx)