import json
import os
- from typing import Tuple
+ from typing import Any, Dict, List, Tuple

import numpy as np
import pandas as pd
from datasets import DatasetDict, load_from_disk
from peft import LoraConfig, get_peft_model
from scipy.special import expit as sigmoid
- from sklearn.metrics import accuracy_score, auc, precision_recall_curve, roc_auc_score
+ from sklearn.metrics import auc, precision_recall_curve, roc_auc_score
from transformers import EarlyStoppingCallback, Trainer, set_seed
from transformers.utils import logging

LOG = logging.get_logger("transformers")

- def compute_metrics(eval_pred):
-     outputs, labels = eval_pred
-     logits = outputs[0]
-
-     # Convert logits to probabilities using sigmoid
-     probabilities = sigmoid(logits)
-
-     if probabilities.shape[1] == 2:
-         positive_probs = probabilities[:, 1]
-     else:
-         positive_probs = probabilities.squeeze()  # Ensure it's a 1D array
-
-     # Calculate predictions based on probability threshold of 0.5
-     predictions = (positive_probs > 0.5).astype(np.int32)
-
-     # Calculate accuracy
-     accuracy = accuracy_score(labels, predictions)
-
-     # Calculate ROC-AUC
-     roc_auc = roc_auc_score(labels, positive_probs)
-
-     # Calculate PR-AUC
-     precision, recall, _ = precision_recall_curve(labels, positive_probs)
-     pr_auc = auc(recall, precision)
-
-     return {"accuracy": accuracy, "roc_auc": roc_auc, "pr_auc": pr_auc}
+ def compute_metrics(references: List[float], logits: List[float]) -> Dict[str, Any]:
+     """
+     Computes evaluation metrics for binary classification, including ROC-AUC and PR-AUC, based on reference labels and model logits.
+
+     Args:
+         references (List[float]): Ground truth binary labels (0 or 1).
+         logits (List[float]): Logits output by the model (raw prediction scores), which are converted to probabilities using the sigmoid function.
+
+     Returns:
+         Dict[str, Any]: A dictionary containing:
+             - 'roc_auc': The Area Under the Receiver Operating Characteristic Curve (ROC-AUC).
+             - 'pr_auc': The Area Under the Precision-Recall Curve (PR-AUC).
+
+     Notes:
+         - The `sigmoid` function is used to convert the logits into probabilities.
+         - ROC-AUC measures the model's ability to distinguish between classes, while PR-AUC focuses on performance with imbalanced data.
+     """
+     # Convert logits to probabilities using sigmoid
+     probabilities = sigmoid(logits)
+     # Calculate ROC-AUC
+     roc_auc = roc_auc_score(references, probabilities)
+     # Calculate PR-AUC
+     precision, recall, _ = precision_recall_curve(references, probabilities)
+     pr_auc = auc(recall, precision)
+     return {"roc_auc": roc_auc, "pr_auc": pr_auc}


def load_pretrained_model_and_tokenizer(
    model_args,
) -> Tuple[CehrBertPreTrainedModel, CehrBertTokenizer]:
-     # Try to load the pretrained tokenizer
+     """
+     Loads a pretrained model and tokenizer based on the given model arguments.
+
+     Args:
+         model_args (Namespace): An argument object containing the following fields:
+             - tokenizer_name_or_path (str): The path or name of the pretrained tokenizer to load.
+             - model_name_or_path (str): The path or name of the pretrained model to load.
+             - finetune_model_type (str): The type of fine-tuning model to use. Must be one of the values in `FineTuneModelType`.
+
+     Returns:
+         Tuple[CehrBertPreTrainedModel, CehrBertTokenizer]:
+             - CehrBertPreTrainedModel: The loaded pretrained model (either a classification or LSTM model).
+             - CehrBertTokenizer: The loaded pretrained tokenizer.
+
+     Raises:
+         ValueError: If the tokenizer cannot be loaded from the specified path, or if the fine-tuning model type is invalid.
+
+     Notes:
+         - If loading the model fails, the function will attempt to create a new model using the provided model arguments
+           and the tokenizer's configuration.
+         - The function supports two types of models for fine-tuning:
+             - `CehrBertForClassification` for pooling-based models.
+             - `CehrBertLstmForClassification` for LSTM-based models.
+     """
    try:
        tokenizer = CehrBertTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
    except Exception:
@@ -284,7 +306,6 @@ def assign_split(example):
        data_collator=collator,
        train_dataset=processed_dataset["train"],
        eval_dataset=processed_dataset["validation"],
-         compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=model_args.early_stopping_patience)],
        args=training_args,
    )
@@ -297,6 +318,8 @@ def assign_split(example):

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
+         trainer.log_metrics("eval", metrics)
+         trainer.save_metrics("eval", metrics)
        trainer.save_state()

    if training_args.do_predict:
@@ -309,12 +332,6 @@ def assign_split(example):
        trainer._load_from_checkpoint(training_args.output_dir)

        test_results = trainer.predict(processed_dataset["test"])
-         # Save results to JSON
-         test_results_path = os.path.join(training_args.output_dir, "test_results.json")
-         with open(test_results_path, "w") as f:
-             json.dump(test_results.metrics, f, indent=4)
-
-         LOG.info(f"Test results: {test_results.metrics}")

        person_ids = [row["person_id"] for row in processed_dataset["test"]]
@@ -330,6 +347,14 @@ def assign_split(example):
        prediction_pd = pd.DataFrame({"person_id": person_ids, "prediction": predictions, "label": labels})
        prediction_pd.to_csv(os.path.join(training_args.output_dir, "test_predictions.csv"), index=False)

+         # Save results to JSON
+         metrics = compute_metrics(references=labels, logits=predictions)
+         test_results_path = os.path.join(training_args.output_dir, "test_results.json")
+         with open(test_results_path, "w") as f:
+             json.dump(metrics, f, indent=4)
+
+         LOG.info(f"Test results: {metrics}")
+
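As a small optional sketch (not part of the commit), the artifacts written above can be reloaded for a sanity check using only modules already imported in this file; `training_args.output_dir` is the same directory the code above writes to:

    # Hypothetical follow-up check: reload the saved predictions and metrics.
    reloaded_predictions = pd.read_csv(os.path.join(training_args.output_dir, "test_predictions.csv"))
    with open(os.path.join(training_args.output_dir, "test_results.json")) as f:
        reloaded_metrics = json.load(f)
    LOG.info(f"Reloaded {len(reloaded_predictions)} predictions; metrics: {reloaded_metrics}")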

if __name__ == "__main__":
    main()