diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py
index 4e7c5cf2445c07f79dddc2c7eba93c520e2b320c..213df6f559fd18e3ba451451569e240971cf2dbf 100644
--- a/pystencils/backends/cbackend.py
+++ b/pystencils/backends/cbackend.py
@@ -17,17 +17,24 @@ from pystencils.integer_functions import (
     int_div, int_power_of_2, modulo_ceil)
 from pystencils.kernelparameters import FieldPointerSymbol
 
+from sympy.printing.codeprinter import requires
 try:
     from sympy.printing.ccode import C99CodePrinter as CCodePrinter
 except ImportError:
     from sympy.printing.ccode import CCodePrinter  # for sympy versions < 1.1
 
-__all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter']
+__all__ = [
+    'generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers',
+    'CustomSympyPrinter'
+]
 
 KERNCRAFT_NO_TERNARY_MODE = False
 
 
-def generate_c(ast_node: Node, signature_only: bool = False, dialect='c', custom_backend=None) -> str:
+def generate_c(ast_node: Node,
+               signature_only: bool = False,
+               dialect='c',
+               custom_backend=None) -> str:
     """Prints an abstract syntax tree node as C or CUDA code.
 
     This function does not need to distinguish between C, C++ or CUDA code, it just prints 'C-like' code as encoded
@@ -76,14 +83,28 @@ def generate_c(ast_node: Node, signature_only: bool = False, dialect='c', custom
 def get_global_declarations(ast):
     global_declarations = []
 
+    # global_declarations.append(
+    # CustomCodeNode(
+    # "using namespace std::complex_literals;", set(),
+    # set()))
     def visit_node(sub_ast):
+        nonlocal global_declarations
         if hasattr(sub_ast, "required_global_declarations"):
-            nonlocal global_declarations
             global_declarations += sub_ast.required_global_declarations
 
         if hasattr(sub_ast, "args"):
             for node in sub_ast.args:
                 visit_node(node)
+        if isinstance(sub_ast, KernelFunction):
+            if any(
+                    np.issubdtype(a.dtype.numpy_dtype, np.complexfloating)
+                    for a in sub_ast.atoms(sp.Expr) if hasattr(a, 'dtype')
+                    and hasattr(a.dtype, 'numpy_dtype')):
+                if sub_ast.backend == 'cpu':
+                    global_declarations.append(
+                        CustomCodeNode(
+                            "using namespace std::complex_literals;", set(),
+                            set()))
 
     visit_node(ast)
 
@@ -94,9 +115,17 @@ def get_headers(ast_node: Node) -> Set[str]:
     """Return a set of header files, necessary to compile the printed C-like code."""
     headers = set()
 
+    headers.update({'"complex_helper.hpp"'})
     if isinstance(ast_node, KernelFunction) and ast_node.instruction_set:
         headers.update(ast_node.instruction_set['headers'])
 
+    if isinstance(ast_node, KernelFunction):
+        if any(
+                np.issubdtype(a.dtype.numpy_dtype, np.complexfloating)
+                for a in ast_node.atoms(sp.Symbol) if hasattr(a,'dtype') and hasattr(a.dtype, 'numpy_dtype')):
+            if ast_node.backend == 'c':
+                headers.update({"<complex>"})
+
     if hasattr(ast_node, 'headers'):
         headers.update(ast_node.headers)
     for a in ast_node.args:
@@ -136,8 +165,11 @@ class CustomCodeNode(Node):
 class PrintNode(CustomCodeNode):
     # noinspection SpellCheckingInspection
     def __init__(self, symbol_to_print):
-        code = '\nstd::cout << "%s  =  " << %s << std::endl; \n' % (symbol_to_print.name, symbol_to_print.name)
-        super(PrintNode, self).__init__(code, symbols_read=[symbol_to_print], symbols_defined=set())
+        code = '\nstd::cout << "%s  =  " << %s << std::endl; \n' % (
+            symbol_to_print.name, symbol_to_print.name)
+        super(PrintNode, self).__init__(code,
+                                        symbols_read=[symbol_to_print],
+                                        symbols_defined=set())
         self.headers.append("<iostream>")
 
 
@@ -146,11 +178,15 @@ class PrintNode(CustomCodeNode):
 
 # noinspection PyPep8Naming
 class CBackend:
-
-    def __init__(self, sympy_printer=None, signature_only=False, vector_instruction_set=None, dialect='c'):
+    def __init__(self,
+                 sympy_printer=None,
+                 signature_only=False,
+                 vector_instruction_set=None,
+                 dialect='c'):
         if sympy_printer is None:
             if vector_instruction_set is not None:
-                self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set)
+                self.sympy_printer = VectorizedCustomSympyPrinter(
+                    vector_instruction_set)
             else:
                 self.sympy_printer = CustomSympyPrinter()
         else:
@@ -175,20 +211,25 @@ class CBackend:
             method_name = "_print_" + cls.__name__
             if hasattr(self, method_name):
                 return getattr(self, method_name)(node)
-        raise NotImplementedError(self.__class__.__name__ + " does not support node of type " + node.__class__.__name__)
+        raise NotImplementedError(self.__class__.__name__ +
+                                  " does not support node of type " +
+                                  node.__class__.__name__)
 
     def _print_Type(self, node):
         return str(node)
 
     def _print_KernelFunction(self, node):
-        function_arguments = ["%s %s" % (self._print(s.symbol.dtype), s.symbol.name) for s in node.get_parameters()]
+        function_arguments = [
+            "%s %s" % (self._print(s.symbol.dtype), s.symbol.name)
+            for s in node.get_parameters()
+        ]
         launch_bounds = ""
         if self._dialect == 'cuda':
             max_threads = node.indexing.max_threads_per_block()
             if max_threads:
                 launch_bounds = "__launch_bounds__({}) ".format(max_threads)
-        func_declaration = "FUNC_PREFIX %svoid %s(%s)" % (launch_bounds, node.function_name,
-                                                          ", ".join(function_arguments))
+        func_declaration = "FUNC_PREFIX %svoid %s(%s)" % (
+            launch_bounds, node.function_name, ", ".join(function_arguments))
         if self._signatureOnly:
             return func_declaration
 
@@ -197,16 +238,22 @@ class CBackend:
 
     def _print_Block(self, node):
         block_contents = "\n".join([self._print(child) for child in node.args])
-        return "{\n%s\n}" % (self._indent + self._indent.join(block_contents.splitlines(True)))
+        return "{\n%s\n}" % (
+            self._indent + self._indent.join(block_contents.splitlines(True)))
 
     def _print_PragmaBlock(self, node):
         return "%s\n%s" % (node.pragma_line, self._print_Block(node))
 
     def _print_LoopOverCoordinate(self, node):
         counter_symbol = node.loop_counter_name
-        start = "int %s = %s" % (counter_symbol, self.sympy_printer.doprint(node.start))
-        condition = "%s < %s" % (counter_symbol, self.sympy_printer.doprint(node.stop))
-        update = "%s += %s" % (counter_symbol, self.sympy_printer.doprint(node.step),)
+        start = "int %s = %s" % (counter_symbol,
+                                 self.sympy_printer.doprint(node.start))
+        condition = "%s < %s" % (counter_symbol,
+                                 self.sympy_printer.doprint(node.stop))
+        update = "%s += %s" % (
+            counter_symbol,
+            self.sympy_printer.doprint(node.step),
+        )
         loop_str = "for (%s; %s; %s)" % (start, condition, update)
 
         prefix = "\n".join(node.prefix_lines)
@@ -221,11 +268,13 @@ class CBackend:
             else:
                 prefix = ''
             data_type = prefix + self._print(node.lhs.dtype) + " "
-            return "%s%s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs),
+            return "%s%s = %s;" % (data_type,
+                                   self.sympy_printer.doprint(node.lhs),
                                    self.sympy_printer.doprint(node.rhs))
         else:
             lhs_type = get_type_of_expression(node.lhs)
-            if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func):
+            if type(lhs_type) is VectorType and isinstance(
+                    node.lhs, cast_func):
                 arg, data_type, aligned, nontemporal = node.lhs.args
                 instr = 'storeU'
                 if aligned:
@@ -237,10 +286,12 @@ class CBackend:
                 else:
                     rhs = node.rhs
 
-                return self._vector_instruction_set[instr].format("&" + self.sympy_printer.doprint(node.lhs.args[0]),
-                                                                  self.sympy_printer.doprint(rhs)) + ';'
+                return self._vector_instruction_set[instr].format(
+                    "&" + self.sympy_printer.doprint(node.lhs.args[0]),
+                    self.sympy_printer.doprint(rhs)) + ';'
             else:
-                return "%s = %s;" % (self.sympy_printer.doprint(node.lhs), self.sympy_printer.doprint(node.rhs))
+                return "%s = %s;" % (self.sympy_printer.doprint(
+                    node.lhs), self.sympy_printer.doprint(node.rhs))
 
     def _print_TemporaryMemoryAllocation(self, node):
         align = 64
@@ -256,7 +307,8 @@ class CBackend:
 
     def _print_TemporaryMemoryFree(self, node):
         align = 64
-        return "free(%s - %d);" % (self.sympy_printer.doprint(node.symbol.name), node.offset(align))
+        return "free(%s - %d);" % (self.sympy_printer.doprint(
+            node.symbol.name), node.offset(align))
 
     def _print_SkipIteration(self, _):
         return "continue;"
@@ -267,7 +319,9 @@ class CBackend:
     def _print_Conditional(self, node):
         cond_type = get_type_of_expression(node.condition_expr)
         if isinstance(cond_type, VectorType):
-            raise ValueError("Problem with Conditional inside vectorized loop - use vec_any or vec_all")
+            raise ValueError(
+                "Problem with Conditional inside vectorized loop - use vec_any or vec_all"
+            )
         condition_expr = self.sympy_printer.doprint(node.condition_expr)
         true_block = self._print_Block(node.true_block)
         result = "if (%s)\n%s " % (condition_expr, true_block)
@@ -279,14 +333,13 @@ class CBackend:
     def _print_DestructuringBindingsForFieldClass(self, node):
         # Define all undefined symbols
         undefined_field_symbols = node.symbols_defined
-        destructuring_bindings = ["%s %s = %s.%s;" %
-                                  (u.dtype,
-                                   u.name,
-                                   u.field_name if hasattr(u, 'field_name') else u.field_names[0],
-                                   node.CLASS_TO_MEMBER_DICT[u.__class__] %
-                                   (() if type(u) == FieldPointerSymbol else (u.coordinate,)))
-                                  for u in undefined_field_symbols
-                                  ]
+        destructuring_bindings = [
+            "%s %s = %s.%s;" %
+            (u.dtype, u.name, u.field_name if hasattr(u, 'field_name') else
+             u.field_names[0], node.CLASS_TO_MEMBER_DICT[u.__class__] %
+             (() if type(u) == FieldPointerSymbol else (u.coordinate, )))
+            for u in undefined_field_symbols
+        ]
         destructuring_bindings.sort()  # only for code aesthetics
         return "{\n" + self._indent + \
                ("\n" + self._indent).join(destructuring_bindings) + \
@@ -300,7 +353,6 @@ class CBackend:
 
 # noinspection PyPep8Naming
 class CustomSympyPrinter(CCodePrinter):
-
     def __init__(self):
         super(CustomSympyPrinter, self).__init__()
         self._float_type = create_type("float32")
@@ -312,12 +364,16 @@ class CustomSympyPrinter(CCodePrinter):
     def _print_Pow(self, expr):
         """Don't use std::pow function, for small integer exponents, write as multiplication"""
         if not expr.free_symbols:
