Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated meds_reader to 0.1.9 #51

Merged
merged 8 commits into from
Sep 7, 2024
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ dependencies = [
"femr==0.2.0",
"Jinja2==3.1.3",
"meds==0.3.3",
"meds_reader==0.1.1",
"meds_reader==0.1.9",
"networkx==3.2.1",
"numpy==1.24.3",
"packaging==23.2",
Expand Down
6 changes: 3 additions & 3 deletions src/cehrbert/data_generators/hf_data_generator/hf_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict

from ...data_generators.hf_data_generator.hf_dataset_mapping import (
from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import (
DatasetMapping,
HFFineTuningMapping,
HFTokenizationMapping,
SortPatientSequenceMapping,
)
from ...models.hf_models.tokenization_hf_cehrbert import CehrBertTokenizer
from ...runners.hf_runner_argument_dataclass import DataTrainingArguments
from cehrbert.models.hf_models.tokenization_hf_cehrbert import CehrBertTokenizer
from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments

CEHRBERT_COLUMNS = [
"concept_ids",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
from typing import List

from ....data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules.meds_to_cehrbert_base import (
from cehrbert.data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules.meds_to_cehrbert_base import (
EventConversionRule,
MedsToCehrBertConversion,
)
Expand Down
109 changes: 89 additions & 20 deletions src/cehrbert/data_generators/hf_data_generator/meds_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
import pandas as pd
from datasets import Dataset, DatasetDict, Split

from ...data_generators.hf_data_generator.hf_dataset import apply_cehrbert_dataset_mapping
from ...data_generators.hf_data_generator.hf_dataset_mapping import MedToCehrBertDatasetMapping, birth_codes
from ...data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules.meds_to_cehrbert_base import (
from cehrbert.data_generators.hf_data_generator.hf_dataset import apply_cehrbert_dataset_mapping
from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import MedToCehrBertDatasetMapping, birth_codes
from cehrbert.data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules.meds_to_cehrbert_base import (
MedsToCehrBertConversion,
)
from ...med_extension.schema_extension import CehrBertPatient, Event, Visit
from ...runners.hf_runner_argument_dataclass import DataTrainingArguments, MedsToCehrBertConversionType
from cehrbert.med_extension.schema_extension import CehrBertPatient, Event, Visit
from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments, MedsToCehrBertConversionType

UNKNOWN_VALUE = "Unknown"
DEFAULT_ED_CONCEPT_ID = "9203"
Expand All @@ -39,19 +39,48 @@ def get_meds_to_cehrbert_conversion_cls(
raise RuntimeError(f"{meds_to_cehrbert_conversion_type} is not a valid MedsToCehrBertConversionType")


def get_patient_split(meds_reader_db_path: str) -> Dict[str, List[int]]:
patient_split = pd.read_parquet(os.path.join(meds_reader_db_path, "metadata/patient_splits.parquet"))
result = {str(group): records["patient_id"].tolist() for group, records in patient_split.groupby("split")}
def get_subject_split(meds_reader_db_path: str) -> Dict[str, List[int]]:
patient_split = pd.read_parquet(os.path.join(meds_reader_db_path, "metadata/subject_splits.parquet"))
result = {str(group): records["subject_id"].tolist() for group, records in patient_split.groupby("split")}
return result


class PatientBlock:
"""
Represents a block of medical events for a single patient visit, including.
inferred visit type and various admission and discharge statuses.
Attributes:
visit_id (int): The unique ID of the visit.
events (List[meds_reader.Event]): A list of medical events associated with this visit.
min_time (datetime): The earliest event time in the visit.
max_time (datetime): The latest event time in the visit.
conversion (MedsToCehrBertConversion): Conversion object for mapping event codes to CEHR-BERT.
has_ed_admission (bool): Whether the visit includes an emergency department (ED) admission event.
has_admission (bool): Whether the visit includes an admission event.
has_discharge (bool): Whether the visit includes a discharge event.
visit_type (str): The inferred type of visit, such as inpatient, ED, or outpatient.
"""

def __init__(
self,
events: List[meds_reader.Event],
visit_id: int,
conversion: MedsToCehrBertConversion,
):
"""
Initializes a PatientBlock instance, inferring the visit type based on the events and caching.
admission and discharge status.
Args:
events (List[meds_reader.Event]): The medical events associated with the visit.
visit_id (int): The unique ID of the visit.
conversion (MedsToCehrBertConversion): Conversion object for mapping event codes to CEHR-BERT.
Attributes are initialized to store visit metadata and calculate admission/discharge statuses.
"""
self.visit_id = visit_id
self.events = events
self.min_time = events[0].time
Expand All @@ -73,28 +102,51 @@ def __init__(
self.visit_type = DEFAULT_OUTPATIENT_CONCEPT_ID

def _has_ed_admission(self) -> bool:
"""Make this configurable in the future."""
"""
Determines if the visit includes an emergency department (ED) admission event.
Returns:
bool: True if an ED admission event is found, False otherwise.
"""
for event in self.events:
for matching_rule in self.conversion.get_ed_admission_matching_rules():
if re.match(matching_rule, event.code):
return True
return False

def _has_admission(self) -> bool:
"""
Determines if the visit includes a hospital admission event.
Returns:
bool: True if an admission event is found, False otherwise.
"""
for event in self.events:
for matching_rule in self.conversion.get_admission_matching_rules():
if re.match(matching_rule, event.code):
return True
return False

def _has_discharge(self) -> bool:
"""
Determines if the visit includes a discharge event.
Returns:
bool: True if a discharge event is found, False otherwise.
"""
for event in self.events:
for matching_rule in self.conversion.get_discharge_matching_rules():
if re.match(matching_rule, event.code):
return True
return False

def get_discharge_facility(self) -> Optional[str]:
"""
Extracts the discharge facility code from the discharge event, if present.
Returns:
Optional[str]: The sanitized discharge facility code, or None if no discharge event is found.
"""
if self._has_discharge():
for event in self.events:
for matching_rule in self.conversion.get_discharge_matching_rules():
Expand All @@ -105,12 +157,22 @@ def get_discharge_facility(self) -> Optional[str]:
return None

def _convert_event(self, event) -> List[Event]:
"""
Converts a medical event into a list of CEHR-BERT-compatible events, potentially parsing.
numeric values from text-based events.
Args:
event (meds_reader.Event): The medical event to be converted.
Returns:
List[Event]: A list of converted events, possibly numeric, based on the original event's code and value.
"""
code = event.code
time = getattr(event, "time", None)
text_value = getattr(event, "text_value", None)
numeric_value = getattr(event, "numeric_value", None)
# We try to parse the numeric values from the text value, in other words,
# we try to construct numeric events from the event with a text value

if numeric_value is None and text_value is not None:
conversion_rule = self.conversion.get_text_event_to_numeric_events_rule(code)
if conversion_rule:
Expand Down Expand Up @@ -140,14 +202,20 @@ def _convert_event(self, event) -> List[Event]:
]

def get_meds_events(self) -> Iterable[Event]:
"""
Retrieves all medication events for the visit, converting each raw event if necessary.
Returns:
Iterable[Event]: A list of CEHR-BERT-compatible medication events for the visit.
"""
events = []
for e in self.events:
events.extend(self._convert_event(e))
return events


def convert_one_patient(
patient: meds_reader.Patient,
patient: meds_reader.Subject,
conversion: MedsToCehrBertConversion,
default_visit_id: int = 1,
prediction_time: datetime = None,
Expand Down Expand Up @@ -296,10 +364,10 @@ def convert_one_patient(
age_at_index -= 1

# birth_datetime can not be None
assert birth_datetime is not None, f"patient_id: {patient.patient_id} does not have a valid birth_datetime"
assert birth_datetime is not None, f"patient_id: {patient.subject_id} does not have a valid birth_datetime"

return CehrBertPatient(
patient_id=patient.patient_id,
patient_id=patient.subject_id,
birth_datetime=birth_datetime,
visits=visits,
race=race if race else UNKNOWN_VALUE,
Expand Down Expand Up @@ -346,7 +414,7 @@ def _meds_to_cehrbert_generator(
) -> CehrBertPatient:
conversion = get_meds_to_cehrbert_conversion_cls(meds_to_cehrbert_conversion_type)
for shard in shards:
with meds_reader.PatientDatabase(path_to_db) as patient_database:
with meds_reader.SubjectDatabase(path_to_db) as patient_database:
for patient_id, prediction_time, label in shard:
patient = patient_database[patient_id]
yield convert_one_patient(patient, conversion, default_visit_id, prediction_time, label)
Expand All @@ -363,20 +431,21 @@ def _create_cehrbert_data_from_meds(
if data_args.cohort_folder:
cohort = pd.read_parquet(os.path.join(data_args.cohort_folder, split))
for cohort_row in cohort.itertuples():
patient_id = cohort_row.patient_id
subject_id = cohort_row.subject_id
prediction_time = cohort_row.prediction_time
label = int(cohort_row.boolean_value)
batches.append((patient_id, prediction_time, label))
batches.append((subject_id, prediction_time, label))
else:
patient_split = get_patient_split(data_args.data_folder)
for patient_id in patient_split[split]:
batches.append((patient_id, None, None))
patient_split = get_subject_split(data_args.data_folder)
for subject_id in patient_split[split]:
batches.append((subject_id, None, None))

split_batches = np.array_split(np.asarray(batches), data_args.preprocessing_num_workers)
batch_func = functools.partial(
_meds_to_cehrbert_generator,
path_to_db=data_args.data_folder,
default_visit_id=default_visit_id,
meds_to_cehrbert_conversion_type=data_args.meds_to_cehrbert_conversion_type,
)
dataset = Dataset.from_generator(
batch_func,
Expand Down
2 changes: 1 addition & 1 deletion src/cehrbert/runners/hf_cehrbert_pretrain_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def main():
dataset = load_from_disk(meds_extension_path)
if data_args.streaming:
dataset = dataset.to_iterable_dataset(num_shards=training_args.dataloader_num_workers)
except RuntimeError as e:
except FileNotFoundError as e:
LOG.exception(e)
dataset = create_dataset_from_meds_reader(data_args, is_pretraining=True)
if not data_args.streaming:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class AttType(Enum):
DAY = "day"
WEEK = "week"
MONTH = "month"
CEHR_BERT = "cehrbert"
CEHR_BERT = "cehr_bert"
MIX = "mix"
NONE = "none"

Expand Down
Loading