Skip to content

Commit

Permalink
Added event_group_id to the patient_events dataframe, the event_group…
Browse files Browse the repository at this point in the history
…_id is constructed using the format "domain-domain_id"; event_group_id is needed to group a measurement and its value_as_concept_id together
  • Loading branch information
ChaoPang committed Oct 2, 2024
1 parent 6c06275 commit 20ff415
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def _decorate(self, patient_events: DataFrame):
.withColumn("concept_order", F.col("min_concept_order") - 1)
.withColumn("priority", F.lit(VS_TOKEN_PRIORITY))
.withColumn("unit", F.lit(None).cast("string"))
.withColumn("event_group_id", F.lit("N/A"))
.drop("min_visit_concept_order", "max_visit_concept_order")
.drop("min_concept_order", "max_concept_order")
)
Expand All @@ -116,6 +117,7 @@ def _decorate(self, patient_events: DataFrame):
.withColumn("concept_order", F.col("max_concept_order") + 1)
.withColumn("priority", F.lit(VE_TOKEN_PRIORITY))
.withColumn("unit", F.lit(None).cast("string"))
.withColumn("event_group_id", F.lit("N/A"))
.drop("min_visit_concept_order", "max_visit_concept_order")
.drop("min_concept_order", "max_concept_order")
)
Expand Down Expand Up @@ -159,6 +161,7 @@ def _decorate(self, patient_events: DataFrame):
.withColumn("visit_concept_order", F.col("min_visit_concept_order"))
.withColumn("concept_order", F.lit(0))
.withColumn("unit", F.lit(None).cast("string"))
.withColumn("event_group_id", F.lit("N/A"))
.drop("prev_visit_end_date", "time_delta")
.drop("min_visit_concept_order", "max_visit_concept_order")
.drop("min_concept_order", "max_concept_order")
Expand All @@ -178,6 +181,7 @@ def _decorate(self, patient_events: DataFrame):
.withColumn("concept_order", F.lit(0))
.withColumn("priority", F.lit(VISIT_TYPE_TOKEN_PRIORITY))
.withColumn("unit", F.lit(None).cast("string"))
.withColumn("event_group_id", F.lit("N/A"))
.drop("min_visit_concept_order", "max_visit_concept_order")
.drop("min_concept_order", "max_concept_order")
)
Expand Down Expand Up @@ -227,6 +231,7 @@ def _decorate(self, patient_events: DataFrame):
.withColumn("datetime", F.expr("datetime - INTERVAL 1 MINUTE"))
.withColumn("priority", F.lit(DISCHARGE_TOKEN_PRIORITY))
.withColumn("unit", F.lit(None).cast("string"))
.withColumn("event_group_id", F.lit("N/A"))
.drop("discharged_to_concept_id", "visit_end_date")
.drop("min_visit_concept_order", "max_visit_concept_order")
.drop("min_concept_order", "max_concept_order")
Expand Down Expand Up @@ -276,6 +281,7 @@ def _decorate(self, patient_events: DataFrame):
.withColumn("concept_value_mask", F.lit(0))
.withColumn("concept_value", F.lit(0.0))
.withColumn("unit", F.lit(None).cast("string"))
.withColumn("event_group_id", F.lit("N/A"))
.drop("prev_date", "time_delta", "is_span_boundary")
.drop("prev_datetime", "hour_delta")
)
Expand All @@ -302,6 +308,7 @@ def _decorate(self, patient_events: DataFrame):
.withColumn("concept_value_mask", F.lit(0))
.withColumn("concept_value", F.lit(0.0))
.withColumn("unit", F.lit(None).cast("string"))
.withColumn("event_group_id", F.lit("N/A"))
.drop("prev_date", "time_delta", "is_span_boundary")
)

Expand Down
4 changes: 4 additions & 0 deletions src/cehrbert_data/decorators/death_event_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def _decorate(self, patient_events: DataFrame):
.withColumn("domain", F.lit("death"))
.withColumn("visit_rank_order", F.lit(1) + F.col("visit_rank_order"))
.withColumn("priority", DEATH_TOKEN_PRIORITY)
.withColumn("event_group_id", F.lit("N/A"))
.drop("max_visit_occurrence_id")
)

Expand All @@ -62,13 +63,15 @@ def _decorate(self, patient_events: DataFrame):
.withColumn("standard_concept_id", F.lit(VS_TOKEN))
.withColumn("priority", VS_TOKEN_PRIORITY)
.withColumn("unit", F.lit(None).cast("string"))
.withColumn("event_group_id", F.lit("N/A"))
)