-            return self._typed_number(expr.evalf(), get_type_of_expression(expr))
+            return self._typed_number(expr.evalf(),
+                                      get_type_of_expression(expr))
 
         if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
-            return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")"
-        elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0:
-            return "1 / ({})".format(self._print(sp.Mul(*[expr.base] * (-expr.exp), evaluate=False)))
+            return "(" + self._print(
+                sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")"
+        elif expr.exp.is_integer and expr.exp.is_number and -8 < expr.exp < 0:
+            return "1 / ({})".format(
+                self._print(sp.Mul(*[expr.base] * (-expr.exp),
+                                   evaluate=False)))
         else:
             return super(CustomSympyPrinter, self)._print_Pow(expr)
 
@@ -328,7 +384,8 @@ class CustomSympyPrinter(CCodePrinter):
 
     def _print_Equality(self, expr):
         """Equality operator is not printable in default printer"""
-        return '((' + self._print(expr.lhs) + ") == (" + self._print(expr.rhs) + '))'
+        return '((' + self._print(expr.lhs) + ") == (" + self._print(
+            expr.rhs) + '))'
 
     def _print_Piecewise(self, expr):
         """Print piecewise in one line (remove newlines)"""
@@ -347,9 +404,11 @@ class CustomSympyPrinter(CCodePrinter):
             return expr.to_c(self._print)
         if isinstance(expr, reinterpret_cast_func):
             arg, data_type = expr.args
-            return "*((%s)(& %s))" % (PointerType(data_type, restrict=False), self._print(arg))
+            return "*((%s)(& %s))" % (PointerType(
+                data_type, restrict=False), self._print(arg))
         elif isinstance(expr, address_of):
-            assert len(expr.args) == 1, "address_of must only have one argument"
+            assert len(
+                expr.args) == 1, "address_of must only have one argument"
             return "&(%s)" % self._print(expr.args[0])
         elif isinstance(expr, cast_func):
             arg, data_type = expr.args
@@ -366,11 +425,14 @@ class CustomSympyPrinter(CCodePrinter):
         elif isinstance(expr, fast_inv_sqrt):
             return "({})".format(self._print(1 / sp.sqrt(expr.args[0])))
         elif expr.func in infix_functions:
-            return "(%s %s %s)" % (self._print(expr.args[0]), infix_functions[expr.func], self._print(expr.args[1]))
+            return "(%s %s %s)" % (self._print(
+                expr.args[0]), infix_functions[expr.func],
+                                   self._print(expr.args[1]))
         elif expr.func == int_power_of_2:
             return "(1 << (%s))" % (self._print(expr.args[0]))
         elif expr.func == int_div:
-            return "((%s) / (%s))" % (self._print(expr.args[0]), self._print(expr.args[1]))
+            return "((%s) / (%s))" % (self._print(
+                expr.args[0]), self._print(expr.args[1]))
         else:
             return super(CustomSympyPrinter, self)._print_Function(expr)
 
@@ -381,9 +443,9 @@ class CustomSympyPrinter(CCodePrinter):
         elif dtype.numpy_dtype == np.float64:
             return res + '.0' if '.' not in res else res
         elif dtype.numpy_dtype == np.complex64:
-            return f"{_typed_number(number.real, np.float32)} + {_typed_number(number.real, np.float32).replace('f', 'if')}"
+            return f"{self._typed_number(number.real, np.float32)} + {self._typed_number(number.real, np.float32).replace('f', 'if')}"
         elif dtype.numpy_dtype == np.complex128:
-            return f"{_typed_number(number.real, np.float64)} + {_typed_number(number.real, np.float64)}i"
+            return f"{self._typed_number(number.real, np.float64)} + {self._typed_number(number.real, np.float64)}i"
         else:
             return res
 
@@ -406,7 +468,8 @@ class CustomSympyPrinter(CCodePrinter):
             end=self._print(end),
             expr=self._print(expr.function),
             increment=str(1),
-            condition=self._print(var) + ' <= ' + self._print(end)  # if start < end else '>='
+            condition=self._print(var) + ' <= ' +
+            self._print(end)  # if start < end else '>='
         )
         return code
 
