Skip to content

Commit

Permalink
fixed a bug in reading the predicted_boolean_probability for computing auc (#82)
Browse files Browse the repository at this point in the history

* fixed a bug in reading the predicted_boolean_probability for computing auc

* added logic for handling invalid timestamps in finetuning dataset
  • Loading branch information
ChaoPang authored Dec 9, 2024
1 parent 07292d0 commit 5bcbd61
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 14 deletions.
16 changes: 10 additions & 6 deletions src/cehrbert/runners/hf_cehrbert_finetune_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,11 +346,15 @@ def do_predict(test_dataloader: DataLoader, model_args: ModelArguments, training
with torch.no_grad():
for index, batch in enumerate(tqdm(test_dataloader, desc="Predicting")):
person_ids = batch.pop("person_id").numpy().squeeze().astype(int)
index_dates = (
map(datetime.fromtimestamp, batch.pop("index_date").numpy().squeeze().tolist())
if "index_date" in batch
else None
)
# Extract and process index_dates for the current batch.
index_dates = None
if "index_date" in batch:
    # squeeze() on a batch of size 1 yields a 0-d array whose tolist()
    # returns a bare float, not a list; the original comprehension then
    # raised TypeError and the fallback `[None] * len(timestamps)` raised
    # again on len(float). atleast_1d keeps the result iterable.
    timestamps = np.atleast_1d(batch.pop("index_date").numpy().squeeze()).tolist()
    # Convert per element so a single NaN / out-of-range timestamp maps to
    # None instead of blanking out the whole batch. OSError is included
    # because fromtimestamp raises it for out-of-range values on some
    # platforms.
    index_dates = []
    for ts in timestamps:
        try:
            index_dates.append(datetime.fromtimestamp(ts) if not np.isnan(ts) else None)
        except (ValueError, OverflowError, TypeError, OSError):
            index_dates.append(None)
batch = {k: v.to(device) for k, v in batch.items()}
# Forward pass
output = model(**batch, output_attentions=False, output_hidden_states=False)
Expand All @@ -377,7 +381,7 @@ def do_predict(test_dataloader: DataLoader, model_args: ModelArguments, training
test_prediction_pd = pd.read_parquet(test_prediction_folder)
# Compute metrics and save results
metrics = compute_metrics(
references=test_prediction_pd.boolean_value, probs=test_prediction_pd.boolean_prediction_probability
references=test_prediction_pd.boolean_value, probs=test_prediction_pd.predicted_boolean_probability
)
metrics["test_loss"] = np.mean(test_losses)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from datasets import disable_caching

from cehrbert.runners.hf_cehrbert_finetune_runner import main as finetune_main
from cehrbert.runners.hf_cehrbert_pretrain_runner import main

disable_caching()
Expand All @@ -19,13 +20,21 @@ class HfCehrBertRunnerIntegrationTest(unittest.TestCase):
def setUp(self):
    """Resolve sample-data paths and create temp model/dataset directories.

    Sets pretrain/finetune data folders relative to the project root and
    prepares fresh temp output directories for each test run.
    """
    # Get the root folder of the project (four levels up from this test file).
    root_folder = Path(os.path.abspath(__file__)).parent.parent.parent.parent
    # Removed the dead, unused `data_folder` local left over from the old
    # version; `self.pretrain_data_folder` replaces it.
    self.pretrain_data_folder = os.path.join(root_folder, "sample_data", "pretrain")
    self.finetune_data_folder = os.path.join(root_folder, "sample_data", "finetune", "full")
    # Create a temporary directory to store model and tokenizer.
    self.temp_dir = tempfile.mkdtemp()
    self.model_folder_path = os.path.join(self.temp_dir, "model")
    self.finetuned_model_folder_path = os.path.join(self.temp_dir, "model_finetuned")
    Path(self.model_folder_path).mkdir(parents=True, exist_ok=True)
    self.dataset_prepared_path = os.path.join(self.temp_dir, "dataset_prepared_path")
    Path(self.dataset_prepared_path).mkdir(parents=True, exist_ok=True)

def tearDown(self):
    """Delete the temporary working directory created in setUp."""
    scratch_dir = self.temp_dir
    shutil.rmtree(scratch_dir)

def test_train_model(self):
sys.argv = [
"hf_cehrbert_pretraining_runner.py",
"--model_name_or_path",
Expand All @@ -35,19 +44,39 @@ def setUp(self):
"--output_dir",
self.model_folder_path,
"--data_folder",
data_folder,
self.pretrain_data_folder,
"--dataset_prepared_path",
self.dataset_prepared_path,
"--max_steps",
"10",
]

def tearDown(self):
# Remove the temporary directory
shutil.rmtree(self.temp_dir)

def test_train_model(self):
main()
sys.argv = [
"hf_cehrbert_finetune_runner.py",
"--model_name_or_path",
self.model_folder_path,
"--tokenizer_name_or_path",
self.model_folder_path,
"--output_dir",
self.finetuned_model_folder_path,
"--data_folder",
self.finetune_data_folder,
"--dataset_prepared_path",
self.dataset_prepared_path,
"--max_steps",
"10",
"--save_strategy",
"steps",
"--evaluation_strategy",
"steps",
"--do_train",
"true",
"--do_predict",
"true",
"--load_best_model_at_end",
"true",
]
finetune_main()


if __name__ == "__main__":
Expand Down

0 comments on commit 5bcbd61

Please sign in to comment.