Skip to content

Commit

Permalink
Fixed a streaming-mode bug in dataset filtering
Browse files Browse the repository at this point in the history
  • Loading branch information
ChaoPang committed Nov 15, 2024
1 parent a58255d commit fbaf2a2
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 12 deletions.
34 changes: 22 additions & 12 deletions src/cehrbert/data_generators/hf_data_generator/hf_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,7 @@ def create_cehrbert_pretraining_dataset(
required_columns = TRANSFORMER_COLUMNS + CEHRBERT_COLUMNS

# Remove patients without any records
dataset = dataset.filter(
lambda batch: [num_of_concepts > 0 for num_of_concepts in batch["num_of_concepts"]],
num_proc=data_args.preprocessing_num_workers if not data_args.streaming else None,
batched=True,
batch_size=data_args.preprocessing_batch_size,
)
dataset = filter_dataset(dataset, data_args)

# If the data is already in meds, we don't need to sort the sequence anymore
if data_args.is_data_in_meds:
Expand Down Expand Up @@ -82,12 +77,7 @@ def create_cehrbert_finetuning_dataset(
required_columns = TRANSFORMER_COLUMNS + CEHRBERT_COLUMNS + FINETUNING_COLUMNS

# Remove patients without any records
dataset = dataset.filter(
lambda batch: [num_of_concepts > 0 for num_of_concepts in batch["num_of_concepts"]],
num_proc=data_args.preprocessing_num_workers if not data_args.streaming else None,
batched=True,
batch_size=data_args.preprocessing_batch_size,
)
dataset = filter_dataset(dataset, data_args)

if data_args.is_data_in_meds:
mapping_functions = [
Expand Down Expand Up @@ -120,6 +110,26 @@ def create_cehrbert_finetuning_dataset(
return dataset


def filter_dataset(dataset: Union[Dataset, DatasetDict], data_args: DataTrainingArguments):
    """Remove patients without any records (i.e. ``num_of_concepts == 0``).

    Works for both map-style datasets (``Dataset`` / ``DatasetDict``) and
    streaming ones (``IterableDataset`` / ``IterableDatasetDict``). Streaming
    datasets' ``filter`` does not accept a ``num_proc`` argument — passing it,
    even as ``None``, raises ``TypeError`` — so that keyword is only supplied
    in the non-streaming case.

    Args:
        dataset: The (possibly split-keyed) dataset to filter.
        data_args: Data arguments providing ``streaming``,
            ``preprocessing_num_workers`` and ``preprocessing_batch_size``.

    Returns:
        The filtered dataset, same container type as the input.
    """

    def has_records(batch):
        # Keep only examples that contain at least one concept.
        return [num_of_concepts > 0 for num_of_concepts in batch["num_of_concepts"]]

    if data_args.streaming:
        # IterableDataset.filter takes no num_proc; dict-style containers
        # (DatasetDict / IterableDatasetDict) apply the filter to every split.
        dataset = dataset.filter(
            has_records,
            batched=True,
            batch_size=data_args.preprocessing_batch_size,
        )
    else:
        dataset = dataset.filter(
            has_records,
            num_proc=data_args.preprocessing_num_workers,
            batched=True,
            batch_size=data_args.preprocessing_batch_size,
        )
    return dataset


def apply_cehrbert_dataset_mapping(
dataset: Union[DatasetDict, Dataset, IterableDataset, IterableDatasetDict],
mapping_function: DatasetMapping,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import os
import shutil
import sys
import tempfile
import unittest
from pathlib import Path

from datasets import disable_caching

from cehrbert.runners.hf_cehrbert_pretrain_runner import main

# Avoid writing Hugging Face datasets cache files during the test run.
disable_caching()
# Force CPU-only execution so the test does not require a GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# Disable Weights & Biases logging for this test.
os.environ["WANDB_MODE"] = "disabled"
# Surface transformers' info-level logs to aid debugging test failures.
os.environ["TRANSFORMERS_VERBOSITY"] = "info"


class HfCehrBertRunnerIntegrationTest(unittest.TestCase):
    """End-to-end check: the CEHR-BERT pretraining runner trains in streaming mode."""

    def setUp(self):
        # Locate the bundled sample pretraining data relative to the repo root
        # (four directory levels above this test file).
        repo_root = Path(os.path.abspath(__file__)).parents[3]
        sample_data_folder = os.path.join(repo_root, "sample_data", "pretrain")

        # Every artifact (model, tokenizer, prepared dataset) lands in a
        # temporary directory that tearDown removes.
        self.temp_dir = tempfile.mkdtemp()
        self.model_folder_path = os.path.join(self.temp_dir, "model")
        self.dataset_prepared_path = os.path.join(self.temp_dir, "dataset_prepared_path")
        for folder in (self.model_folder_path, self.dataset_prepared_path):
            Path(folder).mkdir(parents=True, exist_ok=True)

        # The runner reads its configuration from sys.argv, so fake the
        # command line here before invoking main().
        sys.argv = [
            "hf_cehrbert_pretraining_runner.py",
            "--model_name_or_path",
            self.model_folder_path,
            "--tokenizer_name_or_path",
            self.model_folder_path,
            "--output_dir",
            self.model_folder_path,
            "--data_folder",
            sample_data_folder,
            "--dataset_prepared_path",
            self.dataset_prepared_path,
            "--max_steps",
            "10",
            "--streaming",
        ]

    def tearDown(self):
        # Discard all artifacts produced by the run.
        shutil.rmtree(self.temp_dir)

    def test_train_model(self):
        # Should complete 10 training steps without raising.
        main()


# Allow running this test module directly: ``python <this_file>``.
if __name__ == "__main__":
    unittest.main()

0 comments on commit fbaf2a2

Please sign in to comment.