@@ -429,12 +492,15 @@ class CustomSympyPrinter(CCodePrinter):
             end=self._print(end),
             expr=self._print(expr.function),
             increment=str(1),
-            condition=self._print(var) + ' <= ' + self._print(end)  # if start < end else '>='
+            condition=self._print(var) + ' <= ' +
+            self._print(end)  # if start < end else '>='
         )
         return code
+
     _print_Max = C89CodePrinter._print_Max
     _print_Min = C89CodePrinter._print_Min
 
+    @requires(headers={'stdbool.h'})
     def _print_re(self, expr):
         return f"std::real({self._print(expr.args[0])})"
 
@@ -442,8 +508,11 @@ class CustomSympyPrinter(CCodePrinter):
         return f"std::imag({self._print(expr.args[0])})"
 
     def _print_ImaginaryUnit(self, expr):
-        return "1if"
-        
+        return "std::complex<double>{0,1}"
+
+    def _print_Complex(self, expr):
+        return self._typed_number(expr, np.complex64)
+
 
 # noinspection PyPep8Naming
 class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -456,7 +525,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
     def _scalarFallback(self, func_name, expr, *args, **kwargs):
         expr_type = get_type_of_expression(expr)
         if type(expr_type) is not VectorType:
-            return getattr(super(VectorizedCustomSympyPrinter, self), func_name)(expr, *args, **kwargs)
+            return getattr(super(VectorizedCustomSympyPrinter, self),
+                           func_name)(expr, *args, **kwargs)
         else:
             assert self.instruction_set['width'] == expr_type.width
             return None
