Commit f2137af

feat: avx2 (#2)
* avx2.
1 parent 7eb77ca commit f2137af

File tree: 4 files changed, +90 -8 lines changed

.gitignore (+1)

@@ -9,3 +9,4 @@ __pycache__
 quants-test
 transformer-tasks-test
 main
+run.sh

Makefile (+1 -1)

@@ -1,5 +1,5 @@
 CXX = g++
-CXXFLAGS = -std=c++11 -Werror -O3
+CXXFLAGS = -std=c++11 -Werror -O3 -march=native -mtune=native

 utils: src/utils.cpp
 	$(CXX) $(CXXFLAGS) -c src/utils.cpp -o utils.o
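The new `-march=native -mtune=native` flags let the compiler emit and tune for whatever instruction sets the build host supports, and the AVX2 code path added in src/funcs.cpp below is guarded by the compiler-defined `__AVX2__` macro. As a minimal sketch (not part of this commit), the following standalone program reports which SIMD path a given toolchain/host combination would compile under these flags:

```cpp
// Minimal sketch (not part of this commit): report which SIMD path the
// compiler will build when invoked with the CXXFLAGS above. __AVX2__ and
// __ARM_NEON are defined by GCC/Clang when the target enables them.
#include <cstdio>

int main() {
#if defined(__AVX2__)
    std::printf("AVX2 path enabled\n");
#elif defined(__ARM_NEON)
    std::printf("NEON path enabled\n");
#else
    std::printf("no SIMD path, generic fallback\n");
#endif
    return 0;
}
```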

README.md (+13 -7)

@@ -19,9 +19,15 @@ This project was initiated based on the [llama2.c](https://github.com/karpathy/l
 * This project is a proof of concept, it's not optimized for production usage.
 * You can run Distributed Llama only on 1, 2, 4... 2^n devices.
 * The project supports only the inference mode, the chat mode is not supported.
-* Optimized for:
-  * ✅ ARM CPUs
-  * ❌ x86_64 CPUs (Q40xF32 mode works but is slow)
+* Optimized for (weights format × buffer format):
+  * ARM CPUs
+    * ✅ F32 × F32
+    * ❌ F16 × F16
+    * ✅ Q40 × Q80
+  * x86_64 AVX2 CPUs
+    * ❌ F32 × F32
+    * ❌ F16 × F16
+    * ⚠️ Q40 × Q80 (partial optimization)
 
 **Supported models**
 * Llama 2 7B
@@ -134,7 +140,7 @@ sudo nice -n -20 ./main worker --port 9998
 ```
 10. Run root node on the root device:
 ```sh
-sudo nice -n -20 ./main inference --model ../dllama_llama-2-13b_q40.bin --tokenizer ../tokenizer.bin --weights-float-type q40 --buffer-float-type q80 --prompt "Hello world" --steps 16 --nthreads 4 --workers 10.0.0.2:9998
+sudo nice -n -20 ./main inference --model ../dllama_llama-2-7b_q40.bin --tokenizer ../tokenizer.bin --weights-float-type q40 --buffer-float-type q80 --prompt "Hello world" --steps 16 --nthreads 4 --workers 10.0.0.2:9998
 ```
 
 To add more worker nodes, just add more addresses to the `--workers` argument.
@@ -145,9 +151,9 @@ To add more worker nodes, just add more addresses to the `--workers` argument.
 
 [Share your results](https://github.com/b4rtaz/distributed-llama/discussions)!
 
-## 💻 How to Run on Debian x86_64
+## 💻 How to Run on MacOS or Linux
 
-x86_64 CPUs are not optimized yet but still you can observe a significant speedup when you run Distributed Llama on multiple devices.
+You need to have x86_64 AVX2 CPU or ARM CPU. Different devices may have different CPUs.
 
 1. Install Git and G++:
 ```sh
@@ -177,7 +183,7 @@ sudo nice -n -20 ./main worker --port 9998
 ```
 7. Run worker nodes on worker devices:
 ```sh
-sudo nice -n -20 ./main inference --model ../dllama_llama-2-13b_q40.bin --tokenizer ../tokenizer.bin --weights-float-type q40 --buffer-float-type f32 --prompt "Hello world" --steps 16 --nthreads 4 --workers 192.168.0.1:9998
+sudo nice -n -20 ./main inference --model ../dllama_llama-2-7b_q40.bin --tokenizer ../tokenizer.bin --weights-float-type q40 --buffer-float-type q80 --prompt "Hello world" --steps 16 --nthreads 4 --workers 192.168.0.1:9998
 ```
 
 ## 💡 License

src/funcs.cpp (+75)

@@ -7,6 +7,57 @@
 
 #if defined(__ARM_NEON)
 #include <arm_neon.h>
+#elif defined(__AVX2__)
+#include <immintrin.h>
+#endif
+
+#if defined(__AVX2__)
+#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
+
+static inline __m256i bytes_from_nibbles_32(const uint8_t* rsi) {
+    // Load 16 bytes from memory
+    __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi);
+    __m128i tmph = _mm_srli_epi16(tmpl, 4);
+    const __m128i lowMask = _mm_set1_epi8(0xF);
+    tmpl = _mm_and_si128(lowMask, tmpl);
+    tmph = _mm_and_si128(lowMask, tmph);
+    return MM256_SET_M128I(tmph, tmpl);
+}
+
+static inline float hsum_float_8(const __m256 x) {
+    __m128 res = _mm256_extractf128_ps(x, 1);
+    res = _mm_add_ps(res, _mm256_castps256_ps128(x));
+    res = _mm_add_ps(res, _mm_movehl_ps(res, res));
+    res = _mm_add_ss(res, _mm_movehdup_ps(res));
+    return _mm_cvtss_f32(res);
+}
+
+// add int16_t pairwise and return as float vector
+static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
+    const __m128i ones = _mm_set1_epi16(1);
+    const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
+    const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
+    const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl);
+    return _mm256_cvtepi32_ps(summed_pairs);
+}
+
+// multiply int8_t, add results pairwise twice and return as float vector
+static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+    const __m128i xl = _mm256_castsi256_si128(x);
+    const __m128i xh = _mm256_extractf128_si256(x, 1);
+    const __m128i yl = _mm256_castsi256_si128(y);
+    const __m128i yh = _mm256_extractf128_si256(y, 1);
+    // Get absolute values of x vectors
+    const __m128i axl = _mm_sign_epi8(xl, xl);
+    const __m128i axh = _mm_sign_epi8(xh, xh);
+    // Sign the values of the y vectors
+    const __m128i syl = _mm_sign_epi8(yl, xl);
+    const __m128i syh = _mm_sign_epi8(yh, xh);
+    // Perform multiplication and create 16-bit values
+    const __m128i dotl = _mm_maddubs_epi16(axl, syl);
+    const __m128i doth = _mm_maddubs_epi16(axh, syh);
+    return sum_i16_pairs_float(doth, dotl);
+}
 #endif
 
 void softmax(float* x, const int size) {
@@ -253,6 +304,30 @@ void matmulQ40vQ80(MatmulThreadInfo* a) {
         }
         a->output[d] = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
     }
+#elif defined(__AVX2__)
+    for (int d = a->ds; d < a->de; d++) {
+        __m256 acc = _mm256_setzero_ps();
+
+        for (int j = 0; j < n; j++) {
+            /* Compute combined scale for the block */
+            const __m256 cd = _mm256_set1_ps( convertF16ToF32(w[d * n + j].d) * convertF16ToF32(input[j].d) );
+
+            __m256i bx = bytes_from_nibbles_32(w[d * n + j].qs);
+
+            // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+            const __m256i off = _mm256_set1_epi8( 8 );
+            bx = _mm256_sub_epi8(bx, off);
+
+            __m256i by = _mm256_loadu_si256((const __m256i *)input[j].qs);
+
+            const __m256 q = mul_sum_i8_pairs_float(bx, by);
+
+            /* Multiply q with scale and accumulate */
+            acc = _mm256_fmadd_ps( cd, q, acc );
+        }
+
+        a->output[d] = hsum_float_8(acc);
+    }
 #else
     printf("matmulQ40vQ80 - not implemented\n");
     exit(EXIT_FAILURE);