Skip to content

Commit

Permalink
Updated meds_reader to 0.1.9 (#51)
Browse files Browse the repository at this point in the history
* Updated meds_reader to 0.1.9
* Switched back from RuntimeError to Exception for loading the meds data
* Changed patient_id to subject_id in meds_utils.py file
* Used fully qualified names for the imports instead of using their relative paths
* Added the missing positional parameter for _create_cehrbert_data_from_meds
* Restored the AttType cehr_bert that's automatically updated by the code formatter
* Restored the enum type and MedsToCehrBertConversion class check
  • Loading branch information
ChaoPang authored Sep 7, 2024
1 parent 3d09645 commit c8410e0
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 27 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ dependencies = [
"femr==0.2.0",
"Jinja2==3.1.3",
"meds==0.3.3",
"meds_reader==0.1.1",
"meds_reader==0.1.9",
"networkx==3.2.1",
"numpy==1.24.3",
"packaging==23.2",
Expand Down
6 changes: 3 additions & 3 deletions src/cehrbert/data_generators/hf_data_generator/hf_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict

from ...data_generators.hf_data_generator.hf_dataset_mapping import (
from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import (
DatasetMapping,
HFFineTuningMapping,
HFTokenizationMapping,
SortPatientSequenceMapping,
)
from ...models.hf_models.tokenization_hf_cehrbert import CehrBertTokenizer
from ...runners.hf_runner_argument_dataclass import DataTrainingArguments
from cehrbert.models.hf_models.tokenization_hf_cehrbert import CehrBertTokenizer
from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments

CEHRBERT_COLUMNS = [
"concept_ids",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
from typing import List

from ....data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules.meds_to_cehrbert_base import (
from cehrbert.data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules.meds_to_cehrbert_base import (
EventConversionRule,
MedsToCehrBertConversion,
)
Expand Down
109 changes: 89 additions & 20 deletions src/cehrbert/data_generators/hf_data_generator/meds_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
import pandas as pd
from datasets import Dataset, DatasetDict, Split

from ...data_generators.hf_data_generator.hf_dataset import apply_cehrbert_dataset_mapping
from ...data_generators.hf_data_generator.hf_dataset_mapping import MedToCehrBertDatasetMapping, birth_codes
from ...data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules.meds_to_cehrbert_base import (
from cehrbert.data_generators.hf_data_generator.hf_dataset import apply_cehrbert_dataset_mapping
from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import MedToCehrBertDatasetMapping, birth_codes
from cehrbert.data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules.meds_to_cehrbert_base import (
MedsToCehrBertConversion,
)
from ...med_extension.schema_extension import CehrBertPatient, Event, Visit
from ...runners.hf_runner_argument_dataclass import DataTrainingArguments, MedsToCehrBertConversionType
from cehrbert.med_extension.schema_extension import CehrBertPatient, Event, Visit
from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments, MedsToCehrBertConversionType

UNKNOWN_VALUE = "Unknown"
DEFAULT_ED_CONCEPT_ID = "9203"
Expand All @@ -39,19 +39,48 @@ def get_meds_to_cehrbert_conversion_cls(
raise RuntimeError(f"{meds_to_cehrbert_conversion_type} is not a valid MedsToCehrBertConversionType")


def get_patient_split(meds_reader_db_path: str) -> Dict[str, List[int]]:
patient_split = pd.read_parquet(os.path.join(meds_reader_db_path, "metadata/patient_splits.parquet"))
result = {str(group): records["patient_id"].tolist() for group, records in patient_split.groupby("split")}
def get_subject_split(meds_reader_db_path: str) -> Dict[str, List[int]]:
patient_split = pd.read_parquet(os.path.join(meds_reader_db_path, "metadata/subject_splits.parquet"))
result = {str(group): records["subject_id"].tolist() for group, records in patient_split.groupby("split")}
return result


class PatientBlock:
"""
Represents a block of medical events for a single patient visit, including
inferred visit type and various admission and discharge statuses.
Attributes:
visit_id (int): The unique ID of the visit.
events (List[meds_reader.Event]): A list of medical events associated with this visit.
min_time (datetime): The earliest event time in the visit.
max_time (datetime): The latest event time in the visit.
conversion (MedsToCehrBertConversion): Conversion object for mapping event codes to CEHR-BERT.
has_ed_admission (bool): Whether the visit includes an emergency department (ED) admission event.
has_admission (bool): Whether the visit includes an admission event.
has_discharge (bool): Whether the visit includes a discharge event.
visit_type (str): The inferred type of visit, such as inpatient, ED, or outpatient.
"""

def __init__(
self,
events: List[meds_reader.Event],
visit_id: int,
conversion: MedsToCehrBertConversion,
):
"""
Initializes a PatientBlock instance, inferring the visit type based on the events and caching
admission and discharge status.
Args:
events (List[meds_reader.Event]): The medical events associated with the visit.
visit_id (int): The unique ID of the visit.
conversion (MedsToCehrBertConversion): Conversion object for mapping event codes to CEHR-BERT.
Attributes are initialized to store visit metadata and calculate admission/discharge statuses.
"""
self.visit_id = visit_id
self.events = events
self.min_time = events[0].time
Expand All @@ -73,28 +102,51 @@ def __init__(
self.visit_type = DEFAULT_OUTPATIENT_CONCEPT_ID

def _has_ed_admission(self) -> bool:
"""Make this configurable in the future."""
"""
Determines if the visit includes an emergency department (ED) admission event.
Returns:
bool: True if an ED admission event is found, False otherwise.
"""
for event in self.events:
for matching_rule in self.conversion.get_ed_admission_matching_rules():
if re.match(matching_rule, event.code):
return True
return False

def _has_admission(self) -> bool:
"""
Determines if the visit includes a hospital admission event.
Returns:
bool: True if an admission event is found, False otherwise.
"""
for event in self.events:
for matching_rule in self.conversion.get_admission_matching_rules():
if re.match(matching_rule, event.code):
return True
return False

def _has_discharge(self) -> bool:
"""
Determines if the visit includes a discharge event.
Returns:
bool: True if a discharge event is found, False otherwise.
"""
for event in self.events:
for matching_rule in self.conversion.get_discharge_matching_rules():
if re.match(matching_rule, event.code):
return True
return False

def get_discharge_facility(self) -> Optional[str]:
"""
Extracts the discharge facility code from the discharge event, if present.
Returns:
Optional[str]: The sanitized discharge facility code, or None if no discharge event is found.
"""
if self._has_discharge():
for event in self.events:
for matching_rule in self.conversion.get_discharge_matching_rules():
Expand All @@ -105,12 +157,22 @@ def get_discharge_facility(self) -> Optional[str]:
return None

def _convert_event(self, event) -> List[Event]:
"""
Converts a medical event into a list of CEHR-BERT-compatible events, potentially parsing
numeric values from text-based events.
Args:
event (meds_reader.Event): The medical event to be converted.
Returns:
List[Event]: A list of converted events, possibly numeric, based on the original event's code and value.
"""
code = event.code
time = getattr(event, "time", None)
text_value = getattr(event, "text_value", None)
numeric_value = getattr(event, "numeric_value", None)
# We try to parse the numeric values from the text value, in other words,
# we try to construct numeric events from the event with a text value

if numeric_value is None and text_value is not None:
conversion_rule = self.conversion.get_text_event_to_numeric_events_rule(code)
if conversion_rule:
Expand Down Expand Up @@ -140,14 +202,20 @@ def _convert_event(self, event) -> List[Event]:
]

def get_meds_events(self) -> Iterable[Event]:
    """
    Collect the converted events for every raw event of this visit.

    Despite the name, the result is not restricted to medications: every entry of
    ``self.events`` is passed through ``self._convert_event`` without any filtering,
    and the converted events are concatenated in the original event order. A single
    raw event may contribute zero, one, or several converted events.

    Returns:
        Iterable[Event]: A list holding the converted events of the visit; empty when
        the visit has no events.
    """
    return [
        converted_event
        for raw_event in self.events
        for converted_event in self._convert_event(raw_event)
    ]


def convert_one_patient(
patient: meds_reader.Patient,
patient: meds_reader.Subject,
conversion: MedsToCehrBertConversion,
default_visit_id: int = 1,
prediction_time: datetime = None,
Expand Down Expand Up @@ -296,10 +364,10 @@ def convert_one_patient(
age_at_index -= 1

# birth_datetime can not be None
assert birth_datetime is not None, f"patient_id: {patient.patient_id} does not have a valid birth_datetime"
assert birth_datetime is not None, f"patient_id: {patient.subject_id} does not have a valid birth_datetime"

return CehrBertPatient(
patient_id=patient.patient_id,
patient_id=patient.subject_id,
birth_datetime=birth_datetime,
visits=visits,
race=race if race else UNKNOWN_VALUE,
Expand Down Expand Up @@ -346,7 +414,7 @@ def _meds_to_cehrbert_generator(
) -> CehrBertPatient:
conversion = get_meds_to_cehrbert_conversion_cls(meds_to_cehrbert_conversion_type)
for shard in shards:
with meds_reader.PatientDatabase(path_to_db) as patient_database:
with meds_reader.SubjectDatabase(path_to_db) as patient_database:
for patient_id, prediction_time, label in shard:
patient = patient_database[patient_id]
yield convert_one_patient(patient, conversion, default_visit_id, prediction_time, label)
Expand All @@ -363,20 +431,21 @@ def _create_cehrbert_data_from_meds(
if data_args.cohort_folder:
cohort = pd.read_parquet(os.path.join(data_args.cohort_folder, split))
for cohort_row in cohort.itertuples():
patient_id = cohort_row.patient_id
subject_id = cohort_row.subject_id
prediction_time = cohort_row.prediction_time
label = int(cohort_row.boolean_value)
batches.append((patient_id, prediction_time, label))
batches.append((subject_id, prediction_time, label))
else:
patient_split = get_patient_split(data_args.data_folder)
for patient_id in patient_split[split]:
batches.append((patient_id, None, None))
patient_split = get_subject_split(data_args.data_folder)
for subject_id in patient_split[split]:
batches.append((subject_id, None, None))

split_batches = np.array_split(np.asarray(batches), data_args.preprocessing_num_workers)
batch_func = functools.partial(
_meds_to_cehrbert_generator,
path_to_db=data_args.data_folder,
default_visit_id=default_visit_id,
meds_to_cehrbert_conversion_type=data_args.meds_to_cehrbert_conversion_type,
)
dataset = Dataset.from_generator(
batch_func,
Expand Down
2 changes: 1 addition & 1 deletion src/cehrbert/runners/hf_cehrbert_pretrain_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def main():
dataset = load_from_disk(meds_extension_path)
if data_args.streaming:
dataset = dataset.to_iterable_dataset(num_shards=training_args.dataloader_num_workers)
except RuntimeError as e:
except FileNotFoundError as e:
LOG.exception(e)
dataset = create_dataset_from_meds_reader(data_args, is_pretraining=True)
if not data_args.streaming:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class AttType(Enum):
DAY = "day"
WEEK = "week"
MONTH = "month"
CEHR_BERT = "cehrbert"
CEHR_BERT = "cehr_bert"
MIX = "mix"
NONE = "none"

Expand Down

0 comments on commit c8410e0

Please sign in to comment.