@@ -464,7 +534,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
     def _print_Function(self, expr):
         if isinstance(expr, vector_memory_access):
             arg, data_type, aligned, _ = expr.args
-            instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU']
+            instruction = self.instruction_set[
+                'loadA'] if aligned else self.instruction_set['loadU']
             return instruction.format("& " + self._print(arg))
         elif isinstance(expr, cast_func):
             arg, data_type = expr.args
@@ -473,7 +544,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
         elif expr.func == fast_division:
             result = self._scalarFallback('_print_Function', expr)
             if not result:
-                result = self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1]))
+                result = self.instruction_set['/'].format(
+                    self._print(expr.args[0]), self._print(expr.args[1]))
             return result
         elif expr.func == fast_sqrt:
             return "({})".format(self._print(sp.sqrt(expr.args[0])))
@@ -481,21 +553,25 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
             result = self._scalarFallback('_print_Function', expr)
             if not result:
                 if self.instruction_set['rsqrt']:
-                    return self.instruction_set['rsqrt'].format(self._print(expr.args[0]))
+                    return self.instruction_set['rsqrt'].format(
+                        self._print(expr.args[0]))
                 else:
-                    return "({})".format(self._print(1 / sp.sqrt(expr.args[0])))
+                    return "({})".format(self._print(1 /
+                                                     sp.sqrt(expr.args[0])))
         elif isinstance(expr, vec_any):
             expr_type = get_type_of_expression(expr.args[0])
             if type(expr_type) is not VectorType:
                 return self._print(expr.args[0])
             else:
-                return self.instruction_set['any'].format(self._print(expr.args[0]))
+                return self.instruction_set['any'].format(
+                    self._print(expr.args[0]))
         elif isinstance(expr, vec_all):
             expr_type = get_type_of_expression(expr.args[0])
             if type(expr_type) is not VectorType:
                 return self._print(expr.args[0])
             else:
