cleaned up do_predict and added docstring
ChaoPang committed Sep 10, 2024
1 parent a5019e8 commit f13cc36
Showing 1 changed file with 78 additions and 56 deletions.
134 changes: 78 additions & 56 deletions src/cehrbert/runners/hf_cehrbert_finetune_runner.py
@@ -1,23 +1,22 @@
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Union

import numpy as np
import pandas as pd
import torch
from datasets import DatasetDict, load_from_disk
from pandas import Series
from peft import LoraConfig, get_peft_model
from peft import LoraConfig, LoraModel, PeftModel, get_peft_model
from scipy.special import expit as sigmoid
from sklearn.metrics import auc, precision_recall_curve, roc_auc_score
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed
from transformers.utils import logging

from cehrbert.data_generators.hf_data_generator.hf_dataset import create_cehrbert_finetuning_dataset
from cehrbert.data_generators.hf_data_generator.hf_dataset_collator import CehrBertDataCollator
from cehrbert.data_generators.hf_data_generator.meds_utils import create_dataset_from_meds_reader
from cehrbert.models.hf_models.config import CehrBertConfig
from cehrbert.models.hf_models.hf_cehrbert import (
    CehrBertForClassification,
    CehrBertLstmForClassification,
@@ -36,7 +35,7 @@
LOG = logging.get_logger("transformers")


def compute_metrics(references: Union[List[float], Series], logits: Union[List[float], Series]) -> Dict[str, Any]:
def compute_metrics(references: Union[List[float], pd.Series], logits: Union[List[float], pd.Series]) -> Dict[str, Any]:
"""
Computes evaluation metrics for binary classification, including ROC-AUC and PR-AUC, based on reference labels and model logits.
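# NOTE: the body of compute_metrics is collapsed in this diff view. A minimal
# sketch of such a computation, assuming it uses the sigmoid, roc_auc_score,
# precision_recall_curve, and auc imports at the top of the file (illustrative
# only, not the actual elided code):
#
#     probabilities = sigmoid(np.asarray(logits))
#     roc_auc = roc_auc_score(references, probabilities)
#     precision, recall, _ = precision_recall_curve(references, probabilities)
#     pr_auc = auc(recall, precision)
#     return {"roc_auc": roc_auc, "pr_auc": pr_auc}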
@@ -294,63 +293,86 @@ def assign_split(example):
    trainer.save_state()

    if training_args.do_predict:
        import torch
        from peft import PeftModel
        from tqdm import tqdm

        if torch.cuda.is_available():
            device = torch.device("cuda")
        else:
            device = torch.device("cpu")

        # If LoRA is enabled, we add the LoRA adapters to the model
        if model_args.use_lora:
            base_model = load_finetuned_model(model_args, model_args.model_name_or_path)
            model = PeftModel.from_pretrained(base_model, model_id=training_args.output_dir)
        else:
            model = load_finetuned_model(model_args, training_args.output_dir)

        model = model.to(device).eval()
        # Create the prediction folder if it does not exist
        test_prediction_folder = os.path.join(training_args.output_dir, "test_predictions")
        Path(test_prediction_folder).mkdir(parents=True, exist_ok=True)
        test_ = processed_dataset["test"]
        test_.set_format("pt")
        test_dataloader = DataLoader(
            dataset=test_,
            dataset=processed_dataset["test"],
            batch_size=training_args.per_device_eval_batch_size,
            num_workers=training_args.dataloader_num_workers,
            collate_fn=collator,
            pin_memory=training_args.dataloader_pin_memory,
        )
        LOG.info(
            "Started generating predictions for test set at %s",
            test_prediction_folder,
        )
        with torch.no_grad():
            index = 0
            for batch in tqdm(test_dataloader):
                batched_labels = batch["classifier_label"]
                batch = {k: v.to(device) for k, v in batch.items()}
                cehrbert_output = model(**batch)
                logits = cehrbert_output.logits.squeeze().cpu().detach().numpy()
                labels = batched_labels.squeeze().cpu().detach().numpy()
                prediction_pd = pd.DataFrame({"prediction": logits, "label": labels})
                prediction_pd.to_parquet(os.path.join(test_prediction_folder, f"{index}.parquet"))
                index += 1

        LOG.info(
            "Started computing metrics using the test set predictions at %s",
            test_prediction_folder,
        )
        test_prediction_pd = pd.read_parquet(test_prediction_folder)
        # Save the results to JSON
        metrics = compute_metrics(references=test_prediction_pd.label, logits=test_prediction_pd.prediction)
        test_results_path = os.path.join(training_args.output_dir, "test_results.json")
        with open(test_results_path, "w") as f:
            json.dump(metrics, f, indent=4)

        LOG.info(f"Test results: {metrics}")
        do_predict(test_dataloader, model_args, training_args)


def do_predict(test_dataloader: DataLoader, model_args: ModelArguments, training_args: TrainingArguments):
    """
    Performs inference on the test dataset using a fine-tuned model, then saves the predictions and evaluation metrics.

    We use this custom do_predict instead of transformers' trainer.predict() because trainer.predict() leaks memory
    and throws a CPU OOM error on large test sets.

    Args:
        test_dataloader (DataLoader): DataLoader containing the test dataset, with batches of input features and labels.
        model_args (ModelArguments): Arguments for configuring and loading the fine-tuned model.
        training_args (TrainingArguments): Arguments related to training, evaluation, and output directories.

    Returns:
        None. Results are saved to disk.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the model, attaching the LoRA adapters if applicable
    model = (
        load_finetuned_model(model_args, model_args.model_name_or_path)
        if not model_args.use_lora
        else load_lora_model(model_args, training_args)
    )

    model = model.to(device).eval()

    # Ensure the prediction folder exists
    test_prediction_folder = Path(training_args.output_dir) / "test_predictions"
    test_prediction_folder.mkdir(parents=True, exist_ok=True)

    LOG.info("Generating predictions for the test set at %s", test_prediction_folder)

    test_losses = []
    with torch.no_grad():
        for index, batch in enumerate(tqdm(test_dataloader, desc="Predicting")):
            batch = {k: v.to(device) for k, v in batch.items()}

            # Forward pass
            output = model(**batch)
            test_losses.append(output.loss.item())

            # Collect the logits and labels for this batch
            logits = output.logits.cpu().numpy().squeeze()
            labels = batch["classifier_label"].cpu().numpy().squeeze()

            # Save the batch predictions to a parquet file
            test_prediction_pd = pd.DataFrame({"prediction": logits, "label": labels})
            test_prediction_pd.to_parquet(test_prediction_folder / f"{index}.parquet")

    LOG.info("Computing metrics using the test set predictions at %s", test_prediction_folder)

    # Load all predictions; pandas reads the directory of parquet files as a single DataFrame
    test_prediction_pd = pd.read_parquet(test_prediction_folder)

    # Compute the metrics and save the results
    metrics = compute_metrics(references=test_prediction_pd.label, logits=test_prediction_pd.prediction)
    metrics["test_loss"] = np.mean(test_losses)

    test_results_path = Path(training_args.output_dir) / "test_results.json"
    with open(test_results_path, "w") as f:
        json.dump(metrics, f, indent=4)

    LOG.info("Test results: %s", metrics)
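# NOTE: a quick, illustrative way to inspect the artifacts that do_predict
# writes; the "output_dir" path below is a placeholder for training_args.output_dir:
#
#     import json
#     import pandas as pd
#     with open("output_dir/test_results.json") as f:
#         print(json.load(f))  # e.g. the ROC-AUC, PR-AUC, and test_loss values
#     print(pd.read_parquet("output_dir/test_predictions").head())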


def load_lora_model(model_args, training_args) -> Union[LoraModel, CehrBertPreTrainedModel]:
    LOG.info("Loading base model from %s", model_args.model_name_or_path)
    base_model = load_finetuned_model(model_args, model_args.model_name_or_path)

    LOG.info("Loading LoRA adapter from %s", training_args.output_dir)
    return PeftModel.from_pretrained(base_model, model_id=training_args.output_dir)
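
# NOTE: a minimal sketch of how the adapter that load_lora_model restores would
# have been attached during fine-tuning, assuming the LoraConfig/get_peft_model
# imports at the top of this file serve that purpose (the training code is not
# shown in this diff, and the specific LoraConfig settings are hypothetical):
#
#     peft_model = get_peft_model(base_model, LoraConfig(task_type="SEQ_CLS"))
#     trainer = Trainer(model=peft_model, args=training_args, ...)
#     trainer.train()
#     trainer.save_model(training_args.output_dir)  # writes the adapter weights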


if __name__ == "__main__":
