From b699c8830ebf6a1a33c5ab9dee201287a1c1a94f Mon Sep 17 00:00:00 2001
From: Michael Kuron <m.kuron@gmx.de>
Date: Thu, 16 May 2024 22:07:24 +0200
Subject: [PATCH] Support ARM64 Streaming SVE

This is part of the Scalable Matrix Extensions (SME)
---
 .gitlab-ci.yml                                |  2 +-
 .../backends/arm_instruction_sets.py          | 40 ++++++++++---------
 .../backends/simd_instruction_sets.py         |  5 ++-
 src/pystencils/cpu/cpujit.py                  |  7 +++-
 src/pystencils/include/philox_rand.h          |  8 ++--
 tests/test_conditional_vec.py                 |  8 ++--
 tests/test_random.py                          |  4 +-
 tests/test_vectorization.py                   |  4 +-
 8 files changed, 46 insertions(+), 32 deletions(-)

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 04e690dc6..6d8a1c13f 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -181,7 +181,7 @@ arm64v9:
   image: i10git.cs.fau.de:5005/pycodegen/pycodegen/arm64
   before_script:
     - *multiarch_before_script
-    - sed -i s/march=native/march=armv8-a+sve/g ~/.config/pystencils/config.json
+    - sed -i s/march=native/march=armv8-a+sve+sme/g ~/.config/pystencils/config.json
     - sed -i s/g\+\+/clang++/g ~/.config/pystencils/config.json
 
 riscv64:
diff --git a/src/pystencils/backends/arm_instruction_sets.py b/src/pystencils/backends/arm_instruction_sets.py
index 7dede78aa..3e50d5f45 100644
--- a/src/pystencils/backends/arm_instruction_sets.py
+++ b/src/pystencils/backends/arm_instruction_sets.py
@@ -16,9 +16,9 @@ def get_argument_string(function_shortcut, first=''):
 
 
 def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
-    if instruction_set != 'neon' and not instruction_set.startswith('sve'):
+    if instruction_set not in ['neon', 'sme'] and not instruction_set.startswith('sve'):
         raise NotImplementedError(instruction_set)
-    if instruction_set == 'sve':
+    if instruction_set in ['sve', 'sme']:
         cmp = 'cmp'
     elif instruction_set.startswith('sve'):
         cmp = 'cmp'
@@ -52,7 +52,7 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
 
     result = dict()
 
-    if instruction_set == 'sve':
+    if instruction_set in ['sve', 'sme']:
         width = 'svcntd()' if data_type == 'double' else 'svcntw()'
         intwidth = 'svcntw()'
         result['bytes'] = 'svcntb()'
@@ -60,14 +60,14 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
         width = bitwidth // bits[data_type]
         intwidth = bitwidth // bits['int']
         result['bytes'] = bitwidth // 8
-    if instruction_set.startswith('sve'):
+    if instruction_set.startswith('sve') or instruction_set == 'sme':
         prefix = 'sv'
         suffix = f'_f{bits[data_type]}' 
     elif instruction_set == 'neon':
         prefix = 'v'
         suffix = f'q_f{bits[data_type]}' 
 
-    if instruction_set == 'sve':
+    if instruction_set in ['sve', 'sme']:
         predicate = f'{prefix}whilelt_b{bits[data_type]}_u64({{loop_counter}}, {{loop_stop}})'
         int_predicate = f'{prefix}whilelt_b{bits["int"]}_u64({{loop_counter}}, {{loop_stop}})'
     else:
@@ -86,7 +86,7 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
 
         result[intrinsic_id] = prefix + name + suffix + undef + arg_string
 
-    if instruction_set == 'sve':
+    if instruction_set in ['sve', 'sme']:
         from pystencils.backends.cbackend import CFunction
         result['width'] = CFunction(width, "int")
         result['intwidth'] = CFunction(intwidth, "int")
@@ -94,23 +94,24 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
         result['width'] = width
         result['intwidth'] = intwidth
 
