diff --git a/pystencils/simp/simplifications.py b/pystencils/simp/simplifications.py
index 5ed8c4ea9ee62fca33933eead71be6fff5571704..e1f42dc60467c2cd8a7fa874afd79323e77111a7 100644
--- a/pystencils/simp/simplifications.py
+++ b/pystencils/simp/simplifications.py
@@ -8,6 +8,7 @@ from pystencils.assignment import Assignment
 from pystencils.astnodes import Node
 from pystencils.field import Field
 from pystencils.sympyextensions import subs_additive, is_constant, recursive_collect
+from pystencils.typing import TypedSymbol
 
 
 def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]:
@@ -168,12 +169,14 @@ def add_subexpressions_for_sums(ac):
     return ac.new_with_substitutions(substitutions, True, substitute_on_lhs=False)
 
 
-def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments=True):
+def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments=True, data_type=None):
     r"""Substitutes field accesses on rhs of assignments with subexpressions
 
     Can change semantics of the update rule (which is the goal of this transformation)
     This is useful if a field should be update in place - all values are loaded before into subexpression variables,
     then the new values are computed and written to the same field in-place.
+    Additionally, if a datatype is given to the function the rhs symbol of the new isolated field read will have
+    this data type. This is useful for mixed precision kernels
     """
     field_reads = set()
     to_iterate = []
@@ -185,7 +188,17 @@ def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments
     for assignment in to_iterate:
         if hasattr(assignment, 'lhs') and hasattr(assignment, 'rhs'):
             field_reads.update(assignment.rhs.atoms(Field.Access))
-    substitutions = {fa: next(ac.subexpression_symbol_generator) for fa in field_reads}
+
+    if not field_reads:
+        return
+
+    substitutions = dict()
+    for fa in field_reads:
+        lhs = next(ac.subexpression_symbol_generator)
+        if data_type is not None:
+            substitutions.update({fa: TypedSymbol(lhs.name, data_type)})
+        else:
+            substitutions.update({fa: lhs})
     return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True,
                                      substitute_on_lhs=False, sort_topologically=False)
 
diff --git a/pystencils_tests/test_simplifications.py b/pystencils_tests/test_simplifications.py
index 61d009d03bd341032896076b65fabcb3630859d5..2814ac14f5fb5734e0d23a92d8e0aaf975691b36 100644
--- a/pystencils_tests/test_simplifications.py
+++ b/pystencils_tests/test_simplifications.py
@@ -4,14 +4,14 @@ import pytest
 import pystencils.config
 import sympy as sp
 import pystencils as ps
-import numpy as np
 
+from pystencils import Assignment, AssignmentCollection, fields
 from pystencils.simp import subexpression_substitution_in_main_assignments
 from pystencils.simp import add_subexpressions_for_divisions
 from pystencils.simp import add_subexpressions_for_sums
 from pystencils.simp import add_subexpressions_for_field_reads
 from pystencils.simp.simplifications import add_subexpressions_for_constants
-from pystencils import Assignment, AssignmentCollection, fields
+from pystencils.typing import BasicType, TypedSymbol
 
 a, b, c, d, x, y, z = sp.symbols("a b c d x y z")
 s0, s1, s2, s3 = sp.symbols("s_:4")
@@ -133,14 +133,18 @@ def test_add_subexpressions_for_sums():
 def test_add_subexpressions_for_field_reads():
     s, v = fields("s(5), v(5): double[2D]")
     subexpressions = []
-    main = [
-        Assignment(s[0, 0](0), 3 * v[0, 0](0)),
-        Assignment(s[0, 0](1), 10 * v[0, 0](1))
-    ]
+
+    main = [Assignment(s[0, 0](0), 3 * v[0, 0](0)),
+            Assignment(s[0, 0](1), 10 * v[0, 0](1))]
+
     ac = AssignmentCollection(main, subexpressions)
     assert len(ac.subexpressions) == 0
-    ac = add_subexpressions_for_field_reads(ac)
-    assert len(ac.subexpressions) == 2
+    ac2 = add_subexpressions_for_field_reads(ac)
+    assert len(ac2.subexpressions) == 2
+    ac3 = add_subexpressions_for_field_reads(ac, data_type="float32")
+    assert len(ac3.subexpressions) == 2
+    assert isinstance(ac3.subexpressions[0].lhs, TypedSymbol)
+    assert ac3.subexpressions[0].lhs.dtype == BasicType("float32")
 
 
 @pytest.mark.parametrize('target', (ps.Target.CPU, ps.Target.GPU))