diff --git a/src/cehrbert/runners/hf_cehrbert_finetune_runner.py b/src/cehrbert/runners/hf_cehrbert_finetune_runner.py
index 2ef1047..6123c03 100644
--- a/src/cehrbert/runners/hf_cehrbert_finetune_runner.py
+++ b/src/cehrbert/runners/hf_cehrbert_finetune_runner.py
@@ -1,23 +1,22 @@
 import json
-import os
 from pathlib import Path
 from typing import Any, Dict, List, Union
 
 import numpy as np
 import pandas as pd
+import torch
 from datasets import DatasetDict, load_from_disk
-from pandas import Series
-from peft import LoraConfig, get_peft_model
+from peft import LoraConfig, LoraModel, PeftModel, get_peft_model
 from scipy.special import expit as sigmoid
 from sklearn.metrics import auc, precision_recall_curve, roc_auc_score
 from torch.utils.data import DataLoader
+from tqdm import tqdm
 from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed
 from transformers.utils import logging
 
 from cehrbert.data_generators.hf_data_generator.hf_dataset import create_cehrbert_finetuning_dataset
 from cehrbert.data_generators.hf_data_generator.hf_dataset_collator import CehrBertDataCollator
 from cehrbert.data_generators.hf_data_generator.meds_utils import create_dataset_from_meds_reader
-from cehrbert.models.hf_models.config import CehrBertConfig
 from cehrbert.models.hf_models.hf_cehrbert import (
     CehrBertForClassification,
     CehrBertLstmForClassification,
@@ -36,7 +35,7 @@
 LOG = logging.get_logger("transformers")
 
 
-def compute_metrics(references: Union[List[float], Series], logits: Union[List[float], Series]) -> Dict[str, Any]:
+def compute_metrics(references: Union[List[float], pd.Series], logits: Union[List[float], pd.Series]) -> Dict[str, Any]:
     """
     Computes evaluation metrics for binary classification, including ROC-AUC and PR-AUC,
     based on reference labels and model logits.
@@ -294,63 +293,90 @@ def assign_split(example):
         trainer.save_state()
 
     if training_args.do_predict:
-        import torch
-        from peft import PeftModel
-        from tqdm import tqdm
-
-        if torch.cuda.is_available():
-            device = torch.device("cuda")
-        else:
-            device = torch.device("cpu")
-
-        # If lora is enabled, we add LORA adapters to the model
-        if model_args.use_lora:
-            base_model = load_finetuned_model(model_args, model_args.model_name_or_path)
-            model = PeftModel.from_pretrained(base_model, model_id=training_args.output_dir)
-        else:
-            model = load_finetuned_model(model_args, training_args.output_dir)
-
-        model = model.to(device).eval()
-        # Create the prediction folder if not exists
-        test_prediction_folder = os.path.join(training_args.output_dir, "test_predictions")
-        Path(test_prediction_folder).mkdir(parents=True, exist_ok=True)
-        test_ = processed_dataset["test"]
-        test_.set_format("pt")
         test_dataloader = DataLoader(
-            dataset=test_,
+            dataset=processed_dataset["test"],
             batch_size=training_args.per_device_eval_batch_size,
             num_workers=training_args.dataloader_num_workers,
             collate_fn=collator,
             pin_memory=training_args.dataloader_pin_memory,
         )
-        LOG.info(
-            "Started generating predictions for test set at %s",
-            test_prediction_folder,
-        )
-        with torch.no_grad():
-            index = 0
-            for batch in tqdm(test_dataloader):
-                batched_labels = batch["classifier_label"]
-                batch = {k: v.to(device) for k, v in batch.items()}
-                cehrbert_output = model(**batch)
-                logits = cehrbert_output.logits.squeeze().cpu().detach().numpy()
-                labels = batched_labels.squeeze().cpu().detach().numpy()
-                prediction_pd = pd.DataFrame({"prediction": logits, "label": labels})
-                prediction_pd.to_parquet(os.path.join(test_prediction_folder, f"{index}.parquet"))
-                index += 1
-
-        LOG.info(
-            "Started computing metrics using the test set predictions at %s",
-            test_prediction_folder,
-        )
-        test_prediction_pd = pd.read_parquet(test_prediction_folder)
-        # Save results to JSON
-        metrics = compute_metrics(references=test_prediction_pd.label, logits=test_prediction_pd.prediction)
-        test_results_path = os.path.join(training_args.output_dir, "test_results.json")
-        with open(test_results_path, "w") as f:
-            json.dump(metrics, f, indent=4)
-
-        LOG.info(f"Test results: {metrics}")
+        do_predict(test_dataloader, model_args, training_args)
+
+
+def do_predict(test_dataloader: DataLoader, model_args: ModelArguments, training_args: TrainingArguments):
+    """
+    Runs inference on the test dataset with a fine-tuned model and saves predictions and evaluation metrics to disk.
+
+    We use this custom loop instead of transformers' Trainer.predict() because the latter accumulates
+    all predictions in host memory, which leaks memory and raises a CPU OOM error on large test sets.
+    Writing one parquet shard per batch keeps memory usage constant.
+
+    Args:
+        test_dataloader (DataLoader): DataLoader containing the test dataset, with batches of input features and labels.
+        model_args (ModelArguments): Arguments for configuring and loading the fine-tuned model.
+        training_args (TrainingArguments): Arguments related to training, evaluation, and output directories.
+
+    Returns:
+        None. Results are saved to disk.
+ """ + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Load model and LoRA adapters if applicable + model = ( + load_finetuned_model(model_args, model_args.model_name_or_path) + if not model_args.use_lora + else load_lora_model(model_args, training_args) + ) + + model = model.to(device).eval() + + # Ensure prediction folder exists + test_prediction_folder = Path(training_args.output_dir) / "test_predictions" + test_prediction_folder.mkdir(parents=True, exist_ok=True) + + LOG.info("Generating predictions for test set at %s", test_prediction_folder) + + test_losses = [] + with torch.no_grad(): + for index, batch in enumerate(tqdm(test_dataloader, desc="Predicting")): + + batch = {k: v.to(device) for k, v in batch.items()} + # Forward pass + output = model(**batch) + test_losses.append(output.loss.item()) + + # Collect logits and labels for prediction + logits = output.logits.cpu().numpy().squeeze() + labels = batch["classifier_label"].cpu().numpy().squeeze() + # Save predictions to parquet file + test_prediction_pd = pd.DataFrame({"prediction": logits, "label": labels}) + test_prediction_pd.to_parquet(test_prediction_folder / f"{index}.parquet") + + LOG.info("Computing metrics using the test set predictions at %s", test_prediction_folder) + + # Load all predictions + test_prediction_pd = pd.read_parquet(test_prediction_folder) + + # Compute metrics and save results + metrics = compute_metrics(references=test_prediction_pd.label, logits=test_prediction_pd.prediction) + metrics["test_loss"] = np.mean(test_losses) + + test_results_path = Path(training_args.output_dir) / "test_results.json" + with open(test_results_path, "w") as f: + json.dump(metrics, f, indent=4) + + LOG.info("Test results: %s", metrics) + + +def load_lora_model(model_args, training_args) -> Union[LoraModel, CehrBertPreTrainedModel]: + LOG.info("Loading base model from %s", model_args.model_name_or_path) + base_model = load_finetuned_model(model_args, model_args.model_name_or_path) + + LOG.info("Loading LoRA adapter from %s", training_args.output_dir) + return PeftModel.from_pretrained(base_model, model_id=training_args.output_dir) if __name__ == "__main__":