diff --git a/pystencils/backends/cuda_backend.py b/pystencils/backends/cuda_backend.py
index 915d5601b5fe7bc05b9833c21da62915b22552e5..cfef01e9052a35db0af0019cca830ac6ae7db1c7 100644
--- a/pystencils/backends/cuda_backend.py
+++ b/pystencils/backends/cuda_backend.py
@@ -8,6 +8,7 @@ from pystencils.astnodes import Node
 from pystencils.backends.cbackend import CBackend, CustomSympyPrinter, generate_c
 from pystencils.data_types import cast_func, create_type, get_type_of_expression
 from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
+from pystencils.interpolation_astnodes import InterpolationMode
 
 with open(join(dirname(__file__), 'cuda_known_functions.txt')) as f:
     lines = f.readlines()
@@ -77,7 +78,7 @@ class CudaSympyPrinter(CustomSympyPrinter):
     def _print_TextureAccess(self, node):
         dtype = node.texture.field.dtype.numpy_dtype
 
-        if node.texture.cubic_bspline_interpolation:
+        if node.texture.interpolation_mode == InterpolationMode.CUBIC_SPLINE:
             template = "cubicTex%iDSimple<%s>(%s, %s)"
         else:
             if dtype.itemsize > 4:
diff --git a/pystencils/gpucuda/cudajit.py b/pystencils/gpucuda/cudajit.py
index ee78b4663c4f40ebf3de0f96ebd405046a75de3c..8a403fd76828620e471dc07161a33faf1a5f0c6b 100644
--- a/pystencils/gpucuda/cudajit.py
+++ b/pystencils/gpucuda/cudajit.py
@@ -7,7 +7,7 @@ from pystencils.data_types import StructType
 from pystencils.field import FieldType
 from pystencils.gpucuda.texture_utils import ndarray_to_tex
 from pystencils.include import get_pystencils_include_path
-from pystencils.interpolation_astnodes import TextureAccess
+from pystencils.interpolation_astnodes import InterpolationMode, TextureAccess
 from pystencils.kernelparameters import FieldPointerSymbol
 
 USE_FAST_MATH = True
@@ -46,7 +46,7 @@ def make_python_function(kernel_function_node, argument_dict=None, custom_backen
     if USE_FAST_MATH:
         nvcc_options.append("-use_fast_math")
 
-    if any(t.cubic_bspline_interpolation for t in textures):
+    if any(t.interpolation_mode == InterpolationMode.CUBIC_SPLINE for t in textures):
         assert isdir(join(dirname(__file__), "CubicInterpolationCUDA", "code")), \
             "Submodule CubicInterpolationCUDA does not exist"
         nvcc_options += ["-I" + join(dirname(__file__), "CubicInterpolationCUDA", "code")]
diff --git a/pystencils/interpolation_astnodes.py b/pystencils/interpolation_astnodes.py
index b18c9c40e32c25668ef9d70fd6921f354bc1c261..46aef0bdb05e2875b5523d2e850a0421a64c3772 100644
--- a/pystencils/interpolation_astnodes.py
+++ b/pystencils/interpolation_astnodes.py
@@ -9,6 +9,7 @@
 """
 
 import itertools
+from enum import Enum
 from typing import Set
 
 import sympy as sp
@@ -24,7 +25,14 @@ except Exception:
     pass
 
 
-class LinearInterpolator(object):
+class InterpolationMode(str, Enum):
+    NEAREST_NEIGHBOR = "nearest_neighbour"
+    NN = NEAREST_NEIGHBOR
+    LINEAR = "linear"
+    CUBIC_SPLINE = "cubic_spline"
+
+
+class Interpolator(object):
     """
     Implements non-integer accesses on fields using linear interpolation.
 
@@ -58,7 +66,9 @@ class LinearInterpolator(object):
 
     required_global_declarations = []
 
-    def __init__(self, parent_field: pystencils.Field,
+    def __init__(self,
+                 parent_field: pystencils.Field,
+                 interpolation_mode: InterpolationMode,
                  address_mode='BORDER',
                  use_normalized_coordinates=False):
         super().__init__()
@@ -72,6 +82,7 @@ class LinearInterpolator(object):
                                   'dummy_symbol_carrying_field' + self.field.name + hash_str)
         self.symbol.field = self.field
         self.symbol.interpolator = self
+        self.interpolation_mode = interpolation_mode
 
     def at(self, offset):
         return InterpolatorAccess(self.symbol, *offset)
@@ -93,6 +104,30 @@ class LinearInterpolator(object):
                      self.use_normalized_coordinates))
 
 
