diff --git a/examples/medical_transcriptions_classification.py b/examples/medical_transcriptions_classification.py new file mode 100644 index 00000000..041c855c --- /dev/null +++ b/examples/medical_transcriptions_classification.py @@ -0,0 +1,49 @@ +import numpy as np +import torch + +from pyhealth.datasets import MedicalTranscriptionsDataset +from pyhealth.datasets import get_dataloader +from pyhealth.models import TransformersModel +from pyhealth.trainer import Trainer + +root = "/srv/local/data/zw12/raw_data/MedicalTranscriptions" +base_dataset = MedicalTranscriptionsDataset(root) + +sample_dataset = base_dataset.set_task() + +ratios = [0.7, 0.1, 0.2] +index = np.arange(len(sample_dataset)) +np.random.shuffle(index) +s1 = int(len(sample_dataset) * ratios[0]) +s2 = int(len(sample_dataset) * (ratios[0] + ratios[1])) +train_index = index[: s1] +val_index = index[s1: s2] +test_index = index[s2:] +train_dataset = torch.utils.data.Subset(sample_dataset, train_index) +val_dataset = torch.utils.data.Subset(sample_dataset, val_index) +test_dataset = torch.utils.data.Subset(sample_dataset, test_index) + +train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True) +val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False) +test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False) + +model = TransformersModel( + model_name="emilyalsentzer/Bio_ClinicalBERT", + dataset=sample_dataset, + feature_keys=["transcription"], + label_key="label", + mode="multiclass", +) + +trainer = Trainer(model=model) + +print(trainer.evaluate(test_dataloader)) + +trainer.train( + train_dataloader=train_dataloader, + val_dataloader=val_dataloader, + epochs=1, + monitor="accuracy" +) + +print(trainer.evaluate(test_dataloader)) diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index bd5c530e..b4401ccd 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -1,13 +1,17 @@ +from .base_dataset_v2 import BaseDataset from .base_ehr_dataset import BaseEHRDataset from .base_signal_dataset import BaseSignalDataset +from .covid19_cxr import COVID19CXRDataset from .eicu import eICUDataset +from .isruc import ISRUCDataset +from .medical_transriptions import MedicalTranscriptionsDataset from .mimic3 import MIMIC3Dataset from .mimic4 import MIMIC4Dataset from .mimicextract import MIMICExtractDataset from .omop import OMOPDataset -from .sleepedf import SleepEDFDataset -from .isruc import ISRUCDataset -from .shhs import SHHSDataset from .sample_dataset import SampleBaseDataset, SampleSignalDataset, SampleEHRDataset -from .splitter import split_by_patient, split_by_visit +from .sample_dataset_v2 import SampleDataset +from .shhs import SHHSDataset +from .sleepedf import SleepEDFDataset +from .splitter import split_by_patient, split_by_visit, split_by_sample from .utils import collate_fn_dict, get_dataloader, strptime diff --git a/pyhealth/datasets/base_dataset_v2.py b/pyhealth/datasets/base_dataset_v2.py new file mode 100644 index 00000000..5cbf007e --- /dev/null +++ b/pyhealth/datasets/base_dataset_v2.py @@ -0,0 +1,74 @@ +import logging +from abc import ABC, abstractmethod +from typing import Optional, Dict + +from tqdm import tqdm + +from pyhealth.datasets.sample_dataset_v2 import SampleDataset +from pyhealth.tasks.task_template import TaskTemplate + +logger = logging.getLogger(__name__) + + +class BaseDataset(ABC): + """Abstract base dataset class.""" + + def __init__( + self, + root: str, + dataset_name: Optional[str] = None, + 
**kwargs,
+    ):
+        if dataset_name is None:
+            dataset_name = self.__class__.__name__
+        self.root = root
+        self.dataset_name = dataset_name
+        logger.debug(f"Processing {self.dataset_name} base dataset...")
+        self.patients = self.process()
+        # TODO: cache
+        return
+
+    def __str__(self):
+        return f"Base dataset {self.dataset_name}"
+
+    def __len__(self):
+        return len(self.patients)
+
+    @abstractmethod
+    def process(self) -> Dict:
+        raise NotImplementedError
+
+    @abstractmethod
+    def stat(self):
+        print(f"Statistics of {self.dataset_name}:")
+        return
+
+    @property
+    def default_task(self) -> Optional[TaskTemplate]:
+        return None
+
+    def set_task(self, task: Optional[TaskTemplate] = None) -> SampleDataset:
+        """Processes the base dataset to generate the task-specific sample dataset.
+        """
+        # TODO: cache?
+        if task is None:
+            # fall back to the dataset's default task
+            assert self.default_task is not None, "No default task found"
+            task = self.default_task
+
+        # load from raw data
+        logger.debug(f"Setting task for {self.dataset_name} base dataset...")
+
+        samples = []
+        for patient_id, patient in tqdm(
+            self.patients.items(), desc=f"Generating samples for {task.task_name}"
+        ):
+            samples.extend(task(patient))
+        sample_dataset = SampleDataset(
+            samples,
+            input_schema=task.input_schema,
+            output_schema=task.output_schema,
+            dataset_name=self.dataset_name,
+            task_name=task.task_name,
+        )
+        return sample_dataset
diff --git a/pyhealth/datasets/covid19_cxr.py b/pyhealth/datasets/covid19_cxr.py
new file mode 100644
index 00000000..74e85d39
--- /dev/null
+++ b/pyhealth/datasets/covid19_cxr.py
@@ -0,0 +1,144 @@
+import os
+from collections import Counter
+
+import pandas as pd
+
+from pyhealth.datasets.base_dataset_v2 import BaseDataset
+from pyhealth.tasks.covid19_cxr_classification import COVID19CXRClassification
+
+
+class COVID19CXRDataset(BaseDataset):
+    """Base image dataset for the COVID-19 Radiography Database.
+
+    Dataset is available at https://www.kaggle.com/datasets/tawsifurrahman/covid19-radiography-database
+
+    **COVID-19 data:
+    -----------------------
+    COVID data are collected from different publicly accessible datasets, online sources and published papers.
+    -2473 CXR images are collected from the PadChest dataset [1].
+    -183 CXR images from a German medical school [2].
+    -559 CXR images from SIRM, GitHub, Kaggle & Twitter [3,4,5,6].
+    -400 CXR images from another GitHub source [7].
+
+    ***Normal images:
+    ----------------------------------------
+    10192 Normal images are collected from two different datasets.
+    -8851 RSNA [8]
+    -1341 Kaggle [9]
+
+    ***Lung opacity images:
+    ----------------------------------------
+    6012 Lung opacity CXR images are collected from the Radiological Society of North America (RSNA) CXR dataset [8]
+
+    ***Viral Pneumonia images:
+    ----------------------------------------
+    1345 Viral Pneumonia images are collected from the Chest X-Ray Images (pneumonia) database [9]
+
+    Please cite the following two articles if you are using this dataset:
+    -M.E.H. Chowdhury, T. Rahman, A. Khandakar, R. Mazhar, M.A. Kadir, Z.B. Mahbub, K.R. Islam, M.S. Khan, A. Iqbal, N. Al-Emadi, M.B.I. Reaz, M. T. Islam, “Can AI help in screening Viral and COVID-19 pneumonia?” IEEE Access, Vol. 8, 2020, pp. 132665-132676.
+    -Rahman, T., Khandakar, A., Qiblawey, Y., Tahir, A., Kiranyaz, S., Kashem, S.B.A., Islam, M.T., Maadeed, S.A., Zughaier, S.M., Khan, M.S. and Chowdhury, M.E., 2020. Exploring the Effect of Image Enhancement Techniques on COVID-19 Detection using Chest X-ray Images.
arXiv preprint arXiv:2012.02238.
+
+    **Reference:
+    [1] https://bimcv.cipf.es/bimcv-projects/bimcv-covid19/#1590858128006-9e640421-6711
+    [2] https://github.com/ml-workgroup/covid-19-image-repository/tree/master/png
+    [3] https://sirm.org/category/senza-categoria/covid-19/
+    [4] https://eurorad.org
+    [5] https://github.com/ieee8023/covid-chestxray-dataset
+    [6] https://figshare.com/articles/COVID-19_Chest_X-Ray_Image_Repository/12580328
+    [7] https://github.com/armiro/COVID-CXNet
+    [8] https://www.kaggle.com/c/rsna-pneumonia-detection-challenge/data
+    [9] https://www.kaggle.com/paultimothymooney/chest-xray-pneumonia
+
+    Args:
+        root: root directory of the raw data (should contain the COVID/,
+            Lung_Opacity/, Normal/, and Viral Pneumonia/ image folders and
+            the matching *.metadata.xlsx files).
+        dataset_name: name of the dataset. Default is the name of the class.
+
+    Attributes:
+        root: root directory of the raw data.
+        dataset_name: name of the dataset.
+        patients: a dict mapping sample index to a dict with the image path,
+            url, and label.
+
+    Examples:
+        >>> dataset = COVID19CXRDataset(
+                root="/srv/local/data/zw12/raw_data/covid19-radiography-database/COVID-19_Radiography_Dataset",
+            )
+        >>> print(list(dataset.patients.items())[0])
+        >>> dataset.stat()
+        >>> samples = dataset.set_task()
+    """
+
+    def process(self):
+        # read and merge the raw metadata xlsx files shipped with the dataset
+        covid = pd.read_excel(f"{self.root}/COVID.metadata.xlsx")
+        covid["FILE NAME"] = covid["FILE NAME"].apply(
+            lambda x: f"{self.root}/COVID/images/{x}.png"
+        )
+        covid["label"] = "COVID"
+        lung_opacity = pd.read_excel(f"{self.root}/Lung_Opacity.metadata.xlsx")
+        lung_opacity["FILE NAME"] = lung_opacity["FILE NAME"].apply(
+            lambda x: f"{self.root}/Lung_Opacity/images/{x}.png"
+        )
+        lung_opacity["label"] = "Lung Opacity"
+        normal = pd.read_excel(f"{self.root}/Normal.metadata.xlsx")
+        # the metadata lists upper-case file names ("NORMAL-..."); the image
+        # files on disk are capitalized ("Normal-...")
+        normal["FILE NAME"] = normal["FILE NAME"].apply(
+            lambda x: x.capitalize()
+        )
+        normal["FILE NAME"] = normal["FILE NAME"].apply(
+            lambda x: f"{self.root}/Normal/images/{x}.png"
+        )
+        normal["label"] = "Normal"
+        viral_pneumonia = pd.read_excel(f"{self.root}/Viral Pneumonia.metadata.xlsx")
+        viral_pneumonia["FILE NAME"] = viral_pneumonia["FILE NAME"].apply(
+            lambda x: f"{self.root}/Viral Pneumonia/images/{x}.png"
+        )
+        viral_pneumonia["label"] = "Viral Pneumonia"
+        df = pd.concat(
+            [covid, lung_opacity, normal, viral_pneumonia],
+            axis=0,
+            ignore_index=True
+        )
+        df = df.drop(columns=["FORMAT", "SIZE"])
+        df.columns = ["path", "url", "label"]
+        # paths are already absolute, so no join with self.root is needed
+        for path in df.path:
+            assert os.path.isfile(path)
+        # create patient dict
+        patients = {}
+        for index, row in df.iterrows():
+            patients[index] = row.to_dict()
+        return patients
+
+    def stat(self):
+        super().stat()
+        print(f"Number of samples: {len(self.patients)}")
+        count = Counter([v['label'] for v in self.patients.values()])
+        print(f"Number of classes: {len(count)}")
+        print(f"Class distribution: {count}")
+
+    @property
+    def default_task(self):
+        return COVID19CXRClassification()
+
+
+if __name__ == "__main__":
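+    # NOTE: illustrative local path; point `root` at your own download of the
+    # COVID-19 Radiography Database from Kaggle.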
root="/srv/local/data/zw12/raw_data/covid19-radiography-database/COVID-19_Radiography_Dataset", + ) + print(list(dataset.patients.items())[0]) + dataset.stat() + samples = dataset.set_task() + print(samples[0]) diff --git a/pyhealth/datasets/featurizers/__init__.py b/pyhealth/datasets/featurizers/__init__.py new file mode 100644 index 00000000..7ad6268b --- /dev/null +++ b/pyhealth/datasets/featurizers/__init__.py @@ -0,0 +1,2 @@ +from .image import ImageFeaturizer +from .value import ValueFeaturizer \ No newline at end of file diff --git a/pyhealth/datasets/featurizers/image.py b/pyhealth/datasets/featurizers/image.py new file mode 100644 index 00000000..d867f8ef --- /dev/null +++ b/pyhealth/datasets/featurizers/image.py @@ -0,0 +1,23 @@ + +import PIL.Image +import torchvision.transforms as transforms + + +class ImageFeaturizer: + + def __init__(self): + self.transform = transforms.Compose([transforms.ToTensor()]) + + def encode(self, value): + image = PIL.Image.open(value) + image.load() # to avoid "Too many open files" errors + image = self.transform(image) + return image + + +if __name__ == "__main__": + sample_image = "/srv/local/data/zw12/raw_data/covid19-radiography-database/COVID-19_Radiography_Dataset/Normal/images/Normal-6335.png" + featurizer = ImageFeaturizer() + print(featurizer) + print(type(featurizer)) + print(featurizer.encode(sample_image)) diff --git a/pyhealth/datasets/featurizers/value.py b/pyhealth/datasets/featurizers/value.py new file mode 100644 index 00000000..2bd4c584 --- /dev/null +++ b/pyhealth/datasets/featurizers/value.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass + + +@dataclass +class ValueFeaturizer: + + def encode(self, value): + return value + + +if __name__ == "__main__": + featurizer = ValueFeaturizer() + print(featurizer) + print(featurizer.encode(2)) diff --git a/pyhealth/datasets/medical_transriptions.py b/pyhealth/datasets/medical_transriptions.py new file mode 100644 index 00000000..c809533b --- /dev/null +++ b/pyhealth/datasets/medical_transriptions.py @@ -0,0 +1,68 @@ +import os +from collections import Counter + +import pandas as pd + +from pyhealth.datasets.base_dataset_v2 import BaseDataset +from pyhealth.tasks.medical_transcriptions_classification import MedicalTranscriptionsClassification + + +class MedicalTranscriptionsDataset(BaseDataset): + """Medical transcription data scraped from mtsamples.com + + Dataset is available at https://www.kaggle.com/datasets/tboyle10/medicaltranscriptions + + Args: + dataset_name: name of the dataset. + root: root directory of the raw data. *You can choose to use the path to Cassette portion or the Telemetry portion.* + dev: whether to enable dev mode (only use a small subset of the data). + Default is False. + refresh_cache: whether to refresh the cache; if true, the dataset will + be processed from scratch and the cache will be updated. Default is False. + + Attributes: + root: root directory of the raw data (should contain many csv files). + dataset_name: name of the dataset. Default is the name of the class. + dev: whether to enable dev mode (only use a small subset of the data). + Default is False. + refresh_cache: whether to refresh the cache; if true, the dataset will + be processed from scratch and the cache will be updated. Default is False. 
+
+    Examples:
+        >>> dataset = MedicalTranscriptionsDataset(
+                root="/srv/local/data/zw12/raw_data/MedicalTranscriptions",
+            )
+        >>> print(list(dataset.patients.items())[0])
+        >>> dataset.stat()
+        >>> samples = dataset.set_task()
+    """
+
+    def process(self):
+        df = pd.read_csv(f"{self.root}/mtsamples.csv", index_col=0)
+
+        # create patient dict
+        patients = {}
+        for index, row in df.iterrows():
+            patients[index] = row.to_dict()
+        return patients
+
+    def stat(self):
+        super().stat()
+        print(f"Number of samples: {len(self.patients)}")
+        count = Counter([v['medical_specialty'] for v in self.patients.values()])
+        print(f"Number of classes: {len(count)}")
+        print(f"Class distribution: {count}")
+
+    @property
+    def default_task(self):
+        return MedicalTranscriptionsClassification()
+
+
+if __name__ == "__main__":
+    dataset = MedicalTranscriptionsDataset(
+        root="/srv/local/data/zw12/raw_data/MedicalTranscriptions",
+    )
+    print(list(dataset.patients.items())[0])
+    dataset.stat()
+    samples = dataset.set_task()
+    print(samples[0])
diff --git a/pyhealth/datasets/sample_dataset_v2.py b/pyhealth/datasets/sample_dataset_v2.py
new file mode 100644
index 00000000..cd066682
--- /dev/null
+++ b/pyhealth/datasets/sample_dataset_v2.py
@@ -0,0 +1,175 @@
+from typing import Dict, List, Optional
+
+from torch.utils.data import Dataset
+
+from pyhealth.datasets.featurizers import ImageFeaturizer, ValueFeaturizer
+from pyhealth.datasets.utils import flatten_list
+
+
+class SampleDataset(Dataset):
+    """Sample dataset class.
+    """
+
+    def __init__(
+        self,
+        samples: List[Dict],
+        input_schema: Dict[str, str],
+        output_schema: Dict[str, str],
+        dataset_name: Optional[str] = None,
+        task_name: Optional[str] = None,
+    ):
+        if dataset_name is None:
+            dataset_name = ""
+        if task_name is None:
+            task_name = ""
+        self.samples = samples
+        self.input_schema = input_schema
+        self.output_schema = output_schema
+        self.dataset_name = dataset_name
+        self.task_name = task_name
+        self.transform = None
+        # TODO: get rid of input_info
+        self.input_info: Dict = self.validate()
+        self.build()
+
+    def validate(self):
+        input_keys = set(self.input_schema.keys())
+        output_keys = set(self.output_schema.keys())
+        for s in self.samples:
+            assert input_keys.issubset(s.keys()), \
+                "Input schema does not match samples."
+            assert output_keys.issubset(s.keys()), \
+                "Output schema does not match samples."
+        input_info = {}
+        # get label signal info
+        input_info["label"] = {"type": str, "dim": 0}
+        return input_info
+
+    def build(self):
+        # replace each schema entry (a type string) with its featurizer
+        for k, v in self.input_schema.items():
+            if v == "image":
+                self.input_schema[k] = ImageFeaturizer()
+            else:
+                self.input_schema[k] = ValueFeaturizer()
+        for k, v in self.output_schema.items():
+            if v == "image":
+                self.output_schema[k] = ImageFeaturizer()
+            else:
+                self.output_schema[k] = ValueFeaturizer()
+        return
+
+    def __getitem__(self, index) -> Dict:
+        """Returns a sample by index.
+
+        Returns:
+            Dict, a dict with the sample's attributes as keys; input/output
+                features are encoded by their featurizers, and any remaining
+                keys are passed through unchanged. Conversion to index/tensor
+                will be done in the model.
+        """
+        out = {}
+        for k, v in self.samples[index].items():
+            if k in self.input_schema:
+                out[k] = self.input_schema[k].encode(v)
+            elif k in self.output_schema:
+                out[k] = self.output_schema[k].encode(v)
+            else:
+                out[k] = v
+
+        if self.transform is not None:
+            out = self.transform(out)
+
+        return out
+
+    def set_transform(self, transform):
+        """Sets the transform for the dataset.
+
+        Args:
+            transform: a callable transform function.
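+
+        Example (a sketch mirroring the demo in pyhealth.models.torchvision_model;
+        assumes the image feature is stored under the key "path"):
+            >>> from torchvision import transforms
+            >>> resize = transforms.Resize((224, 224))
+            >>> def encode(sample):
+            ...     sample["path"] = resize(sample["path"])
+            ...     return sample
+            >>> dataset.set_transform(encode)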
+ """ + self.transform = transform + return + + def get_all_tokens( + self, key: str, remove_duplicates: bool = True, sort: bool = True + ) -> List[str]: + """Gets all tokens with a specific key in the samples. + + Args: + key: the key of the tokens in the samples. + remove_duplicates: whether to remove duplicates. Default is True. + sort: whether to sort the tokens by alphabet order. Default is True. + + Returns: + tokens: a list of tokens. + """ + # TODO: get rid of this function + input_type = self.input_info[key]["type"] + input_dim = self.input_info[key]["dim"] + if input_type in [float, int]: + assert input_dim == 0, f"Cannot get tokens for vector with key {key}" + + tokens = [] + for sample in self.samples: + if input_dim == 0: + # a single value + tokens.append(sample[key]) + elif input_dim == 2: + # a list of codes + tokens.extend(sample[key]) + elif input_dim == 3: + # a list of list of codes + tokens.extend(flatten_list(sample[key])) + else: + raise NotImplementedError + if remove_duplicates: + tokens = list(set(tokens)) + if sort: + tokens.sort() + return tokens + + def __str__(self): + """Prints some information of the dataset.""" + return f"Sample dataset {self.dataset_name} {self.task_name}" + + def __len__(self): + """Returns the number of samples in the dataset.""" + return len(self.samples) + + +if __name__ == "__main__": + samples = [ + { + "id": "0", + "single_vector": [1, 2, 3], + "list_codes": ["505800458", "50580045810", "50580045811"], + "list_vectors": [[1.0, 2.55, 3.4], [4.1, 5.5, 6.0]], + "list_list_codes": [ + ["A05B", "A05C", "A06A"], + ["A11D", "A11E"] + ], + "list_list_vectors": [ + [[1.8, 2.25, 3.41], [4.50, 5.9, 6.0]], + [[7.7, 8.5, 9.4]], + ], + "image": "/srv/local/data/zw12/raw_data/covid19-radiography-database/COVID-19_Radiography_Dataset/Normal/images/Normal-6335.png", + "text": "This is a sample text", + "label": 1, + }, + ] + + dataset = SampleDataset( + samples=samples, + input_schema={ + "id": "str", + "single_vector": "vector", + "list_codes": "list", + "list_vectors": "list", + "list_list_codes": "list", + "list_list_vectors": "list", + "image": "image", + "text": "text", + }, + output_schema={ + "label": "label" + } + ) + print(dataset[0]) diff --git a/pyhealth/datasets/splitter.py b/pyhealth/datasets/splitter.py index 3ffc1df8..ccfb7b14 100644 --- a/pyhealth/datasets/splitter.py +++ b/pyhealth/datasets/splitter.py @@ -38,9 +38,10 @@ def split_by_visit( np.random.shuffle(index) train_index = index[: int(len(dataset) * ratios[0])] val_index = index[ - int(len(dataset) * ratios[0]) : int(len(dataset) * (ratios[0] + ratios[1])) - ] - test_index = index[int(len(dataset) * (ratios[0] + ratios[1])) :] + int(len(dataset) * ratios[0]): int( + len(dataset) * (ratios[0] + ratios[1])) + ] + test_index = index[int(len(dataset) * (ratios[0] + ratios[1])):] train_dataset = torch.utils.data.Subset(dataset, train_index) val_dataset = torch.utils.data.Subset(dataset, val_index) test_dataset = torch.utils.data.Subset(dataset, test_index) @@ -75,9 +76,10 @@ def split_by_patient( np.random.shuffle(patient_indx) train_patient_indx = patient_indx[: int(num_patients * ratios[0])] val_patient_indx = patient_indx[ - int(num_patients * ratios[0]) : int(num_patients * (ratios[0] + ratios[1])) - ] - test_patient_indx = patient_indx[int(num_patients * (ratios[0] + ratios[1])) :] + int(num_patients * ratios[0]): int( + num_patients * (ratios[0] + ratios[1])) + ] + test_patient_indx = patient_indx[int(num_patients * (ratios[0] + ratios[1])):] train_index = list( 
chain(*[dataset.patient_to_index[i] for i in train_patient_indx]) ) @@ -87,3 +89,40 @@ def split_by_patient( val_dataset = torch.utils.data.Subset(dataset, val_index) test_dataset = torch.utils.data.Subset(dataset, test_index) return train_dataset, val_dataset, test_dataset + + +def split_by_sample( + dataset: SampleBaseDataset, + ratios: Union[Tuple[float, float, float], List[float]], + seed: Optional[int] = None, +): + """Splits the dataset by sample + + Args: + dataset: a `SampleBaseDataset` object + ratios: a list/tuple of ratios for train / val / test + seed: random seed for shuffling the dataset + + Returns: + train_dataset, val_dataset, test_dataset: three subsets of the dataset of + type `torch.utils.data.Subset`. + + Note: + The original dataset can be accessed by `train_dataset.dataset`, + `val_dataset.dataset`, and `test_dataset.dataset`. + """ + if seed is not None: + np.random.seed(seed) + assert sum(ratios) == 1.0, "ratios must sum to 1.0" + index = np.arange(len(dataset)) + np.random.shuffle(index) + train_index = index[: int(len(dataset) * ratios[0])] + val_index = index[ + int(len(dataset) * ratios[0]): int( + len(dataset) * (ratios[0] + ratios[1])) + ] + test_index = index[int(len(dataset) * (ratios[0] + ratios[1])):] + train_dataset = torch.utils.data.Subset(dataset, train_index) + val_dataset = torch.utils.data.Subset(dataset, val_index) + test_dataset = torch.utils.data.Subset(dataset, test_index) + return train_dataset, val_dataset, test_dataset diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index b0467853..0ebbec30 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -23,3 +23,5 @@ from .stagenet import StageNet, StageNetLayer from .tcn import TCN, TCNLayer from .molerec import MoleRec, MoleRecLayer +from .torchvision_model import TorchvisionModel +from .transformers_model import TransformersModel diff --git a/pyhealth/models/torchvision_model.py b/pyhealth/models/torchvision_model.py new file mode 100644 index 00000000..a6d66e15 --- /dev/null +++ b/pyhealth/models/torchvision_model.py @@ -0,0 +1,182 @@ +from typing import List, Dict + +import torch +import torch.nn as nn +import torchvision + +from pyhealth.datasets.sample_dataset_v2 import SampleDataset +from pyhealth.models import BaseModel + +SUPPORTED_MODELS = [ + "resnet18", + "resnet34", + "resnet50", + "resnet101", + "resnet152", + "densenet121", + "densenet161", + "densenet169", + "densenet201", + "vit_b_16", + "vit_b_32", + "vit_l_16", + "vit_l_32", + "vit_h_14", + "swin_t", + "swin_s", + "swin_b", +] + +SUPPORTED_MODELS_FINAL_LAYER = {} +for model in SUPPORTED_MODELS: + if "resnet" in model: + SUPPORTED_MODELS_FINAL_LAYER[model] = "fc" + elif "densenet" in model: + SUPPORTED_MODELS_FINAL_LAYER[model] = "classifier" + elif "vit" in model: + SUPPORTED_MODELS_FINAL_LAYER[model] = "heads.head" + elif "swin" in model: + SUPPORTED_MODELS_FINAL_LAYER[model] = "head" + else: + raise NotImplementedError + + +class TorchvisionModel(BaseModel): + """Models from PyTorch's torchvision package. + + This class is a wrapper for models from torchvision. It will automatically load + the corresponding model and weights from torchvision. The final layer will be + replaced with a linear layer with the correct output size. + + -----------------------------------ResNet------------------------------------------ + Paper: Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. Deep Residual Learning + for Image Recognition. CVPR 2016. 
+ -----------------------------------DenseNet---------------------------------------- + Paper: Gao Huang, Zhuang Liu, Laurens van der Maaten, Kilian Q. Weinberger. + Densely Connected Convolutional Networks. CVPR 2017. + ----------------------------Vision Transformer (ViT)------------------------------- + Paper: Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, + Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, + Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby. An Image is Worth + 16x16 Words: Transformers for Image Recognition at Scale. ICLR 2021. + ----------------------------Swin Transformer (and V2)------------------------------ + Paper: Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, + Baining Guo. Swin Transformer: Hierarchical Vision Transformer Using Shifted + Windows. ICCV 2021. + + Paper: Ze Liu, Han Hu, Yutong Lin, Zhuliang Yao, Zhenda Xie, Yixuan Wei, Jia Ning, + Yue Cao, Zheng Zhang, Li Dong, Furu Wei, Baining Guo. Swin Transformer V2: Scaling + Up Capacity and Resolution. CVPR 2022. + ----------------------------------------------------------------------------------- + + Args: + dataset: the dataset to train the model. It is used to query certain + information such as the set of all tokens. + feature_keys: list of keys in samples to use as features, e.g., ["image"]. + Only one feature is supported. + label_key: key in samples to use as label, e.g., "drugs". + mode: one of "binary", "multiclass", or "multilabel". + model_name: str, name of the model to use, e.g., "resnet18". + See SUPPORTED_MODELS in the source code for the full list. + model_config: dict, kwargs to pass to the model constructor, + e.g., {"weights": "DEFAULT"}. See the torchvision documentation for the + set of supported kwargs for each model. + ----------------------------------------------------------------------------------- + """ + + def __init__( + self, + dataset: SampleDataset, + feature_keys: List[str], + label_key: str, + mode: str, + model_name: str, + model_config: dict, + ): + super(TorchvisionModel, self).__init__( + dataset=dataset, + feature_keys=feature_keys, + label_key=label_key, + mode=mode, + ) + + self.model_name = model_name + self.model_config = model_config + + assert len(feature_keys) == 1, "Only one feature is supported!" + assert model_name in SUPPORTED_MODELS_FINAL_LAYER.keys(), \ + f"PyHealth does not currently include {model_name} model!" 
+        self.model = torchvision.models.get_model(model_name, **model_config)
+        final_layer_name = SUPPORTED_MODELS_FINAL_LAYER[model_name]
+        final_layer = self.model
+        for name in final_layer_name.split("."):
+            final_layer = getattr(final_layer, name)
+        hidden_dim = final_layer.in_features
+        self.label_tokenizer = self.get_label_tokenizer()
+        output_size = self.get_output_size(self.label_tokenizer)
+        setattr(self.model, final_layer_name.split(".")[0], nn.Linear(hidden_dim, output_size))
+
+    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
+        """Forward propagation."""
+        # stack the images within one batch into (batch, channel, height, width)
+        x = kwargs[self.feature_keys[0]]
+        x = torch.stack(x, dim=0).to(self.device)
+        # replicate single-channel (grayscale) inputs to three channels
+        if x.shape[1] == 1:
+            x = x.repeat((1, 3, 1, 1))
+        logits = self.model(x)
+        y_true = self.prepare_labels(kwargs[self.label_key], self.label_tokenizer)
+        loss = self.get_loss_function()(logits, y_true)
+        y_prob = self.prepare_y_prob(logits)
+        return {
+            "loss": loss,
+            "y_prob": y_prob,
+            "y_true": y_true,
+        }
+
+
+if __name__ == "__main__":
+    from pyhealth.datasets.utils import get_dataloader
+    from torchvision import transforms
+    from pyhealth.datasets import COVID19CXRDataset
+
+    base_dataset = COVID19CXRDataset(
+        root="/srv/local/data/zw12/raw_data/covid19-radiography-database/COVID-19_Radiography_Dataset",
+    )
+
+    sample_dataset = base_dataset.set_task()
+
+    transform = transforms.Compose([
+        transforms.Grayscale(),
+        transforms.Resize((224, 224)),
+        transforms.Normalize(mean=[0.5862785803043838], std=[0.27950088968644304])
+    ])
+
+
+    def encode(sample):
+        sample["path"] = transform(sample["path"])
+        return sample
+
+
+    sample_dataset.set_transform(encode)
+
+    train_loader = get_dataloader(sample_dataset, batch_size=16, shuffle=True)
+
+    model = TorchvisionModel(
+        dataset=sample_dataset,
+        feature_keys=["path"],
+        label_key="label",
+        mode="multiclass",
+        model_name="vit_b_16",
+        model_config={"weights": "DEFAULT"},
+    )
+
+    # data batch
+    data_batch = next(iter(train_loader))
+
+    # try the model
+    ret = model(**data_batch)
+    print(ret)
+
+    # try loss backward
+    ret["loss"].backward()
diff --git a/pyhealth/models/transformers_model.py b/pyhealth/models/transformers_model.py
new file mode 100644
index 00000000..cee6f10e
--- /dev/null
+++ b/pyhealth/models/transformers_model.py
@@ -0,0 +1,84 @@
+from typing import List, Dict
+
+import torch
+import torch.nn as nn
+from transformers import AutoModel, AutoTokenizer
+
+from pyhealth.datasets import SampleDataset
+from pyhealth.models import BaseModel
+
+
+class TransformersModel(BaseModel):
+    """Transformers class for Hugging Face models.
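+
+    This class loads a Hugging Face encoder via AutoModel / AutoTokenizer and
+    puts a linear classification head on top of the encoder's pooled output.
+
+    Args:
+        dataset: the dataset to train the model. It is used to query certain
+            information such as the set of all tokens.
+        feature_keys: list of keys in samples to use as features, e.g.,
+            ["transcription"]. Only the first feature key is used.
+        label_key: key in samples to use as label, e.g., "label".
+        mode: one of "binary", "multiclass", or "multilabel".
+        model_name: str, name of the model checkpoint to load from the
+            Hugging Face Hub, e.g., "emilyalsentzer/Bio_ClinicalBERT".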
+ """ + + def __init__( + self, + dataset: SampleDataset, + feature_keys: List[str], + label_key: str, + mode: str, + model_name: str, + ): + super(TransformersModel, self).__init__( + dataset=dataset, + feature_keys=feature_keys, + label_key=label_key, + mode=mode, + ) + self.model_name = model_name + self.model = AutoModel.from_pretrained(model_name) + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.label_tokenizer = self.get_label_tokenizer() + output_size = self.get_output_size(self.label_tokenizer) + hidden_dim = self.model.config.hidden_size + self.fc = nn.Linear(hidden_dim, output_size) + + def forward(self, **kwargs) -> Dict[str, torch.Tensor]: + """Forward propagation.""" + # concat the info within one batch (batch, channel, length) + x = kwargs[self.feature_keys[0]] + x = self.tokenizer( + x, return_tensors="pt", padding=True, truncation=True, max_length=256 + ) + x = x.to(self.device) + embeddings = self.model(**x).pooler_output + logits = self.fc(embeddings) + y_true = self.prepare_labels(kwargs[self.label_key], self.label_tokenizer) + loss = self.get_loss_function()(logits, y_true) + y_prob = self.prepare_y_prob(logits) + return { + "loss": loss, + "y_prob": y_prob, + "y_true": y_true, + } + + +if __name__ == "__main__": + from pyhealth.datasets import MedicalTranscriptionsDataset, get_dataloader + + base_dataset = MedicalTranscriptionsDataset( + root="/srv/local/data/zw12/raw_data/MedicalTranscriptions" + ) + + sample_dataset = base_dataset.set_task() + + train_loader = get_dataloader(sample_dataset, batch_size=16, shuffle=True) + + model = TransformersModel( + dataset=sample_dataset, + feature_keys=["transcription"], + label_key="label", + mode="multiclass", + model_name="emilyalsentzer/Bio_ClinicalBERT", + ) + + # data batch + data_batch = next(iter(train_loader)) + + # try the model + ret = model(**data_batch) + print(ret) + + # try loss backward + ret["loss"].backward() diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 70c2b848..d3c38e4e 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -1,3 +1,4 @@ +from .task_template import TaskTemplate from .drug_recommendation import ( drug_recommendation_eicu_fn, drug_recommendation_mimic3_fn, @@ -29,3 +30,5 @@ sleep_staging_isruc_fn, sleep_staging_shhs_fn, ) +from .covid19_cxr_classification import COVID19CXRClassification +from .medical_transcriptions_classification import MedicalTranscriptionsClassification diff --git a/pyhealth/tasks/covid19_cxr_classification.py b/pyhealth/tasks/covid19_cxr_classification.py new file mode 100644 index 00000000..bbb5b364 --- /dev/null +++ b/pyhealth/tasks/covid19_cxr_classification.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass, field +from typing import Dict + +from pyhealth.tasks.task_template import TaskTemplate + + +@dataclass(frozen=True) +class COVID19CXRClassification(TaskTemplate): + task_name: str = "COVID19CXRClassification" + input_schema: Dict[str, str] = field(default_factory=lambda: {"path": "image"}) + output_schema: Dict[str, str] = field(default_factory=lambda: {"label": "label"}) + + def __call__(self, patient): + return [patient] + + +if __name__ == "__main__": + task = COVID19CXRClassification() + print(task) + print(type(task)) diff --git a/pyhealth/tasks/medical_transcriptions_classification.py b/pyhealth/tasks/medical_transcriptions_classification.py new file mode 100644 index 00000000..1dbcff12 --- /dev/null +++ b/pyhealth/tasks/medical_transcriptions_classification.py @@ -0,0 +1,29 @@ +from 
dataclasses import dataclass, field +from typing import Dict +import pandas as pd + +from pyhealth.tasks.task_template import TaskTemplate + + +@dataclass(frozen=True) +class MedicalTranscriptionsClassification(TaskTemplate): + task_name: str = "MedicalTranscriptionsClassification" + input_schema: Dict[str, str] = field(default_factory=lambda: {"transcription": "text"}) + output_schema: Dict[str, str] = field(default_factory=lambda: {"label": "label"}) + + def __call__(self, patient): + if patient["transcription"] is None or pd.isna(patient["transcription"]): + return [] + if patient["medical_specialty"] is None or pd.isna(patient["medical_specialty"]): + return [] + sample = { + "transcription": patient["transcription"], + "label": patient["medical_specialty"], + } + return [sample] + + +if __name__ == "__main__": + task = MedicalTranscriptionsClassification() + print(task) + print(type(task)) diff --git a/pyhealth/tasks/task_template.py b/pyhealth/tasks/task_template.py new file mode 100644 index 00000000..0415b5b0 --- /dev/null +++ b/pyhealth/tasks/task_template.py @@ -0,0 +1,14 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Dict, List + + +@dataclass(frozen=True) +class TaskTemplate(ABC): + task_name: str + input_schema: Dict[str, str] + output_schema: Dict[str, str] + + @abstractmethod + def __call__(self, patient) -> List[Dict]: + raise NotImplementedError
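+
+
+if __name__ == "__main__":
+    # Minimal sketch of a concrete task, mirroring the frozen-dataclass
+    # pattern used by COVID19CXRClassification and
+    # MedicalTranscriptionsClassification; the schema keys and sample data
+    # here are illustrative only.
+    from dataclasses import field
+
+    @dataclass(frozen=True)
+    class DummyClassification(TaskTemplate):
+        task_name: str = "DummyClassification"
+        input_schema: Dict[str, str] = field(default_factory=lambda: {"text": "text"})
+        output_schema: Dict[str, str] = field(default_factory=lambda: {"label": "label"})
+
+        def __call__(self, patient) -> List[Dict]:
+            # emit one sample per patient record
+            return [{"text": patient["text"], "label": patient["label"]}]
+
+    task = DummyClassification()
+    print(task)
+    print(task({"text": "chest pain", "label": "Cardiology"}))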