From 5da69849dd674823e63d3d34a7a1951f8434a67c Mon Sep 17 00:00:00 2001 From: Chao Pang Date: Mon, 13 Jan 2025 13:38:35 -0500 Subject: [PATCH] removed a query that does not affect the results --- src/cehrbert_data/tools/ehrshot_to_omop.py | 36 +++------------------- 1 file changed, 5 insertions(+), 31 deletions(-) diff --git a/src/cehrbert_data/tools/ehrshot_to_omop.py b/src/cehrbert_data/tools/ehrshot_to_omop.py index 0023380..2cf8371 100644 --- a/src/cehrbert_data/tools/ehrshot_to_omop.py +++ b/src/cehrbert_data/tools/ehrshot_to_omop.py @@ -493,39 +493,11 @@ def generate_visit_id( domain_records.write.mode("overwrite").parquet(temp_domain_records_folder) domain_records = spark.read.parquet(temp_domain_records_folder) - # Invalidate visit_id if the record's time stamp falls outside the visit start/end - domain_records = domain_records.alias("domain").join( - real_visits.where(f.col("code").isin(['Visit/IP', 'Visit/ERIP'])).alias("in_visit"), - (f.col("domain.patient_id") == f.col("in_visit.patient_id")) & - (f.col("domain.visit_id") == f.col("in_visit.visit_id")), - "left_outer" - ).withColumn( - "new_visit_id", - f.coalesce( - f.when( - f.col("domain.start").between( - f.date_sub(f.col("in_visit.start"), day_cutoff), - f.date_add(f.col("in_visit.end"), day_cutoff) - ), - f.col("domain.visit_id") - ).otherwise(f.lit(None).cast(t.LongType())), - f.col("domain.visit_id") - ) - ).select( - [ - f.col("domain." + field).alias(field) - for field in domain_records.schema.fieldNames() if not field.endswith("visit_id") - ] + [f.col("new_visit_id").alias("visit_id")] - ) - # Join the records to the nearest visits if they occur within the visit span with aliasing domain_records = domain_records.alias("domain").join( - real_visits.alias("visit"), + real_visits.where(f.col("code").isin(['Visit/IP', 'Visit/ERIP'])).alias("visit"), (f.col("domain.patient_id") == f.col("visit.patient_id")) & - (f.col("domain.start").between( - f.col("visit.visit_start_datetime"), - f.col("visit.visit_end_datetime")) - ), + (f.col("domain.start").between(f.col("visit.start"), f.col("visit.end"))), "left_outer" ).withColumn( "ranking", @@ -832,7 +804,9 @@ def main(args): ).where(f.col("omop_table") == "visit_occurrence") # We don't use the end column to get the max end because some end datetime could be years apart from the start date outpatient_visit_start_end = ehr_shot_data.join(outpatient_visits.select("visit_id"), "visit_id").where( - f.col("omop_table").isin(["condition_occurrence", "procedure_occurrence", "drug_exposure", "measurement"]) + f.col("omop_table").isin( + ["condition_occurrence", "procedure_occurrence", "drug_exposure", "measurement", "observation", "death"] + ) ).groupby("visit_id").agg(f.min("start").alias("start"), f.max("start").alias("end")).withColumn( "hour_diff", (f.unix_timestamp("end") - f.unix_timestamp("start")) / 3600 ).withColumn(