Skip to content

Commit

Permalink
Moved the MIMIC MEDS-to-CEHR-BERT logic to meds_to_cehrbert_micmic4.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ChaoPang committed Oct 10, 2024
1 parent 54a7af8 commit 72babce
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 82 deletions.
4 changes: 4 additions & 0 deletions src/cehrbert/data_generators/hf_data_generator/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Shared constants for the hf_data_generator package, defined here so that
# sibling modules can import them from one place.

# Placeholder string; presumably substituted for missing or unrecognised
# values -- its call sites are not visible here, confirm against usage.
UNKNOWN_VALUE = "Unknown"
# Visit-type code for an emergency-department visit.
# NOTE(review): assumed from the name; usage is not visible in this module.
DEFAULT_ED_CONCEPT_ID = "Visit/ER"
# Visit-type code for an outpatient visit.
# NOTE(review): assumed from the name; usage is not visible in this module.
DEFAULT_OUTPATIENT_CONCEPT_ID = "Visit/OP"
# Visit-type code for an inpatient visit. The MIMIC-IV conversion assigns it
# as the `visit_type` of patient blocks that it merges into an admission.
DEFAULT_INPATIENT_CONCEPT_ID = "Visit/IP"
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class MedsToCehrBertConversion(ABC):
or None if no rule exists.
"""

def __init__(self):
def __init__(self, **kwargs):
"""
Initializes the MedsToCehrBertConversion class by caching the matching rules and.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import meds_reader

