 
 #if defined(__ARM_NEON)
     #include <arm_neon.h>
+#elif defined(__AVX2__)
+    #include <immintrin.h>
+#endif
+
+#if defined(__AVX2__)
+    // Combine two __m128i halves into a __m256i (a: high lane, b: low lane)
+    #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
+
+    // Expand 32 packed 4-bit values into 32 bytes: the 16 low nibbles
+    // land in the lower 128-bit lane, the 16 high nibbles in the upper lane
+    static inline __m256i bytes_from_nibbles_32(const uint8_t* rsi) {
+        // Load 16 bytes from memory
+        __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi);
+        __m128i tmph = _mm_srli_epi16(tmpl, 4);
+        const __m128i lowMask = _mm_set1_epi8(0xF);
+        tmpl = _mm_and_si128(lowMask, tmpl);
+        tmph = _mm_and_si128(lowMask, tmph);
+        return MM256_SET_M128I(tmph, tmpl);
+    }
+
+    // Horizontally sum the 8 floats of a __m256
+    static inline float hsum_float_8(const __m256 x) {
+        __m128 res = _mm256_extractf128_ps(x, 1);
+        res = _mm_add_ps(res, _mm256_castps256_ps128(x));
+        res = _mm_add_ps(res, _mm_movehl_ps(res, res));
+        res = _mm_add_ss(res, _mm_movehdup_ps(res));
+        return _mm_cvtss_f32(res);
+    }
+
+    // Add int16_t values pairwise and return the sums as a float vector
+    static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
+        const __m128i ones = _mm_set1_epi16(1);
+        const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
+        const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
+        const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl);
+        return _mm256_cvtepi32_ps(summed_pairs);
+    }
+
+    // Multiply int8_t values, add the results pairwise twice and return as a float vector
+    static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+        const __m128i xl = _mm256_castsi256_si128(x);
+        const __m128i xh = _mm256_extractf128_si256(x, 1);
+        const __m128i yl = _mm256_castsi256_si128(y);
+        const __m128i yh = _mm256_extractf128_si256(y, 1);
+        // Get the absolute values of the x vectors
+        const __m128i axl = _mm_sign_epi8(xl, xl);
+        const __m128i axh = _mm_sign_epi8(xh, xh);
+        // Sign the values of the y vectors according to the signs of x
+        const __m128i syl = _mm_sign_epi8(yl, xl);
+        const __m128i syh = _mm_sign_epi8(yh, xh);
+        // Multiply unsigned |x| by signed y, adding adjacent pairs into int16
+        const __m128i dotl = _mm_maddubs_epi16(axl, syl);
+        const __m128i doth = _mm_maddubs_epi16(axh, syh);
+        return sum_i16_pairs_float(doth, dotl);
+    }
 #endif
 
 void softmax(float* x, const int size) {
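
For readers tracing the nibble unpacking above: a minimal scalar sketch of what bytes_from_nibbles_32 computes, assuming the Q4_0 byte layout implied by the diff (byte i carries element i in its low nibble and element i + 16 in its high nibble). The _ref helper name is hypothetical and not part of the change.

    #include <stdint.h>

    // Scalar reference (illustration only): expand 16 packed bytes
    // into 32 values, low nibbles first, then high nibbles.
    static void bytes_from_nibbles_32_ref(const uint8_t* qs, uint8_t out[32]) {
        for (int i = 0; i < 16; i++) {
            out[i]      = qs[i] & 0x0F; // low nibble of byte i  -> element i
            out[i + 16] = qs[i] >> 4;   // high nibble of byte i -> element i + 16
        }
    }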
|
@@ -253,6 +304,30 @@ void matmulQ40vQ80(MatmulThreadInfo* a) {
         }
         a->output[d] = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
     }
+#elif defined(__AVX2__)
+    for (int d = a->ds; d < a->de; d++) {
+        __m256 acc = _mm256_setzero_ps();
+
+        for (int j = 0; j < n; j++) {
+            // Compute the combined scale for the block
+            const __m256 cd = _mm256_set1_ps(convertF16ToF32(w[d * n + j].d) * convertF16ToF32(input[j].d));
+
+            __m256i bx = bytes_from_nibbles_32(w[d * n + j].qs);
+
+            // The bytes are now in the [0, 15] interval; offset them into [-8, 7]
+            const __m256i off = _mm256_set1_epi8(8);
+            bx = _mm256_sub_epi8(bx, off);
+
+            __m256i by = _mm256_loadu_si256((const __m256i *)input[j].qs);
+
+            const __m256 q = mul_sum_i8_pairs_float(bx, by);
+
+            // Multiply q by the scale and accumulate (FMA: needs -mfma in addition to AVX2)
+            acc = _mm256_fmadd_ps(cd, q, acc);
+        }
+
+        a->output[d] = hsum_float_8(acc);
+    }
 #else
     printf("matmulQ40vQ80 - not implemented\n");
     exit(EXIT_FAILURE);
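
As a correctness check for the new AVX2 path, a plain-C sketch of the same per-block dot product. The BlockQ40/BlockQ80 layouts below are hypothetical minimal stand-ins for the real definitions elsewhere in this codebase (an f16 scale d plus 32 quantized values qs), and convertF16ToF32 is the conversion helper already used in the diff. Comparing its output against matmulQ40vQ80 on random blocks is a quick way to validate the SIMD path.

    #include <stdint.h>

    // Hypothetical minimal block layouts; the real structs live elsewhere in this repo.
    typedef struct { uint16_t d; uint8_t qs[16]; } BlockQ40; // f16 scale + 32 packed nibbles
    typedef struct { uint16_t d; int8_t  qs[32]; } BlockQ80; // f16 scale + 32 int8 values

    float convertF16ToF32(uint16_t value); // provided by the codebase

    // Scalar reference for one output row: sum over n blocks of
    // scale * dot(int4 weights offset to [-8, 7], int8 inputs).
    static float dotQ40Q80_ref(const BlockQ40* w, const BlockQ80* x, const int n) {
        float sum = 0.0f;
        for (int j = 0; j < n; j++) {
            const float scale = convertF16ToF32(w[j].d) * convertF16ToF32(x[j].d);
            int blockSum = 0;
            for (int i = 0; i < 16; i++) {
                blockSum += ((w[j].qs[i] & 0x0F) - 8) * x[j].qs[i];      // low nibble -> element i
                blockSum += ((w[j].qs[i] >> 4)   - 8) * x[j].qs[i + 16]; // high nibble -> element i + 16
            }
            sum += scale * (float)blockSum;
        }
        return sum;
    }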