Skip to content

Commit

Permalink
added an option to cache the clinical events generated along the way
Browse files Browse the repository at this point in the history
  • Loading branch information
ChaoPang committed Feb 3, 2025
1 parent ebcef66 commit 210d5a9
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/cehrbert_data/cohorts/spark_app_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ def __init__(
single_contribution: bool = False,
exclude_features: bool = True,
meds_format: bool = False,
cache_events: bool = False,
):
self._cohort_name = cohort_name
self._input_folder = input_folder
Expand Down Expand Up @@ -361,6 +362,7 @@ def __init__(
self._single_contribution = single_contribution
self._exclude_features = exclude_features
self._meds_format = meds_format
self._cache_events = cache_events

self.get_logger().info(
f"cohort_name: {cohort_name}\n"
Expand Down Expand Up @@ -400,6 +402,7 @@ def __init__(
f"single_contribution: {single_contribution}\n"
f"extract_features: {exclude_features}\n"
f"meds_format: {meds_format}\n"
f"cache_events: {cache_events}\n"
)

self.spark = SparkSession.builder.appName(f"Generate {self._cohort_name}").getOrCreate()
Expand Down Expand Up @@ -697,8 +700,8 @@ def extract_ehr_records_for_cohort(self, cohort: DataFrame):
exclude_demographic=self._exclude_demographic,
use_age_group=self._use_age_group,
include_inpatient_hour_token=self._include_inpatient_hour_token,
spark=self.spark,
persistence_folder=self._output_data_folder,
spark=self.spark if self._cache_events else None,
persistence_folder=self._output_data_folder if self._cache_events else None,
)

return create_sequence_data(
Expand Down
6 changes: 6 additions & 0 deletions src/cehrbert_data/utils/spark_parse_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,4 +403,10 @@ def create_spark_args(parse: bool = True):
action="store_true",
help="Indicate whether we want to generate the cohorts in the MEDS format",
)
parser.add_argument(
"--cache_events",
dest="cache_events",
action="store_true",
help="Indicate whether we want to cache all patient events including ATT in the local folder",
)
return parser.parse_args() if parse else parser

0 comments on commit 210d5a9

Please sign in to comment.