Skip to content

Commit 4b8a0ca

Browse files
authored
feat: support llama 3.1. (#106)
1 parent 8c57298 commit 4b8a0ca

15 files changed

+249
-55
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
*.0
55
*.dSYM
66
*.data
7+
*.temp
78
__pycache__
89

910
*-test

converter/convert-hf.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,14 @@ def parseHiddenAct(act: str):
143143
raise Exception(f'Unsupported hidden act: {act}')
144144
return hiddenAct
145145

146+
def parseRopeType(rt: str):
    """Map a `rope_type` string from the model config to its numeric code.

    Args:
        rt: Rope type name (currently only 'llama3' is supported).

    Returns:
        The integer code for the rope type (2 for 'llama3').

    Raises:
        Exception: If `rt` is not a supported rope type.
    """
    ropeType = {
        'llama3': 2, # LLAMA3_1
    }.get(rt)
    if (ropeType is None):
        # Report the offending input; `ropeType` is always None on this path.
        raise Exception(f'Unsupported rope type: {rt}')
    return ropeType
153+
146154
def loadConfig(folderPath: str, weightsFloatType: int):
147155
allFiles = os.listdir(folderPath)
148156
allFiles.sort()
@@ -178,6 +186,14 @@ def loadConfig(folderPath: str, weightsFloatType: int):
178186
ropeTheta = config.get('rope_theta')
179187
if (ropeTheta is not None):
180188
result['rope_theta'] = int(ropeTheta)
189+
190+
ropeScaling = config.get('rope_scaling')
191+
if (ropeScaling is not None):
192+
result['rope_scaling_factor'] = int(ropeScaling['factor'])
193+
result['rope_scaling_low_freq_factor'] = int(ropeScaling['low_freq_factor'])
194+
result['rope_scaling_high_freq_factory'] = int(ropeScaling['high_freq_factor'])
195+
result['rope_scaling_orig_max_seq_len'] = int(ropeScaling['original_max_position_embeddings'])
196+
result['rope_type'] = parseRopeType(ropeScaling['rope_type'])
181197
return result
182198

183199
def printUsage():

converter/writer.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,12 @@ def writeHeader(file, params):
125125
'max_seq_len': 10,
126126
'hidden_act': 11,
127127
'rope_theta': 12,
128-
'weights_float_type': 13
128+
'weights_float_type': 13,
129+
'rope_scaling_factor': 14,
130+
'rope_scaling_low_freq_factor': 15,
131+
'rope_scaling_high_freq_factory': 16,
132+
'rope_scaling_orig_max_seq_len': 17,
133+
'rope_type': 18,
129134
}
130135
header = struct.pack('i', 0xA00ABCD)
131136

launch.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,12 @@
1818
'https://huggingface.co/b4rtaz/Llama-3-8B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_model_lama3_instruct_q40.m?download=true',
1919
'https://huggingface.co/b4rtaz/Llama-3-8B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_tokenizer_llama3.t?download=true',
2020
'q40', 'q80', 'chat'
21-
]
21+
],
22+
'llama3_1_8b_instruct_q40': [
23+
'https://huggingface.co/b4rtaz/Llama-3_1-8B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_model_llama3.1_instruct_q40.m?download=true',
24+
'https://huggingface.co/b4rtaz/Llama-3_1-8B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_tokenizer_llama_3_1.t?download=true',
25+
'q40', 'q80', 'chat'
26+
],
2227
}
2328

2429
def downloadFile(url: str, path: str):

