From c8410e00c738a903e078883a0b7a2bbb2ff82134 Mon Sep 17 00:00:00 2001 From: Chao Pang Date: Sat, 7 Sep 2024 08:01:09 -0400 Subject: [PATCH] Updated meds_reader to 0.1.9 (#51) * Updated meds_reader to 0.1.9 * Switched back from RunTimeError to Exception for loading the meds data * Changed patient_id to subject_id in meds_utils.py file * Used fully qualified names for the imports instead of using their relative paths * Added the missing positional parameter for _create_cehrbert_data_from_meds * Restored the AttType cehr_bert that's automatically updated by the code formatter * Restored the enum type and MedsToCehrBertConversion class check --- pyproject.toml | 2 +- .../hf_data_generator/hf_dataset.py | 6 +- .../meds_to_cehrbert_micmic4.py | 2 +- .../hf_data_generator/meds_utils.py | 109 ++++++++++++++---- .../runners/hf_cehrbert_pretrain_runner.py | 2 +- .../decorators/patient_event_decorator.py | 2 +- 6 files changed, 96 insertions(+), 27 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b26988e3..e687ea4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ dependencies = [ "femr==0.2.0", "Jinja2==3.1.3", "meds==0.3.3", - "meds_reader==0.1.1", + "meds_reader==0.1.9", "networkx==3.2.1", "numpy==1.24.3", "packaging==23.2", diff --git a/src/cehrbert/data_generators/hf_data_generator/hf_dataset.py b/src/cehrbert/data_generators/hf_data_generator/hf_dataset.py index 6ca0843c..41f19bb7 100644 --- a/src/cehrbert/data_generators/hf_data_generator/hf_dataset.py +++ b/src/cehrbert/data_generators/hf_data_generator/hf_dataset.py @@ -2,14 +2,14 @@ from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict -from ...data_generators.hf_data_generator.hf_dataset_mapping import ( +from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import ( DatasetMapping, HFFineTuningMapping, HFTokenizationMapping, SortPatientSequenceMapping, ) -from ...models.hf_models.tokenization_hf_cehrbert import CehrBertTokenizer -from 
...runners.hf_runner_argument_dataclass import DataTrainingArguments +from cehrbert.models.hf_models.tokenization_hf_cehrbert import CehrBertTokenizer +from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments CEHRBERT_COLUMNS = [ "concept_ids", diff --git a/src/cehrbert/data_generators/hf_data_generator/meds_to_cehrbert_conversion_rules/meds_to_cehrbert_micmic4.py b/src/cehrbert/data_generators/hf_data_generator/meds_to_cehrbert_conversion_rules/meds_to_cehrbert_micmic4.py index 0ffbb696..9e8a57dd 100644 --- a/src/cehrbert/data_generators/hf_data_generator/meds_to_cehrbert_conversion_rules/meds_to_cehrbert_micmic4.py +++ b/src/cehrbert/data_generators/hf_data_generator/meds_to_cehrbert_conversion_rules/meds_to_cehrbert_micmic4.py @@ -1,7 +1,7 @@ import re from typing import List -from ....data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules.meds_to_cehrbert_base import ( +from cehrbert.data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules.meds_to_cehrbert_base import ( EventConversionRule, MedsToCehrBertConversion, ) diff --git a/src/cehrbert/data_generators/hf_data_generator/meds_utils.py b/src/cehrbert/data_generators/hf_data_generator/meds_utils.py index 61260b91..0fe9ecc3 100644 --- a/src/cehrbert/data_generators/hf_data_generator/meds_utils.py +++ b/src/cehrbert/data_generators/hf_data_generator/meds_utils.py @@ -10,13 +10,13 @@ import pandas as pd from datasets import Dataset, DatasetDict, Split -from ...data_generators.hf_data_generator.hf_dataset import apply_cehrbert_dataset_mapping -from ...data_generators.hf_data_generator.hf_dataset_mapping import MedToCehrBertDatasetMapping, birth_codes -from ...data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules.meds_to_cehrbert_base import ( +from cehrbert.data_generators.hf_data_generator.hf_dataset import apply_cehrbert_dataset_mapping +from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import MedToCehrBertDatasetMapping, 
birth_codes +from cehrbert.data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules.meds_to_cehrbert_base import ( MedsToCehrBertConversion, ) -from ...med_extension.schema_extension import CehrBertPatient, Event, Visit -from ...runners.hf_runner_argument_dataclass import DataTrainingArguments, MedsToCehrBertConversionType +from cehrbert.med_extension.schema_extension import CehrBertPatient, Event, Visit +from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments, MedsToCehrBertConversionType UNKNOWN_VALUE = "Unknown" DEFAULT_ED_CONCEPT_ID = "9203" @@ -39,19 +39,48 @@ def get_meds_to_cehrbert_conversion_cls( raise RuntimeError(f"{meds_to_cehrbert_conversion_type} is not a valid MedsToCehrBertConversionType") -def get_patient_split(meds_reader_db_path: str) -> Dict[str, List[int]]: - patient_split = pd.read_parquet(os.path.join(meds_reader_db_path, "metadata/patient_splits.parquet")) - result = {str(group): records["patient_id"].tolist() for group, records in patient_split.groupby("split")} +def get_subject_split(meds_reader_db_path: str) -> Dict[str, List[int]]: + patient_split = pd.read_parquet(os.path.join(meds_reader_db_path, "metadata/subject_splits.parquet")) + result = {str(group): records["subject_id"].tolist() for group, records in patient_split.groupby("split")} return result class PatientBlock: + """ + Represents a block of medical events for a single patient visit. + + Includes the inferred visit type and various admission and discharge statuses. + + Attributes: + visit_id (int): The unique ID of the visit. + events (List[meds_reader.Event]): A list of medical events associated with this visit. + min_time (datetime): The earliest event time in the visit. + max_time (datetime): The latest event time in the visit. + conversion (MedsToCehrBertConversion): Conversion object for mapping event codes to CEHR-BERT. + has_ed_admission (bool): Whether the visit includes an emergency department (ED) admission event. 
+ has_admission (bool): Whether the visit includes an admission event. + has_discharge (bool): Whether the visit includes a discharge event. + visit_type (str): The inferred type of visit, such as inpatient, ED, or outpatient. + """ + def __init__( self, events: List[meds_reader.Event], visit_id: int, conversion: MedsToCehrBertConversion, ): + """ + Initializes a PatientBlock instance and infers the visit type from its events. + + Also caches the admission and discharge statuses. + + Args: + events (List[meds_reader.Event]): The medical events associated with the visit. + visit_id (int): The unique ID of the visit. + conversion (MedsToCehrBertConversion): Conversion object for mapping event codes to CEHR-BERT. + + Attributes are initialized to store visit metadata and calculate admission/discharge statuses. + """ self.visit_id = visit_id self.events = events self.min_time = events[0].time @@ -73,7 +102,12 @@ def __init__( self.visit_type = DEFAULT_OUTPATIENT_CONCEPT_ID def _has_ed_admission(self) -> bool: - """Make this configurable in the future.""" + """ + Determines if the visit includes an emergency department (ED) admission event. + + Returns: + bool: True if an ED admission event is found, False otherwise. + """ for event in self.events: for matching_rule in self.conversion.get_ed_admission_matching_rules(): if re.match(matching_rule, event.code): @@ -81,6 +115,12 @@ def _has_ed_admission(self) -> bool: return False def _has_admission(self) -> bool: + """ + Determines if the visit includes a hospital admission event. + + Returns: + bool: True if an admission event is found, False otherwise. + """ for event in self.events: for matching_rule in self.conversion.get_admission_matching_rules(): if re.match(matching_rule, event.code): @@ -88,6 +128,12 @@ def _has_admission(self) -> bool: return False def _has_discharge(self) -> bool: + """ + Determines if the visit includes a discharge event. + + Returns: + bool: True if a discharge event is found, False otherwise. 
+ """ for event in self.events: for matching_rule in self.conversion.get_discharge_matching_rules(): if re.match(matching_rule, event.code): @@ -95,6 +141,12 @@ def _has_discharge(self) -> bool: return False def get_discharge_facility(self) -> Optional[str]: + """ + Extracts the discharge facility code from the discharge event, if present. + + Returns: + Optional[str]: The sanitized discharge facility code, or None if no discharge event is found. + """ if self._has_discharge(): for event in self.events: for matching_rule in self.conversion.get_discharge_matching_rules(): @@ -105,12 +157,22 @@ def get_discharge_facility(self) -> Optional[str]: return None def _convert_event(self, event) -> List[Event]: + """ + Converts a medical event into a list of CEHR-BERT-compatible events. + + May parse numeric values from text-based events. + + Args: + event (meds_reader.Event): The medical event to be converted. + + Returns: + List[Event]: A list of converted events, possibly numeric, based on the original event's code and value. + """ code = event.code time = getattr(event, "time", None) text_value = getattr(event, "text_value", None) numeric_value = getattr(event, "numeric_value", None) - # We try to parse the numeric values from the text value, in other words, - # we try to construct numeric events from the event with a text value + if numeric_value is None and text_value is not None: conversion_rule = self.conversion.get_text_event_to_numeric_events_rule(code) if conversion_rule: @@ -140,6 +202,12 @@ def _convert_event(self, event) -> List[Event]: ] def get_meds_events(self) -> Iterable[Event]: + """ + Retrieves all medical events for the visit, converting each raw event if necessary. + + Returns: + Iterable[Event]: A list of CEHR-BERT-compatible medical events for the visit. 
+ """ events = [] for e in self.events: events.extend(self._convert_event(e)) @@ -147,7 +215,7 @@ def get_meds_events(self) -> Iterable[Event]: def convert_one_patient( - patient: meds_reader.Patient, + patient: meds_reader.Subject, conversion: MedsToCehrBertConversion, default_visit_id: int = 1, prediction_time: datetime = None, @@ -296,10 +364,10 @@ def convert_one_patient( age_at_index -= 1 # birth_datetime can not be None - assert birth_datetime is not None, f"patient_id: {patient.patient_id} does not have a valid birth_datetime" + assert birth_datetime is not None, f"patient_id: {patient.subject_id} does not have a valid birth_datetime" return CehrBertPatient( - patient_id=patient.patient_id, + patient_id=patient.subject_id, birth_datetime=birth_datetime, visits=visits, race=race if race else UNKNOWN_VALUE, @@ -346,7 +414,7 @@ def _meds_to_cehrbert_generator( ) -> CehrBertPatient: conversion = get_meds_to_cehrbert_conversion_cls(meds_to_cehrbert_conversion_type) for shard in shards: - with meds_reader.PatientDatabase(path_to_db) as patient_database: + with meds_reader.SubjectDatabase(path_to_db) as patient_database: for patient_id, prediction_time, label in shard: patient = patient_database[patient_id] yield convert_one_patient(patient, conversion, default_visit_id, prediction_time, label) @@ -363,20 +431,21 @@ def _create_cehrbert_data_from_meds( if data_args.cohort_folder: cohort = pd.read_parquet(os.path.join(data_args.cohort_folder, split)) for cohort_row in cohort.itertuples(): - patient_id = cohort_row.patient_id + subject_id = cohort_row.subject_id prediction_time = cohort_row.prediction_time label = int(cohort_row.boolean_value) - batches.append((patient_id, prediction_time, label)) + batches.append((subject_id, prediction_time, label)) else: - patient_split = get_patient_split(data_args.data_folder) - for patient_id in patient_split[split]: - batches.append((patient_id, None, None)) + patient_split = get_subject_split(data_args.data_folder) + for 
subject_id in patient_split[split]: + batches.append((subject_id, None, None)) split_batches = np.array_split(np.asarray(batches), data_args.preprocessing_num_workers) batch_func = functools.partial( _meds_to_cehrbert_generator, path_to_db=data_args.data_folder, default_visit_id=default_visit_id, + meds_to_cehrbert_conversion_type=data_args.meds_to_cehrbert_conversion_type, ) dataset = Dataset.from_generator( batch_func, diff --git a/src/cehrbert/runners/hf_cehrbert_pretrain_runner.py b/src/cehrbert/runners/hf_cehrbert_pretrain_runner.py index e53fac93..95626693 100644 --- a/src/cehrbert/runners/hf_cehrbert_pretrain_runner.py +++ b/src/cehrbert/runners/hf_cehrbert_pretrain_runner.py @@ -179,7 +179,7 @@ def main(): dataset = load_from_disk(meds_extension_path) if data_args.streaming: dataset = dataset.to_iterable_dataset(num_shards=training_args.dataloader_num_workers) - except RuntimeError as e: + except FileNotFoundError as e: LOG.exception(e) dataset = create_dataset_from_meds_reader(data_args, is_pretraining=True) if not data_args.streaming: diff --git a/src/cehrbert/spark_apps/decorators/patient_event_decorator.py b/src/cehrbert/spark_apps/decorators/patient_event_decorator.py index 7eaf2409..de49ddd6 100644 --- a/src/cehrbert/spark_apps/decorators/patient_event_decorator.py +++ b/src/cehrbert/spark_apps/decorators/patient_event_decorator.py @@ -16,7 +16,7 @@ class AttType(Enum): DAY = "day" WEEK = "week" MONTH = "month" - CEHR_BERT = "cehrbert" + CEHR_BERT = "cehr_bert" MIX = "mix" NONE = "none"