Skip to content

Commit

Permalink
added the token_priority module to centralize the token priorities
Browse files Browse the repository at this point in the history
  • Loading branch information
ChaoPang committed Oct 2, 2024
1 parent 7e6ab2d commit 6c06275
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 31 deletions.
3 changes: 3 additions & 0 deletions src/cehrbert_data/const/artificial_tokens.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Literal token strings for the artificial (non-clinical) events that the
# decorators write into the "standard_concept_id" column.
# NOTE(review): call sites previously used the bare literals "VS" / "VE";
# the bracketed values below change the emitted token text — confirm that
# downstream vocabularies/tokenizers expect the bracketed form.
VS_TOKEN = "[VS]"  # visit start marker
VE_TOKEN = "[VE]"  # visit end marker
DEATH_TOKEN = "[DEATH]"  # death event marker
40 changes: 28 additions & 12 deletions src/cehrbert_data/decorators/artificial_time_token_decorator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,23 @@
from pyspark.sql import DataFrame, functions as F, types as T, Window as W

from .patient_event_decorator_base import PatientEventDecorator, AttType, time_day_token, \
time_week_token, time_month_token, time_mix_token, time_token_func
from ..const.artificial_tokens import VS_TOKEN, VE_TOKEN
from .patient_event_decorator_base import (
PatientEventDecorator, AttType,
time_day_token,
time_week_token,
time_month_token,
time_mix_token,
time_token_func
)
from .token_priority import (
ATT_TOKEN_PRIORITY,
VS_TOKEN_PRIORITY,
VISIT_TYPE_TOKEN_PRIORITY,
DISCHARGE_TOKEN_PRIORITY,
VE_TOKEN_PRIORITY,
get_inpatient_token_priority,
get_inpatient_att_token_priority
)


class AttEventDecorator(PatientEventDecorator):
Expand Down Expand Up @@ -82,10 +98,10 @@ def _decorate(self, patient_events: DataFrame):
visit_start_events = (
visits.withColumn("date", F.col("visit_start_date"))
.withColumn("datetime", F.to_timestamp("visit_start_date"))
.withColumn("standard_concept_id", F.lit("VS"))
.withColumn("standard_concept_id", F.lit(VS_TOKEN))
.withColumn("visit_concept_order", F.col("min_visit_concept_order"))
.withColumn("concept_order", F.col("min_concept_order") - 1)
.withColumn("priority", F.lit(-2))
.withColumn("priority", F.lit(VS_TOKEN_PRIORITY))
.withColumn("unit", F.lit(None).cast("string"))
.drop("min_visit_concept_order", "max_visit_concept_order")
.drop("min_concept_order", "max_concept_order")
Expand All @@ -95,10 +111,10 @@ def _decorate(self, patient_events: DataFrame):
visits.withColumn("date", F.col("visit_end_date"))
.withColumn("datetime", F.date_add(F.to_timestamp("visit_end_date"), 1))
.withColumn("datetime", F.expr("datetime - INTERVAL 1 MINUTE"))
.withColumn("standard_concept_id", F.lit("VE"))
.withColumn("standard_concept_id", F.lit(VE_TOKEN))
.withColumn("visit_concept_order", F.col("max_visit_concept_order"))
.withColumn("concept_order", F.col("max_concept_order") + 1)
.withColumn("priority", F.lit(200))
.withColumn("priority", F.lit(VE_TOKEN_PRIORITY))
.withColumn("unit", F.lit(None).cast("string"))
.drop("min_visit_concept_order", "max_visit_concept_order")
.drop("min_concept_order", "max_concept_order")
Expand Down Expand Up @@ -138,7 +154,7 @@ def _decorate(self, patient_events: DataFrame):
F.when(F.col("time_delta") < 0, F.lit(0)).otherwise(F.col("time_delta")),
)
.withColumn("standard_concept_id", time_token_udf("time_delta"))
.withColumn("priority", F.lit(-3))
.withColumn("priority", F.lit(ATT_TOKEN_PRIORITY))
.withColumn("visit_rank_order", F.col("visit_rank_order"))
.withColumn("visit_concept_order", F.col("min_visit_concept_order"))
.withColumn("concept_order", F.lit(0))
Expand All @@ -160,7 +176,7 @@ def _decorate(self, patient_events: DataFrame):
.withColumn("datetime", F.to_timestamp("date"))
.withColumn("visit_concept_order", F.col("min_visit_concept_order"))
.withColumn("concept_order", F.lit(0))
.withColumn("priority", F.lit(-1))
.withColumn("priority", F.lit(VISIT_TYPE_TOKEN_PRIORITY))
.withColumn("unit", F.lit(None).cast("string"))
.drop("min_visit_concept_order", "max_visit_concept_order")
.drop("min_concept_order", "max_concept_order")
Expand Down Expand Up @@ -194,7 +210,7 @@ def _decorate(self, patient_events: DataFrame):
F.when(F.col("date") > F.col("visit_end_date"), F.col("visit_end_date")).otherwise(F.col("date"))
),
)
.withColumn("priority", F.col("priority") + F.col("concept_order") * 0.1)
.withColumn("priority", get_inpatient_token_priority())
.drop("visit_end_date")
)

