From bcb36f7752b633127a62ee3abd0b703056cfe33b Mon Sep 17 00:00:00 2001
From: Michael Kuron <m.kuron@gmx.de>
Date: Thu, 27 May 2021 15:28:47 +0200
Subject: [PATCH] only vectorize the RNG symbols if the loop was vectorized

---
 pystencils/cpu/vectorization.py | 28 +++++++++-------------------
 1 file changed, 9 insertions(+), 19 deletions(-)

diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py
index e4a27ad34..b54b78d85 100644
--- a/pystencils/cpu/vectorization.py
+++ b/pystencils/cpu/vectorization.py
@@ -126,7 +126,6 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'best',
     vector_width = vector_is['width']
     kernel_ast.instruction_set = vector_is
 
-    vectorize_rng(kernel_ast, vector_width)
     strided = 'storeS' in vector_is and 'loadS' in vector_is
     keep_loop_stop = '{loop_stop}' in vector_is['storeA' if assume_aligned else 'storeU']
     vectorize_inner_loops_and_adapt_load_stores(kernel_ast, vector_width, assume_aligned, nontemporal,
@@ -134,24 +133,6 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'best',
     insert_vector_casts(kernel_ast)
 
 
-def vectorize_rng(kernel_ast, vector_width):
-    """Replace scalar result symbols on RNG nodes with vectorial ones"""
-    from pystencils.rng import RNGBase
-    subst = {}
-
-    def visit_node(node):
-        for arg in node.args:
-            if isinstance(arg, RNGBase):
-                new_result_symbols = [TypedSymbol(s.name, VectorType(s.dtype, width=vector_width))
-                                      for s in arg.result_symbols]
-                subst.update({s[0]: s[1] for s in zip(arg.result_symbols, new_result_symbols)})
-                arg._symbols_defined = set(new_result_symbols)
-            else:
-                visit_node(arg)
-    visit_node(kernel_ast)
-    fast_subs(kernel_ast.body, subst, skip=lambda e: isinstance(e, RNGBase))
-
-
 def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_aligned, nontemporal_fields,
                                                 strided, keep_loop_stop, assume_sufficient_line_padding):
     """Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type."""
@@ -227,6 +208,15 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
 
         mask_conditionals(loop_node)
 
+        from pystencils.rng import RNGBase
+        substitutions = {}
+        for rng in loop_node.atoms(RNGBase):
+            new_result_symbols = [TypedSymbol(s.name, VectorType(s.dtype, width=vector_width))
+                                  for s in rng.result_symbols]
+            substitutions.update({s[0]: s[1] for s in zip(rng.result_symbols, new_result_symbols)})
+            rng._symbols_defined = set(new_result_symbols)
+        fast_subs(loop_node, substitutions, skip=lambda e: isinstance(e, RNGBase))
+
 
 def mask_conditionals(loop_body):
     def visit_node(node, mask):
-- 
GitLab