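"""Fine-tune a CLIP-style image-text model on MIMIC data.

Training entry point: builds the encoders (either OpenAI's pretrained RN50
CLIP or custom image/text backbones), wraps them in CustomCLIPWrapper, and
runs a PyTorch Lightning Trainer on the MIMICDataModule.
"""
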
import sys

import clip
from data.text_image_dm import MIMICDataModule
from models import CustomCLIPWrapper, init_img_model, init_txt_model, parse_arguments
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint


def main(hparams):
using_clip = False
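
    # Path 1: reuse the pretrained RN50 encoders from OpenAI's clip package.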
if hparams.use_clip:
if hparams.image_encoder is not None or hparams.text_encoder is not None:
print("Warning - image_encoder and text_encoder args unused")
using_clip = True
clp, preprocess = clip.load("RN50", device="cpu")
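        # Make sure every parameter (and any existing gradient) is fp32;
        # CLIP checkpoints are stored in half precision.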
for p in clp.parameters():
p.data = p.data.float()
            if p.grad is not None:
p.grad.data = p.grad.data.float()
img_encoder = clp.visual
txt_encoder = clp.transformer
else:
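        # Path 2: build custom image and text encoders from the CLI arguments.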
if hparams.image_encoder is None or hparams.text_encoder is None:
print("Please --use_clip or set image and text encoders, exiting...")
exit(1)
img_encoder, _ = init_img_model(
hparams.image_encoder,
hparams.embed_dim,
hparams.freeze_img_encoder,
hparams.use_pretrained,
)
txt_encoder, tokenizer = init_txt_model(
hparams.text_encoder,
hparams.embed_dim,
hparams.freeze_txt_encoder,
hparams.freeze_layers,
hparams.add_projection,
hparams.local_files_only,
)
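
    # Fall back to the full batch size when no valid minibatch size is given.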
if hparams.minibatch_size < 1:
hparams.minibatch_size = hparams.batch_size
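
    # Either restore the wrapper from a checkpoint or construct it fresh.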
if hparams.checkpoint_file is not None:
model = CustomCLIPWrapper.load_from_checkpoint(
checkpoint_path=hparams.checkpoint_file,
image_encoder=img_encoder,
text_encoder=txt_encoder,
)
else:
model = CustomCLIPWrapper(
img_encoder,
txt_encoder,
hparams.minibatch_size,
using_clip=using_clip,
lr=hparams.lr,
lr_img=hparams.lr_img,
lr_txt=hparams.lr_txt,
warmup_epochs=hparams.warmup_epochs,
use_teacher=hparams.use_teacher,
)
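
    # Hand the wrapper the full CLIP model so it can finish CLIP-specific setup.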
if using_clip:
model.finish_clip_init(clp)
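
    # OpenAI CLIP ships its own tokenizer, so only pass ours for custom text encoders.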
dm = MIMICDataModule.from_argparse_args(
hparams, custom_tokenizer=None if using_clip else tokenizer
)
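
    # Keep the single best checkpoint by validation loss, and optionally
    # also save one every checkpoint_every training steps.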
callbacks = []
callbacks.append(ModelCheckpoint(monitor="val_loss", save_top_k=1))
if hparams.checkpoint_every > 0:
callbacks.append(ModelCheckpoint(every_n_train_steps=hparams.checkpoint_every))
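
    # Use DDP on GPUs when more than one device is requested; otherwise
    # let Lightning choose the accelerator.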
if int(hparams.devices) > 1:
trainer = Trainer.from_argparse_args(
hparams, accelerator="gpu", strategy="ddp", callbacks=callbacks
)
else:
trainer = Trainer.from_argparse_args(
hparams, accelerator="auto", callbacks=callbacks
)
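
    # Passing ckpt_path resumes the full training state (optimizer,
    # schedulers, epoch counters), not just the model weights.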
if hparams.checkpoint_file is not None:
trainer.fit(model, dm, ckpt_path=hparams.checkpoint_file)
else:
trainer.fit(model, dm)


if __name__ == "__main__":
args = parse_arguments()
main(args)