diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 47a484d..4e1d6e1 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -23,6 +23,11 @@ jobs: uses: actions/setup-python@v3 with: python-version: "3.10" + - name: Set up Java 11 + uses: actions/setup-java@v3 + with: + java-version: "11" # specify the Java version here + distribution: "temurin" # or use 'adopt' or 'zulu', depending on your preference - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/src/cehrbert_data/apps/generate_training_data.py b/src/cehrbert_data/apps/generate_training_data.py index 9b05247..85a79fc 100644 --- a/src/cehrbert_data/apps/generate_training_data.py +++ b/src/cehrbert_data/apps/generate_training_data.py @@ -173,6 +173,8 @@ def main( exclude_demographic=exclude_demographic, use_age_group=use_age_group, include_inpatient_hour_token=include_inpatient_hour_token, + spark=spark, + persistence_folder=output_folder, ) else: sequence_data = create_sequence_data( diff --git a/src/cehrbert_data/decorators/artificial_time_token_decorator.py b/src/cehrbert_data/decorators/artificial_time_token_decorator.py index a45d160..727abf6 100644 --- a/src/cehrbert_data/decorators/artificial_time_token_decorator.py +++ b/src/cehrbert_data/decorators/artificial_time_token_decorator.py @@ -1,4 +1,6 @@ -from pyspark.sql import DataFrame, functions as F, types as T, Window as W +import os.path + +from pyspark.sql import SparkSession, DataFrame, functions as F, types as T, Window as W from ..const.common import NA from ..const.artificial_tokens import VS_TOKEN, VE_TOKEN @@ -26,6 +28,8 @@ def __init__( att_type: AttType, inpatient_att_type: AttType, include_inpatient_hour_token: bool = False, + spark: SparkSession = None, + persistence_folder: str = None, ): self._visit_occurrence = visit_occurrence self._include_visit_type = include_visit_type @@ -33,6 +37,10 @@ def __init__( self._att_type = att_type self._inpatient_att_type = inpatient_att_type self._include_inpatient_hour_token = include_inpatient_hour_token + super().__init__(spark=spark, persistence_folder=persistence_folder) + + def get_name(self): + return "att_events" def _decorate(self, patient_events: DataFrame): if self._att_type == AttType.NONE: @@ -53,9 +61,21 @@ def _decorate(self, patient_events: DataFrame): F.max("concept_order").alias("max_concept_order"), ) + # The visit records are joined to the cohort members (there could be multiple entries for the same patient) + # if multiple entries are present, we duplicate the visit records for those. If the visit_occurrence dataframe + # contains visits for each cohort member, then we need to add cohort_member_id to the joined expression as well. 
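Editor's note: the spark and persistence_folder arguments threaded through these decorators feed a checkpoint-style helper (try_persist_data, added to PatientEventDecorator further down in this diff) that replaces the earlier .cache() calls: intermediate results are written to parquet and read back, so downstream stages start from a short plan instead of the full lineage accumulated by the chained decorators. A minimal sketch of that idea, assuming a local SparkSession and a scratch folder (names are illustrative, not part of the patch):

import os
from pyspark.sql import DataFrame, SparkSession

def persist_and_reload(spark: SparkSession, data: DataFrame, folder: str) -> DataFrame:
    # Materialize the dataframe on disk, then hand back a fresh reader over the files,
    # which truncates the query plan built up so far.
    data.write.mode("overwrite").parquet(folder)
    return spark.read.parquet(folder)

spark = SparkSession.builder.master("local").appName("sketch").getOrCreate()
df = spark.range(10).withColumnRenamed("id", "person_id")
df = persist_and_reload(spark, df, os.path.join("/tmp", "att_events", "visit_occurrence_temp"))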
+ if "cohort_member_id" in self._visit_occurrence.columns: + joined_expr = ["person_id", "cohort_member_id"] + else: + joined_expr = ["person_id"] + visit_occurrence = ( - self._visit_occurrence.select( + self._visit_occurrence.join( + cohort_member_person_pair, + joined_expr + ).select( "person_id", + "cohort_member_id", F.col("visit_start_date").cast(T.DateType()).alias("date"), F.col("visit_start_date").cast(T.DateType()).alias("visit_start_date"), F.col("visit_start_datetime").cast(T.TimestampType()).alias("visit_start_datetime"), @@ -71,8 +91,10 @@ def _decorate(self, patient_events: DataFrame): "age", "discharged_to_concept_id", ) - .join(valid_visit_ids, "visit_occurrence_id") - .join(cohort_member_person_pair, ["person_id", "cohort_member_id"]) + .join( + valid_visit_ids, + ["visit_occurrence_id", "cohort_member_id"] + ) ) # We assume outpatient visits end on the same day, therefore we start visit_end_date to visit_start_date due @@ -89,7 +111,10 @@ def _decorate(self, patient_events: DataFrame): visit_occurrence = visit_occurrence.withColumn("date_in_week", weeks_since_epoch_udf) # Cache visit for faster processing - visit_occurrence.cache() + visit_occurrence = self.try_persist_data( + visit_occurrence, + os.path.join(self.get_name(), "visit_occurrence_temp"), + ) visits = visit_occurrence.drop("discharged_to_concept_id") @@ -180,6 +205,12 @@ def _decorate(self, patient_events: DataFrame): artificial_tokens = artificial_tokens.drop("visit_end_date") + # Try persisting artificial events + artificial_tokens = self.try_persist_data( + artificial_tokens, + os.path.join(self.get_name(), "artificial_tokens"), + ) + # Retrieving the events that are ONLY linked to inpatient visits inpatient_visits = ( visit_occurrence @@ -235,8 +266,10 @@ def _decorate(self, patient_events: DataFrame): # Add discharge events to the inpatient visits inpatient_events = inpatient_events.unionByName(discharge_events) - # Try caching the inpatient events - inpatient_events.cache() + # Try persisting the inpatient events for fasting processing + inpatient_events = self.try_persist_data( + inpatient_events, os.path.join(self.get_name(), "inpatient_events") + ) # Get the prev days_since_epoch inpatient_prev_date_udf = F.lag("date").over( @@ -285,11 +318,11 @@ def _decorate(self, patient_events: DataFrame): # Create ATT tokens within the inpatient visits inpatient_att_events = ( inpatient_events.withColumn( - "time_stamp_hour", F.hour("datetime") + "time_stamp_hour", F.hour("datetime") ).withColumn( "is_span_boundary", F.row_number().over( - W.partitionBy("cohort_member_id", "visit_occurrence_id", "concept_order") + W.partitionBy("cohort_member_id", "visit_occurrence_id") .orderBy("priority", "date", "time_stamp_hour") ), ) @@ -343,6 +376,11 @@ def _decorate(self, patient_events: DataFrame): .drop("prev_date", "time_delta", "is_span_boundary") ) + # Try persisting the inpatient att events + inpatient_att_events = self.try_persist_data( + inpatient_att_events, os.path.join(self.get_name(), "inpatient_att_events") + ) + self.validate(inpatient_events) self.validate(inpatient_att_events) @@ -352,6 +390,10 @@ def _decorate(self, patient_events: DataFrame): ["visit_occurrence_id", "cohort_member_id"], how="left_anti", ) + # Try persisting the other events + other_events = self.try_persist_data( + other_events, os.path.join(self.get_name(), "other_events") + ) patient_events = inpatient_events.unionByName(inpatient_att_events).unionByName(other_events) diff --git 
a/src/cehrbert_data/decorators/clinical_event_decorator.py b/src/cehrbert_data/decorators/clinical_event_decorator.py index 69f8c8e..bb4cec0 100644 --- a/src/cehrbert_data/decorators/clinical_event_decorator.py +++ b/src/cehrbert_data/decorators/clinical_event_decorator.py @@ -1,10 +1,10 @@ +import os.path + from ..const.common import ( MEASUREMENT, CATEGORICAL_MEASUREMENT, - MEASUREMENT_QUESTION_PREFIX, - MEASUREMENT_ANSWER_PREFIX ) -from pyspark.sql import DataFrame, functions as F, Window as W, types as T +from pyspark.sql import SparkSession, DataFrame, functions as F, Window as W, types as T from .patient_event_decorator_base import PatientEventDecorator from .token_priority import DEFAULT_PRIORITY @@ -17,8 +17,12 @@ class ClinicalEventDecorator(PatientEventDecorator): # 'concept_value_masks', 'value_as_numbers', 'value_as_concepts', 'mlm_skip_values', # 'visit_concept_ids', "units" # ] - def __init__(self, visit_occurrence): + def __init__(self, visit_occurrence, spark: SparkSession = None, persistence_folder: str = None): self._visit_occurrence = visit_occurrence + super().__init__(spark=spark, persistence_folder=persistence_folder) + + def get_name(self): + return "clinical_events" def _decorate(self, patient_events: DataFrame): """ @@ -43,10 +47,18 @@ def _decorate(self, patient_events: DataFrame): visit_segment_udf = F.col("visit_rank_order") % F.lit(2) + 1 # The visit records are joined to the cohort members (there could be multiple entries for the same patient) - # if multiple entries are present, we duplicate the visit records for those. + # if multiple entries are present, we duplicate the visit records for those. If the visit_occurrence dataframe + # contains visits for each cohort member, then we need to add cohort_member_id to the joined expression as well. 
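Editor's note: the comment above (mirrored in the ATT decorator) is why the join keys are chosen conditionally in the code that follows: once visit_occurrence has been expanded per cohort member, joining on visit_occurrence_id alone pairs each visit row with every cohort member's copy of that visit. A hedged toy illustration (all values made up):

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local").getOrCreate()
# The same visit duplicated for two cohort windows of the same patient
visits = spark.createDataFrame(
    [(1, 10, 100), (1, 10, 200)],
    ["person_id", "visit_occurrence_id", "cohort_member_id"],
)
valid_visit_ids = spark.createDataFrame(
    [(10, 100), (10, 200)],
    ["visit_occurrence_id", "cohort_member_id"],
)
join_keys = (
    ["visit_occurrence_id", "cohort_member_id"]
    if "cohort_member_id" in visits.columns
    else ["visit_occurrence_id"]
)
visits.join(valid_visit_ids, join_keys).count()                 # 2 rows, one per cohort member
visits.join(valid_visit_ids, "visit_occurrence_id").count()     # 4 rows, cross-matched across members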
+ if "cohort_member_id" in self._visit_occurrence.columns: + joined_expr = ["visit_occurrence_id", "cohort_member_id"] + else: + joined_expr = ["visit_occurrence_id"] + visits = ( - self._visit_occurrence.join(valid_visit_ids, "visit_occurrence_id") - .select( + self._visit_occurrence.join( + valid_visit_ids, + joined_expr + ).select( "person_id", "cohort_member_id", "visit_occurrence_id", @@ -90,6 +102,15 @@ def _decorate(self, patient_events: DataFrame): ), ).otherwise(F.col("visit_start_date")) + # We need to set the visit_start_datetime at the beginning of the visit start date + # because there could be outpatient visit records whose visit_start_datetime is set to the end of the day + visit_start_datetime_udf = ( + F.when( + F.col("is_inpatient") == 0, + F.col("visit_start_date") + ).otherwise(F.col("visit_start_datetime")) + ).cast(T.TimestampType()) + # We need to bound the medical event dates between visit_start_date and visit_end_date bound_medical_event_date = F.when( F.col("date") < F.col("visit_start_date"), F.col("visit_start_date") @@ -108,8 +129,10 @@ def _decorate(self, patient_events: DataFrame): patient_events = ( patient_events.join(visits, ["cohort_member_id", "visit_occurrence_id"]) + .withColumn("datetime", F.to_timestamp("datetime")) + .withColumn("visit_start_datetime", visit_start_datetime_udf) .withColumn("visit_end_date", visit_end_date_udf) - .withColumn("visit_end_datetime", F.date_add("visit_end_date", 1)) + .withColumn("visit_end_datetime", F.date_add("visit_end_date", 1).cast(T.TimestampType())) .withColumn("visit_end_datetime", F.expr("visit_end_datetime - INTERVAL 1 MINUTE")) .withColumn("date", bound_medical_event_date) .withColumn("datetime", bound_medical_event_datetime) @@ -145,4 +168,10 @@ def _decorate(self, patient_events: DataFrame): if "concept_as_value" not in patient_events.schema.fieldNames(): patient_events = patient_events.withColumn("concept_as_value", F.lit(None).cast("string")) + # Try persisting the clinical events + patient_events = self.try_persist_data( + patient_events, + self.get_name() + ) + return patient_events diff --git a/src/cehrbert_data/decorators/death_event_decorator.py b/src/cehrbert_data/decorators/death_event_decorator.py index fc47961..2a9303e 100644 --- a/src/cehrbert_data/decorators/death_event_decorator.py +++ b/src/cehrbert_data/decorators/death_event_decorator.py @@ -1,4 +1,5 @@ -from pyspark.sql import DataFrame, functions as F, Window as W, types as T +import os +from pyspark.sql import SparkSession, DataFrame, functions as F, Window as W, types as T from ..const.common import NA from ..const.artificial_tokens import VS_TOKEN, VE_TOKEN, DEATH_TOKEN @@ -20,9 +21,13 @@ class DeathEventDecorator(PatientEventDecorator): - def __init__(self, death, att_type): + def __init__(self, death, att_type, spark: SparkSession = None, persistence_folder: str = None): self._death = death self._att_type = att_type + super().__init__(spark=spark, persistence_folder=persistence_folder) + + def get_name(self): + return "death_tokens" def _decorate(self, patient_events: DataFrame): if self._death is None: @@ -32,7 +37,7 @@ def _decorate(self, patient_events: DataFrame): max_visit_occurrence_id = death_records.select(F.max("visit_occurrence_id").alias("max_visit_occurrence_id")) - last_ve_record = ( + last_ve_events = ( death_records.where(F.col("standard_concept_id") == VE_TOKEN) .withColumn( "record_rank", @@ -42,7 +47,7 @@ def _decorate(self, patient_events: DataFrame): .drop("record_rank") ) - last_ve_record.cache() + 
last_ve_events.cache() # set(['cohort_member_id', 'person_id', 'standard_concept_id', 'date', # 'visit_occurrence_id', 'domain', 'concept_value', 'visit_rank_order', # 'visit_segment', 'priority', 'date_in_week', 'concept_value_mask', @@ -52,7 +57,7 @@ def _decorate(self, patient_events: DataFrame): ) + F.col("max_visit_occurrence_id") death_records = ( - last_ve_record.crossJoin(max_visit_occurrence_id) + last_ve_events.crossJoin(max_visit_occurrence_id) .withColumn("visit_occurrence_id", artificial_visit_id) .withColumn("standard_concept_id", F.lit(DEATH_TOKEN)) .withColumn("domain", F.lit("death")) @@ -107,6 +112,10 @@ def _decorate(self, patient_events: DataFrame): new_tokens = death_events.unionByName(vs_records).unionByName(death_records).unionByName(ve_records) new_tokens = new_tokens.drop("death_date") + new_tokens = self.try_persist_data( + new_tokens, + os.path.join(self.get_name(), "death_events") + ) self.validate(new_tokens) return patient_events.unionByName(new_tokens) diff --git a/src/cehrbert_data/decorators/demographic_event_decorator.py b/src/cehrbert_data/decorators/demographic_event_decorator.py index ab68b2e..3760b9a 100644 --- a/src/cehrbert_data/decorators/demographic_event_decorator.py +++ b/src/cehrbert_data/decorators/demographic_event_decorator.py @@ -1,4 +1,5 @@ -from pyspark.sql import DataFrame, functions as F, Window as W, types as T +import os +from pyspark.sql import SparkSession, DataFrame, functions as F, Window as W, types as T from .patient_event_decorator_base import PatientEventDecorator @@ -12,9 +13,18 @@ class DemographicEventDecorator(PatientEventDecorator): - def __init__(self, patient_demographic, use_age_group: bool = False): + def __init__( + self, patient_demographic, + use_age_group: bool = False, + spark: SparkSession = None, + persistence_folder: str = None + ): self._patient_demographic = patient_demographic self._use_age_group = use_age_group + super().__init__(spark=spark, persistence_folder=persistence_folder) + + def get_name(self): + return "demographic_events" def _decorate(self, patient_events: DataFrame): if self._patient_demographic is None: @@ -63,7 +73,10 @@ def _decorate(self, patient_events: DataFrame): .withColumn("concept_order", F.lit(0)) ) - sequence_start_year_token.cache() + # Try persisting the start year tokens + sequence_start_year_token = self.try_persist_data( + sequence_start_year_token, os.path.join(self.get_name(), "sequence_start_year_tokens") + ) if self._use_age_group: calculate_age_group_at_first_visit_udf = F.ceil( @@ -89,6 +102,11 @@ def _decorate(self, patient_events: DataFrame): .drop("birth_datetime") ) + # Try persisting the age tokens + sequence_age_token = self.try_persist_data( + sequence_age_token, os.path.join(self.get_name(), "sequence_age_tokens") + ) + sequence_gender_token = ( self._patient_demographic.select(F.col("person_id"), F.col("gender_concept_id")) .join(sequence_start_year_token, "person_id") @@ -97,6 +115,11 @@ def _decorate(self, patient_events: DataFrame): .drop("gender_concept_id") ) + # Try persisting the gender tokens + sequence_gender_token = self.try_persist_data( + sequence_gender_token, os.path.join(self.get_name(), "sequence_gender_tokens") + ) + sequence_race_token = ( self._patient_demographic.select(F.col("person_id"), F.col("race_concept_id")) .join(sequence_start_year_token, "person_id") @@ -105,6 +128,11 @@ def _decorate(self, patient_events: DataFrame): .drop("race_concept_id") ) + # Try persisting the race tokens + sequence_race_token = self.try_persist_data( + 
sequence_race_token, os.path.join(self.get_name(), "sequence_race_tokens") + ) + patient_events = patient_events.unionByName(sequence_start_year_token) patient_events = patient_events.unionByName(sequence_age_token) patient_events = patient_events.unionByName(sequence_gender_token) diff --git a/src/cehrbert_data/decorators/patient_event_decorator_base.py b/src/cehrbert_data/decorators/patient_event_decorator_base.py index ed73866..3d4743e 100644 --- a/src/cehrbert_data/decorators/patient_event_decorator_base.py +++ b/src/cehrbert_data/decorators/patient_event_decorator_base.py @@ -1,10 +1,11 @@ +import os import math from abc import ABC, abstractmethod from enum import Enum from typing import Optional, Union, Set, Callable import numpy as np -from pyspark.sql import DataFrame +from pyspark.sql import DataFrame, SparkSession class AttType(Enum): @@ -17,15 +18,36 @@ class AttType(Enum): class PatientEventDecorator(ABC): + def __init__(self, spark: SparkSession = None, persistence_folder: str = None,): + self.spark = spark + self.persistence_folder = persistence_folder + @abstractmethod def _decorate(self, patient_events): pass + @abstractmethod + def get_name(self): + pass + def decorate(self, patient_events): decorated_patient_events = self._decorate(patient_events) self.validate(decorated_patient_events) return decorated_patient_events + def try_persist_data(self, data: DataFrame, folder_name: str) -> DataFrame: + if self.persistence_folder and self.spark: + temp_folder = os.path.join(self.persistence_folder, folder_name) + data.write.mode("overwrite").parquet(temp_folder) + return self.spark.read.parquet(temp_folder) + return data + + def load_recursive(self) -> Optional[DataFrame]: + if self.persistence_folder and self.spark: + temp_folder = os.path.join(self.persistence_folder, self.get_name()) + return self.spark.read.option("recursiveFileLookup", "true").parquet(temp_folder) + return None + @classmethod def get_required_columns(cls) -> Set[str]: return { diff --git a/src/cehrbert_data/tools/connect_omop_visit.py b/src/cehrbert_data/tools/connect_omop_visit.py new file mode 100644 index 0000000..411bedd --- /dev/null +++ b/src/cehrbert_data/tools/connect_omop_visit.py @@ -0,0 +1,288 @@ +import os +import argparse +from typing import Tuple + +from pyspark.sql import SparkSession, DataFrame +from pyspark.sql import types as t +from pyspark.sql import functions as f +from pyspark.sql.window import Window + + +def connect_visits_in_chronological_order( + spark: SparkSession, + visit_to_fix: DataFrame, + visit_occurrence: DataFrame, + hour_diff_threshold: int, + workspace_folder: str, + visit_name: str +): + visit_to_fix = visit_to_fix.withColumn( + "visit_end_datetime", + f.coalesce("visit_end_datetime", f.col("visit_end_date").cast(t.TimestampType())) + ).withColumn( + "visit_end_datetime", + f.when( + f.col("visit_end_datetime") > f.col("visit_start_datetime"), f.col("visit_end_datetime") + ).otherwise(f.col("visit_start_datetime")) + ).withColumn( + "visit_order", + f.row_number().over( + Window.partitionBy("person_id").orderBy("visit_start_datetime", "visit_occurrence_id") + ) + ).withColumn( + "prev_visit_end_datetime", + f.lag("visit_end_datetime").over( + Window.partitionBy("person_id").orderBy("visit_order") + ) + ).withColumn( + "hour_diff", + f.coalesce( + (f.unix_timestamp("visit_start_datetime") - f.unix_timestamp("prev_visit_end_datetime")) / 3600, + f.lit(0) + ) + ).withColumn( + "visit_partition", + f.sum((f.col("hour_diff") > hour_diff_threshold).cast("int")).over( + 
Window.partitionBy("person_id").orderBy("visit_order") + .rowsBetween(Window.unboundedPreceding, Window.currentRow) + ) + ).withColumn( + "is_master_visit", + f.row_number().over(Window.partitionBy("person_id", "visit_partition").orderBy("visit_order")) == 1 + ) + visit_to_fix_folder = os.path.join(workspace_folder, f"{visit_name}_visit_to_fix") + visit_to_fix.write.mode("overwrite").parquet(visit_to_fix_folder) + visit_to_fix = spark.read.parquet(visit_to_fix_folder) + # Connect all the individual visits + master_visit = visit_to_fix.alias("visit").join( + visit_to_fix.where( + f.col("is_master_visit") + ).alias("master"), + (f.col("visit.person_id") == f.col("master.person_id")) + & (f.col("visit.visit_partition") == f.col("master.visit_partition")), + ).groupby( + f.col("master.person_id").alias("person_id"), + f.col("master.visit_partition").alias("visit_partition"), + f.col("master.visit_occurrence_id").alias("visit_occurrence_id"), + ).agg( + f.min("visit.visit_start_date").alias("visit_start_date"), + f.min("visit.visit_start_datetime").alias("visit_start_datetime"), + f.max("visit.visit_end_date").alias("visit_end_date"), + f.max("visit.visit_end_datetime").alias("visit_end_datetime"), + ) + master_visit_folder = os.path.join(workspace_folder, f"{visit_name}_master_visit") + master_visit.write.mode("overwrite").parquet(master_visit_folder) + master_visit = spark.read.parquet(master_visit_folder) + visit_mapping = master_visit.alias("master").join( + visit_to_fix.alias("visit"), + (f.col("master.person_id") == f.col("visit.person_id")) + & (f.col("master.visit_partition") == f.col("visit.visit_partition")), + ).where( + f.col("master.visit_occurrence_id") != f.col("visit.visit_occurrence_id") + ).select( + f.col("master.person_id").alias("person_id"), + f.col("master.visit_partition").alias("visit_partition"), + f.col("master.visit_occurrence_id").alias("master_visit_occurrence_id"), + f.col("visit.visit_occurrence_id").alias("visit_occurrence_id"), + ) + visit_mapping_folder = os.path.join(workspace_folder, f"{visit_name}_visit_mapping") + visit_mapping.write.mode("overwrite").parquet(visit_mapping_folder) + visit_mapping = spark.read.parquet(visit_mapping_folder) + # Update the visit_start_date(time) and visit_end_date(time) + columns_to_update = [ + "visit_occurrence_id", "visit_start_date", "visit_end_date", "visit_start_datetime", "visit_end_datetime" + ] + other_columns = [column for column in visit_occurrence.columns if column not in columns_to_update] + visit_occurrence_fixed = visit_occurrence.alias("visit").join( + master_visit.alias("master"), + (f.col("master.visit_occurrence_id") == f.col("visit.visit_occurrence_id")), + "left_outer" + ).select( + [ + f.coalesce(f.col(f"master.{column}"), f.col(f"visit.{column}")).alias(column) + for column in columns_to_update + ] + [ + f.col(f"visit.{column}").alias(column) + for column in other_columns + ] + ) + visit_occurrence_fixed = visit_occurrence_fixed.join( + visit_mapping.select("visit_occurrence_id"), + on="visit_occurrence_id", + how="left_anti" + ) + visit_occurrence_fixed.write.mode("overwrite").parquet( + os.path.join(workspace_folder, f"{visit_name}_visit_occurrence_fixed") + ) + return visit_occurrence_fixed, visit_mapping + + +def step_3_consolidate_outpatient_visits( + spark: SparkSession, + visit_occurrence: DataFrame, + output_folder: str, + outpatient_hour_diff_threshold: int +) -> Tuple[DataFrame, DataFrame]: + # We need to connect the visits together + workspace_folder = os.path.join(output_folder, 
"outpatient_visit_workspace") + outpatient_visit = visit_occurrence.where( + ~f.col("visit_concept_id").isin(9201, 262) + ).select( + "person_id", "visit_occurrence_id", + "visit_start_date", "visit_start_datetime", + "visit_end_date", "visit_end_datetime" + ) + visit_occurrence_outpatient_visit_fixed, outpatient_visit_mapping = connect_visits_in_chronological_order( + spark=spark, + visit_to_fix=outpatient_visit, + visit_occurrence=visit_occurrence, + hour_diff_threshold=outpatient_hour_diff_threshold, + workspace_folder=workspace_folder, + visit_name="outpatient", + ) + return visit_occurrence_outpatient_visit_fixed, outpatient_visit_mapping + + +def step_1_consolidate_inpatient_visits( + spark: SparkSession, + visit_occurrence: DataFrame, + output_folder: str, + inpatient_hour_diff_threshold: int +) -> Tuple[DataFrame, DataFrame]: + # We need to connect the visits together + workspace_folder = os.path.join(output_folder, "inpatient_visit_workspace") + inpatient_visits = visit_occurrence.where( + f.col("visit_concept_id").isin(9201, 262) + ).select( + "person_id", "visit_occurrence_id", + "visit_start_date", "visit_start_datetime", + "visit_end_date", "visit_end_datetime" + ) + visit_occurrence_inpatient_visit_fixed, inpatient_visit_mapping = connect_visits_in_chronological_order( + spark=spark, + visit_to_fix=inpatient_visits, + visit_occurrence=visit_occurrence, + hour_diff_threshold=inpatient_hour_diff_threshold, + workspace_folder=workspace_folder, + visit_name="inpatient", + ) + return visit_occurrence_inpatient_visit_fixed, inpatient_visit_mapping + + +def step_2_connect_outpatient_to_inpatient( + spark: SparkSession, + visit_occurrence: DataFrame, + output_folder: str, +) -> Tuple[DataFrame, DataFrame]: + # We need to connect the visits together + workspace_folder = os.path.join(output_folder, "outpatient_to_inpatient_visit_workspace") + inpatient_visits = visit_occurrence.where( + f.col("visit_concept_id").isin(9201, 262) + ).select( + "person_id", "visit_occurrence_id", + "visit_start_date", "visit_start_datetime", + "visit_end_date", "visit_end_datetime" + ) + outpatient_visits = visit_occurrence.where( + ~f.col("visit_concept_id").isin(9201, 262) + ).select( + "person_id", "visit_occurrence_id", + "visit_start_date", "visit_start_datetime", + "visit_end_date", "visit_end_datetime" + ) + outpatient_to_inpatient_visit_mapping = inpatient_visits.alias("in").join( + outpatient_visits.alias("out"), + (f.col("in.person_id") == f.col("out.person_id")) + & (f.col("in.visit_start_datetime") < f.col("out.visit_start_datetime")) + & (f.col("out.visit_start_datetime") < f.col("in.visit_end_datetime")), + ).groupby( + f.col("out.visit_occurrence_id").alias("visit_occurrence_id") + ).agg( + f.min("in.visit_occurrence_id").alias("master_visit_occurrence_id"), + ) + outpatient_to_inpatient_visit_mapping_folder = os.path.join( + workspace_folder, + "outpatient_to_inpatient_visit_mapping" + ) + outpatient_to_inpatient_visit_mapping.write.mode("overwrite").parquet( + outpatient_to_inpatient_visit_mapping_folder + ) + outpatient_to_inpatient_visit_mapping = spark.read.parquet( + outpatient_to_inpatient_visit_mapping_folder + ) + visit_occurrence_fixed = visit_occurrence.join( + outpatient_to_inpatient_visit_mapping.select("visit_occurrence_id"), + on="visit_occurrence_id", + how="left_anti" + ) + visit_occurrence_outpatient_to_inpatient_fix_folder = os.path.join( + workspace_folder, "visit_occurrence_outpatient_to_inpatient_fix" + ) + visit_occurrence_fixed.write.mode("overwrite").parquet( + 
visit_occurrence_outpatient_to_inpatient_fix_folder + ) + visit_occurrence_fixed = spark.read.parquet(visit_occurrence_outpatient_to_inpatient_fix_folder) + return visit_occurrence_fixed, outpatient_to_inpatient_visit_mapping + + +def main(args): + spark = SparkSession.builder.appName("Clean up visit_occurrence").getOrCreate() + visit_occurrence = spark.read.parquet(os.path.join(args.input_folder, "visit_occurrence")) + visit_occurrence_step_1, in_to_in_visit_mapping = step_1_consolidate_inpatient_visits( + spark, + visit_occurrence, + output_folder=args.output_folder, + inpatient_hour_diff_threshold=args.inpatient_hour_diff_threshold, + ) + visit_occurrence_step_2, out_to_in_visit_mapping = step_2_connect_outpatient_to_inpatient( + spark, + visit_occurrence_step_1, + output_folder=args.output_folder, + ) + visit_occurrence_step_3, out_to_out_visit_mapping = step_3_consolidate_outpatient_visits( + spark, + visit_occurrence_step_2, + output_folder=args.output_folder, + outpatient_hour_diff_threshold=args.outpatient_hour_diff_threshold, + ) + visit_occurrence_step_3.write.mode("overwrite").parquet(os.path.join(args.output_folder, "visit_occurrence")) + mapping_columns = ["visit_occurrence_id", "master_visit_occurrence_id"] + visit_mapping = in_to_in_visit_mapping.select(mapping_columns).unionByName( + out_to_in_visit_mapping.select(mapping_columns) + ).unionByName(out_to_out_visit_mapping.select(mapping_columns)) + visit_mapping.write.mode("overwrite").parquet(os.path.join(args.output_folder, "visit_mapping")) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Arguments for connecting OMOP visits in chronological order") + parser.add_argument( + "--input_folder", + dest="input_folder", + action="store", + required=True, + ) + parser.add_argument( + "--output_folder", + dest="output_folder", + action="store", + required=True, + ) + parser.add_argument( + "--inpatient_hour_diff_threshold", + dest="inpatient_hour_diff_threshold", + action="store", + type=int, + default=24, + required=False, + ) + parser.add_argument( + "--outpatient_hour_diff_threshold", + dest="outpatient_hour_diff_threshold", + action="store", + type=int, + default=1, + required=False, + ) + main( + parser.parse_args() + ) diff --git a/src/cehrbert_data/tools/ehrshot_to_omop.py b/src/cehrbert_data/tools/ehrshot_to_omop.py index a27f216..d5b38a0 100644 --- a/src/cehrbert_data/tools/ehrshot_to_omop.py +++ b/src/cehrbert_data/tools/ehrshot_to_omop.py @@ -2,6 +2,7 @@ import logging import argparse import shutil + from cehrbert_data.utils.logging_utils import add_console_logging from pyspark.sql import SparkSession, DataFrame @@ -418,7 +419,12 @@ def convert_code_to_omop_concept( ).select(output_columns) -def generate_visit_id(data: DataFrame) -> DataFrame: +def generate_visit_id( + data: DataFrame, + spark: SparkSession, + cache_folder: str, + day_cutoff: int = 1 +) -> DataFrame: """ Generates unique `visit_id`s for each visit based on distinct patient event records. @@ -436,6 +442,12 @@ def generate_visit_id(data: DataFrame) -> DataFrame: - `end`: (Optional) End timestamp of the event. - `omop_table`: String specifying the type of event (e.g., "visit_occurrence" for real visits). - `visit_id`: (Optional) Identifier for visits. May be missing in some records. 
+ spark: SparkSession + The current spark session + cache_folder: str + The cache folder for saving the intermediate dataframes + day_cutoff: int + Day cutoff to disconnect the records with their associated visit Returns ------- @@ -454,6 +466,7 @@ def generate_visit_id(data: DataFrame) -> DataFrame: 4. **Create Artificial Visits**: Generates artificial visit records for orphan `visit_id`s. 5. **Merge and Validate**: Combines the original records with artificial visits and validates the uniqueness of each `visit_id`. """ + visit_reconstruction_folder = os.path.join(cache_folder, "visit_reconstruction") data = data.repartition(16) real_visits = data.where( f.col("omop_table") == "visit_occurrence" @@ -464,7 +477,9 @@ def generate_visit_id(data: DataFrame) -> DataFrame: "visit_end_date", f.coalesce(f.col("end").cast(t.DateType()), f.col("visit_start_date")) ) - + real_visits_folder = os.path.join(visit_reconstruction_folder, "real_visits") + real_visits.write.mode("overwrite").parquet(real_visits_folder) + real_visits = spark.read.parquet(real_visits_folder) # Getting the records that do not have a visit_id domain_records = data.where( f.col("omop_table") != "visit_occurrence" @@ -474,42 +489,23 @@ def generate_visit_id(data: DataFrame) -> DataFrame: ) # This is important to have a deterministic behavior for generating record_id - domain_records.cache() - - # Invalidate visit_id if the record's time stamp falls outside the visit start/end - domain_records = domain_records.alias("domain").join( - real_visits.where("code == 'Visit/IP'").alias("in_visit"), - (f.col("domain.patient_id") == f.col("in_visit.patient_id")) & - (f.col("domain.visit_id") == f.col("in_visit.visit_id")), - "left_outer" - ).withColumn( - "new_visit_id", - f.coalesce( - f.when( - f.col("domain.start").between( - f.date_sub(f.col("in_visit.start"), 1), f.date_add(f.col("in_visit.end"), 1) - ), - f.col("domain.visit_id") - ).otherwise(f.lit(None).cast(t.LongType())), - f.col("domain.visit_id") - ) - ).select( - [ - f.col("domain." 
+ field).alias(field) - for field in domain_records.schema.fieldNames() if not field.endswith("visit_id") - ] + [f.col("new_visit_id").alias("visit_id")] - ) + temp_domain_records_folder = os.path.join(visit_reconstruction_folder, "temp_domain_records") + domain_records.write.mode("overwrite").parquet(temp_domain_records_folder) + domain_records = spark.read.parquet(temp_domain_records_folder) - # Join the DataFrames with aliasing + # Join the records to the nearest visits if they occur within the visit span with aliasing domain_records = domain_records.alias("domain").join( - real_visits.alias("visit"), + real_visits.where(f.col("code").isin(['Visit/IP', 'Visit/ERIP'])).alias("visit"), (f.col("domain.patient_id") == f.col("visit.patient_id")) & - (f.col("domain.start").cast(t.DateType()).between(f.col("visit.visit_start_date"), - f.col("visit.visit_end_date"))), + (f.col("domain.start").between(f.col("visit.start"), f.col("visit.end"))), "left_outer" ).withColumn( "ranking", - f.row_number().over(Window.partitionBy("domain.record_id").orderBy(f.col("visit.visit_start_date").desc())) + f.row_number().over( + Window.partitionBy("domain.record_id").orderBy( + f.abs(f.unix_timestamp("visit.start") - f.unix_timestamp("domain.start")) + ) + ) ).where( f.col("ranking") == 1 ).select( @@ -561,7 +557,7 @@ def generate_visit_id(data: DataFrame) -> DataFrame: # Generate the artificial visits artificial_visits = orphan_records.groupBy("new_visit_id", "patient_id").agg( f.min("start").alias("start"), - f.max("end").alias("end") + f.max("start").alias("end") ).withColumn( "code", f.lit(0) @@ -578,18 +574,171 @@ def generate_visit_id(data: DataFrame) -> DataFrame: "new_visit_id", "visit_id" ).drop("record_id") + artificial_visits_folder = os.path.join(visit_reconstruction_folder, "artificial_visits") + artificial_visits.write.mode("overwrite").parquet(artificial_visits_folder) + artificial_visits = spark.read.parquet(artificial_visits_folder) + # Drop visit_start_date and visit_end_date real_visits = real_visits.drop("visit_start_date", "visit_end_date") # Validate the uniqueness of visit_id artificial_visits.groupby("visit_id").count().select(f.assert_true(f.col("count") == 1)).collect() - # Join the generated visit_id back to data + return domain_records.unionByName( real_visits ).unionByName( artificial_visits ) +def disconnect_visit_id( + data: DataFrame, + spark: SparkSession, + cache_folder: str, + day_cutoff: int = 1 +): + # There are records that fall outside the corresponding visits, the time difference could be days + # and evens years apart, this is likely that the timestamps of the lab events are the time when the + # lab results came back instead of when the labs were sent out, therefore creating the time discrepancy. + # In this case, we will disassociate such records with the visits and will try to connect them to the + # other visits. 
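Editor's note: connect_visits_in_chronological_order above and the partitioning logic below rely on the same gap-based grouping trick: flag each row whose gap from the previous row exceeds a threshold, then take a running sum of the flags to obtain a partition id. A small self-contained sketch (toy data and column names are assumptions):

from datetime import datetime
from pyspark.sql import SparkSession, functions as f
from pyspark.sql.window import Window

spark = SparkSession.builder.master("local").getOrCreate()
events = spark.createDataFrame(
    [
        (1, datetime(2023, 1, 1, 8)),   # first row -> partition 0
        (1, datetime(2023, 1, 1, 12)),  # 4-hour gap  -> still partition 0
        (1, datetime(2023, 1, 3, 9)),   # 45-hour gap -> partition 1
    ],
    ["visit_id", "start"],
)
w = Window.partitionBy("visit_id").orderBy("start")
partitioned = (
    events
    .withColumn("prev_start", f.lag("start").over(w))
    .withColumn(
        "hour_diff",
        f.coalesce((f.unix_timestamp("start") - f.unix_timestamp("prev_start")) / 3600, f.lit(0)),
    )
    .withColumn(
        "visit_partition",
        f.sum((f.col("hour_diff") > 24).cast("int")).over(
            w.rowsBetween(Window.unboundedPreceding, Window.currentRow)
        ),
    )
)
partitioned.show()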
+ visit_reconstruction_folder = os.path.join(cache_folder, "visit_reconstruction") + domain_records = data.where(f.col("omop_table") != "visit_occurrence") + visit_records = data.where(f.col("omop_table") == "visit_occurrence") + visit_inferred_start_end = domain_records.alias("domain").join( + visit_records.alias("visit"), + f.col("domain.visit_id") == f.col("visit.visit_id"), + ).groupby("domain.visit_id").agg( + f.min("domain.start").alias("start"), + f.max("domain.start").alias("end") + ) + visit_to_fix = visit_inferred_start_end.alias("d_visit").join( + visit_records.alias("visit"), + f.col("d_visit.visit_id") == f.col("visit.visit_id"), + ).where( + # If the record is 24 * day_cutoff hours before the visit_start or + # if the record is 24 * day_cutoff hours after the visit_end + ((f.unix_timestamp("visit.start") - f.unix_timestamp("d_visit.start")) / 3600 > day_cutoff * 24) | + ((f.unix_timestamp("d_visit.end") - f.unix_timestamp("visit.end")) / 3600 > day_cutoff * 24) + ).select( + f.col("visit.visit_id").alias("visit_id"), + f.col("visit.start").alias("start"), + f.col("visit.end").alias("end"), + f.col("d_visit.start").alias("inferred_start"), + f.col("d_visit.end").alias("inferred_end"), + ) + visit_to_fix_folder = os.path.join(visit_reconstruction_folder, "visit_to_fix") + visit_to_fix.write.mode("overwrite").parquet(visit_to_fix_folder) + visit_to_fix = spark.read.parquet(visit_to_fix_folder) + + # Identify the unique visit_id/start pairs, we will identify the boundary of the visit + distinct_visit_date_mapping = domain_records.alias("domain").join( + visit_to_fix.alias("visit"), + f.col("domain.visit_id") == f.col("visit.visit_id"), + ).select( + f.col("domain.visit_id").alias("visit_id"), + f.col("domain.start").alias("start"), + f.col("domain.code").alias("code"), + ).distinct().withColumn( + "visit_order", + f.row_number().over( + Window.partitionBy("visit_id").orderBy("start") + ) + ).withColumn( + "prev_start", + f.lag("start").over( + Window.partitionBy("visit_id").orderBy("visit_order") + ) + ).withColumn( + "hour_diff", + f.coalesce( + (f.unix_timestamp("start") - f.unix_timestamp("prev_start")) / 3600, + f.lit(0) + ) + ).withColumn( + "visit_partition", + f.sum((f.col("hour_diff") > 24).cast("int")).over( + Window.partitionBy("visit_id").orderBy("visit_order") + .rowsBetween(Window.unboundedPreceding, Window.currentRow) + ) + ).withColumn( + "visit_partition_rank", + f.dense_rank().over(Window.orderBy(f.col("visit_id"), f.col("visit_partition"))) + ).crossJoin( + visit_records.select(f.max("visit_id").alias("max_visit_id")) + ).withColumn( + "new_visit_id", + f.col("max_visit_id") + f.col("visit_partition_rank") + ).drop( + "max_visit_id", "row_number" + ) + + # Connect visit partitions in chronological order + distinct_visit_date_pair_folder = os.path.join(visit_reconstruction_folder, "distinct_visit_date_mapping") + distinct_visit_date_mapping.write.mode("overwrite").parquet(distinct_visit_date_pair_folder) + distinct_visit_date_mapping = spark.read.parquet(distinct_visit_date_pair_folder) + + fix_visit_records = data.alias("ehr").join( + distinct_visit_date_mapping.alias("visit"), + f.col("ehr.visit_id") == f.col("visit.visit_id"), + ).where( + f.col("ehr.omop_table") == "visit_occurrence" + ).groupby( + f.col("visit.visit_id").alias("original_visit_id"), + f.col("visit.new_visit_id").alias("visit_id"), + f.col("ehr.patient_id").alias("patient_id"), + f.col("ehr.code").alias("code"), + f.col("ehr.value").alias("value"), + f.col("ehr.unit").alias("unit"), + 
f.col("ehr.omop_table").alias("omop_table"), + ).agg( + f.min("visit.start").alias("start"), + f.max("visit.start").alias("end"), + ).withColumn( + "code", + f.when( + (f.col("code").isin(['Visit/IP', 'Visit/ERIP'])) + & ((f.unix_timestamp("end") - f.unix_timestamp("start")) / 3600 <= 24), + f.lit("Visit/OP") + ).otherwise(f.col("code")) + ) + + # Fix visit records + fix_visit_records_folder = os.path.join(visit_reconstruction_folder, "fix_visit_records") + fix_visit_records.write.mode("overwrite").parquet(fix_visit_records_folder) + fix_visit_records = spark.read.parquet(fix_visit_records_folder) + + fix_domain_records = data.alias("ehr").join( + distinct_visit_date_mapping.alias("visit"), + (f.col("ehr.visit_id") == f.col("visit.visit_id")) + & (f.col("ehr.start") == f.col("visit.start")) + & (f.col("ehr.code") == f.col("visit.code")), + ).where( + f.col("ehr.omop_table") != "visit_occurrence" + ).select( + [ + f.coalesce(f.col("visit.new_visit_id"), f.col("ehr.visit_id")).alias("visit_id"), + f.coalesce(f.col("visit.visit_id"), f.col("ehr.visit_id")).alias("original_visit_id") + ] + + + [ + f.col(f"ehr.{column}").alias(column) for column in data.columns if column != "visit_id" + ] + ) + + # Fix domain records + fix_domain_records_folder = os.path.join(visit_reconstruction_folder, "fix_domain_records") + fix_domain_records.write.mode("overwrite").parquet(fix_domain_records_folder) + fix_domain_records = spark.read.parquet(fix_domain_records_folder) + + # Retrieve other records that do not require fixing + other_events = data.join( + distinct_visit_date_mapping.select("visit_id").distinct(), + "visit_id", + "left_anti" + ).withColumn("original_visit_id", f.col("visit_id")) + + return other_events.unionByName(fix_domain_records).unionByName(fix_visit_records) + def drop_duplicate_visits(data: DataFrame) -> DataFrame: """ @@ -651,7 +800,61 @@ def main(args): ).drop("_c0") # Add visit_id based on the time intervals between neighboring events ehr_shot_data = generate_visit_id( - ehr_shot_data + ehr_shot_data, + spark, + args.output_folder, + ) + # Disconnect domain records whose timestamps fall outside of the corresponding visit ranges + ehr_shot_data = disconnect_visit_id( + ehr_shot_data, + spark, + args.output_folder, + args.day_cutoff, + ) + outpatient_visits = ehr_shot_data.where( + ~f.col("code").isin(["Visit/IP", "Visit/ERIP"]) + ).where(f.col("omop_table") == "visit_occurrence") + # We don't use the end column to get the max end because some end datetime could be years apart from the start date + outpatient_visit_start_end = ehr_shot_data.join(outpatient_visits.select("visit_id"), "visit_id").where( + f.col("omop_table").isin( + ["condition_occurrence", "procedure_occurrence", "drug_exposure", "measurement", "observation", "death"] + ) + ).groupby("visit_id").agg(f.min("start").alias("start"), f.max("start").alias("end")).withColumn( + "hour_diff", (f.unix_timestamp("end") - f.unix_timestamp("start")) / 3600 + ).withColumn( + "inpatient_indicator", + (f.col("hour_diff") > 24).cast("int") + ) + # Reload it from the disk to update the dataframe + outpatient_visit_start_end_folder = os.path.join(args.output_folder, "outpatient_visit_start_end") + outpatient_visit_start_end.write.mode("overwrite").parquet( + outpatient_visit_start_end_folder + ) + outpatient_visit_start_end = spark.read.parquet(outpatient_visit_start_end_folder) + inferred_inpatient_visits = outpatient_visit_start_end.where("inpatient_indicator = 1").select( + "visit_id", "start", "end", f.lit("Visit/IP").alias("code"), + 
) + ehr_shot_data = ehr_shot_data.alias("ehr").join( + inferred_inpatient_visits.alias("visits"), "visit_id", "left_outer" + ).select( + f.col("ehr.patient_id").alias("patient_id"), + f.when( + f.col("ehr.omop_table") == "visit_occurrence", + f.coalesce(f.col("visits.start"), f.col("ehr.start")), + ).otherwise(f.col("ehr.start")).alias("start"), + f.when( + f.col("ehr.omop_table") == "visit_occurrence", + f.coalesce(f.col("visits.end"), f.col("ehr.end")), + ).otherwise(f.col("ehr.end")).alias("end"), + f.when( + f.col("ehr.omop_table") == "visit_occurrence", + f.coalesce(f.col("visits.code"), f.col("ehr.code")), + ).otherwise(f.col("ehr.code")).alias("code"), + f.col("ehr.value").alias("value"), + f.col("ehr.unit").alias("unit"), + f.col("ehr.omop_table").alias("omop_table"), + f.col("ehr.visit_id").alias("visit_id"), + f.col("ehr.original_visit_id").alias("original_visit_id"), ) ehr_shot_data.write.mode("overwrite").parquet(ehr_shot_path) @@ -736,6 +939,14 @@ def main(args): dest="refresh_ehrshot", action="store_true", ) + parser.add_argument( + "--day_cutoff", + dest="day_cutoff", + action="store", + type=int, + default=1, + required=False, + ) main( parser.parse_args() ) diff --git a/src/cehrbert_data/tools/extract_features.py b/src/cehrbert_data/tools/extract_features.py index d03c6c5..8f559f8 100644 --- a/src/cehrbert_data/tools/extract_features.py +++ b/src/cehrbert_data/tools/extract_features.py @@ -1,5 +1,4 @@ import os -import re from pathlib import Path import shutil from enum import Enum @@ -22,11 +21,6 @@ class PredictionType(Enum): REGRESSION = "regression" -def get_temp_folder(args, table_name): - cleaned_cohort_name = re.sub(r'[^A-Za-z0-9]', '_', args.cohort_name) - return os.path.join(args.output_folder, f"{cleaned_cohort_name}_{table_name}") - - def create_feature_extraction_args(): spark_args = create_spark_args( parse=False @@ -91,7 +85,9 @@ def main(args): cohort_csv = cohort_csv.withColumn("cohort_member_id", cohort_member_id_udf) # Save cohort as parquet files - cohort_temp_folder = get_temp_folder(args, "cohort") + cohort_temp_folder = os.path.join( + args.output_folder, args.cohort_name, "cohort" + ) cohort_csv.write.mode("overwrite").parquet(cohort_temp_folder) cohort = spark.read.parquet(cohort_temp_folder) @@ -111,56 +107,100 @@ def main(args): ehr_records = cohort.select("person_id", "cohort_member_id", "index_date").join( ehr_records, "person_id" - ).where(ehr_records["date"] <= cohort["index_date"]) - - ehr_records_temp_folder = get_temp_folder(args, "ehr_records") - ehr_records.repartition("person_id").write.mode("overwrite").parquet(ehr_records_temp_folder) + ).where(ehr_records["datetime"] <= cohort["index_date"]) + ehr_records_temp_folder = os.path.join( + args.output_folder, args.cohort_name, "ehr_records" + ) + ehr_records.write.mode("overwrite").parquet(ehr_records_temp_folder) ehr_records = spark.read.parquet(ehr_records_temp_folder) visit_occurrence = spark.read.parquet(os.path.join(args.input_folder, "visit_occurrence")) cohort_visit_occurrence = visit_occurrence.join( - cohort.select("person_id").distinct(), + cohort.select("person_id", "cohort_member_id", "index_date"), "person_id" + ).withColumn( + "visit_start_date", + f.col("visit_start_date").cast(t.DateType()) ).withColumn( "visit_end_date", - f.coalesce(f.col("visit_end_date"), f.col("visit_start_date")) + f.coalesce(f.col("visit_end_date"), f.col("visit_start_date")).cast(t.DateType()) + ).withColumn( + "visit_start_datetime", + f.col("visit_start_datetime").cast(t.TimestampType()) 
).withColumn( "visit_end_datetime", f.coalesce( f.col("visit_end_datetime"), f.col("visit_end_date").cast(t.TimestampType()), f.col("visit_start_datetime") - ) + ).cast(t.TimestampType()) + ).where( + f.col("visit_start_datetime") <= f.col("index_date") ) + # For each patient/index_date pair, we get the last record before the index_date # we get the corresponding visit_occurrence_id and index_date if args.bound_visit_end_date: cohort_visit_occurrence = cohort_visit_occurrence.withColumn( - "order", f.row_number().over(Window.orderBy(f.monotonically_increasing_id())) - ) + "time_diff_from_index_date", + f.abs(f.unix_timestamp("index_date") - f.unix_timestamp("visit_start_datetime")) + ).withColumn( + "visit_rank", + f.row_number().over( + Window.partitionBy("person_id", "cohort_member_id").orderBy("time_diff_from_index_date") + ) + ).drop("time_diff_from_index_date") - visit_index_date = cohort_visit_occurrence.alias("visit").join( - cohort.alias("cohort"), - "person_id" - ).where( - f.col("cohort.index_date").between(f.col("visit.visit_start_datetime"), f.col("visit.visit_end_datetime")) + # We create placeholder tokens for those inpatient visits, where the first token occurs after the index_date + placeholder_tokens = cohort_visit_occurrence.where( + f.col("visit_rank") == 1 ).select( - f.col("visit.visit_occurrence_id").alias("visit_occurrence_id"), - f.col("cohort.index_date").alias("index_date"), + "person_id", + "cohort_member_id", + "index_date", + "visit_occurrence_id", + f.lit("0").alias("standard_concept_id"), + f.col("index_date").cast(t.DateType()).alias("date"), + f.col("index_date").alias("datetime"), + f.lit("unknown").alias("domain"), + f.lit(None).cast(t.StringType()).alias("unit"), + f.lit(None).cast(t.FloatType()).alias("number_as_value"), + f.lit(None).cast(t.StringType()).alias("concept_as_value"), + f.lit(None).cast(t.StringType()).alias("event_group_id"), + "visit_concept_id", + f.lit(-1).alias("age") + ).join( + ehr_records.select("cohort_member_id", "visit_occurrence_id"), + ["cohort_member_id", "visit_occurrence_id"], + "left_anti", + ) + + placeholder_tokens.write.mode("overwrite").parquet( + os.path.join(args.output_folder, args.cohort_name, "placeholder_tokens") + ) + # Add an artificial token for the visit in which the prediction is made + ehr_records = ehr_records.unionByName( + placeholder_tokens ) # Bound the visit_end_date and visit_end_datetime - cohort_visit_occurrence = cohort_visit_occurrence.join( - visit_index_date, - "visit_occurrence_id", - "left_outer", + cohort_visit_occurrence = cohort_visit_occurrence.withColumn( + "visit_end_datetime", + f.when( + f.col("visit_end_datetime") > f.col("index_date"), + f.col("index_date") + ).otherwise(f.col("visit_end_datetime")) ).withColumn( "visit_end_date", - f.coalesce(f.col("index_date").cast(t.DateType()), f.col("visit_end_date")) - ).withColumn( - "visit_end_datetime", - f.coalesce(f.col("index_date"), f.col("visit_end_datetime")) - ).orderBy(f.col("order")).drop("order") + f.col("visit_end_datetime").cast(t.DateType()) + ) + cohort_member_visit_folder = os.path.join( + args.output_folder, args.cohort_name, "cohort_member_visit_occurrence" + ) + cohort_visit_occurrence.write.mode("overwrite").parquet( + cohort_member_visit_folder + ) + cohort_visit_occurrence = spark.read.parquet(cohort_member_visit_folder).drop("visit_rank") birthdate_udf = f.coalesce( "birth_datetime", @@ -195,7 +235,9 @@ def main(args): inpatient_att_type=AttType(args.inpatient_att_type), exclude_demographic=args.exclude_demographic, 
use_age_group=args.use_age_group, - include_inpatient_hour_token=args.include_inpatient_hour_token + include_inpatient_hour_token=args.include_inpatient_hour_token, + spark=spark, + persistence_folder=str(os.path.join(args.output_folder, args.cohort_name)), ) elif args.is_feature_concept_frequency: ehr_records = create_concept_frequency_data( @@ -251,12 +293,6 @@ def main(args): else: cohort.write.mode("overwrite").parquet(cohort_folder) - if os.path.exists(cohort_temp_folder): - shutil.rmtree(cohort_temp_folder) - - if os.path.exists(ehr_records_temp_folder): - shutil.rmtree(ehr_records_temp_folder) - spark.stop() diff --git a/src/cehrbert_data/tools/update_omop_visit.py b/src/cehrbert_data/tools/update_omop_visit.py new file mode 100644 index 0000000..4fc0f4b --- /dev/null +++ b/src/cehrbert_data/tools/update_omop_visit.py @@ -0,0 +1,60 @@ +import os +import argparse +import shutil + +from cehrbert_data.tools.ehrshot_to_omop import table_mapping, VOCABULARY_TABLES +from pyspark.sql import SparkSession +from pyspark.sql import functions as f + +def main(args): + spark = SparkSession.builder.appName("Clean up visit_occurrence").getOrCreate() + visit_mapping = spark.read.parquet( + os.path.join(args.output_folder, "visit_mapping") + ) + omop_table_name: str + for omop_table_name in table_mapping.keys(): + if omop_table_name not in ["visit_occurrence", "death"]: + omop_table = spark.read.parquet(os.path.join(args.input_folder, omop_table_name)) + omop_table.alias("domain").join( + visit_mapping.alias("visit"), + on=f.col("domain.visit_occurrence_id") == f.col("visit.visit_occurrence_id"), + how="left" + ).select( + [ + f.coalesce( + f.col("visit.master_visit_occurrence_id"), + f.col("domain.visit_occurrence_id") + ).alias("visit_occurrence_id") + ] + + [ + f.col(f"domain.{column}").alias(column) + for column in omop_table.columns if column != "visit_occurrence_id" + ] + ) + omop_table.write.mode("overwrite").parquet(os.path.join(args.output_folder, omop_table_name)) + + vocabulary_table: str + for vocabulary_table in VOCABULARY_TABLES + ["person"]: + if not os.path.exists(os.path.join(args.output_folder, vocabulary_table)): + shutil.copytree( + os.path.join(args.vocabulary_folder, vocabulary_table), + os.path.join(args.output_folder, vocabulary_table), + ) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Arguments for connecting OMOP visits in chronological order") + parser.add_argument( + "--input_folder", + dest="input_folder", + action="store", + required=True, + ) + parser.add_argument( + "--output_folder", + dest="output_folder", + action="store", + required=True, + ) + main( + parser.parse_args() + ) diff --git a/src/cehrbert_data/utils/spark_utils.py b/src/cehrbert_data/utils/spark_utils.py index 288e0b6..e404dae 100644 --- a/src/cehrbert_data/utils/spark_utils.py +++ b/src/cehrbert_data/utils/spark_utils.py @@ -174,7 +174,7 @@ def join_domain_tables(domain_tables: List[DataFrame]) -> DataFrame: filtered_domain_table["person_id"], filtered_domain_table[concept_id_field].alias("standard_concept_id"), filtered_domain_table["date"].cast("date"), - filtered_domain_table["datetime"], + filtered_domain_table["datetime"].cast(T.TimestampType()), filtered_domain_table["visit_occurrence_id"], F.lit(table_domain_field).alias("domain"), F.lit(None).cast("string").alias("event_group_id"), @@ -669,6 +669,8 @@ def create_sequence_data_with_att( exclude_demographic: bool = True, use_age_group: bool = False, include_inpatient_hour_token: bool = False, + spark: 
SparkSession = None, + persistence_folder: str = None, ): """ Create a sequence of the events associated with one patient in a chronological order. @@ -685,6 +687,8 @@ def create_sequence_data_with_att( :param exclude_demographic: :param use_age_group: :param include_inpatient_hour_token: + :param spark: SparkSession + :param persistence_folder: persistence folder for the temp data frames :return: """ @@ -692,7 +696,7 @@ def create_sequence_data_with_att( patient_events = patient_events.where(F.col("date").cast("date") >= date_filter) decorators = [ - ClinicalEventDecorator(visit_occurrence), + ClinicalEventDecorator(visit_occurrence, spark=spark, persistence_folder=persistence_folder), AttEventDecorator( visit_occurrence, include_visit_type, @@ -700,12 +704,21 @@ def create_sequence_data_with_att( att_type, inpatient_att_type, include_inpatient_hour_token, + spark=spark, + persistence_folder=persistence_folder ), - DeathEventDecorator(death, att_type), + DeathEventDecorator(death, att_type, spark=spark, persistence_folder=persistence_folder), ] if not exclude_demographic: - decorators.append(DemographicEventDecorator(patient_demographic, use_age_group)) + decorators.append( + DemographicEventDecorator( + patient_demographic, + use_age_group, + spark=spark, + persistence_folder=persistence_folder + ) + ) for decorator in decorators: patient_events = decorator.decorate(patient_events) diff --git a/tests/pyspark_test_base.py b/tests/pyspark_test_base.py index a4e908b..81e19fb 100644 --- a/tests/pyspark_test_base.py +++ b/tests/pyspark_test_base.py @@ -22,7 +22,17 @@ def setUpClass(cls): def setUp(self): from pyspark.sql import SparkSession + # The error InaccessibleObjectException: Unable to make private java.nio.DirectByteBuffer(long,int) accessible + # occurs because the PySpark code is trying to make a private constructor accessible, which is prohibited by + # the Java module system for security reasons. + # Starting from Java 9, JPMS enforces strong encapsulation of Java modules unless explicitly opened up. + # This means that unless the java.base module explicitly opens the java.nio package to your application, + # reflection on its classes and members will be blocked. + # Add JVM Options: If you must use a newer version of Java, you can try adding JVM options to open up the + # necessary modules. 
This is done by adding arguments to the spark.driver.extraJavaOptions and + # spark.executor.extraJavaOptions in your Spark configuration: self.spark = SparkSession.builder.master("local").appName("test").getOrCreate() + # Get the root folder of the project root_folder = Path(os.path.abspath(__file__)).parent.parent self.data_folder = os.path.join(root_folder, "sample_data", "omop_sample") diff --git a/tests/unit_tests/test_ehrshot_to_omop.py b/tests/unit_tests/test_ehrshot_to_omop.py index 6de39b1..4592f84 100644 --- a/tests/unit_tests/test_ehrshot_to_omop.py +++ b/tests/unit_tests/test_ehrshot_to_omop.py @@ -1,4 +1,5 @@ import unittest +import tempfile from datetime import datetime from pyspark.sql import SparkSession from pyspark.sql import functions as f @@ -65,10 +66,10 @@ def test_generate_visit_id(self): # Sample data with multiple events for each patient and different time gaps data = [ (1, datetime(2023, 1, 1, 8), datetime(2023, 1, 1, 9), 1, "visit_occurrence", None, None, None), - (1, datetime(2023, 1, 1, 20), datetime(2023, 1, 1, 20), None, "condition_occurrence", None, None, None), # 11-hour gap (merged visit) - (1, datetime(2023, 1, 2, 20), datetime(2023, 1, 2, 20), 2, "visit_occurrence", None, None, None), # another visit + (1, datetime(2023, 1, 2, 20), datetime(2023, 1, 2, 20), None, "condition_occurrence", None, None, None), # merge with the visit record below + (1, datetime(2023, 1, 2, 20), datetime(2023, 1, 2, 20), 2, "visit_occurrence", "Visit/IP", None, None), # another visit (2, datetime(2023, 1, 1, 8), datetime(2023, 1, 1, 9), 3, "visit_occurrence", None, None, None), - (2, datetime(2023, 1, 1, 10), datetime(2023, 1, 1, 11), None, "condition_occurrence", None, None, None), # same visit for patient 2 + (2, datetime(2023, 1, 1, 10), datetime(2023, 1, 1, 11), None, "condition_occurrence", None, None, None), (3, datetime(2023, 1, 1, 8), datetime(2023, 1, 1, 9), 1000, "visit_occurrence", None, None, None), (4, datetime(2023, 1, 1, 8), datetime(2023, 1, 1, 9), None, "condition_occurrence", None, None, None) ] @@ -76,11 +77,13 @@ def test_generate_visit_id(self): # Create DataFrame data = self.spark.createDataFrame(data, schema=schema) # Run the function to generate visit IDs - result_df = generate_visit_id(data) - result_df.show() + temp_dir = tempfile.mkdtemp() + result_df = generate_visit_id(data, self.spark, temp_dir) + result_df.orderBy("patient_id", "start").show() # Validate the number of visits - self.assertEqual(5, result_df.select("visit_id").where(f.col("visit_id").isNotNull()).distinct().count()) - self.assertEqual(8, result_df.count()) + self.assertEqual(6, result_df.select("visit_id").where(f.col("visit_id").isNotNull()).distinct().count()) + # Two artificial visits are created therefore it's 7 + 2 = 9 + self.assertEqual(9, result_df.count()) # Check that visit_id was generated as an integer (bigint) self.assertIn( @@ -98,7 +101,7 @@ def test_generate_visit_id(self): patient_2_visits = result_df.filter(f.col("patient_id") == 2).select("visit_id").distinct().count() self.assertEqual( - 1, patient_2_visits, "Patient 2 should have one visit as events are within time interval." + 2, patient_2_visits, "Patient 2 should have one visit as events are within time interval." 
) patient_3_visits = result_df.filter(f.col("patient_id") == 3).select("visit_id").distinct().count() @@ -107,8 +110,8 @@ def test_generate_visit_id(self): ) patient_4_visits = result_df.filter(f.col("patient_id") == 4).select("visit_id").collect()[0].visit_id - self.assertEqual( - 1001, patient_4_visits, "Patient 4 should have one generated visit_id." + self.assertTrue( + patient_4_visits > 1000, "Patient 4 should have one generated visit_id." ) def test_drop_duplicate_visits(self):
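Editor's note: the comment added to tests/pyspark_test_base.py above ends at "in your Spark configuration:" without showing the options themselves. A hedged sketch of what that configuration could look like; the exact --add-opens list depends on the JDK version and on which InaccessibleObjectException is raised:

from pyspark.sql import SparkSession

java_opens = (
    "--add-opens=java.base/java.nio=ALL-UNNAMED "
    "--add-opens=java.base/sun.nio.ch=ALL-UNNAMED"
)
spark = (
    SparkSession.builder
    .master("local")
    .appName("test")
    .config("spark.driver.extraJavaOptions", java_opens)
    .config("spark.executor.extraJavaOptions", java_opens)
    .getOrCreate()
)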