
Commit dc71d88

mhenrichsen (Mads Henrichsen) and Mads Henrichsen authored
feat/llama-2 examples (axolotl-ai-cloud#319)

* qlora llama-2
* qlora llama-2
* linting
* readme
* lora added
* linting
* change group_by_length
* 13b fitting on 24gb
* grouped lengths true
* add pad token
* change out dir

Co-authored-by: Mads Henrichsen <mads@Brbar-tilhrende-Mads.local>

1 parent 77085ea · commit dc71d88

File tree: 3 files changed (+153, -0 lines)


Diff for: examples/llama-2/README.md

+20
@@ -0,0 +1,20 @@
# Overview

This is an example of a llama-2 configuration for 7b and 13b. The yaml file contains the configuration for the 7b variant, but the same settings work just as well for 13b.

The 7b variant fits on any GPU with 24 GB of VRAM and uses about 17 GB of VRAM during training with QLoRA and about 20 GB with LoRA. On an RTX 4090 it trains 3 epochs of the default dataset in about 15 minutes.

The 13b variant will fit if you change these settings to these values (see the override sketch below):

gradient_accumulation_steps: 2
micro_batch_size: 1

```shell
accelerate launch scripts/finetune.py examples/llama-2/qlora.yml
```

or

```shell
accelerate launch scripts/finetune.py examples/llama-2/lora.yml
```
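Regarding the 13b note above: those values halve the per-forward micro-batch (2 → 1) and cut the per-device batch from 8 sequences per optimizer step (2 × 4) down to 2 (1 × 2). A minimal sketch of the overrides is shown below; the 13b checkpoint names are an assumption and are not part of this commit.

```yaml
# Hypothetical 13b overrides, applied on top of examples/llama-2/qlora.yml or lora.yml.
# The base_model / base_model_config values below are assumed 13b checkpoint names.
base_model: meta-llama/Llama-2-13b-hf
base_model_config: meta-llama/Llama-2-13b-hf
gradient_accumulation_steps: 2
micro_batch_size: 1
```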

Diff for: examples/llama-2/lora.yml

+66
@@ -0,0 +1,66 @@
base_model: meta-llama/Llama-2-7b-hf
base_model_config: meta-llama/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer

load_in_8bit: true
load_in_4bit: false
strict: false

datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./lora-out

sequence_len: 4096
max_packed_sequence_len: 4096

adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: true
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention: true
flash_attention:

warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"
  pad_token: "<pad>"
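For quick orientation before qlora.yml below: apart from key ordering and qlora.yml declaring an extra empty lora_target_modules: key, the two example configs in this commit differ only in the keys shown in the sketch below (values copied from the two files).

```yaml
# lora.yml values, with the corresponding qlora.yml value in the trailing comment.
load_in_8bit: true          # qlora.yml: false
load_in_4bit: false         # qlora.yml: true
output_dir: ./lora-out      # qlora.yml: ./qlora-out
adapter: lora               # qlora.yml: qlora
optimizer: adamw_bnb_8bit   # qlora.yml: paged_adamw_32bit
```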

Diff for: examples/llama-2/qlora.yml

+67
@@ -0,0 +1,67 @@
base_model: meta-llama/Llama-2-7b-hf
base_model_config: meta-llama/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./qlora-out

adapter: qlora
lora_model_dir:

sequence_len: 4096
max_packed_sequence_len: 4096
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: true
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention: true
flash_attention:

warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"
  pad_token: "<pad>"
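Both examples train on the small mhenrichsen/alpaca_2k_test set. To fine-tune on your own alpaca-format data, the only section you would need to touch is datasets; the path in the sketch below is a placeholder, not a real dataset.

```yaml
# Hypothetical datasets override; "your-hf-username/your-alpaca-dataset" is a placeholder.
datasets:
  - path: your-hf-username/your-alpaca-dataset
    type: alpaca
```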

0 commit comments