From 6562a56c6b6c0829f760f2e6ba71b197d81f53df Mon Sep 17 00:00:00 2001
From: Michael Kuron <mkuron@icp.uni-stuttgart.de>
Date: Mon, 19 Apr 2021 17:44:07 +0000
Subject: [PATCH] Support shorter SVE vectors via predicates

---
 pystencils/backends/arm_instruction_sets.py  | 23 ++++++++++----------
 pystencils/backends/simd_instruction_sets.py | 18 ++++++++++-----
 pystencils/include/arm_neon_helpers.h        | 13 +++++++++++
 3 files changed, 38 insertions(+), 16 deletions(-)

diff --git a/pystencils/backends/arm_instruction_sets.py b/pystencils/backends/arm_instruction_sets.py
index 369e74ec4..9f7b4ee22 100644
--- a/pystencils/backends/arm_instruction_sets.py
+++ b/pystencils/backends/arm_instruction_sets.py
@@ -65,12 +65,14 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
     result = dict()
     result['bytes'] = bitwidth // 8
 
+    predicate = f'{prefix}whilelt_b{bits[data_type]}(0, {width})'
+    int_predicate = f'{prefix}whilelt_b{bits["int"]}(0, {intwidth})'
+
     for intrinsic_id, function_shortcut in base_names.items():
         function_shortcut = function_shortcut.strip()
         name = function_shortcut[:function_shortcut.index('[')]
 
-        arg_string = get_argument_string(function_shortcut, first=f'{prefix}ptrue_b{bits[data_type]}()'
-                                         if prefix == 'sv' else '')
+        arg_string = get_argument_string(function_shortcut, first=predicate if prefix == 'sv' else '')
         if prefix == 'sv' and not name.startswith('ld') and not name.startswith('st') and not name.startswith(cmp):
             undef = '_x'
         else:
@@ -86,20 +88,19 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
         result['makeVecConstInt'] = f'svdup_s{bits["int"]}' + '({0})'
         result['makeVecIndex'] = f'svindex_s{bits["int"]}' + '({0}, {1})'
 
-        result['+int'] = f"svadd_s{bits['int']}_x(svptrue_b{bits['int']}(), " + "{0}, {1})"
+        result['+int'] = f"svadd_s{bits['int']}_x({int_predicate}, " + "{0}, {1})"
 
-        attr = f' __attribute__((arm_sve_vector_bits({bitwidth})))'
-        result[data_type] = f'svfloat{bits[data_type]}_t{attr}'
-        result['int'] = f'svint{bits["int"]}_t{attr}'
-        result['bool'] = f'svbool_t{attr}'
+        result[data_type] = f'svfloat{bits[data_type]}_st'
+        result['int'] = f'svint{bits["int"]}_st'
+        result['bool'] = 'svbool_st'
 
         result['headers'] = ['<arm_sve.h>', '"arm_neon_helpers.h"']
 
-        result['&'] = f'svand_b_z(svptrue_b{bits[data_type]}(),' + ' {0}, {1})'
-        result['|'] = f'svorr_b_z(svptrue_b{bits[data_type]}(),' + ' {0}, {1})'
+        result['&'] = f'svand_b_z({predicate},' + ' {0}, {1})'
+        result['|'] = f'svorr_b_z({predicate},' + ' {0}, {1})'
         result['blendv'] = f'svsel_f{bits[data_type]}' + '({2}, {1}, {0})'
-        result['any'] = f'svptest_any(svptrue_b{bits[data_type]}(), {{0}}) > 0'
-        result['all'] = f'svcntp_b{bits[data_type]}(svptrue_b{bits[data_type]}(), {{0}}) == {width}'
+        result['any'] = f'svptest_any({predicate}, {{0}})'
+        result['all'] = f'svcntp_b{bits[data_type]}({predicate}, {{0}}) == {width}'
 
         result['compile_flags'] = [f'-msve-vector-bits={bitwidth}']
     else:
diff --git a/pystencils/backends/simd_instruction_sets.py b/pystencils/backends/simd_instruction_sets.py
index 4428982cf..b552da0e9 100644
--- a/pystencils/backends/simd_instruction_sets.py
+++ b/pystencils/backends/simd_instruction_sets.py
@@ -1,5 +1,6 @@
-import os
+import math
 import platform
+from ctypes import CDLL
 
 from pystencils.backends.x86_instruction_sets import get_vector_instruction_set_x86
 from pystencils.backends.arm_instruction_sets import get_vector_instruction_set_arm
@@ -59,10 +60,17 @@ def get_supported_instruction_sets():
     if flags.issuperset(required_neon_flags):
         result.append("neon")
     if flags.issuperset(required_sve_flags):
-        length_file = '/proc/sys/abi/sve_default_vector_length'
-        if os.path.exists(length_file):
-            length = 8 * int(open(length_file, 'r').read())
-            result.append(f"sve{length}")
+        if platform.system() == 'Linux':
+            retval = CDLL('libc.so.6').prctl(51, 0, 0, 0, 0)  # PR_SVE_GET_VL
+            if retval < 0:
+                raise OSError("SVE length query failed")
+            native_length = 8 * (retval & 0xffff)  # PR_SVE_VL_LEN_MASK: drop the flag bits
+            pwr2_length = int(2**math.floor(math.log2(native_length)))
+            if pwr2_length % 256 == 0:
+                result.append(f"sve{pwr2_length//2}")
+            if native_length != pwr2_length:
+                result.append(f"sve{pwr2_length}")
+            result.append(f"sve{native_length}")
         else:
             result.append("sve")
     return result
diff --git a/pystencils/include/arm_neon_helpers.h b/pystencils/include/arm_neon_helpers.h
index 3d06d69bf..a900001e7 100644
--- a/pystencils/include/arm_neon_helpers.h
+++ b/pystencils/include/arm_neon_helpers.h
@@ -1,5 +1,17 @@
+#ifdef __ARM_NEON
 #include <arm_neon.h>
+#endif
 
+#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_SVE_BITS) && __ARM_FEATURE_SVE_BITS > 0
+#include <arm_sve.h>
+
+typedef svbool_t svbool_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS)));
+typedef svfloat32_t svfloat32_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS)));
+typedef svfloat64_t svfloat64_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS)));
+typedef svint32_t svint32_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS)));
+#endif
+
+#ifdef __ARM_NEON
 inline float32x4_t makeVec_f32(float a, float b, float c, float d)
 {
     alignas(16) float data[4] = {a, b, c, d};
@@ -17,6 +29,7 @@ inline int32x4_t makeVec_s32(int a, int b, int c, int d)
     alignas(16) int data[4] = {a, b, c, d};
     return vld1q_s32(data);
 }
+#endif
 
 inline void cachelineZero(void * p) {
 	__asm__ volatile("dc zva, %0"::"r"(p));
-- 
GitLab