Skip to content
Snippets Groups Projects
Commit 84d19478, authored by Michael Kuron, committed by Frederik Hennig
Browse files

ARM SME detection for macOS

parent 8081e9b2
No related branches found
No related tags found
1 merge request: !441 "ARM SME detection for macOS"
import os import os
import platform import platform
from ctypes import CDLL from ctypes import CDLL, c_int, c_size_t, sizeof, byref
from warnings import warn from warnings import warn
import numpy as np import numpy as np
...@@ -38,7 +38,14 @@ def get_supported_instruction_sets(): ...@@ -38,7 +38,14 @@ def get_supported_instruction_sets():
if 'PYSTENCILS_SIMD' in os.environ: if 'PYSTENCILS_SIMD' in os.environ:
return os.environ['PYSTENCILS_SIMD'].split(',') return os.environ['PYSTENCILS_SIMD'].split(',')
if platform.system() == 'Darwin' and platform.machine() == 'arm64': if platform.system() == 'Darwin' and platform.machine() == 'arm64':
return ['neon'] result = ['neon']
libc = CDLL('/usr/lib/libc.dylib')
value = c_int(0)
size = c_size_t(sizeof(value))
status = libc.sysctlbyname(b"hw.optional.arm.FEAT_SME", byref(value), byref(size), None, 0)
if status == 0 and value.value == 1:
result.insert(0, "sme")
return result
elif platform.system() == 'Windows' and platform.machine() == 'ARM64': elif platform.system() == 'Windows' and platform.machine() == 'ARM64':
return ['neon'] return ['neon']
elif platform.system() == 'Linux' and platform.machine() == 'aarch64': elif platform.system() == 'Linux' and platform.machine() == 'aarch64':
...@@ -59,7 +66,7 @@ def get_supported_instruction_sets(): ...@@ -59,7 +66,7 @@ def get_supported_instruction_sets():
length //= 2 length //= 2
result.append(name) result.append(name)
if hwcap2 & (1 << 23): # HWCAP2_SME if hwcap2 & (1 << 23): # HWCAP2_SME
result.append("sme") result.insert(0, "sme") # prepend to list so it is not automatically chosen as best instruction set
return result return result
elif platform.system() == 'Linux' and platform.machine().startswith('riscv'): elif platform.system() == 'Linux' and platform.machine().startswith('riscv'):
libc = CDLL('libc.so.6') libc = CDLL('libc.so.6')
......
...@@ -63,6 +63,7 @@ import numpy as np ...@@ -63,6 +63,7 @@ import numpy as np
from pystencils import FieldType from pystencils import FieldType
from pystencils.astnodes import LoopOverCoordinate from pystencils.astnodes import LoopOverCoordinate
from pystencils.backends.cbackend import generate_c, get_headers from pystencils.backends.cbackend import generate_c, get_headers
from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets
from pystencils.cpu.msvc_detection import get_environment from pystencils.cpu.msvc_detection import get_environment
from pystencils.include import get_pystencils_include_path from pystencils.include import get_pystencils_include_path
from pystencils.kernel_wrapper import KernelWrapper from pystencils.kernel_wrapper import KernelWrapper
...@@ -172,7 +173,11 @@ def read_config(): ...@@ -172,7 +173,11 @@ def read_config():
('restrict_qualifier', '__restrict__') ('restrict_qualifier', '__restrict__')
]) ])
if platform.machine() == 'arm64': if platform.machine() == 'arm64':
default_compiler_config['flags'] = default_compiler_config['flags'].replace('-march=native ', '') if 'sme' in get_supported_instruction_sets():
flag = '-march=armv8.7-a+sme '
else:
flag = ''
default_compiler_config['flags'] = default_compiler_config['flags'].replace('-march=native ', flag)
for libomp in ['/opt/local/lib/libomp/libomp.dylib', '/usr/local/lib/libomp.dylib', for libomp in ['/opt/local/lib/libomp/libomp.dylib', '/usr/local/lib/libomp.dylib',
'/opt/homebrew/lib/libomp.dylib']: '/opt/homebrew/lib/libomp.dylib']:
if os.path.exists(libomp): if os.path.exists(libomp):
......
...@@ -74,13 +74,17 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ...@@ -74,13 +74,17 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define QUALIFIERS static __forceinline__ __device__ #define QUALIFIERS static __forceinline__ __device__
#elif defined(__OPENCL_VERSION__) #elif defined(__OPENCL_VERSION__)
#define QUALIFIERS static inline #define QUALIFIERS static inline
#elif defined(__ARM_FEATURE_SME)
#define QUALIFIERS __attribute__((arm_streaming_compatible))
#else #else
#define QUALIFIERS inline #define QUALIFIERS inline
#include "myintrin.h" #include "myintrin.h"
#endif #endif
#if defined(__ARM_FEATURE_SME)
#define SVE_QUALIFIERS __attribute__((arm_streaming_compatible)) QUALIFIERS
#else
#define SVE_QUALIFIERS QUALIFIERS
#endif
#define PHILOX_W32_0 (0x9E3779B9) #define PHILOX_W32_0 (0x9E3779B9)
#define PHILOX_W32_1 (0xBB67AE85) #define PHILOX_W32_1 (0xBB67AE85)
#define PHILOX_M4x32_0 (0xD2511F53) #define PHILOX_M4x32_0 (0xD2511F53)
...@@ -749,7 +753,7 @@ QUALIFIERS void philox_double2(uint32 ctr0, int32x4_t ctr1, uint32 ctr2, uint32 ...@@ -749,7 +753,7 @@ QUALIFIERS void philox_double2(uint32 ctr0, int32x4_t ctr1, uint32 ctr2, uint32
#if defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_SME) #if defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_SME)
QUALIFIERS void _philox4x32round(svuint32x4_t & ctr, svuint32x2_t & key) SVE_QUALIFIERS void _philox4x32round(svuint32x4_t & ctr, svuint32x2_t & key)
{ {
svuint32_t lo0 = svmul_u32_x(svptrue_b32(), svget4_u32(ctr, 0), svdup_u32(PHILOX_M4x32_0)); svuint32_t lo0 = svmul_u32_x(svptrue_b32(), svget4_u32(ctr, 0), svdup_u32(PHILOX_M4x32_0));
svuint32_t hi0 = svmulh_u32_x(svptrue_b32(), svget4_u32(ctr, 0), svdup_u32(PHILOX_M4x32_0)); svuint32_t hi0 = svmulh_u32_x(svptrue_b32(), svget4_u32(ctr, 0), svdup_u32(PHILOX_M4x32_0));
...@@ -762,14 +766,14 @@ QUALIFIERS void _philox4x32round(svuint32x4_t & ctr, svuint32x2_t & key) ...@@ -762,14 +766,14 @@ QUALIFIERS void _philox4x32round(svuint32x4_t & ctr, svuint32x2_t & key)
ctr = svset4_u32(ctr, 3, lo0); ctr = svset4_u32(ctr, 3, lo0);
} }
QUALIFIERS void _philox4x32bumpkey(svuint32x2_t & key) SVE_QUALIFIERS void _philox4x32bumpkey(svuint32x2_t & key)
{ {
key = svset2_u32(key, 0, svadd_u32_x(svptrue_b32(), svget2_u32(key, 0), svdup_u32(PHILOX_W32_0))); key = svset2_u32(key, 0, svadd_u32_x(svptrue_b32(), svget2_u32(key, 0), svdup_u32(PHILOX_W32_0)));
key = svset2_u32(key, 1, svadd_u32_x(svptrue_b32(), svget2_u32(key, 1), svdup_u32(PHILOX_W32_1))); key = svset2_u32(key, 1, svadd_u32_x(svptrue_b32(), svget2_u32(key, 1), svdup_u32(PHILOX_W32_1)));
} }
template<bool high> template<bool high>
QUALIFIERS svfloat64_t _uniform_double_hq(svuint32_t x, svuint32_t y) SVE_QUALIFIERS svfloat64_t _uniform_double_hq(svuint32_t x, svuint32_t y)
{ {
// convert 32 to 64 bit // convert 32 to 64 bit
if (high) if (high)
...@@ -796,9 +800,9 @@ QUALIFIERS svfloat64_t _uniform_double_hq(svuint32_t x, svuint32_t y) ...@@ -796,9 +800,9 @@ QUALIFIERS svfloat64_t _uniform_double_hq(svuint32_t x, svuint32_t y)
} }
QUALIFIERS void philox_float4(svuint32_t ctr0, svuint32_t ctr1, svuint32_t ctr2, svuint32_t ctr3, SVE_QUALIFIERS void philox_float4(svuint32_t ctr0, svuint32_t ctr1, svuint32_t ctr2, svuint32_t ctr3,
uint32 key0, uint32 key1, uint32 key0, uint32 key1,
svfloat32_st & rnd1, svfloat32_st & rnd2, svfloat32_st & rnd3, svfloat32_st & rnd4) svfloat32_st & rnd1, svfloat32_st & rnd2, svfloat32_st & rnd3, svfloat32_st & rnd4)
{ {
svuint32x2_t key = svcreate2_u32(svdup_u32(key0), svdup_u32(key1)); svuint32x2_t key = svcreate2_u32(svdup_u32(key0), svdup_u32(key1));
svuint32x4_t ctr = svcreate4_u32(ctr0, ctr1, ctr2, ctr3); svuint32x4_t ctr = svcreate4_u32(ctr0, ctr1, ctr2, ctr3);
...@@ -826,9 +830,9 @@ QUALIFIERS void philox_float4(svuint32_t ctr0, svuint32_t ctr1, svuint32_t ctr2, ...@@ -826,9 +830,9 @@ QUALIFIERS void philox_float4(svuint32_t ctr0, svuint32_t ctr1, svuint32_t ctr2,
} }
QUALIFIERS void philox_double2(svuint32_t ctr0, svuint32_t ctr1, svuint32_t ctr2, svuint32_t ctr3, SVE_QUALIFIERS void philox_double2(svuint32_t ctr0, svuint32_t ctr1, svuint32_t ctr2, svuint32_t ctr3,
uint32 key0, uint32 key1, uint32 key0, uint32 key1,
svfloat64_st & rnd1lo, svfloat64_st & rnd1hi, svfloat64_st & rnd2lo, svfloat64_st & rnd2hi) svfloat64_st & rnd1lo, svfloat64_st & rnd1hi, svfloat64_st & rnd2lo, svfloat64_st & rnd2hi)
{ {
svuint32x2_t key = svcreate2_u32(svdup_u32(key0), svdup_u32(key1)); svuint32x2_t key = svcreate2_u32(svdup_u32(key0), svdup_u32(key1));
svuint32x4_t ctr = svcreate4_u32(ctr0, ctr1, ctr2, ctr3); svuint32x4_t ctr = svcreate4_u32(ctr0, ctr1, ctr2, ctr3);
...@@ -849,9 +853,9 @@ QUALIFIERS void philox_double2(svuint32_t ctr0, svuint32_t ctr1, svuint32_t ctr2 ...@@ -849,9 +853,9 @@ QUALIFIERS void philox_double2(svuint32_t ctr0, svuint32_t ctr1, svuint32_t ctr2
rnd2hi = _uniform_double_hq<true>(svget4_u32(ctr, 2), svget4_u32(ctr, 3)); rnd2hi = _uniform_double_hq<true>(svget4_u32(ctr, 2), svget4_u32(ctr, 3));
} }
QUALIFIERS void philox_float4(uint32 ctr0, svuint32_t ctr1, uint32 ctr2, uint32 ctr3, SVE_QUALIFIERS void philox_float4(uint32 ctr0, svuint32_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key0, uint32 key1,
svfloat32_st & rnd1, svfloat32_st & rnd2, svfloat32_st & rnd3, svfloat32_st & rnd4) svfloat32_st & rnd1, svfloat32_st & rnd2, svfloat32_st & rnd3, svfloat32_st & rnd4)
{ {
svuint32_t ctr0v = svdup_u32(ctr0); svuint32_t ctr0v = svdup_u32(ctr0);
svuint32_t ctr2v = svdup_u32(ctr2); svuint32_t ctr2v = svdup_u32(ctr2);
...@@ -860,16 +864,16 @@ QUALIFIERS void philox_float4(uint32 ctr0, svuint32_t ctr1, uint32 ctr2, uint32 ...@@ -860,16 +864,16 @@ QUALIFIERS void philox_float4(uint32 ctr0, svuint32_t ctr1, uint32 ctr2, uint32
philox_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4); philox_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4);
} }
QUALIFIERS void philox_float4(uint32 ctr0, svint32_t ctr1, uint32 ctr2, uint32 ctr3, SVE_QUALIFIERS void philox_float4(uint32 ctr0, svint32_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key0, uint32 key1,
svfloat32_st & rnd1, svfloat32_st & rnd2, svfloat32_st & rnd3, svfloat32_st & rnd4) svfloat32_st & rnd1, svfloat32_st & rnd2, svfloat32_st & rnd3, svfloat32_st & rnd4)
{ {
philox_float4(ctr0, svreinterpret_u32_s32(ctr1), ctr2, ctr3, key0, key1, rnd1, rnd2, rnd3, rnd4); philox_float4(ctr0, svreinterpret_u32_s32(ctr1), ctr2, ctr3, key0, key1, rnd1, rnd2, rnd3, rnd4);
} }
QUALIFIERS void philox_double2(uint32 ctr0, svuint32_t ctr1, uint32 ctr2, uint32 ctr3, SVE_QUALIFIERS void philox_double2(uint32 ctr0, svuint32_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key0, uint32 key1,
svfloat64_st & rnd1lo, svfloat64_st & rnd1hi, svfloat64_st & rnd2lo, svfloat64_st & rnd2hi) svfloat64_st & rnd1lo, svfloat64_st & rnd1hi, svfloat64_st & rnd2lo, svfloat64_st & rnd2hi)
{ {
svuint32_t ctr0v = svdup_u32(ctr0); svuint32_t ctr0v = svdup_u32(ctr0);
svuint32_t ctr2v = svdup_u32(ctr2); svuint32_t ctr2v = svdup_u32(ctr2);
...@@ -878,9 +882,9 @@ QUALIFIERS void philox_double2(uint32 ctr0, svuint32_t ctr1, uint32 ctr2, uint32 ...@@ -878,9 +882,9 @@ QUALIFIERS void philox_double2(uint32 ctr0, svuint32_t ctr1, uint32 ctr2, uint32
philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi); philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
} }
QUALIFIERS void philox_double2(uint32 ctr0, svuint32_t ctr1, uint32 ctr2, uint32 ctr3, SVE_QUALIFIERS void philox_double2(uint32 ctr0, svuint32_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key0, uint32 key1,
svfloat64_st & rnd1, svfloat64_st & rnd2) svfloat64_st & rnd1, svfloat64_st & rnd2)
{ {
svuint32_t ctr0v = svdup_u32(ctr0); svuint32_t ctr0v = svdup_u32(ctr0);
svuint32_t ctr2v = svdup_u32(ctr2); svuint32_t ctr2v = svdup_u32(ctr2);
...@@ -890,9 +894,9 @@ QUALIFIERS void philox_double2(uint32 ctr0, svuint32_t ctr1, uint32 ctr2, uint32 ...@@ -890,9 +894,9 @@ QUALIFIERS void philox_double2(uint32 ctr0, svuint32_t ctr1, uint32 ctr2, uint32
philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, ignore, rnd2, ignore); philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, ignore, rnd2, ignore);
} }
QUALIFIERS void philox_double2(uint32 ctr0, svint32_t ctr1, uint32 ctr2, uint32 ctr3, SVE_QUALIFIERS void philox_double2(uint32 ctr0, svint32_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key0, uint32 key1,
svfloat64_st & rnd1, svfloat64_st & rnd2) svfloat64_st & rnd1, svfloat64_st & rnd2)
{ {
philox_double2(ctr0, svreinterpret_u32_s32(ctr1), ctr2, ctr3, key0, key1, rnd1, rnd2); philox_double2(ctr0, svreinterpret_u32_s32(ctr1), ctr2, ctr3, key0, key1, rnd1, rnd2);
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment