
Commit e144717

Remove eval compute metrics (#54)
* Stopped passing compute_metrics to the Trainer to calculate roc_auc/pr_auc/accuracy during the evaluation step: for large eval datasets, evaluation can run out of CPU memory because all predictions are kept on the CPU (see the offline sketch below)
* Added backward compatibility for fine-tuning data constructed from OMOP data, where age_at_index is labeled as age
1 parent b4f719d commit e144717
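
Because the runner now writes the test-set prediction scores and labels to a CSV (see the second file below), ROC-AUC/PR-AUC can be recomputed offline at any time instead of inside the Trainer's evaluation loop. A minimal sketch, assuming test_predictions.csv sits in the working directory and that its "prediction" column holds raw logits, which is how the commit's compute_metrics treats the stored predictions:

# Offline recomputation of the test metrics from the CSV written by the runner.
# Assumptions: the file path and the logit interpretation of "prediction".
import pandas as pd
from scipy.special import expit as sigmoid
from sklearn.metrics import auc, precision_recall_curve, roc_auc_score

df = pd.read_csv("test_predictions.csv")
probabilities = sigmoid(df["prediction"].to_numpy())  # logits -> probabilities
labels = df["label"].to_numpy()

roc_auc = roc_auc_score(labels, probabilities)
precision, recall, _ = precision_recall_curve(labels, probabilities)
pr_auc = auc(recall, precision)
print({"roc_auc": roc_auc, "pr_auc": pr_auc})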

File tree: 2 files changed, +57 -32 lines changed


src/cehrbert/data_generators/hf_data_generator/hf_dataset_mapping.py

Lines changed: 1 addition & 1 deletion

@@ -453,6 +453,6 @@ class HFFineTuningMapping(DatasetMapping):
 
     def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
         return {
-            "age_at_index": record["age_at_index"],
+            "age_at_index": record["age"] if "age" in record else record["age_at_index"],
            "classifier_label": record["label"],
        }
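
To make the backward-compatibility fallback concrete, the helper below is an illustrative standalone copy of the transform logic above (not the repository's class), fed once with an OMOP-style record that uses age and once with a legacy record that uses age_at_index:

# Standalone paraphrase of HFFineTuningMapping.transform after this commit,
# for illustration only: OMOP-derived records label the field "age", older
# records use "age_at_index"; both now map to "age_at_index".
from typing import Any, Dict

def to_finetuning_record(record: Dict[str, Any]) -> Dict[str, Any]:
    return {
        "age_at_index": record["age"] if "age" in record else record["age_at_index"],
        "classifier_label": record["label"],
    }

print(to_finetuning_record({"age": 63, "label": 1}))           # OMOP-style input
print(to_finetuning_record({"age_at_index": 63, "label": 1}))  # legacy input
# both print {'age_at_index': 63, 'classifier_label': 1}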

src/cehrbert/runners/hf_cehrbert_finetune_runner.py

Lines changed: 56 additions & 31 deletions

@@ -1,13 +1,13 @@
 import json
 import os
-from typing import Tuple
+from typing import Any, Dict, List, Tuple
 
 import numpy as np
 import pandas as pd
 from datasets import DatasetDict, load_from_disk
 from peft import LoraConfig, get_peft_model
 from scipy.special import expit as sigmoid
-from sklearn.metrics import accuracy_score, auc, precision_recall_curve, roc_auc_score
+from sklearn.metrics import auc, precision_recall_curve, roc_auc_score
 from transformers import EarlyStoppingCallback, Trainer, set_seed
 from transformers.utils import logging
 
@@ -33,38 +33,60 @@
 LOG = logging.get_logger("transformers")
 
 
-def compute_metrics(eval_pred):
-    outputs, labels = eval_pred
-    logits = outputs[0]
+def compute_metrics(references: List[float], logits: List[float]) -> Dict[str, Any]:
+    """
+    Computes evaluation metrics for binary classification, including ROC-AUC and PR-AUC, based on reference labels and model logits.
 
-    # Convert logits to probabilities using sigmoid
-    probabilities = sigmoid(logits)
-
-    if probabilities.shape[1] == 2:
-        positive_probs = probabilities[:, 1]
-    else:
-        positive_probs = probabilities.squeeze()  # Ensure it's a 1D array
-
-    # Calculate predictions based on probability threshold of 0.5
-    predictions = (positive_probs > 0.5).astype(np.int32)
+    Args:
+        references (List[float]): Ground truth binary labels (0 or 1).
+        logits (List[float]): Logits output from the model (raw prediction scores), which will be converted to probabilities using the sigmoid function.
 
-    # Calculate accuracy
-    accuracy = accuracy_score(labels, predictions)
+    Returns:
+        Dict[str, Any]: A dictionary containing:
+            - 'roc_auc': The Area Under the Receiver Operating Characteristic Curve (ROC-AUC).
+            - 'pr_auc': The Area Under the Precision-Recall Curve (PR-AUC).
 
+    Notes:
+        - The `sigmoid` function is used to convert the logits into probabilities.
+        - ROC-AUC measures the model's ability to distinguish between classes, while PR-AUC focuses on performance when dealing with imbalanced data.
+    """
+    # Convert logits to probabilities using sigmoid
+    probabilities = sigmoid(logits)
+    # # Calculate PR-AUC
     # Calculate ROC-AUC
-    roc_auc = roc_auc_score(labels, positive_probs)
-
-    # Calculate PR-AUC
-    precision, recall, _ = precision_recall_curve(labels, positive_probs)
+    roc_auc = roc_auc_score(references, probabilities)
+    precision, recall, _ = precision_recall_curve(references, probabilities)
     pr_auc = auc(recall, precision)
-
-    return {"accuracy": accuracy, "roc_auc": roc_auc, "pr_auc": pr_auc}
+    return {"roc_auc": roc_auc, "pr_auc": pr_auc}
 
 
 def load_pretrained_model_and_tokenizer(
     model_args,
 ) -> Tuple[CehrBertPreTrainedModel, CehrBertTokenizer]:
-    # Try to load the pretrained tokenizer
+    """
+    Loads a pretrained model and tokenizer based on the given model arguments.
+
+    Args:
+        model_args (Namespace): An argument object containing the following fields:
+            - tokenizer_name_or_path (str): The path or name of the pretrained tokenizer to load.
+            - model_name_or_path (str): The path or name of the pretrained model to load.
+            - finetune_model_type (str): The type of fine-tuning model to use. Must be one of the values in `FineTuneModelType`.
+
+    Returns:
+        Tuple[CehrBertPreTrainedModel, CehrBertTokenizer]:
+            - CehrBertPreTrainedModel: The loaded pretrained model (either a classification or LSTM model).
+            - CehrBertTokenizer: The loaded pretrained tokenizer.
+
+    Raises:
+        ValueError: If the tokenizer cannot be loaded from the specified path, or if the fine-tuning model type is invalid.
+
+    Notes:
+        - If loading the model fails, the function will attempt to create a new model using the provided model arguments
+          and the tokenizer's configuration.
+        - The function supports two types of models for fine-tuning:
+            - `CehrBertForClassification` for pooling-based models.
+            - `CehrBertLstmForClassification` for LSTM-based models.
+    """
     try:
         tokenizer = CehrBertTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
     except Exception:
@@ -284,7 +306,6 @@ def assign_split(example):
         data_collator=collator,
         train_dataset=processed_dataset["train"],
         eval_dataset=processed_dataset["validation"],
-        compute_metrics=compute_metrics,
         callbacks=[EarlyStoppingCallback(early_stopping_patience=model_args.early_stopping_patience)],
         args=training_args,
     )
@@ -297,6 +318,8 @@ def assign_split(example):
 
         trainer.log_metrics("train", metrics)
         trainer.save_metrics("train", metrics)
+        trainer.log_metrics("eval", metrics)
+        trainer.save_metrics("eval", metrics)
         trainer.save_state()
 
     if training_args.do_predict:
@@ -309,12 +332,6 @@ def assign_split(example):
             trainer._load_from_checkpoint(training_args.output_dir)
 
         test_results = trainer.predict(processed_dataset["test"])
-        # Save results to JSON
-        test_results_path = os.path.join(training_args.output_dir, "test_results.json")
-        with open(test_results_path, "w") as f:
-            json.dump(test_results.metrics, f, indent=4)
-
-        LOG.info(f"Test results: {test_results.metrics}")
 
         person_ids = [row["person_id"] for row in processed_dataset["test"]]
 
@@ -330,6 +347,14 @@ def assign_split(example):
         prediction_pd = pd.DataFrame({"person_id ": person_ids, "prediction": predictions, "label": labels})
        prediction_pd.to_csv(os.path.join(training_args.output_dir, "test_predictions.csv"), index=False)
 
+        # Save results to JSON
+        metrics = compute_metrics(references=labels, logits=predictions)
+        test_results_path = os.path.join(training_args.output_dir, "test_results.json")
+        with open(test_results_path, "w") as f:
+            json.dump(metrics, f, indent=4)
+
+        LOG.info(f"Test results: {metrics}")
+
 
 if __name__ == "__main__":
     main()
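
For reference, a toy call of the refactored compute_metrics defined above; the label and logit values are made up purely for illustration, and the import assumes the cehrbert package is installed so the runner module is importable:

# Toy usage of the refactored compute_metrics; values are illustrative only.
from cehrbert.runners.hf_cehrbert_finetune_runner import compute_metrics

references = [0, 1, 0, 1, 1, 0]              # ground-truth binary labels
logits = [-1.3, 0.8, 0.5, 2.1, -0.4, -0.9]   # raw scores; sigmoid is applied inside

metrics = compute_metrics(references=references, logits=logits)
print(metrics)  # dict with "roc_auc" and "pr_auc" keys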