-    if instruction_set.startswith('sve'):
+    if instruction_set.startswith('sve') or instruction_set == 'sme':
         result['makeVecConst'] = f'svdup_f{bits[data_type]}' + '({0})'
         result['makeVecConstInt'] = f'svdup_s{bits["int"]}' + '({0})'
         result['makeVecIndex'] = f'svindex_s{bits["int"]}' + '({0}, {1})'
 
-        vindex = f'svindex_u{bits[data_type]}(0, {{0}})'
-        result['storeS'] = f'svst1_scatter_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
-                           vindex.format("{2}") + ', {1})'
-        result['loadS'] = f'svld1_gather_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
-                          vindex.format("{1}") + ')'
+        if instruction_set != 'sme':
+            vindex = f'svindex_u{bits[data_type]}(0, {{0}})'
+            result['storeS'] = f'svst1_scatter_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
+                               vindex.format("{2}") + ', {1})'
+            result['loadS'] = f'svld1_gather_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
+                              vindex.format("{1}") + ')'
 
         result['+int'] = f"svadd_s{bits['int']}_x({int_predicate}, " + "{0}, {1})"
 
-        result['float'] = f'svfloat{bits["float"]}_{"s" if instruction_set != "sve" else ""}t'
-        result['double'] = f'svfloat{bits["double"]}_{"s" if instruction_set != "sve" else ""}t'
-        result['int'] = f'svint{bits["int"]}_{"s" if instruction_set != "sve" else ""}t'
-        result['bool'] = f'svbool_{"s" if instruction_set != "sve" else ""}t'
+        result['float'] = f'svfloat{bits["float"]}_{"s" if instruction_set not in ["sve", "sme"] else ""}t'
+        result['double'] = f'svfloat{bits["double"]}_{"s" if instruction_set not in ["sve", "sme"] else ""}t'
+        result['int'] = f'svint{bits["int"]}_{"s" if instruction_set not in ["sve", "sme"] else ""}t'
+        result['bool'] = f'svbool_{"s" if instruction_set not in ["sve", "sme"] else ""}t'
 
         result['headers'] = ['<arm_sve.h>', '"arm_neon_helpers.h"']
 
@@ -121,9 +122,12 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
         result['all'] = f'svcntp_b{bits[data_type]}({predicate}, {{0}}) == {width}'
 
         result['maskStoreU'] = result['storeU'].replace(predicate, '{2}')
-        result['maskStoreS'] = result['storeS'].replace(predicate, '{3}')
+        if instruction_set != 'sme':
+            result['maskStoreS'] = result['storeS'].replace(predicate, '{3}')
 
-        if instruction_set != 'sve':
+        if instruction_set == 'sme':
+            result['function_prefix'] = '__attribute__((arm_locally_streaming))'
+        elif instruction_set not in ['sve', 'sme']:
             result['compile_flags'] = [f'-msve-vector-bits={bitwidth}']
     else:
         result['makeVecConst'] = f'vdupq_n_f{bits[data_type]}' + '({0})'
diff --git a/src/pystencils/backends/simd_instruction_sets.py b/src/pystencils/backends/simd_instruction_sets.py
index e9bce8737..b94d9f374 100644
--- a/src/pystencils/backends/simd_instruction_sets.py
+++ b/src/pystencils/backends/simd_instruction_sets.py
@@ -22,7 +22,7 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
 
     type_name = numpy_name_to_c(np.dtype(data_type).name)
 
-    if instruction_set in ['neon'] or instruction_set.startswith('sve'):
+    if instruction_set in ['neon', 'sme'] or instruction_set.startswith('sve'):
         return get_vector_instruction_set_arm(type_name, instruction_set)
     elif instruction_set in ['vsx']:
         return get_vector_instruction_set_ppc(type_name, instruction_set)
@@ -53,6 +53,9 @@ def get_supported_instruction_sets():
                 result.append(f"sve{length}")
                 length //= 2
             result.append("sve")
+        hwcap2 = libc.getauxval(26)  # AT_HWCAP2
+        if hwcap2 & (1 << 23):  # HWCAP2_SME
+            result.append("sme")
         return result
     elif platform.system() == 'Linux' and platform.machine().startswith('riscv'):
         libc = CDLL('libc.so.6')
