From 72ea97cb5564d39604e37086122122a5826f64ab Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Tue, 9 Jul 2024 16:06:50 +0200
Subject: [PATCH] Revert previous changes to frontend API

---
 pyproject.toml                                |   2 +-
 src/pystencils/__init__.py                    |   8 +-
 src/pystencils/assignment.py                  | 110 ++++
 .../backend/kernelcreation/analysis.py        |   3 +-
 .../backend/kernelcreation/freeze.py          |   3 +-
 .../backend/kernelcreation/iteration_space.py |   2 +-
 .../boundaries/boundaryconditions.py          |   2 +-
 src/pystencils/boundaries/boundaryhandling.py |   2 +-
 src/pystencils/fd/finitedifferences.py        |   2 +-
 src/pystencils/kernel_decorator.py            |   2 +-
 src/pystencils/kernelcreation.py              |   4 +
 src/pystencils/placeholder_function.py        |   2 +-
 src/pystencils/{sympyextensions => }/rng.py   |   3 +-
 src/pystencils/simp/__init__.py               |  45 ++
 src/pystencils/simp/assignment_collection.py  | 476 ++++++++++++++
 .../simplifications.py                        |   8 +-
 .../simplificationstrategy.py                 |   2 +-
 .../subexpression_insertion.py                |   2 +-
 src/pystencils/sympyextensions/__init__.py    |  54 +-
 src/pystencils/sympyextensions/astnodes.py    | 590 +-----------------
 .../sympyextensions/fast_approximation.py     |   3 +-
 src/pystencils/sympyextensions/math.py        |   2 +-
 .../kernelcreation/test_domain_kernels.py     |   2 +-
 tests/symbolics/test_address_of.py            |   2 +-
 .../test_conditional_field_access.py          |   2 +-
 tests/test_fvm.py                             |   2 +-
 tests/test_random.py                          |   2 +-
 27 files changed, 669 insertions(+), 668 deletions(-)
 create mode 100644 src/pystencils/assignment.py
 rename src/pystencils/{sympyextensions => }/rng.py (97%)
 create mode 100644 src/pystencils/simp/__init__.py
 create mode 100644 src/pystencils/simp/assignment_collection.py
 rename src/pystencils/{sympyextensions => simp}/simplifications.py (97%)
 rename src/pystencils/{sympyextensions => simp}/simplificationstrategy.py (99%)
 rename src/pystencils/{sympyextensions => simp}/subexpression_insertion.py (98%)

diff --git a/pyproject.toml b/pyproject.toml
index 5ef106e59..cc33e2b65 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,7 +12,7 @@ authors = [
 ]
 license = { file = "COPYING.txt" }
 requires-python = ">=3.10"