+class LinearInterpolator(Interpolator):
+
+    def __init__(self,
+                 parent_field: pystencils.Field,
+                 address_mode='BORDER',
+                 use_normalized_coordinates=False):
+        super().__init__(parent_field,
+                         InterpolationMode.LINEAR,
+                         address_mode,
+                         use_normalized_coordinates)
+
+
+class NearestNeightborInterpolator(Interpolator):
+
+    def __init__(self,
+                 parent_field: pystencils.Field,
+                 address_mode='BORDER',
+                 use_normalized_coordinates=False):
+        super().__init__(parent_field,
+                         InterpolationMode.NN,
+                         address_mode,
+                         use_normalized_coordinates)
+
+
 class InterpolatorAccess(TypedSymbol):
     def __new__(cls, field, offsets, *args, **kwargs):
         obj = TextureAccess.__xnew_cached_(cls, field, offsets, *args, **kwargs)
@@ -149,6 +184,10 @@ class InterpolatorAccess(TypedSymbol):
     def symbols_defined(self) -> Set[sp.Symbol]:
         return {self}
 
+    @property
+    def interpolation_mode(self):
+        return self.interpolator.interpolation_mode
+
     def implementation_with_stencils(self):
         field = self.field
 
@@ -165,56 +204,66 @@ class InterpolatorAccess(TypedSymbol):
         offsets = self.offsets
         rounding_functions = (sp.floor, lambda x: sp.floor(x) + 1)
 
-        # TODO optimization: implement via lerp: https://devblogs.nvidia.com/lerp-faster-cuda/
-        for c in itertools.product(rounding_functions, repeat=field.spatial_dimensions):
-            weight = sp.Mul(*[1 - sp.Abs(f(offset) - offset) for (f, offset) in zip(c, offsets)])
-            index = [f(offset) for (f, offset) in zip(c, offsets)]
-            for channel_idx in range(field.shape[0] if field.index_dimensions else 1):
-                # Hardware boundary handling on GPU
+        for channel_idx in range(field.shape[0] if field.index_dimensions else 1):
+            if self.interpolation_mode == InterpolationMode.NN:
                 if use_textures:
-                    weight = sp.Mul(*[1 - sp.Abs(f(offset) - offset) for (f, offset) in zip(c, offsets)])
-                    sum[channel_idx] += \
-                        weight * absolute_access(index, channel_idx if field.index_dimensions else ())
-                # else boundary handling using software
-                elif str(self.interpolator.address_mode).lower() == 'border':
-                    is_inside_field = sp.And(
-                        *itertools.chain([i >= 0 for i in index],
-                                         [idx < field.shape[dim] for (dim, idx) in enumerate(index)]))
-                    index = [cast_func(i, default_int_type) for i in index]
-                    sum[channel_idx] += sp.Piecewise(
-                        (weight * absolute_access(index, channel_idx if field.index_dimensions else ()),
-                            is_inside_field),
-                        (sp.simplify(0), True)
-                    )
-                elif str(self.interpolator.address_mode).lower() == 'clamp':
-                    index = [cast_func(sp.Min(sp.Max(0, i), field.shape[dim] - 1), default_int_type)
-                             for (dim, i) in enumerate(index)]
-                    sum[channel_idx] += weight * \
-                        absolute_access(index, channel_idx if field.index_dimensions else ())
-                elif str(self.interpolator.address_mode).lower() == 'wrap':
-                    index = [cast_func(sp.Piecewise((sp.Mod(i, field.shape[dim]), i >= 0),
-                                                    (field.shape[dim] + sp.Mod(i, field.shape[dim]), True)),
-                                       default_int_type)
-                             for (dim, i) in enumerate(index)]
-                    sum[channel_idx] += weight * \
-                        absolute_access(index, channel_idx if field.index_dimensions else ())
-                elif str(self.interpolator.address_mode).lower() == 'mirror':
-                    def triangle_fun(x, half_period):
-                        saw_tooth = sp.Abs(x) % (2 * half_period)
-                        return sp.Piecewise((saw_tooth, saw_tooth < half_period),
-                                            (2 * half_period - 1 - saw_tooth, True))
-                    index = [cast_func(triangle_fun(i, field.shape[dim]),
-                                       default_int_type) for (dim, i) in enumerate(index)]
-                    sum[channel_idx] += weight * absolute_access(index, channel_idx if field.index_dimensions else ())
+                    sum[channel_idx] = self
                 else:
-                    raise NotImplementedError()
-
-        sum = [sp.factor(s) for s in sum]
+                    sum[channel_idx] = absolute_access([sp.floor(i + 0.5) for i in offsets], channel_idx)
 
