@@ -1,3 +1,4 @@
+import functools
 import json
 import os
 from datetime import datetime
@@ -69,6 +70,12 @@ def compute_metrics(references: Union[List[float], pd.Series], probs: Union[List
         return {"roc_auc": None, "pr_auc": None}


+def get_torch_dtype(torch_dtype: str) -> Union[torch.dtype, str]:
+    if hasattr(torch, torch_dtype):
+        return getattr(torch, torch_dtype)
+    return "auto"
+
+
 def load_pretrained_tokenizer(
     model_args,
 ) -> CehrBertTokenizer:
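
The new `get_torch_dtype` helper centralizes the string-to-dtype lookup that was previously inlined in `load_finetuned_model`. A standalone sketch of its behavior follows (the helper is duplicated so the snippet runs on its own); note that `hasattr` accepts any `torch` attribute name, not only dtypes, so callers are trusted to pass dtype strings:

```python
import torch
from typing import Union

def get_torch_dtype(torch_dtype: str) -> Union[torch.dtype, str]:
    # Resolve a config string like "bfloat16" to torch.bfloat16;
    # anything torch does not expose falls back to the "auto" sentinel.
    if hasattr(torch, torch_dtype):
        return getattr(torch, torch_dtype)
    return "auto"

assert get_torch_dtype("bfloat16") is torch.bfloat16
assert get_torch_dtype("float16") is torch.float16
assert get_torch_dtype("no_such_dtype") == "auto"
```
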
@@ -90,9 +97,7 @@ def load_finetuned_model(model_args: ModelArguments, model_name_or_path: str) ->
     )
     # Try to create a new model based on the base model
     model_name_or_path = os.path.expanduser(model_name_or_path)
-    torch_dtype = "auto"
-    if hasattr(torch, model_args.torch_dtype):
-        torch_dtype = getattr(torch, model_args.torch_dtype)
+    torch_dtype = get_torch_dtype(model_args.torch_dtype)
     try:
         model = finetune_model_cls.from_pretrained(
            model_name_or_path, torch_dtype=torch_dtype, attn_implementation=model_args.attn_implementation
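
For context, the resolved value is handed straight to Hugging Face's `from_pretrained`, whose `torch_dtype` argument accepts either a concrete `torch.dtype` or the string `"auto"`. A minimal illustrative call, with a placeholder model class and checkpoint path rather than the runner's own:

```python
import torch
from transformers import AutoModel

# "auto" lets transformers use the dtype recorded in the checkpoint config;
# a concrete dtype such as torch.bfloat16 casts the weights at load time.
model = AutoModel.from_pretrained(
    "path/to/finetuned-checkpoint",  # placeholder path
    torch_dtype=torch.bfloat16,      # or the "auto" fallback from get_torch_dtype
)
```
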
@@ -107,6 +112,16 @@ def load_finetuned_model(model_args: ModelArguments, model_name_or_path: str) ->
         raise ValueError(f"Can not load the finetuned model from {model_name_or_path}")


+def data_collate_fn(features, model_type: torch.dtype, collator: CehrBertDataCollator):
+    batch = collator(features)
+    if model_type != torch.float32:
+        for key, value in batch.items():
+            # Only convert float32 tensors to the target model dtype
+            if isinstance(value, torch.Tensor) and value.dtype == torch.float32:
+                batch[key] = value.to(model_type)
+    return batch
+
+
 def main():

     data_args, model_args, training_args = parse_runner_args()
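
The wrapper's effect is easiest to see on a toy batch: integer tensors such as input IDs pass through untouched, while float32 tensors are cast to the model dtype. A minimal sketch with a stand-in collator and made-up feature keys; note it assumes `model_type` is a concrete dtype, since `value.to("auto")` would raise if the `"auto"` fallback ever reached this point:

```python
import torch

def data_collate_fn(features, model_type, collator):
    batch = collator(features)
    if model_type != torch.float32:
        for key, value in batch.items():
            # Only float32 tensors are cast; ints and non-tensors are untouched.
            if isinstance(value, torch.Tensor) and value.dtype == torch.float32:
                batch[key] = value.to(model_type)
    return batch

# Stand-in collator producing one int and one float tensor.
def toy_collator(features):
    return {
        "input_ids": torch.tensor([[1, 2, 3]]),  # int64
        "values": torch.rand(1, 3),              # float32
    }

batch = data_collate_fn([{}], model_type=torch.bfloat16, collator=toy_collator)
print(batch["input_ids"].dtype)  # torch.int64
print(batch["values"].dtype)     # torch.bfloat16
```
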
@@ -297,7 +312,12 @@ def assign_split(example):

     # Remove all the cached files collected during the data transformation if there are any
     cache_file_collector.remove_cache_files()
-    collator = CehrBertDataCollator(tokenizer, model_args.max_position_embeddings, is_pretraining=False)
+
+    collator = functools.partial(
+        data_collate_fn,
+        model_type=get_torch_dtype(model_args.torch_dtype),
+        collator=CehrBertDataCollator(tokenizer, model_args.max_position_embeddings, is_pretraining=False),
+    )

     # Set seed before initializing model.
     set_seed(training_args.seed)
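
Binding `model_type` and the underlying `CehrBertDataCollator` with `functools.partial` keeps the interface the Hugging Face `Trainer` expects: a callable taking only the list of features. A generic illustration of that binding, using stand-in arguments rather than the runner's real ones:

```python
import functools

def data_collate_fn(features, model_type, collator):
    # Cast logic elided; see the function added above.
    return collator(features)

wrapped = functools.partial(
    data_collate_fn,
    model_type="placeholder-dtype",                 # bound once, up front
    collator=lambda features: {"size": len(features)},
)

# The partial is then called with `features` alone, like any data collator.
print(wrapped([{"a": 1}, {"b": 2}]))  # {'size': 2}
```
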
@@ -307,6 +327,7 @@ def assign_split(example):

     if training_args.do_train:
         model = load_finetuned_model(model_args, model_args.model_name_or_path)
+
         # If lora is enabled, we add LORA adapters to the model
         if model_args.use_lora:
             # When LORA is used, the trainer could not automatically find this label,