-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathIVUS_prediction.py
72 lines (63 loc) · 2.38 KB
/
IVUS_prediction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from PyQt5.QtWidgets import QProgressDialog, QMessageBox
from PyQt5.QtCore import Qt
import numpy as np
import os
import tensorflow as tf
# Intensity mean subtracted from every image by cast_and_center(). It is a
# single value, which broadcasts against the image's trailing axis; inputs are
# presumably single-channel greyscale frames — TODO confirm.
IMG_MEAN = tf.constant([60.3486], dtype=tf.float32)
# NOTE(review): not referenced in the visible code; presumably the number of
# segmentation classes / phenotype labels produced by the model — confirm.
num_classes = 4
num_phenotypes = 5

# Resolve the SavedModel directory relative to this file so loading does not
# depend on the process's current working directory.
model_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'model')
try:
    model = tf.saved_model.load(model_path)
except Exception:
    # Best effort: keep the module importable even without weights. `model` is
    # set to None (instead of being left undefined) so importers can detect
    # the missing model explicitly.
    model = None
    warning = ("Warning: No saved weights have been found, segmentation will be "
               "unsuccessful, check that weights are saved in {}".format(model_path))
    print(warning)
def cast_and_center(image):
    """Return *image* converted to float32 with IMG_MEAN subtracted.

    Used as the per-element map function of the tf.data pipeline in
    predict(); the subtraction broadcasts IMG_MEAN against the image.
    """
    as_float = tf.cast(image, dtype=tf.float32)
    return as_float - IMG_MEAN
def set_input_channels(images, channels=3):
    """Return *images* as a rank-4 batch whose last axis has ``channels`` entries.

    A batch of rank < 4 first gets a trailing channel axis (axis 3). If the
    channel axis then differs from ``channels``, it is repeated ``channels``
    times with ``tf.tile``. This is meant for single-channel frames fed to a
    model that expects ``channels`` inputs.

    Args:
        images: batch tensor; presumably (batch, H, W) or (batch, H, W, C) —
            TODO confirm against callers.
        channels: number of channels the model expects.

    Returns:
        The possibly expanded and tiled tensor.

    NOTE(review): tiling is only meaningful when the channel axis has size 1;
    a C-channel input with C != channels would end up with C * channels.
    """
    if len(images.get_shape()) < 4:
        images = tf.expand_dims(images, axis=3)
    # Re-read the shape after expand_dims. Otherwise, for rank-3 input, the
    # check would compare the image width instead of the new channel axis.
    if images.get_shape()[-1] != channels:
        images = tf.tile(images, [1, 1, 1, channels])
    return images
def predict(images, batch_size=64):
    """Runs Convolutional Neural Network to predict image pixel class.

    Each image is cast to float32 and mean-centred (cast_and_center), batched,
    expanded/tiled to the model's channel count (set_input_channels), and run
    through the module-level ``model``. The logits are resized to the batch's
    spatial size, and the per-pixel argmax is taken over the last axis.

    Args:
        images: array indexed by image along axis 0 (``images.shape[0]`` is
            the image count); presumably (N, H, W) greyscale frames — TODO
            confirm against callers.
        batch_size: number of images fed to the model per forward pass.

    Returns:
        numpy int32 array of per-pixel class indices, one (H, W) map per
        input image, stacked along axis 0.
    """
    num_images = images.shape[0]
    if num_images == 0:
        # np.concatenate([]) raises ValueError, so return an empty stack of
        # label maps directly without touching the model.
        return np.zeros((0, *images.shape[1:3]), dtype=np.int32)

    dataset = (tf.data.Dataset.from_tensor_slices(images)
               .map(cast_and_center)
               .batch(batch_size))
    num_batches = int(np.ceil(num_images / batch_size))

    pred = []
    for i, batch in enumerate(dataset):
        batch = set_input_channels(batch)
        logits = model(batch, training=False)
        # Resize the network output to the input's height/width so that the
        # argmax yields one label per input pixel.
        logits = tf.image.resize(logits, (tf.shape(batch)[1], tf.shape(batch)[2]))
        pred.append(tf.argmax(logits, axis=-1, output_type=tf.dtypes.int32))
        print('Batch {} of {} completed'.format(i + 1, num_batches))
    return np.concatenate(pred)