diff --git a/pystencils_kernels.py b/pystencils_kernels.py
index 8d4e307b1ed73bcff7dcada7cf1eed2395e997d7..182e96ffb0f6e2930595c620509b2bfd8737ec0b 100644
--- a/pystencils_kernels.py
+++ b/pystencils_kernels.py
@@ -128,3 +128,64 @@ def propagation_kernel(size, Q, c_ix, c_iy):
             ps.Assignment(data[0, 0](i), data_in[c_iy[i], -1 * c_ix[i]](i))
         )
     return ps.create_kernel(symbolic_description).compile()
+
+
+def time_step_kernel(size, omega, lattice):
+    """Create a periodic data handling with 'src'/'dst' arrays and a fused collision + propagation kernel.
+
+    Returns the tuple ``(dh, copy_func, kernel)``; ``copy_func`` is the synchronization function for 'src'.
+    """
+    c_s_sq, weights, velocity = lattice.velocity_set()
+    c_ix = velocity[0]
+    c_iy = velocity[1]
+    Q = len(weights)
+    # dst is written at offsets (-c_iy, c_ix): size the halo for the largest |velocity component| on either axis.
+    ghost_layers = max(abs(int(c)) for c in (*c_ix, *c_iy))
+
+    dh = ps.create_data_handling(domain_size=size, periodicity=(True, True))
+    src = dh.add_array('src', values_per_cell=Q, ghost_layers=ghost_layers)
+    dst = dh.add_array_like('dst', 'src')
+    # Copy func used to copy over periodic values at boundaries. Only the space filling property of 'stencil'
+    # is important
+    copy_func = dh.synchronization_function(['src'], stencil='D2Q9')
+
+    density, u_x, u_y = sp.symbols("density, u_x, u_y")
+
+    density_formula = sum([src[0, 0](i) for i in range(Q)])
+    vel_x_formula = (sum([src[0, 0](i) * c_ix[i] for i in range(Q)])) / density
+    vel_y_formula = (sum([src[0, 0](i) * c_iy[i] for i in range(Q)])) / density
+
+    # Higher-order equilibrium terms are only added for the larger velocity sets (>= 17 / >= 37 velocities).
+    third_order, fourth_order = len(c_ix) >= 17, len(c_ix) >= 37
+    u_sq = u_x ** 2 + u_y ** 2
+    feq_formula = [0] * Q
+    for i in range(Q):
+        cu = u_x * c_ix[i] + u_y * c_iy[i]
+        feq_formula[i] = weights[i] * density * (
+                1 + cu / c_s_sq + cu ** 2 / (2 * c_s_sq ** 2) - u_sq / (2 * c_s_sq))
+        if third_order:
+            feq_formula[i] += weights[i] * density * (
+                    -2 * (1 / (2 * c_s_sq)) ** 2 * (
+                        c_ix[i] * u_x ** 3 + c_ix[i] * u_x * u_y ** 2
+                        + c_iy[i] * u_y * u_x ** 2 + c_iy[i] * u_y ** 3)
+                    + 4 / 3 * (1 / (2 * c_s_sq)) ** 3 * cu ** 3)
+        if fourth_order:
+            feq_formula[i] += weights[i] * density * (
+                    1 / (6 * c_s_sq ** 2) * (
+                        cu ** 4 / (4 * c_s_sq ** 2) - (3 * u_sq * cu ** 2) / (2 * c_s_sq)
+                        + 3 * (u_x ** 4 + u_y ** 4)))
+
+    # Update macroscopic moments
+    symbolic_description = [ps.Assignment(density, density_formula),
+                            ps.Assignment(u_x, vel_x_formula),
+                            ps.Assignment(u_y, vel_y_formula), ]
+
+    # Collision + Propagation
+    for i in range(Q):
+        symbolic_description.append(
+            ps.Assignment(dst[-1 * c_iy[i], c_ix[i]](i), omega * feq_formula[i] + (1 - omega) * src[0, 0](i)))
+    kernel = ps.create_kernel(symbolic_description).compile()
+
+    return dh, copy_func, kernel
+
+
diff --git a/visualize.py b/visualize.py
index b0580e5df30734201ef9122b67c5dc4c98e27eed..0259b3620867e6151cbc16708dc0e6da651b7b52 100644
--- a/visualize.py
+++ b/visualize.py
@@ -16,6 +16,7 @@ from utils.utils import logger
 import numpy as np
 from utils.utils import Timer
 import click
