persist the patient events throughout the generation of the patient sequences for better readability and debugging purposes (#24)

* persist the patient events throughout the generation of the patient sequences for debugging purposes

* set the outpatient visit_start_datetime to be at the start of the day

* explicitly cast the datetimes of EHR records to timestamps

* cast timestamps in extract_features

* infer inpatient visits based on the duration of the visit: if the duration is greater than 24 hours, the visit is set to inpatient (see the first sketch after this list)

* switched to the Spark filter function instead of the SQL syntax

* saved all the temp dataframes when cleaning up the visit information of the EHRSHOT records

* updated the real_visits_folder variable name

* fixed a bug in referencing the columns

* use aliases in joining ehr_shot_data to inferred_inpatient_visits

* fixed a bug in creating the code for the EHRSHOT data for non-visit events

* fixed a bug in creating the start and end datetimes for visit records in the EHRSHOT data

* started working on fixing inpatient visits

* added Spark applications to connect EHRSHOT visits chronologically

* removed unnecessary logging

* added inpatient_hour_diff_threshold and outpatient_hour_diff_threshold to control how far apart two visits must be to be considered separate visits (see the second sketch after this list)

* exclude visit_occurrence and death from being updated using the visit_mapping table

* copy vocab tables

* copy person table over

* persist visits during AttEventDecorator

* fixed a bug when extract_features spark app is run when bound_visit_end_date is set to True

* the visits need to be bounded per cohort member

* persist cohort_visit_occurrence to the disk

* updated the persistent dataframe paths

* disconnect the records whose timestamps fall outside of the corresponding visit window

* fix the visit and domain records whose timestamps fall outside of the corresponding visit range

* fixed the bug in creating the new visit_id

* added original_visit_id to the ehrshot output

* fixed a bug when bound_visit_end is enabled in extract_features.py

* Add an artificial token for the visit in which the prediction is made

* added placeholder tokens to the output folder for debugging

* fixed the placeholder token timestamps

* filter for visit_occurrence records based on the cohort_member_id and index_date

* use event_start to construct the visit_end_datetime for artificial visits

* changed the way to bound the visit_end_datetime

* fixed windowing bug in creating visit_rank

* fixed a bug in constructing the visit mappings

* removed a query that does not affect the results

* use hours to infer whether the visit start/end datetimes need to be fixed

* if we split inpatient visits into multiple visits, we need to check whether each individual visit is less than 24 hours long

* changed the day_cutoff default value to 1

* we create placeholder tokens for those inpatient visits where the first token occurs after the index_date

* removed unused imports

* try to fix the Java incompatibility issue with PySpark

* fixed the unit test
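
The first sketch illustrates the duration-based inpatient inference described in the list above. This is a hedged illustration, not the committed implementation: the `visits` dataframe, its column names, and the use of OMOP concept id 9201 ("Inpatient Visit") are assumptions made for the example.

from pyspark.sql import functions as F

# Sketch only: assumes a `visits` dataframe with visit_start_datetime and
# visit_end_datetime timestamp columns. 9201 is the OMOP "Inpatient Visit"
# concept id; visits longer than 24 hours are re-labeled as inpatient.
duration_in_hours = (
    F.unix_timestamp("visit_end_datetime") - F.unix_timestamp("visit_start_datetime")
) / 3600
inferred_inpatient_visits = visits.withColumn(
    "visit_concept_id",
    F.when(duration_in_hours > 24, F.lit(9201)).otherwise(F.col("visit_concept_id")),
)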
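
The second sketch shows how the hour-difference thresholds might be applied when connecting visits chronologically. The threshold value, column names, and grouping logic are assumptions for illustration, not the exact Spark application added in this commit.

from pyspark.sql import functions as F, Window as W

inpatient_hour_diff_threshold = 24  # assumed value, for illustration only

w = W.partitionBy("person_id").orderBy("visit_start_datetime")
visits_with_gap = visits.withColumn(
    "prev_visit_end_datetime", F.lag("visit_end_datetime").over(w)
).withColumn(
    "hour_diff",
    (F.unix_timestamp("visit_start_datetime")
     - F.unix_timestamp("prev_visit_end_datetime")) / 3600,
)
# A visit starts a new group when the gap to the previous visit exceeds the
# threshold; a running sum of that flag assigns a group id to each set of
# visits that should be connected into one.
connected_visits = visits_with_gap.withColumn(
    "new_visit_flag",
    (F.col("hour_diff").isNull()
     | (F.col("hour_diff") > inpatient_hour_diff_threshold)).cast("int"),
).withColumn("visit_group_id", F.sum("new_visit_flag").over(w))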
ChaoPang authored Jan 16, 2025
1 parent 6114e38 commit 66b2a6f
Showing 14 changed files with 872 additions and 114 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/tests.yml
@@ -23,6 +23,11 @@ jobs:
uses: actions/setup-python@v3
with:
python-version: "3.10"
- name: Set up Java 11
uses: actions/setup-java@v3
with:
java-version: "11" # specify the Java version here
distribution: "temurin" # or use 'adopt' or 'zulu', depending on your preference
- name: Install dependencies
run: |
python -m pip install --upgrade pip
2 changes: 2 additions & 0 deletions src/cehrbert_data/apps/generate_training_data.py
@@ -173,6 +173,8 @@ def main(
exclude_demographic=exclude_demographic,
use_age_group=use_age_group,
include_inpatient_hour_token=include_inpatient_hour_token,
spark=spark,
persistence_folder=output_folder,
)
else:
sequence_data = create_sequence_data(
60 changes: 51 additions & 9 deletions src/cehrbert_data/decorators/artificial_time_token_decorator.py
@@ -1,4 +1,6 @@
from pyspark.sql import DataFrame, functions as F, types as T, Window as W
import os.path

from pyspark.sql import SparkSession, DataFrame, functions as F, types as T, Window as W

from ..const.common import NA
from ..const.artificial_tokens import VS_TOKEN, VE_TOKEN
@@ -26,13 +28,19 @@ def __init__(
att_type: AttType,
inpatient_att_type: AttType,
include_inpatient_hour_token: bool = False,
spark: SparkSession = None,
persistence_folder: str = None,
):
self._visit_occurrence = visit_occurrence
self._include_visit_type = include_visit_type
self._exclude_visit_tokens = exclude_visit_tokens
self._att_type = att_type
self._inpatient_att_type = inpatient_att_type
self._include_inpatient_hour_token = include_inpatient_hour_token
super().__init__(spark=spark, persistence_folder=persistence_folder)

def get_name(self):
return "att_events"

def _decorate(self, patient_events: DataFrame):
if self._att_type == AttType.NONE:
@@ -53,9 +61,21 @@ def _decorate(self, patient_events: DataFrame):
F.max("concept_order").alias("max_concept_order"),
)

# The visit records are joined to the cohort members (there could be multiple entries for the same patient)
# if multiple entries are present, we duplicate the visit records for those. If the visit_occurrence dataframe
# contains visits for each cohort member, then we need to add cohort_member_id to the joined expression as well.
if "cohort_member_id" in self._visit_occurrence.columns:
joined_expr = ["person_id", "cohort_member_id"]
else:
joined_expr = ["person_id"]

visit_occurrence = (
self._visit_occurrence.select(
self._visit_occurrence.join(
cohort_member_person_pair,
joined_expr
).select(
"person_id",
"cohort_member_id",
F.col("visit_start_date").cast(T.DateType()).alias("date"),
F.col("visit_start_date").cast(T.DateType()).alias("visit_start_date"),
F.col("visit_start_datetime").cast(T.TimestampType()).alias("visit_start_datetime"),
@@ -71,8 +91,10 @@ def _decorate(self, patient_events: DataFrame):
"age",
"discharged_to_concept_id",
)
.join(valid_visit_ids, "visit_occurrence_id")
.join(cohort_member_person_pair, ["person_id", "cohort_member_id"])
.join(
valid_visit_ids,
["visit_occurrence_id", "cohort_member_id"]
)
)

# We assume outpatient visits end on the same day, therefore we set visit_end_date to visit_start_date due
@@ -89,7 +111,10 @@ def _decorate(self, patient_events: DataFrame):
visit_occurrence = visit_occurrence.withColumn("date_in_week", weeks_since_epoch_udf)

# Cache visit for faster processing
visit_occurrence.cache()
visit_occurrence = self.try_persist_data(
visit_occurrence,
os.path.join(self.get_name(), "visit_occurrence_temp"),
)

visits = visit_occurrence.drop("discharged_to_concept_id")

Expand Down Expand Up @@ -180,6 +205,12 @@ def _decorate(self, patient_events: DataFrame):

artificial_tokens = artificial_tokens.drop("visit_end_date")

# Try persisting artificial events
artificial_tokens = self.try_persist_data(
artificial_tokens,
os.path.join(self.get_name(), "artificial_tokens"),
)

# Retrieving the events that are ONLY linked to inpatient visits
inpatient_visits = (
visit_occurrence
@@ -235,8 +266,10 @@ def _decorate(self, patient_events: DataFrame):
# Add discharge events to the inpatient visits
inpatient_events = inpatient_events.unionByName(discharge_events)

# Try caching the inpatient events
inpatient_events.cache()
# Try persisting the inpatient events for faster processing
inpatient_events = self.try_persist_data(
inpatient_events, os.path.join(self.get_name(), "inpatient_events")
)

# Get the prev days_since_epoch
inpatient_prev_date_udf = F.lag("date").over(
@@ -285,11 +318,11 @@ def _decorate(self, patient_events: DataFrame):
# Create ATT tokens within the inpatient visits
inpatient_att_events = (
inpatient_events.withColumn(
"time_stamp_hour", F.hour("datetime")
"time_stamp_hour", F.hour("datetime")
).withColumn(
"is_span_boundary",
F.row_number().over(
W.partitionBy("cohort_member_id", "visit_occurrence_id", "concept_order")
W.partitionBy("cohort_member_id", "visit_occurrence_id")
.orderBy("priority", "date", "time_stamp_hour")
),
)
@@ -343,6 +376,11 @@ def _decorate(self, patient_events: DataFrame):
.drop("prev_date", "time_delta", "is_span_boundary")
)

# Try persisting the inpatient att events
inpatient_att_events = self.try_persist_data(
inpatient_att_events, os.path.join(self.get_name(), "inpatient_att_events")
)

self.validate(inpatient_events)
self.validate(inpatient_att_events)

@@ -352,6 +390,10 @@ def _decorate(self, patient_events: DataFrame):
["visit_occurrence_id", "cohort_member_id"],
how="left_anti",
)
# Try persisting the other events
other_events = self.try_persist_data(
other_events, os.path.join(self.get_name(), "other_events")
)

patient_events = inpatient_events.unionByName(inpatient_att_events).unionByName(other_events)

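
The try_persist_data calls throughout this decorator replace the old cache() calls. The helper itself lives on the decorator base class and is not shown in this diff; the following is a minimal sketch of how such a method could work, assuming the base class stores the spark session and persistence_folder passed to its constructor under the hypothetical attribute names _spark and _persistence_folder.

import os
from pyspark.sql import DataFrame, SparkSession

class PatientEventDecorator:  # stub of the real base class, for illustration
    def __init__(self, spark: SparkSession = None, persistence_folder: str = None):
        self._spark = spark
        self._persistence_folder = persistence_folder

    def try_persist_data(self, df: DataFrame, path_suffix: str) -> DataFrame:
        # Sketch only: write the dataframe under the persistence folder and
        # read it back, so every intermediate result is inspectable on disk;
        # fall back to an in-memory cache when no session/folder was provided.
        if self._spark is not None and self._persistence_folder is not None:
            path = os.path.join(self._persistence_folder, path_suffix)
            df.write.mode("overwrite").parquet(path)
            return self._spark.read.parquet(path)
        return df.cache()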
45 changes: 37 additions & 8 deletions src/cehrbert_data/decorators/clinical_event_decorator.py
@@ -1,10 +1,10 @@
import os.path

from ..const.common import (
MEASUREMENT,
CATEGORICAL_MEASUREMENT,
MEASUREMENT_QUESTION_PREFIX,
MEASUREMENT_ANSWER_PREFIX
)
from pyspark.sql import DataFrame, functions as F, Window as W, types as T
from pyspark.sql import SparkSession, DataFrame, functions as F, Window as W, types as T

from .patient_event_decorator_base import PatientEventDecorator
from .token_priority import DEFAULT_PRIORITY
@@ -17,8 +17,12 @@ class ClinicalEventDecorator(PatientEventDecorator):
# 'concept_value_masks', 'value_as_numbers', 'value_as_concepts', 'mlm_skip_values',
# 'visit_concept_ids', "units"
# ]
def __init__(self, visit_occurrence):
def __init__(self, visit_occurrence, spark: SparkSession = None, persistence_folder: str = None):
self._visit_occurrence = visit_occurrence
super().__init__(spark=spark, persistence_folder=persistence_folder)

def get_name(self):
return "clinical_events"

def _decorate(self, patient_events: DataFrame):
"""
@@ -43,10 +47,18 @@ def _decorate(self, patient_events: DataFrame):
visit_segment_udf = F.col("visit_rank_order") % F.lit(2) + 1

# The visit records are joined to the cohort members (there could be multiple entries for the same patient)
# if multiple entries are present, we duplicate the visit records for those.
# if multiple entries are present, we duplicate the visit records for those. If the visit_occurrence dataframe
# contains visits for each cohort member, then we need to add cohort_member_id to the joined expression as well.
if "cohort_member_id" in self._visit_occurrence.columns:
joined_expr = ["visit_occurrence_id", "cohort_member_id"]
else:
joined_expr = ["visit_occurrence_id"]

visits = (
self._visit_occurrence.join(valid_visit_ids, "visit_occurrence_id")
.select(
self._visit_occurrence.join(
valid_visit_ids,
joined_expr
).select(
"person_id",
"cohort_member_id",
"visit_occurrence_id",
@@ -90,6 +102,15 @@ def _decorate(self, patient_events: DataFrame):
),
).otherwise(F.col("visit_start_date"))

# We need to set the visit_start_datetime at the beginning of the visit start date
# because there could be outpatient visit records whose visit_start_datetime is set to the end of the day
visit_start_datetime_udf = (
F.when(
F.col("is_inpatient") == 0,
F.col("visit_start_date")
).otherwise(F.col("visit_start_datetime"))
).cast(T.TimestampType())

# We need to bound the medical event dates between visit_start_date and visit_end_date
bound_medical_event_date = F.when(
F.col("date") < F.col("visit_start_date"), F.col("visit_start_date")
@@ -108,8 +129,10 @@ def _decorate(self, patient_events: DataFrame):

patient_events = (
patient_events.join(visits, ["cohort_member_id", "visit_occurrence_id"])
.withColumn("datetime", F.to_timestamp("datetime"))
.withColumn("visit_start_datetime", visit_start_datetime_udf)
.withColumn("visit_end_date", visit_end_date_udf)
.withColumn("visit_end_datetime", F.date_add("visit_end_date", 1))
.withColumn("visit_end_datetime", F.date_add("visit_end_date", 1).cast(T.TimestampType()))
.withColumn("visit_end_datetime", F.expr("visit_end_datetime - INTERVAL 1 MINUTE"))
.withColumn("date", bound_medical_event_date)
.withColumn("datetime", bound_medical_event_datetime)
@@ -145,4 +168,10 @@ def _decorate(self, patient_events: DataFrame):
if "concept_as_value" not in patient_events.schema.fieldNames():
patient_events = patient_events.withColumn("concept_as_value", F.lit(None).cast("string"))

# Try persisting the clinical events
patient_events = self.try_persist_data(
patient_events,
self.get_name()
)

return patient_events
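
The two withColumn calls on visit_end_datetime above pin the end of a visit to one minute before midnight on its end date. A small self-contained example of the same construction, with sample data assumed for illustration:

from pyspark.sql import SparkSession, functions as F, types as T

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("2020-01-03",)], ["visit_end_date"]).withColumn(
    "visit_end_date", F.col("visit_end_date").cast(T.DateType())
)
df = (
    # Midnight of the day after the visit ends...
    df.withColumn("visit_end_datetime", F.date_add("visit_end_date", 1).cast(T.TimestampType()))
    # ...minus one minute, i.e. 2020-01-03 23:59:00 for this row.
    .withColumn("visit_end_datetime", F.expr("visit_end_datetime - INTERVAL 1 MINUTE"))
)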
19 changes: 14 additions & 5 deletions src/cehrbert_data/decorators/death_event_decorator.py
@@ -1,4 +1,5 @@
from pyspark.sql import DataFrame, functions as F, Window as W, types as T
import os
from pyspark.sql import SparkSession, DataFrame, functions as F, Window as W, types as T

from ..const.common import NA
from ..const.artificial_tokens import VS_TOKEN, VE_TOKEN, DEATH_TOKEN
@@ -20,9 +21,13 @@


class DeathEventDecorator(PatientEventDecorator):
def __init__(self, death, att_type):
def __init__(self, death, att_type, spark: SparkSession = None, persistence_folder: str = None):
self._death = death
self._att_type = att_type
super().__init__(spark=spark, persistence_folder=persistence_folder)

def get_name(self):
return "death_tokens"

def _decorate(self, patient_events: DataFrame):
if self._death is None:
@@ -32,7 +37,7 @@ def _decorate(self, patient_events: DataFrame):

max_visit_occurrence_id = death_records.select(F.max("visit_occurrence_id").alias("max_visit_occurrence_id"))

last_ve_record = (
last_ve_events = (
death_records.where(F.col("standard_concept_id") == VE_TOKEN)
.withColumn(
"record_rank",
@@ -42,7 +47,7 @@ def _decorate(self, patient_events: DataFrame):
.drop("record_rank")
)

last_ve_record.cache()
last_ve_events.cache()
# set(['cohort_member_id', 'person_id', 'standard_concept_id', 'date',
# 'visit_occurrence_id', 'domain', 'concept_value', 'visit_rank_order',
# 'visit_segment', 'priority', 'date_in_week', 'concept_value_mask',
@@ -52,7 +57,7 @@ def _decorate(self, patient_events: DataFrame):
) + F.col("max_visit_occurrence_id")

death_records = (
last_ve_record.crossJoin(max_visit_occurrence_id)
last_ve_events.crossJoin(max_visit_occurrence_id)
.withColumn("visit_occurrence_id", artificial_visit_id)
.withColumn("standard_concept_id", F.lit(DEATH_TOKEN))
.withColumn("domain", F.lit("death"))
@@ -107,6 +112,10 @@ def _decorate(self, patient_events: DataFrame):

new_tokens = death_events.unionByName(vs_records).unionByName(death_records).unionByName(ve_records)
new_tokens = new_tokens.drop("death_date")
new_tokens = self.try_persist_data(
new_tokens,
os.path.join(self.get_name(), "death_events")
)
self.validate(new_tokens)

return patient_events.unionByName(new_tokens)
34 changes: 31 additions & 3 deletions src/cehrbert_data/decorators/demographic_event_decorator.py
@@ -1,4 +1,5 @@
from pyspark.sql import DataFrame, functions as F, Window as W, types as T
import os
from pyspark.sql import SparkSession, DataFrame, functions as F, Window as W, types as T

from .patient_event_decorator_base import PatientEventDecorator

@@ -12,9 +13,18 @@


class DemographicEventDecorator(PatientEventDecorator):
def __init__(self, patient_demographic, use_age_group: bool = False):
def __init__(
self, patient_demographic,
use_age_group: bool = False,
spark: SparkSession = None,
persistence_folder: str = None
):
self._patient_demographic = patient_demographic
self._use_age_group = use_age_group
super().__init__(spark=spark, persistence_folder=persistence_folder)

def get_name(self):
return "demographic_events"

def _decorate(self, patient_events: DataFrame):
if self._patient_demographic is None:
@@ -63,7 +73,10 @@ def _decorate(self, patient_events: DataFrame):
.withColumn("concept_order", F.lit(0))
)

sequence_start_year_token.cache()
# Try persisting the start year tokens
sequence_start_year_token = self.try_persist_data(
sequence_start_year_token, os.path.join(self.get_name(), "sequence_start_year_tokens")
)

if self._use_age_group:
calculate_age_group_at_first_visit_udf = F.ceil(
@@ -89,6 +102,11 @@ def _decorate(self, patient_events: DataFrame):
.drop("birth_datetime")
)

# Try persisting the age tokens
sequence_age_token = self.try_persist_data(
sequence_age_token, os.path.join(self.get_name(), "sequence_age_tokens")
)

sequence_gender_token = (
self._patient_demographic.select(F.col("person_id"), F.col("gender_concept_id"))
.join(sequence_start_year_token, "person_id")
@@ -97,6 +115,11 @@ def _decorate(self, patient_events: DataFrame):
.drop("gender_concept_id")
)

# Try persisting the gender tokens
sequence_gender_token = self.try_persist_data(
sequence_gender_token, os.path.join(self.get_name(), "sequence_gender_tokens")
)

sequence_race_token = (
self._patient_demographic.select(F.col("person_id"), F.col("race_concept_id"))
.join(sequence_start_year_token, "person_id")
@@ -105,6 +128,11 @@ def _decorate(self, patient_events: DataFrame):
.drop("race_concept_id")
)

# Try persisting the race tokens
sequence_race_token = self.try_persist_data(
sequence_race_token, os.path.join(self.get_name(), "sequence_race_tokens")
)

patient_events = patient_events.unionByName(sequence_start_year_token)
patient_events = patient_events.unionByName(sequence_age_token)
patient_events = patient_events.unionByName(sequence_gender_token)