-                return self.instruction_set['all'].format(self._print(expr.args[0]))
+                return self.instruction_set['all'].format(
+                    self._print(expr.args[0]))
 
         return super(VectorizedCustomSympyPrinter, self)._print_Function(expr)
 
@@ -545,7 +621,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
         assert len(summands) >= 2
         processed = summands[0].term
         for summand in summands[1:]:
-            func = self.instruction_set['-'] if summand.sign == -1 else self.instruction_set['+']
+            func = self.instruction_set[
+                '-'] if summand.sign == -1 else self.instruction_set['+']
             processed = func.format(processed, summand.term)
         return processed
 
@@ -557,18 +634,22 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
         one = self.instruction_set['makeVec'].format(1.0)
 
         if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
-            return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")"
+            return "(" + self._print(
+                sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")"
         elif expr.exp == -1:
             one = self.instruction_set['makeVec'].format(1.0)
-            return self.instruction_set['/'].format(one, self._print(expr.base))
+            return self.instruction_set['/'].format(one,
+                                                    self._print(expr.base))
         elif expr.exp == 0.5:
             return self.instruction_set['sqrt'].format(self._print(expr.base))
         elif expr.exp == -0.5:
             root = self.instruction_set['sqrt'].format(self._print(expr.base))
             return self.instruction_set['/'].format(one, root)
-        elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0:
-            return self.instruction_set['/'].format(one,
-                                                    self._print(sp.Mul(*[expr.base] * (-expr.exp), evaluate=False)))
+        elif expr.exp.is_integer and expr.exp.is_number and -8 < expr.exp < 0:
+            return self.instruction_set['/'].format(
+                one,
+                self._print(sp.Mul(*[expr.base] * (-expr.exp),
+                                   evaluate=False)))
         else:
             raise ValueError("Generic exponential not supported: " + str(expr))
 
@@ -612,14 +693,16 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
         if len(b) > 0:
             denominator_str = b_str[0]
             for item in b_str[1:]:
-                denominator_str = self.instruction_set['*'].format(denominator_str, item)
+                denominator_str = self.instruction_set['*'].format(
+                    denominator_str, item)
             result = self.instruction_set['/'].format(result, denominator_str)
 
         if inside_add:
             return sign, result
         else:
             if sign < 0:
-                return self.instruction_set['*'].format(self._print(S.NegativeOne), result)
+                return self.instruction_set['*'].format(
+                    self._print(S.NegativeOne), result)
             else:
                 return result
 
@@ -627,13 +710,15 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
         result = self._scalarFallback('_print_Relational', expr)
         if result:
             return result
-        return self.instruction_set[expr.rel_op].format(self._print(expr.lhs), self._print(expr.rhs))
+        return self.instruction_set[expr.rel_op].format(
+            self._print(expr.lhs), self._print(expr.rhs))
 
     def _print_Equality(self, expr):
         result = self._scalarFallback('_print_Equality', expr)
         if result:
             return result
-        return self.instruction_set['=='].format(self._print(expr.lhs), self._print(expr.rhs))
+        return self.instruction_set['=='].format(self._print(expr.lhs),
+                                                 self._print(expr.rhs))
 
     def _print_Piecewise(self, expr):
         result = self._scalarFallback('_print_Piecewise', expr)
@@ -651,13 +736,17 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
 
         result = self._print(expr.args[-1][0])
         for true_expr, condition in reversed(expr.args[:-1]):
-            if isinstance(condition, cast_func) and get_type_of_expression(condition.args[0]) == create_type("bool"):
+            if isinstance(condition, cast_func) and get_type_of_expression(
+                    condition.args[0]) == create_type("bool"):
                 if not KERNCRAFT_NO_TERNARY_MODE:
-                    result = "(({}) ? ({}) : ({}))".format(self._print(condition.args[0]), self._print(true_expr),
-                                                           result)
+                    result = "(({}) ? ({}) : ({}))".format(
+                        self._print(condition.args[0]), self._print(true_expr),
+                        result)
                 else:
                     print("Warning - skipping ternary op")
             else:
                 # noinspection SpellCheckingInspection
-                result = self.instruction_set['blendv'].format(result, self._print(true_expr), self._print(condition))
+                result = self.instruction_set['blendv'].format(
+                    result, self._print(true_expr), self._print(condition))
+        return result
         return result
diff --git a/pystencils/data_types.py b/pystencils/data_types.py
index a0c6efc2a8abb4a496aa1f4cecd6e39ea937c3a7..b52a9a1bb4459f522effbe4effe174338232c620 100644
--- a/pystencils/data_types.py
+++ b/pystencils/data_types.py
@@ -123,7 +123,7 @@ class boolean_cast_func(cast_func, Boolean):
 
 # noinspection PyPep8Naming
 class vector_memory_access(cast_func):
-    nargs = (4,)
+    nargs = (4, )
 
 
 # noinspection PyPep8Naming
@@ -172,7 +172,8 @@ class TypedSymbol(sp.Symbol):
     @property
     def is_integer(self):
         if hasattr(self.dtype, 'numpy_dtype'):
-            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer
+            return np.issubdtype(self.dtype.numpy_dtype,
+                                 np.integer) or super().is_integer
         else:
             return super().is_integer
 
@@ -292,12 +293,10 @@ to_ctypes.map = {
     np.dtype(np.int16): ctypes.c_int16,
     np.dtype(np.int32): ctypes.c_int32,
     np.dtype(np.int64): ctypes.c_int64,
-
     np.dtype(np.uint8): ctypes.c_uint8,
     np.dtype(np.uint16): ctypes.c_uint16,
     np.dtype(np.uint32): ctypes.c_uint32,
     np.dtype(np.uint64): ctypes.c_uint64,
-
     np.dtype(np.float32): ctypes.c_float,
     np.dtype(np.float64): ctypes.c_double,
 }
@@ -330,7 +329,8 @@ def ctypes_from_llvm(data_type):
     elif isinstance(data_type, ir.VoidType):
         return None  # Void type is not supported by ctypes
     else:
-        raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type))
+        raise NotImplementedError('Data type %s of %s is not supported yet' %
+                                  (type(data_type), data_type))
 
 
 def to_llvm_type(data_type):
