Skip to content

Commit

Permalink
added the option to add include_inpatient_hour_token
Browse files Browse the repository at this point in the history
  • Loading branch information
ChaoPang committed Jan 31, 2025
1 parent 4cbf8fb commit 6e49d04
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/cehrbert_data/cohorts/spark_app_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from abc import ABC
from typing import List

from numpy.random import permutation
from pandas import to_datetime
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql import functions as F
Expand Down Expand Up @@ -309,6 +310,7 @@ def __init__(
is_population_estimation: bool = False,
att_type: AttType = AttType.CEHR_BERT,
inpatient_att_type: AttType = AttType.MIX,
include_inpatient_hour_token: bool = False,
exclude_demographic: bool = True,
use_age_group: bool = False,
single_contribution: bool = False,
Expand Down Expand Up @@ -353,6 +355,7 @@ def __init__(
self._is_population_estimation = is_population_estimation
self._att_type = att_type
self._inpatient_att_type = inpatient_att_type
self._include_inpatient_hour_token = include_inpatient_hour_token
self._exclude_demographic = exclude_demographic
self._use_age_group = use_age_group
self._single_contribution = single_contribution
Expand Down Expand Up @@ -391,6 +394,7 @@ def __init__(
f"is_population_estimation: {is_population_estimation}\n"
f"att_type: {att_type}\n"
f"inpatient_att_type: {inpatient_att_type}\n"
f"include_inpatient_hour_token: {include_inpatient_hour_token}\n"
f"exclude_demographic: {exclude_demographic}\n"
f"use_age_group: {use_age_group}\n"
f"single_contribution: {single_contribution}\n"
Expand Down Expand Up @@ -692,6 +696,9 @@ def extract_ehr_records_for_cohort(self, cohort: DataFrame):
inpatient_att_type=self._inpatient_att_type,
exclude_demographic=self._exclude_demographic,
use_age_group=self._use_age_group,
include_inpatient_hour_token=self._include_inpatient_hour_token,
spark=self.spark,
persistence_folder=self._output_data_folder,
)

return create_sequence_data(
Expand Down Expand Up @@ -808,6 +815,7 @@ def create_prediction_cohort(
is_population_estimation=spark_args.is_population_estimation,
att_type=AttType(spark_args.att_type),
inpatient_att_type=AttType(spark_args.inpatient_att_type),
include_inpatient_hour_token=spark_args.include_inpatient_hour_token,
exclude_demographic=spark_args.exclude_demographic,
use_age_group=spark_args.use_age_group,
single_contribution=spark_args.single_contribution,
Expand Down
6 changes: 6 additions & 0 deletions src/cehrbert_data/utils/spark_parse_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,12 @@ def create_spark_args(parse: bool = True):
default=AttType.NONE,
choices=[e.value for e in AttType],
)
parser.add_argument(
"--include_inpatient_hour_token",
dest="include_inpatient_hour_token",
action="store_true",
help="Indicate whether we should insert the hour tokens within inpatient visits",
)
parser.add_argument(
"--exclude_demographic",
dest="exclude_demographic",
Expand Down

0 comments on commit 6e49d04

Please sign in to comment.