Commit 26eabdd

added logic to convert float32 to the corresponding precision

1 parent 8413edd · commit 26eabdd

1 file changed (+25 −4 lines)

Diff for: src/cehrbert/runners/hf_cehrbert_finetune_runner.py (+25 −4)

@@ -1,3 +1,4 @@
+import functools
 import json
 import os
 from datetime import datetime
@@ -69,6 +70,12 @@ def compute_metrics(references: Union[List[float], pd.Series], probs: Union[List
         return {"roc_auc": None, "pr_auc": None}
 
 
+def get_torch_dtype(torch_dtype: str) -> Union[torch.dtype, str]:
+    if hasattr(torch, torch_dtype):
+        return getattr(torch, torch_dtype)
+    return "auto"
+
+
 def load_pretrained_tokenizer(
     model_args,
 ) -> CehrBertTokenizer:
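
For reference, a minimal standalone sketch of how the new get_torch_dtype helper resolves dtype strings (the assertions are illustrative, not part of the commit): any string naming a torch attribute is returned as that attribute, and anything else, including "auto" itself, falls back to the string "auto".

    import torch
    from typing import Union

    def get_torch_dtype(torch_dtype: str) -> Union[torch.dtype, str]:
        if hasattr(torch, torch_dtype):
            return getattr(torch, torch_dtype)
        return "auto"

    assert get_torch_dtype("bfloat16") is torch.bfloat16
    assert get_torch_dtype("float16") is torch.float16
    assert get_torch_dtype("auto") == "auto"  # torch has no "auto" attribute, so the fallback applies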
@@ -90,9 +97,7 @@ def load_finetuned_model(model_args: ModelArguments, model_name_or_path: str) ->
         )
     # Try to create a new model based on the base model
     model_name_or_path = os.path.expanduser(model_name_or_path)
-    torch_dtype = "auto"
-    if hasattr(torch, model_args.torch_dtype):
-        torch_dtype = getattr(torch, model_args.torch_dtype)
+    torch_dtype = get_torch_dtype(model_args.torch_dtype)
     try:
         model = finetune_model_cls.from_pretrained(
             model_name_or_path, torch_dtype=torch_dtype, attn_implementation=model_args.attn_implementation
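
For context, Hugging Face's from_pretrained accepts either a concrete torch.dtype or the string "auto" for its torch_dtype argument, so both possible return values of get_torch_dtype are valid here. A generic sketch, using AutoModel as a stand-in for finetune_model_cls and a placeholder checkpoint path:

    import torch
    from transformers import AutoModel

    # "path/to/checkpoint" is a placeholder; in the runner, finetune_model_cls plays the role of AutoModel.
    model = AutoModel.from_pretrained("path/to/checkpoint", torch_dtype=torch.bfloat16)  # explicit dtype
    model = AutoModel.from_pretrained("path/to/checkpoint", torch_dtype="auto")  # let transformers infer it from the checkpoint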
@@ -107,6 +112,16 @@ def load_finetuned_model(model_args: ModelArguments, model_name_or_path: str) ->
         raise ValueError(f"Can not load the finetuned model from {model_name_or_path}")
 
 
+def data_collate_fn(features, model_type: torch.dtype, collator: CehrBertDataCollator):
+    batch = collator(features)
+    if model_type != torch.float32:
+        for key, value in batch.items():
+            # Only convert float32 tensors to bfloat16
+            if isinstance(value, torch.Tensor) and value.dtype == torch.float32:
+                batch[key] = value.to(model_type)
+    return batch
+
+
 def main():
 
     data_args, model_args, training_args = parse_runner_args()
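
To illustrate the casting behaviour of data_collate_fn: integer tensors pass through unchanged, float32 tensors are cast to the requested model dtype, and a model_type of torch.float32 leaves the batch untouched. A self-contained sketch using a hypothetical dummy collator in place of CehrBertDataCollator (the tensor names are invented for the example):

    import torch

    def data_collate_fn(features, model_type, collator):
        batch = collator(features)
        if model_type != torch.float32:
            for key, value in batch.items():
                if isinstance(value, torch.Tensor) and value.dtype == torch.float32:
                    batch[key] = value.to(model_type)
        return batch

    def dummy_collator(features):
        # Hypothetical batch; the real field names come from CehrBertDataCollator.
        return {"input_ids": torch.tensor([[1, 2, 3]]), "concept_values": torch.zeros(1, 3)}

    batch = data_collate_fn([{}], model_type=torch.bfloat16, collator=dummy_collator)
    assert batch["input_ids"].dtype == torch.int64  # integer tensors are left as-is
    assert batch["concept_values"].dtype == torch.bfloat16  # float32 tensors follow the model dtype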
@@ -297,7 +312,12 @@ def assign_split(example):
 
     # Remove all the cached files collected during the data transformation if there are any
     cache_file_collector.remove_cache_files()
-    collator = CehrBertDataCollator(tokenizer, model_args.max_position_embeddings, is_pretraining=False)
+
+    collator = functools.partial(
+        data_collate_fn,
+        model_type=get_torch_dtype(model_args.torch_dtype),
+        collator=CehrBertDataCollator(tokenizer, model_args.max_position_embeddings, is_pretraining=False),
+    )
 
     # Set seed before initializing model.
     set_seed(training_args.seed)
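
The functools.partial call freezes model_type and collator, so the resulting collator object is a single-argument callable, which is the shape the Hugging Face Trainer expects for data_collator. A small sketch of the pattern with stub bodies (the stubs below are not the real implementations):

    import functools

    def data_collate_fn(features, model_type, collator):
        # Stub standing in for the function added in this commit.
        return collator(features)

    def base_collator(features):
        # Stub standing in for CehrBertDataCollator(tokenizer, max_position_embeddings, is_pretraining=False).
        return {"n": len(features)}

    collate = functools.partial(data_collate_fn, model_type="bfloat16", collator=base_collator)

    # collate(features) now behaves like data_collate_fn(features, model_type=..., collator=...).
    assert collate([{}, {}]) == {"n": 2}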
@@ -307,6 +327,7 @@ def assign_split(example):
 
     if training_args.do_train:
         model = load_finetuned_model(model_args, model_args.model_name_or_path)
+
         # If lora is enabled, we add LORA adapters to the model
         if model_args.use_lora:
             # When LORA is used, the trainer could not automatically find this label,
