diff --git a/pystencils/cpu/kernelcreation.py b/pystencils/cpu/kernelcreation.py
index fc6765e589e341b40ebbc778c7c26839531ae826..c28e7d0ad686cad3f1689f1c1e8c67a0ab40def5 100644
--- a/pystencils/cpu/kernelcreation.py
+++ b/pystencils/cpu/kernelcreation.py
@@ -8,6 +8,7 @@ from pystencils.astnodes import Block, KernelFunction, LoopOverCoordinate, Sympy
 from pystencils.cpu.cpujit import make_python_function
 from pystencils.data_types import BasicType, StructType, TypedSymbol, create_type
 from pystencils.field import Field, FieldType
+from pystencils.math_optimizations import optims_pystencils_cpu, optimize_assignments
 from pystencils.transformations import (
     add_types, filtered_tree_iteration, get_base_buffer_index, get_optimal_loop_ordering,
     make_loop_over_domain, move_constants_before_loop, parse_base_pointer_info,
@@ -18,7 +19,7 @@ AssignmentOrAstNodeList = List[Union[Assignment, ast.Node]]
 
 def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "kernel", type_info='double',
                   split_groups=(), iteration_slice=None, ghost_layers=None,
-                  skip_independence_check=False) -> KernelFunction:
+                  skip_independence_check=False, sympy_optimizations=optims_pystencils_cpu) -> KernelFunction:
     """Creates an abstract syntax tree for a kernel function, by taking a list of update rules.
 
     Loops are created according to the field accesses in the equations.
@@ -54,6 +55,10 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke
         else:
             raise ValueError("Term has to be field access or symbol")
 
+    if sympy_optimizations is None:
+        sympy_optimizations = optims_pystencils_cpu
+    assignments = optimize_assignments(assignments, sympy_optimizations)
+
     fields_read, fields_written, assignments = add_types(assignments, type_info, not skip_independence_check)
     all_fields = fields_read.union(fields_written)
     read_only_fields = set([f.name for f in fields_read - fields_written])
@@ -89,8 +94,12 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke
     return ast_node
 
 
-def create_indexed_kernel(assignments: AssignmentOrAstNodeList, index_fields, function_name="kernel",
-                          type_info=None, coordinate_names=('x', 'y', 'z')) -> KernelFunction:
+def create_indexed_kernel(assignments: AssignmentOrAstNodeList,
+                          index_fields,
+                          function_name="kernel",
+                          type_info=None,
+                          coordinate_names=('x', 'y', 'z'),
+                          sympy_optimizations=optims_pystencils_cpu) -> KernelFunction:
     """
     Similar to :func:`create_kernel`, but here not all cells of a field are updated but only cells with
     coordinates which are stored in an index field. This traversal method can e.g. be used for boundary handling.
