diff --git a/pystencils/rng.py b/pystencils/rng.py
index 87ca251b8ad88586e076025ef878cdf118ffa6a1..c1daed1d4ba43167a33650c7bcf80a2b167885b5 100644
--- a/pystencils/rng.py
+++ b/pystencils/rng.py
@@ -108,7 +108,8 @@ def random_symbol(assignment_list, dim, seed=TypedSymbol("seed", np.uint32), rng
     """
     counter = 0
     while True:
-        node = rng_node(dim, keys=(counter, seed), time_step=time_step, offsets=offsets)
+        keys = (counter, seed) + (0,) * (rng_node._num_keys - 2)
+        node = rng_node(dim, keys=keys, time_step=time_step, offsets=offsets)
         inserted = False
         for symbol in node.result_symbols:
             if not inserted:
diff --git a/pystencils_tests/test_random.py b/pystencils_tests/test_random.py
index 85b8d7ee66adb639bee818f49e1f7082ac227d80..9a0724cf16468d16350667aa67e96b8c815bc366 100644
--- a/pystencils_tests/test_random.py
+++ b/pystencils_tests/test_random.py
@@ -175,7 +175,7 @@ def test_staggered(vectorized):
     dh = ps.create_data_handling((8, 8), default_ghost_layers=0, default_target="cpu")
     j = dh.add_array("j", values_per_cell=dh.dim, field_type=ps.FieldType.STAGGERED_FLUX)
     a = ps.AssignmentCollection([ps.Assignment(j.staggered_access(n), 0) for n in j.staggered_stencil])
-    rng_symbol_gen = random_symbol(a.subexpressions, dim=dh.dim)
+    rng_symbol_gen = random_symbol(a.subexpressions, dim=dh.dim, rng_node=AESNITwoDoubles)
     a.main_assignments[0] = ps.Assignment(a.main_assignments[0].lhs, next(rng_symbol_gen))
     kernel = ps.create_staggered_kernel(a, target=dh.default_target).compile()