
Commit

created integration tests for generating pretraining data (generate_training_data) and finetuning data (hf_admission)
ChaoPang committed Sep 8, 2024
1 parent 1150a2a commit 48a293a
Showing 61 changed files with 43 additions and 51 deletions.
38 changes: 19 additions & 19 deletions .github/workflows/tests.yml
@@ -18,22 +18,22 @@ jobs:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v3
      - name: Set up Python 3.10.0
        uses: actions/setup-python@v3
        with:
          python-version: "3.10"
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install flake8 pytest
          pip install -e .
      - name: Lint with flake8
        run: |
          # stop the build if there are Python syntax errors or undefined names
          flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
          # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
          flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
      - name: Test with pytest
        run: |
          PYTHONPATH=./: pytest
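
The last step can be reproduced locally; a minimal Python sketch (an illustration, not part of the repository) that mirrors the workflow's PYTHONPATH=./: pytest step by putting the repository root on sys.path before handing control to pytest:

    # Local equivalent of the "Test with pytest" step above; run from the
    # repository root with pytest installed.
    import sys

    import pytest

    if __name__ == "__main__":
        sys.path.insert(0, ".")    # stands in for PYTHONPATH=./: in the workflow
        sys.exit(pytest.main([]))  # run the whole suite and propagate the exit code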
1 change: 1 addition & 0 deletions .gitignore
@@ -3,3 +3,4 @@ build/
 .eggs/
 *.egg-info/
 *__pycache__/
+*venv*
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -29,7 +29,7 @@ dependencies = [
     "numpy==1.24.3",
     "packaging==23.2",
     "pandas==2.2.0",
-    "pyspark==3.2.2"
+    "pyspark==3.1.2"
 ]

 [tool.setuptools_scm]
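
Since PySpark is pinned to an exact version here, a quick way to confirm an environment actually picked up the pin is to compare versions at runtime; a small sketch, not part of the repository:

    # Sanity check: the installed PySpark should match the pin in pyproject.toml.
    import pyspark

    expected = "3.1.2"  # version pinned by this commit
    assert pyspark.__version__ == expected, (
        f"expected pyspark {expected}, found {pyspark.__version__}"
    )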
Binary and empty sample-data files added under sample_data/omop_sample/ (including the concept and person tables, e.g. concept/._SUCCESS.crc and person/._SUCCESS.crc); binary contents not shown.
1 change: 1 addition & 0 deletions src/cehrbert_data/utils/spark_utils.py
@@ -856,6 +856,7 @@ def extract_ehr_records(
         patient_ehr_records["person_id"],
         patient_ehr_records["standard_concept_id"],
         patient_ehr_records["date"],
+        patient_ehr_records["datetime"],
         patient_ehr_records["visit_occurrence_id"],
         patient_ehr_records["domain"],
         visit_occurrence["visit_concept_id"],
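
For context, the hunk above adds the event datetime to a column projection over joined patient-event and visit tables. A minimal PySpark sketch of that pattern follows; only the column and table names come from the diff, while the input paths, the join condition, and the surrounding code are assumptions:

    # Sketch of the projection extended above; the real code is in
    # src/cehrbert_data/utils/spark_utils.py around line 856.
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.appName("extract_ehr_records_sketch").getOrCreate()
    patient_ehr_records = spark.read.parquet("patient_ehr_records")  # hypothetical input
    visit_occurrence = spark.read.parquet("visit_occurrence")        # hypothetical input

    records = patient_ehr_records.join(
        visit_occurrence,
        patient_ehr_records["visit_occurrence_id"] == visit_occurrence["visit_occurrence_id"],
    ).select(
        patient_ehr_records["person_id"],
        patient_ehr_records["standard_concept_id"],
        patient_ehr_records["date"],
        patient_ehr_records["datetime"],  # the column this commit adds
        patient_ehr_records["visit_occurrence_id"],
        patient_ehr_records["domain"],
        visit_occurrence["visit_concept_id"],
    )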
36 changes: 14 additions & 22 deletions tests/integration_tests/test_generate_training_data.py
@@ -1,32 +1,24 @@
-import sys
 import unittest
 from ..pyspark_test_base import PySparkAbstract
-from cehrbert_data.spark_parse_args import create_spark_args
-from cehrbert_data.prediction_cohorts.hf_readmission import main
+from cehrbert_data.decorators.patient_event_decorator import AttType
+from cehrbert_data.apps.generate_training_data import main


 class HfReadmissionTest(PySparkAbstract):

     def test_run_pyspark_app(self):
-        sys.argv = [
-            "hf_readmission.py",
-            "--cohort_name", "hf_readmission",
-            "--input_folder", self.get_sample_data_folder(),
-            "--output_folder", self.get_output_folder(),
-            "--date_lower_bound", "1985-01-01",
-            "--date_upper_bound", "2023-12-31",
-            "--age_lower_bound", "18",
-            "--age_upper_bound", "100",
-            "--observation_window", "360",
-            "--prediction_start_days", "0",
-            "--prediction_window", "30",
-            "--include_visit_type",
-            "--is_new_patient_representation",
-            "--att_type", "cehr_bert",
-            "--ehr_table_list", "condition_occurrence", "procedure_occurrence", "drug_exposure"
-        ]
-
-        main(create_spark_args())
+        main(
+            input_folder=self.get_sample_data_folder(),
+            output_folder=self.get_output_folder(),
+            domain_table_list=["condition_occurrence", "drug_exposure", "procedure_occurrence"],
+            date_filter="1985-01-01",
+            include_visit_type=True,
+            is_new_patient_representation=True,
+            include_concept_list=False,
+            gpt_patient_sequence=True,
+            apply_age_filter=True,
+            att_type=AttType.DAY
+        )


 if __name__ == "__main__":
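
Presumably the point of the rewrite: the old test drove the unrelated hf_readmission CLI through sys.argv, whereas the new one calls generate_training_data's main() directly with keyword arguments, so this integration test no longer depends on argparse wiring.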
8 changes: 4 additions & 4 deletions tests/integration_tests/test_hf_readmission.py
@@ -1,13 +1,13 @@
 import sys
 import unittest
-from ..pyspark_test import PySparkAbstract
-from cehrbert_data.spark_parse_args import create_spark_args
+from ..pyspark_test_base import PySparkAbstract
+from cehrbert_data.utils.spark_parse_args import create_spark_args
 from cehrbert_data.prediction_cohorts.hf_readmission import main


 class HfReadmissionTest(PySparkAbstract):

-    def run_pyspark_app_test(self):
+    def test_run_pyspark_app(self):
         sys.argv = [
             "hf_readmission.py",
             "--cohort_name", "hf_readmission",
@@ -23,7 +23,7 @@ def run_pyspark_app_test(self):
             "--include_visit_type",
             "--is_new_patient_representation",
             "--att_type", "cehr_bert",
-            "--ehr_table_list", "condition_occurrence procedure_occurrence drug_exposure"
+            "--ehr_table_list", "condition_occurrence", "procedure_occurrence", "drug_exposure"
         ]

         main(create_spark_args())
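
The ehr_table_list change in the second hunk is more than cosmetic. Assuming the option is declared with nargs (the actual declaration in cehrbert_data.utils.spark_parse_args is not shown here), separate argv entries parse into a list of table names, while the old space-joined string would arrive as a single element; a small self-contained illustration:

    # Illustration of why the argv fix matters; the nargs="+" declaration is
    # an assumption about how --ehr_table_list is defined.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--ehr_table_list", nargs="+")

    fixed = parser.parse_args(
        ["--ehr_table_list", "condition_occurrence", "procedure_occurrence", "drug_exposure"]
    )
    broken = parser.parse_args(
        ["--ehr_table_list", "condition_occurrence procedure_occurrence drug_exposure"]
    )
    print(fixed.ehr_table_list)   # ['condition_occurrence', 'procedure_occurrence', 'drug_exposure']
    print(broken.ehr_table_list)  # ['condition_occurrence procedure_occurrence drug_exposure']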
8 changes: 3 additions & 5 deletions tests/pyspark_test_base.py
@@ -3,12 +3,10 @@
 import unittest
 import tempfile
 from pathlib import Path
-from abc import abstractmethod
-from cehrbert_data.spark_parse_args import create_spark_args
-from cehrbert_data.prediction_cohorts.hf_readmission import main
+from abc import abstractmethod, ABC


-class PySparkAbstract(unittest.TestCase):
+class PySparkAbstract(unittest.TestCase, ABC):

     @classmethod
     def setUpClass(cls):
@@ -34,7 +32,7 @@ def setUp(self):

     @abstractmethod
     def test_run_pyspark_app(self):
-        raise NotImplementedError("Not implemented yet")
+        pass

     def get_sample_data_folder(self):
         return self.data_folder
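
With the base class now abstract, each integration test only has to implement test_run_pyspark_app and can rely on the folder helpers shown above. A hypothetical minimal subclass, for illustration only:

    # Hypothetical concrete test; real ones live under tests/integration_tests/.
    # Only the base class and helper names come from the diff above.
    import unittest

    from tests.pyspark_test_base import PySparkAbstract


    class MyAppTest(PySparkAbstract):

        def test_run_pyspark_app(self):
            input_folder = self.get_sample_data_folder()  # provided by the base class
            output_folder = self.get_output_folder()      # provided by the base class
            # ... run the PySpark application under test with these folders ...
            self.assertIsNotNone(output_folder)


    if __name__ == "__main__":
        unittest.main()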