@@ -107,6 +116,11 @@ def create_indexed_kernel(assignments: AssignmentOrAstNodeList, index_fields, fu
         function_name: see documentation of :func:`create_kernel`
         coordinate_names: name of the coordinate fields in the struct data type
     """
+    if sympy_optimizations is None:
+        sympy_optimizations = optims_pystencils_cpu
+
+    assignments = optimize_assignments(assignments, sympy_optimizations)
+
     fields_read, fields_written, assignments = add_types(assignments, type_info, check_independence_condition=False)
     all_fields = fields_read.union(fields_written)
 
diff --git a/pystencils/gpucuda/kernelcreation.py b/pystencils/gpucuda/kernelcreation.py
index ff82107000b00595dbcca6699ed42b172c324353..12c7bd3c97dd8f8cc17f2641dbd442b615110082 100644
--- a/pystencils/gpucuda/kernelcreation.py
+++ b/pystencils/gpucuda/kernelcreation.py
@@ -3,13 +3,25 @@ from pystencils.data_types import BasicType, StructType, TypedSymbol
 from pystencils.field import Field, FieldType
 from pystencils.gpucuda.cudajit import make_python_function
 from pystencils.gpucuda.indexing import BlockIndexing
+from pystencils.math_optimizations import optimize_assignments, optims_pystencils_gpu
 from pystencils.transformations import (
     add_types, get_base_buffer_index, get_common_shape, parse_base_pointer_info,
     resolve_buffer_accesses, resolve_field_accesses, unify_shape_symbols)
 
 
-def create_cuda_kernel(assignments, function_name="kernel", type_info=None, indexing_creator=BlockIndexing,
-                       iteration_slice=None, ghost_layers=None, skip_independence_check=False):
+def create_cuda_kernel(assignments,
+                       function_name="kernel",
+                       type_info=None,
+                       indexing_creator=BlockIndexing,
+                       iteration_slice=None,
+                       ghost_layers=None,
+                       skip_independence_check=False,
+                       sympy_optimizations=None):
+
+    if sympy_optimizations is None:
+        sympy_optimizations = optims_pystencils_gpu
+    assignments = optimize_assignments(assignments, sympy_optimizations)
+
     fields_read, fields_written, assignments = add_types(assignments, type_info, not skip_independence_check)
     all_fields = fields_read.union(fields_written)
     read_only_fields = set([f.name for f in fields_read - fields_written])
@@ -86,8 +98,17 @@ def create_cuda_kernel(assignments, function_name="kernel", type_info=None, inde
     return ast
 
 
-def created_indexed_cuda_kernel(assignments, index_fields, function_name="kernel", type_info=None,
-                                coordinate_names=('x', 'y', 'z'), indexing_creator=BlockIndexing):
+def created_indexed_cuda_kernel(assignments,
+                                index_fields,
+                                function_name="kernel",
+                                type_info=None,
+                                coordinate_names=('x', 'y', 'z'),
+                                indexing_creator=BlockIndexing,
+                                sympy_optimizations=None):
+    if sympy_optimizations is None:
+        sympy_optimizations = optims_pystencils_gpu
+    assignments = optimize_assignments(assignments, sympy_optimizations)
+
     fields_read, fields_written, assignments = add_types(assignments, type_info, check_independence_condition=False)
     all_fields = fields_read.union(fields_written)
     read_only_fields = set([f.name for f in fields_read - fields_written])
diff --git a/pystencils/kernelcreation.py b/pystencils/kernelcreation.py
index ade980f55969a4005f2c5055ad27f861ab5524de..79c3d04d10224b63733bbbc8b88c9ebb093e5ce6 100644
--- a/pystencils/kernelcreation.py
+++ b/pystencils/kernelcreation.py
@@ -15,7 +15,7 @@ from pystencils.transformations import (
 def create_kernel(assignments, target='cpu', data_type="double", iteration_slice=None, ghost_layers=None,
                   skip_independence_check=False,
                   cpu_openmp=False, cpu_vectorize_info=None, cpu_blocking=None,
-                  gpu_indexing='block', gpu_indexing_params=MappingProxyType({})):
+                  gpu_indexing='block', gpu_indexing_params=MappingProxyType({}), sympy_optimizations=None):
     """
     Creates abstract syntax tree (AST) of kernel, using a list of update equations.
 
@@ -75,7 +75,7 @@ def create_kernel(assignments, target='cpu', data_type="double", iteration_slice
         from pystencils.cpu import add_openmp
         ast = create_kernel(assignments, type_info=data_type, split_groups=split_groups,
                             iteration_slice=iteration_slice, ghost_layers=ghost_layers,
-                            skip_independence_check=skip_independence_check)
+                            skip_independence_check=skip_independence_check, sympy_optimizations=sympy_optimizations)
         omp_collapse = None
         if cpu_blocking:
             omp_collapse = loop_blocking(ast, cpu_blocking)
@@ -91,22 +91,29 @@ def create_kernel(assignments, target='cpu', data_type="double", iteration_slice
         return ast
     elif target == 'llvm':
         from pystencils.llvm import create_kernel
-        ast = create_kernel(assignments, type_info=data_type, split_groups=split_groups,
-                            iteration_slice=iteration_slice, ghost_layers=ghost_layers)
+        ast = create_kernel(assignments,
+                            type_info=data_type,
+                            split_groups=split_groups,
+                            iteration_slice=iteration_slice,
+                            ghost_layers=ghost_layers,
+                            sympy_optimizations=sympy_optimizations)
         return ast
     elif target == 'gpu':
         from pystencils.gpucuda import create_cuda_kernel
         ast = create_cuda_kernel(assignments, type_info=data_type,
                                  indexing_creator=indexing_creator_from_params(gpu_indexing, gpu_indexing_params),
-                                 iteration_slice=iteration_slice, ghost_layers=ghost_layers,
-                                 skip_independence_check=skip_independence_check)
+                                 iteration_slice=iteration_slice,
+                                 ghost_layers=ghost_layers,
+                                 skip_independence_check=skip_independence_check,
+                                 sympy_optimizations=sympy_optimizations)
         return ast
     else:
         raise ValueError("Unknown target %s. Has to be one of 'cpu', 'gpu' or 'llvm' " % (target,))
 
 
 def create_indexed_kernel(assignments, index_fields, target='cpu', data_type="double", coordinate_names=('x', 'y', 'z'),
-                          cpu_openmp=True, gpu_indexing='block', gpu_indexing_params=MappingProxyType({})):
+                          cpu_openmp=True, gpu_indexing='block', gpu_indexing_params=MappingProxyType({}),
+                          sympy_optimizations=None):
     """
     Similar to :func:`create_kernel`, but here not all cells of a field are updated but only cells with
     coordinates which are stored in an index field. This traversal method can e.g. be used for boundary handling.
@@ -150,7 +157,7 @@ def create_indexed_kernel(assignments, index_fields, target='cpu', data_type="do
         from pystencils.cpu import create_indexed_kernel
         from pystencils.cpu import add_openmp
         ast = create_indexed_kernel(assignments, index_fields=index_fields, type_info=data_type,
-                                    coordinate_names=coordinate_names)
+                                    coordinate_names=coordinate_names, sympy_optimizations=sympy_optimizations)
         if cpu_openmp:
             add_openmp(ast, num_threads=cpu_openmp)
         return ast
@@ -160,14 +167,15 @@ def create_indexed_kernel(assignments, index_fields, target='cpu', data_type="do
         from pystencils.gpucuda import created_indexed_cuda_kernel
         idx_creator = indexing_creator_from_params(gpu_indexing, gpu_indexing_params)
         ast = created_indexed_cuda_kernel(assignments, index_fields, type_info=data_type,
-                                          coordinate_names=coordinate_names, indexing_creator=idx_creator)
+                                          coordinate_names=coordinate_names, indexing_creator=idx_creator,
+                                          sympy_optimizations=sympy_optimizations)
         return ast
     else:
         raise ValueError("Unknown target %s. Has to be either 'cpu' or 'gpu'" % (target,))
 
 
 def create_staggered_kernel(staggered_field, expressions, subexpressions=(), target='cpu',
-                            gpu_exclusive_conditions=False, **kwargs):
+                            gpu_exclusive_conditions=False, sympy_optimizations=None, **kwargs):
     """Kernel that updates a staggered field.
 
     .. image:: /img/staggered_grid.svg
