Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added units to cehrbert/cehrgpt training data #6

Merged
merged 11 commits into from
Oct 2, 2024
Merged
54 changes: 29 additions & 25 deletions src/cehrbert_data/apps/generate_training_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from pyspark.sql import functions as F
from pyspark.sql.window import Window

from cehrbert_data.const.common import DEATH, MEASUREMENT, PERSON, REQUIRED_MEASUREMENT, VISIT_OCCURRENCE
from cehrbert_data.decorators.patient_event_decorator import AttType
from cehrbert_data.const.common import DEATH, MEASUREMENT, PERSON, REQUIRED_MEASUREMENT, VISIT_OCCURRENCE, CONCEPT
from cehrbert_data.decorators import AttType
from cehrbert_data.utils.spark_utils import (
create_sequence_data,
create_sequence_data_with_att,
Expand All @@ -21,26 +21,26 @@


def main(
input_folder,
output_folder,
domain_table_list,
date_filter,
include_visit_type,
is_new_patient_representation,
exclude_visit_tokens,
is_classic_bert,
include_prolonged_stay,
include_concept_list: bool,
gpt_patient_sequence: bool,
apply_age_filter: bool,
include_death: bool,
att_type: AttType,
include_sequence_information_content: bool = False,
exclude_demographic: bool = False,
use_age_group: bool = False,
with_drug_rollup: bool = True,
include_inpatient_hour_token: bool = False,
continue_from_events: bool = False,
input_folder,
output_folder,
domain_table_list,
date_filter,
include_visit_type,
is_new_patient_representation,
exclude_visit_tokens,
is_classic_bert,
include_prolonged_stay,
include_concept_list: bool,
gpt_patient_sequence: bool,
apply_age_filter: bool,
include_death: bool,
att_type: AttType,
include_sequence_information_content: bool = False,
exclude_demographic: bool = False,
use_age_group: bool = False,
with_drug_rollup: bool = True,
include_inpatient_hour_token: bool = False,
continue_from_events: bool = False,
):
spark = SparkSession.builder.appName("Generate CEHR-BERT Training Data").getOrCreate()

Expand Down Expand Up @@ -119,15 +119,19 @@ def main(
if MEASUREMENT in domain_table_list:
measurement = preprocess_domain_table(spark, input_folder, MEASUREMENT)
required_measurement = preprocess_domain_table(spark, input_folder, REQUIRED_MEASUREMENT)
if os.path.exists(os.path.join(input_folder, CONCEPT)):
concept = preprocess_domain_table(spark, input_folder, CONCEPT)
else:
concept = None
# The select is necessary to make sure the order of the columns is the same as the
# original dataframe, otherwise the union might use the wrong columns
scaled_measurement = process_measurement(spark, measurement, required_measurement, output_folder)
filtered_measurement = process_measurement(spark, measurement, required_measurement, concept)

if patient_events:
# Union all measurement records together with other domain records
patient_events = patient_events.unionByName(scaled_measurement)
patient_events = patient_events.unionByName(filtered_measurement)
else:
patient_events = scaled_measurement
patient_events = filtered_measurement

patient_events = (
patient_events.join(visit_occurrence_person, "visit_occurrence_id")
Expand Down
4 changes: 2 additions & 2 deletions src/cehrbert_data/cohorts/spark_app_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from pyspark.sql import functions as F
from pyspark.sql.window import Window

from cehrbert_data.decorators.patient_event_decorator import AttType
from cehrbert_data.decorators import AttType
from cehrbert_data.const.common import VISIT_OCCURRENCE
from cehrbert_data.utils.spark_utils import (
VISIT_OCCURRENCE,
build_ancestry_table_for,
create_concept_frequency_data,
create_hierarchical_sequence_data,
Expand Down
5 changes: 5 additions & 0 deletions src/cehrbert_data/decorators/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .artificial_time_token_decorator import AttEventDecorator
from .death_event_decorator import DeathEventDecorator
from .clinical_event_decorator import ClinicalEventDecorator
from .demographic_event_decorator import DemographicEventDecorator
from .patient_event_decorator_base import time_token_func, get_att_function, AttType
Binary file not shown.
Binary file not shown.
Loading
Loading