-        if field.index_dimensions:
-            return sp.Matrix(sum)
-        else:
-            return sum[0]
+            elif self.interpolation_mode == InterpolationMode.LINEAR:
+                # TODO optimization: implement via lerp: https://devblogs.nvidia.com/lerp-faster-cuda/
+                for c in itertools.product(rounding_functions, repeat=field.spatial_dimensions):
+                    weight = sp.Mul(*[1 - sp.Abs(f(offset) - offset) for (f, offset) in zip(c, offsets)])
+                    index = [f(offset) for (f, offset) in zip(c, offsets)]
+                    # Hardware boundary handling on GPU
+                    if use_textures:
+                        weight = sp.Mul(*[1 - sp.Abs(f(offset) - offset) for (f, offset) in zip(c, offsets)])
+                        sum[channel_idx] += \
+                            weight * absolute_access(index, channel_idx if field.index_dimensions else ())
+                    # else boundary handling using software
+                    elif str(self.interpolator.address_mode).lower() == 'border':
+                        is_inside_field = sp.And(
+                            *itertools.chain([i >= 0 for i in index],
+                                             [idx < field.shape[dim] for (dim, idx) in enumerate(index)]))
+                        index = [cast_func(i, default_int_type) for i in index]
+                        sum[channel_idx] += sp.Piecewise(
+                            (weight * absolute_access(index, channel_idx if field.index_dimensions else ()),
+                                is_inside_field),
+                            (sp.simplify(0), True)
+                        )
+                    elif str(self.interpolator.address_mode).lower() == 'clamp':
+                        index = [cast_func(sp.Min(sp.Max(0, i), field.shape[dim] - 1), default_int_type)
+                                 for (dim, i) in enumerate(index)]
+                        sum[channel_idx] += weight * \
+                            absolute_access(index, channel_idx if field.index_dimensions else ())
+                    elif str(self.interpolator.address_mode).lower() == 'wrap':
+                        index = [cast_func(sp.Piecewise((sp.Mod(i, field.shape[dim]), i >= 0),
+                                                        (field.shape[dim] + sp.Mod(i, field.shape[dim]), True)),
+                                           default_int_type)
+                                 for (dim, i) in enumerate(index)]
+                        sum[channel_idx] += weight * \
+                            absolute_access(index, channel_idx if field.index_dimensions else ())
+                    elif str(self.interpolator.address_mode).lower() == 'mirror':
+                        def triangle_fun(x, half_period):
+                            saw_tooth = sp.Abs(x) % (2 * half_period)
+                            return sp.Piecewise((saw_tooth, saw_tooth < half_period),
+                                                (2 * half_period - 1 - saw_tooth, True))
+                        index = [cast_func(triangle_fun(i, field.shape[dim]),
+                                           default_int_type) for (dim, i) in enumerate(index)]
+                        sum[channel_idx] += weight * \
+                            absolute_access(index, channel_idx if field.index_dimensions else ())
+                    else:
+                        raise NotImplementedError()
+            elif self.interpolation_mode == InterpolationMode.CUBIC_SPLINE:
+                raise NotImplementedError("only works with HW interpolation for float32")
+
+            sum = [sp.factor(s) for s in sum]
+
+            if field.index_dimensions:
+                return sp.Matrix(sum)
+            else:
+                return sum[0]
 
     # noinspection SpellCheckingInspection
     __xnew__ = staticmethod(__new_stage2__)
@@ -231,9 +280,10 @@ class TextureCachedField:
     def __init__(self, parent_field,
                  address_mode=None,
                  filter_mode=None,
+                 interpolation_mode: InterpolationMode = InterpolationMode.LINEAR,
                  use_normalized_coordinates=False,
-                 read_as_integer=False,
-                 cubic_bspline_interpolation=False):
+                 read_as_integer=False
+                 ):
         if isinstance(address_mode, str):
             address_mode = getattr(pycuda.driver.address_mode, address_mode.upper())
 
@@ -252,14 +302,14 @@ class TextureCachedField:
         self.symbol = TypedSymbol(str(self), self.field.dtype.numpy_dtype)
         self.symbol.interpolator = self
         self.symbol.field = self.field
-        self.cubic_bspline_interpolation = cubic_bspline_interpolation
+        self.interpolation_mode = interpolation_mode
 
         # assert str(self.field.dtype) != 'double', "CUDA does not support double textures!"
         # assert dtype_supports_textures(self.field.dtype), "CUDA only supports texture types with 32 bits or less"
 
     @classmethod
     def from_interpolator(cls, interpolator: LinearInterpolator):
-        obj = cls(interpolator.field, interpolator.address_mode)
+        obj = cls(interpolator.field, interpolator.address_mode, interpolation_mode=interpolator.interpolation_mode)
         return obj
 
     def at(self, offset):