+import time
 import matplotlib.pyplot as plt
 import matplotlib.cm as cm
 import matplotlib.animation as animation
@@ -200,6 +201,61 @@ def compare_v37_accuracy():
     plt.show()
 
 
+def compare_pystencils_numpy(resolution, iterations):
+    """Print the wall-clock time of `iterations` LBM steps on D2V17 for the NumPy and pystencils back ends."""
+    lattice_v17 = Lattice.from_name("D2V17")
+    size = (resolution, resolution)
+    lbm_np = LBM(size=size, lattice=lattice_v17, pystencils=False, omega=0.8, fill_mode="taylor-green")
+    lbm_ps = LBM(size=size, lattice=lattice_v17, pystencils=True, omega=0.8, fill_mode="taylor-green")
+
+    # perf_counter() is monotonic and high resolution; time.time() can jump when the system clock is adjusted.
+    t_start = time.perf_counter()
+    for _ in range(iterations):
+        lbm_np.iterate()
+    t_break = time.perf_counter()
+    for _ in range(iterations):
+        lbm_ps.iterate()
+    t_stop = time.perf_counter()
+    print("Pystencils = {} Numpy = {}".format((t_stop-t_break), (t_break - t_start)))
+
+
+def time_step_kernel(resolution):
+    """Compare one step of the fused pystencils kernel with one `LBM.iterate()` step on a D2Q9 lattice."""
+    lattice_q9 = Lattice.from_name("D2Q9")
+    size = (resolution, resolution)
+    omega = 0.8
+    lbm_ps = LBM(size=size, lattice=lattice_q9, pystencils=True, omega=omega, fill_mode="taylor-green")
+
+    import pystencils_kernels
+    dh, copy_func, kernel = pystencils_kernels.time_step_kernel(size=size, omega=omega, lattice=lattice_q9)
+    dh.fill('src', 0, ghost_layers=True)
+    dh.fill('dst', 0, ghost_layers=True)
+    for block in dh.iterate(ghost_layers=False):
+        x, y = block.cell_index_arrays
+        block["src"][True] = lbm_ps.data[x, y]
+    assert (dh.gather_array('src', ghost_layers=False) == lbm_ps.data).all()
+
+    # Ghost cells are expected to be exact copies of the opposite interior cells after synchronization.
+    copy_func()
+    with_ghosts = dh.gather_array('src', ghost_layers=True)
+    interior = dh.gather_array('src', ghost_layers=False)
+    assert (with_ghosts[0, 0] == interior[-1, -1]).all()
+    assert (with_ghosts[3, 0] == interior[2, -1]).all()
+
+    dh.run_kernel(kernel)
+    dh.swap('src', 'dst')
+    lbm_ps.iterate()
+
+    tx, ty = 0, 0
+    result = dh.gather_array('src', ghost_layers=False)
+    print(result[tx, ty, :])
+    print(lbm_ps.data[tx, ty, :])
+    print(abs((result[tx, ty, :] - lbm_ps.data[tx, ty, :])) <= 1e-10)
+    print(abs(result[1:-2, 1:-2, :] - lbm_ps.data[1:-2, 1:-2, :]).max())
+    # Independently computed floats: use the 1e-10 tolerance of the printed check, not bit-exact equality.
+    assert np.allclose(result, lbm_ps.data, rtol=0, atol=1e-10)
+
+
 if __name__ == "__main__":
     # logger.setLevel(logging.INFO)
     logger.setLevel(logging.WARNING)
@@ -209,4 +265,7 @@ if __name__ == "__main__":
     # compare_numpy_accuracy()
     # compare_numpy_performance()
     # compare_v37_pystencils()
-    compare_v37_accuracy()
+    # compare_v37_accuracy()
+    # compare_pystencils_numpy(100, iterations=100)
+
+    time_step_kernel(7)
\ No newline at end of file