diff --git a/pystencils/data_types.py b/pystencils/data_types.py
index 81f373e079477bace01d77f0231e51ad914a74aa..09bf9a57b80846ecab8181b9589513374e1e9e05 100644
--- a/pystencils/data_types.py
+++ b/pystencils/data_types.py
@@ -431,11 +431,15 @@ def collate_types(types,
 def get_type_of_expression(expr,
                            default_float_type='double',
                            default_int_type='int',
-                           default_complex_type='complex128',
                            symbol_type_dict=None):
 
     from pystencils.astnodes import ResolvedFieldAccess
     from pystencils.cpu.vectorization import vec_all, vec_any
+    # TODO: determine more general
+    if default_float_type == 'double' or default_float_type == 'float64':
+        default_complex_type = 'complex128'
+    else:
+        default_complex_type = 'complex64'
 
     if not symbol_type_dict:
         symbol_type_dict = defaultdict(lambda: create_type('double'))
@@ -443,7 +447,6 @@ def get_type_of_expression(expr,
     get_type = partial(get_type_of_expression,
                        default_float_type=default_float_type,
                        default_int_type=default_int_type,
-                       default_complex_type=default_complex_type,
                        symbol_type_dict=symbol_type_dict)
 
     expr = sp.sympify(expr)
diff --git a/pystencils/transformations.py b/pystencils/transformations.py
index 2aa5f1603bfe3310666c2ec55fb8b09a720243cc..8469fc79a91268387aab90a8111609eead62d533 100644
--- a/pystencils/transformations.py
+++ b/pystencils/transformations.py
@@ -12,8 +12,8 @@ from sympy.logic.boolalg import Boolean
 import pystencils.astnodes as ast
 from pystencils.assignment import Assignment
 from pystencils.data_types import (
-    PointerType, StructType, TypedSymbol, cast_func, collate_types, create_type, get_base_type,
-    get_type_of_expression, pointer_arithmetic_func, reinterpret_cast_func)
+    PointerType, StructType, TypedImaginaryUnit, TypedSymbol, cast_func, collate_types, create_type,
+    get_base_type, get_type_of_expression, pointer_arithmetic_func, reinterpret_cast_func)
 from pystencils.field import AbstractField, Field, FieldType
 from pystencils.kernelparameters import FieldPointerSymbol
 from pystencils.simp.assignment_collection import AssignmentCollection
@@ -898,6 +898,11 @@ class KernelConstraintsCheck:
             return rhs
         elif isinstance(rhs, TypedSymbol):
             return rhs
+        elif isinstance(rhs, sp.numbers.ImaginaryUnit):
+            return TypedImaginaryUnit(self._type_for_symbol['_ImaginaryUnit'])
+        elif isinstance(rhs, sp.Symbol):
+            return TypedSymbol(rhs.name, self._type_for_symbol[rhs.name])
+            return TypedSymbol(rhs.name, self._type_for_symbol[rhs.name])
         elif isinstance(rhs, sp.Symbol):
             return TypedSymbol(rhs.name, self._type_for_symbol[rhs.name])
         elif type_constants and isinstance(rhs, np.generic):
@@ -1167,6 +1172,11 @@ def typing_from_sympy_inspection(eqs, default_type="double", default_int_type='i
         dictionary, mapping symbol name to type
     """
     result = defaultdict(lambda: default_type)
+    if default_type == 'double' or default_type == 'float64':  # todo: fix
+        result['_ImaginaryUnit'] = create_type('complex128')
+    else:
+        result['_ImaginaryUnit'] = create_type('complex64')
+
     for eq in eqs:
         if isinstance(eq, ast.Conditional):
             result.update(typing_from_sympy_inspection(eq.true_block.args))
diff --git a/pystencils_tests/test_complex_numbers.py b/pystencils_tests/test_complex_numbers.py
index d1ac2bf7907ecc1afa6c9e5b3d66fcd865e6e4e7..5a161914d51423e37476f3ffab0e50a6df7435d1 100644
--- a/pystencils_tests/test_complex_numbers.py
+++ b/pystencils_tests/test_complex_numbers.py
@@ -20,7 +20,8 @@ from pystencils.data_types import TypedImaginaryUnit, TypedSymbol, create_type
 X, Y = pystencils.fields('x, y: complex64[2d]')
 A, B = pystencils.fields('a, b: float32[2d]')
 S1, S2 = sympy.symbols('S1, S2')
-T64 = TypedSymbol('t', create_type('complex64'))
+# T64 = TypedSymbol('t', create_type('complex64'))
+T64 = sympy.Symbol('t')
 
 TEST_ASSIGNMENTS = [
     AssignmentCollection({X[0, 0]: 1j}),
@@ -48,11 +49,9 @@ SCALAR_DTYPES = ['float32', 'float64']
 @pytest.mark.parametrize("assignment, scalar_dtypes",
                          itertools.product(TEST_ASSIGNMENTS, SCALAR_DTYPES))
 def test_complex_numbers(assignment, scalar_dtypes):
-    ast = pystencils.create_kernel(assignment.subs(
-        {sympy.sympify(1j).args[1]:
-         TypedImaginaryUnit(create_type('complex64'))}),
-        target='cpu',
-        data_type=scalar_dtypes)
+    ast = pystencils.create_kernel(assignment,
+                                   target='cpu',
+                                   data_type='float32')
     code = str(pystencils.show_code(ast))
 
     print(code)
@@ -94,11 +93,9 @@ SCALAR_DTYPES = ['float32', 'float64']
 @pytest.mark.parametrize("assignment, scalar_dtypes",
                          itertools.product(TEST_ASSIGNMENTS, SCALAR_DTYPES))
 def test_complex_numbers_64(assignment, scalar_dtypes):
-    ast = pystencils.create_kernel(assignment.subs(
-        {sympy.sympify(1j).args[1]:
-         TypedImaginaryUnit(create_type('complex128'))}),
-        target='cpu',
-        data_type=scalar_dtypes)
+    ast = pystencils.create_kernel(assignment,
+                                   target='cpu',
+                                   data_type='double')
     code = str(pystencils.show_code(ast))
 
     print(code)
@@ -113,5 +110,8 @@ def test_get_data_type():
     from pystencils.data_types import get_type_of_expression
 
     i = TypedImaginaryUnit(create_type('complex128'))
-    # assert get_type_of_expression(i+3).numpy_dtype == np.complex128
+    assert get_type_of_expression(i+3).numpy_dtype == np.complex128
     assert get_type_of_expression(i+3.).numpy_dtype == np.complex128
+    i = TypedImaginaryUnit(create_type('complex64'))
+    assert get_type_of_expression(i+3, default_float_type='float32').numpy_dtype == np.complex64
+    assert get_type_of_expression(i+3., default_float_type='float32').numpy_dtype == np.complex64