-
Notifications
You must be signed in to change notification settings - Fork 97
/
Copy pathtrain.py
80 lines (64 loc) · 2.94 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import sys
import traceback
from finetrainers import BaseArgs, SFTTrainer, TrainingType, get_logger
from finetrainers.config import _get_model_specifiction_cls
from finetrainers.trainer.sft_trainer.config import SFTFullRankConfig, SFTLowRankConfig
# Module-level logger obtained from finetrainers' logging helper; used by main()
# for both the multiprocessing warning and the top-level error handler.
logger = get_logger()
def main():
    """Entry point for SFT training.

    Parses command-line arguments, resolves the training configuration
    (LoRA vs. full fine-tune), builds the model specification, and runs
    the trainer. All failures are logged rather than propagated so the
    process exits cleanly.
    """
    try:
        import multiprocessing

        # Prefer "fork" so dataloader workers inherit state instead of
        # re-importing heavy modules; failure here is non-fatal but can
        # degrade performance, hence the warning below.
        multiprocessing.set_start_method("fork")
    except Exception as e:
        logger.error(
            f'Failed to set multiprocessing start method to "fork". This can lead to poor performance, high memory usage, or crashes. '
            f"See: https://pytorch.org/docs/stable/notes/multiprocessing.html\n"
            f"Error: {e}"
        )

    try:
        args = BaseArgs()

        # Normalize argv: split any argument containing whitespace so the
        # flag and its value are always separate tokens.
        argv = [y.strip() for x in sys.argv for y in x.split()]

        # BUGFIX: list.index raises ValueError when the item is absent, so the
        # original `if training_type_index == -1:` branch was unreachable and
        # a missing flag surfaced as a confusing "'--training_type' is not in
        # list" error. Check membership (and that a value follows) explicitly
        # so the intended message is actually raised.
        if "--training_type" not in argv:
            raise ValueError("Training type not provided in command line arguments.")
        training_type_index = argv.index("--training_type")
        if training_type_index + 1 >= len(argv):
            raise ValueError("Training type not provided in command line arguments.")
        training_type = argv[training_type_index + 1]

        # Map the training type to its extra-argument config class.
        training_cls = None
        if training_type == TrainingType.LORA:
            training_cls = SFTLowRankConfig
        elif training_type == TrainingType.FULL_FINETUNE:
            training_cls = SFTFullRankConfig
        else:
            raise ValueError(f"Training type {training_type} not supported.")

        # Extend the base argument parser with config-specific args, then parse.
        training_config = training_cls()
        args.extend_args(training_config.add_args, training_config.map_args, training_config.validate_args)
        args = args.parse_args()

        # Resolve and instantiate the model specification for the chosen
        # model/training-type combination.
        model_specification_cls = _get_model_specifiction_cls(args.model_name, args.training_type)
        model_specification = model_specification_cls(
            pretrained_model_name_or_path=args.pretrained_model_name_or_path,
            tokenizer_id=args.tokenizer_id,
            tokenizer_2_id=args.tokenizer_2_id,
            tokenizer_3_id=args.tokenizer_3_id,
            text_encoder_id=args.text_encoder_id,
            text_encoder_2_id=args.text_encoder_2_id,
            text_encoder_3_id=args.text_encoder_3_id,
            transformer_id=args.transformer_id,
            vae_id=args.vae_id,
            text_encoder_dtype=args.text_encoder_dtype,
            text_encoder_2_dtype=args.text_encoder_2_dtype,
            text_encoder_3_dtype=args.text_encoder_3_dtype,
            transformer_dtype=args.transformer_dtype,
            vae_dtype=args.vae_dtype,
            revision=args.revision,
            cache_dir=args.cache_dir,
        )

        if args.training_type in [TrainingType.LORA, TrainingType.FULL_FINETUNE]:
            trainer = SFTTrainer(args, model_specification)
        else:
            raise ValueError(f"Training type {args.training_type} not supported.")

        trainer.run()
    except KeyboardInterrupt:
        logger.info("Received keyboard interrupt. Exiting...")
    except Exception as e:
        # Top-level boundary: log the error and full traceback instead of
        # letting the process die with an unformatted stack dump.
        logger.error(f"An error occurred during training: {e}")
        logger.error(traceback.format_exc())
if __name__ == "__main__":
main()