Commit 45c3e5b

funcs-test. (#31)
1 parent b095f78 commit 45c3e5b

12 files changed: +197 −40 lines

.github/workflows/main.yml

Lines changed: 17 additions & 1 deletion
@@ -12,7 +12,11 @@ jobs:
     runs-on: ${{matrix.os}}
     strategy:
       matrix:
-        os: [ubuntu-latest]
+        os:
+          - ubuntu-latest
+        platforms:
+          - linux/arm64
+          - linux/amd64
     steps:
       - name: Checkout Repo
         uses: actions/checkout@v3
@@ -25,3 +29,15 @@ jobs:
         id: build
         run: |
           make main
+          make funcs-test
+          make quants-test
+          make llama2-tasks-test
+          make grok1-tasks-test
+      - name: funcs-test
+        run: ./funcs-test
+      - name: quants-test
+        run: ./quants-test
+      - name: llama2-tasks-test
+        run: ./llama2-tasks-test
+      - name: grok1-tasks-test
+        run: ./grok1-tasks-test

.gitignore

Lines changed: 1 addition & 3 deletions
@@ -6,8 +6,6 @@
 *.bin
 __pycache__
 
-quants-test
-llama2-tasks-test
-grok1-tasks-test
+*-test
 main
 run.sh

Makefile

Lines changed: 4 additions & 0 deletions
@@ -7,6 +7,8 @@ quants: src/quants.cpp
 	$(CXX) $(CXXFLAGS) -c src/quants.cpp -o quants.o
 funcs: src/funcs.cpp
 	$(CXX) $(CXXFLAGS) -c src/funcs.cpp -o funcs.o
+funcs-test: src/funcs-test.cpp funcs
+	$(CXX) $(CXXFLAGS) src/funcs-test.cpp -o funcs-test funcs.o
 socket: src/socket.cpp
 	$(CXX) $(CXXFLAGS) -c src/socket.cpp -o socket.o
 transformer: src/utils.cpp
@@ -24,6 +26,8 @@ tokenizer: src/tokenizer.cpp
 
 main: src/main.cpp utils quants funcs socket transformer tasks llama2-tasks grok1-tasks mixtral-tasks tokenizer
 	$(CXX) $(CXXFLAGS) src/main.cpp -o main utils.o quants.o funcs.o socket.o transformer.o tasks.o llama2-tasks.o grok1-tasks.o mixtral-tasks.o tokenizer.o -lpthread
+funcs-test: src/funcs-test.cpp funcs utils quants
+	$(CXX) $(CXXFLAGS) src/funcs-test.cpp -o funcs-test funcs.o utils.o quants.o
 quants-test: src/quants.cpp utils quants
 	$(CXX) $(CXXFLAGS) src/quants-test.cpp -o quants-test utils.o quants.o -lpthread
 llama2-tasks-test: src/llama2-tasks-test.cpp utils quants funcs socket transformer tasks llama2-tasks tokenizer

converter/convert-llama.py

Lines changed: 2 additions & 2 deletions
@@ -4,7 +4,7 @@
 import torch
 import math
 import numpy as np
-from writer import writeTensor, writeHeader
+from writer import writeTensor, writeHeader, isFloatTypeSupported
 from pathlib import Path
 
 LAYER_CHUNK_SIZE = 48
@@ -106,7 +106,7 @@ def usage():
     modelPath = sys.argv[1]
     targetFloatType = sys.argv[2]
 
-    if (not modelPath or not targetFloatType in ['f16', 'f32', 'q40']):
+    if (not modelPath or not isFloatTypeSupported(targetFloatType)):
         usage()
 
     modelName = modelPath.split('/')[-1]

converter/convert-tokenizer-sentencepiece.py

Lines changed: 1 addition & 0 deletions
@@ -67,6 +67,7 @@ def export(self):
                 print(f"{bytes.decode('utf-8')} {score}")
                 f.write(struct.pack("fI", score, len(bytes)))
                 f.write(bytes)
+        print(f'Created {outputPath}')
 
 if __name__ == "__main__":
     if (len(sys.argv) < 2):

converter/writer.py

Lines changed: 41 additions & 9 deletions
@@ -4,10 +4,9 @@
 import numpy as np
 
 def isFloatTypeSupported(type):
-    return type in ['f16', 'f32', 'q40']
+    return type in ['f16', 'f32', 'q40', 'q80']
 
 def writeQuantizedQ40Tensor(file, x):
-    t0 = time.time()
     x = x.to(torch.float32).numpy().astype(np.float32)
     blockSize = 32
     blockHalfSize = blockSize // 2
@@ -35,28 +34,61 @@ def writeQuantizedQ40Tensor(file, x):
         buffer = struct.pack(f'e{blockHalfSize}B', delta16, *block)
         file.write(buffer)
         nBytes += len(buffer)
-    t1 = time.time()
-    print(f'Quantized tensor to {nBytes} bytes in {t1 - t0:.2f} s')
+    return nBytes
+
+def writeQuantizedQ80Tensor(file, x):
+    x = x.to(torch.float32).numpy().astype(np.float32)
+    blockSize = 32
+    assert(x.shape[0] % blockSize == 0)
+    groups = x.reshape(-1, blockSize)
+    gmax = np.max(groups, axis=1)
+    gmin = np.min(groups, axis=1)
+    gabsMax = np.where(-gmin > gmax, -gmin, gmax)
+    deltas = gabsMax / ((1 << 7) - 1)
+    deltas16 = deltas.astype(np.float16)
+    ids = np.where(deltas != 0, 1.0 / deltas, 0)
+    groups = groups * ids[:, np.newaxis]
+    groups8 = np.round(groups).astype(np.int8)
+
+    nBytes = 0
+    for groupIndex in range(0, len(groups)):
+        buffer = struct.pack(f'e{blockSize}b', deltas16[groupIndex], *groups8[groupIndex])
+        file.write(buffer)
+        nBytes += len(buffer)
+    return nBytes
 
 def writeF32Tensor(file, d):
     chunkSize = 10000
+    nBytes = 0
     for i in range(0, len(d), chunkSize):
         chunk = d[i:i+chunkSize].to(torch.float32).numpy().astype(np.float32)
         b = struct.pack(f'{len(chunk)}f', *chunk)
+        nBytes += len(b)
         file.write(b)
+    return nBytes
+
+def writeF16Tensor(file, d):
+    d = d.to(torch.float16).numpy().astype(np.float16)
+    b = struct.pack(f'{len(d)}e', *d)
+    file.write(b)
+    return len(b)
 
 def writeTensor(file, tensor, floatType):
     d = tensor.detach().cpu().view(-1)
+    t0 = time.time()
+    nBytes = 0
     if (floatType == 'f16'):
-        d = d.to(torch.float16).numpy().astype(np.float16)
-        b = struct.pack(f'{len(d)}e', *d)
-        file.write(b)
+        nBytes = writeF16Tensor(file, d)
    elif (floatType == 'f32'):
-        writeF32Tensor(file, d)
+        nBytes = writeF32Tensor(file, d)
     elif (floatType == 'q40'):
-        writeQuantizedQ40Tensor(file, d)
+        nBytes = writeQuantizedQ40Tensor(file, d)
+    elif (floatType == 'q80'):
+        nBytes = writeQuantizedQ80Tensor(file, d)
     else:
         raise Exception('Unknown float type')
+    t1 = time.time()
+    print(f'Saved {floatType} tensor in {t1 - t0:.2f}s, {nBytes} bytes')
 
 def writeHeader(file, params):
     headerKeys = {
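The writeQuantizedQ80Tensor added above stores each block of 32 values as one float16 delta (absMax / 127) followed by 32 int8 quants, i.e. 34 bytes per block. A minimal round-trip sketch of that layout, using hypothetical helper names that are not part of this commit (only numpy and struct are assumed):

import struct
import numpy as np

def quantizeQ80Block(values):
    # values: 32 float32 numbers -> float16 delta + 32 int8 quants (34 bytes),
    # mirroring the math in writeQuantizedQ80Tensor
    absMax = np.max(np.abs(values))
    delta = absMax / 127.0
    q = np.round(values / delta).astype(np.int8) if delta != 0 else np.zeros(32, dtype=np.int8)
    return struct.pack('e32b', np.float16(delta), *q)

def dequantizeQ80Block(blob):
    # inverse mapping: value ~ delta * quant
    delta = np.frombuffer(blob[:2], dtype=np.float16)[0]
    q = np.frombuffer(blob[2:], dtype=np.int8)
    return q.astype(np.float32) * np.float32(delta)

x = np.linspace(-1.0, 1.0, 32, dtype=np.float32)
xHat = dequantizeQ80Block(quantizeQ80Block(x))
print('max round-trip error:', np.max(np.abs(x - xHat)))  # roughly bounded by delta / 2

That quantization error is why the matmul comparisons in src/funcs-test.cpp below use a 0.001 tolerance rather than exact equality.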

src/funcs-test.cpp

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
+#include "funcs.hpp"
+#include "utils.hpp"
+#include <stdio.h>
+#include <stdlib.h>
+#include <math.h>
+
+void testRms() {
+    float x[] = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f};
+    float r = rms(x, 8);
+    if (fabs(r - 1.980256) > 0.001) {
+        printf("❌ rms() = %f\n", r);
+        exit(EXIT_FAILURE);
+    }
+    printf("✅ rms\n");
+}
+
+void testMatmulQ80() {
+    const int n = 512;
+    const int d = 256;
+    unsigned long long state = 88888888L;
+    float x[n];
+    float w[n * d];
+    float y[d];
+    float yQ0[d];
+    float yQ1[d];
+    int i;
+    for (i = 0; i < n; i++) x[i] = randomF32(&state) / 127.0f;
+    for (i = 0; i < n * d; i++) w[i] = randomF32(&state) / 127.0f;
+
+    char* xQ = new char[getBatchBytes(Q80, n, 1)];
+    char* wQ = new char[getBatchBytes(Q80, n, d)];
+    quantizeQ80Row(x, (BlockQ80*)xQ, n, 1, 0);
+    quantizeQ80Row(w, (BlockQ80*)wQ, n * d, 1, 0);
+
+    matmul(F32, F32, y, x, w, n, d, 1, 0);
+    matmul(Q80, F32, yQ0, x, wQ, n, d, 1, 0);
+    matmul(Q80, Q80, yQ1, xQ, wQ, n, d, 1, 0);
+
+    for (i = 0; i < d; i++) {
+        float diff = fabs(y[i] - yQ0[i]);
+        if (diff > 0.001) {
+            printf("❌ matmulQ80() ix=%d %f != %f diff=%f\n", i, y[i], yQ0[i], diff);
+            exit(EXIT_FAILURE);
+        }
+    }
+    printf("✅ matmulQ80\n");
+
+    for (i = 0; i < d; i++) {
+        float diff = fabs(y[i] - yQ1[i]);
+        if (diff > 0.001) {
+            printf("❌ matmulQ80vQ80() ix=%d %f != %f diff=%f\n", i, y[i], yQ1[i], diff);
+            exit(EXIT_FAILURE);
+        }
+    }
+    printf("✅ matmulQ80vQ80\n");
+
+    delete[] xQ;
+    delete[] wQ;
+}
+
+int main() {
+    initQuants();
+
+    testRms();
+    testMatmulQ80();
+    return EXIT_SUCCESS;
+}
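The expected constant 1.980256 in testRms is consistent with rms() returning the reciprocal root-mean-square (the RMSNorm scale factor), 1 / sqrt(mean(x²) + eps), rather than the plain RMS. A quick numeric check under that assumption (eps = 1e-5 is assumed here, it is not stated in this diff):

import numpy as np

x = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], dtype=np.float32)
eps = 1e-5  # assumed epsilon, commonly used in RMSNorm
r = 1.0 / np.sqrt(np.mean(x * x) + eps)
print(r)  # ~1.98026, within the test's 0.001 tolerance of 1.980256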

src/funcs.cpp

Lines changed: 51 additions & 4 deletions
@@ -233,6 +233,25 @@ void matmulQ40(MatmulThreadInfo* a) {
 #endif
 }
 
+void matmulQ80(MatmulThreadInfo* a) {
+    float* input = (float*)a->input;
+    BlockQ80* weights = (BlockQ80*)a->weights;
+    assert(a->n % QK80 == 0);
+    int nb = a->n / QK80;
+
+    for (int d = a->ds; d < a->de; d++) {
+        float sum = 0.0;
+        for (int i = 0; i < nb; i++) {
+            float s = 0.0;
+            for (int j = 0; j < QK80; j++) {
+                s += input[i * QK80 + j] * (float)weights[d * nb + i].qs[j];
+            }
+            sum += s * convertF16ToF32(weights[d * nb + i].d);
+        }
+        a->output[d] = sum;
+    }
+}
+
 void matmulQ40vQ80(MatmulThreadInfo* a) {
     const BlockQ40* w = (BlockQ40*)a->weights;
     const BlockQ80* input = (BlockQ80*)a->input;
@@ -334,6 +353,25 @@ void matmulQ40vQ80(MatmulThreadInfo* a) {
 #endif
 }
 
+void matmulQ80vQ80(MatmulThreadInfo* a) {
+    BlockQ80* input = (BlockQ80*)a->input;
+    BlockQ80* weights = (BlockQ80*)a->weights;
+    assert(a->n % QK80 == 0);
+    int nb = a->n / QK80;
+
+    for (int d = a->ds; d < a->de; d++) {
+        float sum = 0.0;
+        for (int i = 0; i < nb; i++) {
+            int s = 0;
+            for (int j = 0; j < QK80; j++) {
+                s += input[i].qs[j] * (int)weights[d * nb + i].qs[j];
+            }
+            sum += s * (convertF16ToF32(input[i].d) * convertF16ToF32(weights[d * nb + i].d));
+        }
+        a->output[d] = sum;
+    }
+}
+
 // weights      input    output
 //   ___________   ___      ___
 //  |           | |   |    |   |
@@ -363,10 +401,19 @@ void matmul(FloatType weightsFloatType, FloatType inputFloatType, float* output,
             matmulQ40(&s);
             return;
         }
-    }
-    if (inputFloatType == Q80 && weightsFloatType == Q40) {
-        matmulQ40vQ80(&s);
-        return;
+        if (weightsFloatType == Q80) {
+            matmulQ80(&s);
+            return;
+        }
+    } else if (inputFloatType == Q80) {
+        if (weightsFloatType == Q40) {
+            matmulQ40vQ80(&s);
+            return;
+        }
+        if (weightsFloatType == Q80) {
+            matmulQ80vQ80(&s);
+            return;
+        }
     }
 
     printf("Unsupported float types: %d/%d\n", weightsFloatType, inputFloatType);

src/grok1-tasks-test.cpp

Lines changed: 6 additions & 3 deletions
@@ -16,7 +16,7 @@ float expectedOutput_5012_5016[] = { 0.0126675405, 0.0169415697, 0.0183475353, 0
 
 void compare(float* a, float* b, int n) {
     for (int i = 0; i < n; i++) {
-        if (fabs(a[i] - b[i]) > 0.00001) { // Optimization may cause some differences
+        if (std::isnan(a[i]) || fabs(a[i] - b[i]) > 0.00001) { // Optimization may cause some differences
            printf("%.9g != %.9g\n", a[i], b[i]); i++;
            printf("%.9g != %.9g\n", a[i], b[i]); i++;
            printf("%.9g != %.9g\n", a[i], b[i]); i++;
@@ -44,6 +44,8 @@ int main() {
     spec.weightsFloatType = F32;
     spec.bufferFloatType = F32;
     spec.nSlices = 1;
+    spec.hiddenAct = GELU;
+    spec.ropeTheta = 10000.0f;
 
     size_t beforeBlockBytes = spec.dim * spec.vocabSize * sizeof(float);
     size_t blockBytes = 956596224;
@@ -62,7 +64,7 @@ int main() {
     transformer.pos = 0;
 
     float* x = transformer.x;
-    for (int i = 0; i < spec.dim; i++) x[i] = randomF32(&state) / 100.0;
+    for (int i = 0; i < spec.dim; i++) x[i] = (randomF32(&state) / 100.0) / 78.38367176906169f;
 
     TransformerArch arch = buildGrok1Arch(&spec);
 
@@ -73,7 +75,8 @@ int main() {
     context.socket = NULL;
     context.socketPool = &socketPool;
 
-    TaskLoop loop(nThreads, arch.inference.nTasks, TASK_N_TYPES, arch.inference.tasks, &context);
+    int skipLastNTasks = 4;
+    TaskLoop loop(nThreads, arch.inference.nTasks - skipLastNTasks, TASK_N_TYPES, arch.inference.tasks, &context);
     long t0 = timeMs();
     loop.run();
     long t1 = timeMs();

src/llama2-tasks-test.cpp

Lines changed: 4 additions & 8 deletions
@@ -524,8 +524,6 @@ float expectedOutput[4096] = {
     1.00493455, 1.00216055, 1.02500832, 1.01412213, 0.997673035, 1.01922369, 1.01705575, 1.01369667,
 };
 
-void nop(TASK_ARGS) {}
-
 int main() {
     TransformerSpec spec;
     spec.headerSize = sizeof(TransformerFileOldHeader) + sizeof(int);
@@ -571,18 +569,16 @@ int main() {
     for (int i = 0; i < spec.dim; i++) x[i] = randomF32(&state) / 120.0;
 
     TransformerArch arch = buildLlama2Arch(&spec);
-    arch.inference.tasks[arch.inference.nTasks - 3].handler = &nop;
-    arch.inference.tasks[arch.inference.nTasks - 2].handler = &nop;
-    arch.inference.tasks[arch.inference.nTasks - 1].handler = &nop;
 
-    int nThreads = 1;
+    int nThreads = 4;
     TransformerContext context;
     context.transformer = &transformer;
     context.currentBlockIndex = 0;
     context.socket = NULL;
     context.socketPool = &socketPool;
 
-    TaskLoop loop(nThreads, arch.inference.nTasks, TASK_N_TYPES, arch.inference.tasks, &context);
+    int skipLastNTasks = 3;
+    TaskLoop loop(nThreads, arch.inference.nTasks - skipLastNTasks, TASK_N_TYPES, arch.inference.tasks, &context);
     long t0 = timeMs();
     loop.run();
     long t1 = timeMs();
@@ -591,7 +587,7 @@ int main() {
 
     int ix = -1;
     for (int i = 0; i < spec.dim; i++) {
-        if (fabs(x[i] - expectedOutput[i]) > 0.00001) { // Optimization may cause some differences
+        if (std::isnan(x[i]) || fabs(x[i] - expectedOutput[i]) > 0.00001) { // Optimization may cause some differences
            ix = i;
            break;
        }

0 commit comments
