Skip to content

Commit

Permalink
Fixed the bug where the cohort could not be joined to the patient splits data when the meds format is enabled
Browse files Browse the repository at this point in the history
  • Loading branch information
ChaoPang committed Feb 25, 2025
1 parent 648c114 commit 68ea04e
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions src/cehrbert_data/cohorts/spark_app_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,9 +570,16 @@ def build(self):
# If patient_splits is provided, we will join the cohort to it so the cohort can be filtered by split
if self._patient_splits_folder:
patient_splits = self.spark.read.parquet(self._patient_splits_folder)
cohort.join(patient_splits, "person_id").orderBy("person_id", "cohort_member_id").write.mode(
cohort.join(
patient_splits,
cohort[person_id_column] == patient_splits.person_id
).select(
[cohort[c] for c in cohort.columns] + [patient_splits.split]
).orderBy(person_id_column, index_date_column).write.mode(
"overwrite"
).parquet(os.path.join(self._output_data_folder, "temp"))
).parquet(
os.path.join(self._output_data_folder, "temp")
)
# Reload the data from the disk
cohort = self.spark.read.parquet(os.path.join(self._output_data_folder, "temp"))
cohort.where('split="train"').write.mode("overwrite").parquet(
Expand Down

0 comments on commit 68ea04e

Please sign in to comment.