From 466a3426baebd9f01c220219672eb52f7fdb9e8d Mon Sep 17 00:00:00 2001
From: Michael Kuron <mkuron@icp.uni-stuttgart.de>
Date: Sun, 21 Mar 2021 22:15:23 +0100
Subject: [PATCH] Fix AES-NI on Ice Lake processors

---
 pystencils/include/aesni_rand.h | 38 +++++++++++++++++++++++++++++----
 1 file changed, 34 insertions(+), 4 deletions(-)

diff --git a/pystencils/include/aesni_rand.h b/pystencils/include/aesni_rand.h
index e55f93272..00e1dffee 100644
--- a/pystencils/include/aesni_rand.h
+++ b/pystencils/include/aesni_rand.h
@@ -21,6 +21,36 @@
 typedef std::uint32_t uint32;
 typedef std::uint64_t uint64;
 
+/// Stateless allocator whose storage comes from _mm_malloc(..., Alignment) and is released by _mm_free.
+template <typename T, std::size_t Alignment>
+class AlignedAllocator
+{
+public:
+    typedef T value_type;
+    template <typename U> struct rebind { typedef AlignedAllocator<U, Alignment> other; };
+
+    AlignedAllocator() noexcept {}
+    /// Converting constructor the standard allocator requirements call for; there is no state to copy.
+    template <typename U> AlignedAllocator(const AlignedAllocator<U, Alignment> &) noexcept {}
+
+    /// Returns storage for n objects of T: nullptr if n == 0, throws std::bad_alloc if _mm_malloc returns null.
+    T * allocate(const std::size_t n) const {
+        if (n == 0) { return nullptr; }
+        void * const p = _mm_malloc(n*sizeof(T), Alignment);
+        if (p == nullptr) { throw std::bad_alloc(); }
+        return static_cast<T *>(p);
+    }
+    /// Hands p back to _mm_free; the element count is left unnamed because it is not needed.
+    void deallocate(T * const p, const std::size_t) const { _mm_free(p); }
+
+    /// Instances carry no state, so any two of them (for any value type) compare equal.
+    template <typename U> bool operator==(const AlignedAllocator<U, Alignment> &) const noexcept { return true; }
+    template <typename U> bool operator!=(const AlignedAllocator<U, Alignment> &) const noexcept { return false; }
+};
+
+template <typename Key, typename T>  // std::map whose allocator is asked for the alignment of the map's value_type
+using AlignedMap = std::map<Key, T, std::less<Key>, AlignedAllocator<std::pair<const Key, T>, alignof(std::pair<const Key, T>)>>;
+
 #if defined(__AES__) || defined(_MSC_VER)
 QUALIFIERS __m128i aesni_keygen_assist(__m128i temp1, __m128i temp2) {
     __m128i temp3; 
@@ -88,7 +118,7 @@ QUALIFIERS const std::array<__m128i,11> & aesni_roundkeys(const __m128i & k128)
     alignas(16) std::array<uint32,4> a;
     _mm_store_si128((__m128i*) a.data(), k128);
     
-    static std::map<std::array<uint32,4>, std::array<__m128i,11>> roundkeys;
+    static AlignedMap<std::array<uint32,4>, std::array<__m128i,11>> roundkeys;
     
     if(roundkeys.find(a) == roundkeys.end()) {
         auto rk = aesni_keygen(k128);
@@ -311,7 +341,7 @@ QUALIFIERS const std::array<__m256i,11> & aesni_roundkeys(const __m256i & k256)
     alignas(32) std::array<uint32,8> a;
     _mm256_store_si256((__m256i*) a.data(), k256);
     
-    static std::map<std::array<uint32,8>, std::array<__m256i,11>> roundkeys;
+    static AlignedMap<std::array<uint32,8>, std::array<__m256i,11>> roundkeys;
     
     if(roundkeys.find(a) == roundkeys.end()) {
         auto rk1 = aesni_keygen(_mm256_extractf128_si256(k256, 0));
@@ -526,7 +556,7 @@ QUALIFIERS const std::array<__m512i,11> & aesni_roundkeys(const __m512i & k512)
     alignas(64) std::array<uint32,16> a;
     _mm512_store_si512((__m512i*) a.data(), k512);
     
-    static std::map<std::array<uint32,16>, std::array<__m512i,11>> roundkeys;
+    static AlignedMap<std::array<uint32,16>, std::array<__m512i,11>> roundkeys;
     
     if(roundkeys.find(a) == roundkeys.end()) {
         auto rk1 = aesni_keygen(_mm512_extracti32x4_epi32(k512, 0));
@@ -553,7 +583,7 @@ QUALIFIERS __m512i aesni1xm128i(const __m512i & in, const __m512i & k0) {
     x = _mm512_aesenc_epi128(x, k[7]);
     x = _mm512_aesenc_epi128(x, k[8]);
     x = _mm512_aesenc_epi128(x, k[9]);
-    x = _mm512_aesenclast_epi128(x[10], k);
+    x = _mm512_aesenclast_epi128(x, k[10]);
 #else
     __m128i a = aesni1xm128i(_mm512_extracti32x4_epi32(in, 0), _mm512_extracti32x4_epi32(k0, 0));
     __m128i b = aesni1xm128i(_mm512_extracti32x4_epi32(in, 1), _mm512_extracti32x4_epi32(k0, 1));
-- 
GitLab