diff --git a/src/pystencils/cpu/cpujit.py b/src/pystencils/cpu/cpujit.py
index b839f87cf..8cc100045 100644
--- a/src/pystencils/cpu/cpujit.py
+++ b/src/pystencils/cpu/cpujit.py
@@ -617,7 +617,12 @@ def compile_and_load(ast, custom_backend=None):
     cache_config = get_cache_config()
 
     compiler_config = get_compiler_config()
-    function_prefix = '__declspec(dllexport)' if compiler_config['os'].lower() == 'windows' else ''
+    if compiler_config['os'].lower() == 'windows':
+        function_prefix = '__declspec(dllexport)'
+    elif ast.instruction_set and 'function_prefix' in ast.instruction_set:
+        function_prefix = ast.instruction_set['function_prefix']
+    else:
+        function_prefix = ''
 
     code = ExtensionModuleCode(custom_backend=custom_backend)
     code.add_function(ast, ast.function_name)
diff --git a/src/pystencils/include/philox_rand.h b/src/pystencils/include/philox_rand.h
index 4320a8b93..8aa6a4e75 100644
--- a/src/pystencils/include/philox_rand.h
+++ b/src/pystencils/include/philox_rand.h
@@ -18,7 +18,7 @@
 #ifdef __ARM_NEON
 #include <arm_neon.h>
 #endif
-#ifdef __ARM_FEATURE_SVE
+#if defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_SME)
 #include <arm_sve.h>
 #endif
 
@@ -42,6 +42,8 @@
 #define QUALIFIERS static __forceinline__ __device__
 #elif defined(__OPENCL_VERSION__)
 #define QUALIFIERS static inline
+#elif defined(__ARM_FEATURE_SME)
+#define QUALIFIERS __attribute__((arm_streaming_compatible))
 #else
 #define QUALIFIERS inline
 #include "myintrin.h"
@@ -69,7 +71,7 @@ typedef std::uint64_t uint64;
 #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_SVE_BITS) && __ARM_FEATURE_SVE_BITS > 0
 typedef svfloat32_t svfloat32_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS)));
 typedef svfloat64_t svfloat64_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS)));
-#elif defined(__ARM_FEATURE_SVE)
+#elif defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_SME)
 typedef svfloat32_t svfloat32_st;
 typedef svfloat64_t svfloat64_st;
 #endif
@@ -714,7 +716,7 @@ QUALIFIERS void philox_double2(uint32 ctr0, int32x4_t ctr1, uint32 ctr2, uint32
 #endif
 
 
-#if defined(__ARM_FEATURE_SVE)
+#if defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_SME)
 QUALIFIERS void _philox4x32round(svuint32x4_t & ctr, svuint32x2_t & key)
 {
     svuint32_t lo0 = svmul_u32_x(svptrue_b32(), svget4_u32(ctr, 0), svdup_u32(PHILOX_M4x32_0));
diff --git a/tests/test_conditional_vec.py b/tests/test_conditional_vec.py
index ebb25d50d..032c8ab78 100644
--- a/tests/test_conditional_vec.py
+++ b/tests/test_conditional_vec.py
@@ -15,7 +15,7 @@ supported_instruction_sets = get_supported_instruction_sets() if get_supported_i
 @pytest.mark.parametrize('instruction_set', supported_instruction_sets)
 @pytest.mark.parametrize('dtype', ('float32', 'float64'))
 def test_vec_any(instruction_set, dtype):
-    if instruction_set in ['sve', 'rvv']:
+    if instruction_set in ['sve', 'sme', 'rvv']:
         width = 4  # we don't know the actual value
     else:
         width = get_vector_instruction_set(dtype, instruction_set)['width']
@@ -34,7 +34,7 @@ def test_vec_any(instruction_set, dtype):
                            cpu_vectorize_info={'instruction_set': instruction_set})
     kernel = ast.compile()
     kernel(data=data_arr)
-    if instruction_set in ['sve', 'rvv']:
+    if instruction_set in ['sve', 'sme', 'rvv']:
         # we only know that the first value has changed
         np.testing.assert_equal(data_arr[3:9, :3 * width - 1], 2.0)
     else:
@@ -44,7 +44,7 @@ def test_vec_any(instruction_set, dtype):
 @pytest.mark.parametrize('instruction_set', supported_instruction_sets)
 @pytest.mark.parametrize('dtype', ('float32', 'float64'))
 def test_vec_all(instruction_set, dtype):
-    if instruction_set in ['sve', 'rvv']:
+    if instruction_set in ['sve', 'sme', 'rvv']:
         width = 1000  # we don't know the actual value, need something guaranteed larger than vector
     else:
         width = get_vector_instruction_set(dtype, instruction_set)['width']
@@ -59,7 +59,7 @@ def test_vec_all(instruction_set, dtype):
                            cpu_vectorize_info={'instruction_set': instruction_set})
     kernel = ast.compile()
     kernel(data=data_arr)
-    if instruction_set in ['sve', 'rvv']:
+    if instruction_set in ['sve', 'sme', 'rvv']:
         # we only know that some values in the middle have been replaced
         assert np.all(data_arr[3:9, :2] <= 1.0)
         assert np.any(data_arr[3:9, 2:] == 2.0)
diff --git a/tests/test_random.py b/tests/test_random.py
index e82bff309..21933e893 100644
--- a/tests/test_random.py
+++ b/tests/test_random.py
@@ -32,7 +32,7 @@ if get_compiler_config()['os'] == 'windows':
 def test_rng(target, rng, precision, dtype, t=124, offsets=(0, 0), keys=(0, 0), offset_values=None):
     if target == Target.GPU:
         pytest.importorskip('cupy')
-    if instruction_sets and {'neon', 'sve', 'vsx', 'rvv'}.intersection(instruction_sets) and rng == 'aesni':
+    if instruction_sets and {'neon', 'sve', 'sme', 'vsx', 'rvv'}.intersection(instruction_sets) and rng == 'aesni':
         pytest.xfail('AES not yet implemented for this architecture')
     if rng == 'aesni' and len(keys) == 2:
         keys *= 2
@@ -122,7 +122,7 @@ def test_rng_offsets(kind, vectorized):
 @pytest.mark.parametrize('rng', ('philox', 'aesni'))
 @pytest.mark.parametrize('precision,dtype', (('float', 'float'), ('double', 'double')))
 def test_rng_vectorized(target, rng, precision, dtype, t=130, offsets=(1, 3), keys=(0, 0), offset_values=None):
-    if (target in ['neon', 'vsx', 'rvv'] or target.startswith('sve')) and rng == 'aesni':
+    if (target in ['neon', 'vsx', 'rvv', 'sme'] or target.startswith('sve')) and rng == 'aesni':
         pytest.xfail('AES not yet implemented for this architecture')
     cpu_vectorize_info = {'assume_inner_stride_one': True, 'assume_aligned': True, 'instruction_set': target}
 
diff --git a/tests/test_vectorization.py b/tests/test_vectorization.py
index d3066e1df..d2350526e 100644
--- a/tests/test_vectorization.py
+++ b/tests/test_vectorization.py
@@ -146,7 +146,7 @@ def test_aligned_and_nt_stores(openmp, instruction_set=instruction_set):
     if instruction_set in ['sse'] or instruction_set.startswith('avx'):
         assert 'stream' in ast.instruction_set
         assert 'streamFence' in ast.instruction_set
-    if instruction_set in ['neon', 'vsx', 'rvv'] or instruction_set.startswith('sve'):
+    if instruction_set in ['neon', 'sme', 'vsx', 'rvv'] or instruction_set.startswith('sve'):
         assert 'cachelineZero' in ast.instruction_set
     if instruction_set in ['vsx']:
         assert 'storeAAndFlushCacheline' in ast.instruction_set
@@ -331,7 +331,7 @@ def test_logical_operators(instruction_set=instruction_set):
 
 
 def test_hardware_query():
-    assert {'sse', 'neon', 'sve', 'vsx', 'rvv'}.intersection(supported_instruction_sets)
+    assert {'sse', 'neon', 'sve', 'sme', 'vsx', 'rvv'}.intersection(supported_instruction_sets)
 
 
 def test_vectorised_pow(instruction_set=instruction_set):
-- 
GitLab