@@ -251,7 +259,11 @@ def create_staggered_kernel(staggered_field, expressions, subexpressions=(), tar
     if openmp:
         del kwargs['cpu_openmp']
 
-    ast = create_kernel(final_assignments, ghost_layers=ghost_layers, target=target, **kwargs)
+    ast = create_kernel(final_assignments,
+                        ghost_layers=ghost_layers,
+                        target=target,
+                        sympy_optimizations=sympy_optimizations,
+                        **kwargs)
 
     if target == 'cpu':
         remove_conditionals_in_staggered_kernel(ast)
diff --git a/pystencils/llvm/kernelcreation.py b/pystencils/llvm/kernelcreation.py
index 38ac7fe6be818729eae3a935a4c14003f066e849..86f26ef6f388d95446f323b5c8965b0d90ff15af 100644
--- a/pystencils/llvm/kernelcreation.py
+++ b/pystencils/llvm/kernelcreation.py
@@ -3,7 +3,7 @@ from pystencils.transformations import insert_casts
 
 
 def create_kernel(assignments, function_name="kernel", type_info=None, split_groups=(),
-                  iteration_slice=None, ghost_layers=None):
+                  iteration_slice=None, ghost_layers=None, sympy_optimizations=None):
     """
     Creates an abstract syntax tree for a kernel function, by taking a list of update rules.
 
@@ -26,7 +26,13 @@ def create_kernel(assignments, function_name="kernel", type_info=None, split_gro
     :return: :class:`pystencils.ast.KernelFunction` node
     """
     from pystencils.cpu import create_kernel