ve_records = (
death_records
.withColumn("standard_concept_id", F.lit(VE_TOKEN))
.withColumn("priority", VE_TOKEN_PRIORITY)
.withColumn("unit", F.lit(None).cast("string"))
.withColumn("event_group_id", F.lit("N/A"))
)

# Udf for calculating the time token
Expand All @@ -94,6 +97,7 @@ def _decorate(self, patient_events: DataFrame):
.withColumn("standard_concept_id", time_token_udf("time_delta"))
.withColumn("priority", F.lit(ATT_TOKEN_PRIORITY))
.withColumn("unit", F.lit(None).cast("string"))
.withColumn("event_group_id", F.lit("N/A"))
.drop("time_delta")
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def _decorate(self, patient_events: DataFrame):
.withColumn("concept_value_mask", F.lit(0))
.withColumn("concept_value", F.lit(0.0))
.withColumn("unit", F.lit(None).cast("string"))
.withColumn("event_group_id", F.lit("N/A"))
.where("token_order = 1")
.drop("token_order")
)
Expand Down
3 changes: 2 additions & 1 deletion src/cehrbert_data/decorators/patient_event_decorator_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ def get_required_columns(cls) -> Set[str]:
"visit_start_date",
"visit_start_datetime",
"visit_concept_order",
"concept_order"
"concept_order",
"event_group_id"
}

def validate(self, patient_events: DataFrame):
Expand Down
72 changes: 47 additions & 25 deletions src/cehrbert_data/utils/spark_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,27 +37,34 @@
"condition_concept_id",
"condition_start_date",
"condition_start_datetime",
"condition",
"condition"
)
],
"procedure_occurrence_id": [
(
"procedure_concept_id",
"procedure_date",
"procedure_datetime",
"procedure"
)
],
"procedure_occurrence_id": [("procedure_concept_id", "procedure_date", "procedure_datetime", "procedure")],
"drug_exposure_id": [
(
"drug_concept_id",
"drug_exposure_start_date",
"drug_exposure_start_datetime",
"drug",
"drug"
)
],
"measurement_id": [
(
"measurement_concept_id",
"measurement_date",
"measurement_datetime",
"measurement",
"measurement"
)
],
"death_date": [("person_id", "death_date", "death_datetime", "death")],
"death_date": [("death_concept_id", "death_date", "death_datetime", "death")],
"visit_concept_id": [
("visit_concept_id", "visit_start_date", "visit"),
("discharged_to_concept_id", "visit_end_date", "visit"),
Expand All @@ -67,7 +74,7 @@
LOGGER = logging.getLogger(__name__)


def get_key_fields(domain_table) -> List[Tuple[str, str, str, str]]:
def get_key_fields(domain_table: DataFrame) -> List[Tuple[str, str, str, str]]:
field_names = domain_table.schema.fieldNames()
for k, v in DOMAIN_KEY_FIELDS.items():
if k in field_names:
Expand All @@ -77,7 +84,7 @@ def get_key_fields(domain_table) -> List[Tuple[str, str, str, str]]:
get_concept_id_field(domain_table),
get_domain_date_field(domain_table),
get_domain_datetime_field(domain_table),
get_domain_field(domain_table),
get_domain_field(domain_table)
)
]

Expand All @@ -89,6 +96,17 @@ def domain_has_unit(domain_table: DataFrame) -> bool:
return False


def get_domain_id_field(domain_table: DataFrame) -> str:
    """Return the name of the domain table's record-identifier column.

    The identifier is the first column, in schema order, whose name ends with
    ``_id`` but not with ``_concept_id``.

    NOTE(review): this relies on the record's own id column preceding other
    ``*_id`` columns such as ``person_id`` in the schema -- confirm this holds
    for every table passed in.

    Args:
        domain_table: table whose ``schema.fieldNames()`` is inspected.

    Returns:
        The name of the first matching column.

    Raises:
        ValueError: if no column name ends with ``_id`` without also ending
            with ``_concept_id``.
    """
    table_fields = domain_table.schema.fieldNames()
    for field_name in table_fields:
        # Concept columns (e.g. ``*_concept_id``) also end with ``_id`` and
        # must not be mistaken for the record identifier.
        if field_name.endswith("_id") and not field_name.endswith("_concept_id"):
            return field_name
    raise ValueError(f"{domain_table} does not have a valid id column: {table_fields}")


def get_domain_date_field(domain_table: DataFrame) -> str:
    """Return the name of the first column whose name contains ``"date"``.

    Columns are scanned in schema order and the match is a plain substring
    test, so a ``*_datetime`` column is returned if it appears before any
    ``*_date`` column.

    Args:
        domain_table: table whose ``schema.fieldNames()`` is inspected.

    Returns:
        The name of the first column containing ``"date"``.

    Raises:
        IndexError: if no column name contains ``"date"``. The exception type
            is the one the previous ``[...][0]`` lookup raised; the message
            now lists the available fields.
    """
    field_names = domain_table.schema.fieldNames()
    for field_name in field_names:
        if "date" in field_name:
            return field_name
    raise IndexError(f"No date column found among the domain table fields: {field_names}")
