stopped using the compute_metrics to calculate roc_auc/pr_auc/accuracy for the evaluation step; the reason is that for large eval datasets, the evaluation step can run out of CPU memory as it keeps all predictions on the CPU
ChaoPang committed Sep 9, 2024
1 parent b4f719d commit 63db5d6
86 changes: 55 additions & 31 deletions src/cehrbert/runners/hf_cehrbert_finetune_runner.py
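Note on the rationale: when a compute_metrics callback is attached, the Hugging Face Trainer's evaluation loop moves each batch's logits and labels to the CPU and concatenates them so the callback can run over the whole split, so peak memory grows with the size of the eval set; in recent transformers versions, evaluate() falls back to a loss-only loop when no callback is set, and nothing is accumulated. A minimal sketch of the resulting pattern, assuming a standard Trainer setup (model, dataset, and argument names are illustrative, not from this runner):

import numpy as np
from transformers import Trainer

# No compute_metrics callback: periodic evaluation during training tracks
# loss only and never accumulates the full eval set's logits on the CPU.
trainer = Trainer(
    model=model,                  # assumed: a fine-tuning classification model
    args=training_args,           # assumed: a TrainingArguments instance
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()

# Metrics are computed once, post hoc, from a single prediction pass.
output = trainer.predict(test_dataset)
metrics = compute_metrics(references=output.label_ids, logits=np.squeeze(output.predictions))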
@@ -1,13 +1,13 @@
 import json
 import os
-from typing import Tuple
+from typing import Any, Dict, List, Tuple
 
 import numpy as np
 import pandas as pd
 from datasets import DatasetDict, load_from_disk
 from peft import LoraConfig, get_peft_model
 from scipy.special import expit as sigmoid
-from sklearn.metrics import accuracy_score, auc, precision_recall_curve, roc_auc_score
+from sklearn.metrics import auc, precision_recall_curve, roc_auc_score
 from transformers import EarlyStoppingCallback, Trainer, set_seed
 from transformers.utils import logging

@@ -33,38 +33,59 @@
 LOG = logging.get_logger("transformers")
 
 
-def compute_metrics(eval_pred):
-    outputs, labels = eval_pred
-    logits = outputs[0]
-
-    # Convert logits to probabilities using sigmoid
-    probabilities = sigmoid(logits)
-
-    if probabilities.shape[1] == 2:
-        positive_probs = probabilities[:, 1]
-    else:
-        positive_probs = probabilities.squeeze()  # Ensure it's a 1D array
-
-    # Calculate predictions based on probability threshold of 0.5
-    predictions = (positive_probs > 0.5).astype(np.int32)
-
-    # Calculate accuracy
-    accuracy = accuracy_score(labels, predictions)
-
-    # Calculate ROC-AUC
-    roc_auc = roc_auc_score(labels, positive_probs)
-
-    # Calculate PR-AUC
-    precision, recall, _ = precision_recall_curve(labels, positive_probs)
-    pr_auc = auc(recall, precision)
-
-    return {"accuracy": accuracy, "roc_auc": roc_auc, "pr_auc": pr_auc}
+def compute_metrics(references: List[float], logits: List[float]) -> Dict[str, Any]:
+    """
+    Compute ROC-AUC and PR-AUC for a set of predictions.
+
+    Args:
+        references (list or array-like): Ground truth (correct) labels for each sample.
+        logits (list or array-like): Predicted scores for each sample, typically the model's output.
+
+    Returns:
+        Dict[str, Any]: A dictionary containing the computed metrics, where keys are the metric
+            names ('roc_auc', 'pr_auc') and values are the corresponding metric values.
+    """
+    # Convert logits to probabilities using sigmoid
+    probabilities = sigmoid(logits)
+
+    # Calculate ROC-AUC
+    roc_auc = roc_auc_score(references, probabilities)
+
+    # Calculate PR-AUC
+    precision, recall, _ = precision_recall_curve(references, probabilities)
+    pr_auc = auc(recall, precision)
+
+    return {"roc_auc": roc_auc, "pr_auc": pr_auc}


 def load_pretrained_model_and_tokenizer(
     model_args,
 ) -> Tuple[CehrBertPreTrainedModel, CehrBertTokenizer]:
-    # Try to load the pretrained tokenizer
+    """
+    Loads a pretrained model and tokenizer based on the given model arguments.
+
+    Args:
+        model_args (Namespace): An argument object containing the following fields:
+            - tokenizer_name_or_path (str): The path or name of the pretrained tokenizer to load.
+            - model_name_or_path (str): The path or name of the pretrained model to load.
+            - finetune_model_type (str): The type of fine-tuning model to use. Must be one of the
+              values in `FineTuneModelType`.
+
+    Returns:
+        Tuple[CehrBertPreTrainedModel, CehrBertTokenizer]:
+            - CehrBertPreTrainedModel: The loaded pretrained model (either a classification or LSTM model).
+            - CehrBertTokenizer: The loaded pretrained tokenizer.
+
+    Raises:
+        ValueError: If the tokenizer cannot be loaded from the specified path, or if the
+            fine-tuning model type is invalid.
+
+    Notes:
+        - If loading the model fails, the function will attempt to create a new model using the
+          provided model arguments and the tokenizer's configuration.
+        - The function supports two types of models for fine-tuning:
+            - `CehrBertForClassification` for pooling-based models.
+            - `CehrBertLstmForClassification` for LSTM-based models.
+    """
     try:
         tokenizer = CehrBertTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
     except Exception:
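For illustration, a hypothetical call to the new compute_metrics helper on made-up data (four samples, binary labels; values are not from this repository):

import numpy as np

references = np.array([0, 1, 1, 0])        # ground-truth labels
logits = np.array([-1.2, 0.8, 2.3, 0.1])   # raw model outputs
print(compute_metrics(references=references, logits=logits))
# {'roc_auc': 1.0, 'pr_auc': 1.0}  (the toy scores separate the classes perfectly)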
@@ -284,7 +305,6 @@ def assign_split(example):
         data_collator=collator,
         train_dataset=processed_dataset["train"],
         eval_dataset=processed_dataset["validation"],
-        compute_metrics=compute_metrics,
         callbacks=[EarlyStoppingCallback(early_stopping_patience=model_args.early_stopping_patience)],
         args=training_args,
     )
@@ -297,6 +317,8 @@ def assign_split(example):
 
         trainer.log_metrics("train", metrics)
         trainer.save_metrics("train", metrics)
+        trainer.log_metrics("eval", metrics)
+        trainer.save_metrics("eval", metrics)
         trainer.save_state()
 
     if training_args.do_predict:
@@ -309,12 +331,6 @@ def assign_split(example):
             trainer._load_from_checkpoint(training_args.output_dir)
 
         test_results = trainer.predict(processed_dataset["test"])
-        # Save results to JSON
-        test_results_path = os.path.join(training_args.output_dir, "test_results.json")
-        with open(test_results_path, "w") as f:
-            json.dump(test_results.metrics, f, indent=4)
-
-        LOG.info(f"Test results: {test_results.metrics}")
 
         person_ids = [row["person_id"] for row in processed_dataset["test"]]
 
@@ -330,6 +346,14 @@ def assign_split(example):
         prediction_pd = pd.DataFrame({"person_id": person_ids, "prediction": predictions, "label": labels})
         prediction_pd.to_csv(os.path.join(training_args.output_dir, "test_predictions.csv"), index=False)
 
+        # Save results to JSON
+        metrics = compute_metrics(references=labels, logits=predictions)
+        test_results_path = os.path.join(training_args.output_dir, "test_results.json")
+        with open(test_results_path, "w") as f:
+            json.dump(metrics, f, indent=4)
+
+        LOG.info(f"Test results: {metrics}")
 
 
 if __name__ == "__main__":
     main()
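A quick way to sanity-check the artifacts the runner now writes (a sketch; "output" stands in for the actual training_args.output_dir, and the printed values are placeholders):

import json
import os

import pandas as pd

output_dir = "output"  # stand-in for training_args.output_dir
predictions_pd = pd.read_csv(os.path.join(output_dir, "test_predictions.csv"))
print(predictions_pd.columns.tolist())   # ['person_id', 'prediction', 'label']
with open(os.path.join(output_dir, "test_results.json")) as f:
    print(json.load(f))                  # {'roc_auc': ..., 'pr_auc': ...}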