@@ -353,12 +353,10 @@ if ir:
         np.dtype(np.int16): ir.IntType(16),
         np.dtype(np.int32): ir.IntType(32),
         np.dtype(np.int64): ir.IntType(64),
-
         np.dtype(np.uint8): ir.IntType(8),
         np.dtype(np.uint16): ir.IntType(16),
         np.dtype(np.uint32): ir.IntType(32),
         np.dtype(np.uint64): ir.IntType(64),
-
         np.dtype(np.float32): ir.FloatType(),
         np.dtype(np.float64): ir.DoubleType(),
     }
@@ -370,11 +368,21 @@ def peel_off_type(dtype, type_to_peel_off):
     return dtype
 
 
-def collate_types(types):
+def collate_types(types, forbid_collation_to_complex=False, forbid_collation_to_float=False):
     """
     Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
     Uses the collation rules from numpy.
     """
+    if forbid_collation_to_complex:
+        types = [t for t in types if not np.issubdtype(t.numpy_dtype, np.complexfloating)]
+        if not types:
+            types = [ create_type(np.float64)]
+
+    if forbid_collation_to_float:
+        types = [t for t in types if not np.issubdtype(t.numpy_dtype, np.floating)]
+        if not types:
+            types = [ create_type(np.int64) ]
+
 
     # Pointer arithmetic case i.e. pointer + integer is allowed
     if any(type(t) is PointerType for t in types):
@@ -382,7 +390,8 @@ def collate_types(types):
         for t in types:
             if type(t) is PointerType:
                 if pointer_type is not None:
-                    raise ValueError("Cannot collate the combination of two pointer types")
+                    raise ValueError(
+                        "Cannot collate the combination of two pointer types")
                 pointer_type = t
             elif type(t) is BasicType:
                 if not (t.is_int() or t.is_uint()):
@@ -394,7 +403,8 @@ def collate_types(types):
     # peel of vector types, if at least one vector type occurred the result will also be the vector type
     vector_type = [t for t in types if type(t) is VectorType]
     if not all_equal(t.width for t in vector_type):
-        raise ValueError("Collation failed because of vector types with different width")
+        raise ValueError(
+            "Collation failed because of vector types with different width")
     types = [peel_off_type(t, VectorType) for t in types]
 
     # now we should have a list of basic types - struct types are not yet supported
