
Commit aec85b9

feat: support r1 distill llama. (#161)
1 parent caea6eb commit aec85b9

9 files changed: +135 -99 lines

README.md

+13-21
@@ -4,7 +4,9 @@
 
 [![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/b4rtaz/distributed-llama/.github%2Fworkflows%2Fmain.yml?style=flat-square)](https://github.com/b4rtaz/distributed-llama/actions) [![License: MIT](https://img.shields.io/github/license/mashape/apistatus.svg?style=flat-square)](/LICENSE) [![Support this project](https://img.shields.io/github/sponsors/b4rtaz?style=flat-square&label=support%20this%20project&color=green)](https://github.com/sponsors/b4rtaz) [![Discord](https://discordapp.com/api/guilds/1245814812353495070/widget.png?style=shield)](https://discord.com/widget?id=1245814812353495070&theme=dark)
 
-Tensor parallelism is all you need. Run LLMs on weak devices or make powerful devices even more powerful by distributing the workload and dividing the RAM usage. This project proves that it's possible split the workload of LLMs across multiple devices and achieve a significant speedup. Distributed Llama allows you to run huge LLMs in-house. The project uses TCP sockets to synchronize the state. You can easily configure your AI cluster by using a home router.
+Connect home devices into a powerful cluster to accelerate LLM inference. More devices mean faster performance, leveraging tensor parallelism and high-speed synchronization over Ethernet.
+
+Supports Linux, macOS, and Windows. Optimized for ARM and x86_64 AVX2 CPUs.
 
 **News**
 - 12 Feb 2025 - 🚧 Merged the [fundamental codebase refactor](https://github.com/b4rtaz/distributed-llama/releases/tag/v0.12.0)
@@ -16,36 +18,26 @@ Tensor parallelism is all you need. Run LLMs on weak devices or make powerful de
 
 Python 3 and C++ compiler required. The command will download the model and the tokenizer.
 
-| Model                       | Purpose   | Size    | Command                                       |
-| --------------------------- | --------- | ------- | --------------------------------------------- |
-| Llama 3.1 8B Instruct Q40   | Chat, API | 6.32 GB | `python launch.py llama3_1_8b_instruct_q40`   |
-| Llama 3.1 405B Instruct Q40 | Chat, API | 238 GB  | `python launch.py llama3_1_405b_instruct_q40` |
-| Llama 3.2 1B Instruct Q40   | Chat, API | 1.7 GB  | `python launch.py llama3_2_1b_instruct_q40`   |
-| Llama 3.2 3B Instruct Q40   | Chat, API | 3.4 GB  | `python launch.py llama3_2_3b_instruct_q40`   |
-| Llama 3.3 70B Instruct Q40  | Chat, API | 40 GB   | `python launch.py llama3_3_70b_instruct_q40`  |
+| Model                            | Purpose   | Size    | Command                                             |
+| -------------------------------- | --------- | ------- | --------------------------------------------------- |
+| Llama 3.1 8B Instruct Q40        | Chat, API | 6.32 GB | `python launch.py llama3_1_8b_instruct_q40`         |
+| Llama 3.1 405B Instruct Q40      | Chat, API | 238 GB  | `python launch.py llama3_1_405b_instruct_q40`       |
+| Llama 3.2 1B Instruct Q40        | Chat, API | 1.7 GB  | `python launch.py llama3_2_1b_instruct_q40`         |
+| Llama 3.2 3B Instruct Q40        | Chat, API | 3.4 GB  | `python launch.py llama3_2_3b_instruct_q40`         |
+| Llama 3.3 70B Instruct Q40       | Chat, API | 40 GB   | `python launch.py llama3_3_70b_instruct_q40`        |
+| DeepSeek R1 Distill Llama 8B Q40 | Chat, API | 6.32 GB | `python launch.py deepseek_r1_distill_llama_8b_q40` |
 
 ### 🛠️ Convert Model Manually
 
-Supported architectures: Llama, Mixtral
+Supported architectures: Llama
 
-* [How to Convert Llama 2, Llama 3, Llama 3.1](./docs/LLAMA.md)
+* [How to Convert Llama 3.1](./docs/LLAMA.md)
 * [How to Convert Hugging Face Model](./docs/HUGGINGFACE.md)
 
 ### 🚧 Known Limitations
 
 * You can run Distributed Llama only on 1, 2, 4... 2^n nodes.
 * The maximum number of nodes is equal to the number of KV heads in the model [#70](https://github.com/b4rtaz/distributed-llama/issues/70).
-* CPU support only, GPU support is planned, optimized for (weights format × buffer format):
-  * ARM CPUs
-    * ✅ F32 × F32
-    * ❌ F16 × F32
-    * ✅ Q40 × F32
-    * ✅ Q40 × Q80
-  * x86_64 AVX2 CPUs
-    * ✅ F32 × F32
-    * ❌ F16 × F32
-    * ✅ Q40 × F32
-    * ✅ Q40 × Q80
 
 ### 👷 Architecture

converter/convert-hf.py

-2
@@ -7,7 +7,6 @@
 
 class ArchType:
     LLAMA = 0xABCD00
-    MIXTRAL = 0xABCD02
 
 def permute(tensor, nHeads: int, nKvHeads: int):
     if nHeads != nKvHeads:
@@ -128,7 +127,6 @@ def parseArchType(type: str):
     archType = {
         'llama': ArchType.LLAMA,
         'mistral': ArchType.LLAMA,
-        'mixtral': ArchType.MIXTRAL,
     }.get(type)
     if (archType is None):
         raise Exception(f'Unsupported arch type: {type}')

converter/convert-tokenizer-hf.py

+39-19
@@ -2,12 +2,26 @@
 import json
 import os
 from sentencepiece import SentencePieceProcessor
+from transformers import PreTrainedTokenizerFast
 writer = __import__('tokenizer-writer')
 
 def openJson(path):
     with open(path, 'r', encoding='utf-8') as file:
         return json.load(file)
 
+def unicodeToBytes():
+    # https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/encoder.py#L9
+    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+    cs = bs[:]
+    n = 0
+    for b in range(2 ** 8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2 ** 8 + n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(cs, bs))
+
 class TokensResolver:
     def __init__(self, dirPath, tokenizerConfig):
         self.dirPath = dirPath
@@ -18,25 +32,28 @@ def __init__(self, dirPath, tokenizerConfig):
         self.scores = []
 
     def resolvePreTrainedTokenizerFast(self):
-        tokenizer = openJson(os.path.join(self.dirPath, 'tokenizer.json'))
-        assert(tokenizer['model']['type'] == 'BPE')
-
-        i = 0
-        for token in tokenizer['model']['vocab'].keys():
-            assert(tokenizer['model']['vocab'][token] == i)
-            self.tokens.append(token.encode('utf8'))
+        utb = unicodeToBytes()
+        tokenizer = PreTrainedTokenizerFast(tokenizer_file = os.path.join(self.dirPath, 'tokenizer.json'))
+        vocabLen = len(tokenizer.get_vocab())
+        for i in range(vocabLen):
+            tokenChars = list(tokenizer.convert_ids_to_tokens([i])[0])
+            tokenBytes = []
+            for chr in tokenChars:
+                if (chr in utb):
+                    tokenBytes.append(utb[chr])
+                else:
+                    tokenBytes += list(chr.encode('utf-8'))
+            self.tokens.append(bytes(tokenBytes))
             self.scores.append(-float(i))
-            i += 1
-        if ('added_tokens' in tokenizer):
-            for at in tokenizer['added_tokens']:
-                assert(at['id'] == i)
-                self.tokens.append(at['content'].encode('utf8'))
-                self.scores.append(-float(i))
-                if (at['content'] == self.tokenizerConfig['bos_token']):
-                    self.bosId = i
-                if (at['content'] == self.tokenizerConfig['eos_token']):
-                    self.eosId = i
-                i += 1
+
+        self.bosId = tokenizer.bos_token_id
+        self.eosId = tokenizer.eos_token_id
+        if (self.bosId is None or self.eosId is None):
+            config = openJson(os.path.join(self.dirPath, 'config.json'))
+            if (self.bosId is None):
+                self.bosId = config['bos_token_id']
+            if (self.eosId is None):
+                self.eosId = config['eos_token_id']
 
     def resolveLlamaTokenizer(self):
         modelPath = os.path.join(self.dirPath, 'tokenizer.model')
@@ -57,12 +74,13 @@ def resolveLlamaTokenizer(self):
 
     def resolve(self):
         cls = self.tokenizerConfig['tokenizer_class']
-        if (cls == 'PreTrainedTokenizerFast'):
+        if (cls == 'PreTrainedTokenizerFast' or cls == 'LlamaTokenizerFast'):
            return self.resolvePreTrainedTokenizerFast()
         if (cls == 'LlamaTokenizer'):
            return self.resolveLlamaTokenizer()
         raise Exception(f'Tokenizer {cls} is not supported')
 
+
 def printUsage():
     print('Usage: python convert-tokenizer-hf.py <tokenizerFolderPath> <name>')
     print()
@@ -82,6 +100,8 @@ def printUsage():
 resolver = TokensResolver(dirPath, tokenizerConfig)
 resolver.resolve()
 
+if (resolver.bosId is None or resolver.eosId is None):
+    raise Exception('Cannot resolve bosId or eosId')
 print(f'bosId: {resolver.bosId} ({resolver.tokens[resolver.bosId]})')
 print(f'eosId: {resolver.eosId} ({resolver.tokens[resolver.eosId]})')
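
Note on the conversion change above: `resolvePreTrainedTokenizerFast` no longer reads the raw `vocab` map from `tokenizer.json`; it asks `PreTrainedTokenizerFast` for each token string and maps the byte-level BPE stand-in characters back to raw bytes via `unicodeToBytes()`, the inverse of GPT-2's `bytes_to_unicode` table. A minimal, self-contained sketch of that reversal (the token string is a made-up example, not taken from a real tokenizer file):

```python
# Sketch of the byte-level BPE reversal used by resolvePreTrainedTokenizerFast above.
# The token value below is illustrative only.

def unicodeToBytes():
    # Inverse of GPT-2's bytes_to_unicode: printable stand-in character -> original byte
    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    cs = bs[:]
    n = 0
    for b in range(2 ** 8):
        if b not in bs:
            bs.append(b)
            cs.append(2 ** 8 + n)
            n += 1
    return dict(zip([chr(c) for c in cs], bs))

utb = unicodeToBytes()
token = "ĠHello"  # byte-level BPE renders a leading space as 'Ġ' (U+0120)
tokenBytes = []
for ch in token:
    if ch in utb:
        tokenBytes.append(utb[ch])              # stand-in character -> original byte
    else:
        tokenBytes += list(ch.encode('utf-8'))  # anything unmapped passes through as UTF-8
print(bytes(tokenBytes))  # b' Hello'
```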

launch.py

+5
@@ -38,6 +38,11 @@ def parts(length):
         'https://huggingface.co/b4rtaz/Llama-3_3-70B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_tokenizer_llama-3.3-70b.t?download=true',
         'q40', 'q80', 'chat', '--max-seq-len 4096'
     ],
+    'deepseek_r1_distill_llama_8b_q40': [
+        ['https://huggingface.co/b4rtaz/DeepSeek-R1-Distill-Llama-8B-Distributed-Llama/resolve/main/dllama_model_deepseek-r1-distill-llama-8b_q40.m?download=true'],
+        'https://huggingface.co/b4rtaz/DeepSeek-R1-Distill-Llama-8B-Distributed-Llama/resolve/main/dllama_tokenizer_deepseek-r1-distill-llama-8b.t?download=true',
+        'q40', 'q80', 'chat', '--max-seq-len 4096'
+    ],
 }
 
 def confirm(message: str):
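
Reading the new registry entry, the list fields appear to be, in order: the model part URL(s), the tokenizer URL, the weights float type, the buffer float type, the default mode, and extra CLI arguments. The sketch below only illustrates that reading; the field names and the `describe` helper are not part of launch.py.

```python
# Illustrative reading of the registry entry added above; the real launch.py
# logic may unpack these fields differently.
DEEPSEEK_ENTRY = [
    ['https://huggingface.co/b4rtaz/DeepSeek-R1-Distill-Llama-8B-Distributed-Llama/resolve/main/dllama_model_deepseek-r1-distill-llama-8b_q40.m?download=true'],
    'https://huggingface.co/b4rtaz/DeepSeek-R1-Distill-Llama-8B-Distributed-Llama/resolve/main/dllama_tokenizer_deepseek-r1-distill-llama-8b.t?download=true',
    'q40', 'q80', 'chat', '--max-seq-len 4096'
]

def describe(entry):
    modelUrls, tokenizerUrl, weightsType, bufferType, mode, extraArgs = entry
    print(f'{len(modelUrls)} model part(s), weights={weightsType}, buffer={bufferType}, mode={mode}, args={extraArgs}')

describe(DEEPSEEK_ENTRY)  # 1 model part(s), weights=q40, buffer=q80, mode=chat, args=--max-seq-len 4096
```

Per the README table above, the user-facing entry point is unchanged: `python launch.py deepseek_r1_distill_llama_8b_q40`.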

src/app.cpp

+1-2
@@ -14,8 +14,7 @@ static NnFloatType parseFloatType(char *val) {
 static ChatTemplateType parseChatTemplateType(char *val) {
     if (std::strcmp(val, "llama2") == 0) return TEMPLATE_LLAMA2;
     if (std::strcmp(val, "llama3") == 0) return TEMPLATE_LLAMA3;
-    if (std::strcmp(val, "zephyr") == 0) return TEMPLATE_ZEPHYR;
-    if (std::strcmp(val, "chatml") == 0) return TEMPLATE_CHATML;
+    if (std::strcmp(val, "deepSeek3") == 0) return TEMPLATE_DEEP_SEEK3;
     throw std::runtime_error("Invalid chat template type: " + std::string(val));
 }

src/dllama-api.cpp

+11-11
@@ -345,14 +345,13 @@ class ApiServer {
             inputItems[i].message = deltaPrompt[i].content;
         }
 
-        std::string inputPrompt = chatTemplate->generate(nInputItems, inputItems, true);
-        printf("🔹%s🔸", inputPrompt.c_str());
+        GeneratedChat inputPrompt = chatTemplate->generate(nInputItems, inputItems, true);
+        printf("🔹%s🔸", inputPrompt.content);
 
-        size_t promptLength = inputPrompt.size();
         int nPromptTokens;
-        std::unique_ptr<int[]> promptTokensPtr(new int[promptLength + 2]);
+        std::unique_ptr<int[]> promptTokensPtr(new int[inputPrompt.length + 2]);
         int *promptTokens = promptTokensPtr.get();
-        tokenizer->encode((char*)inputPrompt.c_str(), promptTokens, &nPromptTokens, true, true);
+        tokenizer->encode((char*)inputPrompt.content, promptTokens, &nPromptTokens, true, true);
 
         pos_t promptEndPos = startPos + nPromptTokens - 1;
         if (promptEndPos > header->seqLen)
@@ -366,13 +365,16 @@ class ApiServer {
             naiveCache.push(NaiveCacheItem(promptEndPos, deltaPrompt[j]));
         }
 
-        if (params.stream) {
+        std::string buffer;
+
+        if (params.stream)
             request.writeStreamStartChunk();
+        if (inputPrompt.publicPrompt != nullptr) {
+            if (params.stream)
+                writeChatCompletionChunk(request, inputPrompt.publicPrompt, false);
+            buffer += inputPrompt.publicPrompt;
         }
 
-        std::string buffer;
-        size_t nStops = params.stop.size();
-
         NnSize pos = startPos;
         int token;
         for (NnSize i = 0; ;) {
@@ -400,8 +402,6 @@ class ApiServer {
         tokenizer->resetDecoder();
 
         for (; pos < maxPredPos;) {
-            int prevToken = token;
-
             inference->setPosition(pos);
             inference->setToken(0, token);
             inference->forward();

src/dllama.cpp

+6-4
@@ -142,12 +142,12 @@ static void chat(AppInferenceContext *context) {
 
         deltaItems.push_back(ChatItem{"user", prompt});
 
-        std::string inputPrompt = chatTemplate.generate(deltaItems.size(), deltaItems.data(), true);
-        std::unique_ptr<int[]> inputTokensPtr(new int[inputPrompt.size() + 2]);
+        GeneratedChat inputPrompt = chatTemplate.generate(deltaItems.size(), deltaItems.data(), true);
+        std::unique_ptr<int[]> inputTokensPtr(new int[inputPrompt.length + 2]);
         int *inputTokens = inputTokensPtr.get();
 
         bool addBos = pos == 0;
-        context->tokenizer->encode((char*)inputPrompt.c_str(), inputTokens, &nInputTokens, addBos, true);
+        context->tokenizer->encode((char*)inputPrompt.content, inputTokens, &nInputTokens, addBos, true);
 
         NnSize userPromptEndPos = (NnSize)std::min<unsigned int>(seqLen, pos + nInputTokens - 1);
         for (NnSize i = 0; ;) {
@@ -174,7 +174,9 @@ static void chat(AppInferenceContext *context) {
         context->tokenizer->resetDecoder();
 
         printf("\n🤖 Assistant\n");
-        std::string answer;
+        if (inputPrompt.publicPrompt != nullptr)
+            printf("%s", inputPrompt.publicPrompt);
+
         while (pos < seqLen) {
             context->inference->setPosition(pos);
             context->inference->setToken(0, token);

src/tokenizer.cpp

+51-37
@@ -471,70 +471,84 @@ TokenizerChatStops::~TokenizerChatStops() {
     delete[] stops;
 }
 
-ChatTemplate::ChatTemplate(const ChatTemplateType type, const char* chatTemplate, const char* eos) {
+static const char *chatTemplateTypeToString(const ChatTemplateType type) {
+    if (type == TEMPLATE_LLAMA2) return "llama2";
+    if (type == TEMPLATE_LLAMA3) return "llama3";
+    if (type == TEMPLATE_DEEP_SEEK3) return "deepSeek3";
+    return "unknown";
+}
+
+ChatTemplate::ChatTemplate(const ChatTemplateType type, const char* chatTemplate, const char* eos)
+    : buffer()
+{
     if (type == TEMPLATE_UNKNOWN) {
         if (chatTemplate == NULL)
             throw std::runtime_error("The tokenizer does not include chat template");
         if (strstr(chatTemplate, "[INST]") != NULL) {
             this->type = TEMPLATE_LLAMA2;
         } else if (strstr(chatTemplate, "<|start_header_id|>") != NULL) {
             this->type = TEMPLATE_LLAMA3;
-        } else if (strstr(chatTemplate, "<|user|>") != NULL) {
-            this->type = TEMPLATE_ZEPHYR;
-        } else if (strstr(chatTemplate, "<|im_start|>") != NULL) {
-            this->type = TEMPLATE_CHATML;
+        } else if (strstr(chatTemplate, "<|Assistant|>") != NULL) {
+            this->type = TEMPLATE_DEEP_SEEK3;
         } else {
-            throw new std::runtime_error("Not supported chat template");
+            throw std::runtime_error("Not supported chat template");
         }
     } else {
        this->type = type;
    }
    this->eos = eos;
 
-    printf("⭐ Chat template: ");
-    if (this->type == TEMPLATE_LLAMA2) {
-        printf("llama2\n");
-    } else if (this->type == TEMPLATE_LLAMA3) {
-        printf("llama3\n");
-    } else if (this->type == TEMPLATE_ZEPHYR) {
-        printf("zephyr\n");
-    } else if (this->type == TEMPLATE_CHATML) {
-        printf("chatml\n");
-    }
+    printf("⭐ Chat template: %s\n", chatTemplateTypeToString(this->type));
 }
 
-std::string ChatTemplate::generate(unsigned int nMessages, ChatItem* items, bool appendGenerationPrompt) {
-    std::ostringstream buffer;
+GeneratedChat ChatTemplate::generate(unsigned int nItems, ChatItem* items, bool appendGenerationPrompt) {
+    buffer.clear();
+
+    size_t publicPromptSize = 0;
+
     if (type == TEMPLATE_LLAMA2) {
         unsigned int i = 0;
-        if (nMessages >= 2 && items[0].role == "system" && items[1].role == "user") {
-            buffer << "[INST] <<SYS>>\n" << items[0].message << "\n<</SYS>>\n\n" << items[1].message << " [/INST]" << eos;
+        if (nItems >= 2 && items[0].role == "system" && items[1].role == "user") {
+            buffer += "[INST] <<SYS>>\n" + items[0].message + "\n<</SYS>>\n\n" + items[1].message + " [/INST]" + eos;
             i += 2;
         }
-        for (; i < nMessages; i++) {
+        for (; i < nItems; i++) {
             if (items[i].role == "assistant") {
-                buffer << items[i].message << eos;
+                buffer += items[i].message + eos;
             } else if (items[i].role == "user") {
-                buffer << "[INST] " << items[i].message << " [/INST]" << eos;
+                buffer += "[INST] " + items[i].message + " [/INST]" + eos;
             }
         }
     } else if (type == TEMPLATE_LLAMA3) {
-        for (unsigned int i = 0; i < nMessages; i++)
-            buffer << "<|start_header_id|>" << items[i].role << "<|end_header_id|>\n\n" << items[i].message << eos;
-        if (appendGenerationPrompt)
-            buffer << "<|start_header_id|>assistant<|end_header_id|>\n\n";
-    } else if (type == TEMPLATE_CHATML) {
-        for (unsigned int i = 0; i < nMessages; i++)
-            buffer << "<|im_start|>" << items[i].role << "\n" << items[i].message << "<|im_end|>\n";
-        if (appendGenerationPrompt)
-            buffer << "<|im_start|>assistant\n";
-    } else if (type == TEMPLATE_ZEPHYR) {
-        for (unsigned int i = 0; i < nMessages; i++)
-            buffer << "<|" << items[i].role << "|>\n" << items[i].message << eos << "\n";
+        for (unsigned int i = 0; i < nItems; i++)
+            buffer += "<|start_header_id|>" + items[i].role + "<|end_header_id|>\n\n" + items[i].message + eos;
         if (appendGenerationPrompt)
-            buffer << "<|assistant|>\n";
+            buffer += "<|start_header_id|>assistant<|end_header_id|>\n\n";
+    } else if (type == TEMPLATE_DEEP_SEEK3) {
+        unsigned int i = 0;
+        if (nItems > 0 && items[0].role == "system") {
+            buffer += items[0].message;
+            i++;
+        }
+        for (; i < nItems; i++) {
+            if (items[i].role == "user") {
+                buffer += "<|User|>" + items[i].message;
+            } else if (items[i].role == "assistant") {
+                buffer += "<|Assistant|>" + items[i].message;
+            }
+        }
+        if (appendGenerationPrompt) {
+            buffer += "<|Assistant|><think>\n";
+            publicPromptSize = 8;
+        }
     }
-    return buffer.str();
+
+    const char *content = buffer.c_str();
+    size_t length = buffer.size();
+    const char *publicPrompt = publicPromptSize > 0
+        ? &content[length - publicPromptSize]
+        : nullptr;
+    return {content, length, publicPrompt};
 }
 
 EosDetector::EosDetector(size_t nTokens, const int *tokens, const char** pieces, int paddingLeft, int paddingRight) {
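
`ChatTemplate::generate` now returns a `GeneratedChat` (content, length, publicPrompt) instead of a `std::string`. For the new `deepSeek3` template the generation prompt ends with `<|Assistant|><think>\n`, and `publicPromptSize = 8` marks the trailing `<think>\n` as the public part that the CLI and API callers above print or stream back to the user. A rough Python mock-up of the rendering (illustrative only; the authoritative logic is the C++ above):

```python
# Rough mock-up of the deepSeek3 template rendering from ChatTemplate::generate.
# The helper and the dict-based message format are illustrative, not the real API.
def render_deepseek3(items, append_generation_prompt=True):
    buffer = ''
    i = 0
    if items and items[0]['role'] == 'system':
        buffer += items[0]['message']   # system message is emitted verbatim
        i += 1
    for item in items[i:]:
        if item['role'] == 'user':
            buffer += '<|User|>' + item['message']
        elif item['role'] == 'assistant':
            buffer += '<|Assistant|>' + item['message']
    public_prompt = None
    if append_generation_prompt:
        buffer += '<|Assistant|><think>\n'
        public_prompt = buffer[-8:]     # the trailing "<think>\n" (8 bytes)
    return buffer, public_prompt

content, public = render_deepseek3([
    {'role': 'system', 'message': 'You are a helpful assistant.'},
    {'role': 'user', 'message': 'Hello!'},
])
print(repr(content))  # 'You are a helpful assistant.<|User|>Hello!<|Assistant|><think>\n'
print(repr(public))   # '<think>\n'
```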

0 commit comments
