@@ -61,15 +61,18 @@ def dist_gather(o):
 def wrap_model(model):
     return DistributedDataParallel(model, device_ids=[dist.get_rank()])
 
-def get_run_name(train_cfg):
+def get_run_name(train_cfg, vlm_cfg):
     dataset_size = "full_ds" if train_cfg.data_cutoff_idx is None else f"{train_cfg.data_cutoff_idx}samples"
     batch_size = f"bs{int(train_cfg.batch_size*get_world_size()*train_cfg.gradient_accumulation_steps)}"
     epochs = f"ep{train_cfg.epochs}"
     learning_rate = f"lr{train_cfg.lr_backbones}-{train_cfg.lr_mp}"
     num_gpus = f"{get_world_size()}xGPU"
     date = time.strftime("%m%d")
+    vit = f"{vlm_cfg.vit_model_type.split('/')[-1]}"
+    mp = f"mp{vlm_cfg.mp_pixel_shuffle_factor}"
+    llm = f"{vlm_cfg.lm_model_type.split('/')[-1]}"
 
-    return f"nanoVLM_{num_gpus}_{dataset_size}_{batch_size}_{epochs}_{learning_rate}_{date}"
+    return f"nanoVLM_{vit}_{mp}_{llm}_{num_gpus}_{dataset_size}_{batch_size}_{epochs}_{learning_rate}_{date}"
 
 def get_dataloaders(train_cfg, vlm_cfg):
     # Create datasets
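Note: the run name now encodes the vision backbone, the modality projector's pixel-shuffle factor, and the language backbone. A minimal sketch of the new naming pieces, using hypothetical config values (the model identifiers and factor below are placeholders, not necessarily the repo defaults):

    from types import SimpleNamespace

    # Hypothetical config values for illustration only.
    vlm_cfg = SimpleNamespace(
        vit_model_type="google/siglip-base-patch16-224",
        lm_model_type="HuggingFaceTB/SmolLM2-135M",
        mp_pixel_shuffle_factor=2,
    )
    vit = vlm_cfg.vit_model_type.split("/")[-1]  # drop the org prefix
    mp = f"mp{vlm_cfg.mp_pixel_shuffle_factor}"
    llm = vlm_cfg.lm_model_type.split("/")[-1]
    print(f"nanoVLM_{vit}_{mp}_{llm}")  # nanoVLM_siglip-base-patch16-224_mp2_SmolLM2-135M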
@@ -202,7 +205,7 @@ def train(train_cfg, vlm_cfg):
 
     total_dataset_size = len(train_loader.dataset)
     if train_cfg.log_wandb and is_master():
-        run_name = get_run_name(train_cfg)
+        run_name = get_run_name(train_cfg, vlm_cfg)
         if train_cfg.data_cutoff_idx is None:
             run_name = run_name.replace("full_ds", f"{total_dataset_size}samples")
         run = wandb.init(
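Note: run_name is assigned only under `if train_cfg.log_wandb and is_master():`, while the save and push paths in the later hunks also appear to read it; with wandb logging disabled, that lookup would fail. A hedged sketch of one way around this, assuming the surrounding train() scope (not part of this diff):

    # Compute the name unconditionally so checkpoint paths never depend on logging.
    run_name = get_run_name(train_cfg, vlm_cfg)
    if train_cfg.log_wandb and is_master():
        run = wandb.init(project="nanoVLM", name=run_name)  # kwargs illustrative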
@@ -353,7 +356,7 @@ def train(train_cfg, vlm_cfg):
                 epoch_accuracy = test_mmstar(eval_model, tokenizer, test_loader, device)
                 if epoch_accuracy > best_accuracy:
                     best_accuracy = epoch_accuracy
-                    eval_model.save_pretrained(save_directory=vlm_cfg.vlm_checkpoint_path)
+                    eval_model.save_pretrained(save_directory=os.path.join(vlm_cfg.vlm_checkpoint_path, run_name))
                 if train_cfg.log_wandb and is_master():
                     run.log({"accuracy": epoch_accuracy}, step=global_step)
                 print(f"Step: {global_step}, Loss: {batch_loss:.4f}, Tokens/s: {tokens_per_second:.2f}, Accuracy: {epoch_accuracy:.4f}")
@@ -404,7 +407,7 @@ def train(train_cfg, vlm_cfg):
     # Push the best model to the hub (Please set your user name in the config!)
     if vlm_cfg.hf_repo_name is not None:
         print("Training complete. Pushing model to Hugging Face Hub...")
-        hf_model = VisionLanguageModel.from_pretrained(vlm_cfg.vlm_checkpoint_path)
+        hf_model = VisionLanguageModel.from_pretrained(os.path.join(vlm_cfg.vlm_checkpoint_path, run_name))
         hf_model.push_to_hub(vlm_cfg.hf_repo_name)
 
     if train_cfg.log_wandb: