diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py
index 0c49160c8ba6beb4e9aa58707ec2697d0f47663a..f50ef65a938f48c9102b97b8969fe99019b59b70 100644
--- a/pystencils/astnodes.py
+++ b/pystencils/astnodes.py
@@ -303,7 +303,7 @@ class SkipIteration(Node):
 
 
 class Block(Node):
-    def __init__(self, nodes: List[Node]):
+    def __init__(self, nodes: Union[Node, List[Node]]):
         super(Block, self).__init__()
         if not isinstance(nodes, list):
             nodes = [nodes]
diff --git a/pystencils/gpu/kernelcreation.py b/pystencils/gpu/kernelcreation.py
index c0d6e71d05e3a2250249224e3a12e3daa30978d0..e3ad451bb154c2cee279713fdfd2167a4268b3fb 100644
--- a/pystencils/gpu/kernelcreation.py
+++ b/pystencils/gpu/kernelcreation.py
@@ -1,5 +1,3 @@
-from typing import Union
-
 from pystencils.astnodes import Block, KernelFunction, LoopOverCoordinate, SympyAssignment
 from pystencils.config import CreateKernelConfig
 from pystencils.typing import StructType, TypedSymbol
@@ -9,15 +7,13 @@ from pystencils.enums import Target, Backend
 from pystencils.gpu.gpujit import make_python_function
 from pystencils.node_collection import NodeCollection
 from pystencils.gpu.indexing import indexing_creator_from_params
-from pystencils.simp.assignment_collection import AssignmentCollection
 from pystencils.slicing import normalize_slice
 from pystencils.transformations import (
     get_base_buffer_index, get_common_field, parse_base_pointer_info,
     resolve_buffer_accesses, resolve_field_accesses, unify_shape_symbols)
 
 
-def create_cuda_kernel(assignments: Union[AssignmentCollection, NodeCollection],
-                       config: CreateKernelConfig):
+def create_cuda_kernel(assignments: NodeCollection, config: CreateKernelConfig):
 
     function_name = config.function_name
     indexing_creator = indexing_creator_from_params(config.gpu_indexing, config.gpu_indexing_params)
