Merge pull request #19 from gyfffffff/main
Update the PPT; refine documentation and code details
hscspring authored Dec 21, 2024
2 parents 81699bd + 010145b commit 7c6ea38
Showing 18 changed files with 5,075 additions and 563 deletions.
8 changes: 7 additions & 1 deletion .gitignore
@@ -160,4 +160,10 @@ cython_debug/
.idea/
.DS_Store
docs/chapter2/models/GPT-2/*
MetaICL/
MetaICL/
data/
output/
outputs/

docs/chapter2/code/BabyLlama/models/*
*.zip
Binary file added PPT/distillation.pptx
Binary file not shown.
6 changes: 2 additions & 4 deletions docs/chapter2/README.md
@@ -1,6 +1,6 @@
# 2.1 Distillation

This chapter will introduce mainstream distillation methods and code for Transformer-based models, and will also implement an on-device deployment demo
This chapter introduces the mainstream distillation methods and code for large language models

## Roadmap
### 1. Distillation Basics
@@ -18,7 +18,7 @@
- 2.1 Overview
- When to use white-box distillation
- 2.2 MiniLLM
- 2.3 GKD
- 2.3 BabyLlama

### 3. Distillation Based on Emergent Abilities (Black-Box Distillation)
- 3.1 Overview
@@ -33,5 +33,3 @@


### 4. Summary
- 4.1 Extensions to related frontier work
- 4.2 Summary
4 changes: 4 additions & 0 deletions docs/chapter2/chapter2_1.md
@@ -54,6 +54,10 @@ https://github.com/datawhalechina/awesome-compression/blob/main/docs/ch06/ch06.m

![](images/Figure%206.png)

# Prerequisites
Familiarity with the following will help with what comes next:
1. Logits and soft targets (see the short sketch below)
2. Supervised fine-tuning (SFT)
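
As a quick refresher on the first point, here is a minimal sketch (the logits values are made up) of how a model's logits become soft targets via a temperature-scaled softmax:

```python
import torch
import torch.nn.functional as F

# Made-up teacher logits for a tiny 5-token vocabulary
logits = torch.tensor([2.0, 1.0, 0.5, -1.0, -3.0])

hard_target = logits.argmax()                       # a single label (token id 0 here)
soft_targets_t1 = F.softmax(logits, dim=-1)         # temperature T = 1
soft_targets_t2 = F.softmax(logits / 2.0, dim=-1)   # T = 2 gives a flatter, more informative distribution

print(hard_target, soft_targets_t1, soft_targets_t2)
```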


References:
55 changes: 41 additions & 14 deletions docs/chapter2/chapter2_2.md
@@ -1,7 +1,7 @@
# White-Box Distillation

## 1. What Is White-Box Distillation
For open-source large models, we have access to all inference-time data, including the probability distribution over output tokens. A setting where this token-level output distribution is available can be regarded as a "white-box" setting; otherwise it is a black-box setting. Distillation that makes use of the data a white box exposes is white-box distillation.
White-box distillation refers to distillation techniques that use the teacher model's parameters or logits during distillation [2].

Next we introduce classic white-box distillation methods and their code implementations.
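
Concretely, "access to the logits" usually looks like the following minimal sketch with Hugging Face transformers (the model name is just a placeholder for any open-source causal LM):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "gpt2"  # placeholder; any open-source causal LM works
tokenizer = AutoTokenizer.from_pretrained(name)
teacher = AutoModelForCausalLM.from_pretrained(name).eval()

inputs = tokenizer("Knowledge distillation transfers knowledge", return_tensors="pt")
with torch.no_grad():
    logits = teacher(**inputs).logits    # [batch, seq_len, vocab_size]
probs = torch.softmax(logits, dim=-1)    # per-token output distribution, usable as soft targets
```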

@@ -47,18 +47,45 @@ The MiniLLM paper proposes another novel perspective: reverse KL can actually
Since this part involves a fair amount of mathematical derivation and reinforcement learning, interested readers can study the paper on their own.
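
To make the distinction concrete: standard distillation minimizes the forward KL, KL(teacher || student), whereas MiniLLM targets the reverse KL, KL(student || teacher). MiniLLM actually optimizes it with policy-gradient methods rather than the closed form, but the sketch below (assuming `student_logits` and `teacher_logits` are `[batch, seq, vocab]` tensors) at least shows which direction each KL points:

```python
import torch.nn.functional as F

def forward_kl(student_logits, teacher_logits):
    # KL(teacher || student): the usual distillation loss, "mode-covering"
    return F.kl_div(
        F.log_softmax(student_logits, dim=-1),
        F.softmax(teacher_logits, dim=-1),
        reduction="batchmean",
    )

def reverse_kl(student_logits, teacher_logits):
    # KL(student || teacher): the MiniLLM objective, "mode-seeking"
    return F.kl_div(
        F.log_softmax(teacher_logits, dim=-1),
        F.softmax(student_logits, dim=-1),
        reduction="batchmean",
    )
```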

# 3. BabyLlama (Hands-On)
[BabyLlama](http://arxiv.org/abs/2308.02019) treats distillation as an effective way to make better use of the training samples. As the hands-on code example, we will see that its distillation loss function uses the teacher models' soft labels.

The BabyLlama code covers:
1. Data cleaning and tokenizer training
2. Teacher model training
3. Distilling the student model

In practice, though, white-box distillation can also use off-the-shelf open-source models and tokenizers.


[BabyLlama](http://arxiv.org/abs/2308.02019) applies small-model distillation directly to large models. Its loss function is the weighted sum of the following two terms (written out as a formula below):
- a hard loss: the cross-entropy with the ground-truth labels
- a soft loss: the KL divergence with the teacher models' output distribution
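
Written out, with softening temperature $T$ and loss weight $\alpha$ (matching `self.args.temperature` and `self.args.alpha` in the code below):

$$
\mathcal{L} = \alpha \, \mathcal{L}_{\mathrm{CE}} + (1-\alpha) \, T^2 \, D_{\mathrm{KL}}\!\left(\mathrm{softmax}\!\left(\frac{z_{\mathrm{teacher}}}{T}\right) \,\middle\|\, \mathrm{softmax}\!\left(\frac{z_{\mathrm{student}}}{T}\right)\right)
$$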

The loss function can be seen in code/BabyLlama/3.distill.ipynb:
```python
def compute_loss(self, model, inputs, return_outputs=False):
    # Hard loss: the cross-entropy against the ground truth
    outputs_student = model(**inputs)
    student_loss = outputs_student.loss

    # Compute the (averaged) teacher output
    with torch.no_grad():
        all_teacher_logits = []
        for teacher in self.teachers:
            outputs_teacher = teacher(**inputs)
            all_teacher_logits.append(outputs_teacher.logits)
        avg_teacher_logits = torch.stack(all_teacher_logits).mean(dim=0)

    # Sanity-check shapes
    assert outputs_student.logits.size() == avg_teacher_logits.size()

    # Soft loss: KL divergence against the teachers' output distribution
    loss_function = nn.KLDivLoss(reduction="batchmean")
    loss_logits = (
        loss_function(
            F.log_softmax(outputs_student.logits / self.args.temperature, dim=-1),
            F.softmax(avg_teacher_logits / self.args.temperature, dim=-1),
        )
        * (self.args.temperature ** 2)
    )
    # Return the weighted student loss
    loss = self.args.alpha * student_loss + (1.0 - self.args.alpha) * loss_logits
    return (loss, outputs_student) if return_outputs else loss
```
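
The compute_loss above assumes it is defined on a Trainer subclass that carries the teacher models and the distillation hyperparameters. A minimal sketch of that wiring, with illustrative names that may differ slightly from the notebook's, looks like this:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Trainer, TrainingArguments

class DistillationTrainingArguments(TrainingArguments):
    def __init__(self, *args, alpha=0.5, temperature=2.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.alpha = alpha                # weight of the hard (cross-entropy) loss
        self.temperature = temperature    # softening temperature for the soft targets

class DistillationTrainer(Trainer):
    def __init__(self, *args, teacher_models=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.teachers = teacher_models    # list of frozen teacher models
        for teacher in self.teachers:
            teacher.to(self.model.device)
            teacher.eval()

    # compute_loss(...) goes here, exactly as shown above
```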

## References
- MiniLLM: Knowledge Distillation of Large Language Models
- https://github.com/microsoft/LMOps/tree/main/minillm
- https://blog.csdn.net/ningmengzhihe/article/details/130679350
1. MiniLLM: Knowledge Distillation of Large Language Models
2. Efficient Large Language Models: A Survey
3. https://github.com/microsoft/LMOps/tree/main/minillm
4. https://blog.csdn.net/ningmengzhihe/article/details/130679350
5. Baby Llama: knowledge distillation from an ensemble of teachers trained on a small dataset with no performance penalty
20 changes: 10 additions & 10 deletions docs/chapter2/chapter2_3.md
@@ -1,4 +1,5 @@
# Distillation Based on Emergent Abilities (Black-Box Distillation)
# Black-Box Distillation (Skill Distillation)
Black-box distillation uses only the teacher model's responses (sometimes also the output probability distribution, i.e. the soft targets, but never the logits).

Black-box distillation means the teacher model's outputs are the only training resource we can obtain, so the overall recipe splits into two steps:
1. Collect question-answer data from the teacher model
@@ -17,9 +18,7 @@

Here is a simple ICL example:

<div align="center">
<img src="images/image-1.png" alt="alt text" width="550"/>
</div>
![alt_text](images/image-1.png)

The model successfully imitates the reasoning and answer format shown in the examples.

@@ -44,9 +43,9 @@ y_3
In other words, prepending just a few examples to the prompt lets the model pick up their format and logic, so it learns without any parameter updates.
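
As a minimal illustration (the task and examples below are made up), a few-shot prompt is simply the solved examples concatenated in front of the new query:

```python
# Build a few-shot (in-context learning) prompt; no parameters are updated.
examples = [
    ("What is 2 + 3?", "5"),
    ("What is 7 + 6?", "13"),
]
query = "What is 9 + 4?"

prompt = ""
for question, answer in examples:
    prompt += f"Q: {question}\nA: {answer}\n\n"
prompt += f"Q: {query}\nA:"

print(prompt)  # this string is fed to the model as-is
```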

Before training, we collect data containing prompts and labels, like the following:
<div align="center">
<img src="images/image-3.png" alt="alt text" width="750"/>
</div>


![alt_text](images/image-3.png)

## 1.2 ICL Fine-Tuning

@@ -132,7 +131,7 @@ $$
$$

## 1.5 Hands-On

See code/ICL.

## 1.6 Directions for Improvement
The performance of in-context learning is closely tied to the quality of the in-context examples, so some work has designed a dedicated example retriever to fetch high-quality demonstrations [6].
@@ -175,7 +174,7 @@


## 2.3 Instruction-Following Distillation Hands-On

See code/InsructFollowing.

## 2.4 Adversarial Distillation
Adversarial distillation proposes that, in addition to knowledge flowing one way from teacher to student, the student can also produce "feedback",
@@ -210,12 +209,13 @@
## Step 3: Hands-On
The fine-tuning loss here is the standard cross-entropy loss.

See code/CoT.
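
For reference, here is a minimal sketch of that cross-entropy objective with a Hugging Face causal LM, assuming each training sample concatenates the prompt with a teacher-generated rationale and answer (the model name and the texts are placeholders):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")   # placeholder student model
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt = "Q: Roger has 5 balls and buys 2 cans of 3 balls each. How many balls now?\nA:"
rationale_and_answer = " He buys 2 * 3 = 6 balls, so 5 + 6 = 11. The answer is 11."

prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids
target_ids = tokenizer(rationale_and_answer, return_tensors="pt").input_ids

input_ids = torch.cat([prompt_ids, target_ids], dim=-1)
labels = input_ids.clone()
labels[:, : prompt_ids.size(-1)] = -100   # mask prompt tokens; loss only on rationale + answer

loss = model(input_ids=input_ids, labels=labels).loss  # token-level cross-entropy
```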



# 4. Extensions and Summary
In fact, beyond the three kinds of emergent-ability distillation above, anything that collects some type of data from the teacher model and then fine-tunes the student model on it falls within the scope of black-box distillation.
Therefore, for tasks in specific domains or with specific requirements, similar methods can also achieve the desired effect. For example, the recent [O1 replication paper](https://arxiv.org/pdf/2410.18982) from Shanghai Jiao Tong University is a good example of distilling a teacher model's reasoning ability.
Therefore, for tasks in specific domains or with specific requirements, similar methods can also achieve the desired effect. For example, with the release of OpenAI's powerful complex-reasoning model O1, distilling reasoning ability can follow the same recipe. The recent [O1 replication paper](https://arxiv.org/pdf/2410.18982) from Shanghai Jiao Tong University is a good example of distilling a teacher model's reasoning ability.

However, research [5] has also pointed out that black-box distillation can lead to imitation without understanding; to raise the quality of learning, the student also needs good "innate talent" (the capability of its base model).

5 changes: 5 additions & 0 deletions docs/chapter2/chapter2_4.md
@@ -1,2 +1,7 @@
# Summary

In this chapter we covered the concept of LLM distillation, how it differs from traditional distillation, and the mainstream LLM distillation paradigms.
In the author's view, whether white-box or black-box, the idea running through LLM distillation is that "the training data comes from the teacher" rather than from human or machine annotation.

Distillation is undoubtedly a low-cost, high-efficiency way to improve a small model's capability; one might call it a "shortcut", and its original purpose is to deploy better models under limited resources.
But as researchers taking the long view, we cannot rely solely on the distillation "shortcut" to improve model capability; we still need to start from first principles and explore technical routes that improve capability fundamentally.
81 changes: 49 additions & 32 deletions docs/chapter2/code/BabyLlama/1.clean_and_tokenize.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -11,11 +11,16 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# 下载数据:https://osf.io/rduj2"
"# 下载两份数据:https://osf.io/5mk3x, https://osf.io/m48ed\n",
"# 将两份数据解压到当前目录下的data文件夹中\n",
"# data目录结构如下:\n",
"# data/\n",
"# |--train_10M/\n",
"# |--dev/"
]
},
{
@@ -36,23 +41,24 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"from mrclean import *"
"from mrclean import *\n",
"import os"
]
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"DATA_ROOT = Path(\"F:/llm-deploy-data/data/Babyllama\")\n",
"DATA_ROOT = Path(\"./data\")\n",
"SEQ_LENGTH = 128 # this is a legacy parameter, it does not affect cleaning\n",
"DATA_SPLITS = ['babylm_10M', 'babylm_dev']\n",
"DATA_SPLITS = ['train_10M', 'dev']\n",
"\n",
"CLEANUP_FUNCTIONS = {\n",
" 'aochildes': cleanup_aochildes,\n",
@@ -70,25 +76,25 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"🧹 Cleaned 'bnc_spoken.train' (size 4883879 -> 4851676) in babylm_10M\n",
"🧹 Cleaned 'childes.train' (size 15482927 -> 15482927) in babylm_10M\n",
"🧹 Cleaned 'gutenberg.train' (size 13910986 -> 13910986) in babylm_10M\n",
"🧹 Cleaned 'open_subtitles.train' (size 10806305 -> 10804026) in babylm_10M\n",
"🧹 Cleaned 'simple_wiki.train' (size 8411630 -> 8387062) in babylm_10M\n",
"🧹 Cleaned 'switchboard.train' (size 719322 -> 719322) in babylm_10M\n",
"🧹 Cleaned 'bnc_spoken.dev' (size 6538139 -> 6503778) in babylm_dev\n",
"🧹 Cleaned 'childes.dev' (size 14638378 -> 14638378) in babylm_dev\n",
"🧹 Cleaned 'gutenberg.dev' (size 15490473 -> 15490473) in babylm_dev\n",
"🧹 Cleaned 'open_subtitles.dev' (size 11016133 -> 11014854) in babylm_dev\n",
"🧹 Cleaned 'simple_wiki.dev' (size 8149513 -> 8128239) in babylm_dev\n",
"🧹 Cleaned 'switchboard.dev' (size 724013 -> 724013) in babylm_dev\n"
"🧹 Cleaned 'childes.train' (size 15482927 -> 15482927) in train_10M\n",
"🧹 Cleaned 'simple_wiki.train' (size 8411630 -> 8387062) in train_10M\n",
"🧹 Cleaned 'bnc_spoken.train' (size 4883879 -> 4851676) in train_10M\n",
"🧹 Cleaned 'gutenberg.train' (size 13910986 -> 13910986) in train_10M\n",
"🧹 Cleaned 'switchboard.train' (size 719322 -> 719322) in train_10M\n",
"🧹 Cleaned 'open_subtitles.train' (size 10806305 -> 10804026) in train_10M\n",
"🧹 Cleaned 'switchboard.dev' (size 724013 -> 724013) in dev\n",
"🧹 Cleaned 'simple_wiki.dev' (size 8149513 -> 8128239) in dev\n",
"🧹 Cleaned 'gutenberg.dev' (size 15490473 -> 15490473) in dev\n",
"🧹 Cleaned 'bnc_spoken.dev' (size 6538139 -> 6503778) in dev\n",
"🧹 Cleaned 'open_subtitles.dev' (size 11016133 -> 11014854) in dev\n",
"🧹 Cleaned 'childes.dev' (size 14638378 -> 14638378) in dev\n"
]
}
],
@@ -117,7 +123,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -129,7 +135,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 7,
"metadata": {},
"outputs": [
{
@@ -142,7 +148,7 @@
],
"source": [
"# We train the tokenizer on the train data only\n",
"data_dir = Path(\"F:/llm-deploy-data/data/Babyllama/babylm_10M_clean/\")\n",
"data_dir = Path(\"./data/train_10M_clean/\")\n",
"\n",
"paths = [str(f) for f in data_dir.glob(\"*\") if f.is_file() and not f.name.endswith(\".DS_Store\") and f.suffix in [\".train\"]]\n",
"\n",
@@ -153,7 +159,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
@@ -167,21 +173,32 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 9,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\n"
]
}
],
"source": [
"trainer = trainers.BpeTrainer(vocab_size=16000, min_frequency=2, special_tokens=[\"<pad>\", \"<s>\", \"</s>\"])\n",
"tokenizer.train(paths, trainer)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"tokenizer_path = DATA_ROOT / \"models/gpt-clean-16000.json\"\n",
"tokenizer_path = \"./models/gpt-clean-16000.json\"\n",
"os.makedirs(\"models\", exist_ok=True)\n",
"tokenizer.save(str(tokenizer_path), pretty=True)"
]
},
@@ -194,15 +211,15 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Encoded String: ['ĠThe', 'Ġquick', 'Ġbrown', 'Ġfox', 'Ġjumps', 'Ġover', 'Ġthe', 'Ġlazy', 'Ġdog', '.']\n",
"Encoded IDs: [302, 1784, 3266, 5712, 15961, 541, 190, 11553, 1469, 16]\n",
"Encoded IDs: [300, 1782, 3264, 5710, 15959, 539, 188, 11551, 1467, 16]\n",
"Decoded String: The quick brown fox jumps over the lazy dog.\n"
]
}
@@ -248,7 +265,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
"version": "3.9.20"
},
"orig_nbformat": 4
},
