# arg_parser.py (forked from johnsmith0031/alpaca_lora_4bit)
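"""Command-line argument parsing for 4-bit LoRA fine-tuning.

parse_commandline() builds the argparse parser and returns the parsed arguments
as a dict; get_config() maps that dict onto a Finetune4bConfig instance.
"""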
import os
import argparse
from Finetune4bConfig import Finetune4bConfig


def parse_commandline():
    parser = argparse.ArgumentParser(
        prog=os.path.basename(__file__),
        description="Produce a LoRA adapter via 4-bit training",
        usage="%(prog)s [config] [training]\n\nAll arguments are optional"
    )

    parser.add_argument("dataset", nargs="?",
        default="./dataset.json",
        help="Path to dataset file. Default: %(default)s"
    )

    parser_config = parser.add_argument_group("config")
    parser_training = parser.add_argument_group("training")

    # Config args group
    parser_config.add_argument("--ds_type", choices=["txt", "alpaca", "gpt4all", "plaintext"], default="alpaca", required=False,
        help="Dataset structure format. Default: %(default)s"
    )
    parser_config.add_argument("--lora_out_dir", default="alpaca_lora", required=False,
        help="Directory in which to place the new LoRA. Default: %(default)s"
    )
    parser_config.add_argument("--lora_apply_dir", default=None, required=False,
        help="Path to a directory containing a LoRA to apply before training. Default: %(default)s"
    )
    parser_training.add_argument("--resume_checkpoint", default=None, required=False,
        help="Resume training from the specified checkpoint. Default: %(default)s"
    )
    parser_config.add_argument("--llama_q4_config_dir", default="./llama-13b-4bit/", required=False,
        help="Path to the directory holding config.json, tokenizer_config.json, etc. Default: %(default)s"
    )
    parser_config.add_argument("--llama_q4_model", default="./llama-13b-4bit.pt", required=False,
        help="Path to the quantized model in Hugging Face format. Default: %(default)s"
    )

    # Training args group
    parser_training.add_argument("--mbatch_size", default=1, type=int, help="Micro-batch size. Default: %(default)s")
    parser_training.add_argument("--batch_size", default=2, type=int, help="Batch size. Default: %(default)s")
    parser_training.add_argument("--epochs", default=3, type=int, help="Epochs. Default: %(default)s")
    parser_training.add_argument("--lr", default=2e-4, type=float, help="Learning rate. Default: %(default)s")
    parser_training.add_argument("--cutoff_len", default=256, type=int, help="Cutoff length for tokenized sequences. Default: %(default)s")
    parser_training.add_argument("--max_position_embeddings", default=2048, type=int, help="Extend the LLaMA position embedding table to embed sequences longer than the default. Default: %(default)s")
    parser_training.add_argument("--lora_r", default=8, type=int, help="LoRA rank. Default: %(default)s")
    parser_training.add_argument("--lora_alpha", default=16, type=int, help="LoRA alpha. Default: %(default)s")
    parser_training.add_argument("--lora_dropout", default=0.05, type=float, help="LoRA dropout. Default: %(default)s")
    parser_training.add_argument("--grad_chckpt", action="store_true", required=False, help="Use gradient checkpointing (useful for the 30B model). Default: %(default)s")
    parser_training.add_argument("--grad_chckpt_ratio", default=1, type=float, help="Gradient checkpoint ratio. Default: %(default)s")
    parser_training.add_argument("--val_set_size", default=0.2, type=float, help="Validation set size. Default: %(default)s")
    parser_training.add_argument("--warmup_steps", default=50, type=int, help="Default: %(default)s")
    parser_training.add_argument("--save_steps", default=50, type=int, help="Default: %(default)s")
    parser_training.add_argument("--save_total_limit", default=3, type=int, help="Default: %(default)s")
    parser_training.add_argument("--logging_steps", default=10, type=int, help="Default: %(default)s")
    parser_training.add_argument("-c", "--checkpoint", action="store_true", help="Produce a checkpoint instead of a LoRA. Default: %(default)s")
    parser_training.add_argument("--skip", action="store_true", help="Skip training. Useful for producing a checkpoint from an existing LoRA. Default: %(default)s")
    parser_training.add_argument("--verbose", action="store_true", help="Print the training log. Default: %(default)s")

    # Data args
    parser_training.add_argument("--txt_row_thd", default=-1, type=int, help="Custom threshold for txt rows.")
    parser_training.add_argument("--use_eos_token", default=1, type=int, help="Use the EOS token instead of padding with 0. Enable with 1, disable with 0.")

    # V2 model support
    parser_training.add_argument("--groupsize", type=int, default=-1, help="Groupsize of the v2 model")
    parser_training.add_argument("--v1", action="store_true", help="Use the v1 model")

    # Multi-GPU support
    parser_training.add_argument("--local_rank", type=int, default=0, help="Local rank when launching with torch.distributed.launch")

    # Flash Attention
    parser_training.add_argument("--flash_attention", action="store_true", help="Enable Flash Attention; can improve performance and reduce VRAM use")
    parser_training.add_argument("--xformers", action="store_true", help="Enable xformers memory-efficient attention; can improve performance and reduce VRAM use")

    # Training backend
    parser_training.add_argument("--backend", type=str, default="cuda", help="Backend to use: triton or cuda. Default: %(default)s")

    return vars(parser.parse_args())
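
# Example command line consumed by this parser (a sketch; "finetune.py" stands in
# for whichever training script calls get_config(), and the flag values shown are
# illustrative, not recommendations):
#
#   python finetune.py ./dataset.json --ds_type alpaca \
#       --llama_q4_config_dir ./llama-13b-4bit/ --llama_q4_model ./llama-13b-4bit.pt \
#       --lora_out_dir alpaca_lora --mbatch_size 1 --batch_size 2 --epochs 3 \
#       --grad_chckpt --backend cuda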


def get_config() -> Finetune4bConfig:
    args = parse_commandline()
    return Finetune4bConfig(
        dataset=args["dataset"],
        ds_type=args["ds_type"],
        lora_out_dir=args["lora_out_dir"],
        lora_apply_dir=args["lora_apply_dir"],
        resume_checkpoint=args["resume_checkpoint"],
        llama_q4_config_dir=args["llama_q4_config_dir"],
        llama_q4_model=args["llama_q4_model"],
        mbatch_size=args["mbatch_size"],
        batch_size=args["batch_size"],
        epochs=args["epochs"],
        lr=args["lr"],
        cutoff_len=args["cutoff_len"],
        max_position_embeddings=args["max_position_embeddings"],
        lora_r=args["lora_r"],
        lora_alpha=args["lora_alpha"],
        lora_dropout=args["lora_dropout"],
        val_set_size=args["val_set_size"],
        gradient_checkpointing=args["grad_chckpt"],
        gradient_checkpointing_ratio=args["grad_chckpt_ratio"],
        warmup_steps=args["warmup_steps"],
        save_steps=args["save_steps"],
        save_total_limit=args["save_total_limit"],
        logging_steps=args["logging_steps"],
        checkpoint=args["checkpoint"],
        skip=args["skip"],
        verbose=args["verbose"],
        txt_row_thd=args["txt_row_thd"],
        use_eos_token=args["use_eos_token"] != 0,
        groupsize=args["groupsize"],
        v1=args["v1"],
        local_rank=args["local_rank"],
        flash_attention=args["flash_attention"],
        xformers=args["xformers"],
        backend=args["backend"],
    )
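

# A minimal sketch for inspecting the parsed configuration when this module is run
# directly (assumes the Finetune4bConfig repr is readable; training scripts are
# expected to import get_config() instead).
if __name__ == "__main__":
    print(get_config())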