@@ -114,31 +110,24 @@ def create_cuda_kernel(assignments: Union[AssignmentCollection, NodeCollection],
     return ast
 
 
-def created_indexed_cuda_kernel(assignments: Union[AssignmentCollection, NodeCollection],
-                                config: CreateKernelConfig):
+def created_indexed_cuda_kernel(assignments: NodeCollection, config: CreateKernelConfig):
 
     index_fields = config.index_fields
     function_name = config.function_name
     coordinate_names = config.coordinate_names
     indexing_creator = indexing_creator_from_params(config.gpu_indexing, config.gpu_indexing_params)
-
     fields_written = assignments.bound_fields
     fields_read = assignments.rhs_fields
-    assignments = assignments.all_assignments
-
-    assignments = add_types(assignments, config)
 
     all_fields = fields_read.union(fields_written)
     read_only_fields = set([f.name for f in fields_read - fields_written])
-
-    for index_field in index_fields:
-        index_field.field_type = FieldType.INDEXED
-        assert FieldType.is_indexed(index_field)
-        assert index_field.spatial_dimensions == 1, "Index fields have to be 1D"
-
+    # extract the index fields based on the name. The original index field might have been modified
+    index_fields = [idx_field for idx_field in index_fields if idx_field.name in [f.name for f in all_fields]]
     non_index_fields = [f for f in all_fields if f not in index_fields]
     spatial_coordinates = {f.spatial_dimensions for f in non_index_fields}
-    assert len(spatial_coordinates) == 1, "Non-index fields do not have the same number of spatial coordinates"
+    assert len(spatial_coordinates) == 1, f"Non-index fields do not have the same number of spatial coordinates " \
+                                          f"Non index fields are {non_index_fields}, spatial coordinates are " \
+                                          f"{spatial_coordinates}"
     spatial_coordinates = list(spatial_coordinates)[0]
 
     def get_coordinate_symbol_assignment(name):
diff --git a/pystencils/node_collection.py b/pystencils/node_collection.py
index 352406566e47c9279bd486410d34f6b6b9bfff53..e0af05fd055bf8df0443c7f9401e352138f6b303 100644
--- a/pystencils/node_collection.py
+++ b/pystencils/node_collection.py
@@ -1,34 +1,42 @@
-from collections.abc import Iterable
 from typing import Any, Dict, List, Union, Optional, Set
 
 import sympy
 import sympy as sp
-from sympy.codegen.ast import Assignment, AddAugmentedAssignment
 from sympy.codegen.rewriting import ReplaceOptim, optimize
 
-from pystencils.astnodes import Block, Node, SympyAssignment
+from pystencils.assignment import Assignment, AddAugmentedAssignment
+import pystencils.astnodes as ast
 from pystencils.backends.cbackend import CustomCodeNode
 from pystencils.functions import DivFunc
 from pystencils.simp import AssignmentCollection
 
 
 class NodeCollection:
-    def __init__(self, assignments: List[Union[Node, Assignment]],
+    def __init__(self, assignments: List[Union[ast.Node, Assignment]],
                  simplification_hints: Optional[Dict[str, Any]] = None,
                  bound_fields: Set[sp.Symbol] = None, rhs_fields: Set[sp.Symbol] = None):
-        nodes = list()
-        assignments = [assignments, ] if not isinstance(assignments, Iterable) else assignments
-        for assignment in assignments:
-            if isinstance(assignment, Assignment):
-                nodes.append(SympyAssignment(assignment.lhs, assignment.rhs))
-            elif isinstance(assignment, AddAugmentedAssignment):
-                nodes.append(SympyAssignment(assignment.lhs, assignment.lhs + assignment.rhs))
-            elif isinstance(assignment, Node):
-                nodes.append(assignment)
+
+        def visit(obj):
+            if isinstance(obj, (list, tuple)):
+                return [visit(e) for e in obj]
+            if isinstance(obj, Assignment):
+                return ast.SympyAssignment(obj.lhs, obj.rhs)
+            elif isinstance(obj, AddAugmentedAssignment):
+                return ast.SympyAssignment(obj.lhs, obj.lhs + obj.rhs)
+            elif isinstance(obj, ast.SympyAssignment):
+                return obj
+            elif isinstance(obj, ast.Conditional):
+                true_block = visit(obj.true_block)
+                false_block = None if obj.false_block is None else visit(obj.false_block)
+                return ast.Conditional(obj.condition_expr, true_block=true_block, false_block=false_block)
+            elif isinstance(obj, ast.Block):
+                return ast.Block([visit(e) for e in obj.args])
+            elif isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate):
+                return obj
             else:
-                raise ValueError(f"Unknown node in the AssignmentCollection: {assignment}")
+                raise ValueError("Invalid object in the List of Assignments " + str(type(obj)))
 
-        self.all_assignments = nodes
+        self.all_assignments = visit(assignments)
         self.simplification_hints = simplification_hints if simplification_hints else {}
         self.bound_fields = bound_fields if bound_fields else {}
         self.rhs_fields = rhs_fields if rhs_fields else {}
@@ -57,13 +65,13 @@ class NodeCollection:
         def visitor(node):
             if isinstance(node, CustomCodeNode):
                 return node
-            elif isinstance(node, Block):
+            elif isinstance(node, ast.Block):
                 return node.func([visitor(child) for child in node.args])
-            elif isinstance(node, SympyAssignment):
+            elif isinstance(node, ast.SympyAssignment):
                 new_lhs = visitor(node.lhs)
                 new_rhs = visitor(node.rhs)
                 return node.func(new_lhs, new_rhs, node.is_const, node.use_auto)
-            elif isinstance(node, Node):
+            elif isinstance(node, ast.Node):
                 return node.func(*[visitor(child) for child in node.args])
             elif isinstance(node, sympy.Basic):
                 return optimize(node, sympy_optimisations)