Feature/xray report generation #144

Open · wants to merge 11 commits into develop
5 changes: 5 additions & 0 deletions .vscode/settings.json
@@ -0,0 +1,5 @@
{
"editor.rulers": [
79
]
}
178 changes: 178 additions & 0 deletions examples/xray_report_generation_sat.py
@@ -0,0 +1,178 @@
import os
import argparse
import pickle
import torch
from collections import OrderedDict
from torchvision import transforms
from pyhealth.datasets import BaseImageCaptionDataset
from pyhealth.tasks.xray_report_generation import biview_multisent_fn
from pyhealth.datasets import split_by_patient, get_dataloader
from pyhealth.tokenizer import Tokenizer
from pyhealth.models import WordSAT, SentSAT
from pyhealth.trainer import Trainer

def get_args():
parser = argparse.ArgumentParser()

parser.add_argument('--root', type=str, default=".")
parser.add_argument('--encoder-chkpt-fname', type=str, default=None)
parser.add_argument('--tokenizer-fname', type=str, default=None)
parser.add_argument('--num-epochs', type=int, default=1)
parser.add_argument('--model-type', type=str, default="wordsat")

args = parser.parse_args()
return args

def seed_everything(seed: int):
import random, os
import numpy as np
import torch

random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # benchmarking must be disabled for deterministic cuDNN behavior
    torch.backends.cudnn.benchmark = False


# STEP 1: load data
def load_data(root):
base_dataset = BaseImageCaptionDataset(root=root,dataset_name='IU_XRay')
return base_dataset

# STEP 2: set task
def set_task(base_dataset):
sample_dataset = base_dataset.set_task(biview_multisent_fn)
transform = transforms.Compose([
transforms.RandomAffine(degrees=30),
transforms.Resize((512,512)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485,0.456,0.406],
std=[0.229,0.224,0.225]),
])
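    # (the mean/std above are the standard ImageNet normalization statistics,
    # presumably chosen to match the encoder's pretraining)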
sample_dataset.set_transform(transform)
return sample_dataset

# STEP 3: get dataloaders
def get_dataloaders(sample_dataset):
train_dataset, val_dataset, test_dataset = split_by_patient(
sample_dataset,[0.8,0.1,0.1])

train_dataloader = get_dataloader(train_dataset,batch_size=8,shuffle=True)
val_dataloader = get_dataloader(val_dataset,batch_size=1,shuffle=False)
test_dataloader = get_dataloader(test_dataset,batch_size=1,shuffle=False)

return train_dataloader,val_dataloader,test_dataloader

# STEP 4: get tokenizer
def get_tokenizer(root,sample_dataset=None,tokenizer_fname=None):
if tokenizer_fname:
        with open(os.path.join(root, tokenizer_fname), 'rb') as f:
tokenizer = pickle.load(f)
else:
# <pad> should always be first element in the list of special tokens
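        # (the tokenizer assigns indices in list order, so <pad> receives
        # index 0, which padding and embedding layers typically assume)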
special_tokens = ['<pad>','<start>','<end>','<unk>']
tokenizer = Tokenizer(
sample_dataset.get_all_tokens(key='caption'),
special_tokens=special_tokens,
)
return tokenizer
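
# Not part of the pipeline above: a minimal sketch of how a tokenizer pickle
# (as loaded by get_tokenizer) could be produced; the filename is illustrative.
def save_tokenizer(root, tokenizer, fname='tokenizer.pkl'):
    with open(os.path.join(root, fname), 'wb') as f:
        pickle.dump(tokenizer, f)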

# STEP 5: get encoder pretrained state dictionary
def extract_encoder_state_dict(root,chkpt_fname):
    checkpoint = torch.load(os.path.join(root, chkpt_fname))
state_dict = OrderedDict()
for k,v in checkpoint['state_dict'].items():
if 'classifier' in k: continue
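        # checkpoints saved from nn.DataParallel prefix parameter names with
        # 'module.'; strip the prefix so keys match an unwrapped model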
        if k.startswith('module.'):
            name = k[7:]
        else:
            name = k

name = name.replace('classifier.0.','classifier.')
state_dict[name] = v
return state_dict

# STEP 6: define model
def define_model(
sample_dataset,
tokenizer,
encoder_weights,
model_type='wordsat'):

if model_type == 'wordsat':
model=WordSAT(
dataset = sample_dataset,
n_input_images = 2,
label_key = 'caption',
tokenizer = tokenizer,
encoder_pretrained_weights = encoder_weights,
encoder_freeze_weights = True,
save_generated_caption = True
)
else:
model=SentSAT(
dataset = sample_dataset,
n_input_images = 2,
label_key = 'caption',
tokenizer = tokenizer,
encoder_pretrained_weights = encoder_weights,
encoder_freeze_weights = True,
save_generated_caption = True
)

return model

# STEP 7: run trainer
def run_trainer(output_path, train_dataloader, val_dataloader, model,n_epochs):
trainer = Trainer(
model=model,
output_path=output_path
)

trainer.train(
train_dataloader = train_dataloader,
val_dataloader = val_dataloader,
optimizer_params = {"lr": 1e-4},
weight_decay = 1e-5,
epochs = n_epochs,
monitor = 'Bleu_1'
)
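    # 'Bleu_1' (BLEU-1 on the validation set) is presumably the metric used
    # to select the best checkpoint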
return trainer

# STEP 8: evaluate
def evaluate(trainer,test_dataloader):
print(trainer.evaluate(test_dataloader))
return None

if __name__ == '__main__':
args = get_args()
seed_everything(42)
base_dataset = load_data(args.root)
sample_dataset = set_task(base_dataset)
train_dataloader,val_dataloader,test_dataloader = get_dataloaders(
sample_dataset)
    tokenizer = get_tokenizer(args.root, sample_dataset, args.tokenizer_fname)
    if args.encoder_chkpt_fname is None:
        raise ValueError("an encoder checkpoint is required; "
                         "pass --encoder-chkpt-fname")
    encoder_weights = extract_encoder_state_dict(args.root,
                                                 args.encoder_chkpt_fname)

model = define_model(
sample_dataset,
tokenizer,
encoder_weights,
args.model_type)

trainer = run_trainer(
args.root,
train_dataloader,
val_dataloader,
model,
args.num_epochs)

print("\n===== Evaluating Test Data ======\n")
evaluate(trainer,test_dataloader)
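
# A hypothetical invocation of this script (--encoder-chkpt-fname is
# required; the filename below is a placeholder, not a file shipped with
# the repository):
#
#   python examples/xray_report_generation_sat.py \
#       --root . \
#       --encoder-chkpt-fname encoder_chkpt.pth \
#       --model-type sentsat \
#       --num-epochs 10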



4 changes: 3 additions & 1 deletion pyhealth/datasets/__init__.py
@@ -1,4 +1,5 @@
 from .base_ehr_dataset import BaseEHRDataset
+from .base_image_caption_dataset import BaseImageCaptionDataset
 from .base_signal_dataset import BaseSignalDataset
 from .eicu import eICUDataset
 from .mimic3 import MIMIC3Dataset
@@ -7,6 +8,7 @@
 from .sleepedf import SleepEDFDataset
 from .isruc import ISRUCDataset
 from .shhs import SHHSDataset
-from .sample_dataset import SampleBaseDataset, SampleSignalDataset, SampleEHRDataset
+from .sample_dataset import SampleBaseDataset, SampleSignalDataset, SampleEHRDataset,\
+    SampleImageCaptionDataset
 from .splitter import split_by_patient, split_by_visit
 from .utils import collate_fn_dict, get_dataloader, strptime
152 changes: 152 additions & 0 deletions pyhealth/datasets/base_image_caption_dataset.py
@@ -0,0 +1,152 @@
import logging
import os
from abc import ABC
from typing import Optional, Callable

import pandas as pd
from tqdm import tqdm

from pyhealth.datasets.sample_dataset import SampleImageCaptionDataset

logger = logging.getLogger(__name__)

INFO_MSG = """
dataset.patients:
- key: patient id
    - value: a list of records, each containing image path(s), a caption,
      and other information
"""


class BaseImageCaptionDataset(ABC):
"""Abstract base Image Caption Generation dataset class.

This abstract class defines a uniform interface for all
image caption generation datasets.

Each specific dataset will be a subclass of this abstract class, which can
then be converted to samples dataset for different tasks by calling
`self.set_task()`.

Args:
        root: root directory of the raw data (should contain the
            metadata.jsonl file and the associated images).
dataset_name: name of the dataset. Default is the name of the class.
dev: whether to enable dev mode (only use a small subset of the data).
Default is False.
refresh_cache: whether to refresh the cache; if true, the dataset will
be processed from scratch and the cache will be updated.
Default is False.
"""

def __init__(
self,
root: str,
dataset_name: Optional[str] = None,
dev: bool = False,
refresh_cache: bool = False,
):
# base attributes
self.dataset_name = (
self.__class__.__name__ if dataset_name is None else dataset_name
)
self.root = root
# TODO: dev seems unnecessary for image and signal?
self.dev = dev
if dev:
logger.warning("WARNING: dev has no effect \
for image caption generation datasets.")
# TODO: refresh_cache seems unnecessary for image and signal?
self.refresh_cache = refresh_cache
if refresh_cache:
logger.warning("WARNING: refresh_cache has no effect \
for image caption generation datasets.")

self.metadata = pd.read_json(os.path.join(root,
"metadata.jsonl"), lines=True)
if "patient_id" not in self.metadata.columns:
# no patient_id in metadata, sequentially assign patient_id
self.metadata["patient_id"] = self.metadata.index

# group by patient_id
self.patients = dict()
for patient_id, group in self.metadata.groupby("patient_id"):
self.patients[patient_id] = group.to_dict(orient="records")

return

def __len__(self):
return len(self.patients)

def __str__(self):
"""Prints some information of the dataset."""
return f"Base dataset {self.dataset_name}"

def stat(self) -> str:
"""Returns some statistics of the base dataset."""
lines = list()
lines.append("")
lines.append(f"Statistics of base dataset (dev={self.dev}):")
lines.append(f"\t- Dataset: {self.dataset_name}")
lines.append(f"\t- Number of images: {len(self)}")
lines.append("")
print("\n".join(lines))
return "\n".join(lines)

@staticmethod
def info():
"""Prints the output format."""
print(INFO_MSG)

def set_task(
self,
task_fn: Callable,
task_name: Optional[str] = None,
) -> SampleImageCaptionDataset:
"""Processes the base dataset to generate the task-specific
sample dataset.

This function should be called by the user after the base dataset is
initialized. It will iterate through all patients in the base dataset
and call `task_fn` which should be implemented by the specific task.

Args:
task_fn: a function that takes a single patient and returns a
list of samples (each sample is a dict with patient_id,
image_path_list, caption and other task-specific attributes
as key). The samples will be concatenated to form the
sample dataset.
task_name: the name of the task. If None, the name of the task
function will be used.

        Returns:
            sample_dataset: the task-specific sample dataset.

        Note:
            In `task_fn`, a patient may have one or multiple images
            associated with a caption. For example, a patient can have a
            single report for x-rays taken from different views, which may
            be combined into a single sample such as

                {'patient_id': 1,
                 'image_path_list': [frontal_img_path, lateral_img_path],
                 'caption': 'report_text'}

            Patients can also be excluded from the task dataset by
            returning an empty list.
"""
if task_name is None:
task_name = task_fn.__name__

# load from raw data
logger.debug(f"Processing {self.dataset_name} base dataset...")

samples = []
for patient_id, patient in tqdm(
self.patients.items(), desc=f"Generating samples for {task_name}"):
samples.extend(task_fn(patient))

sample_dataset = SampleImageCaptionDataset(
samples,
dataset_name=self.dataset_name,
task_name=task_name,
)
return sample_dataset
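
# To make the `task_fn` contract above concrete, a minimal illustrative
# sketch follows; field names other than `patient_id` and `caption` are
# assumptions about the metadata records, not part of this class.
def single_view_task_fn(patient):
    samples = []
    for record in patient:
        if not record.get("caption"):
            continue  # a patient is excluded by producing no samples
        samples.append({
            "patient_id": record["patient_id"],
            "image_path_list": [record["image_path"]],
            "caption": record["caption"],
        })
    return samples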