Skip to content

Commit

Permalink
Added units to cehrbert/cehrgpt training data (#6)
Browse files Browse the repository at this point in the history
* split patient event decorators into separate modules

* fixed the imports for cehrbert_data.decorators.patient_event_decorator

* removed src/cehrbert_data/decorators/__pycache__

* updated the import for the visit_occurrence constant in spark_app_base.py

* added the argument type and function return type for att functions

* updated the process_measurement function in spark_utils to include the standard unit names for numeric measurement values

* added the unit column to patient events

* added units to the cehrbert/gpt training data

* removed the wrong Value import

* updated the test class name in test_generate_training_data.py

* fixed a bug where the unit column is missing when include_visit_type is set to True for generating prediction cohorts
  • Loading branch information
ChaoPang authored Oct 2, 2024
1 parent b9c11b8 commit 5697bd6
Show file tree
Hide file tree
Showing 14 changed files with 893 additions and 845 deletions.
54 changes: 29 additions & 25 deletions src/cehrbert_data/apps/generate_training_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from pyspark.sql import functions as F
from pyspark.sql.window import Window

from cehrbert_data.const.common import DEATH, MEASUREMENT, PERSON, REQUIRED_MEASUREMENT, VISIT_OCCURRENCE
from cehrbert_data.decorators.patient_event_decorator import AttType
from cehrbert_data.const.common import DEATH, MEASUREMENT, PERSON, REQUIRED_MEASUREMENT, VISIT_OCCURRENCE, CONCEPT
from cehrbert_data.decorators import AttType
from cehrbert_data.utils.spark_utils import (
create_sequence_data,
create_sequence_data_with_att,
Expand All @@ -21,26 +21,26 @@


def main(
input_folder,
output_folder,
domain_table_list,
date_filter,
include_visit_type,
is_new_patient_representation,
exclude_visit_tokens,
is_classic_bert,
include_prolonged_stay,
include_concept_list: bool,
gpt_patient_sequence: bool,
apply_age_filter: bool,
include_death: bool,
att_type: AttType,
include_sequence_information_content: bool = False,
exclude_demographic: bool = False,
use_age_group: bool = False,
with_drug_rollup: bool = True,
include_inpatient_hour_token: bool = False,
continue_from_events: bool = False,
input_folder,
output_folder,
domain_table_list,
date_filter,
include_visit_type,
is_new_patient_representation,
exclude_visit_tokens,
is_classic_bert,
include_prolonged_stay,
include_concept_list: bool,
gpt_patient_sequence: bool,
apply_age_filter: bool,
include_death: bool,
att_type: AttType,
include_sequence_information_content: bool = False,
exclude_demographic: bool = False,
use_age_group: bool = False,
with_drug_rollup: bool = True,
include_inpatient_hour_token: bool = False,
continue_from_events: bool = False,
):
spark = SparkSession.builder.appName("Generate CEHR-BERT Training Data").getOrCreate()

Expand Down Expand Up @@ -119,15 +119,19 @@ def main(
if MEASUREMENT in domain_table_list:
measurement = preprocess_domain_table(spark, input_folder, MEASUREMENT)
required_measurement = preprocess_domain_table(spark, input_folder, REQUIRED_MEASUREMENT)
if os.path.exists(os.path.join(input_folder, CONCEPT)):
concept = preprocess_domain_table(spark, input_folder, CONCEPT)
else:
concept = None
# The select is necessary to make sure the order of the columns is the same as the
# original dataframe, otherwise the union might use the wrong columns
scaled_measurement = process_measurement(spark, measurement, required_measurement, output_folder)
filtered_measurement = process_measurement(spark, measurement, required_measurement, concept)

if patient_events:
# Union all measurement records together with other domain records
patient_events = patient_events.unionByName(scaled_measurement)
patient_events = patient_events.unionByName(filtered_measurement)
else:
patient_events = scaled_measurement
patient_events = filtered_measurement

patient_events = (
patient_events.join(visit_occurrence_person, "visit_occurrence_id")
Expand Down
4 changes: 2 additions & 2 deletions src/cehrbert_data/cohorts/spark_app_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from pyspark.sql import functions as F
from pyspark.sql.window import Window

from cehrbert_data.decorators.patient_event_decorator import AttType
from cehrbert_data.decorators import AttType
from cehrbert_data.const.common import VISIT_OCCURRENCE
from cehrbert_data.utils.spark_utils import (
VISIT_OCCURRENCE,
build_ancestry_table_for,
create_concept_frequency_data,
create_hierarchical_sequence_data,
Expand Down
5 changes: 5 additions & 0 deletions src/cehrbert_data/decorators/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .artificial_time_token_decorator import AttEventDecorator
from .death_event_decorator import DeathEventDecorator
from .clinical_event_decorator import ClinicalEventDecorator
from .demographic_event_decorator import DemographicEventDecorator
from .patient_event_decorator_base import time_token_func, get_att_function, AttType
Binary file not shown.
Binary file not shown.
Loading

0 comments on commit 5697bd6

Please sign in to comment.