Expand All @@ -115,9 +133,8 @@ def create_file_path(input_folder: str, table_name: str):
return file_path


def join_domain_tables(domain_tables):
def join_domain_tables(domain_tables: List[DataFrame]) -> DataFrame:
"""Standardize the format of OMOP domain tables using a time frame.
Keyword arguments:
domain_tables -- the array containing the OMOP domain tables except visit_occurrence
except measurement
Expand All @@ -136,39 +153,42 @@ def join_domain_tables(domain_tables):
concept_id_field,
date_field,
datetime_field,
table_domain_field,
table_domain_field
) in get_key_fields(domain_table):
domain_id_field = get_domain_id_field(domain_table)
# Remove records that don't have a date or standard_concept_id
sub_domain_table = domain_table.where(F.col(date_field).isNotNull()).where(
filtered_domain_table = domain_table.where(F.col(date_field).isNotNull()).where(
F.col(concept_id_field).isNotNull()
)
datetime_field_udf = F.to_timestamp(F.coalesce(datetime_field, date_field), "yyyy-MM-dd HH:mm:ss")
sub_domain_table = (
sub_domain_table.where(F.col(concept_id_field).cast("string") != "0")
filtered_domain_table = (
filtered_domain_table.where(F.col(concept_id_field).cast("string") != "0")
.withColumn("date", F.to_date(F.col(date_field)))
.withColumn("datetime", datetime_field_udf)
.withColumn("domain_id", F.col(domain_id_field).cast("string"))
)

unit_udf = F.col("unit") if domain_has_unit(sub_domain_table) else F.lit(None).cast("string")
sub_domain_table = sub_domain_table.select(
sub_domain_table["person_id"],
sub_domain_table[concept_id_field].alias("standard_concept_id"),
sub_domain_table["date"].cast("date"),
sub_domain_table["datetime"],
sub_domain_table["visit_occurrence_id"],
unit_udf = F.col("unit") if domain_has_unit(filtered_domain_table) else F.lit(None).cast("string")
filtered_domain_table = filtered_domain_table.select(
filtered_domain_table["person_id"],
filtered_domain_table[concept_id_field].alias("standard_concept_id"),
filtered_domain_table["date"].cast("date"),
filtered_domain_table["datetime"],
filtered_domain_table["visit_occurrence_id"],
F.lit(table_domain_field).alias("domain"),
F.concat(F.lit(table_domain_field), F.lit("-"), F.col("domain_id")).alias("event_group_id"),
F.lit(-1).alias("concept_value"),
unit_udf.alias("unit"),
).distinct()

# Remove "Patient Died" from condition_occurrence
if sub_domain_table == "condition_occurrence":
sub_domain_table = sub_domain_table.where("condition_concept_id != 4216643")
if filtered_domain_table == "condition_occurrence":
filtered_domain_table = filtered_domain_table.where("condition_concept_id != 4216643")

if patient_event is None:
patient_event = sub_domain_table
patient_event = filtered_domain_table
else:
patient_event = patient_event.union(sub_domain_table)
patient_event = patient_event.union(filtered_domain_table)

return patient_event

Expand Down Expand Up @@ -692,6 +712,7 @@ def create_sequence_data_with_att(
"concept_order",
"priority",
"datetime",
"event_group_id",
"standard_concept_id",
)
)
Expand Down Expand Up @@ -891,6 +912,7 @@ def extract_ehr_records(
patient_ehr_records["visit_occurrence_id"],
patient_ehr_records["domain"],
patient_ehr_records["unit"],
patient_ehr_records["event_group_id"],
visit_occurrence["visit_concept_id"],
patient_ehr_records["age"],
)
Expand Down Expand Up @@ -1381,7 +1403,7 @@ def process_measurement(
m.person_id,
CASE
WHEN value_as_concept_id IS NOT NULL AND value_as_concept_id <> 0
THEN CONCAT(CAST(measurement_concept_id AS STRING), '-', CAST(value_as_concept_id AS STRING))
THEN CONCAT(CAST(measurement_concept_id AS STRING), '-', CAST(COALESCE(value_as_concept_id, 0) AS STRING))
ELSE CAST(measurement_concept_id AS STRING)
END AS standard_concept_id,
CAST(m.measurement_date AS DATE) AS date,
Expand Down

0 comments on commit 20ff415

Please sign in to comment.