diff --git a/src/cehrbert_data/tools/extract_features.py b/src/cehrbert_data/tools/extract_features.py index af41741..304a947 100644 --- a/src/cehrbert_data/tools/extract_features.py +++ b/src/cehrbert_data/tools/extract_features.py @@ -153,15 +153,37 @@ def main(args): ).where( f.col("cohort.index_date").between(f.col("visit.visit_start_datetime"), f.col("visit.visit_end_datetime")) ).select( + f.col("visit.person_id").alias("person_id"), f.col("visit.cohort_member_id").alias("cohort_member_id"), f.col("visit.visit_occurrence_id").alias("visit_occurrence_id"), + f.col("visit.visit_concept_id").alias("visit_concept_id"), f.col("cohort.index_date").alias("index_date"), ) + # Add an artificial token for the visit in which the prediction is made + ehr_records = ehr_records.unionByName( + visit_index_date.select( + "person_id", + "cohort_member_id", + "index_date", + "visit_occurrence_id", + f.lit("[END]").alias("standard_concept_id"), + f.col("index_date").cast(t.DateType()).alias("date"), + f.expr("index_date - INTERVAL 1 MINUTE").alias("datetime"), + f.lit("unknown").alias("domain"), + f.lit(None).cast(t.StringType()).alias("unit"), + f.lit(None).cast(t.FloatType()).alias("number_as_value"), + f.lit(None).cast(t.StringType()).alias("concept_as_value"), + f.lit(None).cast(t.StringType()).alias("event_group_id"), + "visit_concept_id", + f.lit(-1).alias("age") + ) + ) + # Bound the visit_end_date and visit_end_datetime cohort_visit_occurrence = cohort_visit_occurrence.join( - visit_index_date, - ["visit_occurrence_id", "cohort_member_id"], + visit_index_date.drop("visit_concept_id"), + ["visit_occurrence_id", "cohort_member_id", "person_id"], "left_outer", ).withColumn( "visit_end_date",