From 84d19478a7d9411329853cbfcde3e7de3a2aa592 Mon Sep 17 00:00:00 2001 From: Michael Kuron <m.kuron@gmx.de> Date: Thu, 30 Jan 2025 11:12:25 +0100 Subject: [PATCH] ARM SME detection for macOS --- .../backends/simd_instruction_sets.py | 13 ++++- src/pystencils/cpu/cpujit.py | 7 ++- src/pystencils/include/philox_rand.h | 56 ++++++++++--------- 3 files changed, 46 insertions(+), 30 deletions(-) diff --git a/src/pystencils/backends/simd_instruction_sets.py b/src/pystencils/backends/simd_instruction_sets.py index ac6a626c3..e0a59ea7a 100644 --- a/src/pystencils/backends/simd_instruction_sets.py +++ b/src/pystencils/backends/simd_instruction_sets.py @@ -1,6 +1,6 @@ import os import platform -from ctypes import CDLL +from ctypes import CDLL, c_int, c_size_t, sizeof, byref from warnings import warn import numpy as np @@ -38,7 +38,14 @@ def get_supported_instruction_sets(): if 'PYSTENCILS_SIMD' in os.environ: return os.environ['PYSTENCILS_SIMD'].split(',') if platform.system() == 'Darwin' and platform.machine() == 'arm64': - return ['neon'] + result = ['neon'] + libc = CDLL('/usr/lib/libc.dylib') + value = c_int(0) + size = c_size_t(sizeof(value)) + status = libc.sysctlbyname(b"hw.optional.arm.FEAT_SME", byref(value), byref(size), None, 0) + if status == 0 and value.value == 1: + result.insert(0, "sme") + return result elif platform.system() == 'Windows' and platform.machine() == 'ARM64': return ['neon'] elif platform.system() == 'Linux' and platform.machine() == 'aarch64': @@ -59,7 +66,7 @@ def get_supported_instruction_sets(): length //= 2 result.append(name) if hwcap2 & (1 << 23): # HWCAP2_SME - result.append("sme") + result.insert(0, "sme") # prepend to list so it is not automatically chosen as best instruction set return result elif platform.system() == 'Linux' and platform.machine().startswith('riscv'): libc = CDLL('libc.so.6') diff --git a/src/pystencils/cpu/cpujit.py b/src/pystencils/cpu/cpujit.py index d9a320e76..0364ca689 100644 --- a/src/pystencils/cpu/cpujit.py +++ b/src/pystencils/cpu/cpujit.py @@ -63,6 +63,7 @@ import numpy as np from pystencils import FieldType from pystencils.astnodes import LoopOverCoordinate from pystencils.backends.cbackend import generate_c, get_headers +from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets from pystencils.cpu.msvc_detection import get_environment from pystencils.include import get_pystencils_include_path from pystencils.kernel_wrapper import KernelWrapper @@ -172,7 +173,11 @@ def read_config(): ('restrict_qualifier', '__restrict__') ]) if platform.machine() == 'arm64': - default_compiler_config['flags'] = default_compiler_config['flags'].replace('-march=native ', '') + if 'sme' in get_supported_instruction_sets(): + flag = '-march=armv8.7-a+sme ' + else: + flag = '' + default_compiler_config['flags'] = default_compiler_config['flags'].replace('-march=native ', flag) for libomp in ['/opt/local/lib/libomp/libomp.dylib', '/usr/local/lib/libomp.dylib', '/opt/homebrew/lib/libomp.dylib']: if os.path.exists(libomp): diff --git a/src/pystencils/include/philox_rand.h b/src/pystencils/include/philox_rand.h index cfbf54b04..186fa86d7 100644 --- a/src/pystencils/include/philox_rand.h +++ b/src/pystencils/include/philox_rand.h @@ -74,13 +74,17 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #define QUALIFIERS static __forceinline__ __device__ #elif defined(__OPENCL_VERSION__) #define QUALIFIERS static inline -#elif defined(__ARM_FEATURE_SME) -#define QUALIFIERS __attribute__((arm_streaming_compatible)) #else #define QUALIFIERS inline #include "myintrin.h" #endif +#if defined(__ARM_FEATURE_SME) +#define SVE_QUALIFIERS __attribute__((arm_streaming_compatible)) QUALIFIERS +#else +#define SVE_QUALIFIERS QUALIFIERS +#endif + #define PHILOX_W32_0 (0x9E3779B9) #define PHILOX_W32_1 (0xBB67AE85) #define PHILOX_M4x32_0 (0xD2511F53) @@ -749,7 +753,7 @@ QUALIFIERS void philox_double2(uint32 ctr0, int32x4_t ctr1, uint32 ctr2, uint32 #if defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_SME) -QUALIFIERS void _philox4x32round(svuint32x4_t & ctr, svuint32x2_t & key) +SVE_QUALIFIERS void _philox4x32round(svuint32x4_t & ctr, svuint32x2_t & key) { svuint32_t lo0 = svmul_u32_x(svptrue_b32(), svget4_u32(ctr, 0), svdup_u32(PHILOX_M4x32_0)); svuint32_t hi0 = svmulh_u32_x(svptrue_b32(), svget4_u32(ctr, 0), svdup_u32(PHILOX_M4x32_0)); @@ -762,14 +766,14 @@ QUALIFIERS void _philox4x32round(svuint32x4_t & ctr, svuint32x2_t & key) ctr = svset4_u32(ctr, 3, lo0); } -QUALIFIERS void _philox4x32bumpkey(svuint32x2_t & key) +SVE_QUALIFIERS void _philox4x32bumpkey(svuint32x2_t & key) { key = svset2_u32(key, 0, svadd_u32_x(svptrue_b32(), svget2_u32(key, 0), svdup_u32(PHILOX_W32_0))); key = svset2_u32(key, 1, svadd_u32_x(svptrue_b32(), svget2_u32(key, 1), svdup_u32(PHILOX_W32_1))); } template<bool high> -QUALIFIERS svfloat64_t _uniform_double_hq(svuint32_t x, svuint32_t y) +SVE_QUALIFIERS svfloat64_t _uniform_double_hq(svuint32_t x, svuint32_t y) { // convert 32 to 64 bit if (high) @@ -796,9 +800,9 @@ QUALIFIERS svfloat64_t _uniform_double_hq(svuint32_t x, svuint32_t y) } -QUALIFIERS void philox_float4(svuint32_t ctr0, svuint32_t ctr1, svuint32_t ctr2, svuint32_t ctr3, - uint32 key0, uint32 key1, - svfloat32_st & rnd1, svfloat32_st & rnd2, svfloat32_st & rnd3, svfloat32_st & rnd4) +SVE_QUALIFIERS void philox_float4(svuint32_t ctr0, svuint32_t ctr1, svuint32_t ctr2, svuint32_t ctr3, + uint32 key0, uint32 key1, + svfloat32_st & rnd1, svfloat32_st & rnd2, svfloat32_st & rnd3, svfloat32_st & rnd4) { svuint32x2_t key = svcreate2_u32(svdup_u32(key0), svdup_u32(key1)); svuint32x4_t ctr = svcreate4_u32(ctr0, ctr1, ctr2, ctr3); @@ -826,9 +830,9 @@ QUALIFIERS void philox_float4(svuint32_t ctr0, svuint32_t ctr1, svuint32_t ctr2, } -QUALIFIERS void philox_double2(svuint32_t ctr0, svuint32_t ctr1, svuint32_t ctr2, svuint32_t ctr3, - uint32 key0, uint32 key1, - svfloat64_st & rnd1lo, svfloat64_st & rnd1hi, svfloat64_st & rnd2lo, svfloat64_st & rnd2hi) +SVE_QUALIFIERS void philox_double2(svuint32_t ctr0, svuint32_t ctr1, svuint32_t ctr2, svuint32_t ctr3, + uint32 key0, uint32 key1, + svfloat64_st & rnd1lo, svfloat64_st & rnd1hi, svfloat64_st & rnd2lo, svfloat64_st & rnd2hi) { svuint32x2_t key = svcreate2_u32(svdup_u32(key0), svdup_u32(key1)); svuint32x4_t ctr = svcreate4_u32(ctr0, ctr1, ctr2, ctr3); @@ -849,9 +853,9 @@ QUALIFIERS void philox_double2(svuint32_t ctr0, svuint32_t ctr1, svuint32_t ctr2 rnd2hi = _uniform_double_hq<true>(svget4_u32(ctr, 2), svget4_u32(ctr, 3)); } -QUALIFIERS void philox_float4(uint32 ctr0, svuint32_t ctr1, uint32 ctr2, uint32 ctr3, - uint32 key0, uint32 key1, - svfloat32_st & rnd1, svfloat32_st & rnd2, svfloat32_st & rnd3, svfloat32_st & rnd4) +SVE_QUALIFIERS void philox_float4(uint32 ctr0, svuint32_t ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, + svfloat32_st & rnd1, svfloat32_st & rnd2, svfloat32_st & rnd3, svfloat32_st & rnd4) { svuint32_t ctr0v = svdup_u32(ctr0); svuint32_t ctr2v = svdup_u32(ctr2); @@ -860,16 +864,16 @@ QUALIFIERS void philox_float4(uint32 ctr0, svuint32_t ctr1, uint32 ctr2, uint32 philox_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4); } -QUALIFIERS void philox_float4(uint32 ctr0, svint32_t ctr1, uint32 ctr2, uint32 ctr3, - uint32 key0, uint32 key1, - svfloat32_st & rnd1, svfloat32_st & rnd2, svfloat32_st & rnd3, svfloat32_st & rnd4) +SVE_QUALIFIERS void philox_float4(uint32 ctr0, svint32_t ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, + svfloat32_st & rnd1, svfloat32_st & rnd2, svfloat32_st & rnd3, svfloat32_st & rnd4) { philox_float4(ctr0, svreinterpret_u32_s32(ctr1), ctr2, ctr3, key0, key1, rnd1, rnd2, rnd3, rnd4); } -QUALIFIERS void philox_double2(uint32 ctr0, svuint32_t ctr1, uint32 ctr2, uint32 ctr3, - uint32 key0, uint32 key1, - svfloat64_st & rnd1lo, svfloat64_st & rnd1hi, svfloat64_st & rnd2lo, svfloat64_st & rnd2hi) +SVE_QUALIFIERS void philox_double2(uint32 ctr0, svuint32_t ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, + svfloat64_st & rnd1lo, svfloat64_st & rnd1hi, svfloat64_st & rnd2lo, svfloat64_st & rnd2hi) { svuint32_t ctr0v = svdup_u32(ctr0); svuint32_t ctr2v = svdup_u32(ctr2); @@ -878,9 +882,9 @@ QUALIFIERS void philox_double2(uint32 ctr0, svuint32_t ctr1, uint32 ctr2, uint32 philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi); } -QUALIFIERS void philox_double2(uint32 ctr0, svuint32_t ctr1, uint32 ctr2, uint32 ctr3, - uint32 key0, uint32 key1, - svfloat64_st & rnd1, svfloat64_st & rnd2) +SVE_QUALIFIERS void philox_double2(uint32 ctr0, svuint32_t ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, + svfloat64_st & rnd1, svfloat64_st & rnd2) { svuint32_t ctr0v = svdup_u32(ctr0); svuint32_t ctr2v = svdup_u32(ctr2); @@ -890,9 +894,9 @@ QUALIFIERS void philox_double2(uint32 ctr0, svuint32_t ctr1, uint32 ctr2, uint32 philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, ignore, rnd2, ignore); } -QUALIFIERS void philox_double2(uint32 ctr0, svint32_t ctr1, uint32 ctr2, uint32 ctr3, - uint32 key0, uint32 key1, - svfloat64_st & rnd1, svfloat64_st & rnd2) +SVE_QUALIFIERS void philox_double2(uint32 ctr0, svint32_t ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, + svfloat64_st & rnd1, svfloat64_st & rnd2) { philox_double2(ctr0, svreinterpret_u32_s32(ctr1), ctr2, ctr3, key0, key1, rnd1, rnd2); } -- GitLab