-dependencies = ["sympy>=1.6,<=1.11.1", "numpy>=1.8.0", "appdirs", "joblib", "pyyaml"]
+dependencies = ["sympy>=1.9,<=1.12.1", "numpy>=1.8.0", "appdirs", "joblib", "pyyaml"]
 classifiers = [
     "Development Status :: 4 - Beta",
     "Framework :: Jupyter",
diff --git a/src/pystencils/__init__.py b/src/pystencils/__init__.py
index c39cd3b82..ac3801518 100644
--- a/src/pystencils/__init__.py
+++ b/src/pystencils/__init__.py
@@ -15,7 +15,7 @@ from .config import (
     OpenMpConfig,
 )
 from .kernel_decorator import kernel, kernel_config
-from .kernelcreation import create_kernel
+from .kernelcreation import create_kernel, create_staggered_kernel
 from .backend.kernelfunction import KernelFunction
 from .slicing import make_slice
 from .spatial_coordinates import (
@@ -28,12 +28,13 @@ from .spatial_coordinates import (
     z_,
     z_staggered,
 )
-from .sympyextensions import Assignment, AssignmentCollection, AddAugmentedAssignment
-from .sympyextensions.astnodes import assignment_from_stencil
+from .assignment import Assignment, AddAugmentedAssignment, assignment_from_stencil
+from .simp import AssignmentCollection
 from .sympyextensions.typed_sympy import TypedSymbol
 from .sympyextensions.math import SymbolCreator
 from .datahandling import create_data_handling
 
+
 __all__ = [
     "Field",
     "FieldType",
@@ -48,6 +49,7 @@ __all__ = [
     "VectorizationConfig",
     "OpenMpConfig",
     "create_kernel",
+    "create_staggered_kernel",
     "KernelFunction",
     "Target",
     "show_code",
diff --git a/src/pystencils/assignment.py b/src/pystencils/assignment.py
new file mode 100644
index 000000000..af32bc664
--- /dev/null
+++ b/src/pystencils/assignment.py
@@ -0,0 +1,110 @@
+import numpy as np
+import sympy as sp
+from sympy.codegen.ast import Assignment, AugmentedAssignment
+from sympy.codegen.ast import AddAugmentedAssignment as SpAddAugAssignment
+from sympy.printing.latex import LatexPrinter
+
+__all__ = ['Assignment', 'AugmentedAssignment', 'AddAugmentedAssignment', 'assignment_from_stencil']
+
+
+def print_assignment_latex(printer, expr):
+    binop = f"{expr.binop}=" if isinstance(expr, AugmentedAssignment) else ''
+    """sympy cannot print Assignments as Latex. Thus, this function is added to the sympy Latex printer"""
+    printed_lhs = printer.doprint(expr.lhs)
+    printed_rhs = printer.doprint(expr.rhs)
+    return fr"{printed_lhs} \leftarrow_{{{binop}}} {printed_rhs}"
+
+
+def assignment_str(assignment):
+    op = f"{assignment.binop}=" if isinstance(assignment, AugmentedAssignment) else '←'
+    return fr"{assignment.lhs} {op} {assignment.rhs}"
+
+
+_old_new = sp.codegen.ast.Assignment.__new__
+
+
+# TODO Typing Part2 add default type, defult_float_type, default_int_type and use sane defaults
+def _Assignment__new__(cls, lhs, rhs, *args, **kwargs):
+    if isinstance(lhs, (list, tuple, sp.Matrix)) and isinstance(rhs, (list, tuple, sp.Matrix)):
+        assert len(lhs) == len(rhs), f'{lhs} and {rhs} must have same length when performing vector assignment!'
+        return tuple(_old_new(cls, a, b, *args, **kwargs) for a, b in zip(lhs, rhs))
+    return _old_new(cls, lhs, rhs, *args, **kwargs)
+
+
+Assignment.__str__ = assignment_str
+Assignment.__new__ = _Assignment__new__
+LatexPrinter._print_Assignment = print_assignment_latex
+
+AugmentedAssignment.__str__ = assignment_str
+LatexPrinter._print_AugmentedAssignment = print_assignment_latex
+
+sp.MutableDenseMatrix.__hash__ = lambda self: hash(tuple(self))
+
+#   Re-Export
+AddAugmentedAssignment = SpAddAugAssignment
+
+
+def assignment_from_stencil(stencil_array, input_field, output_field,
+                            normalization_factor=None, order='visual') -> Assignment:
+    """Creates an assignment
+
+    Args:
+        stencil_array: nested list of numpy array defining the stencil weights
+        input_field: field or field access, defining where the stencil should be applied to
+        output_field: field or field access where the result is written to
+        normalization_factor: optional normalization factor for the stencil
+        order: defines how the stencil_array is interpreted. Possible values are 'visual' and 'numpy'.
+               For details see examples
+
+    Returns:
+        Assignment that can be used to create a kernel
+
+    Examples:
+        >>> import pystencils as ps
+        >>> f, g = ps.fields("f, g: [2D]")
+        >>> stencil = [[0, 2, 0],
+        ...            [3, 4, 5],
+        ...            [0, 6, 0]]
+
+        By default 'visual ordering is used - i.e. the stencil is applied as the nested lists are written down
+        >>> expected_output = Assignment(g[0, 0], 3*f[-1, 0] + 6*f[0, -1] + 4*f[0, 0] + 2*f[0, 1] + 5*f[1, 0])
+        >>> assignment_from_stencil(stencil, f, g, order='visual') == expected_output
+        True
+
+        'numpy' ordering uses the first coordinate of the stencil array for x offset, second for y offset etc.
+        >>> expected_output = Assignment(g[0, 0], 2*f[-1, 0] + 3*f[0, -1] + 4*f[0, 0] + 5*f[0, 1] + 6*f[1, 0])
+        >>> assignment_from_stencil(stencil, f, g, order='numpy') == expected_output
+        True
+
+        You can also pass field accesses to apply the stencil at an already shifted position:
+        >>> expected_output = Assignment(g[2, 0], 3*f[0, 0] + 6*f[1, -1] + 4*f[1, 0] + 2*f[1, 1] + 5*f[2, 0])
+        >>> assignment_from_stencil(stencil, f[1, 0], g[2, 0]) == expected_output
+        True
+    """
+    from pystencils.field import Field
+
+    stencil_array = np.array(stencil_array)
+    if order == 'visual':
+        stencil_array = np.swapaxes(stencil_array, 0, 1)
+        stencil_array = np.flip(stencil_array, axis=1)
+    elif order == 'numpy':
+        pass
+    else:
+        raise ValueError("'order' has to be either 'visual' or 'numpy'")
+
+    if isinstance(input_field, Field):
+        input_field = input_field.center
+    if isinstance(output_field, Field):
+        output_field = output_field.center
+
+    rhs = 0
+    offset = tuple(s // 2 for s in stencil_array.shape)
+
+    for index, factor in np.ndenumerate(stencil_array):
+        shift = tuple(i - o for i, o in zip(index, offset))
+        rhs += factor * input_field.get_shifted(*shift)
+
+    if normalization_factor:
+        rhs *= normalization_factor
+
+    return Assignment(output_field, rhs)
diff --git a/src/pystencils/backend/kernelcreation/analysis.py b/src/pystencils/backend/kernelcreation/analysis.py
index a72191b5b..05aa79928 100644
--- a/src/pystencils/backend/kernelcreation/analysis.py
+++ b/src/pystencils/backend/kernelcreation/analysis.py
@@ -9,7 +9,8 @@ import sympy as sp
 from .context import KernelCreationContext
 
 from ...field import Field
-from ...sympyextensions import Assignment, AssignmentCollection
+from ...assignment import Assignment
+from ...simp import AssignmentCollection
 
 from ..exceptions import PsInternalCompilerError, KernelConstraintsError
 
diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py
index 59fa04b3b..f81ed586b 100644
--- a/src/pystencils/backend/kernelcreation/freeze.py
+++ b/src/pystencils/backend/kernelcreation/freeze.py
@@ -7,7 +7,8 @@ import sympy.core.relational
 import sympy.logic.boolalg
 from sympy.codegen.ast import AssignmentBase, AugmentedAssignment
 
-from ...sympyextensions.astnodes import Assignment, AssignmentCollection
+from ...assignment import Assignment
+from ...simp import AssignmentCollection
 from ...sympyextensions import (
     integer_functions,
     ConditionalFieldAccess,
diff --git a/src/pystencils/backend/kernelcreation/iteration_space.py b/src/pystencils/backend/kernelcreation/iteration_space.py
index 2a3d2774e..5208c906c 100644
--- a/src/pystencils/backend/kernelcreation/iteration_space.py
+++ b/src/pystencils/backend/kernelcreation/iteration_space.py
@@ -6,7 +6,7 @@ from functools import reduce
 from operator import mul
 
 from ...defaults import DEFAULTS
-from ...sympyextensions import AssignmentCollection
+from ...simp import AssignmentCollection
 from ...field import Field, FieldType
 
 from ..symbols import PsSymbol
diff --git a/src/pystencils/boundaries/boundaryconditions.py b/src/pystencils/boundaries/boundaryconditions.py
index f52573bca..cf6a3e824 100644
--- a/src/pystencils/boundaries/boundaryconditions.py
+++ b/src/pystencils/boundaries/boundaryconditions.py
@@ -1,6 +1,6 @@
 from typing import Any, List, Tuple, Sequence
 
-from pystencils.sympyextensions import Assignment
+from pystencils.assignment import Assignment
 from pystencils.boundaries.boundaryhandling import BoundaryOffsetInfo
 from pystencils.types import create_type
 
diff --git a/src/pystencils/boundaries/boundaryhandling.py b/src/pystencils/boundaries/boundaryhandling.py
index 57a1cd95f..f171d5609 100644
--- a/src/pystencils/boundaries/boundaryhandling.py
+++ b/src/pystencils/boundaries/boundaryhandling.py
@@ -4,7 +4,7 @@ import numpy as np
 import sympy as sp
 
 from pystencils import create_kernel, CreateKernelConfig, Target
-from pystencils.sympyextensions import Assignment
+from pystencils.assignment import Assignment
 from pystencils.boundaries.createindexlist import (
     create_boundary_index_array, numpy_data_type_for_boundary_object)
 from pystencils.sympyextensions import TypedSymbol
diff --git a/src/pystencils/fd/finitedifferences.py b/src/pystencils/fd/finitedifferences.py
index 9c4116ee5..f34a448ed 100644
--- a/src/pystencils/fd/finitedifferences.py
+++ b/src/pystencils/fd/finitedifferences.py
@@ -7,7 +7,7 @@ from pystencils.fd import Diff
 from pystencils.fd.derivative import diff_args
 from pystencils.fd.spatial import fd_stencils_standard
 from pystencils.field import Field
-from pystencils.sympyextensions import AssignmentCollection
+from pystencils.simp import AssignmentCollection
 from pystencils.sympyextensions.math import fast_subs
 
 FieldOrFieldAccess = Union[Field, Field.Access]
diff --git a/src/pystencils/kernel_decorator.py b/src/pystencils/kernel_decorator.py
index deb94eec0..ce0a31d54 100644
--- a/src/pystencils/kernel_decorator.py
+++ b/src/pystencils/kernel_decorator.py
@@ -5,7 +5,7 @@ from typing import Callable, Union, List, Dict, Tuple
 
 import sympy as sp
 
-from .sympyextensions import Assignment
+from .assignment import Assignment
 from .sympyextensions.math import SymbolCreator
 from pystencils.config import CreateKernelConfig
 
diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py
index 3cda5aa46..154cb2307 100644
--- a/src/pystencils/kernelcreation.py
+++ b/src/pystencils/kernelcreation.py
@@ -152,3 +152,7 @@ def create_kernel_function(
     return KernelFunction(
         body, target_spec, function_name, params, req_headers, ctx.constraints, jit
     )
+
+
+def create_staggered_kernel(assignments, target: Target = Target.CPU, gpu_exclusive_conditions=False, **kwargs):
+    raise NotImplementedError("Staggered kernels are not yet implemented for pystencils 2.0")
diff --git a/src/pystencils/placeholder_function.py b/src/pystencils/placeholder_function.py
index 00acb17bd..e9a3a0aba 100644
--- a/src/pystencils/placeholder_function.py
+++ b/src/pystencils/placeholder_function.py
@@ -2,7 +2,7 @@ from typing import List
 
 import sympy as sp
 
-from pystencils.sympyextensions import Assignment
+from .assignment import Assignment
 from pystencils.sympyextensions import is_constant
 from pystencils.sympyextensions.astnodes import generic_visit
 
diff --git a/src/pystencils/sympyextensions/rng.py b/src/pystencils/rng.py
similarity index 97%
rename from src/pystencils/sympyextensions/rng.py
rename to src/pystencils/rng.py
index 859669a6a..d6c6cd274 100644
--- a/src/pystencils/sympyextensions/rng.py
+++ b/src/pystencils/rng.py
@@ -2,10 +2,9 @@ import copy
 import numpy as np
 import sympy as sp
 
-from pystencils.sympyextensions import TypedSymbol, CastFunc
+from .sympyextensions import TypedSymbol, CastFunc, fast_subs
 # from pystencils.sympyextensions.astnodes import LoopOverCoordinate # TODO nbackend: replace
 # from pystencils.backends.cbackend import CustomCodeNode # TODO nbackend: replace
-from pystencils.sympyextensions import fast_subs
 
 
 # class RNGBase(CustomCodeNode): TODO nbackend: replace
diff --git a/src/pystencils/simp/__init__.py b/src/pystencils/simp/__init__.py
new file mode 100644
index 000000000..6c553af8b
--- /dev/null
+++ b/src/pystencils/simp/__init__.py
@@ -0,0 +1,45 @@
+from .assignment_collection import AssignmentCollection
+from .simplifications import (
+    add_subexpressions_for_constants,
+    add_subexpressions_for_divisions,
+    add_subexpressions_for_field_reads,
+    add_subexpressions_for_sums,
+    apply_on_all_subexpressions,
+    apply_to_all_assignments,
+    subexpression_substitution_in_existing_subexpressions,
+    subexpression_substitution_in_main_assignments,
+    sympy_cse,
+    sympy_cse_on_assignment_list,
+)
+from .subexpression_insertion import (
+    insert_aliases,
+    insert_zeros,
+    insert_constants,
+    insert_constant_additions,
+    insert_constant_multiples,
+    insert_squares,
+    insert_symbol_times_minus_one,
+)
+from .simplificationstrategy import SimplificationStrategy
+
+__all__ = [
+    "AssignmentCollection",
+    "SimplificationStrategy",
+    "sympy_cse",
+    "sympy_cse_on_assignment_list",
+    "apply_to_all_assignments",
+    "apply_on_all_subexpressions",
+    "subexpression_substitution_in_existing_subexpressions",
+    "subexpression_substitution_in_main_assignments",
+    "add_subexpressions_for_constants",
+    "add_subexpressions_for_divisions",
+    "add_subexpressions_for_sums",
+    "add_subexpressions_for_field_reads",
+    "insert_aliases",
+    "insert_zeros",
+    "insert_constants",
+    "insert_constant_additions",
+    "insert_constant_multiples",
+    "insert_squares",
+    "insert_symbol_times_minus_one",
+]
diff --git a/src/pystencils/simp/assignment_collection.py b/src/pystencils/simp/assignment_collection.py
new file mode 100644
index 000000000..f1ba87154
--- /dev/null
+++ b/src/pystencils/simp/assignment_collection.py
@@ -0,0 +1,476 @@
+import itertools
+from copy import copy
+from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union
+
+import sympy as sp
+
+import pystencils
+from ..assignment import Assignment
+from .simplifications import (sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs)
+from ..sympyextensions import count_operations, fast_subs
+
+
+class AssignmentCollection:
+    """
+    A collection of equations with subexpression definitions, also represented as assignments,
+    that are used in the main equations. AssignmentCollection can be passed to simplification methods.
+    These simplification methods can change the subexpressions, but the number and
+    left hand side of the main equations themselves is not altered.
+    Additionally a dictionary of simplification hints is stored, which are set by the functions that create
+    assignment collections to transport information to the simplification system.
+
+    Args:
+        main_assignments: List of assignments. Main assignments are characterised, that the right hand side of each
+                          assignment is a field access. Thus the generated equations write on arrays.
+        subexpressions: List of assignments defining subexpressions used in main equations
+        simplification_hints: Dict that is used to annotate the assignment collection with hints that are
+                              used by the simplification system. See documentation of the simplification rules for
+                              potentially required hints and their meaning.
+        subexpression_symbol_generator: Generator for new symbols that are used when new subexpressions are added
+                                        used to get new symbols that are unique for this AssignmentCollection
+
+    """
+
+    __match_args__ = ("main_assignments", "subexpressions")
+
+    # ------------------------------- Creation & Inplace Manipulation --------------------------------------------------
+
+    def __init__(self, main_assignments: Union[List[Assignment], Dict[sp.Expr, sp.Expr]],
+                 subexpressions: Union[List[Assignment], Dict[sp.Expr, sp.Expr]] = None,
+                 simplification_hints: Optional[Dict[str, Any]] = None,
+                 subexpression_symbol_generator: Iterator[sp.Symbol] = None) -> None:
+
+        if subexpressions is None:
+            subexpressions = {}
+
+        if isinstance(main_assignments, Dict):
+            main_assignments = [Assignment(k, v)
+                                for k, v in main_assignments.items()]
+        if isinstance(subexpressions, Dict):
+            subexpressions = [Assignment(k, v)
+                              for k, v in subexpressions.items()]
+
+        main_assignments = list(itertools.chain.from_iterable(
+            [(a if isinstance(a, Iterable) else [a]) for a in main_assignments]))
+        subexpressions = list(itertools.chain.from_iterable(
+            [(a if isinstance(a, Iterable) else [a]) for a in subexpressions]))
+
+        self.main_assignments = main_assignments
+        self.subexpressions = subexpressions
+
+        if simplification_hints is None:
+            simplification_hints = {}
+
+        self.simplification_hints = simplification_hints
+
+        ctrs = [int(n.name[3:])for n in self.rhs_symbols if "xi_" in n.name]
+        max_ctr = max(ctrs) + 1 if len(ctrs) > 0 else 0
+
+        if subexpression_symbol_generator is None:
+            self.subexpression_symbol_generator = SymbolGen(ctr=max_ctr)
+        else:
+            self.subexpression_symbol_generator = subexpression_symbol_generator
+
+    def add_simplification_hint(self, key: str, value: Any) -> None:
+        """Adds an entry to the simplification_hints dictionary and checks that is does not exist yet."""
+        assert key not in self.simplification_hints, "This hint already exists"
+        self.simplification_hints[key] = value
+
+    def add_subexpression(self, rhs: sp.Expr, lhs: Optional[sp.Symbol] = None, topological_sort=True) -> sp.Symbol:
+        """Adds a subexpression to current collection.
+
+        Args:
+            rhs: right hand side of new subexpression
+            lhs: optional left hand side of new subexpression. If None a new unique symbol is generated.
+            topological_sort: sort the subexpressions topologically after insertion, to make sure that
+                              definition of a symbol comes before its usage. If False, subexpression is appended.
+
+        Returns:
+            left hand side symbol (which could have been generated)
+        """
+        if lhs is None:
+            lhs = next(self.subexpression_symbol_generator)
+        eq = Assignment(lhs, rhs)
+        self.subexpressions.append(eq)
+        if topological_sort:
+            self.topological_sort(sort_subexpressions=True,
+                                  sort_main_assignments=False)
+        return lhs
+
+    def topological_sort(self, sort_subexpressions: bool = True, sort_main_assignments: bool = True) -> None:
+        """Sorts subexpressions and/or main_equations topologically to make sure symbol usage comes after definition."""
+        if sort_subexpressions:
+            self.subexpressions = sort_assignments_topologically(self.subexpressions)
+        if sort_main_assignments:
+            self.main_assignments = sort_assignments_topologically(self.main_assignments)
+
+    # ---------------------------------------------- Properties  -------------------------------------------------------
+
+    @property
+    def all_assignments(self) -> List[Assignment]:
+        """Subexpression and main equations as a single list."""
+        return self.subexpressions + self.main_assignments
+
+    @property
+    def rhs_symbols(self) -> Set[sp.Symbol]:
+        """All symbols used in the assignment collection, which occur on the rhs of any assignment."""
+        rhs_symbols = set()
+        for eq in self.all_assignments:
+            if isinstance(eq, Assignment):
+                rhs_symbols.update(eq.rhs.atoms(sp.Symbol))
+            # elif isinstance(eq, pystencils.astnodes.Node): # TODO remove or replace
+            #     rhs_symbols.update(eq.undefined_symbols)
+
+        return rhs_symbols
+
+    @property
+    def free_symbols(self) -> Set[sp.Symbol]:
+        """All symbols used in the assignment collection, which do not occur as left hand sides in any assignment."""
+        return self.rhs_symbols - self.bound_symbols
+
+    @property
+    def bound_symbols(self) -> Set[sp.Symbol]:
+        """All symbols which occur on the left hand side of a main assignment or a subexpression."""
+        bound_symbols_set = set(
+            [assignment.lhs for assignment in self.all_assignments if isinstance(assignment, Assignment)]
+        )
+
+        assert len(bound_symbols_set) == len(list(a for a in self.all_assignments if isinstance(a, Assignment))), \
+            "Not in SSA form - same symbol assigned multiple times"
+
+        # bound_symbols_set = bound_symbols_set.union(*[
+        #     assignment.symbols_defined for assignment in self.all_assignments
+        #     if isinstance(assignment, pystencils.astnodes.Node)
+        # ])  TODO: replace?
+
+        return bound_symbols_set
+
+    @property
+    def rhs_fields(self):
+        """All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment."""
+        return {s.field for s in self.rhs_symbols if hasattr(s, 'field')}
+
+    @property
+    def free_fields(self):
+        """All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment."""
+        return {s.field for s in self.free_symbols if hasattr(s, 'field')}
+
+    @property
+    def bound_fields(self):
+        """All field accessed on the left hand side of a main assignment or a subexpression."""
+        return {s.field for s in self.bound_symbols if hasattr(s, 'field')}
+
+    @property
+    def defined_symbols(self) -> Set[sp.Symbol]:
+        """All symbols which occur as left-hand-sides of one of the main equations"""
+        lhs_set = set([assignment.lhs for assignment in self.main_assignments if isinstance(assignment, Assignment)])
+        return lhs_set
+        # return (lhs_set.union(*[assignment.symbols_defined for assignment in self.main_assignments
+        #                         if isinstance(assignment, pystencils.astnodes.Node)])) TODO
+
+    @property
+    def operation_count(self):
+        """See :func:`count_operations` """
+        return count_operations(self.all_assignments, only_type=None)
+
+    def atoms(self, *args):
+        return set().union(*[a.atoms(*args) for a in self.all_assignments])
+
+    def dependent_symbols(self, symbols: Iterable[sp.Symbol]) -> Set[sp.Symbol]:
+        """Returns all symbols that depend on one of the passed symbols.
+
+        A symbol 'a' depends on a symbol 'b', if there is an assignment 'a <- some_expression(b)' i.e. when
+        'b' is required to compute 'a'.
+        """
+
+        queue = list(symbols)
+
+        def add_symbols_from_expr(expr):
+            dependent_symbols = expr.atoms(sp.Symbol)
+            for ds in dependent_symbols:
+                queue.append(ds)
+
+        handled_symbols = set()
+        assignment_dict = {e.lhs: e.rhs for e in self.all_assignments}
+
+        while len(queue) > 0:
+            e = queue.pop(0)
+            if e in handled_symbols:
+                continue
+            if e in assignment_dict:
+                add_symbols_from_expr(assignment_dict[e])
+            handled_symbols.add(e)
+
+        return handled_symbols
+
+    def lambdify(self, symbols: Sequence[sp.Symbol], fixed_symbols: Optional[Dict[sp.Symbol, Any]] = None, module=None):
+        """Returns a python function to evaluate this equation collection.
+
+        Args:
+            symbols: symbol(s) which are the parameter for the created function
+            fixed_symbols: dictionary with substitutions, that are applied before sympy's lambdify
+            module: same as sympy.lambdify parameter. Defines which module to use e.g. 'numpy'
+
+        Examples:
+              >>> a, b, c, d = sp.symbols("a b c d")
+              >>> ac = AssignmentCollection([Assignment(c, a + b), Assignment(d, a**2 + b)],
+              ...                           subexpressions=[Assignment(b, a + b / 2)])
+              >>> python_function = ac.lambdify([a], fixed_symbols={b: 2})
+              >>> python_function(4)
+              {c: 6, d: 18}
+        """
+        assignments = self.new_with_substitutions(fixed_symbols, substitute_on_lhs=False) if fixed_symbols else self
+        assignments = assignments.new_without_subexpressions().main_assignments
+        lambdas = {assignment.lhs: sp.lambdify(symbols, assignment.rhs, module) for assignment in assignments}
+
+        def f(*args, **kwargs):
+            return {s: func(*args, **kwargs) for s, func in lambdas.items()}
+
+        return f
+
+    # ---------------------------- Creating new modified collections ---------------------------------------------------
+
+    def copy(self,
+             main_assignments: Optional[List[Assignment]] = None,
+             subexpressions: Optional[List[Assignment]] = None) -> 'AssignmentCollection':
+        """Returns a copy with optionally replaced main_assignments and/or subexpressions."""
+
+        res = copy(self)
+        res.simplification_hints = self.simplification_hints.copy()
+        res.subexpression_symbol_generator = copy(self.subexpression_symbol_generator)
+
+        if main_assignments is not None:
+            res.main_assignments = main_assignments
+        else:
+            res.main_assignments = self.main_assignments.copy()
+
+        if subexpressions is not None:
+            res.subexpressions = subexpressions
+        else:
+            res.subexpressions = self.subexpressions.copy()
+
+        return res
+
+    def new_with_substitutions(self, substitutions: Dict, add_substitutions_as_subexpressions: bool = False,
+                               substitute_on_lhs: bool = True,
+                               sort_topologically: bool = True) -> 'AssignmentCollection':
+        """Returns new object, where terms are substituted according to the passed substitution dict.
+
+        Args:
+            substitutions: dict that is passed to sympy subs, substitutions are done main assignments and subexpressions
+            add_substitutions_as_subexpressions: if True, the substitutions are added as assignments to subexpressions
+            substitute_on_lhs: if False, the substitutions are done only on the right hand side of assignments
+            sort_topologically: if subexpressions are added as substitutions and this parameters is true,
+                                the subexpressions are sorted topologically after insertion
+        Returns:
+            New AssignmentCollection where substitutions have been applied, self is not altered.
+        """
+        transform = transform_lhs_and_rhs if substitute_on_lhs else transform_rhs
+        transformed_subexpressions = transform(self.subexpressions, fast_subs, substitutions)
+        transformed_assignments = transform(self.main_assignments, fast_subs, substitutions)
+
+        if add_substitutions_as_subexpressions:
+            transformed_subexpressions = [Assignment(b, a) for a, b in
+                                          substitutions.items()] + transformed_subexpressions
+            if sort_topologically:
+                transformed_subexpressions = sort_assignments_topologically(transformed_subexpressions)
+        return self.copy(transformed_assignments, transformed_subexpressions)
+
+    def new_merged(self, other: 'AssignmentCollection') -> 'AssignmentCollection':
+        """Returns a new collection which contains self and other. Subexpressions are renamed if they clash."""
+        own_definitions = set([e.lhs for e in self.main_assignments])
+        other_definitions = set([e.lhs for e in other.main_assignments])
+        assert len(own_definitions.intersection(other_definitions)) == 0, \
+            "Cannot merge collections, since both define the same symbols"
+
+        own_subexpression_symbols = {e.lhs: e.rhs for e in self.subexpressions}
+        substitution_dict = {}
+
+        processed_other_subexpression_equations = []
+        for other_subexpression_eq in other.subexpressions:
+            if other_subexpression_eq.lhs in own_subexpression_symbols:
+                new_rhs = fast_subs(other_subexpression_eq.rhs, substitution_dict)
+                if new_rhs == own_subexpression_symbols[other_subexpression_eq.lhs]:
+                    continue  # exact the same subexpression equation exists already
+                else:
+                    # different definition - a new name has to be introduced
+                    new_lhs = next(self.subexpression_symbol_generator)
+                    new_eq = Assignment(new_lhs, new_rhs)
+                    processed_other_subexpression_equations.append(new_eq)
+                    substitution_dict[other_subexpression_eq.lhs] = new_lhs
+            else:
+                processed_other_subexpression_equations.append(fast_subs(other_subexpression_eq, substitution_dict))
+
+        processed_other_main_assignments = [fast_subs(eq, substitution_dict) for eq in other.main_assignments]
+        return self.copy(self.main_assignments + processed_other_main_assignments,
+                         self.subexpressions + processed_other_subexpression_equations)
+
+    def new_filtered(self, symbols_to_extract: Iterable[sp.Symbol]) -> 'AssignmentCollection':
+        """Extracts equations that have symbols_to_extract as left hand side, together with necessary subexpressions.
+
+        Returns:
+            new AssignmentCollection, self is not altered
+        """
+        symbols_to_extract = set(symbols_to_extract)
+        dependent_symbols = self.dependent_symbols(symbols_to_extract)
+        new_assignments = []
+        for eq in self.all_assignments:
+            if eq.lhs in symbols_to_extract:
+                new_assignments.append(eq)
+
+        new_sub_expr = [eq for eq in self.all_assignments
+                        if eq.lhs in dependent_symbols and eq.lhs not in symbols_to_extract]
+        return self.copy(new_assignments, new_sub_expr)
+
+    def new_without_unused_subexpressions(self) -> 'AssignmentCollection':
+        """Returns new collection that only contains subexpressions required to compute the main assignments."""
+        all_lhs = [eq.lhs for eq in self.main_assignments]
+        return self.new_filtered(all_lhs)
+
+    def new_with_inserted_subexpression(self, symbol: sp.Symbol) -> 'AssignmentCollection':
+        """Eliminates the subexpression with the given symbol on its left hand side, by substituting it everywhere."""
+        new_subexpressions = []
+        subs_dict = None
+        for se in self.subexpressions:
+            if se.lhs == symbol:
+                subs_dict = {se.lhs: se.rhs}
+            else:
+                new_subexpressions.append(se)
+        if subs_dict is None:
+            return self
+
+        new_subexpressions = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in new_subexpressions]
+        new_eqs = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in self.main_assignments]
+        return self.copy(new_eqs, new_subexpressions)
+
+    def new_without_subexpressions(self, subexpressions_to_keep=None) -> 'AssignmentCollection':
+        """Returns a new collection where all subexpressions have been inserted."""
+        if subexpressions_to_keep is None:
+            subexpressions_to_keep = set()
+        if len(self.subexpressions) == 0:
+            return self.copy()
+
+        subexpressions_to_keep = set(subexpressions_to_keep)
+
+        kept_subexpressions = []
+        if self.subexpressions[0].lhs in subexpressions_to_keep:
+            substitution_dict = {}
+            kept_subexpressions.append(self.subexpressions[0])
+        else:
+            substitution_dict = {self.subexpressions[0].lhs: self.subexpressions[0].rhs}
+
+        subexpression = [e for e in self.subexpressions]
+        for i in range(1, len(subexpression)):
+            subexpression[i] = fast_subs(subexpression[i], substitution_dict)
+            if subexpression[i].lhs in subexpressions_to_keep:
+                kept_subexpressions.append(subexpression[i])
+            else:
+                substitution_dict[subexpression[i].lhs] = subexpression[i].rhs
+
+        new_assignment = [fast_subs(eq, substitution_dict) for eq in self.main_assignments]
+        return self.copy(new_assignment, kept_subexpressions)
+
+    # ----------------------------------------- Display and Printing   -------------------------------------------------
+
+    def _repr_html_(self):
+        """Interface to Jupyter notebook, to display as a nicely formatted HTML table"""
+
+        def make_html_equation_table(equations):
+            no_border = 'style="border:none"'
+            html_table = '<table style="border:none; width: 100%; ">'
+            line = '<tr {nb}> <td {nb}>$${eq}$$</td>  </tr> '
+            for eq in equations:
+                format_dict = {'eq': sp.latex(eq),
+                               'nb': no_border, }
+                html_table += line.format(**format_dict)
+            html_table += "</table>"
+            return html_table
+
+        result = ""
+        if len(self.subexpressions) > 0:
+            result += "<div>Subexpressions:</div>"
+            result += make_html_equation_table(self.subexpressions)
+        result += "<div>Main Assignments:</div>"
+        result += make_html_equation_table(self.main_assignments)
+        return result
+
+    def __repr__(self):
+        return f"AssignmentCollection: {str(tuple(self.defined_symbols))[1:-1]} <- f{tuple(self.free_symbols)}"
+
+    def __str__(self):
+        result = "Subexpressions:\n"
+        for eq in self.subexpressions:
+            result += f"\t{eq}\n"
+        result += "Main Assignments:\n"
+        for eq in self.main_assignments:
+            result += f"\t{eq}\n"
+        return result
+
+    def __iter__(self):
+        return self.all_assignments.__iter__()
+
+    @property
+    def main_assignments_dict(self):
+        return {a.lhs: a.rhs for a in self.main_assignments}
+
+    @property
+    def subexpressions_dict(self):
+        return {a.lhs: a.rhs for a in self.subexpressions}
+
+    def set_main_assignments_from_dict(self, main_assignments_dict):
+        self.main_assignments = [Assignment(k, v)
+                                 for k, v in main_assignments_dict.items()]
+
+    def set_sub_expressions_from_dict(self, sub_expressions_dict):
+        self.subexpressions = [Assignment(k, v)
+                               for k, v in sub_expressions_dict.items()]
+
+    def find(self, *args, **kwargs):
+        return set.union(
+            *[a.find(*args, **kwargs) for a in self.all_assignments]
+        )
+
+    def match(self, *args, **kwargs):
+        rtn = {}
+        for a in self.all_assignments:
+            partial_result = a.match(*args, **kwargs)
+            if partial_result:
+                rtn.update(partial_result)
+        return rtn
+
+    def subs(self, *args, **kwargs):
+        return AssignmentCollection(
+            main_assignments=[a.subs(*args, **kwargs) for a in self.main_assignments],
+            subexpressions=[a.subs(*args, **kwargs) for a in self.subexpressions]
+        )
+
+    def replace(self, *args, **kwargs):
+        return AssignmentCollection(
+            main_assignments=[a.replace(*args, **kwargs) for a in self.main_assignments],
+            subexpressions=[a.replace(*args, **kwargs) for a in self.subexpressions]
+        )
+
+    def __eq__(self, other):
+        return set(self.all_assignments) == set(other.all_assignments)
+
+    def __bool__(self):
+        return bool(self.all_assignments)
+
+
+class SymbolGen:
+    """Default symbol generator producing number symbols ζ_0, ζ_1, ..."""
+
+    def __init__(self, symbol="xi", dtype=None, ctr=0):
+        self._ctr = ctr
+        self._symbol = symbol
+        self._dtype = dtype
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        name = f"{self._symbol}_{self._ctr}"
+        self._ctr += 1
+        if self._dtype is not None:
+            return pystencils.TypedSymbol(name, self._dtype)
+        return sp.Symbol(name)
diff --git a/src/pystencils/sympyextensions/simplifications.py b/src/pystencils/simp/simplifications.py
similarity index 97%
rename from src/pystencils/sympyextensions/simplifications.py
rename to src/pystencils/simp/simplifications.py
index cdcad81e7..73d80ecd4 100644
--- a/src/pystencils/sympyextensions/simplifications.py
+++ b/src/pystencils/simp/simplifications.py
@@ -4,9 +4,9 @@ from collections import defaultdict
 
 import sympy as sp
 
-from .astnodes import Assignment
-from .math import subs_additive, is_constant, recursive_collect
-from .typed_sympy import TypedSymbol
+from ..assignment import Assignment
+from ..sympyextensions.math import subs_additive, is_constant, recursive_collect
+from ..sympyextensions.typed_sympy import TypedSymbol
 
 
 # TODO rewrite with SymPy AST
@@ -55,7 +55,7 @@ def sympy_cse(ac, **kwargs):
 
 def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]:
     """Extracts common subexpressions from a list of assignments."""
-    from pystencils.sympyextensions import AssignmentCollection
+    from pystencils.simp import AssignmentCollection
     ec = AssignmentCollection([], assignments)
     return sympy_cse(ec).all_assignments
 
diff --git a/src/pystencils/sympyextensions/simplificationstrategy.py b/src/pystencils/simp/simplificationstrategy.py
similarity index 99%
rename from src/pystencils/sympyextensions/simplificationstrategy.py
rename to src/pystencils/simp/simplificationstrategy.py
index b76d12711..22ffa34d0 100644
--- a/src/pystencils/sympyextensions/simplificationstrategy.py
+++ b/src/pystencils/simp/simplificationstrategy.py
@@ -3,7 +3,7 @@ from typing import Any, Callable, Optional, Sequence
 
 import sympy as sp
 
-from .astnodes import AssignmentCollection
+from ..simp import AssignmentCollection
 
 
 class SimplificationStrategy:
diff --git a/src/pystencils/sympyextensions/subexpression_insertion.py b/src/pystencils/simp/subexpression_insertion.py
similarity index 98%
rename from src/pystencils/sympyextensions/subexpression_insertion.py
rename to src/pystencils/simp/subexpression_insertion.py
index 8cedad466..0c5aca1dd 100644
--- a/src/pystencils/sympyextensions/subexpression_insertion.py
+++ b/src/pystencils/simp/subexpression_insertion.py
@@ -1,5 +1,5 @@
 import sympy as sp
-from .math import is_constant
+from ..sympyextensions.math import is_constant
 
 #   Subexpression Insertion
 
diff --git a/src/pystencils/sympyextensions/__init__.py b/src/pystencils/sympyextensions/__init__.py
index fd1145bcb..0b8b3690d 100644
--- a/src/pystencils/sympyextensions/__init__.py
+++ b/src/pystencils/sympyextensions/__init__.py
@@ -1,34 +1,5 @@
-from .astnodes import (
-    Assignment,
-    AugmentedAssignment,
-    AddAugmentedAssignment,
-    AssignmentCollection,
-    SymbolGen,
-    ConditionalFieldAccess
-)
+from .astnodes import ConditionalFieldAccess
 from .typed_sympy import TypedSymbol, CastFunc
-from .simplificationstrategy import SimplificationStrategy
-from .simplifications import (
-    sympy_cse,
-    sympy_cse_on_assignment_list,
-    apply_to_all_assignments,
-    apply_on_all_subexpressions,
-    subexpression_substitution_in_existing_subexpressions,
-    subexpression_substitution_in_main_assignments,
-    add_subexpressions_for_constants,
-    add_subexpressions_for_divisions,
-    add_subexpressions_for_sums,
-    add_subexpressions_for_field_reads
-)
-from .subexpression_insertion import (
-    insert_aliases,
-    insert_zeros,
-    insert_constants,
-    insert_constant_additions,
-    insert_constant_multiples,
-    insert_squares,
-    insert_symbol_times_minus_one,
-)
 
 from .math import (
     prod,
@@ -59,32 +30,9 @@ from .math import (
 
 
 __all__ = [
-    "Assignment",
-    "AugmentedAssignment",
-    "AddAugmentedAssignment",
-    "AssignmentCollection",
-    "SymbolGen",
     "ConditionalFieldAccess",
     "TypedSymbol",
     "CastFunc",
-    "SimplificationStrategy",
-    "sympy_cse",
-    "sympy_cse_on_assignment_list",
-    "apply_to_all_assignments",
-    "apply_on_all_subexpressions",
-    "subexpression_substitution_in_existing_subexpressions",
-    "subexpression_substitution_in_main_assignments",
-    "add_subexpressions_for_constants",
-    "add_subexpressions_for_divisions",
-    "add_subexpressions_for_sums",
-    "add_subexpressions_for_field_reads",
-    "insert_aliases",
-    "insert_zeros",
-    "insert_constants",
-    "insert_constant_additions",
-    "insert_constant_multiples",
-    "insert_squares",
-    "insert_symbol_times_minus_one",
     "remove_higher_order_terms",
     "prod",
     "remove_small_floats",
diff --git a/src/pystencils/sympyextensions/astnodes.py b/src/pystencils/sympyextensions/astnodes.py
index 68613b41a..74906cc5c 100644
--- a/src/pystencils/sympyextensions/astnodes.py
+++ b/src/pystencils/sympyextensions/astnodes.py
@@ -1,592 +1,4 @@
-from copy import copy
-import itertools
-import uuid
-from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union
-
 import sympy as sp
-from sympy.codegen.ast import Assignment, AugmentedAssignment
-from sympy.codegen.ast import AddAugmentedAssignment as SpAddAugAssignment
-from sympy.printing.latex import LatexPrinter
-import numpy as np
-
-from .math import count_operations, fast_subs
-from .simplifications import (sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs)
-from .typed_sympy import create_type, TypedSymbol
-
-
-def print_assignment_latex(printer, expr):
-    binop = f"{expr.binop}=" if isinstance(expr, AugmentedAssignment) else ''
-    """sympy cannot print Assignments as Latex. Thus, this function is added to the sympy Latex printer"""
-    printed_lhs = printer.doprint(expr.lhs)
-    printed_rhs = printer.doprint(expr.rhs)
-    return fr"{printed_lhs} \leftarrow_{{{binop}}} {printed_rhs}"
-
-
-def assignment_str(assignment):
-    op = f"{assignment.binop}=" if isinstance(assignment, AugmentedAssignment) else '←'
-    return fr"{assignment.lhs} {op} {assignment.rhs}"
-
-
-_old_new = sp.codegen.ast.Assignment.__new__
-
-
-# TODO Typing Part2 add default type, defult_float_type, default_int_type and use sane defaults
-def _Assignment__new__(cls, lhs, rhs, *args, **kwargs):
-    if isinstance(lhs, (list, tuple, sp.Matrix)) and isinstance(rhs, (list, tuple, sp.Matrix)):
-        assert len(lhs) == len(rhs), f'{lhs} and {rhs} must have same length when performing vector assignment!'
-        return tuple(_old_new(cls, a, b, *args, **kwargs) for a, b in zip(lhs, rhs))
-    return _old_new(cls, lhs, rhs, *args, **kwargs)
-
-
-Assignment.__str__ = assignment_str
-Assignment.__new__ = _Assignment__new__
-LatexPrinter._print_Assignment = print_assignment_latex
-
-AugmentedAssignment.__str__ = assignment_str
-LatexPrinter._print_AugmentedAssignment = print_assignment_latex
-
-sp.MutableDenseMatrix.__hash__ = lambda self: hash(tuple(self))
-
-#   Re-Export
-AddAugmentedAssignment = SpAddAugAssignment
-
-
-def assignment_from_stencil(stencil_array, input_field, output_field,
-                            normalization_factor=None, order='visual') -> Assignment:
-    """Creates an assignment
-
-    Args:
-        stencil_array: nested list of numpy array defining the stencil weights
-        input_field: field or field access, defining where the stencil should be applied to
-        output_field: field or field access where the result is written to
-        normalization_factor: optional normalization factor for the stencil
-        order: defines how the stencil_array is interpreted. Possible values are 'visual' and 'numpy'.
-               For details see examples
-
-    Returns:
-        Assignment that can be used to create a kernel
-
-    Examples:
-        >>> import pystencils as ps
-        >>> f, g = ps.fields("f, g: [2D]")
-        >>> stencil = [[0, 2, 0],
-        ...            [3, 4, 5],
-        ...            [0, 6, 0]]
-
-        By default 'visual ordering is used - i.e. the stencil is applied as the nested lists are written down
-        >>> expected_output = Assignment(g[0, 0], 3*f[-1, 0] + 6*f[0, -1] + 4*f[0, 0] + 2*f[0, 1] + 5*f[1, 0])
-        >>> assignment_from_stencil(stencil, f, g, order='visual') == expected_output
-        True
-
-        'numpy' ordering uses the first coordinate of the stencil array for x offset, second for y offset etc.
-        >>> expected_output = Assignment(g[0, 0], 2*f[-1, 0] + 3*f[0, -1] + 4*f[0, 0] + 5*f[0, 1] + 6*f[1, 0])
-        >>> assignment_from_stencil(stencil, f, g, order='numpy') == expected_output
-        True
-
-        You can also pass field accesses to apply the stencil at an already shifted position:
-        >>> expected_output = Assignment(g[2, 0], 3*f[0, 0] + 6*f[1, -1] + 4*f[1, 0] + 2*f[1, 1] + 5*f[2, 0])
-        >>> assignment_from_stencil(stencil, f[1, 0], g[2, 0]) == expected_output
-        True
-    """
-    from pystencils.field import Field
-
-    stencil_array = np.array(stencil_array)
-    if order == 'visual':
-        stencil_array = np.swapaxes(stencil_array, 0, 1)
-        stencil_array = np.flip(stencil_array, axis=1)
-    elif order == 'numpy':
-        pass
-    else:
-        raise ValueError("'order' has to be either 'visual' or 'numpy'")
-
-    if isinstance(input_field, Field):
-        input_field = input_field.center
-    if isinstance(output_field, Field):
-        output_field = output_field.center
-
-    rhs = 0
-    offset = tuple(s // 2 for s in stencil_array.shape)
-
-    for index, factor in np.ndenumerate(stencil_array):
-        shift = tuple(i - o for i, o in zip(index, offset))
-        rhs += factor * input_field.get_shifted(*shift)
-
-    if normalization_factor:
-        rhs *= normalization_factor
-
-    return Assignment(output_field, rhs)
-
-
-class AssignmentCollection:
-    """
-    A collection of equations with subexpression definitions, also represented as assignments,
-    that are used in the main equations. AssignmentCollection can be passed to simplification methods.
-    These simplification methods can change the subexpressions, but the number and
-    left hand side of the main equations themselves is not altered.
-    Additionally a dictionary of simplification hints is stored, which are set by the functions that create
-    assignment collections to transport information to the simplification system.
-
-    Args:
-        main_assignments: List of assignments. Main assignments are characterised, that the right hand side of each
-                          assignment is a field access. Thus the generated equations write on arrays.
-        subexpressions: List of assignments defining subexpressions used in main equations
-        simplification_hints: Dict that is used to annotate the assignment collection with hints that are
-                              used by the simplification system. See documentation of the simplification rules for
-                              potentially required hints and their meaning.
-        subexpression_symbol_generator: Generator for new symbols that are used when new subexpressions are added
-                                        used to get new symbols that are unique for this AssignmentCollection
-
-    """
-
-    __match_args__ = ("main_assignments", "subexpressions")
-
-    # ------------------------------- Creation & Inplace Manipulation --------------------------------------------------
-
-    def __init__(self, main_assignments: Union[List[Assignment], Dict[sp.Expr, sp.Expr]],
-                 subexpressions: Union[List[Assignment], Dict[sp.Expr, sp.Expr]] = None,
-                 simplification_hints: Optional[Dict[str, Any]] = None,
-                 subexpression_symbol_generator: Iterator[sp.Symbol] = None) -> None:
-
-        if subexpressions is None:
-            subexpressions = {}
-
-        if isinstance(main_assignments, Dict):
-            main_assignments = [Assignment(k, v)
-                                for k, v in main_assignments.items()]
-        if isinstance(subexpressions, Dict):
-            subexpressions = [Assignment(k, v)
-                              for k, v in subexpressions.items()]
-
-        main_assignments = list(itertools.chain.from_iterable(
-            [(a if isinstance(a, Iterable) else [a]) for a in main_assignments]))
-        subexpressions = list(itertools.chain.from_iterable(
-            [(a if isinstance(a, Iterable) else [a]) for a in subexpressions]))
-
-        self.main_assignments = main_assignments
-        self.subexpressions = subexpressions
-
-        if simplification_hints is None:
-            simplification_hints = {}
-
-        self.simplification_hints = simplification_hints
-
-        ctrs = [int(n.name[3:])for n in self.rhs_symbols if "xi_" in n.name]
-        max_ctr = max(ctrs) + 1 if len(ctrs) > 0 else 0
-
-        if subexpression_symbol_generator is None:
-            self.subexpression_symbol_generator = SymbolGen(ctr=max_ctr)
-        else:
-            self.subexpression_symbol_generator = subexpression_symbol_generator
-
-    def add_simplification_hint(self, key: str, value: Any) -> None:
-        """Adds an entry to the simplification_hints dictionary and checks that is does not exist yet."""
-        assert key not in self.simplification_hints, "This hint already exists"
-        self.simplification_hints[key] = value
-
-    def add_subexpression(self, rhs: sp.Expr, lhs: Optional[sp.Symbol] = None, topological_sort=True) -> sp.Symbol:
-        """Adds a subexpression to current collection.
-
-        Args:
-            rhs: right hand side of new subexpression
-            lhs: optional left hand side of new subexpression. If None a new unique symbol is generated.
-            topological_sort: sort the subexpressions topologically after insertion, to make sure that
-                              definition of a symbol comes before its usage. If False, subexpression is appended.
-
-        Returns:
-            left hand side symbol (which could have been generated)
-        """
-        if lhs is None:
-            lhs = next(self.subexpression_symbol_generator)
-        eq = Assignment(lhs, rhs)
-        self.subexpressions.append(eq)
-        if topological_sort:
-            self.topological_sort(sort_subexpressions=True,
-                                  sort_main_assignments=False)
-        return lhs
-
-    def topological_sort(self, sort_subexpressions: bool = True, sort_main_assignments: bool = True) -> None:
-        """Sorts subexpressions and/or main_equations topologically to make sure symbol usage comes after definition."""
-        if sort_subexpressions:
-            self.subexpressions = sort_assignments_topologically(self.subexpressions)
-        if sort_main_assignments:
-            self.main_assignments = sort_assignments_topologically(self.main_assignments)
-
-    # ---------------------------------------------- Properties  -------------------------------------------------------
-
-    @property
-    def all_assignments(self) -> List[Assignment]:
-        """Subexpression and main equations as a single list."""
-        return self.subexpressions + self.main_assignments
-
-    @property
-    def rhs_symbols(self) -> Set[sp.Symbol]:
-        """All symbols used in the assignment collection, which occur on the rhs of any assignment."""
-        rhs_symbols = set()
-        for eq in self.all_assignments:
-            if isinstance(eq, Assignment):
-                rhs_symbols.update(eq.rhs.atoms(sp.Symbol))
-            # TODO rewrite with SymPy AST
-            # elif isinstance(eq, pystencils.astnodes.Node):
-            #     rhs_symbols.update(eq.undefined_symbols)
-
-        return rhs_symbols
-
-    @property
-    def free_symbols(self) -> Set[sp.Symbol]:
-        """All symbols used in the assignment collection, which do not occur as left hand sides in any assignment."""
-        return self.rhs_symbols - self.bound_symbols
-
-    @property
-    def bound_symbols(self) -> Set[sp.Symbol]:
-        """All symbols which occur on the left hand side of a main assignment or a subexpression."""
-        bound_symbols_set = set(
-            [assignment.lhs for assignment in self.all_assignments if isinstance(assignment, Assignment)]
-        )
-
-        assert len(bound_symbols_set) == len(list(a for a in self.all_assignments if isinstance(a, Assignment))), \
-            "Not in SSA form - same symbol assigned multiple times"
-
-        # TODO rewrite with SymPy AST
-        # bound_symbols_set = bound_symbols_set.union(*[
-        #     assignment.symbols_defined for assignment in self.all_assignments
-        #     if isinstance(assignment, pystencils.astnodes.Node)
-        # ])
-
-        return bound_symbols_set
-
-    @property
-    def rhs_fields(self):
-        """All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment."""
-        return {s.field for s in self.rhs_symbols if hasattr(s, 'field')}
-
-    @property
-    def free_fields(self):
-        """All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment."""
-        return {s.field for s in self.free_symbols if hasattr(s, 'field')}
-
-    @property
-    def bound_fields(self):
-        """All field accessed on the left hand side of a main assignment or a subexpression."""
-        return {s.field for s in self.bound_symbols if hasattr(s, 'field')}
-
-    @property
-    def defined_symbols(self) -> Set[sp.Symbol]:
-        """All symbols which occur as left-hand-sides of one of the main equations"""
-        lhs_set = set([assignment.lhs for assignment in self.main_assignments if isinstance(assignment, Assignment)])
-        return lhs_set
-        # TODO rewrite with SymPy AST
-        # return (lhs_set.union(*[assignment.symbols_defined for assignment in self.main_assignments
-        #                         if isinstance(assignment, pystencils.astnodes.Node)]))
-
-    @property
-    def operation_count(self):
-        """See :func:`count_operations` """
-        return count_operations(self.all_assignments, only_type=None)
-
-    def atoms(self, *args):
-        return set().union(*[a.atoms(*args) for a in self.all_assignments])
-
-    def dependent_symbols(self, symbols: Iterable[sp.Symbol]) -> Set[sp.Symbol]:
-        """Returns all symbols that depend on one of the passed symbols.
-
-        A symbol 'a' depends on a symbol 'b', if there is an assignment 'a <- some_expression(b)' i.e. when
-        'b' is required to compute 'a'.
-        """
-
-        queue = list(symbols)
-
-        def add_symbols_from_expr(expr):
-            dependent_symbols = expr.atoms(sp.Symbol)
-            for ds in dependent_symbols:
-                queue.append(ds)
-
-        handled_symbols = set()
-        assignment_dict = {e.lhs: e.rhs for e in self.all_assignments}
-
-        while len(queue) > 0:
-            e = queue.pop(0)
-            if e in handled_symbols:
-                continue
-            if e in assignment_dict:
-                add_symbols_from_expr(assignment_dict[e])
-            handled_symbols.add(e)
-
-        return handled_symbols
-
-    def lambdify(self, symbols: Sequence[sp.Symbol], fixed_symbols: Optional[Dict[sp.Symbol, Any]] = None, module=None):
-        """Returns a python function to evaluate this equation collection.
-
-        Args:
-            symbols: symbol(s) which are the parameter for the created function
-            fixed_symbols: dictionary with substitutions, that are applied before sympy's lambdify
-            module: same as sympy.lambdify parameter. Defines which module to use e.g. 'numpy'
-
-        Examples:
-              >>> a, b, c, d = sp.symbols("a b c d")
-              >>> ac = AssignmentCollection([Assignment(c, a + b), Assignment(d, a**2 + b)],
-              ...                           subexpressions=[Assignment(b, a + b / 2)])
-              >>> python_function = ac.lambdify([a], fixed_symbols={b: 2})
-              >>> python_function(4)
-              {c: 6, d: 18}
-        """
-        assignments = self.new_with_substitutions(fixed_symbols, substitute_on_lhs=False) if fixed_symbols else self
-        assignments = assignments.new_without_subexpressions().main_assignments
-        lambdas = {assignment.lhs: sp.lambdify(symbols, assignment.rhs, module) for assignment in assignments}
-
-        def f(*args, **kwargs):
-            return {s: func(*args, **kwargs) for s, func in lambdas.items()}
-
-        return f
-
-    # ---------------------------- Creating new modified collections ---------------------------------------------------
-
-    def copy(self,
-             main_assignments: Optional[List[Assignment]] = None,
-             subexpressions: Optional[List[Assignment]] = None) -> 'AssignmentCollection':
-        """Returns a copy with optionally replaced main_assignments and/or subexpressions."""
-
-        res = copy(self)
-        res.simplification_hints = self.simplification_hints.copy()
-        res.subexpression_symbol_generator = copy(self.subexpression_symbol_generator)
-
-        if main_assignments is not None:
-            res.main_assignments = main_assignments
-        else:
-            res.main_assignments = self.main_assignments.copy()
-
-        if subexpressions is not None:
-            res.subexpressions = subexpressions
-        else:
-            res.subexpressions = self.subexpressions.copy()
-
-        return res
-
-    def new_with_substitutions(self, substitutions: Dict, add_substitutions_as_subexpressions: bool = False,
-                               substitute_on_lhs: bool = True,
-                               sort_topologically: bool = True) -> 'AssignmentCollection':
-        """Returns new object, where terms are substituted according to the passed substitution dict.
-
-        Args:
-            substitutions: dict that is passed to sympy subs, substitutions are done main assignments and subexpressions
-            add_substitutions_as_subexpressions: if True, the substitutions are added as assignments to subexpressions
-            substitute_on_lhs: if False, the substitutions are done only on the right hand side of assignments
-            sort_topologically: if subexpressions are added as substitutions and this parameters is true,
-                                the subexpressions are sorted topologically after insertion
-        Returns:
-            New AssignmentCollection where substitutions have been applied, self is not altered.
-        """
-        transform = transform_lhs_and_rhs if substitute_on_lhs else transform_rhs
-        transformed_subexpressions = transform(self.subexpressions, fast_subs, substitutions)
-        transformed_assignments = transform(self.main_assignments, fast_subs, substitutions)
-
-        if add_substitutions_as_subexpressions:
-            transformed_subexpressions = [Assignment(b, a) for a, b in
-                                          substitutions.items()] + transformed_subexpressions
-            if sort_topologically:
-                transformed_subexpressions = sort_assignments_topologically(transformed_subexpressions)
-        return self.copy(transformed_assignments, transformed_subexpressions)
-
-    def new_merged(self, other: 'AssignmentCollection') -> 'AssignmentCollection':
-        """Returns a new collection which contains self and other. Subexpressions are renamed if they clash."""
-        own_definitions = set([e.lhs for e in self.main_assignments])
-        other_definitions = set([e.lhs for e in other.main_assignments])
-        assert len(own_definitions.intersection(other_definitions)) == 0, \
-            "Cannot merge collections, since both define the same symbols"
-
-        own_subexpression_symbols = {e.lhs: e.rhs for e in self.subexpressions}
-        substitution_dict = {}
-
-        processed_other_subexpression_equations = []
-        for other_subexpression_eq in other.subexpressions:
-            if other_subexpression_eq.lhs in own_subexpression_symbols:
-                if other_subexpression_eq.rhs == own_subexpression_symbols[other_subexpression_eq.lhs]:
-                    continue  # exact the same subexpression equation exists already
-                else:
-                    # different definition - a new name has to be introduced
-                    new_lhs = next(self.subexpression_symbol_generator)
-                    new_eq = Assignment(new_lhs, fast_subs(other_subexpression_eq.rhs, substitution_dict))
-                    processed_other_subexpression_equations.append(new_eq)
-                    substitution_dict[other_subexpression_eq.lhs] = new_lhs
-            else:
-                processed_other_subexpression_equations.append(fast_subs(other_subexpression_eq, substitution_dict))
-
-        processed_other_main_assignments = [fast_subs(eq, substitution_dict) for eq in other.main_assignments]
-        return self.copy(self.main_assignments + processed_other_main_assignments,
-                         self.subexpressions + processed_other_subexpression_equations)
-
-    def new_filtered(self, symbols_to_extract: Iterable[sp.Symbol]) -> 'AssignmentCollection':
-        """Extracts equations that have symbols_to_extract as left hand side, together with necessary subexpressions.
-
-        Returns:
-            new AssignmentCollection, self is not altered
-        """
-        symbols_to_extract = set(symbols_to_extract)
-        dependent_symbols = self.dependent_symbols(symbols_to_extract)
-        new_assignments = []
-        for eq in self.all_assignments:
-            if eq.lhs in symbols_to_extract:
-                new_assignments.append(eq)
-
-        new_sub_expr = [eq for eq in self.all_assignments
-                        if eq.lhs in dependent_symbols and eq.lhs not in symbols_to_extract]
-        return self.copy(new_assignments, new_sub_expr)
-
-    def new_without_unused_subexpressions(self) -> 'AssignmentCollection':
-        """Returns new collection that only contains subexpressions required to compute the main assignments."""
-        all_lhs = [eq.lhs for eq in self.main_assignments]
-        return self.new_filtered(all_lhs)
-
-    def new_with_inserted_subexpression(self, symbol: sp.Symbol) -> 'AssignmentCollection':
-        """Eliminates the subexpression with the given symbol on its left hand side, by substituting it everywhere."""
-        new_subexpressions = []
-        subs_dict = None
-        for se in self.subexpressions:
-            if se.lhs == symbol:
-                subs_dict = {se.lhs: se.rhs}
-            else:
-                new_subexpressions.append(se)
-        if subs_dict is None:
-            return self
-
-        new_subexpressions = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in new_subexpressions]
-        new_eqs = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in self.main_assignments]
-        return self.copy(new_eqs, new_subexpressions)
-
-    def new_without_subexpressions(self, subexpressions_to_keep=None) -> 'AssignmentCollection':
-        """Returns a new collection where all subexpressions have been inserted."""
-        if subexpressions_to_keep is None:
-            subexpressions_to_keep = set()
-        if len(self.subexpressions) == 0:
-            return self.copy()
-
-        subexpressions_to_keep = set(subexpressions_to_keep)
-
-        kept_subexpressions = []
-        if self.subexpressions[0].lhs in subexpressions_to_keep:
-            substitution_dict = {}
-            kept_subexpressions.append(self.subexpressions[0])
-        else:
-            substitution_dict = {self.subexpressions[0].lhs: self.subexpressions[0].rhs}
-
-        subexpression = [e for e in self.subexpressions]
-        for i in range(1, len(subexpression)):
-            subexpression[i] = fast_subs(subexpression[i], substitution_dict)
-            if subexpression[i].lhs in subexpressions_to_keep:
-                kept_subexpressions.append(subexpression[i])
-            else:
-                substitution_dict[subexpression[i].lhs] = subexpression[i].rhs
-
-        new_assignment = [fast_subs(eq, substitution_dict) for eq in self.main_assignments]
-        return self.copy(new_assignment, kept_subexpressions)
-
-    # ----------------------------------------- Display and Printing   -------------------------------------------------
-
-    def _repr_html_(self):
-        """Interface to Jupyter notebook, to display as a nicely formatted HTML table"""
-
-        def make_html_equation_table(equations):
-            no_border = 'style="border:none"'
-            html_table = '<table style="border:none; width: 100%; ">'
-            line = '<tr {nb}> <td {nb}>$${eq}$$</td>  </tr> '
-            for eq in equations:
-                format_dict = {'eq': sp.latex(eq),
-                               'nb': no_border, }
-                html_table += line.format(**format_dict)
-            html_table += "</table>"
-            return html_table
-
-        result = ""
-        if len(self.subexpressions) > 0:
-            result += "<div>Subexpressions:</div>"
-            result += make_html_equation_table(self.subexpressions)
-        result += "<div>Main Assignments:</div>"
-        result += make_html_equation_table(self.main_assignments)
-        return result
-
-    def __repr__(self):
-        return f"AssignmentCollection: {str(tuple(self.defined_symbols))[1:-1]} <- f{tuple(self.free_symbols)}"
-
-    def __str__(self):
-        result = "Subexpressions:\n"
-        for eq in self.subexpressions:
-            result += f"\t{eq}\n"
-        result += "Main Assignments:\n"
-        for eq in self.main_assignments:
-            result += f"\t{eq}\n"
-        return result
-
-    def __iter__(self):
-        return self.all_assignments.__iter__()
-
-    @property
-    def main_assignments_dict(self):
-        return {a.lhs: a.rhs for a in self.main_assignments}
-
-    @property
-    def subexpressions_dict(self):
-        return {a.lhs: a.rhs for a in self.subexpressions}
-
-    def set_main_assignments_from_dict(self, main_assignments_dict):
-        self.main_assignments = [Assignment(k, v)
-                                 for k, v in main_assignments_dict.items()]
-
-    def set_sub_expressions_from_dict(self, sub_expressions_dict):
-        self.subexpressions = [Assignment(k, v)
-                               for k, v in sub_expressions_dict.items()]
-
-    def find(self, *args, **kwargs):
-        return set.union(
-            *[a.find(*args, **kwargs) for a in self.all_assignments]
-        )
-
-    def match(self, *args, **kwargs):
-        rtn = {}
-        for a in self.all_assignments:
-            partial_result = a.match(*args, **kwargs)
-            if partial_result:
-                rtn.update(partial_result)
-        return rtn
-
-    def subs(self, *args, **kwargs):
-        return AssignmentCollection(
-            main_assignments=[a.subs(*args, **kwargs) for a in self.main_assignments],
-            subexpressions=[a.subs(*args, **kwargs) for a in self.subexpressions]
-        )
-
-    def replace(self, *args, **kwargs):
-        return AssignmentCollection(
-            main_assignments=[a.replace(*args, **kwargs) for a in self.main_assignments],
-            subexpressions=[a.replace(*args, **kwargs) for a in self.subexpressions]
-        )
-
-    def __eq__(self, other):
-        return set(self.all_assignments) == set(other.all_assignments)
-
-    def __bool__(self):
-        return bool(self.all_assignments)
-
-
-class SymbolGen:
-    """Default symbol generator producing number symbols ζ_0, ζ_1, ..."""
-
-    def __init__(self, symbol="xi", dtype=None, ctr=0):
-        self._ctr = ctr
-        self._symbol = symbol
-        self._dtype = dtype
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        name = f"{self._symbol}_{self._ctr}"
-        self._ctr += 1
-        if self._dtype is not None:
-            return TypedSymbol(name, self._dtype)
-        return sp.Symbol(name)
-
-
-def get_dummy_symbol(dtype='bool'):
-    return TypedSymbol(f'dummy{uuid.uuid4().hex}', create_type(dtype))
 
 
 class ConditionalFieldAccess(sp.Function):
@@ -618,6 +30,8 @@ class ConditionalFieldAccess(sp.Function):
 
 
 def generic_visit(term, visitor):
+    from pystencils import AssignmentCollection, Assignment
+
     if isinstance(term, AssignmentCollection):
         new_main_assignments = generic_visit(term.main_assignments, visitor)
         new_subexpressions = generic_visit(term.subexpressions, visitor)
diff --git a/src/pystencils/sympyextensions/fast_approximation.py b/src/pystencils/sympyextensions/fast_approximation.py
index 9088348fb..d9656025e 100644
--- a/src/pystencils/sympyextensions/fast_approximation.py
+++ b/src/pystencils/sympyextensions/fast_approximation.py
@@ -2,7 +2,8 @@ from typing import List, Union
 
 import sympy as sp
 
-from pystencils.sympyextensions import AssignmentCollection, Assignment
+from ..assignment import Assignment
+from ..simp import AssignmentCollection
 
 
 # noinspection PyPep8Naming
diff --git a/src/pystencils/sympyextensions/math.py b/src/pystencils/sympyextensions/math.py
index a2df9458e..1a006efe6 100644
--- a/src/pystencils/sympyextensions/math.py
+++ b/src/pystencils/sympyextensions/math.py
@@ -10,7 +10,7 @@ from sympy import PolynomialError
 from sympy.functions import Abs
 from sympy.core.numbers import Zero
 
-from .astnodes import Assignment
+from ..assignment import Assignment
 from .typed_sympy import CastFunc, FieldPointerSymbol
 from ..types import PsPointerType, PsVectorType
 
diff --git a/tests/nbackend/kernelcreation/test_domain_kernels.py b/tests/nbackend/kernelcreation/test_domain_kernels.py
index 9ce2f661d..c9cc81abb 100644
--- a/tests/nbackend/kernelcreation/test_domain_kernels.py
+++ b/tests/nbackend/kernelcreation/test_domain_kernels.py
@@ -2,7 +2,7 @@ import sympy as sp
 import numpy as np
 
 from pystencils import fields, Field, AssignmentCollection
-from pystencils.sympyextensions.astnodes import assignment_from_stencil
+from pystencils.assignment import assignment_from_stencil
 
 from pystencils.kernelcreation import create_kernel
 
diff --git a/tests/symbolics/test_address_of.py b/tests/symbolics/test_address_of.py
index da11ecbe5..99f33ddbd 100644
--- a/tests/symbolics/test_address_of.py
+++ b/tests/symbolics/test_address_of.py
@@ -6,7 +6,7 @@ import pystencils
 from pystencils.types import PsPointerType, create_type
 from pystencils.sympyextensions.pointers import AddressOf
 from pystencils.sympyextensions.typed_sympy import CastFunc
-from pystencils.sympyextensions import sympy_cse
+from pystencils.simp import sympy_cse
 
 import sympy as sp
 
diff --git a/tests/symbolics/test_conditional_field_access.py b/tests/symbolics/test_conditional_field_access.py
index bd384a959..e18ffc56a 100644
--- a/tests/symbolics/test_conditional_field_access.py
+++ b/tests/symbolics/test_conditional_field_access.py
@@ -16,7 +16,7 @@ import sympy as sp
 import pystencils as ps
 from pystencils import Field, x_vector
 from pystencils.sympyextensions.astnodes import ConditionalFieldAccess
-from pystencils.sympyextensions import sympy_cse
+from pystencils.simp import sympy_cse
 
 
 def add_fixed_constant_boundary_handling(assignments, with_cse):
diff --git a/tests/test_fvm.py b/tests/test_fvm.py
index e10874cbd..0b103e525 100644
--- a/tests/test_fvm.py
+++ b/tests/test_fvm.py
@@ -3,7 +3,7 @@ import pystencils as ps
 import numpy as np
 import pytest
 from itertools import product
-from pystencils.sympyextensions.rng import random_symbol
+from pystencils.rng import random_symbol
 from pystencils.sympyextensions.astnodes import SympyAssignment
 from pystencils.node_collection import NodeCollection
 
diff --git a/tests/test_random.py b/tests/test_random.py
index 6c3a888db..63cd61494 100644
--- a/tests/test_random.py
+++ b/tests/test_random.py
@@ -5,7 +5,7 @@ import pytest
 import pystencils as ps
 from pystencils.sympyextensions.astnodes import SympyAssignment
 from pystencils.node_collection import NodeCollection
-from pystencils.sympyextensions.rng import PhiloxFourFloats, PhiloxTwoDoubles, AESNIFourFloats, AESNITwoDoubles, random_symbol
+from pystencils.rng import PhiloxFourFloats, PhiloxTwoDoubles, AESNIFourFloats, AESNITwoDoubles, random_symbol
 from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets
 from pystencils.cpu.cpujit import get_compiler_config
 from pystencils.typing import TypedSymbol
-- 
GitLab