diff --git a/.github/workflows/python-app.yml b/.github/workflows/tests.yml similarity index 95% rename from .github/workflows/python-app.yml rename to .github/workflows/tests.yml index f456de63..26b8b31f 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/tests.yml @@ -1,7 +1,7 @@ # This workflow will install Python dependencies, run tests and lint with a single version of Python # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python -name: Python application +name: Tests on: push: @@ -36,4 +36,4 @@ jobs: flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - name: Test with pytest run: | - PYTHONPATH=./: pytest \ No newline at end of file + PYTHONPATH=./: pytest diff --git a/.gitignore b/.gitignore index 3f9dd123..b4c9b9b1 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,7 @@ .idea/ .vscode/ venv* - +dist/* *ipynb_checkpoints/ *h5 @@ -35,4 +35,4 @@ cehr_transformers.egg-info/top_level.txt test_data test_dataset_prepared -test*results \ No newline at end of file +test*results diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..c49f9d96 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,83 @@ +# For documentation on pre-commit usage, see https://pre-commit.com/ +# This file should be updated quarterly by a developer running `pre-commit autoupdate` +# with changes added and committed. +# This will run all defined formatters prior to adding a commit. +default_language_version: + python: python3 # or python3.10 to set a specific default version + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + + - repo: https://github.com/DanielNoord/pydocstringformatter + rev: 'v0.7.3' + hooks: + - id: pydocstringformatter + + - repo: https://github.com/PyCQA/autoflake + rev: v2.2.0 + hooks: + - id: autoflake + + - repo: https://github.com/psf/black + rev: '24.1.1' + hooks: + - id: black + # It is recommended to specify the latest version of Python + # supported by your project here, or alternatively use + # pre-commit's default_language_version, see + # https://pre-commit.com/#top_level-default_language_version + # Pre-commit hook info from: https://black.readthedocs.io/en/stable/integrations/source_version_control.html + # Editor integration here: https://black.readthedocs.io/en/stable/integrations/editors.html + + - repo: https://github.com/adamchainz/blacken-docs + rev: "v1.12.1" # replace with latest tag on GitHub + hooks: + - id: blacken-docs + additional_dependencies: + - black>=22.12.0 + + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: 'v4.5.0' + hooks: + - id: trailing-whitespace + exclude: .git/COMMIT_EDITMSG + - id: end-of-file-fixer + exclude: .git/COMMIT_EDITMSG + - id: detect-private-key + - id: debug-statements + - id: check-json + - id: pretty-format-json + - id: check-yaml + - id: name-tests-test + - id: requirements-txt-fixer + + - repo: https://github.com/pre-commit/pygrep-hooks + rev: 'v1.10.0' + hooks: + - id: python-no-eval + - id: python-no-log-warn + - id: python-use-type-annotations + + - repo: https://github.com/Lucas-C/pre-commit-hooks + rev: v1.5.4 + hooks: + - id: remove-crlf + - id: remove-tabs # defaults to: 4 + exclude: .git/COMMIT_EDITMSG + + - repo: https://github.com/PyCQA/isort.git + rev: 5.13.2 + hooks: + - id: isort + args: [ "--profile", "black" ] + + - repo: 
https://github.com/PyCQA/bandit + rev: '1.7.7' + hooks: + - id: bandit + args: ["--skip", "B101,B106,B107,B301,B311,B105,B608,B403"] diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..5b34f4e3 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Department of Biomedical Informatics + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 4ad98d22..c859860b 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,7 @@ pip install -e .[dev] Download [jtds-1.3.1.jar](jtds-1.3.1.jar) into the spark jars folder in the python environment ```console -cp jtds-1.3.1.jar .venv/lib/python3.10/site-packages/pyspark/jars/ +cp jtds-1.3.1.jar .venv/lib/python3.10/site-packages/pyspark/jars/ ``` ## Instructions for Use with [MEDS](https://github.com/Medical-Event-Data-Standard/meds) @@ -65,7 +65,7 @@ cp jtds-1.3.1.jar .venv/lib/python3.10/site-packages/pyspark/jars/ ### 1. Convert MEDS to the [meds_reader](https://github.com/som-shahlab/meds_reader) database If you don't have the MEDS dataset, you could convert the OMOP dataset to the MEDS -using [meds_etl](https://github.com/Medical-Event-Data-Standard/meds_etl). +using [meds_etl](https://github.com/Medical-Event-Data-Standard/meds_etl). We have prepared a synthea dataset with 1M patients for you to test, you could download it at [omop_synthea.tar.gz](https://drive.google.com/file/d/1k7-cZACaDNw8A1JRI37mfMAhEErxKaQJ/view?usp=share_link) ```console @@ -115,7 +115,7 @@ The sequence can be seen conceptually as [VS] [V1] [VE] [ATT] [VS] [V2] [VE], wh concepts associated with those visits. ```console -PYTHONPATH=./: spark-submit spark_apps/generate_training_data.py -i ~/Documents/omop_test/ -o ~/Documents/omop_test/cehr-bert -tc condition_occurrence procedure_occurrence drug_exposure -d 1985-01-01 --is_new_patient_representation -iv +PYTHONPATH=./: spark-submit spark_apps/generate_training_data.py -i ~/Documents/omop_test/ -o ~/Documents/omop_test/cehr-bert -tc condition_occurrence procedure_occurrence drug_exposure -d 1985-01-01 --is_new_patient_representation -iv ``` ### 3. Pre-train CEHR-BERT @@ -125,7 +125,7 @@ at `sample/patient_sequence` in the repo. 
CEHR-BERT expects the data folder to b ```console mkdir test_dataset_prepared; mkdir test_results; -python -m cehrbert.runners.hf_cehrbert_pretrain_runner sample_configs/hf_cehrbert_pretrain_runner_config.yaml +python -m cehrbert.runners.hf_cehrbert_pretrain_runner sample_configs/hf_cehrbert_pretrain_runner_config.yaml ``` If your dataset is large, you could add ```--use_dask``` in the command above @@ -157,4 +157,4 @@ Chao Pang, Xinzhuo Jiang, Krishna S. Kalluri, Matthew Spotnitz, RuiJun Chen, Adl Perotte, and Karthik Natarajan. "Cehr-bert: Incorporating temporal information from structured ehr data to improve prediction tasks." In Proceedings of Machine Learning for Health, volume 158 of Proceedings of Machine Learning Research, pages 239–260. PMLR, -04 Dec 2021. \ No newline at end of file +04 Dec 2021. diff --git a/db_properties.ini b/db_properties.ini index 2f48f2c2..db36304e 100644 --- a/db_properties.ini +++ b/db_properties.ini @@ -2,4 +2,4 @@ base_url = jdbc:jtds:sqlserver://servername:1433;useNTLMv2=true;domain=domain_name;databaseName=db driver = net.sourceforge.jtds.jdbc.Driver user = username -password = password \ No newline at end of file +password = password diff --git a/deepspeed_configs/zero1.json b/deepspeed_configs/zero1.json index 1816cb43..787fc0d6 100644 --- a/deepspeed_configs/zero1.json +++ b/deepspeed_configs/zero1.json @@ -19,4 +19,4 @@ "train_batch_size": "auto", "train_micro_batch_size_per_gpu": "auto", "wall_clock_breakdown": false -} \ No newline at end of file +} diff --git a/deepspeed_configs/zero2.json b/deepspeed_configs/zero2.json index 8ca90e5a..5b22d996 100644 --- a/deepspeed_configs/zero2.json +++ b/deepspeed_configs/zero2.json @@ -23,4 +23,4 @@ "train_batch_size": "auto", "train_micro_batch_size_per_gpu": "auto", "wall_clock_breakdown": false -} \ No newline at end of file +} diff --git a/deepspeed_configs/zero3.json b/deepspeed_configs/zero3.json index a74ea983..a185afab 100644 --- a/deepspeed_configs/zero3.json +++ b/deepspeed_configs/zero3.json @@ -27,4 +27,4 @@ "train_batch_size": "auto", "train_micro_batch_size_per_gpu": "auto", "wall_clock_breakdown": false -} \ No newline at end of file +} diff --git a/full_grid_search_config.ini b/full_grid_search_config.ini index 61cd32f8..5d0f1618 100644 --- a/full_grid_search_config.ini +++ b/full_grid_search_config.ini @@ -8,4 +8,4 @@ val_4 = 1.2e-4 val_1 = True [LSTM_UNIT] -val_1 = 128 \ No newline at end of file +val_1 = 128 diff --git a/pyproject.toml b/pyproject.toml index d49e3e7c..439e1689 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ authors = [ ] description = "CEHR-BERT: Incorporating temporal information from structured EHR data to improve prediction tasks" readme = "README.md" +license = { text = "MIT License" } requires-python = ">=3.10.0" classifiers = [ @@ -47,7 +48,7 @@ dependencies = [ "scikit-learn==1.4.0", "scipy==1.12.0", "tensorflow==2.15.0", - "tensorflow-metal==1.1.0; sys_platform == 'darwin'", # macOS only + "tensorflow-metal==1.1.0; sys_platform == 'darwin'", # macOS only "tensorflow-datasets==4.5.2", "tqdm==4.66.1", "torch==2.4.0", @@ -60,11 +61,25 @@ dependencies = [ [tool.setuptools_scm] +[project.urls] +Homepage = "https://github.com/cumc-dbmi/cehr-bert" + [project.scripts] cehrbert-pretraining = "cehrbert.runner.hf_cehrbert_pretrain_runner:main" cehrbert-finetuning = "cehrbert.runner.hf_cehrbert_finetuning_runner:main" [project.optional-dependencies] dev = [ - "pre-commit", "pytest", "pytest-cov", "pytest-subtests", "rootutils", "hypothesis" + "pre-commit", 
"pytest", "pytest-cov", "pytest-subtests", "rootutils", "hypothesis", "black" ] + +[tool.isort] +multi_line_output = 3 +include_trailing_comma = true +force_grid_wrap = 0 +use_parentheses = true +ensure_newline_before_comments = true +line_length = 120 + +[tool.black] +line_length = 120 diff --git a/sample_configs/hf_cehrbert_pretrain_runner_meds_config.yaml b/sample_configs/hf_cehrbert_pretrain_runner_meds_config.yaml index 88bb0d34..e7bbc756 100644 --- a/sample_configs/hf_cehrbert_pretrain_runner_meds_config.yaml +++ b/sample_configs/hf_cehrbert_pretrain_runner_meds_config.yaml @@ -50,4 +50,4 @@ logging_steps: 100 save_total_limit: load_best_model_at_end: true metric_for_best_model: "eval_loss" -greater_is_better: false \ No newline at end of file +greater_is_better: false diff --git a/simple_grid_search_config.ini b/simple_grid_search_config.ini index 2262b323..e6336e39 100644 --- a/simple_grid_search_config.ini +++ b/simple_grid_search_config.ini @@ -5,4 +5,4 @@ val_1 = 1.0e-4 val_1 = True [LSTM_UNIT] -val_1 = 128 \ No newline at end of file +val_1 = 128 diff --git a/src/cehrbert/__init__.py b/src/cehrbert/__init__.py index 1f07e77c..b93eae3f 100644 --- a/src/cehrbert/__init__.py +++ b/src/cehrbert/__init__.py @@ -2,6 +2,7 @@ It contains the main functions and classes needed to extract cohorts. """ + from importlib.metadata import PackageNotFoundError, version __package_name__ = "cehrbert" diff --git a/src/cehrbert/config/grid_search_config.py b/src/cehrbert/config/grid_search_config.py index b3a1148e..6c7e28bc 100644 --- a/src/cehrbert/config/grid_search_config.py +++ b/src/cehrbert/config/grid_search_config.py @@ -1,14 +1,13 @@ -from typing import NamedTuple, List +from typing import List, NamedTuple -LEARNING_RATE = 'LEARNING_RATE' -LSTM_DIRECTION = 'LSTM_DIRECTION' -LSTM_UNIT = 'LSTM_UNIT' +LEARNING_RATE = "LEARNING_RATE" +LSTM_DIRECTION = "LSTM_DIRECTION" +LSTM_UNIT = "LSTM_UNIT" class GridSearchConfig(NamedTuple): - """ - A data class for storing the row from the pandas data frame and the indexes for slicing the - """ + """A data class for storing the row from the pandas data frame and the indexes for slicing the.""" + learning_rates: List[float] = [1.0e-4] lstm_directions: List[bool] = [True] lstm_units: List[int] = [128] diff --git a/src/cehrbert/config/output_names.py b/src/cehrbert/config/output_names.py index ba931b55..fe663dbe 100644 --- a/src/cehrbert/config/output_names.py +++ b/src/cehrbert/config/output_names.py @@ -1,9 +1,9 @@ -PARQUET_DATA_PATH = 'patient_sequence' -QUALIFIED_CONCEPT_LIST_PATH = 'qualified_concept_list' -TIME_ATTENTION_MODEL_PATH = 'time_aware_model.h5' -BERT_MODEL_VALIDATION_PATH = 'bert_model.h5' -MORTALITY_DATA_PATH = 'mortality' -HEART_FAILURE_DATA_PATH = 'heart_failure' -HOSPITALIZATION_DATA_PATH = 'hospitalization' -INFORMATION_CONTENT_DATA_PATH = 'information_content' -CONCEPT_SIMILARITY_PATH = 'concept_similarity' +PARQUET_DATA_PATH = "patient_sequence" +QUALIFIED_CONCEPT_LIST_PATH = "qualified_concept_list" +TIME_ATTENTION_MODEL_PATH = "time_aware_model.h5" +BERT_MODEL_VALIDATION_PATH = "bert_model.h5" +MORTALITY_DATA_PATH = "mortality" +HEART_FAILURE_DATA_PATH = "heart_failure" +HOSPITALIZATION_DATA_PATH = "hospitalization" +INFORMATION_CONTENT_DATA_PATH = "information_content" +CONCEPT_SIMILARITY_PATH = "concept_similarity" diff --git a/src/cehrbert/const/common.py b/src/cehrbert/const/common.py index 336d0eef..2f9d12ec 100644 --- a/src/cehrbert/const/common.py +++ b/src/cehrbert/const/common.py @@ -1,18 +1,28 @@ -PERSON = 'person' 
-VISIT_OCCURRENCE = 'visit_occurrence' -CONDITION_OCCURRENCE = 'condition_occurrence' -PROCEDURE_OCCURRENCE = 'procedure_occurrence' -DRUG_EXPOSURE = 'drug_exposure' -DEVICE_EXPOSURE = 'device_exposure' -OBSERVATION = 'observation' -MEASUREMENT = 'measurement' -CATEGORICAL_MEASUREMENT = 'categorical_measurement' -OBSERVATION_PERIOD = 'observation_period' -DEATH = 'death' -CDM_TABLES = [PERSON, VISIT_OCCURRENCE, CONDITION_OCCURRENCE, PROCEDURE_OCCURRENCE, DRUG_EXPOSURE, - DEVICE_EXPOSURE, OBSERVATION, MEASUREMENT, CATEGORICAL_MEASUREMENT, - OBSERVATION_PERIOD, DEATH] -REQUIRED_MEASUREMENT = 'required_measurement' -UNKNOWN_CONCEPT = '[UNKNOWN]' -CONCEPT = 'concept' -CONCEPT_ANCESTOR = 'concept_ancestor' +PERSON = "person" +VISIT_OCCURRENCE = "visit_occurrence" +CONDITION_OCCURRENCE = "condition_occurrence" +PROCEDURE_OCCURRENCE = "procedure_occurrence" +DRUG_EXPOSURE = "drug_exposure" +DEVICE_EXPOSURE = "device_exposure" +OBSERVATION = "observation" +MEASUREMENT = "measurement" +CATEGORICAL_MEASUREMENT = "categorical_measurement" +OBSERVATION_PERIOD = "observation_period" +DEATH = "death" +CDM_TABLES = [ + PERSON, + VISIT_OCCURRENCE, + CONDITION_OCCURRENCE, + PROCEDURE_OCCURRENCE, + DRUG_EXPOSURE, + DEVICE_EXPOSURE, + OBSERVATION, + MEASUREMENT, + CATEGORICAL_MEASUREMENT, + OBSERVATION_PERIOD, + DEATH, +] +REQUIRED_MEASUREMENT = "required_measurement" +UNKNOWN_CONCEPT = "[UNKNOWN]" +CONCEPT = "concept" +CONCEPT_ANCESTOR = "concept_ancestor" diff --git a/src/cehrbert/data_generators/data_classes.py b/src/cehrbert/data_generators/data_classes.py index aadc0a70..7114ee0a 100644 --- a/src/cehrbert/data_generators/data_classes.py +++ b/src/cehrbert/data_generators/data_classes.py @@ -1,25 +1,26 @@ from enum import Enum -from typing import Tuple, NamedTuple +from typing import NamedTuple, Tuple class RecordStatus(Enum): """ - COMPLETE indicates the record contains the entire history of a patient, therefore we should add [START] and [END] + COMPLETE indicates the record contains the entire history of a patient, therefore we should add [START] and [END]. + tokens to the patient history. RIGHT_TRUNCATION indicates that we employ the right truncation of the patient history for long sequences. This means that we should not add the [END] token to the end of the record because this partial history is not supposed to end. TRUNCATION indicates that the record is a slice of the patient history, we should not add [START] token and [END] token to such a sequence because this partial history got truncated on both ends. 
""" + COMPLETE = 1 RIGHT_TRUNCATION = 2 TRUNCATION = 3 class RowSlicer(NamedTuple): - """ - A data class for storing the row from the pandas data frame and the indexes for slicing the - """ + """A data class for storing the row from the pandas data frame and the indexes for slicing the.""" + row: Tuple start_index: int end_index: int diff --git a/src/cehrbert/data_generators/data_generator_base.py b/src/cehrbert/data_generators/data_generator_base.py index 924fb5ab..c91ea86e 100644 --- a/src/cehrbert/data_generators/data_generator_base.py +++ b/src/cehrbert/data_generators/data_generator_base.py @@ -1,20 +1,35 @@ import inspect import logging -import copy +import random +from abc import ABC, abstractmethod from collections import ChainMap -from itertools import chain -from typing import Set - -from pandas import DataFrame -from .learning_objective import * +from itertools import chain, islice +from typing import List, Set + +import numpy as np +import pandas as pd + +from .data_classes import RowSlicer +from .learning_objective import ( + BertFineTuningLearningObjective, + DemographicsLearningObjective, + HierarchicalArtificialTokenPredictionLearningObjective, + HierarchicalMaskedLanguageModelLearningObjective, + HierarchicalProlongedLengthStayLearningObjective, + HierarchicalReadmissionLearningObjective, + HierarchicalVisitTypePredictionLearningObjective, + LearningObjective, + MaskedLanguageModelLearningObjective, + ProlongedLengthStayLearningObjective, + TimeAttentionLearningObjective, + VisitPredictionLearningObjective, +) from .tokenizer import ConceptTokenizer -from .data_classes import RecordStatus, RowSlicer -def create_indexes_by_time_window(dates, cursor, max_seq_len, - time_window_size): +def create_indexes_by_time_window(dates, cursor, max_seq_len, time_window_size): """ - Extract the start_index and end_index used for slicing the sequences e.g. concept_ids and dates + Extract the start_index and end_index used for slicing the sequences e.g. concept_ids and dates. :param dates: a list of time stamps associated with the context :param cursor: the current index used as the center for slicing the sequence @@ -31,41 +46,39 @@ def create_indexes_by_time_window(dates, cursor, max_seq_len, context_dates = dates[start_index:end_index] time_deltas = context_dates - dates[cursor] context_indexes = np.squeeze( - np.argwhere((time_deltas >= -half_time_window_size) - & (time_deltas <= half_time_window_size)), - axis=-1) + np.argwhere((time_deltas >= -half_time_window_size) & (time_deltas <= half_time_window_size)), + axis=-1, + ) return np.min(context_indexes).item(), np.max(context_indexes).item() def get_required_params(clazz: LearningObjective): """ - Get required parameters for the learning objective class + Get required parameters for the learning objective class. 
+ :param clazz: :return: """ params = inspect.signature(clazz).parameters - return [ - dict(name=name, required=param.default is inspect.Parameter.empty) - for name, param in params.items() - ] + return [dict(name=name, required=param.default is inspect.Parameter.empty) for name, param in params.items()] class AbstractDataGeneratorBase(ABC): default_min_num_of_concepts = 2 - default_required_column = 'concept_ids' + default_required_column = "concept_ids" def __init__( - self, - training_data: DataFrame, - batch_size: int, - max_seq_len: int, - min_num_of_concepts: int, - is_random_cursor: bool = False, - is_pretraining: bool = True, - num_steps: int = None, - *args, - **kwargs + self, + training_data: pd.DataFrame, + batch_size: int, + max_seq_len: int, + min_num_of_concepts: int, + is_random_cursor: bool = False, + is_pretraining: bool = True, + num_steps: int = None, + *args, + **kwargs, ): self._training_data = training_data @@ -77,18 +90,16 @@ def __init__( self._num_steps = num_steps self.get_logger().info( - f'batch_size: {batch_size}\n' - f'max_seq_len: {max_seq_len}\n' - f'min_num_of_concepts: {min_num_of_concepts}\n' - f'is_random_cursor: {is_random_cursor}\n' - f'is_pretraining: {is_pretraining}\n' - f'num_of_steps: {num_steps}\n' + f"batch_size: {batch_size}\n" + f"max_seq_len: {max_seq_len}\n" + f"min_num_of_concepts: {min_num_of_concepts}\n" + f"is_random_cursor: {is_random_cursor}\n" + f"is_pretraining: {is_pretraining}\n" + f"num_of_steps: {num_steps}\n" ) self._learning_objectives = self._initialize_learning_objectives( - max_seq_len=max_seq_len, - is_pretraining=is_pretraining, - **kwargs + max_seq_len=max_seq_len, is_pretraining=is_pretraining, **kwargs ) # validate the required columns in the training data self._validate_data_frame_columns() @@ -97,21 +108,22 @@ def __init__( @abstractmethod def _get_learning_objective_classes(self) -> List[LearningObjective]: """ - Initialize a list of LearningObjectives used for generating the input and and output + Initialize a list of LearningObjectives used for generating the input and and output. + :return: """ - pass - def _initialize_learning_objectives(self, - **kwargs) -> List[LearningObjective]: + def _initialize_learning_objectives(self, **kwargs) -> List[LearningObjective]: """ - Initialize a list of LearningObjectives used for generating the input and and output + Initialize a list of LearningObjectives used for generating the input and and output. + :return: """ def _initialize(learning_objective) -> LearningObjective: """ - Initialize one LearningObjective using the provided keyword arguments + Initialize one LearningObjective using the provided keyword arguments. + from the parent method :param learning_objective: @@ -119,43 +131,40 @@ def _initialize(learning_objective) -> LearningObjective: """ learning_object_input = dict() params = get_required_params(learning_objective) - for required_param in [ - param['name'] for param in params if param['required'] - ]: + for required_param in [param["name"] for param in params if param["required"]]: if required_param in kwargs: - learning_object_input[required_param] = kwargs[ - required_param] + learning_object_input[required_param] = kwargs[required_param] return learning_objective(**learning_object_input) return list(map(_initialize, self._get_learning_objective_classes())) def _validate_data_frame_columns(self): """ - Validate if the training data has all required columns + Validate if the training data has all required columns. 
+ :return: """ dataframe_columns = self._training_data.columns.tolist() for required_column in self._get_required_columns(): if not required_column in dataframe_columns: - raise ValueError( - f'The required column {required_column} does not exist in the training data' - ) + raise ValueError(f"The required column {required_column} does not exist in the training data") @abstractmethod def _clean_dataframe(self): """ - Clean the input data (_training_data) e.g. remove rows whose sequence length is less than + Clean the input data (_training_data) e.g. remove rows whose sequence length is less than. + _minimum_num_of_concepts. Overload this method in the subclasses to overwrite the default behavior :return: """ - pass def create_batch_generator(self): """ - Create the batch generator for tf.dataset.from_generator to use + Create the batch generator for tf.dataset.from_generator to use. + :return: """ while True: @@ -173,18 +182,15 @@ def create_batch_generator(self): output_dicts.append(output_dict) yield dict(ChainMap(*input_dicts)), dict(ChainMap(*output_dicts)) except (RuntimeError, ValueError) as e: - print(f'Error caught: {e}') + print(f"Error caught: {e}") # Break out of the infinite loop in the non pretraining mode if not self._is_pretraining: break - def set_learning_objectives( - self, - learning_objectives: List[LearningObjective] - ): + def set_learning_objectives(self, learning_objectives: List[LearningObjective]): """ - Overwrite the default learning objectives + Overwrite the default learning objectives. :param learning_objectives: :return: @@ -202,13 +208,11 @@ def get_data_size(self): def get_steps_per_epoch(self): """ Calculate the number of steps required for one epoch to complete. + Floor division + 1 if there is any modulo value :return: """ - num_of_steps = ( - self.get_data_size() // self._batch_size - + bool(self.get_data_size() % self._batch_size) - ) + num_of_steps = self.get_data_size() // self._batch_size + bool(self.get_data_size() % self._batch_size) if self._num_steps: return min(self._num_steps, num_of_steps) @@ -217,33 +221,30 @@ def get_steps_per_epoch(self): def _get_required_columns(self) -> Set[str]: """ - Combine lists of required columns from multiple learning objectives into a unique set of + Combine lists of required columns from multiple learning objectives into a unique set of. + required columns :return: """ learning_objective_required_columns = list( - chain(*[ - learning_objective.get_required_columns() - for learning_objective in self._learning_objectives - ])) - return set(learning_objective_required_columns + - [self.default_required_column]) + chain(*[learning_objective.get_required_columns() for learning_objective in self._learning_objectives]) + ) + return set(learning_objective_required_columns + [self.default_required_column]) def get_tf_dataset_schema(self): """ - Combine the input and output tensorflow data schema from multiple learning objectives + Combine the input and output tensorflow data schema from multiple learning objectives. 
+ :return: """ input_dict_schemas = [] output_dict_schemas = [] for learning_objective in self._learning_objectives: - input_dict_schema, output_dict_schema = learning_objective.get_tf_dataset_schema( - ) + input_dict_schema, output_dict_schema = learning_objective.get_tf_dataset_schema() input_dict_schemas.append(input_dict_schema) output_dict_schemas.append(output_dict_schema) - return dict(ChainMap(*input_dict_schemas)), dict( - ChainMap(*output_dict_schemas)) + return dict(ChainMap(*input_dict_schemas)), dict(ChainMap(*output_dict_schemas)) @classmethod def get_logger(cls): @@ -252,40 +253,33 @@ def get_logger(cls): class BertDataGenerator(AbstractDataGeneratorBase): - def __init__( - self, - concept_tokenizer: ConceptTokenizer, - *args, - **kwargs): - super(BertDataGenerator, - self).__init__( - concept_tokenizer=concept_tokenizer, - *args, - **kwargs - ) + def __init__(self, concept_tokenizer: ConceptTokenizer, *args, **kwargs): + super(BertDataGenerator, self).__init__(concept_tokenizer=concept_tokenizer, *args, **kwargs) self._concept_tokenizer = concept_tokenizer def _clean_dataframe(self): self._training_data = self._training_data[ - self._training_data[self.default_required_column].apply( - lambda token_ids: len(token_ids)) >= - max(self.default_min_num_of_concepts, self._min_num_of_concepts)] + self._training_data[self.default_required_column].apply(len) + >= max(self.default_min_num_of_concepts, self._min_num_of_concepts) + ] def _get_learning_objective_classes(self): return [MaskedLanguageModelLearningObjective] def _create_iterator(self): """ - Create an iterator that will iterate through all training data + Create an iterator that will iterate through all training data. + :return: """ for row in self._training_data.sample(frac=1).itertuples(): seq_length = len(row.token_ids) if self._is_pretraining: - cursor = random.randint(0, seq_length - - 1) if self._is_random_cursor & ( - seq_length > self._max_seq_len - ) else seq_length // 2 + cursor = ( + random.randint(0, seq_length - 1) + if self._is_random_cursor & (seq_length > self._max_seq_len) + else seq_length // 2 + ) half_window_size = int(self._max_seq_len / 2) start_index = max(0, cursor - half_window_size) @@ -300,207 +294,28 @@ def get_data_size(self): return len(self._training_data) -class GptDataGenerator(BertDataGenerator): - def __init__( - self, - concept_tokenizer: ConceptTokenizer, - min_num_of_visits: int, - max_num_of_visits: int, - including_long_sequence: bool = False, - sampling_dataset_enabled: bool = False, - include_numeric_value: bool = False, - efficient_training: bool = False, - is_weighted_sample: bool = False, - weighted_sample_scaling_factor: float = 0.5, - weighted_sample_bin_width: int = 20, - sort_sequence_by_length: bool = False, - *args, - **kwargs - ): - self._min_num_of_visits = min_num_of_visits - self._max_num_of_visits = max_num_of_visits - self._including_long_sequence = including_long_sequence - self._concept_tokenizer = concept_tokenizer - self._sampling_dataset_enabled = sampling_dataset_enabled - self._include_numeric_value = include_numeric_value - self._efficient_training = efficient_training - self._is_weighted_sample = is_weighted_sample - self._weighted_sample_scaling_factor = weighted_sample_scaling_factor - self._weighted_sample_bin_width = weighted_sample_bin_width - self._sort_sequence_by_length = sort_sequence_by_length - - super(BertDataGenerator, - self).__init__( - concept_tokenizer=concept_tokenizer, - *args, - **kwargs - ) - - def _clean_dataframe(self): - 
self._training_data = self._training_data[ - self._training_data['num_of_visits'] >= self._min_num_of_visits] - self._training_data = self._training_data[ - self._training_data['num_of_visits'] <= self._max_num_of_visits] - self._training_data = self._training_data[ - self._training_data['num_of_concepts'] >= self._min_num_of_concepts] - - # Only remove the long sequences when these two options are not enabled - if not self._including_long_sequence and not self._is_random_cursor: - self._training_data = self._training_data[ - self._training_data['num_of_concepts'] <= self._max_seq_len] - - if self._efficient_training: - self._training_data = self._training_data.sort_values('num_of_concepts') - self._training_data['row_num'] = self._training_data.reset_index().index + 1 - self._training_data['batch_num'] = self._training_data.row_num // self._batch_size - - if self._sampling_dataset_enabled and self._is_weighted_sample: - self._training_data['bucket'] = self._training_data.num_of_concepts // self._weighted_sample_bin_width - # Calculate the bucket counts - bucket_counts = self._training_data.groupby(['bucket'])['num_of_concepts'].count() - buck_prob_pd = bucket_counts / len(self._training_data) - - # Dampen the bucket probabilities by applying a power function e.g. 0.5 - buck_dampened_probs = np.power(buck_prob_pd, self._weighted_sample_scaling_factor) - # re-scale the probability distribution so it sums up to 1 - buck_dampened_probs = buck_dampened_probs / buck_dampened_probs.sum() - - # Check the probability distribution - assert buck_dampened_probs.sum() - 1 < 1e-8 - - buck_dampened_prob_df = pd.DataFrame({ - 'bucket_freq': bucket_counts, - 'sample_weight': buck_dampened_probs - }).reset_index() - - # Calculate the individual sample weight by dividing the bucket probability by the total number of - # patient sequences in the bucket - buck_dampened_prob_df['sample_weight'] = ( - buck_dampened_prob_df['sample_weight'] / buck_dampened_prob_df['bucket_freq'] - ) - self._training_data = self._training_data.merge(buck_dampened_prob_df, on='bucket') - - # This is important so that the iloc works correctly when retrieving records from the dataframe - self._training_data = self._training_data.reset_index() - - def _get_learning_objective_classes(self): - learning_objs = [SequenceGenerationLearningObjective] - if self._include_numeric_value: - learning_objs.append(PredictNextValueLearningObjective) - return learning_objs - - def _create_iterator(self): - """ - Create an iterator that will iterate through all training data - :return: - """ - if self._efficient_training: - if self._sort_sequence_by_length: - # This sorts the training data from short to long, the model will be fed with short sequences first, - # then long sequences gradually - self._training_data = self._training_data.sort_values('batch_num') - else: - unique_batch_nums = self._training_data['batch_num'].unique() - uniform_random_order = np.random.uniform(size=unique_batch_nums.size) - random_order_pd = pd.DataFrame({ - 'batch_num': unique_batch_nums, - 'random_order': uniform_random_order} - ) - # Random order the batches of examples so that all the data points in the same batch have the same - # number of concepts - self._training_data = self._training_data.merge( - random_order_pd, on='batch_num' - ).sort_values( - ['random_order', 'batch_num'] - ).drop(columns=['random_order']) - else: - self._training_data = self._training_data.sample(frac=1.0) - - # Create a random sample cache utility class to generate a batch of indices - if 
self._sampling_dataset_enabled: - sample_weights = self._training_data.sample_weight if self._is_weighted_sample else None - random_sample_cache = RandomSampleCache( - data_indices=self._training_data.index, - cache_size=self._batch_size, - sample_weights=sample_weights - ) - else: - random_sample_cache = None - - for row_index in self._training_data.index: - # If the sampling strategy is enabled, we will randomly sample a record every time - if self._sampling_dataset_enabled: - # Overwrite row_index with a random index sampled from randomized_indices - row_index = random_sample_cache.next() - row = self._training_data.iloc[row_index] - seq_length = len(row.token_ids) - if seq_length <= self._max_seq_len: - yield RowSlicer(row, 0, seq_length) - elif self._is_random_cursor: - try: - starting_index, end_index, demographic_tokens = random_slice_gpt_sequence( - row.concept_ids, - self._max_seq_len - ) - # This indicates the VE token is not found - if starting_index == end_index: - continue - - # concept_ids = demographic_tokens + row.concept_ids[starting_index:end_index + 1] - concept_ids = row.concept_ids[starting_index:end_index + 1] - token_ids = self._concept_tokenizer.encode([concept_ids])[0] - visit_concept_orders = row.visit_concept_orders[starting_index:end_index + 1] - # visit_concept_orders = np.concatenate( - # [row.visit_concept_orders[:len(demographic_tokens)], - # row.visit_concept_orders[starting_index:end_index + 1]] - # ) - new_row = copy.deepcopy(row) - new_row.token_ids = token_ids - new_row.concept_ids = concept_ids - new_row.visit_concept_orders = visit_concept_orders - assert len(new_row.token_ids) <= self._max_seq_len - yield RowSlicer(new_row, 0, len(new_row.token_ids), record_status=RecordStatus.TRUNCATION) - except RuntimeError as e: - print(e) - elif self._including_long_sequence: - # Because the sequence is longer than the context window, we identify the last VE token in the - # sequence and take the patient history before that point - last_ve_token_index = 0 - for i, token in enumerate(row.token_ids): - # When the index exceeds the context window, we break out of the loop - if i >= self._max_seq_len: - break - if token == self._concept_tokenizer.get_visit_end_token_id(): - last_ve_token_index = i - yield RowSlicer(row, 0, last_ve_token_index + 1, record_status=RecordStatus.RIGHT_TRUNCATION) - - class BertVisitPredictionDataGenerator(BertDataGenerator): def __init__(self, visit_tokenizer: ConceptTokenizer, *args, **kwargs): - super(BertDataGenerator, - self).__init__(visit_tokenizer=visit_tokenizer, *args, **kwargs) + super(BertDataGenerator, self).__init__(visit_tokenizer=visit_tokenizer, *args, **kwargs) self._visit_tokenizer = visit_tokenizer def _get_learning_objective_classes(self): - return [ - MaskedLanguageModelLearningObjective, - VisitPredictionLearningObjective - ] + return [MaskedLanguageModelLearningObjective, VisitPredictionLearningObjective] class HierarchicalBertDataGenerator(AbstractDataGeneratorBase): def __init__( - self, - concept_tokenizer: ConceptTokenizer, - visit_tokenizer: ConceptTokenizer, - max_num_of_visits: int, - max_num_of_concepts: int, - include_att_prediction: bool, - include_visit_prediction: bool, - min_num_of_concepts: int = 5, - min_num_of_visits: int = 2, - *args, - **kwargs + self, + concept_tokenizer: ConceptTokenizer, + visit_tokenizer: ConceptTokenizer, + max_num_of_visits: int, + max_num_of_concepts: int, + include_att_prediction: bool, + include_visit_prediction: bool, + *args, + min_num_of_concepts: int = 5, + 
min_num_of_visits: int = 2, + **kwargs, ): # The num of visits @@ -508,7 +323,7 @@ def __init__( self._max_num_of_visits = max_num_of_visits self._max_num_of_concepts = max_num_of_concepts - super(HierarchicalBertDataGenerator, self).__init__( + super().__init__( concept_tokenizer=concept_tokenizer, visit_tokenizer=visit_tokenizer, max_num_of_visits=max_num_of_visits, @@ -518,18 +333,18 @@ def __init__( include_att_prediction=include_att_prediction, include_visit_prediction=include_visit_prediction, *args, - **kwargs + **kwargs, ) def _clean_dataframe(self): """ - Remove the patients that don't have enough concepts to qualify + Remove the patients that don't have enough concepts to qualify. + :return: """ min_num_of_concepts = max(self.default_min_num_of_concepts, self._min_num_of_concepts) - criteria = ( - (self._training_data['num_of_concepts'] >= min_num_of_concepts) - & (self._training_data['num_of_visits'] >= self._min_num_of_visits) + criteria = (self._training_data["num_of_concepts"] >= min_num_of_concepts) & ( + self._training_data["num_of_visits"] >= self._min_num_of_visits ) self._training_data = self._training_data[criteria] @@ -537,12 +352,13 @@ def _get_learning_objective_classes(self): return [ HierarchicalMaskedLanguageModelLearningObjective, HierarchicalArtificialTokenPredictionLearningObjective, - HierarchicalVisitTypePredictionLearningObjective + HierarchicalVisitTypePredictionLearningObjective, ] def _create_iterator(self): """ - Create an iterator that will iterate through all training example + Create an iterator that will iterate through all training example. + :return: """ for row in self._training_data.itertuples(): @@ -568,29 +384,22 @@ def get_data_size(self): class HierarchicalBertMultiTaskDataGenerator(HierarchicalBertDataGenerator): def __init__( - self, - include_readmission: bool, - include_prolonged_length_stay: bool, - *args, - **kwargs + self, + include_readmission: bool, + include_prolonged_length_stay: bool, + *args, + **kwargs, ): self._include_readmission = include_readmission self._include_prolonged_length_stay = include_prolonged_length_stay - - super( - HierarchicalBertMultiTaskDataGenerator, - self - ).__init__( - *args, - **kwargs - ) + super().__init__(*args, **kwargs) def _get_learning_objective_classes(self): learning_objectives = [ HierarchicalMaskedLanguageModelLearningObjective, HierarchicalArtificialTokenPredictionLearningObjective, - HierarchicalVisitTypePredictionLearningObjective + HierarchicalVisitTypePredictionLearningObjective, ] if self._include_readmission: @@ -606,18 +415,24 @@ class MedBertDataGenerator(BertDataGenerator): def _get_learning_objective_classes(self): return [ MaskedLanguageModelLearningObjective, - ProlongedLengthStayLearningObjective + ProlongedLengthStayLearningObjective, ] class TimeAttentionDataGenerator(AbstractDataGeneratorBase): - def __init__(self, concept_tokenizer: ConceptTokenizer, - time_window_size: int, *args, **kwargs): - super(TimeAttentionDataGenerator, - self).__init__(concept_tokenizer=concept_tokenizer, - time_window_size=time_window_size, - *args, - **kwargs) + def __init__( + self, + concept_tokenizer: ConceptTokenizer, + time_window_size: int, + *args, + **kwargs, + ): + super(TimeAttentionDataGenerator, self).__init__( + concept_tokenizer=concept_tokenizer, + time_window_size=time_window_size, + *args, + **kwargs, + ) self._concept_tokenizer = concept_tokenizer self._time_window_size = time_window_size @@ -626,18 +441,19 @@ def _get_learning_objective_classes(self): def 
_create_iterator(self): """ - Create an iterator that will iterate forever + Create an iterator that will iterate forever. + :return: """ while True: for row in self._training_data.itertuples(): - concept_ids, dates = zip(*sorted(zip(row.token_ids, row.dates), - key=lambda tup2: tup2[1])) + concept_ids, dates = zip(*sorted(zip(row.token_ids, row.dates), key=lambda tup2: tup2[1])) for i in range(len(concept_ids)): # Only include the concepts whose time stamps are within -half_time_window and # half_time_window from the target time stamp start_index, end_index = create_indexes_by_time_window( - dates, i, self._max_seq_len, self._time_window_size) + dates, i, self._max_seq_len, self._time_window_size + ) if start_index < end_index: yield RowSlicer(row, start_index, end_index, i) @@ -647,6 +463,8 @@ def get_data_size(self): class FineTuningHierarchicalBertDataGenerator(HierarchicalBertDataGenerator): def _get_learning_objective_classes(self): - return [HierarchicalMaskedLanguageModelLearningObjective, - DemographicsLearningObjective, - BertFineTuningLearningObjective] + return [ + HierarchicalMaskedLanguageModelLearningObjective, + DemographicsLearningObjective, + BertFineTuningLearningObjective, + ] diff --git a/src/cehrbert/data_generators/graph_sample_method.py b/src/cehrbert/data_generators/graph_sample_method.py index 1bbe101f..6de0d048 100644 --- a/src/cehrbert/data_generators/graph_sample_method.py +++ b/src/cehrbert/data_generators/graph_sample_method.py @@ -1,48 +1,40 @@ +import random from abc import ABC +from enum import Enum + import networkx as nx import pandas as pd -from enum import Enum -import random class SimilarityType(Enum): - SEMANTIC_SIMILARITY = 'semantic_similarity' - MICA_INFORMATION_CONTENT = 'mica_information_content' - LIN_MEASURE = 'lin_measure' - JIANG_MEASURE = 'jiang_measure' - INFORMATION_COEFFICIENT = 'information_coefficient' - RELEVANCE_MEASURE = 'relevance_measure' - GRAPH_IC_MEASURE = 'graph_ic_measure' - NONE = 'none' + SEMANTIC_SIMILARITY = "semantic_similarity" + MICA_INFORMATION_CONTENT = "mica_information_content" + LIN_MEASURE = "lin_measure" + JIANG_MEASURE = "jiang_measure" + INFORMATION_COEFFICIENT = "information_coefficient" + RELEVANCE_MEASURE = "relevance_measure" + GRAPH_IC_MEASURE = "graph_ic_measure" + NONE = "none" class GraphSampler(ABC): # TODO: Weighted similarity to increase the sampling the local neighbors using logit - def __init__( - self, - concept_similarity_path: str, - concept_similarity_type: str - ): + def __init__(self, concept_similarity_path: str, concept_similarity_type: str): self._concept_similarity_type = concept_similarity_type self._concept_dict, self._similarity_dict = self._init_similarity( concept_similarity_type=concept_similarity_type, - concept_similarity_path=concept_similarity_path + concept_similarity_path=concept_similarity_path, ) - def _is_sampling_enabled( - self - ): + def _is_sampling_enabled(self): """ - Check whether the graph sampling is enabled + Check whether the graph sampling is enabled. 
+ :return: """ return self._concept_similarity_type != SimilarityType.NONE.value - def _init_similarity( - self, - concept_similarity_type: str, - concept_similarity_path: str - ): + def _init_similarity(self, concept_similarity_type: str, concept_similarity_path: str): concept_dict = {} similarity_dict = {} @@ -51,16 +43,15 @@ def _init_similarity( similarity_table = pd.read_parquet(concept_similarity_path) graph = nx.from_pandas_edgelist( similarity_table, - source='concept_id_1', - target='concept_id_2', - edge_attr=concept_similarity_type + source="concept_id_1", + target="concept_id_2", + edge_attr=concept_similarity_type, ) for source in graph.nodes(): # Convert target and source to str types because tokenizer expects string type concept_list, similarity_list = zip( - *[(str(target), val[concept_similarity_type]) for _, target, val in - graph.edges(source, data=True)] + *[(str(target), val[concept_similarity_type]) for _, target, val in graph.edges(source, data=True)] ) # Convert all concept_ids to string type concept_dict[str(source)] = list(map(str, concept_list)) @@ -68,23 +59,17 @@ def _init_similarity( return concept_dict, similarity_dict - def sample_graph( - self, - concept_id - ): + def sample_graph(self, concept_id): # sample the concept from the probability distribution in the graph and generate one if self._is_sampling_enabled() and concept_id in self._concept_dict: # This actually returns a list random_choice = random.choices( population=self._concept_dict[concept_id], weights=self._similarity_dict[concept_id], - k=1 + k=1, ) # In case the choice is none, we return itself as the default value - return next( - iter(random_choice), - concept_id - ) + return next(iter(random_choice), concept_id) # If the concept is an orphan, simply return it return concept_id diff --git a/src/cehrbert/data_generators/hf_data_generator/hf_dataset.py b/src/cehrbert/data_generators/hf_data_generator/hf_dataset.py index 42831bd6..6ca0843c 100644 --- a/src/cehrbert/data_generators/hf_data_generator/hf_dataset.py +++ b/src/cehrbert/data_generators/hf_data_generator/hf_dataset.py @@ -1,40 +1,45 @@ from typing import Union + from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict -from ...models.hf_models.tokenization_hf_cehrbert import CehrBertTokenizer + from ...data_generators.hf_data_generator.hf_dataset_mapping import ( - SortPatientSequenceMapping, - HFTokenizationMapping, + DatasetMapping, HFFineTuningMapping, - DatasetMapping + HFTokenizationMapping, + SortPatientSequenceMapping, ) +from ...models.hf_models.tokenization_hf_cehrbert import CehrBertTokenizer from ...runners.hf_runner_argument_dataclass import DataTrainingArguments CEHRBERT_COLUMNS = [ - 'concept_ids', 'ages', 'dates', 'visit_segments', - 'visit_concept_orders', 'concept_values', 'concept_value_masks', - 'mlm_skip_values' + "concept_ids", + "ages", + "dates", + "visit_segments", + "visit_concept_orders", + "concept_values", + "concept_value_masks", + "mlm_skip_values", ] -TRANSFORMER_COLUMNS = ['input_ids', 'labels'] +TRANSFORMER_COLUMNS = ["input_ids", "labels"] -FINETUNING_COLUMNS = ['age_at_index', 'classifier_label', 'index_date', 'person_id'] +FINETUNING_COLUMNS = ["age_at_index", "classifier_label", "index_date", "person_id"] def create_cehrbert_pretraining_dataset( - dataset: Union[Dataset, DatasetDict], - concept_tokenizer: CehrBertTokenizer, - data_args: DataTrainingArguments + dataset: Union[Dataset, DatasetDict], + concept_tokenizer: CehrBertTokenizer, + data_args: DataTrainingArguments, ) -> 
Dataset: required_columns = TRANSFORMER_COLUMNS + CEHRBERT_COLUMNS # If the data is already in meds, we don't need to sort the sequence anymore if data_args.is_data_in_med: - mapping_functions = [ - HFTokenizationMapping(concept_tokenizer, True) - ] + mapping_functions = [HFTokenizationMapping(concept_tokenizer, True)] else: mapping_functions = [ SortPatientSequenceMapping(), - HFTokenizationMapping(concept_tokenizer, True) + HFTokenizationMapping(concept_tokenizer, True), ] for mapping_function in mapping_functions: @@ -43,12 +48,12 @@ def create_cehrbert_pretraining_dataset( mapping_function, num_proc=data_args.preprocessing_num_workers, batch_size=data_args.preprocessing_batch_size, - streaming=data_args.streaming + streaming=data_args.streaming, ) if not data_args.streaming: if isinstance(dataset, DatasetDict): - all_columns = dataset['train'].column_names + all_columns = dataset["train"].column_names else: all_columns = dataset.column_names columns_to_remove = [_ for _ in all_columns if _ not in required_columns] @@ -58,22 +63,22 @@ def create_cehrbert_pretraining_dataset( def create_cehrbert_finetuning_dataset( - dataset: Union[Dataset, DatasetDict], - concept_tokenizer: CehrBertTokenizer, - data_args: DataTrainingArguments + dataset: Union[Dataset, DatasetDict], + concept_tokenizer: CehrBertTokenizer, + data_args: DataTrainingArguments, ) -> Dataset: required_columns = TRANSFORMER_COLUMNS + CEHRBERT_COLUMNS + FINETUNING_COLUMNS if data_args.is_data_in_med: mapping_functions = [ HFTokenizationMapping(concept_tokenizer, False), - HFFineTuningMapping() + HFFineTuningMapping(), ] else: mapping_functions = [ SortPatientSequenceMapping(), HFTokenizationMapping(concept_tokenizer, False), - HFFineTuningMapping() + HFFineTuningMapping(), ] for mapping_function in mapping_functions: @@ -82,12 +87,12 @@ def create_cehrbert_finetuning_dataset( mapping_function, num_proc=data_args.preprocessing_num_workers, batch_size=data_args.preprocessing_batch_size, - streaming=data_args.streaming + streaming=data_args.streaming, ) if not data_args.streaming: if isinstance(dataset, DatasetDict): - all_columns = dataset['train'].column_names + all_columns = dataset["train"].column_names else: all_columns = dataset.column_names columns_to_remove = [_ for _ in all_columns if _ not in required_columns] @@ -96,30 +101,24 @@ def create_cehrbert_finetuning_dataset( def apply_cehrbert_dataset_mapping( - dataset: Union[DatasetDict, Dataset, IterableDataset, IterableDatasetDict], - mapping_function: DatasetMapping, - batch_size: int = 128, - num_proc: int = 1, - streaming: bool = False + dataset: Union[DatasetDict, Dataset, IterableDataset, IterableDatasetDict], + mapping_function: DatasetMapping, + batch_size: int = 128, + num_proc: int = 1, + streaming: bool = False, ): if streaming: if isinstance(dataset, DatasetDict): for dataset_name in dataset.keys(): - dataset[dataset_name] = ( - dataset[dataset_name].map( - mapping_function.batch_transform, - batched=True, - batch_size=batch_size - ) + dataset[dataset_name] = dataset[dataset_name].map( + mapping_function.batch_transform, + batched=True, + batch_size=batch_size, ) if mapping_function.remove_columns(): dataset[dataset_name] = dataset[dataset_name].remove_columns(mapping_function.remove_columns()) else: - dataset = dataset.map( - mapping_function.batch_transform, - batched=True, - batch_size=batch_size - ) + dataset = dataset.map(mapping_function.batch_transform, batched=True, batch_size=batch_size) if mapping_function.remove_columns(): dataset = 
dataset.remove_columns(mapping_function.remove_columns()) else: @@ -127,7 +126,7 @@ def apply_cehrbert_dataset_mapping( mapping_function.batch_transform, num_proc=num_proc, batched=True, - batch_size=batch_size + batch_size=batch_size, ) if mapping_function.remove_columns(): dataset = dataset.remove_columns(mapping_function.remove_columns()) diff --git a/src/cehrbert/data_generators/hf_data_generator/hf_dataset_collator.py b/src/cehrbert/data_generators/hf_data_generator/hf_dataset_collator.py index 967511c0..92f6c4d6 100644 --- a/src/cehrbert/data_generators/hf_data_generator/hf_dataset_collator.py +++ b/src/cehrbert/data_generators/hf_data_generator/hf_dataset_collator.py @@ -1,21 +1,23 @@ import collections import random +from typing import Any, Dict, Tuple + import numpy as np -from typing import Any, Tuple, Dict import torch from torch.nn.utils.rnn import pad_sequence -from ...models.hf_models.tokenization_hf_cehrbert import CehrBertTokenizer -from ...data_generators.hf_data_generator.hf_dataset_mapping import TruncationType +from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import TruncationType +from cehrbert.models.hf_models.tokenization_hf_cehrbert import CehrBertTokenizer class CehrBertDataCollator: def __init__( - self, tokenizer: CehrBertTokenizer, - max_length: int, - mlm_probability: float = 0.15, - is_pretraining: bool = True, - truncate_type: TruncationType = TruncationType.RANDOM_RIGHT_TRUNCATION + self, + tokenizer: CehrBertTokenizer, + max_length: int, + mlm_probability: float = 0.15, + is_pretraining: bool = True, + truncate_type: TruncationType = TruncationType.RANDOM_RIGHT_TRUNCATION, ): self.tokenizer = tokenizer self.max_length = max_length @@ -27,12 +29,12 @@ def __init__( # Pre-compute these so we can use them later on # We used VS for the historical data, currently, we use the new [VS] for the newer data # so we need to check both cases. 
- self.vs_token_id = tokenizer._convert_token_to_id('VS') - if self.vs_token_id == tokenizer._oov_token_index: - self.vs_token_id = tokenizer._convert_token_to_id('[VS]') - self.ve_token_id = tokenizer._convert_token_to_id('VE') - if self.ve_token_id == tokenizer._oov_token_index: - self.ve_token_id = tokenizer._convert_token_to_id('[VE]') + self.vs_token_id = tokenizer.convert_token_to_id("VS") + if self.vs_token_id == tokenizer.oov_token_index: + self.vs_token_id = tokenizer.convert_token_to_id("[VS]") + self.ve_token_id = tokenizer.convert_token_to_id("VE") + if self.ve_token_id == tokenizer.oov_token_index: + self.ve_token_id = tokenizer.convert_token_to_id("[VE]") @staticmethod def _convert_to_tensor(features: Any) -> torch.Tensor: @@ -48,165 +50,107 @@ def __call__(self, examples): batch_size = len(examples) # Assume that each example in the batch is a dictionary with 'input_ids' and 'attention_mask' - batch_input_ids = [self._convert_to_tensor(example['input_ids']) for example in examples] + batch_input_ids = [self._convert_to_tensor(example["input_ids"]) for example in examples] batch_attention_mask = [ - torch.ones_like(self._convert_to_tensor(example['input_ids']), dtype=torch.float) for - example in examples + torch.ones_like(self._convert_to_tensor(example["input_ids"]), dtype=torch.float) for example in examples ] - batch_ages = [self._convert_to_tensor(example['ages']) for example in examples] - batch_dates = [self._convert_to_tensor(example['dates']) for example in examples] - batch_visit_concept_orders = [self._convert_to_tensor(example['visit_concept_orders']) for example in examples] - batch_concept_values = [self._convert_to_tensor(example['concept_values']) for example in examples] - batch_concept_value_masks = [self._convert_to_tensor(example['concept_value_masks']) for example in examples] - batch_visit_segments = [self._convert_to_tensor(example['visit_segments']) for example in examples] + batch_ages = [self._convert_to_tensor(example["ages"]) for example in examples] + batch_dates = [self._convert_to_tensor(example["dates"]) for example in examples] + batch_visit_concept_orders = [self._convert_to_tensor(example["visit_concept_orders"]) for example in examples] + batch_concept_values = [self._convert_to_tensor(example["concept_values"]) for example in examples] + batch_concept_value_masks = [self._convert_to_tensor(example["concept_value_masks"]) for example in examples] + batch_visit_segments = [self._convert_to_tensor(example["visit_segments"]) for example in examples] # Pad sequences to the max length in the batch - batch['input_ids'] = pad_sequence( + batch["input_ids"] = pad_sequence( batch_input_ids, batch_first=True, - padding_value=self.tokenizer.pad_token_index - ) - batch['attention_mask'] = pad_sequence( - batch_attention_mask, - batch_first=True, - padding_value=0. 
+ padding_value=self.tokenizer.pad_token_index, ) - batch['ages'] = pad_sequence( - batch_ages, - batch_first=True, - padding_value=0 - ) - batch['dates'] = pad_sequence( - batch_dates, - batch_first=True, - padding_value=0 - ) - batch['visit_concept_orders'] = pad_sequence( + batch["attention_mask"] = pad_sequence(batch_attention_mask, batch_first=True, padding_value=0.0) + batch["ages"] = pad_sequence(batch_ages, batch_first=True, padding_value=0) + batch["dates"] = pad_sequence(batch_dates, batch_first=True, padding_value=0) + batch["visit_concept_orders"] = pad_sequence( batch_visit_concept_orders, batch_first=True, - padding_value=self.max_length - 1 - ) - batch['concept_values'] = pad_sequence( - batch_concept_values, - batch_first=True, - padding_value=0. - ) - batch['concept_value_masks'] = pad_sequence( - batch_concept_value_masks, - batch_first=True, - padding_value=0. - ) - batch['visit_segments'] = pad_sequence( - batch_visit_segments, - batch_first=True, - padding_value=0 + padding_value=self.max_length - 1, ) + batch["concept_values"] = pad_sequence(batch_concept_values, batch_first=True, padding_value=0.0) + batch["concept_value_masks"] = pad_sequence(batch_concept_value_masks, batch_first=True, padding_value=0.0) + batch["visit_segments"] = pad_sequence(batch_visit_segments, batch_first=True, padding_value=0) # Prepend the CLS token and their associated values to the corresponding time series features - batch['input_ids'] = torch.cat( - [torch.full((batch_size, 1), self.tokenizer.cls_token_index), batch['input_ids']], - dim=1 + batch["input_ids"] = torch.cat( + [ + torch.full((batch_size, 1), self.tokenizer.cls_token_index), + batch["input_ids"], + ], + dim=1, ) # The attention_mask is set to 1 to enable attention for the CLS token - batch['attention_mask'] = torch.cat( - [torch.full((batch_size, 1), 1.0), batch['attention_mask']], - dim=1 - ) + batch["attention_mask"] = torch.cat([torch.full((batch_size, 1), 1.0), batch["attention_mask"]], dim=1) # Set the age of the CLS token to the starting age - batch['ages'] = torch.cat( - [batch['ages'][:, 0:1], batch['ages']], - dim=1 - ) + batch["ages"] = torch.cat([batch["ages"][:, 0:1], batch["ages"]], dim=1) # Set the age of the CLS token to the starting date - batch['dates'] = torch.cat( - [batch['dates'][:, 0:1], batch['dates']], - dim=1 - ) + batch["dates"] = torch.cat([batch["dates"][:, 0:1], batch["dates"]], dim=1) # Set the visit_concept_order of the CLS token to the first visit_concept_order in the sequence subtract by 1 - visit_concept_orders_first = batch['visit_concept_orders'][:, 0:1] - 1 + visit_concept_orders_first = batch["visit_concept_orders"][:, 0:1] - 1 visit_concept_orders_first = torch.maximum( - visit_concept_orders_first, - torch.zeros_like(visit_concept_orders_first) - ) - batch['visit_concept_orders'] = torch.cat( - [visit_concept_orders_first, batch['visit_concept_orders']], - dim=1 + visit_concept_orders_first, torch.zeros_like(visit_concept_orders_first) ) + batch["visit_concept_orders"] = torch.cat([visit_concept_orders_first, batch["visit_concept_orders"]], dim=1) # Set the concept_value of the CLS token to a default value -1.0. 
- batch['concept_values'] = torch.cat( - [torch.full((batch_size, 1), -1.), batch['concept_values']], - dim=1 - ) + batch["concept_values"] = torch.cat([torch.full((batch_size, 1), -1.0), batch["concept_values"]], dim=1) # Set the concept_value of the CLS token to a default value 0.0 indicating that # there is no value associated with this token - batch['concept_value_masks'] = torch.cat( - [torch.full((batch_size, 1), 0.), batch['concept_value_masks']], - dim=1 + batch["concept_value_masks"] = torch.cat( + [torch.full((batch_size, 1), 0.0), batch["concept_value_masks"]], dim=1 ) # Set the visit_segments of the CLS token to a default value 0 because this doesn't belong to a visit - batch['visit_segments'] = torch.cat( - [torch.full((batch_size, 1), 0), batch['visit_segments']], - dim=1 - ) + batch["visit_segments"] = torch.cat([torch.full((batch_size, 1), 0), batch["visit_segments"]], dim=1) # This is the most crucial logic for generating the training labels if self.is_pretraining: batch_mlm_skip_values = [ - self._convert_to_tensor(example['mlm_skip_values']).to(torch.bool) for example in examples + self._convert_to_tensor(example["mlm_skip_values"]).to(torch.bool) for example in examples ] - batch['mlm_skip_values'] = pad_sequence( - batch_mlm_skip_values, - batch_first=True, - padding_value=False - ) + batch["mlm_skip_values"] = pad_sequence(batch_mlm_skip_values, batch_first=True, padding_value=False) # Set the mlm_skip_values of the CLS token to a default value False - batch['mlm_skip_values'] = torch.cat( - [torch.full((batch_size, 1), False), batch['mlm_skip_values']], - dim=1 - ) + batch["mlm_skip_values"] = torch.cat([torch.full((batch_size, 1), False), batch["mlm_skip_values"]], dim=1) # If the labels field is already provided, we will build the MLM labels off of that. # The labels value indicates the positions that are not allowed for MLM. 
# For example, the mlm_skip_values=1, this means this is a lab value and # we don't want to predict the tokens at this position - if 'labels' in examples[0]: - batch_labels = [self._convert_to_tensor(example['labels']) for example in examples] - batch['labels'] = pad_sequence( - batch_labels, - batch_first=True, - padding_value=-100 - ) + if "labels" in examples[0]: + batch_labels = [self._convert_to_tensor(example["labels"]) for example in examples] + batch["labels"] = pad_sequence(batch_labels, batch_first=True, padding_value=-100) # Disable MLM for the CLS token - batch['labels'] = torch.cat( - [torch.full((batch_size, 1), -100), batch['labels']], - dim=1 - ) + batch["labels"] = torch.cat([torch.full((batch_size, 1), -100), batch["labels"]], dim=1) else: # If the labels is not already provided, we assume all the tokens are subject to # the MLM and simply clone the input_ids - batch['labels'] = batch['input_ids'].clone() + batch["labels"] = batch["input_ids"].clone() - batch['input_ids'], batch['labels'] = self.torch_mask_tokens(batch['input_ids'], batch['labels']) + batch["input_ids"], batch["labels"] = self.torch_mask_tokens(batch["input_ids"], batch["labels"]) - if 'age_at_index' in examples[0]: - batch['age_at_index'] = torch.cat( - [self._convert_to_tensor(example['age_at_index']).reshape(-1, 1) for example in examples], - dim=0 + if "age_at_index" in examples[0]: + batch["age_at_index"] = torch.cat( + [self._convert_to_tensor(example["age_at_index"]).reshape(-1, 1) for example in examples], + dim=0, ).to(torch.float) - if 'classifier_label' in examples[0]: - batch['classifier_label'] = torch.cat( - [self._convert_to_tensor(example['classifier_label']).reshape(-1, 1) for example in examples], - dim=0 + if "classifier_label" in examples[0]: + batch["classifier_label"] = torch.cat( + [self._convert_to_tensor(example["classifier_label"]).reshape(-1, 1) for example in examples], + dim=0, ).to(torch.float) return batch def torch_mask_tokens(self, inputs: torch.Tensor, labels: torch.Tensor) -> Tuple[Any, Any]: - """ - Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. - """ + """Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.""" # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`) probability_matrix = torch.full(labels.shape, self.mlm_probability) pad_token_mask = inputs == self.tokenizer.pad_token_index @@ -227,16 +171,13 @@ def torch_mask_tokens(self, inputs: torch.Tensor, labels: torch.Tensor) -> Tuple # The rest of the time (10% of the time) we keep the masked input tokens unchanged return inputs, labels - def generate_start_end_index( - self, - record: Dict[str, Any] - ) -> Dict[str, Any]: + def generate_start_end_index(self, record: Dict[str, Any]) -> Dict[str, Any]: """ - Adapted from https://github.com/OHDSI/Apollo/blob/main/data_loading/data_transformer.py + Adapted from https://github.com/OHDSI/Apollo/blob/main/data_loading/data_transformer.py. 
Adding the start and end indices to extract a portion of the patient sequence """ - seq_length = len(record['input_ids']) + seq_length = len(record["input_ids"]) new_max_length = self.max_length - 1 # Subtract one for the [CLS] token # Return the record directly if the actual sequence length is less than the max sequence @@ -248,12 +189,13 @@ def generate_start_end_index( start_index = random.randint(0, seq_length - new_max_length) end_index = min(seq_length, start_index + new_max_length) elif self.truncate_type in ( - TruncationType.RANDOM_RIGHT_TRUNCATION, TruncationType.RANDOM_COMPLETE + TruncationType.RANDOM_RIGHT_TRUNCATION, + TruncationType.RANDOM_COMPLETE, ): # We randomly pick a [VS] token starting_points = [] for i in range(seq_length - new_max_length): - current_token = record['input_ids'][i] + current_token = record["input_ids"][i] if current_token == self.vs_token_id: starting_points.append(i) @@ -264,7 +206,7 @@ def generate_start_end_index( # We randomly backtrack to a [VE] token so the sample is complete if self.truncate_type == TruncationType.RANDOM_COMPLETE: for i in reversed(list(range(start_index + 1, end_index))): - current_token = record['input_ids'][i] + current_token = record["input_ids"][i] if current_token == self.ve_token_id: end_index = i break @@ -272,18 +214,14 @@ def generate_start_end_index( start_index = max(0, seq_length - new_max_length) end_index = seq_length for i in range(start_index, seq_length): - current_token = record['input_ids'][i] + current_token = record["input_ids"][i] if current_token == self.vs_token_id: start_index = i break new_record = collections.OrderedDict() for k, v in record.items(): - if ( - isinstance(v, list) or - isinstance(v, np.ndarray) or - (isinstance(v, torch.Tensor) and v.dim() > 0) - ): + if isinstance(v, list) or isinstance(v, np.ndarray) or (isinstance(v, torch.Tensor) and v.dim() > 0): if len(v) == seq_length: new_record[k] = v[start_index:end_index] else: diff --git a/src/cehrbert/data_generators/hf_data_generator/hf_dataset_mapping.py b/src/cehrbert/data_generators/hf_data_generator/hf_dataset_mapping.py index 4e42f0f1..fd509025 100644 --- a/src/cehrbert/data_generators/hf_data_generator/hf_dataset_mapping.py +++ b/src/cehrbert/data_generators/hf_data_generator/hf_dataset_mapping.py @@ -1,35 +1,53 @@ +import collections +import copy import datetime import re +from abc import ABC, abstractmethod from enum import Enum -from abc import abstractmethod, ABC -from typing import Dict, List, Any, Union -import collections -import copy +from typing import Any, Dict, List, Union import numpy as np import pandas as pd +from datasets.formatting.formatting import LazyBatch from dateutil.relativedelta import relativedelta +from meds.schema import birth_code, death_code from pandas import Series -from datasets.formatting.formatting import LazyBatch -from meds.schema import birth_code, death_code -from ...spark_apps.decorators.patient_event_decorator import get_att_function -from ...models.hf_models.tokenization_hf_cehrbert import CehrBertTokenizer -from ...runners.hf_runner_argument_dataclass import DataTrainingArguments +from cehrbert.models.hf_models.tokenization_hf_cehrbert import CehrBertTokenizer +from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments +from cehrbert.spark_apps.decorators.patient_event_decorator import get_att_function birth_codes = [birth_code, "MEDS_BIRTH"] death_codes = [death_code, "MEDS_DEATH"] # OMOP concept ids for inpatient related visits -INPATIENT_VISIT_TYPES = [ - '9201', 
'262', '8971', '8920', '38004311' -] +INPATIENT_VISIT_TYPES = ["9201", "262", "8971", "8920", "38004311"] INPATIENT_VISIT_TYPE_CODES = [ - 'Visit/IP', 'Visit/ERIP', 'Visit/51', 'Visit/61', 'NUCC/315D00000X' + "Visit/IP", + "Visit/ERIP", + "Visit/51", + "Visit/61", + "NUCC/315D00000X", ] DISCHARGE_FACILITY_TYPES = [ - '8536', '8863', '44814650', '4161979', '38004519', '4216643', '8717', '8920', '4021968', - '8546', '8971', '8970', '44814649', '8827', '8676', '38003619', '8870', '4146681' + "8536", + "8863", + "44814650", + "4161979", + "38004519", + "4216643", + "8717", + "8920", + "4021968", + "8546", + "8971", + "8970", + "44814649", + "8827", + "8676", + "38003619", + "8870", + "4146681", ] DATE_FORMAT = "%Y-%m-%d %H:%M:%S.%f" @@ -43,21 +61,18 @@ class TruncationType(Enum): RANDOM_COMPLETE = "random_complete" RANDOM_RIGHT_TRUNCATION = "random_right_truncation" RANDOM_TRUNCATION = "random_truncation" - TAIL = 'tail' + TAIL = "tail" class DatasetMapping(ABC): - def batch_transform( - self, - records: Union[LazyBatch, Dict[str, Any]] - ) -> List[Dict[str, Any]]: + def batch_transform(self, records: Union[LazyBatch, Dict[str, Any]]) -> List[Dict[str, Any]]: if isinstance(records, LazyBatch): dataframe = records.pa_table.to_pandas() else: dataframe = pd.DataFrame(records) applied_dataframe = dataframe.apply(self.transform_pandas_series, axis=1) - return applied_dataframe.to_dict(orient='list') + return applied_dataframe.to_dict(orient="list") def transform_pandas_series(self, series: Series) -> Series: record = self.transform(series.to_dict()) @@ -67,26 +82,19 @@ def remove_columns(self): return [] @abstractmethod - def transform( - self, - record: Dict[str, Any] - ) -> Union[Dict[str, Any], Series]: + def transform(self, record: Dict[str, Any]) -> Union[Dict[str, Any], Series]: """ - Transform the record + Transform the record. + Args record: The row to process, as generated by the CDM processing Returns A dictionary from names to numpy arrays to be used by pytorch. 
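For intuition about how the batch_transform helper above re-collates results, here is a minimal sketch of the same pandas mechanism with a toy per-record transform; the column names are invented for illustration:

```python
import pandas as pd


def toy_transform(record: dict) -> dict:
    # Stand-in for a DatasetMapping.transform: derive one field from another.
    record = dict(record)
    record["age_in_months"] = record["age"] * 12
    return record


records = {"person_id": [1, 2], "age": [30, 42]}

dataframe = pd.DataFrame(records)
# Apply the transform row by row, then turn the resulting frame back into column lists.
applied = dataframe.apply(lambda row: pd.Series(toy_transform(row.to_dict())), axis=1)
print(applied.to_dict(orient="list"))
# {'person_id': [1, 2], 'age': [30, 42], 'age_in_months': [360, 504]}
```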
""" - pass class MedToCehrBertDatasetMapping(DatasetMapping): - def __init__( - self, - data_args: DataTrainingArguments, - is_pretraining: bool = True - ): + def __init__(self, data_args: DataTrainingArguments, is_pretraining: bool = True): self._time_token_function = get_att_function(data_args.att_function_type) self._include_auxiliary_token = data_args.include_auxiliary_token self._inpatient_time_token_function = get_att_function(data_args.inpatient_att_function_type) @@ -109,59 +117,63 @@ def remove_columns(self): if self._is_pretraining: return ["visits", "patient_id", "birth_datetime", "index_date"] else: - return ["visits", "patient_id", "birth_datetime", "index_date", - "visit_concept_ids", "num_of_concepts", "num_of_visits"] + return [ + "visits", + "patient_id", + "birth_datetime", + "index_date", + "visit_concept_ids", + "num_of_concepts", + "num_of_visits", + ] @staticmethod def _update_cehrbert_record( - cehrbert_record: Dict[str, Any], - code: str, - visit_segment: int = 0, - date: int = 0, - age: int = -1, - visit_concept_order: int = 0, - visit_concept_id: str = '0', - concept_value_mask: int = 0, - concept_value: float = -1., - mlm_skip_value: int = 0, + cehrbert_record: Dict[str, Any], + code: str, + visit_segment: int = 0, + date: int = 0, + age: int = -1, + visit_concept_order: int = 0, + visit_concept_id: str = "0", + concept_value_mask: int = 0, + concept_value: float = -1.0, + mlm_skip_value: int = 0, ) -> None: - cehrbert_record['concept_ids'].append(code) - cehrbert_record['visit_concept_orders'].append(visit_concept_order) - cehrbert_record['ages'].append(age) - cehrbert_record['dates'].append(date) - cehrbert_record['visit_segments'].append(visit_segment) - cehrbert_record['visit_concept_ids'].append(visit_concept_id) - cehrbert_record['concept_value_masks'].append(concept_value_mask) - cehrbert_record['concept_values'].append(concept_value) - cehrbert_record['mlm_skip_values'].append(mlm_skip_value) - - def transform( - self, - record: Dict[str, Any] - ) -> Dict[str, Any]: + cehrbert_record["concept_ids"].append(code) + cehrbert_record["visit_concept_orders"].append(visit_concept_order) + cehrbert_record["ages"].append(age) + cehrbert_record["dates"].append(date) + cehrbert_record["visit_segments"].append(visit_segment) + cehrbert_record["visit_concept_ids"].append(visit_concept_id) + cehrbert_record["concept_value_masks"].append(concept_value_mask) + cehrbert_record["concept_values"].append(concept_value) + cehrbert_record["mlm_skip_values"].append(mlm_skip_value) + + def transform(self, record: Dict[str, Any]) -> Dict[str, Any]: cehrbert_record = { - 'person_id': record['patient_id'], - 'concept_ids': [], - 'visit_segments': [], - 'orders': [], - 'dates': [], - 'ages': [], - 'visit_concept_orders': [], - 'concept_value_masks': [], - 'concept_values': [], - 'mlm_skip_values': [], - 'visit_concept_ids': [] + "person_id": record["patient_id"], + "concept_ids": [], + "visit_segments": [], + "orders": [], + "dates": [], + "ages": [], + "visit_concept_orders": [], + "concept_value_masks": [], + "concept_values": [], + "mlm_skip_values": [], + "visit_concept_ids": [], } # Extract the demographic information - birth_datetime = record['birth_datetime'] + birth_datetime = record["birth_datetime"] if isinstance(birth_datetime, pd.Timestamp): birth_datetime = birth_datetime.to_pydatetime() - gender = record['gender'] - race = record['race'] + gender = record["gender"] + race = record["race"] if self._include_demographic_prompt: - first_visit = record['visits'][0] + 
first_visit = record["visits"][0] year_str = f'year:{str(first_visit["visit_start_datetime"].year)}' age_str = f'age:{str(relativedelta(first_visit["visit_start_datetime"], birth_datetime).years)}' @@ -177,20 +189,20 @@ def transform( date_cursor = None # Loop through all the visits excluding the first event containing the demographics - for i, visit in enumerate(sorted(record['visits'], key=lambda e: e['visit_start_datetime'])): + for i, visit in enumerate(sorted(record["visits"], key=lambda e: e["visit_start_datetime"])): - events = visit['events'] + events = visit["events"] # Skip this visit if the number measurements in the event is zero if events is None or len(events) == 0: continue - visit_start_datetime = visit['visit_start_datetime'] + visit_start_datetime = visit["visit_start_datetime"] time_delta = (visit_start_datetime - date_cursor).days if date_cursor else None date_cursor = visit_start_datetime # We assume the first measurement to be the visit type of the current visit - visit_type = visit['visit_type'] + visit_type = visit["visit_type"] is_inpatient = visit_type in INPATIENT_VISIT_TYPES or visit_type in INPATIENT_VISIT_TYPE_CODES # Add artificial time tokens to the patient timeline if timedelta exists @@ -199,23 +211,23 @@ def transform( self._update_cehrbert_record( cehrbert_record, code=self._time_token_function(time_delta), - visit_concept_order=i + 1 + visit_concept_order=i + 1, ) # Add the VS token to the patient timeline to mark the start of a visit - age = relativedelta(visit['visit_start_datetime'], birth_datetime).years + age = relativedelta(visit["visit_start_datetime"], birth_datetime).years # Calculate the week number since the epoch time - date = (visit['visit_start_datetime'] - datetime.datetime(year=1970, month=1, day=1)).days // 7 + date = (visit["visit_start_datetime"] - datetime.datetime(year=1970, month=1, day=1)).days // 7 visit_segment = int(visit_segment_indicator) + 1 self._update_cehrbert_record( cehrbert_record, - code='[VS]', + code="[VS]", visit_concept_order=i + 1, age=age, date=date, visit_segment=visit_segment, - visit_concept_id=visit_type + visit_concept_id=visit_type, ) if self._include_auxiliary_token: @@ -226,47 +238,47 @@ def transform( age=age, date=date, visit_segment=visit_segment, - visit_concept_id=visit_type + visit_concept_id=visit_type, ) for e in events: # If the event doesn't have a time stamp, we skip it - if not e['time']: + if not e["time"]: continue # Add a medical token to the patient timeline # If this is an inpatient visit, we use the event time stamps to calculate age and date # because the patient can stay in the hospital for a period of time. 
if is_inpatient: # Calculate age using the event time stamp - age = relativedelta(e['time'], birth_datetime).years + age = relativedelta(e["time"], birth_datetime).years # Calculate the week number since the epoch time - date = (e['time'] - datetime.datetime(year=1970, month=1, day=1)).days // 7 + date = (e["time"] - datetime.datetime(year=1970, month=1, day=1)).days // 7 else: # For outpatient visits, we use the visit time stamp to calculate age and time because we assume # the outpatient visits start and end on the same day pass # Calculate the time diff in days w.r.t the previous measurement - meas_time_diff = (e['time'] - date_cursor).days + meas_time_diff = (e["time"] - date_cursor).days # Update the date_cursor if the time diff between two neighboring measurements is greater than and # equal to 1 day if meas_time_diff > 0: - date_cursor = e['time'] + date_cursor = e["time"] if self._inpatient_time_token_function: # This generates an artificial time token depending on the choice of the time token functions self._update_cehrbert_record( cehrbert_record, - code=f'i-{self._inpatient_time_token_function(meas_time_diff)}', + code=f"i-{self._inpatient_time_token_function(meas_time_diff)}", visit_concept_order=i + 1, visit_segment=visit_segment, - visit_concept_id=visit_type + visit_concept_id=visit_type, ) # If numeric_value exists, this is a concept/value tuple, we indicate this using a concept_value_mask - numeric_value = e.get('numeric_value', None) + numeric_value = e.get("numeric_value", None) concept_value_mask = int(numeric_value is not None) concept_value = numeric_value if concept_value_mask == 1 else -1.0 - code = replace_escape_chars(e['code']) + code = replace_escape_chars(e["code"]) # If the value mask is 1, this indicates a numeric value associated with the concept if concept_value_mask != 1: # Otherwise we will try to concatenate the answer with the code if the categorical value is provide @@ -285,12 +297,12 @@ def transform( visit_concept_id=visit_type, concept_value_mask=concept_value_mask, concept_value=concept_value, - mlm_skip_value=concept_value_mask + mlm_skip_value=concept_value_mask, ) if is_inpatient: # If visit_end_datetime is populated for the inpatient visit, we update the date_cursor - visit_end_datetime = visit.get('visit_end_datetime', None) + visit_end_datetime = visit.get("visit_end_datetime", None) if visit_end_datetime: date_cursor = visit_end_datetime @@ -298,8 +310,9 @@ def transform( # Reuse the age and date calculated for the last event in the patient timeline for the discharge # facility event discharge_facility = ( - visit['discharge_facility'] if ('discharge_facility' in visit) and visit['discharge_facility'] - else '0' + visit["discharge_facility"] + if ("discharge_facility" in visit) and visit["discharge_facility"] + else "0" ) self._update_cehrbert_record( @@ -309,34 +322,34 @@ def transform( date=date, visit_concept_order=i + 1, visit_segment=visit_segment, - visit_concept_id=visit_type + visit_concept_id=visit_type, ) # Reuse the age and date calculated for the last event in the patient timeline self._update_cehrbert_record( cehrbert_record, - code='[VE]', + code="[VE]", age=age, date=date, visit_concept_order=i + 1, visit_segment=visit_segment, - visit_concept_id=visit_type + visit_concept_id=visit_type, ) # Toggle visit_segment_indicator visit_segment_indicator = not visit_segment_indicator # Generate the orders of the concepts that the cehrbert dataset mapping function expects - cehrbert_record['orders'] = list(range(1, 
len(cehrbert_record['concept_ids']) + 1)) + cehrbert_record["orders"] = list(range(1, len(cehrbert_record["concept_ids"]) + 1)) # Add some count information for this sequence - cehrbert_record['num_of_concepts'] = len(cehrbert_record['concept_ids']) - cehrbert_record['num_of_visits'] = len(record['visits']) + cehrbert_record["num_of_concepts"] = len(cehrbert_record["concept_ids"]) + cehrbert_record["num_of_visits"] = len(record["visits"]) - if 'label' in record: - cehrbert_record['label'] = record['label'] - if 'age_at_index' in record: - cehrbert_record['age_at_index'] = record['age_at_index'] + if "label" in record: + cehrbert_record["label"] = record["label"] + if "age_at_index" in record: + cehrbert_record["age_at_index"] = record["age_at_index"] return cehrbert_record @@ -344,30 +357,30 @@ def transform( class SortPatientSequenceMapping(DatasetMapping): """ A mapping function to order all the features using a pre-defined orders/dates column. + This may not be necessary since the order is feature columns should've been ordered correctly during the data generation process in the spark application. However, it's a good idea to sort them explicitly one more time """ - def transform( - self, - record: Dict[str, Any] - ) -> Dict[str, Any]: + def transform(self, record: Dict[str, Any]) -> Dict[str, Any]: """ - Sort all the list features using a pre-defined orders/dates. If orders/dates columns are not provided, + Sort all the list features using a pre-defined orders/dates. + + If orders/dates columns are not provided, do nothing. """ - sorting_columns = record.get('orders', None) + sorting_columns = record.get("orders", None) if sorting_columns is None: - sorting_columns = record.get('dates', None) + sorting_columns = record.get("dates", None) if sorting_columns is None: return record sorting_columns = list(map(int, sorting_columns)) - seq_length = len(record['concept_ids']) - column_names = ['concept_ids'] - column_values = [record['concept_ids']] + seq_length = len(record["concept_ids"]) + column_names = ["concept_ids"] + column_values = [record["concept_ids"]] for k, v in record.items(): if k in column_names: @@ -390,59 +403,56 @@ def transform( class HFTokenizationMapping(DatasetMapping): - def __init__( - self, - concept_tokenizer: CehrBertTokenizer, - is_pretraining: bool - ): + def __init__(self, concept_tokenizer: CehrBertTokenizer, is_pretraining: bool): self._concept_tokenizer = concept_tokenizer self._is_pretraining = is_pretraining self._lab_token_ids = self._concept_tokenizer.lab_token_ids - def transform( - self, - record: Dict[str, Any] - ) -> Dict[str, Any]: + def transform(self, record: Dict[str, Any]) -> Dict[str, Any]: - input_ids = self._concept_tokenizer.encode(record['concept_ids']) - record['input_ids'] = input_ids - concept_value_masks = record['concept_value_masks'] - concept_values = record['concept_values'] + input_ids = self._concept_tokenizer.encode(record["concept_ids"]) + record["input_ids"] = input_ids + concept_value_masks = record["concept_value_masks"] + concept_values = record["concept_values"] - assert len(input_ids) == len(record['concept_ids']), \ - "the length of input_ids needs to be the same as the length of concept_ids" + assert len(input_ids) == len( + record["concept_ids"] + ), "the length of input_ids needs to be the same as the length of concept_ids" # If any concept has a value associated with it, we normalize the value if np.any(np.asarray(concept_value_masks) > 0): normalized_concept_values = copy.deepcopy(concept_values) - for i, 
(concept_id, token_id, concept_value_mask, concept_value) in enumerate( - zip(record['concept_ids'], input_ids, concept_value_masks, concept_values) + for i, ( + concept_id, + token_id, + concept_value_mask, + concept_value, + ) in enumerate( + zip( + record["concept_ids"], + input_ids, + concept_value_masks, + concept_values, + ) ): if token_id in self._lab_token_ids: normalized_concept_value = self._concept_tokenizer.normalize(concept_id, concept_value) normalized_concept_values[i] = normalized_concept_value - record['concept_values'] = normalized_concept_values + record["concept_values"] = normalized_concept_values # If mlm_skip_value=1, this indicates there is a value associated with this position and # hence we block the MLM to randomly pick this token to be predicted if self._is_pretraining: - record.update({ - 'labels': copy.deepcopy(input_ids) - }) + record.update({"labels": copy.deepcopy(input_ids)}) return record class HFFineTuningMapping(DatasetMapping): - """ - Consider removing this transformation in the future - """ + """Consider removing this transformation in the future.""" - def transform( - self, - record: Dict[str, Any] - ) -> Dict[str, Any]: + def transform(self, record: Dict[str, Any]) -> Dict[str, Any]: return { - 'age_at_index': record['age_at_index'], - 'classifier_label': record['label'] - } \ No newline at end of file + "age_at_index": record["age_at_index"], + "classifier_label": record["label"], + } diff --git a/src/cehrbert/data_generators/hf_data_generator/meds_to_cehrbert_conversion_rules/meds_to_cehrbert_base.py b/src/cehrbert/data_generators/hf_data_generator/meds_to_cehrbert_conversion_rules/meds_to_cehrbert_base.py index 4191012a..d1c034d3 100644 --- a/src/cehrbert/data_generators/hf_data_generator/meds_to_cehrbert_conversion_rules/meds_to_cehrbert_base.py +++ b/src/cehrbert/data_generators/hf_data_generator/meds_to_cehrbert_conversion_rules/meds_to_cehrbert_base.py @@ -6,42 +6,121 @@ @dataclass class EventConversionRule: + """ + Represents a rule for converting an event code into corresponding event labels. + + based on a regular expression pattern. + + Attributes: + code (str): The code associated with the event that needs to be parsed. + parsing_pattern (re.Pattern): The regular expression pattern used to parse the event code. + mapped_event_labels (List[str]): A list of event labels mapped to the groups + in the regular expression pattern. + + Methods: + __post_init__(): Ensures that the number of regex groups matches the number of mapped event + labels. This method is automatically called after the object is initialized. + """ + code: str parsing_pattern: re.Pattern mapped_event_labels: List[str] def __post_init__(self): - assert self.parsing_pattern.groups == len(self.mapped_event_labels), \ - "The number the mapped event labels needs to match the number of groups in the regex" + assert self.parsing_pattern.groups == len( + self.mapped_event_labels + ), "The number the mapped event labels needs to match the number of groups in the regex" class MedsToCehrBertConversion(ABC): + """ + Abstract base class for converting medication-related text events into numeric event labels. + + for CehR-BERT models. This class provides an interface for defining matching rules for + ED admission, general admission, discharge, and text-to-numeric event mappings. + + Attributes: + _ed_admission_matching_rules (List[str]): Cached matching rules for identifying ED admissions. + _admission_matching_rules (List[str]): Cached matching rules for identifying admissions. 
+ _discharge_matching_rules (List[str]): Cached matching rules for identifying discharges. + _text_event_numeric_event_map (dict): Cached map of text event codes to EventConversionRule objects. + + Methods: + _create_ed_admission_matching_rules(): Abstract method for creating ED admission matching rules. + _create_admission_matching_rules(): Abstract method for creating admission matching rules. + _create_discharge_matching_rules(): Abstract method for creating discharge matching rules. + _create_text_event_to_numeric_event_rules(): Abstract method for creating text-to-numeric event rules. + get_ed_admission_matching_rules(): Returns the ED admission matching rules. + get_admission_matching_rules(): Returns the general admission matching rules. + get_discharge_matching_rules(): Returns the discharge matching rules. + get_text_event_to_numeric_events_rule(): Returns the EventConversionRule for a given code, + or None if no rule exists. + """ + def __init__(self): + """ + Initializes the MedsToCehrBertConversion class by caching the matching rules and. + + text-to-numeric event mappings, which are created by calling the respective abstract methods. + """ # Cache these variables once self._ed_admission_matching_rules = self._create_ed_admission_matching_rules() self._admission_matching_rules = self._create_admission_matching_rules() self._discharge_matching_rules = self._create_discharge_matching_rules() - self._text_event_numeric_event_map = { - r.code: r for r in self._create_text_event_to_numeric_event_rules() - } + self._text_event_numeric_event_map = {r.code: r for r in self._create_text_event_to_numeric_event_rules()} @abstractmethod def _create_ed_admission_matching_rules(self) -> List[str]: + """ + Abstract method for defining the matching rules for identifying ED admissions. + + Returns: + List[str]: A list of rules for identifying ED admissions. + + Raises: + NotImplementedError: Must be implemented in a subclass. + """ raise NotImplementedError("Must implement the matching rules for identifying the ED admission") @abstractmethod def _create_admission_matching_rules(self) -> List[str]: + """ + Abstract method for defining the matching rules for identifying admissions. + + Returns: + List[str]: A list of rules for identifying admissions. + + Raises: + NotImplementedError: Must be implemented in a subclass. + """ raise NotImplementedError("Must implement the matching rules for identifying the admission") @abstractmethod def _create_discharge_matching_rules(self) -> List[str]: + """ + Abstract method for defining the matching rules for identifying discharges. + + Returns: + List[str]: A list of rules for identifying discharges. + + Raises: + NotImplementedError: Must be implemented in a subclass. + """ raise NotImplementedError("Must implement the matching rules for identifying the discharge") @abstractmethod def _create_text_event_to_numeric_event_rules(self) -> List[EventConversionRule]: - raise NotImplementedError( - "Must implement the event mapping rules for converting the text events to numeric events" - ) + """ + Abstract method for defining the rules for mapping text events to numeric events. + + Returns: + List[EventConversionRule]: A list of event conversion rules mapping text events + to numeric events. + + Raises: + NotImplementedError: Must be implemented in a subclass. 
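For intuition, this is the kind of mapping such a rule encodes: a regex with N groups splits one text-valued event into N numeric events, one per mapped label. A minimal sketch with plain `re`; the label names are illustrative, and the concrete rules live in the MIMIC-IV subclass that follows:

```python
import re

# Mirrors an EventConversionRule for a blood-pressure style code: two regex groups,
# two mapped event labels (names here are illustrative).
parsing_pattern = re.compile(r"(\d+)/(\d+)")
mapped_event_labels = ["Systolic Blood Pressure", "Diastolic Blood Pressure"]

text_value = "120/80"
match = parsing_pattern.match(text_value)
if match:
    numeric_events = {
        label: float(value)
        for label, value in zip(mapped_event_labels, match.groups())
        if value.isnumeric()
    }
    print(numeric_events)
    # {'Systolic Blood Pressure': 120.0, 'Diastolic Blood Pressure': 80.0}
```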
+ """ + raise NotImplementedError("Must implement the event mapping rules for converting text events to numeric events") def get_ed_admission_matching_rules(self) -> List[str]: return self._ed_admission_matching_rules diff --git a/src/cehrbert/data_generators/hf_data_generator/meds_to_cehrbert_conversion_rules/meds_to_cehrbert_micmic4.py b/src/cehrbert/data_generators/hf_data_generator/meds_to_cehrbert_conversion_rules/meds_to_cehrbert_micmic4.py index c9ca2cf8..0ffbb696 100644 --- a/src/cehrbert/data_generators/hf_data_generator/meds_to_cehrbert_conversion_rules/meds_to_cehrbert_micmic4.py +++ b/src/cehrbert/data_generators/hf_data_generator/meds_to_cehrbert_conversion_rules/meds_to_cehrbert_micmic4.py @@ -2,7 +2,8 @@ from typing import List from ....data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules.meds_to_cehrbert_base import ( - MedsToCehrBertConversion, EventConversionRule + EventConversionRule, + MedsToCehrBertConversion, ) @@ -23,13 +24,13 @@ def _create_text_event_to_numeric_event_rules(self) -> List[EventConversionRule] "Blood Pressure Lying", "Blood Pressure Sitting", "Blood Pressure Standing (1 min)", - "Blood Pressure Standing (3 mins)" + "Blood Pressure Standing (3 mins)", ] blood_pressure_rules = [ EventConversionRule( code=code, parsing_pattern=re.compile(r"(\d+)/(\d+)"), - mapped_event_labels=[f"Systolic {code}", f"Diastolic {code}"] + mapped_event_labels=[f"Systolic {code}", f"Diastolic {code}"], ) for code in blood_pressure_codes ] @@ -38,7 +39,7 @@ def _create_text_event_to_numeric_event_rules(self) -> List[EventConversionRule] EventConversionRule( code=code, parsing_pattern=re.compile(r"(\d+)"), - mapped_event_labels=[code] + mapped_event_labels=[code], ) for code in height_weight_codes ] @@ -46,7 +47,7 @@ def _create_text_event_to_numeric_event_rules(self) -> List[EventConversionRule] EventConversionRule( code="LAB//50827//UNK", parsing_pattern=re.compile(r"(\d+)/(\d+)"), - mapped_event_labels=["LAB//50827//UNK//1", "LAB//50827//UNK//2"] + mapped_event_labels=["LAB//50827//UNK//1", "LAB//50827//UNK//2"], ) ] return blood_pressure_rules + height_weight_rules + ventilation_rate_rules diff --git a/src/cehrbert/data_generators/hf_data_generator/meds_utils.py b/src/cehrbert/data_generators/hf_data_generator/meds_utils.py index 5094213b..61260b91 100644 --- a/src/cehrbert/data_generators/hf_data_generator/meds_utils.py +++ b/src/cehrbert/data_generators/hf_data_generator/meds_utils.py @@ -1,36 +1,37 @@ -import os -import re import collections import functools - -from typing import Dict, List, Optional, Union, Tuple, Iterable +import os +import re from datetime import datetime +from typing import Dict, Iterable, List, Optional, Tuple, Union import meds_reader import numpy as np import pandas as pd +from datasets import Dataset, DatasetDict, Split -from ...runners.hf_runner_argument_dataclass import DataTrainingArguments, MedsToCehrBertConversionType -from ...data_generators.hf_data_generator.hf_dataset_mapping import ( - birth_codes, MedToCehrBertDatasetMapping -) -from ...data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules.meds_to_cehrbert_base import \ - MedsToCehrBertConversion from ...data_generators.hf_data_generator.hf_dataset import apply_cehrbert_dataset_mapping -from ...med_extension.schema_extension import CehrBertPatient, Visit, Event - -from datasets import Dataset, DatasetDict, Split +from ...data_generators.hf_data_generator.hf_dataset_mapping import MedToCehrBertDatasetMapping, birth_codes +from 
...data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules.meds_to_cehrbert_base import ( + MedsToCehrBertConversion, +) +from ...med_extension.schema_extension import CehrBertPatient, Event, Visit +from ...runners.hf_runner_argument_dataclass import DataTrainingArguments, MedsToCehrBertConversionType UNKNOWN_VALUE = "Unknown" DEFAULT_ED_CONCEPT_ID = "9203" DEFAULT_OUTPATIENT_CONCEPT_ID = "9202" DEFAULT_INPATIENT_CONCEPT_ID = "9201" -MEDS_SPLIT_DATA_SPLIT_MAPPING = {"train": Split.TRAIN, "tuning": Split.VALIDATION, "held_out": Split.TEST} +MEDS_SPLIT_DATA_SPLIT_MAPPING = { + "train": Split.TRAIN, + "tuning": Split.VALIDATION, + "held_out": Split.TEST, +} NON_ALPHANUMERIC_CHARS = r"[\w\/\\:\-_]" def get_meds_to_cehrbert_conversion_cls( - meds_to_cehrbert_conversion_type: MedsToCehrBertConversionType + meds_to_cehrbert_conversion_type: MedsToCehrBertConversionType, ) -> MedsToCehrBertConversion: for cls in MedsToCehrBertConversion.__subclasses__(): if meds_to_cehrbert_conversion_type.name == cls.__name__: @@ -40,19 +41,16 @@ def get_meds_to_cehrbert_conversion_cls( def get_patient_split(meds_reader_db_path: str) -> Dict[str, List[int]]: patient_split = pd.read_parquet(os.path.join(meds_reader_db_path, "metadata/patient_splits.parquet")) - result = { - str(group): records["patient_id"].tolist() - for group, records in patient_split.groupby("split") - } + result = {str(group): records["patient_id"].tolist() for group, records in patient_split.groupby("split")} return result class PatientBlock: def __init__( - self, - events: List[meds_reader.Event], - visit_id: int, - conversion: MedsToCehrBertConversion + self, + events: List[meds_reader.Event], + visit_id: int, + conversion: MedsToCehrBertConversion, ): self.visit_id = visit_id self.events = events @@ -75,9 +73,7 @@ def __init__( self.visit_type = DEFAULT_OUTPATIENT_CONCEPT_ID def _has_ed_admission(self) -> bool: - """ - Make this configurable in the future - """ + """Make this configurable in the future.""" for event in self.events: for matching_rule in self.conversion.get_ed_admission_matching_rules(): if re.match(matching_rule, event.code): @@ -103,8 +99,8 @@ def get_discharge_facility(self) -> Optional[str]: for event in self.events: for matching_rule in self.conversion.get_discharge_matching_rules(): if matching_rule in event.code: - discharge_facility = event.code.replace(matching_rule, '') - discharge_facility = re.sub(r'[^a-zA-Z]', '_', discharge_facility) + discharge_facility = event.code.replace(matching_rule, "") + discharge_facility = re.sub(r"[^a-zA-Z]", "_", discharge_facility) return discharge_facility return None @@ -126,7 +122,7 @@ def _convert_event(self, event) -> List[Event]: code=label, time=time, numeric_value=float(value), - properties={'visit_id': self.visit_id, "table": "meds"} + properties={"visit_id": self.visit_id, "table": "meds"}, ) for label, value in zip(conversion_rule.mapped_event_labels, match.groups()) if value.isnumeric() @@ -139,7 +135,7 @@ def _convert_event(self, event) -> List[Event]: time=time, numeric_value=numeric_value, text_value=text_value, - properties={'visit_id': self.visit_id, "table": "meds"} + properties={"visit_id": self.visit_id, "table": "meds"}, ) ] @@ -151,11 +147,11 @@ def get_meds_events(self) -> Iterable[Event]: def convert_one_patient( - patient: meds_reader.Patient, - conversion: MedsToCehrBertConversion, - default_visit_id: int = 1, - prediction_time: datetime = None, - label: Union[int, float] = None + patient: meds_reader.Patient, + conversion: 
MedsToCehrBertConversion, + default_visit_id: int = 1, + prediction_time: datetime = None, + label: Union[int, float] = None, ) -> CehrBertPatient: birth_datetime = None race = None @@ -176,11 +172,11 @@ def convert_one_patient( # This indicates demographics features if e.code in birth_codes: birth_datetime = e.time - elif e.code.startswith('RACE'): + elif e.code.startswith("RACE"): race = e.code - elif e.code.startswith('GENDER'): + elif e.code.startswith("GENDER"): gender = e.code - elif e.code.startswith('ETHNICITY'): + elif e.code.startswith("ETHNICITY"): ethnicity = e.code elif e.time is not None: if not current_date: @@ -272,10 +268,11 @@ def convert_one_patient( visit_type = blocks[0].visit_type visit_start_datetime = min([b.min_time for b in blocks]) visit_end_datetime = max([b.max_time for b in blocks]) - discharge_facility = next( - filter(None, [b.get_discharge_facility() for b in blocks]), - None - ) if visit_type == DEFAULT_INPATIENT_CONCEPT_ID else None + discharge_facility = ( + next(filter(None, [b.get_discharge_facility() for b in blocks]), None) + if visit_type == DEFAULT_INPATIENT_CONCEPT_ID + else None + ) visit_events = list() for block in blocks: visit_events.extend(block.get_meds_events()) @@ -285,14 +282,17 @@ def convert_one_patient( visit_type=visit_type, visit_start_datetime=visit_start_datetime, visit_end_datetime=visit_end_datetime, - discharge_facility=discharge_facility if discharge_facility else UNKNOWN_VALUE, - events=visit_events + discharge_facility=(discharge_facility if discharge_facility else UNKNOWN_VALUE), + events=visit_events, ) ) age_at_index = -1 if prediction_time is not None and birth_datetime is not None: age_at_index = prediction_time.year - birth_datetime.year - if (prediction_time.month, prediction_time.day) < (birth_datetime.month, birth_datetime.day): + if (prediction_time.month, prediction_time.day) < ( + birth_datetime.month, + birth_datetime.day, + ): age_at_index -= 1 # birth_datetime can not be None @@ -307,46 +307,42 @@ def convert_one_patient( ethnicity=ethnicity if ethnicity else UNKNOWN_VALUE, index_date=prediction_time, age_at_index=age_at_index, - label=label + label=label, ) def create_dataset_from_meds_reader( - data_args: DataTrainingArguments, - default_visit_id: int = 1, - is_pretraining: bool = True + data_args: DataTrainingArguments, + default_visit_id: int = 1, + is_pretraining: bool = True, ) -> DatasetDict: train_dataset = _create_cehrbert_data_from_meds( data_args=data_args, split="train", default_visit_id=default_visit_id, - is_pretraining=is_pretraining + is_pretraining=is_pretraining, ) tuning_dataset = _create_cehrbert_data_from_meds( data_args=data_args, split="tuning", default_visit_id=default_visit_id, - is_pretraining=is_pretraining + is_pretraining=is_pretraining, ) held_out_dataset = _create_cehrbert_data_from_meds( data_args=data_args, split="held_out", default_visit_id=default_visit_id, - is_pretraining=is_pretraining + is_pretraining=is_pretraining, ) - return DatasetDict({ - "train": train_dataset, - "validation": tuning_dataset, - "test": held_out_dataset - }) + return DatasetDict({"train": train_dataset, "validation": tuning_dataset, "test": held_out_dataset}) def _meds_to_cehrbert_generator( - shards: List[Tuple[np.ndarray, np.ndarray, np.ndarray]], - path_to_db: str, - default_visit_id: int, - meds_to_cehrbert_conversion_type: MedsToCehrBertConversionType + shards: List[Tuple[np.ndarray, np.ndarray, np.ndarray]], + path_to_db: str, + default_visit_id: int, + meds_to_cehrbert_conversion_type: 
MedsToCehrBertConversionType, ) -> CehrBertPatient: conversion = get_meds_to_cehrbert_conversion_cls(meds_to_cehrbert_conversion_type) for shard in shards: @@ -357,12 +353,12 @@ def _meds_to_cehrbert_generator( def _create_cehrbert_data_from_meds( - data_args: DataTrainingArguments, - split: str, - default_visit_id: int = 1, - is_pretraining: bool = True + data_args: DataTrainingArguments, + split: str, + default_visit_id: int = 1, + is_pretraining: bool = True, ): - assert split in ['held_out', 'train', 'tuning'] + assert split in ["held_out", "train", "tuning"] batches = [] if data_args.cohort_folder: cohort = pd.read_parquet(os.path.join(data_args.cohort_folder, split)) @@ -376,23 +372,20 @@ def _create_cehrbert_data_from_meds( for patient_id in patient_split[split]: batches.append((patient_id, None, None)) - split_batches = np.array_split( - np.asarray(batches), - data_args.preprocessing_num_workers - ) + split_batches = np.array_split(np.asarray(batches), data_args.preprocessing_num_workers) batch_func = functools.partial( _meds_to_cehrbert_generator, path_to_db=data_args.data_folder, - default_visit_id=default_visit_id + default_visit_id=default_visit_id, ) dataset = Dataset.from_generator( batch_func, gen_kwargs={ "shards": split_batches, }, - num_proc=data_args.preprocessing_num_workers if not data_args.streaming else None, + num_proc=(data_args.preprocessing_num_workers if not data_args.streaming else None), writer_batch_size=data_args.preprocessing_batch_size, - streaming=data_args.streaming + streaming=data_args.streaming, ) # Convert the CehrBertPatient to CehrBert data inputs dataset = apply_cehrbert_dataset_mapping( @@ -400,6 +393,6 @@ def _create_cehrbert_data_from_meds( MedToCehrBertDatasetMapping(data_args, is_pretraining), num_proc=data_args.preprocessing_num_workers, batch_size=data_args.preprocessing_batch_size, - streaming=data_args.streaming + streaming=data_args.streaming, ) return dataset diff --git a/src/cehrbert/data_generators/learning_objective.py b/src/cehrbert/data_generators/learning_objective.py index 35241828..00bd5fa6 100644 --- a/src/cehrbert/data_generators/learning_objective.py +++ b/src/cehrbert/data_generators/learning_objective.py @@ -1,21 +1,23 @@ import random from abc import ABC, abstractmethod from itertools import islice -from typing import List, Dict +from typing import Dict, List + import numpy as np import pandas as pd -from tensorflow.dtypes import int32, float32, DType +from tensorflow.dtypes import DType, float32, int32 from tensorflow.keras.utils import pad_sequences -from .data_classes import RowSlicer -from .graph_sample_method import GraphSampler -from .tokenizer import ConceptTokenizer -from ..utils.model_utils import convert_to_list_of_lists +from cehrbert.data_generators.data_classes import RowSlicer +from cehrbert.data_generators.graph_sample_method import GraphSampler +from cehrbert.data_generators.tokenizer import ConceptTokenizer +from cehrbert.utils.model_utils import convert_to_list_of_lists def validate_columns_decorator(function): """ - A decorator to validate whether the parameter rows passed to LearningObjective.process_batch + A decorator to validate whether the parameter rows passed to LearningObjective.process_batch. + contain the required columns. 
It raises AttributeError if any of the required columns is missing from the rows @@ -28,8 +30,7 @@ def wrapper(self, rows: List[RowSlicer], *args, **kwargs): for row_slicer in rows: for column in required_columns: if not hasattr(row_slicer.row, column): - raise AttributeError( - f'The required column {column} is missing for {self}') + raise AttributeError(f"The required column {column} is missing for {self}") break return function(self, rows, *args, **kwargs) @@ -37,9 +38,9 @@ def wrapper(self, rows: List[RowSlicer], *args, **kwargs): return wrapper -def post_pad_pre_truncate(inputs, pad_value, max_seq_len, d_type='int32'): +def post_pad_pre_truncate(inputs, pad_value, max_seq_len, d_type="int32"): """ - Post _pad and pre-truncate the sequence + Post _pad and pre-truncate the sequence. :param inputs: :param pad_value: @@ -47,13 +48,7 @@ def post_pad_pre_truncate(inputs, pad_value, max_seq_len, d_type='int32'): :param d_type: :return: """ - return pad_sequences( - inputs, - maxlen=max_seq_len, - padding='post', - value=pad_value, - dtype=d_type - ) + return pad_sequences(inputs, maxlen=max_seq_len, padding="post", value=pad_value, dtype=d_type) class LearningObjective(ABC): @@ -66,24 +61,25 @@ def required_columns(self): @abstractmethod def process_batch(self, rows: List[RowSlicer]): """ - Process a batch of rows to generate input and output data for the learning objective + Process a batch of rows to generate input and output data for the learning objective. + :param rows: :return: """ - pass @abstractmethod def get_tf_dataset_schema(self): """ - Get the schema for the input and output to the tensorflow Dataset + Get the schema for the input and output to the tensorflow Dataset. + :return: """ - pass @classmethod def get_required_columns(cls): """ - Get the required columns for this learning objective + Get the required columns for this learning objective. + :return: """ return cls.required_columns @@ -107,7 +103,8 @@ def get_tf_dataset_schema(self): @validate_columns_decorator def process_batch(self, rows: List[RowSlicer]): """ - Process a batch of rows to generate input and output data for the learning objective + Process a batch of rows to generate input and output data for the learning objective. + :param rows: :return: """ @@ -129,16 +126,17 @@ def process_batch(self, rows: List[RowSlicer]): class BertFineTuningLearningObjective(LearningObjective): - required_columns = ['label'] + required_columns = ["label"] def get_tf_dataset_schema(self): - output_dict_schema = {'label': int32} + output_dict_schema = {"label": int32} return {}, output_dict_schema @validate_columns_decorator def process_batch(self, rows: List[RowSlicer]): """ - Process a batch of rows to generate input and output data for the learning objective + Process a batch of rows to generate input and output data for the learning objective. 
+ :param rows: :return: """ @@ -146,24 +144,22 @@ def process_batch(self, rows: List[RowSlicer]): for row_slicer in rows: labels.append(row_slicer.row.label) - output_dict = {'label': labels} + output_dict = {"label": labels} return {}, output_dict class DemographicsLearningObjective(LearningObjective): - required_columns = ['age', 'gender_concept_id'] + required_columns = ["age", "gender_concept_id"] def get_tf_dataset_schema(self): - input_dict_schema = { - 'age': int32, - 'gender': int32 - } + input_dict_schema = {"age": int32, "gender": int32} return input_dict_schema, {} @validate_columns_decorator def process_batch(self, rows: List[RowSlicer]): """ - Process a batch of rows to generate input and output data for the learning objective + Process a batch of rows to generate input and output data for the learning objective. + :param rows: :return: """ @@ -173,27 +169,23 @@ def process_batch(self, rows: List[RowSlicer]): age_input.append(row_slicer.row.age) gender_input.append(row_slicer.row.gender_concept_id) - input_dict = { - 'age': age_input, - 'gender': gender_input - } + input_dict = {"age": age_input, "gender": gender_input} return input_dict, {} class ProlongedLengthStayLearningObjective(LearningObjective): - required_columns = ['prolonged_length_stay'] + required_columns = ["prolonged_length_stay"] def get_tf_dataset_schema(self): - output_dict_schema = { - 'prolonged_length_stay': int32 - } + output_dict_schema = {"prolonged_length_stay": int32} return {}, output_dict_schema @validate_columns_decorator def process_batch(self, rows: List[RowSlicer]): """ - Process a batch of rows to generate input and output data for the learning objective + Process a batch of rows to generate input and output data for the learning objective. + :param rows: :return: """ @@ -201,71 +193,50 @@ def process_batch(self, rows: List[RowSlicer]): for row_slicer in rows: prolonged_length_stay_input.append(row_slicer.row.prolonged_length_stay) - output_dict = { - 'prolonged_length_stay': prolonged_length_stay_input - } + output_dict = {"prolonged_length_stay": prolonged_length_stay_input} return {}, output_dict class VisitPredictionLearningObjective(LearningObjective): - required_columns = ['visit_token_ids', 'visit_concept_orders'] + required_columns = ["visit_token_ids", "visit_concept_orders"] - def __init__( - self, - visit_tokenizer: ConceptTokenizer, - max_seq_len: int - ): + def __init__(self, visit_tokenizer: ConceptTokenizer, max_seq_len: int): self._max_seq_len = max_seq_len self._visit_tokenizer = visit_tokenizer def get_tf_dataset_schema(self): - input_dict_schema = { - 'masked_visit_concepts': int32, - 'mask_visit': int32 - } - output_dict_schema = {'visit_predictions': int32} + input_dict_schema = {"masked_visit_concepts": int32, "mask_visit": int32} + output_dict_schema = {"visit_predictions": int32} return input_dict_schema, output_dict_schema @validate_columns_decorator def process_batch(self, rows: List[RowSlicer]): - (output_mask, masked_visit_concepts, visit_concepts) = zip( - *list(map(self._make_record, rows))) + (output_mask, masked_visit_concepts, visit_concepts) = zip(*list(map(self._make_record, rows))) unused_token_id = self._visit_tokenizer.get_unused_token_id() - visit_concepts = post_pad_pre_truncate( - visit_concepts, - unused_token_id, - self._max_seq_len - ) - masked_visit_concepts = post_pad_pre_truncate( - masked_visit_concepts, - unused_token_id, - self._max_seq_len - ) + visit_concepts = post_pad_pre_truncate(visit_concepts, unused_token_id, self._max_seq_len) + 
masked_visit_concepts = post_pad_pre_truncate(masked_visit_concepts, unused_token_id, self._max_seq_len) # 1 indicates attention and 0 indicates mask visit_mask = (visit_concepts != unused_token_id).astype(int) - combined_label = np.stack( - [visit_concepts, output_mask], - axis=-1 - ) + combined_label = np.stack([visit_concepts, output_mask], axis=-1) input_dict = { - 'masked_visit_concepts': masked_visit_concepts, - 'mask_visit': visit_mask + "masked_visit_concepts": masked_visit_concepts, + "mask_visit": visit_mask, } - output_dict = {'visit_predictions': combined_label} + output_dict = {"visit_predictions": combined_label} return input_dict, output_dict def _make_record(self, row_slicer: RowSlicer): """ - A method for making a bert record for the bert data generator to yield + A method for making a bert record for the bert data generator to yield. :param row_slicer: a namedtuple containing a pandas row, left_index and right_index for slicing the sequences such as concepts @@ -276,17 +247,16 @@ def _make_record(self, row_slicer: RowSlicer): row, left_index, right_index, *_ = row_slicer iterator = zip(row.visit_concept_orders, row.visit_token_ids) - (dates, visit_concept_ids) = zip( - *islice(sorted(iterator, key=lambda tup2: tup2[0]), left_index, right_index)) + (dates, visit_concept_ids) = zip(*islice(sorted(iterator, key=lambda tup2: tup2[0]), left_index, right_index)) - masked_visit_concepts, output_mask = self._mask_visit_concepts( - visit_concept_ids) + masked_visit_concepts, output_mask = self._mask_visit_concepts(visit_concept_ids) return output_mask, masked_visit_concepts, visit_concept_ids def _mask_visit_concepts(self, visit_concepts): """ - Any visit has 50% chance to be masked + Any visit has 50% chance to be masked. + :param visit_concepts: :return: """ @@ -301,16 +271,21 @@ def _mask_visit_concepts(self, visit_concepts): class MaskedLanguageModelLearningObjective(LearningObjective): required_columns = [ - 'token_ids', 'dates', 'visit_segments', - 'ages', 'visit_concept_orders', - 'concept_values', 'concept_value_masks', 'mlm_skip_values' + "token_ids", + "dates", + "visit_segments", + "ages", + "visit_concept_orders", + "concept_values", + "concept_value_masks", + "mlm_skip_values", ] def __init__( - self, - concept_tokenizer: ConceptTokenizer, - max_seq_len: int, - is_pretraining: bool + self, + concept_tokenizer: ConceptTokenizer, + max_seq_len: int, + is_pretraining: bool, ): self._max_seq_len = max_seq_len self._concept_tokenizer = concept_tokenizer @@ -318,97 +293,70 @@ def __init__( def get_tf_dataset_schema(self): input_dict_schema = { - 'masked_concept_ids': int32, - 'concept_ids': int32, - 'mask': int32, - 'time_stamps': int32, - 'visit_segments': int32, - 'ages': int32, - 'visit_concept_orders': int32, - 'concept_values': float32, - 'concept_value_masks': int32 + "masked_concept_ids": int32, + "concept_ids": int32, + "mask": int32, + "time_stamps": int32, + "visit_segments": int32, + "ages": int32, + "visit_concept_orders": int32, + "concept_values": float32, + "concept_value_masks": int32, } - output_dict_schema = {'concept_predictions': int32} + output_dict_schema = {"concept_predictions": int32} return input_dict_schema, output_dict_schema @validate_columns_decorator def process_batch(self, rows: List[RowSlicer]): ( - output_mask, masked_concepts, concepts, time_stamps, - visit_segments, ages, visit_concept_orders, - concept_value_masks, concept_values + output_mask, + masked_concepts, + concepts, + time_stamps, + visit_segments, + ages, + 
visit_concept_orders, + concept_value_masks, + concept_values, ) = zip(*list(map(self._make_record, rows))) unused_token_id = self._concept_tokenizer.get_unused_token_id() # The main inputs for bert - masked_concepts = post_pad_pre_truncate( - masked_concepts, - unused_token_id, - self._max_seq_len - ) - concepts = post_pad_pre_truncate( - concepts, - unused_token_id, - self._max_seq_len - ) - concept_value_masks = post_pad_pre_truncate( - concept_value_masks, - 0, - self._max_seq_len - ) - concept_values = post_pad_pre_truncate( - concept_values, - -1.0, - self._max_seq_len, - d_type='float32' - ) + masked_concepts = post_pad_pre_truncate(masked_concepts, unused_token_id, self._max_seq_len) + concepts = post_pad_pre_truncate(concepts, unused_token_id, self._max_seq_len) + concept_value_masks = post_pad_pre_truncate(concept_value_masks, 0, self._max_seq_len) + concept_values = post_pad_pre_truncate(concept_values, -1.0, self._max_seq_len, d_type="float32") # 1 indicates attention and 0 indicates mask mask = (concepts != unused_token_id).astype(int) # The auxiliary inputs for bert - visit_segments = post_pad_pre_truncate( - visit_segments, - 0, - self._max_seq_len - ) - time_stamps = post_pad_pre_truncate( - time_stamps, - 0, - self._max_seq_len - ) - ages = post_pad_pre_truncate( - ages, - 0, - self._max_seq_len - ) - visit_concept_orders = post_pad_pre_truncate( - visit_concept_orders, - self._max_seq_len, - self._max_seq_len - ) + visit_segments = post_pad_pre_truncate(visit_segments, 0, self._max_seq_len) + time_stamps = post_pad_pre_truncate(time_stamps, 0, self._max_seq_len) + ages = post_pad_pre_truncate(ages, 0, self._max_seq_len) + visit_concept_orders = post_pad_pre_truncate(visit_concept_orders, self._max_seq_len, self._max_seq_len) input_dict = { - 'masked_concept_ids': masked_concepts, - 'concept_ids': concepts, - 'mask': mask, - 'time_stamps': time_stamps, - 'ages': ages, - 'visit_segments': visit_segments, - 'visit_concept_orders': visit_concept_orders, - 'concept_value_masks': concept_value_masks, - 'concept_values': concept_values + "masked_concept_ids": masked_concepts, + "concept_ids": concepts, + "mask": mask, + "time_stamps": time_stamps, + "ages": ages, + "visit_segments": visit_segments, + "visit_concept_orders": visit_concept_orders, + "concept_value_masks": concept_value_masks, + "concept_values": concept_values, } - output_dict = {'concept_predictions': np.stack([concepts, output_mask], axis=-1)} + output_dict = {"concept_predictions": np.stack([concepts, output_mask], axis=-1)} return input_dict, output_dict def _make_record(self, row_slicer: RowSlicer): """ - A method for making a bert record for the bert data generator to yield + A method for making a bert record for the bert data generator to yield. 
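The masking step handled by _mask_concepts further below follows the standard BERT recipe described in its docstring: roughly 15% of positions are selected, and a selected position is replaced by [MASK] 80% of the time, by a random vocabulary token 10% of the time, and left unchanged otherwise. A minimal sketch of that decision logic, not the exact implementation; the token ids and vocabulary bounds are made up:

```python
import random

# Hypothetical mask id and vocabulary bounds for illustration.
MASK_ID, FIRST_TOKEN, LAST_TOKEN = 1, 10, 99


def mask_concepts(concepts, mlm_probability=0.15):
    masked = list(concepts)
    output_mask = [0] * len(concepts)
    for pos, _ in enumerate(concepts):
        if random.random() < mlm_probability:  # ~15% of positions are selected
            dice = random.random()
            if dice < 0.8:                      # 80%: replace with the [MASK] token
                masked[pos] = MASK_ID
            elif dice < 0.9:                    # 10%: replace with a random token
                masked[pos] = random.randint(FIRST_TOKEN, LAST_TOKEN)
            # else: 10% of the time keep the original token
            output_mask[pos] = 1                # only selected positions contribute to the loss
    return masked, output_mask


print(mask_concepts([11, 12, 13, 14, 15]))
```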
:param row_slicer: a tuple containing a pandas row, left_index and right_index for slicing the sequences such as concepts @@ -418,31 +366,50 @@ def _make_record(self, row_slicer: RowSlicer): row, left_index, right_index, *_ = row_slicer - sorting_columns = getattr(row, 'orders') if hasattr(row, 'orders') else row.dates + sorting_columns = getattr(row, "orders") if hasattr(row, "orders") else row.dates iterator = zip( - map(int, sorting_columns), row.token_ids, row.visit_segments, row.dates, - row.ages, row.visit_concept_orders, row.concept_value_masks, row.concept_values, - row.mlm_skip_values + map(int, sorting_columns), + row.token_ids, + row.visit_segments, + row.dates, + row.ages, + row.visit_concept_orders, + row.concept_value_masks, + row.concept_values, + row.mlm_skip_values, ) sorted_list = sorted(iterator, key=lambda tup2: (tup2[0], tup2[1])) ( - _, concepts, segments, dates, ages, visit_concept_orders, - concept_value_masks, concept_values, mlm_skip_values + _, + concepts, + segments, + dates, + ages, + visit_concept_orders, + concept_value_masks, + concept_values, + mlm_skip_values, ) = zip(*list(islice(sorted_list, left_index, right_index))) masked_concepts, output_mask = self._mask_concepts(concepts, mlm_skip_values) return ( - output_mask, masked_concepts, concepts, - dates, segments, ages, visit_concept_orders, - concept_value_masks, concept_values + output_mask, + masked_concepts, + concepts, + dates, + segments, + ages, + visit_concept_orders, + concept_value_masks, + concept_values, ) def _mask_concepts(self, concepts, mlm_skip_values): """ - Mask out 15% of the concepts + Mask out 15% of the concepts. :param concepts: :param mlm_skip_values: @@ -466,7 +433,8 @@ def _mask_concepts(self, concepts, mlm_skip_values): elif dice < 0.9: masked_concepts[word_pos] = random.randint( self._concept_tokenizer.get_first_token_index(), - self._concept_tokenizer.get_last_token_index()) + self._concept_tokenizer.get_last_token_index(), + ) # else: 10% of the time we just leave the word as is output_mask[word_pos] = 1 @@ -475,65 +443,79 @@ def _mask_concepts(self, concepts, mlm_skip_values): class HierarchicalMaskedLanguageModelLearningObjective(LearningObjective): required_columns = [ - 'concept_ids', 'dates', - 'visit_segments', 'ages', - 'visit_dates', 'visit_masks', - 'visit_rank_orders', - 'concept_values', 'concept_value_masks', 'mlm_skip_values' + "concept_ids", + "dates", + "visit_segments", + "ages", + "visit_dates", + "visit_masks", + "visit_rank_orders", + "concept_values", + "concept_value_masks", + "mlm_skip_values", ] def __init__( - self, - concept_tokenizer: ConceptTokenizer, - max_num_of_visits: int, - max_num_of_concepts: int, - is_pretraining: bool, - concept_similarity_path: str, - concept_similarity_type: str + self, + concept_tokenizer: ConceptTokenizer, + max_num_of_visits: int, + max_num_of_concepts: int, + is_pretraining: bool, + concept_similarity_path: str, + concept_similarity_type: str, ): self._concept_tokenizer = concept_tokenizer self._max_num_of_visits = max_num_of_visits self._max_num_of_concepts = max_num_of_concepts self._is_pretraining = is_pretraining - self._graph_sampler = GraphSampler( - concept_similarity_path, - concept_similarity_type - ) + self._graph_sampler = GraphSampler(concept_similarity_path, concept_similarity_type) def get_tf_dataset_schema(self): input_dict_schema = { - 'pat_seq': int32, - 'pat_seq_age': int32, - 'pat_seq_time': int32, - 'pat_mask': int32, - 'visit_mask': int32, - 'visit_rank_order': int32, - 'concept_values': float32, 
- 'concept_value_masks': int32 + "pat_seq": int32, + "pat_seq_age": int32, + "pat_seq_time": int32, + "pat_mask": int32, + "visit_mask": int32, + "visit_rank_order": int32, + "concept_values": float32, + "concept_value_masks": int32, } - output_dict_schema = {'concept_predictions': int32} + output_dict_schema = {"concept_predictions": int32} return input_dict_schema, output_dict_schema - def _pad(self, x, padded_token, maxlen, token_dtype='int32'): + def _pad(self, x, padded_token, maxlen, token_dtype="int32"): return pad_sequences( x, maxlen=maxlen, - padding='post', - truncating='post', + padding="post", + truncating="post", value=padded_token, - dtype=token_dtype) + dtype=token_dtype, + ) def _concept_mask(self, concept_ids): - return list(map(lambda c: (c == self._concept_tokenizer.get_unused_token_id()).astype(int), - concept_ids)) + return list( + map( + lambda c: (c == self._concept_tokenizer.get_unused_token_id()).astype(int), + concept_ids, + ) + ) @validate_columns_decorator def process_batch(self, rows: List[RowSlicer]): ( - output_concept_masks, masked_concepts, concepts, dates, ages, - visit_dates, visit_masks, visit_rank_orders, - concept_values, concept_value_masks + output_concept_masks, + masked_concepts, + concepts, + dates, + ages, + visit_dates, + visit_masks, + visit_rank_orders, + concept_values, + concept_value_masks, ) = zip(*list(map(self._make_record, rows))) # Retrieve the unused token id to pad the visit sequences @@ -547,19 +529,22 @@ def process_batch(self, rows: List[RowSlicer]): pd.Series(masked_concepts) .apply(convert_to_list_of_lists) .apply(self._concept_tokenizer.encode) - .apply(lambda tokens: self._pad( - tokens, - padded_token=unused_token_id, - maxlen=self._max_num_of_concepts)) + .apply( + lambda tokens: self._pad( + tokens, + padded_token=unused_token_id, + maxlen=self._max_num_of_concepts, + ) + ) ) # Post pad the sequence and pre-truncate the sequence padded_masked_concepts = np.reshape( post_pad_pre_truncate( masked_concepts.apply(lambda d: d.flatten()), unused_token_id, - max_seq_len + max_seq_len, ), - (-1, self._max_num_of_visits, self._max_num_of_concepts) + (-1, self._max_num_of_visits, self._max_num_of_concepts), ) # 1 indicates attention and 0 indicates mask @@ -567,16 +552,10 @@ def process_batch(self, rows: List[RowSlicer]): # Process visit_rank_orders padded_visit_rank_orders = post_pad_pre_truncate( - visit_rank_orders, - pad_value=0, - max_seq_len=self._max_num_of_visits + visit_rank_orders, pad_value=0, max_seq_len=self._max_num_of_visits ) # Process visit_masks - padded_visit_masks = post_pad_pre_truncate( - visit_masks, - pad_value=1, - max_seq_len=self._max_num_of_visits - ) + padded_visit_masks = post_pad_pre_truncate(visit_masks, pad_value=1, max_seq_len=self._max_num_of_visits) # 1 indicates attention and 0 indicates mask, therefore we need to flip it. 
padded_visit_masks = 1 - padded_visit_masks @@ -584,16 +563,19 @@ def process_batch(self, rows: List[RowSlicer]): concept_values = ( pd.Series(concept_values) .apply(convert_to_list_of_lists) - .apply(lambda time_stamps: self._pad( - time_stamps, - padded_token=-1.0, - token_dtype='float32', - maxlen=self._max_num_of_concepts)) + .apply( + lambda time_stamps: self._pad( + time_stamps, + padded_token=-1.0, + token_dtype="float32", + maxlen=self._max_num_of_concepts, + ) + ) ) padded_concept_values = np.reshape( post_pad_pre_truncate(concept_values.apply(lambda d: d.flatten()), -1.0, max_seq_len), - (-1, self._max_num_of_visits, self._max_num_of_concepts) + (-1, self._max_num_of_visits, self._max_num_of_concepts), ) # The concept value masks for bert, this indicates which concept in the visit sequence @@ -601,55 +583,46 @@ def process_batch(self, rows: List[RowSlicer]): concept_value_masks = ( pd.Series(concept_value_masks) .apply(convert_to_list_of_lists) - .apply(lambda time_stamps: self._pad( - time_stamps, - padded_token=0, - maxlen=self._max_num_of_concepts)) + .apply(lambda time_stamps: self._pad(time_stamps, padded_token=0, maxlen=self._max_num_of_concepts)) ) padded_concept_value_masks = np.reshape( post_pad_pre_truncate(concept_value_masks.apply(lambda d: d.flatten()), 0, max_seq_len), - (-1, self._max_num_of_visits, self._max_num_of_concepts) + (-1, self._max_num_of_visits, self._max_num_of_concepts), ) # The auxiliary inputs for bert dates = ( pd.Series(dates) .apply(convert_to_list_of_lists) - .apply(lambda time_stamps: self._pad( - time_stamps, - padded_token=0, - maxlen=self._max_num_of_concepts)) + .apply(lambda time_stamps: self._pad(time_stamps, padded_token=0, maxlen=self._max_num_of_concepts)) ) padded_dates = np.reshape( post_pad_pre_truncate(dates.apply(lambda d: d.flatten()), 0, max_seq_len), - (-1, self._max_num_of_visits, self._max_num_of_concepts) + (-1, self._max_num_of_visits, self._max_num_of_concepts), ) ages = ( pd.Series(ages) .apply(convert_to_list_of_lists) - .apply(lambda time_stamps: self._pad( - time_stamps, - padded_token=0, - maxlen=self._max_num_of_concepts)) + .apply(lambda time_stamps: self._pad(time_stamps, padded_token=0, maxlen=self._max_num_of_concepts)) ) padded_ages = np.reshape( post_pad_pre_truncate(ages.apply(lambda d: d.flatten()), 0, max_seq_len), - (-1, self._max_num_of_visits, self._max_num_of_concepts) + (-1, self._max_num_of_visits, self._max_num_of_concepts), ) input_dict = { - 'pat_seq': padded_masked_concepts, - 'pat_mask': padded_pat_mask, - 'pat_seq_time': padded_dates, - 'pat_seq_age': padded_ages, - 'visit_mask': padded_visit_masks, - 'visit_rank_order': padded_visit_rank_orders, - 'concept_values': padded_concept_values, - 'concept_value_masks': padded_concept_value_masks + "pat_seq": padded_masked_concepts, + "pat_mask": padded_pat_mask, + "pat_seq_time": padded_dates, + "pat_seq_age": padded_ages, + "visit_mask": padded_visit_masks, + "visit_rank_order": padded_visit_rank_orders, + "concept_values": padded_concept_values, + "concept_value_masks": padded_concept_value_masks, } # Create the targets for MLM @@ -658,44 +631,38 @@ def process_batch(self, rows: List[RowSlicer]): pd.Series(concepts) .apply(convert_to_list_of_lists) .apply(self._concept_tokenizer.encode) - .apply(lambda tokens: self._pad( - tokens, - padded_token=unused_token_id, - maxlen=self._max_num_of_concepts)) + .apply( + lambda tokens: self._pad( + tokens, + padded_token=unused_token_id, + maxlen=self._max_num_of_concepts, + ) + ) ) # Reshape this into 1-D for 
the MLM prediction - padded_concepts = post_pad_pre_truncate( - concepts.apply(lambda d: d.flatten()), - unused_token_id, - max_seq_len - ) + padded_concepts = post_pad_pre_truncate(concepts.apply(lambda d: d.flatten()), unused_token_id, max_seq_len) output_concept_masks = ( pd.Series(output_concept_masks) .apply(convert_to_list_of_lists) - .apply(lambda masks: self._pad( - masks, - padded_token=0, - maxlen=self._max_num_of_concepts)) + .apply(lambda masks: self._pad(masks, padded_token=0, maxlen=self._max_num_of_concepts)) ) # Reshape this into 1-D for the MLM prediction padded_output_concept_masks = post_pad_pre_truncate( output_concept_masks.apply(lambda d: d.flatten()), pad_value=0, - max_seq_len=max_seq_len + max_seq_len=max_seq_len, ) - output_dict = { - 'concept_predictions': np.stack([padded_concepts, padded_output_concept_masks], axis=-1) - } + output_dict = {"concept_predictions": np.stack([padded_concepts, padded_output_concept_masks], axis=-1)} return input_dict, output_dict def _make_record(self, row_slicer: RowSlicer): """ - A method for making a bert record for the bert data generator to yield + A method for making a bert record for the bert data generator to yield. :param row_slicer: a tuple containing a pandas row, left_index and right_index for slicing the sequences such as concepts @@ -723,18 +690,25 @@ def _make_record(self, row_slicer: RowSlicer): visit_masks = row.visit_masks[start_index:end_index] visit_rank_orders = row.visit_rank_orders[start_index:end_index] - masked_concepts, output_concept_masks = zip( - *list(map(self._mask_concepts, zip(concepts, mlm_skip_values)))) + masked_concepts, output_concept_masks = zip(*list(map(self._mask_concepts, zip(concepts, mlm_skip_values)))) return ( - output_concept_masks, masked_concepts, concepts, dates, ages, - visit_dates, visit_masks, visit_rank_orders, - concept_values, concept_value_masks + output_concept_masks, + masked_concepts, + concepts, + dates, + ages, + visit_dates, + visit_masks, + visit_rank_orders, + concept_values, + concept_value_masks, ) def _mask_concepts(self, concepts_tuple): """ - Mask out 15% of the concepts + Mask out 15% of the concepts. 
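Note: both `_mask_concepts` implementations follow the standard BERT recipe visible in this diff: an eligible position is selected with probability 0.15; a selected token is replaced by the mask token 80% of the time, by a random in-vocabulary token 10% of the time, and left unchanged otherwise, while `output_mask` records where the MLM loss applies. A self-contained sketch follows; the token ids are invented stand-ins for `ConceptTokenizer`, and the graph-neighbour swap performed by the hierarchical variant via `GraphSampler` is omitted.

```python
# 80/10/10 masking sketch with made-up mask_id and vocab_range (assumptions).
import random


def mask_tokens(token_ids, mask_id=1, vocab_range=(5, 100), mask_prob=0.15, skip_flags=None):
    skip_flags = skip_flags or [0] * len(token_ids)
    masked = list(token_ids)
    output_mask = [0] * len(token_ids)  # 1 marks positions scored by the MLM loss
    for pos in range(len(token_ids)):
        if skip_flags[pos]:  # e.g. mlm_skip_values flags tokens that should never be masked
            continue
        if random.random() < mask_prob:
            dice = random.random()
            if dice < 0.8:
                masked[pos] = mask_id                       # 80%: replace with the mask token
            elif dice < 0.9:
                masked[pos] = random.randint(*vocab_range)  # 10%: random in-vocabulary token
            # else: 10% of the time the original token is kept
            output_mask[pos] = 1
    return masked, output_mask


random.seed(0)
print(mask_tokens(list(range(10, 30))))
```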
+ :param concepts_tuple: :return: """ @@ -751,7 +725,7 @@ def _mask_concepts(self, concepts_tuple): if mlm_skip_values[word_pos] == 1: continue # Do no mask the [UNKNOWN] token - if concepts[word_pos] == '0': + if concepts[word_pos] == "0": continue if random.random() < 0.15: @@ -761,31 +735,28 @@ def _mask_concepts(self, concepts_tuple): elif dice < 0.9: masked_concepts[word_pos] = random.randint( self._concept_tokenizer.get_first_token_index(), - self._concept_tokenizer.get_last_token_index() + self._concept_tokenizer.get_last_token_index(), ) # else: 10% of the time we just leave the word as is output_mask[word_pos] = 1 elif random.random() < 0.15: # the concept will be replaced by the neighbor concept in the graph - masked_concepts[word_pos] = self._graph_sampler.sample_graph( - masked_concepts[word_pos]) + masked_concepts[word_pos] = self._graph_sampler.sample_graph(masked_concepts[word_pos]) return masked_concepts, output_mask -class HierarchicalVisitTypePredictionLearningObjective( - HierarchicalMaskedLanguageModelLearningObjective -): - required_columns = ['visit_token_ids'] +class HierarchicalVisitTypePredictionLearningObjective(HierarchicalMaskedLanguageModelLearningObjective): + required_columns = ["visit_token_ids"] def __init__( - self, - visit_tokenizer: ConceptTokenizer, - max_num_of_visits: int, - is_pretraining: bool, - include_visit_prediction: bool, - warmup_step: int + self, + visit_tokenizer: ConceptTokenizer, + max_num_of_visits: int, + is_pretraining: bool, + include_visit_prediction: bool, + warmup_step: int, ): self._visit_tokenizer = visit_tokenizer self._max_num_of_visits = max_num_of_visits @@ -795,60 +766,45 @@ def __init__( self._counter = 0 def get_tf_dataset_schema(self): - input_dict_schema = { - 'masked_visit_type': int32 - } + input_dict_schema = {"masked_visit_type": int32} output_dict_schema = {} if self._include_visit_prediction: - output_dict_schema.update({ - 'visit_predictions': int32 - }) + output_dict_schema.update({"visit_predictions": int32}) return input_dict_schema, output_dict_schema @validate_columns_decorator def process_batch(self, rows: List[RowSlicer]): """ - Process a batch of rows to generate input and output data for the learning objective + Process a batch of rows to generate input and output data for the learning objective. 
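Note: every objective in this module packs its targets as `np.stack([labels, mask], axis=-1)`, so the label tensor and its loss mask travel together through the tf.data pipeline. The snippet below is a hypothetical loss showing how such a stacked pair is typically unpacked downstream; the repository's actual loss functions are outside this diff and may differ.

```python
# Illustrative masked loss over a (labels, mask) pair stacked on the last axis.
import tensorflow as tf


def masked_sparse_categorical_crossentropy(y_true, y_pred):
    labels = tf.cast(y_true[..., 0], tf.int32)     # token / class ids to predict
    mask = tf.cast(y_true[..., 1], y_pred.dtype)   # 1 where the loss applies, 0 elsewhere
    per_position = tf.keras.losses.sparse_categorical_crossentropy(labels, y_pred)
    return tf.reduce_sum(per_position * mask) / (tf.reduce_sum(mask) + 1e-6)
```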
+ :param rows: :return: """ - (masked_visit_token_ids, output_mask, visit_token_ids) = zip( - *list(map(self._make_record, rows))) + (masked_visit_token_ids, output_mask, visit_token_ids) = zip(*list(map(self._make_record, rows))) padded_masked_visit_token_ids = self._pad( masked_visit_token_ids, padded_token=self._visit_tokenizer.get_unused_token_id(), - maxlen=self._max_num_of_visits + maxlen=self._max_num_of_visits, ) - input_dict = { - 'masked_visit_type': padded_masked_visit_token_ids - } + input_dict = {"masked_visit_type": padded_masked_visit_token_ids} output_dict = {} if self._include_visit_prediction: padded_visit_token_ids = self._pad( visit_token_ids, padded_token=self._visit_tokenizer.get_unused_token_id(), - maxlen=self._max_num_of_visits + maxlen=self._max_num_of_visits, ) - padded_output_masks = self._pad( - output_mask, - padded_token=0, - maxlen=self._max_num_of_visits - ) + padded_output_masks = self._pad(output_mask, padded_token=0, maxlen=self._max_num_of_visits) if self._counter < self._warmup_step: self._counter += 1 - padded_output_masks = np.zeros_like( - padded_output_masks - ) + padded_output_masks = np.zeros_like(padded_output_masks) - output_dict['visit_predictions'] = np.stack( - [padded_visit_token_ids, padded_output_masks], - axis=-1 - ) + output_dict["visit_predictions"] = np.stack([padded_visit_token_ids, padded_output_masks], axis=-1) return input_dict, output_dict @@ -856,13 +812,14 @@ def _pad(self, x, padded_token, maxlen): return pad_sequences( np.asarray(x, dtype=object), maxlen=maxlen, - padding='post', - value=padded_token, dtype='int32' + padding="post", + value=padded_token, + dtype="int32", ) def _make_record(self, row_slicer: RowSlicer): """ - A method for making a bert record for the bert data generator to yield + A method for making a bert record for the bert data generator to yield. :param row_slicer: a tuple containing a pandas row, left_index and right_index for slicing the sequences such as concepts @@ -880,7 +837,8 @@ def _make_record(self, row_slicer: RowSlicer): def _mask_visit_concepts(self, visit_concepts): """ - Any visit has 50% chance to be masked + Any visit has 50% chance to be masked. + :param visit_concepts: :return: """ @@ -894,17 +852,15 @@ def _mask_visit_concepts(self, visit_concepts): return masked_visit_concepts, output_mask -class HierarchicalReadmissionLearningObjective( - HierarchicalVisitTypePredictionLearningObjective -): - required_columns = ['is_readmissions', 'is_inpatients'] +class HierarchicalReadmissionLearningObjective(HierarchicalVisitTypePredictionLearningObjective): + required_columns = ["is_readmissions", "is_inpatients"] def __init__( - self, - max_num_of_visits: int, - is_pretraining: bool, - random_mask_prob: float, - warmup_step: int + self, + max_num_of_visits: int, + is_pretraining: bool, + random_mask_prob: float, + warmup_step: int, ): self._max_num_of_visits = max_num_of_visits self._is_pretraining = is_pretraining @@ -913,31 +869,22 @@ def __init__( self._counter = 0 def get_tf_dataset_schema(self): - output_dict_schema = { - 'is_readmission': int32 - } + output_dict_schema = {"is_readmission": int32} return {}, output_dict_schema @validate_columns_decorator def process_batch(self, rows: List[RowSlicer]): """ - Process a batch of rows to generate input and output data for the learning objective + Process a batch of rows to generate input and output data for the learning objective. 
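Note: the readmission and prolonged-length-of-stay heads below subsample visits with `random_mask_prob` and zero their loss mask entirely for the first `warmup_step` batches, so these auxiliary objectives only start contributing once pre-training has warmed up. The sketch combines the two mechanisms; restricting scoring to inpatient visits is my reading of the call sites, and the parameter values are illustrative.

```python
# Hypothetical gating for an auxiliary visit-level objective.
import numpy as np


def auxiliary_loss_mask(is_inpatient, step, random_mask_prob=0.2, warmup_step=100, rng=None):
    """Return a 0/1 mask of visits that contribute to the auxiliary loss at this step."""
    rng = rng or np.random.default_rng(0)
    bernoulli = rng.random(is_inpatient.shape) < random_mask_prob   # ~20% of positions pass
    mask = (is_inpatient.astype(bool) & bernoulli).astype(int)
    if step < warmup_step:
        return np.zeros_like(mask)  # during warm-up the head receives no labelled positions
    return mask


print(auxiliary_loss_mask(np.array([1, 1, 0, 1, 0, 1]), step=500))
```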
+ :param rows: :return: """ is_readmissions, is_inpatients = zip(*list(map(self._make_record, rows))) - padded_is_readmissions = self._pad( - is_readmissions, - padded_token=0, - maxlen=self._max_num_of_visits - ) + padded_is_readmissions = self._pad(is_readmissions, padded_token=0, maxlen=self._max_num_of_visits) - padded_is_inpatients = self._pad( - is_inpatients, - padded_token=0, - maxlen=self._max_num_of_visits - ) + padded_is_inpatients = self._pad(is_inpatients, padded_token=0, maxlen=self._max_num_of_visits) # if _random_mask_prob=0.2, there is 20% chance of being masked random_mask = np.random.rand(*padded_is_inpatients.shape) < self._random_mask_prob @@ -945,19 +892,15 @@ def process_batch(self, rows: List[RowSlicer]): if self._counter < self._warmup_step: self._counter += 1 - mask = np.zeros_like( - mask - ) + mask = np.zeros_like(mask) - output_dict = { - 'is_readmission': np.stack([padded_is_readmissions, mask], axis=-1) - } + output_dict = {"is_readmission": np.stack([padded_is_readmissions, mask], axis=-1)} return {}, output_dict def _make_record(self, row_slicer: RowSlicer): """ - A method for making a bert record for the bert data generator to yield + A method for making a bert record for the bert data generator to yield. :param row_slicer: a tuple containing a pandas row, left_index and right_index for slicing the sequences such as concepts @@ -970,22 +913,18 @@ def _make_record(self, row_slicer: RowSlicer): is_readmissions = row.is_readmissions[start_index:end_index].astype(int) is_inpatients = row.is_inpatients[start_index:end_index] - return ( - is_readmissions, is_inpatients - ) + return (is_readmissions, is_inpatients) -class HierarchicalProlongedLengthStayLearningObjective( - HierarchicalVisitTypePredictionLearningObjective -): - required_columns = ['visit_prolonged_stays', 'is_inpatients'] +class HierarchicalProlongedLengthStayLearningObjective(HierarchicalVisitTypePredictionLearningObjective): + required_columns = ["visit_prolonged_stays", "is_inpatients"] def __init__( - self, - max_num_of_visits: int, - is_pretraining: bool, - random_mask_prob: float, - warmup_step: int + self, + max_num_of_visits: int, + is_pretraining: bool, + random_mask_prob: float, + warmup_step: int, ): self._max_num_of_visits = max_num_of_visits self._is_pretraining = is_pretraining @@ -994,31 +933,22 @@ def __init__( self._counter = 0 def get_tf_dataset_schema(self): - output_dict_schema = { - 'visit_prolonged_stay': int32 - } + output_dict_schema = {"visit_prolonged_stay": int32} return {}, output_dict_schema @validate_columns_decorator def process_batch(self, rows: List[RowSlicer]): """ - Process a batch of rows to generate input and output data for the learning objective + Process a batch of rows to generate input and output data for the learning objective. 
+ :param rows: :return: """ visit_prolonged_stays, is_inpatients = zip(*list(map(self._make_record, rows))) - padded_visit_prolonged_stays = self._pad( - visit_prolonged_stays, - padded_token=0, - maxlen=self._max_num_of_visits - ) + padded_visit_prolonged_stays = self._pad(visit_prolonged_stays, padded_token=0, maxlen=self._max_num_of_visits) - padded_is_inpatients = self._pad( - is_inpatients, - padded_token=0, - maxlen=self._max_num_of_visits - ) + padded_is_inpatients = self._pad(is_inpatients, padded_token=0, maxlen=self._max_num_of_visits) # if _random_mask_prob=0.2, there is 20% chance of being masked random_mask = np.random.rand(*padded_is_inpatients.shape) < self._random_mask_prob @@ -1026,23 +956,15 @@ def process_batch(self, rows: List[RowSlicer]): if self._counter < self._warmup_step: self._counter += 1 - mask = np.zeros_like( - mask - ) + mask = np.zeros_like(mask) - output_dict = { - 'visit_prolonged_stay': - np.stack( - [padded_visit_prolonged_stays, mask], - axis=-1 - ) - } + output_dict = {"visit_prolonged_stay": np.stack([padded_visit_prolonged_stays, mask], axis=-1)} return {}, output_dict def _make_record(self, row_slicer: RowSlicer): """ - A method for making a bert record for the bert data generator to yield + A method for making a bert record for the bert data generator to yield. :param row_slicer: a tuple containing a pandas row, left_index and right_index for slicing the sequences such as concepts @@ -1055,98 +977,78 @@ def _make_record(self, row_slicer: RowSlicer): visit_prolonged_stays = row.visit_prolonged_stays[start_index:end_index].astype(int) is_inpatients = row.is_inpatients[start_index:end_index] - return ( - visit_prolonged_stays, is_inpatients - ) + return (visit_prolonged_stays, is_inpatients) -class HierarchicalArtificialTokenPredictionLearningObjective( - HierarchicalMaskedLanguageModelLearningObjective -): - required_columns = ['time_interval_atts'] +class HierarchicalArtificialTokenPredictionLearningObjective(HierarchicalMaskedLanguageModelLearningObjective): + required_columns = ["time_interval_atts"] def __init__( - self, - concept_tokenizer: ConceptTokenizer, - max_num_of_visits: int, - include_att_prediction: bool + self, + concept_tokenizer: ConceptTokenizer, + max_num_of_visits: int, + include_att_prediction: bool, ): self._concept_tokenizer = concept_tokenizer self._max_num_of_visits = max_num_of_visits self._include_att_prediction = include_att_prediction def get_tf_dataset_schema(self): - input_dict_schema = { - 'visit_time_delta_att': int32 - } + input_dict_schema = {"visit_time_delta_att": int32} output_dict_schema = {} # when att prediction is enabled, we update the output data schema if self._include_att_prediction: - output_dict_schema.update({ - 'att_predictions': int32 - }) + output_dict_schema.update({"att_predictions": int32}) return input_dict_schema, output_dict_schema @validate_columns_decorator def process_batch(self, rows: List[RowSlicer]): """ - Process a batch of rows to generate input and output data for the learning objective + Process a batch of rows to generate input and output data for the learning objective. 
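Note: the ATT batch below pads the per-visit time-interval tokens to `max_num_of_visits` and then drops the first column with `[:, 1:]`. As I read the slicing, this is because n visits bound only n - 1 intervals, so the leading slot is a placeholder; the token strings in the sketch are invented.

```python
import numpy as np

# Hypothetical ATT tokens for two patients padded to max_num_of_visits = 4.
att_tokens = np.array([["<pad>", "W1", "M2", "LT"],
                       ["<pad>", "W2", "<pad>", "<pad>"]])
visit_time_delta_att = att_tokens[:, 1:]   # shape (batch, max_num_of_visits - 1)
print(visit_time_delta_att)
```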
+ :param rows: :return: """ - ( - output_mask, masked_time_interval_att_tokens, time_interval_att_tokens - ) = zip(*list(map(self._make_record, rows))) + (output_mask, masked_time_interval_att_tokens, time_interval_att_tokens) = zip( + *list(map(self._make_record, rows)) + ) masked_time_interval_att_tokens = np.asarray( - self._concept_tokenizer.encode( - pd.Series(masked_time_interval_att_tokens).apply(lambda t: t.tolist()) - ) + self._concept_tokenizer.encode(pd.Series(masked_time_interval_att_tokens).apply(lambda t: t.tolist())) ) padded_masked_time_interval_att_tokens = post_pad_pre_truncate( masked_time_interval_att_tokens, self._concept_tokenizer.get_unused_token_id(), - self._max_num_of_visits + self._max_num_of_visits, )[:, 1:] - input_dict = { - 'visit_time_delta_att': padded_masked_time_interval_att_tokens - } + input_dict = {"visit_time_delta_att": padded_masked_time_interval_att_tokens} output_dict = {} if self._include_att_prediction: time_interval_att_tokens = np.asarray( - self._concept_tokenizer.encode( - pd.Series(time_interval_att_tokens).apply(lambda t: t.tolist()) - ) + self._concept_tokenizer.encode(pd.Series(time_interval_att_tokens).apply(lambda t: t.tolist())) ) padded_time_interval_att_tokens = post_pad_pre_truncate( time_interval_att_tokens, self._concept_tokenizer.get_unused_token_id(), - self._max_num_of_visits + self._max_num_of_visits, )[:, 1:] - padded_output_mask = post_pad_pre_truncate( - output_mask, - 0, - self._max_num_of_visits - )[:, 1:] + padded_output_mask = post_pad_pre_truncate(output_mask, 0, self._max_num_of_visits)[:, 1:] - output_dict.update({ - 'att_predictions': np.stack( - [padded_time_interval_att_tokens, padded_output_mask], - axis=-1 - ) - }) + output_dict.update( + {"att_predictions": np.stack([padded_time_interval_att_tokens, padded_output_mask], axis=-1)} + ) return input_dict, output_dict def _make_record(self, row_slicer: RowSlicer): """ - A method for making a bert record for the bert data generator to yield + A method for making a bert record for the bert data generator to yield. :param row_slicer: a tuple containing a pandas row, left_index and right_index for slicing the sequences such as concepts @@ -1157,17 +1059,13 @@ def _make_record(self, row_slicer: RowSlicer): row, start_index, end_index, *_ = row_slicer time_interval_att_tokens = row.time_interval_atts[start_index:end_index] - masked_time_interval_att_tokens, output_mask = self._mask_visit_concepts( - time_interval_att_tokens - ) + masked_time_interval_att_tokens, output_mask = self._mask_visit_concepts(time_interval_att_tokens) return output_mask, masked_time_interval_att_tokens, time_interval_att_tokens - def _mask_visit_concepts( - self, - time_interval_att_tokens - ): + def _mask_visit_concepts(self, time_interval_att_tokens): """ - Any visit has 50% chance to be masked when att prediction is enabled, otherwise just + Any visit has 50% chance to be masked when att prediction is enabled, otherwise just. 
+ return the time_interval_att_tokens as the masked_time_interval_att_tokens :param time_interval_att_tokens: @@ -1188,17 +1086,20 @@ def _mask_visit_concepts( if random.random() < 0.15: output_mask[word_pos] = 1 - masked_time_interval_att_tokens[ - word_pos] = self._concept_tokenizer.get_att_mask_token_id() + masked_time_interval_att_tokens[word_pos] = self._concept_tokenizer.get_att_mask_token_id() return masked_time_interval_att_tokens, output_mask class TimeAttentionLearningObjective(LearningObjective): - required_columns = ['token_ids', 'dates'] + required_columns = ["token_ids", "dates"] - def __init__(self, concept_tokenizer: ConceptTokenizer, max_seq_len: int, - time_window_size: int): + def __init__( + self, + concept_tokenizer: ConceptTokenizer, + max_seq_len: int, + time_window_size: int, + ): super(TimeAttentionLearningObjective, self).__init__() self._concept_tokenizer = concept_tokenizer self._max_seq_len = max_seq_len @@ -1206,43 +1107,48 @@ def __init__(self, concept_tokenizer: ConceptTokenizer, max_seq_len: int, def get_tf_dataset_schema(self): input_dict_schema = { - 'target_concepts': int32, - 'target_time_stamps': int32, - 'context_concepts': int32, - 'context_time_stamps': int32, - 'mask': int32 + "target_concepts": int32, + "target_time_stamps": int32, + "context_concepts": int32, + "context_time_stamps": int32, + "mask": int32, } - output_dict_schema = {'concept_predictions': int32} + output_dict_schema = {"concept_predictions": int32} return input_dict_schema, output_dict_schema @validate_columns_decorator def process_batch(self, rows: List[RowSlicer]): (target_concepts, target_dates, context_concepts, context_time_stamps) = zip( - *list(map(self._make_record, rows))) + *list(map(self._make_record, rows)) + ) target_concepts = np.asarray(target_concepts) target_time_stamps = np.asarray(target_dates) - context_concepts = post_pad_pre_truncate(context_concepts, - self._concept_tokenizer.get_unused_token_id(), - self._max_seq_len) + context_concepts = post_pad_pre_truncate( + context_concepts, + self._concept_tokenizer.get_unused_token_id(), + self._max_seq_len, + ) context_time_stamps = post_pad_pre_truncate(context_time_stamps, 0, self._max_seq_len) mask = (context_concepts == self._concept_tokenizer.get_unused_token_id()).astype(int) - input_dict = {'target_concepts': target_concepts, - 'target_time_stamps': target_time_stamps, - 'context_concepts': context_concepts, - 'context_time_stamps': context_time_stamps, - 'mask': mask} + input_dict = { + "target_concepts": target_concepts, + "target_time_stamps": target_time_stamps, + "context_concepts": context_concepts, + "context_time_stamps": context_time_stamps, + "mask": mask, + } - output_dict = {'concept_predictions': target_concepts} + output_dict = {"concept_predictions": target_concepts} return input_dict, output_dict def _make_record(self, row_slicer: RowSlicer): """ - A method for making a bert record for the time attention data generator to yield + A method for making a bert record for the time attention data generator to yield. 
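Note: the time-attention record builds one training example per target concept: the target token and its date on one side, and the surrounding concepts within the window (with the target position excluded) as context. A small sketch of that index bookkeeping, mirroring the slicing in `_make_record` with illustrative window bounds and values:

```python
import numpy as np

concepts = np.array([101, 102, 103, 104, 105])
dates = np.array([0, 3, 7, 9, 12])
target_index, start_index, end_index = 2, 0, 4      # illustrative window around the target

indexes = np.arange(start_index, end_index + 1)
indexes = indexes[indexes != target_index]          # context excludes the target itself

target_concepts = [concepts[target_index]]          # -> [103]
target_time_stamps = [dates[target_index]]          # -> [7]
context_concepts = concepts[indexes]                # -> [101 102 104 105]
context_time_stamps = dates[indexes]                # -> [0 3 9 12]
print(target_concepts, context_concepts.tolist())
```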
:param row_slicer: a tuple containing a pandas row, left_index and right_index for slicing the sequences such as concepts @@ -1259,4 +1165,9 @@ def _make_record(self, row_slicer: RowSlicer): indexes = np.asarray(list(range(start_index, end_index + 1))) indexes = indexes[indexes != target_index] - return [concepts[target_index]], [dates[target_index]], concepts[indexes], dates[indexes] + return ( + [concepts[target_index]], + [dates[target_index]], + concepts[indexes], + dates[indexes], + ) diff --git a/src/cehrbert/data_generators/tokenizer.py b/src/cehrbert/data_generators/tokenizer.py index b2eaae18..487b1b75 100644 --- a/src/cehrbert/data_generators/tokenizer.py +++ b/src/cehrbert/data_generators/tokenizer.py @@ -8,20 +8,18 @@ class ConceptTokenizer: - unused_token = '[UNUSED]' - mask_token = '[MASK]' - att_mask_token = '[ATT_MASK]' - cls_token = '[CLS]' - start_token = '[START]' - end_token = '[END]' - visit_start_token = 'VS' - visit_end_token = 'VE' - - def __init__( - self, special_tokens: Optional[Sequence[str]] = None, oov_token='-1' - ): + unused_token = "[UNUSED]" + mask_token = "[MASK]" + att_mask_token = "[ATT_MASK]" + cls_token = "[CLS]" + start_token = "[START]" + end_token = "[END]" + visit_start_token = "VS" + visit_end_token = "VE" + + def __init__(self, special_tokens: Optional[Sequence[str]] = None, oov_token="-1"): self.special_tokens = special_tokens - self.tokenizer = Tokenizer(oov_token=oov_token, filters='', lower=False) + self.tokenizer = Tokenizer(oov_token=oov_token, filters="", lower=False) self.tokenizer.fit_on_texts([self.mask_token]) self.tokenizer.fit_on_texts([self.att_mask_token]) @@ -40,14 +38,14 @@ def fit_on_concept_sequences(self, concept_sequences: Union[df_series, dd_series if isinstance(concept_sequences, df_series): self.tokenizer.fit_on_texts(concept_sequences.apply(list)) else: - self.tokenizer.fit_on_texts( - concept_sequences.apply(list, meta='iterable') - ) + self.tokenizer.fit_on_texts(concept_sequences.apply(list, meta="iterable")) def encode(self, concept_sequences, is_generator=False): - return self.tokenizer.texts_to_sequences_generator( - concept_sequences) if is_generator else self.tokenizer.texts_to_sequences( - concept_sequences) + return ( + self.tokenizer.texts_to_sequences_generator(concept_sequences) + if is_generator + else self.tokenizer.texts_to_sequences(concept_sequences) + ) def decode(self, concept_sequence_token_ids): return self.tokenizer.sequences_to_texts(concept_sequence_token_ids) @@ -62,15 +60,14 @@ def get_all_token_indexes(self): all_keys.remove(self.tokenizer.word_index[self.tokenizer.oov_token]) if self.special_tokens is not None: - excluded = set( - [self.tokenizer.word_index[special_token] for special_token in self.special_tokens]) + excluded = set([self.tokenizer.word_index[special_token] for special_token in self.special_tokens]) all_keys = all_keys - excluded return all_keys def get_token_by_index(self, index): if index in self.tokenizer.index_word: return self.tokenizer.index_word[index] - raise RuntimeError(f'{index} is not a valid index in tokenizer') + raise RuntimeError(f"{index} is not a valid index in tokenizer") def get_first_token_index(self): return min(self.get_all_token_indexes()) diff --git a/src/cehrbert/evaluations/evaluation.py b/src/cehrbert/evaluations/evaluation.py index e407338c..1b932e10 100644 --- a/src/cehrbert/evaluations/evaluation.py +++ b/src/cehrbert/evaluations/evaluation.py @@ -1,19 +1,51 @@ import configparser -from ..config import output_names as p -from 
..config.grid_search_config import LEARNING_RATE, LSTM_DIRECTION, LSTM_UNIT -from .evaluation_parameters import * -from .evaluation_parse_args import create_evaluation_args -from .model_evaluators.hierarchical_bert_evaluators import * -from .model_evaluators.bert_model_evaluators import * -from .model_evaluators.sequence_model_evaluators import * -from .model_evaluators.frequency_model_evaluators import * -from ..utils.model_utils import * -from ..utils.checkpoint_utils import find_tokenizer_path, find_visit_tokenizer_path +import logging +import os + +import pandas as pd +import tensorflow as tf + +tf.keras.utils.set_random_seed(0) + +from cehrbert.config import output_names as p +from cehrbert.config.grid_search_config import LEARNING_RATE, LSTM_DIRECTION, LSTM_UNIT +from cehrbert.evaluations.evaluation_parameters import ( + BASELINE_MODEL, + FULL, + HIERARCHICAL_BERT_LSTM, + HIERARCHICAL_BERT_POOLING, + LSTM, + RANDOM_HIERARCHICAL_BERT_LSTM, + RANDOM_VANILLA_BERT_LSTM, + SEQUENCE_MODEL, + SLIDING_BERT, + VANILLA_BERT_FEED_FORWARD, + VANILLA_BERT_LSTM, +) +from cehrbert.evaluations.evaluation_parse_args import create_evaluation_args +from cehrbert.evaluations.model_evaluators.bert_model_evaluators import ( + BertFeedForwardModelEvaluator, + BertLstmModelEvaluator, + RandomVanillaLstmBertModelEvaluator, + SlidingBertModelEvaluator, +) +from cehrbert.evaluations.model_evaluators.frequency_model_evaluators import ( + LogisticRegressionModelEvaluator, + XGBClassifierEvaluator, +) +from cehrbert.evaluations.model_evaluators.hierarchical_bert_evaluators import ( + HierarchicalBertEvaluator, + HierarchicalBertPoolingEvaluator, + RandomHierarchicalBertEvaluator, +) +from cehrbert.evaluations.model_evaluators.sequence_model_evaluators import BiLstmModelEvaluator, GridSearchConfig +from cehrbert.utils.checkpoint_utils import find_tokenizer_path, find_visit_tokenizer_path +from cehrbert.utils.model_utils import validate_folder def get_grid_search_config(grid_search_config) -> GridSearchConfig: """ - Read the grid search config file and load learning_rates, lstm_directions and lstm_units + Read the grid search config file and load learning_rates, lstm_directions and lstm_units. :param grid_search_config: :return: @@ -30,13 +62,13 @@ def get_grid_search_config(grid_search_config) -> GridSearchConfig: return GridSearchConfig( learning_rates=learning_rates, lstm_directions=lstm_directions, - lstm_units=lstm_units + lstm_units=lstm_units, ) except Exception as e: print(f'{grid_search_config} cannot be parsed. 
Error message" {e}') else: - print(f'grid_search_config is not provided, will use the default GridSearchConfig') + print(f"grid_search_config is not provided, will use the default GridSearchConfig") return GridSearchConfig() @@ -45,19 +77,14 @@ def evaluate_sequence_models(args): # Load the training data dataset = pd.read_parquet(args.sequence_model_data_path) logging.getLogger(__name__).info( - f'sequence_model_data_path: {args.sequence_model_data_path}\n' - f'args.grid_search_config: {args.grid_search_config}\n' - ) - grid_search_config = get_grid_search_config( - args.grid_search_config + f"sequence_model_data_path: {args.sequence_model_data_path}\n" + f"args.grid_search_config: {args.grid_search_config}\n" ) + grid_search_config = get_grid_search_config(args.grid_search_config) if LSTM in args.model_evaluators: validate_folder(args.time_attention_model_folder) time_attention_tokenizer_path = find_tokenizer_path(args.time_attention_model_folder) - time_aware_model_path = os.path.join( - args.time_attention_model_folder, - p.TIME_ATTENTION_MODEL_PATH - ) + time_aware_model_path = os.path.join(args.time_attention_model_folder, p.TIME_ATTENTION_MODEL_PATH) BiLstmModelEvaluator( dataset=dataset, evaluation_folder=args.evaluation_folder, @@ -76,14 +103,13 @@ def evaluate_sequence_models(args): grid_search_config=grid_search_config, is_chronological_test=args.is_chronological_test, k_fold_test=args.k_fold_test, - multiple_test_run=args.multiple_test_run + multiple_test_run=args.multiple_test_run, ).eval_model() if VANILLA_BERT_FEED_FORWARD in args.model_evaluators: validate_folder(args.vanilla_bert_model_folder) bert_tokenizer_path = find_tokenizer_path(args.vanilla_bert_model_folder) - bert_model_path = os.path.join(args.vanilla_bert_model_folder, - p.BERT_MODEL_VALIDATION_PATH) + bert_model_path = os.path.join(args.vanilla_bert_model_folder, p.BERT_MODEL_VALIDATION_PATH) BertFeedForwardModelEvaluator( dataset=dataset, evaluation_folder=args.evaluation_folder, @@ -101,14 +127,13 @@ def evaluate_sequence_models(args): cross_validation_test=args.cross_validation_test, grid_search_config=grid_search_config, k_fold_test=args.k_fold_test, - multiple_test_run=args.multiple_test_run + multiple_test_run=args.multiple_test_run, ).eval_model() if SLIDING_BERT in args.model_evaluators: validate_folder(args.vanilla_bert_model_folder) bert_tokenizer_path = find_tokenizer_path(args.vanilla_bert_model_folder) - bert_model_path = os.path.join(args.vanilla_bert_model_folder, - p.BERT_MODEL_VALIDATION_PATH) + bert_model_path = os.path.join(args.vanilla_bert_model_folder, p.BERT_MODEL_VALIDATION_PATH) SlidingBertModelEvaluator( dataset=dataset, evaluation_folder=args.evaluation_folder, @@ -127,14 +152,13 @@ def evaluate_sequence_models(args): cross_validation_test=args.cross_validation_test, grid_search_config=grid_search_config, k_fold_test=args.k_fold_test, - multiple_test_run=args.multiple_test_run + multiple_test_run=args.multiple_test_run, ).eval_model() if VANILLA_BERT_LSTM in args.model_evaluators: validate_folder(args.vanilla_bert_model_folder) bert_tokenizer_path = find_tokenizer_path(args.vanilla_bert_model_folder) - bert_model_path = os.path.join(args.vanilla_bert_model_folder, - p.BERT_MODEL_VALIDATION_PATH) + bert_model_path = os.path.join(args.vanilla_bert_model_folder, p.BERT_MODEL_VALIDATION_PATH) BertLstmModelEvaluator( dataset=dataset, evaluation_folder=args.evaluation_folder, @@ -154,13 +178,12 @@ def evaluate_sequence_models(args): is_chronological_test=args.is_chronological_test, 
freeze_pretrained_model=args.freeze_pretrained_model, k_fold_test=args.k_fold_test, - multiple_test_run=args.multiple_test_run + multiple_test_run=args.multiple_test_run, ).eval_model() if RANDOM_VANILLA_BERT_LSTM in args.model_evaluators: validate_folder(args.vanilla_bert_model_folder) - bert_model_path = os.path.join(args.vanilla_bert_model_folder, - p.BERT_MODEL_VALIDATION_PATH) + bert_model_path = os.path.join(args.vanilla_bert_model_folder, p.BERT_MODEL_VALIDATION_PATH) bert_tokenizer_path = find_tokenizer_path(args.vanilla_bert_model_folder) visit_tokenizer_path = find_visit_tokenizer_path(args.vanilla_bert_model_folder) @@ -189,13 +212,12 @@ def evaluate_sequence_models(args): is_chronological_test=args.is_chronological_test, freeze_pretrained_model=args.freeze_pretrained_model, k_fold_test=args.k_fold_test, - multiple_test_run=args.multiple_test_run + multiple_test_run=args.multiple_test_run, ).eval_model() if HIERARCHICAL_BERT_LSTM in args.model_evaluators: validate_folder(args.vanilla_bert_model_folder) - bert_model_path = os.path.join(args.vanilla_bert_model_folder, - p.BERT_MODEL_VALIDATION_PATH) + bert_model_path = os.path.join(args.vanilla_bert_model_folder, p.BERT_MODEL_VALIDATION_PATH) bert_tokenizer_path = find_tokenizer_path(args.vanilla_bert_model_folder) bert_visit_tokenizer_path = find_visit_tokenizer_path(args.vanilla_bert_model_folder) @@ -221,13 +243,12 @@ def evaluate_sequence_models(args): is_chronological_test=args.is_chronological_test, freeze_pretrained_model=args.freeze_pretrained_model, k_fold_test=args.k_fold_test, - multiple_test_run=args.multiple_test_run + multiple_test_run=args.multiple_test_run, ).eval_model() if HIERARCHICAL_BERT_POOLING in args.model_evaluators: validate_folder(args.vanilla_bert_model_folder) - bert_model_path = os.path.join(args.vanilla_bert_model_folder, - p.BERT_MODEL_VALIDATION_PATH) + bert_model_path = os.path.join(args.vanilla_bert_model_folder, p.BERT_MODEL_VALIDATION_PATH) bert_tokenizer_path = find_tokenizer_path(args.vanilla_bert_model_folder) bert_visit_tokenizer_path = find_visit_tokenizer_path(args.vanilla_bert_model_folder) @@ -253,13 +274,12 @@ def evaluate_sequence_models(args): is_chronological_test=args.is_chronological_test, freeze_pretrained_model=args.freeze_pretrained_model, k_fold_test=args.k_fold_test, - multiple_test_run=args.multiple_test_run + multiple_test_run=args.multiple_test_run, ).eval_model() if RANDOM_HIERARCHICAL_BERT_LSTM in args.model_evaluators: validate_folder(args.vanilla_bert_model_folder) - bert_model_path = os.path.join(args.vanilla_bert_model_folder, - p.BERT_MODEL_VALIDATION_PATH) + bert_model_path = os.path.join(args.vanilla_bert_model_folder, p.BERT_MODEL_VALIDATION_PATH) bert_tokenizer_path = find_tokenizer_path(args.vanilla_bert_model_folder) bert_visit_tokenizer_path = find_visit_tokenizer_path(args.vanilla_bert_model_folder) @@ -290,7 +310,7 @@ def evaluate_sequence_models(args): include_att_tokens=args.include_att_tokens, is_chronological_test=args.is_chronological_test, k_fold_test=args.k_fold_test, - multiple_test_run=args.multiple_test_run + multiple_test_run=args.multiple_test_run, ).eval_model() @@ -308,7 +328,7 @@ def evaluate_baseline_models(args): is_transfer_learning=args.is_transfer_learning, training_percentage=args.training_percentage, k_fold_test=args.k_fold_test, - test_person_ids=test_person_ids + test_person_ids=test_person_ids, ).eval_model() XGBClassifierEvaluator( @@ -318,13 +338,11 @@ def evaluate_baseline_models(args): 
is_transfer_learning=args.is_transfer_learning, training_percentage=args.training_percentage, k_fold_test=args.k_fold_test, - test_person_ids=test_person_ids + test_person_ids=test_person_ids, ).eval_model() def main(args): - tf.keras.utils.set_random_seed(0) - if args.action == BASELINE_MODEL or args.action == FULL: evaluate_baseline_models(args) diff --git a/src/cehrbert/evaluations/evaluation_parameters.py b/src/cehrbert/evaluations/evaluation_parameters.py index 0c510908..ef582e5f 100644 --- a/src/cehrbert/evaluations/evaluation_parameters.py +++ b/src/cehrbert/evaluations/evaluation_parameters.py @@ -1,19 +1,28 @@ -FULL = 'full' -SEQUENCE_MODEL = 'sequence_model' -BASELINE_MODEL = 'baseline_model' +FULL = "full" +SEQUENCE_MODEL = "sequence_model" +BASELINE_MODEL = "baseline_model" EVALUATION_CHOICES = [FULL, SEQUENCE_MODEL, BASELINE_MODEL] -LSTM = 'lstm' -VANILLA_BERT_LSTM = 'vanilla_bert_lstm' -PROBABILISTIC_BERT_LSTM = 'probabilistic_bert_lstm' -PROBABILISTIC_PHENOTYPE_LSTM = 'probabilistic_phenotype_lstm' -VANILLA_BERT_FEED_FORWARD = 'vanilla_bert_feed_forward' -SLIDING_BERT = 'sliding_bert' -TEMPORAL_BERT_LSTM = 'temporal_bert_lstm' -RANDOM_VANILLA_BERT_LSTM = 'random_vanilla_bert_lstm' -HIERARCHICAL_BERT_LSTM = 'hierarchical_bert_lstm' -HIERARCHICAL_BERT_POOLING = 'hierarchical_bert_pooling' -RANDOM_HIERARCHICAL_BERT_LSTM = 'random_hierarchical_bert_lstm' -SEQUENCE_MODEL_EVALUATORS = [LSTM, VANILLA_BERT_LSTM, VANILLA_BERT_FEED_FORWARD, TEMPORAL_BERT_LSTM, - SLIDING_BERT, RANDOM_VANILLA_BERT_LSTM, HIERARCHICAL_BERT_LSTM, - RANDOM_HIERARCHICAL_BERT_LSTM, PROBABILISTIC_BERT_LSTM, - PROBABILISTIC_PHENOTYPE_LSTM, HIERARCHICAL_BERT_POOLING] +LSTM = "lstm" +VANILLA_BERT_LSTM = "vanilla_bert_lstm" +PROBABILISTIC_BERT_LSTM = "probabilistic_bert_lstm" +PROBABILISTIC_PHENOTYPE_LSTM = "probabilistic_phenotype_lstm" +VANILLA_BERT_FEED_FORWARD = "vanilla_bert_feed_forward" +SLIDING_BERT = "sliding_bert" +TEMPORAL_BERT_LSTM = "temporal_bert_lstm" +RANDOM_VANILLA_BERT_LSTM = "random_vanilla_bert_lstm" +HIERARCHICAL_BERT_LSTM = "hierarchical_bert_lstm" +HIERARCHICAL_BERT_POOLING = "hierarchical_bert_pooling" +RANDOM_HIERARCHICAL_BERT_LSTM = "random_hierarchical_bert_lstm" +SEQUENCE_MODEL_EVALUATORS = [ + LSTM, + VANILLA_BERT_LSTM, + VANILLA_BERT_FEED_FORWARD, + TEMPORAL_BERT_LSTM, + SLIDING_BERT, + RANDOM_VANILLA_BERT_LSTM, + HIERARCHICAL_BERT_LSTM, + RANDOM_HIERARCHICAL_BERT_LSTM, + PROBABILISTIC_BERT_LSTM, + PROBABILISTIC_PHENOTYPE_LSTM, + HIERARCHICAL_BERT_POOLING, +] diff --git a/src/cehrbert/evaluations/evaluation_parse_args.py b/src/cehrbert/evaluations/evaluation_parse_args.py index ec6aeb40..2eff4095 100644 --- a/src/cehrbert/evaluations/evaluation_parse_args.py +++ b/src/cehrbert/evaluations/evaluation_parse_args.py @@ -1,14 +1,21 @@ import argparse from sys import argv -from .evaluation_parameters import SEQUENCE_MODEL, BASELINE_MODEL, EVALUATION_CHOICES, \ - LSTM, VANILLA_BERT_LSTM, SLIDING_BERT, TEMPORAL_BERT_LSTM, HIERARCHICAL_BERT_LSTM, \ - SEQUENCE_MODEL_EVALUATORS +from cehrbert.evaluations.evaluation_parameters import ( + BASELINE_MODEL, + EVALUATION_CHOICES, + HIERARCHICAL_BERT_LSTM, + LSTM, + SEQUENCE_MODEL, + SEQUENCE_MODEL_EVALUATORS, + SLIDING_BERT, + TEMPORAL_BERT_LSTM, + VANILLA_BERT_LSTM, +) def create_evaluation_args(): - main_parser = argparse.ArgumentParser( - description='Arguments for evaluating the models') + main_parser = argparse.ArgumentParser(description="Arguments for evaluating the models") sequence_model_required = BASELINE_MODEL not in argv 
baseline_model_required = SEQUENCE_MODEL not in argv @@ -18,185 +25,231 @@ def create_evaluation_args(): sliding_bert = SLIDING_BERT in argv hierarchical_bert = HIERARCHICAL_BERT_LSTM in argv - main_parser.add_argument('-a', - '--action', - dest='action', - action='store', - choices=EVALUATION_CHOICES, - help='The action that determines the evaluation process', - required=True) - main_parser.add_argument('-d', - '--data_path', - dest='data_path', - action='store', - help='The training data path', - required=baseline_model_required) - main_parser.add_argument('--test_person_ids_path', - dest='test_person_ids_path', - action='store', - help='The test person_ids data', - required=False) - main_parser.add_argument('-ef', - '--evaluation_folder', - dest='evaluation_folder', - action='store', - required=True) - main_parser.add_argument('-n', - '--num_of_folds', - dest='num_of_folds', - action='store', - required=False, - type=int, - default=4) - main_parser.add_argument('--is_transfer_learning', - dest='is_transfer_learning', - action='store_true') - main_parser.add_argument('--training_percentage', - dest='training_percentage', - required=False, - action='store', - type=float, - default=1.0) - main_parser.add_argument('--learning_rate', - dest='learning_rate', - required=False, - action='store', - type=float, - default=1e-4) + main_parser.add_argument( + "-a", + "--action", + dest="action", + action="store", + choices=EVALUATION_CHOICES, + help="The action that determines the evaluation process", + required=True, + ) + main_parser.add_argument( + "-d", + "--data_path", + dest="data_path", + action="store", + help="The training data path", + required=baseline_model_required, + ) + main_parser.add_argument( + "--test_person_ids_path", + dest="test_person_ids_path", + action="store", + help="The test person_ids data", + required=False, + ) + main_parser.add_argument( + "-ef", + "--evaluation_folder", + dest="evaluation_folder", + action="store", + required=True, + ) + main_parser.add_argument( + "-n", + "--num_of_folds", + dest="num_of_folds", + action="store", + required=False, + type=int, + default=4, + ) + main_parser.add_argument("--is_transfer_learning", dest="is_transfer_learning", action="store_true") + main_parser.add_argument( + "--training_percentage", + dest="training_percentage", + required=False, + action="store", + type=float, + default=1.0, + ) + main_parser.add_argument( + "--learning_rate", + dest="learning_rate", + required=False, + action="store", + type=float, + default=1e-4, + ) - group = main_parser.add_argument_group('sequence model') - group.add_argument('-me', - '--model_evaluators', - dest='model_evaluators', - action='store', - nargs='+', - choices=SEQUENCE_MODEL_EVALUATORS, - required=sequence_model_required) - group.add_argument('-sd', - '--sequence_model_data_path', - dest='sequence_model_data_path', - action='store', - required=sequence_model_required) - group.add_argument('-smn', - '--sequence_model_name', - dest='sequence_model_name', - action='store', - default=None) - group.add_argument('-m', - '--max_seq_length', - dest='max_seq_length', - type=int, - action='store', - required=sequence_model_required) - group.add_argument('-b', - '--batch_size', - dest='batch_size', - action='store', - type=int, - required=sequence_model_required) - group.add_argument('-p', - '--epochs', - dest='epochs', - action='store', - type=int, - required=sequence_model_required) - group.add_argument('-ti', - '--time_attention_model_folder', - dest='time_attention_model_folder', - 
action='store', - required=lstm_model_required) - group.add_argument('-vb', - '--vanilla_bert_model_folder', - dest='vanilla_bert_model_folder', - action='store', - required=vanilla_bert_lstm) - group.add_argument('-tb', - '--temporal_bert_model_folder', - dest='temporal_bert_model_folder', - action='store', - required=temporal_bert_lstm) - group.add_argument('--stride', - dest='stride', - action='store', - type=int, - required=sliding_bert) - group.add_argument('--context_window', - dest='context_window', - action='store', - type=int, - required=sliding_bert) - group.add_argument('--max_num_of_visits', - dest='max_num_of_visits', - action='store', - type=int, - required=hierarchical_bert) - group.add_argument('--max_num_of_concepts', - dest='max_num_of_concepts', - action='store', - type=int, - required=hierarchical_bert) - group.add_argument('--depth', - dest='depth', - action='store', - type=int, - default=5, - required=False) - group.add_argument('-nh', - '--num_heads', - dest='num_heads', - action='store', - type=int, - default=8, - required=False) - group.add_argument('-iv', - '--include_visit', - dest='include_visit_prediction', - action='store_true', - required=False) - group.add_argument('-ut', - '--use_time_embedding', - dest='use_time_embedding', - action='store_true', - required=False) - group.add_argument('--time_embeddings_size', - dest='time_embeddings_size', - action='store', - type=int, - default=16, - required=False) - group.add_argument('--embedding_size', - dest='embedding_size', - action='store', - type=int, - default=128, - required=False) - group.add_argument('--cross_validation_test', - dest='cross_validation_test', - action='store_true', - required=False) - group.add_argument('--k_fold_test', - dest='k_fold_test', - action='store_true', - required=False) - group.add_argument('--grid_search_config', - dest='grid_search_config', - action='store', - help='The path storing the grid search configuration', - required='cross_validation_test' in argv) - group.add_argument('--include_att_tokens', - dest='include_att_tokens', - action='store_true', - required=False) - group.add_argument('--is_chronological_test', - dest='is_chronological_test', - action='store_true', - required=False) - group.add_argument('--freeze_pretrained_model', - dest='freeze_pretrained_model', - action='store_true', - required=False) - group.add_argument('--multiple_test_run', - dest='multiple_test_run', - action='store_true', - required=False) + group = main_parser.add_argument_group("sequence model") + group.add_argument( + "-me", + "--model_evaluators", + dest="model_evaluators", + action="store", + nargs="+", + choices=SEQUENCE_MODEL_EVALUATORS, + required=sequence_model_required, + ) + group.add_argument( + "-sd", + "--sequence_model_data_path", + dest="sequence_model_data_path", + action="store", + required=sequence_model_required, + ) + group.add_argument( + "-smn", + "--sequence_model_name", + dest="sequence_model_name", + action="store", + default=None, + ) + group.add_argument( + "-m", + "--max_seq_length", + dest="max_seq_length", + type=int, + action="store", + required=sequence_model_required, + ) + group.add_argument( + "-b", + "--batch_size", + dest="batch_size", + action="store", + type=int, + required=sequence_model_required, + ) + group.add_argument( + "-p", + "--epochs", + dest="epochs", + action="store", + type=int, + required=sequence_model_required, + ) + group.add_argument( + "-ti", + "--time_attention_model_folder", + dest="time_attention_model_folder", + action="store", + 
required=lstm_model_required, + ) + group.add_argument( + "-vb", + "--vanilla_bert_model_folder", + dest="vanilla_bert_model_folder", + action="store", + required=vanilla_bert_lstm, + ) + group.add_argument( + "-tb", + "--temporal_bert_model_folder", + dest="temporal_bert_model_folder", + action="store", + required=temporal_bert_lstm, + ) + group.add_argument("--stride", dest="stride", action="store", type=int, required=sliding_bert) + group.add_argument( + "--context_window", + dest="context_window", + action="store", + type=int, + required=sliding_bert, + ) + group.add_argument( + "--max_num_of_visits", + dest="max_num_of_visits", + action="store", + type=int, + required=hierarchical_bert, + ) + group.add_argument( + "--max_num_of_concepts", + dest="max_num_of_concepts", + action="store", + type=int, + required=hierarchical_bert, + ) + group.add_argument("--depth", dest="depth", action="store", type=int, default=5, required=False) + group.add_argument( + "-nh", + "--num_heads", + dest="num_heads", + action="store", + type=int, + default=8, + required=False, + ) + group.add_argument( + "-iv", + "--include_visit", + dest="include_visit_prediction", + action="store_true", + required=False, + ) + group.add_argument( + "-ut", + "--use_time_embedding", + dest="use_time_embedding", + action="store_true", + required=False, + ) + group.add_argument( + "--time_embeddings_size", + dest="time_embeddings_size", + action="store", + type=int, + default=16, + required=False, + ) + group.add_argument( + "--embedding_size", + dest="embedding_size", + action="store", + type=int, + default=128, + required=False, + ) + group.add_argument( + "--cross_validation_test", + dest="cross_validation_test", + action="store_true", + required=False, + ) + group.add_argument("--k_fold_test", dest="k_fold_test", action="store_true", required=False) + group.add_argument( + "--grid_search_config", + dest="grid_search_config", + action="store", + help="The path storing the grid search configuration", + required="cross_validation_test" in argv, + ) + group.add_argument( + "--include_att_tokens", + dest="include_att_tokens", + action="store_true", + required=False, + ) + group.add_argument( + "--is_chronological_test", + dest="is_chronological_test", + action="store_true", + required=False, + ) + group.add_argument( + "--freeze_pretrained_model", + dest="freeze_pretrained_model", + action="store_true", + required=False, + ) + group.add_argument( + "--multiple_test_run", + dest="multiple_test_run", + action="store_true", + required=False, + ) return main_parser diff --git a/src/cehrbert/evaluations/model_evaluators/bert_model_evaluators.py b/src/cehrbert/evaluations/model_evaluators/bert_model_evaluators.py index 4554c0b5..8a39a186 100644 --- a/src/cehrbert/evaluations/model_evaluators/bert_model_evaluators.py +++ b/src/cehrbert/evaluations/model_evaluators/bert_model_evaluators.py @@ -1,64 +1,68 @@ import pickle import numpy as np +import tensorflow as tf -from ...data_generators.learning_objective import post_pad_pre_truncate -from ..model_evaluators.model_evaluators import get_metrics -from ..model_evaluators.sequence_model_evaluators import SequenceModelEvaluator -from ...models.evaluation_models import * +from cehrbert.data_generators.learning_objective import post_pad_pre_truncate +from cehrbert.evaluations.model_evaluators.model_evaluators import get_metrics +from cehrbert.evaluations.model_evaluators.sequence_model_evaluators import SequenceModelEvaluator +from cehrbert.models.evaluation_models import ( + 
create_probabilistic_bert_bi_lstm_model, + create_random_vanilla_bert_bi_lstm_model, + create_sliding_bert_model, + create_temporal_bert_bi_lstm_model, + create_vanilla_bert_bi_lstm_model, + create_vanilla_feed_forward_model, +) class BertLstmModelEvaluator(SequenceModelEvaluator): - def __init__(self, - max_seq_length: str, - bert_model_path: str, - tokenizer_path: str, - is_temporal: bool = True, - *args, **kwargs): + def __init__( + self, + max_seq_length: str, + bert_model_path: str, + tokenizer_path: str, + is_temporal: bool = True, + *args, + **kwargs, + ): self._max_seq_length = max_seq_length self._bert_model_path = bert_model_path - self._tokenizer = pickle.load(open(tokenizer_path, 'rb')) + self._tokenizer = pickle.load(open(tokenizer_path, "rb")) self._is_temporal = is_temporal - self.get_logger().info(f'max_seq_length: {max_seq_length}\n' - f'vanilla_bert_model_path: {bert_model_path}\n' - f'tokenizer_path: {tokenizer_path}\n' - f'is_temporal: {is_temporal}\n') + self.get_logger().info( + f"max_seq_length: {max_seq_length}\n" + f"vanilla_bert_model_path: {bert_model_path}\n" + f"tokenizer_path: {tokenizer_path}\n" + f"is_temporal: {is_temporal}\n" + ) super(BertLstmModelEvaluator, self).__init__(*args, **kwargs) - def _create_model( - self, - **kwargs - ): + def _create_model(self, **kwargs): strategy = tf.distribute.MirroredStrategy() - self.get_logger().info('Number of devices: {}'.format(strategy.num_replicas_in_sync)) + self.get_logger().info("Number of devices: {}".format(strategy.num_replicas_in_sync)) with strategy.scope(): - create_model_fn = (create_temporal_bert_bi_lstm_model if self._is_temporal - else create_vanilla_bert_bi_lstm_model) + create_model_fn = ( + create_temporal_bert_bi_lstm_model if self._is_temporal else create_vanilla_bert_bi_lstm_model + ) try: - model = create_model_fn( - self._max_seq_length, - self._bert_model_path, - **kwargs - ) + model = create_model_fn(self._max_seq_length, self._bert_model_path, **kwargs) except ValueError as e: self.get_logger().exception(e) - model = create_model_fn( - self._max_seq_length, - self._bert_model_path, - **kwargs - ) + model = create_model_fn(self._max_seq_length, self._bert_model_path, **kwargs) - model.compile(loss='binary_crossentropy', - optimizer=tf.keras.optimizers.Adam(1e-4), - metrics=get_metrics()) + model.compile( + loss="binary_crossentropy", + optimizer=tf.keras.optimizers.Adam(1e-4), + metrics=get_metrics(), + ) return model def extract_model_inputs(self): - token_ids = self._tokenizer.encode( - self._dataset.concept_ids.apply(lambda concept_ids: concept_ids.tolist())) + token_ids = self._tokenizer.encode(self._dataset.concept_ids.apply(lambda concept_ids: concept_ids.tolist())) visit_segments = self._dataset.visit_segments time_stamps = self._dataset.dates ages = self._dataset.ages @@ -67,154 +71,151 @@ def extract_model_inputs(self): # ((self._dataset['age'] - self._dataset['age'].mean()) / self._dataset[ # 'age'].std()).astype(float).apply(lambda c: [c]).tolist()) labels = self._dataset.label.to_numpy() - padded_token_ides = post_pad_pre_truncate(token_ids, self._tokenizer.get_unused_token_id(), - self._max_seq_length) + padded_token_ides = post_pad_pre_truncate( + token_ids, self._tokenizer.get_unused_token_id(), self._max_seq_length + ) padded_visit_segments = post_pad_pre_truncate(visit_segments, 0, self._max_seq_length) mask = (padded_token_ides == self._tokenizer.get_unused_token_id()).astype(int) padded_time_stamps = post_pad_pre_truncate(time_stamps, 0, self._max_seq_length) padded_ages 
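Every `_create_model` in this file follows the same shape: build the network inside a `tf.distribute.MirroredStrategy` scope, log `strategy.num_replicas_in_sync`, and compile with binary cross-entropy, Adam at 1e-4, and the shared metric set from `get_metrics`. A standalone sketch of that shape, with a toy two-layer classifier standing in for the BERT-backed models (the model itself is invented; only the compile/metric recipe comes from the diff):

```python
import tensorflow as tf


def binary_metrics():
    # Same metric set the evaluators compile with: accuracy, recall,
    # precision, PR AUC and ROC AUC.
    return [
        "binary_accuracy",
        tf.keras.metrics.Recall(name="recall"),
        tf.keras.metrics.Precision(name="precision"),
        tf.keras.metrics.AUC(curve="PR", name="pr_auc"),
        tf.keras.metrics.AUC(name="auc"),
    ]


def create_classifier(input_dim: int) -> tf.keras.Model:
    """Toy binary classifier compiled under MirroredStrategy."""
    strategy = tf.distribute.MirroredStrategy()
    print("Number of devices:", strategy.num_replicas_in_sync)
    with strategy.scope():
        model = tf.keras.Sequential(
            [
                tf.keras.Input(shape=(input_dim,)),
                tf.keras.layers.Dense(64, activation="relu"),
                tf.keras.layers.Dense(1, activation="sigmoid"),
            ]
        )
        model.compile(
            loss="binary_crossentropy",
            optimizer=tf.keras.optimizers.Adam(1e-4),
            metrics=binary_metrics(),
        )
    return model


if __name__ == "__main__":
    create_classifier(input_dim=32).summary()
```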
= post_pad_pre_truncate(ages, 0, self._max_seq_length) - padded_visit_concept_orders = post_pad_pre_truncate(visit_concept_orders, - self._max_seq_length, - self._max_seq_length) + padded_visit_concept_orders = post_pad_pre_truncate( + visit_concept_orders, self._max_seq_length, self._max_seq_length + ) # Retrieve the values associated with the concepts, this is mostly for measurements padded_concept_values = post_pad_pre_truncate( - self._dataset.concept_values, - -1.0, - self._max_seq_length, - d_type='float32' + self._dataset.concept_values, -1.0, self._max_seq_length, d_type="float32" ) - padded_concept_value_masks = post_pad_pre_truncate( - self._dataset.concept_value_masks, - 0, - self._max_seq_length - ) + padded_concept_value_masks = post_pad_pre_truncate(self._dataset.concept_value_masks, 0, self._max_seq_length) inputs = { - 'age': np.expand_dims(self._dataset.age, axis=-1), - 'concept_ids': padded_token_ides, - 'masked_concept_ids': padded_token_ides, - 'mask': mask, - 'visit_segments': padded_visit_segments, - 'time_stamps': padded_time_stamps, - 'ages': padded_ages, - 'visit_concept_orders': padded_visit_concept_orders, - 'concept_values': padded_concept_values, - 'concept_value_masks': padded_concept_value_masks, + "age": np.expand_dims(self._dataset.age, axis=-1), + "concept_ids": padded_token_ides, + "masked_concept_ids": padded_token_ides, + "mask": mask, + "visit_segments": padded_visit_segments, + "time_stamps": padded_time_stamps, + "ages": padded_ages, + "visit_concept_orders": padded_visit_concept_orders, + "concept_values": padded_concept_values, + "concept_value_masks": padded_concept_value_masks, } return inputs, labels class ProbabilisticBertModelEvaluator(BertLstmModelEvaluator): - def __init__(self, - *args, **kwargs): + def __init__(self, *args, **kwargs): super(ProbabilisticBertModelEvaluator, self).__init__(*args, **kwargs) def _create_model(self): strategy = tf.distribute.MirroredStrategy() - self.get_logger().info('Number of devices: {}'.format(strategy.num_replicas_in_sync)) + self.get_logger().info("Number of devices: {}".format(strategy.num_replicas_in_sync)) with strategy.scope(): try: - model = create_probabilistic_bert_bi_lstm_model( - self._max_seq_length, - self._bert_model_path - ) + model = create_probabilistic_bert_bi_lstm_model(self._max_seq_length, self._bert_model_path) except ValueError as e: self.get_logger().exception(e) - model = create_probabilistic_bert_bi_lstm_model( - self._max_seq_length, - self._bert_model_path - ) - model.compile(loss='binary_crossentropy', - optimizer=tf.keras.optimizers.Adam(1e-4), - metrics=get_metrics()) + model = create_probabilistic_bert_bi_lstm_model(self._max_seq_length, self._bert_model_path) + model.compile( + loss="binary_crossentropy", + optimizer=tf.keras.optimizers.Adam(1e-4), + metrics=get_metrics(), + ) return model class BertFeedForwardModelEvaluator(BertLstmModelEvaluator): - def __init__(self, - *args, **kwargs): + def __init__(self, *args, **kwargs): super(BertFeedForwardModelEvaluator, self).__init__(*args, **kwargs) def _create_model(self): strategy = tf.distribute.MirroredStrategy() - self.get_logger().info('Number of devices: {}'.format(strategy.num_replicas_in_sync)) + self.get_logger().info("Number of devices: {}".format(strategy.num_replicas_in_sync)) with strategy.scope(): try: model = create_vanilla_feed_forward_model((self._bert_model_path)) except ValueError as e: self.get_logger().exception(e) model = create_vanilla_feed_forward_model((self._bert_model_path)) - 
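`extract_model_inputs` turns ragged per-patient sequences into fixed-width arrays with `post_pad_pre_truncate`. The project's implementation lives in `cehrbert.data_generators.learning_objective`; the numpy version below only illustrates the semantics the name and the surrounding code suggest (pad at the end with a given value, keep the last `max_seq_len` elements when a sequence is too long), not that function.

```python
import numpy as np


def post_pad_pre_truncate(sequences, pad_value, max_seq_len, d_type="int32"):
    """Illustrative re-implementation: pad at the end, truncate from the front."""
    out = np.full((len(sequences), max_seq_len), pad_value, dtype=d_type)
    for i, seq in enumerate(sequences):
        seq = np.asarray(seq)[-max_seq_len:]  # keep the most recent events
        out[i, : len(seq)] = seq
    return out


if __name__ == "__main__":
    token_ids = [[5, 7, 9], [1, 2, 3, 4, 5, 6, 7]]
    padded = post_pad_pre_truncate(token_ids, pad_value=0, max_seq_len=5)
    print(padded)  # [[5 7 9 0 0], [3 4 5 6 7]]
    # A padding mask falls out by comparing against the pad value, just as the
    # evaluator compares against the tokenizer's unused token id.
    print((padded == 0).astype(int))
```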
model.compile(loss='binary_crossentropy', - optimizer=tf.keras.optimizers.Adam(1e-4), - metrics=get_metrics()) + model.compile( + loss="binary_crossentropy", + optimizer=tf.keras.optimizers.Adam(1e-4), + metrics=get_metrics(), + ) return model class SlidingBertModelEvaluator(BertLstmModelEvaluator): - def __init__(self, - context_window: int, - stride: int, *args, **kwargs): + def __init__(self, context_window: int, stride: int, *args, **kwargs): self._context_window = context_window self._stride = stride super(SlidingBertModelEvaluator, self).__init__(*args, **kwargs) def _create_model(self): strategy = tf.distribute.MirroredStrategy() - self.get_logger().info('Number of devices: {}'.format(strategy.num_replicas_in_sync)) + self.get_logger().info("Number of devices: {}".format(strategy.num_replicas_in_sync)) with strategy.scope(): try: model = create_sliding_bert_model( model_path=self._bert_model_path, max_seq_length=self._max_seq_length, context_window=self._context_window, - stride=self._stride) + stride=self._stride, + ) except ValueError as e: self.get_logger().exception(e) model = create_sliding_bert_model( model_path=self._bert_model_path, max_seq_length=self._max_seq_length, context_window=self._context_window, - stride=self._stride) - model.compile(loss='binary_crossentropy', - optimizer=tf.keras.optimizers.Adam(1e-4), - metrics=get_metrics()) + stride=self._stride, + ) + model.compile( + loss="binary_crossentropy", + optimizer=tf.keras.optimizers.Adam(1e-4), + metrics=get_metrics(), + ) return model class RandomVanillaLstmBertModelEvaluator(BertLstmModelEvaluator): - def __init__(self, - embedding_size, - depth, - num_heads, - use_time_embedding, - time_embeddings_size, - visit_tokenizer_path, - *args, **kwargs): + def __init__( + self, + embedding_size, + depth, + num_heads, + use_time_embedding, + time_embeddings_size, + visit_tokenizer_path, + *args, + **kwargs, + ): self._embedding_size = embedding_size self._depth = depth self._num_heads = num_heads self._use_time_embedding = use_time_embedding self._time_embeddings_size = time_embeddings_size - self._visit_tokenizer = pickle.load(open(visit_tokenizer_path, 'rb')) + self._visit_tokenizer = pickle.load(open(visit_tokenizer_path, "rb")) super(RandomVanillaLstmBertModelEvaluator, self).__init__(*args, **kwargs) - self.get_logger().info(f'embedding_size: {embedding_size}\n' - f'depth: {depth}\n' - f'num_heads: {num_heads}\n' - f'use_time_embedding: {use_time_embedding}\n' - f'time_embeddings_size: {time_embeddings_size}\n' - f'visit_tokenizer_path: {visit_tokenizer_path}\n') + self.get_logger().info( + f"embedding_size: {embedding_size}\n" + f"depth: {depth}\n" + f"num_heads: {num_heads}\n" + f"use_time_embedding: {use_time_embedding}\n" + f"time_embeddings_size: {time_embeddings_size}\n" + f"visit_tokenizer_path: {visit_tokenizer_path}\n" + ) def _create_model(self): strategy = tf.distribute.MirroredStrategy() - self.get_logger().info('Number of devices: {}'.format(strategy.num_replicas_in_sync)) + self.get_logger().info("Number of devices: {}".format(strategy.num_replicas_in_sync)) with strategy.scope(): try: @@ -226,7 +227,8 @@ def _create_model(self): visit_tokenizer=self._visit_tokenizer, num_heads=self._num_heads, use_time_embedding=self._use_time_embedding, - time_embeddings_size=self._time_embeddings_size) + time_embeddings_size=self._time_embeddings_size, + ) except ValueError as e: self.get_logger().exception(e) model = create_random_vanilla_bert_bi_lstm_model( @@ -237,8 +239,11 @@ def _create_model(self): 
visit_tokenizer=self._visit_tokenizer, num_heads=self._num_heads, use_time_embedding=self._use_time_embedding, - time_embeddings_size=self._time_embeddings_size) - model.compile(loss='binary_crossentropy', - optimizer=tf.keras.optimizers.Adam(1e-4), - metrics=get_metrics()) + time_embeddings_size=self._time_embeddings_size, + ) + model.compile( + loss="binary_crossentropy", + optimizer=tf.keras.optimizers.Adam(1e-4), + metrics=get_metrics(), + ) return model diff --git a/src/cehrbert/evaluations/model_evaluators/frequency_model_evaluators.py b/src/cehrbert/evaluations/model_evaluators/frequency_model_evaluators.py index d85f315c..9bec07e2 100644 --- a/src/cehrbert/evaluations/model_evaluators/frequency_model_evaluators.py +++ b/src/cehrbert/evaluations/model_evaluators/frequency_model_evaluators.py @@ -2,17 +2,16 @@ from itertools import chain import numpy as np - from scipy.sparse import csr_matrix, hstack from sklearn.linear_model import LogisticRegression from sklearn.model_selection import GridSearchCV, StratifiedKFold, StratifiedShuffleSplit from sklearn.pipeline import Pipeline -from sklearn.preprocessing import normalize, StandardScaler +from sklearn.preprocessing import StandardScaler, normalize from tensorflow.keras.preprocessing.text import Tokenizer from xgboost import XGBClassifier -from ..model_evaluators.model_evaluators import AbstractModelEvaluator -from ...utils.model_utils import compute_binary_metrics +from cehrbert.evaluations.model_evaluators.model_evaluators import AbstractModelEvaluator +from cehrbert.utils.model_utils import compute_binary_metrics class BaselineModelEvaluator(AbstractModelEvaluator, ABC): @@ -32,22 +31,18 @@ def eval_model(self): train = np.where(~test_mask)[0] val_test = np.where(test_mask)[0] x, y = csr_matrix(hstack([inputs[train], age[train]])), labels[train] - test_data = (csr_matrix(hstack([inputs[val_test], age[val_test]])), labels[val_test]) + test_data = ( + csr_matrix(hstack([inputs[val_test], age[val_test]])), + labels[val_test], + ) self._model = self._create_model() if isinstance(self._model, GridSearchCV): self._model = self._model.fit(x, y) else: self._model.fit(x, y) - compute_binary_metrics( - self._model, - test_data, - self.get_model_metrics_folder() - ) + compute_binary_metrics(self._model, test_data, self.get_model_metrics_folder()) else: - for train, test in self.k_fold( - features=(inputs, age, person_ids), - labels=labels - ): + for train, test in self.k_fold(features=(inputs, age, person_ids), labels=labels): x, y = train self._model = self._create_model() if isinstance(self._model, GridSearchCV): @@ -55,11 +50,7 @@ def eval_model(self): else: self._model.fit(x, y) - compute_binary_metrics( - self._model, - test, - self.get_model_metrics_folder() - ) + compute_binary_metrics(self._model, test, self.get_model_metrics_folder()) def get_model_name(self): return type(self._model).__name__ @@ -72,27 +63,23 @@ def k_fold(self, features, labels): (inputs, age, person_ids) = features if self._k_fold_test: - stratified_splitter = StratifiedKFold( - n_splits=self._num_of_folds, - random_state=10 - ) + stratified_splitter = StratifiedKFold(n_splits=self._num_of_folds, random_state=10) else: - stratified_splitter = StratifiedShuffleSplit( - n_splits=1, - test_size=0.15, - random_state=10 - ) + stratified_splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.15, random_state=10) - for train, val_test in stratified_splitter.split( - X=labels, - y=labels - ): + for train, val_test in stratified_splitter.split(X=labels, y=labels): # 
further split val_test using a 2:3 ratio between val and test if self._is_transfer_learning: size = int(len(train) * self._training_percentage) train = np.random.choice(train, size, replace=False) - train_data = (csr_matrix(hstack([inputs[train], age[train]])), labels[train]) - test_data = (csr_matrix(hstack([inputs[val_test], age[val_test]])), labels[val_test]) + train_data = ( + csr_matrix(hstack([inputs[train], age[train]])), + labels[train], + ) + test_data = ( + csr_matrix(hstack([inputs[val_test], age[val_test]])), + labels[val_test], + ) yield train_data, test_data def extract_model_inputs(self): @@ -102,49 +89,52 @@ def extract_model_inputs(self): self._dataset.gender_concept_id = self._dataset.gender_concept_id.astype(str) # Tokenize the concepts - tokenizer = Tokenizer(filters='', lower=False) - tokenizer.fit_on_texts(self._dataset['concept_ids']) - self._dataset['token_ids'] = tokenizer.texts_to_sequences(self._dataset['concept_ids']) + tokenizer = Tokenizer(filters="", lower=False) + tokenizer.fit_on_texts(self._dataset["concept_ids"]) + self._dataset["token_ids"] = tokenizer.texts_to_sequences(self._dataset["concept_ids"]) # Create the row index dataset = self._dataset.reset_index().reset_index() - dataset['row_index'] = dataset[['token_ids', 'level_0']].apply( - lambda tup: [tup[1]] * len(tup[0]), axis=1) + dataset["row_index"] = dataset[["token_ids", "level_0"]].apply(lambda tup: [tup[1]] * len(tup[0]), axis=1) - row_index = list(chain(*dataset['row_index'].tolist())) - col_index = list(chain(*dataset['token_ids'].tolist())) - values = list(chain(*dataset['frequencies'].tolist())) + row_index = list(chain(*dataset["row_index"].tolist())) + col_index = list(chain(*dataset["token_ids"].tolist())) + values = list(chain(*dataset["frequencies"].tolist())) data_size = len(dataset) vocab_size = len(tokenizer.word_index) + 1 - row_index, col_index, values = zip( - *sorted(zip(row_index, col_index, values), key=lambda tup: (tup[0], tup[1]))) + row_index, col_index, values = zip(*sorted(zip(row_index, col_index, values), key=lambda tup: (tup[0], tup[1]))) - concept_freq_count = csr_matrix((values, (row_index, col_index)), - shape=(data_size, vocab_size)) + concept_freq_count = csr_matrix((values, (row_index, col_index)), shape=(data_size, vocab_size)) normalized_concept_freq_count = normalize(concept_freq_count) # one_hot_gender_race = OneHotEncoder(handle_unknown='ignore') \ # .fit_transform(dataset[['gender_concept_id', 'race_concept_id']].to_numpy()) - scaled_age = StandardScaler().fit_transform(dataset[['age']].to_numpy()) + scaled_age = StandardScaler().fit_transform(dataset[["age"]].to_numpy()) - y = dataset['label'].to_numpy() + y = dataset["label"].to_numpy() - return normalized_concept_freq_count, scaled_age, y, self._dataset.person_id.to_numpy() + return ( + normalized_concept_freq_count, + scaled_age, + y, + self._dataset.person_id.to_numpy(), + ) class LogisticRegressionModelEvaluator(BaselineModelEvaluator): def _create_model(self, *args, **kwargs): - pipe = Pipeline([('classifier', LogisticRegression())]) + pipe = Pipeline([("classifier", LogisticRegression())]) # Create param grid. 
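The baseline evaluator's `extract_model_inputs` builds a classic bag-of-concepts representation: tokenize the concept lists, scatter the per-concept frequencies into a sparse `data_size x vocab_size` matrix, L2-normalize the rows, and (in `eval_model`/`k_fold`) hstack a standard-scaled age column onto it. A compact sketch of that pipeline on made-up data; it swaps the Keras `Tokenizer` for a plain dict vocabulary but keeps the scipy/sklearn steps from the diff:

```python
import numpy as np
from scipy.sparse import csr_matrix, hstack
from sklearn.preprocessing import StandardScaler, normalize

# Toy cohort: concept ids per patient with matching frequencies (invented values).
concept_ids = [["c1", "c2"], ["c2", "c3", "c4"]]
frequencies = [[3, 1], [2, 5, 1]]
ages = np.array([[64.0], [41.0]])

# Vocabulary with ids starting at 1, so column 0 stays unused as with Tokenizer.
vocab = {c: i + 1 for i, c in enumerate(sorted({c for row in concept_ids for c in row}))}

rows, cols, vals = [], [], []
for row_idx, (concepts, freqs) in enumerate(zip(concept_ids, frequencies)):
    for concept, freq in zip(concepts, freqs):
        rows.append(row_idx)
        cols.append(vocab[concept])
        vals.append(freq)

# data_size x (vocab_size + 1) count matrix, then row-normalized.
counts = csr_matrix((vals, (rows, cols)), shape=(len(concept_ids), len(vocab) + 1))
normalized_counts = normalize(counts)

# Scale age and append it as one extra column, as the evaluator does before fitting.
scaled_age = StandardScaler().fit_transform(ages)
features = csr_matrix(hstack([normalized_counts, scaled_age]))
print(features.shape)  # (2, 6): 5 concept columns (index 0 unused) + 1 age column
```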
param_grid = [ - {'classifier': [LogisticRegression()], - 'classifier__penalty': ['l1', 'l2'], - 'classifier__C': np.logspace(-4, 4, 20), - 'classifier__solver': ['liblinear'], - 'classifier__max_iter': [500] - } + { + "classifier": [LogisticRegression()], + "classifier__penalty": ["l1", "l2"], + "classifier__C": np.logspace(-4, 4, 20), + "classifier__solver": ["liblinear"], + "classifier__max_iter": [500], + } ] # Create grid search object clf = GridSearchCV(pipe, param_grid=param_grid, cv=5, verbose=True, n_jobs=-1) diff --git a/src/cehrbert/evaluations/model_evaluators/hierarchical_bert_evaluators.py b/src/cehrbert/evaluations/model_evaluators/hierarchical_bert_evaluators.py index 88e30202..b97e2886 100644 --- a/src/cehrbert/evaluations/model_evaluators/hierarchical_bert_evaluators.py +++ b/src/cehrbert/evaluations/model_evaluators/hierarchical_bert_evaluators.py @@ -1,59 +1,58 @@ from tensorflow.keras.utils import pad_sequences -from ...data_generators.learning_objective import post_pad_pre_truncate -from ..model_evaluators.model_evaluators import get_metrics -from ..model_evaluators.sequence_model_evaluators import SequenceModelEvaluator -from ...models.evaluation_models import create_hierarchical_bert_bi_lstm_model, \ - create_hierarchical_bert_bi_lstm_model_with_model, \ - create_hierarchical_bert_model_with_pooling -from ...models.hierachical_bert_model_v2 import transformer_hierarchical_bert_model -from ...utils.model_utils import * +from cehrbert.data_generators.learning_objective import post_pad_pre_truncate +from cehrbert.evaluations.model_evaluators.model_evaluators import get_metrics +from cehrbert.evaluations.model_evaluators.sequence_model_evaluators import SequenceModelEvaluator +from cehrbert.models.evaluation_models import ( + create_hierarchical_bert_bi_lstm_model, + create_hierarchical_bert_bi_lstm_model_with_model, + create_hierarchical_bert_model_with_pooling, +) +from cehrbert.models.hierachical_bert_model_v2 import transformer_hierarchical_bert_model +from cehrbert.utils.model_utils import convert_to_list_of_lists, np, pickle, tf class HierarchicalBertEvaluator(SequenceModelEvaluator): def __init__( - self, - bert_model_path: str, - tokenizer_path: str, - visit_tokenizer_path: str, - max_num_of_visits: int, - max_num_of_concepts: int, - include_att_tokens: bool = False, - *args, - **kwargs + self, + bert_model_path: str, + tokenizer_path: str, + visit_tokenizer_path: str, + max_num_of_visits: int, + max_num_of_concepts: int, + include_att_tokens: bool = False, + *args, + **kwargs, ): self._max_num_of_visits = max_num_of_visits self._max_num_of_concepts = max_num_of_concepts self._bert_model_path = bert_model_path - self._tokenizer = pickle.load(open(tokenizer_path, 'rb')) - self._visit_tokenizer = pickle.load(open(visit_tokenizer_path, 'rb')) + self._tokenizer = pickle.load(open(tokenizer_path, "rb")) + self._visit_tokenizer = pickle.load(open(visit_tokenizer_path, "rb")) self._include_att_tokens = include_att_tokens self.get_logger().info( - f'max_num_of_visits: {max_num_of_visits}\n' - f'max_num_of_concepts: {max_num_of_concepts}\n' - f'vanilla_bert_model_path: {bert_model_path}\n' - f'tokenizer_path: {tokenizer_path}\n' - f'include_att_token: {include_att_tokens}\n' - f'visit_tokenizer_path: {visit_tokenizer_path}\n' + f"max_num_of_visits: {max_num_of_visits}\n" + f"max_num_of_concepts: {max_num_of_concepts}\n" + f"vanilla_bert_model_path: {bert_model_path}\n" + f"tokenizer_path: {tokenizer_path}\n" + f"include_att_token: {include_att_tokens}\n" + 
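The logistic-regression baseline wraps the estimator in a `Pipeline` and grid-searches penalty and C (with the liblinear solver) under 5-fold cross-validation. A self-contained version on synthetic data; the grid mirrors the one in the diff but shrinks the C range so it runs in seconds:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline

x, y = make_classification(n_samples=500, n_features=20, random_state=10)

pipe = Pipeline([("classifier", LogisticRegression())])
param_grid = [
    {
        "classifier": [LogisticRegression()],
        "classifier__penalty": ["l1", "l2"],
        "classifier__C": np.logspace(-4, 4, 5),  # the diff uses 20 points
        "classifier__solver": ["liblinear"],  # liblinear supports both penalties
        "classifier__max_iter": [500],
    }
]

clf = GridSearchCV(pipe, param_grid=param_grid, cv=5, verbose=True, n_jobs=-1)
clf.fit(x, y)
print(clf.best_params_["classifier__C"], clf.best_score_)
```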
f"visit_tokenizer_path: {visit_tokenizer_path}\n" ) super(HierarchicalBertEvaluator, self).__init__(*args, **kwargs) - def _create_model( - self, - **kwargs - ): + def _create_model(self, **kwargs): strategy = tf.distribute.MirroredStrategy() - self.get_logger().info('Number of devices: {}'.format(strategy.num_replicas_in_sync)) + self.get_logger().info("Number of devices: {}".format(strategy.num_replicas_in_sync)) with strategy.scope(): try: model = create_hierarchical_bert_bi_lstm_model( self._bert_model_path, include_att_tokens=self._include_att_tokens, freeze_pretrained_model=self._freeze_pretrained_model, - **kwargs + **kwargs, ) except ValueError as e: self.get_logger().exception(e) @@ -61,115 +60,110 @@ def _create_model( self._bert_model_path, include_att_tokens=self._include_att_tokens, freeze_pretrained_model=self._freeze_pretrained_model, - **kwargs + **kwargs, ) - model.compile(loss='binary_crossentropy', - optimizer=tf.keras.optimizers.Adam(1e-4), - metrics=get_metrics()) + model.compile( + loss="binary_crossentropy", + optimizer=tf.keras.optimizers.Adam(1e-4), + metrics=get_metrics(), + ) return model def _concept_mask(self, concept_ids): return list( - map(lambda c: (c == self._tokenizer.get_unused_token_id()).astype(int), concept_ids)) + map( + lambda c: (c == self._tokenizer.get_unused_token_id()).astype(int), + concept_ids, + ) + ) - def _pad(self, x, padded_token, dtype='int32'): + def _pad(self, x, padded_token, dtype="int32"): return pad_sequences( np.asarray(x), maxlen=self._max_num_of_concepts, - padding='post', - truncating='post', + padding="post", + truncating="post", value=padded_token, - dtype=dtype) + dtype=dtype, + ) def extract_model_inputs(self): max_seq_len = self._max_num_of_concepts * self._max_num_of_visits unused_token_id = self._tokenizer.get_unused_token_id() # Process concept ids - token_ids = self._dataset.concept_ids \ - .apply(convert_to_list_of_lists) \ - .apply(self._tokenizer.encode) \ + token_ids = ( + self._dataset.concept_ids.apply(convert_to_list_of_lists) + .apply(self._tokenizer.encode) .apply(lambda tokens: self._pad(tokens, padded_token=unused_token_id)) + ) padded_token_ids = np.reshape( - post_pad_pre_truncate( - token_ids.apply(lambda d: d.flatten()), - unused_token_id, - max_seq_len - ), - (-1, self._max_num_of_visits, self._max_num_of_concepts) + post_pad_pre_truncate(token_ids.apply(lambda d: d.flatten()), unused_token_id, max_seq_len), + (-1, self._max_num_of_visits, self._max_num_of_concepts), ) # Generate the concept mask pat_mask = (padded_token_ids == unused_token_id).astype(int) # Process age sequence - ages = self._dataset.ages.apply( - convert_to_list_of_lists - ).apply( + ages = self._dataset.ages.apply(convert_to_list_of_lists).apply( lambda tokens: self._pad(tokens, padded_token=0) ) padded_ages = np.reshape( post_pad_pre_truncate(ages.apply(lambda d: d.flatten()), 0, max_seq_len), - (-1, self._max_num_of_visits, self._max_num_of_concepts) + (-1, self._max_num_of_visits, self._max_num_of_concepts), ) # Process time sequence - dates = self._dataset.dates \ - .apply(convert_to_list_of_lists) \ - .apply(lambda tokens: self._pad(tokens, padded_token=0)) + dates = self._dataset.dates.apply(convert_to_list_of_lists).apply( + lambda tokens: self._pad(tokens, padded_token=0) + ) padded_dates = np.reshape( post_pad_pre_truncate(dates.apply(lambda d: d.flatten()), 0, max_seq_len), - (-1, self._max_num_of_visits, self._max_num_of_concepts) + (-1, self._max_num_of_visits, self._max_num_of_concepts), ) # Process concept ids # 
Retrieve the values associated with the concepts, this is mostly for measurements - concept_values = self._dataset.concept_values \ - .apply(convert_to_list_of_lists) \ - .apply(lambda tokens: self._pad(tokens, padded_token=-1.0, dtype='float32')) + concept_values = self._dataset.concept_values.apply(convert_to_list_of_lists).apply( + lambda tokens: self._pad(tokens, padded_token=-1.0, dtype="float32") + ) padded_concept_values = np.reshape( post_pad_pre_truncate(concept_values.apply(lambda d: d.flatten()), 0, max_seq_len), - (-1, self._max_num_of_visits, self._max_num_of_concepts) + (-1, self._max_num_of_visits, self._max_num_of_concepts), ) - concept_value_masks = self._dataset.concept_value_masks \ - .apply(convert_to_list_of_lists) \ - .apply(lambda tokens: self._pad(tokens, padded_token=0)) + concept_value_masks = self._dataset.concept_value_masks.apply(convert_to_list_of_lists).apply( + lambda tokens: self._pad(tokens, padded_token=0) + ) padded_concept_value_masks = np.reshape( post_pad_pre_truncate(concept_value_masks.apply(lambda d: d.flatten()), 0, max_seq_len), - (-1, self._max_num_of_visits, self._max_num_of_concepts) + (-1, self._max_num_of_visits, self._max_num_of_concepts), ) # Process att tokens - att_tokens = self._tokenizer.encode( - self._dataset.time_interval_atts.apply(lambda t: t.tolist()).tolist()) - padded_att_tokens = post_pad_pre_truncate( - att_tokens, - unused_token_id, - self._max_num_of_visits - )[:, 1:] + att_tokens = self._tokenizer.encode(self._dataset.time_interval_atts.apply(lambda t: t.tolist()).tolist()) + padded_att_tokens = post_pad_pre_truncate(att_tokens, unused_token_id, self._max_num_of_visits)[:, 1:] # Process visit segments padded_visit_segments = post_pad_pre_truncate( self._dataset.visit_segments, pad_value=0, - max_seq_len=self._max_num_of_visits + max_seq_len=self._max_num_of_visits, ) # Process visit_rank_orders padded_visit_rank_orders = post_pad_pre_truncate( self._dataset.visit_rank_orders, pad_value=0, - max_seq_len=self._max_num_of_visits + max_seq_len=self._max_num_of_visits, ) padded_visit_mask = post_pad_pre_truncate( - self._dataset.visit_masks, - pad_value=1, - max_seq_len=self._max_num_of_visits + self._dataset.visit_masks, pad_value=1, max_seq_len=self._max_num_of_visits ) visit_token_ids = self._visit_tokenizer.encode( @@ -179,22 +173,22 @@ def extract_model_inputs(self): padded_masked_visit_type = post_pad_pre_truncate( visit_token_ids, pad_value=self._visit_tokenizer.get_unused_token_id(), - max_seq_len=self._max_num_of_visits + max_seq_len=self._max_num_of_visits, ) inputs = { - 'pat_seq': padded_token_ids, - 'pat_mask': pat_mask, - 'pat_seq_time': padded_dates, - 'pat_seq_age': padded_ages, - 'visit_segment': padded_visit_segments, - 'visit_rank_order': padded_visit_rank_orders, - 'visit_time_delta_att': padded_att_tokens, - 'visit_mask': padded_visit_mask, - 'concept_values': padded_concept_values, - 'concept_value_masks': padded_concept_value_masks, - 'masked_visit_type': padded_masked_visit_type, - 'age': np.expand_dims(self._dataset.age, axis=-1) + "pat_seq": padded_token_ids, + "pat_mask": pat_mask, + "pat_seq_time": padded_dates, + "pat_seq_age": padded_ages, + "visit_segment": padded_visit_segments, + "visit_rank_order": padded_visit_rank_orders, + "visit_time_delta_att": padded_att_tokens, + "visit_mask": padded_visit_mask, + "concept_values": padded_concept_values, + "concept_value_masks": padded_concept_value_masks, + "masked_visit_type": padded_masked_visit_type, + "age": np.expand_dims(self._dataset.age, 
axis=-1), } labels = self._dataset.label.to_numpy() @@ -202,51 +196,47 @@ def extract_model_inputs(self): class HierarchicalBertPoolingEvaluator(HierarchicalBertEvaluator): - def __init__( - self, - *args, - **kwargs - ): + def __init__(self, *args, **kwargs): super(HierarchicalBertPoolingEvaluator, self).__init__(*args, **kwargs) - def _create_model( - self, - **kwargs - ): + def _create_model(self, **kwargs): strategy = tf.distribute.MirroredStrategy() - self.get_logger().info('Number of devices: {}'.format(strategy.num_replicas_in_sync)) + self.get_logger().info("Number of devices: {}".format(strategy.num_replicas_in_sync)) with strategy.scope(): try: model = create_hierarchical_bert_model_with_pooling( self._bert_model_path, freeze_pretrained_model=self._freeze_pretrained_model, - **kwargs + **kwargs, ) except ValueError as e: self.get_logger().exception(e) model = create_hierarchical_bert_model_with_pooling( self._bert_model_path, freeze_pretrained_model=self._freeze_pretrained_model, - **kwargs + **kwargs, ) - model.compile(loss='binary_crossentropy', - optimizer=tf.keras.optimizers.Adam(1e-4), - metrics=get_metrics()) + model.compile( + loss="binary_crossentropy", + optimizer=tf.keras.optimizers.Adam(1e-4), + metrics=get_metrics(), + ) return model class RandomHierarchicalBertEvaluator(HierarchicalBertEvaluator): def __init__( - self, - num_of_exchanges, - embedding_size, - depth, - num_heads, - use_time_embedding, - time_embeddings_size, - *args, **kwargs + self, + num_of_exchanges, + embedding_size, + depth, + num_heads, + use_time_embedding, + time_embeddings_size, + *args, + **kwargs, ): self._num_of_exchanges = num_of_exchanges self._embedding_size = embedding_size @@ -256,16 +246,18 @@ def __init__( self._time_embeddings_size = time_embeddings_size super(RandomHierarchicalBertEvaluator, self).__init__(*args, **kwargs) - self.get_logger().info(f'num_of_exchanges: {num_of_exchanges}\n' - f'embedding_size: {embedding_size}\n' - f'depth: {depth}\n' - f'num_heads: {num_heads}\n' - f'use_time_embedding: {use_time_embedding}\n' - f'time_embeddings_size: {time_embeddings_size}\n') + self.get_logger().info( + f"num_of_exchanges: {num_of_exchanges}\n" + f"embedding_size: {embedding_size}\n" + f"depth: {depth}\n" + f"num_heads: {num_heads}\n" + f"use_time_embedding: {use_time_embedding}\n" + f"time_embeddings_size: {time_embeddings_size}\n" + ) def _create_model(self): strategy = tf.distribute.MirroredStrategy() - self.get_logger().info('Number of devices: {}'.format(strategy.num_replicas_in_sync)) + self.get_logger().info("Number of devices: {}".format(strategy.num_replicas_in_sync)) with strategy.scope(): try: @@ -277,17 +269,15 @@ def _create_model(self): depth=self._depth, num_heads=self._num_heads, num_of_exchanges=self._num_of_exchanges, - time_embeddings_size=self._time_embeddings_size - ) - model = create_hierarchical_bert_bi_lstm_model_with_model( - cherbert_model + time_embeddings_size=self._time_embeddings_size, ) + model = create_hierarchical_bert_bi_lstm_model_with_model(cherbert_model) except ValueError as e: self.get_logger().exception(e) - model = create_hierarchical_bert_bi_lstm_model_with_model( - cherbert_model - ) - model.compile(loss='binary_crossentropy', - optimizer=tf.keras.optimizers.Adam(1e-4), - metrics=get_metrics()) + model = create_hierarchical_bert_bi_lstm_model_with_model(cherbert_model) + model.compile( + loss="binary_crossentropy", + optimizer=tf.keras.optimizers.Adam(1e-4), + metrics=get_metrics(), + ) return model diff --git 
a/src/cehrbert/evaluations/model_evaluators/model_evaluators.py b/src/cehrbert/evaluations/model_evaluators/model_evaluators.py index cea9c4d2..e9260c39 100644 --- a/src/cehrbert/evaluations/model_evaluators/model_evaluators.py +++ b/src/cehrbert/evaluations/model_evaluators/model_evaluators.py @@ -2,35 +2,39 @@ from abc import abstractmethod from ...trainers.model_trainer import AbstractModel -from ...utils.model_utils import * +from ...utils.model_utils import os, pathlib, tf def get_metrics(): """ - Standard metrics used for compiling the models + Standard metrics used for compiling the models. + :return: """ - return ['binary_accuracy', - tf.keras.metrics.Recall(name='recall'), - tf.keras.metrics.Precision(name='precision'), - tf.keras.metrics.AUC(curve='PR', name='pr_auc'), - tf.keras.metrics.AUC(name='auc')] + return [ + "binary_accuracy", + tf.keras.metrics.Recall(name="recall"), + tf.keras.metrics.Precision(name="precision"), + tf.keras.metrics.AUC(curve="PR", name="pr_auc"), + tf.keras.metrics.AUC(name="auc"), + ] class AbstractModelEvaluator(AbstractModel): def __init__( - self, - dataset, - evaluation_folder, - num_of_folds, - is_transfer_learning: bool = False, - training_percentage: float = 1.0, - learning_rate: float = 1e-4, - is_chronological_test: bool = False, - k_fold_test: bool = False, - test_person_ids=None, - *args, **kwargs + self, + dataset, + evaluation_folder, + num_of_folds, + is_transfer_learning: bool = False, + training_percentage: float = 1.0, + learning_rate: float = 1e-4, + is_chronological_test: bool = False, + k_fold_test: bool = False, + test_person_ids=None, + *args, + **kwargs, ): self._dataset = copy.copy(dataset) self._evaluation_folder = evaluation_folder @@ -43,24 +47,23 @@ def __init__( self._test_person_ids = test_person_ids if is_transfer_learning: - extension = 'transfer_learning_{:.2f}'.format(self._training_percentage).replace('.', - '_') + extension = "transfer_learning_{:.2f}".format(self._training_percentage).replace(".", "_") self._evaluation_folder = os.path.join(self._evaluation_folder, extension) self.get_logger().info( - f'evaluation_folder: {self._evaluation_folder}\n' - f'num_of_folds: {self._num_of_folds}\n' - f'is_transfer_learning {self._is_transfer_learning}\n' - f'training_percentage: {self._training_percentage}\n' - f'learning_rate: {self._learning_rate}\n' - f'is_chronological_test: {is_chronological_test}\n' - f'k_fold_test: {k_fold_test}\n' + f"evaluation_folder: {self._evaluation_folder}\n" + f"num_of_folds: {self._num_of_folds}\n" + f"is_transfer_learning {self._is_transfer_learning}\n" + f"training_percentage: {self._training_percentage}\n" + f"learning_rate: {self._learning_rate}\n" + f"is_chronological_test: {is_chronological_test}\n" + f"k_fold_test: {k_fold_test}\n" ) if self._is_chronological_test: - self.get_logger().info(f'Start sorting dataset chronologically using index date') - self._dataset = self._dataset.sort_values('index_date').reset_index() - self.get_logger().info(f'Finish sorting dataset chronologically') + self.get_logger().info(f"Start sorting dataset chronologically using index date") + self._dataset = self._dataset.sort_values("index_date").reset_index() + self.get_logger().info(f"Finish sorting dataset chronologically") super().__init__(*args, **kwargs) @@ -71,13 +74,13 @@ def get_model_name(self): def get_model_folder(self): model_folder = os.path.join(self._evaluation_folder, self.get_model_name()) if not os.path.exists(model_folder): - self.get_logger().info(f'Create the model folder at 
{model_folder}') + self.get_logger().info(f"Create the model folder at {model_folder}") pathlib.Path(model_folder).mkdir(parents=True, exist_ok=True) return model_folder def get_model_path(self): model_folder = self.get_model_folder() - return os.path.join(model_folder, f'{self.get_model_name()}.h5') + return os.path.join(model_folder, f"{self.get_model_name()}.h5") @abstractmethod def k_fold(self, features, labels): diff --git a/src/cehrbert/evaluations/model_evaluators/sequence_model_evaluators.py b/src/cehrbert/evaluations/model_evaluators/sequence_model_evaluators.py index e9d1319f..c6aa9967 100644 --- a/src/cehrbert/evaluations/model_evaluators/sequence_model_evaluators.py +++ b/src/cehrbert/evaluations/model_evaluators/sequence_model_evaluators.py @@ -1,18 +1,17 @@ +import math from abc import ABC, abstractmethod +from itertools import product -import math from scipy import stats -from itertools import product -from sklearn.model_selection import StratifiedShuffleSplit, StratifiedKFold, \ - train_test_split +from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit, train_test_split from tensorflow.python.keras.utils.generic_utils import get_custom_objects -from .model_evaluators import AbstractModelEvaluator, get_metrics from ...config.grid_search_config import GridSearchConfig from ...data_generators.learning_objective import post_pad_pre_truncate from ...models.evaluation_models import create_bi_lstm_model from ...models.loss_schedulers import CosineLRSchedule -from ...utils.model_utils import * +from ...utils.model_utils import compute_binary_metrics, multimode, np, os, pd, pickle, save_training_history, tf +from .model_evaluators import AbstractModelEvaluator, get_metrics # Define a list of learning rates to fine-tune the model with LEARNING_RATES = [0.5e-4, 0.8e-4, 1.0e-4, 1.2e-4] @@ -25,24 +24,25 @@ class SequenceModelEvaluator(AbstractModelEvaluator, ABC): def __init__( - self, - epochs, - batch_size, - sequence_model_name: bool = None, - cross_validation_test: bool = False, - grid_search_config: GridSearchConfig = None, - freeze_pretrained_model=False, - multiple_test_run=False, - *args, **kwargs + self, + epochs, + batch_size, + sequence_model_name: bool = None, + cross_validation_test: bool = False, + grid_search_config: GridSearchConfig = None, + freeze_pretrained_model=False, + multiple_test_run=False, + *args, + **kwargs, ): self.get_logger().info( - f'epochs: {epochs}\n' - f'batch_size: {batch_size}\n' - f'sequence_model_name: {sequence_model_name}\n' - f'cross_validation_test: {cross_validation_test}\n' - f'grid_search_config: {grid_search_config}\n' - f'freeze_pretrained_model: {freeze_pretrained_model}\n' - f'multiple_test_run: {multiple_test_run}\n' + f"epochs: {epochs}\n" + f"batch_size: {batch_size}\n" + f"sequence_model_name: {sequence_model_name}\n" + f"cross_validation_test: {cross_validation_test}\n" + f"grid_search_config: {grid_search_config}\n" + f"freeze_pretrained_model: {freeze_pretrained_model}\n" + f"multiple_test_run: {multiple_test_run}\n" ) self._epochs = epochs self._batch_size = batch_size @@ -55,14 +55,15 @@ def __init__( self._grid_search_config = grid_search_config else: self._grid_search_config = GridSearchConfig() - self.get_logger().info(f'grid_search_config is None and initializing default ' - f'GridSearchConfig') + self.get_logger().info(f"grid_search_config is None and initializing default " f"GridSearchConfig") # Set the GPU to memory growth to true to prevent the entire GPU memory from being # allocated try: - 
[tf.config.experimental.set_memory_growth(device, True) - for device in tf.config.list_physical_devices('GPU')] + [ + tf.config.experimental.set_memory_growth(device, True) + for device in tf.config.list_physical_devices("GPU") + ] except (ValueError, RuntimeError) as error: # Invalid device or cannot modify virtual devices once initialized. tf.print(error) @@ -70,14 +71,15 @@ def __init__( super(SequenceModelEvaluator, self).__init__(*args, **kwargs) def train_model( - self, - training_data: tf.data.Dataset, - val_data: tf.data.Dataset, - model_name, - **kwargs + self, + training_data: tf.data.Dataset, + val_data: tf.data.Dataset, + model_name, + **kwargs, ): """ - Training the model for the keras based sequence models + Training the model for the keras based sequence models. + :param training_data: :param val_data: :param model_name: @@ -88,14 +90,10 @@ def train_model( epochs=self._epochs, validation_data=val_data, callbacks=self._get_callbacks(), - **kwargs + **kwargs, ) - save_training_history( - history, - self.get_model_history_folder(), - model_name - ) + save_training_history(history, self.get_model_history_folder(), model_name) return history def eval_model(self): @@ -111,18 +109,20 @@ def eval_model(self): self.train_model( training_data=train, val_data=val, - model_name=f'{self._sequence_model_name}_{i}') + model_name=f"{self._sequence_model_name}_{i}", + ) compute_binary_metrics( self._model, test, self.get_model_metrics_folder(), - model_name=f'{self._sequence_model_name}_{i}' + model_name=f"{self._sequence_model_name}_{i}", ) def eval_model_cross_validation_test(self): """ + The data is split into train_val and test partitions. - The data is split into train_val and test partitions. It carries out a k-fold cross + It carries out a k-fold cross validation on the train_val partition first, then :return: @@ -132,61 +132,42 @@ def eval_model_cross_validation_test(self): # Hold out 15% of the data for testing if self._is_chronological_test: training_stop = math.ceil(self._dataset.index.stop * 0.85) - training_val_test_set_idx = np.asarray( - range(training_stop) - ) - held_out_set_idx = np.asarray( - range(training_stop, len(self._dataset)) - ) + training_val_test_set_idx = np.asarray(range(training_stop)) + held_out_set_idx = np.asarray(range(training_stop, len(self._dataset))) else: - stratified_splitter = StratifiedShuffleSplit( - n_splits=1, - test_size=0.15, - random_state=10 - ) - training_val_test_set_idx, held_out_set_idx = next( - stratified_splitter.split( - X=labels, - y=labels - ) - ) + stratified_splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.15, random_state=10) + training_val_test_set_idx, held_out_set_idx = next(stratified_splitter.split(X=labels, y=labels)) # Use the remaining 85% of the training data for optimizing - training_val_test_set_inputs = { - k: v[training_val_test_set_idx] - for k, v in features.items() - } + training_val_test_set_inputs = {k: v[training_val_test_set_idx] for k, v in features.items()} training_val_test_set_labels = labels[training_val_test_set_idx] # Conduct a grid search to find the best combination of hyperparameters all_param_configs_pd = self.grid_search_cross_validation( - features=training_val_test_set_inputs, - labels=training_val_test_set_labels + features=training_val_test_set_inputs, labels=training_val_test_set_labels ) # Now that we know the most optimal configurations. Let's retrain the model with the full # set using the most frequent number of epochs in k-fold validation. 
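`eval_model_cross_validation_test` carves off the held-out test set in one of two ways: chronologically (the earliest 85% of rows train, the latest 15% test, relying on the earlier sort by `index_date`) or with a stratified 85/15 shuffle. Just that branching, on made-up labels:

```python
import math

import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

labels = np.array([0, 1] * 50)  # toy labels; assume rows are already in index_date order
is_chronological_test = False

if is_chronological_test:
    # Chronological hold-out: the last 15% of the timeline is never trained on.
    training_stop = math.ceil(len(labels) * 0.85)
    train_val_idx = np.arange(training_stop)
    held_out_idx = np.arange(training_stop, len(labels))
else:
    # Stratified hold-out: preserve the label ratio in both partitions.
    splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.15, random_state=10)
    train_val_idx, held_out_idx = next(splitter.split(X=labels, y=labels))

print(len(train_val_idx), len(held_out_idx))  # 85 / 15 patients
```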
In case of multiple # modes, we always take the smallest mode - optimal_hyperparam_combination = all_param_configs_pd.sort_values( - 'roc_auc', - ascending=False - ).iloc[0] + optimal_hyperparam_combination = all_param_configs_pd.sort_values("roc_auc", ascending=False).iloc[0] self._epochs = optimal_hyperparam_combination.epoch self._learning_rate = optimal_hyperparam_combination.learning_rate - with tf.device('/CPU:0'): + with tf.device("/CPU:0"): # Train using the full training set - full_training_set = tf.data.Dataset.from_tensor_slices( - (training_val_test_set_inputs, - training_val_test_set_labels) - ).cache().batch(self._batch_size) + full_training_set = ( + tf.data.Dataset.from_tensor_slices((training_val_test_set_inputs, training_val_test_set_labels)) + .cache() + .batch(self._batch_size) + ) for _ in range(self._num_of_folds): # Recreate the model self._model = self._create_model( is_bi_directional=optimal_hyperparam_combination.is_bi_directional, - lstm_unit=optimal_hyperparam_combination.lstm_unit + lstm_unit=optimal_hyperparam_combination.lstm_unit, ) if self._is_transfer_learning: @@ -200,41 +181,36 @@ def eval_model_cross_validation_test(self): self.train_model( training_data=training_data, val_data=training_data.take(10), - model_name=f'{self._sequence_model_name}_final' + model_name=f"{self._sequence_model_name}_final", ) # Construct the held-out tensorflow dataset to calculate the metrics - held_out_set_inputs = { - k: v[held_out_set_idx] - for k, v in features.items() - } + held_out_set_inputs = {k: v[held_out_set_idx] for k, v in features.items()} held_out_set_labels = labels[held_out_set_idx] - with tf.device('/CPU:0'): - hold_out_set = tf.data.Dataset.from_tensor_slices( - (held_out_set_inputs, - held_out_set_labels) - ).cache().batch(self._batch_size) + with tf.device("/CPU:0"): + hold_out_set = ( + tf.data.Dataset.from_tensor_slices((held_out_set_inputs, held_out_set_labels)) + .cache() + .batch(self._batch_size) + ) compute_binary_metrics( self._model, hold_out_set, self.get_model_test_metrics_folder(), evaluation_model_folder=self.get_model_test_prediction_folder(), - model_name=f'{self._sequence_model_name}_final', - calculate_ci=not self._multiple_test_run + model_name=f"{self._sequence_model_name}_final", + calculate_ci=not self._multiple_test_run, ) # If multiple test run is not enabled, we break out of the loop if not self._multiple_test_run: break - def grid_search_cross_validation( - self, - features, - labels - ): + def grid_search_cross_validation(self, features, labels): """ - This method conducts a grid search via cross validation to determine the best combination + This method conducts a grid search via cross validation to determine the best combination. 
+ of hyperparameters :param features: @@ -243,22 +219,22 @@ def grid_search_cross_validation( """ all_param_configs = [] for idx, (lr, is_bi_directional, lstm_unit) in enumerate( - product( - self._grid_search_config.learning_rates, - self._grid_search_config.lstm_directions, - self._grid_search_config.lstm_units - ) + product( + self._grid_search_config.learning_rates, + self._grid_search_config.lstm_directions, + self._grid_search_config.lstm_units, + ) ): # Print out the model hyperparameters - tf.print(f'learning_rate: {lr}') - tf.print(f'is_bi_directional: {is_bi_directional}') - tf.print(f'lstm_unit: {lstm_unit}') + tf.print(f"learning_rate: {lr}") + tf.print(f"is_bi_directional: {is_bi_directional}") + tf.print(f"lstm_unit: {lstm_unit}") # Remember this configuration in a dict param_config = { - 'learning_rate': lr, - 'is_bi_directional': is_bi_directional, - 'lstm_unit': lstm_unit + "learning_rate": lr, + "is_bi_directional": is_bi_directional, + "lstm_unit": lstm_unit, } # Update the learning rate self._learning_rate = lr @@ -269,105 +245,77 @@ def grid_search_cross_validation( # Run the k-fold 10 times until we discover a single mode max_iter = 10 while max_iter > 0: - for i, (train, val, test) in enumerate( - self.k_fold( - features=features, - labels=labels - ) - ): - self._model = self._create_model( - is_bi_directional=is_bi_directional, - lstm_unit=lstm_unit - ) + for i, (train, val, test) in enumerate(self.k_fold(features=features, labels=labels)): + self._model = self._create_model(is_bi_directional=is_bi_directional, lstm_unit=lstm_unit) history = self.train_model( training_data=train, val_data=val, - model_name=f'{self._sequence_model_name}_param_{idx}_iter_{i}' + model_name=f"{self._sequence_model_name}_param_{idx}_iter_{i}", ) # This captures the number of epochs each fold trained - num_of_epochs.append(len(history.history['loss']) - 1) + num_of_epochs.append(len(history.history["loss"]) - 1) fold_metrics = compute_binary_metrics( self._model, test, self.get_model_metrics_folder(), - model_name=f'{self._sequence_model_name}_param_{idx}_iter_{i}', + model_name=f"{self._sequence_model_name}_param_{idx}_iter_{i}", extra_info=param_config, - calculate_ci=False + calculate_ci=False, ) - roc_auc_scores.append(fold_metrics['roc_auc']) + roc_auc_scores.append(fold_metrics["roc_auc"]) max_iter = max_iter - 1 # If we find a single mode, we exit the loop if len(multimode(num_of_epochs)) == 1: self.get_logger().info( - f'Found the best epoch for lr={lr},' - f' is_bi_directional={is_bi_directional}, lstm_unit={lstm_unit}' + f"Found the best epoch for lr={lr}," + f" is_bi_directional={is_bi_directional}, lstm_unit={lstm_unit}" ) break if max_iter == 0: raise RuntimeError( - f'Failed to find the best epoch for lr={lr},' - f' is_bi_directional={is_bi_directional}, lstm_unit={lstm_unit}' + f"Failed to find the best epoch for lr={lr}," + f" is_bi_directional={is_bi_directional}, lstm_unit={lstm_unit}" ) # Add the number of epochs and average roc_auc to this combination - param_config.update({ - 'epoch': stats.mode(num_of_epochs).mode[0], - 'roc_auc': np.mean(roc_auc_scores) - }) - - all_param_configs.append( - param_config + param_config.update( + { + "epoch": stats.mode(num_of_epochs).mode[0], + "roc_auc": np.mean(roc_auc_scores), + } ) + + all_param_configs.append(param_config) # Save all the parameter combinations to the model folder all_param_configs_pd = pd.DataFrame(all_param_configs) all_param_configs_pd.to_parquet( os.path.join( self.get_model_folder(), - 
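Choosing the epoch budget for the final retrain is a majority vote across folds: each fold contributes the epoch count early stopping settled on, the whole k-fold loop is re-run (up to 10 times) until those counts have a single mode, and that mode becomes `self._epochs`. Below is the selection rule in isolation, using only the standard-library `statistics` module (which sidesteps SciPy's changing `ModeResult` shape); the epoch counts are invented:

```python
from statistics import mode, multimode

# Epoch counts early stopping picked in each of 5 folds, for two successive
# k-fold runs (invented values).
kfold_runs = [
    [3, 4, 4, 5, 5],  # 4 and 5 tie -> no single mode, re-run the k-fold loop
    [4, 4, 5, 3, 4],  # single mode -> accept
]

chosen_epochs = None
for num_of_epochs in kfold_runs:
    if len(multimode(num_of_epochs)) == 1:
        chosen_epochs = mode(num_of_epochs)
        break

if chosen_epochs is None:
    # The evaluator gives up after 10 attempts and raises RuntimeError.
    raise RuntimeError("no single best epoch found across folds")

print("epochs for the final retrain:", chosen_epochs)  # 4
```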
f'{self._sequence_model_name}_parameter_combinations.parquet' + f"{self._sequence_model_name}_parameter_combinations.parquet", ) ) return all_param_configs_pd - def k_fold( - self, - features, - labels - ): + def k_fold(self, features, labels): """ - :param features: - :param labels: + :param labels: """ # This preserves the percentage of samples for each class (0 and 1 for binary # classification) if self._k_fold_test: - stratified_splitter = StratifiedKFold( - n_splits=self._num_of_folds, - random_state=10 - ) + stratified_splitter = StratifiedKFold(n_splits=self._num_of_folds, random_state=10) else: - stratified_splitter = StratifiedShuffleSplit( - n_splits=self._num_of_folds, - test_size=0.15, - random_state=10 - ) + stratified_splitter = StratifiedShuffleSplit(n_splits=self._num_of_folds, test_size=0.15, random_state=10) - for train, val_test in stratified_splitter.split( - X=labels, - y=labels - ): + for train, val_test in stratified_splitter.split(X=labels, y=labels): if self._k_fold_test: # further split val_test using a 1:1 ratio between val and test - val, test = train_test_split( - val_test, - test_size=0.5, - random_state=10, - stratify=labels[val_test] - ) + val, test = train_test_split(val_test, test_size=0.5, random_state=10, stratify=labels[val_test]) else: test = val_test val = val_test @@ -380,17 +328,18 @@ def k_fold( val_input = {k: v[val] for k, v in features.items()} test_input = {k: v[test] for k, v in features.items()} - tf.print(f'{self}: The train size is {len(train)}') - tf.print(f'{self}: The val size is {len(val)}') - tf.print(f'{self}: The test size is {len(test)}') + tf.print(f"{self}: The train size is {len(train)}") + tf.print(f"{self}: The val size is {len(val)}") + tf.print(f"{self}: The test size is {len(test)}") - with tf.device('/CPU:0'): - training_set = tf.data.Dataset.from_tensor_slices( - (training_input, labels[train])).cache().batch(self._batch_size) - val_set = tf.data.Dataset.from_tensor_slices( - (val_input, labels[val])).cache().batch(self._batch_size) - test_set = tf.data.Dataset.from_tensor_slices( - (test_input, labels[test])).cache().batch(self._batch_size) + with tf.device("/CPU:0"): + training_set = ( + tf.data.Dataset.from_tensor_slices((training_input, labels[train])).cache().batch(self._batch_size) + ) + val_set = tf.data.Dataset.from_tensor_slices((val_input, labels[val])).cache().batch(self._batch_size) + test_set = ( + tf.data.Dataset.from_tensor_slices((test_input, labels[test])).cache().batch(self._batch_size) + ) yield training_set, val_set, test_set @@ -399,22 +348,22 @@ def get_model_name(self): def _get_callbacks(self): """ - Standard callbacks for the evaluations + Standard callbacks for the evaluations. 
+ :return: """ learning_rate_scheduler = tf.keras.callbacks.LearningRateScheduler( CosineLRSchedule(lr_high=self._learning_rate, lr_low=1e-8, initial_period=10), - verbose=1) - - early_stopping = tf.keras.callbacks.EarlyStopping( - monitor='val_loss', - patience=1, - restore_best_weights=True + verbose=1, ) + + early_stopping = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=1, restore_best_weights=True) model_checkpoint = tf.keras.callbacks.ModelCheckpoint( filepath=self.get_model_path(), - monitor='val_loss', mode='auto', - save_best_only=True, verbose=1 + monitor="val_loss", + mode="auto", + save_best_only=True, + verbose=1, ) return [learning_rate_scheduler, early_stopping, model_checkpoint] @@ -426,75 +375,70 @@ def extract_model_inputs(self): class BiLstmModelEvaluator(SequenceModelEvaluator): def __init__( - self, - max_seq_length: int, - time_aware_model_path: str, - tokenizer_path: str, - embedding_size: int, - *args, - **kwargs + self, + max_seq_length: int, + time_aware_model_path: str, + tokenizer_path: str, + embedding_size: int, + *args, + **kwargs, ): self._max_seq_length = max_seq_length self._embedding_size = embedding_size self._time_aware_model_path = time_aware_model_path - self._tokenizer = pickle.load(open(tokenizer_path, 'rb')) + self._tokenizer = pickle.load(open(tokenizer_path, "rb")) self.get_logger().info( - f'max_seq_length: {max_seq_length}\n' - f'embedding_size: {embedding_size}\n' - f'time_aware_model_path: {time_aware_model_path}\n' - f'tokenizer_path: {tokenizer_path}\n' + f"max_seq_length: {max_seq_length}\n" + f"embedding_size: {embedding_size}\n" + f"time_aware_model_path: {time_aware_model_path}\n" + f"tokenizer_path: {tokenizer_path}\n" ) super(BiLstmModelEvaluator, self).__init__(*args, **kwargs) - def _create_model( - self, - **kwargs - ): + def _create_model(self, **kwargs): def get_concept_embeddings(): try: another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0") with another_strategy.scope(): time_aware_model = tf.keras.models.load_model( self._time_aware_model_path, - custom_objects=dict(**get_custom_objects()) + custom_objects=dict(**get_custom_objects()), ) - embedding_layer = time_aware_model.get_layer('embedding_layer') + embedding_layer = time_aware_model.get_layer("embedding_layer") return embedding_layer.get_weights()[0] except (IOError, ImportError) as e: - self.get_logger().info( - f'Cannot load the time attention model, return None. Error: {e}' - ) + self.get_logger().info(f"Cannot load the time attention model, return None. 
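`_get_callbacks` wires up three callbacks: a cosine learning-rate schedule, early stopping on `val_loss` with `restore_best_weights`, and best-only checkpointing. The sketch below substitutes a plain schedule function for the project's `CosineLRSchedule` class (the `period` parameter and the exact annealing curve are assumptions); the other two callbacks use the same arguments as the diff:

```python
import math

import tensorflow as tf


def cosine_schedule(lr_high=1e-4, lr_low=1e-8, period=10):
    """Stand-in for CosineLRSchedule: anneal from lr_high to lr_low over
    `period` epochs, then repeat."""

    def schedule(epoch, lr=None):
        phase = (epoch % period) / period
        return lr_low + 0.5 * (lr_high - lr_low) * (1 + math.cos(math.pi * phase))

    return schedule


def get_callbacks(model_path: str, learning_rate: float):
    return [
        tf.keras.callbacks.LearningRateScheduler(
            cosine_schedule(lr_high=learning_rate, lr_low=1e-8, period=10), verbose=1
        ),
        tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=1, restore_best_weights=True),
        tf.keras.callbacks.ModelCheckpoint(
            filepath=model_path, monitor="val_loss", mode="auto", save_best_only=True, verbose=1
        ),
    ]


if __name__ == "__main__":
    print([type(cb).__name__ for cb in get_callbacks("best_model.h5", learning_rate=1e-4)])
```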
Error: {e}") return None embeddings = get_concept_embeddings() strategy = tf.distribute.MirroredStrategy() - self.get_logger().info('Number of devices: {}'.format(strategy.num_replicas_in_sync)) + self.get_logger().info("Number of devices: {}".format(strategy.num_replicas_in_sync)) with strategy.scope(): model = create_bi_lstm_model( self._max_seq_length, self._tokenizer.get_vocab_size(), self._embedding_size, embeddings, - **kwargs + **kwargs, ) model.compile( - loss='binary_crossentropy', + loss="binary_crossentropy", optimizer=tf.keras.optimizers.Adam(1e-4), - metrics=get_metrics() + metrics=get_metrics(), ) return model def extract_model_inputs(self): - token_ids = self._tokenizer.encode( - self._dataset.concept_ids.apply(lambda concept_ids: concept_ids.tolist())) + token_ids = self._tokenizer.encode(self._dataset.concept_ids.apply(lambda concept_ids: concept_ids.tolist())) labels = self._dataset.label.to_numpy() - padded_token_ides = post_pad_pre_truncate(token_ids, self._tokenizer.get_unused_token_id(), - self._max_seq_length) + padded_token_ides = post_pad_pre_truncate( + token_ids, self._tokenizer.get_unused_token_id(), self._max_seq_length + ) inputs = { - 'age': np.expand_dims(self._dataset.age, axis=-1), - 'concept_ids': padded_token_ides + "age": np.expand_dims(self._dataset.age, axis=-1), + "concept_ids": padded_token_ides, } return inputs, labels diff --git a/src/cehrbert/evaluations/transfer_learning_evaluation.py b/src/cehrbert/evaluations/transfer_learning_evaluation.py index e77ed08e..b922275c 100644 --- a/src/cehrbert/evaluations/transfer_learning_evaluation.py +++ b/src/cehrbert/evaluations/transfer_learning_evaluation.py @@ -1,8 +1,8 @@ from .evaluation import main from .evaluation_parse_args import create_evaluation_args -TRAINING_PERCENTAGE = 'training_percentage' -IS_TRANSFER_LEARNING = 'is_transfer_learning' +TRAINING_PERCENTAGE = "training_percentage" +IS_TRANSFER_LEARNING = "is_transfer_learning" PERCENTAGES = [0.05, 0.1, 0.2, 0.4, 0.8] if __name__ == "__main__": diff --git a/src/cehrbert/keras_transformer/bert.py b/src/cehrbert/keras_transformer/bert.py index 874ede39..94e9e340 100644 --- a/src/cehrbert/keras_transformer/bert.py +++ b/src/cehrbert/keras_transformer/bert.py @@ -15,14 +15,15 @@ """ import tensorflow as tf -from tensorflow.keras.losses import binary_crossentropy from tensorflow.keras import backend as K +from tensorflow.keras.losses import binary_crossentropy from tensorflow.keras.utils import get_custom_objects def masked_perplexity(y_true, y_pred): """ - Masked version of popular metric for evaluating performance of + Masked version of popular metric for evaluating performance of. + language modelling architectures. 
It assumes that y_pred has shape (batch_size, sequence_length, 2), containing both - the original token ids @@ -34,17 +35,16 @@ def masked_perplexity(y_true, y_pred): More info: http://cs224d.stanford.edu/lecture_notes/LectureNotes4.pdf """ y_true_value = y_true[:, :, 0] - mask = K.cast(y_true[:, :, 1], dtype='float32') + mask = K.cast(y_true[:, :, 1], dtype="float32") cross_entropy = K.sparse_categorical_crossentropy(y_true_value, y_pred) - batch_perplexities = K.exp( - K.sum(mask * cross_entropy, axis=-1) / (K.sum(mask, axis=-1) + 1e-6)) + batch_perplexities = K.exp(K.sum(mask * cross_entropy, axis=-1) / (K.sum(mask, axis=-1) + 1e-6)) return K.mean(batch_perplexities) class MaskedMeanSquaredError(object): def __call__(self, y_true, y_pred): - y_true_val = K.cast(y_true[:, :, 0], dtype='float32') - mask = K.cast(y_true[:, :, 1], dtype='float32') + y_true_val = K.cast(y_true[:, :, 0], dtype="float32") + mask = K.cast(y_true[:, :, 1], dtype="float32") num_items_masked = tf.reduce_sum(mask, axis=-1) + 1e-6 masked_mse = tf.reduce_sum(tf.square(y_true_val - y_pred) * mask, axis=-1) / num_items_masked @@ -54,7 +54,8 @@ def __call__(self, y_true, y_pred): class MaskedPenalizedSparseCategoricalCrossentropy(object): """ - Masked cross-entropy (see `masked_perplexity` for more details) + Masked cross-entropy (see `masked_perplexity` for more details). + loss function with penalized confidence. Combines two loss functions: cross-entropy and negative entropy (weighted by `penalty_weight` parameter), following paper @@ -68,47 +69,41 @@ class MaskedPenalizedSparseCategoricalCrossentropy(object): """ def __init__(self, penalty_weight: float): - self.__name__ = 'MaskedPenalizedSparseCategoricalCrossentropy' + self.__name__ = "MaskedPenalizedSparseCategoricalCrossentropy" self.penalty_weight = penalty_weight def __call__(self, y_true, y_pred): - y_true_val = K.cast(y_true[:, :, 0], dtype='float32') - mask = K.cast(y_true[:, :, 1], dtype='float32') + y_true_val = K.cast(y_true[:, :, 0], dtype="float32") + mask = K.cast(y_true[:, :, 1], dtype="float32") # masked per-sample means of each loss num_items_masked = K.sum(mask, axis=-1) + 1e-6 masked_cross_entropy = ( - K.sum(mask * K.sparse_categorical_crossentropy(y_true_val, y_pred), - axis=-1) - / num_items_masked) - masked_entropy = ( - K.sum(mask * -K.sum(y_pred * K.log(y_pred), axis=-1), axis=-1) - / num_items_masked) + K.sum(mask * K.sparse_categorical_crossentropy(y_true_val, y_pred), axis=-1) / num_items_masked + ) + masked_entropy = K.sum(mask * -K.sum(y_pred * K.log(y_pred), axis=-1), axis=-1) / num_items_masked return masked_cross_entropy - self.penalty_weight * masked_entropy def get_config(self): - return { - 'penalty_weight': self.penalty_weight - } + return {"penalty_weight": self.penalty_weight} class SequenceCrossentropy(object): def __init__(self): - self.__name__ = 'SequenceCrossentropy' + self.__name__ = "SequenceCrossentropy" def __call__(self, y_true, y_pred): - y_true_val = K.cast(y_true[:, :, 0], dtype='float32') - mask = K.cast(y_true[:, :, 1], dtype='float32') + y_true_val = K.cast(y_true[:, :, 0], dtype="float32") + mask = K.cast(y_true[:, :, 1], dtype="float32") num_items_masked = K.sum(mask, axis=-1) + 1e-6 - loss = K.sum( - binary_crossentropy(y_true_val[:, :, tf.newaxis], y_pred) * mask, - axis=-1) + loss = K.sum(binary_crossentropy(y_true_val[:, :, tf.newaxis], y_pred) * mask, axis=-1) return loss / num_items_masked -get_custom_objects().update({ - 'MaskedPenalizedSparseCategoricalCrossentropy': - 
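The masked metrics in `keras_transformer/bert.py` all rely on the same layout: `y_true` carries two channels per position, the target token id and a 0/1 mask, so losses and perplexity are averaged over masked positions only. A toy-sized `masked_perplexity` written with plain TensorFlow ops (the batch values are invented; the two-channel layout and the exp-of-mean-masked-cross-entropy formula come from the diff):

```python
import tensorflow as tf


def masked_perplexity(y_true, y_pred):
    """Perplexity over masked positions only; y_true stacks token ids
    (channel 0) and a 0/1 mask (channel 1)."""
    y_true_ids = tf.cast(y_true[:, :, 0], tf.int32)
    mask = tf.cast(y_true[:, :, 1], tf.float32)
    cross_entropy = tf.keras.losses.sparse_categorical_crossentropy(y_true_ids, y_pred)
    per_example = tf.reduce_sum(mask * cross_entropy, axis=-1) / (tf.reduce_sum(mask, axis=-1) + 1e-6)
    return tf.reduce_mean(tf.exp(per_example))


# One sequence of length 3 over a 4-token vocabulary; only the first two
# positions are masked, so the third (uniform) prediction is ignored.
y_pred = tf.constant([[[0.7, 0.1, 0.1, 0.1],
                       [0.1, 0.7, 0.1, 0.1],
                       [0.25, 0.25, 0.25, 0.25]]])
y_true = tf.constant([[[0.0, 1.0], [1.0, 1.0], [3.0, 0.0]]])
print(float(masked_perplexity(y_true, y_pred)))  # ~1.43, i.e. exp(-log 0.7)
```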
MaskedPenalizedSparseCategoricalCrossentropy, - 'masked_perplexity': masked_perplexity, - 'SequenceCrossentropy': SequenceCrossentropy -}) +get_custom_objects().update( + { + "MaskedPenalizedSparseCategoricalCrossentropy": MaskedPenalizedSparseCategoricalCrossentropy, + "masked_perplexity": masked_perplexity, + "SequenceCrossentropy": SequenceCrossentropy, + } +) diff --git a/src/cehrbert/keras_transformer/extras.py b/src/cehrbert/keras_transformer/extras.py index b42176ba..a037619f 100644 --- a/src/cehrbert/keras_transformer/extras.py +++ b/src/cehrbert/keras_transformer/extras.py @@ -1,30 +1,33 @@ """ -Tools that are not necessary for the Transformer by itself, but might be +Tools that are not necessary for the Transformer by itself, but might be. + useful in building models with it. """ + import math -import tensorflow as tf -from tensorflow.keras import activations, regularizers +import tensorflow as tf +from tensorflow.keras import activations from tensorflow.keras import backend as K - +from tensorflow.keras import regularizers from tensorflow.keras.utils import get_custom_objects class ReusableEmbedding(tf.keras.layers.Embedding): """ - A "reusable" form of the Embedding layer, which returns its + A "reusable" form of the Embedding layer, which returns its. + full embedding matrix as one of the outputs. This is necessary to guarantee correct work of Keras when the matrix is being re-used again in TiedOutputEmbedding layer. """ + def call(self, inputs): result = super().call(inputs) return [result, self.embeddings] def compute_output_shape(self, input_shape): - return [super().compute_output_shape(input_shape), - K.int_shape(self.embeddings)] + return [super().compute_output_shape(input_shape), K.int_shape(self.embeddings)] def compute_mask(self, inputs, mask=None): return [super().compute_mask(inputs, mask), None] @@ -32,7 +35,8 @@ def compute_mask(self, inputs, mask=None): class TiedOutputEmbedding(tf.keras.layers.Layer): """ - Allows to reuse the same word embedding matrix both for the input and + Allows to reuse the same word embedding matrix both for the input and. + the output layers of the network. 
This is called Weight Tying and is proven to improve performance of neural network language models, as well as decrease their number @@ -50,11 +54,16 @@ class TiedOutputEmbedding(tf.keras.layers.Layer): https://arxiv.org/abs/1611.01462 https://blog.openai.com/language-unsupervised/ """ - def __init__(self, activation=None, - add_biases=False, projection_regularizer=None, - projection_dropout: float = 0.0, - scaled_attention=False, - **kwargs): + + def __init__( + self, + activation=None, + add_biases=False, + projection_regularizer=None, + projection_dropout: float = 0.0, + scaled_attention=False, + **kwargs, + ): self.activation = activations.get(activation) self.add_biases = add_biases self.projection_regularizer = regularizers.get(projection_regularizer) @@ -68,10 +77,10 @@ def get_config(self): config, activation=activations.serialize(self.activation), add_biases=self.add_biases, - projection_regularizer=regularizers.serialize( - self.projection_regularizer), + projection_regularizer=regularizers.serialize(self.projection_regularizer), projection_dropout=self.projection_dropout, - scaled_attention=self.scaled_attention) + scaled_attention=self.scaled_attention, + ) # noinspection PyAttributeOutsideInit def build(self, input_shape): @@ -79,17 +88,19 @@ def build(self, input_shape): emb_input_dim, emb_output_dim = embedding_matrix_shape assert len(main_input_shape) == 3 self.projection = self.add_weight( - name='kernel', + name="kernel", shape=(main_input_shape[-1], emb_output_dim), - initializer='glorot_uniform', + initializer="glorot_uniform", regularizer=self.projection_regularizer, - trainable=True) + trainable=True, + ) if self.add_biases: self.biases = self.add_weight( - name='biases', + name="biases", shape=(emb_output_dim,), - initializer='zeros', - trainable=True) + initializer="zeros", + trainable=True, + ) return super().build(input_shape) def call(self, inputs, **kwargs): @@ -97,16 +108,15 @@ def call(self, inputs, **kwargs): input_shape_tensor = K.shape(main_input) last_input_dim = K.int_shape(main_input)[-1] emb_input_dim, emb_output_dim = K.int_shape(embedding_matrix) - projected = K.dot(K.reshape(main_input, (-1, last_input_dim)), - self.projection) + projected = K.dot(K.reshape(main_input, (-1, last_input_dim)), self.projection) if self.add_biases: - projected = K.bias_add(projected, self.biases, - data_format='channels_last') + projected = K.bias_add(projected, self.biases, data_format="channels_last") if 0 < self.projection_dropout < 1: projected = K.in_train_phase( lambda: K.dropout(projected, self.projection_dropout), projected, - training=kwargs.get('training')) + training=kwargs.get("training"), + ) attention = K.dot(projected, K.transpose(embedding_matrix)) if self.scaled_attention: # scaled dot-product attention, described in @@ -115,9 +125,8 @@ def call(self, inputs, **kwargs): attention = attention / sqrt_d result = K.reshape( self.activation(attention), - (input_shape_tensor[0], - input_shape_tensor[1], - emb_input_dim)) + (input_shape_tensor[0], input_shape_tensor[1], emb_input_dim), + ) return result def compute_output_shape(self, input_shape): @@ -126,7 +135,9 @@ def compute_output_shape(self, input_shape): return main_input_shape[0], main_input_shape[1], emb_input_dim -get_custom_objects().update({ - 'ReusableEmbedding': ReusableEmbedding, - 'TiedOutputEmbedding': TiedOutputEmbedding, -}) +get_custom_objects().update( + { + "ReusableEmbedding": ReusableEmbedding, + "TiedOutputEmbedding": TiedOutputEmbedding, + } +) diff --git 
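Editor's note: the `ReusableEmbedding` / `TiedOutputEmbedding` pair reformatted above implements weight tying: the embedding layer returns its full matrix alongside the lookups, and the output layer projects hidden states back onto that same matrix. A minimal wiring sketch follows, assuming the package is importable as `cehrbert` per the `src/cehrbert/...` layout in this diff; the sizes are arbitrary toy values.

```python
# Minimal weight-tying sketch using the two layers from keras_transformer/extras.py.
import tensorflow as tf

from cehrbert.keras_transformer.extras import ReusableEmbedding, TiedOutputEmbedding

max_seq_length, vocab_size, embedding_size = 16, 100, 32

token_ids = tf.keras.layers.Input(shape=(max_seq_length,), dtype="int32", name="token_ids")

# Returns both the embedded sequence and the full embedding matrix
embedded, embedding_matrix = ReusableEmbedding(
    vocab_size, embedding_size, input_length=max_seq_length, name="tied_embeddings"
)(token_ids)

# ... an encoder would normally go here; identity keeps the sketch short ...
hidden = embedded

# Project hidden states back onto the same embedding matrix (weight tying)
logits = TiedOutputEmbedding(projection_dropout=0.1, name="tied_logits")([hidden, embedding_matrix])
predictions = tf.keras.layers.Softmax(name="token_predictions")(logits)

model = tf.keras.Model(inputs=token_ids, outputs=predictions)
model.summary()
```

This is the same pattern `bert_models.py` uses further down in this diff, where the encoder sits between the embedding lookup and the tied output projection.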
a/src/cehrbert/keras_transformer/position.py b/src/cehrbert/keras_transformer/position.py index 13c1e5e0..bb628f24 100644 --- a/src/cehrbert/keras_transformer/position.py +++ b/src/cehrbert/keras_transformer/position.py @@ -5,6 +5,7 @@ class TransformerCoordinateEmbedding(tf.keras.layers.Layer): """ Represents trainable positional embeddings for the Transformer model: + 1. word position embeddings - one for each position in the sequence. 2. depth embeddings - one for each block of the model Calling the layer with the Transformer's input will return a new input @@ -17,7 +18,7 @@ def __init__(self, max_transformer_depth: int, **kwargs): def get_config(self): config = super().get_config() - config['max_transformer_depth'] = self.max_depth + config["max_transformer_depth"] = self.max_depth return config # noinspection PyAttributeOutsideInit @@ -25,27 +26,30 @@ def build(self, input_shape): sequence_length, d_model = input_shape[-2:] self.word_position_embeddings = self.add_weight( shape=(sequence_length, d_model), - initializer='uniform', - name='word_position_embeddings', - trainable=True) + initializer="uniform", + name="word_position_embeddings", + trainable=True, + ) self.depth_embeddings = self.add_weight( shape=(self.max_depth, d_model), - initializer='uniform', - name='depth_position_embeddings', - trainable=True) + initializer="uniform", + name="depth_position_embeddings", + trainable=True, + ) super().build(input_shape) def call(self, inputs, **kwargs): - depth = kwargs.get('step') + depth = kwargs.get("step") if depth is None: - raise ValueError("Please, provide current Transformer's step" - "using 'step' keyword argument.") + raise ValueError("Please, provide current Transformer's step" "using 'step' keyword argument.") result = inputs + self.word_position_embeddings if depth is not None: result = result + self.depth_embeddings[depth] return result -get_custom_objects().update({ - 'TransformerCoordinateEmbedding': TransformerCoordinateEmbedding, -}) +get_custom_objects().update( + { + "TransformerCoordinateEmbedding": TransformerCoordinateEmbedding, + } +) diff --git a/src/cehrbert/med_extension/schema_extension.py b/src/cehrbert/med_extension/schema_extension.py index 5ac47eb7..13c20c73 100644 --- a/src/cehrbert/med_extension/schema_extension.py +++ b/src/cehrbert/med_extension/schema_extension.py @@ -1,15 +1,19 @@ -from typing import TypedDict, List, Mapping, Any, Union, Optional -from typing_extensions import NotRequired import datetime +from typing import Any, List, Mapping, Optional, TypedDict, Union + +from typing_extensions import NotRequired -Event = TypedDict('Event', { - 'time': NotRequired[datetime.datetime], - 'code': str, - 'text_value': NotRequired[Optional[str]], - 'numeric_value': NotRequired[Optional[float]], - 'datetime_value': NotRequired[datetime.datetime], - 'properties': NotRequired[Optional[Mapping[str, Any]]], -}) +Event = TypedDict( + "Event", + { + "time": NotRequired[datetime.datetime], + "code": str, + "text_value": NotRequired[Optional[str]], + "numeric_value": NotRequired[Optional[float]], + "datetime_value": NotRequired[datetime.datetime], + "properties": NotRequired[Optional[Mapping[str, Any]]], + }, +) class Visit(TypedDict): diff --git a/src/cehrbert/models/bert_models.py b/src/cehrbert/models/bert_models.py index 01a7f540..2df4e126 100644 --- a/src/cehrbert/models/bert_models.py +++ b/src/cehrbert/models/bert_models.py @@ -1,31 +1,33 @@ import tensorflow as tf from ..keras_transformer.extras import ReusableEmbedding, TiedOutputEmbedding - +from 
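Editor's note: the `schema_extension.py` hunk above only reflows the MEDS `Event` TypedDict. A short constructed example (all values invented) shows which keys are required versus `NotRequired`.

```python
# Illustrative Event instance for the TypedDict above. Only `code` is required;
# every other key is NotRequired. The values are hypothetical.
import datetime

from cehrbert.med_extension.schema_extension import Event

event: Event = {
    "time": datetime.datetime(2020, 1, 15, 9, 30),
    "code": "SNOMED/38341003",       # hypothetical concept code
    "numeric_value": 142.0,          # e.g. a systolic blood pressure reading
    "properties": {"unit": "mmHg"},  # free-form extra attributes
}

print(event["code"], event.get("numeric_value"))
```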
..utils.model_utils import create_concept_mask from .layers.custom_layers import ( - VisitEmbeddingLayer, Encoder, - PositionalEncodingLayer, TimeEmbeddingLayer, - ConceptValueTransformationLayer + ConceptValueTransformationLayer, + Encoder, + PositionalEncodingLayer, + TimeEmbeddingLayer, + VisitEmbeddingLayer, ) -from ..utils.model_utils import create_concept_mask def transformer_bert_model( - max_seq_length: int, - vocab_size: int, - embedding_size: int, - depth: int, - num_heads: int, - transformer_dropout: float = 0.1, - embedding_dropout: float = 0.6, - l2_reg_penalty: float = 1e-4, - use_time_embedding: bool = False, - time_embeddings_size: int = 16, - use_behrt: bool = False, - include_prolonged_length_stay: bool = False + max_seq_length: int, + vocab_size: int, + embedding_size: int, + depth: int, + num_heads: int, + transformer_dropout: float = 0.1, + embedding_dropout: float = 0.6, + l2_reg_penalty: float = 1e-4, + use_time_embedding: bool = False, + time_embeddings_size: int = 16, + use_behrt: bool = False, + include_prolonged_length_stay: bool = False, ): """ - Builds a BERT-based model (Bidirectional Encoder Representations + Builds a BERT-based model (Bidirectional Encoder Representations. + from Transformers) following paper "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding" (https://arxiv.org/abs/1810.04805) @@ -35,156 +37,108 @@ def transformer_bert_model( or a vanilla Transformer (2017) to do the job (the original paper uses vanilla Transformer). """ - masked_concept_ids = tf.keras.layers.Input( - shape=(max_seq_length,), - dtype='int32', - name='masked_concept_ids' - ) + masked_concept_ids = tf.keras.layers.Input(shape=(max_seq_length,), dtype="int32", name="masked_concept_ids") - visit_segments = tf.keras.layers.Input( - shape=(max_seq_length,), - dtype='int32', - name='visit_segments' - ) + visit_segments = tf.keras.layers.Input(shape=(max_seq_length,), dtype="int32", name="visit_segments") - visit_concept_orders = tf.keras.layers.Input( - shape=(max_seq_length,), - dtype='int32', - name='visit_concept_orders' - ) + visit_concept_orders = tf.keras.layers.Input(shape=(max_seq_length,), dtype="int32", name="visit_concept_orders") - mask = tf.keras.layers.Input( - shape=(max_seq_length,), - dtype='int32', - name='mask' - ) + mask = tf.keras.layers.Input(shape=(max_seq_length,), dtype="int32", name="mask") - concept_value_masks = tf.keras.layers.Input( - shape=(max_seq_length,), - dtype='int32', - name='concept_value_masks' - ) - concept_values = tf.keras.layers.Input( - shape=(max_seq_length,), - dtype='float32', - name='concept_values' - ) + concept_value_masks = tf.keras.layers.Input(shape=(max_seq_length,), dtype="int32", name="concept_value_masks") + concept_values = tf.keras.layers.Input(shape=(max_seq_length,), dtype="float32", name="concept_values") concept_mask = create_concept_mask(mask, max_seq_length) default_inputs = [ - masked_concept_ids, visit_segments, - visit_concept_orders, mask, - concept_value_masks, concept_values + masked_concept_ids, + visit_segments, + visit_concept_orders, + mask, + concept_value_masks, + concept_values, ] - l2_regularizer = (tf.keras.regularizers.l2(l2_reg_penalty) if l2_reg_penalty else None) + l2_regularizer = tf.keras.regularizers.l2(l2_reg_penalty) if l2_reg_penalty else None embedding_layer = ReusableEmbedding( - vocab_size, embedding_size, + vocab_size, + embedding_size, input_length=max_seq_length, - name='concept_embeddings', - embeddings_regularizer=l2_regularizer + 
name="concept_embeddings", + embeddings_regularizer=l2_regularizer, ) - visit_segment_layer = VisitEmbeddingLayer( - visit_order_size=3, - embedding_size=embedding_size, - name='visit_segment' - ) + visit_segment_layer = VisitEmbeddingLayer(visit_order_size=3, embedding_size=embedding_size, name="visit_segment") concept_value_transformation_layer = ConceptValueTransformationLayer( - embedding_size=embedding_size, - name='concept_value_transformation_layer' + embedding_size=embedding_size, name="concept_value_transformation_layer" ) encoder = Encoder( - name='encoder', + name="encoder", num_layers=depth, d_model=embedding_size, num_heads=num_heads, - dropout_rate=transformer_dropout) + dropout_rate=transformer_dropout, + ) output_layer = TiedOutputEmbedding( projection_regularizer=l2_regularizer, projection_dropout=embedding_dropout, add_biases=use_time_embedding, - name='concept_prediction_logits') - - softmax_layer = tf.keras.layers.Softmax( - name='concept_predictions' + name="concept_prediction_logits", ) - next_step_input, embedding_matrix = embedding_layer( - masked_concept_ids - ) + softmax_layer = tf.keras.layers.Softmax(name="concept_predictions") + + next_step_input, embedding_matrix = embedding_layer(masked_concept_ids) # Transform the concept embeddings by combining their concept embeddings with the # corresponding val next_step_input = concept_value_transformation_layer( concept_embeddings=next_step_input, concept_values=concept_values, - concept_value_masks=concept_value_masks + concept_value_masks=concept_value_masks, ) if use_behrt: - ages = tf.keras.layers.Input(shape=(max_seq_length,), dtype='int32', - name='ages') + ages = tf.keras.layers.Input(shape=(max_seq_length,), dtype="int32", name="ages") default_inputs.extend([ages]) age_embedding_layer = TimeEmbeddingLayer(embedding_size=embedding_size) next_step_input = next_step_input + age_embedding_layer(ages) - positional_encoding_layer = PositionalEncodingLayer( - embedding_size=embedding_size - ) + positional_encoding_layer = PositionalEncodingLayer(embedding_size=embedding_size) next_step_input += positional_encoding_layer(visit_concept_orders) elif use_time_embedding: - time_stamps = tf.keras.layers.Input( - shape=(max_seq_length,), - dtype='int32', - name='time_stamps') - ages = tf.keras.layers.Input( - shape=(max_seq_length,), - dtype='int32', - name='ages') + time_stamps = tf.keras.layers.Input(shape=(max_seq_length,), dtype="int32", name="time_stamps") + ages = tf.keras.layers.Input(shape=(max_seq_length,), dtype="int32", name="ages") default_inputs.extend([time_stamps, ages]) # # define the time embedding layer for absolute time stamps (since 1970) - time_embedding_layer = TimeEmbeddingLayer( - embedding_size=time_embeddings_size, - name='time_embedding_layer') + time_embedding_layer = TimeEmbeddingLayer(embedding_size=time_embeddings_size, name="time_embedding_layer") # define the age embedding layer for the age w.r.t the medical record - age_embedding_layer = TimeEmbeddingLayer( - embedding_size=time_embeddings_size, - name='age_embedding_layer') + age_embedding_layer = TimeEmbeddingLayer(embedding_size=time_embeddings_size, name="age_embedding_layer") positional_encoding_layer = PositionalEncodingLayer( - embedding_size=time_embeddings_size, - name='positional_encoding_layer' + embedding_size=time_embeddings_size, name="positional_encoding_layer" ) - scale_back_concat_layer = tf.keras.layers.Dense( - embedding_size, - activation='tanh', - name='scale_pat_seq_layer' - ) - time_embeddings = time_embedding_layer( - 
time_stamps - ) - age_embeddings = age_embedding_layer( - ages - ) - positional_encodings = positional_encoding_layer( - visit_concept_orders - ) + scale_back_concat_layer = tf.keras.layers.Dense(embedding_size, activation="tanh", name="scale_pat_seq_layer") + time_embeddings = time_embedding_layer(time_stamps) + age_embeddings = age_embedding_layer(ages) + positional_encodings = positional_encoding_layer(visit_concept_orders) next_step_input = scale_back_concat_layer( tf.concat( - [next_step_input, time_embeddings, age_embeddings, positional_encodings], - axis=-1 + [ + next_step_input, + time_embeddings, + age_embeddings, + positional_encodings, + ], + axis=-1, ) ) else: - positional_encoding_layer = PositionalEncodingLayer( - embedding_size=embedding_size - ) + positional_encoding_layer = PositionalEncodingLayer(embedding_size=embedding_size) next_step_input += positional_encoding_layer(visit_concept_orders) # Building a Vanilla Transformer (described in @@ -193,26 +147,23 @@ def transformer_bert_model( next_step_input, _ = encoder(next_step_input, concept_mask) - concept_predictions = softmax_layer( - output_layer([next_step_input, embedding_matrix])) + concept_predictions = softmax_layer(output_layer([next_step_input, embedding_matrix])) outputs = [concept_predictions] if include_prolonged_length_stay: - mask_embeddings = tf.tile(tf.expand_dims(mask == 0, -1, name='bert_expand_prolonged'), - [1, 1, embedding_size], name='bert_tile_prolonged') - mask_embeddings = tf.cast(mask_embeddings, dtype=tf.float32, name='bert_cast_prolonged') - contextualized_embeddings = tf.math.multiply(next_step_input, mask_embeddings, - name='bert_multiply_prolonged') + mask_embeddings = tf.tile( + tf.expand_dims(mask == 0, -1, name="bert_expand_prolonged"), + [1, 1, embedding_size], + name="bert_tile_prolonged", + ) + mask_embeddings = tf.cast(mask_embeddings, dtype=tf.float32, name="bert_cast_prolonged") + contextualized_embeddings = tf.math.multiply(next_step_input, mask_embeddings, name="bert_multiply_prolonged") summed_contextualized_embeddings = tf.reduce_sum(contextualized_embeddings, axis=-1) - prolonged_length_stay_prediction = tf.keras.layers.Dense(1, - name='prolonged_length_stay', - activation='sigmoid') + prolonged_length_stay_prediction = tf.keras.layers.Dense(1, name="prolonged_length_stay", activation="sigmoid") outputs.append(prolonged_length_stay_prediction(summed_contextualized_embeddings)) - model = tf.keras.Model( - inputs=default_inputs, - outputs=outputs) + model = tf.keras.Model(inputs=default_inputs, outputs=outputs) return model diff --git a/src/cehrbert/models/bert_models_visit_prediction.py b/src/cehrbert/models/bert_models_visit_prediction.py index 341f1628..be252ad8 100644 --- a/src/cehrbert/models/bert_models_visit_prediction.py +++ b/src/cehrbert/models/bert_models_visit_prediction.py @@ -1,30 +1,33 @@ import tensorflow as tf from ..keras_transformer.extras import ReusableEmbedding, TiedOutputEmbedding - from .layers.custom_layers import ( - VisitEmbeddingLayer, Encoder, DecoderLayer, - PositionalEncodingLayer, TimeEmbeddingLayer, - ConceptValueTransformationLayer + ConceptValueTransformationLayer, + DecoderLayer, + Encoder, + PositionalEncodingLayer, + TimeEmbeddingLayer, + VisitEmbeddingLayer, ) def transformer_bert_model_visit_prediction( - max_seq_length: int, - concept_vocab_size: int, - visit_vocab_size: int, - embedding_size: int, - depth: int, - num_heads: int, - transformer_dropout: float = 0.1, - embedding_dropout: float = 0.6, - l2_reg_penalty: float = 1e-4, - 
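Editor's note: `transformer_bert_model` above wires the tied concept embeddings, visit/positional/time embeddings and the encoder into the pre-training graph. A toy instantiation (deliberately small hyperparameters, not the values used for real training; import path assumed from the `src/cehrbert` layout) shows the inputs and outputs it exposes.

```python
# Toy instantiation of the pre-training model defined above.
from cehrbert.models.bert_models import transformer_bert_model

model = transformer_bert_model(
    max_seq_length=128,
    vocab_size=20_000,
    embedding_size=128,
    depth=2,
    num_heads=4,
    transformer_dropout=0.1,
    use_time_embedding=True,
    time_embeddings_size=16,
)

# With use_time_embedding=True the model expects, in order:
# masked_concept_ids, visit_segments, visit_concept_orders, mask,
# concept_value_masks, concept_values, time_stamps, ages
print([i.name for i in model.inputs])
print([o.name for o in model.outputs])  # concept_predictions softmax
```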
use_time_embedding: bool = False, - use_behrt: bool = False, - time_embeddings_size: int = 16 + max_seq_length: int, + concept_vocab_size: int, + visit_vocab_size: int, + embedding_size: int, + depth: int, + num_heads: int, + transformer_dropout: float = 0.1, + embedding_dropout: float = 0.6, + l2_reg_penalty: float = 1e-4, + use_time_embedding: bool = False, + use_behrt: bool = False, + time_embeddings_size: int = 16, ): """ - Builds a BERT-based model (Bidirectional Encoder Representations + Builds a BERT-based model (Bidirectional Encoder Representations. + from Transformers) following paper "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding" (https://arxiv.org/abs/1810.04805) @@ -34,111 +37,80 @@ def transformer_bert_model_visit_prediction( or a vanilla Transformer (2017) to do the job (the original paper uses vanilla Transformer). """ - masked_concept_ids = tf.keras.Input( - shape=[max_seq_length], - dtype='int32', - name='masked_concept_ids' - ) - visit_segments = tf.keras.Input( - shape=[max_seq_length], - dtype='int32', - name='visit_segments' - ) - visit_concept_orders = tf.keras.Input( - shape=[max_seq_length], - dtype='int32', - name='visit_concept_orders' - ) - mask = tf.keras.Input( - shape=[max_seq_length], - dtype='int32', - name='mask' - ) - masked_visit_concepts = tf.keras.Input( - shape=[max_seq_length], - dtype='int32', - name='masked_visit_concepts' - ) - mask_visit = tf.keras.Input( - shape=[max_seq_length], - dtype='int32', - name='mask_visit' - ) - concept_value_masks = tf.keras.Input( - shape=[max_seq_length], - dtype='int32', - name='concept_value_masks' - ) - concept_values = tf.keras.Input( - shape=[max_seq_length], - dtype='float32', - name='concept_values' - ) + masked_concept_ids = tf.keras.Input(shape=[max_seq_length], dtype="int32", name="masked_concept_ids") + visit_segments = tf.keras.Input(shape=[max_seq_length], dtype="int32", name="visit_segments") + visit_concept_orders = tf.keras.Input(shape=[max_seq_length], dtype="int32", name="visit_concept_orders") + mask = tf.keras.Input(shape=[max_seq_length], dtype="int32", name="mask") + masked_visit_concepts = tf.keras.Input(shape=[max_seq_length], dtype="int32", name="masked_visit_concepts") + mask_visit = tf.keras.Input(shape=[max_seq_length], dtype="int32", name="mask_visit") + concept_value_masks = tf.keras.Input(shape=[max_seq_length], dtype="int32", name="concept_value_masks") + concept_values = tf.keras.Input(shape=[max_seq_length], dtype="float32", name="concept_values") default_inputs = [ - masked_concept_ids, visit_segments, - visit_concept_orders, mask, - masked_visit_concepts, mask_visit, - concept_value_masks, concept_values + masked_concept_ids, + visit_segments, + visit_concept_orders, + mask, + masked_visit_concepts, + mask_visit, + concept_value_masks, + concept_values, ] - l2_regularizer = (tf.keras.regularizers.l2(l2_reg_penalty) if l2_reg_penalty else None) + l2_regularizer = tf.keras.regularizers.l2(l2_reg_penalty) if l2_reg_penalty else None concept_embedding_layer = ReusableEmbedding( - concept_vocab_size, embedding_size, + concept_vocab_size, + embedding_size, input_length=max_seq_length, - name='concept_embeddings', + name="concept_embeddings", # Regularization is based on paper "A Comparative Study on # Regularization Strategies for Embedding-based Neural Networks" # https://arxiv.org/pdf/1508.03721.pdf - embeddings_regularizer=l2_regularizer + embeddings_regularizer=l2_regularizer, ) visit_embedding_layer = ReusableEmbedding( - visit_vocab_size, 
embedding_size, + visit_vocab_size, + embedding_size, input_length=max_seq_length, - name='visit_embeddings', - embeddings_regularizer=l2_regularizer + name="visit_embeddings", + embeddings_regularizer=l2_regularizer, ) visit_segment_layer = VisitEmbeddingLayer( - visit_order_size=3, - embedding_size=embedding_size, - name='visit_segment_layer' + visit_order_size=3, embedding_size=embedding_size, name="visit_segment_layer" ) concept_value_transformation_layer = ConceptValueTransformationLayer( - embedding_size=embedding_size, - name='concept_value_transformation_layer' + embedding_size=embedding_size, name="concept_value_transformation_layer" ) encoder = Encoder( - name='encoder', + name="encoder", num_layers=depth, d_model=embedding_size, num_heads=num_heads, - dropout_rate=transformer_dropout + dropout_rate=transformer_dropout, ) - decoder_layer = DecoderLayer( - d_model=embedding_size, - num_heads=num_heads, - dff=512 - ) + decoder_layer = DecoderLayer(d_model=embedding_size, num_heads=num_heads, dff=512) output_layer_1 = TiedOutputEmbedding( projection_regularizer=l2_regularizer, projection_dropout=embedding_dropout, - name='concept_prediction_logits') + name="concept_prediction_logits", + ) output_layer_2 = TiedOutputEmbedding( projection_regularizer=l2_regularizer, projection_dropout=embedding_dropout, - name='visit_prediction_logits') + name="visit_prediction_logits", + ) - concept_softmax_layer = tf.keras.layers.Softmax(name='concept_predictions') + concept_softmax_layer = tf.keras.layers.Softmax(name="concept_predictions") - visit_softmax_layer = tf.keras.layers.Softmax(name='visit_predictions') + visit_softmax_layer = tf.keras.layers.Softmax(name="visit_predictions") # embeddings for encoder input input_for_encoder, concept_embedding_matrix = concept_embedding_layer(masked_concept_ids) @@ -148,7 +120,7 @@ def transformer_bert_model_visit_prediction( input_for_encoder = concept_value_transformation_layer( concept_embeddings=input_for_encoder, concept_values=concept_values, - concept_value_masks=concept_value_masks + concept_value_masks=concept_value_masks, ) # embeddings for decoder input @@ -159,47 +131,35 @@ def transformer_bert_model_visit_prediction( input_for_encoder = visit_segment_layer([visit_segments, input_for_encoder]) if use_behrt: - ages = tf.keras.layers.Input(shape=(max_seq_length,), dtype='int32', - name='ages') + ages = tf.keras.layers.Input(shape=(max_seq_length,), dtype="int32", name="ages") default_inputs.extend([ages]) age_embedding_layer = TimeEmbeddingLayer(embedding_size=embedding_size) input_for_encoder = input_for_encoder + age_embedding_layer(ages) - positional_encoding_layer = PositionalEncodingLayer( - embedding_size=embedding_size - ) + positional_encoding_layer = PositionalEncodingLayer(embedding_size=embedding_size) input_for_encoder += positional_encoding_layer(visit_concept_orders) elif use_time_embedding: # additional inputs with time embeddings - time_stamps = tf.keras.layers.Input(shape=(max_seq_length,), - dtype='int32', - name='time_stamps') - ages = tf.keras.layers.Input(shape=(max_seq_length,), - dtype='int32', - name='ages') + time_stamps = tf.keras.layers.Input(shape=(max_seq_length,), dtype="int32", name="time_stamps") + ages = tf.keras.layers.Input(shape=(max_seq_length,), dtype="int32", name="ages") default_inputs.extend([time_stamps, ages]) # # define the time embedding layer for absolute time stamps (since 1970) - time_embedding_layer = TimeEmbeddingLayer(embedding_size=time_embeddings_size, - name='time_embedding_layer') + 
time_embedding_layer = TimeEmbeddingLayer(embedding_size=time_embeddings_size, name="time_embedding_layer") # define the age embedding layer for the age w.r.t the medical record - age_embedding_layer = TimeEmbeddingLayer(embedding_size=time_embeddings_size, - name='age_embedding_layer') + age_embedding_layer = TimeEmbeddingLayer(embedding_size=time_embeddings_size, name="age_embedding_layer") positional_encoding_layer = PositionalEncodingLayer( - embedding_size=embedding_size, - name='positional_encoding_layer' + embedding_size=embedding_size, name="positional_encoding_layer" ) # dense layer for rescale the patient sequence embeddings back to the original size scale_back_patient_seq_concat_layer = tf.keras.layers.Dense( - embedding_size, - activation='tanh', - name='scale_pat_seq_layer') + embedding_size, activation="tanh", name="scale_pat_seq_layer" + ) # dense layer for rescale the visit sequence embeddings back to the original size scale_back_visit_seq_concat_layer = tf.keras.layers.Dense( - embedding_size, - activation='tanh', - name='scale_visit_seq_layer') + embedding_size, activation="tanh", name="scale_visit_seq_layer" + ) time_embeddings = time_embedding_layer(time_stamps) age_embeddings = age_embedding_layer(ages) @@ -207,20 +167,25 @@ def transformer_bert_model_visit_prediction( input_for_encoder = scale_back_patient_seq_concat_layer( tf.concat( - [input_for_encoder, time_embeddings, age_embeddings, positional_encodings], + [ + input_for_encoder, + time_embeddings, + age_embeddings, + positional_encodings, + ], axis=-1, - name='concat_for_encoder') + name="concat_for_encoder", + ) ) input_for_decoder = scale_back_visit_seq_concat_layer( tf.concat( [input_for_decoder, time_embeddings, age_embeddings], axis=-1, - name='concat_for_decoder') + name="concat_for_decoder", + ) ) else: - positional_encoding_layer = PositionalEncodingLayer( - embedding_size=embedding_size - ) + positional_encoding_layer = PositionalEncodingLayer(embedding_size=embedding_size) input_for_encoder += positional_encoding_layer(visit_concept_orders) input_for_encoder, att_weights = encoder( @@ -228,21 +193,16 @@ def transformer_bert_model_visit_prediction( mask[:, tf.newaxis, :], ) - concept_predictions = concept_softmax_layer( - output_layer_1([input_for_encoder, concept_embedding_matrix])) + concept_predictions = concept_softmax_layer(output_layer_1([input_for_encoder, concept_embedding_matrix])) decoder_output, _, _ = decoder_layer( input_for_decoder, input_for_encoder, mask[:, tf.newaxis, :], - mask_visit[:, tf.newaxis, :] - ) - visit_predictions = visit_softmax_layer( - output_layer_2([decoder_output, visit_embedding_matrix]) + mask_visit[:, tf.newaxis, :], ) + visit_predictions = visit_softmax_layer(output_layer_2([decoder_output, visit_embedding_matrix])) - model = tf.keras.Model( - inputs=default_inputs, - outputs=[concept_predictions, visit_predictions]) + model = tf.keras.Model(inputs=default_inputs, outputs=[concept_predictions, visit_predictions]) return model diff --git a/src/cehrbert/models/evaluation_models.py b/src/cehrbert/models/evaluation_models.py index 707cf477..897aac97 100644 --- a/src/cehrbert/models/evaluation_models.py +++ b/src/cehrbert/models/evaluation_models.py @@ -1,44 +1,38 @@ import tensorflow as tf - from tensorflow.keras.initializers import Constant from tensorflow.keras.models import Model -from .layers.custom_layers import get_custom_objects -from .layers.custom_layers import ConvolutionBertLayer from .bert_models_visit_prediction import 
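Editor's note: one plausible way to connect the dual-output visit-prediction model above to the masked, confidence-penalized loss from `keras_transformer/bert.py` is sketched below. The penalty weight, optimizer settings and toy sizes are illustrative; the project's actual training scripts may compile the model differently.

```python
# Sketch: compiling the visit-prediction model with the masked loss defined earlier.
import tensorflow as tf

from cehrbert.keras_transformer.bert import MaskedPenalizedSparseCategoricalCrossentropy
from cehrbert.models.bert_models_visit_prediction import transformer_bert_model_visit_prediction

model = transformer_bert_model_visit_prediction(
    max_seq_length=128,
    concept_vocab_size=20_000,
    visit_vocab_size=50,
    embedding_size=128,
    depth=2,
    num_heads=4,
)

loss = MaskedPenalizedSparseCategoricalCrossentropy(penalty_weight=0.1)
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    # keyed by the softmax layer names defined in the model above
    loss={"concept_predictions": loss, "visit_predictions": loss},
)
```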
transformer_bert_model_visit_prediction +from .layers.custom_layers import ConvolutionBertLayer, get_custom_objects def create_bi_lstm_model( - max_seq_length, - vocab_size, - embedding_size, - concept_embeddings, - dropout_rate=0.2, - lstm_unit=128, - activation='relu', - is_bi_directional=True + max_seq_length, + vocab_size, + embedding_size, + concept_embeddings, + dropout_rate=0.2, + lstm_unit=128, + activation="relu", + is_bi_directional=True, ): - age_of_visit_input = tf.keras.layers.Input(name='age', shape=(1,)) + age_of_visit_input = tf.keras.layers.Input(name="age", shape=(1,)) - age_batch_norm_layer = tf.keras.layers.BatchNormalization(name='age_batch_norm_layer') + age_batch_norm_layer = tf.keras.layers.BatchNormalization(name="age_batch_norm_layer") normalized_index_age = age_batch_norm_layer(age_of_visit_input) - concept_ids = tf.keras.layers.Input(shape=(max_seq_length,), dtype='int32', name='concept_ids') + concept_ids = tf.keras.layers.Input(shape=(max_seq_length,), dtype="int32", name="concept_ids") if concept_embeddings is not None: embedding_layer = tf.keras.layers.Embedding( vocab_size, embedding_size, embeddings_initializer=Constant(concept_embeddings), - mask_zero=True + mask_zero=True, ) else: - embedding_layer = tf.keras.layers.Embedding( - vocab_size, - embedding_size, - mask_zero=True - ) + embedding_layer = tf.keras.layers.Embedding(vocab_size, embedding_size, mask_zero=True) bi_lstm_layer = tf.keras.layers.LSTM(lstm_unit) @@ -51,7 +45,7 @@ def create_bi_lstm_model( dropout_dense_layer = tf.keras.layers.Dropout(dropout_rate) - output_layer = tf.keras.layers.Dense(1, activation='sigmoid') + output_layer = tf.keras.layers.Dense(1, activation="sigmoid") next_input = embedding_layer(concept_ids) @@ -63,141 +57,147 @@ def create_bi_lstm_model( output = output_layer(next_input) - model = Model(inputs=[concept_ids, age_of_visit_input], outputs=output, name='Vanilla_BI_LSTM') + model = Model(inputs=[concept_ids, age_of_visit_input], outputs=output, name="Vanilla_BI_LSTM") return model def create_vanilla_feed_forward_model(vanilla_bert_model_path): """ - BERT + Feedforward model for binary prediction + BERT + Feedforward model for binary prediction. 
+ :param vanilla_bert_model_path: :return: """ - age_at_index_date = tf.keras.layers.Input(name='age', shape=(1,)) + age_at_index_date = tf.keras.layers.Input(name="age", shape=(1,)) - vanilla_bert_model = tf.keras.models.load_model(vanilla_bert_model_path, - custom_objects=dict(**get_custom_objects())) - bert_inputs = [i for i in vanilla_bert_model.inputs if - 'visit' not in i.name or ('visit_segment' in i.name - or 'visit_concept_order' in i.name)] + vanilla_bert_model = tf.keras.models.load_model( + vanilla_bert_model_path, custom_objects=dict(**get_custom_objects()) + ) + bert_inputs = [ + i + for i in vanilla_bert_model.inputs + if "visit" not in i.name or ("visit_segment" in i.name or "visit_concept_order" in i.name) + ] - contextualized_embeddings, _ = vanilla_bert_model.get_layer('encoder').output + contextualized_embeddings, _ = vanilla_bert_model.get_layer("encoder").output _, _, embedding_size = contextualized_embeddings.get_shape().as_list() - mask_input = [i for i in bert_inputs if - 'mask' in i.name and 'concept' not in i.name][0] - mask_embeddings = tf.tile(tf.expand_dims(mask_input == 0, -1, name='bert_expand_ff'), - [1, 1, embedding_size], name='bert_tile_ff') - contextualized_embeddings = tf.math.multiply(contextualized_embeddings, - tf.cast(mask_embeddings, dtype=tf.float32, - name='bert_cast_ff'), - name='bert_multiply_ff') + mask_input = [i for i in bert_inputs if "mask" in i.name and "concept" not in i.name][0] + mask_embeddings = tf.tile( + tf.expand_dims(mask_input == 0, -1, name="bert_expand_ff"), + [1, 1, embedding_size], + name="bert_tile_ff", + ) + contextualized_embeddings = tf.math.multiply( + contextualized_embeddings, + tf.cast(mask_embeddings, dtype=tf.float32, name="bert_cast_ff"), + name="bert_multiply_ff", + ) - output_layer = tf.keras.layers.Dense(1, name='prediction', activation='sigmoid') + output_layer = tf.keras.layers.Dense(1, name="prediction", activation="sigmoid") output = output_layer(tf.reduce_sum(contextualized_embeddings, axis=-2)) - lstm_with_vanilla_bert = Model(inputs=bert_inputs + [age_at_index_date], - outputs=output, name='Vanilla_BERT_PLUS_BI_LSTM') + lstm_with_vanilla_bert = Model( + inputs=bert_inputs + [age_at_index_date], + outputs=output, + name="Vanilla_BERT_PLUS_BI_LSTM", + ) return lstm_with_vanilla_bert def create_sliding_bert_model(model_path, max_seq_length, context_window, stride): - age_at_index_date = tf.keras.layers.Input(name='age', shape=(1,)) + age_at_index_date = tf.keras.layers.Input(name="age", shape=(1,)) - age_batch_norm_layer = tf.keras.layers.BatchNormalization(name='age_batch_norm_layer') + age_batch_norm_layer = tf.keras.layers.BatchNormalization(name="age_batch_norm_layer") normalized_index_age = age_batch_norm_layer(age_at_index_date) - concept_ids = tf.keras.layers.Input(shape=(max_seq_length,), dtype='int32', - name='concept_ids') - visit_segments = tf.keras.layers.Input(shape=(max_seq_length,), dtype='int32', - name='visit_segments') - time_stamps = tf.keras.layers.Input(shape=(max_seq_length,), - dtype='int32', - name='time_stamps') - visit_concept_orders = tf.keras.layers.Input( - shape=(max_seq_length,), - dtype='int32', - name='visit_concept_orders') - ages = tf.keras.layers.Input(shape=(max_seq_length,), - dtype='int32', - name='ages') - mask = tf.keras.layers.Input(shape=(max_seq_length,), dtype='int32', name='mask') - - convolution_bert_layer = ConvolutionBertLayer(model_path=model_path, - seq_len=max_seq_length, - context_window=context_window, - stride=stride) - - conv_bert_output = 
convolution_bert_layer([concept_ids, - visit_segments, - visit_concept_orders, - time_stamps, - ages, - mask]) + concept_ids = tf.keras.layers.Input(shape=(max_seq_length,), dtype="int32", name="concept_ids") + visit_segments = tf.keras.layers.Input(shape=(max_seq_length,), dtype="int32", name="visit_segments") + time_stamps = tf.keras.layers.Input(shape=(max_seq_length,), dtype="int32", name="time_stamps") + visit_concept_orders = tf.keras.layers.Input(shape=(max_seq_length,), dtype="int32", name="visit_concept_orders") + ages = tf.keras.layers.Input(shape=(max_seq_length,), dtype="int32", name="ages") + mask = tf.keras.layers.Input(shape=(max_seq_length,), dtype="int32", name="mask") + + convolution_bert_layer = ConvolutionBertLayer( + model_path=model_path, + seq_len=max_seq_length, + context_window=context_window, + stride=stride, + ) + + conv_bert_output = convolution_bert_layer( + [concept_ids, visit_segments, visit_concept_orders, time_stamps, ages, mask] + ) next_input = tf.keras.layers.concatenate([conv_bert_output, normalized_index_age]) dropout_conv_layer = tf.keras.layers.Dropout(0.2) - dense_layer = tf.keras.layers.Dense(64, activation='tanh') + dense_layer = tf.keras.layers.Dense(64, activation="tanh") dropout_dense_layer = tf.keras.layers.Dropout(0.2) - output_layer = tf.keras.layers.Dense(1, name='prediction', activation='sigmoid') + output_layer = tf.keras.layers.Dense(1, name="prediction", activation="sigmoid") output = output_layer(dropout_dense_layer(dense_layer(dropout_conv_layer(next_input)))) - model_inputs = [concept_ids, visit_segments, visit_concept_orders, time_stamps, ages, mask] - ffd_bert_model = tf.keras.models.Model(inputs=model_inputs + [age_at_index_date], - outputs=output) + model_inputs = [ + concept_ids, + visit_segments, + visit_concept_orders, + time_stamps, + ages, + mask, + ] + ffd_bert_model = tf.keras.models.Model(inputs=model_inputs + [age_at_index_date], outputs=output) return ffd_bert_model def create_vanilla_bert_bi_lstm_model( - max_seq_length, - vanilla_bert_model_path, - dropout_rate=0.2, - lstm_unit=128, - activation='relu', - is_bi_directional=True + max_seq_length, + vanilla_bert_model_path, + dropout_rate=0.2, + lstm_unit=128, + activation="relu", + is_bi_directional=True, ): - age_of_visit_input = tf.keras.layers.Input(name='age', shape=(1,)) + age_of_visit_input = tf.keras.layers.Input(name="age", shape=(1,)) - age_batch_norm_layer = tf.keras.layers.BatchNormalization(name='age_batch_norm_layer') + age_batch_norm_layer = tf.keras.layers.BatchNormalization(name="age_batch_norm_layer") normalized_index_age = age_batch_norm_layer(age_of_visit_input) vanilla_bert_model = tf.keras.models.load_model( - vanilla_bert_model_path, - custom_objects=dict(**get_custom_objects()) + vanilla_bert_model_path, custom_objects=dict(**get_custom_objects()) ) - bert_inputs = [i for i in vanilla_bert_model.inputs if - 'visit' not in i.name or ('visit_segment' in i.name - or 'visit_concept_order' in i.name)] + bert_inputs = [ + i + for i in vanilla_bert_model.inputs + if "visit" not in i.name or ("visit_segment" in i.name or "visit_concept_order" in i.name) + ] # bert_inputs = vanilla_bert_model.inputs - contextualized_embeddings, _ = vanilla_bert_model.get_layer('encoder').output + contextualized_embeddings, _ = vanilla_bert_model.get_layer("encoder").output _, _, embedding_size = contextualized_embeddings.get_shape().as_list() # mask_input = bert_inputs[-1] - mask_input = [i for i in bert_inputs if - 'mask' in i.name and 'concept' not in i.name][0] + 
mask_input = [i for i in bert_inputs if "mask" in i.name and "concept" not in i.name][0] mask_embeddings = tf.tile( - tf.expand_dims(mask_input == 0, -1, name='expand_mask'), - [1, 1, embedding_size], name='tile_mask' + tf.expand_dims(mask_input == 0, -1, name="expand_mask"), + [1, 1, embedding_size], + name="tile_mask", ) contextualized_embeddings = tf.math.multiply( contextualized_embeddings, - tf.cast(mask_embeddings, dtype=tf.float32, name='cast_mask') + tf.cast(mask_embeddings, dtype=tf.float32, name="cast_mask"), ) - masking_layer = tf.keras.layers.Masking(mask_value=0., - input_shape=(max_seq_length, embedding_size)) + masking_layer = tf.keras.layers.Masking(mask_value=0.0, input_shape=(max_seq_length, embedding_size)) bi_lstm_layer = tf.keras.layers.LSTM(lstm_unit) @@ -210,13 +210,11 @@ def create_vanilla_bert_bi_lstm_model( dropout_dense_layer = tf.keras.layers.Dropout(dropout_rate) - output_layer = tf.keras.layers.Dense(1, activation='sigmoid') + output_layer = tf.keras.layers.Dense(1, activation="sigmoid") # attach a property to the concept embeddings to indicate where are the masks, send the flag # to the downstream layer - next_input = masking_layer( - contextualized_embeddings - ) + next_input = masking_layer(contextualized_embeddings) next_input = dropout_lstm_layer(bi_lstm_layer(next_input)) @@ -226,25 +224,28 @@ def create_vanilla_bert_bi_lstm_model( output = output_layer(next_input) - lstm_with_vanilla_bert = Model(inputs=bert_inputs + [age_of_visit_input], - outputs=output, name='Vanilla_BERT_PLUS_BI_LSTM') + lstm_with_vanilla_bert = Model( + inputs=bert_inputs + [age_of_visit_input], + outputs=output, + name="Vanilla_BERT_PLUS_BI_LSTM", + ) return lstm_with_vanilla_bert def create_random_vanilla_bert_bi_lstm_model( - max_seq_length, - embedding_size, - depth, - tokenizer, - visit_tokenizer, - num_heads, - use_time_embedding, - time_embeddings_size + max_seq_length, + embedding_size, + depth, + tokenizer, + visit_tokenizer, + num_heads, + use_time_embedding, + time_embeddings_size, ): - age_of_visit_input = tf.keras.layers.Input(name='age', shape=(1,)) + age_of_visit_input = tf.keras.layers.Input(name="age", shape=(1,)) - age_batch_norm_layer = tf.keras.layers.BatchNormalization(name='age_batch_norm_layer') + age_batch_norm_layer = tf.keras.layers.BatchNormalization(name="age_batch_norm_layer") normalized_index_age = age_batch_norm_layer(age_of_visit_input) @@ -256,39 +257,44 @@ def create_random_vanilla_bert_bi_lstm_model( depth=depth, num_heads=num_heads, use_time_embedding=use_time_embedding, - time_embeddings_size=time_embeddings_size + time_embeddings_size=time_embeddings_size, ) - bert_inputs = [i for i in vanilla_bert_model.inputs if - 'visit' not in i.name or ('visit_segment' in i.name - or 'visit_concept_order' in i.name)] + bert_inputs = [ + i + for i in vanilla_bert_model.inputs + if "visit" not in i.name or ("visit_segment" in i.name or "visit_concept_order" in i.name) + ] # bert_inputs = vanilla_bert_model.inputs - contextualized_embeddings, _ = vanilla_bert_model.get_layer('encoder').output + contextualized_embeddings, _ = vanilla_bert_model.get_layer("encoder").output _, _, embedding_size = contextualized_embeddings.get_shape().as_list() # mask_input = bert_inputs[-1] - mask_input = [i for i in bert_inputs if - 'mask' in i.name and 'concept' not in i.name][0] - mask_embeddings = tf.tile(tf.expand_dims(mask_input == 0, -1, name='expand_mask'), - [1, 1, embedding_size], name='tile_mask') - contextualized_embeddings = 
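Editor's note: the evaluation models in this file repeat the same masking idiom: expand and tile the integer padding mask (where `1` marks a padded position in this codebase) and zero out the contextualized embeddings there, so that the downstream `Masking`/LSTM layers skip those steps. A toy TensorFlow sketch of just that idiom:

```python
# Sketch of the padding-mask -> embedding-mask idiom used repeatedly above.
import tensorflow as tf

batch, seq_len, embedding_size = 2, 4, 3
contextualized_embeddings = tf.ones((batch, seq_len, embedding_size))
mask = tf.constant([[0, 0, 1, 1], [0, 0, 0, 1]])  # 1 = padded position

# (batch, seq_len) bool -> (batch, seq_len, embedding_size) float 0/1
keep = tf.cast(tf.expand_dims(mask == 0, -1), tf.float32)
keep = tf.tile(keep, [1, 1, embedding_size])

masked_embeddings = contextualized_embeddings * keep
print(masked_embeddings[0, :, 0].numpy())  # [1. 1. 0. 0.]
```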
tf.math.multiply(contextualized_embeddings, - tf.cast(mask_embeddings, dtype=tf.float32, - name='cast_mask')) + mask_input = [i for i in bert_inputs if "mask" in i.name and "concept" not in i.name][0] + mask_embeddings = tf.tile( + tf.expand_dims(mask_input == 0, -1, name="expand_mask"), + [1, 1, embedding_size], + name="tile_mask", + ) + contextualized_embeddings = tf.math.multiply( + contextualized_embeddings, + tf.cast(mask_embeddings, dtype=tf.float32, name="cast_mask"), + ) - masking_layer = tf.keras.layers.Masking(mask_value=0., - input_shape=(max_seq_length, embedding_size)) + masking_layer = tf.keras.layers.Masking(mask_value=0.0, input_shape=(max_seq_length, embedding_size)) bi_lstm_layer = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(128)) dropout_lstm_layer = tf.keras.layers.Dropout(0.2) - dense_layer = tf.keras.layers.Dense(64, activation='relu') + dense_layer = tf.keras.layers.Dense(64, activation="relu") dropout_dense_layer = tf.keras.layers.Dropout(0.2) - output_layer = tf.keras.layers.Dense(1, activation='sigmoid') + output_layer = tf.keras.layers.Dense(1, activation="sigmoid") next_input = masking_layer( - contextualized_embeddings) # attach a property to the concept embeddings to indicate where are the masks, send the flag to the downstream layer + contextualized_embeddings + ) # attach a property to the concept embeddings to indicate where are the masks, send the flag to the downstream layer next_input = dropout_lstm_layer(bi_lstm_layer(next_input)) @@ -298,98 +304,77 @@ def create_random_vanilla_bert_bi_lstm_model( output = output_layer(next_input) - lstm_with_vanilla_bert = Model(inputs=bert_inputs + [age_of_visit_input], - outputs=output, name='Vanilla_BERT_PLUS_BI_LSTM') + lstm_with_vanilla_bert = Model( + inputs=bert_inputs + [age_of_visit_input], + outputs=output, + name="Vanilla_BERT_PLUS_BI_LSTM", + ) return lstm_with_vanilla_bert -def create_hierarchical_bert_bi_lstm_model( - bert_model_path, - **kwargs -): +def create_hierarchical_bert_bi_lstm_model(bert_model_path, **kwargs): model = tf.keras.models.load_model(bert_model_path, custom_objects=get_custom_objects()) - return create_hierarchical_bert_bi_lstm_model_with_model( - model, - **kwargs - ) + return create_hierarchical_bert_bi_lstm_model_with_model(model, **kwargs) def create_hierarchical_bert_bi_lstm_model_with_model( - hierarchical_bert_model, - dropout_rate=0.2, - lstm_unit=128, - activation='relu', - is_bi_directional=True, - include_att_tokens=False, - freeze_pretrained_model=False + hierarchical_bert_model, + dropout_rate=0.2, + lstm_unit=128, + activation="relu", + is_bi_directional=True, + include_att_tokens=False, + freeze_pretrained_model=False, ): - index_age_input = tf.keras.layers.Input(name='age', shape=(1,)) + index_age_input = tf.keras.layers.Input(name="age", shape=(1,)) - age_batch_norm_layer = tf.keras.layers.BatchNormalization(name='age_batch_norm_layer') + age_batch_norm_layer = tf.keras.layers.BatchNormalization(name="age_batch_norm_layer") normalized_index_age = age_batch_norm_layer(index_age_input) _, num_of_visits, num_of_concepts, embedding_size = hierarchical_bert_model.get_layer( - 'temporal_transformation_layer' + "temporal_transformation_layer" ).output.shape - is_phenotype_enabled = ( - 'hidden_visit_embeddings' in [layer.name for layer in hierarchical_bert_model.layers] - ) + is_phenotype_enabled = "hidden_visit_embeddings" in [layer.name for layer in hierarchical_bert_model.layers] # Freeze the weight of the pretrained model if enabled if freeze_pretrained_model: 
hierarchical_bert_model.trainable = False if is_phenotype_enabled: - contextualized_visit_embeddings, _ = hierarchical_bert_model.get_layer( - 'hidden_visit_embeddings' - ).output + contextualized_visit_embeddings, _ = hierarchical_bert_model.get_layer("hidden_visit_embeddings").output else: - contextualized_visit_embeddings, _ = hierarchical_bert_model.get_layer( - 'visit_encoder' - ).output + contextualized_visit_embeddings, _ = hierarchical_bert_model.get_layer("visit_encoder").output if not include_att_tokens: # Pad contextualized_visit_embeddings on axis 1 with one extra visit so we can extract the # visit embeddings using the reshape trick expanded_contextualized_visit_embeddings = tf.concat( - [contextualized_visit_embeddings, - contextualized_visit_embeddings[:, 0:1, :]], - axis=1 + [ + contextualized_visit_embeddings, + contextualized_visit_embeddings[:, 0:1, :], + ], + axis=1, ) # Extract the visit embeddings elements contextualized_visit_embeddings = tf.reshape( - expanded_contextualized_visit_embeddings, (-1, num_of_visits, 3 * embedding_size) - )[:, :, embedding_size: embedding_size * 2] + expanded_contextualized_visit_embeddings, + (-1, num_of_visits, 3 * embedding_size), + )[:, :, embedding_size : embedding_size * 2] - visit_mask = hierarchical_bert_model.get_layer('visit_mask').output + visit_mask = hierarchical_bert_model.get_layer("visit_mask").output if include_att_tokens: - visit_mask = tf.reshape( - tf.tile(visit_mask[:, :, tf.newaxis], [1, 1, 3]), - (-1, num_of_visits * 3) - )[:, 1:] + visit_mask = tf.reshape(tf.tile(visit_mask[:, :, tf.newaxis], [1, 1, 3]), (-1, num_of_visits * 3))[:, 1:] - mask_embeddings = tf.cast( - tf.math.logical_not( - tf.cast( - visit_mask, - dtype=tf.bool - ) - ), - dtype=tf.float32 - )[:, :, tf.newaxis] + mask_embeddings = tf.cast(tf.math.logical_not(tf.cast(visit_mask, dtype=tf.bool)), dtype=tf.float32)[ + :, :, tf.newaxis + ] - contextualized_embeddings = tf.math.multiply( - contextualized_visit_embeddings, - mask_embeddings - ) + contextualized_embeddings = tf.math.multiply(contextualized_visit_embeddings, mask_embeddings) - masking_layer = tf.keras.layers.Masking( - mask_value=0., - input_shape=(num_of_visits, embedding_size) - ) + masking_layer = tf.keras.layers.Masking(mask_value=0.0, input_shape=(num_of_visits, embedding_size)) bi_lstm_layer = tf.keras.layers.LSTM(lstm_unit) @@ -402,15 +387,13 @@ def create_hierarchical_bert_bi_lstm_model_with_model( dropout_dense_layer = tf.keras.layers.Dropout(dropout_rate) - output_layer = tf.keras.layers.Dense(1, activation='sigmoid', name='label') + output_layer = tf.keras.layers.Dense(1, activation="sigmoid", name="label") next_input = masking_layer(contextualized_embeddings) next_input = dropout_lstm_layer(bi_lstm_layer(next_input)) - next_input = tf.keras.layers.concatenate( - [next_input, tf.reshape(normalized_index_age, (-1, 1))] - ) + next_input = tf.keras.layers.concatenate([next_input, tf.reshape(normalized_index_age, (-1, 1))]) next_input = dropout_dense_layer(dense_layer(next_input)) @@ -419,65 +402,54 @@ def create_hierarchical_bert_bi_lstm_model_with_model( lstm_with_hierarchical_bert = tf.keras.models.Model( inputs=hierarchical_bert_model.inputs + [index_age_input], outputs=output, - name='HIERARCHICAL_BERT_PLUS_BI_LSTM' + name="HIERARCHICAL_BERT_PLUS_BI_LSTM", ) return lstm_with_hierarchical_bert def create_hierarchical_bert_model_with_pooling( - bert_model_path, - dropout_rate=0.2, - activation='tanh', - freeze_pretrained_model=False, - **kwargs + bert_model_path, + 
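Editor's note: the "reshape trick" above pads the contextualized visit sequence by one step, groups it into triplets, and keeps the middle slice of each group, which the model treats as the visit embeddings. The NumPy sketch below only demonstrates the indexing mechanics (it selects original positions 1, 4, 7, ...); the exact token layout of the hierarchical sequence is not spelled out in this hunk.

```python
# Mechanics of the reshape trick above, with positions used as values so the
# selection is easy to see. Toy sizes only.
import numpy as np

num_of_visits, embedding_size = 4, 2
seq_len = 3 * num_of_visits - 1  # length of the visit-level sequence before padding

positions = np.arange(seq_len, dtype=float)
x = np.repeat(positions[:, None], embedding_size, axis=1)[None, ...]  # (1, seq_len, emb)

expanded = np.concatenate([x, x[:, 0:1, :]], axis=1)                  # pad to 3 * num_of_visits
grouped = expanded.reshape(-1, num_of_visits, 3 * embedding_size)
visit_embeddings = grouped[:, :, embedding_size: 2 * embedding_size]

print(visit_embeddings[0, :, 0])  # [ 1.  4.  7. 10.]
```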
dropout_rate=0.2, + activation="tanh", + freeze_pretrained_model=False, + **kwargs, ): - hierarchical_bert_model = tf.keras.models.load_model( - bert_model_path, - custom_objects=get_custom_objects() - ) + hierarchical_bert_model = tf.keras.models.load_model(bert_model_path, custom_objects=get_custom_objects()) - age_of_visit_input = tf.keras.layers.Input(name='age', shape=(1,)) + age_of_visit_input = tf.keras.layers.Input(name="age", shape=(1,)) _, num_of_visits, num_of_concepts, embedding_size = hierarchical_bert_model.get_layer( - 'temporal_transformation_layer' + "temporal_transformation_layer" ).output.shape - is_phenotype_enabled = ( - 'hidden_visit_embeddings' in [layer.name for layer in hierarchical_bert_model.layers] - ) + is_phenotype_enabled = "hidden_visit_embeddings" in [layer.name for layer in hierarchical_bert_model.layers] # Freeze the weight of the pretrained model if enabled if freeze_pretrained_model: hierarchical_bert_model.trainable = False if is_phenotype_enabled: - contextualized_visit_embeddings, _ = hierarchical_bert_model.get_layer( - 'hidden_visit_embeddings' - ).output + contextualized_visit_embeddings, _ = hierarchical_bert_model.get_layer("hidden_visit_embeddings").output else: - contextualized_visit_embeddings, _ = hierarchical_bert_model.get_layer( - 'visit_encoder' - ).output + contextualized_visit_embeddings, _ = hierarchical_bert_model.get_layer("visit_encoder").output - visit_masks = hierarchical_bert_model.get_layer('visit_mask').output + visit_masks = hierarchical_bert_model.get_layer("visit_mask").output # Get the first embedding from the visit embedding sequence # [batch_size, embedding_size] visit_embedding_pooling = tf.gather_nd( contextualized_visit_embeddings, indices=tf.argmax(visit_masks, axis=1)[:, tf.newaxis], - batch_dims=1 + batch_dims=1, ) dense_layer_1 = tf.keras.layers.Dense(128, activation=activation) dropout_dense_layer_1 = tf.keras.layers.Dropout(dropout_rate) dense_layer_2 = tf.keras.layers.Dense(64, activation=activation) dropout_dense_layer_2 = tf.keras.layers.Dropout(dropout_rate) - output_layer = tf.keras.layers.Dense(1, activation='sigmoid', name='label') + output_layer = tf.keras.layers.Dense(1, activation="sigmoid", name="label") - next_input = tf.keras.layers.concatenate( - [visit_embedding_pooling, tf.reshape(age_of_visit_input, (-1, 1))] - ) + next_input = tf.keras.layers.concatenate([visit_embedding_pooling, tf.reshape(age_of_visit_input, (-1, 1))]) next_input = dropout_dense_layer_1(dense_layer_1(next_input)) next_input = dropout_dense_layer_2(dense_layer_2(next_input)) @@ -487,7 +459,7 @@ def create_hierarchical_bert_model_with_pooling( hierarchical_bert_with_pooling = tf.keras.models.Model( inputs=hierarchical_bert_model.inputs + [age_of_visit_input], outputs=output, - name='HIERARCHICAL_BERT_POOLING' + name="HIERARCHICAL_BERT_POOLING", ) return hierarchical_bert_with_pooling @@ -495,52 +467,40 @@ def create_hierarchical_bert_model_with_pooling( def create_prob_phenotype_bi_lstm_model_with_model(bert_model_path): """ - :param bert_model_path: + :return: """ model = tf.keras.models.load_model(bert_model_path, custom_objects=get_custom_objects()) - age_of_visit_input = tf.keras.layers.Input(name='age', shape=(1,)) + age_of_visit_input = tf.keras.layers.Input(name="age", shape=(1,)) - _, _, contextual_visit_embeddings, _ = model.get_layer( - 'visit_phenotype_layer' - ).output + _, _, contextual_visit_embeddings, _ = model.get_layer("visit_phenotype_layer").output embedding_size = contextual_visit_embeddings.shape[-1] - 
visit_mask = [i for i in model.inputs if i.name == 'visit_mask'][0] + visit_mask = [i for i in model.inputs if i.name == "visit_mask"][0] num_of_visits = visit_mask.shape[1] # Expand dimension for masking MultiHeadAttention in Visit Encoder - visit_mask_with_att = (tf.reshape( - tf.stack([visit_mask, visit_mask], axis=2), - shape=(-1, num_of_visits * 2) - )[:, 1:]) + visit_mask_with_att = tf.reshape(tf.stack([visit_mask, visit_mask], axis=2), shape=(-1, num_of_visits * 2))[:, 1:] mask_embeddings = tf.cast( - tf.math.logical_not( - tf.cast( - visit_mask_with_att, - dtype=tf.bool - ) - ), - dtype=tf.float32 + tf.math.logical_not(tf.cast(visit_mask_with_att, dtype=tf.bool)), + dtype=tf.float32, )[:, :, tf.newaxis] contextual_visit_embeddings = tf.math.multiply(contextual_visit_embeddings, mask_embeddings) - masking_layer = tf.keras.layers.Masking( - mask_value=0., - input_shape=(num_of_visits, embedding_size)) + masking_layer = tf.keras.layers.Masking(mask_value=0.0, input_shape=(num_of_visits, embedding_size)) bi_lstm_layer = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(128)) dropout_lstm_layer = tf.keras.layers.Dropout(0.2) - dense_layer = tf.keras.layers.Dense(64, activation='tanh') + dense_layer = tf.keras.layers.Dense(64, activation="tanh") dropout_dense_layer = tf.keras.layers.Dropout(0.2) - output_layer = tf.keras.layers.Dense(1, activation='sigmoid', name='label') + output_layer = tf.keras.layers.Dense(1, activation="sigmoid", name="label") next_input = masking_layer(contextual_visit_embeddings) @@ -552,40 +512,48 @@ def create_prob_phenotype_bi_lstm_model_with_model(bert_model_path): output = output_layer(next_input) - lstm_with_cher_bert = tf.keras.models.Model(inputs=model.inputs + [age_of_visit_input], - outputs=output, name='PROB_PHENOTYPE_PLUS_BI_LSTM') + lstm_with_cher_bert = tf.keras.models.Model( + inputs=model.inputs + [age_of_visit_input], + outputs=output, + name="PROB_PHENOTYPE_PLUS_BI_LSTM", + ) return lstm_with_cher_bert def create_temporal_bert_bi_lstm_model(max_seq_length, temporal_bert_model_path): - temporal_bert_model = tf.keras.models.load_model(temporal_bert_model_path, - custom_objects=dict(**get_custom_objects())) + temporal_bert_model = tf.keras.models.load_model( + temporal_bert_model_path, custom_objects=dict(**get_custom_objects()) + ) bert_inputs = temporal_bert_model.inputs[0:5] - _, _, embedding_size = temporal_bert_model.get_layer('temporal_encoder').output[ - 0].get_shape().as_list() - contextualized_embeddings, _, _ = temporal_bert_model.get_layer('temporal_encoder').output + _, _, embedding_size = temporal_bert_model.get_layer("temporal_encoder").output[0].get_shape().as_list() + contextualized_embeddings, _, _ = temporal_bert_model.get_layer("temporal_encoder").output - age_of_visit_input = tf.keras.layers.Input(name='age', shape=(1,)) + age_of_visit_input = tf.keras.layers.Input(name="age", shape=(1,)) mask_input = bert_inputs[-1] mask_embeddings = tf.cast( - tf.tile(tf.expand_dims(mask_input == 0, -1, name='expand_mask'), [1, 1, embedding_size], - name='tile_mask'), tf.float32, name='cast_mask_embeddings') + tf.tile( + tf.expand_dims(mask_input == 0, -1, name="expand_mask"), + [1, 1, embedding_size], + name="tile_mask", + ), + tf.float32, + name="cast_mask_embeddings", + ) contextualized_embeddings = tf.math.multiply(contextualized_embeddings, mask_embeddings) - masking_layer = tf.keras.layers.Masking(mask_value=0., - input_shape=(max_seq_length, embedding_size)) + masking_layer = tf.keras.layers.Masking(mask_value=0.0, 
input_shape=(max_seq_length, embedding_size)) bi_lstm_layer = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(128)) dropout_lstm_layer = tf.keras.layers.Dropout(0.1) - dense_layer = tf.keras.layers.Dense(64, activation='tanh') + dense_layer = tf.keras.layers.Dense(64, activation="tanh") dropout_dense_layer = tf.keras.layers.Dropout(0.1) - output_layer = tf.keras.layers.Dense(1, activation='sigmoid') + output_layer = tf.keras.layers.Dense(1, activation="sigmoid") next_input = masking_layer(contextualized_embeddings) @@ -597,89 +565,69 @@ def create_temporal_bert_bi_lstm_model(max_seq_length, temporal_bert_model_path) output = output_layer(next_input) - return Model(inputs=bert_inputs + [age_of_visit_input], outputs=output, - name='TEMPORAL_BERT_PLUS_BI_LSTM') + return Model( + inputs=bert_inputs + [age_of_visit_input], + outputs=output, + name="TEMPORAL_BERT_PLUS_BI_LSTM", + ) def create_probabilistic_bert_bi_lstm_model(max_seq_length, vanilla_bert_model_path): - age_of_visit_input = tf.keras.layers.Input( - name='age', - shape=(1,) - ) + age_of_visit_input = tf.keras.layers.Input(name="age", shape=(1,)) - bert_model = tf.keras.models.load_model( - vanilla_bert_model_path, - custom_objects=dict(**get_custom_objects()) - ) + bert_model = tf.keras.models.load_model(vanilla_bert_model_path, custom_objects=dict(**get_custom_objects())) bert_inputs = bert_model.inputs # Get phenotype embeddings and probability distribution - _, phenotype_probability_dist = bert_model.get_layer('hidden_phenotype_layer').output + _, phenotype_probability_dist = bert_model.get_layer("hidden_phenotype_layer").output - num_hidden_state = bert_model.get_layer('hidden_phenotype_layer').hidden_unit + num_hidden_state = bert_model.get_layer("hidden_phenotype_layer").hidden_unit # (batch, max_sequence, embedding_size) - concept_embeddings, _ = bert_model.get_layer('concept_embeddings').output + concept_embeddings, _ = bert_model.get_layer("concept_embeddings").output _, _, embedding_size = concept_embeddings.get_shape().as_list() # (batch * num_hidden_state, max_sequence, embedding_size) - contextualized_embeddings = bert_model.get_layer('tf.reshape').output + contextualized_embeddings = bert_model.get_layer("tf.reshape").output # mask_input = bert_inputs[-1] - mask_input = [i for i in bert_inputs if - 'mask' in i.name and 'concept' not in i.name][0] + mask_input = [i for i in bert_inputs if "mask" in i.name and "concept" not in i.name][0] # (batch * num_hidden_state, max_sequence, 1) mask_embeddings = tf.reshape( - tf.tile( - (mask_input == 0)[:, tf.newaxis, :], - [1, num_hidden_state, 1] - ), - (-1, max_seq_length) + tf.tile((mask_input == 0)[:, tf.newaxis, :], [1, num_hidden_state, 1]), + (-1, max_seq_length), )[:, :, tf.newaxis] # (batch * num_hidden_state, max_sequence, embedding_size) - contextualized_embeddings = tf.math.multiply( - contextualized_embeddings, - tf.cast(mask_embeddings, dtype=tf.float32) - ) + contextualized_embeddings = tf.math.multiply(contextualized_embeddings, tf.cast(mask_embeddings, dtype=tf.float32)) # Masking layer for LSTM - masking_layer = tf.keras.layers.Masking( - mask_value=0., - input_shape=(max_seq_length, embedding_size) - ) + masking_layer = tf.keras.layers.Masking(mask_value=0.0, input_shape=(max_seq_length, embedding_size)) bi_lstm_layer = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(128)) dropout_lstm_layer = tf.keras.layers.Dropout(0.2) - dense_layer = tf.keras.layers.Dense(64, activation='relu') + dense_layer = tf.keras.layers.Dense(64, activation="relu") 
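The tile-and-reshape of the padding mask a few lines above is easier to follow with concrete shapes; the snippet below walks through it with toy sizes (a batch of 2, 5 tokens, 3 hidden phenotype states), all of which are made up for illustration.

```python
import tensorflow as tf

max_seq_length, num_hidden_state = 5, 3
# In this code base a 1 marks a padded position and a 0 marks a real token.
mask_input = tf.constant([[0, 0, 0, 1, 1],
                          [0, 0, 1, 1, 1]])  # (batch=2, max_seq_length)

# (batch, 1, max_seq_length) -> (batch, num_hidden_state, max_seq_length)
tiled = tf.tile((mask_input == 0)[:, tf.newaxis, :], [1, num_hidden_state, 1])

# Fold the hidden-state axis into the batch axis and add a trailing axis,
# so every phenotype state gets its own copy of the token mask.
mask_embeddings = tf.reshape(tiled, (-1, max_seq_length))[:, :, tf.newaxis]

print(mask_embeddings.shape)  # (6, 5, 1) == (batch * num_hidden_state, max_seq_length, 1)
```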
dropout_dense_layer = tf.keras.layers.Dropout(0.2) # (batch * num_hidden_state, 1) - output_layer = tf.keras.layers.Dense(1, activation='sigmoid') + output_layer = tf.keras.layers.Dense(1, activation="sigmoid") # attach a property to the concept embeddings to indicate where are the masks, send the flag # to the downstream layer: (batch * num_hidden_state, max_sequence, embedding_size) - next_input = masking_layer( - contextualized_embeddings) + next_input = masking_layer(contextualized_embeddings) # (batch * num_hidden_state, 256) next_input = dropout_lstm_layer(bi_lstm_layer(next_input)) # (batch * num_hidden_state, 1) duplicate_age_of_visit_input = tf.reshape( - tf.tile( - age_of_visit_input[:, tf.newaxis, :], - [1, num_hidden_state, 1] - ), - (-1, 1) + tf.tile(age_of_visit_input[:, tf.newaxis, :], [1, num_hidden_state, 1]), (-1, 1) ) # (batch * num_hidden_state, 256 + 1) - next_input = tf.keras.layers.concatenate( - [next_input, duplicate_age_of_visit_input] - ) + next_input = tf.keras.layers.concatenate([next_input, duplicate_age_of_visit_input]) # (batch * num_hidden_state, 64) next_input = dropout_dense_layer(dense_layer(next_input)) @@ -688,21 +636,14 @@ def create_probabilistic_bert_bi_lstm_model(max_seq_length, vanilla_bert_model_p output = output_layer(next_input) # (batch, num_hidden_state, 1) - reshaped_output = tf.reshape( - output, - (-1, num_hidden_state, 1) - ) + reshaped_output = tf.reshape(output, (-1, num_hidden_state, 1)) # (batch, 1) - weighted_output = tf.squeeze( - tf.reduce_sum( - phenotype_probability_dist[:, :, tf.newaxis] * reshaped_output, - axis=1 - ) - ) + weighted_output = tf.squeeze(tf.reduce_sum(phenotype_probability_dist[:, :, tf.newaxis] * reshaped_output, axis=1)) lstm_with_prob_bert = Model( inputs=bert_inputs + [age_of_visit_input], - outputs=weighted_output, name='Probabilistic_BERT_PLUS_BI_LSTM' + outputs=weighted_output, + name="Probabilistic_BERT_PLUS_BI_LSTM", ) return lstm_with_prob_bert diff --git a/src/cehrbert/models/hf_models/config.py b/src/cehrbert/models/hf_models/config.py index 83317a29..6ce0bd9e 100644 --- a/src/cehrbert/models/hf_models/config.py +++ b/src/cehrbert/models/hf_models/config.py @@ -1,10 +1,12 @@ -from typing import Dict, Any, List +from typing import Dict, List + from transformers import PretrainedConfig class CEHRGPTConfig(PretrainedConfig): """ Args: + vocab_size (`int`, *optional*, defaults to 50257): Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`]. 
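For orientation, here is a minimal sketch of how a config object like the `CEHRGPTConfig` defined in this file might be constructed and round-tripped through the standard `PretrainedConfig` serialization it inherits; the sizes and the output directory are illustrative only.

```python
from cehrbert.models.hf_models.config import CEHRGPTConfig

config = CEHRGPTConfig(
    vocab_size=30_000,          # number of distinct concept tokens (illustrative)
    time_token_vocab_size=300,  # vocabulary for the sub-time tokens (illustrative)
    n_positions=1024,
    n_embd=768,
    n_layer=12,
    n_head=12,
)

# Inherited PretrainedConfig behaviour: serialize to config.json and reload.
config.save_pretrained("cehrgpt_config_demo")
reloaded = CEHRGPTConfig.from_pretrained("cehrgpt_config_demo")
assert reloaded.vocab_size == 30_000
```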
@@ -87,40 +89,40 @@ class CEHRGPTConfig(PretrainedConfig): } def __init__( - self, - vocab_size=50257, - time_token_vocab_size=50257, - n_positions=1024, - n_embd=768, - n_layer=12, - n_head=12, - n_inner=None, - activation_function="gelu_new", - resid_pdrop=0.1, - embd_pdrop=0.1, - attn_pdrop=0.1, - layer_norm_epsilon=1e-5, - initializer_range=0.02, - summary_type="cls_index", - summary_use_proj=True, - summary_activation=None, - summary_proj_to_labels=True, - summary_first_dropout=0.1, - scale_attn_weights=True, - use_cache=True, - bos_token_id=50256, - eos_token_id=50256, - lab_token_ids=None, - scale_attn_by_inverse_layer_idx=False, - reorder_and_upcast_attn=False, - exclude_position_ids=False, - include_values=False, - include_ttv_prediction=False, - use_sub_time_tokenization=True, - time_token_loss_weight=1.0, - time_to_visit_loss_weight=1.0, - token_to_time_token_mapping: Dict[int, List] = None, - **kwargs, + self, + vocab_size=50257, + time_token_vocab_size=50257, + n_positions=1024, + n_embd=768, + n_layer=12, + n_head=12, + n_inner=None, + activation_function="gelu_new", + resid_pdrop=0.1, + embd_pdrop=0.1, + attn_pdrop=0.1, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + summary_type="cls_index", + summary_use_proj=True, + summary_activation=None, + summary_proj_to_labels=True, + summary_first_dropout=0.1, + scale_attn_weights=True, + use_cache=True, + bos_token_id=50256, + eos_token_id=50256, + lab_token_ids=None, + scale_attn_by_inverse_layer_idx=False, + reorder_and_upcast_attn=False, + exclude_position_ids=False, + include_values=False, + include_ttv_prediction=False, + use_sub_time_tokenization=True, + time_token_loss_weight=1.0, + time_to_visit_loss_weight=1.0, + token_to_time_token_mapping: Dict[int, List] = None, + **kwargs, ): if token_to_time_token_mapping is None: token_to_time_token_mapping = {} @@ -165,8 +167,7 @@ def __init__( def token_to_time_token_mapping(self) -> Dict[int, List[int]]: # The saved _token_to_time_token_mapping converts the key to string, so we need to convert it back to int return { - int(token): list(map(int, sub_tokens)) - for token, sub_tokens in self._token_to_time_token_mapping.items() + int(token): list(map(int, sub_tokens)) for token, sub_tokens in self._token_to_time_token_mapping.items() } @@ -174,32 +175,32 @@ class CehrBertConfig(PretrainedConfig): model_type = "cehrbert" def __init__( - self, - vocab_size=20000, - n_visit_segments=3, - hidden_size=128, - n_time_embd=16, - num_hidden_layers=12, - num_attention_heads=8, - intermediate_size=2048, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=2, - initializer_range=0.02, - layer_norm_eps=1e-12, - pad_token_id=0, - lab_token_ids=None, - tie_word_embeddings=True, - num_labels=2, - classifier_dropout=0.1, - bidirectional=True, - include_value_prediction=False, - mlm_probability=0.15, - time_embedding_scaling_factor: float = 1000, - age_embedding_scaling_factor: float = 100, - **kwargs, + self, + vocab_size=20000, + n_visit_segments=3, + hidden_size=128, + n_time_embd=16, + num_hidden_layers=12, + num_attention_heads=8, + intermediate_size=2048, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=0, + lab_token_ids=None, + tie_word_embeddings=True, + num_labels=2, + classifier_dropout=0.1, + bidirectional=True, + include_value_prediction=False, + 
mlm_probability=0.15, + time_embedding_scaling_factor: float = 1000, + age_embedding_scaling_factor: float = 100, + **kwargs, ): self.vocab_size = vocab_size self.hidden_size = hidden_size diff --git a/src/cehrbert/models/hf_models/hf_cehrbert.py b/src/cehrbert/models/hf_models/hf_cehrbert.py index 0d5bd7df..4d07361e 100644 --- a/src/cehrbert/models/hf_models/hf_cehrbert.py +++ b/src/cehrbert/models/hf_models/hf_cehrbert.py @@ -1,26 +1,23 @@ import math from typing import Optional + import torch from torch import nn from torch.nn.utils.rnn import pack_padded_sequence - -from transformers.models.bert.modeling_bert import BertEncoder, BertPooler, BertOnlyMLMHead from transformers import PreTrainedModel from transformers.activations import gelu_new -from ...models.hf_models.config import CehrBertConfig -from ...models.hf_models.hf_modeling_outputs import CehrBertModelOutput, CehrBertSequenceClassifierOutput +from transformers.models.bert.modeling_bert import BertEncoder, BertOnlyMLMHead, BertPooler from transformers.utils import logging +from cehrbert.models.hf_models.config import CehrBertConfig +from cehrbert.models.hf_models.hf_modeling_outputs import CehrBertModelOutput, CehrBertSequenceClassifierOutput + logger = logging.get_logger("transformers") LARGE_POSITION_VALUE = 1000000 class PositionalEncodingLayer(nn.Module): - def __init__( - self, - embedding_size: int, - max_sequence_length: int - ): + def __init__(self, embedding_size: int, max_sequence_length: int): super(PositionalEncodingLayer, self).__init__() self.max_sequence_length = max_sequence_length @@ -29,26 +26,19 @@ def __init__( pe = torch.zeros(max_sequence_length, embedding_size) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) - self.register_buffer('pe', pe) + self.register_buffer("pe", pe) - def forward( - self, - visit_concept_orders: torch.IntTensor - ) -> torch.Tensor: + def forward(self, visit_concept_orders: torch.IntTensor) -> torch.Tensor: # Normalize the visit_orders using the smallest visit_concept_orders masked_visit_concept_orders = torch.where( visit_concept_orders > 0, visit_concept_orders, - torch.tensor(LARGE_POSITION_VALUE) + torch.tensor(LARGE_POSITION_VALUE), ) first_vals = torch.min(masked_visit_concept_orders, dim=1).values.unsqueeze(dim=-1) visit_concept_orders = visit_concept_orders - first_vals - visit_concept_orders = torch.maximum( - visit_concept_orders, torch.zeros_like(visit_concept_orders) - ) - visit_concept_orders = torch.minimum( - visit_concept_orders, torch.tensor(self.max_sequence_length) - 1 - ) + visit_concept_orders = torch.maximum(visit_concept_orders, torch.zeros_like(visit_concept_orders)) + visit_concept_orders = torch.minimum(visit_concept_orders, torch.tensor(self.max_sequence_length) - 1) # Get the same positional encodings for the concepts with the same visit_order positional_embeddings = self.pe[visit_concept_orders] return positional_embeddings @@ -56,10 +46,10 @@ def forward( class TimeEmbeddingLayer(nn.Module): def __init__( - self, - embedding_size: int, - is_time_delta: bool = False, - scaling_factor: float = 1.0 + self, + embedding_size: int, + is_time_delta: bool = False, + scaling_factor: float = 1.0, ): super(TimeEmbeddingLayer, self).__init__() self.embedding_size = embedding_size @@ -68,17 +58,13 @@ def __init__( self.w = nn.Parameter(torch.randn(1, self.embedding_size)) self.phi = nn.Parameter(torch.randn(1, self.embedding_size)) - def forward( - self, - dates: torch.Tensor - ) -> torch.Tensor: + def forward(self, dates: 
torch.Tensor) -> torch.Tensor: dates = dates.to(torch.float) dates = dates / self.scaling_factor if self.is_time_delta: dates = torch.cat( - [torch.zeros(dates[..., 0:1].shape), - dates[..., 1:] - dates[..., :-1]], - dim=-1 + [torch.zeros(dates[..., 0:1].shape), dates[..., 1:] - dates[..., :-1]], + dim=-1, ) next_input = dates.unsqueeze(-1) * self.w + self.phi return torch.sin(next_input) @@ -87,19 +73,16 @@ def forward( class ConceptValueTransformationLayer(nn.Module): def __init__(self, embedding_size): super(ConceptValueTransformationLayer, self).__init__() - self.merge_value_transformation_layer = nn.Linear( - embedding_size + 1, - embedding_size - ) + self.merge_value_transformation_layer = nn.Linear(embedding_size + 1, embedding_size) def forward( - self, - concept_embeddings: torch.Tensor, - concept_values: Optional[torch.FloatTensor] = None, - concept_value_masks: Optional[torch.FloatTensor] = None, + self, + concept_embeddings: torch.Tensor, + concept_values: Optional[torch.FloatTensor] = None, + concept_value_masks: Optional[torch.FloatTensor] = None, ): if concept_values is None or concept_value_masks is None: - logger.warning('concept_values and concept_value_masks are ignored') + logger.warning("concept_values and concept_value_masks are ignored") return concept_embeddings # (batch_size, seq_length, 1) @@ -107,18 +90,14 @@ def forward( # (batch_size, seq_length, 1) concept_value_masks = concept_value_masks.unsqueeze(-1) # (batch_size, seq_length, 1 + embedding_size) - concept_embeddings_with_val = torch.cat( - [concept_embeddings, concept_values], dim=-1 - ) + concept_embeddings_with_val = torch.cat([concept_embeddings, concept_values], dim=-1) # Run through a dense layer to bring the dimension back to embedding_size - concept_embeddings_with_val = self.merge_value_transformation_layer( - concept_embeddings_with_val - ) + concept_embeddings_with_val = self.merge_value_transformation_layer(concept_embeddings_with_val) merged = torch.where( concept_value_masks.to(torch.bool), concept_embeddings_with_val, - concept_embeddings + concept_embeddings, ) return merged @@ -134,10 +113,7 @@ def __init__(self, embedding_size): nn.Linear(embedding_size // 2, 1), ) - def forward( - self, - hidden_states: Optional[torch.FloatTensor] - ): + def forward(self, hidden_states: Optional[torch.FloatTensor]): # (batch_size, context_window, 1) concept_vals = self.concept_value_decoder_layer(hidden_states) return concept_vals @@ -159,29 +135,23 @@ def __init__(self, config: CehrBertConfig): self.linear_proj = nn.Linear(config.hidden_size + 3 * config.n_time_embd, config.hidden_size) def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - ages: Optional[torch.LongTensor] = None, - dates: Optional[torch.LongTensor] = None, - visit_concept_orders: Optional[torch.LongTensor] = None, - concept_values: Optional[torch.FloatTensor] = None, - concept_value_masks: Optional[torch.FloatTensor] = None, - visit_segments: Optional[torch.LongTensor] = None + self, + input_ids: Optional[torch.LongTensor] = None, + ages: Optional[torch.LongTensor] = None, + dates: Optional[torch.LongTensor] = None, + visit_concept_orders: Optional[torch.LongTensor] = None, + concept_values: Optional[torch.FloatTensor] = None, + concept_value_masks: Optional[torch.FloatTensor] = None, + visit_segments: Optional[torch.LongTensor] = None, ) -> torch.Tensor: # Get the concept embeddings x = self.concept_embeddings(input_ids) # Combine values with the concept embeddings - x = self.concept_value_transformation_layer( - x, - 
concept_values, - concept_value_masks - ) + x = self.concept_value_transformation_layer(x, concept_values, concept_value_masks) age_embeddings = self.age_embedding_layer(ages) time_embeddings = self.age_embedding_layer(dates) positional_embeddings = self.positional_embedding_layer(visit_concept_orders) - x = self.linear_proj( - torch.cat([x, time_embeddings, age_embeddings, positional_embeddings], dim=-1) - ) + x = self.linear_proj(torch.cat([x, time_embeddings, age_embeddings, positional_embeddings], dim=-1)) x = gelu_new(x) x += self.visit_segment_embeddings(visit_segments) return x @@ -189,8 +159,9 @@ def forward( class CehrBertPreTrainedModel(PreTrainedModel): """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. + An abstract class to handle weights initialization and a simple interface for downloading. + + and loading pretrained models. """ config_class = CehrBertConfig @@ -200,7 +171,7 @@ class CehrBertPreTrainedModel(PreTrainedModel): _no_split_modules = ["BertLayer"] def _init_weights(self, module): - """Initialize the weights""" + """Initialize the weights.""" if isinstance(module, nn.Linear): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 @@ -217,10 +188,6 @@ def _init_weights(self, module): class CehrBert(CehrBertPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ def __init__(self, config: CehrBertConfig): super().__init__(config) @@ -233,17 +200,17 @@ def __init__(self, config: CehrBertConfig): self.post_init() def forward( - self, - input_ids: torch.LongTensor, - attention_mask: torch.Tensor, - ages: Optional[torch.LongTensor] = None, - dates: Optional[torch.LongTensor] = None, - visit_concept_orders: Optional[torch.LongTensor] = None, - concept_values: Optional[torch.FloatTensor] = None, - concept_value_masks: Optional[torch.FloatTensor] = None, - visit_segments: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None + self, + input_ids: torch.LongTensor, + attention_mask: torch.Tensor, + ages: Optional[torch.LongTensor] = None, + dates: Optional[torch.LongTensor] = None, + visit_concept_orders: Optional[torch.LongTensor] = None, + concept_values: Optional[torch.FloatTensor] = None, + concept_value_masks: Optional[torch.FloatTensor] = None, + visit_segments: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, ) -> CehrBertModelOutput: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -253,7 +220,8 @@ def forward( self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) input_shape = input_ids.size() - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # We can provide a self-attention mask of dimensions + # [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. 
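To make the comment above concrete, the sketch below shows roughly what such a broadcastable mask looks like for a padded batch: the 2-D padding mask is expanded to four dimensions and turned into additive biases. It mirrors the usual Hugging Face treatment of padding masks, but it is only an illustration, not the library routine invoked just below.

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])  # (batch, seq_len), 1 = real token

# Broadcastable over (batch, num_heads, from_seq, to_seq) ...
extended = attention_mask[:, None, None, :].to(torch.float32)
# ... with padded positions turned into large negative additive biases
# that vanish after the softmax inside self-attention.
extended = (1.0 - extended) * torch.finfo(torch.float32).min

print(extended.shape)     # torch.Size([2, 1, 1, 4])
print(extended[0, 0, 0])  # only the padded position carries a large negative bias
```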
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) @@ -264,14 +232,14 @@ def forward( visit_concept_orders=visit_concept_orders, concept_values=concept_values, concept_value_masks=concept_value_masks, - visit_segments=visit_segments + visit_segments=visit_segments, ) encoder_outputs = self.encoder( embedding_output, attention_mask=extended_attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=True + return_dict=True, ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) @@ -279,7 +247,7 @@ def forward( return CehrBertModelOutput( last_hidden_state=encoder_outputs.last_hidden_state, attentions=encoder_outputs.attentions, - pooler_output=pooled_output + pooler_output=pooled_output, ) @@ -310,19 +278,19 @@ def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings def forward( - self, - input_ids: torch.LongTensor, - attention_mask: torch.Tensor, - ages: Optional[torch.LongTensor] = None, - dates: Optional[torch.LongTensor] = None, - visit_concept_orders: Optional[torch.LongTensor] = None, - concept_values: Optional[torch.FloatTensor] = None, - concept_value_masks: Optional[torch.FloatTensor] = None, - visit_segments: Optional[torch.LongTensor] = None, - mlm_skip_values: Optional[torch.BoolTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - labels: Optional[torch.LongTensor] = None + self, + input_ids: torch.LongTensor, + attention_mask: torch.Tensor, + ages: Optional[torch.LongTensor] = None, + dates: Optional[torch.LongTensor] = None, + visit_concept_orders: Optional[torch.LongTensor] = None, + concept_values: Optional[torch.FloatTensor] = None, + concept_value_masks: Optional[torch.FloatTensor] = None, + visit_segments: Optional[torch.LongTensor] = None, + mlm_skip_values: Optional[torch.BoolTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + labels: Optional[torch.LongTensor] = None, ) -> CehrBertModelOutput: cehrbert_output = self.bert( input_ids, @@ -334,7 +302,7 @@ def forward( concept_value_masks, visit_segments, output_attentions, - output_hidden_states + output_hidden_states, ) prediction_scores = self.cls(cehrbert_output.last_hidden_state) @@ -354,9 +322,7 @@ def forward( predicted_values = self.concept_value_decoder_layer(cehrbert_output.last_hidden_state) num_items = torch.sum(concept_value_masks.to(torch.float32), dim=-1) + 1e-6 values_ = (predicted_values.squeeze(-1) - concept_values) ** 2 - masked_mse = torch.sum( - values_ * concept_value_masks * mlm_masks, dim=-1 - ) / num_items + masked_mse = torch.sum(values_ * concept_value_masks * mlm_masks, dim=-1) / num_items total_loss += torch.mean(masked_mse) return CehrBertModelOutput( @@ -364,7 +330,7 @@ def forward( prediction_logits=prediction_scores, last_hidden_state=cehrbert_output.last_hidden_state, attentions=cehrbert_output.attentions, - pooler_output=cehrbert_output.pooler_output + pooler_output=cehrbert_output.pooler_output, ) @@ -388,19 +354,19 @@ def __init__(self, config: CehrBertConfig): self.post_init() def forward( - self, - input_ids: torch.LongTensor, - attention_mask: torch.Tensor, - age_at_index: torch.FloatTensor, - ages: Optional[torch.LongTensor] = None, - dates: Optional[torch.LongTensor] = None, - visit_concept_orders: Optional[torch.LongTensor] = None, - concept_values: Optional[torch.FloatTensor] = None, - 
concept_value_masks: Optional[torch.FloatTensor] = None, - visit_segments: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - classifier_label: Optional[torch.FloatTensor] = None + self, + input_ids: torch.LongTensor, + attention_mask: torch.Tensor, + age_at_index: torch.FloatTensor, + ages: Optional[torch.LongTensor] = None, + dates: Optional[torch.LongTensor] = None, + visit_concept_orders: Optional[torch.LongTensor] = None, + concept_values: Optional[torch.FloatTensor] = None, + concept_value_masks: Optional[torch.FloatTensor] = None, + visit_segments: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + classifier_label: Optional[torch.FloatTensor] = None, ) -> CehrBertSequenceClassifierOutput: normalized_age = self.age_batch_norm(age_at_index) @@ -414,7 +380,7 @@ def forward( concept_value_masks, visit_segments, output_attentions, - output_hidden_states + output_hidden_states, ) next_input = self.dropout(cehrbert_output.pooler_output) @@ -433,7 +399,7 @@ def forward( loss=loss, logits=logits, hidden_states=cehrbert_output.last_hidden_state, - attentions=cehrbert_output.attentions + attentions=cehrbert_output.attentions, ) @@ -449,7 +415,7 @@ def __init__(self, config: CehrBertConfig): input_size=config.hidden_size, hidden_size=config.hidden_size, batch_first=True, - bidirectional=config.bidirectional + bidirectional=config.bidirectional, ) classifier_dropout = ( @@ -464,19 +430,19 @@ def __init__(self, config: CehrBertConfig): self.post_init() def forward( - self, - input_ids: torch.LongTensor, - attention_mask: torch.Tensor, - age_at_index: torch.FloatTensor, - ages: Optional[torch.LongTensor] = None, - dates: Optional[torch.LongTensor] = None, - visit_concept_orders: Optional[torch.LongTensor] = None, - concept_values: Optional[torch.FloatTensor] = None, - concept_value_masks: Optional[torch.FloatTensor] = None, - visit_segments: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - classifier_label: Optional[torch.FloatTensor] = None + self, + input_ids: torch.LongTensor, + attention_mask: torch.Tensor, + age_at_index: torch.FloatTensor, + ages: Optional[torch.LongTensor] = None, + dates: Optional[torch.LongTensor] = None, + visit_concept_orders: Optional[torch.LongTensor] = None, + concept_values: Optional[torch.FloatTensor] = None, + concept_value_masks: Optional[torch.FloatTensor] = None, + visit_segments: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + classifier_label: Optional[torch.FloatTensor] = None, ) -> CehrBertSequenceClassifierOutput: normalized_age = self.age_batch_norm(age_at_index) @@ -490,13 +456,14 @@ def forward( concept_value_masks, visit_segments, output_attentions, - output_hidden_states + output_hidden_states, ) lengths = torch.sum(attention_mask, dim=-1) packed_input = pack_padded_sequence( - cehrbert_output.last_hidden_state, lengths.cpu(), + cehrbert_output.last_hidden_state, + lengths.cpu(), batch_first=True, - enforce_sorted=False + enforce_sorted=False, ) _, (h_n, c_n) = self.lstm(packed_input) next_input = self.dropout(h_n.transpose(1, 0).reshape([h_n.shape[1], -1])) @@ -515,5 +482,5 @@ def forward( loss=loss, logits=logits, hidden_states=cehrbert_output.last_hidden_state, - attentions=cehrbert_output.attentions + attentions=cehrbert_output.attentions, 
) diff --git a/src/cehrbert/models/hf_models/hf_modeling_outputs.py b/src/cehrbert/models/hf_models/hf_modeling_outputs.py index 895c26bf..24853914 100644 --- a/src/cehrbert/models/hf_models/hf_modeling_outputs.py +++ b/src/cehrbert/models/hf_models/hf_modeling_outputs.py @@ -1,6 +1,7 @@ +from dataclasses import dataclass from typing import Optional, Tuple + import torch -from dataclasses import dataclass from transformers.modeling_outputs import ModelOutput @@ -152,6 +153,7 @@ class CehrBertModelOutput(ModelOutput): the classification token after processing through a linear layer and a tanh activation function. The linear layer weights are trained from the next sentence prediction (classification) objective during pretraining. """ + loss: Optional[torch.FloatTensor] = None prediction_logits: Optional[torch.FloatTensor] = None pooler_output: torch.FloatTensor = None diff --git a/src/cehrbert/models/hf_models/tokenization_hf_cehrbert.py b/src/cehrbert/models/hf_models/tokenization_hf_cehrbert.py index 04fdce34..65c4ce61 100644 --- a/src/cehrbert/models/hf_models/tokenization_hf_cehrbert.py +++ b/src/cehrbert/models/hf_models/tokenization_hf_cehrbert.py @@ -1,24 +1,25 @@ -import os import json +import os import pickle from functools import partial -from typing import Sequence, Union, List, Dict, Any from itertools import islice +from typing import Any, Dict, List, Sequence, Union import transformers from datasets import Dataset, DatasetDict from tokenizers import Tokenizer from tokenizers.models import WordLevel -from tokenizers.trainers import WordLevelTrainer from tokenizers.pre_tokenizers import WhitespaceSplit +from tokenizers.trainers import WordLevelTrainer from transformers.tokenization_utils_base import PushToHubMixin -from .tokenization_utils import agg_statistics, map_statistics, _agg_helper -from ...runners.hf_runner_argument_dataclass import DataTrainingArguments + +from cehrbert.models.hf_models.tokenization_utils import agg_helper, agg_statistics, map_statistics +from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments PAD_TOKEN = "[PAD]" CLS_TOKEN = "[CLS]" MASK_TOKEN = "[MASK]" -UNUSED_TOKEN = '[UNUSED]' +UNUSED_TOKEN = "[UNUSED]" OUT_OF_VOCABULARY_TOKEN = "[OOV]" TOKENIZER_FILE_NAME = "tokenizer.json" @@ -26,34 +27,45 @@ LAB_STATS_FILE_NAME = "cehrgpt_lab_stats.json" -def load_json_file( - json_file -): +def load_json_file(json_file) -> Union[List[Dict[str, Any]], Dict[str, Any]]: + """ + Loads a JSON file and returns the parsed JSON object. + + Args: + json_file (str): The path to the JSON file. + + Returns: + dict: The parsed JSON object. + + Raises: + RuntimeError: If the JSON file cannot be read. 
+ """ try: with open(json_file, "r", encoding="utf-8") as reader: file_contents = reader.read() parsed_json = json.loads(file_contents) return parsed_json - except Exception as e: - raise RuntimeError(f"Can't load the json file at {json_file} due to {e}") + except RuntimeError as e: + raise RuntimeError(f"Can't load the json file at {json_file}") from e class CehrBertTokenizer(PushToHubMixin): def __init__( - self, - tokenizer: Tokenizer, - lab_stats: List[Dict[str, Any]], - concept_name_mapping: Dict[str, str] + self, + tokenizer: Tokenizer, + lab_stats: List[Dict[str, Any]], + concept_name_mapping: Dict[str, str], ): self._tokenizer = tokenizer self._lab_stats = lab_stats self._lab_stat_mapping = { lab_stat["concept_id"]: { - 'unit': lab_stat["unit"], - 'mean': lab_stat['mean'], - 'std': lab_stat['std'] - } for lab_stat in lab_stats + "unit": lab_stat["unit"], + "mean": lab_stat["mean"], + "std": lab_stat["std"], + } + for lab_stat in lab_stats } self._concept_name_mapping = concept_name_mapping self._oov_token_index = self._tokenizer.token_to_id(OUT_OF_VOCABULARY_TOKEN) @@ -68,6 +80,10 @@ def __init__( def vocab_size(self): return self._tokenizer.get_vocab_size() + @property + def oov_token_index(self): + return self._oov_token_index + @property def mask_token_index(self): return self._mask_token_index @@ -87,25 +103,26 @@ def cls_token_index(self): @property def lab_token_ids(self): reserved_tokens = [ - OUT_OF_VOCABULARY_TOKEN, PAD_TOKEN, UNUSED_TOKEN, OUT_OF_VOCABULARY_TOKEN + OUT_OF_VOCABULARY_TOKEN, + PAD_TOKEN, + UNUSED_TOKEN, + OUT_OF_VOCABULARY_TOKEN, ] - return self.encode( - [_['concept_id'] for _ in self._lab_stats if _['concept_id'] not in reserved_tokens] - ) + return self.encode([_["concept_id"] for _ in self._lab_stats if _["concept_id"] not in reserved_tokens]) def encode(self, concept_ids: Sequence[str]) -> Sequence[int]: encoded = self._tokenizer.encode(concept_ids, is_pretokenized=True) return encoded.ids def decode(self, concept_token_ids: List[int]) -> List[str]: - return self._tokenizer.decode(concept_token_ids).split(' ') + return self._tokenizer.decode(concept_token_ids).split(" ") - def _convert_token_to_id(self, token): + def convert_token_to_id(self, token): """Converts a token (str) in an id using the vocab.""" token_id = self._tokenizer.token_to_id(token) return token_id if token_id else self._oov_token_index - def _convert_id_to_token(self, index): + def convert_id_to_token(self, index): """Converts an index (integer) in a token (str) using the vocab.""" token = self._tokenizer.id_to_token(index) return token if token else OUT_OF_VOCABULARY_TOKEN @@ -115,18 +132,22 @@ def convert_tokens_to_string(self, tokens): out_string = " ".join([self._concept_name_mapping[t] for t in tokens]) return out_string - def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + push_to_hub: bool = False, + **kwargs, + ): """ Save the Cehrbert tokenizer. - This method make sure the batch processor can then be re-loaded using the .from_pretrained class method. Args: save_directory (`str` or `os.PathLike`): The path to a directory where the tokenizer will be saved. push_to_hub (`bool`, *optional*, defaults to `False`): - Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the + Whether to push your model to the Hugging Face model hub after saving it. 
You can specify the repository you want to push to with `repo_id` (will default to the name of `save_directory` in your namespace). kwargs (`Dict[str, Any]`, *optional*): @@ -161,9 +182,9 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: @classmethod def from_pretrained( - cls, - pretrained_model_name_or_path: Union[str, os.PathLike], - **kwargs, + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + **kwargs, ): """ Load the CehrBert tokenizer. @@ -187,7 +208,7 @@ def from_pretrained( ) if not tokenizer_file: - return None + raise RuntimeError(f"tokenizer_file does not exist: {tokenizer_file}") tokenizer = Tokenizer.from_file(tokenizer_file) @@ -195,13 +216,13 @@ def from_pretrained( pretrained_model_name_or_path, LAB_STATS_FILE_NAME, **kwargs ) if not lab_stats_file: - return None + raise RuntimeError(f"lab_stats_file does not exist: {lab_stats_file}") concept_name_mapping_file = transformers.utils.hub.cached_file( pretrained_model_name_or_path, CONCEPT_MAPPING_FILE_NAME, **kwargs ) if not concept_name_mapping_file: - return None + raise RuntimeError(f"concept_name_mapping_file does not exist: {concept_name_mapping_file}") lab_stats = load_json_file(lab_stats_file) @@ -211,29 +232,37 @@ def from_pretrained( @classmethod def train_tokenizer( - cls, - dataset: Union[Dataset, DatasetDict], - feature_names: List[str], - concept_name_mapping: Dict[str, str], - data_args: DataTrainingArguments + cls, + dataset: Union[Dataset, DatasetDict], + feature_names: List[str], + concept_name_mapping: Dict[str, str], + data_args: DataTrainingArguments, ): """ - Train a huggingface word level tokenizer. To use their tokenizer, we need to concatenate all the concepts + Train a huggingface word level tokenizer. + + To use their tokenizer, we need to concatenate all the concepts together and treat it as a sequence. """ if isinstance(dataset, DatasetDict): - dataset = dataset['train'] + dataset = dataset["train"] # Use the Fast Tokenizer from the Huggingface tokenizers Rust implementation. 
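The concatenation strategy described in the docstring above can be exercised end to end on a toy corpus; the sketch below mirrors the word-level training flow that follows, with made-up concept IDs and the special tokens taken from this module.

```python
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import WhitespaceSplit
from tokenizers.trainers import WordLevelTrainer

# Two toy patient sequences of concept tokens (values are made up).
patient_sequences = [
    ["320128", "VS", "9201", "VE", "W1"],
    ["320128", "VS", "140838", "VE", "W2"],
]

tokenizer = Tokenizer(WordLevel(vocab=dict(), unk_token="[OOV]"))
tokenizer.pre_tokenizer = WhitespaceSplit()
trainer = WordLevelTrainer(
    special_tokens=["[PAD]", "[MASK]", "[OOV]", "[CLS]", "[UNUSED]"],
    min_frequency=1,
)

# One whitespace-joined string per patient, mirroring batch_concat_concepts.
tokenizer.train_from_iterator((" ".join(seq) for seq in patient_sequences), trainer=trainer)

encoded = tokenizer.encode(patient_sequences[0], is_pretokenized=True)
print(encoded.ids)  # one word-level id per concept token
```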
# https://github.com/huggingface/tokenizers tokenizer = Tokenizer(WordLevel(unk_token=OUT_OF_VOCABULARY_TOKEN, vocab=dict())) tokenizer.pre_tokenizer = WhitespaceSplit() trainer = WordLevelTrainer( - special_tokens=[PAD_TOKEN, MASK_TOKEN, OUT_OF_VOCABULARY_TOKEN, CLS_TOKEN, UNUSED_TOKEN], + special_tokens=[ + PAD_TOKEN, + MASK_TOKEN, + OUT_OF_VOCABULARY_TOKEN, + CLS_TOKEN, + UNUSED_TOKEN, + ], vocab_size=data_args.vocab_size, min_frequency=data_args.min_frequency, - show_progress=True + show_progress=True, ) for feature_name in feature_names: batch_concat_concepts_partial_func = partial(cls.batch_concat_concepts, feature_name=feature_name) @@ -241,7 +270,7 @@ def train_tokenizer( concatenated_features = dataset.map( batch_concat_concepts_partial_func, batched=True, - batch_size=data_args.preprocessing_batch_size + batch_size=data_args.preprocessing_batch_size, ) def batched_generator(): @@ -250,9 +279,7 @@ def batched_generator(): batch = list(islice(iterator, data_args.preprocessing_batch_size)) if not batch: break - yield [ - example[feature_name] for example in batch - ] + yield [example[feature_name] for example in batch] # We pass a generator of list of texts (concatenated concept_ids) to train_from_iterator # for efficient training @@ -263,7 +290,7 @@ def batched_generator(): num_proc=data_args.preprocessing_num_workers, batched=True, batch_size=data_args.preprocessing_batch_size, - remove_columns=dataset.column_names + remove_columns=dataset.column_names, ) generator = concatenated_features[feature_name] @@ -271,14 +298,14 @@ def batched_generator(): if data_args.streaming: parts = dataset.map( - partial(_agg_helper, map_func=map_statistics), + partial(agg_helper, map_func=map_statistics), batched=True, batch_size=data_args.preprocessing_batch_size, - remove_columns=dataset.column_names + remove_columns=dataset.column_names, ) else: parts = dataset.map( - partial(_agg_helper, map_func=map_statistics), + partial(agg_helper, map_func=map_statistics), batched=True, batch_size=data_args.preprocessing_batch_size, remove_columns=dataset.column_names, @@ -295,13 +322,13 @@ def batched_generator(): current = agg_statistics(current, fixed_stat) lab_stats = [ { - 'concept_id': concept_id, - 'unit': unit, - 'mean': online_stats.mean(), - 'std': online_stats.standard_deviation(), - 'count': online_stats.count + "concept_id": concept_id, + "unit": unit, + "mean": online_stats.mean(), + "std": online_stats.standard_deviation(), + "count": online_stats.count, } - for (concept_id, unit), online_stats in current['numeric_stats_by_lab'].items() + for (concept_id, unit), online_stats in current["numeric_stats_by_lab"].items() ] return CehrBertTokenizer(tokenizer, lab_stats, concept_name_mapping) @@ -312,10 +339,10 @@ def batch_concat_concepts(cls, records: Dict[str, List], feature_name) -> Dict[s def normalize(self, concept_id, concept_value) -> float: if concept_id in self._lab_stat_mapping: - mean_ = (concept_value - self._lab_stat_mapping[concept_id]['mean']) - std = self._lab_stat_mapping[concept_id]['std'] + mean_ = concept_value - self._lab_stat_mapping[concept_id]["mean"] + std = self._lab_stat_mapping[concept_id]["std"] if std > 0: - normalized_value = mean_ / self._lab_stat_mapping[concept_id]['std'] + normalized_value = mean_ / self._lab_stat_mapping[concept_id]["std"] else: normalized_value = mean_ return normalized_value diff --git a/src/cehrbert/models/hf_models/tokenization_utils.py b/src/cehrbert/models/hf_models/tokenization_utils.py index c7d38c76..f43d8f4c 100644 --- 
a/src/cehrbert/models/hf_models/tokenization_utils.py +++ b/src/cehrbert/models/hf_models/tokenization_utils.py @@ -1,14 +1,12 @@ import collections import json import pickle -from typing import Dict, Any +from typing import Any, Dict from femr.stat_utils import OnlineStatistics -def load_json_file( - json_file -): +def load_json_file(json_file): try: with open(json_file, "r", encoding="utf-8") as reader: file_contents = reader.read() @@ -18,32 +16,30 @@ def load_json_file( raise RuntimeError(f"Can't load the json file at {json_file} due to {e}") -def _agg_helper(*args, map_func): +def agg_helper(*args, map_func): result = map_func(*args) return {"data": [pickle.dumps(result)]} def map_statistics(batch: Dict[str, Any]) -> Dict[str, Any]: - if 'units' in batch: - concept_value_units = batch['units'] + if "units" in batch: + concept_value_units = batch["units"] else: - concept_value_units = [['default_unit' for _ in cons] for cons in batch['concept_ids']] + concept_value_units = [["default_unit" for _ in cons] for cons in batch["concept_ids"]] numeric_stats_by_lab = collections.defaultdict(OnlineStatistics) for concept_ids, concept_values, concept_value_indicators, units in zip( - batch['concept_ids'], - batch['concept_values'], - batch['concept_value_masks'], - concept_value_units + batch["concept_ids"], + batch["concept_values"], + batch["concept_value_masks"], + concept_value_units, ): for concept_id, concept_value, concept_value_indicator, unit in zip( - concept_ids, concept_values, concept_value_indicators, units + concept_ids, concept_values, concept_value_indicators, units ): if concept_value_indicator == 1: numeric_stats_by_lab[(concept_id, unit)].add(1, concept_value) - return { - 'numeric_stats_by_lab': numeric_stats_by_lab - } + return {"numeric_stats_by_lab": numeric_stats_by_lab} def agg_statistics(stats1, stats2): diff --git a/src/cehrbert/models/hierachical_bert_model_v2.py b/src/cehrbert/models/hierachical_bert_model_v2.py index 5ddfbff6..9fd87f3a 100644 --- a/src/cehrbert/models/hierachical_bert_model_v2.py +++ b/src/cehrbert/models/hierachical_bert_model_v2.py @@ -1,25 +1,33 @@ -from .layers.custom_layers import * +from .layers.custom_layers import ( + ConceptValueTransformationLayer, + Encoder, + ReusableEmbedding, + SimpleDecoderLayer, + TemporalTransformationLayer, + TiedOutputEmbedding, + tf, +) def transformer_hierarchical_bert_model( - num_of_visits, - num_of_concepts, - concept_vocab_size, - embedding_size, - depth: int, - num_heads: int, - transformer_dropout: float = 0.1, - embedding_dropout: float = 0.6, - l2_reg_penalty: float = 1e-4, - time_embeddings_size: int = 16, - include_att_prediction: bool = False, - include_visit_type_prediction: bool = False, - include_readmission: bool = False, - include_prolonged_length_stay: bool = False, - visit_vocab_size: int = None + num_of_visits, + num_of_concepts, + concept_vocab_size, + embedding_size, + depth: int, + num_heads: int, + transformer_dropout: float = 0.1, + embedding_dropout: float = 0.6, + l2_reg_penalty: float = 1e-4, + time_embeddings_size: int = 16, + include_att_prediction: bool = False, + include_visit_type_prediction: bool = False, + include_readmission: bool = False, + include_prolonged_length_stay: bool = False, + visit_vocab_size: int = None, ): """ - Create a hierarchical bert model + Create a hierarchical bert model. 
:param num_of_visits: :param num_of_concepts: @@ -40,100 +48,103 @@ def transformer_hierarchical_bert_model( """ # If the second tiered learning objectives are enabled, visit_vocab_size needs to be provided if include_visit_type_prediction and not visit_vocab_size: - raise RuntimeError(f'visit_vocab_size can not be null ' - f'when the second learning objectives are enabled') + raise RuntimeError(f"visit_vocab_size can not be null " f"when the second learning objectives are enabled") pat_seq = tf.keras.layers.Input( - shape=(num_of_visits, num_of_concepts,), - dtype='int32', - name='pat_seq' + shape=( + num_of_visits, + num_of_concepts, + ), + dtype="int32", + name="pat_seq", ) pat_seq_age = tf.keras.layers.Input( - shape=(num_of_visits, num_of_concepts,), - dtype='int32', - name='pat_seq_age' + shape=( + num_of_visits, + num_of_concepts, + ), + dtype="int32", + name="pat_seq_age", ) pat_seq_time = tf.keras.layers.Input( - shape=(num_of_visits, num_of_concepts,), - dtype='int32', - name='pat_seq_time' + shape=( + num_of_visits, + num_of_concepts, + ), + dtype="int32", + name="pat_seq_time", ) pat_mask = tf.keras.layers.Input( - shape=(num_of_visits, num_of_concepts,), - dtype='int32', - name='pat_mask' + shape=( + num_of_visits, + num_of_concepts, + ), + dtype="int32", + name="pat_mask", ) concept_values = tf.keras.layers.Input( - shape=(num_of_visits, num_of_concepts,), - dtype='float32', - name='concept_values' + shape=( + num_of_visits, + num_of_concepts, + ), + dtype="float32", + name="concept_values", ) concept_value_masks = tf.keras.layers.Input( - shape=(num_of_visits, num_of_concepts,), - dtype='int32', - name='concept_value_masks' + shape=( + num_of_visits, + num_of_concepts, + ), + dtype="int32", + name="concept_value_masks", ) - visit_mask = tf.keras.layers.Input( - shape=(num_of_visits,), - dtype='int32', - name='visit_mask') + visit_mask = tf.keras.layers.Input(shape=(num_of_visits,), dtype="int32", name="visit_mask") - visit_time_delta_att = tf.keras.layers.Input( - shape=(num_of_visits - 1,), - dtype='int32', - name='visit_time_delta_att' - ) + visit_time_delta_att = tf.keras.layers.Input(shape=(num_of_visits - 1,), dtype="int32", name="visit_time_delta_att") - visit_rank_order = tf.keras.layers.Input( - shape=(num_of_visits,), - dtype='int32', - name='visit_rank_order' - ) + visit_rank_order = tf.keras.layers.Input(shape=(num_of_visits,), dtype="int32", name="visit_rank_order") - visit_visit_type = tf.keras.layers.Input( - shape=(num_of_visits,), - dtype='int32', - name='masked_visit_type' - ) + visit_visit_type = tf.keras.layers.Input(shape=(num_of_visits,), dtype="int32", name="masked_visit_type") # Create a list of inputs so the model could reference these later default_inputs = [ - pat_seq, pat_seq_age, pat_seq_time, pat_mask, - concept_values, concept_value_masks, visit_mask, - visit_time_delta_att, visit_rank_order, visit_visit_type + pat_seq, + pat_seq_age, + pat_seq_time, + pat_mask, + concept_values, + concept_value_masks, + visit_mask, + visit_time_delta_att, + visit_rank_order, + visit_visit_type, ] # Expand dimensions for masking MultiHeadAttention in Concept Encoder - pat_concept_mask = tf.reshape( - pat_mask, - shape=(-1, num_of_concepts) - )[:, tf.newaxis, tf.newaxis, :] + pat_concept_mask = tf.reshape(pat_mask, shape=(-1, num_of_concepts))[:, tf.newaxis, tf.newaxis, :] # output the embedding_matrix: - l2_regularizer = (tf.keras.regularizers.l2(l2_reg_penalty) if l2_reg_penalty else None) + l2_regularizer = tf.keras.regularizers.l2(l2_reg_penalty) if 
l2_reg_penalty else None concept_embedding_layer = ReusableEmbedding( concept_vocab_size, embedding_size, - name='concept_embedding_layer', - embeddings_regularizer=l2_regularizer + name="concept_embedding_layer", + embeddings_regularizer=l2_regularizer, ) visit_type_embedding_layer = ReusableEmbedding( concept_vocab_size, embedding_size, - name='visit_type_embedding_layer', - embeddings_regularizer=l2_regularizer + name="visit_type_embedding_layer", + embeddings_regularizer=l2_regularizer, ) # Look up the embeddings for the concepts - concept_embeddings, embedding_matrix = concept_embedding_layer( - pat_seq - ) + concept_embeddings, embedding_matrix = concept_embedding_layer(pat_seq) concept_value_transformation_layer = ConceptValueTransformationLayer( - embedding_size=embedding_size, - name='concept_value_transformation_layer' + embedding_size=embedding_size, name="concept_value_transformation_layer" ) # Transform the concept embeddings by combining their concept embeddings with the @@ -141,84 +152,53 @@ def transformer_hierarchical_bert_model( concept_embeddings = concept_value_transformation_layer( concept_embeddings=concept_embeddings, concept_values=concept_values, - concept_value_masks=concept_value_masks + concept_value_masks=concept_value_masks, ) # Look up the embeddings for the att tokens - att_embeddings, _ = concept_embedding_layer( - visit_time_delta_att - ) + att_embeddings, _ = concept_embedding_layer(visit_time_delta_att) # Re-purpose token id 0 as the visit start embedding - visit_start_embeddings, _ = concept_embedding_layer( - tf.zeros_like( - visit_mask, - dtype=tf.int32 - ) - ) + visit_start_embeddings, _ = concept_embedding_layer(tf.zeros_like(visit_mask, dtype=tf.int32)) temporal_transformation_layer = TemporalTransformationLayer( time_embeddings_size=time_embeddings_size, embedding_size=embedding_size, - name='temporal_transformation_layer' + name="temporal_transformation_layer", ) # (batch, num_of_visits, num_of_concepts, embedding_size) - concept_embeddings = temporal_transformation_layer( - concept_embeddings, - pat_seq_age, - pat_seq_time, - visit_rank_order - ) + concept_embeddings = temporal_transformation_layer(concept_embeddings, pat_seq_age, pat_seq_time, visit_rank_order) # (batch, num_of_visits, embedding_size) # The first bert applied at the visit level concept_encoder = Encoder( - name='concept_encoder', + name="concept_encoder", num_layers=depth, d_model=embedding_size, num_heads=num_heads, - dropout_rate=transformer_dropout + dropout_rate=transformer_dropout, ) - concept_embeddings = tf.reshape( - concept_embeddings, - shape=(-1, num_of_concepts, embedding_size) - ) + concept_embeddings = tf.reshape(concept_embeddings, shape=(-1, num_of_concepts, embedding_size)) - concept_embeddings, _ = concept_encoder( - concept_embeddings, # be reused - pat_concept_mask # not change - ) + concept_embeddings, _ = concept_encoder(concept_embeddings, pat_concept_mask) # be reused # not change # (batch_size, num_of_visits, num_of_concepts, embedding_size) - concept_embeddings = tf.reshape( - concept_embeddings, - shape=(-1, num_of_visits, num_of_concepts, embedding_size) - ) + concept_embeddings = tf.reshape(concept_embeddings, shape=(-1, num_of_visits, num_of_concepts, embedding_size)) # Step 2 generate visit embeddings # Slice out the first contextualized embedding of each visit # (batch_size, num_of_visits, embedding_size) visit_embeddings = concept_embeddings[:, :, 0] - visit_type_embedding_dense_layer = tf.keras.layers.Dense( - embedding_size, - 
name='visit_type_embedding_dense_layer' - ) + visit_type_embedding_dense_layer = tf.keras.layers.Dense(embedding_size, name="visit_type_embedding_dense_layer") # (batch_size, num_of_visits, embedding_size) - visit_type_embeddings, visit_type_embedding_matrix = visit_type_embedding_layer( - visit_visit_type - ) + visit_type_embeddings, visit_type_embedding_matrix = visit_type_embedding_layer(visit_visit_type) # Combine visit_type_embeddings with visit_embeddings - visit_embeddings = visit_type_embedding_dense_layer( - tf.concat([ - visit_embeddings, - visit_type_embeddings - ], axis=-1) - ) + visit_embeddings = visit_type_embedding_dense_layer(tf.concat([visit_embeddings, visit_type_embeddings], axis=-1)) # (batch_size, num_of_visits, embedding_size) expanded_att_embeddings = tf.concat([att_embeddings, att_embeddings[:, 0:1, :]], axis=1) @@ -226,89 +206,68 @@ def transformer_hierarchical_bert_model( # Insert the att embeddings between visit embeddings # (batch_size, num_of_visits + num_of_visits + num_of_visits - 1, embedding_size) contextualized_visit_embeddings = tf.reshape( - tf.concat( - [visit_start_embeddings, - visit_embeddings, - expanded_att_embeddings], - axis=-1 - ), - (-1, 3 * num_of_visits, embedding_size) + tf.concat([visit_start_embeddings, visit_embeddings, expanded_att_embeddings], axis=-1), + (-1, 3 * num_of_visits, embedding_size), )[:, :-1, :] # Expand dimension for masking MultiHeadAttention in Visit Encoder - visit_mask_with_att = tf.reshape( - tf.tile(visit_mask[:, :, tf.newaxis], [1, 1, 3]), - (-1, num_of_visits * 3) - )[:, tf.newaxis, tf.newaxis, 1:] + visit_mask_with_att = tf.reshape(tf.tile(visit_mask[:, :, tf.newaxis], [1, 1, 3]), (-1, num_of_visits * 3))[ + :, tf.newaxis, tf.newaxis, 1: + ] # (num_of_visits_with_att, num_of_visits_with_att) look_ahead_mask_base = tf.cast( 1 - tf.linalg.band_part(tf.ones((num_of_visits, num_of_visits)), -1, 0), - dtype=tf.int32 + dtype=tf.int32, ) look_ahead_visit_mask_with_att = tf.reshape( - tf.tile( - look_ahead_mask_base[:, tf.newaxis, :, tf.newaxis], - [1, 3, 1, 3] - ), - shape=(num_of_visits * 3, num_of_visits * 3) + tf.tile(look_ahead_mask_base[:, tf.newaxis, :, tf.newaxis], [1, 3, 1, 3]), + shape=(num_of_visits * 3, num_of_visits * 3), )[:-1, :-1] look_ahead_concept_mask = tf.reshape( tf.tile( look_ahead_mask_base[:, tf.newaxis, :, tf.newaxis], - [1, num_of_concepts, 1, 1] + [1, num_of_concepts, 1, 1], ), - (num_of_concepts * num_of_visits, -1) + (num_of_concepts * num_of_visits, -1), ) # (batch_size, 1, num_of_visits_with_att, num_of_visits_with_att) - look_ahead_visit_mask_with_att = tf.maximum( - visit_mask_with_att, - look_ahead_visit_mask_with_att - ) + look_ahead_visit_mask_with_att = tf.maximum(visit_mask_with_att, look_ahead_visit_mask_with_att) # (batch_size, 1, num_of_visits * num_of_concepts, num_of_visits) - look_ahead_concept_mask = tf.maximum( - visit_mask[:, tf.newaxis, tf.newaxis, :], - look_ahead_concept_mask - ) + look_ahead_concept_mask = tf.maximum(visit_mask[:, tf.newaxis, tf.newaxis, :], look_ahead_concept_mask) # Second bert applied at the patient level to the visit embeddings visit_encoder = Encoder( - name='visit_encoder', + name="visit_encoder", num_layers=depth, d_model=embedding_size, num_heads=num_heads, - dropout_rate=transformer_dropout + dropout_rate=transformer_dropout, ) # Feed augmented visit embeddings into encoders to get contextualized visit embeddings - contextualized_visit_embeddings, _ = visit_encoder( - contextualized_visit_embeddings, - look_ahead_visit_mask_with_att - ) + 
contextualized_visit_embeddings, _ = visit_encoder(contextualized_visit_embeddings, look_ahead_visit_mask_with_att) # Pad contextualized_visit_embeddings on axis 1 with one extra visit so we can extract the # visit embeddings using the reshape trick expanded_contextualized_visit_embeddings = tf.concat( - [contextualized_visit_embeddings, - contextualized_visit_embeddings[:, 0:1, :]], - axis=1 + [contextualized_visit_embeddings, contextualized_visit_embeddings[:, 0:1, :]], + axis=1, ) # Extract the visit embeddings elements visit_embeddings_without_att = tf.reshape( - expanded_contextualized_visit_embeddings, (-1, num_of_visits, 3 * embedding_size) - )[:, :, embedding_size: embedding_size * 2] + expanded_contextualized_visit_embeddings, + (-1, num_of_visits, 3 * embedding_size), + )[:, :, embedding_size : embedding_size * 2] # # Step 3 decoder applied to patient level # Reshape the data in visit view back to patient view: # (batch, num_of_visits * num_of_concepts, embedding_size) - concept_embeddings = tf.reshape( - concept_embeddings, - shape=(-1, num_of_visits * num_of_concepts, embedding_size) - ) + concept_embeddings = tf.reshape(concept_embeddings, shape=(-1, num_of_visits * num_of_concepts, embedding_size)) # Let local concept embeddings access the global representatives of each visit global_concept_embeddings_layer = SimpleDecoderLayer( @@ -316,44 +275,35 @@ def transformer_hierarchical_bert_model( num_heads=num_heads, rate=transformer_dropout, dff=512, - name='global_concept_embeddings_layer' + name="global_concept_embeddings_layer", ) global_concept_embeddings, _ = global_concept_embeddings_layer( - concept_embeddings, - visit_embeddings_without_att, - look_ahead_concept_mask + concept_embeddings, visit_embeddings_without_att, look_ahead_concept_mask ) concept_output_layer = TiedOutputEmbedding( projection_regularizer=l2_regularizer, projection_dropout=embedding_dropout, - name='concept_prediction_logits' + name="concept_prediction_logits", ) - concept_softmax_layer = tf.keras.layers.Softmax( - name='concept_predictions' - ) + concept_softmax_layer = tf.keras.layers.Softmax(name="concept_predictions") - concept_predictions = concept_softmax_layer( - concept_output_layer([global_concept_embeddings, embedding_matrix]) - ) + concept_predictions = concept_softmax_layer(concept_output_layer([global_concept_embeddings, embedding_matrix])) outputs = [concept_predictions] if include_att_prediction: # Extract the ATT embeddings contextualized_att_embeddings = tf.reshape( - expanded_contextualized_visit_embeddings, (-1, num_of_visits, 3 * embedding_size) - )[:, :-1, embedding_size * 2:] + expanded_contextualized_visit_embeddings, + (-1, num_of_visits, 3 * embedding_size), + )[:, :-1, embedding_size * 2 :] # Create the att to concept mask ATT tokens only attend to the concepts in the # neighboring visits - att_concept_mask = create_att_concept_mask( - num_of_concepts, - num_of_visits, - visit_mask - ) + att_concept_mask = create_att_concept_mask(num_of_concepts, num_of_visits, visit_mask) # Use the simple decoder layer to decode att embeddings using the neighboring concept # embeddings @@ -362,22 +312,18 @@ def transformer_hierarchical_bert_model( num_heads=num_heads, rate=transformer_dropout, dff=512, - name='global_att_embeddings_layer' + name="global_att_embeddings_layer", ) contextualized_att_embeddings, _ = global_att_embeddings_layer( - contextualized_att_embeddings, - concept_embeddings, - att_concept_mask + contextualized_att_embeddings, concept_embeddings, att_concept_mask ) 
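The pad-then-reshape trick used above, once to recover the visit embeddings and once to recover the ATT embeddings, is easiest to see with small numbers; the snippet below reproduces it for a single patient with three visits and an embedding size of two, all chosen purely for illustration.

```python
import tensorflow as tf

num_of_visits, embedding_size = 3, 2
# Interleaved sequence of length 3 * num_of_visits - 1 = 8:
# VS1 V1 ATT1 VS2 V2 ATT2 VS3 V3 (one 2-dim embedding per position, batch of 1).
seq = tf.constant([[[10., 10.], [1., 1.], [91., 91.],
                    [20., 20.], [2., 2.], [92., 92.],
                    [30., 30.], [3., 3.]]])

# Append the first position so the sequence length becomes a multiple of 3 ...
expanded = tf.concat([seq, seq[:, 0:1, :]], axis=1)                   # (1, 9, 2)
# ... reshape so every row packs one [VS, V, ATT] triple side by side ...
rows = tf.reshape(expanded, (-1, num_of_visits, 3 * embedding_size))  # (1, 3, 6)
# ... and slice the middle block to recover the visit embeddings.
visit_embeddings = rows[:, :, embedding_size: embedding_size * 2]

print(visit_embeddings.numpy())  # [[[1. 1.] [2. 2.] [3. 3.]]]
```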
att_prediction_layer = tf.keras.layers.Softmax( - name='att_predictions', + name="att_predictions", ) - att_predictions = att_prediction_layer( - concept_output_layer([contextualized_att_embeddings, embedding_matrix]) - ) + att_predictions = att_prediction_layer(concept_output_layer([contextualized_att_embeddings, embedding_matrix])) outputs.append(att_predictions) if include_visit_type_prediction: @@ -386,85 +332,49 @@ def transformer_hierarchical_bert_model( visit_type_prediction_output_layer = TiedOutputEmbedding( projection_regularizer=l2_regularizer, projection_dropout=embedding_dropout, - name='visit_type_prediction_logits' + name="visit_type_prediction_logits", ) - visit_softmax_layer = tf.keras.layers.Softmax( - name='visit_predictions' - ) + visit_softmax_layer = tf.keras.layers.Softmax(name="visit_predictions") visit_predictions = visit_softmax_layer( - visit_type_prediction_output_layer( - [visit_embeddings_without_att, visit_type_embedding_matrix] - ) + visit_type_prediction_output_layer([visit_embeddings_without_att, visit_type_embedding_matrix]) ) outputs.append(visit_predictions) if include_readmission: - is_readmission_layer = tf.keras.layers.Dense( - 1, - activation='sigmoid', - name='is_readmission' - ) + is_readmission_layer = tf.keras.layers.Dense(1, activation="sigmoid", name="is_readmission") - is_readmission_output = is_readmission_layer( - visit_embeddings_without_att - ) + is_readmission_output = is_readmission_layer(visit_embeddings_without_att) outputs.append(is_readmission_output) if include_prolonged_length_stay: - visit_prolonged_stay_layer = tf.keras.layers.Dense( - 1, - activation='sigmoid', - name='visit_prolonged_stay' - ) + visit_prolonged_stay_layer = tf.keras.layers.Dense(1, activation="sigmoid", name="visit_prolonged_stay") - visit_prolonged_stay_output = visit_prolonged_stay_layer( - visit_embeddings_without_att - ) + visit_prolonged_stay_output = visit_prolonged_stay_layer(visit_embeddings_without_att) outputs.append(visit_prolonged_stay_output) - hierarchical_bert = tf.keras.Model( - inputs=default_inputs, - outputs=outputs - ) + hierarchical_bert = tf.keras.Model(inputs=default_inputs, outputs=outputs) return hierarchical_bert -def create_att_concept_mask( - num_of_concepts, - num_of_visits, - visit_mask -): +def create_att_concept_mask(num_of_concepts, num_of_visits, visit_mask): """ - :param num_of_concepts: + :param num_of_visits: :param visit_mask: :return: """ - att_concept_mask = tf.eye( - num_of_visits - 1, - num_of_visits, - dtype=tf.int32 - ) - att_concept_mask = 1 - att_concept_mask - tf.roll( - att_concept_mask, - axis=-1, - shift=1 - )[tf.newaxis, :, :] - att_concept_mask = tf.maximum( - att_concept_mask, - visit_mask[:, 1:, tf.newaxis] - ) + att_concept_mask = tf.eye(num_of_visits - 1, num_of_visits, dtype=tf.int32) + att_concept_mask = 1 - att_concept_mask - tf.roll(att_concept_mask, axis=-1, shift=1)[tf.newaxis, :, :] + att_concept_mask = tf.maximum(att_concept_mask, visit_mask[:, 1:, tf.newaxis]) att_concept_mask = tf.reshape( - tf.tile( - att_concept_mask[:, :, :, tf.newaxis], - [1, 1, 1, num_of_concepts] - ), (-1, 1, num_of_visits - 1, num_of_concepts * num_of_visits) + tf.tile(att_concept_mask[:, :, :, tf.newaxis], [1, 1, 1, num_of_concepts]), + (-1, 1, num_of_visits - 1, num_of_concepts * num_of_visits), ) return att_concept_mask diff --git a/src/cehrbert/models/hierachical_phenotype_model_new.py b/src/cehrbert/models/hierachical_phenotype_model_new.py index a9191d61..df73b197 100644 --- 
a/src/cehrbert/models/hierachical_phenotype_model_new.py +++ b/src/cehrbert/models/hierachical_phenotype_model_new.py @@ -1,89 +1,81 @@ -from .layers.custom_layers import * from .hierachical_bert_model_v2 import create_att_concept_mask - - -def create_visit_masks( - visit_mask, - num_of_visits -): +from .layers.custom_layers import ( + ConceptValueTransformationLayer, + Encoder, + ReusableEmbedding, + SimpleDecoderLayer, + TemporalTransformationLayer, + TiedOutputEmbedding, + VisitPhenotypeLayer, + tf, +) + + +def create_visit_masks(visit_mask, num_of_visits): # Expand dimension for masking MultiHeadAttention in Visit Encoder - visit_mask_with_att = tf.reshape( - tf.tile(visit_mask[:, :, tf.newaxis], [1, 1, 3]), - (-1, num_of_visits * 3) - )[:, tf.newaxis, tf.newaxis, 1:] + visit_mask_with_att = tf.reshape(tf.tile(visit_mask[:, :, tf.newaxis], [1, 1, 3]), (-1, num_of_visits * 3))[ + :, tf.newaxis, tf.newaxis, 1: + ] # (num_of_visits_with_att, num_of_visits_with_att) look_ahead_mask_base = tf.cast( 1 - tf.linalg.band_part(tf.ones((num_of_visits, num_of_visits)), -1, 0), - dtype=tf.int32 + dtype=tf.int32, ) look_ahead_visit_mask_with_att = tf.reshape( - tf.tile( - look_ahead_mask_base[:, tf.newaxis, :, tf.newaxis], - [1, 3, 1, 3] - ), - shape=(num_of_visits * 3, num_of_visits * 3) + tf.tile(look_ahead_mask_base[:, tf.newaxis, :, tf.newaxis], [1, 3, 1, 3]), + shape=(num_of_visits * 3, num_of_visits * 3), )[:-1, :-1] # (batch_size, 1, num_of_visits_with_att, num_of_visits_with_att) - look_ahead_visit_mask_with_att = tf.maximum( - visit_mask_with_att, - look_ahead_visit_mask_with_att - ) + look_ahead_visit_mask_with_att = tf.maximum(visit_mask_with_att, look_ahead_visit_mask_with_att) return look_ahead_visit_mask_with_att -def create_concept_masks( - visit_mask, - num_of_visits, - num_of_concepts -): +def create_concept_masks(visit_mask, num_of_visits, num_of_concepts): # Expand dimension for masking MultiHeadAttention in Visit Encoder # (num_of_visits_with_att, num_of_visits_with_att) look_ahead_mask_base = tf.cast( 1 - tf.linalg.band_part(tf.ones((num_of_visits, num_of_visits)), -1, 0), - dtype=tf.int32 + dtype=tf.int32, ) look_ahead_concept_mask = tf.reshape( tf.tile( look_ahead_mask_base[:, tf.newaxis, :, tf.newaxis], - [1, num_of_concepts, 1, 1] + [1, num_of_concepts, 1, 1], ), - (num_of_concepts * num_of_visits, -1) + (num_of_concepts * num_of_visits, -1), ) # (batch_size, 1, num_of_visits * num_of_concepts, num_of_visits) - look_ahead_concept_mask = tf.maximum( - visit_mask[:, tf.newaxis, tf.newaxis, :], - look_ahead_concept_mask - ) + look_ahead_concept_mask = tf.maximum(visit_mask[:, tf.newaxis, tf.newaxis, :], look_ahead_concept_mask) return look_ahead_concept_mask class HierarchicalBertModel(tf.keras.Model): def __init__( - self, - num_of_visits: int, - num_of_concepts: int, - concept_vocab_size: int, - visit_vocab_size: int, - embedding_size: int, - depth: int, - num_heads: int, - transformer_dropout: float = 0.1, - embedding_dropout: float = 0.6, - l2_reg_penalty: float = 1e-4, - time_embeddings_size: int = 16, - num_of_phenotypes: int = 20, - num_of_phenotype_neighbors: int = 3, - num_of_concept_neighbors: int = 10, - phenotype_entropy_weight: float = 2e-05, - phenotype_euclidean_weight: float = 2e-05, - phenotype_concept_distance_weight: float = 1e-04, - include_att_prediction: bool = True, - **kwargs + self, + num_of_visits: int, + num_of_concepts: int, + concept_vocab_size: int, + visit_vocab_size: int, + embedding_size: int, + depth: int, + num_heads: int, + 
transformer_dropout: float = 0.1, + embedding_dropout: float = 0.6, + l2_reg_penalty: float = 1e-4, + time_embeddings_size: int = 16, + num_of_phenotypes: int = 20, + num_of_phenotype_neighbors: int = 3, + num_of_concept_neighbors: int = 10, + phenotype_entropy_weight: float = 2e-05, + phenotype_euclidean_weight: float = 2e-05, + phenotype_concept_distance_weight: float = 1e-04, + include_att_prediction: bool = True, + **kwargs, ): super(HierarchicalBertModel, self).__init__(**kwargs) @@ -107,59 +99,55 @@ def __init__( self.include_att_prediction = include_att_prediction # output the embedding_matrix: - self.l2_regularizer = ( - tf.keras.regularizers.l2(l2_reg_penalty) if l2_reg_penalty else None - ) + self.l2_regularizer = tf.keras.regularizers.l2(l2_reg_penalty) if l2_reg_penalty else None self.concept_embedding_layer = ReusableEmbedding( concept_vocab_size, embedding_size, - name='concept_embedding_layer', - embeddings_regularizer=self.l2_regularizer + name="concept_embedding_layer", + embeddings_regularizer=self.l2_regularizer, ) self.visit_type_embedding_layer = ReusableEmbedding( visit_vocab_size, embedding_size, - name='visit_type_embedding_layer', - embeddings_regularizer=self.l2_regularizer + name="visit_type_embedding_layer", + embeddings_regularizer=self.l2_regularizer, ) # a layer for combining concept values with concept embeddings self.concept_value_transformation_layer = ConceptValueTransformationLayer( - embedding_size=embedding_size, - name='concept_value_transformation_layer' + embedding_size=embedding_size, name="concept_value_transformation_layer" ) # a layer for combining time/age embeddings with the concept embeddings self.temporal_transformation_layer = TemporalTransformationLayer( time_embeddings_size=time_embeddings_size, embedding_size=embedding_size, - name='temporal_transformation_layer' + name="temporal_transformation_layer", ) # (batch, num_of_visits, embedding_size) # The first bert applied at the visit level self.concept_encoder = Encoder( - name='concept_encoder', + name="concept_encoder", num_layers=depth, d_model=embedding_size, num_heads=num_heads, - dropout_rate=transformer_dropout + dropout_rate=transformer_dropout, ) # Second bert applied at the patient level to the visit embeddings self.visit_encoder = Encoder( - name='visit_encoder', + name="visit_encoder", num_layers=depth, d_model=embedding_size, num_heads=num_heads, - dropout_rate=transformer_dropout + dropout_rate=transformer_dropout, ) self.visit_type_embedding_dense_layer = tf.keras.layers.Dense( - embedding_size, - name='visit_type_embedding_dense_layer' + embedding_size, name="visit_type_embedding_dense_layer" ) # A hidden phenotype layer that each visit embedding needs to go through @@ -172,7 +160,7 @@ def __init__( phenotype_entropy_weight=phenotype_entropy_weight, phenotype_euclidean_weight=phenotype_euclidean_weight, phenotype_concept_distance_weight=phenotype_concept_distance_weight, - name='hidden_visit_embeddings' + name="hidden_visit_embeddings", ) # Let local concept embeddings access the global representatives of each visit @@ -181,18 +169,16 @@ def __init__( num_heads=num_heads, rate=transformer_dropout, dff=512, - name='global_concept_embeddings_layer' + name="global_concept_embeddings_layer", ) self.concept_output_layer = TiedOutputEmbedding( projection_regularizer=self.l2_regularizer, projection_dropout=embedding_dropout, - name='concept_prediction_logits' + name="concept_prediction_logits", ) - self.concept_softmax_layer = tf.keras.layers.Softmax( - name='concept_predictions' 
- ) + self.concept_softmax_layer = tf.keras.layers.Softmax(name="concept_predictions") if include_att_prediction: self.global_att_embeddings_layer = SimpleDecoderLayer( @@ -200,59 +186,43 @@ def __init__( num_heads=num_heads, rate=transformer_dropout, dff=512, - name='global_att_embeddings_layer' + name="global_att_embeddings_layer", ) self.att_prediction_layer = tf.keras.layers.Softmax( - name='att_predictions', + name="att_predictions", ) - def call( - self, - inputs, - **kwargs - ): + def call(self, inputs, **kwargs): - pat_seq = inputs['pat_seq'] - pat_seq_age = inputs['pat_seq_age'] - pat_seq_time = inputs['pat_seq_time'] - pat_mask = inputs['pat_mask'] - concept_values = inputs['concept_values'] - concept_value_masks = inputs['concept_value_masks'] - visit_mask = inputs['visit_mask'] - visit_time_delta_att = inputs['visit_time_delta_att'] - visit_rank_order = inputs['visit_rank_order'] - visit_visit_type = inputs['masked_visit_type'] + pat_seq = inputs["pat_seq"] + pat_seq_age = inputs["pat_seq_age"] + pat_seq_time = inputs["pat_seq_time"] + pat_mask = inputs["pat_mask"] + concept_values = inputs["concept_values"] + concept_value_masks = inputs["concept_value_masks"] + visit_mask = inputs["visit_mask"] + visit_time_delta_att = inputs["visit_time_delta_att"] + visit_rank_order = inputs["visit_rank_order"] + visit_visit_type = inputs["masked_visit_type"] # Expand dimensions for masking MultiHeadAttention in Concept Encoder - pat_concept_mask = tf.reshape( - pat_mask, - shape=(-1, self.num_of_concepts) - )[:, tf.newaxis, tf.newaxis, :] + pat_concept_mask = tf.reshape(pat_mask, shape=(-1, self.num_of_concepts))[:, tf.newaxis, tf.newaxis, :] # Look up the embeddings for the concepts - concept_embeddings, embedding_matrix = self.concept_embedding_layer( - pat_seq - ) + concept_embeddings, embedding_matrix = self.concept_embedding_layer(pat_seq) # Look up the embeddings for the att tokens - att_embeddings, _ = self.concept_embedding_layer( - visit_time_delta_att - ) + att_embeddings, _ = self.concept_embedding_layer(visit_time_delta_att) # Re-purpose token id 0 as the visit start embedding - visit_start_embeddings, _ = self.concept_embedding_layer( - tf.zeros_like( - visit_mask, - dtype=tf.int32 - ) - ) + visit_start_embeddings, _ = self.concept_embedding_layer(tf.zeros_like(visit_mask, dtype=tf.int32)) # Transform the concept embeddings by combining their concept embeddings with the # corresponding val concept_embeddings = self.concept_value_transformation_layer( concept_embeddings=concept_embeddings, concept_values=concept_values, - concept_value_masks=concept_value_masks + concept_value_masks=concept_value_masks, ) # (batch, num_of_visits, num_of_concepts, embedding_size) @@ -261,26 +231,21 @@ def call( concept_embeddings, pat_seq_age, pat_seq_time, - visit_rank_order + visit_rank_order, ) # (batch * num_of_visits, num_of_concepts, embedding_size) - concept_embeddings = tf.reshape( - concept_embeddings, - shape=(-1, self.num_of_concepts, self.embedding_size) - ) + concept_embeddings = tf.reshape(concept_embeddings, shape=(-1, self.num_of_concepts, self.embedding_size)) # Step 1: apply the first bert to the concept embeddings for each visit concept_embeddings, _ = self.concept_encoder( - concept_embeddings, # be reused - pat_concept_mask, # not change, - **kwargs + concept_embeddings, pat_concept_mask, **kwargs # be reused # not change, ) # (batch, num_of_visits, num_of_concepts, embedding_size) concept_embeddings = tf.reshape( concept_embeddings, - shape=(-1, self.num_of_visits, 
self.num_of_concepts, self.embedding_size) + shape=(-1, self.num_of_visits, self.num_of_concepts, self.embedding_size), ) # Step 2: generate visit embeddings @@ -289,87 +254,61 @@ def call( visit_embeddings = concept_embeddings[:, :, 0] # (batch_size, num_of_visits, embedding_size) - visit_type_embeddings, visit_type_embedding_matrix = self.visit_type_embedding_layer( - visit_visit_type, - **kwargs - ) + visit_type_embeddings, visit_type_embedding_matrix = self.visit_type_embedding_layer(visit_visit_type, **kwargs) # Combine visit_type_embeddings with visit_embeddings visit_embeddings = self.visit_type_embedding_dense_layer( - tf.concat([ - visit_embeddings, - visit_type_embeddings - ], axis=-1), - **kwargs + tf.concat([visit_embeddings, visit_type_embeddings], axis=-1), **kwargs ) # (batch_size, num_of_visits, embedding_size) - expanded_att_embeddings = tf.concat( - [att_embeddings, att_embeddings[:, 0:1, :]], - axis=1 - ) + expanded_att_embeddings = tf.concat([att_embeddings, att_embeddings[:, 0:1, :]], axis=1) # Insert the att embeddings between visit embeddings # (batch_size, num_of_visits + num_of_visits + num_of_visits - 1, embedding_size) visit_embeddings = tf.reshape( tf.concat( - [visit_start_embeddings, - visit_embeddings, - expanded_att_embeddings], - axis=-1 + [visit_start_embeddings, visit_embeddings, expanded_att_embeddings], + axis=-1, ), - (-1, 3 * self.num_of_visits, self.embedding_size) + (-1, 3 * self.num_of_visits, self.embedding_size), )[:, :-1, :] - look_ahead_visit_mask_with_att = create_visit_masks( - visit_mask, self.num_of_visits - ) + look_ahead_visit_mask_with_att = create_visit_masks(visit_mask, self.num_of_visits) # Feed augmented visit embeddings into encoders to get contextualized visit embeddings - visit_embeddings, _ = self.visit_encoder( - visit_embeddings, - look_ahead_visit_mask_with_att, - **kwargs - ) + visit_embeddings, _ = self.visit_encoder(visit_embeddings, look_ahead_visit_mask_with_att, **kwargs) # Pad contextualized_visit_embeddings on axis 1 with one extra visit, we can extract the # visit embeddings using reshape trick - padded_visit_embeddings = tf.concat( - [visit_embeddings, - visit_embeddings[:, 0:1, :]], - axis=1 - ) + padded_visit_embeddings = tf.concat([visit_embeddings, visit_embeddings[:, 0:1, :]], axis=1) # Extract the visit embeddings elements visit_embeddings_without_att = tf.reshape( padded_visit_embeddings, (-1, self.num_of_visits, 3 * self.embedding_size) - )[:, :, self.embedding_size: self.embedding_size * 2] + )[:, :, self.embedding_size : self.embedding_size * 2] # (batch_size, num_of_visits, vocab_size) - visit_embeddings_without_att, _, = self.visit_phenotype_layer( - [visit_embeddings_without_att, - visit_mask, - embedding_matrix], - **kwargs - ) + ( + visit_embeddings_without_att, + _, + ) = self.visit_phenotype_layer([visit_embeddings_without_att, visit_mask, embedding_matrix], **kwargs) # # Step 3 decoder applied to patient level # Reshape the data in visit view back to patient view: # (batch, num_of_visits * num_of_concepts, embedding_size) concept_embeddings = tf.reshape( concept_embeddings, - shape=(-1, self.num_of_visits * self.num_of_concepts, self.embedding_size) + shape=(-1, self.num_of_visits * self.num_of_concepts, self.embedding_size), ) - look_ahead_concept_mask = create_concept_masks( - visit_mask, self.num_of_visits, self.num_of_concepts - ) + look_ahead_concept_mask = create_concept_masks(visit_mask, self.num_of_visits, self.num_of_concepts) global_concept_embeddings, _ = 
self.global_concept_embeddings_layer( concept_embeddings, visit_embeddings_without_att, look_ahead_concept_mask, - **kwargs + **kwargs, ) concept_predictions = self.concept_softmax_layer( @@ -377,7 +316,7 @@ def call( ) outputs = { - 'concept_predictions': concept_predictions, + "concept_predictions": concept_predictions, # 'padded_visit_embeddings': padded_visit_embeddings, # 'visit_embeddings_without_att': visit_embeddings_without_att } @@ -385,16 +324,13 @@ def call( if self.include_att_prediction: # Extract the ATT embeddings contextualized_att_embeddings = tf.reshape( - padded_visit_embeddings, (-1, self.num_of_visits, 3 * self.embedding_size) - )[:, :-1, self.embedding_size * 2:] + padded_visit_embeddings, + (-1, self.num_of_visits, 3 * self.embedding_size), + )[:, :-1, self.embedding_size * 2 :] # Create the att to concept mask ATT tokens only attend to the concepts in the # neighboring visits - att_concept_mask = create_att_concept_mask( - self.num_of_concepts, - self.num_of_visits, - visit_mask - ) + att_concept_mask = create_att_concept_mask(self.num_of_concepts, self.num_of_visits, visit_mask) # Use the simple decoder layer to decode att embeddings using the neighboring concept # embeddings @@ -402,45 +338,45 @@ def call( contextualized_att_embeddings, concept_embeddings, att_concept_mask, - **kwargs + **kwargs, ) att_predictions = self.att_prediction_layer( self.concept_output_layer([contextualized_att_embeddings, embedding_matrix]) ) - outputs['att_predictions'] = att_predictions + outputs["att_predictions"] = att_predictions return outputs def get_config(self): config = super().get_config() - config['concept_vocab_size'] = self.concept_vocab_size - config['visit_vocab_size'] = self.visit_vocab_size - config['embedding_size'] = self.embedding_size - config['time_embeddings_size'] = self.time_embeddings_size - config['depth'] = self.depth - config['num_heads'] = self.num_heads - config['transformer_dropout'] = self.transformer_dropout - config['embedding_dropout'] = self.embedding_dropout - config['l2_reg_penalty'] = self.l2_reg_penalty - config['num_of_phenotypes'] = self.num_of_phenotypes - config['num_of_phenotype_neighbors'] = self.num_of_phenotype_neighbors - config['num_of_concept_neighbors'] = self.num_of_concept_neighbors - config['phenotype_entropy_weight'] = self.phenotype_entropy_weight - config['phenotype_euclidean_weight'] = self.phenotype_euclidean_weight - config['phenotype_concept_distance_weight'] = self.phenotype_concept_distance_weight - config['include_att_prediction'] = self.include_att_prediction + config["concept_vocab_size"] = self.concept_vocab_size + config["visit_vocab_size"] = self.visit_vocab_size + config["embedding_size"] = self.embedding_size + config["time_embeddings_size"] = self.time_embeddings_size + config["depth"] = self.depth + config["num_heads"] = self.num_heads + config["transformer_dropout"] = self.transformer_dropout + config["embedding_dropout"] = self.embedding_dropout + config["l2_reg_penalty"] = self.l2_reg_penalty + config["num_of_phenotypes"] = self.num_of_phenotypes + config["num_of_phenotype_neighbors"] = self.num_of_phenotype_neighbors + config["num_of_concept_neighbors"] = self.num_of_concept_neighbors + config["phenotype_entropy_weight"] = self.phenotype_entropy_weight + config["phenotype_euclidean_weight"] = self.phenotype_euclidean_weight + config["phenotype_concept_distance_weight"] = self.phenotype_concept_distance_weight + config["include_att_prediction"] = self.include_att_prediction return config class 
MultiTaskHierarchicalBertModel(HierarchicalBertModel): def __init__( - self, - include_visit_prediction: bool, - include_readmission: bool, - include_prolonged_length_stay: bool, - *args, - **kwargs + self, + include_visit_prediction: bool, + include_readmission: bool, + include_prolonged_length_stay: bool, + *args, + **kwargs, ): super(MultiTaskHierarchicalBertModel, self).__init__(*args, **kwargs) @@ -452,64 +388,46 @@ def __init__( self.visit_prediction_output_layer = TiedOutputEmbedding( projection_regularizer=self.l2_regularizer, projection_dropout=self.embedding_dropout, - name='visit_type_prediction_logits' + name="visit_type_prediction_logits", ) - self.visit_softmax_layer = tf.keras.layers.Softmax( - name='visit_predictions' - ) + self.visit_softmax_layer = tf.keras.layers.Softmax(name="visit_predictions") if include_readmission: - self.is_readmission_layer = tf.keras.layers.Dense( - 1, - activation='sigmoid', - name='is_readmission' - ) + self.is_readmission_layer = tf.keras.layers.Dense(1, activation="sigmoid", name="is_readmission") if include_prolonged_length_stay: self.visit_prolonged_stay_layer = tf.keras.layers.Dense( - 1, - activation='sigmoid', - name='visit_prolonged_stay' + 1, activation="sigmoid", name="visit_prolonged_stay" ) def get_config(self): config = super().get_config() - config['include_visit_prediction'] = self.include_visit_prediction - config['include_readmission'] = self.include_readmission - config['include_prolonged_length_stay'] = self.include_prolonged_length_stay + config["include_visit_prediction"] = self.include_visit_prediction + config["include_readmission"] = self.include_readmission + config["include_prolonged_length_stay"] = self.include_prolonged_length_stay return config - def call( - self, - inputs, - **kwargs - ): + def call(self, inputs, **kwargs): # Get the outputs from the super class outputs = super(MultiTaskHierarchicalBertModel, self).call(inputs, **kwargs) - visit_embeddings_without_att = outputs['visit_embeddings_without_att'] + visit_embeddings_without_att = outputs["visit_embeddings_without_att"] if self.include_visit_prediction: # Slice out the visit embeddings (CLS tokens) visit_type_embedding_matrix = self.visit_type_embedding_layer.embeddings visit_predictions = self.visit_softmax_layer( - self.visit_prediction_output_layer( - [visit_embeddings_without_att, visit_type_embedding_matrix] - ) + self.visit_prediction_output_layer([visit_embeddings_without_att, visit_type_embedding_matrix]) ) - outputs['visit_predictions'] = visit_predictions + outputs["visit_predictions"] = visit_predictions if self.include_readmission: - is_readmission_output = self.is_readmission_layer( - visit_embeddings_without_att - ) - outputs['is_readmission'] = is_readmission_output + is_readmission_output = self.is_readmission_layer(visit_embeddings_without_att) + outputs["is_readmission"] = is_readmission_output if self.include_prolonged_length_stay: - visit_prolonged_stay_output = self.visit_prolonged_stay_layer( - visit_embeddings_without_att - ) - outputs['visit_prolonged_stay'] = visit_prolonged_stay_output + visit_prolonged_stay_output = self.visit_prolonged_stay_layer(visit_embeddings_without_att) + outputs["visit_prolonged_stay"] = visit_prolonged_stay_output return outputs diff --git a/src/cehrbert/models/layers/custom_layers.py b/src/cehrbert/models/layers/custom_layers.py index b3ded591..e916d4fd 100644 --- a/src/cehrbert/models/layers/custom_layers.py +++ b/src/cehrbert/models/layers/custom_layers.py @@ -1,4 +1,3 @@ -import platform import 
numpy as np import tensorflow as tf from tensorflow.keras.utils import get_custom_objects @@ -14,9 +13,7 @@ def get_angles(pos, i, d_model): def positional_encoding(position, d_model): - angle_rads = get_angles(np.arange(position)[:, np.newaxis], - np.arange(d_model)[np.newaxis, :], - d_model) + angle_rads = get_angles(np.arange(position)[:, np.newaxis], np.arange(d_model)[np.newaxis, :], d_model) # apply sin to even indices in the array; 2i angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2]) @@ -30,10 +27,12 @@ def positional_encoding(position, d_model): def point_wise_feed_forward_network(d_model, dff): - return tf.keras.Sequential([ - tf.keras.layers.Dense(dff, activation='relu'), # (batch_size, seq_len, dff) - tf.keras.layers.Dense(d_model) # (batch_size, seq_len, d_model) - ]) + return tf.keras.Sequential( + [ + tf.keras.layers.Dense(dff, activation="relu"), # (batch_size, seq_len, dff) + tf.keras.layers.Dense(d_model), # (batch_size, seq_len, d_model) + ] + ) class EncoderLayer(tf.keras.layers.Layer): @@ -51,7 +50,7 @@ def __init__(self, d_model, num_heads, dff, rate=0.1, *args, **kwargs): num_heads=num_heads, key_dim=d_model // num_heads, output_shape=d_model, - attention_axes=1 + attention_axes=1, ) self.ffn = point_wise_feed_forward_network(d_model, dff) @@ -63,10 +62,10 @@ def __init__(self, d_model, num_heads, dff, rate=0.1, *args, **kwargs): def get_config(self): config = super().get_config() - config['d_model'] = self.d_model - config['num_heads'] = self.num_heads - config['dff'] = self.dff - config['rate'] = self.rate + config["d_model"] = self.d_model + config["num_heads"] = self.num_heads + config["dff"] = self.dff + config["rate"] = self.rate return config def call(self, x, mask, **kwargs): @@ -82,7 +81,7 @@ def call(self, x, mask, **kwargs): value=x, attention_mask=mask, return_attention_scores=True, - **kwargs + **kwargs, ) attn_output = self.dropout1(attn_output, **kwargs) out1 = self.layernorm1(x + attn_output) # (batch_size, input_seq_len, d_model) @@ -94,7 +93,16 @@ def call(self, x, mask, **kwargs): class Encoder(tf.keras.layers.Layer): - def __init__(self, num_layers, d_model, num_heads, dff=2148, dropout_rate=0.1, *args, **kwargs): + def __init__( + self, + num_layers, + d_model, + num_heads, + dff=2148, + dropout_rate=0.1, + *args, + **kwargs, + ): super(Encoder, self).__init__(*args, **kwargs) self.d_model = d_model @@ -103,17 +111,17 @@ def __init__(self, num_layers, d_model, num_heads, dff=2148, dropout_rate=0.1, * self.dff = dff self.dropout_rate = dropout_rate self.enc_layers = [ - EncoderLayer(d_model, num_heads, dff, dropout_rate, name='transformer' + str(i)) - for i in range(num_layers)] + EncoderLayer(d_model, num_heads, dff, dropout_rate, name="transformer" + str(i)) for i in range(num_layers) + ] self.dropout = tf.keras.layers.Dropout(dropout_rate) def get_config(self): config = super().get_config() - config['num_layers'] = self.num_layers - config['d_model'] = self.d_model - config['num_heads'] = self.num_heads - config['dff'] = self.dff - config['dropout_rate'] = self.dropout_rate + config["num_layers"] = self.num_layers + config["d_model"] = self.d_model + config["num_heads"] = self.num_heads + config["dff"] = self.dff + config["dropout_rate"] = self.dropout_rate return config def call(self, x, mask, **kwargs): @@ -126,14 +134,14 @@ def call(self, x, mask, **kwargs): class GptDecoder(tf.keras.layers.Layer): def __init__( - self, - num_layers, - d_model, - num_heads, - dff=2148, - dropout_rate=0.1, - *args, - **kwargs + self, + num_layers, + d_model, + 
num_heads, + dff=2148, + dropout_rate=0.1, + *args, + **kwargs, ): super(GptDecoder, self).__init__(*args, **kwargs) @@ -143,18 +151,18 @@ def __init__( self.dff = dff self.dropout_rate = dropout_rate self.decoder_layers = [ - GptDecoderLayer(d_model, num_heads, dff, dropout_rate, name='transformer' + str(i)) + GptDecoderLayer(d_model, num_heads, dff, dropout_rate, name="transformer" + str(i)) for i in range(num_layers) ] self.dropout = tf.keras.layers.Dropout(dropout_rate) def get_config(self): config = super().get_config() - config['num_layers'] = self.num_layers - config['d_model'] = self.d_model - config['num_heads'] = self.num_heads - config['dff'] = self.dff - config['dropout_rate'] = self.dropout_rate + config["num_layers"] = self.num_layers + config["d_model"] = self.d_model + config["num_heads"] = self.num_heads + config["dff"] = self.dff + config["dropout_rate"] = self.dropout_rate return config def call(self, x, **kwargs): @@ -182,13 +190,13 @@ def __init__(self, d_model, num_heads, dff, rate=0.1, *args, **kwargs): num_heads=num_heads, key_dim=d_model // num_heads, output_shape=d_model, - attention_axes=1 + attention_axes=1, ) self.mha2 = tf.keras.layers.MultiHeadAttention( num_heads=num_heads, key_dim=d_model // num_heads, output_shape=d_model, - attention_axes=1 + attention_axes=1, ) self.ffn = point_wise_feed_forward_network(d_model, dff) @@ -203,10 +211,10 @@ def __init__(self, d_model, num_heads, dff, rate=0.1, *args, **kwargs): def get_config(self): config = super().get_config() - config['d_model'] = self.d_model - config['num_heads'] = self.num_heads - config['dff'] = self.dff - config['rate'] = self.rate + config["d_model"] = self.d_model + config["num_heads"] = self.num_heads + config["dff"] = self.dff + config["rate"] = self.rate return config def call(self, x, enc_output, decoder_mask, encoder_mask, **kwargs): @@ -224,7 +232,7 @@ def call(self, x, enc_output, decoder_mask, encoder_mask, **kwargs): value=x, attention_mask=decoder_mask, return_attention_scores=True, - **kwargs + **kwargs, ) # (batch_size, target_seq_len, d_model) attn1 = self.dropout1(attn1, **kwargs) @@ -236,7 +244,7 @@ def call(self, x, enc_output, decoder_mask, encoder_mask, **kwargs): query=out1, attention_mask=encoder_mask, return_attention_scores=True, - **kwargs + **kwargs, ) # (batch_size, target_seq_len, d_model) attn2 = self.dropout2(attn2, **kwargs) @@ -259,10 +267,7 @@ def __init__(self, d_model, num_heads, dff, rate=0.1, *args, **kwargs): self.dff = dff self.rate = rate - self.mha = tf.keras.layers.MultiHeadAttention( - num_heads=num_heads, - key_dim=d_model // num_heads - ) + self.mha = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model // num_heads) self.ffn = point_wise_feed_forward_network(d_model, dff) self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6) self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6) @@ -271,16 +276,16 @@ def __init__(self, d_model, num_heads, dff, rate=0.1, *args, **kwargs): def get_config(self): config = super().get_config() - config['d_model'] = self.d_model - config['num_heads'] = self.num_heads - config['dff'] = self.dff - config['rate'] = self.rate + config["d_model"] = self.d_model + config["num_heads"] = self.num_heads + config["dff"] = self.dff + config["rate"] = self.rate return config def call(self, query, key, value, decoder_mask=None, **kwargs): # Supports backward compatibility - if 'mask' in kwargs: - kwargs.pop('mask') + if "mask" in kwargs: + kwargs.pop("mask") # (batch_size, target_seq_len, 
d_model) attn, attn_weights_block = self.mha( @@ -290,7 +295,7 @@ def call(self, query, key, value, decoder_mask=None, **kwargs): attention_mask=decoder_mask, use_causal_mask=decoder_mask is None, return_attention_scores=True, - **kwargs + **kwargs, ) attn = self.dropout1(attn, **kwargs) @@ -310,13 +315,7 @@ def call(self, query, key, value, decoder_mask=None, **kwargs): class NonTrainablePositionEmbedding(tf.keras.layers.Layer): - def __init__( - self, - maxlen, - embed_dim, - *args, - **kwargs - ): + def __init__(self, maxlen, embed_dim, *args, **kwargs): super().__init__(*args, **kwargs) self._maxlen = maxlen self._embed_dim = embed_dim @@ -324,8 +323,8 @@ def __init__( def get_config(self): config = super().get_config() - config['maxlen'] = self._maxlen - config['embed_dim'] = self._embed_dim + config["maxlen"] = self._maxlen + config["embed_dim"] = self._embed_dim return config def call(self, x, **kwargs): @@ -337,13 +336,7 @@ def call(self, x, **kwargs): class TrainablePositionEmbedding(tf.keras.layers.Layer): - def __init__( - self, - maxlen, - embed_dim, - *args, - **kwargs - ): + def __init__(self, maxlen, embed_dim, *args, **kwargs): super().__init__(*args, **kwargs) self.pos_emb = tf.keras.layers.Embedding(input_dim=maxlen, output_dim=embed_dim) self._maxlen = maxlen @@ -351,8 +344,8 @@ def __init__( def get_config(self): config = super().get_config() - config['maxlen'] = self._maxlen - config['embed_dim'] = self._embed_dim + config["maxlen"] = self._maxlen + config["embed_dim"] = self._embed_dim return config def call(self, x, **kwargs): @@ -364,27 +357,15 @@ def call(self, x, **kwargs): class SimpleDecoderLayer(tf.keras.layers.Layer): - def __init__( - self, - d_model, - num_heads, - dff=512, - rate=0.1, - *args, - **kwargs - ): - super(SimpleDecoderLayer, self).__init__( - *args, - **kwargs - ) + def __init__(self, d_model, num_heads, dff=512, rate=0.1, *args, **kwargs): + super(SimpleDecoderLayer, self).__init__(*args, **kwargs) assert d_model % num_heads == 0 self.d_model = d_model self.num_heads = num_heads self.dff = dff self.rate = rate self.multi_head_attention_layer = tf.keras.layers.MultiHeadAttention( - num_heads=num_heads, - key_dim=d_model // num_heads + num_heads=num_heads, key_dim=d_model // num_heads ) self.ffn = point_wise_feed_forward_network(d_model, dff) self.mha_layernorm = tf.keras.layers.LayerNormalization(epsilon=1e-6) @@ -394,15 +375,13 @@ def __init__( def get_config(self): config = super().get_config() - config['d_model'] = self.d_model - config['num_heads'] = self.num_heads - config['dff'] = self.dff - config['rate'] = self.rate + config["d_model"] = self.d_model + config["num_heads"] = self.num_heads + config["dff"] = self.dff + config["rate"] = self.rate return config - def call( - self, decoder_input, enc_output, encoder_mask, **kwargs - ): + def call(self, decoder_input, enc_output, encoder_mask, **kwargs): # The reason we are doing this is that tensorflow on Mac doesn't seem to recognize the rank correctly # if platform.system() == 'Darwin': batch, enc_length = tf.shape(enc_output)[0], tf.shape(enc_output)[1] @@ -416,7 +395,7 @@ def call( query=decoder_input, attention_mask=encoder_mask, return_attention_scores=True, - **kwargs + **kwargs, ) # (batch_size, target_seq_len, d_model) attn = self.mha_dropout_layer(attn, **kwargs) out2 = self.mha_layernorm(attn + decoder_input) # (batch_size, target_seq_len, d_model) @@ -429,13 +408,7 @@ def call( class PositionalEncodingLayer(tf.keras.layers.Layer): - def __init__( - self, - embedding_size, - 
max_sequence_length=512, - *args, - **kwargs - ): + def __init__(self, embedding_size, max_sequence_length=512, *args, **kwargs): super(PositionalEncodingLayer, self).__init__(*args, **kwargs) self.embedding_size = embedding_size self.max_sequence_length = max_sequence_length @@ -444,8 +417,8 @@ def __init__( def get_config(self): config = super().get_config() - config['max_sequence_length'] = self.max_sequence_length - config['embedding_size'] = self.embedding_size + config["max_sequence_length"] = self.max_sequence_length + config["embedding_size"] = self.embedding_size return config def call(self, visit_concept_orders): @@ -465,45 +438,49 @@ def __init__(self, embedding_size, is_time_delta=False, *args, **kwargs): super(TimeEmbeddingLayer, self).__init__(*args, **kwargs) self.embedding_size = embedding_size self.is_time_delta = is_time_delta - self.w = self.add_weight(shape=(1, self.embedding_size), - trainable=True, - initializer=tf.keras.initializers.GlorotNormal(), - name=f'time_embedding_weight_{self.name}') - self.phi = self.add_weight(shape=(1, self.embedding_size), - trainable=True, - initializer=tf.keras.initializers.GlorotNormal(), - name=f'time_embedding_phi_{self.name}') + self.w = self.add_weight( + shape=(1, self.embedding_size), + trainable=True, + initializer=tf.keras.initializers.GlorotNormal(), + name=f"time_embedding_weight_{self.name}", + ) + self.phi = self.add_weight( + shape=(1, self.embedding_size), + trainable=True, + initializer=tf.keras.initializers.GlorotNormal(), + name=f"time_embedding_phi_{self.name}", + ) def get_config(self): config = super().get_config() - config['embedding_size'] = self.embedding_size - config['is_time_delta'] = self.is_time_delta + config["embedding_size"] = self.embedding_size + config["is_time_delta"] = self.is_time_delta return config def call(self, time_stamps): time_stamps = tf.cast(time_stamps, tf.float32) if self.is_time_delta: time_stamps = tf.concat( - [time_stamps[:, 0:1] * 0, time_stamps[:, 1:] - time_stamps[:, :-1]], axis=-1) + [time_stamps[:, 0:1] * 0, time_stamps[:, 1:] - time_stamps[:, :-1]], + axis=-1, + ) next_input = tf.expand_dims(time_stamps, axis=-1) * self.w + self.phi return tf.sin(next_input) class VisitEmbeddingLayer(tf.keras.layers.Layer): - def __init__(self, visit_order_size: int, - embedding_size: int, *args, **kwargs): + def __init__(self, visit_order_size: int, embedding_size: int, *args, **kwargs): super(VisitEmbeddingLayer, self).__init__(*args, **kwargs) self.visit_order_size = visit_order_size self.embedding_size = embedding_size - self.visit_embedding_layer = tf.keras.layers.Embedding(self.visit_order_size, - self.embedding_size) + self.visit_embedding_layer = tf.keras.layers.Embedding(self.visit_order_size, self.embedding_size) def get_config(self): config = super().get_config() - config['visit_order_size'] = self.visit_order_size - config['embedding_size'] = self.embedding_size + config["visit_order_size"] = self.visit_order_size + config["embedding_size"] = self.embedding_size return config def call(self, inputs, **kwargs): @@ -515,23 +492,18 @@ class ConceptValuePredictionLayer(tf.keras.layers.Layer): def __init__(self, embedding_size, *args, **kwargs): super(ConceptValuePredictionLayer, self).__init__(*args, **kwargs) self.embedding_size = embedding_size - self.concept_value_decoder_layer = tf.keras.Sequential(layers=[ - tf.keras.layers.Dense( - self.embedding_size, - activation='tanh' - ), - tf.keras.layers.Dense( - self.embedding_size, - activation='tanh' - ), - tf.keras.layers.Dense( - 1 - ) 
- ], name='value_decoder_layer') + self.concept_value_decoder_layer = tf.keras.Sequential( + layers=[ + tf.keras.layers.Dense(self.embedding_size, activation="tanh"), + tf.keras.layers.Dense(self.embedding_size, activation="tanh"), + tf.keras.layers.Dense(1), + ], + name="value_decoder_layer", + ) def get_config(self): config = super().get_config() - config['embedding_size'] = self.embedding_size + config["embedding_size"] = self.embedding_size return config def call(self, original_concept_embeddings, concept_val_embeddings, concept_value_masks): @@ -541,15 +513,9 @@ def call(self, original_concept_embeddings, concept_val_embeddings, concept_valu concept_vals = self.concept_value_decoder_layer(context) # (batch_size, context_window, 1) - concept_value_masks = tf.expand_dims( - concept_value_masks, - axis=-1 - ) + concept_value_masks = tf.expand_dims(concept_value_masks, axis=-1) # Zero out the positions without a val - concept_vals = tf.multiply( - concept_vals, - tf.cast(concept_value_masks, dtype=tf.float32) - ) + concept_vals = tf.multiply(concept_vals, tf.cast(concept_value_masks, dtype=tf.float32)) return concept_vals @@ -558,13 +524,12 @@ def __init__(self, embedding_size, *args, **kwargs): super(ConceptValueTransformationLayer, self).__init__(*args, **kwargs) self.embedding_size = embedding_size self.merge_value_transformation_layer = tf.keras.layers.Dense( - embedding_size, - name='merge_value_transformation_layer' + embedding_size, name="merge_value_transformation_layer" ) def get_config(self): config = super().get_config() - config['embedding_size'] = self.embedding_size + config["embedding_size"] = self.embedding_size return config def call(self, concept_embeddings, concept_values, concept_value_masks): @@ -572,42 +537,25 @@ def call(self, concept_embeddings, concept_values, concept_value_masks): # Combine the concept embeddings with concept_values # (batch_size, num_of_visits, num_of_concepts, 1) - concept_values = tf.expand_dims( - concept_values, - axis=-1 - ) + concept_values = tf.expand_dims(concept_values, axis=-1) # (batch_size, num_of_visits, num_of_concepts, 1) - concept_value_masks = tf.expand_dims( - concept_value_masks, - axis=-1 - ) + concept_value_masks = tf.expand_dims(concept_value_masks, axis=-1) # (batch_size, num_of_visits, num_of_concepts, 1 + embedding_size) - concept_embeddings_with_val = tf.concat( - [concept_embeddings, concept_values], - axis=-1 - ) + concept_embeddings_with_val = tf.concat([concept_embeddings, concept_values], axis=-1) # Run through a dense layer to bring the dimension back to embedding_size - concept_embeddings_with_val = self.merge_value_transformation_layer( - concept_embeddings_with_val - ) + concept_embeddings_with_val = self.merge_value_transformation_layer(concept_embeddings_with_val) # Zero out the positions without a val concept_embeddings_with_val = tf.multiply( - concept_embeddings_with_val, - tf.cast(concept_value_masks, dtype=tf.float32) + concept_embeddings_with_val, tf.cast(concept_value_masks, dtype=tf.float32) ) # Derive the inverse concept value masks for zeroing out the embeddings without a val inverse_concept_value_masks = tf.cast( - tf.logical_not( - tf.cast(concept_value_masks, dtype=tf.bool) - ), - dtype=tf.float32 + tf.logical_not(tf.cast(concept_value_masks, dtype=tf.bool)), + dtype=tf.float32, ) # Zero out the position of concept embeddings with a val - concept_embeddings_without_val = tf.multiply( - inverse_concept_value_masks, - concept_embeddings - ) + concept_embeddings_without_val = 
tf.multiply(inverse_concept_value_masks, concept_embeddings) # Merge two sets of concept embeddings concept_embeddings = concept_embeddings_without_val + concept_embeddings_with_val @@ -623,62 +571,45 @@ def __init__(self, time_embeddings_size, embedding_size, *args, **kwargs): self.embedding_size = embedding_size # define the time embedding layer for absolute time stamps (since 1970) - self.time_embedding_layer = TimeEmbeddingLayer( - embedding_size=time_embeddings_size, - name='time_embedding_layer' - ) + self.time_embedding_layer = TimeEmbeddingLayer(embedding_size=time_embeddings_size, name="time_embedding_layer") # define the age embedding layer for the age w.r.t the medical record - self.age_embedding_layer = TimeEmbeddingLayer( - embedding_size=time_embeddings_size, - name='age_embedding_layer' - ) + self.age_embedding_layer = TimeEmbeddingLayer(embedding_size=time_embeddings_size, name="age_embedding_layer") # define positional encoding layer for visit numbers, the visit numbers are normalized # by subtracting visit numbers off the first visit number self.positional_encoding_layer = PositionalEncodingLayer( - embedding_size=time_embeddings_size, - name='positional_encoding_layer' + embedding_size=time_embeddings_size, name="positional_encoding_layer" ) # Temporal transformation self.temporal_transformation_layer = tf.keras.layers.Dense( - embedding_size, - activation='tanh', - name='temporal_transformation' + embedding_size, activation="tanh", name="temporal_transformation" ) def get_config(self): config = super().get_config() - config['time_embeddings_size'] = self.time_embeddings_size - config['embedding_size'] = self.embedding_size + config["time_embeddings_size"] = self.time_embeddings_size + config["embedding_size"] = self.embedding_size return config def call(self, concept_embeddings, pat_seq_age, pat_seq_time, visit_rank_order, **kwargs): _, _, num_of_concepts = pat_seq_age.shape - pt_seq_age_embeddings = self.age_embedding_layer( - pat_seq_age, - **kwargs - ) - pt_seq_time_embeddings = self.time_embedding_layer( - pat_seq_time, - **kwargs - ) - visit_positional_encoding = self.positional_encoding_layer( - visit_rank_order, - **kwargs - ) + pt_seq_age_embeddings = self.age_embedding_layer(pat_seq_age, **kwargs) + pt_seq_time_embeddings = self.time_embedding_layer(pat_seq_time, **kwargs) + visit_positional_encoding = self.positional_encoding_layer(visit_rank_order, **kwargs) - visit_positional_encoding = tf.tile( - visit_positional_encoding[:, :, tf.newaxis, :], [1, 1, num_of_concepts, 1]) + visit_positional_encoding = tf.tile(visit_positional_encoding[:, :, tf.newaxis, :], [1, 1, num_of_concepts, 1]) # (batch, num_of_visits, num_of_concepts, embedding_size) temporal_concept_embeddings = self.temporal_transformation_layer( tf.concat( - [concept_embeddings, - pt_seq_age_embeddings, - pt_seq_time_embeddings, - visit_positional_encoding], - axis=-1 + [ + concept_embeddings, + pt_seq_age_embeddings, + pt_seq_time_embeddings, + visit_positional_encoding, + ], + axis=-1, ) ) @@ -692,28 +623,33 @@ def __init__(self, model_path: str, *args, **kwargs): bert_model = tf.keras.models.load_model(model_path, custom_objects=get_custom_objects()) self.model_path = model_path - self.concept_embedding_layer = bert_model.get_layer('concept_embeddings') - self.visit_segment_layer = [layer for layer in bert_model.layers if - layer.name in ['visit_embedding_layer', - 'visit_segment_layer']][0] - self.positional_encoding_layer = bert_model.get_layer('positional_encoding_layer') - 
self.time_embedding_layer = bert_model.get_layer('time_embedding_layer') - self.age_embedding_layer = bert_model.get_layer('age_embedding_layer') - self.scale_pat_seq_layer = bert_model.get_layer('scale_pat_seq_layer') - self.encoder_layer = bert_model.get_layer('encoder') + self.concept_embedding_layer = bert_model.get_layer("concept_embeddings") + self.visit_segment_layer = [ + layer for layer in bert_model.layers if layer.name in ["visit_embedding_layer", "visit_segment_layer"] + ][0] + self.positional_encoding_layer = bert_model.get_layer("positional_encoding_layer") + self.time_embedding_layer = bert_model.get_layer("time_embedding_layer") + self.age_embedding_layer = bert_model.get_layer("age_embedding_layer") + self.scale_pat_seq_layer = bert_model.get_layer("scale_pat_seq_layer") + self.encoder_layer = bert_model.get_layer("encoder") # self.conv_1d = tf.keras.layers.Conv1D(1, 1) - self.attention_dense = tf.keras.layers.Dense(self.scale_pat_seq_layer.units, - activation='tanh') - self.dense = tf.keras.layers.Dense(self.scale_pat_seq_layer.units, activation='tanh') + self.attention_dense = tf.keras.layers.Dense(self.scale_pat_seq_layer.units, activation="tanh") + self.dense = tf.keras.layers.Dense(self.scale_pat_seq_layer.units, activation="tanh") def get_config(self): config = super().get_config() - config['model_path'] = self.model_path + config["model_path"] = self.model_path return config def call(self, inputs, **kwargs): - (local_concept_ids, local_visit_segments, local_visit_concept_orders, - local_time_stamps, local_ages, local_mask) = inputs + ( + local_concept_ids, + local_visit_segments, + local_visit_concept_orders, + local_time_stamps, + local_ages, + local_mask, + ) = inputs batch_size, max_seq_length = local_mask.get_shape().as_list() @@ -724,19 +660,30 @@ def call(self, inputs, **kwargs): concept_mask = create_concept_mask(local_mask, max_seq_length) input_for_encoder = self.scale_pat_seq_layer( - tf.concat([concept_embeddings, time_embeddings, age_embeddings, positional_encoddings], - axis=-1)) + tf.concat( + [ + concept_embeddings, + time_embeddings, + age_embeddings, + positional_encoddings, + ], + axis=-1, + ) + ) input_for_encoder = self.visit_segment_layer([local_visit_segments, input_for_encoder]) contextualized_embeddings, _ = self.encoder_layer(input_for_encoder, concept_mask) _, _, embedding_size = contextualized_embeddings.get_shape().as_list() mask_embeddings = tf.tile(tf.expand_dims(local_mask == 0, -1), [1, 1, embedding_size]) - contextualized_embeddings = tf.math.multiply(contextualized_embeddings, - tf.cast(mask_embeddings, dtype=tf.float32)) + contextualized_embeddings = tf.math.multiply( + contextualized_embeddings, tf.cast(mask_embeddings, dtype=tf.float32) + ) # (batch, seq_len, embeddings_size) - multi_dim_att = tf.nn.softmax(self.attention_dense(contextualized_embeddings) - + (tf.cast(tf.expand_dims(local_mask, axis=-1), - dtype='float32') * -1e9), axis=1) + multi_dim_att = tf.nn.softmax( + self.attention_dense(contextualized_embeddings) + + (tf.cast(tf.expand_dims(local_mask, axis=-1), dtype="float32") * -1e9), + axis=1, + ) context_representation = tf.reduce_sum(multi_dim_att * contextualized_embeddings, axis=1) # conv_output = self.conv_1d(contextualized_embeddings) @@ -750,11 +697,15 @@ def call(self, inputs, **kwargs): class ConvolutionBertLayer(tf.keras.layers.Layer): - def __init__(self, - model_path: str, - seq_len: int, - context_window: int, - stride: int, *args, **kwargs): + def __init__( + self, + model_path: str, + seq_len: int, + 
context_window: int, + stride: int, + *args, + **kwargs, + ): super(ConvolutionBertLayer, self).__init__(*args, **kwargs) self.model_path = model_path self.seq_len = seq_len @@ -763,17 +714,16 @@ def __init__(self, self.step = (seq_len - context_window) // stride + 1 self.bert_layer = BertLayer(model_path=model_path) # self.conv_1d = tf.keras.layers.Conv1D(1, 1) - self.attention_dense = tf.keras.layers.Dense(self.bert_layer.scale_pat_seq_layer.units, - activation='tanh') + self.attention_dense = tf.keras.layers.Dense(self.bert_layer.scale_pat_seq_layer.units, activation="tanh") assert (self.step - 1) * self.stride + self.context_window == self.seq_len def get_config(self): config = super().get_config() - config['model_path'] = self.model_path - config['seq_len'] = self.seq_len - config['context_window'] = self.context_window - config['stride'] = self.stride + config["model_path"] = self.model_path + config["seq_len"] = self.seq_len + config["context_window"] = self.context_window + config["stride"] = self.stride return config def call(self, inputs, **kwargs): @@ -792,12 +742,14 @@ def call(self, inputs, **kwargs): visit_concept_orders_step = visit_concept_orders[:, start_index:end_index] mask_step = mask[:, start_index:end_index] - inputs_step = [concept_ids_step, - visit_segments_step, - visit_concept_orders_step, - time_stamps_step, - ages_step, - mask_step] + inputs_step = [ + concept_ids_step, + visit_segments_step, + visit_concept_orders_step, + time_stamps_step, + ages_step, + mask_step, + ] output_masking = tf.cast(tf.reduce_all(mask_step == 1, axis=-1), dtype=tf.int32) @@ -814,13 +766,11 @@ def call(self, inputs, **kwargs): attn = self.attention_dense(bert_output_tensor) - attn += (tf.cast(tf.expand_dims(bert_output_masking_tensor, axis=-1), - dtype='float32') * -1e9) + attn += tf.cast(tf.expand_dims(bert_output_masking_tensor, axis=-1), dtype="float32") * -1e9 _, _, embedding_size = bert_output_tensor.get_shape().as_list() - context_representation = tf.reduce_sum(tf.nn.softmax(attn, axis=1) * bert_output_tensor, - axis=1) + context_representation = tf.reduce_sum(tf.nn.softmax(attn, axis=1) * bert_output_tensor, axis=1) # context_representation = tf.reshape( # tf.transpose(tf.nn.softmax(conv_output, axis=1), [0, 2, 1]) @ bert_output_tensor, @@ -832,15 +782,15 @@ def call(self, inputs, **kwargs): class VisitPhenotypeLayer(tf.keras.layers.Layer): def __init__( - self, - num_of_phenotypes: int, - num_of_phenotype_neighbors: int, - num_of_concept_neighbors: int, - embedding_size: int, - transformer_dropout: float, - dff: int = 2148, - *args, - **kwargs + self, + num_of_phenotypes: int, + num_of_phenotype_neighbors: int, + num_of_concept_neighbors: int, + embedding_size: int, + transformer_dropout: float, + dff: int = 2148, + *args, + **kwargs, ): super(VisitPhenotypeLayer, self).__init__(*args, **kwargs) self.num_of_phenotypes = num_of_phenotypes @@ -856,13 +806,10 @@ def __init__( shape=(num_of_phenotypes, embedding_size), initializer=tf.keras.initializers.GlorotUniform(seed=0), trainable=True, - name='phenotype_embeddings_matrix' + name="phenotype_embeddings_matrix", ) - self.ffn = point_wise_feed_forward_network( - embedding_size, - dff - ) + self.ffn = point_wise_feed_forward_network(embedding_size, dff) self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6) self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6) @@ -871,89 +818,63 @@ def __init__( def get_config(self): config = super().get_config() - config['num_of_phenotypes'] = self.num_of_phenotypes - 
config['embedding_size'] = self.embedding_size - config['transformer_dropout'] = self.transformer_dropout - config['dff'] = self.dff - config['num_of_concept_neighbors'] = self.num_of_concept_neighbors - config['num_of_phenotype_neighbors'] = self.num_of_phenotype_neighbors + config["num_of_phenotypes"] = self.num_of_phenotypes + config["embedding_size"] = self.embedding_size + config["transformer_dropout"] = self.transformer_dropout + config["dff"] = self.dff + config["num_of_concept_neighbors"] = self.num_of_concept_neighbors + config["num_of_phenotype_neighbors"] = self.num_of_phenotype_neighbors return config def call(self, inputs, **kwargs): visit_embeddings, visit_mask, embedding_matrix = inputs # Do not compute the entropy for the masked visits - converted_visit_mask = tf.cast( - tf.logical_not( - tf.cast( - visit_mask, - dtype=tf.bool - ) - ), - dtype=tf.float32 - )[:, :, tf.newaxis] + converted_visit_mask = tf.cast(tf.logical_not(tf.cast(visit_mask, dtype=tf.bool)), dtype=tf.float32)[ + :, :, tf.newaxis + ] # (batch_size, num_of_visits, num_of_phenotypes) visit_phenotype_probs = tf.nn.softmax( - visit_embeddings @ tf.transpose( - self.phenotype_embeddings, - [1, 0] - ) * converted_visit_mask + visit_embeddings @ tf.transpose(self.phenotype_embeddings, [1, 0]) * converted_visit_mask ) # calculate phenotype concept distance matrix (num_of_phenotypes, top_k) phenotype_concept_dist = tf.reduce_mean( -tf.math.top_k( - -distance_matrix( - self.phenotype_embeddings, - embedding_matrix - ), - k=self.num_of_concept_neighbors + -distance_matrix(self.phenotype_embeddings, embedding_matrix), + k=self.num_of_concept_neighbors, ).values ) - self.add_metric( - phenotype_concept_dist, - name='phenotype_concept_dist' - ) + self.add_metric(phenotype_concept_dist, name="phenotype_concept_dist") # Calculate the probability distribution entropy phenotype_prob_entropy = -tf.reduce_sum( visit_phenotype_probs * tf.math.log(visit_phenotype_probs) * converted_visit_mask, - axis=-1 + axis=-1, ) # Add the entropy to the model metrics - self.add_metric( - phenotype_prob_entropy, - name='phenotype_probability_entropy' - ) + self.add_metric(phenotype_prob_entropy, name="phenotype_probability_entropy") # Get phenotype pairwise distance metrics phe_inv_loss, phe_dist_metric, phe_dist_var = self.get_inverse_phenotype_dist_loss_metric() - self.add_metric( - phe_dist_metric, - name='phenotype_euclidean_distance' - ) + self.add_metric(phe_dist_metric, name="phenotype_euclidean_distance") - self.add_metric( - phe_dist_var, - name='phenotype_euclidean_variance' - ) + self.add_metric(phe_dist_var, name="phenotype_euclidean_variance") # Calculate the contextualized visit embeddings using the pre-defined phenotype embeddings # (batch_size, num_of_visits, embedding_size) contextualized_phenotype_embeddings = self.dropout1( visit_phenotype_probs @ self.phenotype_embeddings, - training=kwargs.get('training') + training=kwargs.get("training"), ) - out1 = self.layernorm1( - visit_embeddings + contextualized_phenotype_embeddings - ) + out1 = self.layernorm1(visit_embeddings + contextualized_phenotype_embeddings) ffn_output = self.ffn(out1) # (batch_size, input_seq_len, d_model) - ffn_output = self.dropout2(ffn_output, training=kwargs.get('training')) + ffn_output = self.dropout2(ffn_output, training=kwargs.get("training")) out2 = self.layernorm2(out1 + ffn_output) # (batch_size, input_seq_len, d_model) return out2, visit_phenotype_probs @@ -962,25 +883,17 @@ def get_inverse_phenotype_dist_loss_metric(self): r = 
tf.reduce_sum(self.phenotype_embeddings * self.phenotype_embeddings, 1) # turn r into column vector r = tf.reshape(r, [-1, 1]) - euclidean_distances_full = r - 2 * tf.matmul(self.phenotype_embeddings, tf.transpose( - self.phenotype_embeddings)) + tf.transpose(r) + euclidean_distances_full = ( + r - 2 * tf.matmul(self.phenotype_embeddings, tf.transpose(self.phenotype_embeddings)) + tf.transpose(r) + ) - euclidean_distances = -tf.math.top_k( - -euclidean_distances_full, - k=self.num_of_phenotype_neighbors - ).values + euclidean_distances = -tf.math.top_k(-euclidean_distances_full, k=self.num_of_phenotype_neighbors).values - inv_loss = tf.reduce_mean( - tf.math.exp(-euclidean_distances) - ) + inv_loss = tf.reduce_mean(tf.math.exp(-euclidean_distances)) - var_loss = tf.math.reduce_variance( - euclidean_distances - ) + var_loss = tf.math.reduce_variance(euclidean_distances) - dist_metric = tf.reduce_mean( - euclidean_distances - ) + dist_metric = tf.reduce_mean(euclidean_distances) return inv_loss, dist_metric, var_loss @@ -989,7 +902,9 @@ def distance_matrix(matrix_1, matrix_2): m = matrix_1.shape[0] n = matrix_2.shape[0] - assert matrix_1.shape[1] == matrix_2.shape[1], f"The number of components for vectors in A \ + assert ( + matrix_1.shape[1] == matrix_2.shape[1] + ), f"The number of components for vectors in A \ {matrix_1.shape[1]} does not match that of B {matrix_2.shape[1]}!" matrix_1_dots = tf.reshape(tf.reduce_sum(matrix_1 * matrix_1, axis=1), (m, 1)) * tf.ones((1, n)) @@ -1000,25 +915,27 @@ def distance_matrix(matrix_1, matrix_2): return tf.sqrt(matrix_distance_squared) -get_custom_objects().update({ - 'Encoder': Encoder, - 'GptDecoder': GptDecoder, - 'GptDecoderLayer': GptDecoderLayer, - 'TrainablePositionEmbedding': TrainablePositionEmbedding, - 'EncoderLayer': EncoderLayer, - 'DecoderLayer': DecoderLayer, - 'SimpleDecoderLayer': SimpleDecoderLayer, - 'VisitEmbeddingLayer': VisitEmbeddingLayer, - 'PositionalEncodingLayer': PositionalEncodingLayer, - 'NonTrainablePositionEmbedding': NonTrainablePositionEmbedding, - 'TimeEmbeddingLayer': TimeEmbeddingLayer, - 'TemporalTransformationLayer': TemporalTransformationLayer, - 'ConceptValueTransformationLayer': ConceptValueTransformationLayer, - 'ReusableEmbedding': ReusableEmbedding, - 'TiedOutputEmbedding': TiedOutputEmbedding, - 'MaskedPenalizedSparseCategoricalCrossentropy': MaskedPenalizedSparseCategoricalCrossentropy, - 'BertLayer': BertLayer, - 'ConvolutionBertLayer': ConvolutionBertLayer, - 'VisitPhenotypeLayer': VisitPhenotypeLayer, - 'ConceptValuePredictionLayer': ConceptValuePredictionLayer -}) +get_custom_objects().update( + { + "Encoder": Encoder, + "GptDecoder": GptDecoder, + "GptDecoderLayer": GptDecoderLayer, + "TrainablePositionEmbedding": TrainablePositionEmbedding, + "EncoderLayer": EncoderLayer, + "DecoderLayer": DecoderLayer, + "SimpleDecoderLayer": SimpleDecoderLayer, + "VisitEmbeddingLayer": VisitEmbeddingLayer, + "PositionalEncodingLayer": PositionalEncodingLayer, + "NonTrainablePositionEmbedding": NonTrainablePositionEmbedding, + "TimeEmbeddingLayer": TimeEmbeddingLayer, + "TemporalTransformationLayer": TemporalTransformationLayer, + "ConceptValueTransformationLayer": ConceptValueTransformationLayer, + "ReusableEmbedding": ReusableEmbedding, + "TiedOutputEmbedding": TiedOutputEmbedding, + "MaskedPenalizedSparseCategoricalCrossentropy": MaskedPenalizedSparseCategoricalCrossentropy, + "BertLayer": BertLayer, + "ConvolutionBertLayer": ConvolutionBertLayer, + "VisitPhenotypeLayer": VisitPhenotypeLayer, + 
"ConceptValuePredictionLayer": ConceptValuePredictionLayer, + } +) diff --git a/src/cehrbert/models/layers/hierarchical_custom_layers.py b/src/cehrbert/models/layers/hierarchical_custom_layers.py index 57123a62..fb964d02 100644 --- a/src/cehrbert/models/layers/hierarchical_custom_layers.py +++ b/src/cehrbert/models/layers/hierarchical_custom_layers.py @@ -6,16 +6,18 @@ class HierarchicalBertLayer(tf.keras.layers.Layer): - def __init__(self, - num_of_exchanges, - num_of_visits, - num_of_concepts, - depth, - embedding_size, - num_heads, - dropout_rate=0.1, - *args, - **kwargs): + def __init__( + self, + num_of_exchanges, + num_of_visits, + num_of_concepts, + depth, + embedding_size, + num_heads, + dropout_rate=0.1, + *args, + **kwargs, + ): super(HierarchicalBertLayer, self).__init__(*args, **kwargs) assert embedding_size % num_heads == 0 @@ -28,23 +30,21 @@ def __init__(self, self.num_heads = num_heads self.dropout_rate = dropout_rate self.concept_encoder_layer = Encoder( - name='concept_encoder', + name="concept_encoder", num_layers=depth, d_model=embedding_size, num_heads=num_heads, - dropout_rate=dropout_rate + dropout_rate=dropout_rate, ) self.visit_encoder_layer = Encoder( - name='visit_encoder', + name="visit_encoder", num_layers=depth, d_model=embedding_size, num_heads=num_heads, - dropout_rate=dropout_rate + dropout_rate=dropout_rate, ) self.mha_layer = tf.keras.layers.MultiHeadAttention( - num_heads=num_heads, - key_dim=embedding_size // num_heads, - name='mha' + num_heads=num_heads, key_dim=embedding_size // num_heads, name="mha" ) # Insert the att embeddings between the visit embeddings using the following trick @@ -53,9 +53,9 @@ def __init__(self, np.identity(self.num_of_visits), obj=range(1, self.num_of_visits), values=0, - axis=1 + axis=1, ), - dtype=tf.float32 + dtype=tf.float32, ) # Create the inverse "identity" matrix for inserting att embeddings @@ -64,44 +64,43 @@ def __init__(self, np.identity(self.num_of_visits - 1), obj=range(0, self.num_of_visits), values=0, - axis=1), - dtype=tf.float32) + axis=1, + ), + dtype=tf.float32, + ) - self.merge_matrix = tf.constant( - [1] + [0] * (self.num_of_concepts - 1), - dtype=tf.float32 - )[tf.newaxis, tf.newaxis, :, tf.newaxis] + self.merge_matrix = tf.constant([1] + [0] * (self.num_of_concepts - 1), dtype=tf.float32)[ + tf.newaxis, tf.newaxis, :, tf.newaxis + ] - self.merge_matrix_inverse = tf.constant( - [0] + [1] * (self.num_of_concepts - 1), - dtype=tf.float32 - )[tf.newaxis, tf.newaxis, :, tf.newaxis] + self.merge_matrix_inverse = tf.constant([0] + [1] * (self.num_of_concepts - 1), dtype=tf.float32)[ + tf.newaxis, tf.newaxis, :, tf.newaxis + ] self.global_embedding_dropout_layer = tf.keras.layers.Dropout(dropout_rate) self.global_concept_embeddings_normalization = tf.keras.layers.LayerNormalization( - name='global_concept_embeddings_normalization', - epsilon=1e-6 + name="global_concept_embeddings_normalization", epsilon=1e-6 ) def get_config(self): config = super().get_config() - config['num_of_visits'] = self.num_of_visits - config['num_of_concepts'] = self.num_of_concepts - config['num_of_exchanges'] = self.num_of_exchanges - config['embedding_size'] = self.embedding_size - config['depth'] = self.depth - config['num_heads'] = self.num_heads - config['dropout_rate'] = self.dropout_rate + config["num_of_visits"] = self.num_of_visits + config["num_of_concepts"] = self.num_of_concepts + config["num_of_exchanges"] = self.num_of_exchanges + config["embedding_size"] = self.embedding_size + config["depth"] = self.depth + 
config["num_heads"] = self.num_heads + config["dropout_rate"] = self.dropout_rate return config def call( - self, - temporal_concept_embeddings, - att_embeddings, - pat_concept_mask, - visit_concept_mask, - **kwargs + self, + temporal_concept_embeddings, + att_embeddings, + pat_concept_mask, + visit_concept_mask, + **kwargs, ): for i in range(self.num_of_exchanges): # Step 1 @@ -109,13 +108,18 @@ def call( contextualized_concept_embeddings, _ = self.concept_encoder_layer( temporal_concept_embeddings, # be reused pat_concept_mask, # not change - **kwargs + **kwargs, ) # (batch_size, num_of_visits, num_of_concepts, embedding_size) contextualized_concept_embeddings = tf.reshape( contextualized_concept_embeddings, - shape=(-1, self.num_of_visits, self.num_of_concepts, self.embedding_size) + shape=( + -1, + self.num_of_visits, + self.num_of_concepts, + self.embedding_size, + ), ) # Step 2 generate augmented visit embeddings # Slice out the first contextualized embedding of each visit @@ -126,19 +130,23 @@ def call( # (batch, num_of_visits * num_of_concepts, embedding_size) contextualized_concept_embeddings = tf.reshape( contextualized_concept_embeddings, - shape=(-1, self.num_of_visits * self.num_of_concepts, self.embedding_size) + shape=( + -1, + self.num_of_visits * self.num_of_concepts, + self.embedding_size, + ), ) # (batch_size, num_of_visits + num_of_visits - 1, embedding_size) expanded_visit_embeddings = tf.transpose( tf.transpose(visit_embeddings, perm=[0, 2, 1]) @ self.identity, - perm=[0, 2, 1] + perm=[0, 2, 1], ) # (batch_size, num_of_visits + num_of_visits - 1, embedding_size) expanded_att_embeddings = tf.transpose( tf.transpose(att_embeddings, perm=[0, 2, 1]) @ self.identity_inverse, - perm=[0, 2, 1] + perm=[0, 2, 1], ) # Insert the att embeddings between visit embedidngs @@ -147,23 +155,17 @@ def call( # Step 3 encoder applied to patient level # Feed augmented visit embeddings into encoders to get contextualized visit embeddings - visit_embeddings, _ = self.visit_encoder_layer( - augmented_visit_embeddings, - visit_concept_mask, - **kwargs - ) + visit_embeddings, _ = self.visit_encoder_layer(augmented_visit_embeddings, visit_concept_mask, **kwargs) # v, k, q global_concept_embeddings = self.mha_layer( value=visit_embeddings, key=visit_embeddings, query=contextualized_concept_embeddings, attention_mask=visit_concept_mask, - return_attention_scores=False + return_attention_scores=False, ) - global_concept_embeddings = self.global_embedding_dropout_layer( - global_concept_embeddings - ) + global_concept_embeddings = self.global_embedding_dropout_layer(global_concept_embeddings) global_concept_embeddings = self.global_concept_embeddings_normalization( global_concept_embeddings + contextualized_concept_embeddings @@ -175,30 +177,25 @@ def call( global_concept_embeddings = tf.reshape( global_concept_embeddings, - (-1, self.num_of_visits, self.num_of_concepts, self.embedding_size) + (-1, self.num_of_visits, self.num_of_concepts, self.embedding_size), ) global_concept_embeddings += ( - global_concept_embeddings * self.merge_matrix_inverse + - tf.expand_dims( - visit_embeddings_without_att, - axis=-2 - ) * self.merge_matrix + global_concept_embeddings * self.merge_matrix_inverse + + tf.expand_dims(visit_embeddings_without_att, axis=-2) * self.merge_matrix ) temporal_concept_embeddings = tf.reshape( global_concept_embeddings, - (-1, self.num_of_concepts, self.embedding_size) + (-1, self.num_of_concepts, self.embedding_size), ) global_concept_embeddings = tf.reshape( 
global_concept_embeddings, - (-1, self.num_of_visits * self.num_of_concepts, self.embedding_size) + (-1, self.num_of_visits * self.num_of_concepts, self.embedding_size), ) return global_concept_embeddings, self.identity @ visit_embeddings -get_custom_objects().update({ - 'HierarchicalBertLayer': HierarchicalBertLayer -}) +get_custom_objects().update({"HierarchicalBertLayer": HierarchicalBertLayer}) diff --git a/src/cehrbert/models/loss_schedulers.py b/src/cehrbert/models/loss_schedulers.py index af659bf1..db71ee50 100644 --- a/src/cehrbert/models/loss_schedulers.py +++ b/src/cehrbert/models/loss_schedulers.py @@ -2,8 +2,8 @@ class CosineLRSchedule: - """ - Cosine annealing with warm restarts, described in paper + """Cosine annealing with warm restarts, described in paper. + "SGDR: stochastic gradient descent with warm restarts" https://arxiv.org/abs/1608.03983 @@ -20,8 +20,14 @@ class CosineLRSchedule: `keras.callbacks.LearningRateScheduler`. """ - def __init__(self, lr_high: float, lr_low: float, initial_period: int = 50, - period_mult: float = 2, high_lr_mult: float = 0.97): + def __init__( + self, + lr_high: float, + lr_low: float, + initial_period: int = 50, + period_mult: float = 2, + high_lr_mult: float = 0.97, + ): self._lr_high = lr_high self._lr_low = lr_low self._initial_period = initial_period @@ -39,9 +45,7 @@ def get_lr_for_epoch(self, epoch): result = lr_max for i in range(epoch + 1): if i == epoch: # last iteration - result = (self._lr_low + - 0.5 * (lr_max - self._lr_low) * - (1 + math.cos(math.pi * t_cur / period))) + result = self._lr_low + 0.5 * (lr_max - self._lr_low) * (1 + math.cos(math.pi * t_cur / period)) else: if t_cur == period: period *= self._period_mult diff --git a/src/cehrbert/models/parse_args.py b/src/cehrbert/models/parse_args.py index ab4baacc..64ad85d7 100644 --- a/src/cehrbert/models/parse_args.py +++ b/src/cehrbert/models/parse_args.py @@ -5,94 +5,91 @@ def create_parse_args(): - parser = argparse.ArgumentParser( - description='Arguments for concept embedding model' - ) + parser = argparse.ArgumentParser(description="Arguments for concept embedding model") parser.add_argument( - '--training_data_parquet_path', - dest='training_data_parquet_path', - action='store', - help='The path for your input_folder where the raw data is', - required=True + "--training_data_parquet_path", + dest="training_data_parquet_path", + action="store", + help="The path for your input_folder where the raw data is", + required=True, ) parser.add_argument( - '-o', - '--output_folder', - dest='output_folder', - action='store', - help= - 'The output folder that stores the domain tables download destination', - required=True + "-o", + "--output_folder", + dest="output_folder", + action="store", + help="The output folder that stores the domain tables download destination", + required=True, ) parser.add_argument( - '--checkpoint_name', - dest='checkpoint_name', - action='store', - help='This refers to the model name in the model output folder, which will be used as the checkpoint to ' - 'restore the training from', - required=False + "--checkpoint_name", + dest="checkpoint_name", + action="store", + help="This refers to the model name in the model output folder, which will be used as the checkpoint to " + "restore the training from", + required=False, ) parser.add_argument( - '-m', - '--max_seq_length', - dest='max_seq_length', - action='store', + "-m", + "--max_seq_length", + dest="max_seq_length", + action="store", type=int, default=100, - required=False + required=False, ) 
parser.add_argument( - '-t', - '--time_window_size', - dest='time_window_size', - action='store', + "-t", + "--time_window_size", + dest="time_window_size", + action="store", type=int, default=100, - required=False + required=False, ) parser.add_argument( - '-c', - '--embedding_size', - dest='embedding_size', - action='store', + "-c", + "--embedding_size", + dest="embedding_size", + action="store", type=int, default=128, - required=False + required=False, ) parser.add_argument( - '-e', - '--epochs', - dest='epochs', - action='store', + "-e", + "--epochs", + dest="epochs", + action="store", type=int, default=50, - required=False + required=False, ) parser.add_argument( - '-b', - '--batch_size', - dest='batch_size', - action='store', + "-b", + "--batch_size", + dest="batch_size", + action="store", type=int, default=128, - required=False + required=False, ) parser.add_argument( - '-lr', - '--learning_rate', - dest='learning_rate', - action='store', + "-lr", + "--learning_rate", + dest="learning_rate", + action="store", type=float, default=2e-4, - required=False + required=False, ) parser.add_argument( - '-bl', - '--tf_board_log_path', - dest='tf_board_log_path', - action='store', - default='./logs', - required=False + "-bl", + "--tf_board_log_path", + dest="tf_board_log_path", + action="store", + default="./logs", + required=False, ) return parser @@ -100,65 +97,47 @@ def create_parse_args(): def create_parse_args_base_bert(): parser = create_parse_args() parser.add_argument( - '--min_num_of_concepts', - dest='min_num_of_concepts', - action='store', + "--min_num_of_concepts", + dest="min_num_of_concepts", + action="store", type=int, default=5, - required=False + required=False, ) parser.add_argument( - '-d', - '--depth', - dest='depth', - action='store', + "-d", + "--depth", + dest="depth", + action="store", type=int, default=5, - required=False + required=False, ) parser.add_argument( - '-nh', - '--num_heads', - dest='num_heads', - action='store', + "-nh", + "--num_heads", + dest="num_heads", + action="store", type=int, default=8, - required=False - ) - parser.add_argument( - '-iv', - '--include_visit', - dest='include_visit_prediction', - action='store_true' - ) - parser.add_argument( - '--include_prolonged_length_stay', - dest='include_prolonged_length_stay', - action='store_true' - ) - parser.add_argument( - '-ut', - '--use_time_embedding', - dest='use_time_embedding', - action='store_true' - ) - parser.add_argument( - '--use_behrt', - dest='use_behrt', - action='store_true' + required=False, ) + parser.add_argument("-iv", "--include_visit", dest="include_visit_prediction", action="store_true") parser.add_argument( - '--use_dask', - dest='use_dask', - action='store_true' + "--include_prolonged_length_stay", + dest="include_prolonged_length_stay", + action="store_true", ) + parser.add_argument("-ut", "--use_time_embedding", dest="use_time_embedding", action="store_true") + parser.add_argument("--use_behrt", dest="use_behrt", action="store_true") + parser.add_argument("--use_dask", dest="use_dask", action="store_true") parser.add_argument( - '--time_embeddings_size', - dest='time_embeddings_size', - action='store', + "--time_embeddings_size", + dest="time_embeddings_size", + action="store", type=int, default=16, - required=False + required=False, ) return parser @@ -167,233 +146,184 @@ def create_parse_args_gpt(): from numpy import infty def valid_mask_rate(mask_rate: str) -> float: - """Custom argparse type for validating the mask rate given from the command line""" + """Custom argparse 
type for validating the mask rate given from the command line.""" try: rate = float(mask_rate) if rate < 0: - raise RuntimeError(f'{mask_rate} cannot be less than 0') + raise RuntimeError(f"{mask_rate} cannot be less than 0") if rate > 1: - raise RuntimeError(f'{mask_rate} cannot be grater than 1') + raise RuntimeError(f"{mask_rate} cannot be grater than 1") return rate except ValueError as e: raise argparse.ArgumentTypeError(e) parser = create_parse_args() parser.add_argument( - '--min_num_of_concepts', - dest='min_num_of_concepts', - action='store', + "--min_num_of_concepts", + dest="min_num_of_concepts", + action="store", type=int, default=5, - required=False + required=False, ) parser.add_argument( - '-d', - '--depth', - dest='depth', - action='store', + "-d", + "--depth", + dest="depth", + action="store", type=int, default=5, - required=False + required=False, ) parser.add_argument( - '-nh', - '--num_heads', - dest='num_heads', - action='store', + "-nh", + "--num_heads", + dest="num_heads", + action="store", type=int, default=8, - required=False - ) - parser.add_argument( - '--concept_path', - dest='concept_path', - action='store', - help='The path for the concept', - required=True + required=False, ) parser.add_argument( - '--use_dask', - dest='use_dask', - action='store_true' + "--concept_path", + dest="concept_path", + action="store", + help="The path for the concept", + required=True, ) + parser.add_argument("--use_dask", dest="use_dask", action="store_true") parser.add_argument( - '--min_num_of_visits', - dest='min_num_of_visits', - action='store', + "--min_num_of_visits", + dest="min_num_of_visits", + action="store", type=int, default=2, - required=False + required=False, ) parser.add_argument( - '--max_num_of_visits', - dest='max_num_of_visits', - action='store', + "--max_num_of_visits", + dest="max_num_of_visits", + action="store", type=int, default=20, - required=False + required=False, ) parser.add_argument( - '--print_every', - dest='print_every', - action='store', + "--print_every", + dest="print_every", + action="store", type=int, default=500, - required=False + required=False, ) parser.add_argument( - '--num_of_patients', - dest='num_of_patients', - action='store', + "--num_of_patients", + dest="num_of_patients", + action="store", type=int, default=1024, - required=False + required=False, ) parser.add_argument( - '--sampling_batch_size', - dest='sampling_batch_size', - action='store', + "--sampling_batch_size", + dest="sampling_batch_size", + action="store", type=int, default=256, - required=False + required=False, ) parser.add_argument( - '--low_rate', - dest='low_rate', - action='store', + "--low_rate", + dest="low_rate", + action="store", type=valid_mask_rate, default=0.5, - required=False + required=False, ) parser.add_argument( - '--high_rate', - dest='high_rate', - action='store', + "--high_rate", + dest="high_rate", + action="store", type=valid_mask_rate, default=1.0, - required=False + required=False, ) parser.add_argument( - '--period', - dest='period', - action='store', + "--period", + dest="period", + action="store", type=int, default=1000, - required=False - ) - parser.add_argument( - '--total', - dest='total', - action='store', - type=int, - default=infty, - required=False - ) - parser.add_argument( - '--including_long_sequence', - dest='including_long_sequence', - action='store_true' - ) - parser.add_argument( - '--save_checkpoint', - dest='save_checkpoint', - action='store_true' + required=False, ) + parser.add_argument("--total", dest="total", 
action="store", type=int, default=infty, required=False) + parser.add_argument("--including_long_sequence", dest="including_long_sequence", action="store_true") + parser.add_argument("--save_checkpoint", dest="save_checkpoint", action="store_true") parser.add_argument( - '--save_freq', - dest='save_freq', - action='store', + "--save_freq", + dest="save_freq", + action="store", type=int, default=0, - required='--save_checkpoint' in argv - ) - parser.add_argument( - '--sampling_dataset_enabled', - dest='sampling_dataset_enabled', - action='store_true' - ) - parser.add_argument( - '--is_random_cursor_long_sequence', - dest='is_random_cursor_long_sequence', - action='store_true' + required="--save_checkpoint" in argv, ) parser.add_argument( - '--efficient_training', - dest='efficient_training', - action='store_true' + "--sampling_dataset_enabled", + dest="sampling_dataset_enabled", + action="store_true", ) parser.add_argument( - '--include_numeric_value', - dest='include_numeric_value', - action='store_true' + "--is_random_cursor_long_sequence", + dest="is_random_cursor_long_sequence", + action="store_true", ) + parser.add_argument("--efficient_training", dest="efficient_training", action="store_true") + parser.add_argument("--include_numeric_value", dest="include_numeric_value", action="store_true") + parser.add_argument("--shuffle_records", dest="shuffle_records", action="store_true") parser.add_argument( - '--shuffle_records', - dest='shuffle_records', - action='store_true' - ) - parser.add_argument( - '--val_data_parquet_path', - dest='val_data_parquet_path', - action='store', - required=False - ) - parser.add_argument( - '--is_weighted_sample', - dest='is_weighted_sample', - action='store_true' + "--val_data_parquet_path", + dest="val_data_parquet_path", + action="store", + required=False, ) + parser.add_argument("--is_weighted_sample", dest="is_weighted_sample", action="store_true") parser.add_argument( - '--weighted_sample_scaling_factor', - dest='weighted_sample_scaling_factor', - action='store', + "--weighted_sample_scaling_factor", + dest="weighted_sample_scaling_factor", + action="store", required=False, type=float, - default=0.5 + default=0.5, ) parser.add_argument( - '--weighted_sample_bin_width', - dest='weighted_sample_bin_width', - action='store', + "--weighted_sample_bin_width", + dest="weighted_sample_bin_width", + action="store", required=False, type=int, - default=20 - ) - parser.add_argument( - '--num_steps', - dest='num_steps', - action='store', - required=False, - type=int - ) - parser.add_argument( - '--include_penalty', - dest='include_penalty', - action='store_true' - ) - parser.add_argument( - '--include_positional_encoding', - dest='include_positional_encoding', - action='store_true' + default=20, ) + parser.add_argument("--num_steps", dest="num_steps", action="store", required=False, type=int) + parser.add_argument("--include_penalty", dest="include_penalty", action="store_true") parser.add_argument( - '--sort_sequence_by_length', - dest='sort_sequence_by_length', - action='store_true' + "--include_positional_encoding", + dest="include_positional_encoding", + action="store_true", ) + parser.add_argument("--sort_sequence_by_length", dest="sort_sequence_by_length", action="store_true") return parser def create_parse_args_temporal_bert(): parser = create_parse_args_base_bert() parser.add_argument( - '-ti', - '--time_attention_folder', - dest='time_attention_folder', - action='store', - help= - 'The path for your time attention input_folder where the raw data is', - 
required=True) + "-ti", + "--time_attention_folder", + dest="time_attention_folder", + action="store", + help="The path for your time attention input_folder where the raw data is", + required=True, + ) return parser @@ -406,74 +336,64 @@ def check_prob(value): parser = create_parse_args_base_bert() parser.add_argument( - '--max_num_visits', - dest='max_num_visits', - action='store', + "--max_num_visits", + dest="max_num_visits", + action="store", type=int, - help='Max no.of visits per patient', - required=True + help="Max no.of visits per patient", + required=True, ) parser.add_argument( - '--max_num_concepts', - dest='max_num_concepts', - action='store', + "--max_num_concepts", + dest="max_num_concepts", + action="store", type=int, - help='Max no.of concepts per visit per patient', - required=True + help="Max no.of concepts per visit per patient", + required=True, ) parser.add_argument( - '--min_num_of_visits', - dest='min_num_of_visits', - action='store', + "--min_num_of_visits", + dest="min_num_of_visits", + action="store", type=int, default=1, - required=False - ) - parser.add_argument( - '--include_att_prediction', - dest='include_att_prediction', - action='store_true' - ) - parser.add_argument( - '--include_readmission', - dest='include_readmission', - action='store_true' + required=False, ) + parser.add_argument("--include_att_prediction", dest="include_att_prediction", action="store_true") + parser.add_argument("--include_readmission", dest="include_readmission", action="store_true") parser.add_argument( - '--random_mask_prob', - dest='random_mask_prob', + "--random_mask_prob", + dest="random_mask_prob", type=check_prob, - required='include_readmission' in argv or 'include_prolonged_length_stay' in argv, + required="include_readmission" in argv or "include_prolonged_length_stay" in argv, default=1.0, - help='The probability the secondary learning objective uses. The value 0.2 ' - 'indicates there is a 20% chance of masking in pre-training ' - 'for secondary learning objectives' + help="The probability the secondary learning objective uses. 
The value 0.2 " + "indicates there is a 20% chance of masking in pre-training " + "for secondary learning objectives", ) parser.add_argument( - '--concept_similarity_path', - dest='concept_similarity_path', - action='store', - required=False + "--concept_similarity_path", + dest="concept_similarity_path", + action="store", + required=False, ) parser.add_argument( - '--concept_similarity_type', - dest='concept_similarity_type', - action='store', - choices=[ - member.value for member in SimilarityType - ], - help='The concept similarity measures to use for masking', + "--concept_similarity_type", + dest="concept_similarity_type", + action="store", + choices=[member.value for member in SimilarityType], + help="The concept similarity measures to use for masking", default=SimilarityType.NONE.value, - required=False + required=False, ) parser.add_argument( - '--secondary_learning_warmup_step', - dest='warmup_step', - action='store', + "--secondary_learning_warmup_step", + dest="warmup_step", + action="store", type=int, - help='The number steps before secondary learning objectives start', + help="The number steps before secondary learning objectives start", default=-1, - required=False + required=False, ) return parser @@ -481,31 +401,30 @@ def check_prob(value): def create_parse_args_hierarchical_bert_phenotype(): parser = create_parse_args_hierarchical_bert() parser.add_argument( - '--num_of_phenotypes', - dest='num_of_phenotypes', - action='store', + "--num_of_phenotypes", + dest="num_of_phenotypes", + action="store", type=int, - help='Num of phenotypes', + help="Num of phenotypes", default=20, - required=False + required=False, ) parser.add_argument( - '--num_of_phenotype_neighbors', - dest='num_of_phenotype_neighbors', - action='store', + "--num_of_phenotype_neighbors", + dest="num_of_phenotype_neighbors", + action="store", type=int, - help='Num of phenotype neighbors to consider when driving the phenotypes apart from each ' - 'other', + help="Num of phenotype neighbors to consider when driving the phenotypes apart from each " "other", default=3, - required=False + required=False, ) parser.add_argument( - '--num_of_concept_neighbors', - dest='num_of_concept_neighbors', - action='store', + "--num_of_concept_neighbors", + dest="num_of_concept_neighbors", + action="store", type=int, - help='Num of concept neighbors to consider when minimizing the phenotype-concept distances', + help="Num of concept neighbors to consider when minimizing the phenotype-concept distances", default=10, - required=False + required=False, ) return parser diff --git a/src/cehrbert/runners/hf_cehrbert_finetune_runner.py b/src/cehrbert/runners/hf_cehrbert_finetune_runner.py index 91af5c76..05ecd860 100644 --- a/src/cehrbert/runners/hf_cehrbert_finetune_runner.py +++ b/src/cehrbert/runners/hf_cehrbert_finetune_runner.py @@ -1,31 +1,33 @@ -import os import json - +import os from typing import Tuple import numpy as np import pandas as pd -from sklearn.metrics import accuracy_score, roc_auc_score, precision_recall_curve, auc +from datasets import DatasetDict, load_from_disk +from peft import LoraConfig, get_peft_model from scipy.special import expit as sigmoid - -from datasets import load_from_disk, DatasetDict +from sklearn.metrics import accuracy_score, auc, precision_recall_curve, roc_auc_score +from transformers import EarlyStoppingCallback, Trainer, set_seed from transformers.utils import logging -from transformers import Trainer, set_seed -from transformers import EarlyStoppingCallback -from peft import LoraConfig, 
get_peft_model -from ..data_generators.hf_data_generator.meds_utils import create_dataset_from_meds_reader -from ..data_generators.hf_data_generator.hf_dataset_collator import CehrBertDataCollator -from ..data_generators.hf_data_generator.hf_dataset import create_cehrbert_finetuning_dataset -from ..models.hf_models.tokenization_hf_cehrbert import CehrBertTokenizer -from ..models.hf_models.config import CehrBertConfig -from ..models.hf_models.hf_cehrbert import ( - CehrBertPreTrainedModel, CehrBertForClassification, CehrBertLstmForClassification +from cehrbert.data_generators.hf_data_generator.hf_dataset import create_cehrbert_finetuning_dataset +from cehrbert.data_generators.hf_data_generator.hf_dataset_collator import CehrBertDataCollator +from cehrbert.data_generators.hf_data_generator.meds_utils import create_dataset_from_meds_reader +from cehrbert.models.hf_models.config import CehrBertConfig +from cehrbert.models.hf_models.hf_cehrbert import ( + CehrBertForClassification, + CehrBertLstmForClassification, + CehrBertPreTrainedModel, ) -from .hf_runner_argument_dataclass import FineTuneModelType -from .runner_util import ( - get_last_hf_checkpoint, load_parquet_as_dataset, - generate_prepared_ds_path, parse_runner_args, get_meds_extension_path +from cehrbert.models.hf_models.tokenization_hf_cehrbert import CehrBertTokenizer +from cehrbert.runners.hf_runner_argument_dataclass import FineTuneModelType +from cehrbert.runners.runner_util import ( + generate_prepared_ds_path, + get_last_hf_checkpoint, + get_meds_extension_path, + load_parquet_as_dataset, + parse_runner_args, ) LOG = logging.get_logger("transformers") @@ -56,22 +58,18 @@ def compute_metrics(eval_pred): precision, recall, _ = precision_recall_curve(labels, positive_probs) pr_auc = auc(recall, precision) - return { - "accuracy": accuracy, - "roc_auc": roc_auc, - "pr_auc": pr_auc - } + return {"accuracy": accuracy, "roc_auc": roc_auc, "pr_auc": pr_auc} -def load_pretrained_model_and_tokenizer(model_args) -> Tuple[CehrBertPreTrainedModel, CehrBertTokenizer]: +def load_pretrained_model_and_tokenizer( + model_args, +) -> Tuple[CehrBertPreTrainedModel, CehrBertTokenizer]: # Try to load the pretrained tokenizer try: tokenizer_abspath = os.path.abspath(model_args.tokenizer_name_or_path) tokenizer = CehrBertTokenizer.from_pretrained(tokenizer_abspath) - except Exception as e: - raise ValueError( - f'Can not load the pretrained tokenizer from {model_args.tokenizer_name_or_path}' - ) + except Exception: + raise ValueError(f"Can not load the pretrained tokenizer from {model_args.tokenizer_name_or_path}") if model_args.finetune_model_type == FineTuneModelType.POOLING.value: finetune_model_cls = CehrBertForClassification @@ -79,7 +77,7 @@ def load_pretrained_model_and_tokenizer(model_args) -> Tuple[CehrBertPreTrainedM finetune_model_cls = CehrBertLstmForClassification else: raise ValueError( - f'finetune_model_type can be one of the following types {[e.value for e in FineTuneModelType]}' + f"finetune_model_type can be one of the following types {[e.value for e in FineTuneModelType]}" ) # Try to load the pretrained model @@ -91,7 +89,7 @@ def load_pretrained_model_and_tokenizer(model_args) -> Tuple[CehrBertPreTrainedM model_config = CehrBertConfig( vocab_size=tokenizer.vocab_size, lab_token_ids=tokenizer.lab_token_ids, - **model_args.as_dict() + **model_args.as_dict(), ) model = finetune_model_cls(model_config) @@ -122,11 +120,11 @@ def main(): target_modules=model_args.target_modules, lora_dropout=model_args.lora_dropout, bias="none", - 
modules_to_save=["classifier", "age_batch_norm", "dense_layer"] + modules_to_save=["classifier", "age_batch_norm", "dense_layer"], ) model = get_peft_model(model, config) else: - raise ValueError(f'The LORA adapter is not supported for {model_args.finetune_model_type}') + raise ValueError(f"The LORA adapter is not supported for {model_args.finetune_model_type}") if any(prepared_ds_path.glob("*")): LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...") @@ -139,7 +137,7 @@ def main(): if data_args.is_data_in_med: meds_extension_path = get_meds_extension_path( data_folder=data_args.cohort_folder, - dataset_prepared_path=data_args.dataset_prepared_path + dataset_prepared_path=data_args.dataset_prepared_path, ) try: LOG.info(f"Trying to load the MEDS extension from disk at {meds_extension_path}...") @@ -161,7 +159,7 @@ def main(): test_set = load_parquet_as_dataset(data_args.test_data_folder) if data_args.chronological_split: - dataset = dataset.sort('index_date') + dataset = dataset.sort("index_date") # Determine the split index total_size = len(dataset) train_end = int((1 - data_args.validation_split_percentage) * total_size) @@ -173,12 +171,12 @@ def main(): test_valid = validation_set.train_test_split( test_size=data_args.test_eval_ratio, seed=training_args.seed ) - validation_set = test_valid['train'] - test_set = test_valid['test'] + validation_set = test_valid["train"] + test_set = test_valid["test"] elif data_args.split_by_patient: LOG.info(f"Using the split_by_patient strategy") - unique_patient_ids = np.unique(dataset['person_id']) + unique_patient_ids = np.unique(dataset["person_id"]) LOG.info(f"There are {len(unique_patient_ids)} num of patients in total") np.random.seed(training_args.seed) np.random.shuffle(unique_patient_ids) @@ -187,96 +185,89 @@ def main(): train_patient_ids = set(unique_patient_ids[:train_end]) if not test_set: # Calculate split indices - validation_end = int( - len(unique_patient_ids) - * data_args.validation_split_percentage - * data_args.test_eval_ratio - ) + train_end + validation_end = ( + int(len(unique_patient_ids) * data_args.validation_split_percentage * data_args.test_eval_ratio) + + train_end + ) # Split patient IDs val_patient_ids = set(unique_patient_ids[train_end:validation_end]) test_patient_ids = set(unique_patient_ids[validation_end:]) def assign_split(example): - pid = example['person_id'] + pid = example["person_id"] if pid in train_patient_ids: - return 'train' + return "train" elif pid in val_patient_ids: - return 'validation' + return "validation" elif pid in test_patient_ids: - return 'test' + return "test" else: raise ValueError(f"Unknown patient {pid}") # Apply the function to assign splits dataset = dataset.map( - lambda example: {'split': assign_split(example)}, - num_proc=data_args.preprocessing_num_workers + lambda example: {"split": assign_split(example)}, + num_proc=data_args.preprocessing_num_workers, ) train_set = dataset.filter( - lambda example: example['split'] == 'train', - num_proc=data_args.preprocessing_num_workers + lambda example: example["split"] == "train", + num_proc=data_args.preprocessing_num_workers, ) validation_set = dataset.filter( - lambda example: example['split'] == 'validation', - num_proc=data_args.preprocessing_num_workers + lambda example: example["split"] == "validation", + num_proc=data_args.preprocessing_num_workers, ) test_set = dataset.filter( - lambda example: example['split'] == 'test', - num_proc=data_args.preprocessing_num_workers + lambda example: example["split"] == "test", + 
num_proc=data_args.preprocessing_num_workers, ) else: # Split patient IDs val_patient_ids = set(unique_patient_ids[train_end:]) def assign_split(example): - pid = example['person_id'] + pid = example["person_id"] if pid in train_patient_ids: - return 'train' + return "train" elif pid in val_patient_ids: - return 'validation' + return "validation" else: raise ValueError(f"Unknown patient {pid}") # Apply the function to assign splits dataset = dataset.map( - lambda example: {'split': assign_split(example)}, - num_proc=data_args.preprocessing_num_workers + lambda example: {"split": assign_split(example)}, + num_proc=data_args.preprocessing_num_workers, ) train_set = dataset.filter( - lambda example: example['split'] == 'train', - num_proc=data_args.preprocessing_num_workers + lambda example: example["split"] == "train", + num_proc=data_args.preprocessing_num_workers, ) validation_set = dataset.filter( - lambda example: example['split'] == 'validation', - num_proc=data_args.preprocessing_num_workers + lambda example: example["split"] == "validation", + num_proc=data_args.preprocessing_num_workers, ) else: # Split the dataset into train/val train_val = dataset.train_test_split( test_size=data_args.validation_split_percentage, - seed=training_args.seed + seed=training_args.seed, ) - train_set = train_val['train'] - validation_set = train_val['test'] + train_set = train_val["train"] + validation_set = train_val["test"] if not test_set: test_valid = validation_set.train_test_split( test_size=data_args.test_eval_ratio, seed=training_args.seed ) - validation_set = test_valid['train'] - test_set = test_valid['test'] + validation_set = test_valid["train"] + test_set = test_valid["test"] # Organize them into a single DatasetDict - final_splits = DatasetDict({ - 'train': train_set, - 'validation': validation_set, - 'test': test_set - }) + final_splits = DatasetDict({"train": train_set, "validation": validation_set, "test": test_set}) processed_dataset = create_cehrbert_finetuning_dataset( - dataset=final_splits, - concept_tokenizer=tokenizer, - data_args=data_args + dataset=final_splits, concept_tokenizer=tokenizer, data_args=data_args ) if not data_args.streaming: @@ -288,16 +279,16 @@ def assign_split(example): set_seed(training_args.seed) if not data_args.streaming: - processed_dataset.set_format('pt') + processed_dataset.set_format("pt") trainer = Trainer( model=model, data_collator=collator, - train_dataset=processed_dataset['train'], - eval_dataset=processed_dataset['validation'], + train_dataset=processed_dataset["train"], + eval_dataset=processed_dataset["validation"], compute_metrics=compute_metrics, callbacks=[EarlyStoppingCallback(early_stopping_patience=model_args.early_stopping_patience)], - args=training_args + args=training_args, ) checkpoint = get_last_hf_checkpoint(training_args) @@ -313,18 +304,21 @@ def assign_split(example): if training_args.do_predict: # If do_train is set to False, we need to load the model from the checkpoint. if not training_args.do_train: - LOG.info(f"The do_train flag is set to False. Loading the weights form {training_args.output_dir}") + LOG.info( + "The do_train flag is set to False. 
Loading the weights from %s", + training_args.output_dir, + ) trainer._load_from_checkpoint(training_args.output_dir) - test_results = trainer.predict(processed_dataset['test']) + test_results = trainer.predict(processed_dataset["test"]) # Save results to JSON - test_results_path = os.path.join(training_args.output_dir, 'test_results.json') - with open(test_results_path, 'w') as f: + test_results_path = os.path.join(training_args.output_dir, "test_results.json") + with open(test_results_path, "w") as f: json.dump(test_results.metrics, f, indent=4) - LOG.info(f'Test results: {test_results.metrics}') + LOG.info(f"Test results: {test_results.metrics}") - person_ids = [row['person_id'] for row in processed_dataset['test']] + person_ids = [row["person_id"] for row in processed_dataset["test"]] if isinstance(test_results.predictions, np.ndarray): predictions = np.squeeze(test_results.predictions).tolist() @@ -335,17 +329,8 @@ def assign_split(example): else: labels = np.squeeze(test_results.label_ids[0]).tolist() - prediction_pd = pd.DataFrame( - { - 'person_id ': person_ids, - 'prediction': predictions, - 'label': labels - } - ) - prediction_pd.to_csv( - os.path.join(training_args.output_dir, 'test_predictions.csv'), - index=False - ) + prediction_pd = pd.DataFrame({"person_id ": person_ids, "prediction": predictions, "label": labels}) + prediction_pd.to_csv(os.path.join(training_args.output_dir, "test_predictions.csv"), index=False) if __name__ == "__main__": diff --git a/src/cehrbert/runners/hf_cehrbert_pretrain_runner.py b/src/cehrbert/runners/hf_cehrbert_pretrain_runner.py index dc913a92..6d0b026c 100644 --- a/src/cehrbert/runners/hf_cehrbert_pretrain_runner.py +++ b/src/cehrbert/runners/hf_cehrbert_pretrain_runner.py @@ -1,131 +1,210 @@ +import json import os +from typing import Optional, Union -from typing import Union, Optional - -from datasets import DatasetDict, IterableDatasetDict, Dataset, load_from_disk -from transformers.utils import logging +from datasets import Dataset, DatasetDict, IterableDatasetDict, load_from_disk from transformers import AutoConfig, Trainer, set_seed +from transformers.utils import logging + +from cehrbert.data_generators.hf_data_generator.hf_dataset import create_cehrbert_pretraining_dataset +from cehrbert.data_generators.hf_data_generator.hf_dataset_collator import CehrBertDataCollator +from cehrbert.data_generators.hf_data_generator.meds_utils import create_dataset_from_meds_reader +from cehrbert.models.hf_models.config import CehrBertConfig +from cehrbert.models.hf_models.hf_cehrbert import CehrBertForPreTraining +from cehrbert.models.hf_models.tokenization_hf_cehrbert import CehrBertTokenizer -from ..data_generators.hf_data_generator.meds_utils import create_dataset_from_meds_reader -from ..data_generators.hf_data_generator.hf_dataset_collator import CehrBertDataCollator -from ..data_generators.hf_data_generator.hf_dataset import ( - create_cehrbert_pretraining_dataset -) -from ..models.hf_models.tokenization_hf_cehrbert import CehrBertTokenizer -from ..models.hf_models.config import CehrBertConfig -from ..models.hf_models.hf_cehrbert import CehrBertForPreTraining -from .runner_util import generate_prepared_ds_path, load_parquet_as_dataset, get_last_hf_checkpoint, \ - parse_runner_args, get_meds_extension_path from .hf_runner_argument_dataclass import DataTrainingArguments, ModelArguments +from .runner_util import ( + generate_prepared_ds_path, + get_last_hf_checkpoint, + get_meds_extension_path, + load_parquet_as_dataset, + parse_runner_args, +) LOG = 
logging.get_logger("transformers") def load_and_create_tokenizer( - data_args: DataTrainingArguments, - model_args: ModelArguments, - dataset: Optional[Union[Dataset, DatasetDict]] = None + data_args: DataTrainingArguments, + model_args: ModelArguments, + dataset: Optional[Union[Dataset, DatasetDict]] = None, ) -> CehrBertTokenizer: + """ + Loads a pretrained tokenizer or creates a new one if it cannot be loaded. + + Args: + data_args (DataTrainingArguments): Data-related arguments used for training the tokenizer. + model_args (ModelArguments): Model-related arguments including the tokenizer's path or name. + dataset (Optional[Union[Dataset, DatasetDict]]): A dataset used to train the tokenizer if it cannot be loaded. + + Returns: + CehrBertTokenizer: The loaded or newly created and trained tokenizer. + + Raises: + RuntimeError: If the tokenizer cannot be loaded and no dataset is provided to create a new tokenizer. + + Behavior: + - Attempts to load the tokenizer from the specified path in `model_args.tokenizer_name_or_path`. + - If loading fails and no dataset is provided, it raises the original exception. + - If a dataset is provided, it trains a new tokenizer on the dataset using the `concept_ids` feature. + - Saves the newly created tokenizer at the specified path. + + Example: + tokenizer = load_and_create_tokenizer(data_args, model_args, dataset) + """ # Try to load the pretrained tokenizer tokenizer_abspath = os.path.abspath(model_args.tokenizer_name_or_path) try: tokenizer = CehrBertTokenizer.from_pretrained(tokenizer_abspath) - except Exception as e: - LOG.warning(e) + except (OSError, RuntimeError, FileNotFoundError, json.JSONDecodeError) as e: + LOG.warning( + "Failed to load the tokenizer from %s with the error " + "\n%s\nTried to create the tokenizer, however the dataset is not provided.", + tokenizer_abspath, + e, + ) if dataset is None: - raise RuntimeError( - f"Failed to load the tokenizer from {tokenizer_abspath} with the error \n{e}\n" - f"Tried to create the tokenizer, however the dataset is not provided." - ) - + raise e tokenizer = CehrBertTokenizer.train_tokenizer( - dataset, ['concept_ids'], {}, data_args + dataset, feature_names=["concept_ids"], concept_name_mapping={}, data_args=data_args ) tokenizer.save_pretrained(tokenizer_abspath) return tokenizer -def load_and_create_model( - model_args: ModelArguments, - tokenizer: CehrBertTokenizer -) -> CehrBertForPreTraining: +def load_and_create_model(model_args: ModelArguments, tokenizer: CehrBertTokenizer) -> CehrBertForPreTraining: + """ + Loads a pretrained model or creates a new model configuration if the pretrained model cannot be loaded. + + Args: + model_args (ModelArguments): Model-related arguments including the model's path or configuration details. + tokenizer (CehrBertTokenizer): The tokenizer to be used with the model, providing vocab and token information. + + Returns: + CehrBertForPreTraining: The loaded or newly configured model for pretraining. + + Behavior: + - Attempts to load the model's configuration from the specified path in `model_args.model_name_or_path`. + - If loading fails, it logs the error and creates a new model configuration using the tokenizer's vocab size + and lab token IDs. + - Returns a `CehrBertForPreTraining` model initialized with the loaded or newly created configuration. 
+ + Example: + model = load_and_create_model(model_args, tokenizer) + """ try: model_abspath = os.path.abspath(model_args.model_name_or_path) model_config = AutoConfig.from_pretrained(model_abspath) - except Exception as e: + except (OSError, ValueError, FileNotFoundError, json.JSONDecodeError) as e: LOG.warning(e) model_config = CehrBertConfig( vocab_size=tokenizer.vocab_size, lab_token_ids=tokenizer.lab_token_ids, - **model_args.as_dict() + **model_args.as_dict(), ) - return CehrBertForPreTraining(model_config) def main(): + """ + Main function for preparing, loading, and training a CEHR-BERT model for pretraining. + + This function handles: + - Parsing input arguments for data, model, and training configurations. + - Loading or creating a dataset, either from a previously saved state or raw data (e.g., MEDS). + - Creating or loading a CEHR-BERT tokenizer, depending on whether a tokenizer exists. + - Creating and configuring the CEHR-BERT model for pretraining. + - Setting up a data collator and trainer for pretraining using Hugging Face's `Trainer` class. + - Handling dataset splitting for training and validation. + - Optionally resuming training from the last checkpoint. + + Key Steps: + 1. Check for streaming data support and adjust settings accordingly. + 2. Load the dataset from disk if available, or create it from raw data. + 3. Tokenize the dataset using the CEHR-BERT tokenizer. + 4. Train the model, resume from a checkpoint if specified, and save the final model and metrics. + + Raises: + RuntimeError: Raised if required arguments (e.g., validation split details) are missing. + + Example Usage: + Run this function in a script with appropriate arguments: + ``` + python hf_cehrbert_pretrain_runner.py --data_args --model_args \ + --training_args + ``` + + Dependencies: + - Hugging Face Transformers (Trainer, Dataset, DatasetDict, etc.) + - CEHR-BERT modules such as `CehrBertTokenizer`, `CehrBertForPreTraining`, + and `CehrBertDataCollator`. + + Notes: + - Assumes the data is in the CEHR-BERT format or needs conversion from the MEDS format. + - Supports both disk-based and streaming datasets, depending on the argument configuration. + - The tokenizer and model are saved to disk after the training process completes. + """ data_args, model_args, training_args = parse_runner_args() if data_args.streaming: - # This is for disabling the warning message https://github.com/huggingface/transformers/issues/5486 - # This happens only when streaming is enabled + # This happens only when streaming is enabled. This is for disabling the warning message + # https://github.com/huggingface/transformers/issues/5486 os.environ["TOKENIZERS_PARALLELISM"] = "false" - # The iterable dataset doesn't have sharding implemented, so the number of works has to be set to 0 - # Otherwise the trainer will throw an error + # The iterable dataset doesn't have sharding implemented, so the number of works has to + # be set to 0. 
Otherwise the trainer will throw an error training_args.dataloader_num_workers = 0 prepared_ds_path = generate_prepared_ds_path(data_args, model_args) if any(prepared_ds_path.glob("*")): - LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...") + LOG.info("Loading prepared dataset from disk at %s...", prepared_ds_path) processed_dataset = load_from_disk(str(prepared_ds_path)) if data_args.streaming: processed_dataset = processed_dataset.to_iterable_dataset(num_shards=training_args.dataloader_num_workers) LOG.info("Prepared dataset loaded from disk...") - # If the data has been processed in the past, it's assume the tokenizer has been created before. - # we load the CEHR-BERT tokenizer from the output folder. - tokenizer = load_and_create_tokenizer( - data_args=data_args, - model_args=model_args, - dataset=processed_dataset - ) + # If the data has been processed in the past, it's assume the tokenizer has been created + # before. We load the CEHR-BERT tokenizer from the output folder. + tokenizer = load_and_create_tokenizer(data_args=data_args, model_args=model_args, dataset=processed_dataset) else: # If the data is in the MEDS format, we need to convert it to the CEHR-BERT format if data_args.is_data_in_med: meds_extension_path = get_meds_extension_path( data_folder=data_args.data_folder, - dataset_prepared_path=data_args.dataset_prepared_path + dataset_prepared_path=data_args.dataset_prepared_path, ) try: - LOG.info(f"Trying to load the MEDS extension from disk at {meds_extension_path}...") + LOG.info( + "Trying to load the MEDS extension from disk at %s...", + meds_extension_path, + ) dataset = load_from_disk(meds_extension_path) if data_args.streaming: dataset = dataset.to_iterable_dataset(num_shards=training_args.dataloader_num_workers) - except Exception as e: + except RuntimeError as e: LOG.exception(e) dataset = create_dataset_from_meds_reader(data_args, is_pretraining=True) if not data_args.streaming: dataset.save_to_disk(meds_extension_path) else: # Load the dataset from the parquet files - dataset = load_parquet_as_dataset(data_args.data_folder, split='train', streaming=data_args.streaming) + dataset = load_parquet_as_dataset(data_args.data_folder, split="train", streaming=data_args.streaming) # If streaming is enabled, we need to manually split the data into train/val if data_args.streaming and data_args.validation_split_num: dataset = dataset.shuffle(buffer_size=10_000, seed=training_args.seed) train_set = dataset.skip(data_args.validation_split_num) val_set = dataset.take(data_args.validation_split_num) - dataset = DatasetDict({ - 'train': train_set, - 'validation': val_set - }) + dataset = DatasetDict({"train": train_set, "validation": val_set}) elif data_args.validation_split_percentage: - dataset = dataset.train_test_split(test_size=data_args.validation_split_percentage, - seed=training_args.seed) + dataset = dataset.train_test_split( + test_size=data_args.validation_split_percentage, + seed=training_args.seed, + ) else: raise RuntimeError( - f"Can not split the data. If streaming is enabled, validation_split_num needs to be " - f"defined, otherwise validation_split_percentage needs to be provided. " + f"Can not split the data. If streaming is enabled, validation_split_num needs " + f"to be defined, otherwise validation_split_percentage needs to be provided. 
" f"The current values are:\n" f"validation_split_percentage: {data_args.validation_split_percentage}\n" f"validation_split_num: {data_args.validation_split_num}\n" @@ -133,16 +212,10 @@ def main(): ) # Create the CEHR-BERT tokenizer if it's not available in the output folder - tokenizer = load_and_create_tokenizer( - data_args=data_args, - model_args=model_args, - dataset=dataset - ) + tokenizer = load_and_create_tokenizer(data_args=data_args, model_args=model_args, dataset=dataset) # sort the patient features chronologically and tokenize the data processed_dataset = create_cehrbert_pretraining_dataset( - dataset=dataset, - concept_tokenizer=tokenizer, - data_args=data_args + dataset=dataset, concept_tokenizer=tokenizer, data_args=data_args ) # only save the data to the disk if it is not streaming if not data_args.streaming: @@ -154,7 +227,7 @@ def main(): tokenizer=tokenizer, max_length=model_args.max_position_embeddings, is_pretraining=True, - mlm_probability=model.config.mlm_probability + mlm_probability=model.config.mlm_probability, ) # Detecting last checkpoint. @@ -164,13 +237,13 @@ def main(): set_seed(training_args.seed) if not data_args.streaming: - processed_dataset.set_format('pt') + processed_dataset.set_format("pt") eval_dataset = None if isinstance(processed_dataset, DatasetDict) or isinstance(processed_dataset, IterableDatasetDict): - train_dataset = processed_dataset['train'] - if 'validation' in processed_dataset: - eval_dataset = processed_dataset['validation'] + train_dataset = processed_dataset["train"] + if "validation" in processed_dataset: + eval_dataset = processed_dataset["validation"] else: train_dataset = processed_dataset @@ -180,7 +253,7 @@ def main(): train_dataset=train_dataset, eval_dataset=eval_dataset, # compute_metrics=compute_metrics, - args=training_args + args=training_args, ) checkpoint = None diff --git a/src/cehrbert/runners/hf_runner_argument_dataclass.py b/src/cehrbert/runners/hf_runner_argument_dataclass.py index b1ae890b..b78c6da9 100644 --- a/src/cehrbert/runners/hf_runner_argument_dataclass.py +++ b/src/cehrbert/runners/hf_runner_argument_dataclass.py @@ -1,15 +1,17 @@ -from dataclasses import dataclass, field, asdict +import dataclasses from enum import Enum -from typing import Optional, Dict, Any, Literal, List +from typing import Any, Dict, List, Literal, Optional -from ..data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules import MedsToCehrBertConversion -from ..data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules import MedsToBertMimic4 +from ..data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules import ( + MedsToBertMimic4, + MedsToCehrBertConversion, +) from ..spark_apps.decorators.patient_event_decorator import AttType # Create an enum dynamically from the list MedsToCehrBertConversionType = Enum( - 'MedsToCehrBertConversionType', - [cls.__name__ for cls in MedsToCehrBertConversion.__subclasses__()] + "MedsToCehrBertConversionType", + [cls.__name__ for cls in MedsToCehrBertConversion.__subclasses__()], ) @@ -18,65 +20,58 @@ class FineTuneModelType(Enum): LSTM = "lstm" -@dataclass +@dataclasses.dataclass class DataTrainingArguments: - """ - Arguments pertaining to what data we are going to input our model for training and eval. 
- """ - data_folder: Optional[str] = field( + """Arguments pertaining to what data we are going to input our model for training and eval.""" + + data_folder: Optional[str] = dataclasses.field( metadata={"help": "The name of the dataset to use (via the datasets library)."} ) - dataset_prepared_path: Optional[str] = field( + dataset_prepared_path: Optional[str] = dataclasses.field( metadata={"help": "The folder in which the prepared dataset is cached"} ) - test_data_folder: Optional[str] = field( + test_data_folder: Optional[str] = dataclasses.field( default=None, - metadata={"help": "The name of the test dataset to use (via the datasets library)."} + metadata={"help": "The name of the test dataset to use (via the datasets library)."}, ) - cohort_folder: Optional[str] = field( + cohort_folder: Optional[str] = dataclasses.field( default=None, - metadata={"help": "The name of the cohort generated by ACES or OHDSI cohort builder."} + metadata={"help": "The name of the cohort generated by ACES or OHDSI cohort builder."}, ) - chronological_split: Optional[bool] = field( + chronological_split: Optional[bool] = dataclasses.field( default=False, metadata={ "help": "A flag to indicate whether the data will be split chronologically, " - "where the historical data is used for training " - "and the future data is used for validation adn testing" - } + "where the historical data is used for training " + "and the future data is used for validation adn testing" + }, ) - split_by_patient: Optional[bool] = field( + split_by_patient: Optional[bool] = dataclasses.field( default=False, metadata={ "help": "A flag to indicate whether the records associated with the same person_id " - "should end up in the same split" - } + "should end up in the same split" + }, ) - validation_split_percentage: Optional[float] = field( + validation_split_percentage: Optional[float] = dataclasses.field( default=0.05, - metadata={ - "help": "The percentage of the train set used as validation set in case there's no validation split" - } + metadata={"help": "The percentage of the train set used as validation set in case there's no validation split"}, ) - validation_split_num: Optional[int] = field( + validation_split_num: Optional[int] = dataclasses.field( default=1000, - metadata={ - "help": "The number of the train set used as validation set in case there's no validation split" - }, + metadata={"help": "The number of the train set used as validation set in case there's no validation split"}, ) - test_eval_ratio: Optional[float] = field( + test_eval_ratio: Optional[float] = dataclasses.field( default=0.5, - metadata={ - "help": "The percentage of the train set used as validation set in case there's no validation split" - }, + metadata={"help": "The percentage of the train set used as validation set in case there's no validation split"}, ) - preprocessing_num_workers: Optional[int] = field( + preprocessing_num_workers: Optional[int] = dataclasses.field( default=4, metadata={"help": "The number of processes to use for the preprocessing."}, ) - preprocessing_batch_size: Optional[int] = field( + preprocessing_batch_size: Optional[int] = dataclasses.field( default=10000, - metadata={"help": "The batch size to use for preprocessing a streaming dataset"} + metadata={"help": "The batch size to use for preprocessing a streaming dataset"}, ) att_function_type: Literal[ AttType.CEHR_BERT.value, @@ -84,18 +79,18 @@ class DataTrainingArguments: AttType.WEEK.value, AttType.MONTH.value, AttType.MIX.value, - AttType.NONE.value - ] = field( + 
AttType.NONE.value, + ] = dataclasses.field( default=AttType.CEHR_BERT.value, metadata={ "help": "The ATT type to choose the level of granularity to use for creating the " - "artificial time tokens between visits", - "choices": f"choices={[e.value for e in AttType]}" - } + "artificial time tokens between visits", + "choices": f"choices={[e.value for e in AttType]}", + }, ) - is_data_in_med: Optional[bool] = field( + is_data_in_med: Optional[bool] = dataclasses.field( default=False, - metadata={"help": "The boolean indicator to indicate whether the data is in the MED format"} + metadata={"help": "The boolean indicator to indicate whether the data is in the MED format"}, ) inpatient_att_function_type: Literal[ AttType.CEHR_BERT.value, @@ -103,95 +98,99 @@ class DataTrainingArguments: AttType.WEEK.value, AttType.MONTH.value, AttType.MIX.value, - AttType.NONE.value - ] = field( + AttType.NONE.value, + ] = dataclasses.field( default=AttType.NONE, metadata={ "help": "The ATT type to choose the level of granularity to use for creating the " - "artificial time tokens between neighboring events within inpatient visits." - "Default to None, meaning the inpatient artificial time tokens are not created.", - "choices": f"choices={[e.value for e in AttType]}" - } + "artificial time tokens between neighboring events within inpatient visits." + "Default to None, meaning the inpatient artificial time tokens are not created.", + "choices": f"choices={[e.value for e in AttType]}", + }, ) - # TODO: Python 3.9/10 do not support dynamic unpacking, we have to manually provide the entire list right now. - meds_to_cehrbert_conversion_type: Literal[MedsToBertMimic4.__name__] = field( + # TODO: Python 3.9/10 do not support dynamic unpacking, we have to manually provide the entire + # list right now. + meds_to_cehrbert_conversion_type: Literal[MedsToBertMimic4.__name__] = dataclasses.field( default=MedsToBertMimic4, metadata={ "help": "The MEDS to CEHRBERT conversion type e.g. MedsToBertMimic4", - "choices": f"choices={[e for e in MedsToCehrBertConversionType.__members__]}" - } + "choices": f"choices={[e for e in MedsToCehrBertConversionType.__members__]}", + }, ) - include_auxiliary_token: Optional[bool] = field( + include_auxiliary_token: Optional[bool] = dataclasses.field( default=False, metadata={ "help": "The boolean indicator to indicate whether visit type should be included " - "at the beginning of the visit and discharge facility should be included at the end of the visit" - } + "at the beginning of the visit and discharge facility should be included at " + "the end of the visit" + }, ) - include_demographic_prompt: Optional[bool] = field( + include_demographic_prompt: Optional[bool] = dataclasses.field( default=False, metadata={ - "help": "The boolean indicator to indicate whether the demographic tokens should be added " - "at the beginning of the sequence including start_year, start_age, gender, race" - } + "help": "The boolean indicator to indicate whether the demographic tokens should be " + "added at the beginning of the sequence " + "including start_year, start_age, gender, race." 
+ }, ) - streaming: Optional[bool] = field( + streaming: Optional[bool] = dataclasses.field( default=False, - metadata={ - "help": "The boolean indicator to indicate whether the data should be streamed" - } + metadata={"help": "The boolean indicator to indicate whether the data should be streamed"}, ) - vocab_size: Optional[int] = field( + vocab_size: Optional[int] = dataclasses.field( default=50_000, - metadata={"help": "The maximum vocab size allowed for the tokenizer trainer to use"} + metadata={"help": "The maximum vocab size allowed for the tokenizer trainer to use"}, ) - min_frequency: Optional[int] = field( + min_frequency: Optional[int] = dataclasses.field( default=0, - metadata={"help": "The minimum frequency for concepts to be kept by the tokenizer"} + metadata={"help": "The minimum frequency for concepts to be kept by the tokenizer"}, ) - min_num_tokens: Optional[int] = field( + min_num_tokens: Optional[int] = dataclasses.field( default=20, - metadata={"help": "The minimum num of tokens required in a sequences to be included for training"} + metadata={"help": "The minimum num of tokens required in a sequences to be included for training"}, ) - shuffle_records: Optional[bool] = field( + shuffle_records: Optional[bool] = dataclasses.field( default=False, - metadata={"help": "Indicates whether to randomly shuffle the records that have the same rank"} + metadata={"help": "Indicates whether to randomly shuffle the records that have the same rank"}, ) -@dataclass +@dataclasses.dataclass class ModelArguments: - """ - Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. + """Arguments pertaining to which model/config/tokenizer we are going to fine-tune,. + + or train from scratch. """ - model_name_or_path: Optional[str] = field( + model_name_or_path: Optional[str] = dataclasses.field( metadata={ "help": ( - "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch." + "The model checkpoint for weights initialization. " + "Don't set if you want to train a model from scratch." ) }, ) - tokenizer_name_or_path: Optional[str] = field( + tokenizer_name_or_path: Optional[str] = dataclasses.field( metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} ) - early_stopping_patience: Optional[int] = field( + early_stopping_patience: Optional[int] = dataclasses.field( default=1, metadata={ - "help": "stop training when the specified metric worsens for `early_stopping_patience` evaluation calls." - } + "help": "stop training when the specified metric worsens " "for `early_stopping_patience` evaluation calls." + }, ) - cache_dir: Optional[str] = field( + cache_dir: Optional[str] = dataclasses.field( default=None, - metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, + metadata={"help": "Where do you want to store the pretrained models downloaded " "from huggingface.co"}, ) - use_auth_token: bool = field( + use_auth_token: bool = dataclasses.field( default=None, metadata={ - "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead." - } + "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. " + "Please use `token` instead." 
+ }, ) - trust_remote_code: bool = field( + trust_remote_code: bool = dataclasses.field( default=False, metadata={ "help": ( @@ -201,7 +200,7 @@ class ModelArguments: ) }, ) - torch_dtype: Optional[str] = field( + torch_dtype: Optional[str] = dataclasses.field( default=None, metadata={ "help": ( @@ -209,87 +208,81 @@ class ModelArguments: "dtype will be automatically derived from the model's weights." ), "choices": ["auto", "bfloat16", "float16", "float32"], - } + }, ) - hidden_size: Optional[int] = field( + hidden_size: Optional[int] = dataclasses.field( default=128, - metadata={"help": "The embedding and hidden size for the transformer block"} + metadata={"help": "The embedding and hidden size for the transformer block"}, ) - num_hidden_layers: Optional[int] = field( + num_hidden_layers: Optional[int] = dataclasses.field( default=6, - metadata={"help": "The number of layers used in the transformer model"} + metadata={"help": "The number of layers used in the transformer model"}, ) - n_head: Optional[int] = field( - default=8, - metadata={"help": "The number of heads in Multi-Head Attention"} + n_head: Optional[int] = dataclasses.field( + default=8, metadata={"help": "The number of heads in Multi-Head Attention"} ) - max_position_embeddings: Optional[int] = field( + max_position_embeddings: Optional[int] = dataclasses.field( default=512, - metadata={"help": "The maximum length of the sequence allowed for the transformer model"} + metadata={"help": "The maximum length of the sequence allowed for the transformer model"}, ) - finetune_model_type: Literal[FineTuneModelType.POOLING.value, FineTuneModelType.LSTM.value] = field( + finetune_model_type: Literal[FineTuneModelType.POOLING.value, FineTuneModelType.LSTM.value] = dataclasses.field( default=FineTuneModelType.POOLING.value, metadata={ "help": "The finetune model type to choose from", - "choices": f"choices={[e.value for e in FineTuneModelType]}" - } + "choices": f"choices={[e.value for e in FineTuneModelType]}", + }, ) - use_lora: Optional[bool] = field( + use_lora: Optional[bool] = dataclasses.field( default=False, - metadata={"help": "The flag to indicate whether or not to use the Lora adapter for finetuning"} + metadata={"help": "The flag to indicate whether or not to use the Lora adapter for finetuning"}, ) - lora_rank: Optional[int] = field( - default=16, - metadata={"help": "Lora attention dimension (the “rank”)."} + lora_rank: Optional[int] = dataclasses.field( + default=16, metadata={"help": "Lora attention dimension (the “rank”)."} ) - lora_alpha: Optional[int] = field( - default=16, - metadata={"help": "The alpha parameter for Lora scaling."} + lora_alpha: Optional[int] = dataclasses.field( + default=16, metadata={"help": "The alpha parameter for Lora scaling."} ) - target_modules: Optional[List[str]] = field( + target_modules: Optional[List[str]] = dataclasses.field( default_factory=lambda: ["query", "value"], metadata={ - "help": - "The names of the modules to apply the adapter to. If this is specified, only the modules with the " - "specified names will be replaced. When passing a string, a regex match will be performed. When " - "passing a list of strings, either an exact match will be performed or it is checked if the name " - "of the module ends with any of the passed strings. If this is specified as ‘all-linear’, " - "then all linear/Conv1D modules are chosen, excluding the output layer. If this is not specified, " - "modules will be chosen according to the model architecture. 
If the architecture is not known, " - "an error will be raised — in this case, you should specify the target modules manually."} - ) - lora_dropout: Optional[float] = field( - default=0.1, - metadata={"help": "The dropout probability for Lora layers"} - ) - exclude_position_ids: Optional[bool] = field( + "help": "The names of the modules to apply the adapter to. If this is specified, only the modules with the " + "specified names will be replaced. When passing a string, a regex match will be performed. When " + "passing a list of strings, either an exact match will be performed or it is checked if the name " + "of the module ends with any of the passed strings. If this is specified as ‘all-linear’, " + "then all linear/Conv1D modules are chosen, excluding the output layer. If this is not specified, " + "modules will be chosen according to the model architecture. If the architecture is not known, " + "an error will be raised — in this case, you should specify the target modules manually." + }, + ) + lora_dropout: Optional[float] = dataclasses.field( + default=0.1, metadata={"help": "The dropout probability for Lora layers"} + ) + exclude_position_ids: Optional[bool] = dataclasses.field( default=False, - metadata={"help": "Whether or not to exclude position ids from the transformer model"} + metadata={"help": "Whether or not to exclude position ids from the transformer model"}, ) - include_values: Optional[bool] = field( + include_values: Optional[bool] = dataclasses.field( default=False, - metadata={"help": "Whether or not to include values into the model"} + metadata={"help": "Whether or not to include values into the model"}, ) - use_sub_time_tokenization: Optional[bool] = field( + use_sub_time_tokenization: Optional[bool] = dataclasses.field( default=True, - metadata={"help": "Whether or not to decompose the time interval into year/month/day"} + metadata={"help": "Whether or not to decompose the time interval into year/month/day"}, ) - include_value_prediction: Optional[bool] = field( + include_value_prediction: Optional[bool] = dataclasses.field( default=True, - metadata={"help": "Whether or not to include value prediction head for cehrbert"} + metadata={"help": "Whether or not to include value prediction head for cehrbert"}, ) - include_ttv_prediction: Optional[bool] = field( + include_ttv_prediction: Optional[bool] = dataclasses.field( default=True, - metadata={"help": "Whether or not to include the time to visit prediction"} + metadata={"help": "Whether or not to include the time to visit prediction"}, ) - time_token_loss_weight: Optional[float] = field( - default=1.0, - metadata={"help": "The weight of the time token loss"} + time_token_loss_weight: Optional[float] = dataclasses.field( + default=1.0, metadata={"help": "The weight of the time token loss"} ) - time_to_visit_loss_weight: Optional[float] = field( - default=1.0, - metadata={"help": "The weight of the time to visit loss"} + time_to_visit_loss_weight: Optional[float] = dataclasses.field( + default=1.0, metadata={"help": "The weight of the time to visit loss"} ) def as_dict(self) -> Dict[str, Any]: - return asdict(self) + return dataclasses.asdict(self) diff --git a/src/cehrbert/runners/runner_util.py b/src/cehrbert/runners/runner_util.py index fe20ba78..63876ca9 100644 --- a/src/cehrbert/runners/runner_util.py +++ b/src/cehrbert/runners/runner_util.py @@ -1,19 +1,19 @@ +import glob import hashlib import os import re -import glob import sys -from typing import Tuple, Union from pathlib import Path +from typing import Tuple, 
Union import torch -from datasets import load_dataset, Dataset, IterableDataset +from datasets import Dataset, IterableDataset, load_dataset from torch.nn import functional as F -from transformers import HfArgumentParser, TrainingArguments, EvalPrediction -from transformers.utils import logging +from transformers import EvalPrediction, HfArgumentParser, TrainingArguments from transformers.trainer_utils import get_last_checkpoint +from transformers.utils import logging -from .hf_runner_argument_dataclass import ModelArguments, DataTrainingArguments +from .hf_runner_argument_dataclass import DataTrainingArguments, ModelArguments LOG = logging.get_logger("transformers") @@ -50,13 +50,14 @@ def load_parquet_as_dataset(data_folder, split="train", streaming=False) -> Unio """ data_abspath = os.path.abspath(data_folder) data_files = glob.glob(os.path.join(data_abspath, "*.parquet")) - dataset = load_dataset('parquet', data_files=data_files, split=split, streaming=streaming) + dataset = load_dataset("parquet", data_files=data_files, split=split, streaming=streaming) return dataset def get_last_hf_checkpoint(training_args): """ - Retrieves the path to the last saved checkpoint from the specified output directory, + Retrieves the path to the last saved checkpoint from the specified output directory, + if it exists and conditions permit resuming training from that checkpoint. This function checks if an output directory contains any previously saved checkpoints and @@ -106,17 +107,36 @@ def get_last_hf_checkpoint(training_args): ) elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: LOG.info( - f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " - "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + "Checkpoint detected, resuming training at %s. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch.", + last_checkpoint, ) return last_checkpoint def md5(to_hash: str, encoding: str = "utf-8") -> str: - try: - return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest() - except TypeError: - return hashlib.md5(to_hash.encode(encoding)).hexdigest() + """ + Computes the MD5 hash of a given string. + + Args: + to_hash (str): The string to be hashed. + encoding (str, optional): The character encoding to use for the string. + Defaults to "utf-8". + + Returns: + str: The resulting MD5 hash as a hexadecimal string. + + Notes: + - The `usedforsecurity=False` flag is used to signal that the MD5 hash + is not being used for security purposes. + - The `usedforsecurity=False` argument requires Python 3.9 or later; older + Python versions will raise a TypeError when this argument is passed.
+ + Example: + >>> md5("hello") + '5d41402abc4b2a76b9719d911017c592' + """ + return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest() def generate_prepared_ds_path(data_args, model_args, data_folder=None) -> Path: @@ -167,60 +187,66 @@ def generate_prepared_ds_path(data_args, model_args, data_folder=None) -> Path: """ data_folder = data_folder if data_folder else data_args.data_folder concatenated_str = ( - str(model_args.max_position_embeddings) + - "|" + os.path.abspath(data_folder) + - "|" + os.path.abspath(model_args.tokenizer_name_or_path) + - "|" + (str(data_args.validation_split_percentage) if data_args.validation_split_percentage else "") + - "|" + f"test_eval_ratio={str(data_args.test_eval_ratio)}" + - "|" + f"split_by_patient={str(data_args.split_by_patient)}" + - "|" + f"chronological_split={str(data_args.chronological_split)}" + str(model_args.max_position_embeddings) + + "|" + + os.path.abspath(data_folder) + + "|" + + os.path.abspath(model_args.tokenizer_name_or_path) + + "|" + + (str(data_args.validation_split_percentage) if data_args.validation_split_percentage else "") + + "|" + + f"test_eval_ratio={str(data_args.test_eval_ratio)}" + + "|" + + f"split_by_patient={str(data_args.split_by_patient)}" + + "|" + + f"chronological_split={str(data_args.chronological_split)}" ) basename = os.path.basename(data_folder) - cleaned_basename = re.sub(r'[^a-zA-Z0-9_]', '', basename) + cleaned_basename = re.sub(r"[^a-zA-Z0-9_]", "", basename) LOG.info(f"concatenated_str: {concatenated_str}") ds_hash = f"{cleaned_basename}_{str(md5(concatenated_str))}" LOG.info(f"ds_hash: {ds_hash}") - prepared_ds_path = ( - Path(os.path.abspath(data_args.dataset_prepared_path)) / ds_hash - ) + prepared_ds_path = Path(os.path.abspath(data_args.dataset_prepared_path)) / ds_hash return prepared_ds_path def parse_runner_args() -> Tuple[DataTrainingArguments, ModelArguments, TrainingArguments]: """ - Parses command line arguments provided to a script for training a model using the Hugging Face library. - - This function uses HfArgumentParser to parse arguments from either command line directly or from configuration files - in JSON or YAML format. The arguments are expected to belong to three categories: ModelArguments, DataTrainingArguments, - and TrainingArguments, each corresponding to specific configurations required for the model training. - - The function checks the system's command line arguments: - - If there is exactly one argument and it is a JSON file, it parses the JSON file to extract the arguments. - - If there is exactly one argument and it is a YAML file, it parses the YAML file instead. - - Otherwise, it assumes arguments are provided directly through the command line and parses them accordingly. - - Returns: - tuple: A tuple containing three elements: - - data_args (DataTrainingArguments): Arguments related to data processing and dataset handling. - - model_args (ModelArguments): Arguments related to model configuration and specifics. - - training_args (TrainingArguments): Arguments related to the training process, such as learning rate and - training epochs. - - Raises: - FileNotFoundError: If the specified JSON or YAML file does not exist. - json.JSONDecodeError: If there is an error parsing a JSON file. - yaml.YAMLError: If there is an error parsing a YAML file. - Exception: For other issues that occur during argument parsing. 
- - Examples: - Command line usage might look like this: - $ python training_script.py --model_name_or_path bert-base-uncased --do_train - - Or using a JSON configuration file: - $ python training_script.py config.json - - Or using a YAML configuration file: - $ python training_script.py config.yaml + Parses command line arguments provided to a script for training a model using the Hugging Face. + + library. + + This function uses HfArgumentParser to parse arguments from either command line directly or from configuration files + in JSON or YAML format. The arguments are expected to belong to three categories: ModelArguments, DataTrainingArguments, + and TrainingArguments, each corresponding to specific configurations required for the model training. + + The function checks the system's command line arguments: + - If there is exactly one argument and it is a JSON file, it parses the JSON file to extract the arguments. + - If there is exactly one argument and it is a YAML file, it parses the YAML file instead. + - Otherwise, it assumes arguments are provided directly through the command line and parses them accordingly. + + Returns: + tuple: A tuple containing three elements: + - data_args (DataTrainingArguments): Arguments related to data processing and dataset handling. + - model_args (ModelArguments): Arguments related to model configuration and specifics. + - training_args (TrainingArguments): Arguments related to the training process, such as learning rate and + training epochs. + + Raises: + FileNotFoundError: If the specified JSON or YAML file does not exist. + json.JSONDecodeError: If there is an error parsing a JSON file. + yaml.YAMLError: If there is an error parsing a YAML file. + Exception: For other issues that occur during argument parsing. + + Examples: + Command line usage might look like this: + $ python training_script.py --model_name_or_path bert-base-uncased --do_train + + Or using a JSON configuration file: + $ python training_script.py config.json + + Or using a YAML configuration file: + $ python training_script.py config.yaml """ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): @@ -286,13 +312,28 @@ def compute_metrics(eval_pred: EvalPrediction): # Calculate perplexity perplexity = torch.exp(torch.mean(cross_entropy_loss)) - return {"perplexity": perplexity.item()} # Use .item() to extract the scalar value from the tensor + return {"perplexity": perplexity.item()} def get_meds_extension_path(data_folder: str, dataset_prepared_path: str): - data_folder = data_folder + """ + Generates the file path for the 'meds_extension' by appending the base name of the data folder. + + to the dataset prepared path. + + Args: + data_folder (str): The path to the data folder. The trailing backslash will be removed. + dataset_prepared_path (str): The directory where the dataset is prepared. + + Returns: + str: The constructed file path for the meds extension. + + Example: + If data_folder is "C:\\data\\" and dataset_prepared_path is "C:\\prepared_data", + the function will return "C:\\prepared_data\\data_meds_extension". 
+ """ if data_folder.endswith("\\"): - data_folder.rstrip("\\") + data_folder = data_folder.rstrip("\\") basename = os.path.basename(data_folder) meds_extension_path = os.path.join(dataset_prepared_path, f"{basename}_meds_extension") return meds_extension_path diff --git a/src/cehrbert/spark_apps/cohorts/atrial_fibrillation.py b/src/cehrbert/spark_apps/cohorts/atrial_fibrillation.py index ebbcec0b..76395f2c 100644 --- a/src/cehrbert/spark_apps/cohorts/atrial_fibrillation.py +++ b/src/cehrbert/spark_apps/cohorts/atrial_fibrillation.py @@ -1,13 +1,13 @@ -from ..cohorts.query_builder import QueryBuilder, QuerySpec, AncestorTableSpec +from ..cohorts.query_builder import AncestorTableSpec, QueryBuilder, QuerySpec COHORT_QUERY_TEMPLATE = """ SELECT co.person_id, - FIRST(DATE(vo.visit_start_date)) OVER (PARTITION BY co.person_id + FIRST(DATE(vo.visit_start_date)) OVER (PARTITION BY co.person_id ORDER BY DATE(vo.visit_start_date), vo.visit_occurrence_id) AS index_date, - FIRST(vo.visit_occurrence_id) OVER (PARTITION BY co.person_id + FIRST(vo.visit_occurrence_id) OVER (PARTITION BY co.person_id ORDER BY DATE(vo.visit_start_date), vo.visit_occurrence_id) AS visit_occurrence_id -FROM global_temp.condition_occurrence AS co +FROM global_temp.condition_occurrence AS co JOIN global_temp.visit_occurrence AS vo ON co.visit_occurrence_id = vo.visit_occurrence_id JOIN global_temp.{atrial_fibrillation_concepts} AS c @@ -16,21 +16,29 @@ ATRIAL_FIBRILLATION_CONCEPT_ID = [313217] -DEPENDENCY_LIST = ['person', 'condition_occurrence', 'visit_occurrence'] +DEPENDENCY_LIST = ["person", "condition_occurrence", "visit_occurrence"] -DEFAULT_COHORT_NAME = 'atrial_fibrillation' -ATRIAL_FIBRILLATION_CONCEPTS = 'atrial_fibrillation_concepts' +DEFAULT_COHORT_NAME = "atrial_fibrillation" +ATRIAL_FIBRILLATION_CONCEPTS = "atrial_fibrillation_concepts" def query_builder(): - query = QuerySpec(table_name=DEFAULT_COHORT_NAME, - query_template=COHORT_QUERY_TEMPLATE, - parameters={'atrial_fibrillation_concepts': ATRIAL_FIBRILLATION_CONCEPTS}) + query = QuerySpec( + table_name=DEFAULT_COHORT_NAME, + query_template=COHORT_QUERY_TEMPLATE, + parameters={"atrial_fibrillation_concepts": ATRIAL_FIBRILLATION_CONCEPTS}, + ) - ancestor_table_specs = [AncestorTableSpec(table_name=ATRIAL_FIBRILLATION_CONCEPTS, - ancestor_concept_ids=ATRIAL_FIBRILLATION_CONCEPT_ID, - is_standard=True)] - return QueryBuilder(cohort_name=DEFAULT_COHORT_NAME, - dependency_list=DEPENDENCY_LIST, - query=query, - ancestor_table_specs=ancestor_table_specs) + ancestor_table_specs = [ + AncestorTableSpec( + table_name=ATRIAL_FIBRILLATION_CONCEPTS, + ancestor_concept_ids=ATRIAL_FIBRILLATION_CONCEPT_ID, + is_standard=True, + ) + ] + return QueryBuilder( + cohort_name=DEFAULT_COHORT_NAME, + dependency_list=DEPENDENCY_LIST, + query=query, + ancestor_table_specs=ancestor_table_specs, + ) diff --git a/src/cehrbert/spark_apps/cohorts/cabg.py b/src/cehrbert/spark_apps/cohorts/cabg.py index e1666d7a..8f4c7051 100644 --- a/src/cehrbert/spark_apps/cohorts/cabg.py +++ b/src/cehrbert/spark_apps/cohorts/cabg.py @@ -1,4 +1,4 @@ -from ..cohorts.query_builder import QueryBuilder, QuerySpec, AncestorTableSpec +from ..cohorts.query_builder import AncestorTableSpec, QueryBuilder, QuerySpec COHORT_QUERY_TEMPLATE = """ SELECT DISTINCT @@ -9,15 +9,15 @@ ( SELECT DISTINCT vo.person_id, - FIRST(DATE(vo.visit_start_date)) OVER (PARTITION BY po.person_id + FIRST(DATE(vo.visit_start_date)) OVER (PARTITION BY po.person_id ORDER BY DATE(vo.visit_start_date), vo.visit_occurrence_id) AS index_date, - 
FIRST(vo.visit_occurrence_id) OVER (PARTITION BY po.person_id + FIRST(vo.visit_occurrence_id) OVER (PARTITION BY po.person_id ORDER BY DATE(vo.visit_start_date), vo.visit_occurrence_id) AS visit_occurrence_id FROM global_temp.procedure_occurrence AS po JOIN global_temp.visit_occurrence AS vo ON po.visit_occurrence_id = vo.visit_occurrence_id WHERE EXISTS ( - SELECT 1 + SELECT 1 FROM global_temp.{cabg_concept_table} AS ie WHERE po.procedure_concept_id = ie.concept_id ) @@ -25,12 +25,24 @@ WHERE c.index_date >= '{date_lower_bound}' """ -DEFAULT_COHORT_NAME = 'cabg' -DEPENDENCY_LIST = ['person', 'procedure_occurrence', 'visit_occurrence'] -CABG_INCLUSION_TABLE = 'CABG' +DEFAULT_COHORT_NAME = "cabg" +DEPENDENCY_LIST = ["person", "procedure_occurrence", "visit_occurrence"] +CABG_INCLUSION_TABLE = "CABG" CABG_CONCEPTS = [ - 43528001, 43528003, 43528004, 43528002, 4305852, 4168831, 2107250, - 2107216, 2107222, 2107231, 4336464, 4231998, 4284104, 2100873 + 43528001, + 43528003, + 43528004, + 43528002, + 4305852, + 4168831, + 2107250, + 2107216, + 2107222, + 2107231, + 4336464, + 4231998, + 4284104, + 2100873, ] @@ -39,21 +51,21 @@ def query_builder(spark_args): table_name=DEFAULT_COHORT_NAME, query_template=COHORT_QUERY_TEMPLATE, parameters={ - 'cabg_concept_table': CABG_INCLUSION_TABLE, - 'date_lower_bound': spark_args.date_lower_bound - } + "cabg_concept_table": CABG_INCLUSION_TABLE, + "date_lower_bound": spark_args.date_lower_bound, + }, ) ancestor_table_specs = [ AncestorTableSpec( table_name=CABG_INCLUSION_TABLE, ancestor_concept_ids=CABG_CONCEPTS, - is_standard=True + is_standard=True, ) ] return QueryBuilder( cohort_name=DEFAULT_COHORT_NAME, dependency_list=DEPENDENCY_LIST, query=query, - ancestor_table_specs=ancestor_table_specs + ancestor_table_specs=ancestor_table_specs, ) diff --git a/src/cehrbert/spark_apps/cohorts/coronary_artery_disease.py b/src/cehrbert/spark_apps/cohorts/coronary_artery_disease.py index f51b4096..a8f7539f 100644 --- a/src/cehrbert/spark_apps/cohorts/coronary_artery_disease.py +++ b/src/cehrbert/spark_apps/cohorts/coronary_artery_disease.py @@ -1,4 +1,4 @@ -from ..cohorts.query_builder import QueryBuilder, QuerySpec, AncestorTableSpec +from ..cohorts.query_builder import AncestorTableSpec, QueryBuilder, QuerySpec COHORT_QUERY_TEMPLATE = """ WITH prior_graft_stent AS ( @@ -20,35 +20,40 @@ ( SELECT DISTINCT vo.person_id, - FIRST(DATE(vo.visit_start_date)) OVER (PARTITION BY co.person_id + FIRST(DATE(vo.visit_start_date)) OVER (PARTITION BY co.person_id ORDER BY DATE(vo.visit_start_date), vo.visit_occurrence_id) AS index_date, - FIRST(vo.visit_occurrence_id) OVER (PARTITION BY co.person_id + FIRST(vo.visit_occurrence_id) OVER (PARTITION BY co.person_id ORDER BY DATE(vo.visit_start_date), vo.visit_occurrence_id) AS visit_occurrence_id FROM global_temp.condition_occurrence AS co JOIN global_temp.visit_occurrence AS vo ON co.visit_occurrence_id = vo.visit_occurrence_id WHERE EXISTS ( - SELECT 1 + SELECT 1 FROM global_temp.{cad_concept_table} AS ie WHERE co.condition_concept_id = ie.concept_id ) ) c WHERE NOT EXISTS ( - -- The patients who had a graft or stent procedures before the index date + -- The patients who had a graft or stent procedures before the index date -- need to be removed from the cohort SELECT 1 FROM prior_graft_stent AS exclusion - WHERE exclusion.person_id = c.person_id + WHERE exclusion.person_id = c.person_id AND c.index_date > exclusion.procedure_date -) AND c.index_date >= '{date_lower_bound}' +) AND c.index_date >= '{date_lower_bound}' """ 
-DEFAULT_COHORT_NAME = 'coronary_artery_disease' -DEPENDENCY_LIST = ['person', 'condition_occurrence', 'procedure_occurrence', 'visit_occurrence'] -CAD_INCLUSION_TABLE = 'CAD' +DEFAULT_COHORT_NAME = "coronary_artery_disease" +DEPENDENCY_LIST = [ + "person", + "condition_occurrence", + "procedure_occurrence", + "visit_occurrence", +] +CAD_INCLUSION_TABLE = "CAD" CAD_CONCEPTS = [317576] -PRIOR_PROCEDURE_TABLE = 'graft_stent' +PRIOR_PROCEDURE_TABLE = "graft_stent" PRIOR_PROCEDURES = [4296227, 42537730, 762043, 44782770, 42537729] @@ -57,27 +62,27 @@ def query_builder(spark_args): table_name=DEFAULT_COHORT_NAME, query_template=COHORT_QUERY_TEMPLATE, parameters={ - 'cad_concept_table': CAD_INCLUSION_TABLE, - 'graft_stent_table': PRIOR_PROCEDURE_TABLE, - 'date_lower_bound': spark_args.date_lower_bound - } + "cad_concept_table": CAD_INCLUSION_TABLE, + "graft_stent_table": PRIOR_PROCEDURE_TABLE, + "date_lower_bound": spark_args.date_lower_bound, + }, ) ancestor_table_specs = [ AncestorTableSpec( table_name=CAD_INCLUSION_TABLE, ancestor_concept_ids=CAD_CONCEPTS, - is_standard=True + is_standard=True, ), AncestorTableSpec( table_name=PRIOR_PROCEDURE_TABLE, ancestor_concept_ids=PRIOR_PROCEDURES, - is_standard=True - ) + is_standard=True, + ), ] return QueryBuilder( cohort_name=DEFAULT_COHORT_NAME, dependency_list=DEPENDENCY_LIST, query=query, - ancestor_table_specs=ancestor_table_specs + ancestor_table_specs=ancestor_table_specs, ) diff --git a/src/cehrbert/spark_apps/cohorts/covid.py b/src/cehrbert/spark_apps/cohorts/covid.py index 0728a76e..89c7c8be 100644 --- a/src/cehrbert/spark_apps/cohorts/covid.py +++ b/src/cehrbert/spark_apps/cohorts/covid.py @@ -21,7 +21,7 @@ UNION - SELECT + SELECT co.person_id, FIRST(visit_start_date) OVER (PARTITION BY v.person_id ORDER BY visit_start_date, v.visit_occurrence_id) AS index_date, FIRST(v.visit_occurrence_id) OVER (PARTITION BY v.person_id ORDER BY visit_start_date, v.visit_occurrence_id) AS visit_occurrence_id @@ -32,15 +32,11 @@ ) c """ -DEFAULT_COHORT_NAME = 'covid19' -DEPENDENCY_LIST = ['person', 'visit_occurrence', 'measurement', 'condition_occurrence'] +DEFAULT_COHORT_NAME = "covid19" +DEPENDENCY_LIST = ["person", "visit_occurrence", "measurement", "condition_occurrence"] def query_builder(): - query = QuerySpec(table_name=DEFAULT_COHORT_NAME, - query_template=COVID_COHORT_QUERY, - parameters={}) + query = QuerySpec(table_name=DEFAULT_COHORT_NAME, query_template=COVID_COHORT_QUERY, parameters={}) - return QueryBuilder(cohort_name=DEFAULT_COHORT_NAME, - dependency_list=DEPENDENCY_LIST, - query=query) + return QueryBuilder(cohort_name=DEFAULT_COHORT_NAME, dependency_list=DEPENDENCY_LIST, query=query) diff --git a/src/cehrbert/spark_apps/cohorts/covid_inpatient.py b/src/cehrbert/spark_apps/cohorts/covid_inpatient.py index 3cb49e65..33ee1eca 100644 --- a/src/cehrbert/spark_apps/cohorts/covid_inpatient.py +++ b/src/cehrbert/spark_apps/cohorts/covid_inpatient.py @@ -1,7 +1,7 @@ from ..cohorts.query_builder import QueryBuilder, QuerySpec COVID_COHORT_QUERY = """ -WITH covid_positive AS +WITH covid_positive AS ( SELECT DISTINCT @@ -20,7 +20,7 @@ WHERE measurement_concept_id IN (723475,723479,706178,723473,723474,586515,706177,706163,706180,706181) AND value_source_value = 'Detected' - UNION + UNION SELECT DISTINCT co.person_id, @@ -66,22 +66,18 @@ FIRST_VALUE(vo.visit_occurrence_id) OVER(PARTITION BY vo.person_id ORDER BY vo.index_date) AS visit_occurrence_id FROM ( - SELECT + SELECT co.* FROM all_covid_tests AS co WHERE visit_concept_id IN (262, 9203, 9201) ) 
vo """ -DEFAULT_COHORT_NAME = 'covid19' -DEPENDENCY_LIST = ['person', 'visit_occurrence', 'measurement', 'condition_occurrence'] +DEFAULT_COHORT_NAME = "covid19" +DEPENDENCY_LIST = ["person", "visit_occurrence", "measurement", "condition_occurrence"] def query_builder(): - query = QuerySpec(table_name=DEFAULT_COHORT_NAME, - query_template=COVID_COHORT_QUERY, - parameters={}) + query = QuerySpec(table_name=DEFAULT_COHORT_NAME, query_template=COVID_COHORT_QUERY, parameters={}) - return QueryBuilder(cohort_name=DEFAULT_COHORT_NAME, - dependency_list=DEPENDENCY_LIST, - query=query) + return QueryBuilder(cohort_name=DEFAULT_COHORT_NAME, dependency_list=DEPENDENCY_LIST, query=query) diff --git a/src/cehrbert/spark_apps/cohorts/death.py b/src/cehrbert/spark_apps/cohorts/death.py index 48bd2f42..32ab339a 100644 --- a/src/cehrbert/spark_apps/cohorts/death.py +++ b/src/cehrbert/spark_apps/cohorts/death.py @@ -1,20 +1,20 @@ from ..cohorts.query_builder import QueryBuilder, QuerySpec, create_cohort_entry_query_spec DEATH_COHORT_QUERY = """ -WITH max_death_date_cte AS +WITH max_death_date_cte AS ( - SELECT + SELECT person_id, MAX(death_date) AS death_date FROM global_temp.death GROUP BY person_id ), -last_visit_start_date AS +last_visit_start_date AS ( SELECT person_id, MAX(visit_start_date) AS last_visit_start_date - FROM global_temp.visit_occurrence + FROM global_temp.visit_occurrence GROUP BY person_id ) @@ -28,20 +28,18 @@ AND v.last_visit_start_date <= d.death_date """ -DEFAULT_COHORT_NAME = 'mortality' -DEPENDENCY_LIST = ['person', 'death', 'visit_occurrence'] +DEFAULT_COHORT_NAME = "mortality" +DEPENDENCY_LIST = ["person", "death", "visit_occurrence"] def query_builder(): - query = QuerySpec(table_name=DEFAULT_COHORT_NAME, - query_template=DEATH_COHORT_QUERY, - parameters={}) + query = QuerySpec(table_name=DEFAULT_COHORT_NAME, query_template=DEATH_COHORT_QUERY, parameters={}) - entry_cohort_query = create_cohort_entry_query_spec( - entry_query_template=DEATH_COHORT_QUERY, - parameters={}) + entry_cohort_query = create_cohort_entry_query_spec(entry_query_template=DEATH_COHORT_QUERY, parameters={}) - return QueryBuilder(cohort_name=DEFAULT_COHORT_NAME, - dependency_list=DEPENDENCY_LIST, - query=query, - entry_cohort_query=entry_cohort_query) + return QueryBuilder( + cohort_name=DEFAULT_COHORT_NAME, + dependency_list=DEPENDENCY_LIST, + query=query, + entry_cohort_query=entry_cohort_query, + ) diff --git a/src/cehrbert/spark_apps/cohorts/heart_failure.py b/src/cehrbert/spark_apps/cohorts/heart_failure.py index 8938aee1..1362edef 100644 --- a/src/cehrbert/spark_apps/cohorts/heart_failure.py +++ b/src/cehrbert/spark_apps/cohorts/heart_failure.py @@ -1,5 +1,10 @@ -from ..cohorts.query_builder import QueryBuilder, AncestorTableSpec, QuerySpec, \ - create_cohort_entry_query_spec, create_negative_query_spec +from ..cohorts.query_builder import ( + AncestorTableSpec, + QueryBuilder, + QuerySpec, + create_cohort_entry_query_spec, + create_negative_query_spec, +) # 1. 
Incidens of Heart Failure HEART_FAILURE_CONCEPT = [316139] @@ -12,25 +17,79 @@ PHYSICAL_EXAM_CONCEPT = [433595, 200528, 4117930, 4329988, 4289004, 4285133] ## Lab result concept # https://labtestsonline.org/tests/bnp-and-nt-probnp -BNP_CONCEPT = [4307029, 3031569, 3011960, - 3052295] # High B-type Natriuretic Peptide (BNP) > 500 pg/mL +BNP_CONCEPT = [ + 4307029, + 3031569, + 3011960, + 3052295, +] # High B-type Natriuretic Peptide (BNP) > 500 pg/mL NT_PRO_BNP_CONCEPT = [3029187, 42529224, 3029435, 42529225] -PWP_CONCEPT = [1002721, 4040920, - 21490776] # Pulmonary artery wedge pressure >= 18 no patient in cumc -CVP_CONCEPT = [21490675, 4323687, 3000333, - 1003995] # Central venous pressure >= 12 no patient in cumc +PWP_CONCEPT = [ + 1002721, + 4040920, + 21490776, +] # Pulmonary artery wedge pressure >= 18 no patient in cumc +CVP_CONCEPT = [ + 21490675, + 4323687, + 3000333, + 1003995, +] # Central venous pressure >= 12 no patient in cumc CI_CONCEPT = 21490712 # Cardiac index < 2.2 no patient in cumc # 4. At least ONE of the treatments specifically for HF -DRUG_CONCEPT = [956874, 942350, 987406, 932745, 1309799, 970250, 992590, 907013, 1942960] -MECHANICAL_CIRCULATORY_SUPPORT_CONCEPT = [45888564, 4052536, 4337306, 2107514, 45889695, 2107500, - 45887675, 43527920, 2107501, 45890116, 40756954, 4338594, - 43527923, 40757060, 2100812] +DRUG_CONCEPT = [ + 956874, + 942350, + 987406, + 932745, + 1309799, + 970250, + 992590, + 907013, + 1942960, +] +MECHANICAL_CIRCULATORY_SUPPORT_CONCEPT = [ + 45888564, + 4052536, + 4337306, + 2107514, + 45889695, + 2107500, + 45887675, + 43527920, + 2107501, + 45890116, + 40756954, + 4338594, + 43527923, + 40757060, + 2100812, +] DIALYSIS_CONCEPT = [4032243, 45889365] -ARTIFICIAL_HEART_ASSOCIATED_PROCEDURE_CONCEPT = [4144390, 4150347, 4281764, 725038, 725037, 2100816, - 2100822, 725039, 2100828, 4337306, 4140024, - 4146121, 4060257, 4309033, 4222272, 4243758, - 4241906, 4080968, 4224193, 4052537, 4050864] +ARTIFICIAL_HEART_ASSOCIATED_PROCEDURE_CONCEPT = [ + 4144390, + 4150347, + 4281764, + 725038, + 725037, + 2100816, + 2100822, + 725039, + 2100828, + 4337306, + 4140024, + 4146121, + 4060257, + 4309033, + 4222272, + 4243758, + 4241906, + 4080968, + 4224193, + 4052537, + 4050864, +] DIURETIC_CONCEPT_ID = [4186999] @@ -66,9 +125,9 @@ SELECT DISTINCT v.person_id, v.visit_occurrence_id, - first(DATE(c.condition_start_date)) OVER (PARTITION BY v.person_id + first(DATE(c.condition_start_date)) OVER (PARTITION BY v.person_id ORDER BY DATE(c.condition_start_date)) AS earliest_condition_start_date, - first(DATE(v.visit_start_date)) OVER (PARTITION BY v.person_id + first(DATE(v.visit_start_date)) OVER (PARTITION BY v.person_id ORDER BY DATE(v.visit_start_date)) AS earliest_visit_start_date, first(v.visit_occurrence_id) OVER (PARTITION BY v.person_id ORDER BY DATE(v.visit_start_date)) AS earliest_visit_occurrence_id @@ -89,7 +148,7 @@ ), worsen_hf_diagnosis AS ( - SELECT DISTINCT person_id, visit_occurrence_id + SELECT DISTINCT person_id, visit_occurrence_id FROM global_temp.condition_occurrence AS co JOIN global_temp.{worsen_hf_dx_concepts} AS w_hf ON co.condition_concept_id = w_hf.concept_id @@ -108,7 +167,7 @@ JOIN global_temp.{bnp_concepts} AS bnp ON m.measurement_concept_id = bnp.concept_id AND m.value_source_value > 500 - UNION ALL + UNION ALL SELECT DISTINCT person_id, visit_occurrence_id FROM global_temp.measurement AS m JOIN global_temp.{nt_pro_bnp_concepts} AS nt_bnp @@ -117,15 +176,15 @@ ), drug_concepts AS ( - SELECT DISTINCT + SELECT DISTINCT * FROM ( - SELECT 
* - FROM global_temp.{drug_concepts} - - UNION - + SELECT * + FROM global_temp.{drug_concepts} + + UNION + SELECT * FROM global_temp.diuretics_concepts ) d @@ -180,9 +239,9 @@ SELECT DISTINCT v.person_id, v.visit_occurrence_id, - first(DATE(c.condition_start_date)) OVER (PARTITION BY v.person_id + first(DATE(c.condition_start_date)) OVER (PARTITION BY v.person_id ORDER BY DATE(c.condition_start_date)) AS earliest_condition_start_date, - first(DATE(v.visit_start_date)) OVER (PARTITION BY v.person_id + first(DATE(v.visit_start_date)) OVER (PARTITION BY v.person_id ORDER BY DATE(v.visit_start_date)) AS earliest_visit_start_date, first(v.visit_occurrence_id) OVER (PARTITION BY v.person_id ORDER BY DATE(v.visit_start_date)) AS earliest_visit_occurrence_id @@ -202,13 +261,13 @@ ) AS bnp ON c.person_id = bnp.person_id LEFT JOIN ( - SELECT DISTINCT + SELECT DISTINCT person_id FROM treatment_cohort ) AS tc ON c.person_id = tc.person_id LEFT JOIN ( - SELECT DISTINCT + SELECT DISTINCT hf.person_id FROM hf_conditions hf JOIN drug_cohort dc @@ -226,95 +285,138 @@ WHERE inclusion = {inclusion} """ -DEPENDENCY_LIST = ['person', 'condition_occurrence', 'visit_occurrence', 'drug_exposure', - 'measurement', - 'procedure_occurrence'] -HEART_FAILURE_CONCEPT_TABLE = 'hf_concept' -WORSEN_HF_DX_CONCEPT_TABLE = 'worsen_hf_dx_concepts' -PHYSICAL_EXAM_COHORT_TABLE = 'phy_exam_concepts' -BNP_CONCEPT_TABLE = 'bnp_concepts' -NT_PRO_BNP_CONCEPT_TABLE = 'nt_pro_bnp_concepts' -DRUG_CONCEPT_TABLE = 'drug_concepts' -MECHANICAL_SUPPORT_CONCEPT_TABLE = 'mechanical_support_concepts' -DIALYSIS_CONCEPT_TABLE = 'dialysis_concepts' -ARTIFICIAL_HEART_CONCEPT_TABLE = 'artificial_heart_concepts' +DEPENDENCY_LIST = [ + "person", + "condition_occurrence", + "visit_occurrence", + "drug_exposure", + "measurement", + "procedure_occurrence", +] +HEART_FAILURE_CONCEPT_TABLE = "hf_concept" +WORSEN_HF_DX_CONCEPT_TABLE = "worsen_hf_dx_concepts" +PHYSICAL_EXAM_COHORT_TABLE = "phy_exam_concepts" +BNP_CONCEPT_TABLE = "bnp_concepts" +NT_PRO_BNP_CONCEPT_TABLE = "nt_pro_bnp_concepts" +DRUG_CONCEPT_TABLE = "drug_concepts" +MECHANICAL_SUPPORT_CONCEPT_TABLE = "mechanical_support_concepts" +DIALYSIS_CONCEPT_TABLE = "dialysis_concepts" +ARTIFICIAL_HEART_CONCEPT_TABLE = "artificial_heart_concepts" -DIURETICS_ANCESTOR_TABLE = 'diuretics_ancestor_table' -DIURETICS_INGREDIENT_CONCEPTS = 'diuretics_concepts' +DIURETICS_ANCESTOR_TABLE = "diuretics_ancestor_table" +DIURETICS_INGREDIENT_CONCEPTS = "diuretics_concepts" -INTERMEDIATE_COHORT_NAME = 'intermediate_heart_failure' -DEFAULT_COHORT_NAME = 'heart_failure' -NEGATIVE_COHORT_NAME = 'negative_heart_failure' +INTERMEDIATE_COHORT_NAME = "intermediate_heart_failure" +DEFAULT_COHORT_NAME = "heart_failure" +NEGATIVE_COHORT_NAME = "negative_heart_failure" def query_builder(): - query = QuerySpec(table_name=DEFAULT_COHORT_NAME, - query_template=HEART_FAILURE_COHORT_QUERY, - parameters={'intermediate_heart_failure': INTERMEDIATE_COHORT_NAME, - 'inclusion': 1}) - - ancestor_table_specs = [AncestorTableSpec(table_name=HEART_FAILURE_CONCEPT_TABLE, - ancestor_concept_ids=HEART_FAILURE_CONCEPT, - is_standard=True), - AncestorTableSpec(table_name=WORSEN_HF_DX_CONCEPT_TABLE, - ancestor_concept_ids=WORSEN_HF_DIAGNOSIS_CONCEPT, - is_standard=True), - AncestorTableSpec(table_name=PHYSICAL_EXAM_COHORT_TABLE, - ancestor_concept_ids=PHYSICAL_EXAM_CONCEPT, - is_standard=True), - AncestorTableSpec(table_name=BNP_CONCEPT_TABLE, - ancestor_concept_ids=BNP_CONCEPT, - is_standard=True), - 
AncestorTableSpec(table_name=NT_PRO_BNP_CONCEPT_TABLE, - ancestor_concept_ids=NT_PRO_BNP_CONCEPT, - is_standard=True), - AncestorTableSpec(table_name=DRUG_CONCEPT_TABLE, - ancestor_concept_ids=DRUG_CONCEPT, - is_standard=True), - AncestorTableSpec(table_name=MECHANICAL_SUPPORT_CONCEPT_TABLE, - ancestor_concept_ids=MECHANICAL_CIRCULATORY_SUPPORT_CONCEPT, - is_standard=True), - AncestorTableSpec(table_name=DIALYSIS_CONCEPT_TABLE, - ancestor_concept_ids=DIALYSIS_CONCEPT, - is_standard=True), - AncestorTableSpec(table_name=ARTIFICIAL_HEART_CONCEPT_TABLE, - ancestor_concept_ids=ARTIFICIAL_HEART_ASSOCIATED_PROCEDURE_CONCEPT, - is_standard=True), - AncestorTableSpec(table_name=DIURETICS_ANCESTOR_TABLE, - ancestor_concept_ids=DIURETIC_CONCEPT_ID, - is_standard=False) - ] - - dependency_queries = [QuerySpec(table_name=DIURETICS_INGREDIENT_CONCEPTS, - query_template=ROLL_UP_DIURETICS_TO_INGREDIENT_TEMPLATE, - parameters={}), - QuerySpec(table_name=INTERMEDIATE_COHORT_NAME, - query_template=HEART_FAILURE_INTERMEDIATE_COHORT_QUERY, - parameters={'hf_concept': HEART_FAILURE_CONCEPT_TABLE, - 'worsen_hf_dx_concepts': WORSEN_HF_DX_CONCEPT_TABLE, - 'phy_exam_concepts': PHYSICAL_EXAM_COHORT_TABLE, - 'bnp_concepts': BNP_CONCEPT_TABLE, - 'nt_pro_bnp_concepts': NT_PRO_BNP_CONCEPT_TABLE, - 'drug_concepts': DRUG_CONCEPT_TABLE, - 'mechanical_support_concepts': MECHANICAL_SUPPORT_CONCEPT_TABLE, - 'dialysis_concepts': DIALYSIS_CONCEPT_TABLE, - 'artificial_heart_concepts': ARTIFICIAL_HEART_CONCEPT_TABLE - })] + query = QuerySpec( + table_name=DEFAULT_COHORT_NAME, + query_template=HEART_FAILURE_COHORT_QUERY, + parameters={ + "intermediate_heart_failure": INTERMEDIATE_COHORT_NAME, + "inclusion": 1, + }, + ) + + ancestor_table_specs = [ + AncestorTableSpec( + table_name=HEART_FAILURE_CONCEPT_TABLE, + ancestor_concept_ids=HEART_FAILURE_CONCEPT, + is_standard=True, + ), + AncestorTableSpec( + table_name=WORSEN_HF_DX_CONCEPT_TABLE, + ancestor_concept_ids=WORSEN_HF_DIAGNOSIS_CONCEPT, + is_standard=True, + ), + AncestorTableSpec( + table_name=PHYSICAL_EXAM_COHORT_TABLE, + ancestor_concept_ids=PHYSICAL_EXAM_CONCEPT, + is_standard=True, + ), + AncestorTableSpec( + table_name=BNP_CONCEPT_TABLE, + ancestor_concept_ids=BNP_CONCEPT, + is_standard=True, + ), + AncestorTableSpec( + table_name=NT_PRO_BNP_CONCEPT_TABLE, + ancestor_concept_ids=NT_PRO_BNP_CONCEPT, + is_standard=True, + ), + AncestorTableSpec( + table_name=DRUG_CONCEPT_TABLE, + ancestor_concept_ids=DRUG_CONCEPT, + is_standard=True, + ), + AncestorTableSpec( + table_name=MECHANICAL_SUPPORT_CONCEPT_TABLE, + ancestor_concept_ids=MECHANICAL_CIRCULATORY_SUPPORT_CONCEPT, + is_standard=True, + ), + AncestorTableSpec( + table_name=DIALYSIS_CONCEPT_TABLE, + ancestor_concept_ids=DIALYSIS_CONCEPT, + is_standard=True, + ), + AncestorTableSpec( + table_name=ARTIFICIAL_HEART_CONCEPT_TABLE, + ancestor_concept_ids=ARTIFICIAL_HEART_ASSOCIATED_PROCEDURE_CONCEPT, + is_standard=True, + ), + AncestorTableSpec( + table_name=DIURETICS_ANCESTOR_TABLE, + ancestor_concept_ids=DIURETIC_CONCEPT_ID, + is_standard=False, + ), + ] + + dependency_queries = [ + QuerySpec( + table_name=DIURETICS_INGREDIENT_CONCEPTS, + query_template=ROLL_UP_DIURETICS_TO_INGREDIENT_TEMPLATE, + parameters={}, + ), + QuerySpec( + table_name=INTERMEDIATE_COHORT_NAME, + query_template=HEART_FAILURE_INTERMEDIATE_COHORT_QUERY, + parameters={ + "hf_concept": HEART_FAILURE_CONCEPT_TABLE, + "worsen_hf_dx_concepts": WORSEN_HF_DX_CONCEPT_TABLE, + "phy_exam_concepts": PHYSICAL_EXAM_COHORT_TABLE, + "bnp_concepts": BNP_CONCEPT_TABLE, 
+ "nt_pro_bnp_concepts": NT_PRO_BNP_CONCEPT_TABLE, + "drug_concepts": DRUG_CONCEPT_TABLE, + "mechanical_support_concepts": MECHANICAL_SUPPORT_CONCEPT_TABLE, + "dialysis_concepts": DIALYSIS_CONCEPT_TABLE, + "artificial_heart_concepts": ARTIFICIAL_HEART_CONCEPT_TABLE, + }, + ), + ] entry_cohort_query = create_cohort_entry_query_spec( entry_query_template=HEART_FAILURE_ENTRY_COHORT, - parameters={'hf_concept': HEART_FAILURE_CONCEPT_TABLE}) + parameters={"hf_concept": HEART_FAILURE_CONCEPT_TABLE}, + ) negative_query = create_negative_query_spec( entry_query_template=HEART_FAILURE_COHORT_QUERY, - parameters={'intermediate_heart_failure': INTERMEDIATE_COHORT_NAME, - 'inclusion': 0}) - - return QueryBuilder(cohort_name=DEFAULT_COHORT_NAME, - query=query, - negative_query=negative_query, - entry_cohort_query=entry_cohort_query, - dependency_list=DEPENDENCY_LIST, - dependency_queries=dependency_queries, - post_queries=[], - ancestor_table_specs=ancestor_table_specs) + parameters={ + "intermediate_heart_failure": INTERMEDIATE_COHORT_NAME, + "inclusion": 0, + }, + ) + + return QueryBuilder( + cohort_name=DEFAULT_COHORT_NAME, + query=query, + negative_query=negative_query, + entry_cohort_query=entry_cohort_query, + dependency_list=DEPENDENCY_LIST, + dependency_queries=dependency_queries, + post_queries=[], + ancestor_table_specs=ancestor_table_specs, + ) diff --git a/src/cehrbert/spark_apps/cohorts/ischemic_stroke.py b/src/cehrbert/spark_apps/cohorts/ischemic_stroke.py index cd9fc7c3..f24f3cfc 100644 --- a/src/cehrbert/spark_apps/cohorts/ischemic_stroke.py +++ b/src/cehrbert/spark_apps/cohorts/ischemic_stroke.py @@ -1,13 +1,13 @@ -from ..cohorts.query_builder import QueryBuilder, QuerySpec, AncestorTableSpec +from ..cohorts.query_builder import AncestorTableSpec, QueryBuilder, QuerySpec COHORT_QUERY_TEMPLATE = """ SELECT co.person_id, - FIRST(DATE(vo.visit_start_date)) OVER (PARTITION BY co.person_id + FIRST(DATE(vo.visit_start_date)) OVER (PARTITION BY co.person_id ORDER BY DATE(vo.visit_start_date), vo.visit_occurrence_id) AS index_date, - FIRST(vo.visit_occurrence_id) OVER (PARTITION BY co.person_id + FIRST(vo.visit_occurrence_id) OVER (PARTITION BY co.person_id ORDER BY DATE(vo.visit_start_date), vo.visit_occurrence_id) AS visit_occurrence_id -FROM global_temp.condition_occurrence AS co +FROM global_temp.condition_occurrence AS co JOIN global_temp.visit_occurrence AS vo ON co.visit_occurrence_id = vo.visit_occurrence_id JOIN global_temp.{ischemic_stroke_concepts} AS c @@ -16,21 +16,29 @@ ISCHEMIC_STROKE_CONCEPT_ID = [443454] -DEPENDENCY_LIST = ['person', 'condition_occurrence', 'visit_occurrence'] +DEPENDENCY_LIST = ["person", "condition_occurrence", "visit_occurrence"] -DEFAULT_COHORT_NAME = 'ischemic_stroke' -ISCHEMIC_STROKE_CONCEPTS = 'ischemic_stroke_concepts' +DEFAULT_COHORT_NAME = "ischemic_stroke" +ISCHEMIC_STROKE_CONCEPTS = "ischemic_stroke_concepts" def query_builder(): - query = QuerySpec(table_name=DEFAULT_COHORT_NAME, - query_template=COHORT_QUERY_TEMPLATE, - parameters={'ischemic_stroke_concepts': ISCHEMIC_STROKE_CONCEPTS}) + query = QuerySpec( + table_name=DEFAULT_COHORT_NAME, + query_template=COHORT_QUERY_TEMPLATE, + parameters={"ischemic_stroke_concepts": ISCHEMIC_STROKE_CONCEPTS}, + ) - ancestor_table_specs = [AncestorTableSpec(table_name=ISCHEMIC_STROKE_CONCEPTS, - ancestor_concept_ids=ISCHEMIC_STROKE_CONCEPT_ID, - is_standard=True)] - return QueryBuilder(cohort_name=DEFAULT_COHORT_NAME, - dependency_list=DEPENDENCY_LIST, - query=query, - 
ancestor_table_specs=ancestor_table_specs) + ancestor_table_specs = [ + AncestorTableSpec( + table_name=ISCHEMIC_STROKE_CONCEPTS, + ancestor_concept_ids=ISCHEMIC_STROKE_CONCEPT_ID, + is_standard=True, + ) + ] + return QueryBuilder( + cohort_name=DEFAULT_COHORT_NAME, + dependency_list=DEPENDENCY_LIST, + query=query, + ancestor_table_specs=ancestor_table_specs, + ) diff --git a/src/cehrbert/spark_apps/cohorts/last_visit_discharged_home.py b/src/cehrbert/spark_apps/cohorts/last_visit_discharged_home.py index c6047060..5d32112e 100644 --- a/src/cehrbert/spark_apps/cohorts/last_visit_discharged_home.py +++ b/src/cehrbert/spark_apps/cohorts/last_visit_discharged_home.py @@ -5,7 +5,7 @@ v.person_id, v.visit_occurrence_id, v.index_date -FROM +FROM ( SELECT v.person_id, @@ -21,18 +21,14 @@ WHERE v.rn = 1 AND v.index_date >= '{date_lower_bound}' """ -DEPENDENCY_LIST = ['person', 'visit_occurrence'] -DEFAULT_COHORT_NAME = 'last_visit_discharge_home' +DEPENDENCY_LIST = ["person", "visit_occurrence"] +DEFAULT_COHORT_NAME = "last_visit_discharge_home" def query_builder(spark_args): query = QuerySpec( table_name=DEFAULT_COHORT_NAME, query_template=COHORT_QUERY, - parameters={'date_lower_bound': spark_args.date_lower_bound} - ) - return QueryBuilder( - cohort_name=DEFAULT_COHORT_NAME, - dependency_list=DEPENDENCY_LIST, - query=query + parameters={"date_lower_bound": spark_args.date_lower_bound}, ) + return QueryBuilder(cohort_name=DEFAULT_COHORT_NAME, dependency_list=DEPENDENCY_LIST, query=query) diff --git a/src/cehrbert/spark_apps/cohorts/query_builder.py b/src/cehrbert/spark_apps/cohorts/query_builder.py index acaa3b56..ec125e0f 100644 --- a/src/cehrbert/spark_apps/cohorts/query_builder.py +++ b/src/cehrbert/spark_apps/cohorts/query_builder.py @@ -1,21 +1,25 @@ +import logging from abc import ABC from typing import List, NamedTuple -import logging -ENTRY_COHORT = 'entry_cohort' -NEGATIVE_COHORT = 'negative_cohort' +ENTRY_COHORT = "entry_cohort" +NEGATIVE_COHORT = "negative_cohort" def create_cohort_entry_query_spec(entry_query_template, parameters): - return QuerySpec(table_name=ENTRY_COHORT, - query_template=entry_query_template, - parameters=parameters) + return QuerySpec( + table_name=ENTRY_COHORT, + query_template=entry_query_template, + parameters=parameters, + ) def create_negative_query_spec(entry_query_template, parameters): - return QuerySpec(table_name=NEGATIVE_COHORT, - query_template=entry_query_template, - parameters=parameters) + return QuerySpec( + table_name=NEGATIVE_COHORT, + query_template=entry_query_template, + parameters=parameters, + ) class QuerySpec(NamedTuple): @@ -24,8 +28,7 @@ class QuerySpec(NamedTuple): table_name: str def __str__(self): - return (f'table={self.table_name}\n' - f'query={self.query_template.format(**self.parameters)}\n') + return f"table={self.table_name}\n" f"query={self.query_template.format(**self.parameters)}\n" class AncestorTableSpec(NamedTuple): @@ -34,25 +37,29 @@ class AncestorTableSpec(NamedTuple): is_standard: bool def __str__(self): - return (f'table_name={self.table_name}\n' - f'ancestor_concept_ids={self.ancestor_concept_ids}\n' - f'is_standard={self.is_standard}\n') + return ( + f"table_name={self.table_name}\n" + f"ancestor_concept_ids={self.ancestor_concept_ids}\n" + f"is_standard={self.is_standard}\n" + ) class QueryBuilder(ABC): - def __init__(self, - cohort_name: str, - dependency_list: List[str], - query: QuerySpec, - negative_query: QuerySpec = None, - entry_cohort_query: QuerySpec = None, - dependency_queries: List[QuerySpec] = None, - 
post_queries: List[QuerySpec] = None, - ancestor_table_specs: List[AncestorTableSpec] = None): + def __init__( + self, + cohort_name: str, + dependency_list: List[str], + query: QuerySpec, + negative_query: QuerySpec = None, + entry_cohort_query: QuerySpec = None, + dependency_queries: List[QuerySpec] = None, + post_queries: List[QuerySpec] = None, + ancestor_table_specs: List[AncestorTableSpec] = None, + ): """ - :param cohort_name: + :param query: :param dependency_queries: :param post_queries: @@ -68,53 +75,61 @@ def __init__(self, self._dependency_list = dependency_list self._ancestor_table_specs = ancestor_table_specs - self.get_logger().info(f'cohort_name: {cohort_name}\n' - f'post_queries: {post_queries}\n' - f'entry_cohort: {entry_cohort_query}\n' - f'dependency_queries: {dependency_queries}\n' - f'dependency_list: {dependency_list}\n' - f'ancestor_table_specs: {ancestor_table_specs}\n' - f'query: {query}\n' - f'negative_query: {negative_query}\n') + self.get_logger().info( + f"cohort_name: {cohort_name}\n" + f"post_queries: {post_queries}\n" + f"entry_cohort: {entry_cohort_query}\n" + f"dependency_queries: {dependency_queries}\n" + f"dependency_list: {dependency_list}\n" + f"ancestor_table_specs: {ancestor_table_specs}\n" + f"query: {query}\n" + f"negative_query: {negative_query}\n" + ) def get_dependency_queries(self): """ - Instantiate table dependencies in spark for + Instantiate table dependencies in spark for. + :return: """ return self._dependency_queries def get_entry_cohort_query(self): """ - Queryspec for Instantiating the entry cohort in spark context + Queryspec for Instantiating the entry cohort in spark context. + :return: """ return self._entry_cohort_query def get_query(self): """ - Create a query that can be executed by spark.sql + Create a query that can be executed by spark.sql. + :return: """ return self._query def get_negative_query(self): """ - Return the negative query that can be executed by spark.sql + Return the negative query that can be executed by spark.sql. + :return: """ return self._negative_query def get_post_process_queries(self): """ - Get a list of post process queries to process the cohort + Get a list of post process queries to process the cohort. + :return: """ return self._post_queries def get_dependency_list(self): """ - Get a list of tables that are required for this cohort + Get a list of tables that are required for this cohort. + :return: """ return self._dependency_list @@ -124,13 +139,14 @@ def get_cohort_name(self): def get_ancestor_table_specs(self): """ - Create the descendant table for the provided ancestor_table_specs + Create the descendant table for the provided ancestor_table_specs. 
+ :return: """ return self._ancestor_table_specs def __str__(self): - return f'{str(self.__class__.__name__)} for {self.get_cohort_name()}' + return f"{str(self.__class__.__name__)} for {self.get_cohort_name()}" @classmethod def get_logger(cls): diff --git a/src/cehrbert/spark_apps/cohorts/spark_app_base.py b/src/cehrbert/spark_apps/cohorts/spark_app_base.py index 018247fe..9c920411 100644 --- a/src/cehrbert/spark_apps/cohorts/spark_app_base.py +++ b/src/cehrbert/spark_apps/cohorts/spark_app_base.py @@ -1,26 +1,47 @@ import os import re -from abc import ABC import shutil +from abc import ABC from pandas import to_datetime -from pyspark.sql import DataFrame -from pyspark.sql import SparkSession +from pyspark.sql import DataFrame, SparkSession from pyspark.sql.window import Window -from ..cohorts.query_builder import QueryBuilder, ENTRY_COHORT, NEGATIVE_COHORT -from ...utils.spark_utils import * - -COHORT_TABLE_NAME = 'cohort' -PERSON = 'person' -OBSERVATION_PERIOD = 'observation_period' -DEFAULT_DEPENDENCY = ['person', 'visit_occurrence', 'observation_period', 'concept', - 'concept_ancestor', 'concept_relationship'] +from ...utils.spark_utils import ( + VISIT_OCCURRENCE, + AttType, + F, + List, + W, + build_ancestry_table_for, + create_concept_frequency_data, + create_hierarchical_sequence_data, + create_sequence_data, + create_sequence_data_with_att, + extract_ehr_records, + get_descendant_concept_ids, + logging, + preprocess_domain_table, +) +from ..cohorts.query_builder import ENTRY_COHORT, NEGATIVE_COHORT, QueryBuilder + +COHORT_TABLE_NAME = "cohort" +PERSON = "person" +OBSERVATION_PERIOD = "observation_period" +DEFAULT_DEPENDENCY = [ + "person", + "visit_occurrence", + "observation_period", + "concept", + "concept_ancestor", + "concept_relationship", +] def cohort_validator(required_columns_attribute): """ - Decorator for validating the cohort dataframe returned by build function in + Decorator for validating the cohort dataframe returned by build function in. 
+ AbstractCohortBuilderBase :param required_columns_attribute: attribute for storing cohort_required_columns in :class:`spark_apps.spark_app_base.AbstractCohortBuilderBase` @@ -33,7 +54,7 @@ def wrapper(self, *args, **kwargs): required_columns = getattr(self, required_columns_attribute) for required_column in required_columns: if required_column not in cohort.columns: - raise AssertionError(f'{required_column} is a required column in the cohort') + raise AssertionError(f"{required_column} is a required column in the cohort") return cohort return wrapper @@ -54,27 +75,29 @@ def validate_date_folder(input_folder, table_list): for domain_table_name in table_list: parquet_file_path = os.path.join(input_folder, domain_table_name) if not os.path.exists(parquet_file_path): - raise FileExistsError(f'{parquet_file_path} does not exist') + raise FileExistsError(f"{parquet_file_path} does not exist") def validate_folder(folder_path): if not os.path.exists(folder_path): - raise FileExistsError(f'{folder_path} does not exist') + raise FileExistsError(f"{folder_path} does not exist") class BaseCohortBuilder(ABC): - cohort_required_columns = ['person_id', 'index_date', 'visit_occurrence_id'] - - def __init__(self, - query_builder: QueryBuilder, - input_folder: str, - output_folder: str, - date_lower_bound: str, - date_upper_bound: str, - age_lower_bound: int, - age_upper_bound: int, - prior_observation_period: int, - post_observation_period: int): + cohort_required_columns = ["person_id", "index_date", "visit_occurrence_id"] + + def __init__( + self, + query_builder: QueryBuilder, + input_folder: str, + output_folder: str, + date_lower_bound: str, + date_upper_bound: str, + age_lower_bound: int, + age_upper_bound: int, + prior_observation_period: int, + post_observation_period: int, + ): self._query_builder = query_builder self._input_folder = input_folder @@ -85,18 +108,20 @@ def __init__(self, self._age_upper_bound = age_upper_bound self._prior_observation_period = prior_observation_period self._post_observation_period = post_observation_period - cohort_name = re.sub('[^a-z0-9]+', '_', self._query_builder.get_cohort_name().lower()) + cohort_name = re.sub("[^a-z0-9]+", "_", self._query_builder.get_cohort_name().lower()) self._output_data_folder = os.path.join(self._output_folder, cohort_name) - self.get_logger().info(f'query_builder: {query_builder}\n' - f'input_folder: {input_folder}\n' - f'output_folder: {output_folder}\n' - f'date_lower_bound: {date_lower_bound}\n' - f'date_upper_bound: {date_upper_bound}\n' - f'age_lower_bound: {age_lower_bound}\n' - f'age_upper_bound: {age_upper_bound}\n' - f'prior_observation_period: {prior_observation_period}\n' - f'post_observation_period: {post_observation_period}\n') + self.get_logger().info( + f"query_builder: {query_builder}\n" + f"input_folder: {input_folder}\n" + f"output_folder: {output_folder}\n" + f"date_lower_bound: {date_lower_bound}\n" + f"date_upper_bound: {date_upper_bound}\n" + f"age_lower_bound: {age_lower_bound}\n" + f"age_upper_bound: {age_upper_bound}\n" + f"prior_observation_period: {prior_observation_period}\n" + f"post_observation_period: {post_observation_period}\n" + ) # Validate the age range, observation_window and prediction_window self._validate_integer_inputs() @@ -106,16 +131,17 @@ def __init__(self, # Validate if the data folders exist validate_date_folder(self._input_folder, self._query_builder.get_dependency_list()) - self.spark = SparkSession.builder.appName( - f'Generate {self._query_builder.get_cohort_name()}').getOrCreate() + 
self.spark = SparkSession.builder.appName(f"Generate {self._query_builder.get_cohort_name()}").getOrCreate() - self._dependency_dict = instantiate_dependencies(self.spark, self._input_folder, - self._query_builder.get_dependency_list()) + self._dependency_dict = instantiate_dependencies( + self.spark, self._input_folder, self._query_builder.get_dependency_list() + ) - @cohort_validator('cohort_required_columns') + @cohort_validator("cohort_required_columns") def create_cohort(self): """ - Create cohort + Create cohort. + :return: """ # Build the ancestor tables for the main query to use if the ancestor_table_specs are @@ -136,16 +162,14 @@ def create_cohort(self): # Build the dependency for the entry cohort if exists if self._query_builder.get_entry_cohort_query(): entry_cohort_query = self._query_builder.get_entry_cohort_query() - query = entry_cohort_query.query_template.format( - **entry_cohort_query.parameters) + query = entry_cohort_query.query_template.format(**entry_cohort_query.parameters) dependency_table = self.spark.sql(query) dependency_table.createOrReplaceGlobalTempView(entry_cohort_query.table_name) # Build the negative cohort if exists if self._query_builder.get_negative_query(): negative_cohort_query = self._query_builder.get_negative_query() - query = negative_cohort_query.query_template.format( - **negative_cohort_query.parameters) + query = negative_cohort_query.query_template.format(**negative_cohort_query.parameters) dependency_table = self.spark.sql(query) dependency_table.createOrReplaceGlobalTempView(negative_cohort_query.table_name) @@ -162,54 +186,61 @@ def create_cohort(self): return cohort def build(self): - """ - Build the cohort and write the dataframe as parquet files to _output_data_folder - """ + """Build the cohort and write the dataframe as parquet files to _output_data_folder.""" cohort = self.create_cohort() cohort = self._apply_observation_period(cohort) cohort = self._add_demographics(cohort) - cohort = cohort.where(F.col('age').between(self._age_lower_bound, self._age_upper_bound)) \ - .where(F.col('index_date').between(to_datetime(self._date_lower_bound), - to_datetime(self._date_upper_bound))) + cohort = cohort.where(F.col("age").between(self._age_lower_bound, self._age_upper_bound)).where( + F.col("index_date").between(to_datetime(self._date_lower_bound), to_datetime(self._date_upper_bound)) + ) - cohort.write.mode('overwrite').parquet(self._output_data_folder) + cohort.write.mode("overwrite").parquet(self._output_data_folder) return self def load_cohort(self): return self.spark.read.parquet(self._output_data_folder) - @cohort_validator('cohort_required_columns') + @cohort_validator("cohort_required_columns") def _apply_observation_period(self, cohort: DataFrame): - cohort.createOrReplaceGlobalTempView('cohort') + cohort.createOrReplaceGlobalTempView("cohort") - qualified_cohort = self.spark.sql(""" + qualified_cohort = self.spark.sql( + """ SELECT c.* - FROM global_temp.cohort AS c - JOIN global_temp.observation_period AS p - ON c.person_id = p.person_id + FROM global_temp.cohort AS c + JOIN global_temp.observation_period AS p + ON c.person_id = p.person_id AND DATE_ADD(c.index_date, -{prior_observation_period}) >= p.observation_period_start_date AND DATE_ADD(c.index_date, {post_observation_period}) <= p.observation_period_end_date - """.format(prior_observation_period=self._prior_observation_period, - post_observation_period=self._post_observation_period)) + """.format( + prior_observation_period=self._prior_observation_period, + 
post_observation_period=self._post_observation_period, + ) + ) - self.spark.sql(f'DROP VIEW global_temp.cohort') + self.spark.sql(f"DROP VIEW global_temp.cohort") return qualified_cohort - @cohort_validator('cohort_required_columns') + @cohort_validator("cohort_required_columns") def _add_demographics(self, cohort: DataFrame): - return cohort.join(self._dependency_dict[PERSON], 'person_id') \ - .withColumn('age', F.year('index_date') - F.col('year_of_birth')) \ - .select(F.col('person_id'), - F.col('age'), - F.col('gender_concept_id'), - F.col('race_concept_id'), - F.col('index_date'), - F.col('visit_occurrence_id')).distinct() + return ( + cohort.join(self._dependency_dict[PERSON], "person_id") + .withColumn("age", F.year("index_date") - F.col("year_of_birth")) + .select( + F.col("person_id"), + F.col("age"), + F.col("gender_concept_id"), + F.col("race_concept_id"), + F.col("index_date"), + F.col("visit_occurrence_id"), + ) + .distinct() + ) def _validate_integer_inputs(self): assert self._age_lower_bound >= 0 @@ -225,41 +256,41 @@ def get_logger(cls): class NestedCohortBuilder: def __init__( - self, - cohort_name: str, - input_folder: str, - output_folder: str, - target_cohort: DataFrame, - outcome_cohort: DataFrame, - ehr_table_list: List[str], - observation_window: int, - hold_off_window: int, - prediction_start_days: int, - prediction_window: int, - num_of_visits: int, - num_of_concepts: int, - patient_splits_folder: str = None, - is_window_post_index: bool = False, - include_visit_type: bool = True, - allow_measurement_only: bool = False, - exclude_visit_tokens: bool = False, - is_feature_concept_frequency: bool = False, - is_roll_up_concept: bool = False, - include_concept_list: bool = True, - is_new_patient_representation: bool = False, - gpt_patient_sequence: bool = False, - is_hierarchical_bert: bool = False, - classic_bert_seq: bool = False, - is_first_time_outcome: bool = False, - is_questionable_outcome_existed: bool = False, - is_remove_index_prediction_starts: bool = False, - is_prediction_window_unbounded: bool = False, - is_observation_window_unbounded: bool = False, - is_population_estimation: bool = False, - att_type: AttType = AttType.CEHR_BERT, - exclude_demographic: bool = True, - use_age_group: bool = False, - single_contribution: bool = False + self, + cohort_name: str, + input_folder: str, + output_folder: str, + target_cohort: DataFrame, + outcome_cohort: DataFrame, + ehr_table_list: List[str], + observation_window: int, + hold_off_window: int, + prediction_start_days: int, + prediction_window: int, + num_of_visits: int, + num_of_concepts: int, + patient_splits_folder: str = None, + is_window_post_index: bool = False, + include_visit_type: bool = True, + allow_measurement_only: bool = False, + exclude_visit_tokens: bool = False, + is_feature_concept_frequency: bool = False, + is_roll_up_concept: bool = False, + include_concept_list: bool = True, + is_new_patient_representation: bool = False, + gpt_patient_sequence: bool = False, + is_hierarchical_bert: bool = False, + classic_bert_seq: bool = False, + is_first_time_outcome: bool = False, + is_questionable_outcome_existed: bool = False, + is_remove_index_prediction_starts: bool = False, + is_prediction_window_unbounded: bool = False, + is_observation_window_unbounded: bool = False, + is_population_estimation: bool = False, + att_type: AttType = AttType.CEHR_BERT, + exclude_demographic: bool = True, + use_age_group: bool = False, + single_contribution: bool = False, ): self._cohort_name = cohort_name 
self._input_folder = input_folder @@ -290,9 +321,9 @@ def __init__( self._is_prediction_window_unbounded = is_prediction_window_unbounded self._include_concept_list = include_concept_list self._allow_measurement_only = allow_measurement_only - self._output_data_folder = os.path.join(self._output_folder, - re.sub('[^a-z0-9]+', '_', - self._cohort_name.lower())) + self._output_data_folder = os.path.join( + self._output_folder, re.sub("[^a-z0-9]+", "_", self._cohort_name.lower()) + ) self._is_population_estimation = is_population_estimation self._att_type = att_type self._exclude_demographic = exclude_demographic @@ -300,41 +331,40 @@ def __init__( self._single_contribution = single_contribution self.get_logger().info( - f'cohort_name: {cohort_name}\n' - f'input_folder: {input_folder}\n' - f'output_folder: {output_folder}\n' - f'ehr_table_list: {ehr_table_list}\n' - f'observation_window: {observation_window}\n' - f'prediction_start_days: {prediction_start_days}\n' - f'prediction_window: {prediction_window}\n' - f'hold_off_window: {hold_off_window}\n' - f'num_of_visits: {num_of_visits}\n' - f'num_of_concepts: {num_of_concepts}\n' - f'is_window_post_index: {is_window_post_index}\n' - f'include_visit_type: {include_visit_type}\n' - f'exclude_visit_tokens: {exclude_visit_tokens}\n' - f'allow_measurement_only: {allow_measurement_only}\n' - f'is_feature_concept_frequency: {is_feature_concept_frequency}\n' - f'is_roll_up_concept: {is_roll_up_concept}\n' - f'is_new_patient_representation: {is_new_patient_representation}\n' - f'gpt_patient_sequence: {gpt_patient_sequence}\n' - f'is_hierarchical_bert: {is_hierarchical_bert}\n' - f'is_first_time_outcome: {is_first_time_outcome}\n' - f'is_questionable_outcome_existed: {is_questionable_outcome_existed}\n' - f'is_remove_index_prediction_starts: {is_remove_index_prediction_starts}\n' - f'is_prediction_window_unbounded: {is_prediction_window_unbounded}\n' - f'include_concept_list: {include_concept_list}\n' - f'is_observation_window_unbounded: {is_observation_window_unbounded}\n' - f'is_population_estimation: {is_population_estimation}\n' - f'att_type: {att_type}\n' - f'exclude_demographic: {exclude_demographic}\n' - f'use_age_group: {use_age_group}\n' - f'single_contribution: {single_contribution}\n' + f"cohort_name: {cohort_name}\n" + f"input_folder: {input_folder}\n" + f"output_folder: {output_folder}\n" + f"ehr_table_list: {ehr_table_list}\n" + f"observation_window: {observation_window}\n" + f"prediction_start_days: {prediction_start_days}\n" + f"prediction_window: {prediction_window}\n" + f"hold_off_window: {hold_off_window}\n" + f"num_of_visits: {num_of_visits}\n" + f"num_of_concepts: {num_of_concepts}\n" + f"is_window_post_index: {is_window_post_index}\n" + f"include_visit_type: {include_visit_type}\n" + f"exclude_visit_tokens: {exclude_visit_tokens}\n" + f"allow_measurement_only: {allow_measurement_only}\n" + f"is_feature_concept_frequency: {is_feature_concept_frequency}\n" + f"is_roll_up_concept: {is_roll_up_concept}\n" + f"is_new_patient_representation: {is_new_patient_representation}\n" + f"gpt_patient_sequence: {gpt_patient_sequence}\n" + f"is_hierarchical_bert: {is_hierarchical_bert}\n" + f"is_first_time_outcome: {is_first_time_outcome}\n" + f"is_questionable_outcome_existed: {is_questionable_outcome_existed}\n" + f"is_remove_index_prediction_starts: {is_remove_index_prediction_starts}\n" + f"is_prediction_window_unbounded: {is_prediction_window_unbounded}\n" + f"include_concept_list: {include_concept_list}\n" + f"is_observation_window_unbounded: 
{is_observation_window_unbounded}\n" + f"is_population_estimation: {is_population_estimation}\n" + f"att_type: {att_type}\n" + f"exclude_demographic: {exclude_demographic}\n" + f"use_age_group: {use_age_group}\n" + f"single_contribution: {single_contribution}\n" ) - self.spark = SparkSession.builder.appName(f'Generate {self._cohort_name}').getOrCreate() - self._dependency_dict = instantiate_dependencies(self.spark, self._input_folder, - DEFAULT_DEPENDENCY) + self.spark = SparkSession.builder.appName(f"Generate {self._cohort_name}").getOrCreate() + self._dependency_dict = instantiate_dependencies(self.spark, self._input_folder, DEFAULT_DEPENDENCY) # Validate the input and output folders validate_folder(self._input_folder) @@ -343,8 +373,8 @@ def __init__( validate_date_folder(self._input_folder, self._ehr_table_list) def build(self): - self._target_cohort.createOrReplaceGlobalTempView('target_cohort') - self._outcome_cohort.createOrReplaceGlobalTempView('outcome_cohort') + self._target_cohort.createOrReplaceGlobalTempView("target_cohort") + self._outcome_cohort.createOrReplaceGlobalTempView("outcome_cohort") prediction_start_days = self._prediction_start_days prediction_window = self._prediction_window @@ -354,7 +384,8 @@ def build(self): prediction_window += self._observation_window + self._hold_off_window if self._is_first_time_outcome: - target_cohort = self.spark.sql(""" + target_cohort = self.spark.sql( + """ SELECT t.person_id AS cohort_member_id, t.* @@ -362,35 +393,46 @@ def build(self): LEFT JOIN global_temp.{entry_cohort} AS o ON t.person_id = o.person_id AND DATE_ADD(t.index_date, {prediction_start_days}) > o.index_date - WHERE o.person_id IS NULL - """.format(entry_cohort=ENTRY_COHORT, - prediction_start_days=prediction_start_days)) - target_cohort.createOrReplaceGlobalTempView('target_cohort') + WHERE o.person_id IS NULL + """.format( + entry_cohort=ENTRY_COHORT, + prediction_start_days=prediction_start_days, + ) + ) + target_cohort.createOrReplaceGlobalTempView("target_cohort") if self._is_questionable_outcome_existed: - target_cohort = self.spark.sql(""" + target_cohort = self.spark.sql( + """ SELECT t.* FROM global_temp.target_cohort AS t LEFT JOIN global_temp.{questionnation_outcome_cohort} AS o ON t.person_id = o.person_id - WHERE o.person_id IS NULL - """.format(questionnation_outcome_cohort=NEGATIVE_COHORT)) - target_cohort.createOrReplaceGlobalTempView('target_cohort') + WHERE o.person_id IS NULL + """.format( + questionnation_outcome_cohort=NEGATIVE_COHORT + ) + ) + target_cohort.createOrReplaceGlobalTempView("target_cohort") if self._is_remove_index_prediction_starts: # Remove the patients whose outcome date lies between index_date and index_date + # prediction_start_days - target_cohort = self.spark.sql(""" + target_cohort = self.spark.sql( + """ SELECT DISTINCT t.* - FROM global_temp.target_cohort AS t + FROM global_temp.target_cohort AS t LEFT JOIN global_temp.outcome_cohort AS exclusion ON t.person_id = exclusion.person_id - AND exclusion.index_date BETWEEN t.index_date - AND DATE_ADD(t.index_date, {prediction_start_days}) + AND exclusion.index_date BETWEEN t.index_date + AND DATE_ADD(t.index_date, {prediction_start_days}) WHERE exclusion.person_id IS NULL - """.format(prediction_start_days=max(prediction_start_days - 1, 0))) - target_cohort.createOrReplaceGlobalTempView('target_cohort') + """.format( + prediction_start_days=max(prediction_start_days - 1, 0) + ) + ) + target_cohort.createOrReplaceGlobalTempView("target_cohort") if 
self._is_prediction_window_unbounded: query_template = """ @@ -398,10 +440,10 @@ def build(self): t.*, o.index_date as outcome_date, CAST(ISNOTNULL(o.person_id) AS INT) AS label - FROM global_temp.target_cohort AS t + FROM global_temp.target_cohort AS t LEFT JOIN global_temp.outcome_cohort AS o ON t.person_id = o.person_id - AND o.index_date >= DATE_ADD(t.index_date, {prediction_start_days}) + AND o.index_date >= DATE_ADD(t.index_date, {prediction_start_days}) """ else: query_template = """ @@ -409,68 +451,68 @@ def build(self): t.*, o.index_date as outcome_date, CAST(ISNOTNULL(o.person_id) AS INT) AS label - FROM global_temp.target_cohort AS t + FROM global_temp.target_cohort AS t LEFT JOIN global_temp.observation_period AS op - ON t.person_id = op.person_id + ON t.person_id = op.person_id AND DATE_ADD(t.index_date, {prediction_window}) <= op.observation_period_end_date LEFT JOIN global_temp.outcome_cohort AS o ON t.person_id = o.person_id - AND o.index_date BETWEEN DATE_ADD(t.index_date, {prediction_start_days}) + AND o.index_date BETWEEN DATE_ADD(t.index_date, {prediction_start_days}) AND DATE_ADD(t.index_date, {prediction_window}) WHERE op.person_id IS NOT NULL OR o.person_id IS NOT NULL """ - cohort_member_id_udf = F.dense_rank().over( - W.orderBy('person_id', 'index_date', 'visit_occurrence_id')) - cohort = self.spark.sql(query_template.format(prediction_start_days=prediction_start_days, - prediction_window=prediction_window)) \ - .withColumn('cohort_member_id', cohort_member_id_udf) + cohort_member_id_udf = F.dense_rank().over(W.orderBy("person_id", "index_date", "visit_occurrence_id")) + cohort = self.spark.sql( + query_template.format( + prediction_start_days=prediction_start_days, + prediction_window=prediction_window, + ) + ).withColumn("cohort_member_id", cohort_member_id_udf) # Keep one record in case that there are multiple samples generated for the same index_date. 
# This should not happen in theory, this is really just a safeguard row_rank = F.row_number().over( - Window.partitionBy('person_id', 'cohort_member_id', 'index_date').orderBy(F.desc('label'))) - cohort = cohort.withColumn('row_rank', row_rank) \ - .where('row_rank == 1') \ - .drop('row_rank') + Window.partitionBy("person_id", "cohort_member_id", "index_date").orderBy(F.desc("label")) + ) + cohort = cohort.withColumn("row_rank", row_rank).where("row_rank == 1").drop("row_rank") # We only allow the patient to contribute once to the dataset # If the patient has any positive outcomes, we will take the most recent positive outcome, # otherwise we will take the most recent negative outcome if self._single_contribution: record_rank = F.row_number().over( - Window.partitionBy('person_id').orderBy(F.desc('label'), F.desc('index_date'))) - cohort = cohort.withColumn('record_rank', record_rank) \ - .where('record_rank == 1') \ - .drop('record_rank') + Window.partitionBy("person_id").orderBy(F.desc("label"), F.desc("index_date")) + ) + cohort = cohort.withColumn("record_rank", record_rank).where("record_rank == 1").drop("record_rank") ehr_records_for_cohorts = self.extract_ehr_records_for_cohort(cohort) # ehr_records_for_cohorts.show() - cohort = cohort.join(ehr_records_for_cohorts, ['person_id', 'cohort_member_id']) \ - .where(F.col('num_of_visits') >= self._num_of_visits) \ - .where(F.col('num_of_concepts') >= self._num_of_concepts) + cohort = ( + cohort.join(ehr_records_for_cohorts, ["person_id", "cohort_member_id"]) + .where(F.col("num_of_visits") >= self._num_of_visits) + .where(F.col("num_of_concepts") >= self._num_of_concepts) + ) # if patient_splits is provided, we will if self._patient_splits_folder: patient_splits = self.spark.read.parquet(self._patient_splits_folder) - cohort.join(patient_splits, 'person_id').orderBy('person_id', 'cohort_member_id') \ - .write.mode('overwrite').parquet(os.path.join(self._output_data_folder, 'temp')) + cohort.join(patient_splits, "person_id").orderBy("person_id", "cohort_member_id").write.mode( + "overwrite" + ).parquet(os.path.join(self._output_data_folder, "temp")) # Reload the data from the disk - cohort = self.spark.read.parquet(os.path.join(self._output_data_folder, 'temp')) - cohort.where('split="train"').write.mode('overwrite').parquet( - os.path.join(self._output_data_folder, 'train') + cohort = self.spark.read.parquet(os.path.join(self._output_data_folder, "temp")) + cohort.where('split="train"').write.mode("overwrite").parquet( + os.path.join(self._output_data_folder, "train") ) - cohort.where('split="test"').write.mode('overwrite').parquet( - os.path.join(self._output_data_folder, 'test') - ) - shutil.rmtree(os.path.join(self._output_data_folder, 'temp')) + cohort.where('split="test"').write.mode("overwrite").parquet(os.path.join(self._output_data_folder, "test")) + shutil.rmtree(os.path.join(self._output_data_folder, "temp")) else: - cohort.orderBy('person_id', 'cohort_member_id') \ - .write.mode('overwrite').parquet(self._output_data_folder) + cohort.orderBy("person_id", "cohort_member_id").write.mode("overwrite").parquet(self._output_data_folder) def extract_ehr_records_for_cohort(self, cohort: DataFrame): """ - Create the patient sequence based on the observation window for the given cohort + Create the patient sequence based on the observation window for the given cohort. 
:param cohort: :return: @@ -482,89 +524,98 @@ def extract_ehr_records_for_cohort(self, cohort: DataFrame): self._ehr_table_list, self._include_visit_type, self._is_roll_up_concept, - self._include_concept_list + self._include_concept_list, ) # Duplicate the records for cohorts that allow multiple entries - ehr_records = ehr_records.join(cohort, 'person_id').select( - [ehr_records[field_name] for field_name in ehr_records.schema.fieldNames()] + ['cohort_member_id'] + ehr_records = ehr_records.join(cohort, "person_id").select( + [ehr_records[field_name] for field_name in ehr_records.schema.fieldNames()] + ["cohort_member_id"] ) # Only allow the data records that occurred between the index date and the prediction window if self._is_population_estimation: if self._is_prediction_window_unbounded: - record_window_filter = ehr_records['date'] <= F.current_date() + record_window_filter = ehr_records["date"] <= F.current_date() else: - record_window_filter = ehr_records['date'] <= F.date_add( - cohort['index_date'], - self._prediction_window - ) + record_window_filter = ehr_records["date"] <= F.date_add(cohort["index_date"], self._prediction_window) else: # For patient level prediction, we remove all records post index date if self._is_observation_post_index: - record_window_filter = ehr_records['date'].between( - cohort['index_date'], - F.date_add(cohort['index_date'], self._observation_window)) + record_window_filter = ehr_records["date"].between( + cohort["index_date"], + F.date_add(cohort["index_date"], self._observation_window), + ) else: if self._is_observation_window_unbounded: - record_window_filter = ehr_records['date'] <= F.date_sub( - cohort['index_date'], self._hold_off_window + record_window_filter = ehr_records["date"] <= F.date_sub( + cohort["index_date"], self._hold_off_window ) else: - record_window_filter = ehr_records['date'].between( - F.date_sub(cohort['index_date'], - self._observation_window + self._hold_off_window), - F.date_sub(cohort['index_date'], self._hold_off_window) + record_window_filter = ehr_records["date"].between( + F.date_sub( + cohort["index_date"], + self._observation_window + self._hold_off_window, + ), + F.date_sub(cohort["index_date"], self._hold_off_window), ) - cohort_ehr_records = ehr_records.join( - cohort, - (ehr_records.person_id == cohort.person_id) & - (ehr_records.cohort_member_id == cohort.cohort_member_id) - ).where(record_window_filter) \ + cohort_ehr_records = ( + ehr_records.join( + cohort, + (ehr_records.person_id == cohort.person_id) & (ehr_records.cohort_member_id == cohort.cohort_member_id), + ) + .where(record_window_filter) .select([ehr_records[field_name] for field_name in ehr_records.schema.fieldNames()]) + ) if self._is_hierarchical_bert: return create_hierarchical_sequence_data( person=self._dependency_dict[PERSON], visit_occurrence=self._dependency_dict[VISIT_OCCURRENCE], patient_events=cohort_ehr_records, - allow_measurement_only=self._allow_measurement_only + allow_measurement_only=self._allow_measurement_only, ) if self._is_feature_concept_frequency: return create_concept_frequency_data(cohort_ehr_records, None) if self._is_new_patient_representation: - birthdate_udf = F.coalesce('birth_datetime', F.concat('year_of_birth', F.lit('-01-01')).cast('timestamp')) + birthdate_udf = F.coalesce( + "birth_datetime", + F.concat("year_of_birth", F.lit("-01-01")).cast("timestamp"), + ) patient_demographic = self._dependency_dict[PERSON].select( - 'person_id', - birthdate_udf.alias('birth_datetime'), - 'race_concept_id', - 'gender_concept_id' 
+ "person_id", + birthdate_udf.alias("birth_datetime"), + "race_concept_id", + "gender_concept_id", ) - age_udf = F.ceil(F.months_between(F.col('visit_start_date'), F.col('birth_datetime')) / F.lit(12)) - visit_occurrence_person = self._dependency_dict[VISIT_OCCURRENCE].join(patient_demographic, 'person_id') \ - .withColumn('age', age_udf) \ - .drop('birth_datetime') + age_udf = F.ceil(F.months_between(F.col("visit_start_date"), F.col("birth_datetime")) / F.lit(12)) + visit_occurrence_person = ( + self._dependency_dict[VISIT_OCCURRENCE] + .join(patient_demographic, "person_id") + .withColumn("age", age_udf) + .drop("birth_datetime") + ) return create_sequence_data_with_att( cohort_ehr_records, visit_occurrence=visit_occurrence_person, include_visit_type=self._include_visit_type, exclude_visit_tokens=self._exclude_visit_tokens, - patient_demographic=patient_demographic if self._gpt_patient_sequence else None, + patient_demographic=(patient_demographic if self._gpt_patient_sequence else None), att_type=self._att_type, exclude_demographic=self._exclude_demographic, - use_age_group=self._use_age_group + use_age_group=self._use_age_group, ) return create_sequence_data( cohort_ehr_records, date_filter=None, include_visit_type=self._include_visit_type, - classic_bert_seq=self._classic_bert_seq) + classic_bert_seq=self._classic_bert_seq, + ) @classmethod def get_logger(cls): @@ -572,13 +623,14 @@ def get_logger(cls): def create_prediction_cohort( - spark_args, - target_query_builder: QueryBuilder, - outcome_query_builder: QueryBuilder, - ehr_table_list + spark_args, + target_query_builder: QueryBuilder, + outcome_query_builder: QueryBuilder, + ehr_table_list, ): """ - TODO + TODO. + :param spark_args: :param target_query_builder: :param outcome_query_builder: @@ -622,30 +674,38 @@ def create_prediction_cohort( post_observation_period = observation_window + hold_off_window if is_window_post_index else 0 # Generate the target cohort - target_cohort = BaseCohortBuilder( - query_builder=target_query_builder, - input_folder=input_folder, - output_folder=output_folder, - date_lower_bound=date_lower_bound, - date_upper_bound=date_upper_bound, - age_lower_bound=age_lower_bound, - age_upper_bound=age_upper_bound, - prior_observation_period=prior_observation_period, - post_observation_period=post_observation_period - ).build().load_cohort() + target_cohort = ( + BaseCohortBuilder( + query_builder=target_query_builder, + input_folder=input_folder, + output_folder=output_folder, + date_lower_bound=date_lower_bound, + date_upper_bound=date_upper_bound, + age_lower_bound=age_lower_bound, + age_upper_bound=age_upper_bound, + prior_observation_period=prior_observation_period, + post_observation_period=post_observation_period, + ) + .build() + .load_cohort() + ) # Generate the outcome cohort - outcome_cohort = BaseCohortBuilder( - query_builder=outcome_query_builder, - input_folder=input_folder, - output_folder=output_folder, - date_lower_bound=date_lower_bound, - date_upper_bound=date_upper_bound, - age_lower_bound=age_lower_bound, - age_upper_bound=age_upper_bound, - prior_observation_period=0, - post_observation_period=0 - ).build().load_cohort() + outcome_cohort = ( + BaseCohortBuilder( + query_builder=outcome_query_builder, + input_folder=input_folder, + output_folder=output_folder, + date_lower_bound=date_lower_bound, + date_upper_bound=date_upper_bound, + age_lower_bound=age_lower_bound, + age_upper_bound=age_upper_bound, + prior_observation_period=0, + post_observation_period=0, + ) + .build() + 
.load_cohort() + ) NestedCohortBuilder( cohort_name=cohort_name, @@ -681,5 +741,5 @@ def create_prediction_cohort( att_type=AttType(spark_args.att_type), exclude_demographic=spark_args.exclude_demographic, use_age_group=spark_args.use_age_group, - single_contribution=spark_args.single_contribution + single_contribution=spark_args.single_contribution, ).build() diff --git a/src/cehrbert/spark_apps/cohorts/type_two_diabietes.py b/src/cehrbert/spark_apps/cohorts/type_two_diabietes.py index bbd56c0b..3da47290 100644 --- a/src/cehrbert/spark_apps/cohorts/type_two_diabietes.py +++ b/src/cehrbert/spark_apps/cohorts/type_two_diabietes.py @@ -1,4 +1,4 @@ -from ..cohorts.query_builder import QueryBuilder, QuerySpec, AncestorTableSpec +from ..cohorts.query_builder import AncestorTableSpec, QueryBuilder, QuerySpec COHORT_QUERY_TEMPLATE = """ WITH person_ids_to_include_drug AS @@ -9,9 +9,9 @@ JOIN global_temp.{drug_inclusion_concepts} AS e ON d.drug_concept_id = e.concept_id ), -person_ids_to_exclude_observation AS +person_ids_to_exclude_observation AS ( - + SELECT DISTINCT o.person_id, o.observation_date @@ -19,7 +19,7 @@ JOIN global_temp.{observation_exclusion_concepts} AS oec ON o.observation_concept_id = oec.concept_id ) -SELECT +SELECT distinct c.person_id, c.index_date, @@ -28,9 +28,9 @@ ( SELECT DISTINCT vo.person_id, - FIRST(DATE(vo.visit_start_date)) OVER (PARTITION BY co.person_id + FIRST(DATE(vo.visit_start_date)) OVER (PARTITION BY co.person_id ORDER BY DATE(vo.visit_start_date), vo.visit_occurrence_id) AS index_date, - FIRST(vo.visit_occurrence_id) OVER (PARTITION BY co.person_id + FIRST(vo.visit_occurrence_id) OVER (PARTITION BY co.person_id ORDER BY DATE(vo.visit_start_date), vo.visit_occurrence_id) AS visit_occurrence_id FROM global_temp.condition_occurrence AS co JOIN global_temp.{diabetes_inclusion_concepts} AS ie @@ -46,25 +46,85 @@ """ DIABETES_INCLUSION = [443238, 201820, 442793, 4016045] -DIABETES_EXCLUSION = [40484648, 201254, 435216, 4058243, 30968, 438476, 195771, 193323, 4019513, - 40484649] +DIABETES_EXCLUSION = [ + 40484648, + 201254, + 435216, + 4058243, + 30968, + 438476, + 195771, + 193323, + 4019513, + 40484649, +] DRUG_INCLUSION = [ - 1503297, 1594973, 1597756, 1559684, 1560171, 1502855, 1502809, 1525215, 1547504, 1580747, - 40166035, 43013884, 40239216, 1516766, 1502826, 1510202, 1529331, 35605670, 35602717, 1516976, - 1502905, 46221581, 1550023, 35198096, 42899447, 1544838, 1567198, 35884381, 1531601, 1588986, - 1513876, 19013951, 1590165, 1596977, 1586346, 19090204, 1513843, 1513849, 1562586, 19090226, - 19090221, 1586369, 19090244, 19090229, 19090247, 19090249, 19090180, 19013926, 19091621, - 19090187] + 1503297, + 1594973, + 1597756, + 1559684, + 1560171, + 1502855, + 1502809, + 1525215, + 1547504, + 1580747, + 40166035, + 43013884, + 40239216, + 1516766, + 1502826, + 1510202, + 1529331, + 35605670, + 35602717, + 1516976, + 1502905, + 46221581, + 1550023, + 35198096, + 42899447, + 1544838, + 1567198, + 35884381, + 1531601, + 1588986, + 1513876, + 19013951, + 1590165, + 1596977, + 1586346, + 19090204, + 1513843, + 1513849, + 1562586, + 19090226, + 19090221, + 1586369, + 19090244, + 19090229, + 19090247, + 19090249, + 19090180, + 19013926, + 19091621, + 19090187, +] OBSERVATION_EXCLUSION = [40769338, 43021173, 42539022, 46270562] -DEPENDENCY_LIST = ['person', 'condition_occurrence', 'visit_occurrence', 'drug_exposure', - 'observation'] +DEPENDENCY_LIST = [ + "person", + "condition_occurrence", + "visit_occurrence", + "drug_exposure", + "observation", +] 
-DIABETES_INCLUSION_TABLE = 'diabetes_inclusion_concepts' -DIABETES_EXCLUSION_TABLE = 'diabetes_exclusion_concepts' -DRUG_INCLUSION_TABLE = 'drug_inclusion_concepts' -OBSERVATION_EXCLUSION_TABLE = 'observation_exclusion_concepts' +DIABETES_INCLUSION_TABLE = "diabetes_inclusion_concepts" +DIABETES_EXCLUSION_TABLE = "diabetes_exclusion_concepts" +DRUG_INCLUSION_TABLE = "drug_inclusion_concepts" +OBSERVATION_EXCLUSION_TABLE = "observation_exclusion_concepts" -DEFAULT_COHORT_NAME = 'type_two_diabetes' +DEFAULT_COHORT_NAME = "type_two_diabetes" def query_builder(spark_args): @@ -72,31 +132,39 @@ def query_builder(spark_args): table_name=DEFAULT_COHORT_NAME, query_template=COHORT_QUERY_TEMPLATE, parameters={ - 'diabetes_exclusion_concepts': DIABETES_EXCLUSION_TABLE, - 'diabetes_inclusion_concepts': DIABETES_INCLUSION_TABLE, - 'drug_inclusion_concepts': DRUG_INCLUSION_TABLE, - 'observation_exclusion_concepts': OBSERVATION_EXCLUSION_TABLE, - 'date_lower_bound': spark_args.date_lower_bound - } + "diabetes_exclusion_concepts": DIABETES_EXCLUSION_TABLE, + "diabetes_inclusion_concepts": DIABETES_INCLUSION_TABLE, + "drug_inclusion_concepts": DRUG_INCLUSION_TABLE, + "observation_exclusion_concepts": OBSERVATION_EXCLUSION_TABLE, + "date_lower_bound": spark_args.date_lower_bound, + }, ) ancestor_table_specs = [ - AncestorTableSpec(table_name=DIABETES_INCLUSION_TABLE, - ancestor_concept_ids=DIABETES_INCLUSION, - is_standard=True), - AncestorTableSpec(table_name=DIABETES_EXCLUSION_TABLE, - ancestor_concept_ids=DIABETES_EXCLUSION, - is_standard=True), - AncestorTableSpec(table_name=OBSERVATION_EXCLUSION_TABLE, - ancestor_concept_ids=OBSERVATION_EXCLUSION, - is_standard=True), - AncestorTableSpec(table_name=DRUG_INCLUSION_TABLE, - ancestor_concept_ids=DRUG_INCLUSION, - is_standard=True) - + AncestorTableSpec( + table_name=DIABETES_INCLUSION_TABLE, + ancestor_concept_ids=DIABETES_INCLUSION, + is_standard=True, + ), + AncestorTableSpec( + table_name=DIABETES_EXCLUSION_TABLE, + ancestor_concept_ids=DIABETES_EXCLUSION, + is_standard=True, + ), + AncestorTableSpec( + table_name=OBSERVATION_EXCLUSION_TABLE, + ancestor_concept_ids=OBSERVATION_EXCLUSION, + is_standard=True, + ), + AncestorTableSpec( + table_name=DRUG_INCLUSION_TABLE, + ancestor_concept_ids=DRUG_INCLUSION, + is_standard=True, + ), ] return QueryBuilder( cohort_name=DEFAULT_COHORT_NAME, dependency_list=DEPENDENCY_LIST, query=query, - ancestor_table_specs=ancestor_table_specs) + ancestor_table_specs=ancestor_table_specs, + ) diff --git a/src/cehrbert/spark_apps/cohorts/ventilation.py b/src/cehrbert/spark_apps/cohorts/ventilation.py index aaa783a2..d3af84d0 100644 --- a/src/cehrbert/spark_apps/cohorts/ventilation.py +++ b/src/cehrbert/spark_apps/cohorts/ventilation.py @@ -8,15 +8,15 @@ FROM global_temp.vent AS vent """ -DEFAULT_COHORT_NAME = 'ventilation' -DEPENDENCY_LIST = ['vent'] +DEFAULT_COHORT_NAME = "ventilation" +DEPENDENCY_LIST = ["vent"] def query_builder(): - query = QuerySpec(table_name=DEFAULT_COHORT_NAME, - query_template=VENTILATION_COHORT_QUERY, - parameters={}) + query = QuerySpec( + table_name=DEFAULT_COHORT_NAME, + query_template=VENTILATION_COHORT_QUERY, + parameters={}, + ) - return QueryBuilder(cohort_name=DEFAULT_COHORT_NAME, - dependency_list=DEPENDENCY_LIST, - query=query) + return QueryBuilder(cohort_name=DEFAULT_COHORT_NAME, dependency_list=DEPENDENCY_LIST, query=query) diff --git a/src/cehrbert/spark_apps/decorators/patient_event_decorator.py b/src/cehrbert/spark_apps/decorators/patient_event_decorator.py index 
afc6d61a..7eaf2409 100644 --- a/src/cehrbert/spark_apps/decorators/patient_event_decorator.py +++ b/src/cehrbert/spark_apps/decorators/patient_event_decorator.py @@ -5,18 +5,20 @@ import numpy as np from pyspark.sql import DataFrame -from pyspark.sql import functions as F, Window as W, types as T +from pyspark.sql import Window as W +from pyspark.sql import functions as F +from pyspark.sql import types as T -from ...const.common import MEASUREMENT, CATEGORICAL_MEASUREMENT +from ...const.common import CATEGORICAL_MEASUREMENT, MEASUREMENT class AttType(Enum): - DAY = 'day' - WEEK = 'week' - MONTH = 'month' - CEHR_BERT = 'cehrbert' - MIX = 'mix' - NONE = 'none' + DAY = "day" + WEEK = "week" + MONTH = "month" + CEHR_BERT = "cehrbert" + MIX = "mix" + NONE = "none" class PatientEventDecorator(ABC): @@ -31,11 +33,30 @@ def decorate(self, patient_events): @classmethod def get_required_columns(cls): - return set(['cohort_member_id', 'person_id', 'standard_concept_id', 'date', - 'datetime', 'visit_occurrence_id', 'domain', 'concept_value', 'visit_rank_order', - 'visit_segment', 'priority', 'date_in_week', 'concept_value_mask', - 'mlm_skip_value', 'age', 'visit_concept_id', 'visit_start_date', - 'visit_start_datetime', 'visit_concept_order', 'concept_order']) + return set( + [ + "cohort_member_id", + "person_id", + "standard_concept_id", + "date", + "datetime", + "visit_occurrence_id", + "domain", + "concept_value", + "visit_rank_order", + "visit_segment", + "priority", + "date_in_week", + "concept_value_mask", + "mlm_skip_value", + "age", + "visit_concept_id", + "visit_start_date", + "visit_start_datetime", + "visit_concept_order", + "concept_order", + ] + ) def validate(self, patient_events: DataFrame): actual_column_set = set(patient_events.columns) @@ -44,15 +65,13 @@ def validate(self, patient_events: DataFrame): diff_left = actual_column_set - expected_column_set diff_right = expected_column_set - actual_column_set raise RuntimeError( - f'{self}\n' - f'actual_column_set - expected_column_set: {diff_left}\n' - f'expected_column_set - actual_column_set: {diff_right}' + f"{self}\n" + f"actual_column_set - expected_column_set: {diff_left}\n" + f"expected_column_set - actual_column_set: {diff_right}" ) -class PatientEventBaseDecorator( - PatientEventDecorator -): +class PatientEventBaseDecorator(PatientEventDecorator): # output_columns = [ # 'cohort_member_id', 'person_id', 'concept_ids', 'visit_segments', 'orders', # 'dates', 'ages', 'visit_concept_orders', 'num_of_visits', 'num_of_concepts', @@ -62,12 +81,10 @@ class PatientEventBaseDecorator( def __init__(self, visit_occurrence): self._visit_occurrence = visit_occurrence - def _decorate( - self, - patient_events: DataFrame - ): + def _decorate(self, patient_events: DataFrame): """ - patient_events contains the following columns (cohort_member_id, person_id, + Patient_events contains the following columns (cohort_member_id, person_id,. 
+ standard_concept_id, date, visit_occurrence_id, domain, concept_value) :param patient_events: @@ -76,108 +93,111 @@ def _decorate( # todo: create an assertion the dataframe contains the above columns - valid_visit_ids = patient_events.select( - 'visit_occurrence_id', 'cohort_member_id' - ).distinct() + valid_visit_ids = patient_events.select("visit_occurrence_id", "cohort_member_id").distinct() # Add visit_start_date to the patient_events dataframe and create the visit rank visit_rank_udf = F.row_number().over( - W.partitionBy('person_id', 'cohort_member_id').orderBy( - 'visit_start_datetime', 'is_inpatient', 'expired', 'visit_occurrence_id' + W.partitionBy("person_id", "cohort_member_id").orderBy( + "visit_start_datetime", "is_inpatient", "expired", "visit_occurrence_id" ) ) - visit_segment_udf = F.col('visit_rank_order') % F.lit(2) + 1 + visit_segment_udf = F.col("visit_rank_order") % F.lit(2) + 1 # The visit records are joined to the cohort members (there could be multiple entries for the same patient) # if multiple entries are present, we duplicate the visit records for those. - visits = self._visit_occurrence.join( - valid_visit_ids, - 'visit_occurrence_id' - ).select( - 'person_id', - 'cohort_member_id', - 'visit_occurrence_id', - 'visit_end_date', - F.col('visit_start_date').cast(T.DateType()).alias('visit_start_date'), - F.to_timestamp('visit_start_datetime').alias('visit_start_datetime'), - F.col('visit_concept_id').cast('int').isin([9201, 262, 8971, 8920]).cast('int').alias('is_inpatient'), - F.when(F.col('discharged_to_concept_id').cast('int') == 4216643, F.lit(1)).otherwise(F.lit(0)).alias( - 'expired') - ).withColumn('visit_rank_order', visit_rank_udf) \ - .withColumn('visit_segment', visit_segment_udf) \ - .drop('person_id', 'expired') + visits = ( + self._visit_occurrence.join(valid_visit_ids, "visit_occurrence_id") + .select( + "person_id", + "cohort_member_id", + "visit_occurrence_id", + "visit_end_date", + F.col("visit_start_date").cast(T.DateType()).alias("visit_start_date"), + F.to_timestamp("visit_start_datetime").alias("visit_start_datetime"), + F.col("visit_concept_id").cast("int").isin([9201, 262, 8971, 8920]).cast("int").alias("is_inpatient"), + F.when(F.col("discharged_to_concept_id").cast("int") == 4216643, F.lit(1)) + .otherwise(F.lit(0)) + .alias("expired"), + ) + .withColumn("visit_rank_order", visit_rank_udf) + .withColumn("visit_segment", visit_segment_udf) + .drop("person_id", "expired") + ) # Determine the concept order depending on the visit type. For outpatient visits, we assume the concepts to # have the same order, whereas for inpatient visits, the concept order is determined by the time stamp. # the concept order needs to be generated per each cohort member because the same visit could be used # in multiple cohort's histories of the same patient concept_order_udf = F.when( - F.col('is_inpatient') == 1, - F.dense_rank().over(W.partitionBy('cohort_member_id', 'visit_occurrence_id').orderBy('datetime')) + F.col("is_inpatient") == 1, + F.dense_rank().over(W.partitionBy("cohort_member_id", "visit_occurrence_id").orderBy("datetime")), ).otherwise(F.lit(1)) # Determine the global visit concept order for each patient, this takes both visit_rank_order and concept_order # into account when assigning this new order. # e.g. 
visit_rank_order = [1, 1, 2, 2], concept_order = [1, 1, 1, 2] -> visit_concept_order = [1, 1, 2, 3] visit_concept_order_udf = F.dense_rank().over( - W.partitionBy('person_id', 'cohort_member_id').orderBy( - 'visit_rank_order', 'concept_order' - ) + W.partitionBy("person_id", "cohort_member_id").orderBy("visit_rank_order", "concept_order") ) # We need to set the visit_end_date as the visit_start_date for outpatient visits # For inpatient visits, we use the original visit_end_date if available, otherwise # we will infer the visit_end_date using the max(date) of the current visit visit_end_date_udf = F.when( - F.col('is_inpatient') == 1, F.coalesce( - F.col('visit_end_date'), - F.max('date').over(W.partitionBy( - 'cohort_member_id', - 'visit_occurrence_id') - ) - ) - ).otherwise(F.col('visit_start_date')) + F.col("is_inpatient") == 1, + F.coalesce( + F.col("visit_end_date"), + F.max("date").over(W.partitionBy("cohort_member_id", "visit_occurrence_id")), + ), + ).otherwise(F.col("visit_start_date")) # We need to bound the medical event dates between visit_start_date and visit_end_date bound_medical_event_date = F.when( - F.col('date') < F.col('visit_start_date'), F.col('visit_start_date')).otherwise( - F.when(F.col('date') > F.col('visit_end_date'), F.col('visit_end_date')).otherwise(F.col('date')) - ) + F.col("date") < F.col("visit_start_date"), F.col("visit_start_date") + ).otherwise(F.when(F.col("date") > F.col("visit_end_date"), F.col("visit_end_date")).otherwise(F.col("date"))) # We need to bound the medical event dates between visit_start_date and visit_end_date bound_medical_event_datetime = F.when( - F.col('datetime') < F.col('visit_start_datetime'), F.col('visit_start_datetime')).otherwise( - F.when(F.col('datetime') > F.col('visit_end_datetime'), F.col('visit_end_datetime')).otherwise( - F.col('datetime')) + F.col("datetime") < F.col("visit_start_datetime"), + F.col("visit_start_datetime"), + ).otherwise( + F.when( + F.col("datetime") > F.col("visit_end_datetime"), + F.col("visit_end_datetime"), + ).otherwise(F.col("datetime")) ) - patient_events = patient_events.join(visits, ['cohort_member_id', 'visit_occurrence_id']) \ - .withColumn('visit_end_date', visit_end_date_udf) \ - .withColumn('visit_end_datetime', F.date_add('visit_end_date', 1)) \ - .withColumn('visit_end_datetime', F.expr('visit_end_datetime - INTERVAL 1 MINUTE')) \ - .withColumn('date', bound_medical_event_date) \ - .withColumn('datetime', bound_medical_event_datetime) \ - .withColumn('concept_order', concept_order_udf) \ - .withColumn('visit_concept_order', visit_concept_order_udf) \ - .drop('is_inpatient', 'visit_end_date', 'visit_end_datetime') \ + patient_events = ( + patient_events.join(visits, ["cohort_member_id", "visit_occurrence_id"]) + .withColumn("visit_end_date", visit_end_date_udf) + .withColumn("visit_end_datetime", F.date_add("visit_end_date", 1)) + .withColumn("visit_end_datetime", F.expr("visit_end_datetime - INTERVAL 1 MINUTE")) + .withColumn("date", bound_medical_event_date) + .withColumn("datetime", bound_medical_event_datetime) + .withColumn("concept_order", concept_order_udf) + .withColumn("visit_concept_order", visit_concept_order_udf) + .drop("is_inpatient", "visit_end_date", "visit_end_datetime") .distinct() + ) # Set the priority for the events. 
# Create the week since epoch UDF - weeks_since_epoch_udf = (F.unix_timestamp('date') / F.lit(24 * 60 * 60 * 7)).cast('int') - patient_events = patient_events \ - .withColumn('priority', F.lit(0)) \ - .withColumn('date_in_week', weeks_since_epoch_udf) + weeks_since_epoch_udf = (F.unix_timestamp("date") / F.lit(24 * 60 * 60 * 7)).cast("int") + patient_events = patient_events.withColumn("priority", F.lit(0)).withColumn( + "date_in_week", weeks_since_epoch_udf + ) # Create the concept_value_mask field to indicate whether domain values should be skipped # As of now only measurement has values, so other domains would be skipped. - patient_events = patient_events \ - .withColumn('concept_value_mask', (F.col('domain') == MEASUREMENT).cast('int')) \ - .withColumn('mlm_skip_value', - (F.col('domain').isin([MEASUREMENT, CATEGORICAL_MEASUREMENT])).cast('int')) + patient_events = patient_events.withColumn( + "concept_value_mask", (F.col("domain") == MEASUREMENT).cast("int") + ).withColumn( + "mlm_skip_value", + (F.col("domain").isin([MEASUREMENT, CATEGORICAL_MEASUREMENT])).cast("int"), + ) - if 'concept_value' not in patient_events.schema.fieldNames(): - patient_events = patient_events.withColumn('concept_value', F.lit(0.0)) + if "concept_value" not in patient_events.schema.fieldNames(): + patient_events = patient_events.withColumn("concept_value", F.lit(0.0)) # (cohort_member_id, person_id, standard_concept_id, date, datetime, visit_occurrence_id, domain, # concept_value, visit_rank_order, visit_segment, priority, date_in_week, @@ -187,12 +207,12 @@ def _decorate( class PatientEventAttDecorator(PatientEventDecorator): def __init__( - self, - visit_occurrence, - include_visit_type, - exclude_visit_tokens, - att_type: AttType, - include_inpatient_hour_token: bool = False + self, + visit_occurrence, + include_visit_type, + exclude_visit_tokens, + att_type: AttType, + include_inpatient_hour_token: bool = False, ): self._visit_occurrence = visit_occurrence self._include_visit_type = include_visit_type @@ -200,89 +220,97 @@ def __init__( self._att_type = att_type self._include_inpatient_hour_token = include_inpatient_hour_token - def _decorate( - self, - patient_events: DataFrame - ): + def _decorate(self, patient_events: DataFrame): if self._att_type == AttType.NONE: return patient_events # visits should the following columns (person_id, # visit_concept_id, visit_start_date, visit_occurrence_id, domain, concept_value) - cohort_member_person_pair = patient_events.select('person_id', 'cohort_member_id').distinct() - valid_visit_ids = patient_events \ - .groupby('cohort_member_id', 'visit_occurrence_id', 'visit_segment', 'visit_rank_order') \ - .agg(F.min('visit_concept_order').alias('min_visit_concept_order'), - F.max('visit_concept_order').alias('max_visit_concept_order'), - F.min('concept_order').alias('min_concept_order'), - F.max('concept_order').alias('max_concept_order')) - - visit_occurrence = self._visit_occurrence.select( - 'person_id', - F.col('visit_start_date').cast(T.DateType()).alias('date'), - F.col('visit_start_date').cast(T.DateType()).alias('visit_start_date'), - F.col('visit_start_datetime').cast(T.TimestampType()).alias('visit_start_datetime'), - F.coalesce('visit_end_date', 'visit_start_date').cast(T.DateType()).alias('visit_end_date'), - 'visit_concept_id', - 'visit_occurrence_id', - F.lit('visit').alias('domain'), - F.lit(0.0).alias('concept_value'), - F.lit(0).alias('concept_value_mask'), - F.lit(0).alias('mlm_skip_value'), - 'age', - 'discharged_to_concept_id' - 
).join(valid_visit_ids, 'visit_occurrence_id') \ - .join(cohort_member_person_pair, ['person_id', 'cohort_member_id']) + cohort_member_person_pair = patient_events.select("person_id", "cohort_member_id").distinct() + valid_visit_ids = patient_events.groupby( + "cohort_member_id", + "visit_occurrence_id", + "visit_segment", + "visit_rank_order", + ).agg( + F.min("visit_concept_order").alias("min_visit_concept_order"), + F.max("visit_concept_order").alias("max_visit_concept_order"), + F.min("concept_order").alias("min_concept_order"), + F.max("concept_order").alias("max_concept_order"), + ) + + visit_occurrence = ( + self._visit_occurrence.select( + "person_id", + F.col("visit_start_date").cast(T.DateType()).alias("date"), + F.col("visit_start_date").cast(T.DateType()).alias("visit_start_date"), + F.col("visit_start_datetime").cast(T.TimestampType()).alias("visit_start_datetime"), + F.coalesce("visit_end_date", "visit_start_date").cast(T.DateType()).alias("visit_end_date"), + "visit_concept_id", + "visit_occurrence_id", + F.lit("visit").alias("domain"), + F.lit(0.0).alias("concept_value"), + F.lit(0).alias("concept_value_mask"), + F.lit(0).alias("mlm_skip_value"), + "age", + "discharged_to_concept_id", + ) + .join(valid_visit_ids, "visit_occurrence_id") + .join(cohort_member_person_pair, ["person_id", "cohort_member_id"]) + ) # We assume outpatient visits end on the same day, therefore we start visit_end_date to visit_start_date due # to bad date - visit_occurrence = visit_occurrence \ - .withColumn('visit_end_date', - F.when(F.col('visit_concept_id').isin([9201, 262, 8971, 8920]), - F.col('visit_end_date')).otherwise(F.col('visit_start_date')) - ) + visit_occurrence = visit_occurrence.withColumn( + "visit_end_date", + F.when( + F.col("visit_concept_id").isin([9201, 262, 8971, 8920]), + F.col("visit_end_date"), + ).otherwise(F.col("visit_start_date")), + ) - weeks_since_epoch_udf = (F.unix_timestamp('date') / F.lit(24 * 60 * 60 * 7)).cast('int') - visit_occurrence = visit_occurrence \ - .withColumn('date_in_week', weeks_since_epoch_udf) + weeks_since_epoch_udf = (F.unix_timestamp("date") / F.lit(24 * 60 * 60 * 7)).cast("int") + visit_occurrence = visit_occurrence.withColumn("date_in_week", weeks_since_epoch_udf) # Cache visit for faster processing visit_occurrence.cache() - visits = visit_occurrence.drop('discharged_to_concept_id') + visits = visit_occurrence.drop("discharged_to_concept_id") # (cohort_member_id, person_id, standard_concept_id, date, visit_occurrence_id, domain, # concept_value, visit_rank_order, visit_segment, priority, date_in_week, # concept_value_mask, mlm_skip_value, visit_end_date) - visit_start_events = visits \ - .withColumn('date', F.col('visit_start_date')) \ - .withColumn('datetime', F.to_timestamp('visit_start_date')) \ - .withColumn('standard_concept_id', F.lit('VS')) \ - .withColumn('visit_concept_order', F.col('min_visit_concept_order')) \ - .withColumn('concept_order', F.col('min_concept_order') - 1) \ - .withColumn('priority', F.lit(-2)) \ - .drop('min_visit_concept_order', 'max_visit_concept_order') \ - .drop('min_concept_order', 'max_concept_order') - - visit_end_events = visits \ - .withColumn('date', F.col('visit_end_date')) \ - .withColumn('datetime', F.date_add(F.to_timestamp('visit_end_date'), 1)) \ - .withColumn('datetime', F.expr("datetime - INTERVAL 1 MINUTE")) \ - .withColumn('standard_concept_id', F.lit('VE')) \ - .withColumn('visit_concept_order', F.col('max_visit_concept_order')) \ - .withColumn('concept_order', F.col('max_concept_order') + 1) \ 
- .withColumn('priority', F.lit(200)) \ - .drop('min_visit_concept_order', 'max_visit_concept_order') \ - .drop('min_concept_order', 'max_concept_order') + visit_start_events = ( + visits.withColumn("date", F.col("visit_start_date")) + .withColumn("datetime", F.to_timestamp("visit_start_date")) + .withColumn("standard_concept_id", F.lit("VS")) + .withColumn("visit_concept_order", F.col("min_visit_concept_order")) + .withColumn("concept_order", F.col("min_concept_order") - 1) + .withColumn("priority", F.lit(-2)) + .drop("min_visit_concept_order", "max_visit_concept_order") + .drop("min_concept_order", "max_concept_order") + ) + + visit_end_events = ( + visits.withColumn("date", F.col("visit_end_date")) + .withColumn("datetime", F.date_add(F.to_timestamp("visit_end_date"), 1)) + .withColumn("datetime", F.expr("datetime - INTERVAL 1 MINUTE")) + .withColumn("standard_concept_id", F.lit("VE")) + .withColumn("visit_concept_order", F.col("max_visit_concept_order")) + .withColumn("concept_order", F.col("max_concept_order") + 1) + .withColumn("priority", F.lit(200)) + .drop("min_visit_concept_order", "max_visit_concept_order") + .drop("min_concept_order", "max_concept_order") + ) # Get the prev days_since_epoch - prev_visit_end_date_udf = F.lag('visit_end_date').over( - W.partitionBy('person_id', 'cohort_member_id').orderBy('visit_rank_order') + prev_visit_end_date_udf = F.lag("visit_end_date").over( + W.partitionBy("person_id", "cohort_member_id").orderBy("visit_rank_order") ) # Compute the time difference between the current record and the previous record - time_delta_udf = F.when(F.col('prev_visit_end_date').isNull(), 0).otherwise( - F.datediff('visit_start_date', 'prev_visit_end_date') + time_delta_udf = F.when(F.col("prev_visit_end_date").isNull(), 0).otherwise( + F.datediff("visit_start_date", "prev_visit_end_date") ) # Udf for calculating the time token @@ -299,20 +327,24 @@ def _decorate( time_token_udf = F.udf(att_func, T.StringType()) - att_tokens = visits \ - .withColumn('datetime', F.to_timestamp('date')) \ - .withColumn('prev_visit_end_date', prev_visit_end_date_udf) \ - .where(F.col('prev_visit_end_date').isNotNull()) \ - .withColumn('time_delta', time_delta_udf) \ - .withColumn('time_delta', F.when(F.col('time_delta') < 0, F.lit(0)).otherwise(F.col('time_delta'))) \ - .withColumn('standard_concept_id', time_token_udf('time_delta')) \ - .withColumn('priority', F.lit(-3)) \ - .withColumn('visit_rank_order', F.col('visit_rank_order')) \ - .withColumn('visit_concept_order', F.col('min_visit_concept_order')) \ - .withColumn('concept_order', F.lit(0)) \ - .drop('prev_visit_end_date', 'time_delta') \ - .drop('min_visit_concept_order', 'max_visit_concept_order') \ - .drop('min_concept_order', 'max_concept_order') + att_tokens = ( + visits.withColumn("datetime", F.to_timestamp("date")) + .withColumn("prev_visit_end_date", prev_visit_end_date_udf) + .where(F.col("prev_visit_end_date").isNotNull()) + .withColumn("time_delta", time_delta_udf) + .withColumn( + "time_delta", + F.when(F.col("time_delta") < 0, F.lit(0)).otherwise(F.col("time_delta")), + ) + .withColumn("standard_concept_id", time_token_udf("time_delta")) + .withColumn("priority", F.lit(-3)) + .withColumn("visit_rank_order", F.col("visit_rank_order")) + .withColumn("visit_concept_order", F.col("min_visit_concept_order")) + .withColumn("concept_order", F.lit(0)) + .drop("prev_visit_end_date", "time_delta") + .drop("min_visit_concept_order", "max_visit_concept_order") + .drop("min_concept_order", "max_concept_order") + ) if 
self._exclude_visit_tokens: artificial_tokens = att_tokens @@ -321,139 +353,147 @@ def _decorate( if self._include_visit_type: # insert visit type after the VS token - visit_type_tokens = visits \ - .withColumn('standard_concept_id', F.col('visit_concept_id')) \ - .withColumn('datetime', F.to_timestamp('date')) \ - .withColumn('visit_concept_order', F.col('min_visit_concept_order')) \ - .withColumn('concept_order', F.lit(0)) \ - .withColumn('priority', F.lit(-1)) \ - .drop('min_visit_concept_order', 'max_visit_concept_order') \ - .drop('min_concept_order', 'max_concept_order') + visit_type_tokens = ( + visits.withColumn("standard_concept_id", F.col("visit_concept_id")) + .withColumn("datetime", F.to_timestamp("date")) + .withColumn("visit_concept_order", F.col("min_visit_concept_order")) + .withColumn("concept_order", F.lit(0)) + .withColumn("priority", F.lit(-1)) + .drop("min_visit_concept_order", "max_visit_concept_order") + .drop("min_concept_order", "max_concept_order") + ) artificial_tokens = artificial_tokens.unionByName(visit_type_tokens) - artificial_tokens = artificial_tokens.drop('visit_end_date') + artificial_tokens = artificial_tokens.drop("visit_end_date") # Retrieving the events that are ONLY linked to inpatient visits - inpatient_visits = visit_occurrence.where(F.col('visit_concept_id').isin([9201, 262, 8971, 8920])).select( - 'visit_occurrence_id', - 'visit_end_date', - 'cohort_member_id' - ) - inpatient_events = patient_events.join( - inpatient_visits, ['visit_occurrence_id', 'cohort_member_id'] + inpatient_visits = visit_occurrence.where(F.col("visit_concept_id").isin([9201, 262, 8971, 8920])).select( + "visit_occurrence_id", "visit_end_date", "cohort_member_id" ) + inpatient_events = patient_events.join(inpatient_visits, ["visit_occurrence_id", "cohort_member_id"]) # Fill in the visit_end_date if null (because some visits are still ongoing at the time of data extraction) # Bound the event dates within visit_start_date and visit_end_date # Generate a span rank to indicate the position of the group of events # Update the priority for each span - inpatient_events = inpatient_events.withColumn( - 'visit_end_date', - F.coalesce('visit_end_date', F.max('date').over( - W.partitionBy('cohort_member_id', 'visit_occurrence_id'))) - ).withColumn( - 'date', - F.when(F.col('date') < F.col('visit_start_date'), F.col('visit_start_date')).otherwise( - F.when(F.col('date') > F.col('visit_end_date'), F.col('visit_end_date')).otherwise( - F.col('date') - ) + inpatient_events = ( + inpatient_events.withColumn( + "visit_end_date", + F.coalesce( + "visit_end_date", + F.max("date").over(W.partitionBy("cohort_member_id", "visit_occurrence_id")), + ), ) - ).withColumn( - 'priority', F.col('priority') + F.col('concept_order') * 0.1 - ).drop('visit_end_date') - - discharge_events = visit_occurrence \ - .where(F.col('visit_concept_id').isin([9201, 262, 8971, 8920])) \ - .withColumn('standard_concept_id', F.coalesce(F.col('discharged_to_concept_id'), F.lit(0))) \ - .withColumn('visit_concept_order', F.col('max_visit_concept_order')) \ - .withColumn('concept_order', F.col('max_concept_order') + 1) \ - .withColumn('date', F.col('visit_end_date')) \ - .withColumn('datetime', F.date_add(F.to_timestamp('visit_end_date'), 1)) \ - .withColumn('datetime', F.expr("datetime - INTERVAL 1 MINUTE")) \ - .withColumn('priority', F.lit(100)) \ - .drop('discharged_to_concept_id', 'visit_end_date') \ - .drop('min_visit_concept_order', 'max_visit_concept_order') \ - .drop('min_concept_order', 'max_concept_order') + 
.withColumn( + "date", + F.when(F.col("date") < F.col("visit_start_date"), F.col("visit_start_date")).otherwise( + F.when(F.col("date") > F.col("visit_end_date"), F.col("visit_end_date")).otherwise(F.col("date")) + ), + ) + .withColumn("priority", F.col("priority") + F.col("concept_order") * 0.1) + .drop("visit_end_date") + ) + + discharge_events = ( + visit_occurrence.where(F.col("visit_concept_id").isin([9201, 262, 8971, 8920])) + .withColumn( + "standard_concept_id", + F.coalesce(F.col("discharged_to_concept_id"), F.lit(0)), + ) + .withColumn("visit_concept_order", F.col("max_visit_concept_order")) + .withColumn("concept_order", F.col("max_concept_order") + 1) + .withColumn("date", F.col("visit_end_date")) + .withColumn("datetime", F.date_add(F.to_timestamp("visit_end_date"), 1)) + .withColumn("datetime", F.expr("datetime - INTERVAL 1 MINUTE")) + .withColumn("priority", F.lit(100)) + .drop("discharged_to_concept_id", "visit_end_date") + .drop("min_visit_concept_order", "max_visit_concept_order") + .drop("min_concept_order", "max_concept_order") + ) # Add discharge events to the inpatient visits inpatient_events = inpatient_events.unionByName(discharge_events) # Get the prev days_since_epoch - inpatient_prev_date_udf = F.lag('date').over( - W.partitionBy('cohort_member_id', 'visit_occurrence_id').orderBy('concept_order') + inpatient_prev_date_udf = F.lag("date").over( + W.partitionBy("cohort_member_id", "visit_occurrence_id").orderBy("concept_order") ) # Compute the time difference between the current record and the previous record - inpatient_time_delta_udf = F.when(F.col('prev_date').isNull(), 0).otherwise( - F.datediff('date', 'prev_date') - ) + inpatient_time_delta_udf = F.when(F.col("prev_date").isNull(), 0).otherwise(F.datediff("date", "prev_date")) if self._include_inpatient_hour_token: # Create ATT tokens within the inpatient visits - inpatient_prev_datetime_udf = F.lag('datetime').over( - W.partitionBy('cohort_member_id', 'visit_occurrence_id').orderBy('concept_order') + inpatient_prev_datetime_udf = F.lag("datetime").over( + W.partitionBy("cohort_member_id", "visit_occurrence_id").orderBy("concept_order") ) # Compute the time difference between the current record and the previous record - inpatient_hour_delta_udf = F.when(F.col('prev_datetime').isNull(), 0).otherwise( - F.floor((F.unix_timestamp('datetime') - F.unix_timestamp('prev_datetime')) / 3600) + inpatient_hour_delta_udf = F.when(F.col("prev_datetime").isNull(), 0).otherwise( + F.floor((F.unix_timestamp("datetime") - F.unix_timestamp("prev_datetime")) / 3600) ) inpatient_att_token = F.when( - F.col('hour_delta') < 24, - F.concat(F.lit('i-H'), F.col('hour_delta')) - ).otherwise( - F.concat(F.lit('i-'), time_token_udf('time_delta')) - ) + F.col("hour_delta") < 24, F.concat(F.lit("i-H"), F.col("hour_delta")) + ).otherwise(F.concat(F.lit("i-"), time_token_udf("time_delta"))) # Create ATT tokens within the inpatient visits - inpatient_att_events = inpatient_events \ - .withColumn('is_span_boundary', F.row_number().over( - W.partitionBy('cohort_member_id', 'visit_occurrence_id', 'concept_order').orderBy('priority'))) \ - .where(F.col('is_span_boundary') == 1) \ - .withColumn('prev_date', inpatient_prev_date_udf) \ - .withColumn('time_delta', inpatient_time_delta_udf) \ - .withColumn('prev_datetime', inpatient_prev_datetime_udf) \ - .withColumn('hour_delta', inpatient_hour_delta_udf) \ - .where(F.col('prev_date').isNotNull()) \ - .where(F.col('hour_delta') > 0) \ - .withColumn('standard_concept_id', inpatient_att_token) \ - 
.withColumn('visit_concept_order', F.col('visit_concept_order')) \ - .withColumn('priority', F.col('priority') - 0.01) \ - .withColumn('concept_value_mask', F.lit(0)) \ - .withColumn('concept_value', F.lit(0.0)) \ - .drop('prev_date', 'time_delta', 'is_span_boundary') \ - .drop('prev_datetime', 'hour_delta') + inpatient_att_events = ( + inpatient_events.withColumn( + "is_span_boundary", + F.row_number().over( + W.partitionBy("cohort_member_id", "visit_occurrence_id", "concept_order").orderBy("priority") + ), + ) + .where(F.col("is_span_boundary") == 1) + .withColumn("prev_date", inpatient_prev_date_udf) + .withColumn("time_delta", inpatient_time_delta_udf) + .withColumn("prev_datetime", inpatient_prev_datetime_udf) + .withColumn("hour_delta", inpatient_hour_delta_udf) + .where(F.col("prev_date").isNotNull()) + .where(F.col("hour_delta") > 0) + .withColumn("standard_concept_id", inpatient_att_token) + .withColumn("visit_concept_order", F.col("visit_concept_order")) + .withColumn("priority", F.col("priority") - 0.01) + .withColumn("concept_value_mask", F.lit(0)) + .withColumn("concept_value", F.lit(0.0)) + .drop("prev_date", "time_delta", "is_span_boundary") + .drop("prev_datetime", "hour_delta") + ) else: # Create ATT tokens within the inpatient visits - inpatient_att_events = inpatient_events \ - .withColumn('is_span_boundary', F.row_number().over( - W.partitionBy('cohort_member_id', 'visit_occurrence_id', 'concept_order').orderBy('priority'))) \ - .where(F.col('is_span_boundary') == 1) \ - .withColumn('prev_date', inpatient_prev_date_udf) \ - .withColumn('time_delta', inpatient_time_delta_udf) \ - .where(F.col('time_delta') != 0) \ - .where(F.col('prev_date').isNotNull()) \ - .withColumn('standard_concept_id', F.concat(F.lit('i-'), time_token_udf('time_delta'))) \ - .withColumn('visit_concept_order', F.col('visit_concept_order')) \ - .withColumn('priority', F.col('priority') - 0.01) \ - .withColumn('concept_value_mask', F.lit(0)) \ - .withColumn('concept_value', F.lit(0.0)) \ - .drop('prev_date', 'time_delta', 'is_span_boundary') + inpatient_att_events = ( + inpatient_events.withColumn( + "is_span_boundary", + F.row_number().over( + W.partitionBy("cohort_member_id", "visit_occurrence_id", "concept_order").orderBy("priority") + ), + ) + .where(F.col("is_span_boundary") == 1) + .withColumn("prev_date", inpatient_prev_date_udf) + .withColumn("time_delta", inpatient_time_delta_udf) + .where(F.col("time_delta") != 0) + .where(F.col("prev_date").isNotNull()) + .withColumn( + "standard_concept_id", + F.concat(F.lit("i-"), time_token_udf("time_delta")), + ) + .withColumn("visit_concept_order", F.col("visit_concept_order")) + .withColumn("priority", F.col("priority") - 0.01) + .withColumn("concept_value_mask", F.lit(0)) + .withColumn("concept_value", F.lit(0.0)) + .drop("prev_date", "time_delta", "is_span_boundary") + ) self.validate(inpatient_events) self.validate(inpatient_att_events) # Retrieving the events that are NOT linked to inpatient visits other_events = patient_events.join( - inpatient_visits.select('visit_occurrence_id', 'cohort_member_id'), - ['visit_occurrence_id', 'cohort_member_id'], - how='left_anti' + inpatient_visits.select("visit_occurrence_id", "cohort_member_id"), + ["visit_occurrence_id", "cohort_member_id"], + how="left_anti", ) - patient_events = inpatient_events.unionByName( - inpatient_att_events - ).unionByName( - other_events - ) + patient_events = inpatient_events.unionByName(inpatient_att_events).unionByName(other_events) self.validate(patient_events) 
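A minimal plain-Python sketch of the inpatient ATT selection above, assuming the week/month bucketing of `time_token_func` defined later in this file: gaps under 24 hours become an `i-H{hours}` token, while longer gaps fall back to an `i-`-prefixed coarse bucket (the sample gaps below are invented).

```python
import math
from typing import Optional


def week_month_bucket(time_delta_days: int) -> Optional[str]:
    # Mirrors time_token_func: weekly buckets under 28 days, monthly under 360, "LT" beyond.
    if time_delta_days is None:
        return None
    if time_delta_days < 0:
        return "W-1"
    if time_delta_days < 28:
        return f"W{math.floor(time_delta_days / 7)}"
    if time_delta_days < 360:
        return f"M{math.floor(time_delta_days / 30)}"
    return "LT"


def inpatient_att_token(hour_delta: int, time_delta_days: int) -> str:
    # Hour-level token for gaps under a day, otherwise an "i-" prefixed coarse bucket.
    if hour_delta < 24:
        return f"i-H{hour_delta}"
    return f"i-{week_month_bucket(time_delta_days)}"


# Invented gaps between consecutive inpatient events: (hours, whole days).
for hours, days in [(6, 0), (30, 1), (200, 8)]:
    print(inpatient_att_token(hours, days))
# Prints: i-H6, i-W0, i-W1
```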
self.validate(artificial_tokens) @@ -462,21 +502,12 @@ def _decorate( return patient_events.unionByName(artificial_tokens) -class DemographicPromptDecorator( - PatientEventDecorator -): - def __init__( - self, - patient_demographic, - use_age_group: bool = False - ): +class DemographicPromptDecorator(PatientEventDecorator): + def __init__(self, patient_demographic, use_age_group: bool = False): self._patient_demographic = patient_demographic self._use_age_group = use_age_group - def _decorate( - self, - patient_events: DataFrame - ): + def _decorate(self, patient_events: DataFrame): if self._patient_demographic is None: return patient_events @@ -487,86 +518,79 @@ def _decorate( # Get the first token of the patient history first_token_udf = F.row_number().over( - W.partitionBy('cohort_member_id', 'person_id').orderBy( - 'visit_start_datetime', - 'visit_occurrence_id', - 'priority', - 'standard_concept_id') + W.partitionBy("cohort_member_id", "person_id").orderBy( + "visit_start_datetime", + "visit_occurrence_id", + "priority", + "standard_concept_id", + ) ) # Identify the first token of each patient history - patient_first_token = patient_events \ - .withColumn('token_order', first_token_udf) \ - .withColumn('concept_value_mask', F.lit(0)) \ - .withColumn('concept_value', F.lit(0.0)) \ - .where('token_order = 1') \ - .drop('token_order') + patient_first_token = ( + patient_events.withColumn("token_order", first_token_udf) + .withColumn("concept_value_mask", F.lit(0)) + .withColumn("concept_value", F.lit(0.0)) + .where("token_order = 1") + .drop("token_order") + ) # Udf for identifying the earliest date associated with a visit_occurrence_id - sequence_start_year_token = patient_first_token \ - .withColumn('standard_concept_id', - F.concat(F.lit('year:'), F.year('date').cast(T.StringType()))) \ - .withColumn('priority', F.lit(-10)) \ - .withColumn('visit_segment', F.lit(0)) \ - .withColumn('date_in_week', F.lit(0)) \ - .withColumn('age', F.lit(-1)) \ - .withColumn('visit_rank_order', F.lit(0)) \ - .withColumn('visit_concept_order', F.lit(0)) \ - .withColumn('concept_order', F.lit(0)) + sequence_start_year_token = ( + patient_first_token.withColumn( + "standard_concept_id", + F.concat(F.lit("year:"), F.year("date").cast(T.StringType())), + ) + .withColumn("priority", F.lit(-10)) + .withColumn("visit_segment", F.lit(0)) + .withColumn("date_in_week", F.lit(0)) + .withColumn("age", F.lit(-1)) + .withColumn("visit_rank_order", F.lit(0)) + .withColumn("visit_concept_order", F.lit(0)) + .withColumn("concept_order", F.lit(0)) + ) sequence_start_year_token.cache() if self._use_age_group: calculate_age_group_at_first_visit_udf = F.ceil( - F.floor(F.months_between(F.col('date'), F.col('birth_datetime')) / F.lit(12) / 10) + F.floor(F.months_between(F.col("date"), F.col("birth_datetime")) / F.lit(12) / 10) ) age_at_first_visit_udf = F.concat( - F.lit('age:'), + F.lit("age:"), (calculate_age_group_at_first_visit_udf * 10).cast(T.StringType()), - F.lit('-'), - ((calculate_age_group_at_first_visit_udf + 1) * 10).cast(T.StringType()) + F.lit("-"), + ((calculate_age_group_at_first_visit_udf + 1) * 10).cast(T.StringType()), ) else: calculate_age_at_first_visit_udf = F.ceil( - F.months_between(F.col('date'), F.col('birth_datetime')) / F.lit(12) - ) - age_at_first_visit_udf = F.concat( - F.lit('age:'), - calculate_age_at_first_visit_udf.cast(T.StringType()) + F.months_between(F.col("date"), F.col("birth_datetime")) / F.lit(12) ) + age_at_first_visit_udf = F.concat(F.lit("age:"), 
calculate_age_at_first_visit_udf.cast(T.StringType())) + + sequence_age_token = ( + self._patient_demographic.select(F.col("person_id"), F.col("birth_datetime")) + .join(sequence_start_year_token, "person_id") + .withColumn("standard_concept_id", age_at_first_visit_udf) + .withColumn("priority", F.lit(-9)) + .drop("birth_datetime") + ) - sequence_age_token = self._patient_demographic.select( - F.col('person_id'), - F.col('birth_datetime') - ).join( - sequence_start_year_token, - 'person_id' - ).withColumn( - 'standard_concept_id', - age_at_first_visit_udf - ).withColumn('priority', F.lit(-9)).drop('birth_datetime') - - sequence_gender_token = self._patient_demographic.select( - F.col('person_id'), - F.col('gender_concept_id') - ).join( - sequence_start_year_token, - 'person_id' - ).withColumn( - 'standard_concept_id', - F.col('gender_concept_id').cast(T.StringType()) - ).withColumn('priority', F.lit(-8)).drop('gender_concept_id') - - sequence_race_token = self._patient_demographic.select( - F.col('person_id'), - F.col('race_concept_id') - ).join( - sequence_start_year_token, - 'person_id' - ).withColumn( - 'standard_concept_id', - F.col('race_concept_id').cast(T.StringType()) - ).withColumn('priority', F.lit(-7)).drop('race_concept_id') + sequence_gender_token = ( + self._patient_demographic.select(F.col("person_id"), F.col("gender_concept_id")) + .join(sequence_start_year_token, "person_id") + .withColumn("standard_concept_id", F.col("gender_concept_id").cast(T.StringType())) + .withColumn("priority", F.lit(-8)) + .drop("gender_concept_id") + ) + + sequence_race_token = ( + self._patient_demographic.select(F.col("person_id"), F.col("race_concept_id")) + .join(sequence_start_year_token, "person_id") + .withColumn("standard_concept_id", F.col("race_concept_id").cast(T.StringType())) + .withColumn("priority", F.lit(-7)) + .drop("race_concept_id") + ) patient_events = patient_events.unionByName(sequence_start_year_token) patient_events = patient_events.unionByName(sequence_age_token) @@ -577,31 +601,27 @@ def _decorate( class DeathEventDecorator(PatientEventDecorator): - def __init__( - self, - death, - att_type - ): + def __init__(self, death, att_type): self._death = death self._att_type = att_type - def _decorate( - self, - patient_events: DataFrame - ): + def _decorate(self, patient_events: DataFrame): if self._death is None: return patient_events - death_records = patient_events.join(self._death.select('person_id', 'death_date'), 'person_id') + death_records = patient_events.join(self._death.select("person_id", "death_date"), "person_id") - max_visit_occurrence_id = death_records.select( - F.max('visit_occurrence_id').alias('max_visit_occurrence_id') - ) + max_visit_occurrence_id = death_records.select(F.max("visit_occurrence_id").alias("max_visit_occurrence_id")) - last_ve_record = death_records.where(F.col('standard_concept_id') == 'VE').withColumn( - 'record_rank', - F.row_number().over(W.partitionBy('person_id', 'cohort_member_id').orderBy(F.desc('date'))) - ).where(F.col('record_rank') == 1).drop('record_rank') + last_ve_record = ( + death_records.where(F.col("standard_concept_id") == "VE") + .withColumn( + "record_rank", + F.row_number().over(W.partitionBy("person_id", "cohort_member_id").orderBy(F.desc("date"))), + ) + .where(F.col("record_rank") == 1) + .drop("record_rank") + ) last_ve_record.cache() last_ve_record.show() @@ -610,26 +630,22 @@ def _decorate( # 'visit_segment', 'priority', 'date_in_week', 'concept_value_mask', # 'mlm_skip_value', 'age', 'visit_concept_id']) - 
artificial_visit_id = ( - F.row_number().over(W.partitionBy(F.lit(0)).orderBy('person_id', 'cohort_member_id')) - + F.col('max_visit_occurrence_id') + artificial_visit_id = F.row_number().over( + W.partitionBy(F.lit(0)).orderBy("person_id", "cohort_member_id") + ) + F.col("max_visit_occurrence_id") + death_records = ( + last_ve_record.crossJoin(max_visit_occurrence_id) + .withColumn("visit_occurrence_id", artificial_visit_id) + .withColumn("standard_concept_id", F.lit("[DEATH]")) + .withColumn("domain", F.lit("death")) + .withColumn("visit_rank_order", F.lit(1) + F.col("visit_rank_order")) + .withColumn("priority", F.lit(20)) + .drop("max_visit_occurrence_id") ) - death_records = last_ve_record \ - .crossJoin(max_visit_occurrence_id) \ - .withColumn('visit_occurrence_id', artificial_visit_id) \ - .withColumn('standard_concept_id', F.lit('[DEATH]')) \ - .withColumn('domain', F.lit('death')) \ - .withColumn('visit_rank_order', F.lit(1) + F.col('visit_rank_order')) \ - .withColumn('priority', F.lit(20)) \ - .drop('max_visit_occurrence_id') - - vs_records = death_records \ - .withColumn('standard_concept_id', F.lit('VS')) \ - .withColumn('priority', F.lit(15)) - - ve_records = death_records \ - .withColumn('standard_concept_id', F.lit('VE')) \ - .withColumn('priority', F.lit(30)) + + vs_records = death_records.withColumn("standard_concept_id", F.lit("VS")).withColumn("priority", F.lit(15)) + + ve_records = death_records.withColumn("standard_concept_id", F.lit("VE")).withColumn("priority", F.lit(30)) # Udf for calculating the time token if self._att_type == AttType.DAY: @@ -646,16 +662,18 @@ def _decorate( time_token_udf = F.udf(att_func, T.StringType()) att_records = death_records.withColumn( - 'death_date', F.when(F.col('death_date') < F.col('date'), F.col('date')).otherwise(F.col('death_date')) + "death_date", + F.when(F.col("death_date") < F.col("date"), F.col("date")).otherwise(F.col("death_date")), + ) + att_records = ( + att_records.withColumn("time_delta", F.datediff("death_date", "date")) + .withColumn("standard_concept_id", time_token_udf("time_delta")) + .withColumn("priority", F.lit(10)) + .drop("time_delta") ) - att_records = att_records \ - .withColumn('time_delta', F.datediff('death_date', 'date')) \ - .withColumn('standard_concept_id', time_token_udf('time_delta')) \ - .withColumn('priority', F.lit(10)) \ - .drop('time_delta') new_tokens = att_records.unionByName(vs_records).unionByName(death_records).unionByName(ve_records) - new_tokens = new_tokens.drop('death_date') + new_tokens = new_tokens.drop("death_date") self.validate(new_tokens) return patient_events.unionByName(new_tokens) @@ -665,36 +683,36 @@ def time_token_func(time_delta) -> Optional[str]: if time_delta is None or np.isnan(time_delta): return None if time_delta < 0: - return 'W-1' + return "W-1" if time_delta < 28: - return f'W{str(math.floor(time_delta / 7))}' + return f"W{str(math.floor(time_delta / 7))}" if time_delta < 360: - return f'M{str(math.floor(time_delta / 30))}' - return 'LT' + return f"M{str(math.floor(time_delta / 30))}" + return "LT" def time_day_token(time_delta): if time_delta is None or np.isnan(time_delta): return None if time_delta < 1080: - return f'D{str(time_delta)}' - return 'LT' + return f"D{str(time_delta)}" + return "LT" def time_week_token(time_delta): if time_delta is None or np.isnan(time_delta): return None if time_delta < 1080: - return f'W{str(math.floor(time_delta / 7))}' - return 'LT' + return f"W{str(math.floor(time_delta / 7))}" + return "LT" def time_month_token(time_delta): 
    if time_delta is None or np.isnan(time_delta):
        return None
    if time_delta < 1080:
-        return f'M{str(math.floor(time_delta / 30))}'
-    return 'LT'
+        return f"M{str(math.floor(time_delta / 30))}"
+    return "LT"


 def time_mix_token(time_delta):
@@ -707,20 +725,20 @@ def time_mix_token(time_delta):
     if time_delta is None or np.isnan(time_delta):
         return None
     if time_delta <= 7:
-        return f'D{str(time_delta)}'
+        return f"D{str(time_delta)}"
     if time_delta <= 30:
         # e.g. 8 -> W2
-        return f'W{str(math.ceil(time_delta / 7))}'
+        return f"W{str(math.ceil(time_delta / 7))}"
     if time_delta <= 360:
         # e.g. 31 -> M2
-        return f'M{str(math.ceil(time_delta / 30))}'
+        return f"M{str(math.ceil(time_delta / 30))}"
     # if time_delta <= 720:
     #     # e.g. 361 -> Q5
     #     return f'Q{str(math.ceil(time_delta / 90))}'
     # if time_delta <= 1080:
     #     # e.g. 1081 -> Y2
     #     return f'Y{str(math.ceil(time_delta / 360))}'
-    return 'LT'
+    return "LT"


 def get_att_function(att_type: Union[AttType, str]):
diff --git a/src/cehrbert/spark_apps/generate_concept_similarity_table.py b/src/cehrbert/spark_apps/generate_concept_similarity_table.py
index dbe991e0..339d4880 100644
--- a/src/cehrbert/spark_apps/generate_concept_similarity_table.py
+++ b/src/cehrbert/spark_apps/generate_concept_similarity_table.py
@@ -1,25 +1,40 @@
-import os
+"""Extract patient event data from domain tables, compute information content and
+semantic similarity for concepts, and calculate concept similarity scores.
+
+Functions:
+    extract_data: Extract data from specified domain tables.
+    compute_information_content: Compute the information content for concepts based on frequency.
+    compute_information_content_similarity: Compute the similarity between concepts based on
+        information content.
+    compute_semantic_similarity: Compute the semantic similarity between concept pairs.
+    main: Orchestrate the extraction, processing, and saving of concept similarity data.
+"""
+
 import datetime
-from pyspark.sql import SparkSession, DataFrame
+import logging
+import os
+from typing import List
+
+from pyspark.sql import DataFrame, SparkSession
+from pyspark.sql import functions as F

-from ..utils.spark_utils import *
-from ..config.output_names import *
+from ..config.output_names import CONCEPT_SIMILARITY_PATH, QUALIFIED_CONCEPT_LIST_PATH
 from ..const.common import CONCEPT, CONCEPT_ANCESTOR
+from ..utils.spark_utils import join_domain_tables, preprocess_domain_table, validate_table_names


-def extract_data(
-        spark: SparkSession,
-        input_folder: str,
-        domain_table_list: List[str]
-):
+def extract_data(spark: SparkSession, input_folder: str, domain_table_list: List[str]):
     """
-    Extract all data points from the specified domains
+    Extract patient event data from the specified domain tables.

-    :param spark:
-    :param input_folder:
-    :param domain_table_list:
-    :param include_concept_list:
-    :return:
+    Args:
+        spark (SparkSession): The Spark session to use for processing.
+        input_folder (str): Path to the input folder containing domain tables.
+        domain_table_list (List[str]): List of domain table names to extract data from.
+
+    Returns:
+        DataFrame: A DataFrame containing extracted and processed patient event data.
""" domain_tables = [] for domain_table_name in domain_table_list: @@ -31,12 +46,9 @@ def extract_data( return patient_event -def compute_information_content( - patient_event: DataFrame, - concept_ancestor: DataFrame -): +def compute_information_content(patient_event: DataFrame, concept_ancestor: DataFrame): """ - Calculate the information content using the frequency of each concept and the graph + Calculate the information content using the frequency of each concept and the graph. :param patient_event: :param concept_ancestor: @@ -45,163 +57,193 @@ def compute_information_content( # Get the total count total_count = patient_event.distinct().count() # Count the frequency of each concept - concept_frequency = patient_event.distinct().groupBy( - 'standard_concept_id').count() + concept_frequency = patient_event.distinct().groupBy("standard_concept_id").count() # left join b/w descendent_concept_id and the standard_concept_id in the concept freq table - freq_df = concept_frequency.join( - concept_ancestor, F.col('descendant_concept_id') == F.col('standard_concept_id')) \ - .groupBy('ancestor_concept_id').sum("count") \ - .withColumnRenamed('ancestor_concept_id', 'concept_id') \ - .withColumnRenamed('sum(count)', 'count') + freq_df = ( + concept_frequency.join( + concept_ancestor, + F.col("descendant_concept_id") == F.col("standard_concept_id"), + ) + .groupBy("ancestor_concept_id") + .sum("count") + .withColumnRenamed("ancestor_concept_id", "concept_id") + .withColumnRenamed("sum(count)", "count") + ) # Calculate information content for each concept - information_content = freq_df.withColumn( - 'information_content', (-F.log(F.col('count') / total_count))) \ - .withColumn('probability', F.col('count') / total_count) + information_content = freq_df.withColumn("information_content", (-F.log(F.col("count") / total_count))).withColumn( + "probability", F.col("count") / total_count + ) return information_content def compute_information_content_similarity( - concept_pair: DataFrame, - information_content: DataFrame, - concept_ancestor: DataFrame + concept_pair: DataFrame, information_content: DataFrame, concept_ancestor: DataFrame ): + """ + Compute the similarity between concept pairs based on their information content. + + Args: + concept_pair (DataFrame): A DataFrame containing pairs of concepts. + information_content (DataFrame): A DataFrame with information content for concepts. + concept_ancestor (DataFrame): A DataFrame containing concept ancestor relationships. + + Returns: + DataFrame: A DataFrame containing various similarity measures for concept pairs. 
+ """ # Extract the pairs of concepts from the training data and join to the information content table - information_content_concept_pair = concept_pair.select("concept_id_1", "concept_id_2") \ - .join(information_content, F.col("concept_id_1") == F.col("concept_id"), "left_outer") \ - .select(F.col("concept_id_1"), - F.col("concept_id_2"), - F.col("information_content").alias("information_content_1")) \ - .join(information_content, F.col("concept_id_2") == F.col("concept_id"), "left_outer") \ - .select(F.col("concept_id_1"), - F.col("concept_id_2"), - F.col("information_content_1"), - F.col("information_content").alias("information_content_2")) + information_content_concept_pair = ( + concept_pair.select("concept_id_1", "concept_id_2") + .join( + information_content, + F.col("concept_id_1") == F.col("concept_id"), + "left_outer", + ) + .select( + F.col("concept_id_1"), + F.col("concept_id_2"), + F.col("information_content").alias("information_content_1"), + ) + .join( + information_content, + F.col("concept_id_2") == F.col("concept_id"), + "left_outer", + ) + .select( + F.col("concept_id_1"), + F.col("concept_id_2"), + F.col("information_content_1"), + F.col("information_content").alias("information_content_2"), + ) + ) # Join to get all the ancestors of concept_id_1 concept_id_1_ancestor = information_content_concept_pair.join( - concept_ancestor, F.col('concept_id_1') == F.col('descendant_concept_id')) \ - .select('concept_id_1', 'concept_id_2', 'ancestor_concept_id') + concept_ancestor, F.col("concept_id_1") == F.col("descendant_concept_id") + ).select("concept_id_1", "concept_id_2", "ancestor_concept_id") # Join to get all the ancestors of concept_id_2 concept_id_2_ancestor = concept_pair.join( - concept_ancestor, F.col('concept_id_2') == F.col('descendant_concept_id')) \ - .select('concept_id_1', 'concept_id_2', 'ancestor_concept_id') + concept_ancestor, F.col("concept_id_2") == F.col("descendant_concept_id") + ).select("concept_id_1", "concept_id_2", "ancestor_concept_id") # Compute the summed information content of all ancestors of concept_id_1 and concept_id_2 - union_sum = concept_id_1_ancestor.union(concept_id_2_ancestor).distinct() \ - .join(information_content, F.col("ancestor_concept_id") == F.col("concept_id")) \ - .groupBy('concept_id_1', 'concept_id_2').agg( - F.sum('information_content').alias('ancestor_union_ic')) + union_sum = ( + concept_id_1_ancestor.union(concept_id_2_ancestor) + .distinct() + .join(information_content, F.col("ancestor_concept_id") == F.col("concept_id")) + .groupBy("concept_id_1", "concept_id_2") + .agg(F.sum("information_content").alias("ancestor_union_ic")) + ) # Compute the summed information content of common ancestors of concept_id_1 and concept_id_2 - intersection_sum = concept_id_1_ancestor.intersect(concept_id_2_ancestor) \ - .join(information_content, F.col('ancestor_concept_id') == F.col('concept_id')) \ - .groupBy('concept_id_1', 'concept_id_2').agg( - F.sum('information_content').alias('ancestor_intersection_ic')) + intersection_sum = ( + concept_id_1_ancestor.intersect(concept_id_2_ancestor) + .join(information_content, F.col("ancestor_concept_id") == F.col("concept_id")) + .groupBy("concept_id_1", "concept_id_2") + .agg(F.sum("information_content").alias("ancestor_intersection_ic")) + ) # Compute the information content and probability of the most informative common ancestor (MICA) - mica_ancestor = concept_id_1_ancestor.intersect(concept_id_2_ancestor) \ - .join(information_content, F.col("ancestor_concept_id") == F.col("concept_id")) 
\ - .groupBy("concept_id_1", "concept_id_2").agg( - F.max('information_content').alias('mica_information_content'), - F.max('probability').alias('mica_probability') + mica_ancestor = ( + concept_id_1_ancestor.intersect(concept_id_2_ancestor) + .join(information_content, F.col("ancestor_concept_id") == F.col("concept_id")) + .groupBy("concept_id_1", "concept_id_2") + .agg( + F.max("information_content").alias("mica_information_content"), + F.max("probability").alias("mica_probability"), + ) ) # Join the MICA to pairs of concepts features = information_content_concept_pair.join( mica_ancestor, - (information_content_concept_pair['concept_id_1'] == mica_ancestor['concept_id_1']) - & (information_content_concept_pair['concept_id_2'] == mica_ancestor['concept_id_2']), - "left_outer") \ - .select([information_content_concept_pair[f] for f in - information_content_concept_pair.schema.fieldNames()] + [ - F.col("mica_information_content"), F.col("mica_probability")]) + (information_content_concept_pair["concept_id_1"] == mica_ancestor["concept_id_1"]) + & (information_content_concept_pair["concept_id_2"] == mica_ancestor["concept_id_2"]), + "left_outer", + ).select( + [information_content_concept_pair[f] for f in information_content_concept_pair.schema.fieldNames()] + + [F.col("mica_information_content"), F.col("mica_probability")] + ) # Compute the lin measure features = features.withColumn( "lin_measure", - 2 * F.col('mica_information_content') / ( - F.col('information_content_1') * F.col('information_content_2'))) + 2 * F.col("mica_information_content") / (F.col("information_content_1") * F.col("information_content_2")), + ) # Compute the jiang measure features = features.withColumn( - 'jiang_measure', - 1 - (F.col('information_content_1') + F.col('information_content_2') - 2 * F.col( - 'mica_information_content')) + "jiang_measure", + 1 - (F.col("information_content_1") + F.col("information_content_2") - 2 * F.col("mica_information_content")), ) # Compute the information coefficient features = features.withColumn( - 'information_coefficient', - F.col('lin_measure') * (1 - 1 / (1 + F.col('mica_information_content'))) + "information_coefficient", + F.col("lin_measure") * (1 - 1 / (1 + F.col("mica_information_content"))), ) # Compute the relevance_measure - features = features.withColumn( - 'relevance_measure', - F.col('lin_measure') * (1 - F.col('mica_probability')) - ) + features = features.withColumn("relevance_measure", F.col("lin_measure") * (1 - F.col("mica_probability"))) # Join to get the summed information content of the common ancestors of concept_id_1 and # concept_id_2 features = features.join( intersection_sum, (features["concept_id_1"] == intersection_sum["concept_id_1"]) - & (features["concept_id_2"] == intersection_sum["concept_id_2"]), "left_outer") \ - .select( - [features[f] for f in features.schema.fieldNames()] + [ - F.col("ancestor_intersection_ic")]) + & (features["concept_id_2"] == intersection_sum["concept_id_2"]), + "left_outer", + ).select([features[f] for f in features.schema.fieldNames()] + [F.col("ancestor_intersection_ic")]) # Join to get the summed information content of the common ancestors of concept_id_1 and # concept_id_2 features = features.join( union_sum, - (features['concept_id_1'] == union_sum['concept_id_1']) - & (features['concept_id_2'] == union_sum['concept_id_2']), 'left_outer') \ - .select([features[f] for f in features.schema.fieldNames()] + [ - F.col("ancestor_union_ic")]) + (features["concept_id_1"] == union_sum["concept_id_1"]) + & 
(features["concept_id_2"] == union_sum["concept_id_2"]), + "left_outer", + ).select([features[f] for f in features.schema.fieldNames()] + [F.col("ancestor_union_ic")]) # Compute the graph information content measure features = features.withColumn( - 'graph_ic_measure', - F.col('ancestor_intersection_ic') / F.col( - 'ancestor_union_ic') + "graph_ic_measure", + F.col("ancestor_intersection_ic") / F.col("ancestor_union_ic"), ) - return features.select([ - F.col('concept_id_1'), - F.col('concept_id_2'), - F.col("mica_information_content"), - F.col("lin_measure"), - F.col("jiang_measure"), - F.col("information_coefficient"), - F.col("relevance_measure"), - F.col("graph_ic_measure") - ]) - - -def compute_semantic_similarity( - spark, - patient_event, - concept, - concept_ancestor -): - required_concept = patient_event.distinct().select('standard_concept_id').join( - concept, - F.col('standard_concept_id') == F.col('concept_id') - ).select('standard_concept_id', 'domain_id') + return features.select( + [ + F.col("concept_id_1"), + F.col("concept_id_2"), + F.col("mica_information_content"), + F.col("lin_measure"), + F.col("jiang_measure"), + F.col("information_coefficient"), + F.col("relevance_measure"), + F.col("graph_ic_measure"), + ] + ) - concept_ancestor.createOrReplaceTempView('concept_ancestor') - required_concept.createOrReplaceTempView('required_concept') - concept_pair = spark.sql(''' +def compute_semantic_similarity(spark, patient_event, concept, concept_ancestor): + required_concept = ( + patient_event.distinct() + .select("standard_concept_id") + .join(concept, F.col("standard_concept_id") == F.col("concept_id")) + .select("standard_concept_id", "domain_id") + ) + + concept_ancestor.createOrReplaceTempView("concept_ancestor") + required_concept.createOrReplaceTempView("required_concept") + + concept_pair = spark.sql( + """ WITH concept_pair AS ( - SELECT + SELECT c1.standard_concept_id AS concept_id_1, c2.standard_concept_id AS concept_id_2, c1.domain_id - FROM required_concept AS c1 + FROM required_concept AS c1 JOIN required_concept AS c2 ON c1.domain_id = c2.domain_id WHERE c1.standard_concept_id <> c2.standard_concept_id @@ -218,185 +260,157 @@ def compute_semantic_similarity( JOIN concept_ancestor AS ca_2 ON cp.concept_id_2 = ca_2.descendant_concept_id WHERE ca_1.ancestor_concept_id = ca_2.ancestor_concept_id - ''') + """ + ) # Find the root concepts - root_concept = concept_ancestor \ - .groupBy('descendant_concept_id') \ - .count().where('count = 1') \ - .withColumnRenamed('descendant_concept_id', 'root_concept_id') + root_concept = ( + concept_ancestor.groupBy("descendant_concept_id") + .count() + .where("count = 1") + .withColumnRenamed("descendant_concept_id", "root_concept_id") + ) # Retrieve all ancestor descendant relationships for the root concepts - root_concept_relationship = root_concept.join( - concept_ancestor, - root_concept["root_concept_id"] == - concept_ancestor["ancestor_concept_id"]) \ - .select(concept_ancestor["ancestor_concept_id"], - concept_ancestor["descendant_concept_id"], - concept_ancestor["max_levels_of_separation"].alias("root_distance")) \ + root_concept_relationship = ( + root_concept.join( + concept_ancestor, + root_concept["root_concept_id"] == concept_ancestor["ancestor_concept_id"], + ) + .select( + concept_ancestor["ancestor_concept_id"], + concept_ancestor["descendant_concept_id"], + concept_ancestor["max_levels_of_separation"].alias("root_distance"), + ) .where("ancestor_concept_id <> descendant_concept_id") + ) # Join to get all root 
concepts and their corresponding root_distance concept_pair = concept_pair.join( root_concept_relationship, - F.col('common_ancestor_concept_id') == F.col('descendant_concept_id')) \ - .select('concept_id_1', 'concept_id_2', 'distance_1', 'distance_2', 'root_distance') + F.col("common_ancestor_concept_id") == F.col("descendant_concept_id"), + ).select("concept_id_1", "concept_id_2", "distance_1", "distance_2", "root_distance") # Compute the semantic similarity concept_pair_similarity = concept_pair.withColumn( - 'semantic_similarity', - 2 * F.col("root_distance") / ( - 2 * F.col("root_distance") + F.col("distance_1") + F.col("distance_2")) + "semantic_similarity", + 2 * F.col("root_distance") / (2 * F.col("root_distance") + F.col("distance_1") + F.col("distance_2")), ) # Find the maximum semantic similarity - concept_pair_similarity = concept_pair_similarity.groupBy("concept_id_1", "concept_id_2") \ - .agg(F.max("semantic_similarity").alias("semantic_similarity")) + concept_pair_similarity = concept_pair_similarity.groupBy("concept_id_1", "concept_id_2").agg( + F.max("semantic_similarity").alias("semantic_similarity") + ) return concept_pair_similarity def main( - input_folder: str, - output_folder: str, - domain_table_list: List[str], - date_filter: str, - include_concept_list: bool + input_folder: str, + output_folder: str, + domain_table_list: List[str], + date_filter: str, + include_concept_list: bool, ): """ - spark: SparkSession, - input_folder: str, - domain_table_list: List[str], - include_concept_list: bool - This function creates the information content table based on the given domain tables - - Keyword arguments: - domain_tables -- the array containing the OMOP domain tables except visit_occurrence - + Main function to generate the concept similarity table. + + Args: + input_folder (str): The path to the input folder containing raw data. + output_folder (str): The path to the output folder to store the results. + domain_table_list (List[str]): List of domain tables to process. + date_filter (str): Date filter to apply to the data. + include_concept_list (bool): Whether to include a filtered concept list. 
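The semantic-similarity expression above is plain arithmetic over graph distances; the small helper below reproduces it for one hypothetical concept pair whose closest shared ancestor sits three levels below a root concept (all distances are invented).

```python
def semantic_similarity(root_distance: int, distance_1: int, distance_2: int) -> float:
    # Same arithmetic as the withColumn("semantic_similarity", ...) expression above.
    return 2 * root_distance / (2 * root_distance + distance_1 + distance_2)


# Shared ancestor is 3 hops from the root; the two concepts sit 2 and 1 hops below it.
print(round(semantic_similarity(root_distance=3, distance_1=2, distance_2=1), 3))  # 0.667
```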
""" - spark = SparkSession.builder.appName('Generate the concept similarity table').getOrCreate() + spark = SparkSession.builder.appName("Generate the concept similarity table").getOrCreate() logger = logging.getLogger(__name__) logger.info( - f'input_folder: {input_folder}\n' - f'output_folder: {output_folder}\n' - f'domain_table_list: {domain_table_list}\n' - f'date_filter: {date_filter}\n' - f'include_concept_list: {include_concept_list}\n' + "input_folder: %s\noutput_folder: %s\ndomain_table_list: %s\ndate_filter: " "%s\ninclude_concept_list: %s", + input_folder, + output_folder, + domain_table_list, + date_filter, + include_concept_list, ) concept = preprocess_domain_table(spark, input_folder, CONCEPT) concept_ancestor = preprocess_domain_table(spark, input_folder, CONCEPT_ANCESTOR) # Extract all data points from specified domains - patient_event = extract_data( - spark, - input_folder, - domain_table_list - ) + patient_event = extract_data(spark, input_folder, domain_table_list) # Calculate information content using unfiltered the patient event dataframe - information_content = compute_information_content( - patient_event, - concept_ancestor - ) + information_content = compute_information_content(patient_event, concept_ancestor) # Filter out concepts that are not required in the required concept_list if include_concept_list and patient_event: # Filter out concepts - qualified_concepts = broadcast( - preprocess_domain_table( - spark, - input_folder, - QUALIFIED_CONCEPT_LIST_PATH - ) - ) + qualified_concepts = F.broadcast(preprocess_domain_table(spark, input_folder, QUALIFIED_CONCEPT_LIST_PATH)) - patient_event = patient_event.join( - qualified_concepts, - 'standard_concept_id' - ).select('standard_concept_id') + patient_event = patient_event.join(qualified_concepts, "standard_concept_id").select("standard_concept_id") - concept_pair_similarity = compute_semantic_similarity( - spark, - patient_event, - concept, - concept_ancestor - ) + concept_pair_similarity = compute_semantic_similarity(spark, patient_event, concept, concept_ancestor) # Compute the information content based similarity scores concept_pair_ic_similarity = compute_information_content_similarity( - concept_pair_similarity, - information_content, - concept_ancestor + concept_pair_similarity, information_content, concept_ancestor ) - concept_pair_similarity_columns = [ - concept_pair_similarity[f] for f in - concept_pair_similarity.schema.fieldNames() - ] + concept_pair_similarity_columns = [concept_pair_similarity[f] for f in concept_pair_similarity.schema.fieldNames()] concept_pair_ic_similarity_columns = [ - f for f in concept_pair_ic_similarity.schema.fieldNames() - if 'concept_id' not in f + f for f in concept_pair_ic_similarity.schema.fieldNames() if "concept_id" not in f ] # Join two dataframes to get the final result concept_pair_similarity = concept_pair_similarity.join( concept_pair_ic_similarity, - (concept_pair_similarity['concept_id_1'] == concept_pair_ic_similarity['concept_id_1']) & ( - concept_pair_similarity['concept_id_2'] == concept_pair_ic_similarity[ - 'concept_id_2']) + (concept_pair_similarity["concept_id_1"] == concept_pair_ic_similarity["concept_id_1"]) + & (concept_pair_similarity["concept_id_2"] == concept_pair_ic_similarity["concept_id_2"]), ).select(concept_pair_similarity_columns + concept_pair_ic_similarity_columns) - concept_pair_similarity.write.mode('overwrite').parquet( - os.path.join( - output_folder, - CONCEPT_SIMILARITY_PATH - ) - ) + 
concept_pair_similarity.write.mode("overwrite").parquet(os.path.join(output_folder, CONCEPT_SIMILARITY_PATH)) -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Arguments for generate Concept Similarity Table') +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Arguments for generate Concept Similarity Table") parser.add_argument( - '-i', - '--input_folder', - dest='input_folder', - action='store', - help='The path for your input_folder where the raw data is', - required=True + "-i", + "--input_folder", + dest="input_folder", + action="store", + help="The path for your input_folder where the raw data is", + required=True, ) parser.add_argument( - '-o', - '--output_folder', - dest='output_folder', - action='store', - help='The path for your output_folder', - required=True + "-o", + "--output_folder", + dest="output_folder", + action="store", + help="The path for your output_folder", + required=True, ) parser.add_argument( - '-tc', - '--domain_table_list', - dest='domain_table_list', - nargs='+', - action='store', - help='The list of domain tables you want to download', + "-tc", + "--domain_table_list", + dest="domain_table_list", + nargs="+", + action="store", + help="The list of domain tables you want to download", type=validate_table_names, - required=True + required=True, ) parser.add_argument( - '-d', - '--date_filter', - dest='date_filter', - type=lambda s: datetime.datetime.strptime(s, '%Y-%m-%d'), - action='store', + "-d", + "--date_filter", + dest="date_filter", + type=lambda s: datetime.datetime.strptime(s, "%Y-%m-%d"), + action="store", required=False, - default='2018-01-01' - ) - parser.add_argument( - '--include_concept_list', - dest='include_concept_list', - action='store_true' + default="2018-01-01", ) + parser.add_argument("--include_concept_list", dest="include_concept_list", action="store_true") ARGS = parser.parse_args() @@ -405,5 +419,5 @@ def main( ARGS.output_folder, ARGS.domain_table_list, ARGS.date_filter, - ARGS.include_concept_list + ARGS.include_concept_list, ) diff --git a/src/cehrbert/spark_apps/generate_hierarchical_bert_training_data.py b/src/cehrbert/spark_apps/generate_hierarchical_bert_training_data.py index ba7c89b9..82c875bd 100644 --- a/src/cehrbert/spark_apps/generate_hierarchical_bert_training_data.py +++ b/src/cehrbert/spark_apps/generate_hierarchical_bert_training_data.py @@ -1,36 +1,91 @@ +""" +This module generates hierarchical BERT training data based on domain tables from OMOP EHR data. + +It processes patient event data, joins multiple domain tables, filters concepts based on a +minimum number of patients, and creates hierarchical sequence data for BERT training. + +Key Functions: + - preprocess_domain_table: Preprocesses domain tables for data extraction. + - process_measurement: Handles special processing for measurement data. + - join_domain_tables: Joins multiple domain tables into a unified DataFrame. + - create_hierarchical_sequence_data: Generates hierarchical sequence data for training. + +Command-line Arguments: + - input_folder: Path to the directory containing input data. + - output_folder: Path to the directory where the output will be saved. + - domain_table_list: List of domain tables to process. + - date_filter: Optional filter for processing the data based on date. + - max_num_of_visits_per_person: Maximum number of visits per patient to include. + - min_observation_period: Minimum observation period in days for patients to be included. 
+ - include_concept_list: Whether to apply a filter to retain certain concepts. + - include_incomplete_visit: Whether to include incomplete visit records in the training data. +""" + import datetime +import logging import os from pyspark.sql import SparkSession +from pyspark.sql import functions as F -from ..config.output_names import * -from ..utils.spark_utils import * -from ..const.common import OBSERVATION_PERIOD, VISIT_OCCURRENCE, PERSON, MEASUREMENT, REQUIRED_MEASUREMENT -from ..utils.spark_utils import validate_table_names +from ..config.output_names import PARQUET_DATA_PATH, QUALIFIED_CONCEPT_LIST_PATH +from ..const.common import MEASUREMENT, OBSERVATION_PERIOD, PERSON, REQUIRED_MEASUREMENT, VISIT_OCCURRENCE +from ..utils.spark_utils import ( + create_hierarchical_sequence_data, + join_domain_tables, + preprocess_domain_table, + process_measurement, + validate_table_names, +) def main( + input_folder, + output_folder, + domain_table_list, + date_filter, + max_num_of_visits_per_person, + min_observation_period: int = 360, + include_concept_list: bool = True, + include_incomplete_visit: bool = True, +): + """ + Main function to generate hierarchical BERT training data from domain tables. + + Args: + input_folder (str): The path to the input folder containing raw data. + output_folder (str): The path to the output folder for storing the training data. + domain_table_list (list): A list of domain tables to process (e.g., condition_occurrence). + date_filter (str): Date filter for processing data, default is '2018-01-01'. + max_num_of_visits_per_person (int): The maximum number of visits to include per person. + min_observation_period (int, optional): Minimum observation period in days. Default is 360. + include_concept_list (bool, optional): Whether to filter by concept list. Default is True. + include_incomplete_visit (bool, optional): Whether to include incomplete visits. Default is + True. + + This function preprocesses domain tables, filters and processes measurement data, + and generates hierarchical sequence data for training BERT models on EHR records. 
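A minimal sketch of driving this module from Python instead of `spark-submit`, assuming the package is installed so that `cehrbert.spark_apps` is importable and using hypothetical folder paths; the keyword arguments mirror the `main` signature documented above, and the `date_filter` value matches what the CLI builds via `strptime`.

```python
import datetime

from cehrbert.spark_apps.generate_hierarchical_bert_training_data import main

# Hypothetical OMOP input folder and output folder; adjust to your environment.
main(
    input_folder="/data/omop",
    output_folder="/data/omop/cehr-bert-hierarchical",
    domain_table_list=["condition_occurrence", "procedure_occurrence", "drug_exposure"],
    date_filter=datetime.datetime(2018, 1, 1),
    max_num_of_visits_per_person=200,
    min_observation_period=360,
    include_concept_list=True,
    include_incomplete_visit=True,
)
```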
+ """ + spark = SparkSession.builder.appName("Generate Hierarchical Bert Training Data").getOrCreate() + + logger = logging.getLogger(__name__) + logger.info( + "input_folder: %s\n" + "output_folder: %s\n" + "domain_table_list: %s\n" + "date_filter: %s\n" + "max_num_of_visits_per_person: %s\n" + "min_observation_period: %s\n" + "include_concept_list: %s\n" + "include_incomplete_visit: %s", input_folder, output_folder, domain_table_list, date_filter, max_num_of_visits_per_person, - min_observation_period: int = 360, - include_concept_list: bool = True, - include_incomplete_visit: bool = True -): - spark = SparkSession.builder.appName('Generate Hierarchical Bert Training Data').getOrCreate() - - logger = logging.getLogger(__name__) - logger.info( - f'input_folder: {input_folder}\n' - f'output_folder: {output_folder}\n' - f'domain_table_list: {domain_table_list}\n' - f'date_filter: {date_filter}\n' - f'max_num_of_visits_per_person: {max_num_of_visits_per_person}\n' - f'min_observation_period: {min_observation_period}\n' - f'include_concept_list: {include_concept_list}\n' - f'include_incomplete_visit: {include_incomplete_visit}' + min_observation_period, + include_concept_list, + include_incomplete_visit, ) domain_tables = [] @@ -41,26 +96,28 @@ def main( domain_tables.append(preprocess_domain_table(spark, input_folder, domain_table_name)) observation_period = ( - preprocess_domain_table(spark, input_folder, OBSERVATION_PERIOD).withColumn( - 'observation_period_start_date', - F.col('observation_period_start_date').cast('date') - ).withColumn( - 'observation_period_end_date', - F.col('observation_period_end_date').cast('date') - ).withColumn( - 'period', - F.datediff('observation_period_end_date', 'observation_period_start_date') - ).where(F.col('period') >= min_observation_period).select('person_id') + preprocess_domain_table(spark, input_folder, OBSERVATION_PERIOD) + .withColumn( + "observation_period_start_date", + F.col("observation_period_start_date").cast("date"), + ) + .withColumn( + "observation_period_end_date", + F.col("observation_period_end_date").cast("date"), + ) + .withColumn( + "period", + F.datediff("observation_period_end_date", "observation_period_start_date"), + ) + .where(F.col("period") >= min_observation_period) + .select("person_id") ) visit_occurrence = preprocess_domain_table(spark, input_folder, VISIT_OCCURRENCE) person = preprocess_domain_table(spark, input_folder, PERSON) # Filter for the persons that have enough observation period - person = person.join( - observation_period, - 'person_id' - ).select([person[f] for f in person.schema.fieldNames()]) + person = person.join(observation_period, "person_id").select([person[f] for f in person.schema.fieldNames()]) # Union all domain table records patient_events = join_domain_tables(domain_tables) @@ -69,19 +126,10 @@ def main( if include_concept_list and patient_events: # Filter out concepts - qualified_concepts = broadcast( - preprocess_domain_table( - spark, - input_folder, - QUALIFIED_CONCEPT_LIST_PATH - ) - ) + qualified_concepts = F.broadcast(preprocess_domain_table(spark, input_folder, QUALIFIED_CONCEPT_LIST_PATH)) # The select is necessary to make sure the order of the columns is the same as the # original dataframe - patient_events = patient_events.join( - qualified_concepts, - 'standard_concept_id' - ).select(column_names) + patient_events = patient_events.join(qualified_concepts, "standard_concept_id").select(column_names) # Process the measurement table if exists if MEASUREMENT in domain_table_list: @@ -89,103 
+137,91 @@ def main( required_measurement = preprocess_domain_table(spark, input_folder, REQUIRED_MEASUREMENT) # The select is necessary to make sure the order of the columns is the same as the # original dataframe, otherwise the union might use the wrong columns - scaled_measurement = process_measurement( - spark, - measurement, - required_measurement - ).select(column_names) + scaled_measurement = process_measurement(spark, measurement, required_measurement).select(column_names) if patient_events: # Union all measurement records together with other domain records - patient_events = patient_events.union( - scaled_measurement - ) + patient_events = patient_events.union(scaled_measurement) else: patient_events = scaled_measurement # cohort_member_id is the same as the person_id - patient_events = patient_events.withColumn('cohort_member_id', F.col('person_id')) + patient_events = patient_events.withColumn("cohort_member_id", F.col("person_id")) sequence_data = create_hierarchical_sequence_data( - person, visit_occurrence, patient_events, + person, + visit_occurrence, + patient_events, date_filter=date_filter, max_num_of_visits_per_person=max_num_of_visits_per_person, - include_incomplete_visit=include_incomplete_visit + include_incomplete_visit=include_incomplete_visit, ) - sequence_data.write.mode('overwrite').parquet( - os.path.join( - output_folder, - PARQUET_DATA_PATH - ) - ) + sequence_data.write.mode("overwrite").parquet(os.path.join(output_folder, PARQUET_DATA_PATH)) -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Arguments for generate training ' - 'data for Hierarchical Bert') +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Arguments for generate training data for Hierarchical Bert") parser.add_argument( - '-i', - '--input_folder', - dest='input_folder', - action='store', - help='The path for your input_folder where the raw data is', - required=True + "-i", + "--input_folder", + dest="input_folder", + action="store", + help="The path for your input_folder where the raw data is", + required=True, ) parser.add_argument( - '-o', - '--output_folder', - dest='output_folder', - action='store', - help='The path for your output_folder', - required=True + "-o", + "--output_folder", + dest="output_folder", + action="store", + help="The path for your output_folder", + required=True, ) parser.add_argument( - '-tc', - '--domain_table_list', - dest='domain_table_list', - nargs='+', - action='store', - help='The list of domain tables you want to download', + "-tc", + "--domain_table_list", + dest="domain_table_list", + nargs="+", + action="store", + help="The list of domain tables you want to download", type=validate_table_names, - required=True + required=True, ) parser.add_argument( - '-d', - '--date_filter', - dest='date_filter', - type=lambda s: datetime.datetime.strptime(s, '%Y-%m-%d'), - action='store', + "-d", + "--date_filter", + dest="date_filter", + type=lambda s: datetime.datetime.strptime(s, "%Y-%m-%d"), + action="store", required=False, - default='2018-01-01' + default="2018-01-01", ) parser.add_argument( - '--max_num_of_visits', - dest='max_num_of_visits', - action='store', + "--max_num_of_visits", + dest="max_num_of_visits", + action="store", type=int, default=200, - help='Max no.of visits per patient to be included', - required=False + help="Max no.of visits per patient to be included", + required=False, ) parser.add_argument( - '--min_observation_period', - dest='min_observation_period', - action='store', 
+ "--min_observation_period", + dest="min_observation_period", + action="store", type=int, default=1, - help='Minimum observation period in days', - required=False - ) - parser.add_argument( - '--include_concept_list', - dest='include_concept_list', - action='store_true' + help="Minimum observation period in days", + required=False, ) + parser.add_argument("--include_concept_list", dest="include_concept_list", action="store_true") parser.add_argument( - '--include_incomplete_visit', - dest='include_incomplete_visit', - action='store_true' + "--include_incomplete_visit", + dest="include_incomplete_visit", + action="store_true", ) ARGS = parser.parse_args() @@ -198,5 +234,5 @@ def main( max_num_of_visits_per_person=ARGS.max_num_of_visits, min_observation_period=ARGS.min_observation_period, include_concept_list=ARGS.include_concept_list, - include_incomplete_visit=ARGS.include_incomplete_visit + include_incomplete_visit=ARGS.include_incomplete_visit, ) diff --git a/src/cehrbert/spark_apps/generate_included_concept_list.py b/src/cehrbert/spark_apps/generate_included_concept_list.py index 295ec646..ac81dedd 100644 --- a/src/cehrbert/spark_apps/generate_included_concept_list.py +++ b/src/cehrbert/spark_apps/generate_included_concept_list.py @@ -1,22 +1,51 @@ +""" +This module generates a qualified concept list by processing patient event data across various. + +domain tables (e.g., condition_occurrence, procedure_occurrence, drug_exposure) and applying a +patient frequency filter to retain concepts linked to a minimum number of patients. + +Key Functions: + - preprocess_domain_table: Preprocesses domain tables to prepare for event extraction. + - join_domain_tables: Joins multiple domain tables into a unified DataFrame. + - main: Coordinates the entire process of reading domain tables, applying frequency filters, + and saving the qualified concept list. + +Command-line Arguments: + - input_folder: Directory containing the input data. + - output_folder: Directory where the qualified concept list will be saved. + - min_num_of_patients: Minimum number of patients linked to a concept for it to be included. + - with_drug_rollup: Boolean flag indicating whether drug concept rollups should be applied. +""" + import os from pyspark.sql import SparkSession -from pyspark.sql.functions import countDistinct +from pyspark.sql import functions as F -from ..utils.spark_utils import * -from ..const.common import MEASUREMENT from ..config.output_names import QUALIFIED_CONCEPT_LIST_PATH +from ..const.common import MEASUREMENT +from ..utils.spark_utils import join_domain_tables, preprocess_domain_table + +DOMAIN_TABLE_LIST = ["condition_occurrence", "procedure_occurrence", "drug_exposure"] -DOMAIN_TABLE_LIST = ['condition_occurrence', 'procedure_occurrence', 'drug_exposure'] +def main(input_folder, output_folder, min_num_of_patients, with_drug_rollup: bool = True): + """ + Main function to generate a qualified concept list based on patient event data from multiple. -def main( - input_folder, - output_folder, - min_num_of_patients, - with_drug_rollup: bool = True -): - spark = SparkSession.builder.appName('Generate concept list').getOrCreate() + domain tables. + + Args: + input_folder (str): The directory where the input data is stored. + output_folder (str): The directory where the output (qualified concept list) will be saved. + min_num_of_patients (int): Minimum number of patients that a concept must be linked to for + nclusion. 
+ with_drug_rollup (bool): If True, applies drug rollup logic to the drug_exposure domain. + + The function processes patient event data across various domain tables, excludes low-frequency + concepts, and saves the filtered concepts to a specified output folder. + """ + spark = SparkSession.builder.appName("Generate concept list").getOrCreate() domain_tables = [] # Exclude measurement from domain_table_list if exists because we need to process measurement @@ -24,59 +53,58 @@ def main( for domain_table_name in DOMAIN_TABLE_LIST: if domain_table_name != MEASUREMENT: domain_tables.append( - preprocess_domain_table(spark, input_folder, domain_table_name, with_drug_rollup=with_drug_rollup) + preprocess_domain_table( + spark, + input_folder, + domain_table_name, + with_drug_rollup=with_drug_rollup, + ) ) # Union all domain table records patient_events = join_domain_tables(domain_tables) # Filter out concepts that are linked to less than 100 patients - qualified_concepts = patient_events.where('visit_occurrence_id IS NOT NULL') \ - .groupBy('standard_concept_id') \ - .agg(countDistinct('person_id').alias('freq')) \ - .where(F.col('freq') >= min_num_of_patients) - - qualified_concepts.write.mode('overwrite').parquet( - os.path.join( - output_folder, - QUALIFIED_CONCEPT_LIST_PATH - ) + qualified_concepts = ( + patient_events.where("visit_occurrence_id IS NOT NULL") + .groupBy("standard_concept_id") + .agg(F.countDistinct("person_id").alias("freq")) + .where(F.col("freq") >= min_num_of_patients) ) + qualified_concepts.write.mode("overwrite").parquet(os.path.join(output_folder, QUALIFIED_CONCEPT_LIST_PATH)) + -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Arguments for generate ' - 'concept list to be included') +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Arguments for generate concept list to be included") parser.add_argument( - '-i', - '--input_folder', - dest='input_folder', - action='store', - help='The path for your input_folder where the raw data is', - required=True + "-i", + "--input_folder", + dest="input_folder", + action="store", + help="The path for your input_folder where the raw data is", + required=True, ) parser.add_argument( - '-o', - '--output_folder', - dest='output_folder', - action='store', - help='The path for your output_folder', - required=True + "-o", + "--output_folder", + dest="output_folder", + action="store", + help="The path for your output_folder", + required=True, ) parser.add_argument( - '--min_num_of_patients', - dest='min_num_of_patients', - action='store', + "--min_num_of_patients", + dest="min_num_of_patients", + action="store", type=int, default=0, - help='Min no.of patients linked to concepts to be included', - required=False - ) - parser.add_argument( - '--with_drug_rollup', - dest='with_drug_rollup', - action='store_true' + help="Min no.of patients linked to concepts to be included", + required=False, ) + parser.add_argument("--with_drug_rollup", dest="with_drug_rollup", action="store_true") ARGS = parser.parse_args() @@ -84,5 +112,5 @@ def main( ARGS.input_folder, ARGS.output_folder, ARGS.min_num_of_patients, - ARGS.with_drug_rollup + ARGS.with_drug_rollup, ) diff --git a/src/cehrbert/spark_apps/generate_information_content.py b/src/cehrbert/spark_apps/generate_information_content.py index 801cdc02..eadbdd73 100644 --- a/src/cehrbert/spark_apps/generate_information_content.py +++ b/src/cehrbert/spark_apps/generate_information_content.py @@ -1,18 +1,36 @@ -import os +""" +This 
module generates an information content table based on a list of domain tables from OMOP data. + +It processes patient event data, calculates the frequency of each concept, and computes information +conten using the concept ancestor hierarchy. The results are written to a specified output path. + +Key Functions: + - preprocess_domain_table: Preprocess the domain tables for analysis. + - join_domain_tables: Join multiple domain tables to generate a unified patient event table. + - main: Orchestrates the process of reading input data, calculating concept frequencies, + and generating the information content table. + +Command-line Arguments: + - input_folder: The folder containing the raw OMOP domain data. + - output_folder: The folder where the results will be stored. + - domain_table_list: A list of OMOP domain tables to include in the analysis. + - date_filter: Optional date filter for processing the data. +""" + import datetime +import logging +import os + from pyspark.sql import SparkSession -from ..utils.spark_utils import * -from ..config.output_names import * +from pyspark.sql import functions as F + +from ..config.output_names import INFORMATION_CONTENT_DATA_PATH from ..const.common import CONCEPT_ANCESTOR +from ..utils.spark_utils import join_domain_tables, preprocess_domain_table, validate_table_names -def main( - input_folder, - output_folder, - domain_table_list, - date_filter -): - """Create the information content table +def main(input_folder, output_folder, domain_table_list, date_filter): + """Create the information content table. Keyword arguments: domain_tables -- the array containing the OMOP domain tables except visit_occurrence @@ -21,15 +39,17 @@ def main( This function creates the information content table based on the given domain tables """ - spark = SparkSession.builder.appName('Generate the information content table').getOrCreate() + spark = SparkSession.builder.appName("Generate the information content table").getOrCreate() logger = logging.getLogger(__name__) logger.info( - f'input_folder: {input_folder}\n' - f'output_folder: {output_folder}\n' - f'domain_table_list: {domain_table_list}\n' - f'date_filter: {date_filter}\n' + "input_folder: %s\noutput_folder: %s\ndomain_table_list: %s\ndate_filter: %s", + input_folder, + output_folder, + domain_table_list, + date_filter, ) + concept_ancestor = preprocess_domain_table(spark, input_folder, CONCEPT_ANCESTOR) domain_tables = [] for domain_table_name in domain_table_list: @@ -44,69 +64,68 @@ def main( total_count = patient_events.distinct().count() # Count the frequency of each concept - concept_frequency = patient_events.distinct().groupBy( - 'standard_concept_id').count() + concept_frequency = patient_events.distinct().groupBy("standard_concept_id").count() # left join b/w descendent_concept_id and the standard_concept_id in the concept freq table - freq_df = concept_frequency.join( - concept_ancestor, F.col('descendant_concept_id') == F.col('standard_concept_id')) \ - .groupBy('ancestor_concept_id').sum("count") \ - .withColumnRenamed('ancestor_concept_id', 'concept_id') \ - .withColumnRenamed('sum(count)', 'count') + freq_df = ( + concept_frequency.join( + concept_ancestor, + F.col("descendant_concept_id") == F.col("standard_concept_id"), + ) + .groupBy("ancestor_concept_id") + .sum("count") + .withColumnRenamed("ancestor_concept_id", "concept_id") + .withColumnRenamed("sum(count)", "count") + ) # Calculate information content for each concept - information_content = freq_df.withColumn( - 'information_content', 
(-F.log(F.col('count') / total_count))) \ - .withColumn('probability', F.col('count') / total_count) - - information_content.write.mode('overwrite').parquet( - os.path.join(output_folder, INFORMATION_CONTENT_DATA_PATH) + information_content = freq_df.withColumn("information_content", (-F.log(F.col("count") / total_count))).withColumn( + "probability", F.col("count") / total_count ) + information_content.write.mode("overwrite").parquet(os.path.join(output_folder, INFORMATION_CONTENT_DATA_PATH)) + -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Arguments for generate training data for Bert') +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Arguments for generate training data for Bert") parser.add_argument( - '-i', - '--input_folder', - dest='input_folder', - action='store', - help='The path for your input_folder where the raw data is', - required=True + "-i", + "--input_folder", + dest="input_folder", + action="store", + help="The path for your input_folder where the raw data is", + required=True, ) parser.add_argument( - '-o', - '--output_folder', - dest='output_folder', - action='store', - help='The path for your output_folder', - required=True + "-o", + "--output_folder", + dest="output_folder", + action="store", + help="The path for your output_folder", + required=True, ) parser.add_argument( - '-tc', - '--domain_table_list', - dest='domain_table_list', - nargs='+', - action='store', - help='The list of domain tables you want to download', + "-tc", + "--domain_table_list", + dest="domain_table_list", + nargs="+", + action="store", + help="The list of domain tables you want to download", type=validate_table_names, - required=True + required=True, ) parser.add_argument( - '-d', - '--date_filter', - dest='date_filter', - type=lambda s: datetime.datetime.strptime(s, '%Y-%m-%d'), - action='store', + "-d", + "--date_filter", + dest="date_filter", + type=lambda s: datetime.datetime.strptime(s, "%Y-%m-%d"), + action="store", required=False, - default='2018-01-01' + default="2018-01-01", ) ARGS = parser.parse_args() - main( - ARGS.input_folder, - ARGS.output_folder, - ARGS.domain_table_list, - ARGS.date_filter - ) + main(ARGS.input_folder, ARGS.output_folder, ARGS.domain_table_list, ARGS.date_filter) diff --git a/src/cehrbert/spark_apps/generate_required_labs.py b/src/cehrbert/spark_apps/generate_required_labs.py index 61441d43..a096e70a 100644 --- a/src/cehrbert/spark_apps/generate_required_labs.py +++ b/src/cehrbert/spark_apps/generate_required_labs.py @@ -2,32 +2,28 @@ from pyspark.sql import SparkSession -from ..utils.spark_utils import * -from ..const.common import MEASUREMENT, REQUIRED_MEASUREMENT, CONCEPT +from ..const.common import CONCEPT, MEASUREMENT, REQUIRED_MEASUREMENT +from ..utils.spark_utils import F, W, argparse, preprocess_domain_table -def main( - input_folder, - output_folder, - num_of_numeric_labs, - num_of_categorical_labs -): - spark = SparkSession.builder.appName('Generate required labs').getOrCreate() +def main(input_folder, output_folder, num_of_numeric_labs, num_of_categorical_labs): + spark = SparkSession.builder.appName("Generate required labs").getOrCreate() # Load measurement as a dataframe in pyspark measurement = preprocess_domain_table(spark, input_folder, MEASUREMENT) concept = preprocess_domain_table(spark, input_folder, CONCEPT) # Create the local measurement view - measurement.createOrReplaceTempView('measurement') + measurement.createOrReplaceTempView("measurement") # Create the 
local concept view - concept.createOrReplaceTempView('concept') + concept.createOrReplaceTempView("concept") - popular_labs = spark.sql(''' + popular_labs = spark.sql( + """ SELECT - m.measurement_concept_id, - c.concept_name, + m.measurement_concept_id, + c.concept_name, COUNT(*) AS freq, SUM(CASE WHEN m.value_as_number IS NOT NULL THEN 1 ELSE 0 END) / COUNT(*) AS numeric_percentage, SUM(CASE WHEN m.value_as_concept_id IS NOT NULL AND m.value_as_concept_id <> 0 THEN 1 ELSE 0 END) / COUNT(*) AS categorical_percentage @@ -37,70 +33,70 @@ def main( WHERE m.measurement_concept_id <> 0 GROUP BY m.measurement_concept_id, c.concept_name ORDER BY COUNT(*) DESC - ''') + """ + ) # Cache the dataframe for faster computation in the below transformations popular_labs.cache() - popular_numeric_labs = popular_labs \ - .withColumn('is_numeric', F.col('numeric_percentage') >= 0.5) \ - .where('is_numeric') \ - .withColumn('rn', F.row_number().over(W.orderBy(F.desc('freq')))) \ - .where(F.col('rn') <= num_of_numeric_labs) \ - .drop('rn') - - popular_categorical_labs = popular_labs \ - .withColumn('is_categorical', F.col('categorical_percentage') >= 0.5) \ - .where('is_categorical') \ - .withColumn('is_numeric', ~F.col('is_categorical')) \ - .withColumn('rn', F.row_number().over(W.orderBy(F.desc('freq')))) \ - .where(F.col('rn') <= num_of_categorical_labs) \ - .drop('is_categorical').drop('rn') - - popular_numeric_labs.unionAll(popular_categorical_labs).write.mode('overwrite').parquet( - os.path.join( - output_folder, - REQUIRED_MEASUREMENT - ) + popular_numeric_labs = ( + popular_labs.withColumn("is_numeric", F.col("numeric_percentage") >= 0.5) + .where("is_numeric") + .withColumn("rn", F.row_number().over(W.orderBy(F.desc("freq")))) + .where(F.col("rn") <= num_of_numeric_labs) + .drop("rn") + ) + + popular_categorical_labs = ( + popular_labs.withColumn("is_categorical", F.col("categorical_percentage") >= 0.5) + .where("is_categorical") + .withColumn("is_numeric", ~F.col("is_categorical")) + .withColumn("rn", F.row_number().over(W.orderBy(F.desc("freq")))) + .where(F.col("rn") <= num_of_categorical_labs) + .drop("is_categorical") + .drop("rn") + ) + + popular_numeric_labs.unionAll(popular_categorical_labs).write.mode("overwrite").parquet( + os.path.join(output_folder, REQUIRED_MEASUREMENT) ) -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Arguments for generate ' - 'required labs to be included') +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Arguments for generate " "required labs to be included") parser.add_argument( - '-i', - '--input_folder', - dest='input_folder', - action='store', - help='The path for your input_folder where the raw data is', - required=True + "-i", + "--input_folder", + dest="input_folder", + action="store", + help="The path for your input_folder where the raw data is", + required=True, ) parser.add_argument( - '-o', - '--output_folder', - dest='output_folder', - action='store', - help='The path for your output_folder', - required=True + "-o", + "--output_folder", + dest="output_folder", + action="store", + help="The path for your output_folder", + required=True, ) parser.add_argument( - '--num_of_numeric_labs', - dest='num_of_numeric_labs', - action='store', + "--num_of_numeric_labs", + dest="num_of_numeric_labs", + action="store", type=int, default=100, - help='The top most popular numeric labs to be included', - required=False + help="The top most popular numeric labs to be included", + required=False, ) parser.add_argument( 
- '--num_of_categorical_labs', - dest='num_of_categorical_labs', - action='store', + "--num_of_categorical_labs", + dest="num_of_categorical_labs", + action="store", type=int, default=100, - help='The top most popular categorical labs to be included', - required=False + help="The top most popular categorical labs to be included", + required=False, ) ARGS = parser.parse_args() @@ -109,5 +105,5 @@ def main( ARGS.input_folder, ARGS.output_folder, ARGS.num_of_numeric_labs, - ARGS.num_of_categorical_labs + ARGS.num_of_categorical_labs, ) diff --git a/src/cehrbert/spark_apps/generate_training_data.py b/src/cehrbert/spark_apps/generate_training_data.py index 99854f10..5f2d16ba 100644 --- a/src/cehrbert/spark_apps/generate_training_data.py +++ b/src/cehrbert/spark_apps/generate_training_data.py @@ -5,90 +5,108 @@ from pyspark.sql import SparkSession from pyspark.sql.window import Window -from ..utils.spark_utils import * from ..spark_apps.decorators.patient_event_decorator import AttType +from ..utils.spark_utils import ( + MEASUREMENT, + REQUIRED_MEASUREMENT, + F, + W, + argparse, + create_sequence_data, + create_sequence_data_with_att, + join_domain_tables, + logging, + preprocess_domain_table, + process_measurement, + validate_table_names, +) -VISIT_OCCURRENCE = 'visit_occurrence' -PERSON = 'person' -DEATH = 'death' +VISIT_OCCURRENCE = "visit_occurrence" +PERSON = "person" +DEATH = "death" def main( - input_folder, - output_folder, - domain_table_list, - date_filter, - include_visit_type, - is_new_patient_representation, - exclude_visit_tokens, - is_classic_bert, - include_prolonged_stay, - include_concept_list: bool, - gpt_patient_sequence: bool, - apply_age_filter: bool, - include_death: bool, - att_type: AttType, - include_sequence_information_content: bool = False, - exclude_demographic: bool = False, - use_age_group: bool = False, - with_drug_rollup: bool = True, - include_inpatient_hour_token: bool = False, - continue_from_events: bool = False + input_folder, + output_folder, + domain_table_list, + date_filter, + include_visit_type, + is_new_patient_representation, + exclude_visit_tokens, + is_classic_bert, + include_prolonged_stay, + include_concept_list: bool, + gpt_patient_sequence: bool, + apply_age_filter: bool, + include_death: bool, + att_type: AttType, + include_sequence_information_content: bool = False, + exclude_demographic: bool = False, + use_age_group: bool = False, + with_drug_rollup: bool = True, + include_inpatient_hour_token: bool = False, + continue_from_events: bool = False, ): - spark = SparkSession.builder.appName('Generate CEHR-BERT Training Data').getOrCreate() + spark = SparkSession.builder.appName("Generate CEHR-BERT Training Data").getOrCreate() logger = logging.getLogger(__name__) logger.info( - f'input_folder: {input_folder}\n' - f'output_folder: {output_folder}\n' - f'domain_table_list: {domain_table_list}\n' - f'date_filter: {date_filter}\n' - f'include_visit_type: {include_visit_type}\n' - f'is_new_patient_representation: {is_new_patient_representation}\n' - f'exclude_visit_tokens: {exclude_visit_tokens}\n' - f'is_classic_bert: {is_classic_bert}\n' - f'include_prolonged_stay: {include_prolonged_stay}\n' - f'include_concept_list: {include_concept_list}\n' - f'gpt_patient_sequence: {gpt_patient_sequence}\n' - f'apply_age_filter: {apply_age_filter}\n' - f'include_death: {include_death}\n' - f'att_type: {att_type}\n' - f'exclude_demographic: {exclude_demographic}\n' - f'use_age_group: {use_age_group}\n' - f'with_drug_rollup: {with_drug_rollup}\n' + 
f"input_folder: {input_folder}\n" + f"output_folder: {output_folder}\n" + f"domain_table_list: {domain_table_list}\n" + f"date_filter: {date_filter}\n" + f"include_visit_type: {include_visit_type}\n" + f"is_new_patient_representation: {is_new_patient_representation}\n" + f"exclude_visit_tokens: {exclude_visit_tokens}\n" + f"is_classic_bert: {is_classic_bert}\n" + f"include_prolonged_stay: {include_prolonged_stay}\n" + f"include_concept_list: {include_concept_list}\n" + f"gpt_patient_sequence: {gpt_patient_sequence}\n" + f"apply_age_filter: {apply_age_filter}\n" + f"include_death: {include_death}\n" + f"att_type: {att_type}\n" + f"exclude_demographic: {exclude_demographic}\n" + f"use_age_group: {use_age_group}\n" + f"with_drug_rollup: {with_drug_rollup}\n" ) domain_tables = [] for domain_table_name in domain_table_list: if domain_table_name != MEASUREMENT: domain_tables.append( - preprocess_domain_table(spark, input_folder, domain_table_name, with_drug_rollup=with_drug_rollup) + preprocess_domain_table( + spark, + input_folder, + domain_table_name, + with_drug_rollup=with_drug_rollup, + ) ) visit_occurrence = preprocess_domain_table(spark, input_folder, VISIT_OCCURRENCE) visit_occurrence = visit_occurrence.select( - 'visit_occurrence_id', - 'visit_start_date', - 'visit_start_datetime', - 'visit_end_date', - 'visit_concept_id', - 'person_id', - 'discharged_to_concept_id' + "visit_occurrence_id", + "visit_start_date", + "visit_start_datetime", + "visit_end_date", + "visit_concept_id", + "person_id", + "discharged_to_concept_id", ) person = preprocess_domain_table(spark, input_folder, PERSON) - birth_datetime_udf = F.coalesce('birth_datetime', - F.concat('year_of_birth', F.lit('-01-01')).cast('timestamp')) + birth_datetime_udf = F.coalesce("birth_datetime", F.concat("year_of_birth", F.lit("-01-01")).cast("timestamp")) person = person.select( - 'person_id', - birth_datetime_udf.alias('birth_datetime'), - 'race_concept_id', - 'gender_concept_id' + "person_id", + birth_datetime_udf.alias("birth_datetime"), + "race_concept_id", + "gender_concept_id", ) - visit_occurrence_person = visit_occurrence.join(person, 'person_id') \ - .withColumn('age', F.ceil( - F.months_between(F.col('visit_start_date'), F.col('birth_datetime')) / F.lit(12))) - visit_occurrence_person = visit_occurrence_person.drop('birth_datetime') + visit_occurrence_person = visit_occurrence.join(person, "person_id").withColumn( + "age", + F.ceil(F.months_between(F.col("visit_start_date"), F.col("birth_datetime")) / F.lit(12)), + ) + visit_occurrence_person = visit_occurrence_person.drop("birth_datetime") death = preprocess_domain_table(spark, input_folder, DEATH) if include_death else None @@ -97,16 +115,11 @@ def main( if include_concept_list and patient_events: column_names = patient_events.schema.fieldNames() # Filter out concepts - qualified_concepts = preprocess_domain_table( - spark, - input_folder, - 'qualified_concept_list' - ).select('standard_concept_id') + qualified_concepts = preprocess_domain_table(spark, input_folder, "qualified_concept_list").select( + "standard_concept_id" + ) - patient_events = patient_events.join( - qualified_concepts, - 'standard_concept_id' - ).select(column_names) + patient_events = patient_events.join(qualified_concepts, "standard_concept_id").select(column_names) # Process the measurement table if exists if MEASUREMENT in domain_table_list: @@ -114,41 +127,32 @@ def main( required_measurement = preprocess_domain_table(spark, input_folder, REQUIRED_MEASUREMENT) # The select is necessary to make 
sure the order of the columns is the same as the # original dataframe, otherwise the union might use the wrong columns - scaled_measurement = process_measurement( - spark, - measurement, - required_measurement, - output_folder - ) + scaled_measurement = process_measurement(spark, measurement, required_measurement, output_folder) if patient_events: # Union all measurement records together with other domain records - patient_events = patient_events.unionByName( - scaled_measurement - ) + patient_events = patient_events.unionByName(scaled_measurement) else: patient_events = scaled_measurement - patient_events = patient_events.join(visit_occurrence_person, 'visit_occurrence_id') \ - .select([patient_events[fieldName] for fieldName in patient_events.schema.fieldNames()] + - ['visit_concept_id', 'age']) \ - .withColumn('cohort_member_id', F.col('person_id')) + patient_events = ( + patient_events.join(visit_occurrence_person, "visit_occurrence_id") + .select( + [patient_events[fieldName] for fieldName in patient_events.schema.fieldNames()] + + ["visit_concept_id", "age"] + ) + .withColumn("cohort_member_id", F.col("person_id")) + ) # Apply the age security measure # We only keep the patient records, whose corresponding age is less than 90 if apply_age_filter: - patient_events = patient_events.where( - F.col('age') < 90 - ) + patient_events = patient_events.where(F.col("age") < 90) if not continue_from_events: - patient_events.write.mode("overwrite").parquet( - os.path.join(output_folder, 'all_patient_events') - ) + patient_events.write.mode("overwrite").parquet(os.path.join(output_folder, "all_patient_events")) - patient_events = spark.read.parquet( - os.path.join(output_folder, 'all_patient_events') - ) + patient_events = spark.read.parquet(os.path.join(output_folder, "all_patient_events")) if is_new_patient_representation: sequence_data = create_sequence_data_with_att( @@ -162,194 +166,171 @@ def main( att_type=att_type, exclude_demographic=exclude_demographic, use_age_group=use_age_group, - include_inpatient_hour_token=include_inpatient_hour_token + include_inpatient_hour_token=include_inpatient_hour_token, ) else: sequence_data = create_sequence_data( patient_events, date_filter=date_filter, include_visit_type=include_visit_type, - classic_bert_seq=is_classic_bert + classic_bert_seq=is_classic_bert, ) if include_prolonged_stay: - udf = F.when(F.col('visit_concept_id').isin([9201, 262, 9203]), - F.coalesce((F.datediff('visit_end_date', 'visit_start_date') > 7).cast('int'), - F.lit(0))).otherwise(F.lit(0)) + udf = F.when( + F.col("visit_concept_id").isin([9201, 262, 9203]), + F.coalesce( + (F.datediff("visit_end_date", "visit_start_date") > 7).cast("int"), + F.lit(0), + ), + ).otherwise(F.lit(0)) visit_occurrence = preprocess_domain_table(spark, input_folder, VISIT_OCCURRENCE) - visit_occurrence = visit_occurrence.withColumn('prolonged_length_stay', udf) \ - .select('person_id', 'prolonged_length_stay') \ - .withColumn('prolonged_length_stay', - F.max('prolonged_length_stay').over(W.partitionBy('person_id'))).distinct() - sequence_data = sequence_data.join(visit_occurrence, 'person_id') + visit_occurrence = ( + visit_occurrence.withColumn("prolonged_length_stay", udf) + .select("person_id", "prolonged_length_stay") + .withColumn( + "prolonged_length_stay", + F.max("prolonged_length_stay").over(W.partitionBy("person_id")), + ) + .distinct() + ) + sequence_data = sequence_data.join(visit_occurrence, "person_id") if include_sequence_information_content: - concept_df = 
patient_events.select('person_id', F.col('standard_concept_id').alias('concept_id')) - concept_freq = concept_df \ - .groupBy('concept_id').count() \ - .withColumn('prob', F.col('count') / F.sum('count').over(Window.partitionBy())) \ - .withColumn('ic', -F.log('prob')) + concept_df = patient_events.select("person_id", F.col("standard_concept_id").alias("concept_id")) + concept_freq = ( + concept_df.groupBy("concept_id") + .count() + .withColumn("prob", F.col("count") / F.sum("count").over(Window.partitionBy())) + .withColumn("ic", -F.log("prob")) + ) - patient_ic_df = concept_df.join(concept_freq, 'concept_id') \ - .groupby('person_id') \ - .agg(F.mean('ic').alias('ic')) + patient_ic_df = concept_df.join(concept_freq, "concept_id").groupby("person_id").agg(F.mean("ic").alias("ic")) - sequence_data = sequence_data.join(patient_ic_df, 'person_id') + sequence_data = sequence_data.join(patient_ic_df, "person_id") - patient_splits_folder = os.path.join(input_folder, 'patient_splits') + patient_splits_folder = os.path.join(input_folder, "patient_splits") if os.path.exists(patient_splits_folder): patient_splits = spark.read.parquet(patient_splits_folder) - sequence_data.join(patient_splits, 'person_id').write.mode('overwrite').parquet( - os.path.join(output_folder, 'patient_sequence', 'temp') + sequence_data.join(patient_splits, "person_id").write.mode("overwrite").parquet( + os.path.join(output_folder, "patient_sequence", "temp") ) - sequence_data = spark.read.parquet( - os.path.join(output_folder, 'patient_sequence', 'temp') + sequence_data = spark.read.parquet(os.path.join(output_folder, "patient_sequence", "temp")) + sequence_data.where('split="train"').write.mode("overwrite").parquet( + os.path.join(output_folder, "patient_sequence/train") ) - sequence_data.where('split="train"').write.mode('overwrite').parquet( - os.path.join(output_folder, 'patient_sequence/train') + sequence_data.where('split="test"').write.mode("overwrite").parquet( + os.path.join(output_folder, "patient_sequence/test") ) - sequence_data.where('split="test"').write.mode('overwrite').parquet( - os.path.join(output_folder, 'patient_sequence/test') - ) - shutil.rmtree(os.path.join(output_folder, 'patient_sequence', 'temp')) + shutil.rmtree(os.path.join(output_folder, "patient_sequence", "temp")) else: - sequence_data.write.mode('overwrite').parquet( - os.path.join(output_folder, 'patient_sequence') - ) + sequence_data.write.mode("overwrite").parquet(os.path.join(output_folder, "patient_sequence")) -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Arguments for generate training data for Bert') +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Arguments for generate training data for Bert") parser.add_argument( - '-i', - '--input_folder', - dest='input_folder', - action='store', - help='The path for your input_folder where the raw data is', - required=True + "-i", + "--input_folder", + dest="input_folder", + action="store", + help="The path for your input_folder where the raw data is", + required=True, ) parser.add_argument( - '-o', - '--output_folder', - dest='output_folder', - action='store', - help='The path for your output_folder', - required=True + "-o", + "--output_folder", + dest="output_folder", + action="store", + help="The path for your output_folder", + required=True, ) parser.add_argument( - '-tc', - '--domain_table_list', - dest='domain_table_list', - nargs='+', - action='store', - help='The list of domain tables you want to download', + "-tc", + 
"--domain_table_list", + dest="domain_table_list", + nargs="+", + action="store", + help="The list of domain tables you want to download", type=validate_table_names, - required=True + required=True, ) parser.add_argument( - '-d', - '--date_filter', - dest='date_filter', - type=lambda s: datetime.datetime.strptime(s, '%Y-%m-%d'), - action='store', + "-d", + "--date_filter", + dest="date_filter", + type=lambda s: datetime.datetime.strptime(s, "%Y-%m-%d"), + action="store", required=False, - default='2018-01-01' - ) - parser.add_argument( - '-iv', - '--include_visit_type', - dest='include_visit_type', - action='store_true', - help='Specify whether to include visit types for generating the training data' - ) - parser.add_argument( - '-ip', - '--is_new_patient_representation', - dest='is_new_patient_representation', - action='store_true', - help='Specify whether to generate the sequence of EHR records using the new patient ' - 'representation' - ) - parser.add_argument( - '-ib', - '--is_classic_bert_sequence', - dest='is_classic_bert_sequence', - action='store_true', - help='Specify whether to generate the sequence of EHR records using the classic BERT ' - 'sequence' - ) - parser.add_argument( - '-ev', - '--exclude_visit_tokens', - dest='exclude_visit_tokens', - action='store_true', - help='Specify whether or not to exclude the VS and VE tokens' - ) - parser.add_argument( - '--include_prolonged_length_stay', - dest='include_prolonged_stay', - action='store_true', - help='Specify whether or not to include the data for the second learning objective for ' - 'Med-BERT' - ) - parser.add_argument( - '--include_concept_list', - dest='include_concept_list', - action='store_true' - ) - parser.add_argument( - '--gpt_patient_sequence', - dest='gpt_patient_sequence', - action='store_true' - ) - parser.add_argument( - '--apply_age_filter', - dest='apply_age_filter', - action='store_true' + default="2018-01-01", ) parser.add_argument( - '--include_death', - dest='include_death', - action='store_true' + "-iv", + "--include_visit_type", + dest="include_visit_type", + action="store_true", + help="Specify whether to include visit types for generating the training data", ) parser.add_argument( - '--exclude_demographic', - dest='exclude_demographic', - action='store_true' + "-ip", + "--is_new_patient_representation", + dest="is_new_patient_representation", + action="store_true", + help="Specify whether to generate the sequence of EHR records using the new patient " "representation", ) parser.add_argument( - '--use_age_group', - dest='use_age_group', - action='store_true' + "-ib", + "--is_classic_bert_sequence", + dest="is_classic_bert_sequence", + action="store_true", + help="Specify whether to generate the sequence of EHR records using the classic BERT " "sequence", ) parser.add_argument( - '--with_drug_rollup', - dest='with_drug_rollup', - action='store_true' + "-ev", + "--exclude_visit_tokens", + dest="exclude_visit_tokens", + action="store_true", + help="Specify whether or not to exclude the VS and VE tokens", ) parser.add_argument( - '--include_inpatient_hour_token', - dest='include_inpatient_hour_token', - action='store_true' + "--include_prolonged_length_stay", + dest="include_prolonged_stay", + action="store_true", + help="Specify whether or not to include the data for the second learning objective for " "Med-BERT", ) + parser.add_argument("--include_concept_list", dest="include_concept_list", action="store_true") + parser.add_argument("--gpt_patient_sequence", dest="gpt_patient_sequence", 
action="store_true") + parser.add_argument("--apply_age_filter", dest="apply_age_filter", action="store_true") + parser.add_argument("--include_death", dest="include_death", action="store_true") + parser.add_argument("--exclude_demographic", dest="exclude_demographic", action="store_true") + parser.add_argument("--use_age_group", dest="use_age_group", action="store_true") + parser.add_argument("--with_drug_rollup", dest="with_drug_rollup", action="store_true") parser.add_argument( - '--continue_from_events', - dest='continue_from_events', - action='store_true' + "--include_inpatient_hour_token", + dest="include_inpatient_hour_token", + action="store_true", ) + parser.add_argument("--continue_from_events", dest="continue_from_events", action="store_true") parser.add_argument( - '--att_type', - dest='att_type', - action='store', + "--att_type", + dest="att_type", + action="store", choices=[e.value for e in AttType], ) ARGS = parser.parse_args() main( - ARGS.input_folder, ARGS.output_folder, ARGS.domain_table_list, ARGS.date_filter, - ARGS.include_visit_type, ARGS.is_new_patient_representation, ARGS.exclude_visit_tokens, - ARGS.is_classic_bert_sequence, ARGS.include_prolonged_stay, ARGS.include_concept_list, + ARGS.input_folder, + ARGS.output_folder, + ARGS.domain_table_list, + ARGS.date_filter, + ARGS.include_visit_type, + ARGS.is_new_patient_representation, + ARGS.exclude_visit_tokens, + ARGS.is_classic_bert_sequence, + ARGS.include_prolonged_stay, + ARGS.include_concept_list, ARGS.gpt_patient_sequence, ARGS.apply_age_filter, ARGS.include_death, @@ -358,5 +339,5 @@ def main( use_age_group=ARGS.use_age_group, with_drug_rollup=ARGS.with_drug_rollup, include_inpatient_hour_token=ARGS.include_inpatient_hour_token, - continue_from_events=ARGS.continue_from_events + continue_from_events=ARGS.continue_from_events, ) diff --git a/src/cehrbert/spark_apps/legacy/mortality.py b/src/cehrbert/spark_apps/legacy/mortality.py index 2c3b1470..19c0db39 100644 --- a/src/cehrbert/spark_apps/legacy/mortality.py +++ b/src/cehrbert/spark_apps/legacy/mortality.py @@ -5,9 +5,9 @@ from ..spark_parse_args import create_spark_args QUALIFIED_DEATH_DATE_QUERY = """ -WITH max_death_date_cte AS +WITH max_death_date_cte AS ( - SELECT + SELECT person_id, MAX(death_date) AS death_date FROM global_temp.death @@ -34,7 +34,7 @@ WITH last_visit_cte AS ( SELECT v.*, - COUNT(CASE WHEN DATE(v.visit_start_date) >= DATE_SUB(index_date, {observation_period}) + COUNT(CASE WHEN DATE(v.visit_start_date) >= DATE_SUB(index_date, {observation_period}) AND DATE(v.visit_start_date) < index_date THEN 1 ELSE NULL END) OVER (PARTITION BY v.person_id) AS num_of_visits FROM @@ -42,13 +42,13 @@ SELECT DISTINCT v.person_id, v.visit_start_date, - FIRST(v.visit_occurrence_id) OVER(PARTITION BY v.person_id + FIRST(v.visit_occurrence_id) OVER(PARTITION BY v.person_id ORDER BY DATE(v.visit_start_date) DESC) AS visit_occurrence_id, - FIRST(DATE(v.visit_start_date)) OVER(PARTITION BY v.person_id + FIRST(DATE(v.visit_start_date)) OVER(PARTITION BY v.person_id ORDER BY DATE(v.visit_start_date) DESC) AS index_date, - FIRST(v.discharge_to_concept_id) OVER(PARTITION BY v.person_id + FIRST(v.discharge_to_concept_id) OVER(PARTITION BY v.person_id ORDER BY DATE(v.visit_start_date) DESC) AS discharge_to_concept_id, - FIRST(DATE(v.visit_start_date)) OVER(PARTITION BY v.person_id + FIRST(DATE(v.visit_start_date)) OVER(PARTITION BY v.person_id ORDER BY DATE(v.visit_start_date)) AS earliest_visit_start_date FROM global_temp.visit_occurrence AS v -- Need to make sure 
that there is enough observation for the observation window. @@ -76,12 +76,12 @@ --AND v.num_of_visits >= {num_of_visits} """ -DOMAIN_TABLE_LIST = ['condition_occurrence', 'drug_exposure', 'procedure_occurrence'] +DOMAIN_TABLE_LIST = ["condition_occurrence", "drug_exposure", "procedure_occurrence"] -COHORT_TABLE = 'cohort' -DEATH = 'death' -PERSON = 'person' -VISIT_OCCURRENCE = 'visit_occurrence' +COHORT_TABLE = "cohort" +DEATH = "death" +PERSON = "person" +VISIT_OCCURRENCE = "visit_occurrence" DEPENDENCY_LIST = [DEATH, PERSON, VISIT_OCCURRENCE] @@ -90,12 +90,14 @@ class MortalityCohortBuilder(LastVisitCohortBuilderBase): def preprocess_dependencies(self): self.spark.sql(QUALIFIED_DEATH_DATE_QUERY).createOrReplaceGlobalTempView(DEATH) - num_of_visits = ((self._observation_window // 360) + 1) + num_of_visits = (self._observation_window // 360) + 1 - cohort_query = COHORT_QUERY_TEMPLATE.format(date_lower_bound=self._date_lower_bound, - date_upper_bound=self._date_upper_bound, - observation_period=self._observation_window, - num_of_visits=num_of_visits) + cohort_query = COHORT_QUERY_TEMPLATE.format( + date_lower_bound=self._date_lower_bound, + date_upper_bound=self._date_upper_bound, + observation_period=self._observation_window, + num_of_visits=num_of_visits, + ) cohort = self.spark.sql(cohort_query) cohort.createOrReplaceGlobalTempView(COHORT_TABLE) @@ -104,15 +106,16 @@ def preprocess_dependencies(self): def create_incident_cases(self): cohort = self._dependency_dict[COHORT_TABLE] - return cohort.where(f.col('label') == 1) + return cohort.where(f.col("label") == 1) def create_control_cases(self): cohort = self._dependency_dict[COHORT_TABLE] - return cohort.where(f.col('label') == 0) + return cohort.where(f.col("label") == 0) def create_matching_control_cases(self, incident_cases: DataFrame, control_cases: DataFrame): """ - Do not match for control and simply what's in the control cases + Do not match the controls; simply return what's in the control cases.
+ :param incident_cases: :param control_cases: :return: @@ -120,45 +123,61 @@ def create_matching_control_cases(self, incident_cases: DataFrame, control_cases return control_cases -def main(cohort_name, input_folder, output_folder, date_lower_bound, date_upper_bound, - age_lower_bound, age_upper_bound, observation_window, prediction_window, hold_off_window, - index_date_match_window, include_visit_type, is_feature_concept_frequency, - is_roll_up_concept): - cohort_builder = MortalityCohortBuilder(cohort_name, - input_folder, - output_folder, - date_lower_bound, - date_upper_bound, - age_lower_bound, - age_upper_bound, - observation_window, - prediction_window, - hold_off_window, - index_date_match_window, - DOMAIN_TABLE_LIST, - DEPENDENCY_LIST, - True, - include_visit_type, - is_feature_concept_frequency, - is_roll_up_concept) +def main( + cohort_name, + input_folder, + output_folder, + date_lower_bound, + date_upper_bound, + age_lower_bound, + age_upper_bound, + observation_window, + prediction_window, + hold_off_window, + index_date_match_window, + include_visit_type, + is_feature_concept_frequency, + is_roll_up_concept, +): + cohort_builder = MortalityCohortBuilder( + cohort_name, + input_folder, + output_folder, + date_lower_bound, + date_upper_bound, + age_lower_bound, + age_upper_bound, + observation_window, + prediction_window, + hold_off_window, + index_date_match_window, + DOMAIN_TABLE_LIST, + DEPENDENCY_LIST, + True, + include_visit_type, + is_feature_concept_frequency, + is_roll_up_concept, + ) cohort_builder.build() -if __name__ == '__main__': +if __name__ == "__main__": spark_args = create_spark_args() - main(spark_args.cohort_name, - spark_args.input_folder, - spark_args.output_folder, - spark_args.date_lower_bound, - spark_args.date_upper_bound, - spark_args.lower_bound, - spark_args.upper_bound, - spark_args.observation_window, - spark_args.prediction_window, - spark_args.hold_off_window, - spark_args.index_date_match_window, - spark_args.include_visit_type, - spark_args.is_feature_concept_frequency, - spark_args.is_roll_up_concept) + main( + spark_args.cohort_name, + spark_args.input_folder, + spark_args.output_folder, + spark_args.date_lower_bound, + spark_args.date_upper_bound, + spark_args.lower_bound, + spark_args.upper_bound, + spark_args.observation_window, + spark_args.prediction_window, + spark_args.hold_off_window, + spark_args.index_date_match_window, + spark_args.include_visit_type, + spark_args.is_feature_concept_frequency, + spark_args.is_roll_up_concept, + ) diff --git a/src/cehrbert/spark_apps/prediction_cohorts/afib_ischemic_stroke.py b/src/cehrbert/spark_apps/prediction_cohorts/afib_ischemic_stroke.py index 76bef31b..2edd5a10 100644 --- a/src/cehrbert/spark_apps/prediction_cohorts/afib_ischemic_stroke.py +++ b/src/cehrbert/spark_apps/prediction_cohorts/afib_ischemic_stroke.py @@ -1,12 +1,13 @@ -from ..cohorts import ischemic_stroke, atrial_fibrillation +from ..cohorts import atrial_fibrillation, ischemic_stroke from ..cohorts.spark_app_base import create_prediction_cohort - from ..spark_parse_args import create_spark_args -DOMAIN_TABLE_LIST = ['condition_occurrence', 'drug_exposure', 'procedure_occurrence'] +DOMAIN_TABLE_LIST = ["condition_occurrence", "drug_exposure", "procedure_occurrence"] -if __name__ == '__main__': - create_prediction_cohort(create_spark_args(), - atrial_fibrillation.query_builder(), - ischemic_stroke.query_builder(), - DOMAIN_TABLE_LIST) +if __name__ == "__main__": + create_prediction_cohort( + create_spark_args(), + 
atrial_fibrillation.query_builder(), + ischemic_stroke.query_builder(), + DOMAIN_TABLE_LIST, + ) diff --git a/src/cehrbert/spark_apps/prediction_cohorts/cad_cabg_cohort.py b/src/cehrbert/spark_apps/prediction_cohorts/cad_cabg_cohort.py index 34976534..93296cd1 100644 --- a/src/cehrbert/spark_apps/prediction_cohorts/cad_cabg_cohort.py +++ b/src/cehrbert/spark_apps/prediction_cohorts/cad_cabg_cohort.py @@ -1,11 +1,11 @@ -from ..spark_parse_args import create_spark_args -from ..cohorts import coronary_artery_disease as cad from ..cohorts import cabg +from ..cohorts import coronary_artery_disease as cad from ..cohorts.spark_app_base import create_prediction_cohort +from ..spark_parse_args import create_spark_args -DOMAIN_TABLE_LIST = ['condition_occurrence', 'drug_exposure', 'procedure_occurrence'] +DOMAIN_TABLE_LIST = ["condition_occurrence", "drug_exposure", "procedure_occurrence"] -if __name__ == '__main__': +if __name__ == "__main__": spark_args = create_spark_args() ehr_table_list = spark_args.ehr_table_list if spark_args.ehr_table_list else DOMAIN_TABLE_LIST @@ -14,5 +14,5 @@ spark_args, cad.query_builder(spark_args), cabg.query_builder(spark_args), - ehr_table_list + ehr_table_list, ) diff --git a/src/cehrbert/spark_apps/prediction_cohorts/cad_hf_cohort.py b/src/cehrbert/spark_apps/prediction_cohorts/cad_hf_cohort.py index cf028a8e..767babbc 100644 --- a/src/cehrbert/spark_apps/prediction_cohorts/cad_hf_cohort.py +++ b/src/cehrbert/spark_apps/prediction_cohorts/cad_hf_cohort.py @@ -1,19 +1,18 @@ -from ..spark_parse_args import create_spark_args from ..cohorts import coronary_artery_disease as cad from ..cohorts import heart_failure as hf from ..cohorts.spark_app_base import create_prediction_cohort +from ..spark_parse_args import create_spark_args -DOMAIN_TABLE_LIST = ['condition_occurrence', 'drug_exposure', - 'procedure_occurrence', 'measurement'] +DOMAIN_TABLE_LIST = [ + "condition_occurrence", + "drug_exposure", + "procedure_occurrence", + "measurement", +] -if __name__ == '__main__': +if __name__ == "__main__": spark_args = create_spark_args() ehr_table_list = spark_args.ehr_table_list if spark_args.ehr_table_list else DOMAIN_TABLE_LIST - create_prediction_cohort( - spark_args, - cad.query_builder(spark_args), - hf.query_builder(), - ehr_table_list - ) + create_prediction_cohort(spark_args, cad.query_builder(spark_args), hf.query_builder(), ehr_table_list) diff --git a/src/cehrbert/spark_apps/prediction_cohorts/copd_readmission.py b/src/cehrbert/spark_apps/prediction_cohorts/copd_readmission.py index e8ec130a..327ca843 100644 --- a/src/cehrbert/spark_apps/prediction_cohorts/copd_readmission.py +++ b/src/cehrbert/spark_apps/prediction_cohorts/copd_readmission.py @@ -1,12 +1,11 @@ -from ..spark_parse_args import create_spark_args +from ..cohorts.query_builder import AncestorTableSpec, QueryBuilder, QuerySpec from ..cohorts.spark_app_base import create_prediction_cohort -from ..cohorts.query_builder import QueryBuilder, AncestorTableSpec, QuerySpec - +from ..spark_parse_args import create_spark_args COPD_HOSPITALIZATION_QUERY = """ WITH copd_conditions AS ( - SELECT DISTINCT - descendant_concept_id AS concept_id + SELECT DISTINCT + descendant_concept_id AS concept_id FROM global_temp.concept_ancestor AS ca WHERE ca.ancestor_concept_id in (255573, 258780) ) @@ -34,32 +33,37 @@ WHERE v.visit_concept_id IN (9201, 262) --inpatient, er-inpatient """ -COPD_HOSPITALIZATION_COHORT = 'copd_readmission' -HOSPITALIZATION_COHORT = 'hospitalization' -DEPENDENCY_LIST = ['person', 
'condition_occurrence', 'visit_occurrence'] -DOMAIN_TABLE_LIST = ['condition_occurrence'] +COPD_HOSPITALIZATION_COHORT = "copd_readmission" +HOSPITALIZATION_COHORT = "hospitalization" +DEPENDENCY_LIST = ["person", "condition_occurrence", "visit_occurrence"] +DOMAIN_TABLE_LIST = ["condition_occurrence"] def main(spark_args): - copd_inpatient_query = QuerySpec(table_name=COPD_HOSPITALIZATION_COHORT, - query_template=COPD_HOSPITALIZATION_QUERY, - parameters={}) - copd_inpatient = QueryBuilder(cohort_name=COPD_HOSPITALIZATION_COHORT, - dependency_list=DEPENDENCY_LIST, - query=copd_inpatient_query) + copd_inpatient_query = QuerySpec( + table_name=COPD_HOSPITALIZATION_COHORT, + query_template=COPD_HOSPITALIZATION_QUERY, + parameters={}, + ) + copd_inpatient = QueryBuilder( + cohort_name=COPD_HOSPITALIZATION_COHORT, + dependency_list=DEPENDENCY_LIST, + query=copd_inpatient_query, + ) - hospitalization_query = QuerySpec(table_name=HOSPITALIZATION_COHORT, - query_template=HOSPITALIZATION_QUERY, - parameters={}) - hospitalization = QueryBuilder(cohort_name=HOSPITALIZATION_COHORT, - dependency_list=DEPENDENCY_LIST, - query=hospitalization_query) + hospitalization_query = QuerySpec( + table_name=HOSPITALIZATION_COHORT, + query_template=HOSPITALIZATION_QUERY, + parameters={}, + ) + hospitalization = QueryBuilder( + cohort_name=HOSPITALIZATION_COHORT, + dependency_list=DEPENDENCY_LIST, + query=hospitalization_query, + ) - create_prediction_cohort(spark_args, - copd_inpatient, - hospitalization, - DOMAIN_TABLE_LIST) + create_prediction_cohort(spark_args, copd_inpatient, hospitalization, DOMAIN_TABLE_LIST) -if __name__ == '__main__': +if __name__ == "__main__": main(create_spark_args()) diff --git a/src/cehrbert/spark_apps/prediction_cohorts/covid_death.py b/src/cehrbert/spark_apps/prediction_cohorts/covid_death.py index 0a50d479..74f6c549 100644 --- a/src/cehrbert/spark_apps/prediction_cohorts/covid_death.py +++ b/src/cehrbert/spark_apps/prediction_cohorts/covid_death.py @@ -1,12 +1,13 @@ from ..cohorts import covid_inpatient, death from ..cohorts.spark_app_base import create_prediction_cohort - from ..spark_parse_args import create_spark_args -DOMAIN_TABLE_LIST = ['condition_occurrence', 'drug_exposure', 'procedure_occurrence'] +DOMAIN_TABLE_LIST = ["condition_occurrence", "drug_exposure", "procedure_occurrence"] -if __name__ == '__main__': - create_prediction_cohort(create_spark_args(), - covid_inpatient.query_builder(), - death.query_builder(), - DOMAIN_TABLE_LIST) +if __name__ == "__main__": + create_prediction_cohort( + create_spark_args(), + covid_inpatient.query_builder(), + death.query_builder(), + DOMAIN_TABLE_LIST, + ) diff --git a/src/cehrbert/spark_apps/prediction_cohorts/covid_ventilation.py b/src/cehrbert/spark_apps/prediction_cohorts/covid_ventilation.py index 299ac86b..1063fcb1 100644 --- a/src/cehrbert/spark_apps/prediction_cohorts/covid_ventilation.py +++ b/src/cehrbert/spark_apps/prediction_cohorts/covid_ventilation.py @@ -1,12 +1,13 @@ from ..cohorts import covid, ventilation from ..cohorts.spark_app_base import create_prediction_cohort - from ..spark_parse_args import create_spark_args -DOMAIN_TABLE_LIST = ['condition_occurrence', 'drug_exposure', 'procedure_occurrence'] +DOMAIN_TABLE_LIST = ["condition_occurrence", "drug_exposure", "procedure_occurrence"] -if __name__ == '__main__': - create_prediction_cohort(create_spark_args(), - covid.query_builder(), - ventilation.query_builder(), - DOMAIN_TABLE_LIST) +if __name__ == "__main__": + create_prediction_cohort( + create_spark_args(), 
+ covid.query_builder(), + ventilation.query_builder(), + DOMAIN_TABLE_LIST, + ) diff --git a/src/cehrbert/spark_apps/prediction_cohorts/discharge_home_death.py b/src/cehrbert/spark_apps/prediction_cohorts/discharge_home_death.py index 25b5bd39..58d80401 100644 --- a/src/cehrbert/spark_apps/prediction_cohorts/discharge_home_death.py +++ b/src/cehrbert/spark_apps/prediction_cohorts/discharge_home_death.py @@ -1,12 +1,16 @@ -from ..cohorts.spark_app_base import create_prediction_cohort from ..cohorts import death from ..cohorts import last_visit_discharged_home as last - +from ..cohorts.spark_app_base import create_prediction_cohort from ..spark_parse_args import create_spark_args -DOMAIN_TABLE_LIST = ['condition_occurrence', 'drug_exposure', 'procedure_occurrence', 'measurement'] +DOMAIN_TABLE_LIST = [ + "condition_occurrence", + "drug_exposure", + "procedure_occurrence", + "measurement", +] -if __name__ == '__main__': +if __name__ == "__main__": spark_args = create_spark_args() ehr_table_list = spark_args.ehr_table_list if spark_args.ehr_table_list else DOMAIN_TABLE_LIST @@ -14,5 +18,5 @@ spark_args, last.query_builder(spark_args), death.query_builder(), - ehr_table_list + ehr_table_list, ) diff --git a/src/cehrbert/spark_apps/prediction_cohorts/hf_readmission.py b/src/cehrbert/spark_apps/prediction_cohorts/hf_readmission.py index 10653218..42b2baed 100644 --- a/src/cehrbert/spark_apps/prediction_cohorts/hf_readmission.py +++ b/src/cehrbert/spark_apps/prediction_cohorts/hf_readmission.py @@ -4,8 +4,8 @@ HEART_FAILURE_HOSPITALIZATION_QUERY = """ WITH hf_concepts AS ( - SELECT DISTINCT - descendant_concept_id AS concept_id + SELECT DISTINCT + descendant_concept_id AS concept_id FROM global_temp.concept_ancestor AS ca WHERE ca.ancestor_concept_id = 316139 ) @@ -35,45 +35,45 @@ WHERE v.visit_concept_id IN (9201, 262, 8971, 8920) --inpatient, er-inpatient """ -HF_HOSPITALIZATION_COHORT = 'hf_hospitalization' -HOSPITALIZATION_COHORT = 'hospitalization' -DEPENDENCY_LIST = ['person', 'condition_occurrence', 'visit_occurrence'] -DOMAIN_TABLE_LIST = ['condition_occurrence', 'drug_exposure', 'procedure_occurrence', 'measurement'] +HF_HOSPITALIZATION_COHORT = "hf_hospitalization" +HOSPITALIZATION_COHORT = "hospitalization" +DEPENDENCY_LIST = ["person", "condition_occurrence", "visit_occurrence"] +DOMAIN_TABLE_LIST = [ + "condition_occurrence", + "drug_exposure", + "procedure_occurrence", + "measurement", +] def main(spark_args): hf_inpatient_target_query = QuerySpec( table_name=HF_HOSPITALIZATION_COHORT, query_template=HEART_FAILURE_HOSPITALIZATION_QUERY, - parameters={'date_lower_bound': spark_args.date_lower_bound} + parameters={"date_lower_bound": spark_args.date_lower_bound}, ) hf_inpatient_target_querybuilder = QueryBuilder( cohort_name=HF_HOSPITALIZATION_COHORT, dependency_list=DEPENDENCY_LIST, - query=hf_inpatient_target_query + query=hf_inpatient_target_query, ) hospitalization_query = QuerySpec( table_name=HOSPITALIZATION_COHORT, query_template=HOSPITALIZATION_QUERY, - parameters={} + parameters={}, ) hospitalization = QueryBuilder( cohort_name=HOSPITALIZATION_COHORT, dependency_list=DEPENDENCY_LIST, - query=hospitalization_query + query=hospitalization_query, ) ehr_table_list = spark_args.ehr_table_list if spark_args.ehr_table_list else DOMAIN_TABLE_LIST - create_prediction_cohort( - spark_args, - hf_inpatient_target_querybuilder, - hospitalization, - ehr_table_list - ) + create_prediction_cohort(spark_args, hf_inpatient_target_querybuilder, hospitalization, ehr_table_list) -if __name__ == 
'__main__': +if __name__ == "__main__": main(create_spark_args()) diff --git a/src/cehrbert/spark_apps/prediction_cohorts/hospitalization.py b/src/cehrbert/spark_apps/prediction_cohorts/hospitalization.py index 6e3679a9..918f2322 100644 --- a/src/cehrbert/spark_apps/prediction_cohorts/hospitalization.py +++ b/src/cehrbert/spark_apps/prediction_cohorts/hospitalization.py @@ -1,6 +1,6 @@ -from ..spark_parse_args import create_spark_args -from ..cohorts.spark_app_base import create_prediction_cohort from ..cohorts.query_builder import QueryBuilder, QuerySpec +from ..cohorts.spark_app_base import create_prediction_cohort +from ..spark_parse_args import create_spark_args HOSPITALIZATION_OUTCOME_QUERY = """ SELECT DISTINCT @@ -23,7 +23,7 @@ ), HOSPITAL_TARGET AS ( - SELECT DISTINCT + SELECT DISTINCT iv.person_id, iv.index_date, count(distinct case when v1.visit_concept_id IN (9201, 262) then v1.visit_occurrence_id end) as num_of_hospitalizations, @@ -37,7 +37,7 @@ GROUP BY iv.person_id, iv.index_date ) -SELECT +SELECT person_id, index_date, CAST(null AS INT) AS visit_occurrence_id @@ -46,10 +46,15 @@ AND index_date >= '{date_lower_bound}' """ -HOSPITALIZATION_TARGET_COHORT = 'hospitalization_target' -HOSPITALIZATION_OUTCOME_COHORT = 'hospitalization_outcome' -DEPENDENCY_LIST = ['person', 'condition_occurrence', 'visit_occurrence'] -DOMAIN_TABLE_LIST = ['condition_occurrence', 'drug_exposure', 'procedure_occurrence', 'measurement'] +HOSPITALIZATION_TARGET_COHORT = "hospitalization_target" +HOSPITALIZATION_OUTCOME_COHORT = "hospitalization_outcome" +DEPENDENCY_LIST = ["person", "condition_occurrence", "visit_occurrence"] +DOMAIN_TABLE_LIST = [ + "condition_occurrence", + "drug_exposure", + "procedure_occurrence", + "measurement", +] def main(spark_args): @@ -58,25 +63,25 @@ def main(spark_args): table_name=HOSPITALIZATION_TARGET_COHORT, query_template=HOSPITALIZATION_TARGET_QUERY, parameters={ - 'total_window': total_window, - 'date_lower_bound': spark_args.date_lower_bound - } + "total_window": total_window, + "date_lower_bound": spark_args.date_lower_bound, + }, ) hospitalization_querybuilder = QueryBuilder( cohort_name=HOSPITALIZATION_TARGET_COHORT, dependency_list=DEPENDENCY_LIST, - query=hospitalization_target_query + query=hospitalization_target_query, ) hospitalization_outcome_query = QuerySpec( table_name=HOSPITALIZATION_OUTCOME_COHORT, query_template=HOSPITALIZATION_OUTCOME_QUERY, - parameters={} + parameters={}, ) hospitalization_outcome_querybuilder = QueryBuilder( cohort_name=HOSPITALIZATION_OUTCOME_COHORT, dependency_list=DEPENDENCY_LIST, - query=hospitalization_outcome_query + query=hospitalization_outcome_query, ) ehr_table_list = spark_args.ehr_table_list if spark_args.ehr_table_list else DOMAIN_TABLE_LIST @@ -85,9 +90,9 @@ def main(spark_args): spark_args, hospitalization_querybuilder, hospitalization_outcome_querybuilder, - ehr_table_list + ehr_table_list, ) -if __name__ == '__main__': +if __name__ == "__main__": main(create_spark_args()) diff --git a/src/cehrbert/spark_apps/prediction_cohorts/hospitalization_mortality.py b/src/cehrbert/spark_apps/prediction_cohorts/hospitalization_mortality.py index c9533b7d..6e2452b8 100644 --- a/src/cehrbert/spark_apps/prediction_cohorts/hospitalization_mortality.py +++ b/src/cehrbert/spark_apps/prediction_cohorts/hospitalization_mortality.py @@ -1,10 +1,14 @@ +from ..cohorts.query_builder import QueryBuilder, QuerySpec from ..cohorts.spark_app_base import create_prediction_cohort - from ..spark_parse_args import create_spark_args -from 
..cohorts.query_builder import QueryBuilder, QuerySpec -DEPENDENCY_LIST = ['visit_occurrence'] -DOMAIN_TABLE_LIST = ['condition_occurrence', 'drug_exposure', 'procedure_occurrence', 'measurement'] +DEPENDENCY_LIST = ["visit_occurrence"] +DOMAIN_TABLE_LIST = [ + "condition_occurrence", + "drug_exposure", + "procedure_occurrence", + "measurement", +] HOSPITALIZATION_QUERY = """ SELECT DISTINCT @@ -12,7 +16,7 @@ v.visit_occurrence_id, v.index_date, v.expired -FROM +FROM ( SELECT v.person_id, @@ -39,41 +43,39 @@ WHERE expired = 1 """ -HOSPITALIZATION_TARGET_COHORT = 'hospitalization_target' -MORTALITY_COHORT = 'hospitalization_mortality' +HOSPITALIZATION_TARGET_COHORT = "hospitalization_target" +MORTALITY_COHORT = "hospitalization_mortality" -if __name__ == '__main__': +if __name__ == "__main__": spark_args = create_spark_args() ehr_table_list = spark_args.ehr_table_list if spark_args.ehr_table_list else DOMAIN_TABLE_LIST hospitalization_target_query = QuerySpec( table_name=HOSPITALIZATION_TARGET_COHORT, query_template=HOSPITALIZATION_QUERY, - parameters={ - 'date_lower_bound': spark_args.date_lower_bound - } + parameters={"date_lower_bound": spark_args.date_lower_bound}, ) hospitalization_querybuilder = QueryBuilder( cohort_name=HOSPITALIZATION_TARGET_COHORT, dependency_list=DEPENDENCY_LIST, - query=hospitalization_target_query + query=hospitalization_target_query, ) hospitalization_mortality_query = QuerySpec( table_name=MORTALITY_COHORT, query_template=MORTALITY_QUERY, - parameters={'target_table_name': HOSPITALIZATION_TARGET_COHORT} + parameters={"target_table_name": HOSPITALIZATION_TARGET_COHORT}, ) hospitalization_mortality_querybuilder = QueryBuilder( cohort_name=MORTALITY_COHORT, dependency_list=DEPENDENCY_LIST, - query=hospitalization_mortality_query + query=hospitalization_mortality_query, ) create_prediction_cohort( spark_args, hospitalization_querybuilder, hospitalization_mortality_querybuilder, - ehr_table_list + ehr_table_list, ) diff --git a/src/cehrbert/spark_apps/prediction_cohorts/t2dm_hf_cohort.py b/src/cehrbert/spark_apps/prediction_cohorts/t2dm_hf_cohort.py index 9c9fa9bd..c737cfdb 100644 --- a/src/cehrbert/spark_apps/prediction_cohorts/t2dm_hf_cohort.py +++ b/src/cehrbert/spark_apps/prediction_cohorts/t2dm_hf_cohort.py @@ -1,19 +1,18 @@ -from ..spark_parse_args import create_spark_args -from ..cohorts import type_two_diabietes as t2dm from ..cohorts import heart_failure as hf +from ..cohorts import type_two_diabietes as t2dm from ..cohorts.spark_app_base import create_prediction_cohort +from ..spark_parse_args import create_spark_args -DOMAIN_TABLE_LIST = ['condition_occurrence', 'drug_exposure', - 'procedure_occurrence', 'measurement'] +DOMAIN_TABLE_LIST = [ + "condition_occurrence", + "drug_exposure", + "procedure_occurrence", + "measurement", +] -if __name__ == '__main__': +if __name__ == "__main__": spark_args = create_spark_args() ehr_table_list = spark_args.ehr_table_list if spark_args.ehr_table_list else DOMAIN_TABLE_LIST - create_prediction_cohort( - spark_args, - t2dm.query_builder(spark_args), - hf.query_builder(), - ehr_table_list - ) + create_prediction_cohort(spark_args, t2dm.query_builder(spark_args), hf.query_builder(), ehr_table_list) diff --git a/src/cehrbert/spark_apps/spark_parse_args.py b/src/cehrbert/spark_apps/spark_parse_args.py index 0532b883..518b67ad 100644 --- a/src/cehrbert/spark_apps/spark_parse_args.py +++ b/src/cehrbert/spark_apps/spark_parse_args.py @@ -1,219 +1,350 @@ +""" +This module defines functions for parsing command-line arguments 
for Spark applications. + +that generate cohort definitions. It includes argument parsing for cohort specifications, +date ranges, patient information, and EHR data extraction settings. + +Functions: + valid_date: Validates and converts a date string into a datetime object. + create_spark_args: Defines and parses command-line arguments for cohort generation and EHR + processing. +""" + import argparse import datetime -from ..spark_apps.decorators.patient_event_decorator import AttType +from .decorators.patient_event_decorator import AttType def valid_date(s): + """ + Validates and converts a date string into a datetime object. + + Args: + s (str): The date string in the format 'YYYY-MM-DD'. + Returns: + datetime.datetime: The parsed date. + Raises: + argparse.ArgumentTypeError: If the date string is not valid. + """ try: return datetime.datetime.strptime(s, "%Y-%m-%d") - except ValueError: - msg = "Not a valid date: '{0}'.".format(s) - raise argparse.ArgumentTypeError(msg) + except ValueError as e: + raise argparse.ArgumentTypeError(e) def create_spark_args(): - parser = argparse.ArgumentParser( - description='Arguments for spark applications for generating cohort definitions') - parser.add_argument('-c', - '--cohort_name', - dest='cohort_name', - action='store', - help='The cohort name', - required=True) - parser.add_argument('-i', - '--input_folder', - dest='input_folder', - action='store', - help='The path for your input_folder where the sequence data is', - required=True) - parser.add_argument('--patient_splits_folder', - dest='patient_splits_folder', - action='store', - help='The folder that contains the patient_splits data', - required=False) - parser.add_argument('-o', - '--output_folder', - dest='output_folder', - action='store', - help='The path for your output_folder', - required=True) - parser.add_argument('--ehr_table_list', - dest='ehr_table_list', - nargs='+', - action='store', - help='The list of domain tables you want to include for feature extraction', - required=False) - parser.add_argument('-dl', - '--date_lower_bound', - dest='date_lower_bound', - action='store', - help='The date filter lower bound for filtering training data', - required=True, - type=valid_date) - parser.add_argument('-du', - '--date_upper_bound', - dest='date_upper_bound', - action='store', - help='The date filter upper bound for filtering training data', - required=True, - type=valid_date) - parser.add_argument('-l', - '--age_lower_bound', - dest='age_lower_bound', - action='store', - help='The age lower bound', - required=False, - type=int, - default=0) - parser.add_argument('-u', - '--age_upper_bound', - dest='age_upper_bound', - action='store', - help='The age upper bound', - required=False, - type=int, - default=100) - parser.add_argument('-ow', - '--observation_window', - dest='observation_window', - action='store', - help='The observation window in days for extracting features', - required=False, - type=int, - default=365) - parser.add_argument('-pw', - '--prediction_window', - dest='prediction_window', - action='store', - help='The prediction window in which the prediction is made', - required=False, - type=int, - default=180) - parser.add_argument('-ps', - '--prediction_start_days', - dest='prediction_start_days', - action='store', - help='The prediction start days in which the prediction is made', - required=False, - type=int, - default=1) - parser.add_argument('-hw', - '--hold_off_window', - dest='hold_off_window', - action='store', - help='The hold off window for excluding the features', - 
required=False, - type=int, - default=0) - parser.add_argument('--num_of_visits', - dest='num_of_visits', - action='store', - help='The number of visits to qualify for the inclusion of the cohorts', - required=False, - type=int, - default=0) - parser.add_argument('--num_of_concepts', - dest='num_of_concepts', - action='store', - help='The number of concepts to qualify for the inclusion of the cohorts', - required=False, - type=int, - default=0) - parser.add_argument('-iw', - '--is_window_post_index', - dest='is_window_post_index', - action='store_true', - help='Indicate if the observation window is pre/post the index date') - parser.add_argument('-iv', - '--include_visit_type', - dest='include_visit_type', - action='store_true', - help='Specify whether to include visit types for ' - 'generating the training data') - parser.add_argument('-ev', - '--exclude_visit_tokens', - dest='exclude_visit_tokens', - action='store_true', - help='Specify whether or not to exclude the VS and VE tokens') - parser.add_argument('-f', - '--is_feature_concept_frequency', - dest='is_feature_concept_frequency', - action='store_true', - help='Specify whether the features are concept counts or not') - parser.add_argument('-ir', - '--is_roll_up_concept', - dest='is_roll_up_concept', - action='store_true', - help='Specify whether to roll up the concepts to their ancestors') - parser.add_argument('-ip', - '--is_new_patient_representation', - dest='is_new_patient_representation', - action='store_true', - help='Specify whether to generate the sequence of ' - 'EHR records using the new patient representation') - parser.add_argument('--gpt_patient_sequence', - dest='gpt_patient_sequence', - action='store_true', - help='Specify whether to generate the GPT sequence of ' - 'EHR records using the new patient representation') - parser.add_argument('-ih', - '--is_hierarchical_bert', - dest='is_hierarchical_bert', - action='store_true', - help='Specify whether to generate the sequence of ' - 'EHR records using the hierarchical patient representation') - parser.add_argument('-cbs', - '--classic_bert_seq', - dest='classic_bert_seq', - action='store_true', - help='Specify whether to generate the sequence of ' - 'EHR records using the classic BERT sequence representation where ' - 'visits are separated by a SEP token') - parser.add_argument('--is_first_time_outcome', - dest='is_first_time_outcome', - action='store_true', - help='is the outcome the first time occurrence?') - parser.add_argument('--is_remove_index_prediction_starts', - dest='is_remove_index_prediction_starts', - action='store_true', - help='is outcome between index_date and prediction start window removed?') - parser.add_argument('--is_prediction_window_unbounded', - dest='is_prediction_window_unbounded', - action='store_true', - help='is the end of the prediction window unbounded?') - parser.add_argument('--is_observation_window_unbounded', - dest='is_observation_window_unbounded', - action='store_true', - help='is the observation window unbounded?') - parser.add_argument('--include_concept_list', - dest='include_concept_list', - action='store_true', - help='Apply the filter to remove low-frequency concepts') - parser.add_argument('--allow_measurement_only', - dest='allow_measurement_only', - action='store_true', - help='Indicate whether we allow patients with measurements only') - parser.add_argument('--is_population_estimation', - dest='is_population_estimation', - action='store_true', - help='Indicate whether the cohort is constructed for population level ' - 
'estimation') - parser.add_argument('--att_type', - dest='att_type', - action='store', - choices=[e.value for e in AttType]) - parser.add_argument('--exclude_demographic', - dest='exclude_demographic', - action='store_true', - help='Indicate whether we should exclude the demographic prompt of the patient sequence') - parser.add_argument('--use_age_group', - dest='use_age_group', - action='store_true', - help='Indicate whether we should age group to represent the age at the first event in the ' - 'patient sequence') - parser.add_argument('--single_contribution', - dest='single_contribution', - action='store_true', - help='Indicate whether we should contribute once to the training data') + """ + Defines and parses the command-line arguments for Spark applications. + + that generate cohort definitions based on EHR data. + + Returns: + argparse.Namespace: The parsed arguments as a namespace object containing the user + inputs. + + Command-line Arguments: + -c, --cohort_name: The name of the cohort being generated. + -i, --input_folder: The folder path containing the input data. + --patient_splits_folder: The folder containing patient splits data. + -o, --output_folder: The folder path to store the output data. + --ehr_table_list: List of EHR domain tables for feature extraction. + -dl, --date_lower_bound: The lower bound for date filtering. + -du, --date_upper_bound: The upper bound for date filtering. + -l, --age_lower_bound: The minimum age filter for cohort inclusion. + -u, --age_upper_bound: The maximum age filter for cohort inclusion. + -ow, --observation_window: The observation window duration in days. + -pw, --prediction_window: The prediction window duration in days. + -ps, --prediction_start_days: The start point of the prediction window in days. + -hw, --hold_off_window: The hold-off window for excluding certain features. + --num_of_visits: The minimum number of visits required for cohort inclusion. + --num_of_concepts: The minimum number of concepts required for cohort inclusion. + -iw, --is_window_post_index: Whether the observation window is post-index. + -iv, --include_visit_type: Whether to include visit types in feature generation. + -ev, --exclude_visit_tokens: Whether to exclude certain visit tokens (VS and VE). + -f, --is_feature_concept_frequency: Whether the features are based on concept counts. + -ir, --is_roll_up_concept: Whether to roll up concepts to their ancestors. + -ip, --is_new_patient_representation: Whether to use a new patient representation. + --gpt_patient_sequence: Whether to generate GPT sequences for EHR records. + -ih, --is_hierarchical_bert: Whether to use a hierarchical patient representation for BERT. + -cbs, --classic_bert_seq: Whether to use classic BERT sequence representation with SEP. + --is_first_time_outcome: Whether the outcome is the first-time occurrence. + --is_remove_index_prediction_starts: Whether to remove outcomes between index and prediction + start. + --is_prediction_window_unbounded: Whether the prediction window end is unbounded. + --is_observation_window_unbounded: Whether the observation window is unbounded. + --include_concept_list: Whether to apply filters for low-frequency concepts. + --allow_measurement_only: Whether patients with only measurements are allowed. + --is_population_estimation: Whether cohort is constructed for population-level estimation. + --att_type: The attribute type used for cohort definitions. + --exclude_demographic: Whether to exclude demographic prompts in patient sequences. 
+ --use_age_group: Whether to represent age using age groups in patient sequences. + --single_contribution: Whether patients should contribute only once to the training data. + """ + parser = argparse.ArgumentParser(description="Arguments for spark applications for generating cohort definitions") + parser.add_argument( + "-c", + "--cohort_name", + dest="cohort_name", + action="store", + help="The cohort name", + required=True, + ) + parser.add_argument( + "-i", + "--input_folder", + dest="input_folder", + action="store", + help="The path for your input_folder where the sequence data is", + required=True, + ) + parser.add_argument( + "--patient_splits_folder", + dest="patient_splits_folder", + action="store", + help="The folder that contains the patient_splits data", + required=False, + ) + parser.add_argument( + "-o", + "--output_folder", + dest="output_folder", + action="store", + help="The path for your output_folder", + required=True, + ) + parser.add_argument( + "--ehr_table_list", + dest="ehr_table_list", + nargs="+", + action="store", + help="The list of domain tables you want to include for feature extraction", + required=False, + ) + parser.add_argument( + "-dl", + "--date_lower_bound", + dest="date_lower_bound", + action="store", + help="The date filter lower bound for filtering training data", + required=True, + type=valid_date, + ) + parser.add_argument( + "-du", + "--date_upper_bound", + dest="date_upper_bound", + action="store", + help="The date filter upper bound for filtering training data", + required=True, + type=valid_date, + ) + parser.add_argument( + "-l", + "--age_lower_bound", + dest="age_lower_bound", + action="store", + help="The age lower bound", + required=False, + type=int, + default=0, + ) + parser.add_argument( + "-u", + "--age_upper_bound", + dest="age_upper_bound", + action="store", + help="The age upper bound", + required=False, + type=int, + default=100, + ) + parser.add_argument( + "-ow", + "--observation_window", + dest="observation_window", + action="store", + help="The observation window in days for extracting features", + required=False, + type=int, + default=365, + ) + parser.add_argument( + "-pw", + "--prediction_window", + dest="prediction_window", + action="store", + help="The prediction window in which the prediction is made", + required=False, + type=int, + default=180, + ) + parser.add_argument( + "-ps", + "--prediction_start_days", + dest="prediction_start_days", + action="store", + help="The prediction start days in which the prediction is made", + required=False, + type=int, + default=1, + ) + parser.add_argument( + "-hw", + "--hold_off_window", + dest="hold_off_window", + action="store", + help="The hold off window for excluding the features", + required=False, + type=int, + default=0, + ) + parser.add_argument( + "--num_of_visits", + dest="num_of_visits", + action="store", + help="The number of visits to qualify for the inclusion of the cohorts", + required=False, + type=int, + default=0, + ) + parser.add_argument( + "--num_of_concepts", + dest="num_of_concepts", + action="store", + help="The number of concepts to qualify for the inclusion of the cohorts", + required=False, + type=int, + default=0, + ) + parser.add_argument( + "-iw", + "--is_window_post_index", + dest="is_window_post_index", + action="store_true", + help="Indicate if the observation window is pre/post the index date", + ) + parser.add_argument( + "-iv", + "--include_visit_type", + dest="include_visit_type", + action="store_true", + help="Specify whether to include visit 
types for " "generating the training data", + ) + parser.add_argument( + "-ev", + "--exclude_visit_tokens", + dest="exclude_visit_tokens", + action="store_true", + help="Specify whether or not to exclude the VS and VE tokens", + ) + parser.add_argument( + "-f", + "--is_feature_concept_frequency", + dest="is_feature_concept_frequency", + action="store_true", + help="Specify whether the features are concept counts or not", + ) + parser.add_argument( + "-ir", + "--is_roll_up_concept", + dest="is_roll_up_concept", + action="store_true", + help="Specify whether to roll up the concepts to their ancestors", + ) + parser.add_argument( + "-ip", + "--is_new_patient_representation", + dest="is_new_patient_representation", + action="store_true", + help="Specify whether to generate the sequence of " "EHR records using the new patient representation", + ) + parser.add_argument( + "--gpt_patient_sequence", + dest="gpt_patient_sequence", + action="store_true", + help="Specify whether to generate the GPT sequence of " "EHR records using the new patient representation", + ) + parser.add_argument( + "-ih", + "--is_hierarchical_bert", + dest="is_hierarchical_bert", + action="store_true", + help="Specify whether to generate the sequence of " "EHR records using the hierarchical patient representation", + ) + parser.add_argument( + "-cbs", + "--classic_bert_seq", + dest="classic_bert_seq", + action="store_true", + help="Specify whether to generate the sequence of " + "EHR records using the classic BERT sequence representation where " + "visits are separated by a SEP token", + ) + parser.add_argument( + "--is_first_time_outcome", + dest="is_first_time_outcome", + action="store_true", + help="is the outcome the first time occurrence?", + ) + parser.add_argument( + "--is_remove_index_prediction_starts", + dest="is_remove_index_prediction_starts", + action="store_true", + help="is outcome between index_date and prediction start window removed?", + ) + parser.add_argument( + "--is_prediction_window_unbounded", + dest="is_prediction_window_unbounded", + action="store_true", + help="is the end of the prediction window unbounded?", + ) + parser.add_argument( + "--is_observation_window_unbounded", + dest="is_observation_window_unbounded", + action="store_true", + help="is the observation window unbounded?", + ) + parser.add_argument( + "--include_concept_list", + dest="include_concept_list", + action="store_true", + help="Apply the filter to remove low-frequency concepts", + ) + parser.add_argument( + "--allow_measurement_only", + dest="allow_measurement_only", + action="store_true", + help="Indicate whether we allow patients with measurements only", + ) + parser.add_argument( + "--is_population_estimation", + dest="is_population_estimation", + action="store_true", + help="Indicate whether the cohort is constructed for population level " "estimation", + ) + parser.add_argument( + "--att_type", + dest="att_type", + action="store", + choices=[e.value for e in AttType], + ) + parser.add_argument( + "--exclude_demographic", + dest="exclude_demographic", + action="store_true", + help="Indicate whether we should exclude the demographic prompt of the patient sequence", + ) + parser.add_argument( + "--use_age_group", + dest="use_age_group", + action="store_true", + help="Indicate whether we should age group to represent the age at the first event in the " "patient sequence", + ) + parser.add_argument( + "--single_contribution", + dest="single_contribution", + action="store_true", + help="Indicate whether we should contribute 
once to the training data", + ) return parser.parse_args() diff --git a/src/cehrbert/spark_apps/sql_templates.py b/src/cehrbert/spark_apps/sql_templates.py index d82ec999..559bc5ca 100644 --- a/src/cehrbert/spark_apps/sql_templates.py +++ b/src/cehrbert/spark_apps/sql_templates.py @@ -1,4 +1,4 @@ -measurement_unit_stats_query = ''' +measurement_unit_stats_query = """ WITH measurement_percentile AS ( SELECT @@ -22,7 +22,7 @@ SELECT m.measurement_concept_id, - m.unit_concept_id, + m.unit_concept_id, MEAN(m.value_as_number) AS value_mean, STDDEV(m.value_as_number) AS value_stddev, COUNT(*) AS measurement_freq, @@ -32,11 +32,11 @@ JOIN measurement_percentile AS mp ON m.measurement_concept_id = mp.measurement_concept_id AND m.unit_concept_id = mp.unit_concept_id -WHERE +WHERE m.value_as_number BETWEEN mp.lower_bound AND mp.upper_bound - AND m.visit_occurrence_id IS NOT NULL - AND m.unit_concept_id <> 0 + AND m.visit_occurrence_id IS NOT NULL + AND m.unit_concept_id <> 0 AND m.measurement_concept_id <> 0 GROUP BY m.measurement_concept_id, m.unit_concept_id HAVING COUNT(*) >= 100 -''' +""" diff --git a/src/cehrbert/tools/download_omop_tables.py b/src/cehrbert/tools/download_omop_tables.py index d4f00f38..3fce836f 100644 --- a/src/cehrbert/tools/download_omop_tables.py +++ b/src/cehrbert/tools/download_omop_tables.py @@ -1,96 +1,100 @@ -import configparser import argparse +import configparser import os.path from pyspark.sql import SparkSession from pyspark.sql import functions as f -omop_table_dict = {'person': 'person_id', 'condition_occurrence': 'condition_occurrence_id', - 'measurement': 'measurement_id', - 'drug_exposure': 'drug_exposure_id', - 'procedure_occurrence': 'procedure_occurrence_id', - 'observation': 'observation_id', 'visit_occurrence': 'visit_occurrence_id'} - - -def find_num_of_records( - domain_table_name, - db_properties, - column_name, - spark_session -): - table_max_id = spark_session.read.format("jdbc") \ - .option("driver", db_properties['driver']) \ - .option("url", db_properties['base_url']) \ - .option("dbtable", "(SELECT MAX({}) AS {} FROM {}) as {}".format(column_name, column_name, - domain_table_name, - column_name)) \ - .option("user", db_properties['user']) \ - .option("password", db_properties['password']) \ - .load() \ - .select("{}".format(column_name)).collect()[0]['{}'.format(column_name)] +omop_table_dict = { + "person": "person_id", + "condition_occurrence": "condition_occurrence_id", + "measurement": "measurement_id", + "drug_exposure": "drug_exposure_id", + "procedure_occurrence": "procedure_occurrence_id", + "observation": "observation_id", + "visit_occurrence": "visit_occurrence_id", +} + + +def find_num_of_records(domain_table_name, db_properties, column_name, spark_session): + table_max_id = ( + spark_session.read.format("jdbc") + .option("driver", db_properties["driver"]) + .option("url", db_properties["base_url"]) + .option( + "dbtable", + "(SELECT MAX({}) AS {} FROM {}) as {}".format(column_name, column_name, domain_table_name, column_name), + ) + .option("user", db_properties["user"]) + .option("password", db_properties["password"]) + .load() + .select("{}".format(column_name)) + .collect()[0]["{}".format(column_name)] + ) return table_max_id -def download_omop_tables_with_partitions( - domain_table, - column_name, - db_properties, - output_folder, - spark_session -): - table = spark_session.read.format("jdbc") \ - .option("url", db_properties['base_url']) \ - .option("dbtable", "%s" % domain_table) \ - .option("user", db_properties['user']) \ - 
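# The partitioned download here works in two JDBC round trips: first MAX(<id column>) is read
# with a one-row subquery, then that value becomes upperBound so Spark splits the table into
# numPartitions parallel range scans over partitionColumn. A condensed sketch of the same
# pattern, assuming a reachable JDBC URL, credentials, and a jTDS driver on the classpath
# (the connection values below are placeholders, not part of this patch):
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("jdbc_partition_sketch").getOrCreate()
url = "jdbc:jtds:sqlserver://example-host:1433/omop"  # hypothetical
user, password = "reader", "change-me"                # hypothetical

# Round trip 1: the maximum person_id, used as the partitioning upper bound.
max_id = (
    spark.read.format("jdbc")
    .option("url", url)
    .option("dbtable", "(SELECT MAX(person_id) AS person_id FROM person) AS person_id")
    .option("user", user)
    .option("password", password)
    .load()
    .collect()[0]["person_id"]
)

# Round trip 2: 16 concurrent range queries over person_id, written out as parquet.
person = (
    spark.read.format("jdbc")
    .option("url", url)
    .option("dbtable", "person")
    .option("user", user)
    .option("password", password)
    .option("numPartitions", 16)
    .option("partitionColumn", "person_id")
    .option("lowerBound", 1)
    .option("upperBound", max_id)
    .load()
)
person.write.mode("overwrite").parquet("/tmp/omop_download/person/")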
.option("password", db_properties['password']) \ - .option("numPartitions", 16) \ - .option("partitionColumn", column_name) \ - .option("lowerBound", 1) \ - .option("upperBound", - find_num_of_records(domain_table, db_properties, column_name, spark_session)) \ +def download_omop_tables_with_partitions(domain_table, column_name, db_properties, output_folder, spark_session): + table = ( + spark_session.read.format("jdbc") + .option("url", db_properties["base_url"]) + .option("dbtable", "%s" % domain_table) + .option("user", db_properties["user"]) + .option("password", db_properties["password"]) + .option("numPartitions", 16) + .option("partitionColumn", column_name) + .option("lowerBound", 1) + .option( + "upperBound", + find_num_of_records(domain_table, db_properties, column_name, spark_session), + ) .load() - table.write.mode('overwrite').parquet(output_folder + '/' + str(domain_table) + '/') - - -def download_omop_tables( - domain_table, - db_properties, - output_folder, - spark_session -): - table = spark_session.read.format("jdbc") \ - .option("url", db_properties['base_url']) \ - .option("dbtable", "%s" % domain_table) \ - .option("user", db_properties['user']) \ - .option("password", db_properties['password']) \ + ) + table.write.mode("overwrite").parquet(output_folder + "/" + str(domain_table) + "/") + + +def download_omop_tables(domain_table, db_properties, output_folder, spark_session): + table = ( + spark_session.read.format("jdbc") + .option("url", db_properties["base_url"]) + .option("dbtable", "%s" % domain_table) + .option("user", db_properties["user"]) + .option("password", db_properties["password"]) .load() - table.write.mode('overwrite').parquet(output_folder + '/' + str(domain_table) + '/') + ) + table.write.mode("overwrite").parquet(output_folder + "/" + str(domain_table) + "/") if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Arguments for downloading OMOP tables') - - parser.add_argument('-c', - '--credential_path', - dest='credential_path', - action='store', - help='The path for your database credentials', - required=True) - - parser.add_argument('-tc', - '--domain_table_list', - dest='domain_table_list', - nargs='+', - action='store', - help='The list of domain tables you want to download', - required=True) - - parser.add_argument('-o', - '--output_folder', - dest='output_folder', - action='store', - help='The output folder that stores the domain tables download destination', - required=True) + parser = argparse.ArgumentParser(description="Arguments for downloading OMOP tables") + + parser.add_argument( + "-c", + "--credential_path", + dest="credential_path", + action="store", + help="The path for your database credentials", + required=True, + ) + + parser.add_argument( + "-tc", + "--domain_table_list", + dest="domain_table_list", + nargs="+", + action="store", + help="The list of domain tables you want to download", + required=True, + ) + + parser.add_argument( + "-o", + "--output_folder", + dest="output_folder", + action="store", + help="The output folder that stores the domain tables download destination", + required=True, + ) ARGS = parser.parse_args() spark = SparkSession.builder.appName("Download OMOP tables").getOrCreate() @@ -105,21 +109,22 @@ def download_omop_tables( for item in domain_table_list: try: if item in omop_table_dict: - download_omop_tables_with_partitions(item, omop_table_dict.get(item), properties, - download_folder, spark) + download_omop_tables_with_partitions( + item, omop_table_dict.get(item), properties, 
download_folder, spark + ) else: download_omop_tables(item, properties, download_folder, spark) downloaded_tables.append(item) - print('table: ' + str(item) + ' is downloaded') + print("table: " + str(item) + " is downloaded") except Exception as e: print(str(e)) - print('The following tables were downloaded:' + str(downloaded_tables)) + print("The following tables were downloaded:" + str(downloaded_tables)) patient_splits_folder = os.path.join(download_folder, "patient_splits") if not os.path.exists(patient_splits_folder): - print('Creating the patient splits') + print("Creating the patient splits") person = spark.read.parquet(os.path.join(download_folder, "person")) - train_split, test_split = person.select('person_id').randomSplit([0.8, 0.2], seed=42) + train_split, test_split = person.select("person_id").randomSplit([0.8, 0.2], seed=42) train_split = train_split.withColumn("split", f.lit("train")) test_split = test_split.withColumn("split", f.lit("test")) patient_splits = train_split.unionByName(test_split) diff --git a/src/cehrbert/trainers/model_trainer.py b/src/cehrbert/trainers/model_trainer.py index ecb976cb..5ff74d27 100644 --- a/src/cehrbert/trainers/model_trainer.py +++ b/src/cehrbert/trainers/model_trainer.py @@ -1,6 +1,6 @@ -import os -import json import copy +import json +import os from abc import ABC, abstractmethod from pathlib import Path @@ -9,12 +9,11 @@ import tensorflow as tf from ..data_generators.data_generator_base import AbstractDataGeneratorBase -from ..models.loss_schedulers import CosineLRSchedule from ..models.layers.custom_layers import get_custom_objects -from ..utils.logging_utils import * -from ..utils.model_utils import log_function_decorator, create_folder_if_not_exist, \ - save_training_history -from ..utils.checkpoint_utils import get_checkpoint_epoch, MODEL_CONFIG_FILE +from ..models.loss_schedulers import CosineLRSchedule +from ..utils.checkpoint_utils import MODEL_CONFIG_FILE, get_checkpoint_epoch +from ..utils.logging_utils import logging +from ..utils.model_utils import create_folder_if_not_exist, log_function_decorator, save_training_history class AbstractModel(ABC): @@ -35,16 +34,16 @@ def get_model_folder(self): pass def get_model_metrics_folder(self): - return create_folder_if_not_exist(self.get_model_folder(), 'metrics') + return create_folder_if_not_exist(self.get_model_folder(), "metrics") def get_model_test_metrics_folder(self): - return create_folder_if_not_exist(self.get_model_folder(), 'test_metrics') + return create_folder_if_not_exist(self.get_model_folder(), "test_metrics") def get_model_test_prediction_folder(self): - return create_folder_if_not_exist(self.get_model_folder(), 'test_prediction') + return create_folder_if_not_exist(self.get_model_folder(), "test_prediction") def get_model_history_folder(self): - return create_folder_if_not_exist(self.get_model_folder(), 'history') + return create_folder_if_not_exist(self.get_model_folder(), "history") @classmethod def get_logger(cls): @@ -58,22 +57,23 @@ class AbstractConceptEmbeddingTrainer(AbstractModel): min_num_of_concepts = 5 def __init__( - self, - training_data_parquet_path: str, - model_folder: str, - batch_size: int, - epochs: int, - learning_rate: float, - checkpoint_name: str = None, - val_data_parquet_path: str = None, - tf_board_log_path: str = None, - shuffle_training_data: bool = True, - cache_dataset: bool = False, - use_dask: bool = False, - save_checkpoint: bool = False, - save_freq: int = 0, - shuffle_records: bool = False, - *args, **kwargs + self, + 
training_data_parquet_path: str, + model_folder: str, + batch_size: int, + epochs: int, + learning_rate: float, + checkpoint_name: str = None, + val_data_parquet_path: str = None, + tf_board_log_path: str = None, + shuffle_training_data: bool = True, + cache_dataset: bool = False, + use_dask: bool = False, + save_checkpoint: bool = False, + save_freq: int = 0, + shuffle_records: bool = False, + *args, + **kwargs, ): self._training_data_parquet_path = training_data_parquet_path @@ -107,23 +107,23 @@ def __init__( super(AbstractConceptEmbeddingTrainer, self).__init__(*args, **kwargs) self.get_logger().info( - f'training_data_parquet_path: {training_data_parquet_path}\n' - f'val_data_parquet_path: {val_data_parquet_path}\n' - f'batch_size: {batch_size}\n' - f'epochs: {epochs}\n' - f'learning_rate: {learning_rate}\n' - f'model_folder: {model_folder}\n' - f'checkpoint_name: {checkpoint_name}\n' - f'tf_board_log_path: {tf_board_log_path}\n' - f'shuffle_training_data: {shuffle_training_data}\n' - f'cache_dataset: {cache_dataset}\n' - f'use_dask: {use_dask}\n' - f'save_checkpoint: {save_checkpoint}\n' - f'save_freq: {save_freq}\n' - f'shuffle_records: {shuffle_records}\n' + f"training_data_parquet_path: {training_data_parquet_path}\n" + f"val_data_parquet_path: {val_data_parquet_path}\n" + f"batch_size: {batch_size}\n" + f"epochs: {epochs}\n" + f"learning_rate: {learning_rate}\n" + f"model_folder: {model_folder}\n" + f"checkpoint_name: {checkpoint_name}\n" + f"tf_board_log_path: {tf_board_log_path}\n" + f"shuffle_training_data: {shuffle_training_data}\n" + f"cache_dataset: {cache_dataset}\n" + f"use_dask: {use_dask}\n" + f"save_checkpoint: {save_checkpoint}\n" + f"save_freq: {save_freq}\n" + f"shuffle_records: {shuffle_records}\n" ) - self.get_logger().info('Saving the model configuration') + self.get_logger().info("Saving the model configuration") self.save_model_config() @abstractmethod @@ -133,7 +133,7 @@ def _load_dependencies(self): @log_function_decorator def _load_data(self, data_parquet_path): if not os.path.exists(data_parquet_path): - raise FileExistsError(f'{data_parquet_path} does not exist!') + raise FileExistsError(f"{data_parquet_path} does not exist!") if self._use_dask: return dd.read_parquet(data_parquet_path) @@ -144,27 +144,29 @@ def _load_data(self, data_parquet_path): def create_data_generator(self) -> AbstractDataGeneratorBase: """ Prepare _training_data for the model such as tokenize concepts. + :return: """ - pass def create_val_data_generator(self) -> AbstractDataGeneratorBase: """ Prepare _training_data for the model such as tokenize concepts. + :return: """ return None def train_model(self): """ - Train the model and save the history metrics into the model folder + Train the model and save the history metrics into the model folder. 
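# train_model below feeds Keras through tf.data: the Python batch generator is wrapped with
# Dataset.from_generator and prefetch(AUTOTUNE) so batch preparation overlaps with training.
# A minimal sketch of that wiring with a toy generator (illustrative only; the real generator
# and dtype schema come from AbstractDataGeneratorBase.get_tf_dataset_schema):
import numpy as np
import tensorflow as tf


def toy_batch_generator():
    # Stands in for data_generator.create_batch_generator.
    for _ in range(10):
        features = np.random.rand(32, 8).astype(np.float32)
        labels = np.random.randint(0, 2, size=(32,)).astype(np.int32)
        yield features, labels


toy_dataset = tf.data.Dataset.from_generator(
    toy_batch_generator,
    output_types=(tf.float32, tf.int32),  # mirrors get_tf_dataset_schema()
).prefetch(tf.data.experimental.AUTOTUNE)

for features, labels in toy_dataset.take(1):
    print(features.shape, labels.shape)  # (32, 8) (32,)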
+ :return: """ data_generator = self.create_data_generator() steps_per_epoch = data_generator.get_steps_per_epoch() dataset = tf.data.Dataset.from_generator( data_generator.create_batch_generator, - output_types=(data_generator.get_tf_dataset_schema()) + output_types=(data_generator.get_tf_dataset_schema()), ).prefetch(tf.data.experimental.AUTOTUNE) if self._cache_dataset: @@ -178,7 +180,7 @@ def train_model(self): val_steps_per_epoch = val_data_generator.get_steps_per_epoch() val_dataset = tf.data.Dataset.from_generator( val_data_generator.create_batch_generator, - output_types=(val_data_generator.get_tf_dataset_schema()) + output_types=(val_data_generator.get_tf_dataset_schema()), ).prefetch(tf.data.experimental.AUTOTUNE) history = self._model.fit( @@ -190,7 +192,7 @@ def train_model(self): callbacks=self._get_callbacks(), validation_freq=1 if val_dataset is not None else None, initial_epoch=self._current_epoch, - use_multiprocessing=True + use_multiprocessing=True, ) save_training_history(history, self.get_model_history_folder()) @@ -200,71 +202,61 @@ def restore_from_checkpoint(self): current_epoch = get_checkpoint_epoch(existing_model_path) self._current_epoch = current_epoch self._epochs += current_epoch - self.get_logger().info( - f'The {self} model will be loaded from {existing_model_path}') - model = tf.keras.models.load_model( - existing_model_path, custom_objects=get_custom_objects() - ) + self.get_logger().info(f"The {self} model will be loaded from {existing_model_path}") + model = tf.keras.models.load_model(existing_model_path, custom_objects=get_custom_objects()) return model def _get_callbacks(self): tensor_board_callback = tf.keras.callbacks.TensorBoard(log_dir=self._tf_board_log_path) model_checkpoint_args = { - 'filepath': self.get_model_path_epoch(), - 'save_best_only': True, - 'monitor': 'loss', - 'verbose': 1 + "filepath": self.get_model_path_epoch(), + "save_best_only": True, + "monitor": "loss", + "verbose": 1, } - model_checkpoint = tf.keras.callbacks.ModelCheckpoint( - **model_checkpoint_args - ) + model_checkpoint = tf.keras.callbacks.ModelCheckpoint(**model_checkpoint_args) learning_rate_scheduler = tf.keras.callbacks.LearningRateScheduler( CosineLRSchedule(lr_high=self._learning_rate, lr_low=1e-8, initial_period=10), - verbose=1) + verbose=1, + ) - callbacks = [ - tensor_board_callback, - model_checkpoint, - learning_rate_scheduler - ] + callbacks = [tensor_board_callback, model_checkpoint, learning_rate_scheduler] # Additional step-based checkpoint callback if self._save_checkpoint: + def on_epoch_begin(self, epoch, logs=None): self._current_epoch = epoch self._last_batch_seen = -1 self._batches_seen_since_last_saving = 0 frequency_checkpoint_args = copy.deepcopy(model_checkpoint_args) - frequency_checkpoint_args['filepath'] = self.get_model_path_step() - frequency_checkpoint_args['save_freq'] = self._save_freq - frequency_checkpoint_args['name'] = ' ' + frequency_checkpoint_args["filepath"] = self.get_model_path_step() + frequency_checkpoint_args["save_freq"] = self._save_freq + frequency_checkpoint_args["name"] = " " # Monkey patch the on_epoch_begin in ModelCheckpoint because we need to clear out _last_batch_seen and # _batches_seen_since_last_saving So the batch number in the model checkpoints created is a multiple of # save_freq frequencyModelCheckpoint = tf.keras.callbacks.ModelCheckpoint frequencyModelCheckpoint.on_epoch_begin = on_epoch_begin - callbacks.append( - frequencyModelCheckpoint( - **frequency_checkpoint_args - ) - ) + 
callbacks.append(frequencyModelCheckpoint(**frequency_checkpoint_args)) return callbacks def get_model_folder(self): """ - Infer the model folder from the property model_path + Infer the model folder from the property model_path. + :return: """ return str(Path(self._model_folder)) def get_model_path_epoch(self): - model_name = f"{self.get_model_name()}" + '_epoch_{epoch:02d}_batch_final.h5' + model_name = f"{self.get_model_name()}" + "_epoch_{epoch:02d}_batch_final.h5" return os.path.join(self.get_model_folder(), model_name) def get_model_path_step(self): - model_name = f"{self.get_model_name()}" + '_epoch_{epoch:02d}_batch_{batch:02d}.h5' + model_name = f"{self.get_model_name()}" + "_epoch_{epoch:02d}_batch_{batch:02d}.h5" return os.path.join(self.get_model_folder(), model_name) def get_tokenizer_name(self): @@ -287,25 +279,28 @@ def checkpoint_exists(self): def get_model_config(self): def remove_first_underscore(name): - if name[0] == '_': + if name[0] == "_": return name[1:] return name model_config = { - remove_first_underscore(k): v for k, v in self.__dict__.items() + remove_first_underscore(k): v + for k, v in self.__dict__.items() if type(v) in (int, float, str, bool, type(None)) } - model_config.update({ - 'model_name': self.get_model_name(), - 'tokenizer': self.get_tokenizer_name() - }) + model_config.update( + { + "model_name": self.get_model_name(), + "tokenizer": self.get_tokenizer_name(), + } + ) return model_config def save_model_config(self): model_config = self.get_model_config() model_config_path = os.path.join(self.get_model_folder(), MODEL_CONFIG_FILE) if not os.path.exists(model_config_path): - with open(model_config_path, 'w') as f: + with open(model_config_path, "w") as f: f.write(json.dumps(model_config)) @abstractmethod diff --git a/src/cehrbert/trainers/train_cehr_bert.py b/src/cehrbert/trainers/train_cehr_bert.py index e9f53abc..0ddd3018 100644 --- a/src/cehrbert/trainers/train_cehr_bert.py +++ b/src/cehrbert/trainers/train_cehr_bert.py @@ -1,34 +1,35 @@ import tensorflow as tf +from tensorflow.keras import optimizers + +from ..data_generators.data_generator_base import ( + BertDataGenerator, + BertVisitPredictionDataGenerator, + MedBertDataGenerator, +) +from ..keras_transformer.bert import MaskedPenalizedSparseCategoricalCrossentropy, masked_perplexity +from ..models.bert_models import transformer_bert_model +from ..models.bert_models_visit_prediction import transformer_bert_model_visit_prediction from ..models.parse_args import create_parse_args_base_bert from ..trainers.model_trainer import AbstractConceptEmbeddingTrainer from ..utils.model_utils import tokenize_one_field -from ..models.bert_models_visit_prediction import transformer_bert_model_visit_prediction -from ..models.bert_models import transformer_bert_model -from ..data_generators.data_generator_base import * - -from ..keras_transformer.bert import ( - masked_perplexity, - MaskedPenalizedSparseCategoricalCrossentropy -) - -from tensorflow.keras import optimizers class VanillaBertTrainer(AbstractConceptEmbeddingTrainer): confidence_penalty = 0.1 def __init__( - self, - embedding_size: int, - context_window_size: int, - depth: int, - num_heads: int, - include_visit_prediction: bool, - include_prolonged_length_stay: bool, - use_time_embedding: bool, - use_behrt: bool, - time_embeddings_size: int, - *args, **kwargs + self, + embedding_size: int, + context_window_size: int, + depth: int, + num_heads: int, + include_visit_prediction: bool, + include_prolonged_length_stay: bool, + use_time_embedding: bool, 
+ use_behrt: bool, + time_embeddings_size: int, + *args, + **kwargs, ): self._embedding_size = embedding_size self._context_window_size = context_window_size @@ -43,53 +44,49 @@ def __init__( super(VanillaBertTrainer, self).__init__(*args, **kwargs) self.get_logger().info( - f'{self} will be trained with the following parameters:\n' - f'model_name: {self.get_model_name()}\n' - f'tokenizer_path: {self.get_tokenizer_path()}\n' - f'visit_tokenizer_path: {self.get_visit_tokenizer_path()}\n' - f'embedding_size: {embedding_size}\n' - f'context_window_size: {context_window_size}\n' - f'depth: {depth}\n' - f'num_heads: {num_heads}\n' - f'include_visit_prediction: {include_visit_prediction}\n' - f'include_prolonged_length_stay: {include_prolonged_length_stay}\n' - f'use_time_embeddings: {use_time_embedding}\n' - f'use_behrt: {use_behrt}\n' - f'time_embeddings_size: {time_embeddings_size}') + f"{self} will be trained with the following parameters:\n" + f"model_name: {self.get_model_name()}\n" + f"tokenizer_path: {self.get_tokenizer_path()}\n" + f"visit_tokenizer_path: {self.get_visit_tokenizer_path()}\n" + f"embedding_size: {embedding_size}\n" + f"context_window_size: {context_window_size}\n" + f"depth: {depth}\n" + f"num_heads: {num_heads}\n" + f"include_visit_prediction: {include_visit_prediction}\n" + f"include_prolonged_length_stay: {include_prolonged_length_stay}\n" + f"use_time_embeddings: {use_time_embedding}\n" + f"use_behrt: {use_behrt}\n" + f"time_embeddings_size: {time_embeddings_size}" + ) def _load_dependencies(self): - self._tokenizer = tokenize_one_field( - self._training_data, - 'concept_ids', - 'token_ids', - self.get_tokenizer_path() - ) + self._tokenizer = tokenize_one_field(self._training_data, "concept_ids", "token_ids", self.get_tokenizer_path()) if self._include_visit_prediction: self._visit_tokenizer = tokenize_one_field( self._training_data, - 'visit_concept_ids', - 'visit_token_ids', + "visit_concept_ids", + "visit_token_ids", self.get_visit_tokenizer_path(), - oov_token='-1' + oov_token="-1", ) def create_data_generator(self) -> BertDataGenerator: parameters = { - 'training_data': self._training_data, - 'batch_size': self._batch_size, - 'max_seq_len': self._context_window_size, - 'min_num_of_concepts': self.min_num_of_concepts, - 'concept_tokenizer': self._tokenizer, - 'is_random_cursor': True + "training_data": self._training_data, + "batch_size": self._batch_size, + "max_seq_len": self._context_window_size, + "min_num_of_concepts": self.min_num_of_concepts, + "concept_tokenizer": self._tokenizer, + "is_random_cursor": True, } data_generator_class = BertDataGenerator if self._include_visit_prediction: - parameters['visit_tokenizer'] = self._visit_tokenizer + parameters["visit_tokenizer"] = self._visit_tokenizer data_generator_class = BertVisitPredictionDataGenerator elif self._include_prolonged_length_stay: data_generator_class = MedBertDataGenerator @@ -98,13 +95,12 @@ def create_data_generator(self) -> BertDataGenerator: def _create_model(self): strategy = tf.distribute.MirroredStrategy() - self.get_logger().info('Number of devices: {}'.format(strategy.num_replicas_in_sync)) + self.get_logger().info("Number of devices: {}".format(strategy.num_replicas_in_sync)) with strategy.scope(): if self.checkpoint_exists(): model = self.restore_from_checkpoint() else: - optimizer = optimizers.Adam( - lr=self._learning_rate, beta_1=0.9, beta_2=0.999) + optimizer = optimizers.Adam(lr=self._learning_rate, beta_1=0.9, beta_2=0.999) if self._include_visit_prediction: model = 
transformer_bert_model_visit_prediction( @@ -115,14 +111,12 @@ def _create_model(self): depth=self._depth, num_heads=self._num_heads, use_time_embedding=self._use_time_embedding, - time_embeddings_size=self._time_embeddings_size + time_embeddings_size=self._time_embeddings_size, ) losses = { - 'concept_predictions': MaskedPenalizedSparseCategoricalCrossentropy( - self.confidence_penalty), - 'visit_predictions': MaskedPenalizedSparseCategoricalCrossentropy( - self.confidence_penalty) + "concept_predictions": MaskedPenalizedSparseCategoricalCrossentropy(self.confidence_penalty), + "visit_predictions": MaskedPenalizedSparseCategoricalCrossentropy(self.confidence_penalty), } else: model = transformer_bert_model( @@ -134,28 +128,30 @@ def _create_model(self): use_time_embedding=self._use_time_embedding, time_embeddings_size=self._time_embeddings_size, use_behrt=self._use_behrt, - include_prolonged_length_stay=self._include_prolonged_length_stay + include_prolonged_length_stay=self._include_prolonged_length_stay, ) losses = { - 'concept_predictions': MaskedPenalizedSparseCategoricalCrossentropy( - self.confidence_penalty) + "concept_predictions": MaskedPenalizedSparseCategoricalCrossentropy(self.confidence_penalty) } if self._include_prolonged_length_stay: - losses['prolonged_length_stay'] = tf.losses.BinaryCrossentropy() + losses["prolonged_length_stay"] = tf.losses.BinaryCrossentropy() - model.compile(optimizer, loss=losses, - metrics={'concept_predictions': masked_perplexity}) + model.compile( + optimizer, + loss=losses, + metrics={"concept_predictions": masked_perplexity}, + ) return model def get_model_name(self): - return 'CEHR_BERT' + return "CEHR_BERT" def get_model_config(self): model_config = super().get_model_config() if self._include_visit_prediction: - model_config['visit_tokenizer'] = self.get_visit_tokenizer_name() + model_config["visit_tokenizer"] = self.get_visit_tokenizer_name() return model_config @@ -177,7 +173,7 @@ def main(args): time_embeddings_size=args.time_embeddings_size, use_behrt=args.use_behrt, use_dask=args.use_dask, - tf_board_log_path=args.tf_board_log_path + tf_board_log_path=args.tf_board_log_path, ).train_model() diff --git a/src/cehrbert/utils/checkpoint_utils.py b/src/cehrbert/utils/checkpoint_utils.py index 5d081394..89546019 100644 --- a/src/cehrbert/utils/checkpoint_utils.py +++ b/src/cehrbert/utils/checkpoint_utils.py @@ -25,20 +25,18 @@ def get_checkpoint_epoch(checkpoint_path): return epoch raise RuntimeError( - f'The model checkpoint at {checkpoint_path} does not match any patterns below:\n' - f'{LEGACY_MODEL_CHECKPOINT_PATTERN.pattern}\n' - f'{EPOCH_CHECKPOINT_PATTERN.pattern}\n' - f'{BATCH_CHECKPOINT_PATTERN.pattern}\n' + f"The model checkpoint at {checkpoint_path} does not match any patterns below:\n" + f"{LEGACY_MODEL_CHECKPOINT_PATTERN.pattern}\n" + f"{EPOCH_CHECKPOINT_PATTERN.pattern}\n" + f"{BATCH_CHECKPOINT_PATTERN.pattern}\n" ) -def find_latest_checkpoint_path( - checkpoint_dir -): +def find_latest_checkpoint_path(checkpoint_dir): # Try to find the checkpoint with the legacy model naming convention legacy_checkpoint_path_dict = find_latest_checkpoint_legacy_model_path(checkpoint_dir) if legacy_checkpoint_path_dict: - return legacy_checkpoint_path_dict['checkpoint_path'] + return legacy_checkpoint_path_dict["checkpoint_path"] # Try to find the checkpoints associated with batch or epoch epoch_checkpoint_path_dict = find_latest_epoch_checkpoint_path(checkpoint_dir) @@ -46,22 +44,22 @@ def find_latest_checkpoint_path( # We always prefer the epoch 
checkpoint over the batch checkpoint if they have the same epoch if epoch_checkpoint_path_dict and batch_checkpoint_path_dict: - if epoch_checkpoint_path_dict['epoch'] >= batch_checkpoint_path_dict['epoch']: - return epoch_checkpoint_path_dict['checkpoint_path'] + if epoch_checkpoint_path_dict["epoch"] >= batch_checkpoint_path_dict["epoch"]: + return epoch_checkpoint_path_dict["checkpoint_path"] else: - return batch_checkpoint_path_dict['checkpoint_path'] + return batch_checkpoint_path_dict["checkpoint_path"] if epoch_checkpoint_path_dict: - return epoch_checkpoint_path_dict['checkpoint_path'] + return epoch_checkpoint_path_dict["checkpoint_path"] if batch_checkpoint_path_dict: - return batch_checkpoint_path_dict['checkpoint_path'] + return batch_checkpoint_path_dict["checkpoint_path"] raise RuntimeError( - f'Could not discover any model checkpoint in {checkpoint_dir} matching patterns\n' - f'{LEGACY_MODEL_CHECKPOINT_PATTERN.pattern}\n' - f'{EPOCH_CHECKPOINT_PATTERN.pattern}\n' - f'{BATCH_CHECKPOINT_PATTERN.pattern}\n' + f"Could not discover any model checkpoint in {checkpoint_dir} matching patterns\n" + f"{LEGACY_MODEL_CHECKPOINT_PATTERN.pattern}\n" + f"{EPOCH_CHECKPOINT_PATTERN.pattern}\n" + f"{BATCH_CHECKPOINT_PATTERN.pattern}\n" ) @@ -92,9 +90,9 @@ def find_latest_checkpoint_legacy_model_path(checkpoint_dir): checkpoints.sort(reverse=True, key=lambda x: (x[0], -x[1])) if checkpoints: return { - 'epoch': checkpoints[0][0], - 'loss': checkpoints[0][1], - 'checkpoint_path': checkpoints[0][2] + "epoch": checkpoints[0][0], + "loss": checkpoints[0][1], + "checkpoint_path": checkpoints[0][2], } return None @@ -107,9 +105,7 @@ def get_latest_epoch_checkpoint_epoch(filename): return None -def find_latest_epoch_checkpoint_path( - checkpoint_dir -): +def find_latest_epoch_checkpoint_path(checkpoint_dir): # List all files in the checkpoint directory files = os.listdir(checkpoint_dir) @@ -124,10 +120,7 @@ def find_latest_epoch_checkpoint_path( # Sort the checkpoints by epoch in descending order (to get the latest one first) checkpoints.sort(reverse=True, key=lambda x: x[0]) if checkpoints: - return { - 'epoch': checkpoints[0][0], - 'checkpoint_path': checkpoints[0][1] - } + return {"epoch": checkpoints[0][0], "checkpoint_path": checkpoints[0][1]} return None @@ -156,48 +149,50 @@ def find_latest_batch_checkpoint_path(checkpoint_dir): # Return the latest checkpoint filename, or None if no checkpoint was found if checkpoints: return { - 'epoch': checkpoints[0][0], - 'batch': checkpoints[0][1], - 'checkpoint_path': checkpoints[0][2] + "epoch": checkpoints[0][0], + "batch": checkpoints[0][1], + "checkpoint_path": checkpoints[0][2], } return None def find_tokenizer_path(model_folder: str): import glob + file_path = os.path.join(model_folder, MODEL_CONFIG_FILE) if os.path.exists(file_path): # Open the JSON file for reading - with open(file_path, 'r') as file: + with open(file_path, "r") as file: model_config = json.load(file) - tokenizer_name = model_config['tokenizer'] + tokenizer_name = model_config["tokenizer"] tokenizer_path = os.path.join(model_folder, tokenizer_name) return tokenizer_path else: - for candidate_name in glob.glob(os.path.join(model_folder, '*tokenizer.pickle')): - if 'visit_tokenizer.pickle' not in candidate_name: + for candidate_name in glob.glob(os.path.join(model_folder, "*tokenizer.pickle")): + if "visit_tokenizer.pickle" not in candidate_name: return os.path.join(model_folder, candidate_name) - raise RuntimeError(f'Could not discover any tokenizer in {model_folder} matching the 
pattern *tokenizer.pickle') + raise RuntimeError(f"Could not discover any tokenizer in {model_folder} matching the pattern *tokenizer.pickle") def find_visit_tokenizer_path(model_folder: str): import glob + file_path = os.path.join(model_folder, MODEL_CONFIG_FILE) if os.path.exists(file_path): # Open the JSON file for reading - with open(file_path, 'r') as file: + with open(file_path, "r") as file: model_config = json.load(file) - visit_tokenizer_name = model_config['visit_tokenizer'] + visit_tokenizer_name = model_config["visit_tokenizer"] visit_tokenizer_path = os.path.join(model_folder, visit_tokenizer_name) return visit_tokenizer_path else: - for candidate_name in glob.glob(os.path.join(model_folder, '*visit_tokenizer.pickle')): + for candidate_name in glob.glob(os.path.join(model_folder, "*visit_tokenizer.pickle")): return os.path.join(model_folder, candidate_name) raise RuntimeError( - f'Could not discover any tokenizer in {model_folder} matching the pattern *_visit_tokenizer.pickle' + f"Could not discover any tokenizer in {model_folder} matching the pattern *_visit_tokenizer.pickle" ) -MODEL_CONFIG_FILE = 'model_config.json' +MODEL_CONFIG_FILE = "model_config.json" diff --git a/src/cehrbert/utils/logging_utils.py b/src/cehrbert/utils/logging_utils.py index 07cf7ac9..ddacadb7 100644 --- a/src/cehrbert/utils/logging_utils.py +++ b/src/cehrbert/utils/logging_utils.py @@ -8,7 +8,7 @@ def add_console_logging(): handler = logging.StreamHandler(sys.stdout) handler.setLevel(logging.DEBUG) - formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") handler.setFormatter(formatter) root.addHandler(handler) diff --git a/src/cehrbert/utils/model_utils.py b/src/cehrbert/utils/model_utils.py index 049f4f1b..5f80c5a1 100644 --- a/src/cehrbert/utils/model_utils.py +++ b/src/cehrbert/utils/model_utils.py @@ -2,19 +2,19 @@ import inspect import logging import os -import pathlib import pickle import random import re from collections import Counter from itertools import chain -from typing import Dict, Union, Tuple, List +from pathlib import Path +from typing import Dict, List, Tuple, Union import numpy as np import pandas as pd import tensorflow as tf -from dask.dataframe import DataFrame as dd_dataframe -from pandas import DataFrame as pd_dataframe +from dask.dataframe import DataFrame as DaskDataFrame +from pandas import DataFrame as PandasDataFrame from sklearn import metrics from sklearn.linear_model import LogisticRegression from sklearn.model_selection import GridSearchCV @@ -22,25 +22,33 @@ from tensorflow.keras.models import Model from xgboost import XGBClassifier -from ..data_generators.data_classes import TokenizeFieldInfo -from ..data_generators.tokenizer import ConceptTokenizer +from cehrbert.data_generators.data_classes import TokenizeFieldInfo +from cehrbert.data_generators.tokenizer import ConceptTokenizer +DEFAULT_OOV_TOKEN = "-1" DECIMAL_PLACE = 4 LOGGER = logging.getLogger(__name__) -def create_folder_if_not_exist(folder, sub_folder_name): +def create_folder_if_not_exist(folder: str, sub_folder_name: str) -> Path: """ - Create the sub-folder if not exists. Will do not thing if the sub-folder already exists. + Creates a subfolder if it does not exist and returns the full Path object. - :param folder: - :param sub_folder_name: - :return: + Args: + folder (str): The parent folder where the subfolder will be created. 
+ sub_folder_name (str): The name of the subfolder to be created. + + Returns: + Path: The full path to the created or existing subfolder. + + Example: + create_sub_folder("/home/user", "new_folder") + # Creates /home/user/new_folder if it doesn't exist and returns the path. """ - sub_folder = os.path.join(folder, sub_folder_name) - if not os.path.exists(sub_folder): - LOGGER.info(f'Create folder: {sub_folder}') - pathlib.Path(sub_folder).mkdir(parents=True, exist_ok=True) + sub_folder = Path(folder) / sub_folder_name + if not sub_folder.exists(): + LOGGER.info("Create folder: %s", sub_folder) + sub_folder.mkdir(parents=True, exist_ok=True) return sub_folder @@ -52,40 +60,49 @@ def wrapper(self, *args, **kwargs): beginning = datetime.datetime.now() logging.getLogger(function.__name__).info( - f'Started running {module_name}: {function_name} at line {line_no}') + "Started running %s: %s at line %s", module_name, function_name, line_no + ) output = function(self, *args, **kwargs) ending = datetime.datetime.now() logging.getLogger(function.__name__).info( - f'Took {ending - beginning} to run {module_name}: {function_name}.') + "Took %s to run %s: %s.", ending - beginning, module_name, function_name + ) return output return wrapper @log_function_decorator -def tokenize_one_field(training_data: Union[pd_dataframe, dd_dataframe], - column_name, tokenized_column_name, tokenizer_path, - oov_token='-1', encode=True, recreate=False): +def tokenize_one_field( + training_data: Union[PandasDataFrame, DaskDataFrame], + column_name, + tokenized_column_name, + tokenizer_path, + oov_token=DEFAULT_OOV_TOKEN, + encode=True, + recreate=False, +): """ - Tokenize the concept sequence and save the tokenizer as a pickle file + Tokenize the concept sequence and save the tokenizer as a pickle file. + :return: """ - tokenize_fields_info = [TokenizeFieldInfo(column_name=column_name, - tokenized_column_name=tokenized_column_name)] - return tokenize_multiple_fields(training_data, - tokenize_fields_info, - tokenizer_path, - oov_token, - encode, - recreate) + tokenize_fields_info = [TokenizeFieldInfo(column_name=column_name, tokenized_column_name=tokenized_column_name)] + return tokenize_multiple_fields(training_data, tokenize_fields_info, tokenizer_path, oov_token, encode, recreate) @log_function_decorator -def tokenize_multiple_fields(training_data: Union[pd_dataframe, dd_dataframe], - tokenize_fields_info: List[TokenizeFieldInfo], tokenizer_path, - oov_token='-1', encode=True, recreate=False): +def tokenize_multiple_fields( + training_data: Union[PandasDataFrame, DaskDataFrame], + tokenize_fields_info: List[TokenizeFieldInfo], + tokenizer_path, + oov_token=DEFAULT_OOV_TOKEN, + encode=True, + recreate=False, +): """ - Tokenize a list of fields + Tokenize a list of fields. + :param training_data: :param tokenize_fields_info: :param tokenizer_path: @@ -97,38 +114,45 @@ def tokenize_multiple_fields(training_data: Union[pd_dataframe, dd_dataframe], def tokenize_one_column(_tokenize_field_info: TokenizeFieldInfo): """ - Tokenize a field + Tokenize a field. 
+ :param _tokenize_field_info: :return: """ - if isinstance(training_data, dd_dataframe): + if isinstance(training_data, DaskDataFrame): training_data[_tokenize_field_info.tokenized_column_name] = training_data[ - _tokenize_field_info.column_name].map_partitions( + _tokenize_field_info.column_name + ].map_partitions( lambda ds: pd.Series( - tokenizer.encode(map(lambda t: list(t[1]), ds.iteritems()), - is_generator=True), - name=_tokenize_field_info.tokenized_column_name), meta='iterable') + tokenizer.encode(map(lambda t: list(t[1]), ds.iteritems()), is_generator=True), + name=_tokenize_field_info.tokenized_column_name, + ), + meta="iterable", + ) else: - training_data[_tokenize_field_info.column_name] = training_data[ - _tokenize_field_info.column_name].apply(list) + training_data[_tokenize_field_info.column_name] = training_data[_tokenize_field_info.column_name].apply( + list + ) training_data[_tokenize_field_info.tokenized_column_name] = tokenizer.encode( - training_data[_tokenize_field_info.column_name]) + training_data[_tokenize_field_info.column_name] + ) if not os.path.exists(tokenizer_path) or recreate: tokenizer = ConceptTokenizer(oov_token=oov_token) for tokenize_field_info in tokenize_fields_info: tokenizer.fit_on_concept_sequences(training_data[tokenize_field_info.column_name]) else: - logging.getLogger(__name__).info( - f'Loading the existing tokenizer from {tokenizer_path}') - tokenizer = pickle.load(open(tokenizer_path, 'rb')) + logging.getLogger(__name__).info("Loading the existing tokenizer from %s", tokenizer_path) + with open(tokenizer_path, "rb") as f: + tokenizer = pickle.load(f) if encode: for tokenize_field_info in tokenize_fields_info: tokenize_one_column(tokenize_field_info) if not os.path.exists(tokenizer_path) or recreate: - pickle.dump(tokenizer, open(tokenizer_path, 'wb')) + with open(tokenizer_path, "wb") as f: + pickle.dump(tokenizer, f) return tokenizer @@ -138,8 +162,8 @@ def convert_to_list_of_lists(concept_lists): @log_function_decorator def run_model( - model, - dataset: Union[Dataset, Tuple[np.ndarray, np.ndarray]], + model, + dataset: Union[Dataset, Tuple[np.ndarray, np.ndarray]], ): if isinstance(dataset, Dataset): x = dataset.map(lambda _x, _y: _x) @@ -148,49 +172,43 @@ def run_model( elif len(dataset) == 2: x, y = dataset else: - raise TypeError('Only numpy array and tensorflow Dataset are supported types.') + raise TypeError("Only numpy array and tensorflow Dataset are supported types.") if isinstance(model, Model): prob = model.predict(x) elif isinstance(model, (LogisticRegression, XGBClassifier, GridSearchCV)): prob = model.predict_proba(x)[:, 1] else: - raise TypeError(f'Unknown type for the model {type(model)}') + raise TypeError(f"Unknown type for the model {type(model)}") return np.asarray(prob), y -def calculate_pr_auc( - labels, - probabilities -): +def calculate_pr_auc(labels, probabilities): """ - Calculate PR AUC given labels and probabilities + Calculate PR AUC given labels and probabilities. 
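# calculate_pr_auc integrates precision over recall: sklearn's precision_recall_curve supplies
# the (recall, precision) points and metrics.auc does the trapezoidal integration, the same
# two-call pattern used in the function body below. A tiny self-contained check with made-up
# labels and scores (illustrative only):
import numpy as np
from sklearn import metrics

toy_labels = np.array([0, 0, 1, 1, 1, 0, 1, 0])
toy_scores = np.array([0.10, 0.40, 0.35, 0.80, 0.70, 0.20, 0.90, 0.50])

precisions, recalls, _ = metrics.precision_recall_curve(toy_labels, toy_scores)
print(round(metrics.auc(recalls, precisions), 4))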
:param labels: :param probabilities: :return: """ # Calculate precision-recall auc - precisions, recalls, _ = metrics.precision_recall_curve( - labels, - np.asarray(probabilities) - ) + precisions, recalls, _ = metrics.precision_recall_curve(labels, np.asarray(probabilities)) return metrics.auc(recalls, precisions) @log_function_decorator def compute_binary_metrics( - model, - test_data: Union[Dataset, Tuple[np.ndarray, np.ndarray]], - metrics_folder, - evaluation_model_folder: str = None, - model_name: str = None, - extra_info: dict = None, - calculate_ci: bool = True + model, + test_data: Union[Dataset, Tuple[np.ndarray, np.ndarray]], + metrics_folder, + evaluation_model_folder: str = None, + model_name: str = None, + extra_info: dict = None, + calculate_ci: bool = True, ): """ - Compute Recall, Precision, F1-score and PR-AUC for the test data + Compute Recall, Precision, F1-score and PR-AUC for the test data. :param model: :param test_data: @@ -202,13 +220,10 @@ def compute_binary_metrics( :return: """ - def compute_confidence_interval( - x, - y, - metric_func - ): + def compute_confidence_interval(x, y, metric_func): """ - A helper function to calculate the 95% confidence interval for a given metric function + A helper function to calculate the 95% confidence interval for a given metric function. + :param x: :param y: :param metric_func: @@ -219,12 +234,7 @@ def compute_confidence_interval( bootstrap_metrics = [] total = len(y) for _ in range(1001): - x_sample, y_sample = zip( - *random.choices( - list(zip(x, y)), - k=total - ) - ) + x_sample, y_sample = zip(*random.choices(list(zip(x, y)), k=total)) bootstrap_metrics.append(metric_func(x_sample, y_sample)) bootstrap_metrics = sorted(bootstrap_metrics) @@ -233,21 +243,15 @@ def compute_confidence_interval( validate_folder(metrics_folder) - probabilities, labels = run_model( - model, - test_data - ) + probabilities, labels = run_model(model, test_data) predictions = (np.asarray(probabilities) > 0.5).astype(int) - recall = metrics.recall_score(labels, predictions, average='binary') - precision = metrics.precision_score(labels, predictions, average='binary') - f1_score = metrics.f1_score(labels, predictions, average='binary') + recall = metrics.recall_score(labels, predictions, average="binary") + precision = metrics.precision_score(labels, predictions, average="binary") + f1_score = metrics.f1_score(labels, predictions, average="binary") # Calculate precision-recall auc - precisions, recalls, pr_auc_thresholds = metrics.precision_recall_curve( - labels, - np.asarray(probabilities) - ) + precisions, recalls, pr_auc_thresholds = metrics.precision_recall_curve(labels, np.asarray(probabilities)) pr_auc = metrics.auc(recalls, precisions) # Calculate the best threshold for pr auc @@ -258,9 +262,7 @@ def compute_confidence_interval( # Calculate the 95% CI for pr_auc if calculate_ci: pr_auc_lower, pr_auc_upper = compute_confidence_interval( - x=labels, - y=probabilities, - metric_func=calculate_pr_auc + x=labels, y=probabilities, metric_func=calculate_pr_auc ) else: pr_auc_lower = pr_auc_upper = pr_auc @@ -269,9 +271,7 @@ def compute_confidence_interval( roc_auc = metrics.roc_auc_score(labels, probabilities) if calculate_ci: roc_auc_lower, roc_auc_upper = compute_confidence_interval( - x=labels, - y=probabilities, - metric_func=metrics.roc_auc_score + x=labels, y=probabilities, metric_func=metrics.roc_auc_score ) else: roc_auc_lower = roc_auc_upper = roc_auc @@ -284,40 +284,39 @@ def compute_confidence_interval( current_time = 
datetime.datetime.now().strftime("%m-%d-%Y-%H-%M-%S") data_metrics = { - 'model_name': model_name, - 'time_stamp': [current_time], - 'recall': [round(recall, DECIMAL_PLACE)], - 'precision': [round(precision, DECIMAL_PLACE)], - 'f1-score': [round(f1_score, DECIMAL_PLACE)], - 'pr_auc': [round(pr_auc, DECIMAL_PLACE)], - 'pr_auc_ci': f'({round(pr_auc_lower, DECIMAL_PLACE)}, {round(pr_auc_upper, DECIMAL_PLACE)})', - 'pr_auc_best_threshold': round(pr_auc_best_threshold, DECIMAL_PLACE), - 'roc_auc': [round(roc_auc, DECIMAL_PLACE)], - 'roc_auc_ci': f'({round(roc_auc_lower, DECIMAL_PLACE)}, {round(roc_auc_upper, DECIMAL_PLACE)})', - 'roc_auc_best_threshold': round(roc_auc_best_threshold, DECIMAL_PLACE) + "model_name": model_name, + "time_stamp": [current_time], + "recall": [round(recall, DECIMAL_PLACE)], + "precision": [round(precision, DECIMAL_PLACE)], + "f1-score": [round(f1_score, DECIMAL_PLACE)], + "pr_auc": [round(pr_auc, DECIMAL_PLACE)], + "pr_auc_ci": f"({round(pr_auc_lower, DECIMAL_PLACE)}, {round(pr_auc_upper, DECIMAL_PLACE)})", + "pr_auc_best_threshold": round(pr_auc_best_threshold, DECIMAL_PLACE), + "roc_auc": [round(roc_auc, DECIMAL_PLACE)], + "roc_auc_ci": f"({round(roc_auc_lower, DECIMAL_PLACE)}, {round(roc_auc_upper, DECIMAL_PLACE)})", + "roc_auc_best_threshold": round(roc_auc_best_threshold, DECIMAL_PLACE), } if extra_info: # Add the additional information to the metrics - tf.print(f'Adding extra_info to the metrics folder: {extra_info}') - data_metrics.update( - extra_info - ) + tf.print(f"Adding extra_info to the metrics folder: {extra_info}") + data_metrics.update(extra_info) data_metrics_pd = pd.DataFrame(data_metrics) - data_metrics_pd.to_parquet(os.path.join(metrics_folder, f'{current_time}.parquet')) + data_metrics_pd.to_parquet(os.path.join(metrics_folder, f"{current_time}.parquet")) if evaluation_model_folder: validate_folder(evaluation_model_folder) - prediction_pd = pd.DataFrame(zip(labels, probabilities), columns=['label', 'prediction']) - prediction_pd.to_parquet(os.path.join(evaluation_model_folder, f'{current_time}.parquet')) + prediction_pd = pd.DataFrame(zip(labels, probabilities), columns=["label", "prediction"]) + prediction_pd.to_parquet(os.path.join(evaluation_model_folder, f"{current_time}.parquet")) return data_metrics def save_training_history(history: Dict, history_folder, model_name: str = None): """ - Save the training metrics in the history dictionary as pandas dataframe to the file + Save the training metrics in the history dictionary as pandas dataframe to the file. 
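# save_training_history flattens the Keras metrics into a DataFrame with one row per epoch,
# prepends model_name/time_stamp columns, and writes one timestamped parquet file per run.
# A compact sketch with a hand-built history dict (the real argument is the History object
# returned by model.fit, whose .history attribute holds this mapping; the folder and model
# name below are placeholders):
import datetime
import os

import pandas as pd

toy_history = {"loss": [0.91, 0.52, 0.33], "masked_perplexity": [7.1, 4.2, 3.0]}
history_folder = "/tmp/cehr_bert_history"  # hypothetical
os.makedirs(history_folder, exist_ok=True)

current_time = datetime.datetime.now().strftime("%m-%d-%Y-%H-%M-%S")
frame = pd.DataFrame(dict(sorted(toy_history.items())))
frame.insert(0, "time_stamp", current_time)
frame.insert(0, "model_name", "CEHR_BERT")
frame.columns = frame.columns.astype(str)
frame.to_parquet(os.path.join(history_folder, f"{current_time}.parquet"))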
+ system in parquet format :param history: @@ -329,28 +328,29 @@ def save_training_history(history: Dict, history_folder, model_name: str = None) validate_folder(history_folder) current_time = datetime.datetime.now().strftime("%m-%d-%Y-%H-%M-%S") - history_parquet_file_path = f'{current_time}.parquet' + history_parquet_file_path = f"{current_time}.parquet" data_frame = pd.DataFrame(dict(sorted(history.history.items()))) - data_frame.insert(0, 'time_stamp', current_time) - data_frame.insert(0, 'model_name', model_name) + data_frame.insert(0, "time_stamp", current_time) + data_frame.insert(0, "model_name", model_name) data_frame.columns = data_frame.columns.astype(str) data_frame.to_parquet(os.path.join(history_folder, history_parquet_file_path)) def validate_folder(folder): if not os.path.exists(folder): - raise FileExistsError(f'{folder} does not exist!') + raise FileExistsError(f"{folder} does not exist!") def create_concept_mask(mask, max_seq_length): # mask the third dimension - concept_mask_1 = tf.tile(tf.expand_dims(tf.expand_dims(mask, axis=1), axis=-1), - [1, 1, 1, max_seq_length]) + concept_mask_1 = tf.tile(tf.expand_dims(tf.expand_dims(mask, axis=1), axis=-1), [1, 1, 1, max_seq_length]) # mask the fourth dimension concept_mask_2 = tf.expand_dims(tf.expand_dims(mask, axis=1), axis=1) concept_mask = tf.cast( - (concept_mask_1 + concept_mask_2) > 0, dtype=tf.int32, - name=f'{re.sub("[^0-9a-zA-Z]+", "", mask.name)}_mask') + (concept_mask_1 + concept_mask_2) > 0, + dtype=tf.int32, + name=f'{re.sub("[^0-9a-zA-Z]+", "", mask.name)}_mask', + ) return concept_mask diff --git a/src/cehrbert/utils/spark_utils.py b/src/cehrbert/utils/spark_utils.py index 7cdb60e8..3ad03c11 100644 --- a/src/cehrbert/utils/spark_utils.py +++ b/src/cehrbert/utils/spark_utils.py @@ -10,26 +10,57 @@ from pyspark.sql.pandas.functions import pandas_udf from ..config.output_names import QUALIFIED_CONCEPT_LIST_PATH -from ..const.common import PERSON, VISIT_OCCURRENCE, UNKNOWN_CONCEPT, MEASUREMENT, \ - CATEGORICAL_MEASUREMENT, REQUIRED_MEASUREMENT, CDM_TABLES +from ..const.common import ( + CATEGORICAL_MEASUREMENT, + CDM_TABLES, + MEASUREMENT, + PERSON, + REQUIRED_MEASUREMENT, + UNKNOWN_CONCEPT, + VISIT_OCCURRENCE, +) from ..spark_apps.decorators.patient_event_decorator import ( - DemographicPromptDecorator, PatientEventAttDecorator, PatientEventBaseDecorator, DeathEventDecorator, - time_token_func, AttType + AttType, + DeathEventDecorator, + DemographicPromptDecorator, + PatientEventAttDecorator, + PatientEventBaseDecorator, + time_token_func, ) from ..spark_apps.sql_templates import measurement_unit_stats_query -from ..utils.logging_utils import * +from ..utils.logging_utils import logging DOMAIN_KEY_FIELDS = { - 'condition_occurrence_id': [ - ('condition_concept_id', 'condition_start_date', 'condition_start_datetime', 'condition')], - 'procedure_occurrence_id': [('procedure_concept_id', 'procedure_date', 'procedure_datetime', 'procedure')], - 'drug_exposure_id': [('drug_concept_id', 'drug_exposure_start_date', 'drug_exposure_start_datetime', 'drug')], - 'measurement_id': [('measurement_concept_id', 'measurement_date', 'measurement_datetime', 'measurement')], - 'death_date': [('person_id', 'death_date', 'death_datetime', 'death')], - 'visit_concept_id': [ - ('visit_concept_id', 'visit_start_date', 'visit'), - ('discharged_to_concept_id', 'visit_end_date', 'visit') - ] + "condition_occurrence_id": [ + ( + "condition_concept_id", + "condition_start_date", + "condition_start_datetime", + "condition", + ) + ], + 
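For the `create_concept_mask` helper reformatted a couple of hunks above: the two expand/tile steps build a pairwise seq-by-seq mask per example, with an extra axis of size 1 kept for broadcasting over attention heads. A NumPy sketch of the same shape arithmetic follows; what the 0/1 values of the input mask mean (e.g. padding) is an assumption here, the broadcasting is the point.

```python
import numpy as np

batch, seq_len = 2, 4
# Assumed convention: 1 = flagged position (e.g. padding), 0 = regular token.
mask = np.array([[0, 0, 1, 1],
                 [0, 0, 0, 1]])

mask_rows = np.tile(mask[:, None, :, None], (1, 1, 1, seq_len))  # (batch, 1, seq, seq)
mask_cols = mask[:, None, None, :]                               # (batch, 1, 1, seq)
pairwise = ((mask_rows + mask_cols) > 0).astype(np.int32)        # (batch, 1, seq, seq)

assert pairwise.shape == (batch, 1, seq_len, seq_len)
# Entry [b, 0, i, j] is 1 whenever token i or token j is flagged in example b.
print(pairwise[0, 0])
```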
"procedure_occurrence_id": [("procedure_concept_id", "procedure_date", "procedure_datetime", "procedure")], + "drug_exposure_id": [ + ( + "drug_concept_id", + "drug_exposure_start_date", + "drug_exposure_start_datetime", + "drug", + ) + ], + "measurement_id": [ + ( + "measurement_concept_id", + "measurement_date", + "measurement_datetime", + "measurement", + ) + ], + "death_date": [("person_id", "death_date", "death_datetime", "death")], + "visit_concept_id": [ + ("visit_concept_id", "visit_start_date", "visit"), + ("discharged_to_concept_id", "visit_end_date", "visit"), + ], } LOGGER = logging.getLogger(__name__) @@ -40,39 +71,45 @@ def get_key_fields(domain_table) -> List[Tuple[str, str, str, str]]: for k, v in DOMAIN_KEY_FIELDS.items(): if k in field_names: return v - return [(get_concept_id_field(domain_table), get_domain_date_field(domain_table), - get_domain_datetime_field(domain_table), get_domain_field(domain_table))] + return [ + ( + get_concept_id_field(domain_table), + get_domain_date_field(domain_table), + get_domain_datetime_field(domain_table), + get_domain_field(domain_table), + ) + ] def get_domain_date_field(domain_table): # extract the domain start_date column - return [f for f in domain_table.schema.fieldNames() if 'date' in f][0] + return [f for f in domain_table.schema.fieldNames() if "date" in f][0] def get_domain_datetime_field(domain_table): # extract the domain start_date column - return [f for f in domain_table.schema.fieldNames() if 'datetime' in f][0] + return [f for f in domain_table.schema.fieldNames() if "datetime" in f][0] def get_concept_id_field(domain_table): - return [f for f in domain_table.schema.fieldNames() if 'concept_id' in f][0] + return [f for f in domain_table.schema.fieldNames() if "concept_id" in f][0] def get_domain_field(domain_table): - return get_concept_id_field(domain_table).replace('_concept_id', '') + return get_concept_id_field(domain_table).replace("_concept_id", "") def create_file_path(input_folder, table_name): - if input_folder[-1] == '/': + if input_folder[-1] == "/": file_path = input_folder + table_name else: - file_path = input_folder + '/' + table_name + file_path = input_folder + "/" + table_name return file_path def join_domain_tables(domain_tables): - """Standardize the format of OMOP domain tables using a time frame + """Standardize the format of OMOP domain tables using a time frame. Keyword arguments: domain_tables -- the array containing the OMOOP domain tabls except visit_occurrence @@ -82,39 +119,42 @@ def join_domain_tables(domain_tables): (person_id, standard_concept_id, date, lower_bound, upper_bound, domain). In this case, co-occurrence is defined as those concept ids that have co-occurred within the same time window of a patient. - """ patient_event = None for domain_table in domain_tables: # extract the domain concept_id from the table fields. E.g. 
condition_concept_id from # condition_occurrence extract the domain start_date column extract the name of the table - for concept_id_field, date_field, datetime_field, table_domain_field in get_key_fields(domain_table): + for ( + concept_id_field, + date_field, + datetime_field, + table_domain_field, + ) in get_key_fields(domain_table): # Remove records that don't have a date or standard_concept_id - sub_domain_table = domain_table \ - .where(F.col(date_field).isNotNull()) \ - .where(F.col(concept_id_field).isNotNull()) - datetime_field_udf = F.to_timestamp( - F.coalesce(datetime_field, date_field), - 'yyyy-MM-dd HH:mm:ss' + sub_domain_table = domain_table.where(F.col(date_field).isNotNull()).where( + F.col(concept_id_field).isNotNull() + ) + datetime_field_udf = F.to_timestamp(F.coalesce(datetime_field, date_field), "yyyy-MM-dd HH:mm:ss") + sub_domain_table = ( + sub_domain_table.where(F.col(concept_id_field).cast("string") != "0") + .withColumn("date", F.to_date(F.col(date_field))) + .withColumn("datetime", datetime_field_udf) ) - sub_domain_table = sub_domain_table.where(F.col(concept_id_field).cast('string') != '0') \ - .withColumn('date', F.to_date(F.col(date_field))) \ - .withColumn('datetime', datetime_field_udf) sub_domain_table = sub_domain_table.select( - sub_domain_table['person_id'], - sub_domain_table[concept_id_field].alias('standard_concept_id'), - sub_domain_table['date'].cast('date'), - sub_domain_table['datetime'], - sub_domain_table['visit_occurrence_id'], - F.lit(table_domain_field).alias('domain'), - F.lit(-1).alias('concept_value') + sub_domain_table["person_id"], + sub_domain_table[concept_id_field].alias("standard_concept_id"), + sub_domain_table["date"].cast("date"), + sub_domain_table["datetime"], + sub_domain_table["visit_occurrence_id"], + F.lit(table_domain_field).alias("domain"), + F.lit(-1).alias("concept_value"), ).distinct() # Remove "Patient Died" from condition_occurrence - if sub_domain_table == 'condition_occurrence': - sub_domain_table = sub_domain_table.where('condition_concept_id != 4216643') + if sub_domain_table == "condition_occurrence": + sub_domain_table = sub_domain_table.where("condition_concept_id != 4216643") if patient_event is None: patient_event = sub_domain_table @@ -125,58 +165,60 @@ def join_domain_tables(domain_tables): def preprocess_domain_table( - spark, - input_folder, - domain_table_name, - with_diagnosis_rollup=False, - with_drug_rollup=True + spark, + input_folder, + domain_table_name, + with_diagnosis_rollup=False, + with_drug_rollup=True, ): domain_table = spark.read.parquet(create_file_path(input_folder, domain_table_name)) - if 'concept' in domain_table_name.lower(): + if "concept" in domain_table_name.lower(): return domain_table # lowercase the schema fields - domain_table = domain_table.select( - [F.col(f_n).alias(f_n.lower()) for f_n in domain_table.schema.fieldNames()]) + domain_table = domain_table.select([F.col(f_n).alias(f_n.lower()) for f_n in domain_table.schema.fieldNames()]) for f_n in domain_table.schema.fieldNames(): - if 'date' in f_n and 'datetime' not in f_n: + if "date" in f_n and "datetime" not in f_n: # convert date columns to the date type domain_table = domain_table.withColumn(f_n, F.to_date(f_n)) - elif 'datetime' in f_n: + elif "datetime" in f_n: # convert date columns to the datetime type domain_table = domain_table.withColumn(f_n, F.to_timestamp(f_n)) - if domain_table_name == 'visit_occurrence': + if domain_table_name == "visit_occurrence": # This is CDM 5.2, we need to rename this column to be 
CDM 5.3 compatible - if 'discharge_to_concept_id' in domain_table.schema.fieldNames(): - domain_table = domain_table.withColumnRenamed('discharge_to_concept_id', 'discharged_to_concept_id') + if "discharge_to_concept_id" in domain_table.schema.fieldNames(): + domain_table = domain_table.withColumnRenamed("discharge_to_concept_id", "discharged_to_concept_id") if with_drug_rollup: - if domain_table_name == 'drug_exposure' \ - and path.exists(create_file_path(input_folder, 'concept')) \ - and path.exists(create_file_path(input_folder, 'concept_ancestor')): - concept = spark.read.parquet(create_file_path(input_folder, 'concept')) - concept_ancestor = spark.read.parquet( - create_file_path(input_folder, 'concept_ancestor')) + if ( + domain_table_name == "drug_exposure" + and path.exists(create_file_path(input_folder, "concept")) + and path.exists(create_file_path(input_folder, "concept_ancestor")) + ): + concept = spark.read.parquet(create_file_path(input_folder, "concept")) + concept_ancestor = spark.read.parquet(create_file_path(input_folder, "concept_ancestor")) domain_table = roll_up_to_drug_ingredients(domain_table, concept, concept_ancestor) if with_diagnosis_rollup: - if domain_table_name == 'condition_occurrence' \ - and path.exists(create_file_path(input_folder, 'concept')) \ - and path.exists(create_file_path(input_folder, 'concept_relationship')): - concept = spark.read.parquet(create_file_path(input_folder, 'concept')) - concept_relationship = spark.read.parquet( - create_file_path(input_folder, 'concept_relationship')) + if ( + domain_table_name == "condition_occurrence" + and path.exists(create_file_path(input_folder, "concept")) + and path.exists(create_file_path(input_folder, "concept_relationship")) + ): + concept = spark.read.parquet(create_file_path(input_folder, "concept")) + concept_relationship = spark.read.parquet(create_file_path(input_folder, "concept_relationship")) domain_table = roll_up_diagnosis(domain_table, concept, concept_relationship) - if domain_table_name == 'procedure_occurrence' \ - and path.exists(create_file_path(input_folder, 'concept')) \ - and path.exists(create_file_path(input_folder, 'concept_ancestor')): - concept = spark.read.parquet(create_file_path(input_folder, 'concept')) - concept_ancestor = spark.read.parquet( - create_file_path(input_folder, 'concept_ancestor')) + if ( + domain_table_name == "procedure_occurrence" + and path.exists(create_file_path(input_folder, "concept")) + and path.exists(create_file_path(input_folder, "concept_ancestor")) + ): + concept = spark.read.parquet(create_file_path(input_folder, "concept")) + concept_ancestor = spark.read.parquet(create_file_path(input_folder, "concept_ancestor")) domain_table = roll_up_procedure(domain_table, concept, concept_ancestor) return domain_table @@ -184,200 +226,280 @@ def preprocess_domain_table( def roll_up_to_drug_ingredients(drug_exposure, concept, concept_ancestor): # lowercase the schema fields - drug_exposure = drug_exposure.select( - [F.col(f_n).alias(f_n.lower()) for f_n in drug_exposure.schema.fieldNames()]) + drug_exposure = drug_exposure.select([F.col(f_n).alias(f_n.lower()) for f_n in drug_exposure.schema.fieldNames()]) - drug_ingredient = drug_exposure.select('drug_concept_id').distinct() \ - .join(concept_ancestor, F.col('drug_concept_id') == F.col('descendant_concept_id')) \ - .join(concept, F.col('ancestor_concept_id') == F.col('concept_id')) \ - .where(concept['concept_class_id'] == 'Ingredient') \ - .select(F.col('drug_concept_id'), 
F.col('concept_id').alias('ingredient_concept_id')) + drug_ingredient = ( + drug_exposure.select("drug_concept_id") + .distinct() + .join(concept_ancestor, F.col("drug_concept_id") == F.col("descendant_concept_id")) + .join(concept, F.col("ancestor_concept_id") == F.col("concept_id")) + .where(concept["concept_class_id"] == "Ingredient") + .select(F.col("drug_concept_id"), F.col("concept_id").alias("ingredient_concept_id")) + ) drug_ingredient_fields = [ - F.coalesce(F.col('ingredient_concept_id'), F.col('drug_concept_id')).alias( - 'drug_concept_id')] + F.coalesce(F.col("ingredient_concept_id"), F.col("drug_concept_id")).alias("drug_concept_id") + ] drug_ingredient_fields.extend( - [F.col(field_name) for field_name in drug_exposure.schema.fieldNames() if - field_name != 'drug_concept_id']) + [F.col(field_name) for field_name in drug_exposure.schema.fieldNames() if field_name != "drug_concept_id"] + ) - drug_exposure = drug_exposure.join(drug_ingredient, 'drug_concept_id', 'left_outer') \ - .select(drug_ingredient_fields) + drug_exposure = drug_exposure.join(drug_ingredient, "drug_concept_id", "left_outer").select(drug_ingredient_fields) return drug_exposure def roll_up_diagnosis(condition_occurrence, concept, concept_relationship): - list_3dig_code = ['3-char nonbill code', '3-dig nonbill code', '3-char billing code', - '3-dig billing code', - '3-dig billing E code', '3-dig billing V code', '3-dig nonbill E code', - '3-dig nonbill V code'] + list_3dig_code = [ + "3-char nonbill code", + "3-dig nonbill code", + "3-char billing code", + "3-dig billing code", + "3-dig billing E code", + "3-dig billing V code", + "3-dig nonbill E code", + "3-dig nonbill V code", + ] condition_occurrence = condition_occurrence.select( - [F.col(f_n).alias(f_n.lower()) for f_n in condition_occurrence.schema.fieldNames()]) - - condition_icd = condition_occurrence.select('condition_source_concept_id').distinct() \ - .join(concept, (F.col('condition_source_concept_id') == F.col('concept_id'))) \ - .where(concept['domain_id'] == 'Condition') \ - .where(concept['vocabulary_id'] != 'SNOMED') \ - .select(F.col('condition_source_concept_id'), - F.col('vocabulary_id').alias('child_vocabulary_id'), - F.col('concept_class_id').alias('child_concept_class_id')) - - condition_icd_hierarchy = condition_icd.join(concept_relationship, - F.col('condition_source_concept_id') == F.col( - 'concept_id_1')) \ - .join(concept, (F.col('concept_id_2') == F.col('concept_id')) & ( - F.col('concept_class_id').isin(list_3dig_code)), how='left') \ - .select(F.col('condition_source_concept_id').alias('source_concept_id'), - F.col('child_concept_class_id'), F.col('concept_id').alias('parent_concept_id'), - F.col('concept_name').alias('parent_concept_name'), - F.col('vocabulary_id').alias('parent_vocabulary_id'), - F.col('concept_class_id').alias('parent_concept_class_id')).distinct() - - condition_icd_hierarchy = condition_icd_hierarchy.withColumn('ancestor_concept_id', F.when( - F.col('child_concept_class_id').isin(list_3dig_code), F.col('source_concept_id')).otherwise( - F.col('parent_concept_id'))) \ - .dropna(subset='ancestor_concept_id') - - condition_occurrence_fields = [F.col(f_n).alias(f_n.lower()) for f_n in - condition_occurrence.schema.fieldNames() if - f_n != 'condition_source_concept_id'] - condition_occurrence_fields.append(F.coalesce(F.col('ancestor_concept_id'), - F.col('condition_source_concept_id')).alias( - 'condition_source_concept_id')) - - condition_occurrence = condition_occurrence.join(condition_icd_hierarchy, 
condition_occurrence[ - 'condition_source_concept_id'] == condition_icd_hierarchy['source_concept_id'], how='left') \ - .select(condition_occurrence_fields).withColumn('condition_concept_id', - F.col('condition_source_concept_id')) + [F.col(f_n).alias(f_n.lower()) for f_n in condition_occurrence.schema.fieldNames()] + ) + + condition_icd = ( + condition_occurrence.select("condition_source_concept_id") + .distinct() + .join(concept, (F.col("condition_source_concept_id") == F.col("concept_id"))) + .where(concept["domain_id"] == "Condition") + .where(concept["vocabulary_id"] != "SNOMED") + .select( + F.col("condition_source_concept_id"), + F.col("vocabulary_id").alias("child_vocabulary_id"), + F.col("concept_class_id").alias("child_concept_class_id"), + ) + ) + + condition_icd_hierarchy = ( + condition_icd.join( + concept_relationship, + F.col("condition_source_concept_id") == F.col("concept_id_1"), + ) + .join( + concept, + (F.col("concept_id_2") == F.col("concept_id")) & (F.col("concept_class_id").isin(list_3dig_code)), + how="left", + ) + .select( + F.col("condition_source_concept_id").alias("source_concept_id"), + F.col("child_concept_class_id"), + F.col("concept_id").alias("parent_concept_id"), + F.col("concept_name").alias("parent_concept_name"), + F.col("vocabulary_id").alias("parent_vocabulary_id"), + F.col("concept_class_id").alias("parent_concept_class_id"), + ) + .distinct() + ) + + condition_icd_hierarchy = condition_icd_hierarchy.withColumn( + "ancestor_concept_id", + F.when( + F.col("child_concept_class_id").isin(list_3dig_code), + F.col("source_concept_id"), + ).otherwise(F.col("parent_concept_id")), + ).dropna(subset="ancestor_concept_id") + + condition_occurrence_fields = [ + F.col(f_n).alias(f_n.lower()) + for f_n in condition_occurrence.schema.fieldNames() + if f_n != "condition_source_concept_id" + ] + condition_occurrence_fields.append( + F.coalesce(F.col("ancestor_concept_id"), F.col("condition_source_concept_id")).alias( + "condition_source_concept_id" + ) + ) + + condition_occurrence = ( + condition_occurrence.join( + condition_icd_hierarchy, + condition_occurrence["condition_source_concept_id"] == condition_icd_hierarchy["source_concept_id"], + how="left", + ) + .select(condition_occurrence_fields) + .withColumn("condition_concept_id", F.col("condition_source_concept_id")) + ) return condition_occurrence def roll_up_procedure(procedure_occurrence, concept, concept_ancestor): def extract_parent_code(concept_code): - return concept_code.split('.')[0] + return concept_code.split(".")[0] - parent_code_udf = F.udf(lambda code: extract_parent_code(code), T.StringType()) + parent_code_udf = F.udf(extract_parent_code, T.StringType()) - procedure_code = procedure_occurrence.select('procedure_source_concept_id').distinct() \ - .join(concept, F.col('procedure_source_concept_id') == F.col('concept_id')) \ - .where(concept['domain_id'] == 'Procedure') \ - .select(F.col('procedure_source_concept_id').alias('source_concept_id'), - F.col('vocabulary_id').alias('child_vocabulary_id'), - F.col('concept_class_id').alias('child_concept_class_id'), - F.col('concept_code').alias('child_concept_code')) + procedure_code = ( + procedure_occurrence.select("procedure_source_concept_id") + .distinct() + .join(concept, F.col("procedure_source_concept_id") == F.col("concept_id")) + .where(concept["domain_id"] == "Procedure") + .select( + F.col("procedure_source_concept_id").alias("source_concept_id"), + F.col("vocabulary_id").alias("child_vocabulary_id"), + 
F.col("concept_class_id").alias("child_concept_class_id"), + F.col("concept_code").alias("child_concept_code"), + ) + ) # cpt code rollup - cpt_code = procedure_code.where(F.col('child_vocabulary_id') == 'CPT4') - - cpt_hierarchy = cpt_code.join(concept_ancestor, - cpt_code['source_concept_id'] == concept_ancestor[ - 'descendant_concept_id']) \ - .join(concept, concept_ancestor['ancestor_concept_id'] == concept['concept_id']) \ - .where(concept['vocabulary_id'] == 'CPT4') \ - .select(F.col('source_concept_id'), F.col('child_concept_class_id'), - F.col('ancestor_concept_id').alias('parent_concept_id'), - F.col('min_levels_of_separation'), - F.col('concept_class_id').alias('parent_concept_class_id')) - - cpt_hierarchy_level_1 = cpt_hierarchy.where(F.col('min_levels_of_separation') == 1) \ - .where(F.col('child_concept_class_id') == 'CPT4') \ - .where(F.col('parent_concept_class_id') == 'CPT4 Hierarchy') \ - .select(F.col('source_concept_id'), F.col('parent_concept_id')) - - cpt_hierarchy_level_1 = cpt_hierarchy_level_1.join(concept_ancestor, ( - cpt_hierarchy_level_1['source_concept_id'] == concept_ancestor['descendant_concept_id']) - & (concept_ancestor[ - 'min_levels_of_separation'] == 1), - how='left') \ - .select(F.col('source_concept_id'), F.col('parent_concept_id'), - F.col('ancestor_concept_id').alias('root_concept_id')) - - cpt_hierarchy_level_1 = cpt_hierarchy_level_1.withColumn('isroot', F.when( - cpt_hierarchy_level_1['root_concept_id'] == 45889197, - cpt_hierarchy_level_1['source_concept_id']) \ - .otherwise( - cpt_hierarchy_level_1['parent_concept_id'])) \ - .select(F.col('source_concept_id'), F.col('isroot').alias('ancestor_concept_id')) - - cpt_hierarchy_level_0 = cpt_hierarchy.groupby('source_concept_id').max() \ - .where(F.col('max(min_levels_of_separation)') == 0) \ - .select(F.col('source_concept_id').alias('cpt_level_0_concept_id')) - - cpt_hierarchy_level_0 = cpt_hierarchy.join(cpt_hierarchy_level_0, - cpt_hierarchy['source_concept_id'] == - cpt_hierarchy_level_0['cpt_level_0_concept_id']) \ - .select(F.col('source_concept_id'), F.col('parent_concept_id').alias('ancestor_concept_id')) + cpt_code = procedure_code.where(F.col("child_vocabulary_id") == "CPT4") + + cpt_hierarchy = ( + cpt_code.join( + concept_ancestor, + cpt_code["source_concept_id"] == concept_ancestor["descendant_concept_id"], + ) + .join(concept, concept_ancestor["ancestor_concept_id"] == concept["concept_id"]) + .where(concept["vocabulary_id"] == "CPT4") + .select( + F.col("source_concept_id"), + F.col("child_concept_class_id"), + F.col("ancestor_concept_id").alias("parent_concept_id"), + F.col("min_levels_of_separation"), + F.col("concept_class_id").alias("parent_concept_class_id"), + ) + ) + + cpt_hierarchy_level_1 = ( + cpt_hierarchy.where(F.col("min_levels_of_separation") == 1) + .where(F.col("child_concept_class_id") == "CPT4") + .where(F.col("parent_concept_class_id") == "CPT4 Hierarchy") + .select(F.col("source_concept_id"), F.col("parent_concept_id")) + ) + + cpt_hierarchy_level_1 = cpt_hierarchy_level_1.join( + concept_ancestor, + (cpt_hierarchy_level_1["source_concept_id"] == concept_ancestor["descendant_concept_id"]) + & (concept_ancestor["min_levels_of_separation"] == 1), + how="left", + ).select( + F.col("source_concept_id"), + F.col("parent_concept_id"), + F.col("ancestor_concept_id").alias("root_concept_id"), + ) + + cpt_hierarchy_level_1 = cpt_hierarchy_level_1.withColumn( + "isroot", + F.when( + cpt_hierarchy_level_1["root_concept_id"] == 45889197, + 
cpt_hierarchy_level_1["source_concept_id"], + ).otherwise(cpt_hierarchy_level_1["parent_concept_id"]), + ).select(F.col("source_concept_id"), F.col("isroot").alias("ancestor_concept_id")) + + cpt_hierarchy_level_0 = ( + cpt_hierarchy.groupby("source_concept_id") + .max() + .where(F.col("max(min_levels_of_separation)") == 0) + .select(F.col("source_concept_id").alias("cpt_level_0_concept_id")) + ) + + cpt_hierarchy_level_0 = cpt_hierarchy.join( + cpt_hierarchy_level_0, + cpt_hierarchy["source_concept_id"] == cpt_hierarchy_level_0["cpt_level_0_concept_id"], + ).select( + F.col("source_concept_id"), + F.col("parent_concept_id").alias("ancestor_concept_id"), + ) cpt_hierarchy_rollup_all = cpt_hierarchy_level_1.union(cpt_hierarchy_level_0).drop_duplicates() # ICD code rollup - icd_list = ['ICD9CM', 'ICD9Proc', 'ICD10CM'] + icd_list = ["ICD9CM", "ICD9Proc", "ICD10CM"] - procedure_icd = procedure_code.where(F.col('vocabulary_id').isin(icd_list)) + procedure_icd = procedure_code.where(F.col("vocabulary_id").isin(icd_list)) - procedure_icd = procedure_icd.withColumn('parent_concept_code', - parent_code_udf(F.col('child_concept_code'))) \ - .withColumnRenamed('procedure_source_concept_id', 'source_concept_id') \ - .withColumnRenamed('concept_name', 'child_concept_name') \ - .withColumnRenamed('vocabulary_id', 'child_vocabulary_id') \ - .withColumnRenamed('concept_code', 'child_concept_code') \ - .withColumnRenamed('concept_class_id', 'child_concept_class_id') + procedure_icd = ( + procedure_icd.withColumn("parent_concept_code", parent_code_udf(F.col("child_concept_code"))) + .withColumnRenamed("procedure_source_concept_id", "source_concept_id") + .withColumnRenamed("concept_name", "child_concept_name") + .withColumnRenamed("vocabulary_id", "child_vocabulary_id") + .withColumnRenamed("concept_code", "child_concept_code") + .withColumnRenamed("concept_class_id", "child_concept_class_id") + ) - procedure_icd_map = procedure_icd.join(concept, ( - procedure_icd['parent_concept_code'] == concept['concept_code']) - & (procedure_icd['child_vocabulary_id'] == concept[ - 'vocabulary_id']), how='left') \ - .select('source_concept_id', F.col('concept_id').alias('ancestor_concept_id')).distinct() + procedure_icd_map = ( + procedure_icd.join( + concept, + (procedure_icd["parent_concept_code"] == concept["concept_code"]) + & (procedure_icd["child_vocabulary_id"] == concept["vocabulary_id"]), + how="left", + ) + .select("source_concept_id", F.col("concept_id").alias("ancestor_concept_id")) + .distinct() + ) # ICD10PCS rollup - procedure_10pcs = procedure_code.where(F.col('vocabulary_id') == 'ICD10PCS') - - procedure_10pcs = procedure_10pcs.withColumn('parent_concept_code', - F.substring(F.col('child_concept_code'), 1, 3)) \ - .withColumnRenamed('procedure_source_concept_id', 'source_concept_id') \ - .withColumnRenamed('concept_name', 'child_concept_name') \ - .withColumnRenamed('vocabulary_id', 'child_vocabulary_id') \ - .withColumnRenamed('concept_code', 'child_concept_code') \ - .withColumnRenamed('concept_class_id', 'child_concept_class_id') - - procedure_10pcs_map = procedure_10pcs.join(concept, ( - procedure_10pcs['parent_concept_code'] == concept['concept_code']) - & (procedure_10pcs['child_vocabulary_id'] == concept[ - 'vocabulary_id']), how='left') \ - .select('source_concept_id', F.col('concept_id').alias('ancestor_concept_id')).distinct() + procedure_10pcs = procedure_code.where(F.col("vocabulary_id") == "ICD10PCS") + + procedure_10pcs = ( + procedure_10pcs.withColumn("parent_concept_code", 
F.substring(F.col("child_concept_code"), 1, 3)) + .withColumnRenamed("procedure_source_concept_id", "source_concept_id") + .withColumnRenamed("concept_name", "child_concept_name") + .withColumnRenamed("vocabulary_id", "child_vocabulary_id") + .withColumnRenamed("concept_code", "child_concept_code") + .withColumnRenamed("concept_class_id", "child_concept_class_id") + ) + + procedure_10pcs_map = ( + procedure_10pcs.join( + concept, + (procedure_10pcs["parent_concept_code"] == concept["concept_code"]) + & (procedure_10pcs["child_vocabulary_id"] == concept["vocabulary_id"]), + how="left", + ) + .select("source_concept_id", F.col("concept_id").alias("ancestor_concept_id")) + .distinct() + ) # HCPCS rollup --- keep the concept_id itself - procedure_hcpcs = procedure_code.where(F.col('child_vocabulary_id') == 'HCPCS') - procedure_hcpcs_map = procedure_hcpcs.withColumn('ancestor_concept_id', - F.col('source_concept_id')) \ - .select('source_concept_id', 'ancestor_concept_id').distinct() - - procedure_hierarchy = cpt_hierarchy_rollup_all \ - .union(procedure_icd_map) \ - .union(procedure_10pcs_map) \ - .union(procedure_hcpcs_map) \ + procedure_hcpcs = procedure_code.where(F.col("child_vocabulary_id") == "HCPCS") + procedure_hcpcs_map = ( + procedure_hcpcs.withColumn("ancestor_concept_id", F.col("source_concept_id")) + .select("source_concept_id", "ancestor_concept_id") .distinct() - procedure_occurrence_fields = [F.col(f_n).alias(f_n.lower()) for f_n in - procedure_occurrence.schema.fieldNames() if - f_n != 'procedure_source_concept_id'] - procedure_occurrence_fields.append(F.coalesce(F.col('ancestor_concept_id'), - F.col('procedure_source_concept_id')).alias( - 'procedure_source_concept_id')) - - procedure_occurrence = procedure_occurrence.join(procedure_hierarchy, procedure_occurrence[ - 'procedure_source_concept_id'] == procedure_hierarchy['source_concept_id'], how='left') \ - .select(procedure_occurrence_fields) \ - .withColumn('procedure_concept_id', F.col('procedure_source_concept_id')) + ) + + procedure_hierarchy = ( + cpt_hierarchy_rollup_all.union(procedure_icd_map) + .union(procedure_10pcs_map) + .union(procedure_hcpcs_map) + .distinct() + ) + procedure_occurrence_fields = [ + F.col(f_n).alias(f_n.lower()) + for f_n in procedure_occurrence.schema.fieldNames() + if f_n != "procedure_source_concept_id" + ] + procedure_occurrence_fields.append( + F.coalesce(F.col("ancestor_concept_id"), F.col("procedure_source_concept_id")).alias( + "procedure_source_concept_id" + ) + ) + + procedure_occurrence = ( + procedure_occurrence.join( + procedure_hierarchy, + procedure_occurrence["procedure_source_concept_id"] == procedure_hierarchy["source_concept_id"], + how="left", + ) + .select(procedure_occurrence_fields) + .withColumn("procedure_concept_id", F.col("procedure_source_concept_id")) + ) return procedure_occurrence -def create_sequence_data(patient_event, - date_filter=None, - include_visit_type=False, - classic_bert_seq=False): +def create_sequence_data(patient_event, date_filter=None, include_visit_type=False, classic_bert_seq=False): """ - Create a sequence of the events associated with one patient in a chronological order + Create a sequence of the events associated with one patient in a chronological order. 
+ :param patient_event: :param date_filter: :param include_visit_type: @@ -386,107 +508,138 @@ def create_sequence_data(patient_event, """ if date_filter: - patient_event = patient_event.where(F.col('date') >= date_filter) + patient_event = patient_event.where(F.col("date") >= date_filter) # Define a list of custom UDFs for creating custom columns - date_conversion_udf = (F.unix_timestamp('date') / F.lit(24 * 60 * 60 * 7)).cast('int') - earliest_visit_date_udf = F.min('date_in_week').over(W.partitionBy('visit_occurrence_id')) + date_conversion_udf = (F.unix_timestamp("date") / F.lit(24 * 60 * 60 * 7)).cast("int") + earliest_visit_date_udf = F.min("date_in_week").over(W.partitionBy("visit_occurrence_id")) - visit_rank_udf = F.dense_rank().over( - W.partitionBy('cohort_member_id', 'person_id').orderBy('earliest_visit_date')) - visit_segment_udf = F.col('visit_rank_order') % F.lit(2) + 1 + visit_rank_udf = F.dense_rank().over(W.partitionBy("cohort_member_id", "person_id").orderBy("earliest_visit_date")) + visit_segment_udf = F.col("visit_rank_order") % F.lit(2) + 1 # Derive columns - patient_event = patient_event.where('visit_occurrence_id IS NOT NULL') \ - .withColumn('date_in_week', date_conversion_udf) \ - .withColumn('earliest_visit_date', earliest_visit_date_udf) \ - .withColumn('visit_rank_order', visit_rank_udf) \ - .withColumn('visit_segment', visit_segment_udf) \ - .withColumn('priority', F.lit(0)) + patient_event = ( + patient_event.where("visit_occurrence_id IS NOT NULL") + .withColumn("date_in_week", date_conversion_udf) + .withColumn("earliest_visit_date", earliest_visit_date_udf) + .withColumn("visit_rank_order", visit_rank_udf) + .withColumn("visit_segment", visit_segment_udf) + .withColumn("priority", F.lit(0)) + ) if classic_bert_seq: # Udf for identifying the earliest date associated with a visit_occurrence_id - visit_start_date_udf = F.first('date').over( - W.partitionBy('cohort_member_id', 'person_id', 'visit_occurrence_id').orderBy('date')) + visit_start_date_udf = F.first("date").over( + W.partitionBy("cohort_member_id", "person_id", "visit_occurrence_id").orderBy("date") + ) # Udf for identifying the previous visit_occurrence_id - prev_visit_occurrence_id_udf = F.lag('visit_occurrence_id').over( - W.partitionBy('cohort_member_id', 'person_id').orderBy('visit_start_date', - 'visit_occurrence_id')) + prev_visit_occurrence_id_udf = F.lag("visit_occurrence_id").over( + W.partitionBy("cohort_member_id", "person_id").orderBy("visit_start_date", "visit_occurrence_id") + ) # We can achieve this by overwriting the record with the earliest time stamp - separator_events = patient_event.withColumn('visit_start_date', visit_start_date_udf) \ - .withColumn('prev_visit_occurrence_id', prev_visit_occurrence_id_udf) \ - .where('prev_visit_occurrence_id IS NOT NULL') \ - .where('visit_occurrence_id <> prev_visit_occurrence_id') \ - .withColumn('domain', F.lit('Separator')) \ - .withColumn('standard_concept_id', F.lit('SEP')) \ - .withColumn('priority', F.lit(-1)) \ - .withColumn('visit_segment', F.lit(0)) \ + separator_events = ( + patient_event.withColumn("visit_start_date", visit_start_date_udf) + .withColumn("prev_visit_occurrence_id", prev_visit_occurrence_id_udf) + .where("prev_visit_occurrence_id IS NOT NULL") + .where("visit_occurrence_id <> prev_visit_occurrence_id") + .withColumn("domain", F.lit("Separator")) + .withColumn("standard_concept_id", F.lit("SEP")) + .withColumn("priority", F.lit(-1)) + .withColumn("visit_segment", F.lit(0)) .select(patient_event.schema.fieldNames()) 
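Two expressions in the `create_sequence_data` hunk above are easy to gloss over: `date_in_week` is the integer number of weeks since the Unix epoch, and `visit_segment` alternates between 1 and 2 as the visit rank increases (a BERT-style segment id). A plain-Python equivalent with a made-up date; Spark's `unix_timestamp` additionally depends on the session time zone, which is ignored here:

```python
import datetime

SECONDS_PER_WEEK = 24 * 60 * 60 * 7


def date_in_week(d: datetime.date) -> int:
    """Whole weeks since 1970-01-01, mirroring unix_timestamp(date) / (24*60*60*7) cast to int."""
    epoch = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)
    dt = datetime.datetime(d.year, d.month, d.day, tzinfo=datetime.timezone.utc)
    return int((dt - epoch).total_seconds() // SECONDS_PER_WEEK)


def visit_segment(visit_rank_order: int) -> int:
    """visit_rank_order % 2 + 1: alternating 2, 1, 2, 1, ... starting from rank 1."""
    return visit_rank_order % 2 + 1


print(date_in_week(datetime.date(2020, 1, 15)))  # -> 2610
print([visit_segment(r) for r in range(1, 6)])   # -> [2, 1, 2, 1, 2]
```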
+ ) # Combine this artificial token SEP with the original data patient_event = patient_event.union(separator_events) order_udf = F.row_number().over( - W.partitionBy('cohort_member_id', 'person_id').orderBy('earliest_visit_date', - 'visit_occurrence_id', - 'priority', 'date_in_week', - 'standard_concept_id')) + W.partitionBy("cohort_member_id", "person_id").orderBy( + "earliest_visit_date", + "visit_occurrence_id", + "priority", + "date_in_week", + "standard_concept_id", + ) + ) # Group the data into sequences - output_columns = ['order', 'date_in_week', 'standard_concept_id', - 'visit_segment', 'age', 'visit_rank_order'] + output_columns = [ + "order", + "date_in_week", + "standard_concept_id", + "visit_segment", + "age", + "visit_rank_order", + ] if include_visit_type: - output_columns.append('visit_concept_id') + output_columns.append("visit_concept_id") # Group by data by person_id and put all the events into a list # The order of the list is determined by the order column - patient_grouped_events = patient_event.withColumn('order', order_udf) \ - .withColumn('date_concept_id_period', F.struct(output_columns)) \ - .groupBy('person_id', 'cohort_member_id') \ - .agg(F.sort_array(F.collect_set('date_concept_id_period')).alias('date_concept_id_period'), - F.min('earliest_visit_date').alias('earliest_visit_date'), - F.max('date').alias('max_event_date'), - F.max('visit_rank_order').alias('num_of_visits'), - F.count('standard_concept_id').alias('num_of_concepts')) \ - .withColumn('orders', - F.col('date_concept_id_period.order').cast(T.ArrayType(T.IntegerType()))) \ - .withColumn('dates', F.col('date_concept_id_period.date_in_week')) \ - .withColumn('concept_ids', F.col('date_concept_id_period.standard_concept_id')) \ - .withColumn('visit_segments', F.col('date_concept_id_period.visit_segment')) \ - .withColumn('ages', F.col('date_concept_id_period.age')) \ - .withColumn('visit_concept_orders', F.col('date_concept_id_period.visit_rank_order')) + patient_grouped_events = ( + patient_event.withColumn("order", order_udf) + .withColumn("date_concept_id_period", F.struct(output_columns)) + .groupBy("person_id", "cohort_member_id") + .agg( + F.sort_array(F.collect_set("date_concept_id_period")).alias("date_concept_id_period"), + F.min("earliest_visit_date").alias("earliest_visit_date"), + F.max("date").alias("max_event_date"), + F.max("visit_rank_order").alias("num_of_visits"), + F.count("standard_concept_id").alias("num_of_concepts"), + ) + .withColumn( + "orders", + F.col("date_concept_id_period.order").cast(T.ArrayType(T.IntegerType())), + ) + .withColumn("dates", F.col("date_concept_id_period.date_in_week")) + .withColumn("concept_ids", F.col("date_concept_id_period.standard_concept_id")) + .withColumn("visit_segments", F.col("date_concept_id_period.visit_segment")) + .withColumn("ages", F.col("date_concept_id_period.age")) + .withColumn("visit_concept_orders", F.col("date_concept_id_period.visit_rank_order")) + ) # Default columns in the output dataframe - columns_for_output = ['cohort_member_id', 'person_id', 'earliest_visit_date', - 'max_event_date', 'orders', 'dates', 'ages', 'concept_ids', - 'visit_segments', 'visit_concept_orders', 'num_of_visits', - 'num_of_concepts'] + columns_for_output = [ + "cohort_member_id", + "person_id", + "earliest_visit_date", + "max_event_date", + "orders", + "dates", + "ages", + "concept_ids", + "visit_segments", + "visit_concept_orders", + "num_of_visits", + "num_of_concepts", + ] if include_visit_type: - patient_grouped_events = patient_grouped_events \ - 
.withColumn('visit_concept_ids', F.col('date_concept_id_period.visit_concept_id')) - columns_for_output.append('visit_concept_ids') + patient_grouped_events = patient_grouped_events.withColumn( + "visit_concept_ids", F.col("date_concept_id_period.visit_concept_id") + ) + columns_for_output.append("visit_concept_ids") return patient_grouped_events.select(columns_for_output) def create_sequence_data_with_att( - patient_events, - visit_occurrence, - date_filter=None, - include_visit_type=False, - exclude_visit_tokens=False, - patient_demographic=None, - death=None, - att_type: AttType = AttType.CEHR_BERT, - exclude_demographic: bool = True, - use_age_group: bool = False, - include_inpatient_hour_token: bool = False + patient_events, + visit_occurrence, + date_filter=None, + include_visit_type=False, + exclude_visit_tokens=False, + patient_demographic=None, + death=None, + att_type: AttType = AttType.CEHR_BERT, + exclude_demographic: bool = True, + use_age_group: bool = False, + include_inpatient_hour_token: bool = False, ): """ - Create a sequence of the events associated with one patient in a chronological order + Create a sequence of the events associated with one patient in a chronological order. :param patient_events: :param visit_occurrence: @@ -503,7 +656,7 @@ def create_sequence_data_with_att( :return: """ if date_filter: - patient_events = patient_events.where(F.col('date').cast('date') >= date_filter) + patient_events = patient_events.where(F.col("date").cast("date") >= date_filter) decorators = [ PatientEventBaseDecorator(visit_occurrence), @@ -512,10 +665,10 @@ def create_sequence_data_with_att( include_visit_type, exclude_visit_tokens, att_type, - include_inpatient_hour_token + include_inpatient_hour_token, ), # DemographicPromptDecorator(patient_demographic), - DeathEventDecorator(death, att_type) + DeathEventDecorator(death, att_type), ] if not exclude_demographic: @@ -526,90 +679,126 @@ def create_sequence_data_with_att( # add randomness to the order of the concepts that have the same time stamp order_udf = F.row_number().over( - W.partitionBy('cohort_member_id', 'person_id').orderBy( - 'visit_rank_order', - 'concept_order', - 'priority', - 'datetime', - 'standard_concept_id' + W.partitionBy("cohort_member_id", "person_id").orderBy( + "visit_rank_order", + "concept_order", + "priority", + "datetime", + "standard_concept_id", ) ) dense_rank_udf = F.dense_rank().over( - W.partitionBy('cohort_member_id', 'person_id').orderBy( - 'visit_rank_order', - 'concept_order', - 'priority', - 'datetime') + W.partitionBy("cohort_member_id", "person_id").orderBy( + "visit_rank_order", "concept_order", "priority", "datetime" + ) ) # Those columns are derived from the previous decorators struct_columns = [ - 'order', 'record_rank', 'date_in_week', 'standard_concept_id', 'visit_segment', 'age', - 'visit_rank_order', 'concept_value_mask', 'concept_value', 'mlm_skip_value', - 'visit_concept_id', 'visit_concept_order', 'concept_order', 'priority' + "order", + "record_rank", + "date_in_week", + "standard_concept_id", + "visit_segment", + "age", + "visit_rank_order", + "concept_value_mask", + "concept_value", + "mlm_skip_value", + "visit_concept_id", + "visit_concept_order", + "concept_order", + "priority", ] output_columns = [ - 'cohort_member_id', 'person_id', 'concept_ids', 'visit_segments', 'orders', - 'dates', 'ages', 'visit_concept_orders', 'num_of_visits', 'num_of_concepts', - 'concept_value_masks', 'concept_values', 'mlm_skip_values', 'priorities', - 'visit_concept_ids', 'visit_rank_orders', 
'concept_orders', 'record_ranks' + "cohort_member_id", + "person_id", + "concept_ids", + "visit_segments", + "orders", + "dates", + "ages", + "visit_concept_orders", + "num_of_visits", + "num_of_concepts", + "concept_value_masks", + "concept_values", + "mlm_skip_values", + "priorities", + "visit_concept_ids", + "visit_rank_orders", + "concept_orders", + "record_ranks", ] - patient_grouped_events = patient_events \ - .withColumn('order', order_udf) \ - .withColumn('record_rank', dense_rank_udf) \ - .withColumn('data_for_sorting', F.struct(struct_columns)) \ - .groupBy('cohort_member_id', 'person_id') \ - .agg(F.sort_array(F.collect_set('data_for_sorting')).alias('data_for_sorting'), - F.max('visit_rank_order').alias('num_of_visits'), - F.count('standard_concept_id').alias('num_of_concepts')) \ - .withColumn('orders', F.col('data_for_sorting.order').cast(T.ArrayType(T.IntegerType()))) \ - .withColumn('record_ranks', F.col('data_for_sorting.record_rank').cast(T.ArrayType(T.IntegerType()))) \ - .withColumn('dates', F.col('data_for_sorting.date_in_week')) \ - .withColumn('concept_ids', F.col('data_for_sorting.standard_concept_id')) \ - .withColumn('visit_segments', F.col('data_for_sorting.visit_segment')) \ - .withColumn('ages', F.col('data_for_sorting.age')) \ - .withColumn('visit_rank_orders', F.col('data_for_sorting.visit_rank_order')) \ - .withColumn('visit_concept_orders', F.col('data_for_sorting.visit_concept_order')) \ - .withColumn('concept_orders', F.col('data_for_sorting.concept_order')) \ - .withColumn('priorities', F.col('data_for_sorting.priority')) \ - .withColumn('concept_value_masks', F.col('data_for_sorting.concept_value_mask')) \ - .withColumn('concept_values', F.col('data_for_sorting.concept_value')) \ - .withColumn('mlm_skip_values', F.col('data_for_sorting.mlm_skip_value')) \ - .withColumn('visit_concept_ids', F.col('data_for_sorting.visit_concept_id')) + patient_grouped_events = ( + patient_events.withColumn("order", order_udf) + .withColumn("record_rank", dense_rank_udf) + .withColumn("data_for_sorting", F.struct(struct_columns)) + .groupBy("cohort_member_id", "person_id") + .agg( + F.sort_array(F.collect_set("data_for_sorting")).alias("data_for_sorting"), + F.max("visit_rank_order").alias("num_of_visits"), + F.count("standard_concept_id").alias("num_of_concepts"), + ) + .withColumn("orders", F.col("data_for_sorting.order").cast(T.ArrayType(T.IntegerType()))) + .withColumn( + "record_ranks", + F.col("data_for_sorting.record_rank").cast(T.ArrayType(T.IntegerType())), + ) + .withColumn("dates", F.col("data_for_sorting.date_in_week")) + .withColumn("concept_ids", F.col("data_for_sorting.standard_concept_id")) + .withColumn("visit_segments", F.col("data_for_sorting.visit_segment")) + .withColumn("ages", F.col("data_for_sorting.age")) + .withColumn("visit_rank_orders", F.col("data_for_sorting.visit_rank_order")) + .withColumn("visit_concept_orders", F.col("data_for_sorting.visit_concept_order")) + .withColumn("concept_orders", F.col("data_for_sorting.concept_order")) + .withColumn("priorities", F.col("data_for_sorting.priority")) + .withColumn("concept_value_masks", F.col("data_for_sorting.concept_value_mask")) + .withColumn("concept_values", F.col("data_for_sorting.concept_value")) + .withColumn("mlm_skip_values", F.col("data_for_sorting.mlm_skip_value")) + .withColumn("visit_concept_ids", F.col("data_for_sorting.visit_concept_id")) + ) return patient_grouped_events.select(output_columns) def create_concept_frequency_data(patient_event, date_filter=None): if date_filter: - 
patient_event = patient_event.where(F.col('date') >= date_filter) + patient_event = patient_event.where(F.col("date") >= date_filter) take_concept_ids_udf = F.udf(lambda rows: [row[0] for row in rows], T.ArrayType(T.StringType())) take_freqs_udf = F.udf(lambda rows: [row[1] for row in rows], T.ArrayType(T.IntegerType())) - num_of_visits_concepts = patient_event.groupBy('cohort_member_id', 'person_id') \ - .agg( - F.countDistinct('visit_occurrence_id').alias('num_of_visits'), - F.count('standard_concept_id').alias('num_of_concepts') + num_of_visits_concepts = patient_event.groupBy("cohort_member_id", "person_id").agg( + F.countDistinct("visit_occurrence_id").alias("num_of_visits"), + F.count("standard_concept_id").alias("num_of_concepts"), ) - patient_event = patient_event.groupBy( - 'cohort_member_id', 'person_id', 'standard_concept_id').count() \ - .withColumn('concept_id_freq', F.struct('standard_concept_id', 'count')) \ - .groupBy('cohort_member_id', 'person_id').agg( - F.collect_list('concept_id_freq').alias('sequence')) \ - .withColumn('concept_ids', take_concept_ids_udf('sequence')) \ - .withColumn('frequencies', take_freqs_udf('sequence')) \ - .select('cohort_member_id', 'person_id', 'concept_ids', 'frequencies') \ - .join(num_of_visits_concepts, ['person_id', 'cohort_member_id']) + patient_event = ( + patient_event.groupBy("cohort_member_id", "person_id", "standard_concept_id") + .count() + .withColumn("concept_id_freq", F.struct("standard_concept_id", "count")) + .groupBy("cohort_member_id", "person_id") + .agg(F.collect_list("concept_id_freq").alias("sequence")) + .withColumn("concept_ids", take_concept_ids_udf("sequence")) + .withColumn("frequencies", take_freqs_udf("sequence")) + .select("cohort_member_id", "person_id", "concept_ids", "frequencies") + .join(num_of_visits_concepts, ["person_id", "cohort_member_id"]) + ) return patient_event -def extract_ehr_records(spark, input_folder, domain_table_list, include_visit_type=False, - with_rollup=False, include_concept_list=False): +def extract_ehr_records( + spark, + input_folder, + domain_table_list, + include_visit_type=False, + with_rollup=False, + include_concept_list=False, +): """ Extract the ehr records for domain_table_list from input_folder. 
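The `create_concept_frequency_data` function shown above reduces each patient to a bag of concepts: one list of distinct concept ids and a parallel list of counts, plus visit/concept totals. The same reduction in pandas on a toy event table (column names follow the Spark code; ids and values are made up):

```python
import pandas as pd

events = pd.DataFrame(
    {
        "cohort_member_id": [1, 1, 1, 2],
        "person_id": [10, 10, 10, 20],
        "standard_concept_id": ["320128", "320128", "201826", "320128"],
    }
)

counts = (
    events.groupby(["cohort_member_id", "person_id", "standard_concept_id"])
    .size()
    .reset_index(name="count")
)

bag = (
    counts.groupby(["cohort_member_id", "person_id"])
    .agg(concept_ids=("standard_concept_id", list), frequencies=("count", list))
    .reset_index()
)
print(bag)
# person 10 -> concept_ids ['201826', '320128'], frequencies [1, 2]
```

Unlike the Spark version, the list order here is deterministic (group keys are sorted); in the Spark code the order comes from `collect_list`, so `concept_ids[i]` and `frequencies[i]` stay aligned but are not otherwise ordered.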
@@ -624,64 +813,54 @@ def extract_ehr_records(spark, input_folder, domain_table_list, include_visit_ty domain_tables = [] for domain_table_name in domain_table_list: if domain_table_name != MEASUREMENT: - domain_tables.append( - preprocess_domain_table( - spark, - input_folder, - domain_table_name, - with_rollup - ) - ) + domain_tables.append(preprocess_domain_table(spark, input_folder, domain_table_name, with_rollup)) patient_ehr_records = join_domain_tables(domain_tables) if include_concept_list and patient_ehr_records: # Filter out concepts - qualified_concepts = preprocess_domain_table( - spark, - input_folder, - QUALIFIED_CONCEPT_LIST_PATH - ).select('standard_concept_id') - - patient_ehr_records = patient_ehr_records.join( - qualified_concepts, - 'standard_concept_id' + qualified_concepts = preprocess_domain_table(spark, input_folder, QUALIFIED_CONCEPT_LIST_PATH).select( + "standard_concept_id" ) + patient_ehr_records = patient_ehr_records.join(qualified_concepts, "standard_concept_id") + # Process the measurement table if exists if MEASUREMENT in domain_table_list: measurement = preprocess_domain_table(spark, input_folder, MEASUREMENT) required_measurement = preprocess_domain_table(spark, input_folder, REQUIRED_MEASUREMENT) - scaled_measurement = process_measurement( - spark, - measurement, - required_measurement - ) + scaled_measurement = process_measurement(spark, measurement, required_measurement) if patient_ehr_records: # Union all measurement records together with other domain records - patient_ehr_records = patient_ehr_records.union( - scaled_measurement - ) + patient_ehr_records = patient_ehr_records.union(scaled_measurement) else: patient_ehr_records = scaled_measurement - patient_ehr_records = patient_ehr_records.where('visit_occurrence_id IS NOT NULL').distinct() + patient_ehr_records = patient_ehr_records.where("visit_occurrence_id IS NOT NULL").distinct() person = preprocess_domain_table(spark, input_folder, PERSON) - person = person.withColumn('birth_datetime', - F.coalesce('birth_datetime', - F.concat('year_of_birth', F.lit('-01-01')).cast( - 'timestamp'))) - patient_ehr_records = patient_ehr_records.join(person, 'person_id') \ - .withColumn('age', - F.ceil(F.months_between(F.col('date'), F.col('birth_datetime')) / F.lit(12))) + person = person.withColumn( + "birth_datetime", + F.coalesce( + "birth_datetime", + F.concat("year_of_birth", F.lit("-01-01")).cast("timestamp"), + ), + ) + patient_ehr_records = patient_ehr_records.join(person, "person_id").withColumn( + "age", + F.ceil(F.months_between(F.col("date"), F.col("birth_datetime")) / F.lit(12)), + ) if include_visit_type: visit_occurrence = preprocess_domain_table(spark, input_folder, VISIT_OCCURRENCE) - patient_ehr_records = patient_ehr_records.join(visit_occurrence, 'visit_occurrence_id') \ - .select(patient_ehr_records['person_id'], patient_ehr_records['standard_concept_id'], - patient_ehr_records['date'], patient_ehr_records['visit_occurrence_id'], - patient_ehr_records['domain'], visit_occurrence['visit_concept_id'], - patient_ehr_records['age']) + patient_ehr_records = patient_ehr_records.join(visit_occurrence, "visit_occurrence_id").select( + patient_ehr_records["person_id"], + patient_ehr_records["standard_concept_id"], + patient_ehr_records["date"], + patient_ehr_records["visit_occurrence_id"], + patient_ehr_records["domain"], + visit_occurrence["visit_concept_id"], + patient_ehr_records["age"], + ) return patient_ehr_records @@ -691,7 +870,7 @@ def build_ancestry_table_for(spark, concept_ids): SELECT 
cr.concept_id_1 AS ancestor_concept_id, cr.concept_id_2 AS descendant_concept_id, - 1 AS distance + 1 AS distance FROM global_temp.concept_relationship AS cr WHERE cr.concept_id_1 in ({concept_ids}) AND cr.relationship_id = 'Subsumes' """ @@ -714,84 +893,98 @@ def build_ancestry_table_for(spark, concept_ids): * FROM global_temp.ancestry_table - UNION + UNION SELECT * FROM global_temp.candidate """ - ancestry_table = spark.sql( - initial_query.format(concept_ids=','.join([str(c) for c in concept_ids]))) - ancestry_table.createOrReplaceGlobalTempView('ancestry_table') + ancestry_table = spark.sql(initial_query.format(concept_ids=",".join([str(c) for c in concept_ids]))) + ancestry_table.createOrReplaceGlobalTempView("ancestry_table") candidate_set = spark.sql(recurring_query) - candidate_set.createOrReplaceGlobalTempView('candidate') + candidate_set.createOrReplaceGlobalTempView("candidate") while candidate_set.count() != 0: - spark.sql(union_query).createOrReplaceGlobalTempView('ancestry_table') + spark.sql(union_query).createOrReplaceGlobalTempView("ancestry_table") candidate_set = spark.sql(recurring_query) - candidate_set.createOrReplaceGlobalTempView('candidate') + candidate_set.createOrReplaceGlobalTempView("candidate") - ancestry_table = spark.sql(""" - SELECT + ancestry_table = spark.sql( + """ + SELECT * FROM global_temp.ancestry_table - """) + """ + ) - spark.sql(""" + spark.sql( + """ DROP VIEW global_temp.ancestry_table - """) + """ + ) return ancestry_table def get_descendant_concept_ids(spark, concept_ids): """ - Query concept_ancestor table to get all descendant_concept_ids for the given list of concept_ids + Query concept_ancestor table to get all descendant_concept_ids for the given list of concept_ids. + :param spark: :param concept_ids: :return: """ - descendant_concept_ids = spark.sql(""" + sanitized_concept_ids = [int(c) for c in concept_ids] + # Join the sanitized IDs into a string for the query + concept_ids_str = ",".join(map(str, sanitized_concept_ids)) + # Construct and execute the SQL query using the sanitized string + descendant_concept_ids = spark.sql( + f""" SELECT DISTINCT c.* FROM global_temp.concept_ancestor AS ca - JOIN global_temp.concept AS c + JOIN global_temp.concept AS c ON ca.descendant_concept_id = c.concept_id - WHERE ca.ancestor_concept_id IN ({concept_ids}) - """.format(concept_ids=','.join([str(c) for c in concept_ids]))) + WHERE ca.ancestor_concept_id IN ({concept_ids_str}) + """ + ) return descendant_concept_ids def get_standard_concept_ids(spark, concept_ids): - standard_concept_ids = spark.sql(""" + standard_concept_ids = spark.sql( + """ SELECT DISTINCT c.* FROM global_temp.concept_relationship AS cr - JOIN global_temp.concept AS c + JOIN global_temp.concept AS c ON ca.concept_id_2 = c.concept_id AND cr.relationship_id = 'Maps to' WHERE ca.concept_id_1 IN ({concept_ids}) - """.format(concept_ids=','.join([str(c) for c in concept_ids]))) + """.format( + concept_ids=",".join([str(c) for c in concept_ids]) + ) + ) return standard_concept_ids def get_table_column_refs(dataframe): - return [dataframe[fieldName] for fieldName in - dataframe.schema.fieldNames()] + return [dataframe[fieldName] for fieldName in dataframe.schema.fieldNames()] def create_hierarchical_sequence_data( - person, - visit_occurrence, - patient_events, - date_filter=None, - max_num_of_visits_per_person=None, - include_incomplete_visit=True, - allow_measurement_only=False + person, + visit_occurrence, + patient_events, + date_filter=None, + max_num_of_visits_per_person=None, + 
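The `build_ancestry_table_for` hunks above compute a transitive closure iteratively: seed the table with direct 'Subsumes' edges, then keep extending by one hop and unioning until the candidate set is empty. The sketch below shows the same fixed-point loop over a plain in-memory edge list; the distance bookkeeping is an assumption about what the recurring query does (its body is not visible in this hunk), and the concept ids are made up.

```python
from collections import defaultdict

# Toy 'Subsumes' edges: parent concept -> child concept.
edges = [(1, 2), (2, 3), (2, 4), (4, 5)]
children = defaultdict(set)
for parent, child in edges:
    children[parent].add(child)


def ancestry_closure(seed_concepts):
    """(ancestor, descendant, distance) triples reachable from the seed concepts."""
    closure = {(a, d, 1) for a in seed_concepts for d in children[a]}
    frontier = closure
    while frontier:  # extend one hop at a time until nothing new appears
        frontier = {
            (ancestor, grandchild, distance + 1)
            for (ancestor, descendant, distance) in frontier
            for grandchild in children[descendant]
        } - closure
        closure |= frontier
    return closure


print(sorted(ancestry_closure({1})))
# [(1, 2, 1), (1, 3, 2), (1, 4, 2), (1, 5, 3)]
```

Note also that the reworked `get_descendant_concept_ids` above now casts every id through `int()` before interpolating it into the SQL string, which is what keeps that f-string safe for numeric concept ids.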
include_incomplete_visit=True, + allow_measurement_only=False, ): """ - This creates a hierarchical data frame for the hierarchical bert model + This creates a hierarchical data frame for the hierarchical bert model. + :param person: :param visit_occurrence: :param patient_events: @@ -803,188 +996,209 @@ def create_hierarchical_sequence_data( """ if date_filter: - visit_occurrence = visit_occurrence.where( - F.col('visit_start_date').cast('date') >= date_filter - ) + visit_occurrence = visit_occurrence.where(F.col("visit_start_date").cast("date") >= date_filter) # Construct visit information with the person demographic - visit_occurrence_person = create_visit_person_join( - person, - visit_occurrence, - include_incomplete_visit - ) + visit_occurrence_person = create_visit_person_join(person, visit_occurrence, include_incomplete_visit) # Retrieve all visit column references visit_column_refs = get_table_column_refs(visit_occurrence_person) # Construct the patient event column references pat_col_refs = [ - F.coalesce( - patient_events['cohort_member_id'], - visit_occurrence['person_id'] - ).alias('cohort_member_id'), - F.coalesce( - patient_events['standard_concept_id'], - F.lit(UNKNOWN_CONCEPT) - ).alias('standard_concept_id'), - F.coalesce( - patient_events['date'], - visit_occurrence['visit_start_date'] - ).alias('date'), - F.coalesce( - patient_events['domain'], - F.lit('unknown') - ).alias('domain'), - F.coalesce( - patient_events['concept_value'], - F.lit(-1.0) - ).alias('concept_value') + F.coalesce(patient_events["cohort_member_id"], visit_occurrence["person_id"]).alias("cohort_member_id"), + F.coalesce(patient_events["standard_concept_id"], F.lit(UNKNOWN_CONCEPT)).alias("standard_concept_id"), + F.coalesce(patient_events["date"], visit_occurrence["visit_start_date"]).alias("date"), + F.coalesce(patient_events["domain"], F.lit("unknown")).alias("domain"), + F.coalesce(patient_events["concept_value"], F.lit(-1.0)).alias("concept_value"), ] # Convert standard_concept_id to string type, this is needed for the tokenization # Calculate the age w.r.t to the event - patient_events = visit_occurrence_person.join( - patient_events, 'visit_occurrence_id', 'left_outer') \ - .select(visit_column_refs + pat_col_refs) \ - .withColumn('standard_concept_id', F.col('standard_concept_id').cast('string')) \ - .withColumn('age', F.ceil( - F.months_between(F.col('date'), F.col("birth_datetime")) / F.lit(12))) \ - .withColumn('concept_value_mask', (F.col('domain') == MEASUREMENT).cast('int')) \ - .withColumn('mlm_skip', - (F.col('domain').isin([MEASUREMENT, CATEGORICAL_MEASUREMENT])).cast('int')) \ - .withColumn('condition_mask', (F.col('domain') == 'condition').cast('int')) + patient_events = ( + visit_occurrence_person.join(patient_events, "visit_occurrence_id", "left_outer") + .select(visit_column_refs + pat_col_refs) + .withColumn("standard_concept_id", F.col("standard_concept_id").cast("string")) + .withColumn( + "age", + F.ceil(F.months_between(F.col("date"), F.col("birth_datetime")) / F.lit(12)), + ) + .withColumn("concept_value_mask", (F.col("domain") == MEASUREMENT).cast("int")) + .withColumn( + "mlm_skip", + (F.col("domain").isin([MEASUREMENT, CATEGORICAL_MEASUREMENT])).cast("int"), + ) + .withColumn("condition_mask", (F.col("domain") == "condition").cast("int")) + ) if not allow_measurement_only: # We only allow persons that have a non measurement record in the dataset - qualified_person_df = patient_events \ - .where(~F.col('domain').isin([MEASUREMENT, CATEGORICAL_MEASUREMENT])) \ - 
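The age expression that appears both in `extract_ehr_records` and in the hunk above is `ceil(months_between(event_date, birth_datetime) / 12)`, with a missing `birth_datetime` backfilled from `year_of_birth` as January 1st. A rough pandas equivalent; the 31-day month fraction is an approximation of Spark's `months_between`, not an exact reimplementation, and the sample values are made up:

```python
import math

import pandas as pd

person = pd.DataFrame({"person_id": [1], "birth_datetime": [pd.NaT], "year_of_birth": [1960]})

# Backfill a missing birth_datetime from year_of_birth as <year>-01-01.
person["birth_datetime"] = person["birth_datetime"].fillna(
    pd.to_datetime(person["year_of_birth"].astype(str) + "-01-01")
)

event_date = pd.Timestamp("2021-06-15")
birth = person.loc[0, "birth_datetime"]

# Approximate months_between as whole months plus a day fraction, then ceil(months / 12).
months = (
    (event_date.year - birth.year) * 12
    + (event_date.month - birth.month)
    + (event_date.day - birth.day) / 31
)
print(math.ceil(months / 12))  # -> 62
```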
.where(F.col('standard_concept_id') != UNKNOWN_CONCEPT) \ - .select('person_id').distinct() + qualified_person_df = ( + patient_events.where(~F.col("domain").isin([MEASUREMENT, CATEGORICAL_MEASUREMENT])) + .where(F.col("standard_concept_id") != UNKNOWN_CONCEPT) + .select("person_id") + .distinct() + ) - patient_events = patient_events.join(qualified_person_df, 'person_id') + patient_events = patient_events.join(qualified_person_df, "person_id") # Create the udf for calculating the weeks since the epoch time 1970-01-01 - weeks_since_epoch_udf = ( - F.unix_timestamp('date') / F.lit(24 * 60 * 60 * 7) - ).cast('int') + weeks_since_epoch_udf = (F.unix_timestamp("date") / F.lit(24 * 60 * 60 * 7)).cast("int") # UDF for creating the concept orders within each visit visit_concept_order_udf = F.row_number().over( - W.partitionBy('cohort_member_id', - 'person_id', - 'visit_occurrence_id').orderBy('date', 'standard_concept_id') + W.partitionBy("cohort_member_id", "person_id", "visit_occurrence_id").orderBy("date", "standard_concept_id") ) - patient_events = patient_events \ - .withColumn('date', F.col('date').cast('date')) \ - .withColumn('date_in_week', weeks_since_epoch_udf) \ - .withColumn('visit_concept_order', visit_concept_order_udf) + patient_events = ( + patient_events.withColumn("date", F.col("date").cast("date")) + .withColumn("date_in_week", weeks_since_epoch_udf) + .withColumn("visit_concept_order", visit_concept_order_udf) + ) # Insert a CLS token at the beginning of each visit, this CLS token will be used as the visit # summary in pre-training / fine-tuning. We basically make a copy of the first concept of # each visit and change it to CLS, and set the concept order to 0 to make sure this is always # the first token of each visit - insert_cls_tokens = patient_events \ - .where('visit_concept_order == 1') \ - .withColumn('standard_concept_id', F.lit('CLS')) \ - .withColumn('domain', F.lit('CLS')) \ - .withColumn('visit_concept_order', F.lit(0)) \ - .withColumn('date', F.col('visit_start_date')) \ - .withColumn('concept_value_mask', F.lit(0)) \ - .withColumn('concept_value', F.lit(-1.0)) \ - .withColumn('mlm_skip', F.lit(1)) \ - .withColumn('condition_mask', F.lit(0)) + insert_cls_tokens = ( + patient_events.where("visit_concept_order == 1") + .withColumn("standard_concept_id", F.lit("CLS")) + .withColumn("domain", F.lit("CLS")) + .withColumn("visit_concept_order", F.lit(0)) + .withColumn("date", F.col("visit_start_date")) + .withColumn("concept_value_mask", F.lit(0)) + .withColumn("concept_value", F.lit(-1.0)) + .withColumn("mlm_skip", F.lit(1)) + .withColumn("condition_mask", F.lit(0)) + ) # Declare a list of columns that need to be collected per each visit - struct_columns = ['visit_concept_order', 'standard_concept_id', 'date_in_week', - 'age', 'concept_value_mask', 'concept_value', 'mlm_skip', 'condition_mask'] + struct_columns = [ + "visit_concept_order", + "standard_concept_id", + "date_in_week", + "age", + "concept_value_mask", + "concept_value", + "mlm_skip", + "condition_mask", + ] # Merge the first CLS tokens into patient sequence and collect events for each visit - patent_visit_sequence = patient_events.union(insert_cls_tokens) \ - .withColumn('visit_struct_data', F.struct(struct_columns)) \ - .groupBy('cohort_member_id', 'person_id', 'visit_occurrence_id') \ - .agg(F.sort_array(F.collect_set('visit_struct_data')).alias('visit_struct_data'), - F.first('visit_start_date').alias('visit_start_date'), - F.first('visit_rank_order').alias('visit_rank_order'), - 
F.first('visit_concept_id').alias('visit_concept_id'), - F.first('is_readmission').alias('is_readmission'), - F.first('is_inpatient').alias('is_inpatient'), - F.first('visit_segment').alias('visit_segment'), - F.first('time_interval_att').alias('time_interval_att'), - F.first('prolonged_stay').alias('prolonged_stay'), - F.count('standard_concept_id').alias('num_of_concepts')) \ - .orderBy(['person_id', 'visit_rank_order']) - - patient_visit_sequence = patent_visit_sequence \ - .withColumn('visit_concept_orders', F.col('visit_struct_data.visit_concept_order')) \ - .withColumn('visit_concept_ids', F.col('visit_struct_data.standard_concept_id')) \ - .withColumn('visit_concept_dates', F.col('visit_struct_data.date_in_week')) \ - .withColumn('visit_concept_ages', F.col('visit_struct_data.age')) \ - .withColumn('concept_value_masks', F.col('visit_struct_data.concept_value_mask')) \ - .withColumn('concept_values', F.col('visit_struct_data.concept_value')) \ - .withColumn('mlm_skip_values', F.col('visit_struct_data.mlm_skip')) \ - .withColumn('condition_masks', F.col('visit_struct_data.condition_mask')) \ - .withColumn('visit_mask', F.lit(0)) \ - .drop('visit_struct_data') - - visit_struct_data_columns = ['visit_rank_order', 'visit_occurrence_id', 'visit_start_date', - 'visit_concept_id', 'prolonged_stay', 'visit_mask', - 'visit_segment', 'num_of_concepts', 'is_readmission', - 'is_inpatient', 'time_interval_att', 'visit_concept_orders', - 'visit_concept_ids', 'visit_concept_dates', 'visit_concept_ages', - 'concept_values', 'concept_value_masks', 'mlm_skip_values', - 'condition_masks'] - - visit_weeks_since_epoch_udf = (F.unix_timestamp(F.col('visit_start_date').cast('date')) / F.lit( - 24 * 60 * 60 * 7)).cast('int') - - patient_sequence = patient_visit_sequence \ - .withColumn('visit_start_date', visit_weeks_since_epoch_udf) \ - .withColumn('visit_struct_data', - F.struct(visit_struct_data_columns).alias('visit_struct_data')) \ - .groupBy('cohort_member_id', 'person_id') \ - .agg(F.sort_array(F.collect_list('visit_struct_data')).alias('patient_list'), - F.sum(F.lit(1) - F.col('visit_mask')).alias('num_of_visits'), - F.sum('num_of_concepts').alias('num_of_concepts')) + patent_visit_sequence = ( + patient_events.union(insert_cls_tokens) + .withColumn("visit_struct_data", F.struct(struct_columns)) + .groupBy("cohort_member_id", "person_id", "visit_occurrence_id") + .agg( + F.sort_array(F.collect_set("visit_struct_data")).alias("visit_struct_data"), + F.first("visit_start_date").alias("visit_start_date"), + F.first("visit_rank_order").alias("visit_rank_order"), + F.first("visit_concept_id").alias("visit_concept_id"), + F.first("is_readmission").alias("is_readmission"), + F.first("is_inpatient").alias("is_inpatient"), + F.first("visit_segment").alias("visit_segment"), + F.first("time_interval_att").alias("time_interval_att"), + F.first("prolonged_stay").alias("prolonged_stay"), + F.count("standard_concept_id").alias("num_of_concepts"), + ) + .orderBy(["person_id", "visit_rank_order"]) + ) + + patient_visit_sequence = ( + patent_visit_sequence.withColumn("visit_concept_orders", F.col("visit_struct_data.visit_concept_order")) + .withColumn("visit_concept_ids", F.col("visit_struct_data.standard_concept_id")) + .withColumn("visit_concept_dates", F.col("visit_struct_data.date_in_week")) + .withColumn("visit_concept_ages", F.col("visit_struct_data.age")) + .withColumn("concept_value_masks", F.col("visit_struct_data.concept_value_mask")) + .withColumn("concept_values", 
F.col("visit_struct_data.concept_value")) + .withColumn("mlm_skip_values", F.col("visit_struct_data.mlm_skip")) + .withColumn("condition_masks", F.col("visit_struct_data.condition_mask")) + .withColumn("visit_mask", F.lit(0)) + .drop("visit_struct_data") + ) + + visit_struct_data_columns = [ + "visit_rank_order", + "visit_occurrence_id", + "visit_start_date", + "visit_concept_id", + "prolonged_stay", + "visit_mask", + "visit_segment", + "num_of_concepts", + "is_readmission", + "is_inpatient", + "time_interval_att", + "visit_concept_orders", + "visit_concept_ids", + "visit_concept_dates", + "visit_concept_ages", + "concept_values", + "concept_value_masks", + "mlm_skip_values", + "condition_masks", + ] + + visit_weeks_since_epoch_udf = ( + F.unix_timestamp(F.col("visit_start_date").cast("date")) / F.lit(24 * 60 * 60 * 7) + ).cast("int") + + patient_sequence = ( + patient_visit_sequence.withColumn("visit_start_date", visit_weeks_since_epoch_udf) + .withColumn( + "visit_struct_data", + F.struct(visit_struct_data_columns).alias("visit_struct_data"), + ) + .groupBy("cohort_member_id", "person_id") + .agg( + F.sort_array(F.collect_list("visit_struct_data")).alias("patient_list"), + F.sum(F.lit(1) - F.col("visit_mask")).alias("num_of_visits"), + F.sum("num_of_concepts").alias("num_of_concepts"), + ) + ) if max_num_of_visits_per_person: - patient_sequence = patient_sequence \ - .where(F.col('num_of_visits') <= max_num_of_visits_per_person) - - patient_sequence = patient_sequence \ - .withColumn('visit_rank_orders', F.col('patient_list.visit_rank_order')) \ - .withColumn('concept_orders', F.col('patient_list.visit_concept_orders')) \ - .withColumn('concept_ids', F.col('patient_list.visit_concept_ids')) \ - .withColumn('dates', F.col('patient_list.visit_concept_dates')) \ - .withColumn('ages', F.col('patient_list.visit_concept_ages')) \ - .withColumn('visit_dates', F.col('patient_list.visit_start_date')) \ - .withColumn('visit_segments', F.col('patient_list.visit_segment')) \ - .withColumn('visit_masks', F.col('patient_list.visit_mask')) \ - .withColumn('visit_concept_ids', - F.col('patient_list.visit_concept_id').cast(T.ArrayType(T.StringType()))) \ - .withColumn('time_interval_atts', F.col('patient_list.time_interval_att')) \ - .withColumn('concept_values', F.col('patient_list.concept_values')) \ - .withColumn('concept_value_masks', F.col('patient_list.concept_value_masks')) \ - .withColumn('mlm_skip_values', F.col('patient_list.mlm_skip_values')) \ - .withColumn('condition_masks', F.col('patient_list.condition_masks')) \ - .withColumn('is_readmissions', - F.col('patient_list.is_readmission').cast(T.ArrayType(T.IntegerType()))) \ - .withColumn('is_inpatients', - F.col('patient_list.is_inpatient').cast(T.ArrayType(T.IntegerType()))) \ - .withColumn('visit_prolonged_stays', - F.col('patient_list.prolonged_stay').cast(T.ArrayType(T.IntegerType()))) \ - .drop('patient_list') + patient_sequence = patient_sequence.where(F.col("num_of_visits") <= max_num_of_visits_per_person) + + patient_sequence = ( + patient_sequence.withColumn("visit_rank_orders", F.col("patient_list.visit_rank_order")) + .withColumn("concept_orders", F.col("patient_list.visit_concept_orders")) + .withColumn("concept_ids", F.col("patient_list.visit_concept_ids")) + .withColumn("dates", F.col("patient_list.visit_concept_dates")) + .withColumn("ages", F.col("patient_list.visit_concept_ages")) + .withColumn("visit_dates", F.col("patient_list.visit_start_date")) + .withColumn("visit_segments", F.col("patient_list.visit_segment")) + 
.withColumn("visit_masks", F.col("patient_list.visit_mask")) + .withColumn( + "visit_concept_ids", + F.col("patient_list.visit_concept_id").cast(T.ArrayType(T.StringType())), + ) + .withColumn("time_interval_atts", F.col("patient_list.time_interval_att")) + .withColumn("concept_values", F.col("patient_list.concept_values")) + .withColumn("concept_value_masks", F.col("patient_list.concept_value_masks")) + .withColumn("mlm_skip_values", F.col("patient_list.mlm_skip_values")) + .withColumn("condition_masks", F.col("patient_list.condition_masks")) + .withColumn( + "is_readmissions", + F.col("patient_list.is_readmission").cast(T.ArrayType(T.IntegerType())), + ) + .withColumn( + "is_inpatients", + F.col("patient_list.is_inpatient").cast(T.ArrayType(T.IntegerType())), + ) + .withColumn( + "visit_prolonged_stays", + F.col("patient_list.prolonged_stay").cast(T.ArrayType(T.IntegerType())), + ) + .drop("patient_list") + ) return patient_sequence -def create_visit_person_join( - person, - visit_occurrence, - include_incomplete_visit=True -): +def create_visit_person_join(person, visit_occurrence, include_incomplete_visit=True): """ - Create a new spark data frame based on person and visit_occurrence + Create a new spark data frame based on person and visit_occurrence. :param person: :param visit_occurrence: @@ -993,87 +1207,98 @@ def create_visit_person_join( """ # Create a pandas udf for generating the att token between two neighboring visits - @pandas_udf('string') + @pandas_udf("string") def pandas_udf_to_att(time_intervals: pd.Series) -> pd.Series: return time_intervals.apply(time_token_func) visit_rank_udf = F.row_number().over( - W.partitionBy('person_id').orderBy('visit_start_date', 'visit_end_date', - 'visit_occurrence_id')) - visit_segment_udf = F.col('visit_rank_order') % F.lit(2) + 1 - visit_windowing = W.partitionBy('person_id').orderBy('visit_start_date', - 'visit_end_date', - 'visit_occurrence_id') + W.partitionBy("person_id").orderBy("visit_start_date", "visit_end_date", "visit_occurrence_id") + ) + visit_segment_udf = F.col("visit_rank_order") % F.lit(2) + 1 + visit_windowing = W.partitionBy("person_id").orderBy("visit_start_date", "visit_end_date", "visit_occurrence_id") # Check whehter or not the visit is either an inpatient visit or E-I visit - is_inpatient_logic = F.col('visit_concept_id').isin([9201, 262]).cast('integer') + is_inpatient_logic = F.col("visit_concept_id").isin([9201, 262]).cast("integer") # Construct the logic for readmission, which is defined as inpatient visit occurred within 30 # days of the discharge readmission_logic = F.coalesce( - ((F.col('time_interval') <= 30) \ - & (F.col('visit_concept_id').isin([9201, 262])) \ - & (F.col('prev_visit_concept_id').isin([9201, 262]))).cast('integer'), F.lit(0) + ( + (F.col("time_interval") <= 30) + & (F.col("visit_concept_id").isin([9201, 262])) + & (F.col("prev_visit_concept_id").isin([9201, 262])) + ).cast("integer"), + F.lit(0), ) # Create prolonged inpatient stay # For the incomplete visit, we set prolonged_length_stay_logic to 0 prolonged_length_stay_logic = F.coalesce( - (F.datediff('visit_end_date', 'visit_start_date') >= 7).cast('integer'), F.lit(0) + (F.datediff("visit_end_date", "visit_start_date") >= 7).cast("integer"), + F.lit(0), ) - visit_filter = 'visit_start_date IS NOT NULL' + visit_filter = "visit_start_date IS NOT NULL" if not include_incomplete_visit: - visit_filter = f'{visit_filter} AND visit_end_date IS NOT NULL' + visit_filter = f"{visit_filter} AND visit_end_date IS NOT NULL" # Select the subset of 
columns and create derived columns using the UDF or spark sql # functions. In addition, we allow visits where visit_end_date IS NOT NULL, indicating the # visit is still on-going - visit_occurrence = visit_occurrence.select( - 'visit_occurrence_id', - 'person_id', - 'visit_concept_id', - 'visit_start_date', - 'visit_end_date' - ).where(visit_filter) \ - .withColumn('visit_rank_order', visit_rank_udf) \ - .withColumn('visit_segment', visit_segment_udf) \ - .withColumn('prev_visit_occurrence_id', F.lag('visit_occurrence_id').over(visit_windowing)) \ - .withColumn('prev_visit_concept_id', F.lag('visit_concept_id').over(visit_windowing)) \ - .withColumn('prev_visit_start_date', F.lag('visit_start_date').over(visit_windowing)) \ - .withColumn('prev_visit_end_date', F.lag('visit_end_date').over(visit_windowing)) \ - .withColumn('time_interval', F.datediff('visit_start_date', 'prev_visit_end_date')) \ - .withColumn('time_interval', - F.when(F.col('time_interval') < 0, F.lit(0)).otherwise(F.col('time_interval'))) \ - .withColumn('time_interval_att', pandas_udf_to_att('time_interval')) \ - .withColumn('is_inpatient', is_inpatient_logic) \ - .withColumn('is_readmission', readmission_logic) - - visit_occurrence = visit_occurrence \ - .withColumn('prolonged_stay', prolonged_length_stay_logic) \ - .select('visit_occurrence_id', - 'visit_concept_id', - 'person_id', - 'prolonged_stay', - 'is_readmission', - 'is_inpatient', - 'time_interval_att', - 'visit_rank_order', - 'visit_start_date', - 'visit_segment') + visit_occurrence = ( + visit_occurrence.select( + "visit_occurrence_id", + "person_id", + "visit_concept_id", + "visit_start_date", + "visit_end_date", + ) + .where(visit_filter) + .withColumn("visit_rank_order", visit_rank_udf) + .withColumn("visit_segment", visit_segment_udf) + .withColumn( + "prev_visit_occurrence_id", + F.lag("visit_occurrence_id").over(visit_windowing), + ) + .withColumn("prev_visit_concept_id", F.lag("visit_concept_id").over(visit_windowing)) + .withColumn("prev_visit_start_date", F.lag("visit_start_date").over(visit_windowing)) + .withColumn("prev_visit_end_date", F.lag("visit_end_date").over(visit_windowing)) + .withColumn("time_interval", F.datediff("visit_start_date", "prev_visit_end_date")) + .withColumn( + "time_interval", + F.when(F.col("time_interval") < 0, F.lit(0)).otherwise(F.col("time_interval")), + ) + .withColumn("time_interval_att", pandas_udf_to_att("time_interval")) + .withColumn("is_inpatient", is_inpatient_logic) + .withColumn("is_readmission", readmission_logic) + ) + + visit_occurrence = visit_occurrence.withColumn("prolonged_stay", prolonged_length_stay_logic).select( + "visit_occurrence_id", + "visit_concept_id", + "person_id", + "prolonged_stay", + "is_readmission", + "is_inpatient", + "time_interval_att", + "visit_rank_order", + "visit_start_date", + "visit_segment", + ) # Assume the birthday to be the first day of the birth year if birth_datetime is missing - person = person.select('person_id', F.coalesce('birth_datetime', - F.concat('year_of_birth', F.lit('-01-01')).cast( - 'timestamp')).alias('birth_datetime')) - return visit_occurrence.join(person, 'person_id') + person = person.select( + "person_id", + F.coalesce( + "birth_datetime", + F.concat("year_of_birth", F.lit("-01-01")).cast("timestamp"), + ).alias("birth_datetime"), + ) + return visit_occurrence.join(person, "person_id") -def process_measurement( - spark, - measurement, - required_measurement, - output_folder: str = None -): +def process_measurement(spark, measurement, 
required_measurement, output_folder: str = None):
     """
-    Remove the measurement values that are outside the 0.01-0.99 quantiles. And scale the the
+    Remove the measurement values that are outside the 0.01-0.99 quantiles.
+
+    And scale the
     measurement value by subtracting the mean and dividing by the standard deviation
     :param spark:
     :param
@@ -1085,26 +1310,23 @@ def process_measurement(
     # Register the tables in spark context
     measurement.createOrReplaceTempView(MEASUREMENT)
     required_measurement.createOrReplaceTempView(REQUIRED_MEASUREMENT)
-    measurement_unit_stats_df = spark.sql(
-        measurement_unit_stats_query
-    )
+    measurement_unit_stats_df = spark.sql(measurement_unit_stats_query)

     if output_folder:
-        measurement_unit_stats_df.repartition(10) \
-            .write.mode('overwrite') \
-            .parquet(path.join(output_folder, 'measurement_unit_stats'))
-        measurement_unit_stats_df = spark.read.parquet(
-            path.join(output_folder, 'measurement_unit_stats')
+        measurement_unit_stats_df.repartition(10).write.mode("overwrite").parquet(
+            path.join(output_folder, "measurement_unit_stats")
         )
+        measurement_unit_stats_df = spark.read.parquet(path.join(output_folder, "measurement_unit_stats"))

     # Cache the stats in memory
     measurement_unit_stats_df.cache()
     # Broadcast df to local executors
     broadcast(measurement_unit_stats_df)
     # Create the temp view for this dataframe
-    measurement_unit_stats_df.createOrReplaceTempView('measurement_unit_stats')
+    measurement_unit_stats_df.createOrReplaceTempView("measurement_unit_stats")

-    scaled_numeric_lab = spark.sql('''
+    scaled_numeric_lab = spark.sql(
+        """
         SELECT
             m.person_id,
             m.measurement_concept_id AS standard_concept_id,
@@ -1115,16 +1337,18 @@ def process_measurement(
             (m.value_as_number - s.value_mean) / value_stddev AS concept_value
         FROM measurement AS m
         JOIN measurement_unit_stats AS s
-            ON s.measurement_concept_id = m.measurement_concept_id
+            ON s.measurement_concept_id = m.measurement_concept_id
                 AND s.unit_concept_id = m.unit_concept_id
         WHERE m.visit_occurrence_id IS NOT NULL
             AND m.value_as_number IS NOT NULL
             AND m.value_as_number BETWEEN s.lower_bound AND s.upper_bound
-    ''')
+    """
+    )

     # For categorical measurements in required_measurement, we concatenate measurement_concept_id
     # with value_as_concept_id to construct a new standard_concept_id
-    categorical_lab = spark.sql('''
+    categorical_lab = spark.sql(
+        """
         SELECT
             m.person_id,
             CASE
@@ -1141,24 +1365,25 @@ def process_measurement(
         WHERE EXISTS (
             SELECT
                 1
-            FROM required_measurement AS r
+            FROM required_measurement AS r
             WHERE r.measurement_concept_id = m.measurement_concept_id
                 AND r.is_numeric = false
         )
-    ''')
+    """
+    )

     processed_measurement_df = scaled_numeric_lab.unionAll(categorical_lab)

     if output_folder:
-        processed_measurement_df.write.mode('overwrite').parquet(path.join(output_folder, 'processed_measurement'))
-        processed_measurement_df = spark.read.parquet(path.join(output_folder, 'processed_measurement'))
+        processed_measurement_df.write.mode("overwrite").parquet(path.join(output_folder, "processed_measurement"))
+        processed_measurement_df = spark.read.parquet(path.join(output_folder, "processed_measurement"))

     return processed_measurement_df


 def get_mlm_skip_domains(spark, input_folder, mlm_skip_table_list):
     """
-    Translate the domain_table_name to the domain name
+    Translate the domain_table_name to the domain name.
:param spark: :param input_folder: @@ -1166,15 +1391,14 @@ def get_mlm_skip_domains(spark, input_folder, mlm_skip_table_list): :return: """ domain_tables = [ - preprocess_domain_table(spark, input_folder, domain_table_name) - for domain_table_name in mlm_skip_table_list + preprocess_domain_table(spark, input_folder, domain_table_name) for domain_table_name in mlm_skip_table_list ] return list(map(get_domain_field, domain_tables)) def validate_table_names(domain_names): - for domain_name in domain_names.split(' '): + for domain_name in domain_names.split(" "): if domain_name not in CDM_TABLES: - raise argparse.ArgumentTypeError(f'{domain_name} is an invalid CDM table name') + raise argparse.ArgumentTypeError(f"{domain_name} is an invalid CDM table name") return domain_names diff --git a/tests/integration_tests/runners/hf_cehrbert_pretrain_runner_test.py b/tests/integration_tests/runners/hf_cehrbert_pretrain_runner_test.py index b60c32b7..19e5a8ab 100644 --- a/tests/integration_tests/runners/hf_cehrbert_pretrain_runner_test.py +++ b/tests/integration_tests/runners/hf_cehrbert_pretrain_runner_test.py @@ -1,10 +1,12 @@ -import unittest +import os +import shutil import sys import tempfile -import shutil -import os -from datasets import disable_caching +import unittest from pathlib import Path + +from datasets import disable_caching + from cehrbert.runners.hf_cehrbert_pretrain_runner import main disable_caching() @@ -17,27 +19,27 @@ class HfCehrBertRunnerIntegrationTest(unittest.TestCase): def setUp(self): # Get the root folder of the project root_folder = Path(os.path.abspath(__file__)).parent.parent.parent.parent - data_folder = os.path.join(root_folder, 'sample_data', 'pretrain') + data_folder = os.path.join(root_folder, "sample_data", "pretrain") # Create a temporary directory to store model and tokenizer self.temp_dir = tempfile.mkdtemp() - self.model_folder_path = os.path.join(self.temp_dir, 'model') + self.model_folder_path = os.path.join(self.temp_dir, "model") Path(self.model_folder_path).mkdir(parents=True, exist_ok=True) - self.dataset_prepared_path = os.path.join(self.temp_dir, 'dataset_prepared_path') + self.dataset_prepared_path = os.path.join(self.temp_dir, "dataset_prepared_path") Path(self.dataset_prepared_path).mkdir(parents=True, exist_ok=True) sys.argv = [ - 'hf_cehrbert_pretraining_runner.py', - '--model_name_or_path', + "hf_cehrbert_pretraining_runner.py", + "--model_name_or_path", self.model_folder_path, - '--tokenizer_name_or_path', + "--tokenizer_name_or_path", self.model_folder_path, - '--output_dir', + "--output_dir", self.model_folder_path, - '--data_folder', + "--data_folder", data_folder, - '--dataset_prepared_path', + "--dataset_prepared_path", self.dataset_prepared_path, - '--max_steps', - '10' + "--max_steps", + "10", ] def tearDown(self): @@ -48,5 +50,5 @@ def test_train_model(self): main() -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/integration_tests/trainers/train_cehr_bert_test.py b/tests/integration_tests/trainers/train_cehr_bert_test.py index 4fe679aa..d5548743 100644 --- a/tests/integration_tests/trainers/train_cehr_bert_test.py +++ b/tests/integration_tests/trainers/train_cehr_bert_test.py @@ -1,10 +1,11 @@ -import unittest -import tempfile -import shutil import os +import shutil +import tempfile +import unittest os.environ["CUDA_VISIBLE_DEVICES"] = "-1" from pathlib import Path + from cehrbert.trainers.train_cehr_bert import VanillaBertTrainer @@ -14,11 +15,11 @@ def setUp(self): root_folder = 
Path(os.path.abspath(__file__)).parent.parent.parent.parent # Create a temporary directory to store model and tokenizer self.temp_dir = tempfile.mkdtemp() - self.model_folder_path = os.path.join(self.temp_dir, 'model') + self.model_folder_path = os.path.join(self.temp_dir, "model") Path(self.model_folder_path).mkdir(parents=True, exist_ok=True) - self.tf_board_log_path = os.path.join(self.model_folder_path, 'logs') - self.training_data_parquet_path = os.path.join(root_folder, 'sample_data/pretrain/patient_sequence.parquet') + self.tf_board_log_path = os.path.join(self.model_folder_path, "logs") + self.training_data_parquet_path = os.path.join(root_folder, "sample_data/pretrain/patient_sequence.parquet") self.embedding_size = 16 self.context_window_size = 10 @@ -48,7 +49,7 @@ def setUp(self): batch_size=self.batch_size, epochs=self.epochs, learning_rate=self.learning_rate, - tf_board_log_path=self.tf_board_log_path + tf_board_log_path=self.tf_board_log_path, ) def tearDown(self): @@ -60,5 +61,5 @@ def test_train_model(self): self.trainer.train_model() -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/unit_tests/data_generators/bert_masked_language_modeling_learning_objective_test.py b/tests/unit_tests/data_generators/bert_masked_language_modeling_learning_objective_test.py index 35eef0d2..9f7a3c81 100644 --- a/tests/unit_tests/data_generators/bert_masked_language_modeling_learning_objective_test.py +++ b/tests/unit_tests/data_generators/bert_masked_language_modeling_learning_objective_test.py @@ -1,6 +1,8 @@ import unittest + import numpy as np import pandas as pd + from cehrbert.data_generators.data_classes import RowSlicer from cehrbert.data_generators.learning_objective import MaskedLanguageModelLearningObjective from cehrbert.data_generators.tokenizer import ConceptTokenizer @@ -10,33 +12,53 @@ class TestMaskedLanguageModelLearningObjective(unittest.TestCase): def setUp(self): # Setup code to run before each test, e.g., create a ConceptTokenizer instance - self.concept_tokenizer = ConceptTokenizer() # Initialize this with whatever parameters are appropriate for your implementation + self.concept_tokenizer = ( + ConceptTokenizer() + ) # Initialize this with whatever parameters are appropriate for your implementation self.max_seq_len = 6 self.is_pretraining = True self.learning_obj = MaskedLanguageModelLearningObjective( - self.concept_tokenizer, - self.max_seq_len, - self.is_pretraining + self.concept_tokenizer, self.max_seq_len, self.is_pretraining ) @staticmethod def create_mock_row(): # Create a mock row with 5 elements in each list return RowSlicer( - row=pd.Series({ - 'dates': [1, 2, 3, 4, 5], - 'token_ids': [101, 102, 103, 104, 105], # Example token IDs - 'visit_segments': [1, 1, 2, 2, 1], # Example visit segments - 'ages': [25, 26, 27, 28, 29], # Example ages - 'visit_concept_orders': [1, 2, 3, 4, 5], # Example visit concept orders - 'concept_values': [0.0, 0.0, 0.0, 0.0, 0.9], # Example concept values - 'concept_value_masks': [0, 0, 0, 0, 1], # Example concept value masks - 'mlm_skip_values': [0, 0, 0, 0, 1], # Example MLM skip values - 'orders': [1, 2, 3, 4, 5] # Example orders for sorting - }), + row=pd.Series( + { + "dates": [1, 2, 3, 4, 5], + "token_ids": [101, 102, 103, 104, 105], # Example token IDs + "visit_segments": [1, 1, 2, 2, 1], # Example visit segments + "ages": [25, 26, 27, 28, 29], # Example ages + "visit_concept_orders": [ + 1, + 2, + 3, + 4, + 5, + ], # Example visit concept orders + "concept_values": [ + 0.0, + 0.0, + 
0.0, + 0.0, + 0.9, + ], # Example concept values + "concept_value_masks": [ + 0, + 0, + 0, + 0, + 1, + ], # Example concept value masks + "mlm_skip_values": [0, 0, 0, 0, 1], # Example MLM skip values + "orders": [1, 2, 3, 4, 5], # Example orders for sorting + } + ), start_index=0, end_index=5, # Updated to include all 5 elements - target_index=2 # Adjusted target index for demonstration + target_index=2, # Adjusted target index for demonstration ) def test_initialization(self): @@ -48,10 +70,10 @@ def test_initialization(self): def test_get_tf_dataset_schema(self): # Test the get_tf_dataset_schema method input_schema, output_schema = self.learning_obj.get_tf_dataset_schema() - self.assertIn('masked_concept_ids', input_schema) - self.assertIn('concept_ids', input_schema) - self.assertIn('mask', input_schema) - self.assertIn('concept_predictions', output_schema) + self.assertIn("masked_concept_ids", input_schema) + self.assertIn("concept_ids", input_schema) + self.assertIn("mask", input_schema) + self.assertIn("concept_predictions", output_schema) def test_process_batch(self): # Test the process_batch method with a mock input @@ -60,22 +82,31 @@ def test_process_batch(self): input_dict, output_dict = self.learning_obj.process_batch(mock_rows) # Assert that the input and output dictionaries have the correct structure and values - self.assertIn('masked_concept_ids', input_dict) - self.assertIn('concept_ids', input_dict) - self.assertIn('mask', input_dict) + self.assertIn("masked_concept_ids", input_dict) + self.assertIn("concept_ids", input_dict) + self.assertIn("mask", input_dict) # Continue for all expected keys in the input and output dictionaries... - self.assertIn('concept_predictions', output_dict) + self.assertIn("concept_predictions", output_dict) self.assertTrue( - (input_dict['concept_ids'][0] == np.asarray( - [101, 102, 103, 104, 105, self.concept_tokenizer.get_unused_token_id()])).all() + ( + input_dict["concept_ids"][0] + == np.asarray( + [ + 101, + 102, + 103, + 104, + 105, + self.concept_tokenizer.get_unused_token_id(), + ] + ) + ).all() ) # Test the concept mask, where 1 indicates attention and 0 indicates mask - self.assertTrue( - (input_dict['mask'][0] == np.asarray([1, 1, 1, 1, 1, 0])).all() - ) + self.assertTrue((input_dict["mask"][0] == np.asarray([1, 1, 1, 1, 1, 0])).all()) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/unit_tests/data_generators/hf_data_generator/hf_generate_start_end_index_mapping_test.py b/tests/unit_tests/data_generators/hf_data_generator/hf_generate_start_end_index_mapping_test.py index 425d698e..fbc84ac4 100644 --- a/tests/unit_tests/data_generators/hf_data_generator/hf_generate_start_end_index_mapping_test.py +++ b/tests/unit_tests/data_generators/hf_data_generator/hf_generate_start_end_index_mapping_test.py @@ -1,8 +1,9 @@ -import unittest -from cehrbert.data_generators.hf_data_generator.hf_dataset_collator import CehrBertDataCollator import random +import unittest from unittest.mock import MagicMock +from cehrbert.data_generators.hf_data_generator.hf_dataset_collator import CehrBertDataCollator + # Seed the random number generator for reproducibility in tests random.seed(42) @@ -15,51 +16,46 @@ def setUp(self): self.mock_tokenizer.mask_token_index = 1 self.mock_tokenizer.unused_token_index = 99 self.mock_tokenizer.encode.return_value = [10, 20, 30] # Example token IDs - self.mock_tokenizer._convert_token_to_id.side_effect = [2, 3, 17, 18] - self.mock_tokenizer._convert_id_to_token.side_effect = 
['year:2000', 'age:20-30'] - self.data_collator = CehrBertDataCollator( - tokenizer=self.mock_tokenizer, - max_length=10 - ) + self.mock_tokenizer.convert_token_to_id.side_effect = [2, 3, 17, 18] + self.mock_tokenizer.convert_id_to_token.side_effect = [ + "year:2000", + "age:20-30", + ] + self.data_collator = CehrBertDataCollator(tokenizer=self.mock_tokenizer, max_length=10) def test_long_sequence(self): # Test with a sequence longer than max_sequence_length - record = { - 'input_ids': [2, 4, 3, 5, 2, 6, 7, 8, 9, 10, 3, 11, 2, 12, 3, 13, 2, 14, 3] - } + record = {"input_ids": [2, 4, 3, 5, 2, 6, 7, 8, 9, 10, 3, 11, 2, 12, 3, 13, 2, 14, 3]} result = self.data_collator.generate_start_end_index(record) - self.assertListEqual(result['input_ids'], [2, 4, 3, 5, 2, 6, 7, 8, 9]) + self.assertListEqual(result["input_ids"], [2, 4, 3, 5, 2, 6, 7, 8, 9]) def test_short_sequence(self): # Test with a sequence shorter than max_sequence_length - record = { - 'input_ids': list(range(5)) # Shorter than max_sequence_length - } + record = {"input_ids": list(range(5))} # Shorter than max_sequence_length result = self.data_collator.generate_start_end_index(record) - self.assertListEqual(result['input_ids'], list(range(5))) + self.assertListEqual(result["input_ids"], list(range(5))) def test_edge_case_sequence_length_equal_to_max(self): # Test with a sequence exactly equal to max_sequence_length - record = { - 'input_ids': list(range(9)) # Exactly max_sequence_length - 1 - } + record = {"input_ids": list(range(9))} # Exactly max_sequence_length - 1 result = self.data_collator.generate_start_end_index(record) - self.assertListEqual(result['input_ids'], list(range(9))) + self.assertListEqual(result["input_ids"], list(range(9))) def test_tail_case_sequence_length_equal_to_max(self): from cehrbert.data_generators.hf_data_generator.hf_dataset_collator import TruncationType + # Test with a sequence exactly equal to max_sequence_length default_val = self.data_collator.truncate_type self.data_collator.truncate_type = TruncationType.TAIL record = { - 'input_ids': [13, 14, 15, 16] + list(range(2, 8)), # Exactly max_sequence_length - 1, - 'dates': [0, 0, 0, 0] + list(range(2052, 2058)) + "input_ids": [13, 14, 15, 16] + list(range(2, 8)), # Exactly max_sequence_length - 1, + "dates": [0, 0, 0, 0] + list(range(2052, 2058)), } result = self.data_collator.generate_start_end_index(record) - self.assertListEqual(result['input_ids'], [2, 3, 4, 5, 6, 7]) - self.assertListEqual(result['dates'], [2052, 2053, 2054, 2055, 2056, 2057]) + self.assertListEqual(result["input_ids"], [2, 3, 4, 5, 6, 7]) + self.assertListEqual(result["dates"], [2052, 2053, 2054, 2055, 2056, 2057]) self.data_collator.truncate_type = default_val -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/unit_tests/data_generators/hf_data_generator/hf_masked_language_modeling_mapping_test.py b/tests/unit_tests/data_generators/hf_data_generator/hf_masked_language_modeling_mapping_test.py index 1b91d30d..2d198c49 100644 --- a/tests/unit_tests/data_generators/hf_data_generator/hf_masked_language_modeling_mapping_test.py +++ b/tests/unit_tests/data_generators/hf_data_generator/hf_masked_language_modeling_mapping_test.py @@ -1,6 +1,7 @@ -import unittest import random +import unittest from unittest.mock import MagicMock + from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import HFTokenizationMapping @@ -19,10 +20,14 @@ def setUp(self): def test_transform_with_valid_indices(self): # Given a valid record with start and 
end indices record = { - 'concept_ids': ['c1', 'c2', 'c3'], - 'mlm_skip_values': [0, 0, 0, ], - 'concept_value_masks': [0, 0, 0], - 'concept_values': [0., 0., 0.], + "concept_ids": ["c1", "c2", "c3"], + "mlm_skip_values": [ + 0, + 0, + 0, + ], + "concept_value_masks": [0, 0, 0], + "concept_values": [0.0, 0.0, 0.0], } # Random seed for predictability in tests @@ -31,28 +36,32 @@ def test_transform_with_valid_indices(self): # Expected masked input ids might depend on random masking logic # Here we assume the second token gets masked with the mask token index (1) expected_masked_input_ids = [10, 20, 30] - expected_labels = [10, 20, 30] # Only non-masked tokens are labeled with original ids + expected_labels = [ + 10, + 20, + 30, + ] # Only non-masked tokens are labeled with original ids result = self.mapping.transform(record) # Check if the tokenizer's encode method was called correctly - self.mock_tokenizer.encode.assert_called_once_with(['c1', 'c2', 'c3']) + self.mock_tokenizer.encode.assert_called_once_with(["c1", "c2", "c3"]) # Validate the output - self.assertEqual(result['input_ids'], expected_masked_input_ids) - self.assertEqual(result['labels'], expected_labels) + self.assertEqual(result["input_ids"], expected_masked_input_ids) + self.assertEqual(result["labels"], expected_labels) def test_transform_assertion(self): # Given a valid record with start and end indices record = { - 'concept_ids': ['c1', 'c2', 'c3', 'c4'], - 'mlm_skip_values': [0, 0, 0, 1], - 'concept_value_masks': [0, 0, 0, 0], - 'concept_values': [0., 0., 0., 0.], + "concept_ids": ["c1", "c2", "c3", "c4"], + "mlm_skip_values": [0, 0, 0, 1], + "concept_value_masks": [0, 0, 0, 0], + "concept_values": [0.0, 0.0, 0.0, 0.0], } with self.assertRaises(AssertionError): self.mapping.transform(record) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/unit_tests/data_generators/hf_data_generator/hf_med_to_cehrbert_mapping_test.py b/tests/unit_tests/data_generators/hf_data_generator/hf_med_to_cehrbert_mapping_test.py index 9bda2f08..f0d76a7b 100644 --- a/tests/unit_tests/data_generators/hf_data_generator/hf_med_to_cehrbert_mapping_test.py +++ b/tests/unit_tests/data_generators/hf_data_generator/hf_med_to_cehrbert_mapping_test.py @@ -1,9 +1,9 @@ import unittest -from cehrbert.med_extension.schema_extension import CehrBertPatient, Visit, Event from datetime import datetime -from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import MedToCehrBertDatasetMapping +from cehrbert.med_extension.schema_extension import CehrBertPatient, Event, Visit +from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments from cehrbert.spark_apps.decorators.patient_event_decorator import AttType @@ -14,30 +14,18 @@ def setUp(self): outpatient_visit = Visit( visit_type="9202", visit_start_datetime=datetime(2024, 4, 14, 0, 0), - events=[ - Event( - time=datetime(2024, 4, 14, 0, 0), - code='320128' - ) - ] + events=[Event(time=datetime(2024, 4, 14, 0, 0), code="320128")], ) inpatient_visit = Visit( visit_type="9201", visit_start_datetime=datetime(2024, 4, 21, 0, 0), visit_end_datetime=datetime(2024, 4, 22, 0, 0), - discharge_facility='8536', + discharge_facility="8536", events=[ - Event( - time=datetime(2024, 4, 21, 0, 0), - code='320128' - ), - Event( - time=datetime(2024, 4, 22, 0, 0), - code='4134120', - numeric_value=0.5 - ) - ] + Event(time=datetime(2024, 4, 21, 0, 0), code="320128"), + 
Event(time=datetime(2024, 4, 22, 0, 0), code="4134120", numeric_value=0.5), + ], ) # Intentionally perturb the chronological order of visits by putting outpatient_visit after inpatient_visit, @@ -45,11 +33,9 @@ def setUp(self): self.patient = CehrBertPatient( patient_id=0, birth_datetime=datetime(1980, 4, 14, 0, 0), - gender='Gender/F', - race='Race/unknown', - visits=[ - inpatient_visit, outpatient_visit - ] + gender="Gender/F", + race="Race/unknown", + visits=[inpatient_visit, outpatient_visit], ) def test_transform_cehrbert_with_auxiliary_token(self): @@ -58,130 +44,105 @@ def test_transform_cehrbert_with_auxiliary_token(self): data_folder=None, # required field set to None dataset_prepared_path=None, # required field set to None att_function_type=AttType.CEHR_BERT.value, - include_auxiliary_token=True + include_auxiliary_token=True, ) # Create an instance of the mapping class - mapper = MedToCehrBertDatasetMapping( - data_args - ) + mapper = MedToCehrBertDatasetMapping(data_args) transformed_record = mapper.transform(self.patient) # Assert - self.assertEqual(transformed_record['person_id'], 0) + self.assertEqual(transformed_record["person_id"], 0) # Test concept_ids self.assertListEqual( - transformed_record['concept_ids'], - ['[VS]', '9202', '320128', '[VE]', 'W1', '[VS]', '9201', '320128', '4134120', '8536', '[VE]'] + transformed_record["concept_ids"], + [ + "[VS]", + "9202", + "320128", + "[VE]", + "W1", + "[VS]", + "9201", + "320128", + "4134120", + "8536", + "[VE]", + ], ) # Test ages, age=-1 used for the ATT tokens - self.assertListEqual( - transformed_record['ages'], - [44, 44, 44, 44, -1, 44, 44, 44, 44, 44, 44] - ) + self.assertListEqual(transformed_record["ages"], [44, 44, 44, 44, -1, 44, 44, 44, 44, 44, 44]) # Test dates, dates=0 used for the ATT tokens self.assertListEqual( - transformed_record['dates'], - [2832, 2832, 2832, 2832, 0, 2833, 2833, 2833, 2833, 2833, 2833] + transformed_record["dates"], + [2832, 2832, 2832, 2832, 0, 2833, 2833, 2833, 2833, 2833, 2833], ) # Test visit_segments, visit_segment=0 used for the ATT tokens - self.assertListEqual( - transformed_record['visit_segments'], - [1, 1, 1, 1, 0, 2, 2, 2, 2, 2, 2] - ) + self.assertListEqual(transformed_record["visit_segments"], [1, 1, 1, 1, 0, 2, 2, 2, 2, 2, 2]) # Test visit_concept_orders, we visit_concept_order to be same as next visit for the ATT tokens self.assertListEqual( - transformed_record['visit_concept_orders'], - [1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2] + transformed_record["visit_concept_orders"], + [1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2], ) # Test concept_value_masks - self.assertListEqual( - transformed_record['concept_value_masks'], - [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0] - ) + self.assertListEqual(transformed_record["concept_value_masks"], [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]) # Test concept_values, concept_value=-1 is a default value associated with non-numeric measurements self.assertListEqual( - transformed_record['concept_values'], - [-1, -1, -1, -1, -1, -1, -1, -1, 0.5, -1, -1] + transformed_record["concept_values"], + [-1, -1, -1, -1, -1, -1, -1, -1, 0.5, -1, -1], ) # Test mlm_skip_values - self.assertListEqual( - transformed_record['mlm_skip_values'], - [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0] - ) + self.assertListEqual(transformed_record["mlm_skip_values"], [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]) def test_transform_basic(self): data_args = DataTrainingArguments( data_folder=None, # required field set to None dataset_prepared_path=None, # required field set to None att_function_type=AttType.CEHR_BERT.value, - 
include_auxiliary_token=False + include_auxiliary_token=False, ) # Create an instance of the mapping class - mapper = MedToCehrBertDatasetMapping( - data_args - ) + mapper = MedToCehrBertDatasetMapping(data_args) transformed_record = mapper.transform(self.patient) # Assert - self.assertEqual(transformed_record['person_id'], 0) + self.assertEqual(transformed_record["person_id"], 0) # Test concept_ids self.assertListEqual( - transformed_record['concept_ids'], - ['[VS]', '320128', '[VE]', 'W1', '[VS]', '320128', '4134120', '[VE]'] + transformed_record["concept_ids"], + ["[VS]", "320128", "[VE]", "W1", "[VS]", "320128", "4134120", "[VE]"], ) # Test ages, age=-1 used for the ATT tokens - self.assertListEqual( - transformed_record['ages'], - [44, 44, 44, -1, 44, 44, 44, 44] - ) + self.assertListEqual(transformed_record["ages"], [44, 44, 44, -1, 44, 44, 44, 44]) # Test dates, dates=0 used for the ATT tokens - self.assertListEqual( - transformed_record['dates'], - [2832, 2832, 2832, 0, 2833, 2833, 2833, 2833] - ) + self.assertListEqual(transformed_record["dates"], [2832, 2832, 2832, 0, 2833, 2833, 2833, 2833]) # Test visit_segments, visit_segment=0 used for the ATT tokens - self.assertListEqual( - transformed_record['visit_segments'], - [1, 1, 1, 0, 2, 2, 2, 2] - ) + self.assertListEqual(transformed_record["visit_segments"], [1, 1, 1, 0, 2, 2, 2, 2]) # Test visit_concept_orders, we visit_concept_order to be same as next visit for the ATT tokens - self.assertListEqual( - transformed_record['visit_concept_orders'], - [1, 1, 1, 2, 2, 2, 2, 2] - ) + self.assertListEqual(transformed_record["visit_concept_orders"], [1, 1, 1, 2, 2, 2, 2, 2]) # Test concept_value_masks - self.assertListEqual( - transformed_record['concept_value_masks'], - [0, 0, 0, 0, 0, 0, 1, 0] - ) + self.assertListEqual(transformed_record["concept_value_masks"], [0, 0, 0, 0, 0, 0, 1, 0]) # Test concept_values, concept_value=-1 is a default value associated with non-numeric measurements - self.assertListEqual( - transformed_record['concept_values'], - [-1, -1, -1, -1, -1, -1, 0.5, -1] - ) + self.assertListEqual(transformed_record["concept_values"], [-1, -1, -1, -1, -1, -1, 0.5, -1]) # Test mlm_skip_values - self.assertListEqual( - transformed_record['mlm_skip_values'], - [0, 0, 0, 0, 0, 0, 1, 0] - ) + self.assertListEqual(transformed_record["mlm_skip_values"], [0, 0, 0, 0, 0, 0, 1, 0]) def test_cehrgpt_transform(self): data_args = DataTrainingArguments( @@ -190,60 +151,91 @@ def test_cehrgpt_transform(self): att_function_type=AttType.DAY.value, inpatient_att_function_type=AttType.DAY.value, include_auxiliary_token=True, - include_demographic_prompt=True + include_demographic_prompt=True, ) # Create an instance of the mapping class - mapper = MedToCehrBertDatasetMapping( - data_args - ) + mapper = MedToCehrBertDatasetMapping(data_args) transformed_record = mapper.transform(self.patient) # Test concept_ids self.assertListEqual( - transformed_record['concept_ids'], - ['year:2024', 'age:44', 'Gender/F', 'Race/unknown', '[VS]', '9202', '320128', '[VE]', - 'D7', '[VS]', '9201', '320128', 'i-D1', '4134120', '8536', '[VE]'] + transformed_record["concept_ids"], + [ + "year:2024", + "age:44", + "Gender/F", + "Race/unknown", + "[VS]", + "9202", + "320128", + "[VE]", + "D7", + "[VS]", + "9201", + "320128", + "i-D1", + "4134120", + "8536", + "[VE]", + ], ) # Test ages, age=-1 used for the ATT tokens self.assertListEqual( - transformed_record['ages'], - [-1, -1, -1, -1, 44, 44, 44, 44, -1, 44, 44, 44, -1, 44, 44, 44] + transformed_record["ages"], + 
[-1, -1, -1, -1, 44, 44, 44, 44, -1, 44, 44, 44, -1, 44, 44, 44], ) # Test dates, dates=0 used for the ATT tokens self.assertListEqual( - transformed_record['dates'], - [0, 0, 0, 0, 2832, 2832, 2832, 2832, 0, 2833, 2833, 2833, 0, 2833, 2833, 2833] + transformed_record["dates"], + [ + 0, + 0, + 0, + 0, + 2832, + 2832, + 2832, + 2832, + 0, + 2833, + 2833, + 2833, + 0, + 2833, + 2833, + 2833, + ], ) # Test visit_segments, visit_segment=0 used for the ATT tokens self.assertListEqual( - transformed_record['visit_segments'], - [0, 0, 0, 0, 1, 1, 1, 1, 0, 2, 2, 2, 2, 2, 2, 2] + transformed_record["visit_segments"], + [0, 0, 0, 0, 1, 1, 1, 1, 0, 2, 2, 2, 2, 2, 2, 2], ) # Test visit_concept_orders, we visit_concept_order to be same as next visit for the ATT tokens self.assertListEqual( - transformed_record['visit_concept_orders'], - [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2] + transformed_record["visit_concept_orders"], + [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2], ) # Test concept_value_masks self.assertListEqual( - transformed_record['concept_value_masks'], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0] + transformed_record["concept_value_masks"], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], ) # Test concept_values, concept_value=-1 is a default value associated with non-numeric measurements self.assertListEqual( - transformed_record['concept_values'], - [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0.5, -1, -1] + transformed_record["concept_values"], + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0.5, -1, -1], ) # Test mlm_skip_values self.assertListEqual( - transformed_record['mlm_skip_values'], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0] + transformed_record["mlm_skip_values"], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], ) def test_inpatient_att_transform(self): @@ -253,67 +245,69 @@ def test_inpatient_att_transform(self): att_function_type=AttType.CEHR_BERT.value, inpatient_att_function_type=AttType.DAY.value, include_auxiliary_token=True, - include_demographic_prompt=False + include_demographic_prompt=False, ) # Create an instance of the mapping class - mapper = MedToCehrBertDatasetMapping( - data_args - ) + mapper = MedToCehrBertDatasetMapping(data_args) transformed_record = mapper.transform(self.patient) # Assert - self.assertEqual(transformed_record['person_id'], 0) + self.assertEqual(transformed_record["person_id"], 0) # Test concept_ids self.assertListEqual( - transformed_record['concept_ids'], - ['[VS]', '9202', '320128', '[VE]', 'W1', '[VS]', '9201', '320128', 'i-D1', '4134120', '8536', '[VE]'] + transformed_record["concept_ids"], + [ + "[VS]", + "9202", + "320128", + "[VE]", + "W1", + "[VS]", + "9201", + "320128", + "i-D1", + "4134120", + "8536", + "[VE]", + ], ) # Test ages, age=-1 used for the ATT tokens - self.assertListEqual( - transformed_record['ages'], - [44, 44, 44, 44, -1, 44, 44, 44, -1, 44, 44, 44] - ) + self.assertListEqual(transformed_record["ages"], [44, 44, 44, 44, -1, 44, 44, 44, -1, 44, 44, 44]) # Test dates, dates=0 used for the ATT tokens self.assertListEqual( - transformed_record['dates'], - [2832, 2832, 2832, 2832, 0, 2833, 2833, 2833, 0, 2833, 2833, 2833] + transformed_record["dates"], + [2832, 2832, 2832, 2832, 0, 2833, 2833, 2833, 0, 2833, 2833, 2833], ) # Test visit_segments, visit_segment=0 used for the ATT tokens - self.assertListEqual( - transformed_record['visit_segments'], - [1, 1, 1, 1, 0, 2, 2, 2, 2, 2, 2, 2] - ) + self.assertListEqual(transformed_record["visit_segments"], [1, 1, 1, 1, 0, 2, 2, 2, 2, 
2, 2, 2]) # Test visit_concept_orders, we visit_concept_order to be same as next visit for the ATT tokens self.assertListEqual( - transformed_record['visit_concept_orders'], - [1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2] + transformed_record["visit_concept_orders"], + [1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2], ) # Test concept_value_masks self.assertListEqual( - transformed_record['concept_value_masks'], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0] + transformed_record["concept_value_masks"], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], ) # Test concept_values, concept_value=-1 is a default value associated with non-numeric measurements self.assertListEqual( - transformed_record['concept_values'], - [-1, -1, -1, -1, -1, -1, -1, -1, -1, 0.5, -1, -1] + transformed_record["concept_values"], + [-1, -1, -1, -1, -1, -1, -1, -1, -1, 0.5, -1, -1], ) # Test mlm_skip_values - self.assertListEqual( - transformed_record['mlm_skip_values'], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0] - ) + self.assertListEqual(transformed_record["mlm_skip_values"], [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/unit_tests/data_generators/hf_data_generator/hf_sort_patient_sequence_dataset_mapping_test.py b/tests/unit_tests/data_generators/hf_data_generator/hf_sort_patient_sequence_dataset_mapping_test.py index d3cafdd8..0d8e36c9 100644 --- a/tests/unit_tests/data_generators/hf_data_generator/hf_sort_patient_sequence_dataset_mapping_test.py +++ b/tests/unit_tests/data_generators/hf_data_generator/hf_sort_patient_sequence_dataset_mapping_test.py @@ -1,4 +1,5 @@ import unittest + from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import SortPatientSequenceMapping @@ -9,20 +10,20 @@ def test_transform_with_orders(self): # Mock data with 'orders' column as integers record = { - 'orders': [2, 1, 3], - 'concept_ids': ['c', 'b', 'a'], - 'values': [30, 20, 10], - 'ages': [30, 25, 40], - 'visit_concept_orders': [5, 3, 9] + "orders": [2, 1, 3], + "concept_ids": ["c", "b", "a"], + "values": [30, 20, 10], + "ages": [30, 25, 40], + "visit_concept_orders": [5, 3, 9], } # Expected output after sorting expected = { - 'concept_ids': ['b', 'c', 'a'], - 'values': [20, 30, 10], - 'ages': [25, 30, 40], - 'visit_concept_orders': [3, 5, 9], - 'orders': [1, 2, 3] + "concept_ids": ["b", "c", "a"], + "values": [20, 30, 10], + "ages": [25, 30, 40], + "visit_concept_orders": [3, 5, 9], + "orders": [1, 2, 3], } # Perform transformation @@ -37,20 +38,20 @@ def test_transform_with_dates(self): # Mock data with 'dates' column as integers record = { - 'dates': [20210301, 20210101, 20210201], - 'concept_ids': ['c', 'b', 'a'], - 'values': [30, 20, 10], - 'ages': [40, 25, 30], - 'visit_concept_orders': [5, 3, 9] + "dates": [20210301, 20210101, 20210201], + "concept_ids": ["c", "b", "a"], + "values": [30, 20, 10], + "ages": [40, 25, 30], + "visit_concept_orders": [5, 3, 9], } # Expected output after sorting based on dates expected = { - 'concept_ids': ['b', 'a', 'c'], - 'values': [20, 10, 30], - 'ages': [25, 30, 40], - 'visit_concept_orders': [3, 9, 5], - 'dates': [20210101, 20210201, 20210301] + "concept_ids": ["b", "a", "c"], + "values": [20, 10, 30], + "ages": [25, 30, 40], + "visit_concept_orders": [3, 9, 5], + "dates": [20210101, 20210201, 20210301], } # Perform transformation @@ -65,10 +66,10 @@ def test_transform_without_sorting_columns(self): # Mock data without 'orders' or 'dates' record = { - 'concept_ids': ['c', 'b', 'a'], - 'values': [30, 20, 10], - 'ages': [30, 25, 40], - 
'visit_concept_orders': [5, 3, 9]
+            "concept_ids": ["c", "b", "a"],
+            "values": [30, 20, 10],
+            "ages": [30, 25, 40],
+            "visit_concept_orders": [5, 3, 9],
         }
         # Expected output should be unchanged since no sorting column is provided
@@ -81,5 +82,5 @@ def test_transform_without_sorting_columns(self):
         self.assertEqual(result, expected)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
diff --git a/tests/unit_tests/data_generators/hf_data_generator/meds_utils_test.py b/tests/unit_tests/data_generators/hf_data_generator/meds_utils_test.py
index 9620ebad..89379ef0 100644
--- a/tests/unit_tests/data_generators/hf_data_generator/meds_utils_test.py
+++ b/tests/unit_tests/data_generators/hf_data_generator/meds_utils_test.py
@@ -1,7 +1,8 @@
 import unittest
+
 from cehrbert.data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules import MedsToBertMimic4
 from cehrbert.data_generators.hf_data_generator.meds_utils import get_meds_to_cehrbert_conversion_cls
-from cehrbert.runners.hf_runner_argument_dataclass import MedsToCehrBertConversionType, AttType
+from cehrbert.runners.hf_runner_argument_dataclass import AttType, MedsToCehrBertConversionType
 
 
 class TestGetMedsToCehrBertConversionCls(unittest.TestCase):
@@ -18,5 +19,5 @@ def test_invalid_conversion(self):
         self.assertIn("is not a valid MedsToCehrBertConversionType", str(context.exception))
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
diff --git a/tests/unit_tests/data_generators/visit_prediction_learning_objective_test.py b/tests/unit_tests/data_generators/visit_prediction_learning_objective_test.py
index 6c5c61c7..90372669 100644
--- a/tests/unit_tests/data_generators/visit_prediction_learning_objective_test.py
+++ b/tests/unit_tests/data_generators/visit_prediction_learning_objective_test.py
@@ -1,6 +1,8 @@
 import unittest
+
 import numpy as np
 import pandas as pd
+
 from cehrbert.data_generators.data_classes import RowSlicer
 from cehrbert.data_generators.learning_objective import VisitPredictionLearningObjective
 from cehrbert.data_generators.tokenizer import ConceptTokenizer
@@ -11,22 +13,21 @@ class TestVisitPredictionLearningObjective(unittest.TestCase):
     def setUp(self):
         self.visit_tokenizer = ConceptTokenizer()  # Use a real or mock ConceptTokenizer as needed
         self.max_seq_len = 5
-        self.learning_obj = VisitPredictionLearningObjective(
-            self.visit_tokenizer,
-            self.max_seq_len
-        )
+        self.learning_obj = VisitPredictionLearningObjective(self.visit_tokenizer, self.max_seq_len)
 
     @staticmethod
     def create_mock_row():
         # Create a mock row with 5 elements in each list
         return RowSlicer(
-            row=pd.Series({
-                'visit_token_ids': [101, 102, 103],  # Example token IDs
-                'visit_concept_orders': [1, 2, 3]  # Example orders for sorting
-            }),
+            row=pd.Series(
+                {
+                    "visit_token_ids": [101, 102, 103],  # Example token IDs
+                    "visit_concept_orders": [1, 2, 3],  # Example orders for sorting
+                }
+            ),
             start_index=0,
             end_index=3,  # Updated to include all 5 elements
-            target_index=2  # Adjusted target index for demonstration
+            target_index=2,  # Adjusted target index for demonstration
         )
 
     def test_initialization(self):
@@ -35,24 +36,22 @@ def test_initialization(self):
 
 
     def test_get_tf_dataset_schema(self):
         input_schema, output_schema = self.learning_obj.get_tf_dataset_schema()
-        self.assertIn('masked_visit_concepts', input_schema)
-        self.assertIn('mask_visit', input_schema)
-        self.assertIn('visit_predictions', output_schema)
+        self.assertIn("masked_visit_concepts", input_schema)
+        self.assertIn("mask_visit", input_schema)
+        self.assertIn("visit_predictions", output_schema)
 
     def test_process_batch(self):
         # Test the process_batch method with a mock input
         mock_rows = [self.create_mock_row() for _ in range(5)]  # Create a list of mock rows
         input_dict, output_dict = self.learning_obj.process_batch(mock_rows)
-        self.assertIn('masked_visit_concepts', input_dict)
-        self.assertIn('mask_visit', input_dict)
-        self.assertIn('visit_predictions', output_dict)
+        self.assertIn("masked_visit_concepts", input_dict)
+        self.assertIn("mask_visit", input_dict)
+        self.assertIn("visit_predictions", output_dict)
 
         # Test the concept mask, where 1 indicates attention and 0 indicates mask
-        self.assertTrue(
-            (input_dict['mask_visit'][0] == np.asarray([1, 1, 1, 0, 0])).all()
-        )
+        self.assertTrue((input_dict["mask_visit"][0] == np.asarray([1, 1, 1, 0, 0])).all())
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
diff --git a/tests/unit_tests/models/concept_value_decoder_layer_test.py b/tests/unit_tests/models/concept_value_decoder_layer_test.py
index 4e134d29..dddd14a9 100644
--- a/tests/unit_tests/models/concept_value_decoder_layer_test.py
+++ b/tests/unit_tests/models/concept_value_decoder_layer_test.py
@@ -1,25 +1,27 @@
+import unittest
+
 import tensorflow as tf
+
 from cehrbert.models.layers.custom_layers import ConceptValuePredictionLayer
-import unittest
 
 
 class TestConceptValuePredictionLayer(unittest.TestCase):
 
     def test_layer_initialization(self):
-        """ Test if the layer initializes with the correct embedding size. """
+        """Test if the layer initializes with the correct embedding size."""
         embedding_size = 64
         layer = ConceptValuePredictionLayer(embedding_size)
         self.assertEqual(layer.embedding_size, embedding_size)
 
     def test_get_config(self):
-        """ Test if the get_config method returns the correct configuration. """
+        """Test if the get_config method returns the correct configuration."""
        embedding_size = 64
         layer = ConceptValuePredictionLayer(embedding_size)
         config = layer.get_config()
-        self.assertEqual(config['embedding_size'], embedding_size)
+        self.assertEqual(config["embedding_size"], embedding_size)
 
     def test_call(self):
-        """ Test the call method of the layer. """
+        """Test the call method of the layer."""
         embedding_size = 64
         layer = ConceptValuePredictionLayer(embedding_size)
 
@@ -37,5 +39,5 @@ def test_call(self):
         self.assertEqual(concept_vals.shape, (batch_size, context_window, 1))
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
diff --git a/tests/unit_tests/models/hf_models/hf_cehrbert_test.py b/tests/unit_tests/models/hf_models/hf_cehrbert_test.py
index ab955288..8b3e49a9 100644
--- a/tests/unit_tests/models/hf_models/hf_cehrbert_test.py
+++ b/tests/unit_tests/models/hf_models/hf_cehrbert_test.py
@@ -1,5 +1,7 @@
 import unittest
+
 import torch
+
 from cehrbert.models.hf_models.config import CehrBertConfig
 from cehrbert.models.hf_models.hf_cehrbert import CehrBertForPreTraining
 
@@ -69,36 +71,24 @@ def test_model_output(self):
             concept_value_masks=concept_value_masks,
             visit_segments=visit_segments,
             labels=input_ids,
-            mlm_skip_values=mlm_skip_values
+            mlm_skip_values=mlm_skip_values,
         )
 
-        self.assertTrue(hasattr(output, 'loss'))
-        self.assertTrue(hasattr(output, 'last_hidden_state'))
-        self.assertTrue(hasattr(output, 'attentions'))
-        self.assertTrue(hasattr(output, 'prediction_logits'))
-        self.assertTrue(hasattr(output, 'pooler_output'))
+        self.assertTrue(hasattr(output, "loss"))
+        self.assertTrue(hasattr(output, "last_hidden_state"))
+        self.assertTrue(hasattr(output, "attentions"))
+        self.assertTrue(hasattr(output, "prediction_logits"))
+        self.assertTrue(hasattr(output, "pooler_output"))
 
-        self.assertEqual(
-            output.prediction_logits.shape,
-            torch.Size([1, 10, self.config.vocab_size])
-        )
-        self.assertEqual(
-            output.pooler_output.shape,
-            torch.Size([1, 128])
-        )
-        self.assertEqual(
-            output.last_hidden_state.shape,
-            torch.Size([1, 10, self.config.hidden_size])
-        )
-        self.assertEqual(
-            len(output.attentions),
-            self.config.num_hidden_layers
-        )
+        self.assertEqual(output.prediction_logits.shape, torch.Size([1, 10, self.config.vocab_size]))
+        self.assertEqual(output.pooler_output.shape, torch.Size([1, 128]))
+        self.assertEqual(output.last_hidden_state.shape, torch.Size([1, 10, self.config.hidden_size]))
+        self.assertEqual(len(output.attentions), self.config.num_hidden_layers)
         self.assertEqual(
             output.attentions[0].shape,
-            torch.Size([1, self.config.num_attention_heads, 10, 10])
+            torch.Size([1, self.config.num_attention_heads, 10, 10]),
         )
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
diff --git a/tests/unit_tests/models/hf_models/tokenization_hf_cehrbert_test.py b/tests/unit_tests/models/hf_models/tokenization_hf_cehrbert_test.py
index ad0d813a..e8badadb 100644
--- a/tests/unit_tests/models/hf_models/tokenization_hf_cehrbert_test.py
+++ b/tests/unit_tests/models/hf_models/tokenization_hf_cehrbert_test.py
@@ -1,9 +1,16 @@
 import unittest
+
 from tokenizers import Tokenizer
 from tokenizers.models import WordLevel
 from tokenizers.pre_tokenizers import Whitespace
+
 from cehrbert.models.hf_models.tokenization_hf_cehrbert import (
-    CehrBertTokenizer, PAD_TOKEN, MASK_TOKEN, OUT_OF_VOCABULARY_TOKEN, UNUSED_TOKEN, CLS_TOKEN
+    CLS_TOKEN,
+    MASK_TOKEN,
+    OUT_OF_VOCABULARY_TOKEN,
+    PAD_TOKEN,
+    UNUSED_TOKEN,
+    CehrBertTokenizer,
 )
 
 
@@ -19,15 +26,12 @@ def setUpClass(cls):
             UNUSED_TOKEN: 3,
             CLS_TOKEN: 4,
             "hello": 5,
-            "world": 6
+            "world": 6,
         }
 
         tokenizer = Tokenizer(WordLevel(unk_token=OUT_OF_VOCABULARY_TOKEN, vocab=vocab))
         tokenizer.pre_tokenizer = Whitespace()
-        concept_mapping = {
-            "hello": "Hello",
-            "world": "World"
-        }
+        concept_mapping = {"hello": "Hello", "world": "World"}
         cls.tokenizer = CehrBertTokenizer(tokenizer, lab_stats=[], concept_name_mapping=concept_mapping)
 
     def test_vocab_size(self):
@@ -52,13 +56,13 @@ def test_convert_tokens_to_string(self):
     def test_oov_token(self):
         # Test the encoding of an out-of-vocabulary token
         encoded = self.tokenizer.encode(["nonexistent"])
-        self.assertEqual(encoded, [self.tokenizer._oov_token_index])
+        self.assertEqual(encoded, [self.tokenizer.oov_token_index])
 
     def test_convert_id_to_token_oov(self):
         # Test decoding an out-of-vocabulary token ID
-        decoded = self.tokenizer._convert_id_to_token(99)  # Assuming 99 is not in the index
+        decoded = self.tokenizer.convert_id_to_token(99)  # Assuming 99 is not in the index
         self.assertEqual(decoded, OUT_OF_VOCABULARY_TOKEN)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
diff --git a/validation.py b/validation.py
deleted file mode 100755
index 2fddc5e5..00000000
--- a/validation.py
+++ /dev/null
@@ -1,152 +0,0 @@
-import re
-
-
-test_sequence = {'concept_ids': ['[START]', 'year:2005', 'age:20', '8532', '0', 'VS', '9203', '200219', '2007052', 'VE', 'W1', 'VS', '9203', '2006976', 'VE', 'M3', 'VS', '9203', '140214', '4063814', 'VE', 'M6', 'VS', '9203', '2007052', '4047791', '440029', 'VE', 'W0', 'VS', '9202', '140214', 'VE']},{'concept_ids': ['[START]', 'year:2010', 'age:77', '8507', '0', 'VS', '9202', '2414392', '4150062', 'VE', 'W0', 'VS', '9202', '2211481', '4150062', 'VE', 'W0', 'VS', '9202', '2211480', '2211481', '4150062', 'VE', 'LT', 'VS', '9202', '2211361', '77670', 'VE', 'LT', 'VS', '9202', '2211361', '77670', 'VE']},{'concept_ids': ['[START]', 'year:2019', 'age:55', '8507', '0', 'VS', '5083', '320128', '432867', '433736', '437827', '764123', '2414397', '320128', '437827', '764123', 'VE', 'LT', 'VS', '581477', '2313814', '2414398', '320128', '432867', '437827', '764123', 'VE', 'M2', 'VS', '581477', '2108115', '2514528', '320128', '4193704', '437827', '764123', 'VE', 'W0', 'VS', '581477', '1307046', '2313814', '2414398', '320128', '4193704', '437827', '764123', 'VE', 'M4', 'VS', '581477', '2313814', '2414398', '320128', '4193704', '437827', '764123', 'VE', 'W2', 'VS', '581477', '2414397', 'VE', 'W3', 'VS', '581477', '2108115', '2414397', '320128', '4193704', '437827', '77670', 'VE']},{'concept_ids': ['[START]', 'year:2008', 'age:79', '8507', '0', 'VS', '9202', '2313634', '2313654', '375545', '439297', 'VE', 'M4', 'VS', '9202', '2313636', '2313655', '375545', '439297', 'VE', 'LT', 'VS', '9202', '2313635', '2313655', '375545', '380103', '439297', 'VE']},{'concept_ids': ['[START]', 'year:1998', 'age:2', '8507', '0', 'VS', '9203', '372328', '42738971', 'VE', 'LT', 'VS', '9203', '25297', 'VE', 'LT', 'VS', '9203', '2514435', '25297', '440029', 'VE', 'LT', 'VS', '9203', '2514435', '2514437', '378253', 'VE', 'W0', 'VS', '9203', '2514435', '378253', '378253', '378253', 'VE', 'LT', 'VS', '9203', '2514434', '2514435', '28060', 'VE', 'W0', 'VS', '9203', '2514435', '25297', '372328', '380733', 'VE', 'LT', 'VS', '9203', '2514435', '254761', '257011', '440029', 'VE']},{'concept_ids': ['[START]', 'year:1991', 'age:46', '8507', '8552', 'VS', '9203', '77139', 'VE', 'LT', 'VS', '9203', '77670', 'VE', 'W0', 'VS', '9201', '321318', '4142645', '4195852', '4205879', '438791', '320128', '321318', '42537729', 'VE', 'W0', 'VS', '9203', '320128', 'VE', 'LT', 'VS', '9202', '2414393', '320128', 'VE', 'W3', 'VS', '9202', '2414397', '320128', '77670', 'VE']},{'concept_ids': ['[START]', 'year:2002', 'age:42', '8532', '0', 'VS', '9202', '4306780', 'VE', 'LT', 'VS', '9202', '4306780', 'VE', 'LT',
'VS', '9202', '45773385', 'VE']},{'concept_ids': ['[START]', 'year:1995', 'age:0', '8507', '8527', 'VS', '9201', '4014296', '4014296', 'VE', 'LT', 'VS', '9202', '2414392', '4088016', 'VE']},{'concept_ids': ['[START]', 'year:2003', 'age:45', '8532', '0', 'VS', '9202', '4149084', 'VE', 'M6', 'VS', '9202', '4306780', 'VE', 'M2', 'VS', '9202', '4172857', 'VE', 'LT', 'VS', '9202', '45773385', 'VE', 'M6', 'VS', '9202', '45773385', 'VE', 'LT', 'VS', '9202', '2213244', '45773385', 'VE', 'LT', 'VS', '9202', '2213244', '45773385', 'VE', 'M2', 'VS', '9202', '2211826', '4295261', '77646', 'VE', 'LT', 'VS', '9202', '2211826', '4295261', '80502', 'VE', 'W1', 'VS', '9202', '2211826', '80824', 'VE', 'LT', 'VS', '9202', '2211826', '2211828', '80502', 'VE', 'LT', 'VS', '9202', '2211809', '42627987', 'VE', 'M11', 'VS', '581477', '2313814', '2414392', '312437', '315078', '77670', 'VE']},{'concept_ids': ['[START]', 'year:2022', 'age:34', '8507', '8516', 'VS', '5083', '2414393', '442588', 'VE', 'M1', 'VS', '5083', '2414392', '442588', 'VE', 'W3', 'VS', '5083', '2414397', '442588', 'VE', 'M1', 'VS', '5083', '2414398', '313459', '442588', 'VE']},{'concept_ids': ['[START]', 'year:1988', 'age:2', '8532', '8516', 'VS', '0', '257011', 'VE', 'M4', 'VS', '0', '257011', 'VE']},{'concept_ids': ['[START]', 'year:1998', 'age:17', '8507', '0', 'VS', '9202', '200219', 'VE', 'LT', 'VS', '9202', '137275', 'VE', 'W0', 'VS', '9202', '137275', 'VE']},{'concept_ids': ['[START]', 'year:2015', 'age:62', '8532', '8527', 'VS', '9202', '1332418', '1501700', '2313814', '2414393', '320128', '77670', 'VE', 'W3', 'VS', '9202', '2313819', '2313869', '2414397', '77670', 'VE', 'W0', 'VS', '9202', '2313869', '320128', '77670', 'VE']},{'concept_ids': ['[START]', 'year:2006', 'age:0', '8532', '0', 'VS', '9201', '42739011', '2007893', 'VE', 'W0', 'VS', '9202', '4088016', 'VE', 'LT', 'VS', '9203', '140214', '2514434', '2514435', 'VE']},{'concept_ids': ['[START]', 'year:1988', 'age:1', '8507', '0', 'VS', '0', '257011', 'VE', 'M5', 'VS', '0', '317009', 'VE']},{'concept_ids': ['[START]', 'year:2019', 'age:16', '8507', '0', 'VS', '9202', '2414394', '4254485', '438409', '442077', '705944', '705944', 'VE', 'W2', 'VS', '9202', '2314103', '377091', 'VE', 'W0', 'VS', '9202', '2414398', '4149904', '438409', '442077', 'VE']},{'concept_ids': ['[START]', 'year:2020', 'age:17', '8507', '0', 'VS', '581477', '2108115', '2414398', '4267558', '440076', 'VE', 'M4', 'VS', '581477', '2108115', '2514399', '4267558', '440076', 'VE']},{'concept_ids': ['[START]', 'year:2002', 'age:33', '8532', '0', 'VS', '9203', '42738971', '44784105', '81151', 'VE', 'W3', 'VS', '9203', '42738972', '77139', 'VE']},{'concept_ids': ['[START]', 'year:2021', 'age:40', '8532', '0', 'VS', '9203', '2514435', 'VE', 'LT', 'VS', '581477', '2213244', '2514520', 'VE']},{'concept_ids': ['[START]', 'year:1996', 'age:7', '8532', '8516', 'VS', '9203', '257011', '42738972', 'VE', 'LT', 'VS', '9202', '42738972', '257011', 'VE', 'LT', 'VS', '9202', '2006977', '4036803', '4036803', 'VE', 'W0', 'VS', '9203', '2006977', '4043371', 'VE', 'W0', 'VS', '9202', '2006977', '4088016', 'VE', 'LT', 'VS', '9202', '4135174', '4210151', 'VE', 'W0', 'VS', '9202', '4135174', 'VE', 'LT', 'VS', '9203', '2514435', '372328', '378253', 'VE']}, {'concept_ids': ['[START]', 'year:1996', 'age:7', '8532', '8516', 'VS', '9203', '257011', '42738972', 'VE', 'LT', 'VS', '9202', '42738972', '257011', 'VE', 'LT', 'VS', '9202', '2006977', '4036803', '4036803', 'VE', 'W0', 'VS', '9203', '2006977', '4043371', 'VE', 'W0', 'VS', '1234', 'VE', 
'LT', 'VS', '9202', '4135174', '4210151', 'VE', 'W0', 'VS', '9202', '4135174', 'VE', 'LT', 'VS', '9203', '2514435', '372328', '378253', 'VE']} - -GENDER = ['0', '8532', '8507'] - -RACE = ['0', -'38003577', -'38003595', -'38003604', -'38003607', -'38003610', -'38003609', -'38003602', -'38003576', -'38003583', -'38003605', -'38003615', -'38003606', -'38003574', -'38003596', -'38003599', -'8515', -'38003586', -'38003585', -'38003594', -'38003573', -'38003593', -'38003600', -'38003584', -'38003581', -'38003597', -'38003591', -'8657', -'38003580', -'38003579', -'38003613', -'8516', -'38003612', -'38003598', -'38003614', -'38003603', -'38003582', -'38003616', -'38003575', -'38003608', -'38003601', -'38003589', -'8557', -'38003587', -'38003588', -'38003592', -'38003611', -'8527', -'38003578', -'38003572', -'38003590' -] - -VISITS = ['0', '42898160', '38004283', '38004440', '38004220', '38004229', '38004213', '38004209', '38004367', '38004331', '38003820', '38004344', '38004254', '38004234', '38004266', '38004361', '38004330', '8782', '38004332', '38004204', '9202', '8883', '32760', '581475', '38004291', '8913', '38004277', '581383', '8947', '38004327', '38004307', '32037', '38004453', '8809', '38004239', '38004262', '38004293', '38004263', '38004193', '38004237', '38004218', '38004290', '38004342', '8827', '38004280', '38004362', '581385', '8966', '38004285', '38004334', '5084', '38004338', '38004360', '8957', '38004197', '38004356', '38004329', '32036', '38004326', '38004231', '32261', '38004351', '581478', '38004515', '38004246', '8676', '38004074', '8971', '38004324', '38004207', '38004444', '581381', '8546', '38004343', '8761', '38004364', '38004345', '38004311', '38004349', '38004352', '38004287', '38004284', '581477', '38004249', '38004205', '8964', '38004366', '38004354', '8941', '8974', '38004202', '38004335', '38004702', '38004210', '38004238', '8949', '38004442', '32693', '38004226', '38004697', '8870', '38004242', '8650', '38004521', '8672', '38004247', '8668', '38004689', '33004', '38004225', '38004346', '38004264', '38004691', '8968', '8905', '38004282', '581379', '32254', '38004323', '38004353', '32276', '32761', '38004677', '38004267', '581384', '38004314', '8951', '38004196', '38004215', '38004296', '38004328', '38004316', '38004250', '38004368', '38004227', '38004526', '38004222', '38004325', '38004275', '8960', '8716', '38004365', '38003619', '5083', '38004235', '38004278', '38004121', '38004274', '38004276', '38004703', '38004236', '38004228', '38004269', '38004217', '8976', '8858', '38004315', '38004256', '38004687', '8851', '33007', '32253', '38004303', '38004678', '38004336', '38004201', '581479', '38004248', '38004294', '38004340', '38004305', '38004523', '38004233', '38004317', '38004198', '38004519', '38004240', '38004363', '38003620', '38004321', '38004288', '38004348', '38004337', '262', '38004244', '38004223', '38004318', '38004698', '8920', '8602', '38004261', '38004268', '38004696', '38004680', '38004206', '38004700', '8850', '38004350', '38004322', '38004253', '38004259', '38004302', '8615', '38004333', '38004260', '38004306', '38004232', '8969', '38004682', '38003809', '38004194', '38004245', '38004693', '8977', '38003793', '38004339', '38004219', '8537', '38004683', '38004286', '38004358', '38004257', '9201', '38004690', '8584', '8756', '38004243', '9203', '38004310', '38004211', '38004270', '38004522', '38004279', '38004216', '581476', '8863', '8882', '581458', '38004341', '38004681', '38004443', '38004192', '38004241', '38004525', '38004281', '38004441', 
'38004203', '38004347', '38004359', '38004252', '38003821', '38004208', '38004295', '38004357', '32759', '38004258', '8717', '581380', '38004251', '38004199', '705159']
-
-RESULTS = []
-
-sequences = [x['concept_ids'] for x in test_sequence]
-
-
-def check_ATT(token):
-    if token == 'LT':
-        return True
-    elif token.startswith('W'):
-        weeks = int(token[1:])
-        if weeks < 0 or weeks > 3:
-            return False
-        else:
-            return True
-
-    elif token.startswith('M'):
-        months = int(token[1:])
-        if months < 1 or months > 11:
-            return False
-        else:
-            return True
-
-    else:
-        return False
-
-
-
-def valid_visit(i, seq):
-
-    if seq[i] not in VISITS or seq[i+1] == 'VE':
-        return -1
-
-    while seq[i] != 'VE':
-        i += 1
-
-    if seq[i] == 'VE' and i != len(seq) - 1:
-        i += 1
-        time_token = check_ATT(seq[i])
-        if time_token:
-            return i
-        else:
-            return -1
-
-    elif seq[i] == 'VE' and i == len(seq) - 1:
-        return i
-
-    else:
-        return -1
-
-def main():
-    for seq in sequences:
-        valid = 1
-        i = 0
-        if seq[0] == '[START]':
-            i += 1
-        if seq[1].startswith('year:') and int(seq[1][5:]) >= 0:
-            i += 1
-        if seq[2].startswith('age:') and int(seq[2][4:]) >= 0:
-            i += 1
-        if seq[3] in GENDER:
-            i += 1
-        if seq[4] in RACE:
-            i += 1
-
-        while i < len(seq):
-            # iterate through the rest of the visit sequence
-            if seq[i] == 'VS':
-                i += 1
-                valid_index = valid_visit(i, seq)
-                if valid_index != -1:
-                    i = valid_index
-
-                else:
-                    valid = 0
-                    break
-
-            i += 1
-
-        RESULTS.append(valid)
-    return RESULTS
-
-
-
-if __name__ == '__main__':
-    results = main()
-    print(results)
-
-
-
-