From a946d58ea90b43d8e43d1554d7a2c502bb5334ec Mon Sep 17 00:00:00 2001
From: Michael Kuron <m.kuron@gmx.de>
Date: Wed, 28 Apr 2021 19:29:12 +0200
Subject: [PATCH] Add gather/scatter vector instructions for ARM SVE and x86 AVX-512

---
 pystencils/backends/arm_instruction_sets.py | 13 ++++++++++++-
 pystencils/backends/x86_instruction_sets.py | 10 +++++++++-
 2 files changed, 21 insertions(+), 2 deletions(-)

diff --git a/pystencils/backends/arm_instruction_sets.py b/pystencils/backends/arm_instruction_sets.py
index 9f7b4ee22..5318dffeb 100644
--- a/pystencils/backends/arm_instruction_sets.py
+++ b/pystencils/backends/arm_instruction_sets.py
@@ -88,9 +88,16 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
         result['makeVecConstInt'] = f'svdup_s{bits["int"]}' + '({0})'
         result['makeVecIndex'] = f'svindex_s{bits["int"]}' + '({0}, {1})'
 
+        vindex = f'svindex_u{bits[data_type]}(0, {{0}})'
+        result['scatter'] = f'svst1_scatter_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
+                            vindex.format("{2}") + ', {1})'
+        result['gather'] = f'svld1_gather_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
+                           vindex.format("{1}") + ')'
+
         result['+int'] = f"svadd_s{bits['int']}_x({int_predicate}, " + "{0}, {1})"
 
-        result[data_type] = f'svfloat{bits[data_type]}_st'
+        result['float'] = 'svfloat32_st'
+        result['double'] = 'svfloat64_st'
         result['int'] = f'svint{bits["int"]}_st'
         result['bool'] = 'svbool_st'
 
@@ -102,6 +109,10 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
         result['any'] = f'svptest_any({predicate}, {{0}})'
         result['all'] = f'svcntp_b{bits[data_type]}({predicate}, {{0}}) == {width}'
 
+        result['maskStoreU'] = result['storeU'].replace(predicate, '{2}')
+        result['maskStoreA'] = result['storeA'].replace(predicate, '{2}')
+        result['maskScatter'] = result['scatter'].replace(predicate, '{3}')
+
         result['compile_flags'] = [f'-msve-vector-bits={bitwidth}']
     else:
         result['makeVecConst'] = f'vdupq_n_f{bits[data_type]}' + '({0})'
diff --git a/pystencils/backends/x86_instruction_sets.py b/pystencils/backends/x86_instruction_sets.py
index 50005c5ae..6bf12b448 100644
--- a/pystencils/backends/x86_instruction_sets.py
+++ b/pystencils/backends/x86_instruction_sets.py
@@ -130,7 +130,7 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'):
     result['all'] = f"{pre}_movemask_{suf}({{0}}) == {hex(2**result['width']-1)}"
 
     if instruction_set == 'avx512':
-        size = 8 if data_type == 'double' else 16
+        size = result['width']
         result['&'] = f'_kand_mask{size}({{0}}, {{1}})'
         result['|'] = f'_kor_mask{size}({{0}}, {{1}})'
         result['any'] = f'!_ktestz_mask{size}_u8({{0}}, {{0}})'
@@ -145,6 +145,14 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'):
         params = " | ".join(["({{0}} ? {power} : 0)".format(power=2 ** i) for i in range(8)])
         result['makeVecConstBool'] = f"__mmask8(({params}) )"
 
+        vindex = f'{pre}_set_epi{bit_width//size}(' + ', '.join([str(i) for i in range(result['width'])][::-1]) + ')'
+        vindex = f'{pre}_mullo_epi{bit_width//size}({vindex}, {pre}_set1_epi{bit_width//size}({{0}}))'
+        result['scatter'] = f'{pre}_i{bit_width//size}scatter_{suf}({{0}}, ' + vindex.format("{2}") + \
+                            f', {{1}}, {64//size})'
+        result['maskScatter'] = f'{pre}_mask_i{bit_width//size}scatter_{suf}({{0}}, {{3}}, ' + vindex.format("{2}") + \
+                                f', {{1}}, {64//size})'
+        result['gather'] = f'{pre}_i{bit_width//size}gather_{suf}(' + vindex.format("{1}") + f', {{0}}, {64//size})'
+
     if instruction_set == 'avx' and data_type == 'float':
         result['rsqrt'] = f"{pre}_rsqrt_{suf}({{0}})"
 
-- 
GitLab