
Commit fa88952

keeeeenw and rasbt authored
Add MicroLlama training support (#1457)
Co-authored-by: rasbt <mail@sebastianraschka.com>
1 parent e567dbe · commit fa88952

10 files changed: +249 −72 lines

README.md

Lines changed: 24 additions & 23 deletions
@@ -73,30 +73,31 @@ LitGPT has 🤯 **custom, from-scratch implementations** of [20+ LLMs](tutorials

 | Model | Model size | Author | Reference |
 |----|----|----|----|
-| CodeGemma | 7B | Google | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) |
-| Code Llama | 7B, 13B, 34B, 70B | Meta AI | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) |
-| Danube2 | 1.8B | H2O.ai | [H2O.ai](https://h2o.ai/platform/danube-1-8b/) |
-| Dolly | 3B, 7B, 12B | Databricks | [Conover et al. 2023](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm) |
-| Falcon | 7B, 40B, 180B | TII UAE | [TII 2023](https://falconllm.tii.ae) |
-| FreeWilly2 (Stable Beluga 2) | 70B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) |
-| Function Calling Llama 2 | 7B | Trelis | [Trelis et al. 2023](https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2) |
-| Gemma | 2B, 7B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf) |
-| Llama 2 | 7B, 13B, 70B | Meta AI | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288) |
-| Llama 3 | 8B, 70B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) |
-| LongChat | 7B, 13B | LMSYS | [LongChat Team 2023](https://lmsys.org/blog/2023-06-29-longchat/) |
-| Mixtral MoE | 8x7B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/mixtral-of-experts/) |
-| Mistral | 7B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/announcing-mistral-7b/) |
-| Nous-Hermes | 7B, 13B, 70B | NousResearch | [Org page](https://huggingface.co/NousResearch) |
-| OpenLLaMA | 3B, 7B, 13B | OpenLM Research | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) |
-| Phi | 1.3B, 2.7B | Microsoft Research | [Li et al. 2023](https://arxiv.org/abs/2309.05463) |
+| CodeGemma | 7B | Google | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) |
+| Code Llama | 7B, 13B, 34B, 70B | Meta AI | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) |
+| Danube2 | 1.8B | H2O.ai | [H2O.ai](https://h2o.ai/platform/danube-1-8b/) |
+| Dolly | 3B, 7B, 12B | Databricks | [Conover et al. 2023](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm) |
+| Falcon | 7B, 40B, 180B | TII UAE | [TII 2023](https://falconllm.tii.ae) |
+| FreeWilly2 (Stable Beluga 2) | 70B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) |
+| Function Calling Llama 2 | 7B | Trelis | [Trelis et al. 2023](https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2) |
+| Gemma | 2B, 7B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf) |
+| Llama 2 | 7B, 13B, 70B | Meta AI | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288) |
+| Llama 3 | 8B, 70B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) |
+| LongChat | 7B, 13B | LMSYS | [LongChat Team 2023](https://lmsys.org/blog/2023-06-29-longchat/) |
+| MicroLlama | 300M | Ken Wang | [MicroLlama repo](https://github.com/keeeeenw/MicroLlama) |
+| Mixtral MoE | 8x7B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/mixtral-of-experts/) |
+| Mistral | 7B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/announcing-mistral-7b/) |
+| Nous-Hermes | 7B, 13B, 70B | NousResearch | [Org page](https://huggingface.co/NousResearch) |
+| OpenLLaMA | 3B, 7B, 13B | OpenLM Research | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) |
+| Phi | 1.3B, 2.7B | Microsoft Research | [Li et al. 2023](https://arxiv.org/abs/2309.05463) |
 | Platypus | 7B, 13B, 70B | Lee et al. | [Lee, Hunter, and Ruiz 2023](https://arxiv.org/abs/2308.07317) |
-| Pythia | {14,31,70,160,410}M, {1,1.4,2.8,6.9,12}B | EleutherAI | [Biderman et al. 2023](https://arxiv.org/abs/2304.01373) |
-| RedPajama-INCITE | 3B, 7B | Together | [Together 2023](https://together.ai/blog/redpajama-models-v1) |
-| StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
-| StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) |
-| StableLM Zephyr | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
-| TinyLlama | 1.1B | Zhang et al. | [Zhang et al. 2023](https://github.com/jzhang38/TinyLlama) |
-| Vicuna | 7B, 13B, 33B | LMSYS | [Li et al. 2023](https://lmsys.org/blog/2023-03-30-vicuna/)
+| Pythia | {14,31,70,160,410}M, {1,1.4,2.8,6.9,12}B | EleutherAI | [Biderman et al. 2023](https://arxiv.org/abs/2304.01373) |
+| RedPajama-INCITE | 3B, 7B | Together | [Together 2023](https://together.ai/blog/redpajama-models-v1) |
+| StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
+| StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) |
+| StableLM Zephyr | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
+| TinyLlama | 1.1B | Zhang et al. | [Zhang et al. 2023](https://github.com/jzhang38/TinyLlama) |
+| Vicuna | 7B, 13B, 33B | LMSYS | [Li et al. 2023](https://lmsys.org/blog/2023-03-30-vicuna/) |

 **Tip**: You can list all available models by running the `litgpt download list` command.

config_hub/pretrain/microllama.yaml

Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
+
+# The name of the model to pretrain. Choose from names in ``litgpt.config``. Mutually exclusive with
+# ``model_config``. (type: Optional[str], default: null)
+model_name: micro-llama-300M
+
+# A ``litgpt.Config`` object to define the model architecture. Mutually exclusive with
+# ``model_name``. (type: Optional[Config], default: null)
+model_config:
+
+# Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in
+# /teamspace/jobs/<job-name>/share. (type: <class 'Path'>, default: out/pretrain)
+out_dir: out/pretrain/micro-llama
+
+# The precision to use for pretraining. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
+precision: bf16-mixed
+
+# Optional path to a checkpoint directory to initialize the model from.
+# Useful for continued pretraining. Mutually exclusive with ``resume``. (type: Optional[Path], default: null)
+initial_checkpoint_dir:
+
+# Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume
+# from the latest checkpoint in ``out_dir``. (type: Union[bool, Path], default: False)
+resume: false
+
+# Data-related arguments. If not provided, the default is ``litgpt.data.TinyLlama``.
+data: MicroLlama
+
+# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
+train:
+
+  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
+  save_interval: 1000
+
+  # Number of iterations between logging calls (type: int, default: 1)
+  log_interval: 1
+
+  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 48)
+  # Scale this number according to the number of GPUs and the memory size per GPU.
+  # For example, we used 48 for 4 x 24 GB 4090s.
+  global_batch_size: 48
+
+  # Number of samples per data-parallel rank (type: int, default: 12)
+  # Scale this number according to the memory size per GPU.
+  # For example, we used 12 for a 24 GB 4090.
+  micro_batch_size: 12
+
+  # Number of iterations with learning rate warmup active (type: int, default: 2000)
+  lr_warmup_steps: 2000
+
+  # Number of epochs to train on (type: Optional[int], default: null)
+  epochs:
+
+  # Total number of tokens to train on (type: Optional[int], default: 3000000000000)
+  max_tokens: 3000000000000
+
+  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)
+  max_steps:
+
+  # Limits the length of samples. Off by default (type: Optional[int], default: null)
+  max_seq_length: 2048
+
+  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: False)
+  tie_embeddings:
+
+  # (type: Optional[float], default: 1.0)
+  max_norm: 1.0
+
+  # (type: float, default: 4e-05)
+  min_lr: 4.0e-05
+
+# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
+eval:
+
+  # Number of optimizer steps between evaluation calls (type: int, default: 1000)
+  interval: 1000
+
+  # Number of tokens to generate (type: Optional[int], default: null)
+  max_new_tokens:
+
+  # Number of iterations (type: int, default: 100)
+  max_iters: 100
+
+  # Whether to evaluate on the validation set at the beginning of the training
+  initial_validation: false
+
+# Optimizer-related arguments
+optimizer:
+
+  class_path: torch.optim.AdamW
+
+  init_args:
+
+    # (type: float, default: 0.001)
+    lr: 4e-4
+
+    # (type: float, default: 0.01)
+    weight_decay: 0.1
+
+    # (type: tuple, default: (0.9,0.999))
+    betas:
+      - 0.9
+      - 0.95
+
+# How many devices/GPUs to use. Uses all GPUs by default. (type: Union[int, str], default: auto)
+devices: auto
+
+# Optional path to the tokenizer dir that was used for preprocessing the dataset. Only some data
+# modules require this. (type: Optional[Path], default: null)
+tokenizer_dir: checkpoints/meta-llama/Llama-2-7b-hf
+
+# The name of the logger to send metrics to. (type: Literal['wandb', 'tensorboard', 'csv'], default: tensorboard)
+logger_name: tensorboard
+
+# The random seed to use for reproducibility. (type: int, default: 42)
+seed: 42
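The batch-size comments above can be made concrete with a quick calculation. This is a sketch, not LitGPT source; the device count is the 4 x 24 GB 4090 setup mentioned in the config comments:

```python
# How the batch settings in the config above relate (illustrative sketch).
devices = 4               # e.g. 4 x 24 GB 4090s, as noted in the comments
global_batch_size = 48    # samples per optimizer step, summed across all ranks
micro_batch_size = 12     # samples per forward/backward pass on a single rank

samples_per_iteration = micro_batch_size * devices

# Gradient-accumulation iterations between optimizer steps:
assert global_batch_size % samples_per_iteration == 0
grad_accum_iters = global_batch_size // samples_per_iteration
print(grad_accum_iters)   # 1 -> with these numbers, every iteration takes an optimizer step
```

Fewer GPUs or less memory per GPU means a smaller `micro_batch_size` and correspondingly more gradient accumulation to keep the same effective `global_batch_size`.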

litgpt/config.py

Lines changed: 27 additions & 1 deletion
@@ -1548,7 +1548,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        norm_class_name="RMSNorm",  # original TinyLlama uses FusedRMSNorm
+        norm_class_name="RMSNorm",  # original TinyLlama use FusedRMSNorm
         norm_eps=1e-5,
         mlp_class_name="LLaMAMLP",
         intermediate_size=5632,
@@ -1563,6 +1563,32 @@ def norm_class(self) -> Type:
         configs.append(copy)


+############
+# MicroLlama
+############
+micro_llama = [
+    dict(
+        name="micro-llama-300M",
+        hf_config=dict(org="keeeeenw", name="MicroLlama"),
+        block_size=2048,
+        vocab_size=32000,
+        padding_multiple=64,
+        n_layer=12,
+        n_head=16,
+        n_embd=1024,
+        rotary_percentage=1.0,
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="RMSNorm",  # original TinyLlama and MicroLlama use FusedRMSNorm
+        norm_eps=1e-5,
+        mlp_class_name="LLaMAMLP",
+        intermediate_size=5632,
+        n_query_groups=4,
+    )
+]
+configs.extend(micro_llama)
+
+
 ##########################
 # Trelis Function Calling
 ##########################
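Since `name_to_config` is built from the `configs` list, the new entry can be sanity-checked from Python. A minimal sketch, not part of the commit, assuming a LitGPT checkout that includes this change:

```python
# Minimal check that the new MicroLlama config entry resolves.
from litgpt.config import Config
from litgpt.model import GPT

config = Config.from_name("micro-llama-300M")
print(config.n_layer, config.n_head, config.n_embd)  # 12 16 1024

# Building the model from the config gives a ~300M-parameter Llama-style network.
model = GPT(config)
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.0f}M parameters")
```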

litgpt/data/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,7 @@
 from litgpt.data.tinyllama import TinyLlama
 from litgpt.data.tinystories import TinyStories
 from litgpt.data.openwebtext import OpenWebText
+from litgpt.data.microllama import MicroLlama


 __all__ = [
@@ -34,5 +35,6 @@
     "TextFiles",
     "TinyLlama",
     "TinyStories",
+    "MicroLlama",
     "get_sft_collate_fn",
 ]

litgpt/data/microllama.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Union
+
+from litgpt.data import TinyLlama
+
+@dataclass
+class MicroLlama(TinyLlama):
+    """The MicroLlama data module is composed of only SlimPajama data."""
+
+    def __init__(self, data_path: Union[str, Path] = Path("data/"), seed: int = 42, num_workers: int = 8):
+        super().__init__(data_path=data_path, seed=seed, num_workers=num_workers, use_starcoder=False)
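For reference, a short usage sketch of the new data module (not part of the commit; the `data/` path is illustrative and must already contain the preprocessed SlimPajama splits):

```python
# Sketch: constructing the SlimPajama-only data module added above.
from litgpt.data import MicroLlama

data = MicroLlama(data_path="data/", seed=42, num_workers=8)

# prepare_data() only validates that data/slimpajama/train and data/slimpajama/val
# exist (or are s3:// URIs); it raises FileNotFoundError otherwise.
data.prepare_data()
```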

litgpt/data/prepare_slimpajama.py

Lines changed: 2 additions & 0 deletions
@@ -11,6 +11,8 @@


 class SlimPajamaDataRecipe(DataChunkRecipe):
+    is_generator = True
+
     def __init__(self, tokenizer: Tokenizer, chunk_size: int):
         super().__init__(chunk_size)
         self.tokenizer = tokenizer

litgpt/data/prepare_starcoder.py

Lines changed: 2 additions & 0 deletions
@@ -18,6 +18,8 @@


 class StarcoderDataRecipe(DataChunkRecipe):
+    is_generator = True
+
     def __init__(self, tokenizer: Tokenizer, chunk_size: int):
         super().__init__(chunk_size)
         self.tokenizer = tokenizer

litgpt/data/tinyllama.py

Lines changed: 33 additions & 22 deletions
@@ -24,6 +24,8 @@ class TinyLlama(DataModule):
     """The random seed for shuffling the dataset."""
     num_workers: int = 8
     """How many DataLoader processes to use for loading."""
+    use_starcoder: bool = True
+    """Toggle for using Starcoder data."""

     batch_size: int = field(init=False, repr=False, default=1)
     seq_length: int = field(init=False, repr=False, default=2048)
@@ -32,7 +34,11 @@ def __post_init__(self):
         # Could be a remote path (s3://) or a local path
         self.slimpajama_train = str(self.data_path).rstrip("/") + "/slimpajama/train"
         self.slimpajama_val = str(self.data_path).rstrip("/") + "/slimpajama/val"
-        self.starcoder_train = str(self.data_path).rstrip("/") + "/starcoder"
+        self.required_paths = [self.slimpajama_train, self.slimpajama_val]
+
+        if self.use_starcoder:
+            self.starcoder_train = str(self.data_path).rstrip("/") + "/starcoder"
+            self.required_paths += [self.starcoder_train]

     def connect(
         self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None
@@ -41,7 +47,7 @@ def connect(
         self.seq_length = max_seq_length + 1  # Increase by one because we need the next token as well

     def prepare_data(self) -> None:
-        for path in (self.slimpajama_train, self.slimpajama_val, self.starcoder_train):
+        for path in self.required_paths:
             if not path.startswith("s3://") and not Path(path).is_dir():
                 raise FileNotFoundError(
                     "The data path for TinyLlama is expected to be the directory containing these subdirectories:"
@@ -52,28 +58,33 @@ def prepare_data(self) -> None:
     def train_dataloader(self) -> DataLoader:
         from litdata.streaming import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset, TokensLoader

-        train_datasets = [
-            StreamingDataset(
-                input_dir=self.slimpajama_train,
-                item_loader=TokensLoader(block_size=self.seq_length),
-                shuffle=True,
-                drop_last=True,
-            ),
-            StreamingDataset(
-                input_dir=self.starcoder_train,
-                item_loader=TokensLoader(block_size=self.seq_length),
-                shuffle=True,
-                drop_last=True,
-            ),
-        ]
-
-        # Mix SlimPajama data and Starcoder data with these proportions:
-        weights = (0.693584, 0.306416)
-        combined_dataset = CombinedStreamingDataset(
-            datasets=train_datasets, seed=self.seed, weights=weights, iterate_over_all=False
+        slim_train_data = StreamingDataset(
+            input_dir=self.slimpajama_train,
+            item_loader=TokensLoader(block_size=self.seq_length),
+            shuffle=True,
+            drop_last=True,
         )
+        train_data = slim_train_data
+
+        if self.use_starcoder:
+            train_datasets = [
+                slim_train_data,
+                StreamingDataset(
+                    input_dir=self.starcoder_train,
+                    item_loader=TokensLoader(block_size=self.seq_length),
+                    shuffle=True,
+                    drop_last=True,
+                ),
+            ]
+
+            # Mix SlimPajama data and Starcoder data with these proportions:
+            weights = (0.693584, 0.306416)
+            train_data = CombinedStreamingDataset(
+                datasets=train_datasets, seed=self.seed, weights=weights, iterate_over_all=False
+            )
+
         train_dataloader = StreamingDataLoader(
-            combined_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True
+            train_data, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True
         )
         return train_dataloader
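In effect, the `use_starcoder` flag lets a SlimPajama-only run skip the `starcoder/` directory entirely, and the `MicroLlama` module above is just this toggle with defaults baked in. A small illustrative sketch (paths are placeholders, not from the commit):

```python
# Two equivalent ways to get a SlimPajama-only streaming setup (illustrative sketch).
from litgpt.data import MicroLlama, TinyLlama

slim_only = TinyLlama(data_path="data/", use_starcoder=False)  # no data/starcoder required
micro = MicroLlama(data_path="data/")                          # subclass that sets the same flag
```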

litgpt/pretrain.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@
 from litgpt import Tokenizer
 from litgpt.args import EvalArgs, TrainArgs
 from litgpt.config import name_to_config
-from litgpt.data import DataModule, TinyLlama
+from litgpt.data import DataModule, TinyLlama, MicroLlama
 from litgpt.model import GPT, Block, CausalSelfAttention, Config, LLaMAMLP
 from litgpt.utils import (
     CycleIterator,

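Putting the pieces together, a MicroLlama pretraining run should be launchable with roughly `litgpt pretrain --config config_hub/pretrain/microllama.yaml` (exact flags depend on the LitGPT version), or from Python via the `setup` entry point in `litgpt/pretrain.py`. The keyword names in the sketch below mirror the top-level keys of the YAML config and are otherwise an assumption:

```python
# Hedged sketch of launching the run from Python; mirrors config_hub/pretrain/microllama.yaml.
from pathlib import Path

from litgpt.data import MicroLlama
from litgpt.pretrain import setup

setup(
    model_name="micro-llama-300M",
    data=MicroLlama(),                     # SlimPajama-only data module added in this commit
    precision="bf16-mixed",
    out_dir=Path("out/pretrain/micro-llama"),
    tokenizer_dir=Path("checkpoints/meta-llama/Llama-2-7b-hf"),  # tokenizer downloaded separately
    logger_name="tensorboard",
    seed=42,
)
```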