Combine bin with concept (#21)
* created partitions on intermediate dataframes

* test new partition

* try broadcasting visit_occurrence_bound

* testing using a refreshed visit_occurrence data frame

* cache ehr records

* testing a new strategy

* updated ehrshot omop conversion

* updated pyspark dependency

* re-implemented generate_visit_id for ehrshot omop conversion

* fixed the visit_id data type in the ehrshot data: it needs to be loaded as Float, otherwise the entire column would be populated with null

* fixed the logic for creating the artificial visits

* fixed a bug in creating the end date of the artificial visits

* test visit_end_date bounding with the new OMOP

* try broadcasting cohort_visit_occurrence

* try broadcasting visit_index_date

* broadcast visit_occurrence_person

* try repartitioning

* try a different partition strategy

* randomly shuffle visit_occurrence_person

* created the order column and reshuffled the dataframe using it afterwards

* removed an extra comma from a query

* removed event_group_ids from the cehrgpt input data

* upgraded pyspark

* replace event_group_id with N/A instead of NULL

* Revert "removed event_group_ids from the cehrgpt input data"

This reverts commit 38742ec.

* do not take person records into account when creating artificial visits

* invalidate the records that fall outside the visits

* downgrade pyspark to 3.1.2

* resolve the ambiguous visit_id column

* set the ehrshot visit_id to string type when loading the csv file

* use max_visit_id df to cross join

* added another assertion to test whether the patient count equals 1

* updated the visit construction logic to ensure its uniqueness

* cache domain_records so record_id is fixed

* changed death_concept_id to cause_concept_id

* fixed the death token priority column

* fixed the unit test

* added an option to exclude features and store the cohorts in the meds format (#20)

* convert prediction_time to timestamp when meds_format is enabled (#22)

* added number_as_value and concept_as_value to the spark dataframes

* set the default value for number_as_value and concept_as_value to None

* insert an hour token between the visit type token and the first medical event in the inpatient visit

* added include_inpatient_hour_token to extract_features

* calculate the first inpatient hour token using the date part of visit_start_datetime

* added aggregate_by_hour to perform the average aggregation over lab values that occurred within the same hour (see the sketch after the changed-file summary below)

* fixed a bug in the lab_hour

* fixed a union bug

* try caching the inpatient events in att decorator

* fixed a bug in generating the hour tokens

* fixed a bug in generating the hour tokens
ChaoPang authored Jan 7, 2025
1 parent 26a65a6 commit 6114e38
Showing 15 changed files with 371 additions and 185 deletions.
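
A minimal sketch of the aggregate_by_hour averaging described in the commit message, assuming lab events carry a datetime column and the number_as_value column introduced in this change; the actual implementation lives in get_measurement_table and may differ:

from pyspark.sql import SparkSession, functions as F

# Hypothetical illustration: average numeric lab values that fall within
# the same clock hour, keyed by patient and lab concept. Column names
# follow the ones introduced in this diff.
spark = SparkSession.builder.appName("hourly-aggregation-sketch").getOrCreate()

labs = spark.createDataFrame(
    [(1, 3004249, "2024-01-01 08:10:00", 120.0),
     (1, 3004249, "2024-01-01 08:40:00", 130.0),
     (1, 3004249, "2024-01-01 09:05:00", 110.0)],
    ["person_id", "standard_concept_id", "datetime", "number_as_value"],
).withColumn("datetime", F.to_timestamp("datetime"))

hourly = labs.groupBy(
    "person_id",
    "standard_concept_id",
    F.date_trunc("hour", "datetime").alias("datetime"),
).agg(F.avg("number_as_value").alias("number_as_value"))

hourly.show()  # 125.0 for the 08:00 hour, 110.0 for the 09:00 hour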
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -29,7 +29,7 @@ dependencies = [
"numpy==1.24.3",
"packaging==23.2",
"pandas==2.2.0",
"pyspark==3.5.3"
"pyspark==3.1.2"
]

[tool.setuptools_scm]
13 changes: 11 additions & 2 deletions src/cehrbert_data/apps/generate_training_data.py
@@ -48,7 +48,8 @@ def main(
with_drug_rollup: bool = True,
include_inpatient_hour_token: bool = False,
continue_from_events: bool = False,
- refresh_measurement: bool = False
+ refresh_measurement: bool = False,
+ aggregate_by_hour: bool = True,
):
spark = SparkSession.builder.appName("Generate CEHR-BERT Training Data").getOrCreate()

@@ -72,6 +73,7 @@ def main(
f"use_age_group: {use_age_group}\n"
f"with_drug_rollup: {with_drug_rollup}\n"
f"refresh_measurement: {refresh_measurement}\n"
f"aggregate_by_hour: {aggregate_by_hour}\n"
)

domain_tables = []
@@ -129,7 +131,8 @@ def main(
processed_measurement = get_measurement_table(
spark,
input_folder,
- refresh=refresh_measurement
+ refresh=refresh_measurement,
+ aggregate_by_hour=aggregate_by_hour,
)
if patient_events:
# Union all measurement records together with other domain records
@@ -327,6 +330,11 @@ def main(
dest="refresh_measurement",
action="store_true"
)
+ parser.add_argument(
+ "--aggregate_by_hour",
+ dest="aggregate_by_hour",
+ action="store_true"
+ )
parser.add_argument(
"--att_type",
dest="att_type",
@@ -367,4 +375,5 @@ def main(
include_inpatient_hour_token=ARGS.include_inpatient_hour_token,
continue_from_events=ARGS.continue_from_events,
refresh_measurement=ARGS.refresh_measurement,
+ aggregate_by_hour=ARGS.aggregate_by_hour,
)
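
The new flag is worth a note: main() defaults aggregate_by_hour to True, while argparse's store_true defaults to False when the flag is absent, so the CLI default and the function default disagree. A self-contained sketch of that behavior:

import argparse

# Sketch of the flag's CLI behavior: action="store_true" makes the parser
# default False, which differs from the aggregate_by_hour=True default in
# main()'s signature.
parser = argparse.ArgumentParser()
parser.add_argument("--aggregate_by_hour", dest="aggregate_by_hour", action="store_true")

print(parser.parse_args([]).aggregate_by_hour)                        # False
print(parser.parse_args(["--aggregate_by_hour"]).aggregate_by_hour)   # True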
5 changes: 5 additions & 0 deletions src/cehrbert_data/cohorts/spark_app_base.py
@@ -295,6 +295,7 @@
is_drug_roll_up_concept: bool = True,
include_concept_list: bool = True,
refresh_measurement: bool = False,
+ aggregate_by_hour: bool = True,
is_new_patient_representation: bool = False,
gpt_patient_sequence: bool = False,
is_hierarchical_bert: bool = False,
@@ -343,6 +344,7 @@ def __init__(
self._is_prediction_window_unbounded = is_prediction_window_unbounded
self._include_concept_list = include_concept_list
self._refresh_measurement = refresh_measurement
+ self._aggregate_by_hour = aggregate_by_hour
self._allow_measurement_only = allow_measurement_only
self._output_data_folder = os.path.join(
self._output_folder, re.sub("[^a-z0-9]+", "_", self._cohort_name.lower())
@@ -383,6 +385,7 @@ def __init__(
f"is_prediction_window_unbounded: {is_prediction_window_unbounded}\n"
f"include_concept_list: {include_concept_list}\n"
f"refresh_measurement: {refresh_measurement}\n"
f"aggregate_by_hour: {aggregate_by_hour}\n"
f"is_observation_window_unbounded: {is_observation_window_unbounded}\n"
f"is_population_estimation: {is_population_estimation}\n"
f"att_type: {att_type}\n"
@@ -597,6 +600,7 @@ def extract_ehr_records_for_cohort(self, cohort: DataFrame):
with_drug_rollup=self._is_drug_roll_up_concept,
include_concept_list=self._include_concept_list,
refresh_measurement=self._refresh_measurement,
+ aggregate_by_hour=self._aggregate_by_hour,
)

# Duplicate the records for cohorts that allow multiple entries
@@ -786,6 +790,7 @@ def create_prediction_cohort(
is_drug_roll_up_concept=spark_args.is_drug_roll_up_concept,
include_concept_list=spark_args.include_concept_list,
refresh_measurement=spark_args.refresh_measurement,
+ aggregate_by_hour=spark_args.aggregate_by_hour,
is_new_patient_representation=spark_args.is_new_patient_representation,
gpt_patient_sequence=spark_args.gpt_patient_sequence,
is_hierarchical_bert=spark_args.is_hierarchical_bert,
53 changes: 46 additions & 7 deletions src/cehrbert_data/decorators/artificial_time_token_decorator.py
@@ -11,6 +11,7 @@
VISIT_TYPE_TOKEN_PRIORITY,
DISCHARGE_TOKEN_PRIORITY,
VE_TOKEN_PRIORITY,
+ FIRST_VISIT_HOUR_TOKEN_PRIORITY,
get_inpatient_token_priority,
get_inpatient_att_token_priority
)
@@ -38,7 +39,7 @@ def _decorate(self, patient_events: DataFrame):
return patient_events

# visits should have the following columns (person_id,
- # visit_concept_id, visit_start_date, visit_occurrence_id, domain, concept_value)
+ # visit_concept_id, visit_start_date, visit_occurrence_id, domain)
cohort_member_person_pair = patient_events.select("person_id", "cohort_member_id").distinct()
valid_visit_ids = patient_events.groupby(
"cohort_member_id",
@@ -62,7 +63,9 @@ def _decorate(self, patient_events: DataFrame):
"visit_concept_id",
"visit_occurrence_id",
F.lit("visit").alias("domain"),
- F.lit(0.0).alias("concept_value"),
+ F.lit(None).cast("float").alias("number_as_value"),
+ F.lit(None).cast("string").alias("concept_as_value"),
+ F.lit(0).alias("is_numeric_type"),
F.lit(0).alias("concept_value_mask"),
F.lit(0).alias("mlm_skip_value"),
"age",
@@ -232,6 +235,9 @@ def _decorate(self, patient_events: DataFrame):
# Add discharge events to the inpatient visits
inpatient_events = inpatient_events.unionByName(discharge_events)

+ # Try caching the inpatient events
+ inpatient_events.cache()

# Get the prev days_since_epoch
inpatient_prev_date_udf = F.lag("date").over(
W.partitionBy("cohort_member_id", "visit_occurrence_id").orderBy("concept_order")
@@ -252,15 +258,41 @@ def _decorate(self, patient_events: DataFrame):
inpatient_att_token = F.when(
F.col("hour_delta") < 24, F.concat(F.lit("i-H"), F.col("hour_delta"))
).otherwise(F.concat(F.lit("i-"), inpatient_time_token_udf("time_delta")))

+ # We need to insert an ATT token between midnight and the visit start datetime
+ first_inpatient_hour_delta_udf = (
+ F.floor((F.unix_timestamp("visit_start_datetime") - F.unix_timestamp(
+ F.col("visit_start_datetime").cast("date"))) / 3600)
+ )
+
+ first_hour_tokens = (
+ visits.where(F.col("visit_concept_id").isin([9201, 262, 8971, 8920]))
+ .withColumn("hour_delta", first_inpatient_hour_delta_udf)
+ .where(F.col("hour_delta") > 0)
+ .withColumn("date", F.col("visit_start_date"))
+ .withColumn("datetime", F.to_timestamp("date"))
+ .withColumn("standard_concept_id", F.concat(F.lit("i-H"), F.col("hour_delta")))
+ .withColumn("visit_concept_order", F.col("min_visit_concept_order"))
+ .withColumn("concept_order", F.lit(0))
+ .withColumn("priority", F.lit(FIRST_VISIT_HOUR_TOKEN_PRIORITY))
+ .withColumn("unit", F.lit(NA))
+ .withColumn("event_group_id", F.lit(NA))
+ .drop("min_visit_concept_order", "max_visit_concept_order")
+ .drop("min_concept_order", "max_concept_order")
+ .drop("hour_delta", "visit_end_date")
+ )

# Create ATT tokens within the inpatient visits
inpatient_att_events = (
inpatient_events.withColumn(
"time_stamp_hour", F.hour("datetime")
).withColumn(
"is_span_boundary",
F.row_number().over(
W.partitionBy("cohort_member_id", "visit_occurrence_id", "concept_order").orderBy("priority")
W.partitionBy("cohort_member_id", "visit_occurrence_id", "concept_order")
.orderBy("priority", "date", "time_stamp_hour")
),
)
.where(F.col("is_span_boundary") == 1)
.withColumn("prev_date", inpatient_prev_date_udf)
.withColumn("time_delta", inpatient_time_delta_udf)
.withColumn("prev_datetime", inpatient_prev_datetime_udf)
@@ … @@
.withColumn("visit_concept_order", F.col("visit_concept_order"))
.withColumn("priority", get_inpatient_att_token_priority())
.withColumn("concept_value_mask", F.lit(0))
.withColumn("concept_value", F.lit(0.0))
.withColumn("number_as_value", F.lit(None).cast("float"))
.withColumn("concept_as_value", F.lit(None).cast("string"))
.withColumn("is_numeric_type", F.lit(0))
.withColumn("unit", F.lit(NA))
.withColumn("event_group_id", F.lit(NA))
.drop("prev_date", "time_delta", "is_span_boundary")
.drop("prev_datetime", "hour_delta")
.drop("prev_datetime", "hour_delta", "time_stamp_hour")
)

+ # Insert the first hour tokens between the visit type and first medical event
+ inpatient_att_events = inpatient_att_events.unionByName(first_hour_tokens)
else:
# Create ATT tokens within the inpatient visits
inpatient_att_events = (
@@ -298,7 +335,9 @@ def _decorate(self, patient_events: DataFrame):
.withColumn("visit_concept_order", F.col("visit_concept_order"))
.withColumn("priority", get_inpatient_att_token_priority())
.withColumn("concept_value_mask", F.lit(0))
.withColumn("concept_value", F.lit(0.0))
.withColumn("number_as_value", F.lit(None).cast("float"))
.withColumn("concept_as_value", F.lit(None).cast("string"))
.withColumn("is_numeric_type", F.lit(0))
.withColumn("unit", F.lit(NA))
.withColumn("event_group_id", F.lit(NA))
.drop("prev_date", "time_delta", "is_span_boundary")
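A worked example of the first-hour-token arithmetic above: the delta is the number of whole hours between midnight and visit_start_datetime, and the emitted token is "i-H{delta}". A minimal sketch with made-up data:

from pyspark.sql import SparkSession, functions as F

# The hour delta is floor(seconds since midnight / 3600); a visit starting
# at 14:30 yields hour_delta = 14 and the token "i-H14".
spark = SparkSession.builder.appName("first-hour-token-sketch").getOrCreate()

visits = spark.createDataFrame(
    [(1, "2024-01-01 14:30:00")],
    ["visit_occurrence_id", "visit_start_datetime"],
).withColumn("visit_start_datetime", F.to_timestamp("visit_start_datetime"))

hour_delta = F.floor(
    (F.unix_timestamp("visit_start_datetime")
     - F.unix_timestamp(F.col("visit_start_datetime").cast("date"))) / 3600
)

visits.withColumn("hour_delta", hour_delta).withColumn(
    "standard_concept_id", F.concat(F.lit("i-H"), F.col("hour_delta"))
).show()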
38 changes: 9 additions & 29 deletions src/cehrbert_data/decorators/clinical_event_decorator.py
@@ -14,7 +14,7 @@ class ClinicalEventDecorator(PatientEventDecorator):
# output_columns = [
# 'cohort_member_id', 'person_id', 'concept_ids', 'visit_segments', 'orders',
# 'dates', 'ages', 'visit_concept_orders', 'num_of_visits', 'num_of_concepts',
- # 'concept_value_masks', 'concept_values', 'mlm_skip_values',
+ # 'concept_value_masks', 'value_as_numbers', 'value_as_concepts', 'mlm_skip_values',
# 'visit_concept_ids', "units"
# ]
def __init__(self, visit_occurrence):
@@ -131,38 +131,18 @@ def _decorate(self, patient_events: DataFrame):
# Create the concept_value_mask field to indicate whether domain values should be skipped
# As of now only measurement has values, so other domains would be skipped.
patient_events = patient_events.withColumn(
"concept_value_mask", (F.col("domain") == MEASUREMENT).cast("int")
"concept_value_mask", (F.col("domain").isin(MEASUREMENT, CATEGORICAL_MEASUREMENT)).cast("int")
).withColumn(
"is_numeric_type", (F.col("domain") == MEASUREMENT).cast("int")
).withColumn(
"mlm_skip_value",
(F.col("domain").isin([MEASUREMENT, CATEGORICAL_MEASUREMENT])).cast("int"),
)

if "concept_value" not in patient_events.schema.fieldNames():
patient_events = patient_events.withColumn("concept_value", F.lit(0.0))
if "number_as_value" not in patient_events.schema.fieldNames():
patient_events = patient_events.withColumn("number_as_value", F.lit(None).cast("float"))

- # Split the categorical measurement standard_concept_id into the question/answer pairs
- categorical_measurement_events = (
- patient_events.where(F.col("domain") == CATEGORICAL_MEASUREMENT)
- .withColumn("measurement_components", F.split("standard_concept_id", "-"))
- )
+ if "concept_as_value" not in patient_events.schema.fieldNames():
+ patient_events = patient_events.withColumn("concept_as_value", F.lit(None).cast("string"))

- categorical_measurement_events_question = categorical_measurement_events.withColumn(
- "standard_concept_id",
- F.concat(F.lit(MEASUREMENT_QUESTION_PREFIX), F.col("measurement_components").getItem(0))
- ).drop("measurement_components")
-
- categorical_measurement_events_answer = categorical_measurement_events.withColumn(
- "standard_concept_id",
- F.concat(F.lit(MEASUREMENT_ANSWER_PREFIX), F.coalesce(F.col("measurement_components").getItem(1), F.lit("0")))
- ).drop("measurement_components")
-
- other_events = patient_events.where(F.col("domain") != CATEGORICAL_MEASUREMENT)
-
- # (cohort_member_id, person_id, standard_concept_id, date, datetime, visit_occurrence_id, domain,
- # concept_value, visit_rank_order, visit_segment, priority, date_in_week,
- # concept_value_mask, mlm_skip_value, age)
- return other_events.unionByName(
- categorical_measurement_events_question
- ).unionByName(
- categorical_measurement_events_answer
- )
+ return patient_events
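
The rewrite above replaces the question/answer splitting of categorical measurements with two mask columns. A toy illustration (not from the repo) of how they behave, with stand-ins for the MEASUREMENT and CATEGORICAL_MEASUREMENT constants used above:

from pyspark.sql import SparkSession, functions as F

# concept_value_mask flags both numeric and categorical measurements;
# is_numeric_type singles out the numeric ones. Domain labels here are
# assumed stand-ins for the repo's constants.
MEASUREMENT, CATEGORICAL_MEASUREMENT = "measurement", "categorical_measurement"

spark = SparkSession.builder.appName("mask-sketch").getOrCreate()
events = spark.createDataFrame(
    [("condition",), (MEASUREMENT,), (CATEGORICAL_MEASUREMENT,)], ["domain"]
)

events.withColumn(
    "concept_value_mask",
    F.col("domain").isin(MEASUREMENT, CATEGORICAL_MEASUREMENT).cast("int"),
).withColumn(
    "is_numeric_type", (F.col("domain") == MEASUREMENT).cast("int")
).show()
# condition -> (0, 0); measurement -> (1, 1); categorical_measurement -> (1, 0)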
8 changes: 3 additions & 5 deletions src/cehrbert_data/decorators/death_event_decorator.py
@@ -43,12 +43,10 @@ def _decorate(self, patient_events: DataFrame):
)

last_ve_record.cache()
- last_ve_record.show()
# set(['cohort_member_id', 'person_id', 'standard_concept_id', 'date',
# 'visit_occurrence_id', 'domain', 'concept_value', 'visit_rank_order',
# 'visit_segment', 'priority', 'date_in_week', 'concept_value_mask',
# 'mlm_skip_value', 'age', 'visit_concept_id'])

artificial_visit_id = F.row_number().over(
W.partitionBy(F.lit(0)).orderBy("person_id", "cohort_member_id")
) + F.col("max_visit_occurrence_id")
@@ -59,23 +57,23 @@ def _decorate(self, patient_events: DataFrame):
.withColumn("standard_concept_id", F.lit(DEATH_TOKEN))
.withColumn("domain", F.lit("death"))
.withColumn("visit_rank_order", F.lit(1) + F.col("visit_rank_order"))
.withColumn("priority", DEATH_TOKEN_PRIORITY)
.withColumn("priority", F.lit(DEATH_TOKEN_PRIORITY))
.withColumn("event_group_id", F.lit(NA))
.drop("max_visit_occurrence_id")
)

vs_records = (
death_records
.withColumn("standard_concept_id", F.lit(VS_TOKEN))
.withColumn("priority", VS_TOKEN_PRIORITY)
.withColumn("priority", F.lit(VS_TOKEN_PRIORITY))
.withColumn("unit", F.lit(NA))
.withColumn("event_group_id", F.lit(NA))
)

ve_records = (
death_records
.withColumn("standard_concept_id", F.lit(VE_TOKEN))
.withColumn("priority", VE_TOKEN_PRIORITY)
.withColumn("priority", F.lit(VE_TOKEN_PRIORITY))
.withColumn("unit", F.lit(NA))
.withColumn("event_group_id", F.lit(NA))
)
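The priority fix above ("fixed the death token priority column") is the F.lit() wrapping: DataFrame.withColumn requires a Column, so passing a bare Python constant fails. A minimal repro, independent of the repo:

from pyspark.sql import SparkSession, functions as F

# withColumn requires a Column; a bare Python int raises an error, while
# F.lit() wraps the constant into a literal Column.
spark = SparkSession.builder.appName("lit-sketch").getOrCreate()
df = spark.range(1)

DEATH_TOKEN_PRIORITY = 199
try:
    df.withColumn("priority", DEATH_TOKEN_PRIORITY)  # not a Column
except (AssertionError, TypeError) as e:
    print(f"fails as expected: {e}")

df.withColumn("priority", F.lit(DEATH_TOKEN_PRIORITY)).show()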
8 changes: 5 additions & 3 deletions src/cehrbert_data/decorators/demographic_event_decorator.py
@@ -21,8 +21,8 @@ def _decorate(self, patient_events: DataFrame):
return patient_events

# set(['cohort_member_id', 'person_id', 'standard_concept_id', 'date',
- # 'visit_occurrence_id', 'domain', 'concept_value', 'visit_rank_order',
- # 'visit_segment', 'priority', 'date_in_week', 'concept_value_mask',
+ # 'visit_occurrence_id', 'domain', 'value_as_number', 'value_as_concept', 'visit_rank_order',
+ # 'is_numeric_type', 'visit_segment', 'priority', 'date_in_week', 'concept_value_mask',
# 'mlm_skip_value', 'age', 'visit_concept_id'])

# Get the first token of the patient history
@@ -39,7 +39,9 @@ def _decorate(self, patient_events: DataFrame):
patient_first_token = (
patient_events.withColumn("token_order", first_token_udf)
.withColumn("concept_value_mask", F.lit(0))
.withColumn("concept_value", F.lit(0.0))
.withColumn("number_as_value", F.lit(None).cast("float"))
.withColumn("concept_as_value", F.lit(None).cast("string"))
.withColumn("is_numeric_type", F.lit(0))
.withColumn("unit", F.lit(NA))
.withColumn("event_group_id", F.lit(NA))
.where("token_order = 1")
4 changes: 3 additions & 1 deletion src/cehrbert_data/decorators/patient_event_decorator_base.py
@@ -37,7 +37,9 @@ def get_required_columns(cls) -> Set[str]:
"datetime",
"visit_occurrence_id",
"domain",
"concept_value",
"concept_as_value",
"is_numeric_type",
"number_as_value",
"visit_rank_order",
"visit_segment",
"priority",
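For context, a hypothetical helper (not in the repo) showing how a required-column set like this is typically enforced before decorating:

from pyspark.sql import DataFrame

# Hypothetical validation helper: fail fast if the incoming patient_events
# frame lacks any of the newly required value columns.
REQUIRED_VALUE_COLUMNS = {"concept_as_value", "is_numeric_type", "number_as_value"}

def validate_columns(df: DataFrame) -> None:
    missing = REQUIRED_VALUE_COLUMNS - set(df.columns)
    if missing:
        raise ValueError(f"patient_events is missing columns: {sorted(missing)}")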
1 change: 1 addition & 0 deletions src/cehrbert_data/decorators/token_priority.py
@@ -8,6 +8,7 @@
ATT_TOKEN_PRIORITY = -3
VS_TOKEN_PRIORITY = -2
VISIT_TYPE_TOKEN_PRIORITY = -1
+ FIRST_VISIT_HOUR_TOKEN_PRIORITY = -0.5
DEFAULT_PRIORITY = 0
DISCHARGE_TOKEN_PRIORITY = 100
DEATH_TOKEN_PRIORITY = 199
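These priorities control token ordering within a patient sequence; the new -0.5 value slots the first-visit-hour token after the visit type token (-1) but before ordinary clinical events (0). A small sanity check in plain Python:

# Sorting by priority reproduces the intended token order.
VS_TOKEN_PRIORITY = -2
VISIT_TYPE_TOKEN_PRIORITY = -1
FIRST_VISIT_HOUR_TOKEN_PRIORITY = -0.5
DEFAULT_PRIORITY = 0
DISCHARGE_TOKEN_PRIORITY = 100

tokens = [
    ("clinical event", DEFAULT_PRIORITY),
    ("i-H14", FIRST_VISIT_HOUR_TOKEN_PRIORITY),
    ("discharge", DISCHARGE_TOKEN_PRIORITY),
    ("VS", VS_TOKEN_PRIORITY),
    ("visit type", VISIT_TYPE_TOKEN_PRIORITY),
]
print([name for name, _ in sorted(tokens, key=lambda t: t[1])])
# ['VS', 'visit type', 'i-H14', 'clinical event', 'discharge']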
[Diffs for the remaining 6 changed files are not shown.]
