Skip to content

Commit

Permalink
enabled the continuous job to build the prediction cohorts based on the existing cohorts (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
ChaoPang authored Oct 2, 2024
1 parent 5697bd6 commit 7e6ab2d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
16 changes: 16 additions & 0 deletions src/cehrbert_data/cohorts/spark_app_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def __init__(
age_upper_bound: int,
prior_observation_period: int,
post_observation_period: int,
continue_job: bool = False
):

self._query_builder = query_builder
Expand All @@ -110,6 +111,7 @@ def __init__(
self._post_observation_period = post_observation_period
cohort_name = re.sub("[^a-z0-9]+", "_", self._query_builder.get_cohort_name().lower())
self._output_data_folder = os.path.join(self._output_folder, cohort_name)
self._continue_job = continue_job

self.get_logger().info(
f"query_builder: {query_builder}\n"
Expand All @@ -121,6 +123,7 @@ def __init__(
f"age_upper_bound: {age_upper_bound}\n"
f"prior_observation_period: {prior_observation_period}\n"
f"post_observation_period: {post_observation_period}\n"
f"continue_job: {continue_job}\n"
)

# Validate the age range, observation_window and prediction_window
Expand Down Expand Up @@ -187,6 +190,11 @@ def create_cohort(self):

def build(self):
"""Build the cohort and write the dataframe as parquet files to _output_data_folder."""

# Check whether the cohort has been generated
if self._continue_job and self.cohort_exists():
return self

cohort = self.create_cohort()

cohort = self._apply_observation_period(cohort)
Expand All @@ -201,6 +209,13 @@ def build(self):

return self

def cohort_exists(self) -> bool:
    """Best-effort check for whether this cohort's parquet output was already written.

    Attempts to read the cohort from ``self._output_data_folder``; a successful
    read means a previous run produced it. Any read failure (missing path,
    unreadable files) is treated as "not generated yet" rather than an error,
    since this is only used to decide whether a continued job can skip work.
    """
    try:
        self.load_cohort()
    except Exception:
        # Deliberately broad: any failure to load means we must (re)build.
        return False
    return True

def load_cohort(self):
    """Read and return the cohort dataframe from the parquet output folder."""
    reader = self.spark.read
    return reader.parquet(self._output_data_folder)

Expand Down Expand Up @@ -678,6 +693,7 @@ def create_prediction_cohort(
age_upper_bound=spark_args.age_upper_bound,
prior_observation_period=prior_observation_period,
post_observation_period=post_observation_period,
continue_job=spark_args.continue_job
)
.build()
.load_cohort()
Expand Down
6 changes: 6 additions & 0 deletions src/cehrbert_data/utils/spark_parse_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,12 @@ def create_spark_args():
help="The path for your output_folder",
required=True,
)
parser.add_argument(
"--continue_job",
dest="continue_job",
action="store_true",
help="If set, the job continues from a previous run"
)
parser.add_argument(
"--ehr_table_list",
dest="ehr_table_list",
Expand Down

0 comments on commit 7e6ab2d

Please sign in to comment.