from cehrbert.data_generators.hf_data_generator import DEFAULT_INPATIENT_CONCEPT_ID
from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import birth_codes
from cehrbert.data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules.meds_to_cehrbert_base import (
EventConversionRule,
Expand All @@ -14,8 +15,8 @@

class MedsToBertMimic4(MedsToCehrBertConversion):

def __init__(self, default_visit_id):
super().__init__()
def __init__(self, default_visit_id, **kwargs):
super().__init__(**kwargs)
self.default_visit_id = default_visit_id

def generate_demographics_and_patient_blocks(
Expand Down Expand Up @@ -62,6 +63,73 @@ def generate_demographics_and_patient_blocks(
if events_for_current_date:
patient_blocks.append(PatientBlock(events_for_current_date, visit_id, self))

admit_discharge_pairs = []
active_ed_index = None
active_admission_index = None
# |ED|24-hours|Admission| ... |Discharge| -> ED will be merged into the admission (within 24 hours)
# |ED|25-hours|Admission| ... |Discharge| -> ED will NOT be merged into the admission
# |Admission|ED| ... |Discharge| -> ED will be merged into the admission
# |Admission|Admission|ED| ... |Discharge|
# -> The first admission will be ignored and turned into a separate visit
# -> The second Admission and ED will be merged
for i, patient_block in enumerate(patient_blocks):
# Keep track of the ED block when there is no on-going admission
if patient_block.has_ed_admission and active_admission_index is None:
active_ed_index = i
# Keep track of the admission block
if patient_block.has_admission:
# If the ED event has occurred, we need to check the time difference between
# the ED event and the subsequent hospital admission
if active_ed_index is not None:

hour_diff = (
patient_block.min_time - patient_blocks[active_ed_index].max_time
).total_seconds() / 3600
# If the time difference between the ed and admission is leq 24 hours,
# we consider ED to be part of the visits
if hour_diff <= 24 or active_ed_index == i:
active_admission_index = active_ed_index
active_ed_index = None
else:
active_admission_index = i

if patient_block.has_discharge:
if active_admission_index is not None:
admit_discharge_pairs.append((active_admission_index, i))
# When the patient is discharged from the hospital, we assume the admission and ED should end
active_admission_index = None
active_ed_index = None

# Check the last block of the patient history to see whether the admission is partial
if i == len(patient_blocks) - 1:
# This indicates an ongoing (incomplete) inpatient visit,
# this is a common pattern for inpatient visit prediction problems,
# where the data from the first 24-48 hours after the admission
# are used to predict something about the admission
if active_admission_index is not None and prediction_time is not None:
admit_discharge_pairs.append((active_admission_index, i))

# Update visit_id for the admission blocks
for admit_index, discharge_index in admit_discharge_pairs:
admission_block = patient_blocks[admit_index]
discharge_block = patient_blocks[discharge_index]
visit_id = admission_block.visit_id
for i in range(admit_index, discharge_index + 1):
patient_blocks[i].visit_id = visit_id
patient_blocks[i].visit_type = DEFAULT_INPATIENT_CONCEPT_ID
# There could be events that occur after the discharge, which are considered as part of the visit
# we need to check if the time stamp of the next block is within 12 hours
if discharge_index + 1 < len(patient_blocks):
next_block = patient_blocks[discharge_index + 1]
hour_diff = (next_block.min_time - discharge_block.max_time).total_seconds() / 3600
assert hour_diff >= 0, (
f"next_block.min_time: {next_block.min_time} "
f"must be GE discharge_block.max_time: {discharge_block.max_time}"
)
if hour_diff <= 12:
next_block.visit_id = visit_id
next_block.visit_type = DEFAULT_INPATIENT_CONCEPT_ID

demographics = PatientDemographics(birth_datetime=birth_datetime, race=race, gender=gender, ethnicity=ethnicity)
return demographics, patient_blocks

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def generate_demographics_and_patient_blocks(
current_date = None
events_for_current_date = []
patient_blocks = []

for e in patient.events:

# Skip out of the loop if the events' time stamps are beyond the prediction time
Expand Down
93 changes: 14 additions & 79 deletions src/cehrbert/data_generators/hf_data_generator/meds_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,20 @@
import pandas as pd
from datasets import Dataset, DatasetDict, Split

from cehrbert.data_generators.hf_data_generator import (
DEFAULT_ED_CONCEPT_ID,
DEFAULT_INPATIENT_CONCEPT_ID,
DEFAULT_OUTPATIENT_CONCEPT_ID,
UNKNOWN_VALUE,
)
from cehrbert.data_generators.hf_data_generator.hf_dataset import apply_cehrbert_dataset_mapping
from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import MedToCehrBertDatasetMapping, birth_codes
from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import MedToCehrBertDatasetMapping
from cehrbert.data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules.meds_to_cehrbert_base import (
MedsToCehrBertConversion,
)
from cehrbert.data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules.meds_to_cehrbert_omop import (
MedsToCehrbertOMOP,
)
from cehrbert.med_extension.schema_extension import CehrBertPatient, Event, Visit
from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments, MedsToCehrBertConversionType

UNKNOWN_VALUE = "Unknown"
DEFAULT_ED_CONCEPT_ID = "9203"
DEFAULT_OUTPATIENT_CONCEPT_ID = "9202"
DEFAULT_INPATIENT_CONCEPT_ID = "9201"
MEDS_SPLIT_DATA_SPLIT_MAPPING = {
"train": Split.TRAIN,
"tuning": Split.VALIDATION,
Expand All @@ -35,15 +34,15 @@


def get_meds_to_cehrbert_conversion_cls(
meds_to_cehrbert_conversion_type: Union[MedsToCehrBertConversionType, str],
meds_to_cehrbert_conversion_type: Union[MedsToCehrBertConversionType, str], **kwargs
) -> MedsToCehrBertConversion:
for cls in MedsToCehrBertConversion.__subclasses__():
if isinstance(meds_to_cehrbert_conversion_type, MedsToCehrBertConversionType):
if meds_to_cehrbert_conversion_type.name == cls.__name__:
return cls()
return cls(**kwargs)
elif isinstance(meds_to_cehrbert_conversion_type, str):
if meds_to_cehrbert_conversion_type == cls.__name__:
return cls()
return cls(**kwargs)
raise RuntimeError(f"{meds_to_cehrbert_conversion_type} is not a valid MedsToCehrBertConversionType")


Expand Down Expand Up @@ -236,7 +235,6 @@ def get_meds_events(self) -> Iterable[Event]:
def convert_one_patient(
patient: meds_reader.Subject,
conversion: MedsToCehrBertConversion,
default_visit_id: int = 1,
prediction_time: datetime = None,
label: Union[int, float] = None,
) -> CehrBertPatient:
Expand Down Expand Up @@ -309,71 +307,6 @@ def convert_one_patient(
patient=patient, prediction_time=prediction_time
)

admit_discharge_pairs = []
active_ed_index = None
active_admission_index = None
# |ED|24-hours|Admission| ... |Discharge| -> ED will be merged into the admission (within 24 hours)
# |ED|25-hours|Admission| ... |Discharge| -> ED will NOT be merged into the admission
# |Admission|ED| ... |Discharge| -> ED will be merged into the admission
# |Admission|Admission|ED| ... |Discharge|
# -> The first admission will be ignored and turned into a separate visit
# -> The second Admission and ED will be merged
for i, patient_block in enumerate(patient_blocks):
# Keep track of the ED block when there is no on-going admission
if patient_block.has_ed_admission and active_admission_index is None:
active_ed_index = i
# Keep track of the admission block
if patient_block.has_admission:
# If the ED event has occurred, we need to check the time difference between
# the ED event and the subsequent hospital admission
if active_ed_index is not None:

hour_diff = (patient_block.min_time - patient_blocks[active_ed_index].max_time).total_seconds() / 3600
# If the time difference between the ed and admission is leq 24 hours,
# we consider ED to be part of the visits
if hour_diff <= 24 or active_ed_index == i:
active_admission_index = active_ed_index
active_ed_index = None
else:
active_admission_index = i

if patient_block.has_discharge:
if active_admission_index is not None:
admit_discharge_pairs.append((active_admission_index, i))
# When the patient is discharged from the hospital, we assume the admission and ED should end
active_admission_index = None
active_ed_index = None

# Check the last block of the patient history to see whether the admission is partial
if i == len(patient_blocks) - 1:
# This indicates an ongoing (incomplete) inpatient visit,
# this is a common pattern for inpatient visit prediction problems,
# where the data from the first 24-48 hours after the admission
# are used to predict something about the admission
if active_admission_index is not None and prediction_time is not None:
admit_discharge_pairs.append((active_admission_index, i))

# Update visit_id for the admission blocks
for admit_index, discharge_index in admit_discharge_pairs:
admission_block = patient_blocks[admit_index]
discharge_block = patient_blocks[discharge_index]
visit_id = admission_block.visit_id
for i in range(admit_index, discharge_index + 1):
patient_blocks[i].visit_id = visit_id
patient_blocks[i].visit_type = DEFAULT_INPATIENT_CONCEPT_ID
# There could be events that occur after the discharge, which are considered as part of the visit
# we need to check if the time stamp of the next block is within 12 hours
if discharge_index + 1 < len(patient_blocks):
next_block = patient_blocks[discharge_index + 1]
hour_diff = (next_block.min_time - discharge_block.max_time).total_seconds() / 3600
assert hour_diff >= 0, (
f"next_block.min_time: {next_block.min_time} "
f"must be GE discharge_block.max_time: {discharge_block.max_time}"
)
if hour_diff <= 12:
next_block.visit_id = visit_id
next_block.visit_type = DEFAULT_INPATIENT_CONCEPT_ID

patient_block_dict = collections.defaultdict(list)
for patient_block in patient_blocks:
patient_block_dict[patient_block.visit_id].append(patient_block)
Expand Down Expand Up @@ -461,12 +394,14 @@ def _meds_to_cehrbert_generator(
default_visit_id: int,
meds_to_cehrbert_conversion_type: MedsToCehrBertConversionType,
) -> CehrBertPatient:
conversion = get_meds_to_cehrbert_conversion_cls(meds_to_cehrbert_conversion_type)
conversion = get_meds_to_cehrbert_conversion_cls(
meds_to_cehrbert_conversion_type, default_visit_id=default_visit_id
)
with meds_reader.SubjectDatabase(path_to_db) as patient_database:
for shard in shards:
for patient_id, prediction_time, label in shard:
patient = patient_database[patient_id]
yield convert_one_patient(patient, conversion, default_visit_id, prediction_time, label)
yield convert_one_patient(patient, conversion, prediction_time, label)


def _create_cehrbert_data_from_meds(
Expand Down

0 comments on commit 72babce

Please sign in to comment.