-    code = create_kernel(assignments, function_name, type_info, split_groups, iteration_slice, ghost_layers)
+    code = create_kernel(assignments,
+                         function_name,
+                         type_info,
+                         split_groups,
+                         iteration_slice,
+                         ghost_layers,
+                         sympy_optimizations=sympy_optimizations)
     code.body = insert_casts(code.body)
     code._compile_function = make_python_function
     code._backend = 'llvm'
diff --git a/pystencils_tests/test_sum_prod.py b/pystencils_tests/test_sum_prod.py
index 4fa5c0618612b013edda4d164dd035dafdd2438a..3aa51b85623764f5ddc49a84472e45a24c52fcc4 100644
--- a/pystencils_tests/test_sum_prod.py
+++ b/pystencils_tests/test_sum_prod.py
@@ -29,7 +29,7 @@ def test_sum():
         x.center(): sum
     })
 
-    ast = pystencils.create_kernel(assignments)
+    ast = pystencils.create_kernel(assignments, sympy_optimizations=[])
     code = str(pystencils.show_code(ast))
     kernel = ast.compile()
 
@@ -57,7 +57,7 @@ def test_sum_use_float():
         x.center(): sum
     })
 
-    ast = pystencils.create_kernel(assignments, data_type=create_type('float32'))
+    ast = pystencils.create_kernel(assignments, data_type=create_type('float32'), sympy_optimizations=[])
     code = str(pystencils.show_code(ast))
     kernel = ast.compile()
 
@@ -88,7 +88,7 @@ def test_product():
         x.center(): sum
     })
 
-    ast = pystencils.create_kernel(assignments)
+    ast = pystencils.create_kernel(assignments, sympy_optimizations=[])
     code = str(pystencils.show_code(ast))
     kernel = ast.compile()
 
diff --git a/pystencils_tests/test_sympy_optimizations.py b/pystencils_tests/test_sympy_optimizations.py
index 745b936ea7d0e3ce5da6c684dee42c56c09dd835..8a76370b6f587088f8c0d389e4ef05ad23cba747 100644
--- a/pystencils_tests/test_sympy_optimizations.py
+++ b/pystencils_tests/test_sympy_optimizations.py
@@ -15,7 +15,7 @@ def test_sympy_optimizations():
             x[0, 0]: sp.exp(y[0, 0]) - 1
         })
 
-        assignments = optimize_assignments(assignments, optims_pystencils_cpu)
+        optimize_assignments(assignments, optims_pystencils_cpu)
 
         ast = pystencils.create_kernel(assignments, target=target)
         code = str(pystencils.show_code(ast))
@@ -32,7 +32,7 @@ def test_evaluate_constant_terms():
             x[0, 0]: -sp.cos(1) + y[0, 0]
         })
 
-        assignments = optimize_assignments(assignments, optims_pystencils_cpu)
+        optimize_assignments(assignments, optims_pystencils_cpu)
 
         ast = pystencils.create_kernel(assignments, target=target)
         code = str(pystencils.show_code(ast))
@@ -54,7 +54,7 @@ def test_do_not_evaluate_constant_terms():
 
         optimize_assignments(assignments, optimizations)
 
-        ast = pystencils.create_kernel(assignments, target=target)
+        ast = pystencils.create_kernel(assignments, target=target, sympy_optimizations=optimizations)
         code = str(pystencils.show_code(ast))
         assert 'cos(' in code
         print(code)