Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Core structure updates: image and text featurizer #172

Draft
wants to merge 17 commits into
base: master
Choose a base branch
from
Draft
49 changes: 49 additions & 0 deletions examples/medical_transcriptions_classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import numpy as np
import torch

from pyhealth.datasets import MedicalTranscriptionsDataset
from pyhealth.datasets import get_dataloader
from pyhealth.models import TransformersModel
from pyhealth.trainer import Trainer

# Location of the raw Medical Transcriptions data on disk.
root = "/srv/local/data/zw12/raw_data/MedicalTranscriptions"
base_dataset = MedicalTranscriptionsDataset(root)

# No task is passed, so the dataset's default task is used.
sample_dataset = base_dataset.set_task()

# Random train/val/test split over sample indices (fractions in `ratios`).
ratios = [0.7, 0.1, 0.2]
num_samples = len(sample_dataset)
index = np.arange(num_samples)
np.random.shuffle(index)
s1 = int(num_samples * ratios[0])
s2 = int(num_samples * (ratios[0] + ratios[1]))
train_index, val_index, test_index = index[:s1], index[s1:s2], index[s2:]
train_dataset, val_dataset, test_dataset = (
    torch.utils.data.Subset(sample_dataset, part)
    for part in (train_index, val_index, test_index)
)

# Only the training loader reshuffles; validation/test keep a fixed order.
train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)

model = TransformersModel(
    model_name="emilyalsentzer/Bio_ClinicalBERT",
    dataset=sample_dataset,
    feature_keys=["transcription"],
    label_key="label",
    mode="multiclass",
)

trainer = Trainer(model=model)

# Test-set metrics before any training, for comparison with the final run.
scores_before = trainer.evaluate(test_dataloader)
print(scores_before)

trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=1,
    monitor="accuracy",
)

# Test-set metrics after training.
scores_after = trainer.evaluate(test_dataloader)
print(scores_after)
12 changes: 8 additions & 4 deletions pyhealth/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from .base_dataset_v2 import BaseDataset
from .base_ehr_dataset import BaseEHRDataset
from .base_signal_dataset import BaseSignalDataset
from .covid19_cxr import COVID19CXRDataset
from .eicu import eICUDataset
from .isruc import ISRUCDataset
from .medical_transriptions import MedicalTranscriptionsDataset
from .mimic3 import MIMIC3Dataset
from .mimic4 import MIMIC4Dataset
from .mimicextract import MIMICExtractDataset
from .omop import OMOPDataset
from .sleepedf import SleepEDFDataset
from .isruc import ISRUCDataset
from .shhs import SHHSDataset
from .sample_dataset import SampleBaseDataset, SampleSignalDataset, SampleEHRDataset
from .splitter import split_by_patient, split_by_visit
from .sample_dataset_v2 import SampleDataset
from .shhs import SHHSDataset
from .sleepedf import SleepEDFDataset
from .splitter import split_by_patient, split_by_visit, split_by_sample
from .utils import collate_fn_dict, get_dataloader, strptime
74 changes: 74 additions & 0 deletions pyhealth/datasets/base_dataset_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import logging
from abc import ABC, abstractmethod
from typing import Optional, Dict

from tqdm import tqdm

from pyhealth.datasets.sample_dataset_v2 import SampleDataset
from pyhealth.tasks.task_template import TaskTemplate

logger = logging.getLogger(__name__)


class BaseDataset(ABC):
    """Abstract base dataset class.

    Subclasses implement :meth:`process`, which is called once from
    ``__init__`` and whose return value is stored in ``self.patients``, and
    :meth:`stat`, which prints dataset statistics.

    Args:
        root: root directory of the raw data.
        dataset_name: name of the dataset. Defaults to the name of the class.
        **kwargs: extra keyword arguments; accepted but not used by this
            base class.

    Attributes:
        root: root directory of the raw data.
        dataset_name: name of the dataset.
        patients: the dict returned by :meth:`process`.
    """

    def __init__(
        self,
        root: str,
        dataset_name: Optional[str] = None,
        **kwargs,
    ):
        if dataset_name is None:
            dataset_name = self.__class__.__name__
        self.root = root
        self.dataset_name = dataset_name
        logger.debug(f"Processing {self.dataset_name} base dataset...")
        self.patients = self.process()
        # TODO: cache

    def __str__(self):
        return f"Base dataset {self.dataset_name}"

    def __len__(self):
        """Return the number of entries in ``self.patients``."""
        return len(self.patients)

    @abstractmethod
    def process(self) -> Dict:
        """Load the raw data and return it as a dict.

        ``set_task`` iterates over the values of the returned dict and passes
        each one to the task callable.
        """
        raise NotImplementedError

    @abstractmethod
    def stat(self):
        """Print the statistics header.

        Subclasses must override this method; they may call
        ``super().stat()`` to print the header line first.
        """
        print(f"Statistics of {self.dataset_name}:")

    @property
    def default_task(self) -> Optional[TaskTemplate]:
        """Task used by :meth:`set_task` when none is given (``None`` here)."""
        return None

    def set_task(self, task: Optional[TaskTemplate] = None) -> SampleDataset:
        """Processes the base dataset to generate the task-specific sample dataset.

        Args:
            task: callable task applied to every patient; it must expose
                ``task_name``, ``input_schema`` and ``output_schema``. If
                None, ``self.default_task`` is used.

        Returns:
            A ``SampleDataset`` built from the samples produced by ``task``.

        Raises:
            AssertionError: if ``task`` is None and no default task exists.
        """
        # TODO: cache?
        if task is None:
            # Read the property once: subclasses may build a fresh task
            # object on every access.
            task = self.default_task
            assert task is not None, "No default tasks found"

        # load from raw data
        logger.debug(f"Setting task for {self.dataset_name} base dataset...")

        samples = []
        for patient in tqdm(
            self.patients.values(), desc=f"Generating samples for {task.task_name}"
        ):
            # The task returns an iterable of samples for each patient.
            samples.extend(task(patient))
        sample_dataset = SampleDataset(
            samples,
            input_schema=task.input_schema,
            output_schema=task.output_schema,
            dataset_name=self.dataset_name,
            # Pass the task's name, not the task object itself.
            task_name=task.task_name,
        )
        return sample_dataset
