
Commit 098db57

Merge pull request #79 from huggingface/improve_checkpointing
Improve checkpointing
2 parents: 8be6c6a + 469afdc

File tree: 2 files changed (+9, -6 lines)


models/config.py: 1 addition, 1 deletion

@@ -37,7 +37,7 @@ class VLMConfig:
     mp_pixel_shuffle_factor: int = 2
 
     vlm_load_backbone_weights: bool = True
-    vlm_checkpoint_path: str = 'checkpoints/nanoVLM-222M'
+    vlm_checkpoint_path: str = 'checkpoints'
     hf_repo_name: str = 'nanoVLM'
 
 
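With vlm_checkpoint_path now naming a parent directory instead of one fixed folder, every run can save into its own subdirectory. A minimal sketch of the path composition this enables, matching the os.path.join call added in train.py below (the run name here is a placeholder, not a real run):

    import os

    vlm_checkpoint_path = 'checkpoints'      # new default from VLMConfig
    run_name = 'nanoVLM_example_run'         # placeholder; see get_run_name below
    save_dir = os.path.join(vlm_checkpoint_path, run_name)
    print(save_dir)                          # checkpoints/nanoVLM_example_run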

train.py: 8 additions, 5 deletions

@@ -61,15 +61,18 @@ def dist_gather(o):
 def wrap_model(model):
     return DistributedDataParallel(model, device_ids=[dist.get_rank()])
 
-def get_run_name(train_cfg):
+def get_run_name(train_cfg, vlm_cfg):
     dataset_size = "full_ds" if train_cfg.data_cutoff_idx is None else f"{train_cfg.data_cutoff_idx}samples"
     batch_size = f"bs{int(train_cfg.batch_size*get_world_size()*train_cfg.gradient_accumulation_steps)}"
     epochs = f"ep{train_cfg.epochs}"
     learning_rate = f"lr{train_cfg.lr_backbones}-{train_cfg.lr_mp}"
     num_gpus = f"{get_world_size()}xGPU"
     date = time.strftime("%m%d")
+    vit = f"{vlm_cfg.vit_model_type.split('/')[-1]}"
+    mp = f"mp{vlm_cfg.mp_pixel_shuffle_factor}"
+    llm = f"{vlm_cfg.lm_model_type.split('/')[-1]}"
 
-    return f"nanoVLM_{num_gpus}_{dataset_size}_{batch_size}_{epochs}_{learning_rate}_{date}"
+    return f"nanoVLM_{vit}_{mp}_{llm}_{num_gpus}_{dataset_size}_{batch_size}_{epochs}_{learning_rate}_{date}"
 
 def get_dataloaders(train_cfg, vlm_cfg):
     # Create datasets
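For a feel of what the richer naming scheme produces, here is a standalone rendition of the three new fields with illustrative values (the model identifiers and hyperparameters below are assumptions for the example, not the repo's defaults):

    # Illustrative values only; not the repo's defaults.
    vit_model_type = 'google/siglip-base-patch16-224'
    lm_model_type = 'HuggingFaceTB/SmolLM2-135M'
    mp_pixel_shuffle_factor = 2

    vit = vit_model_type.split('/')[-1]    # 'siglip-base-patch16-224'
    mp = f"mp{mp_pixel_shuffle_factor}"    # 'mp2'
    llm = lm_model_type.split('/')[-1]     # 'SmolLM2-135M'

    # With the remaining fields as before, a full run name now reads like:
    # nanoVLM_siglip-base-patch16-224_mp2_SmolLM2-135M_1xGPU_full_ds_bs512_ep5_lr0.0001-0.001_0523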
@@ -202,7 +205,7 @@ def train(train_cfg, vlm_cfg):
 
     total_dataset_size = len(train_loader.dataset)
     if train_cfg.log_wandb and is_master():
-        run_name = get_run_name(train_cfg)
+        run_name = get_run_name(train_cfg, vlm_cfg)
         if train_cfg.data_cutoff_idx is None:
             run_name = run_name.replace("full_ds", f"{total_dataset_size}samples")
         run = wandb.init(
@@ -353,7 +356,7 @@ def train(train_cfg, vlm_cfg):
             epoch_accuracy = test_mmstar(eval_model, tokenizer, test_loader, device)
             if epoch_accuracy > best_accuracy:
                 best_accuracy = epoch_accuracy
-                eval_model.save_pretrained(save_directory=vlm_cfg.vlm_checkpoint_path)
+                eval_model.save_pretrained(save_directory=os.path.join(vlm_cfg.vlm_checkpoint_path, run_name))
             if train_cfg.log_wandb and is_master():
                 run.log({"accuracy": epoch_accuracy}, step=global_step)
             print(f"Step: {global_step}, Loss: {batch_loss:.4f}, Tokens/s: {tokens_per_second:.2f}, Accuracy: {epoch_accuracy:.4f}")
@@ -404,7 +407,7 @@ def train(train_cfg, vlm_cfg):
     # Push the best model to the hub (Please set your user name in the config!)
     if vlm_cfg.hf_repo_name is not None:
         print("Training complete. Pushing model to Hugging Face Hub...")
-        hf_model = VisionLanguageModel.from_pretrained(vlm_cfg.vlm_checkpoint_path)
+        hf_model = VisionLanguageModel.from_pretrained(os.path.join(vlm_cfg.vlm_checkpoint_path, run_name))
         hf_model.push_to_hub(vlm_cfg.hf_repo_name)
 
     if train_cfg.log_wandb:
