|
1295 | 1295 | "from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl\n",
|
1296 | 1296 | "from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR\n",
|
1297 | 1297 | "\n",
|
| 1298 | + "\n", |
1298 | 1299 | "class SavePeftModelCallback(TrainerCallback):\n",
|
1299 | 1300 | " def on_save(\n",
|
1300 | | - "        self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs,\n", |
| 1301 | + " self,\n", |
| 1302 | + " args: TrainingArguments,\n", |
| 1303 | + " state: TrainerState,\n", |
| 1304 | + " control: TrainerControl,\n", |
| 1305 | + " **kwargs,\n", |
1301 | 1306 | " ):\n",
|
1302 | | - "        checkpoint_folder = os.path.join(\n", |
1303 | | - "            args.output_dir, f\"{PREFIX_CHECKPOINT_DIR}-{state.global_step}\"\n", |
1304 | | - "        ) \n", |
| 1307 | + " checkpoint_folder = os.path.join(args.output_dir, f\"{PREFIX_CHECKPOINT_DIR}-{state.global_step}\")\n", |
1305 | 1308 | "\n",
|
1306 | 1309 | " peft_model_path = os.path.join(checkpoint_folder, \"adapter_model\")\n",
|
1307 | 1310 | " kwargs[\"model\"].save_pretrained(peft_model_path)\n",
|
|
1311 | 1314 | " os.remove(pytorch_model_path)\n",
|
1312 | 1315 | " return control\n",
|
1313 | 1316 | "\n",
|
| 1317 | + "\n", |
1314 | 1318 | "trainer = Seq2SeqTrainer(\n",
|
1315 | 1319 | " args=training_args,\n",
|
1316 | 1320 | " model=model,\n",
|
|
1319 | 1323 | " data_collator=data_collator,\n",
|
1320 | 1324 | " # compute_metrics=compute_metrics,\n",
|
1321 | 1325 | " tokenizer=processor.feature_extractor,\n",
|
1322 | | - "    callbacks=[SavePeftModelCallback]\n", |
| 1326 | + " callbacks=[SavePeftModelCallback],\n", |
1323 | 1327 | ")\n",
|
1324 | 1328 | "model.config.use_cache = False # silence the warnings. Please re-enable for inference!"
|
1325 | 1329 | ]
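
For reference, a minimal runnable sketch of this cell after the reformatting above. `training_args`, `model`, `data_collator`, and `processor` are defined in earlier notebook cells; the `pytorch_model_path` definition and existence check are assumptions reconstructed from the `os.remove(pytorch_model_path)` line visible in the hunk, following the usual PEFT checkpointing pattern, and the dataset arguments are elided here because they fall outside the diff context.

```python
import os

from transformers import (
    Seq2SeqTrainer,
    TrainerCallback,
    TrainingArguments,
    TrainerState,
    TrainerControl,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR


class SavePeftModelCallback(TrainerCallback):
    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")

        # Save only the small PEFT adapter weights for this checkpoint.
        peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
        kwargs["model"].save_pretrained(peft_model_path)

        # Assumed from the elided context lines: drop the full-model weights
        # the Trainer wrote, so each checkpoint keeps only the adapter.
        pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
        if os.path.exists(pytorch_model_path):
            os.remove(pytorch_model_path)
        return control


trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    # train/eval dataset arguments sit in lines elided from this hunk
    data_collator=data_collator,
    # compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
    callbacks=[SavePeftModelCallback],
)
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
```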
|
|