cd3_inference.py
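"""Inference script for T-lymphocyte (CD3) detection on IHC whole-slide images.

Runs a RetinaNet detector (ResNet-18 encoder) patch-wise over every slide in a
directory, merges the per-patch detections with non-maximum suppression, and
writes one CSV per slide (columns: x1, y1, x2, y2, score, class) to a
`processing/` subfolder, optionally together with a PNG visualization.
"""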
from object_detection_fastai.helper.object_detection_helper import *
from object_detection_fastai.models.RetinaNet import RetinaNet
from torchvision.models.resnet import resnet18
from fastai.vision.learner import create_body
from data.cd3_dataset import CD3Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import pandas as pd
import openslide
import argparse
import torch
import cv2
import os

def visualize_detections(slide, detections, ds=1):
    """Draw detection boxes onto a thumbnail of the slide, downsampled by factor `ds`."""
    slide_image = np.array(slide.get_thumbnail((slide.dimensions[0]//ds, slide.dimensions[1]//ds)))[:, :, :3]
    slide_image = cv2.cvtColor(slide_image, cv2.COLOR_RGB2BGR)
    colors = {
        'IMMUNE CELL': (0, 255, 0),     # Green (BGR)
        'NON-TUMOR CELL': (255, 0, 0),  # Blue
        'TUMOR CELL': (0, 165, 255),    # Orange
    }
    for detection in detections:
        x1, y1, x2, y2, _, class_label = detection
        color = colors[class_label]
        cv2.rectangle(slide_image, (x1//ds, y1//ds), (x2//ds, y2//ds), color, 2)
    return slide_image

def load_model_and_anchors(checkpoint):
    """Build the RetinaNet detector and its anchor grid, then load the checkpoint weights."""
    print('Loading model')
    anchors = create_anchors(sizes=[(32, 32), (16, 16), (8, 8), (4, 4)], ratios=[0.5, 1, 2], scales=[0.5, 0.75, 1, 1.25, 1.5])
    encoder = create_body(resnet18(), pretrained=False, cut=-2)
    model = RetinaNet(encoder, n_classes=4, n_anchors=15, sizes=[32, 16, 8, 4], chs=128, final_bias=-4., n_conv=3)
    # Checkpoints are expected under ./ckpts/<checkpoint>.pth with the weights stored under the 'model' key.
    model.load_state_dict(torch.load(os.path.join(os.getcwd(), 'ckpts', f'{checkpoint}.pth'), map_location=torch.device('cpu'))['model'])
    return model, anchors

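# A minimal standalone-use sketch (assumes a checkpoint file at ./ckpts/all.pth,
# matching the --checkpoint choices defined in main() below):
#   model, anchors = load_model_and_anchors('all')
#   model = model.eval().to('cpu')
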
def inference(dataloader, model, anchors, device, patch_size, down_factor, detect_thresh=0.5, nms_thresh=0.4):
    """Run patch-wise detection over one slide and return detections as
    [x1, y1, x2, y2, score, class_name] lists in level-0 pixel coordinates."""
    classes = ['IMMUNE CELL', 'NON-TUMOR CELL', 'TUMOR CELL']
    class_pred_batch, bbox_pred_batch, x_coordinates, y_coordinates = [], [], [], []
    patch_counter = 0
    detections = []
    with torch.inference_mode():
        for patches, x, y in tqdm(dataloader):
            class_pred, bbox_pred, _ = model(patches.to(device))
            class_pred_batch.extend(class_pred.cpu())
            bbox_pred_batch.extend(bbox_pred.cpu())
            x_coordinates.extend(x)
            y_coordinates.extend(y)
            patch_counter += len(patches)
    print(f'Ran inference for {patch_counter} patches.')
    final_bbox_pred, final_scores, final_class_pred = [], [], []
    print('Postprocessing predictions.')
    for clas_pred, bbox_pred, x, y in zip(class_pred_batch, bbox_pred_batch, x_coordinates, y_coordinates):
        # Anchor matching; filter predictions with the detection threshold.
        bbox_pred, scores, preds = process_output(clas_pred.cpu(), bbox_pred.cpu(), anchors, detect_thresh=detect_thresh)
        if bbox_pred is not None:
            # Rescale detection boxes to pixel coordinates using the patch size.
            t_sz = torch.Tensor([patch_size, patch_size])[None].float()
            bbox_pred = rescale_box(bbox_pred, t_sz)
            # Shift boxes by the patch offset (in coordinates at the inference level).
            bbox_pred += torch.Tensor([y//down_factor, x//down_factor, 0, 0]).long()
            # Apply non-maximum suppression per patch.
            to_keep = nms(bbox_pred, scores, return_ids=True, nms_thresh=nms_thresh)
            bbox_pred, preds, scores = bbox_pred[to_keep].cpu(), preds[to_keep].cpu(), scores[to_keep].cpu()
            final_bbox_pred.extend(bbox_pred)
            final_class_pred.extend(preds)
            final_scores.extend(scores)
    if not final_bbox_pred:
        return detections
    # Global non-maximum suppression across patch borders.
    keep_global = nms(torch.Tensor(np.array(final_bbox_pred)), torch.Tensor(final_scores), return_ids=True, nms_thresh=nms_thresh)
    final_bbox_pred = torch.Tensor(np.array(final_bbox_pred))[keep_global]
    final_class_pred = torch.Tensor(np.array(final_class_pred))[keep_global]
    final_scores = torch.Tensor(np.array(final_scores))[keep_global]
    # Convert boxes from (y, x, h, w) at the inference level to (x1, y1, x2, y2) at level 0.
    for box, pred, score in zip(final_bbox_pred, final_class_pred, final_scores):
        y_box, x_box = box[:2]
        h, w = box[2:4]
        x1 = int(x_box) * down_factor
        y1 = int(y_box) * down_factor
        x2 = x1 + int(w) * down_factor
        y2 = y1 + int(h) * down_factor
        detections.append([int(x1), int(y1), int(x2), int(y2), float(score), classes[int(pred)]])
    return detections

def process(slide_dir, level, patch_size, checkpoint, detect_thresh, visualize):
    """Run inference on every slide in `slide_dir` and store results under `slide_dir`/processing."""
    device = torch.device('mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu')
    model, anchors = load_model_and_anchors(checkpoint)
    model = model.eval().to(device)
    os.makedirs(os.path.join(slide_dir, 'processing'), exist_ok=True)
    # Iterate through the slides in the folder.
    print(f'Running inference on all slides in folder {slide_dir} using a patch size of {patch_size}x{patch_size} at resolution level {level}')
    for file in [slide for slide in os.listdir(slide_dir) if not os.path.isdir(os.path.join(slide_dir, slide))]:
        print(f'Slide {file}')
        slide = openslide.open_slide(os.path.join(slide_dir, file))
        ds = CD3Dataset(slide, level, patch_size)
        dl = DataLoader(ds, num_workers=0, batch_size=8)
        detections = inference(dl, model, anchors, device, patch_size, slide.level_downsamples[level], detect_thresh)
        detection_df = pd.DataFrame(detections, columns=['x1', 'y1', 'x2', 'y2', 'score', 'class'])
        result_path = os.path.join(slide_dir, 'processing', f'{os.path.splitext(file)[0]}.csv')
        detection_df.to_csv(result_path, index=False)
        print('Stored results at', result_path)
        # Visualize results.
        if visualize:
            slide_image_with_detections = visualize_detections(slide, detections)
            cv2.imwrite(result_path.replace('.csv', '.png'), slide_image_with_detections)
            print('Stored visualization of predictions at', result_path.replace('.csv', '.png'))

def main():
    parser = argparse.ArgumentParser(description='Inference for T-lymphocyte detection on CD3-stained IHC samples')
    parser.add_argument('--slide_dir', type=str, required=True, help='Slide directory')
    parser.add_argument('--level', type=int, help='Resolution level (models were trained on level 0, i.e. 0.25 um/pixel)', default=0)
    parser.add_argument('--patch_size', type=int, help='Patch size (models were trained on 256 x 256 pixels)', default=256)
    parser.add_argument('--checkpoint', type=str, help='Model checkpoint', default='all', choices=['all', 'HNSCC', 'NSCLC', 'TNBC', 'GC'])
    parser.add_argument('--detect_thresh', type=float, help='Confidence threshold for detections. A lower threshold increases recall, a higher threshold increases precision.', default=0.5)
    parser.add_argument('--visualize', action=argparse.BooleanOptionalAction, help='Export detection results as PNG.', default=True)
    args = parser.parse_args()
    # Call the processing pipeline with the parsed arguments.
    process(slide_dir=args.slide_dir, level=args.level, patch_size=args.patch_size, checkpoint=args.checkpoint, detect_thresh=args.detect_thresh, visualize=args.visualize)


if __name__ == "__main__":
    main()
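
# Example usage (hypothetical paths; checkpoints must be present under ./ckpts):
#   python cd3_inference.py --slide_dir /path/to/slides --checkpoint all --detect_thresh 0.5
#   python cd3_inference.py --slide_dir /path/to/slides --no-visualize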