Expand All @@ -209,7 +225,7 @@ def _decorate(self, patient_events: DataFrame):
.withColumn("date", F.col("visit_end_date"))
.withColumn("datetime", F.date_add(F.to_timestamp("visit_end_date"), 1))
.withColumn("datetime", F.expr("datetime - INTERVAL 1 MINUTE"))
.withColumn("priority", F.lit(100))
.withColumn("priority", F.lit(DISCHARGE_TOKEN_PRIORITY))
.withColumn("unit", F.lit(None).cast("string"))
.drop("discharged_to_concept_id", "visit_end_date")
.drop("min_visit_concept_order", "max_visit_concept_order")
Expand Down Expand Up @@ -256,7 +272,7 @@ def _decorate(self, patient_events: DataFrame):
.where(F.col("hour_delta") > 0)
.withColumn("standard_concept_id", inpatient_att_token)
.withColumn("visit_concept_order", F.col("visit_concept_order"))
.withColumn("priority", F.col("priority") - 0.01)
.withColumn("priority", get_inpatient_att_token_priority())
.withColumn("concept_value_mask", F.lit(0))
.withColumn("concept_value", F.lit(0.0))
.withColumn("unit", F.lit(None).cast("string"))
Expand All @@ -282,7 +298,7 @@ def _decorate(self, patient_events: DataFrame):
F.concat(F.lit("i-"), time_token_udf("time_delta")),
)
.withColumn("visit_concept_order", F.col("visit_concept_order"))
.withColumn("priority", F.col("priority") - 0.01)
.withColumn("priority", get_inpatient_att_token_priority())
.withColumn("concept_value_mask", F.lit(0))
.withColumn("concept_value", F.lit(0.0))
.withColumn("unit", F.lit(None).cast("string"))
Expand Down
13 changes: 8 additions & 5 deletions src/cehrbert_data/decorators/clinical_event_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pyspark.sql import DataFrame, functions as F, Window as W, types as T

from .patient_event_decorator_base import PatientEventDecorator
from .token_priority import DEFAULT_PRIORITY


class ClinicalEventDecorator(PatientEventDecorator):
Expand Down Expand Up @@ -113,11 +114,13 @@ def _decorate(self, patient_events: DataFrame):
.distinct()
)

