Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add graph model #226

Merged
merged 4 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 243 additions & 0 deletions examples/graph_torchvision_model.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "3cc9dcc6-9d43-47e9-a4d3-0ba29beb425e",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import os\n",
"PATH = '/home/namkyeong/PyHealth'\n",
"os.chdir(PATH)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2e5c4ec0-0887-48c7-8f97-1ccafb401f2d",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from pyhealth.sampler import NeighborSampler\n",
"from pyhealth.models import Graph_TorchvisionModel\n",
"\n",
"from torchvision import transforms\n",
"from pyhealth.datasets import COVID19CXRDataset"
]
},
{
"cell_type": "markdown",
"id": "a80901ce-1077-4272-a087-f728683bcdad",
"metadata": {},
"source": [
"## Load Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e1c54991-ed56-400f-a82e-3609ef0f953c",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from pyhealth.datasets import COVID19CXRDataset\n",
"\n",
"root = \"./data/COVID-19_Radiography_Dataset\"\n",
"base_dataset = COVID19CXRDataset(root)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "46475e1d-e31a-48f1-848d-457a9c2eee48",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"base_dataset.default_task"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ad7bc53d-66db-499d-bddd-97c3e6184ad7",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"sample_dataset = base_dataset.set_task()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b6caf164-9875-448a-938f-bae758b439d9",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from torchvision import transforms\n",
"\n",
"\n",
"transform = transforms.Compose([\n",
" transforms.Resize((224, 224)),\n",
" transforms.Grayscale(),\n",
" transforms.Normalize(mean=[0.5862785803043838], std=[0.27950088968644304])\n",
"])\n",
"\n",
"\n",
"def encode(sample):\n",
" sample[\"path\"] = transform(sample[\"path\"])\n",
" return sample\n",
"\n",
"\n",
"sample_dataset.set_transform(encode)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6c46515c-9ff3-4ff5-b942-cf43666d3a2d",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from pyhealth.datasets import split_by_sample\n",
"\n",
"# Get Index of train, valid, test set\n",
"train_index, val_index, test_index = split_by_sample(\n",
" dataset=sample_dataset,\n",
" ratios=[0.7, 0.1, 0.2],\n",
" get_index = True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a5c1e65c-ac81-421b-b265-7404fcab404c",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"model = Graph_TorchvisionModel(\n",
" dataset=sample_dataset,\n",
" feature_keys=[\"path\"],\n",
" label_key=\"label\",\n",
" mode=\"multiclass\",\n",
" model_name=\"vit_b_16\",\n",
" model_config={\"weights\": \"DEFAULT\"},\n",
" gnn_config={\"input_dim\": 256, \"hidden_dim\": 128},\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c1640ebc-9fcb-49e1-9f2c-e7113ca85c62",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Build graph\n",
"# Set random = True will build random graph data\n",
"graph = model.build_graph(sample_dataset, random = True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "567e6091-87c7-4b37-a78d-402bc7c96939",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Define Sampler as Dataloader\n",
"train_dataloader = NeighborSampler(sample_dataset, graph[\"edge_index\"], node_idx=train_index, sizes=[15, 10], batch_size=64, shuffle=True, num_workers=12)\n",
"\n",
"# We sample all edges connected to target node for validation and test (Sizes = [-1, -1])\n",
"valid_dataloader = NeighborSampler(sample_dataset, graph[\"edge_index\"], node_idx=val_index, sizes=[-1, -1], batch_size=64, shuffle=False, num_workers=12)\n",
"test_dataloader = NeighborSampler(sample_dataset, graph[\"edge_index\"], node_idx=test_index, sizes=[-1, -1], batch_size=64, shuffle=False, num_workers=12)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c54b30ec-e7fc-402d-aba5-7c361e95cd60",
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [],
"source": [
"from pyhealth.trainer import Trainer\n",
"\n",
"resnet_trainer = Trainer(model=model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e1d2e71f-2a33-4187-b145-2161172821a9",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"print(resnet_trainer.evaluate(test_dataloader))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ddf1c7c9-20f2-4d84-acb3-f24a5a99670e",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"resnet_trainer.train(\n",
" train_dataloader=train_dataloader,\n",
" val_dataloader=valid_dataloader,\n",
" epochs=1,\n",
" monitor=\"accuracy\"\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "pyhealth",
"language": "python",
"name": "pyhealth"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
3 changes: 2 additions & 1 deletion pyhealth/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@
from .tuab import TUABDataset
from .tuev import TUEVDataset
from .sample_dataset import SampleBaseDataset, SampleSignalDataset, SampleEHRDataset
from .splitter import split_by_patient, split_by_visit
from .splitter import split_by_patient, split_by_visit, split_by_sample
from .utils import collate_fn_dict, get_dataloader, strptime
from .covid19_cxr import COVID19CXRDataset
74 changes: 74 additions & 0 deletions pyhealth/datasets/base_dataset_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import logging
from abc import ABC, abstractmethod
from typing import Optional, Dict

from tqdm import tqdm

from pyhealth.datasets.sample_dataset_v2 import SampleDataset
from pyhealth.tasks.task_template import TaskTemplate

logger = logging.getLogger(__name__)


class BaseDataset(ABC):
"""Abstract base dataset class."""

def __init__(
self,
root: str,
dataset_name: Optional[str] = None,
**kwargs,
):
if dataset_name is None:
dataset_name = self.__class__.__name__
self.root = root
self.dataset_name = dataset_name
logger.debug(f"Processing {self.dataset_name} base dataset...")
self.patients = self.process()
# TODO: cache
return

def __str__(self):
return f"Base dataset {self.dataset_name}"

def __len__(self):
return len(self.patients)

@abstractmethod
def process(self) -> Dict:
raise NotImplementedError

@abstractmethod
def stat(self):
print(f"Statistics of {self.dataset_name}:")
return

@property
def default_task(self) -> Optional[TaskTemplate]:
return None

def set_task(self, task: Optional[TaskTemplate] = None) -> SampleDataset:
"""Processes the base dataset to generate the task-specific sample dataset.
"""
# TODO: cache?
if task is None:
# assert default tasks exist in attr
assert self.default_task is not None, "No default tasks found"
task = self.default_task

# load from raw data
logger.debug(f"Setting task for {self.dataset_name} base dataset...")

samples = []
for patient_id, patient in tqdm(
self.patients.items(), desc=f"Generating samples for {task.task_name}"
):
samples.extend(task(patient))
sample_dataset = SampleDataset(
samples,
input_schema=task.input_schema,
output_schema=task.output_schema,
dataset_name=self.dataset_name,
task_name=task,
)
return sample_dataset
Loading