fixed the streaming bug #75

Merged: 2 commits, Nov 15, 2024
34 changes: 22 additions & 12 deletions src/cehrbert/data_generators/hf_data_generator/hf_dataset.py
@@ -38,12 +38,7 @@ def create_cehrbert_pretraining_dataset(
     required_columns = TRANSFORMER_COLUMNS + CEHRBERT_COLUMNS
 
     # Remove patients without any records
-    dataset = dataset.filter(
-        lambda batch: [num_of_concepts > 0 for num_of_concepts in batch["num_of_concepts"]],
-        num_proc=data_args.preprocessing_num_workers if not data_args.streaming else None,
-        batched=True,
-        batch_size=data_args.preprocessing_batch_size,
-    )
+    dataset = filter_dataset(dataset, data_args)
 
     # If the data is already in meds, we don't need to sort the sequence anymore
     if data_args.is_data_in_meds:
@@ -82,12 +77,7 @@ def create_cehrbert_finetuning_dataset(
     required_columns = TRANSFORMER_COLUMNS + CEHRBERT_COLUMNS + FINETUNING_COLUMNS
 
     # Remove patients without any records
-    dataset = dataset.filter(
-        lambda batch: [num_of_concepts > 0 for num_of_concepts in batch["num_of_concepts"]],
-        num_proc=data_args.preprocessing_num_workers if not data_args.streaming else None,
-        batched=True,
-        batch_size=data_args.preprocessing_batch_size,
-    )
+    dataset = filter_dataset(dataset, data_args)
 
     if data_args.is_data_in_meds:
         mapping_functions = [
@@ -120,6 +110,26 @@ def create_cehrbert_finetuning_dataset(
     return dataset
 
 
+def filter_dataset(dataset: Union[Dataset, DatasetDict, IterableDataset, IterableDatasetDict], data_args: DataTrainingArguments):
+    # Remove patients without any records.
+    # Check whether this is a DatasetDict or IterableDatasetDict; if so, filter each split.
+    if isinstance(dataset, (DatasetDict, IterableDatasetDict)) and data_args.streaming:
+        for key in dataset.keys():
+            dataset[key] = dataset[key].filter(
+                lambda batch: [num_of_concepts > 0 for num_of_concepts in batch["num_of_concepts"]],
+                batched=True,
+                batch_size=data_args.preprocessing_batch_size,
+            )
+    else:
+        dataset = dataset.filter(
+            lambda batch: [num_of_concepts > 0 for num_of_concepts in batch["num_of_concepts"]],
+            num_proc=data_args.preprocessing_num_workers if not data_args.streaming else None,
+            batched=True,
+            batch_size=data_args.preprocessing_batch_size,
+        )
+    return dataset
+
+
 def apply_cehrbert_dataset_mapping(
     dataset: Union[DatasetDict, Dataset, IterableDataset, IterableDatasetDict],
     mapping_function: DatasetMapping,
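For context, a minimal sketch of how the new filter_dataset helper would be exercised, assuming the Hugging Face datasets API (load_dataset with streaming=True returns an IterableDatasetDict); the parquet path and the SimpleNamespace stand-in for DataTrainingArguments are illustrative only:

from types import SimpleNamespace

from datasets import load_dataset

from cehrbert.data_generators.hf_data_generator.hf_dataset import filter_dataset

# Illustrative stand-in for the runner's parsed DataTrainingArguments.
data_args = SimpleNamespace(streaming=True, preprocessing_batch_size=1000, preprocessing_num_workers=4)

# streaming=True makes load_dataset return an IterableDatasetDict.
dataset = load_dataset("parquet", data_files={"train": "patient_sequences/*.parquet"}, streaming=True)

# Each split is an IterableDataset, whose filter() does not accept num_proc,
# which is why the streaming branch of filter_dataset omits that argument.
dataset = filter_dataset(dataset, data_args)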
14 changes: 3 additions & 11 deletions src/cehrbert/runners/hf_cehrbert_pretrain_runner.py
@@ -206,6 +206,7 @@ def main():
                test_size=data_args.validation_split_percentage,
                seed=training_args.seed,
            )
+           dataset = DatasetDict({"train": dataset["train"], "validation": dataset["test"]})
        else:
            raise RuntimeError(
                f"Can not split the data. If streaming is enabled, validation_split_num needs "
@@ -261,20 +262,11 @@ def filter_func(examples):
     if not data_args.streaming:
         processed_dataset.set_format("pt")
 
-    eval_dataset = None
-    if isinstance(processed_dataset, DatasetDict) or isinstance(processed_dataset, IterableDatasetDict):
-        train_dataset = processed_dataset["train"]
-        if "validation" in processed_dataset:
-            eval_dataset = processed_dataset["validation"]
-    else:
-        train_dataset = processed_dataset
-
     trainer = Trainer(
         model=model,
         data_collator=collator,
-        train_dataset=train_dataset,
-        eval_dataset=eval_dataset,
-        # compute_metrics=compute_metrics,
+        train_dataset=processed_dataset["train"],
+        eval_dataset=processed_dataset["validation"],
         args=training_args,
     )
 
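The one-line addition after train_test_split pairs with the simplified Trainer wiring: train_test_split always names its two splits "train" and "test", so without the rename there would be no "validation" key for eval_dataset to index. A minimal sketch of that behavior, using toy data:

from datasets import Dataset, DatasetDict

# Toy dataset, illustrative only.
dataset = Dataset.from_dict({"num_of_concepts": [3, 0, 5, 2]})
split = dataset.train_test_split(test_size=0.25, seed=42)
print(list(split.keys()))  # ['train', 'test']

# Rename "test" to "validation" so downstream code can rely on that key.
split = DatasetDict({"train": split["train"], "validation": split["test"]})
print(list(split.keys()))  # ['train', 'validation']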
@@ -0,0 +1,55 @@
+import os
+import shutil
+import sys
+import tempfile
+import unittest
+from pathlib import Path
+
+from datasets import disable_caching
+
+from cehrbert.runners.hf_cehrbert_pretrain_runner import main
+
+disable_caching()
+os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+os.environ["WANDB_MODE"] = "disabled"
+os.environ["TRANSFORMERS_VERBOSITY"] = "info"
+
+
+class HfCehrBertRunnerIntegrationTest(unittest.TestCase):
+    def setUp(self):
+        # Get the root folder of the project
+        root_folder = Path(os.path.abspath(__file__)).parent.parent.parent.parent
+        data_folder = os.path.join(root_folder, "sample_data", "pretrain")
+        # Create a temporary directory to store model and tokenizer
+        self.temp_dir = tempfile.mkdtemp()
+        self.model_folder_path = os.path.join(self.temp_dir, "model")
+        Path(self.model_folder_path).mkdir(parents=True, exist_ok=True)
+        self.dataset_prepared_path = os.path.join(self.temp_dir, "dataset_prepared_path")
+        Path(self.dataset_prepared_path).mkdir(parents=True, exist_ok=True)
+        sys.argv = [
+            "hf_cehrbert_pretraining_runner.py",
+            "--model_name_or_path",
+            self.model_folder_path,
+            "--tokenizer_name_or_path",
+            self.model_folder_path,
+            "--output_dir",
+            self.model_folder_path,
+            "--data_folder",
+            data_folder,
+            "--dataset_prepared_path",
+            self.dataset_prepared_path,
+            "--max_steps",
+            "10",
+            "--streaming",
+        ]
+
+    def tearDown(self):
+        # Remove the temporary directory
+        shutil.rmtree(self.temp_dir)
+
+    def test_train_model(self):
+        main()
+
+
+if __name__ == "__main__":
+    unittest.main()
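The setUp method drives the whole runner through sys.argv, which works if the runner's main() parses its arguments from the command line; transformers-style runners typically do this via HfArgumentParser, though whether cehrbert's runner does exactly that is an assumption here. A minimal illustration with a hypothetical dataclass:

import sys
from dataclasses import dataclass, field

from transformers import HfArgumentParser


# Hypothetical stand-in for the runner's real argument dataclasses.
@dataclass
class DemoArguments:
    data_folder: str = field(default="")
    max_steps: int = field(default=-1)
    streaming: bool = field(default=False)


# Patching sys.argv before parsing is exactly what the test's setUp relies on.
sys.argv = ["runner.py", "--data_folder", "sample_data/pretrain", "--max_steps", "10", "--streaming"]
(demo_args,) = HfArgumentParser(DemoArguments).parse_args_into_dataclasses()
print(demo_args.max_steps, demo_args.streaming)  # 10 True (bare boolean flags flip to True)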