From 6fcc0cccaf935c92d1adafa0a01128847d4efc6a Mon Sep 17 00:00:00 2001
From: Michael Kuron <m.kuron@gmx.de>
Date: Thu, 20 May 2021 19:04:54 +0200
Subject: [PATCH] Rename scatter/gather to strided

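Some instruction sets, e.g. the RISC-V vector extension (RVV) or NEC SX,
have strided load/store operations that are separate from general
scatter/gather. The instruction-set keys are therefore renamed from
gather/scatter/maskScatter to loadS/storeS/maskStoreS.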
---
 pystencils/backends/arm_instruction_sets.py     | 10 +++++-----
 pystencils/backends/cbackend.py                 |  4 ++--
 pystencils/backends/riscv_instruction_sets.py   | 12 ++++++------
 pystencils/backends/x86_instruction_sets.py     | 10 +++++-----
 pystencils/cpu/vectorization.py                 | 10 +++++-----
 pystencils_tests/test_vectorization_specific.py |  4 ++--
 6 files changed, 25 insertions(+), 25 deletions(-)

diff --git a/pystencils/backends/arm_instruction_sets.py b/pystencils/backends/arm_instruction_sets.py
index 904329628..73ea7eb44 100644
--- a/pystencils/backends/arm_instruction_sets.py
+++ b/pystencils/backends/arm_instruction_sets.py
@@ -102,10 +102,10 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
         result['makeVecIndex'] = f'svindex_s{bits["int"]}' + '({0}, {1})'
 
         vindex = f'svindex_u{bits[data_type]}(0, {{0}})'
-        result['scatter'] = f'svst1_scatter_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
-                            vindex.format("{2}") + ', {1})'
-        result['gather'] = f'svld1_gather_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
-                           vindex.format("{1}") + ')'
+        result['storeS'] = f'svst1_scatter_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
+                           vindex.format("{2}") + ', {1})'
+        result['loadS'] = f'svld1_gather_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
+                          vindex.format("{1}") + ')'
 
         result['+int'] = f"svadd_s{bits['int']}_x({int_predicate}, " + "{0}, {1})"
 
@@ -124,7 +124,7 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
 
         result['maskStoreU'] = result['storeU'].replace(predicate, '{2}')
         result['maskStoreA'] = result['storeA'].replace(predicate, '{2}')
-        result['maskScatter'] = result['scatter'].replace(predicate, '{3}')
+        result['maskStoreS'] = result['storeS'].replace(predicate, '{3}')
 
         if instruction_set != 'sve':
             result['compile_flags'] = [f'-msve-vector-bits={bitwidth}']
diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py
index 0774b53ff..ba7426a14 100644
--- a/pystencils/backends/cbackend.py
+++ b/pystencils/backends/cbackend.py
@@ -311,7 +311,7 @@ class CBackend:
                 ptr = "&" + self.sympy_printer.doprint(node.lhs.args[0])
 
                 if stride != 1:
-                    instr = 'maskScatter' if mask != True else 'scatter'  # NOQA
+                    instr = 'maskStoreS' if mask != True else 'storeS'  # NOQA
                     return self._vector_instruction_set[instr].format(ptr, self.sympy_printer.doprint(rhs),
                                                                       stride, printed_mask, **self._kwargs) + ';'
 
@@ -648,7 +648,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
         if isinstance(expr, vector_memory_access):
             arg, data_type, aligned, _, mask, stride = expr.args
             if stride != 1:
-                return self.instruction_set['gather'].format("& " + self._print(arg), stride, **self._kwargs)
+                return self.instruction_set['loadS'].format("& " + self._print(arg), stride, **self._kwargs)
             instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU']
             return instruction.format("& " + self._print(arg), **self._kwargs)
         elif isinstance(expr, cast_func):
diff --git a/pystencils/backends/riscv_instruction_sets.py b/pystencils/backends/riscv_instruction_sets.py
index 78c16da09..d93aee701 100644
--- a/pystencils/backends/riscv_instruction_sets.py
+++ b/pystencils/backends/riscv_instruction_sets.py
@@ -35,9 +35,9 @@ def get_vector_instruction_set_riscv(data_type='double', instruction_set='rvv'):
         'storeA': f'se{bits[data_type]}_v[0, 1]',
         'maskStoreU': f'se{bits[data_type]}_v[2, 0, 1]',
         'maskStoreA': f'se{bits[data_type]}_v[2, 0, 1]',
-        'gather': f'lse{bits[data_type]}_v[0, 1]',
-        'scatter': f'sse{bits[data_type]}_v[0, 2, 1]',
-        'maskScatter': f'sse{bits[data_type]}_v[2, 0, 3, 1]',
+        'loadS': f'lse{bits[data_type]}_v[0, 1]',
+        'storeS': f'sse{bits[data_type]}_v[0, 2, 1]',
+        'maskStoreS': f'sse{bits[data_type]}_v[2, 0, 3, 1]',
 
         'abs': 'fabs_v[0]',
         '==': 'mfeq_vv[0, 1]',
@@ -90,9 +90,9 @@ def get_vector_instruction_set_riscv(data_type='double', instruction_set='rvv'):
     result['makeVecIndex'] = f'vmacc_vx_i{bits["int"]}m1({result["makeVecConstInt"]}, {{1}}, ' + \
                              f'vid_v_i{bits["int"]}m1({int_vl}), {int_vl})'
 
-    result['scatter'] = result['scatter'].replace('{2}', f'{{2}}*{bits[data_type]//8}')
-    result['gather'] = result['gather'].replace('{1}', f'{{1}}*{bits[data_type]//8}')
-    result['maskScatter'] = result['maskScatter'].replace('{3}', f'{{3}}*{bits[data_type]//8}')
+    result['storeS'] = result['storeS'].replace('{2}', f'{{2}}*{bits[data_type]//8}')
+    result['loadS'] = result['loadS'].replace('{1}', f'{{1}}*{bits[data_type]//8}')
+    result['maskStoreS'] = result['maskStoreS'].replace('{3}', f'{{3}}*{bits[data_type]//8}')
 
     result['+int'] = f"vadd_vv_i{bits['int']}m1({{0}}, {{1}}, {int_vl})"
 
diff --git a/pystencils/backends/x86_instruction_sets.py b/pystencils/backends/x86_instruction_sets.py
index 913db542f..f72b48266 100644
--- a/pystencils/backends/x86_instruction_sets.py
+++ b/pystencils/backends/x86_instruction_sets.py
@@ -147,11 +147,11 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'):
 
         vindex = f'{pre}_set_epi{bit_width//size}(' + ', '.join([str(i) for i in range(result['width'])][::-1]) + ')'
         vindex = f'{pre}_mullo_epi{bit_width//size}({vindex}, {pre}_set1_epi{bit_width//size}({{0}}))'
-        result['scatter'] = f'{pre}_i{bit_width//size}scatter_{suf}({{0}}, ' + vindex.format("{2}") + \
-                            f', {{1}}, {64//size})'
-        result['maskScatter'] = f'{pre}_mask_i{bit_width//size}scatter_{suf}({{0}}, {{3}}, ' + vindex.format("{2}") + \
-                                f', {{1}}, {64//size})'
-        result['gather'] = f'{pre}_i{bit_width//size}gather_{suf}(' + vindex.format("{1}") + f', {{0}}, {64//size})'
+        result['storeS'] = f'{pre}_i{bit_width//size}scatter_{suf}({{0}}, ' + vindex.format("{2}") + \
+                           f', {{1}}, {64//size})'
+        result['maskStoreS'] = f'{pre}_mask_i{bit_width//size}scatter_{suf}({{0}}, {{3}}, ' + vindex.format("{2}") + \
+                               f', {{1}}, {64//size})'
+        result['loadS'] = f'{pre}_i{bit_width//size}gather_{suf}(' + vindex.format("{1}") + f', {{0}}, {64//size})'
 
     if instruction_set == 'avx' and data_type == 'float':
         result['rsqrt'] = f"{pre}_rsqrt_{suf}({{0}})"
diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py
index 2444519d7..b9fa2819e 100644
--- a/pystencils/cpu/vectorization.py
+++ b/pystencils/cpu/vectorization.py
@@ -127,10 +127,10 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'best',
     kernel_ast.instruction_set = vector_is
 
     vectorize_rng(kernel_ast, vector_width)
-    scattergather = 'scatter' in vector_is and 'gather' in vector_is
+    strided = 'storeS' in vector_is and 'loadS' in vector_is
     keep_loop_stop = '{loop_stop}' in vector_is['storeA' if assume_aligned else 'storeU']
     vectorize_inner_loops_and_adapt_load_stores(kernel_ast, vector_width, assume_aligned, nontemporal,
-                                                scattergather, keep_loop_stop, assume_sufficient_line_padding)
+                                                strided, keep_loop_stop, assume_sufficient_line_padding)
     insert_vector_casts(kernel_ast)
 
 
@@ -153,7 +153,7 @@ def vectorize_rng(kernel_ast, vector_width):
 
 
 def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_aligned, nontemporal_fields,
-                                                scattergather, keep_loop_stop, assume_sufficient_line_padding):
+                                                strided, keep_loop_stop, assume_sufficient_line_padding):
     """Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type."""
     all_loops = filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment)
     inner_loops = [n for n in all_loops if n.is_innermost_loop]
@@ -187,7 +187,7 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
                 loop_counter_is_offset = loop_counter_symbol not in (index - loop_counter_symbol).atoms()
                 aligned_access = (index - loop_counter_symbol).subs(zero_loop_counters) == 0
                 stride = sp.simplify(index.subs({loop_counter_symbol: loop_counter_symbol + 1}) - index)
-                if not loop_counter_is_offset and (not scattergather or loop_counter_symbol in stride.atoms()):
+                if not loop_counter_is_offset and (not strided or loop_counter_symbol in stride.atoms()):
                     successful = False
                     break
                 typed_symbol = base.label
@@ -200,7 +200,7 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
                 if hasattr(indexed, 'field'):
                     nontemporal = (indexed.field in nontemporal_fields) or (indexed.field.name in nontemporal_fields)
                 substitutions[indexed] = vector_memory_access(indexed, vec_type, use_aligned_access, nontemporal, True,
-                                                              stride if scattergather else 1)
+                                                              stride if strided else 1)
                 if nontemporal:
                     # insert NontemporalFence after the outermost loop
                     parent = loop_node.parent
diff --git a/pystencils_tests/test_vectorization_specific.py b/pystencils_tests/test_vectorization_specific.py
index ebc86ee8a..1c0c35e53 100644
--- a/pystencils_tests/test_vectorization_specific.py
+++ b/pystencils_tests/test_vectorization_specific.py
@@ -55,10 +55,10 @@ def test_vectorized_abs(instruction_set, dtype):
 
 @pytest.mark.parametrize('dtype', ('float', 'double'))
 @pytest.mark.parametrize('instruction_set', supported_instruction_sets)
-def test_scatter_gather(instruction_set, dtype):
+def test_strided(instruction_set, dtype):
     f, g = ps.fields(f"f, g : float{64 if dtype == 'double' else 32}[2D]")
     update_rule = [ps.Assignment(g[0, 0], f[0, 0] + f[-1, 0] + f[1, 0] + f[0, 1] + f[0, -1] + 42.0)]
-    if 'scatter' not in get_vector_instruction_set(dtype, instruction_set) and not instruction_set in ['avx512', 'rvv'] and not instruction_set.startswith('sve'):
+    if 'storeS' not in get_vector_instruction_set(dtype, instruction_set) and not instruction_set in ['avx512', 'rvv'] and not instruction_set.startswith('sve'):
         with pytest.warns(UserWarning) as warn:
             ast = ps.create_kernel(update_rule, cpu_vectorize_info={'instruction_set': instruction_set})
             assert 'Could not vectorize loop' in warn[0].message.args[0]
-- 
GitLab