From 6114e38221c2e325a4b0620cf9ac09851a810cc5 Mon Sep 17 00:00:00 2001
From: Chao Pang
Date: Tue, 7 Jan 2025 09:02:08 -0500
Subject: [PATCH] Combine bin with concept (#21)

* created partitions on intermediate dataframes
* test new partition
* try broadcasting visit_occurrence_bound
* testing using a refreshed visit_occurrence data frame
* cache ehr records
* testing a new strategy
* updated ehrshot omop conversion
* updated pyspark dependency
* re-implemented generate_visit_id for ehrshot omop conversion
* fixed the visit_id data type in ehrshot data because it needs to be loaded as Float; otherwise the entire column would be populated with null
* fixed the logic for creating the artificial visits
* fixed a bug in creating the end date of the artificial visits
* test visit_end_date bounding with the new OMOP
* try broadcasting cohort_visit_occurrence
* try broadcasting visit_index_date
* broadcast visit_occurrence_person
* try repartitioning
* try a different partition strategy
* randomly shuffle visit_occurrence_person
* created an order column and reshuffled the dataframe using it afterwards
* removed an extra comma from a query
* removed event_group_ids from the cehrgpt input data
* upgraded pyspark
* replace event_group_id with N/A instead of NULL
* Revert "removed event_group_ids from the cehrgpt input data"

  This reverts commit 38742ec850f0ec6561bd8a807be7532ed89d1688.

* do not take person records into account when creating artificial visits
* invalidate the records that fall outside the visits
* downgrade pyspark to 3.1.2
* resolve the ambiguous visit_id column
* set the ehrshot visit_id to string type when loading the csv file
* use max_visit_id df to cross join
* added another assertion to test whether the patient count equals 1
* updated the visit construction logic to ensure its uniqueness
* cache domain_records so record_id is fixed
* changed death_concept_id to cause_concept_id
* fixed the death token priority column
* fixed the unit test
* added an option to exclude features and store the cohorts in the meds format (#20)
* convert prediction_time to timestamp when meds_format is enabled (#22)
* added number_as_value and concept_as_value to the spark dataframes
* set the default value for number_as_value and concept_as_value to None
* insert an hour token between the visit type token and the first medical event in the inpatient visit
* added include_inpatient_hour_token to extract_features
* calculate the first inpatient hour token using the date part of visit_start_datetime
* added aggregate_by_hour to perform the average aggregation over the lab values that occurred within the same hour
* fixed a bug in the lab_hour calculation
* fixed a union bug
* try caching the inpatient events in the att decorator
* fixed a bug in generating the hour tokens
* fixed a bug in generating the hour tokens
---
 pyproject.toml                                 |   2 +-
 .../apps/generate_training_data.py             |  13 +-
 src/cehrbert_data/cohorts/spark_app_base.py    |   5 +
 .../artificial_time_token_decorator.py         |  53 ++++-
 .../decorators/clinical_event_decorator.py     |  38 +--
 .../decorators/death_event_decorator.py        |   8 +-
 .../decorators/demographic_event_decorator.py  |   8 +-
 .../patient_event_decorator_base.py            |   4 +-
 .../decorators/token_priority.py               |   1 +
 src/cehrbert_data/tools/ehrshot_to_omop.py     | 223 +++++++++++++-----
 src/cehrbert_data/tools/extract_features.py    |  69 +++---
 src/cehrbert_data/utils/spark_parse_args.py    |   6 +
 src/cehrbert_data/utils/spark_utils.py         |  94 +++++---
 .../test_generate_training_data.py             |   1 +
 tests/unit_tests/test_ehrshot_to_omop.py       |  31 
++- 15 files changed, 371 insertions(+), 185 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ae06c81..433c622 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ dependencies = [ "numpy==1.24.3", "packaging==23.2", "pandas==2.2.0", - "pyspark==3.5.3" + "pyspark==3.1.2" ] [tool.setuptools_scm] diff --git a/src/cehrbert_data/apps/generate_training_data.py b/src/cehrbert_data/apps/generate_training_data.py index 5d8b722..9b05247 100644 --- a/src/cehrbert_data/apps/generate_training_data.py +++ b/src/cehrbert_data/apps/generate_training_data.py @@ -48,7 +48,8 @@ def main( with_drug_rollup: bool = True, include_inpatient_hour_token: bool = False, continue_from_events: bool = False, - refresh_measurement: bool = False + refresh_measurement: bool = False, + aggregate_by_hour: bool = True, ): spark = SparkSession.builder.appName("Generate CEHR-BERT Training Data").getOrCreate() @@ -72,6 +73,7 @@ def main( f"use_age_group: {use_age_group}\n" f"with_drug_rollup: {with_drug_rollup}\n" f"refresh_measurement: {refresh_measurement}\n" + f"aggregate_by_hour: {aggregate_by_hour}\n" ) domain_tables = [] @@ -129,7 +131,8 @@ def main( processed_measurement = get_measurement_table( spark, input_folder, - refresh=refresh_measurement + refresh=refresh_measurement, + aggregate_by_hour=aggregate_by_hour, ) if patient_events: # Union all measurement records together with other domain records @@ -327,6 +330,11 @@ def main( dest="refresh_measurement", action="store_true" ) + parser.add_argument( + "--aggregate_by_hour", + dest="aggregate_by_hour", + action="store_true" + ) parser.add_argument( "--att_type", dest="att_type", @@ -367,4 +375,5 @@ def main( include_inpatient_hour_token=ARGS.include_inpatient_hour_token, continue_from_events=ARGS.continue_from_events, refresh_measurement=ARGS.refresh_measurement, + aggregate_by_hour=ARGS.aggregate_by_hour, ) diff --git a/src/cehrbert_data/cohorts/spark_app_base.py b/src/cehrbert_data/cohorts/spark_app_base.py index 3067bab..8461f79 100644 --- a/src/cehrbert_data/cohorts/spark_app_base.py +++ b/src/cehrbert_data/cohorts/spark_app_base.py @@ -295,6 +295,7 @@ def __init__( is_drug_roll_up_concept: bool = True, include_concept_list: bool = True, refresh_measurement: bool = False, + aggregate_by_hour: bool = True, is_new_patient_representation: bool = False, gpt_patient_sequence: bool = False, is_hierarchical_bert: bool = False, @@ -343,6 +344,7 @@ def __init__( self._is_prediction_window_unbounded = is_prediction_window_unbounded self._include_concept_list = include_concept_list self._refresh_measurement = refresh_measurement + self._aggregate_by_hour = aggregate_by_hour self._allow_measurement_only = allow_measurement_only self._output_data_folder = os.path.join( self._output_folder, re.sub("[^a-z0-9]+", "_", self._cohort_name.lower()) @@ -383,6 +385,7 @@ def __init__( f"is_prediction_window_unbounded: {is_prediction_window_unbounded}\n" f"include_concept_list: {include_concept_list}\n" f"refresh_measurement: {refresh_measurement}\n" + f"aggregate_by_hour: {aggregate_by_hour}\n" f"is_observation_window_unbounded: {is_observation_window_unbounded}\n" f"is_population_estimation: {is_population_estimation}\n" f"att_type: {att_type}\n" @@ -597,6 +600,7 @@ def extract_ehr_records_for_cohort(self, cohort: DataFrame): with_drug_rollup=self._is_drug_roll_up_concept, include_concept_list=self._include_concept_list, refresh_measurement=self._refresh_measurement, + aggregate_by_hour=self._aggregate_by_hour, ) # Duplicate the records for cohorts that 
allow multiple entries @@ -786,6 +790,7 @@ def create_prediction_cohort( is_drug_roll_up_concept=spark_args.is_drug_roll_up_concept, include_concept_list=spark_args.include_concept_list, refresh_measurement=spark_args.refresh_measurement, + aggregate_by_hour=spark_args.aggregate_by_hour, is_new_patient_representation=spark_args.is_new_patient_representation, gpt_patient_sequence=spark_args.gpt_patient_sequence, is_hierarchical_bert=spark_args.is_hierarchical_bert, diff --git a/src/cehrbert_data/decorators/artificial_time_token_decorator.py b/src/cehrbert_data/decorators/artificial_time_token_decorator.py index 0eb8cbc..a45d160 100644 --- a/src/cehrbert_data/decorators/artificial_time_token_decorator.py +++ b/src/cehrbert_data/decorators/artificial_time_token_decorator.py @@ -11,6 +11,7 @@ VISIT_TYPE_TOKEN_PRIORITY, DISCHARGE_TOKEN_PRIORITY, VE_TOKEN_PRIORITY, + FIRST_VISIT_HOUR_TOKEN_PRIORITY, get_inpatient_token_priority, get_inpatient_att_token_priority ) @@ -38,7 +39,7 @@ def _decorate(self, patient_events: DataFrame): return patient_events # visits should the following columns (person_id, - # visit_concept_id, visit_start_date, visit_occurrence_id, domain, concept_value) + # visit_concept_id, visit_start_date, visit_occurrence_id, domain) cohort_member_person_pair = patient_events.select("person_id", "cohort_member_id").distinct() valid_visit_ids = patient_events.groupby( "cohort_member_id", @@ -62,7 +63,9 @@ def _decorate(self, patient_events: DataFrame): "visit_concept_id", "visit_occurrence_id", F.lit("visit").alias("domain"), - F.lit(0.0).alias("concept_value"), + F.lit(None).cast("float").alias("number_as_value"), + F.lit(None).cast("string").alias("concept_as_value"), + F.lit(0).alias("is_numeric_type"), F.lit(0).alias("concept_value_mask"), F.lit(0).alias("mlm_skip_value"), "age", @@ -232,6 +235,9 @@ def _decorate(self, patient_events: DataFrame): # Add discharge events to the inpatient visits inpatient_events = inpatient_events.unionByName(discharge_events) + # Try caching the inpatient events + inpatient_events.cache() + # Get the prev days_since_epoch inpatient_prev_date_udf = F.lag("date").over( W.partitionBy("cohort_member_id", "visit_occurrence_id").orderBy("concept_order") @@ -252,15 +258,41 @@ def _decorate(self, patient_events: DataFrame): inpatient_att_token = F.when( F.col("hour_delta") < 24, F.concat(F.lit("i-H"), F.col("hour_delta")) ).otherwise(F.concat(F.lit("i-"), inpatient_time_token_udf("time_delta"))) + + # We need to insert an ATT token between midnight and the visit start datetime + first_inpatient_hour_delta_udf = ( + F.floor((F.unix_timestamp("visit_start_datetime") - F.unix_timestamp( + F.col("visit_start_datetime").cast("date"))) / 3600) + ) + + first_hour_tokens = ( + visits.where(F.col("visit_concept_id").isin([9201, 262, 8971, 8920])) + .withColumn("hour_delta", first_inpatient_hour_delta_udf) + .where(F.col("hour_delta") > 0) + .withColumn("date", F.col("visit_start_date")) + .withColumn("datetime", F.to_timestamp("date")) + .withColumn("standard_concept_id", F.concat(F.lit("i-H"), F.col("hour_delta"))) + .withColumn("visit_concept_order", F.col("min_visit_concept_order")) + .withColumn("concept_order", F.lit(0)) + .withColumn("priority", F.lit(FIRST_VISIT_HOUR_TOKEN_PRIORITY)) + .withColumn("unit", F.lit(NA)) + .withColumn("event_group_id", F.lit(NA)) + .drop("min_visit_concept_order", "max_visit_concept_order") + .drop("min_concept_order", "max_concept_order") + .drop("hour_delta", "visit_end_date") + ) + # Create ATT tokens within the inpatient 
visits inpatient_att_events = ( inpatient_events.withColumn( + "time_stamp_hour", F.hour("datetime") + ).withColumn( "is_span_boundary", F.row_number().over( - W.partitionBy("cohort_member_id", "visit_occurrence_id", "concept_order").orderBy("priority") + W.partitionBy("cohort_member_id", "visit_occurrence_id", "concept_order") + .orderBy("priority", "date", "time_stamp_hour") ), ) - .where(F.col("is_span_boundary") == 1) .withColumn("prev_date", inpatient_prev_date_udf) .withColumn("time_delta", inpatient_time_delta_udf) .withColumn("prev_datetime", inpatient_prev_datetime_udf) @@ -271,12 +303,17 @@ def _decorate(self, patient_events: DataFrame): .withColumn("visit_concept_order", F.col("visit_concept_order")) .withColumn("priority", get_inpatient_att_token_priority()) .withColumn("concept_value_mask", F.lit(0)) - .withColumn("concept_value", F.lit(0.0)) + .withColumn("number_as_value", F.lit(None).cast("float")) + .withColumn("concept_as_value", F.lit(None).cast("string")) + .withColumn("is_numeric_type", F.lit(0)) .withColumn("unit", F.lit(NA)) .withColumn("event_group_id", F.lit(NA)) .drop("prev_date", "time_delta", "is_span_boundary") - .drop("prev_datetime", "hour_delta") + .drop("prev_datetime", "hour_delta", "time_stamp_hour") ) + + # Insert the first hour tokens between the visit type and first medical event + inpatient_att_events = inpatient_att_events.unionByName(first_hour_tokens) else: # Create ATT tokens within the inpatient visits inpatient_att_events = ( @@ -298,7 +335,9 @@ def _decorate(self, patient_events: DataFrame): .withColumn("visit_concept_order", F.col("visit_concept_order")) .withColumn("priority", get_inpatient_att_token_priority()) .withColumn("concept_value_mask", F.lit(0)) - .withColumn("concept_value", F.lit(0.0)) + .withColumn("number_as_value", F.lit(None).cast("float")) + .withColumn("concept_as_value", F.lit(None).cast("string")) + .withColumn("is_numeric_type", F.lit(0)) .withColumn("unit", F.lit(NA)) .withColumn("event_group_id", F.lit(NA)) .drop("prev_date", "time_delta", "is_span_boundary") diff --git a/src/cehrbert_data/decorators/clinical_event_decorator.py b/src/cehrbert_data/decorators/clinical_event_decorator.py index 7d3a38b..69f8c8e 100644 --- a/src/cehrbert_data/decorators/clinical_event_decorator.py +++ b/src/cehrbert_data/decorators/clinical_event_decorator.py @@ -14,7 +14,7 @@ class ClinicalEventDecorator(PatientEventDecorator): # output_columns = [ # 'cohort_member_id', 'person_id', 'concept_ids', 'visit_segments', 'orders', # 'dates', 'ages', 'visit_concept_orders', 'num_of_visits', 'num_of_concepts', - # 'concept_value_masks', 'concept_values', 'mlm_skip_values', + # 'concept_value_masks', 'value_as_numbers', 'value_as_concepts', 'mlm_skip_values', # 'visit_concept_ids', "units" # ] def __init__(self, visit_occurrence): @@ -131,38 +131,18 @@ def _decorate(self, patient_events: DataFrame): # Create the concept_value_mask field to indicate whether domain values should be skipped # As of now only measurement has values, so other domains would be skipped. 
patient_events = patient_events.withColumn( - "concept_value_mask", (F.col("domain") == MEASUREMENT).cast("int") + "concept_value_mask", (F.col("domain").isin(MEASUREMENT, CATEGORICAL_MEASUREMENT)).cast("int") + ).withColumn( + "is_numeric_type", (F.col("domain") == MEASUREMENT).cast("int") ).withColumn( "mlm_skip_value", (F.col("domain").isin([MEASUREMENT, CATEGORICAL_MEASUREMENT])).cast("int"), ) - if "concept_value" not in patient_events.schema.fieldNames(): - patient_events = patient_events.withColumn("concept_value", F.lit(0.0)) + if "number_as_value" not in patient_events.schema.fieldNames(): + patient_events = patient_events.withColumn("number_as_value", F.lit(None).cast("float")) - # Split the categorical measurement standard_concept_id into the question/answer pairs - categorical_measurement_events = ( - patient_events.where(F.col("domain") == CATEGORICAL_MEASUREMENT) - .withColumn("measurement_components", F.split("standard_concept_id", "-")) - ) + if "concept_as_value" not in patient_events.schema.fieldNames(): + patient_events = patient_events.withColumn("concept_as_value", F.lit(None).cast("string")) - categorical_measurement_events_question = categorical_measurement_events.withColumn( - "standard_concept_id", - F.concat(F.lit(MEASUREMENT_QUESTION_PREFIX), F.col("measurement_components").getItem(0)) - ).drop("measurement_components") - - categorical_measurement_events_answer = categorical_measurement_events.withColumn( - "standard_concept_id", - F.concat(F.lit(MEASUREMENT_ANSWER_PREFIX), F.coalesce(F.col("measurement_components").getItem(1), F.lit("0"))) - ).drop("measurement_components") - - other_events = patient_events.where(F.col("domain") != CATEGORICAL_MEASUREMENT) - - # (cohort_member_id, person_id, standard_concept_id, date, datetime, visit_occurrence_id, domain, - # concept_value, visit_rank_order, visit_segment, priority, date_in_week, - # concept_value_mask, mlm_skip_value, age) - return other_events.unionByName( - categorical_measurement_events_question - ).unionByName( - categorical_measurement_events_answer - ) + return patient_events diff --git a/src/cehrbert_data/decorators/death_event_decorator.py b/src/cehrbert_data/decorators/death_event_decorator.py index 0fc7127..fc47961 100644 --- a/src/cehrbert_data/decorators/death_event_decorator.py +++ b/src/cehrbert_data/decorators/death_event_decorator.py @@ -43,12 +43,10 @@ def _decorate(self, patient_events: DataFrame): ) last_ve_record.cache() - last_ve_record.show() # set(['cohort_member_id', 'person_id', 'standard_concept_id', 'date', # 'visit_occurrence_id', 'domain', 'concept_value', 'visit_rank_order', # 'visit_segment', 'priority', 'date_in_week', 'concept_value_mask', # 'mlm_skip_value', 'age', 'visit_concept_id']) - artificial_visit_id = F.row_number().over( W.partitionBy(F.lit(0)).orderBy("person_id", "cohort_member_id") ) + F.col("max_visit_occurrence_id") @@ -59,7 +57,7 @@ def _decorate(self, patient_events: DataFrame): .withColumn("standard_concept_id", F.lit(DEATH_TOKEN)) .withColumn("domain", F.lit("death")) .withColumn("visit_rank_order", F.lit(1) + F.col("visit_rank_order")) - .withColumn("priority", DEATH_TOKEN_PRIORITY) + .withColumn("priority", F.lit(DEATH_TOKEN_PRIORITY)) .withColumn("event_group_id", F.lit(NA)) .drop("max_visit_occurrence_id") ) @@ -67,7 +65,7 @@ def _decorate(self, patient_events: DataFrame): vs_records = ( death_records .withColumn("standard_concept_id", F.lit(VS_TOKEN)) - .withColumn("priority", VS_TOKEN_PRIORITY) + .withColumn("priority", F.lit(VS_TOKEN_PRIORITY)) 
.withColumn("unit", F.lit(NA)) .withColumn("event_group_id", F.lit(NA)) ) @@ -75,7 +73,7 @@ def _decorate(self, patient_events: DataFrame): ve_records = ( death_records .withColumn("standard_concept_id", F.lit(VE_TOKEN)) - .withColumn("priority", VE_TOKEN_PRIORITY) + .withColumn("priority", F.lit(VE_TOKEN_PRIORITY)) .withColumn("unit", F.lit(NA)) .withColumn("event_group_id", F.lit(NA)) ) diff --git a/src/cehrbert_data/decorators/demographic_event_decorator.py b/src/cehrbert_data/decorators/demographic_event_decorator.py index 6c5db51..ab68b2e 100644 --- a/src/cehrbert_data/decorators/demographic_event_decorator.py +++ b/src/cehrbert_data/decorators/demographic_event_decorator.py @@ -21,8 +21,8 @@ def _decorate(self, patient_events: DataFrame): return patient_events # set(['cohort_member_id', 'person_id', 'standard_concept_id', 'date', - # 'visit_occurrence_id', 'domain', 'concept_value', 'visit_rank_order', - # 'visit_segment', 'priority', 'date_in_week', 'concept_value_mask', + # 'visit_occurrence_id', 'domain', 'value_as_number', 'value_as_concept', 'visit_rank_order', + # 'is_numeric_type', 'visit_segment', 'priority', 'date_in_week', 'concept_value_mask', # 'mlm_skip_value', 'age', 'visit_concept_id']) # Get the first token of the patient history @@ -39,7 +39,9 @@ def _decorate(self, patient_events: DataFrame): patient_first_token = ( patient_events.withColumn("token_order", first_token_udf) .withColumn("concept_value_mask", F.lit(0)) - .withColumn("concept_value", F.lit(0.0)) + .withColumn("number_as_value", F.lit(None).cast("float")) + .withColumn("concept_as_value", F.lit(None).cast("string")) + .withColumn("is_numeric_type", F.lit(0)) .withColumn("unit", F.lit(NA)) .withColumn("event_group_id", F.lit(NA)) .where("token_order = 1") diff --git a/src/cehrbert_data/decorators/patient_event_decorator_base.py b/src/cehrbert_data/decorators/patient_event_decorator_base.py index 77f4d5c..ed73866 100644 --- a/src/cehrbert_data/decorators/patient_event_decorator_base.py +++ b/src/cehrbert_data/decorators/patient_event_decorator_base.py @@ -37,7 +37,9 @@ def get_required_columns(cls) -> Set[str]: "datetime", "visit_occurrence_id", "domain", - "concept_value", + "concept_as_value", + "is_numeric_type", + "number_as_value", "visit_rank_order", "visit_segment", "priority", diff --git a/src/cehrbert_data/decorators/token_priority.py b/src/cehrbert_data/decorators/token_priority.py index bcf3af5..0c62f6f 100644 --- a/src/cehrbert_data/decorators/token_priority.py +++ b/src/cehrbert_data/decorators/token_priority.py @@ -8,6 +8,7 @@ ATT_TOKEN_PRIORITY = -3 VS_TOKEN_PRIORITY = -2 VISIT_TYPE_TOKEN_PRIORITY = -1 +FIRST_VISIT_HOUR_TOKEN_PRIORITY = -0.5 DEFAULT_PRIORITY = 0 DISCHARGE_TOKEN_PRIORITY = 100 DEATH_TOKEN_PRIORITY = 199 diff --git a/src/cehrbert_data/tools/ehrshot_to_omop.py b/src/cehrbert_data/tools/ehrshot_to_omop.py index 5579cb1..a27f216 100644 --- a/src/cehrbert_data/tools/ehrshot_to_omop.py +++ b/src/cehrbert_data/tools/ehrshot_to_omop.py @@ -98,7 +98,7 @@ def get_schema() -> t.StructType: t.StructField("code", t.StringType(), True), t.StructField("value", t.StringType(), True), t.StructField("unit", t.StringType(), True), - t.StructField("visit_id", t.LongType(), True), # Converted to IntegerType + t.StructField("visit_id", t.StringType(), True), # Converted to IntegerType t.StructField("omop_table", t.StringType(), True) ]) @@ -418,81 +418,177 @@ def convert_code_to_omop_concept( ).select(output_columns) -def generate_visit_id(data: DataFrame, time_interval: int = 12) -> DataFrame: 
+def generate_visit_id(data: DataFrame) -> DataFrame: """ - Generates unique `visit_id`s for each visit based on time intervals between events per patient. - - This function identifies distinct visits within a patient's event history by analyzing time gaps - between consecutive events. Events with gaps exceeding the specified `time_interval` (in hours) - are considered separate visits. A unique integer `visit_id` is then assigned to each visit - using a hash of the patient ID and visit order. - - Parameters - ---------- - data : DataFrame - A PySpark DataFrame containing at least the following columns: - - `patient_id`: Identifier for each patient. - - `start`: Start timestamp of the event. - - `end`: (Optional) End timestamp of the event. - - time_interval : int, optional, default=12 - The maximum time gap in hours between consecutive events within the same visit. If the time - difference between two events exceeds this interval, a new visit is assigned. - - Returns - ------- - DataFrame - The input DataFrame with an additional `visit_id` column, which is a unique integer identifier - for each visit, and each `patient_id`'s events are grouped according to visit. - """ - order_window = Window.partitionBy("patient_id").orderBy(f.col("start")) - - data = data.repartition(16).withColumn( - "patient_event_order", f.row_number().over(order_window) - ).withColumn( - "time", f.coalesce(f.col("end"), f.col("start")) + Generates unique `visit_id`s for each visit based on distinct patient event records. + + This function identifies records associated with actual visits (`visit_occurrence` table) and assigns + `visit_id`s to those records. For other event records without a `visit_id`, it attempts to link them to + existing visits based on overlapping date ranges. If no matching visit is found, it generates new `visit_id`s + for these orphan records and creates artificial visits. + + Parameters + ---------- + data : DataFrame + A PySpark DataFrame containing at least the following columns: + - `patient_id`: Identifier for each patient. + - `start`: Start timestamp of the event. + - `end`: (Optional) End timestamp of the event. + - `omop_table`: String specifying the type of event (e.g., "visit_occurrence" for real visits). + - `visit_id`: (Optional) Identifier for visits. May be missing in some records. + + Returns + ------- + DataFrame + A DataFrame with a `visit_id` assigned to each event record, including both real visits and artificial visits. + The returned DataFrame includes both the original records and any generated artificial visit records, + with each record grouped according to the identified visit. + + Steps + ----- + 1. **Identify Real Visits**: Filters out records from `visit_occurrence` and sets start and end dates for each visit. + 2. **Assign `visit_id`s to Other Records**: Attempts to link non-visit records (from other tables) to real visits + based on matching `patient_id` and date ranges. + 3. **Handle Orphan Records**: For records without a matching visit, assigns new `visit_id`s by grouping + records by patient and start date. + 4. **Create Artificial Visits**: Generates artificial visit records for orphan `visit_id`s. + 5. **Merge and Validate**: Combines the original records with artificial visits and validates the uniqueness of each `visit_id`. 
+ """ + data = data.repartition(16) + real_visits = data.where( + f.col("omop_table") == "visit_occurrence" ).withColumn( - "prev_time", f.coalesce(f.lag(f.col("time")).over(order_window), f.col("time")) + "visit_start_date", + f.col("start").cast(t.DateType()) ).withColumn( - "hour_diff", (f.unix_timestamp("start") - f.unix_timestamp("prev_time")) / 3600 + "visit_end_date", + f.coalesce(f.col("end").cast(t.DateType()), f.col("visit_start_date")) + ) + + # Getting the records that do not have a visit_id + domain_records = data.where( + f.col("omop_table") != "visit_occurrence" ).withColumn( - "is_gap", (f.col("hour_diff") > time_interval).cast(t.IntegerType()) - ).drop( - "time", "prev_time", "hour_diff" + "record_id", + f.row_number().over(Window.orderBy(f.monotonically_increasing_id())) ) - cumulative_window = Window.partitionBy("patient_id").orderBy("patient_event_order").rowsBetween( - Window.unboundedPreceding, - Window.currentRow + # This is important to have a deterministic behavior for generating record_id + domain_records.cache() + + # Invalidate visit_id if the record's time stamp falls outside the visit start/end + domain_records = domain_records.alias("domain").join( + real_visits.where("code == 'Visit/IP'").alias("in_visit"), + (f.col("domain.patient_id") == f.col("in_visit.patient_id")) & + (f.col("domain.visit_id") == f.col("in_visit.visit_id")), + "left_outer" + ).withColumn( + "new_visit_id", + f.coalesce( + f.when( + f.col("domain.start").between( + f.date_sub(f.col("in_visit.start"), 1), f.date_add(f.col("in_visit.end"), 1) + ), + f.col("domain.visit_id") + ).otherwise(f.lit(None).cast(t.LongType())), + f.col("domain.visit_id") + ) + ).select( + [ + f.col("domain." + field).alias(field) + for field in domain_records.schema.fieldNames() if not field.endswith("visit_id") + ] + [f.col("new_visit_id").alias("visit_id")] ) - data = data.withColumn( - "visit_order", - f.sum("is_gap").over(cumulative_window) - ).drop( - "is_gap" + # Join the DataFrames with aliasing + domain_records = domain_records.alias("domain").join( + real_visits.alias("visit"), + (f.col("domain.patient_id") == f.col("visit.patient_id")) & + (f.col("domain.start").cast(t.DateType()).between(f.col("visit.visit_start_date"), + f.col("visit.visit_end_date"))), + "left_outer" + ).withColumn( + "ranking", + f.row_number().over(Window.partitionBy("domain.record_id").orderBy(f.col("visit.visit_start_date").desc())) + ).where( + f.col("ranking") == 1 + ).select( + [f.col("domain." 
+ _).alias(_) for _ in domain_records.schema.fieldNames() if _ != "visit_id"] + + [f.coalesce(f.col("visit.visit_id"), f.col("domain.visit_id")).alias("visit_id")] ) - # We only allow the generated visit_ids associated with the visit_occurrence table - visit = data.where( - f.col("omop_table") == "visit_occurrence" - ).select("patient_id", "visit_order").distinct().withColumn( + max_visit_id_df = real_visits.select(f.max("visit_id").alias("max_visit_id")) + orphan_records = domain_records.where( + f.col("visit_id").isNull() + ).where( + f.col("omop_table") != "person" + ).crossJoin( + max_visit_id_df + ).withColumn( "new_visit_id", - f.abs( - f.hash(f.concat(f.col("patient_id").cast("string"), f.col("visit_order").cast("string"))) - ).cast("bigint") + f.dense_rank().over( + Window.orderBy(f.col("patient_id"), f.col("start").cast(t.DateType())) + ).cast(t.LongType()) + f.col("max_visit_id").cast(t.LongType()) + ).drop( + "visit_id" + ) + orphan_records.groupby("new_visit_id").agg( + f.countDistinct("patient_id").alias("pat_count") + ).select( + f.assert_true(f.col("pat_count") == 1) + ).collect() + + # Link the artificial visit_ids back to the domain_records + domain_records = domain_records.alias("domain").join( + orphan_records.alias("orphan").select( + f.col("orphan.record_id"), + f.col("orphan.new_visit_id"), + ), + f.col("domain.record_id") == f.col("orphan.record_id"), + "left_outer" + ).withColumn( + "update_visit_id", + f.coalesce(f.col("orphan.new_visit_id"), f.col("domain.visit_id")) + ).select( + [ + f.col("domain." + field).alias(field) + for field in domain_records.schema.fieldNames() if not field.endswith("visit_id") + ] + [f.col("update_visit_id").alias("visit_id")] + ).drop( + "record_id" ) + # Generate the artificial visits + artificial_visits = orphan_records.groupBy("new_visit_id", "patient_id").agg( + f.min("start").alias("start"), + f.max("end").alias("end") + ).withColumn( + "code", + f.lit(0) + ).withColumn( + "value", + f.lit(None).cast(t.StringType()) + ).withColumn( + "unit", + f.lit(None).cast(t.StringType()) + ).withColumn( + "omop_table", + f.lit("visit_occurrence") + ).withColumnRenamed( + "new_visit_id", "visit_id" + ).drop("record_id") + + # Drop visit_start_date and visit_end_date + real_visits = real_visits.drop("visit_start_date", "visit_end_date") + # Validate the uniqueness of visit_id - visit.groupby("new_visit_id").count().select(f.assert_true(f.col("count") == 1)) + artificial_visits.groupby("visit_id").count().select(f.assert_true(f.col("count") == 1)).collect() # Join the generated visit_id back to data - return data.join( - visit, - on=["patient_id", "visit_order"], - how="left_outer" - ).withColumn( - "visit_id", f.coalesce(f.col("new_visit_id"), f.col("visit_id")) - ).drop("visit_order", "patient_event_order", "new_visit_id") + return domain_records.unionByName( + real_visits + ).unionByName( + artificial_visits + ) def drop_duplicate_visits(data: DataFrame) -> DataFrame: @@ -549,7 +645,10 @@ def main(args): if args.refresh_ehrshot or not os.path.exists(ehr_shot_path): ehr_shot_data = spark.read.option("header", "true").schema(get_schema()).csv( args.ehr_shot_file - ) + ).withColumn( + "visit_id", + f.col("visit_id").cast(t.LongType()) + ).drop("_c0") # Add visit_id based on the time intervals between neighboring events ehr_shot_data = generate_visit_id( ehr_shot_data diff --git a/src/cehrbert_data/tools/extract_features.py b/src/cehrbert_data/tools/extract_features.py index d7264a7..d03c6c5 100644 --- 
a/src/cehrbert_data/tools/extract_features.py +++ b/src/cehrbert_data/tools/extract_features.py @@ -57,6 +57,11 @@ def create_feature_extraction_args(): '--bound_visit_end_date', action='store_true', ) + spark_args.add_argument( + "--include_inpatient_hour_token", + dest="include_inpatient_hour_token", + action="store_true", + ) return spark_args.parse_args() @@ -99,6 +104,7 @@ def main(args): with_drug_rollup=args.is_drug_roll_up_concept, include_concept_list=args.include_concept_list, refresh_measurement=args.refresh_measurement, + aggregate_by_hour=args.aggregate_by_hour, ) # Drop index_date because create_sequence_data_with_att does not expect this column @@ -108,45 +114,53 @@ def main(args): ).where(ehr_records["date"] <= cohort["index_date"]) ehr_records_temp_folder = get_temp_folder(args, "ehr_records") - ehr_records.write.mode("overwrite").parquet(ehr_records_temp_folder) + ehr_records.repartition("person_id").write.mode("overwrite").parquet(ehr_records_temp_folder) ehr_records = spark.read.parquet(ehr_records_temp_folder) visit_occurrence = spark.read.parquet(os.path.join(args.input_folder, "visit_occurrence")) + cohort_visit_occurrence = visit_occurrence.join( + cohort.select("person_id").distinct(), + "person_id" + ).withColumn( + "visit_end_date", + f.coalesce(f.col("visit_end_date"), f.col("visit_start_date")) + ).withColumn( + "visit_end_datetime", + f.coalesce( + f.col("visit_end_datetime"), + f.col("visit_end_date").cast(t.TimestampType()), + f.col("visit_start_datetime") + ) + ) # For each patient/index_date pair, we get the last record before the index_date # we get the corresponding visit_occurrence_id and index_date if args.bound_visit_end_date: - visit_occurrence_bound = ehr_records.withColumn( - "rn", - f.row_number().over(Window.partitionBy("person_id", "index_date").orderBy(f.desc("datetime"))) + cohort_visit_occurrence = cohort_visit_occurrence.withColumn( + "order", f.row_number().over(Window.orderBy(f.monotonically_increasing_id())) + ) + + visit_index_date = cohort_visit_occurrence.alias("visit").join( + cohort.alias("cohort"), + "person_id" ).where( - f.col("rn") == 1 + f.col("cohort.index_date").between(f.col("visit.visit_start_datetime"), f.col("visit.visit_end_datetime")) ).select( - "visit_occurrence_id", - "index_date", + f.col("visit.visit_occurrence_id").alias("visit_occurrence_id"), + f.col("cohort.index_date").alias("index_date"), ) + # Bound the visit_end_date and visit_end_datetime - visit_occurrence = visit_occurrence.join( - visit_occurrence_bound, + cohort_visit_occurrence = cohort_visit_occurrence.join( + visit_index_date, "visit_occurrence_id", "left_outer", ).withColumn( "visit_end_date", - f.coalesce(f.col("visit_end_date"), f.col("visit_start_date")) - ).withColumn( - "visit_end_datetime", - f.coalesce( - f.col("visit_end_datetime"), - f.col("visit_end_date").cast(t.TimestampType()), - f.col("visit_start_datetime") - ) - ).withColumn( - "visit_end_date", - f.least(f.col("visit_end_date"), f.col("index_date").cast(t.DateType())) + f.coalesce(f.col("index_date").cast(t.DateType()), f.col("visit_end_date")) ).withColumn( "visit_end_datetime", - f.least(f.col("visit_end_datetime"), f.col("index_date")) - ) - + f.coalesce(f.col("index_date"), f.col("visit_end_datetime")) + ).orderBy(f.col("order")).drop("order") birthdate_udf = f.coalesce( "birth_datetime", @@ -162,7 +176,7 @@ def main(args): age_udf = f.ceil(f.months_between(f.col("visit_start_date"), f.col("birth_datetime")) / f.lit(12)) visit_occurrence_person = ( - visit_occurrence + 
cohort_visit_occurrence .join(patient_demographic, "person_id") .withColumn("age", age_udf) .drop("birth_datetime") @@ -177,10 +191,11 @@ def main(args): patient_demographic=( patient_demographic if args.gpt_patient_sequence else None ), - att_type=AttType.DAY, - inpatient_att_type=AttType.DAY, + att_type=AttType(args.att_type), + inpatient_att_type=AttType(args.inpatient_att_type), exclude_demographic=args.exclude_demographic, - use_age_group=args.use_age_group + use_age_group=args.use_age_group, + include_inpatient_hour_token=args.include_inpatient_hour_token ) elif args.is_feature_concept_frequency: ehr_records = create_concept_frequency_data( diff --git a/src/cehrbert_data/utils/spark_parse_args.py b/src/cehrbert_data/utils/spark_parse_args.py index 0174466..c1e2fc8 100644 --- a/src/cehrbert_data/utils/spark_parse_args.py +++ b/src/cehrbert_data/utils/spark_parse_args.py @@ -335,6 +335,12 @@ def create_spark_args(parse: bool = True): action="store_true", help="Apply the filter to remove low-frequency concepts", ) + parser.add_argument( + "--aggregate_by_hour", + dest="aggregate_by_hour", + action="store_true", + help="Apply the aggregation on numeric labs by hour", + ) parser.add_argument( "--allow_measurement_only", dest="allow_measurement_only", diff --git a/src/cehrbert_data/utils/spark_utils.py b/src/cehrbert_data/utils/spark_utils.py index 77791ed..288e0b6 100644 --- a/src/cehrbert_data/utils/spark_utils.py +++ b/src/cehrbert_data/utils/spark_utils.py @@ -68,7 +68,7 @@ "measurement" ) ], - "death_date": [("death_concept_id", "death_date", "death_datetime", "death")], + "death_date": [("cause_concept_id", "death_date", "death_datetime", "death")], "visit_concept_id": [ ("visit_concept_id", "visit_start_date", "visit"), ("discharged_to_concept_id", "visit_end_date", "visit"), @@ -178,7 +178,8 @@ def join_domain_tables(domain_tables: List[DataFrame]) -> DataFrame: filtered_domain_table["visit_occurrence_id"], F.lit(table_domain_field).alias("domain"), F.lit(None).cast("string").alias("event_group_id"), - F.lit(0.0).alias("concept_value"), + F.lit(None).cast("float").alias("number_as_value"), + F.lit(None).cast("string").alias("concept_as_value"), F.col("unit") if domain_has_unit(filtered_domain_table) else F.lit(NA).alias("unit"), ).distinct() @@ -716,7 +717,6 @@ def create_sequence_data_with_att( "concept_order", "priority", "datetime", - "event_group_id", "standard_concept_id", ) ) @@ -737,14 +737,15 @@ def create_sequence_data_with_att( "age", "visit_rank_order", "concept_value_mask", - "concept_value", + "number_as_value", + "concept_as_value", + "is_numeric_type", "mlm_skip_value", "visit_concept_id", "visit_concept_order", "concept_order", "priority", "unit", - "event_group_id", ] output_columns = [ "cohort_member_id", @@ -758,7 +759,9 @@ def create_sequence_data_with_att( "num_of_visits", "num_of_concepts", "concept_value_masks", - "concept_values", + "number_as_values", + "concept_as_values", + "is_numeric_types", "mlm_skip_values", "priorities", "visit_concept_ids", @@ -766,7 +769,6 @@ def create_sequence_data_with_att( "concept_orders", "record_ranks", "units", - "event_group_ids", ] patient_grouped_events = ( @@ -793,13 +795,13 @@ def create_sequence_data_with_att( .withColumn("concept_orders", F.col("data_for_sorting.concept_order")) .withColumn("priorities", F.col("data_for_sorting.priority")) .withColumn("concept_value_masks", F.col("data_for_sorting.concept_value_mask")) - .withColumn("concept_values", F.col("data_for_sorting.concept_value")) + 
.withColumn("number_as_values", F.col("data_for_sorting.number_as_value")) + .withColumn("concept_as_values", F.col("data_for_sorting.concept_as_value")) + .withColumn("is_numeric_types", F.col("data_for_sorting.is_numeric_type")) .withColumn("mlm_skip_values", F.col("data_for_sorting.mlm_skip_value")) .withColumn("visit_concept_ids", F.col("data_for_sorting.visit_concept_id")) .withColumn("units", F.col("data_for_sorting.unit")) - .withColumn("event_group_ids", F.col("data_for_sorting.event_group_id")) ) - return patient_grouped_events.select(output_columns) @@ -831,14 +833,15 @@ def create_concept_frequency_data(patient_event, date_filter=None): def extract_ehr_records( - spark, - input_folder, - domain_table_list, - include_visit_type=False, - with_diagnosis_rollup=False, - with_drug_rollup=True, - include_concept_list=False, - refresh_measurement=False + spark: SparkSession, + input_folder: str, + domain_table_list: List[str], + include_visit_type: bool = False, + with_diagnosis_rollup: bool = False, + with_drug_rollup: bool = True, + include_concept_list: bool = False, + refresh_measurement: bool = False, + aggregate_by_hour: bool = False, ): """ Extract the ehr records for domain_table_list from input_folder. @@ -851,6 +854,7 @@ def extract_ehr_records( :param with_drug_rollup: whether ot not to roll up the drug concepts to the parent levels :param include_concept_list: :param refresh_measurement: + :param aggregate_by_hour: :return: """ domain_tables = [] @@ -872,7 +876,6 @@ def extract_ehr_records( qualified_concepts = preprocess_domain_table(spark, input_folder, QUALIFIED_CONCEPT_LIST_PATH).select( "standard_concept_id" ) - patient_ehr_records = patient_ehr_records.join(qualified_concepts, "standard_concept_id") # Process the measurement table if exists @@ -880,7 +883,8 @@ def extract_ehr_records( processed_measurement = get_measurement_table( spark, input_folder, - refresh=refresh_measurement + refresh=refresh_measurement, + aggregate_by_hour=aggregate_by_hour, ) if patient_ehr_records: # Union all measurement records together with other domain records @@ -914,7 +918,8 @@ def extract_ehr_records( patient_ehr_records["visit_occurrence_id"], patient_ehr_records["domain"], patient_ehr_records["unit"], - patient_ehr_records["concept_value"], + patient_ehr_records["number_as_value"], + patient_ehr_records["concept_as_value"], patient_ehr_records["event_group_id"], visit_occurrence["visit_concept_id"], patient_ehr_records["age"], @@ -1068,7 +1073,7 @@ def create_hierarchical_sequence_data( F.coalesce(patient_events["standard_concept_id"], F.lit(UNKNOWN_CONCEPT)).alias("standard_concept_id"), F.coalesce(patient_events["date"], visit_occurrence["visit_start_date"]).alias("date"), F.coalesce(patient_events["domain"], F.lit("unknown")).alias("domain"), - F.coalesce(patient_events["concept_value"], F.lit(-1.0)).alias("concept_value"), + F.coalesce(patient_events["number_as_value"], F.lit(-1.0)).alias("number_as_value"), ] # Convert standard_concept_id to string type, this is needed for the tokenization @@ -1125,7 +1130,7 @@ def create_hierarchical_sequence_data( .withColumn("visit_concept_order", F.lit(0)) .withColumn("date", F.col("visit_start_date")) .withColumn("concept_value_mask", F.lit(0)) - .withColumn("concept_value", F.lit(-1.0)) + .withColumn("number_as_value", F.lit(-1.0)) .withColumn("mlm_skip", F.lit(1)) .withColumn("condition_mask", F.lit(0)) ) @@ -1137,7 +1142,7 @@ def create_hierarchical_sequence_data( "date_in_week", "age", "concept_value_mask", - "concept_value", + 
"number_as_value", "mlm_skip", "condition_mask", ] @@ -1365,7 +1370,8 @@ def clean_up_unit(dataframe: DataFrame) -> DataFrame: def get_measurement_table( spark: SparkSession, input_folder: str, - refresh: bool = False + refresh: bool = False, + aggregate_by_hour: bool = False, ) -> DataFrame: """ A helper function to process and create the measurement table @@ -1373,6 +1379,7 @@ def get_measurement_table( :param spark: :param input_folder: :param refresh: + :param aggregate_by_hour: :return: """ @@ -1396,7 +1403,7 @@ def get_measurement_table( processed_measurement = preprocess_domain_table(spark, input_folder, PROCESSED_MEASUREMENT) else: processed_measurement = process_measurement( - spark, measurement, required_measurement, measurement_stats, concept + spark, measurement, required_measurement, measurement_stats, concept, aggregate_by_hour ) processed_measurement.write.mode("overwrite").parquet(os.path.join(input_folder, PROCESSED_MEASUREMENT)) @@ -1408,7 +1415,8 @@ def process_measurement( measurement: DataFrame, required_measurement: DataFrame, measurement_stats: DataFrame, - concept: DataFrame + concept: DataFrame, + aggregate_by_hour: bool = False ): """ Preprocess the measurement table and only include the measurements whose measurement_concept_ids are specified @@ -1436,10 +1444,11 @@ def process_measurement( m.measurement_concept_id AS standard_concept_id, CAST(m.measurement_date AS DATE) AS date, CAST(COALESCE(m.measurement_datetime, m.measurement_date) AS TIMESTAMP) AS datetime, - m.visit_occurrence_id, + m.visit_occurrence_id AS visit_occurrence_id, 'measurement' AS domain, CAST(NULL AS STRING) AS event_group_id, - m.value_as_number AS concept_value, + m.value_as_number AS number_as_value, + CAST(NULL AS STRING) AS concept_as_value, c.concept_code AS unit FROM measurement AS m JOIN measurement_unit_stats AS s @@ -1451,22 +1460,37 @@ def process_measurement( AND m.value_as_number BETWEEN s.lower_bound AND s.upper_bound """ ) - numeric_lab = clean_up_unit(numeric_lab) + if aggregate_by_hour: + numeric_lab = numeric_lab.withColumn("lab_hour", F.hour("datetime")) + numeric_lab = numeric_lab.groupby( + "person_id", "visit_occurrence_id", "standard_concept_id", "unit", "date", "lab_hour" + ).agg( + F.min("datetime").alias("datetime"), + F.avg("number_as_value").alias("number_as_value"), + ).withColumn( + "domain", F.lit("measurement").cast("string") + ).withColumn( + "concept_as_value", F.lit(None).cast("string") + ).withColumn( + "event_group_id", F.lit(None).cast("string") + ).drop("lab_hour") + numeric_lab = clean_up_unit(numeric_lab) # For categorical measurements in required_measurement, we concatenate measurement_concept_id # with value_as_concept_id to construct a new standard_concept_id categorical_lab = spark.sql( """ SELECT m.person_id, - CONCAT(CAST(measurement_concept_id AS STRING), '-', CAST(COALESCE(value_as_concept_id, 0) AS STRING)) AS standard_concept_id, + measurement_concept_id AS standard_concept_id, CAST(m.measurement_date AS DATE) AS date, CAST(COALESCE(m.measurement_datetime, m.measurement_date) AS TIMESTAMP) AS datetime, - m.visit_occurrence_id, + m.visit_occurrence_id AS visit_occurrence_id, 'categorical_measurement' AS domain, CONCAT('mea-', CAST(m.measurement_id AS STRING)) AS event_group_id, - 0.0 AS concept_value, - 'N/A' AS unit, + CAST(NULL AS FLOAT) AS number_as_value, + CAST(COALESCE(value_as_concept_id, 0) AS STRING) AS concept_as_value, + 'N/A' AS unit FROM measurement AS m WHERE EXISTS ( SELECT @@ -1477,7 +1501,7 @@ def process_measurement( ) """ ) 
- return numeric_lab.unionAll(categorical_lab) + return numeric_lab.unionByName(categorical_lab) def get_mlm_skip_domains(spark, input_folder, mlm_skip_table_list): diff --git a/tests/integration_tests/test_generate_training_data.py b/tests/integration_tests/test_generate_training_data.py index c9ba102..81970c1 100644 --- a/tests/integration_tests/test_generate_training_data.py +++ b/tests/integration_tests/test_generate_training_data.py @@ -23,6 +23,7 @@ def test_run_pyspark_app(self): gpt_patient_sequence=True, apply_age_filter=True, include_death=False, + include_inpatient_hour_token=True, att_type=AttType.DAY, inpatient_att_type=AttType.DAY, ) diff --git a/tests/unit_tests/test_ehrshot_to_omop.py b/tests/unit_tests/test_ehrshot_to_omop.py index acf0354..6de39b1 100644 --- a/tests/unit_tests/test_ehrshot_to_omop.py +++ b/tests/unit_tests/test_ehrshot_to_omop.py @@ -21,6 +21,7 @@ class EHRShotUnitTest(unittest.TestCase): def setUpClass(cls): # Initialize the Spark session for testing cls.spark = SparkSession.builder.appName("ehr_shot").getOrCreate() + cls.spark.conf.set("spark.sql.analyzer.failAmbiguousSelfJoin", False) @classmethod def tearDownClass(cls): @@ -56,30 +57,34 @@ def test_generate_visit_id(self): StructField("end", TimestampType(), True), StructField("visit_id", IntegerType(), True), StructField("omop_table", StringType(), True), + StructField("code", StringType(), True), + StructField("unit", StringType(), True), + StructField("value", StringType(), True), ]) # Sample data with multiple events for each patient and different time gaps data = [ - (1, datetime(2023, 1, 1, 8), datetime(2023, 1, 1, 9), None, "visit_occurrence"), - (1, datetime(2023, 1, 1, 20), datetime(2023, 1, 1, 20), None, "condition_occurrence"), # 11-hour gap (merged visit) - (1, datetime(2023, 1, 2, 20), datetime(2023, 1, 2, 20), None, "visit_occurrence"), # another visit - (2, datetime(2023, 1, 1, 8), datetime(2023, 1, 1, 9), None, "visit_occurrence"), - (2, datetime(2023, 1, 1, 10), datetime(2023, 1, 1, 11), None, "condition_occurrence"), # same visit for patient 2 - (3, datetime(2023, 1, 1, 8), datetime(2023, 1, 1, 9), 1000, "visit_occurrence"), - (4, datetime(2023, 1, 1, 8), datetime(2023, 1, 1, 9), None, "condition_occurrence") + (1, datetime(2023, 1, 1, 8), datetime(2023, 1, 1, 9), 1, "visit_occurrence", None, None, None), + (1, datetime(2023, 1, 1, 20), datetime(2023, 1, 1, 20), None, "condition_occurrence", None, None, None), # 11-hour gap (merged visit) + (1, datetime(2023, 1, 2, 20), datetime(2023, 1, 2, 20), 2, "visit_occurrence", None, None, None), # another visit + (2, datetime(2023, 1, 1, 8), datetime(2023, 1, 1, 9), 3, "visit_occurrence", None, None, None), + (2, datetime(2023, 1, 1, 10), datetime(2023, 1, 1, 11), None, "condition_occurrence", None, None, None), # same visit for patient 2 + (3, datetime(2023, 1, 1, 8), datetime(2023, 1, 1, 9), 1000, "visit_occurrence", None, None, None), + (4, datetime(2023, 1, 1, 8), datetime(2023, 1, 1, 9), None, "condition_occurrence", None, None, None) ] # Create DataFrame data = self.spark.createDataFrame(data, schema=schema) # Run the function to generate visit IDs - result_df = generate_visit_id(data, time_interval=12) + result_df = generate_visit_id(data) + result_df.show() # Validate the number of visits - self.assertEqual(4, result_df.select("visit_id").where(f.col("visit_id").isNotNull()).distinct().count()) - self.assertEqual(7, result_df.count()) + self.assertEqual(5, 
result_df.select("visit_id").where(f.col("visit_id").isNotNull()).distinct().count()) + self.assertEqual(8, result_df.count()) # Check that visit_id was generated as an integer (bigint) - self.assertEqual( - result_df.schema["visit_id"].dataType.simpleString(), "bigint", + self.assertIn( + result_df.schema["visit_id"].dataType.simpleString(), ["int", "bigint"], "visit_id should be of type bigint" ) @@ -103,7 +108,7 @@ def test_generate_visit_id(self): patient_4_visits = result_df.filter(f.col("patient_id") == 4).select("visit_id").collect()[0].visit_id self.assertEqual( - None, patient_4_visits, "Patient 4 should have a null visit_id." + 1001, patient_4_visits, "Patient 4 should have one generated visit_id." ) def test_drop_duplicate_visits(self):
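Reviewer note (illustrative sketch, not part of the diff): the new generate_visit_id links orphan domain records to artificial visits by ranking them per patient and start date and offsetting the rank by the largest real visit_id, which is why patient 4 in the unit test above ends up with visit_id 1001. A minimal, self-contained sketch of that step, using hypothetical toy rows and a local Spark session:

from pyspark.sql import SparkSession, functions as f, types as t
from pyspark.sql.window import Window

spark = SparkSession.builder.appName("orphan_visit_id_sketch").getOrCreate()

# Orphan events that could not be matched to any real visit (toy data)
orphan_records = spark.createDataFrame(
    [(4, "2023-01-01 08:00:00"), (4, "2023-01-01 09:30:00"), (5, "2023-02-01 10:00:00")],
    ["patient_id", "start"],
).withColumn("start", f.to_timestamp("start"))

# Largest visit_id among the real visits (1000 in the unit test above)
max_visit_id_df = spark.createDataFrame([(1000,)], ["max_visit_id"])

orphan_records = orphan_records.crossJoin(max_visit_id_df).withColumn(
    "new_visit_id",
    f.dense_rank().over(
        Window.orderBy(f.col("patient_id"), f.col("start").cast(t.DateType()))
    ).cast(t.LongType()) + f.col("max_visit_id").cast(t.LongType()),
)
orphan_records.show()
# Patient 4's two same-day events share new_visit_id 1001; patient 5's event gets 1002.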
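Reviewer note (illustrative sketch, not part of the diff): the first inpatient hour token added in artificial_time_token_decorator.py measures how many whole hours the visit start is past midnight and emits an i-H{n} token between the visit type token and the first medical event. A minimal sketch of the hour computation, with hypothetical toy rows:

from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.appName("first_hour_token_sketch").getOrCreate()

visits = spark.createDataFrame(
    [(1, "2023-01-01 08:30:00"), (2, "2023-01-02 00:00:00")],
    ["visit_occurrence_id", "visit_start_datetime"],
).withColumn("visit_start_datetime", F.to_timestamp("visit_start_datetime"))

# Whole hours between midnight of the visit start date and the visit start datetime
first_inpatient_hour_delta = F.floor(
    (F.unix_timestamp("visit_start_datetime")
     - F.unix_timestamp(F.col("visit_start_datetime").cast("date"))) / 3600
)

first_hour_tokens = (
    visits.withColumn("hour_delta", first_inpatient_hour_delta)
    .where(F.col("hour_delta") > 0)  # no token when the visit starts exactly at midnight
    .withColumn("standard_concept_id", F.concat(F.lit("i-H"), F.col("hour_delta")))
)
first_hour_tokens.show()
# Visit 1 yields the token i-H8; visit 2 yields no first-hour token.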
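Reviewer note (illustrative sketch, not part of the diff): with aggregate_by_hour enabled, process_measurement averages numeric lab values that fall within the same hour for the same person, visit, concept, and unit, keeping the earliest timestamp of that hour. A minimal sketch with hypothetical toy rows:

from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.appName("aggregate_by_hour_sketch").getOrCreate()

numeric_lab = spark.createDataFrame(
    [
        (1, 10, 3004249, "mmHg", "2023-01-01", "2023-01-01 08:05:00", 120.0),
        (1, 10, 3004249, "mmHg", "2023-01-01", "2023-01-01 08:45:00", 130.0),
        (1, 10, 3004249, "mmHg", "2023-01-01", "2023-01-01 09:10:00", 118.0),
    ],
    ["person_id", "visit_occurrence_id", "standard_concept_id", "unit", "date", "datetime", "number_as_value"],
).withColumn("date", F.col("date").cast("date")).withColumn("datetime", F.to_timestamp("datetime"))

aggregated = (
    numeric_lab.withColumn("lab_hour", F.hour("datetime"))
    .groupby("person_id", "visit_occurrence_id", "standard_concept_id", "unit", "date", "lab_hour")
    .agg(
        F.min("datetime").alias("datetime"),                 # earliest reading in the hour
        F.avg("number_as_value").alias("number_as_value"),   # hourly average of the lab value
    )
    .drop("lab_hour")
)
aggregated.show()
# The two 08:xx readings collapse into a single row with number_as_value = 125.0;
# the 09:10 reading remains its own row.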