src/app.cpp

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -41,24 +41,27 @@ AppArgs AppArgs::parse(int argc, char** argv, bool hasMode) {
4141
args.steps = 0;
4242
args.seed = (unsigned long long)time(NULL);
4343
args.chatTemplateType = TEMPLATE_UNKNOWN;
44+
args.useDiscForKvCache = false;
4445

4546
int i = 1;
4647
if (hasMode && argc > 1) {
4748
args.mode = argv[1];
4849
i++;
4950
}
5051
for (; i + 1 < argc; i += 2) {
51-
if (strcmp(argv[i], "--model") == 0) {
52-
args.modelPath = argv[i + 1];
53-
} else if (strcmp(argv[i], "--tokenizer") == 0) {
54-
args.tokenizerPath = argv[i + 1];
55-
} else if (strcmp(argv[i], "--prompt") == 0) {
56-
args.prompt = argv[i + 1];
57-
} else if (strcmp(argv[i], "--weights-float-type") == 0) {
58-
args.weightsFloatType = parseFloatType(argv[i + 1]);
59-
} else if (strcmp(argv[i], "--buffer-float-type") == 0) {
60-
args.bufferFloatType = parseFloatType(argv[i + 1]);
61-
} else if (strcmp(argv[i], "--workers") == 0) {
52+
char* name = argv[i];
53+
char* value = argv[i + 1];
54+
if (strcmp(name, "--model") == 0) {
55+
args.modelPath = value;
56+
} else if (strcmp(name, "--tokenizer") == 0) {
57+
args.tokenizerPath = value;
58+
} else if (strcmp(name, "--prompt") == 0) {
59+
args.prompt = value;
60+
} else if (strcmp(name, "--weights-float-type") == 0) {
61+
args.weightsFloatType = parseFloatType(value);
62+
} else if (strcmp(name, "--buffer-float-type") == 0) {
63+
args.bufferFloatType = parseFloatType(value);
64+
} else if (strcmp(name, "--workers") == 0) {
6265
int j = i + 1;
6366
for (; j < argc && argv[j][0] != '-'; j++);
6467
int count = j - i - 1;
@@ -82,22 +85,24 @@ AppArgs AppArgs::parse(int argc, char** argv, bool hasMode) {
8285
}
8386

8487
i += count - 1;
85-
} else if (strcmp(argv[i], "--port") == 0) {
86-
args.port = atoi(argv[i + 1]);
87-
} else if (strcmp(argv[i], "--nthreads") == 0) {
88-
args.nThreads = atoi(argv[i + 1]);
89-
} else if (strcmp(argv[i], "--steps") == 0) {
90-
args.steps = atoi(argv[i + 1]);
91-
} else if (strcmp(argv[i], "--temperature") == 0) {
92-
args.temperature = atof(argv[i + 1]);
93-
} else if (strcmp(argv[i], "--topp") == 0) {
94-
args.topp = atof(argv[i + 1]);
95-
} else if (strcmp(argv[i], "--seed") == 0) {
96-
args.seed = atoll(argv[i + 1]);
97-
} else if (strcmp(argv[i], "--chat-template") == 0) {
98-
args.chatTemplateType = parseChatTemplateType(argv[i + 1]);
88+
} else if (strcmp(name, "--port") == 0) {
89+
args.port = atoi(value);
90+
} else if (strcmp(name, "--nthreads") == 0) {
91+
args.nThreads = atoi(value);
92+
} else if (strcmp(name, "--steps") == 0) {
93+
args.steps = atoi(value);
94+
} else if (strcmp(name, "--temperature") == 0) {
95+
args.temperature = atof(value);
96+
} else if (strcmp(name, "--topp") == 0) {
97+
args.topp = atof(value);
98+
} else if (strcmp(name, "--seed") == 0) {
99+
args.seed = atoll(value);
100+
} else if (strcmp(name, "--chat-template") == 0) {
101+
args.chatTemplateType = parseChatTemplateType(value);
102+
} else if (strcmp(name, "--kv-cache-storage") == 0) {
103+
args.useDiscForKvCache = strcmp(value, "disc") == 0;
99104
} else {
100-
printf("Unknown option %s\n", argv[i]);
105+
printf("Unknown option %s\n", name);
101106
exit(EXIT_FAILURE);
102107
}
103108
}
@@ -131,7 +136,10 @@ void App::run(AppArgs* args, void (*program)(Inference* inference, SocketPool* s
131136
args->steps = spec.seqLen;
132137
}
133138

134-
Transformer transformer = Transformer::loadRootFromFile(args->modelPath, &spec, socketPool);
139+
TransformerConfig config;
140+
config.useDiscForKvCache = args->useDiscForKvCache;
141+
142+
Transformer transformer = Transformer::loadRootFromFile(args->modelPath, &spec, &config, socketPool);
135143
socketPool->setTurbo(true);
136144

137145
Inference inference = Inference(&arch, args->nThreads, &transformer, socketPool);

src/app.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
class AppArgs {
1717
public:
1818
char* mode;
19-
int nThreads;
19+
int nThreads;
20+
bool useDiscForKvCache;
2021

2122
// inference
2223
char* modelPath;

src/apps/dllama/dllama.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ class Chat {
160160
int nInputTokens;
161161
tokenizer->encode((char*)inputPrompt.c_str(), inputTokens, &nInputTokens, true, false);
162162

163-
pos_t userPromptEndPos = (pos_t)std::min(spec->seqLen, pos + nInputTokens - 1);
163+
pos_t userPromptEndPos = (pos_t)std::min(spec->seqLen, (int)pos + nInputTokens - 1);
164164
for (pos_t i = 0; pos < userPromptEndPos; pos++, i++) {
165165
inference->infer(inputTokens[i], pos);
166166
token = inputTokens[i + 1];
@@ -207,10 +207,13 @@ void worker(AppArgs* args) {
207207
throw std::runtime_error("Invalid port number");
208208
}
209209

210+
TransformerConfig config;
211+
config.useDiscForKvCache = args->useDiscForKvCache;
212+
210213
SocketServer server(args->port);
211214
Socket socket = server.accept();
212215
TransformerSpec spec;
213-
Transformer transformer = Transformer::loadSlice(&spec, &socket);
216+
Transformer transformer = Transformer::loadSlice(&spec, &config, &socket);
214217
TransformerArch arch = TransformerArchFactory::create(&spec);
215218

216219
Worker worker = Worker(&arch, args->nThreads, &transformer, &socket);

src/commands.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
#include <cassert>
22
#include <cstring>
3+
#ifdef _WIN32
4+
#define _USE_MATH_DEFINES
5+
#endif
36
#include <cmath>
47
#include "utils.hpp"
58
#include "funcs.hpp"
@@ -167,6 +170,54 @@ void LlamaRopeCommand::forward(bool isQ, float* qOrK, pos_t pos, unsigned int nT
167170
}
168171
}
169172

173+
// Stores the rope slice and the four rope-scaling parameters, then prints the
// scaling parameters to stdout.
Llama3_1RopeCommand::Llama3_1RopeCommand(RopeSlice *slice, float ropeScalingFactor, float ropeScalingLowFreqFactor, float ropeScalingHighFreqFactory, int ropeScalingOrigMaxSeqLen)
    : slice(slice),
      ropeScalingFactor(ropeScalingFactor),
      ropeScalingLowFreqFactor(ropeScalingLowFreqFactor),
      ropeScalingHighFreqFactory(ropeScalingHighFreqFactory),
      ropeScalingOrigMaxSeqLen(ropeScalingOrigMaxSeqLen)
{
    // Inside the body these names refer to the constructor parameters, which
    // hold the same values as the members initialized above.
    printf("🕒 ropeScalingFactor: %f\n", ropeScalingFactor);
    printf("🕒 ropeScalingLowFreqFactor: %f\n", ropeScalingLowFreqFactor);
    printf("🕒 ropeScalingHighFreqFactory: %f\n", ropeScalingHighFreqFactory);
    printf("🕒 ropeScalingOrigMaxSeqLen: %d\n", ropeScalingOrigMaxSeqLen);
}
184+
185+
float Llama3_1RopeCommand::scale(float freq) {
186+
float waveLen = 2.0f * M_PI * freq;
187+
float lowFreqWavelen = ropeScalingOrigMaxSeqLen / ropeScalingLowFreqFactor;
188+
float highFreqWavelen = ropeScalingOrigMaxSeqLen / ropeScalingHighFreqFactory;
189+
if (waveLen < highFreqWavelen) {
190+
return freq;
191+
} else if (waveLen > lowFreqWavelen) {
192+
return freq / ropeScalingFactor;
193+
} else {
194+
float smooth = (ropeScalingOrigMaxSeqLen / waveLen - ropeScalingLowFreqFactor) / (ropeScalingHighFreqFactory - ropeScalingLowFreqFactor);
195+
return (1 - smooth) * freq / ropeScalingFactor + smooth * freq;
196+
}
197+
}
198+
199+
void Llama3_1RopeCommand::forward(bool isQ, float* qOrK, pos_t pos, unsigned int nThreads, unsigned int threadIndex) {
200+
const unsigned int dim0Half = (isQ ? slice->qDim0 : slice->kvDim0) / 2;
201+
const unsigned int shift = isQ ? slice->qShift : 0;
202+
SPLIT_RANGE_TO_THREADS(s, e, 0, dim0Half, nThreads, threadIndex);
203+
const unsigned int iStart = s * 2;
204+
const unsigned int iEnd = e * 2;
205+
206+
for (unsigned int i = iStart; i < iEnd; i += 2) {
207+
const unsigned int headDim = i % slice->headSize;
208+
const float freq = 1.0f / powf(slice->ropeTheta, headDim / (float)slice->headSize);
209+
const float val = pos * freq;
210+
const float fcr = cosf(val);
211+
const float fci = sinf(val);
212+
213+
float v0 = qOrK[i];
214+
float v1 = qOrK[i + 1];
215+
216+
qOrK[i] = scale(v0 * fcr - v1 * fci);
217+
qOrK[i + 1] = scale(v0 * fci + v1 * fcr);
218+
}
219+
}
220+
170221
FalconRopeCommand::FalconRopeCommand(RopeSlice *slice) {
171222
this->slice = slice;
172223
}

src/commands.hpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
// *Slice - calculates sizes, offsets, slice sizes etc. It is not responsible for memory allocation. It may help in the loading of data.
1010
// *Command - allocates memory for weights, performs calculations.
1111

12-
typedef unsigned short pos_t;
12+
typedef unsigned int pos_t;
1313
typedef uint8_t slice_index_t;
1414

1515
class MatmulSlice {
@@ -106,6 +106,19 @@ class LlamaRopeCommand : public RopeCommand {
106106
void forward(bool isQ, float* qOrK, pos_t pos, unsigned int nThreads, unsigned int threadIndex);
107107
};
108108

109+
class Llama3_1RopeCommand : public RopeCommand {
110+
private:
111+
RopeSlice* slice;
112+
float ropeScalingFactor;
113+
float ropeScalingLowFreqFactor;
114+
float ropeScalingHighFreqFactory;
115+
int ropeScalingOrigMaxSeqLen;
116+
public:
117+
Llama3_1RopeCommand(RopeSlice *slice, float ropeScalingFactor, float ropeScalingLowFreqFactor, float ropeScalingHighFreqFactory, int ropeScalingOrigMaxSeqLen);
118+
void forward(bool isQ, float* qOrK, pos_t pos, unsigned int nThreads, unsigned int threadIndex);
119+
float scale(float freq);
120+
};
121+
109122
class FalconRopeCommand : public RopeCommand {
110123
private:
111124
RopeSlice* slice;

src/grok1-tasks-test.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ int main() {
3030
TransformerSpec spec;
3131
spec.headerSize = sizeof(TransformerFileOldHeader) + sizeof(int);
3232
spec.archType = GROK1;
33+
spec.ropeType = ROPE_FALCON;
3334
spec.dim = 6144;
3435
spec.nLayers = 1;
3536
spec.nHeads = 48;
@@ -47,6 +48,9 @@ int main() {
4748
spec.hiddenAct = GELU;
4849
spec.ropeTheta = 10000.0f;
4950

51+
TransformerConfig config;
52+
config.useDiscForKvCache = false;
53+
5054
size_t beforeBlockBytes = spec.dim * spec.vocabSize * sizeof(float);
5155
size_t blockBytes = 956596224;
5256
size_t afterBlockBytes = (spec.dim + spec.dim * spec.vocabSize) * sizeof(float);
@@ -60,7 +64,7 @@ int main() {
6064
for (int f = 0; f < nFloats; f++) block[f] = randomF32(&state) / 100.0;
6165

6266
SocketPool socketPool(0, NULL);
63-
Transformer transformer = Transformer::loadRoot(weights, &spec, &socketPool);
67+
Transformer transformer = Transformer::loadRoot(weights, &spec, &config, &socketPool);
6468
transformer.pos = 0;
6569

6670
float* x = transformer.x;

src/llama2-tasks-test.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,7 @@ int main() {
528528
TransformerSpec spec;
529529
spec.headerSize = sizeof(TransformerFileOldHeader) + sizeof(int);
530530
spec.archType = LLAMA;
531+
spec.ropeType = ROPE_LLAMA;
531532
spec.dim = 4096;
532533
spec.nLayers = 1;
533534
spec.headSize = 128;
@@ -545,6 +546,9 @@ int main() {
545546
spec.hiddenAct = SILU;
546547
spec.ropeTheta = 10000.0f;
547548

549+
TransformerConfig config;
550+
config.useDiscForKvCache = false;
551+
548552
size_t beforeBlockBytes = /* embedding */ 524288000;
549553
size_t blockBytes = 809533440;
550554
size_t afterBlockBytes = /* norm */ 16384 + /* embedding */ 524288000;
@@ -562,7 +566,7 @@ int main() {
562566
for (int i = 0; i < mm; i++) mmData[i] = randomF32(&state) / 120.0;
563567

564568
SocketPool socketPool(0, NULL);
565-
Transformer transformer = Transformer::loadRoot((char*)data, &spec, &socketPool);
569+
Transformer transformer = Transformer::loadRoot((char*)data, &spec, &config, &socketPool);
566570
transformer.pos = 0;
567571

568572
float* x = transformer.x;

0 commit comments

Comments
 (0)