From 334226aac5ab5c4393ecf60a1f4d9be265c3dda9 Mon Sep 17 00:00:00 2001 From: Zijian Wu Date: Mon, 7 Aug 2023 05:02:52 +0000 Subject: [PATCH] new dataset and new task --- pyhealth/datasets/base_dataset_v2.py | 8 +- pyhealth/datasets/chexpert_v1.py | 106 ++++++++++++++++++ pyhealth/datasets/covid19_cxr.py | 10 +- pyhealth/datasets/sample_dataset_v2.py | 2 +- pyhealth/tasks/chexpert_v1_classification.py | 20 ++++ pyhealth/tasks/covid19_cxr_classification.py | 2 +- .../medical_transcriptions_classification.py | 2 +- 7 files changed, 140 insertions(+), 10 deletions(-) create mode 100644 pyhealth/datasets/chexpert_v1.py create mode 100644 pyhealth/tasks/chexpert_v1_classification.py diff --git a/pyhealth/datasets/base_dataset_v2.py b/pyhealth/datasets/base_dataset_v2.py index 5cbf007e..8b550311 100644 --- a/pyhealth/datasets/base_dataset_v2.py +++ b/pyhealth/datasets/base_dataset_v2.py @@ -3,9 +3,11 @@ from typing import Optional, Dict from tqdm import tqdm - -from pyhealth.datasets.sample_dataset_v2 import SampleDataset -from pyhealth.tasks.task_template import TaskTemplate +import sys +sys.path.append('.') +from sample_dataset_v2 import SampleDataset# from pyhealth.datasets.sample_dataset_v2 import SampleDataset +sys.path.append('..') +from tasks.task_template import TaskTemplate #from pyhealth.tasks.task_template import TaskTemplate logger = logging.getLogger(__name__) diff --git a/pyhealth/datasets/chexpert_v1.py b/pyhealth/datasets/chexpert_v1.py new file mode 100644 index 00000000..d1917b62 --- /dev/null +++ b/pyhealth/datasets/chexpert_v1.py @@ -0,0 +1,106 @@ +import os +from collections import Counter +import pandas as pd +from tqdm import tqdm + +from base_dataset_v2 import BaseDataset# from pyhealth.datasets.base_dataset_v2 import BaseDataset +from tasks.chexpert_v1_classification import CheXpertV1Classification + +class CheXpertV1Dataset(BaseDataset): + """Base image dataset for CheXpert Database + + Dataset is available at 
https://stanfordaimi.azurewebsites.net/datasets/8cbd9ed4-2eb9-4565-affc-111cf4f7ebe2 + + **CheXpert v1 data:** + ----------------------- + - Train: 223414 images from 64540 patients + - Validation: 902 images from 700 patients + + The CheXpert dataset consists of 14 labeled observations (pathology): + - No Finding, Enlarged Cardiomediastinum, Cardiomegaly, Lung Opacity, Lung Lesion, Edema, Consolidation, Pneumonia, + Atelectasis, Pneumothorax, Pleural Effusion, Pleural Other, Fracture, Support Devices + For each observation (pathology), there are 4 possible statuses: + - positive (1), negative (0), uncertain (-1), unmentioned (2) + + Please cite the following article if you are using this dataset: + - Irvin, J., Rajpurkar, P., Ko, M., Yu, Y., Ciurea-Ilcus, S., Chute, C., Marklund, H., Haghgoo, B., Ball, R., + Shpanskaya, K. and Seekins, J., 2019, July. Chexpert: A large chest radiograph dataset with uncertainty labels + and expert comparison. In Proceedings of the AAAI conference on artificial intelligence (Vol. 33, No. 01, pp. 590-597). + + Args: + dataset_name: name of the dataset. + root: root directory of the raw data (The parent directory of /CheXpert-v1.0). *You can choose to use the path to Cassette portion or the Telemetry portion.* + dev: whether to enable dev mode (only use a small subset of the data). + Default is False. + refresh_cache: whether to refresh the cache; if true, the dataset will + be processed from scratch and the cache will be updated. Default is False. + + Attributes: + root: root directory of the raw data (should contain many csv files). + dataset_name: name of the dataset. Default is the name of the class. + dev: whether to enable dev mode (only use a small subset of the data). + Default is False. + refresh_cache: whether to refresh the cache; if true, the dataset will + be processed from scratch and the cache will be updated. Default is False. 
+ + Examples: + + >>> dataset = CheXpertV1Dataset( + root="/home/wuzijian1231/Datasets", + ) + >>> print(dataset.patients[0]) + >>> dataset.stat() + >>> samples = dataset.set_task() + >>> print(samples[0]) + """ + + def process(self): + # process and merge raw xlsx files from the dataset + df = pd.DataFrame( + pd.read_csv(f"{self.root}/CheXpert-v1.0/train.csv") + ) + df.fillna(value=2.0, inplace=True) # positive (1), negative (0), uncertain (-1), unmentioned (2) + df["Path"] = df["Path"].apply( + lambda x: f"{self.root}/{x}" + ) + df = df.drop(columns=["Sex", "Age", "Frontal/Lateral", "AP/PA"]) + self.pathology = [c for c in df] + del self.pathology[0] + df_list= [] + for p in self.pathology: + df_list.append(df[p]) + self.df_label = pd.concat(df_list, axis=1) + labels = self.df_label.values.tolist() + df.columns = [col for col in df] + for path in tqdm(df.Path): + assert os.path.isfile(path) + # create patient dict + patients = {} + for index, row in tqdm(df.iterrows()): + patients[index] = {'path':row['Path'], 'label':labels[index]} + return patients + + def stat(self): + super().stat() + print(f"Number of samples: {len(self.patients)}") + print(f"Number of Pathology: {len(self.pathology)}") + count = {} + for p in self.pathology: + cn = self.df_label[p] + count[p] = Counter(cn) + for p in self.pathology: + print(f"Class distribution - {p}: {count[p]}") + + @property + def default_task(self): + return CheXpertV1Classification() + +if __name__ == "__main__": + dataset = CheXpertV1Dataset( + root="/home/wuzijian1231/Datasets", + ) + print(dataset.patients[0]) + dataset.stat() + samples = dataset.set_task() + print(samples[0]) + \ No newline at end of file diff --git a/pyhealth/datasets/covid19_cxr.py b/pyhealth/datasets/covid19_cxr.py index 74e85d39..b6b05d7c 100644 --- a/pyhealth/datasets/covid19_cxr.py +++ b/pyhealth/datasets/covid19_cxr.py @@ -2,10 +2,11 @@ from collections import Counter import pandas as pd +import sys +sys.path.append('.') -from 
pyhealth.datasets.base_dataset_v2 import BaseDataset -from pyhealth.tasks.covid19_cxr_classification import COVID19CXRClassification - +from base_dataset_v2 import BaseDataset +from tasks.covid19_cxr_classification import COVID19CXRClassification# from pyhealth.tasks.covid19_cxr_classification import COVID19CXRClassification class COVID19CXRDataset(BaseDataset): """Base image dataset for COVID-19 Radiography Database @@ -120,6 +121,7 @@ def process(self): patients = {} for index, row in df.iterrows(): patients[index] = row.to_dict() + breakpoint() return patients def stat(self): @@ -136,7 +138,7 @@ def default_task(self): if __name__ == "__main__": dataset = COVID19CXRDataset( - root="/srv/local/data/zw12/raw_data/covid19-radiography-database/COVID-19_Radiography_Dataset", + root="/home/wuzijian1231/Datasets/COVID-19_Radiography_Dataset"#"/srv/local/data/zw12/raw_data/covid19-radiography-database/COVID-19_Radiography_Dataset", ) print(list(dataset.patients.items())[0]) dataset.stat() diff --git a/pyhealth/datasets/sample_dataset_v2.py b/pyhealth/datasets/sample_dataset_v2.py index cd066682..6aa14490 100644 --- a/pyhealth/datasets/sample_dataset_v2.py +++ b/pyhealth/datasets/sample_dataset_v2.py @@ -2,7 +2,7 @@ from torch.utils.data import Dataset -from pyhealth.datasets.featurizers import ImageFeaturizer, ValueFeaturizer +from featurizers import ImageFeaturizer, ValueFeaturizer# from pyhealth.datasets.featurizers import ImageFeaturizer, ValueFeaturizer class SampleDataset(Dataset): diff --git a/pyhealth/tasks/chexpert_v1_classification.py b/pyhealth/tasks/chexpert_v1_classification.py new file mode 100644 index 00000000..ff2eabd1 --- /dev/null +++ b/pyhealth/tasks/chexpert_v1_classification.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass, field +from typing import Dict + +from tasks.task_template import TaskTemplate # from pyhealth.tasks import TaskTemplate + + +@dataclass(frozen=True) +class CheXpertV1Classification(TaskTemplate): + task_name: str = 
"CheXpertV1Classification" + input_schema: Dict[str, str] = field(default_factory=lambda: {"path": "image"}) + output_schema: Dict[str, str] = field(default_factory=lambda: {"label": "label"}) + + def __call__(self, patient): + return [patient] + + +if __name__ == "__main__": + task = CheXpertV1Classification() + print(task) + print(type(task)) diff --git a/pyhealth/tasks/covid19_cxr_classification.py b/pyhealth/tasks/covid19_cxr_classification.py index 25c494f7..5d7e05eb 100644 --- a/pyhealth/tasks/covid19_cxr_classification.py +++ b/pyhealth/tasks/covid19_cxr_classification.py @@ -1,7 +1,7 @@ from dataclasses import dataclass, field from typing import Dict -from pyhealth.tasks import TaskTemplate +from tasks.task_template import TaskTemplate # from pyhealth.tasks import TaskTemplate @dataclass(frozen=True) diff --git a/pyhealth/tasks/medical_transcriptions_classification.py b/pyhealth/tasks/medical_transcriptions_classification.py index 2f7fa5b3..5b1c4b9b 100644 --- a/pyhealth/tasks/medical_transcriptions_classification.py +++ b/pyhealth/tasks/medical_transcriptions_classification.py @@ -2,7 +2,7 @@ from typing import Dict import pandas as pd -from pyhealth.tasks import TaskTemplate +from tasks.task_template import TaskTemplate# from pyhealth.tasks import TaskTemplate @dataclass(frozen=True)