@@ -429,6 +439,8 @@ def get_type_of_expression(expr,
     expr = sp.sympify(expr)
     if isinstance(expr, sp.Integer):
         return create_type(default_int_type)
+    elif  expr.is_real == False:
+        return create_type(default_complex_type)
     elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
         return create_type(default_float_type)
     elif isinstance(expr, ResolvedFieldAccess):
@@ -453,7 +465,8 @@ def get_type_of_expression(expr,
     elif isinstance(expr, sp.Indexed):
         typed_symbol = expr.base.label
         return typed_symbol.dtype.base_type
-    elif isinstance(expr, sp.boolalg.Boolean) or isinstance(expr, sp.boolalg.BooleanFunction):
+    elif isinstance(expr, sp.boolalg.Boolean) or isinstance(
+            expr, sp.boolalg.BooleanFunction):
         # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
         result = create_type("bool")
         vec_args = [get_type(a) for a in expr.args if isinstance(get_type(a), VectorType)]
@@ -499,14 +512,15 @@ class BasicType(Type):
             return 'std::complex<double>'
         elif name.startswith('int'):
             width = int(name[len("int"):])
-            return "int%d_t" % (width,)
+            return "int%d_t" % (width, )
         elif name.startswith('uint'):
             width = int(name[len("uint"):])
-            return "uint%d_t" % (width,)
+            return "uint%d_t" % (width, )
         elif name == 'bool':
             return 'bool'
         else:
-            raise NotImplementedError("Can map numpy to C name for %s" % (name,))
+            raise NotImplementedError("Can map numpy to C name for %s" %
+                                      (name, ))
 
     def __init__(self, dtype, const=False):
         self.const = const
@@ -534,7 +548,8 @@ class BasicType(Type):
         return 1
 
     def is_int(self):
-        return self.numpy_dtype in np.sctypes['int'] or self.numpy_dtype in np.sctypes['uint']
+        return self.numpy_dtype in np.sctypes[
+            'int'] or self.numpy_dtype in np.sctypes['uint']
 
     def is_float(self):
         return self.numpy_dtype in np.sctypes['float']
@@ -565,7 +580,8 @@ class BasicType(Type):
         if not isinstance(other, BasicType):
             return False
         else:
-            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
+            return (self.numpy_dtype, self.const) == (other.numpy_dtype,
+                                                      other.const)
 
     def __hash__(self):
         return hash(str(self))
@@ -590,7 +606,8 @@ class VectorType(Type):
         if not isinstance(other, VectorType):
             return False
         else:
-            return (self.base_type, self.width) == (other.base_type, other.width)
+            return (self.base_type, self.width) == (other.base_type,
+                                                    other.width)
 
     def __str__(self):
         if self.instruction_set is None:
@@ -639,7 +656,9 @@ class PointerType(Type):
         if not isinstance(other, PointerType):
             return False
         else:
-            return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict)
+            return (self.base_type, self.const,
+                    self.restrict) == (other.base_type, other.const,
+                                       other.restrict)
 
     def __str__(self):
         components = [str(self.base_type), '*']
@@ -690,7 +709,8 @@ class StructType:
         if not isinstance(other, StructType):
             return False
         else:
-            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
+            return (self.numpy_dtype, self.const) == (other.numpy_dtype,
+                                                      other.const)
 
     def __str__(self):
         # structs are handled byte-wise
diff --git a/pystencils/include/complex_helper.hpp b/pystencils/include/complex_helper.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..92b8f7c5454fa08d2b82d2ddf56a431ecb06a65d
--- /dev/null
+++ b/pystencils/include/complex_helper.hpp
@@ -0,0 +1,66 @@
+/*
+ * complex_helper.hpp
+ * Copyright (C) 2019 Stephan Seitz <stephan.seitz@fau.de>
+ *
+ * Distributed under terms of the GPLv3 license.
+ */
+
+#pragma once
+
+#include <complex>
+
+template <class U, class V>
+auto operator*(const std::complex<U> &complexNumber, const V &scalar)
+    -> std::complex<U> {
+  return std::complex<U>{std::real(complexNumber) * scalar,
+                         std::imag(complexNumber) * scalar};
+}
+
+template <class U, class V>
+auto operator*(const V &scalar, const std::complex<U> &complexNumber)
+    -> std::complex<U> {
+  return std::complex<U>{std::real(complexNumber) * scalar,
+                         std::imag(complexNumber) * scalar};
+}
+
+template <class U, class V>
+auto operator+(const std::complex<U> &complexNumber, const V &scalar)
+    -> std::complex<U> {
+  return std::complex<U>{std::real(complexNumber) + scalar,
+                         std::imag(complexNumber)};
+}
+
+template <class U, class V>
+auto operator+(const V &scalar, const std::complex<U> &complexNumber)
+    -> std::complex<U> {
+  return std::complex<U>{std::real(complexNumber) + scalar,
+                         std::imag(complexNumber)};
+}
+
+template <class U, class V>
+auto operator-(const std::complex<U> &complexNumber, const V &scalar)
+    -> std::complex<U> {
+  return std::complex<U>{std::real(complexNumber) - scalar,
+                         std::imag(complexNumber)};
+}
+
+template <class U, class V>
+auto operator-(const V &scalar, const std::complex<U> &complexNumber)
+    -> std::complex<U> {
+  return std::complex<U>{scalar - std::real(complexNumber),
+                         std::imag(complexNumber)};
+}
+
+template <class U, class V>
+auto operator/(const std::complex<U> &complexNumber, const V &scalar)
+    -> std::complex<U> {
+  return std::complex<U>{std::real(complexNumber) / scalar,
+                         std::imag(complexNumber) / scalar};
+}
+
+template <class U, class V>
+auto operator/(const V &scalar, const std::complex<U> &complexNumber)
+    -> std::complex<U> {
+  return std::complex<U>{scalar, 0} / complexNumber;
+}
+
diff --git a/pystencils/transformations.py b/pystencils/transformations.py
index 554b9cf9b2257e6b0ef97b24239ea210baa8b63b..2aa5f1603bfe3310666c2ec55fb8b09a720243cc 100644
--- a/pystencils/transformations.py
+++ b/pystencils/transformations.py
@@ -887,7 +887,7 @@ class KernelConstraintsCheck:
     def process_assignment(self, assignment):
         # for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1
         new_rhs = self.process_expression(assignment.rhs)
-        new_lhs = self._process_lhs(assignment.lhs)
+        new_lhs = self._process_lhs(assignment.lhs, assignment.rhs)
         return ast.SympyAssignment(new_lhs, new_rhs)
 
     def process_expression(self, rhs, type_constants=True):
@@ -933,7 +933,7 @@ class KernelConstraintsCheck:
     def fields_written(self):
         return set(k.field for k, v in self._field_writes.items() if len(v))
 
-    def _process_lhs(self, lhs):
+    def _process_lhs(self, lhs, rhs):
         assert isinstance(lhs, sp.Symbol)
         self._update_accesses_lhs(lhs)
         if not isinstance(lhs, (AbstractField.AbstractAccess, TypedSymbol)):
diff --git a/pystencils_tests/test_complex_numbers.py b/pystencils_tests/test_complex_numbers.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1230f3a31ab53346befa1f9db62359f34a7f606
--- /dev/null
+++ b/pystencils_tests/test_complex_numbers.py
@@ -0,0 +1,101 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
+#
+# Distributed under terms of the GPLv3 license.
+"""
+
+"""
+
+import itertools
+
+import pytest
+import sympy
+from sympy.functions import im, re
+
+import pystencils
+from pystencils import AssignmentCollection
+from pystencils.data_types import create_type, TypedSymbol
+
+
+X, Y = pystencils.fields('x, y: complex64[2d]')
+A, B = pystencils.fields('a, b: float32[2d]')
+S1, S2 = sympy.symbols('S1, S2')
+T64 = TypedSymbol('t', create_type('complex64'))
+
+TEST_ASSIGNMENTS = [
+    AssignmentCollection({X[0, 0]: 1j}),
+    AssignmentCollection({
+        S1: re(Y.center),
+        S2: im(Y.center),
+        X[0, 0]: 2j * S1 + S2
+    }),
+    AssignmentCollection({
+        A.center: re(Y.center),
+        B.center: im(Y.center),
+    }),
+    AssignmentCollection({
+        Y.center: re(Y.center) + X.center + 2j,
+    }),
+    AssignmentCollection({
+        T64: 2 + 4j,
+        Y.center: X.center / T64,
+    })
+]
+
+SCALAR_DTYPES = ['float32', 'float64']
+
+
+@pytest.mark.parametrize("assignment, scalar_dtypes",
+                         itertools.product(TEST_ASSIGNMENTS, SCALAR_DTYPES))
+def test_complex_numbers(assignment, scalar_dtypes):
+    ast = pystencils.create_kernel(assignment,
+                                   target='cpu',
+                                   data_type=scalar_dtypes)
+    code = str(pystencils.show_code(ast))
+
+    print(code)
+    assert "Not supported" not in code
+
+    kernel = ast.compile()
+    assert kernel is not None
+
+X, Y = pystencils.fields('x, y: complex128[2d]')
+A, B = pystencils.fields('a, b: float64[2d]')
+S1, S2 = sympy.symbols('S1, S2')
+T128 = TypedSymbol('t', create_type('complex128'))
+
+TEST_ASSIGNMENTS = [
+    AssignmentCollection({X[0, 0]: 1j}),
+    AssignmentCollection({
+        S1: re(Y.center),
+        S2: im(Y.center),
+        X[0, 0]: 2j * S1 + S2
+    }),
+    AssignmentCollection({
+        A.center: re(Y.center),
+        B.center: im(Y.center),
+    }),
+    AssignmentCollection({
+        Y.center: re(Y.center) + X.center + 2j,
+    }),
+    AssignmentCollection({
+        T128: 2 + 4j,
+        Y.center: X.center / T128,
+    })
+]
+
+SCALAR_DTYPES = ['float32', 'float64']
+@pytest.mark.parametrize("assignment, scalar_dtypes",
+                         itertools.product(TEST_ASSIGNMENTS, SCALAR_DTYPES))
+def test_complex_numbers_64(assignment, scalar_dtypes):
+    ast = pystencils.create_kernel(assignment,
+                                   target='cpu',
+                                   data_type=scalar_dtypes)
+    code = str(pystencils.show_code(ast))
+
+    print(code)
+    assert "Not supported" not in code
+
+    kernel = ast.compile()
+    assert kernel is not None