# Set the priority for the events.
# Create the week since epoch UDF
weeks_since_epoch_udf = (F.unix_timestamp("date") / F.lit(24 * 60 * 60 * 7)).cast("int")
patient_events = patient_events.withColumn("priority", F.lit(0)).withColumn(
"date_in_week", weeks_since_epoch_udf
# Set the priority for the events. Create the week since epoch UDF
patient_events = (
patient_events
.withColumn("priority", F.lit(DEFAULT_PRIORITY))
.withColumn(
"date_in_week", (F.unix_timestamp("date") / F.lit(24 * 60 * 60 * 7)).cast("int")
)
)

# Create the concept_value_mask field to indicate whether domain values should be skipped
Expand Down
30 changes: 20 additions & 10 deletions src/cehrbert_data/decorators/death_event_decorator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
from pyspark.sql import DataFrame, functions as F, Window as W, types as T

from .patient_event_decorator_base import PatientEventDecorator, AttType, time_day_token, \
time_week_token, time_month_token, time_mix_token, time_token_func
from ..const.artificial_tokens import VS_TOKEN, VE_TOKEN, DEATH_TOKEN
from .patient_event_decorator_base import (
PatientEventDecorator,
AttType,
time_day_token,
time_week_token,
time_month_token,
time_mix_token,
time_token_func
)
from .token_priority import VS_TOKEN_PRIORITY, VE_TOKEN_PRIORITY, ATT_TOKEN_PRIORITY, DEATH_TOKEN_PRIORITY


class DeathEventDecorator(PatientEventDecorator):
Expand All @@ -18,7 +27,7 @@ def _decorate(self, patient_events: DataFrame):
max_visit_occurrence_id = death_records.select(F.max("visit_occurrence_id").alias("max_visit_occurrence_id"))

last_ve_record = (
death_records.where(F.col("standard_concept_id") == "VE")
death_records.where(F.col("standard_concept_id") == VE_TOKEN)
.withColumn(
"record_rank",
F.row_number().over(W.partitionBy("person_id", "cohort_member_id").orderBy(F.desc("date"))),
Expand All @@ -37,27 +46,28 @@ def _decorate(self, patient_events: DataFrame):
artificial_visit_id = F.row_number().over(
W.partitionBy(F.lit(0)).orderBy("person_id", "cohort_member_id")
) + F.col("max_visit_occurrence_id")

death_records = (
last_ve_record.crossJoin(max_visit_occurrence_id)
.withColumn("visit_occurrence_id", artificial_visit_id)
.withColumn("standard_concept_id", F.lit("[DEATH]"))
.withColumn("standard_concept_id", F.lit(DEATH_TOKEN))
.withColumn("domain", F.lit("death"))
.withColumn("visit_rank_order", F.lit(1) + F.col("visit_rank_order"))
.withColumn("priority", F.lit(20))
.withColumn("priority", DEATH_TOKEN_PRIORITY)
.drop("max_visit_occurrence_id")
)

vs_records = (
death_records
.withColumn("standard_concept_id", F.lit("VS"))
.withColumn("priority", F.lit(15))
.withColumn("standard_concept_id", F.lit(VS_TOKEN))
.withColumn("priority", VS_TOKEN_PRIORITY)
.withColumn("unit", F.lit(None).cast("string"))
)

ve_records = (
death_records
.withColumn("standard_concept_id", F.lit("VE"))
.withColumn("priority", F.lit(30))
.withColumn("standard_concept_id", F.lit(VE_TOKEN))
.withColumn("priority", VE_TOKEN_PRIORITY)
.withColumn("unit", F.lit(None).cast("string"))
)

Expand All @@ -82,7 +92,7 @@ def _decorate(self, patient_events: DataFrame):
death_events = (
death_events.withColumn("time_delta", F.datediff("death_date", "date"))
.withColumn("standard_concept_id", time_token_udf("time_delta"))
.withColumn("priority", F.lit(10))
.withColumn("priority", F.lit(ATT_TOKEN_PRIORITY))
.withColumn("unit", F.lit(None).cast("string"))
.drop("time_delta")
)
Expand Down
15 changes: 11 additions & 4 deletions src/cehrbert_data/decorators/demographic_event_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@

from .patient_event_decorator_base import PatientEventDecorator

from .token_priority import (
YEAR_TOKEN_PRIORITY,
AGE_TOKEN_PRIORITY,
GENDER_TOKEN_PRIORITY,
RACE_TOKEN_PRIORITY
)


class DemographicEventDecorator(PatientEventDecorator):
def __init__(self, patient_demographic, use_age_group: bool = False):
Expand Down Expand Up @@ -43,7 +50,7 @@ def _decorate(self, patient_events: DataFrame):
"standard_concept_id",
F.concat(F.lit("year:"), F.year("date").cast(T.StringType())),
)
.withColumn("priority", F.lit(-10))
.withColumn("priority", F.lit(YEAR_TOKEN_PRIORITY))
.withColumn("visit_segment", F.lit(0))
.withColumn("date_in_week", F.lit(0))
.withColumn("age", F.lit(-1))
Expand Down Expand Up @@ -74,23 +81,23 @@ def _decorate(self, patient_events: DataFrame):
self._patient_demographic.select(F.col("person_id"), F.col("birth_datetime"))
.join(sequence_start_year_token, "person_id")
.withColumn("standard_concept_id", age_at_first_visit_udf)
.withColumn("priority", F.lit(-9))
.withColumn("priority", F.lit(AGE_TOKEN_PRIORITY))
.drop("birth_datetime")
)

sequence_gender_token = (
self._patient_demographic.select(F.col("person_id"), F.col("gender_concept_id"))
.join(sequence_start_year_token, "person_id")
.withColumn("standard_concept_id", F.col("gender_concept_id").cast(T.StringType()))
.withColumn("priority", F.lit(-8))
.withColumn("priority", F.lit(GENDER_TOKEN_PRIORITY))
.drop("gender_concept_id")
)

sequence_race_token = (
self._patient_demographic.select(F.col("person_id"), F.col("race_concept_id"))
.join(sequence_start_year_token, "person_id")
.withColumn("standard_concept_id", F.col("race_concept_id").cast(T.StringType()))
.withColumn("priority", F.lit(-7))
.withColumn("priority", F.lit(RACE_TOKEN_PRIORITY))
.drop("race_concept_id")
)

Expand Down
22 changes: 22 additions & 0 deletions src/cehrbert_data/decorators/token_priority.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from pyspark.sql import functions
from pyspark.sql import Column

# Values written to the "priority" column by the patient event decorators,
# kept in one place so the relative ordering of tokens is visible at a glance.
# NOTE(review): lower values presumably sort earlier within a visit/timestamp —
# confirm against the code that orders events by "priority".

# Demographic prompt tokens (DemographicEventDecorator).
YEAR_TOKEN_PRIORITY = -10
AGE_TOKEN_PRIORITY = -9
GENDER_TOKEN_PRIORITY = -8
RACE_TOKEN_PRIORITY = -7
# Artificial time token, visit start and visit type tokens.
ATT_TOKEN_PRIORITY = -3
VS_TOKEN_PRIORITY = -2
VISIT_TYPE_TOKEN_PRIORITY = -1
# Regular clinical events (ClinicalEventDecorator).
DEFAULT_PRIORITY = 0
# Tokens emitted at the end of a visit; DEATH sits just below VE.
DISCHARGE_TOKEN_PRIORITY = 100
DEATH_TOKEN_PRIORITY = 199
VE_TOKEN_PRIORITY = 200


def get_inpatient_token_priority() -> Column:
    """Build the priority expression for events inside an inpatient visit.

    Returns:
        A Column equal to the existing "priority" column plus one tenth of
        the "concept_order" column. The expression is unresolved: both
        columns must exist on the DataFrame it is applied to.
    """
    base_priority = functions.col("priority")
    # Scale concept_order by 0.1 before adding it to the base priority.
    order_offset = functions.col("concept_order") * 0.1
    return base_priority + order_offset


def get_inpatient_att_token_priority() -> Column:
    """Build the priority expression for inpatient artificial time tokens.

    Returns:
        A Column equal to the existing "priority" column minus 0.01. The
        expression is unresolved: the "priority" column must exist on the
        DataFrame it is applied to.
    """
    current_priority = functions.col("priority")
    # NOTE(review): the small negative offset presumably places the time
    # token just ahead of the event it is derived from — confirm with caller.
    return current_priority - 0.01

0 comments on commit 6c06275

Please sign in to comment.