diff --git a/pystencils/gpu/indexing.py b/pystencils/gpu/indexing.py
index 05837445af9649906498ee988b18a4bb973804dc..70af0ab315a2e929d9940bd744ada2aff4c165ce 100644
--- a/pystencils/gpu/indexing.py
+++ b/pystencils/gpu/indexing.py
@@ -1,6 +1,7 @@
 import abc
 from functools import partial
 import math
+from typing import Tuple
 
 import sympy as sp
 from sympy.core.cache import cacheit
@@ -8,7 +9,6 @@ from sympy.core.cache import cacheit
 from pystencils.astnodes import Block, Conditional
 from pystencils.typing import TypedSymbol, create_type
 from pystencils.integer_functions import div_ceil, div_floor
-from pystencils.slicing import normalize_slice
 from pystencils.sympyextensions import is_integer_sequence, prod
 
 
@@ -33,12 +33,34 @@ GRID_DIM = [ThreadIndexingSymbol("gridDim." + coord, create_type("int32")) for c
 
 class AbstractIndexing(abc.ABC):
     """
-    Abstract base class for all Indexing classes. An Indexing class defines how a multidimensional
-    field is mapped to GPU's block and grid system. It calculates indices based on GPU's thread and block indices
-    and computes the number of blocks and threads a kernel is started with. The Indexing class is created with
-    a pystencils field, a slice to iterate over, and further optional parameters that must have default values.
+    Abstract base class for all Indexing classes. An Indexing class defines how an iteration space is mapped
+    to GPU's block and grid system. It calculates indices based on GPU's thread and block indices
+    and computes the number of blocks and threads a kernel is started with.
+    The Indexing class is created with an iteration space that is given as list of slices to determine start, stop
+    and the step size for each coordinate. Further the data_layout is given as tuple to determine the fast and slow
+    coordinates. This is important to get an optimal mapping of coordinates to GPU threads.
     """
 
+    def __init__(self, iteration_space: Tuple[slice], data_layout: Tuple):
+        self._iteration_space = iteration_space
+        self._data_layout = data_layout
+        self._dim = len(iteration_space)
+
+    @property
+    def iteration_space(self):
+        """Iteration space to loop over"""
+        return self._iteration_space
+
+    @property
+    def data_layout(self):
+        """Data layout of the kernels arrays"""
+        return self._data_layout
+
+    @property
+    def dim(self):
+        """Number of spatial dimensions"""
+        return self._dim
+
     @property
     @abc.abstractmethod
     def coordinates(self):
@@ -92,8 +114,8 @@ class BlockIndexing(AbstractIndexing):
     """Generic indexing scheme that maps sub-blocks of an array to GPU blocks.
 
     Args:
-        field: pystencils field (common to all Indexing classes)
-        iteration_slice: slice that defines rectangular subarea which is iterated over
+        iteration_space: list of slices to determine start, stop and the step size for each coordinate
+        data_layout: tuple to determine the fast and slow coordinates.
         permute_block_size_dependent_on_layout: if True the block_size is permuted such that the fastest coordinate
                                                 gets the largest amount of threads
         compile_time_block_size: compile in concrete block size, otherwise the gpu variable 'blockDim' is used
@@ -102,14 +124,16 @@ class BlockIndexing(AbstractIndexing):
         device_number: device number of the used GPU. By default, the zeroth device is used.
     """
 
-    def __init__(self, field, iteration_slice,
+    def __init__(self, iteration_space: Tuple[slice], data_layout: Tuple,
                  block_size=(16, 16, 1), permute_block_size_dependent_on_layout=True, compile_time_block_size=False,
                  maximum_block_size=(1024, 1024, 64), device_number=None):
-        if field.spatial_dimensions > 3:
+        super(BlockIndexing, self).__init__(iteration_space, data_layout)
+
+        if len(iteration_space) > 3:
             raise NotImplementedError("This indexing scheme supports at most 3 spatial dimensions")
 
         if permute_block_size_dependent_on_layout:
-            block_size = self.permute_block_size_according_to_layout(block_size, field.layout)
+            block_size = self.permute_block_size_according_to_layout(block_size, data_layout)
 
         self._block_size = block_size
         if maximum_block_size == 'auto':
@@ -124,9 +148,6 @@ class BlockIndexing(AbstractIndexing):
                 maximum_block_size = tuple(da[f"MaxBlockDim{c}"] for c in ["X", "Y", "Z"])
 
         self._maximum_block_size = maximum_block_size
-        self._iterationSlice = normalize_slice(iteration_slice, field.spatial_shape)
-        self._dim = field.spatial_dimensions
-        self._symbolic_shape = [e if isinstance(e, sp.Basic) else None for e in field.spatial_shape]
         self._compile_time_block_size = compile_time_block_size
         self._device_number = device_number
 
@@ -140,17 +161,13 @@ class BlockIndexing(AbstractIndexing):
 
     @property
     def coordinates(self):
-        offsets = _get_start_from_slice(self._iterationSlice)
-        coordinates = [c + off for c, off in zip(self.cuda_indices, offsets)]
-
+        coordinates = [c + iter_slice.start for c, iter_slice in zip(self.cuda_indices, self._iteration_space)]
         return coordinates[:self._dim]
 
     def call_parameters(self, arr_shape):
-        substitution_dict = {sym: value for sym, value in zip(self._symbolic_shape, arr_shape) if sym is not None}
+        numeric_iteration_slice = _get_numeric_iteration_slice(self._iteration_space, arr_shape)
+        widths = [len(range(*s.indices(s.stop))) for s in numeric_iteration_slice]
 
-        widths = [end - start for start, end in zip(_get_start_from_slice(self._iterationSlice),
-                                                    _get_end_from_slice(self._iterationSlice, arr_shape))]
-        widths = sp.Matrix(widths).subs(substitution_dict)
         extend_bs = (1,) * (3 - len(self._block_size))
         block_size = self._block_size + extend_bs
         if not self._compile_time_block_size:
@@ -171,10 +188,11 @@ class BlockIndexing(AbstractIndexing):
 
     def guard(self, kernel_content, arr_shape):
         arr_shape = arr_shape[:self._dim]
-        end = _get_end_from_slice(self._iterationSlice, arr_shape)
+        numeric_iteration_slice = _get_numeric_iteration_slice(self._iteration_space, arr_shape)
+        end = [s.stop for s in numeric_iteration_slice]
 
         conditions = [c < e for c, e in zip(self.coordinates, end)]
-        for cuda_index, iter_slice in zip(self.cuda_indices, self._iterationSlice):
+        for cuda_index, iter_slice in zip(self.cuda_indices, self._iteration_space):
             if isinstance(iter_slice, slice) and iter_slice.step > 1:
                 conditions.append(sp.Eq(sp.Mod(cuda_index, iter_slice.step), 0))
 
@@ -183,8 +201,8 @@ class BlockIndexing(AbstractIndexing):
             condition = sp.And(condition, c)
         return Block([Conditional(condition, kernel_content)])
 
-    def iteration_space(self, arr_shape):
-        return _iteration_space(self._iterationSlice, arr_shape)
+    def numeric_iteration_space(self, arr_shape):
+        return _get_numeric_iteration_slice(self._iteration_space, arr_shape)
 
     def limit_block_size_by_register_restriction(self, block_size, required_registers_per_thread):
         """Shrinks the block_size if there are too many registers used per block.
@@ -246,38 +264,41 @@ class LineIndexing(AbstractIndexing):
     The fastest coordinate is indexed with thread_idx.x, the remaining coordinates are mapped to block_idx.{x,y,z}
     This indexing scheme supports up to 4 spatial dimensions, where the innermost dimensions is not larger than the
     maximum amount of threads allowed in a GPU block (which depends on device).
+
+    Args:
+        iteration_space: list of slices to determine start, stop and the step size for each coordinate
+        data_layout: tuple to determine the fast and slow coordinates.
     """
 
-    def __init__(self, field, iteration_slice):
-        available_indices = [THREAD_IDX[0]] + BLOCK_IDX
-        if field.spatial_dimensions > 4:
+    def __init__(self,  iteration_space: Tuple[slice], data_layout: Tuple):
+        super(LineIndexing, self).__init__(iteration_space, data_layout)
+
+        if len(iteration_space) > 4:
             raise NotImplementedError("This indexing scheme supports at most 4 spatial dimensions")
 
-        coordinates = available_indices[:field.spatial_dimensions]
+    @property
+    def cuda_indices(self):
+        available_indices = [THREAD_IDX[0]] + BLOCK_IDX
+        coordinates = available_indices[:self.dim]
 
-        fastest_coordinate = field.layout[-1]
+        fastest_coordinate = self.data_layout[-1]
         coordinates[0], coordinates[fastest_coordinate] = coordinates[fastest_coordinate], coordinates[0]
 
-        self._coordinates = coordinates
-        self._iterationSlice = normalize_slice(iteration_slice, field.spatial_shape)
-        self._symbolicShape = [e if isinstance(e, sp.Basic) else None for e in field.spatial_shape]
+        return coordinates
 
     @property
     def coordinates(self):
-        return [i + offset for i, offset in zip(self._coordinates, _get_start_from_slice(self._iterationSlice))]
+        return [i + o.start for i, o in zip(self.cuda_indices, self._iteration_space)]
 
     def call_parameters(self, arr_shape):
-        substitution_dict = {sym: value for sym, value in zip(self._symbolicShape, arr_shape) if sym is not None}
-
-        widths = [end - start for start, end in zip(_get_start_from_slice(self._iterationSlice),
-                                                    _get_end_from_slice(self._iterationSlice, arr_shape))]
-        widths = sp.Matrix(widths).subs(substitution_dict)
+        numeric_iteration_slice = _get_numeric_iteration_slice(self._iteration_space, arr_shape)
+        widths = [len(range(*s.indices(s.stop))) for s in numeric_iteration_slice]
 
         def get_shape_of_cuda_idx(cuda_idx):
-            if cuda_idx not in self._coordinates:
+            if cuda_idx not in self.cuda_indices:
                 return 1
             else:
-                idx = self._coordinates.index(cuda_idx)
+                idx = self.cuda_indices.index(cuda_idx)
                 return widths[idx]
 
         return {'block': tuple([get_shape_of_cuda_idx(idx) for idx in THREAD_IDX]),
@@ -292,52 +313,25 @@ class LineIndexing(AbstractIndexing):
     def symbolic_parameters(self):
         return set()
 
-    def iteration_space(self, arr_shape):
-        return _iteration_space(self._iterationSlice, arr_shape)
+    def numeric_iteration_space(self, arr_shape):
+        return _get_numeric_iteration_slice(self._iteration_space, arr_shape)
 
 
 # -------------------------------------- Helper functions --------------------------------------------------------------
 
-def _get_start_from_slice(iteration_slice):
+def _get_numeric_iteration_slice(iteration_slice, arr_shape):
     res = []
-    for slice_component in iteration_slice:
-        if type(slice_component) is slice:
-            res.append(slice_component.start if slice_component.start is not None else 0)
+    for slice_component, shape in zip(iteration_slice, arr_shape):
+        if not isinstance(slice_component.stop, int):
+            stop = slice_component.stop
+            assert len(stop.free_symbols) == 1
+            stop = stop.subs({symbol: shape for symbol in stop.free_symbols})
+            res.append(slice(slice_component.start, stop, slice_component.step))
         else:
-            assert isinstance(slice_component, int)
             res.append(slice_component)
     return res
 
 
-def _get_end_from_slice(iteration_slice, arr_shape):
-    iter_slice = normalize_slice(iteration_slice, arr_shape)
-    res = []
-    for slice_component in iter_slice:
-        if type(slice_component) is slice:
-            res.append(slice_component.stop)
-        else:
-            assert isinstance(slice_component, int)
-            res.append(slice_component + 1)
-    return res
-
-
-def _get_steps_from_slice(iteration_slice):
-    res = []
-    for slice_component in iteration_slice:
-        if type(slice_component) is slice:
-            res.append(slice_component.step)
-        else:
-            res.append(1)
-    return res
-
-
-def _iteration_space(iteration_slice, arr_shape):
-    starts = _get_start_from_slice(iteration_slice)
-    ends = _get_end_from_slice(iteration_slice, arr_shape)
-    steps = _get_steps_from_slice(iteration_slice)
-    return [slice(start, end, step) for start, end, step in zip(starts, ends, steps)]
-
-
 def indexing_creator_from_params(gpu_indexing, gpu_indexing_params):
     if isinstance(gpu_indexing, str):
         if gpu_indexing == 'block':
diff --git a/pystencils/gpu/kernelcreation.py b/pystencils/gpu/kernelcreation.py
index 066038cde246934d1694ee20613434ac9e3dc678..e8ac135c3681b7284ac0e048fa6cc1104f65e640 100644
--- a/pystencils/gpu/kernelcreation.py
+++ b/pystencils/gpu/kernelcreation.py
@@ -12,6 +12,7 @@ from pystencils.gpu.gpujit import make_python_function
 from pystencils.node_collection import NodeCollection
 from pystencils.gpu.indexing import indexing_creator_from_params
 from pystencils.simp.assignment_collection import AssignmentCollection
+from pystencils.slicing import normalize_slice
 from pystencils.transformations import (
     get_base_buffer_index, get_common_field, parse_base_pointer_info,
     resolve_buffer_accesses, resolve_field_accesses, unify_shape_symbols)
@@ -64,7 +65,11 @@ def create_cuda_kernel(assignments: Union[AssignmentCollection, NodeCollection],
                 iteration_slice.append(slice(ghost_layers[i][0],
                                              -ghost_layers[i][1] if ghost_layers[i][1] > 0 else None))
 
-    indexing = indexing_creator(field=common_field, iteration_slice=iteration_slice)
+        iteration_space = normalize_slice(iteration_slice, common_shape)
+    else:
+        iteration_space = normalize_slice(iteration_slice, common_shape)
+
+    indexing = indexing_creator(iteration_space=iteration_space, data_layout=common_field.layout)
     coord_mapping = indexing.coordinates
 
     cell_idx_assignments = [SympyAssignment(LoopOverCoordinate.get_loop_counter_symbol(i), value)
@@ -94,7 +99,6 @@ def create_cuda_kernel(assignments: Union[AssignmentCollection, NodeCollection],
     coord_mapping = {f.name: cell_idx_symbols for f in all_fields}
 
     if any(FieldType.is_buffer(f) for f in all_fields):
-        iteration_space = indexing.iteration_space(common_shape)
         resolve_buffer_accesses(ast, get_base_buffer_index(ast, cell_idx_symbols, iteration_space), read_only_fields)
 
     resolve_field_accesses(ast, read_only_fields, field_to_base_pointer_info=base_pointer_info,