diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py
index 03ab76ea9e711d8945618975e85061aa9736ca96..ad2fd4b7522394c74ae66f7895fdd32769f528f1 100644
--- a/pystencils/backends/cbackend.py
+++ b/pystencils/backends/cbackend.py
@@ -495,8 +495,8 @@ class CustomSympyPrinter(CCodePrinter):
                 known = self.known_functions[arg.__class__.__name__.lower()]
                 code = self._print(arg)
                 return code.replace(known, f"{known}f")
-            elif isinstance(arg, sp.Pow) and data_type == BasicType('float32'):
-                known = ['sqrt', 'cbrt', 'pow']
+            elif isinstance(arg, (sp.Pow, sp.exp)) and data_type == BasicType('float32'):
+                known = ['sqrt', 'cbrt', 'pow', 'exp']
                 code = self._print(arg)
                 for k in known:
                     if k in code:
diff --git a/pystencils/typing/leaf_typing.py b/pystencils/typing/leaf_typing.py
index b4648835a662027b124a6c5b0192f67b76da5980..0d133038688f722d7d428c01442db2d8fb2458a9 100644
--- a/pystencils/typing/leaf_typing.py
+++ b/pystencils/typing/leaf_typing.py
@@ -216,7 +216,8 @@ class TypeAdder:
                 else:
                     new_args.append(a)
             return expr.func(*new_args) if new_args else expr, collated_type
-        elif isinstance(expr, (sp.Pow, InverseTrigonometricFunction, TrigonometricFunction, HyperbolicFunction)):
+        elif isinstance(expr, (sp.Pow, sp.exp, InverseTrigonometricFunction, TrigonometricFunction,
+                               HyperbolicFunction)):
             args_types = [self.figure_out_type(arg) for arg in expr.args]
             collated_type = collate_types([t for _, t in args_types])
             new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]
diff --git a/pystencils_tests/test_simplifications.py b/pystencils_tests/test_simplifications.py
index fbce59aca95db083a35a29d4f2e0335767c4f43f..21a5d5a9b54114288ea33468951d88dbee9dfc1c 100644
--- a/pystencils_tests/test_simplifications.py
+++ b/pystencils_tests/test_simplifications.py
@@ -4,6 +4,7 @@ import pytest
 import pystencils.config
 import sympy as sp
 import pystencils as ps
+import numpy as np
 
 from pystencils.simp import subexpression_substitution_in_main_assignments
 from pystencils.simp import add_subexpressions_for_divisions
@@ -143,29 +144,27 @@ def test_add_subexpressions_for_field_reads():
 
 
 @pytest.mark.parametrize('target', (ps.Target.CPU, ps.Target.GPU))
-@pytest.mark.parametrize('simplification', (True, False))
+@pytest.mark.parametrize('dtype', ('float32', 'float64'))
 @pytest.mark.skipif((vs.major, vs.minor, vs.micro) == (3, 8, 2), reason="does not work on python 3.8.2 for some reason")
-def test_sympy_optimizations(target, simplification):
+def test_sympy_optimizations(target, dtype):
     if target == ps.Target.GPU:
         pytest.importorskip("pycuda")
-    src, dst = ps.fields('src, dst:  float32[2d]')
+    src, dst = ps.fields(f'src, dst:  {dtype}[2d]')
 
-    # Triggers Sympy's expm1 optimization
-    # Sympy's expm1 optimization is tedious to use and the behaviour is highly depended on the sympy version. In
-    # some cases the exp expression has to be encapsulated in brackets or multiplied with 1 or 1.0
-    # for sympy to work properly ...
     assignments = ps.AssignmentCollection({
         src[0, 0]: 1.0 * (sp.exp(dst[0, 0]) - 1)
     })
 
-    config = pystencils.config.CreateKernelConfig(target=target, default_assignment_simplifications=simplification)
+    config = pystencils.config.CreateKernelConfig(target=target, default_number_float=dtype)
     ast = ps.create_kernel(assignments, config=config)
 
+    ps.show_code(ast)
+
     code = ps.get_code_str(ast)
-    if simplification:
-        assert 'expm1(' in code
-    else:
-        assert 'expm1(' not in code
+    if dtype == 'float32':
+        assert 'expf(' in code
+    elif dtype == 'float64':
+        assert 'exp(' in code
 
 
 @pytest.mark.parametrize('target', (ps.Target.CPU, ps.Target.GPU))