144 changes: 144 additions & 0 deletions pyhealth/datasets/covid19_cxr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import os
from collections import Counter

import pandas as pd

from pyhealth.datasets.base_dataset_v2 import BaseDataset
from pyhealth.tasks.covid19_cxr_classification import COVID19CXRClassification


class COVID19CXRDataset(BaseDataset):
    """Base image dataset for COVID-19 Radiography Database

    Dataset is available at https://www.kaggle.com/datasets/tawsifurrahman/covid19-radiography-database

    **COVID-19 data:
    -----------------------
    COVID data are collected from different publicly accessible dataset, online sources and published papers.
    -2473 CXR images are collected from padchest dataset[1].
    -183 CXR images from a German medical school[2].
    -559 CXR image from SIRM, Github, Kaggle & Twitter[3,4,5,6]
    -400 CXR images from another Github source[7].

    ***Normal images:
    ----------------------------------------
    10192 Normal data are collected from three different dataset.
    -8851 RSNA [8]
    -1341 Kaggle [9]

    ***Lung opacity images:
    ----------------------------------------
    6012 Lung opacity CXR images are collected from Radiological Society of North America (RSNA) CXR dataset [8]

    ***Viral Pneumonia images:
    ----------------------------------------
    1345 Viral Pneumonia data are collected from the Chest X-Ray Images (pneumonia) database [9]

    Please cite the following two articles if you are using this dataset:
    -M.E.H. Chowdhury, T. Rahman, A. Khandakar, R. Mazhar, M.A. Kadir, Z.B. Mahbub, K.R. Islam, M.S. Khan, A. Iqbal, N. Al-Emadi, M.B.I. Reaz, M. T. Islam, “Can AI help in screening Viral and COVID-19 pneumonia?” IEEE Access, Vol. 8, 2020, pp. 132665 - 132676.
    -Rahman, T., Khandakar, A., Qiblawey, Y., Tahir, A., Kiranyaz, S., Kashem, S.B.A., Islam, M.T., Maadeed, S.A., Zughaier, S.M., Khan, M.S. and Chowdhury, M.E., 2020. Exploring the Effect of Image Enhancement Techniques on COVID-19 Detection using Chest X-ray Images. arXiv preprint arXiv:2012.02238.

    **Reference:
    [1] https://bimcv.cipf.es/bimcv-projects/bimcv-covid19/#1590858128006-9e640421-6711
    [2] https://github.com/ml-workgroup/covid-19-image-repository/tree/master/png
    [3] https://sirm.org/category/senza-categoria/covid-19/
    [4] https://eurorad.org
    [5] https://github.com/ieee8023/covid-chestxray-dataset
    [6] https://figshare.com/articles/COVID-19_Chest_X-Ray_Image_Repository/12580328
    [7] https://github.com/armiro/COVID-CXNet
    [8] https://www.kaggle.com/c/rsna-pneumonia-detection-challenge/data
    [9] https://www.kaggle.com/paultimothymooney/chest-xray-pneumonia

    Args:
        root: root directory of the raw data. It must contain the four
            ``<category>.metadata.xlsx`` files and the matching
            ``<category>/images`` folders.
        dataset_name: name of the dataset. Default is the name of the class.

    Attributes:
        root: root directory of the raw data.
        dataset_name: name of the dataset. Default is the name of the class.
        patients: dict mapping a running integer index to a dict with the
            keys ``path``, ``url`` and ``label``.

    Examples:
        >>> dataset = COVID19CXRDataset(
                root="/srv/local/data/zw12/raw_data/covid19-radiography-database/COVID-19_Radiography_Dataset",
            )
        >>> print(list(dataset.patients.items())[0])
        >>> dataset.stat()
        >>> samples = dataset.set_task()
        >>> print(samples[0])
    """

    def _load_metadata(self, category, label, normalize_name=None):
        """Load one category's metadata sheet and attach paths and label.

        Args:
            category: name shared by the ``<category>.metadata.xlsx`` file
                and the ``<category>/images`` folder under ``self.root``.
            label: class label stored in the ``label`` column.
            normalize_name: optional callable applied to every ``FILE NAME``
                entry before the image path is built.

        Returns:
            The metadata DataFrame with ``FILE NAME`` rewritten to the image
            path and a ``label`` column added.
        """
        frame = pd.read_excel(f"{self.root}/{category}.metadata.xlsx")
        if normalize_name is not None:
            frame["FILE NAME"] = frame["FILE NAME"].apply(normalize_name)
        frame["FILE NAME"] = frame["FILE NAME"].apply(
            lambda x: f"{self.root}/{category}/images/{x}.png"
        )
        frame["label"] = label
        return frame

    def process(self):
        """Merge the four metadata sheets into the patient dict.

        Returns:
            Dict mapping a running integer index to a dict with the keys
            ``path``, ``url`` and ``label``.

        Raises:
            AssertionError: if any referenced image file does not exist.
        """
        # process and merge raw xlsx files from the dataset
        frames = [
            self._load_metadata("COVID", "COVID"),
            self._load_metadata("Lung_Opacity", "Lung Opacity"),
            # File names in the Normal sheet are capitalized before the
            # path is built (e.g. "NORMAL-1" becomes "Normal-1").
            self._load_metadata(
                "Normal", "Normal", normalize_name=str.capitalize
            ),
            self._load_metadata("Viral Pneumonia", "Viral Pneumonia"),
        ]
        df = pd.concat(frames, axis=0, ignore_index=True)
        df = df.drop(columns=["FORMAT", "SIZE"])
        # Columns are renamed positionally: file path, source URL, label.
        df.columns = ["path", "url", "label"]
        for path in df["path"]:
            # The paths already start with self.root, so they are checked
            # as-is; joining with self.root again would double the prefix
            # whenever root is a relative path.
            assert os.path.isfile(path), f"Image file not found: {path}"
        # create patient dict
        return {index: row.to_dict() for index, row in df.iterrows()}

    def stat(self):
        """Print the sample count, class count and class distribution."""
        super().stat()
        print(f"Number of samples: {len(self.patients)}")
        count = Counter(v["label"] for v in self.patients.values())
        print(f"Number of classes: {len(count)}")
        print(f"Class distribution: {count}")

    @property
    def default_task(self):
        """Default task: a new ``COVID19CXRClassification`` instance."""
        return COVID19CXRClassification()


if __name__ == "__main__":
    # Manual smoke test: build the dataset, show the first patient record,
    # print statistics, then generate samples with the default task.
    dataset = COVID19CXRDataset(
        root="/srv/local/data/zw12/raw_data/covid19-radiography-database/COVID-19_Radiography_Dataset",
    )
    first_patient = list(dataset.patients.items())[0]
    print(first_patient)
    dataset.stat()
    samples = dataset.set_task()
    first_sample = samples[0]
    print(first_sample)
2 changes: 2 additions & 0 deletions pyhealth/datasets/featurizers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .image import ImageFeaturizer
from .value import ValueFeaturizer
23 changes: 23 additions & 0 deletions pyhealth/datasets/featurizers/image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@

import PIL.Image
import torchvision.transforms as transforms


class ImageFeaturizer:
    """Featurizer that loads an image file and converts it to a tensor.

    Attributes:
        transform: torchvision transform pipeline applied to every loaded
            image (currently only ``ToTensor``).
    """

    def __init__(self):
        self.transform = transforms.Compose([transforms.ToTensor()])

    def encode(self, value):
        """Load the image at ``value`` and return its transformed form.

        Args:
            value: anything accepted by ``PIL.Image.open`` (for example a
                path to an image file).

        Returns:
            The output of ``self.transform`` applied to the loaded image.
        """
        # The context manager releases the file handle even when loading or
        # the transform raises, and for formats where load() does not close
        # the file on its own; this avoids "Too many open files" errors.
        with PIL.Image.open(value) as image:
            image.load()  # read the pixel data while the file is still open
            return self.transform(image)


if __name__ == "__main__":
    # Manual smoke test: encode one image from a local copy of the dataset.
    sample_image = "/srv/local/data/zw12/raw_data/covid19-radiography-database/COVID-19_Radiography_Dataset/Normal/images/Normal-6335.png"
    featurizer = ImageFeaturizer()
    print(featurizer)
    featurizer_type = type(featurizer)
    print(featurizer_type)
    encoded_image = featurizer.encode(sample_image)
    print(encoded_image)
14 changes: 14 additions & 0 deletions pyhealth/datasets/featurizers/value.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from dataclasses import dataclass


@dataclass
class ValueFeaturizer:
    """Identity featurizer: hands the raw value back untouched."""

    def encode(self, value):
        """Return ``value`` exactly as it was given (same object)."""
        encoded = value
        return encoded


if __name__ == "__main__":
    # Manual smoke test: show the featurizer's repr and a pass-through encode.
    featurizer = ValueFeaturizer()
    print(featurizer)
    encoded_value = featurizer.encode(2)
    print(encoded_value)
Loading