
Commit 6a8bca7

pass the attn_implementation and torch_dtype to the model during fine-tuning
1 parent fe5af52 commit 6a8bca7

File tree

1 file changed: +12, -1 lines changed

Diff for: src/cehrbert/runners/hf_cehrbert_finetune_runner.py

@@ -90,8 +90,19 @@ def load_finetuned_model(model_args: ModelArguments, model_name_or_path: str) ->
     )
     # Try to create a new model based on the base model
     model_name_or_path = os.path.expanduser(model_name_or_path)
+    torch_dtype = "auto"
+    if hasattr(torch, model_args.torch_dtype):
+        torch_dtype = getattr(torch, model_args.torch_dtype)
     try:
-        return finetune_model_cls.from_pretrained(model_name_or_path)
+        model = finetune_model_cls.from_pretrained(
+            model_name_or_path, torch_dtype=torch_dtype, attn_implementation=model_args.attn_implementation
+        )
+        if torch_dtype == torch.bfloat16:
+            return model.bfloat16()
+        elif torch_dtype == torch.float16:
+            return model.half()
+        else:
+            return model.float()
     except ValueError:
         raise ValueError(f"Can not load the finetuned model from {model_name_or_path}")
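Below is a minimal, self-contained sketch of the dtype-resolution and casting logic this commit introduces. It is not the runner itself: the ModelArguments dataclass here is a hypothetical stand-in exposing only the torch_dtype and attn_implementation fields visible in the diff, and a plain torch.nn.Linear stands in for the fine-tuned CEHR-BERT model so the snippet runs without loading any checkpoint.

# Sketch of the commit's dtype handling, assuming a simplified ModelArguments.
from dataclasses import dataclass

import torch


@dataclass
class ModelArguments:
    # Hypothetical stand-in for the runner's real arguments class.
    torch_dtype: str = "bfloat16"
    attn_implementation: str = "eager"


def resolve_torch_dtype(model_args: ModelArguments):
    # Fall back to "auto" when the string does not name a torch dtype,
    # mirroring the logic added to load_finetuned_model.
    torch_dtype = "auto"
    if hasattr(torch, model_args.torch_dtype):
        torch_dtype = getattr(torch, model_args.torch_dtype)
    return torch_dtype


def cast_model(model: torch.nn.Module, torch_dtype) -> torch.nn.Module:
    # Cast the loaded model to the requested precision, as in the diff above.
    if torch_dtype == torch.bfloat16:
        return model.bfloat16()
    elif torch_dtype == torch.float16:
        return model.half()
    return model.float()


if __name__ == "__main__":
    args = ModelArguments(torch_dtype="bfloat16")
    dtype = resolve_torch_dtype(args)               # -> torch.bfloat16
    model = cast_model(torch.nn.Linear(4, 4), dtype)
    print(next(model.parameters()).dtype)           # torch.bfloat16

The explicit post-load cast matters because passing torch_dtype to from_pretrained only controls how the checkpoint weights are loaded; casting the returned module afterwards ensures any freshly initialized parameters (for example a new classification head) end up in the same precision.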
