-
Notifications
You must be signed in to change notification settings - Fork 6
/
app.py
95 lines (77 loc) · 2.57 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os
import gdown
import gradio as gr
import tensorflow as tf
from config import Parameters
from models.hybrid_model import GradientAccumulation
from utils.model_utils import *
from utils.viz_utils import make_gradcam_heatmap
from utils.viz_utils import save_and_display_gradcam
# Side length that input images are resized to before being fed to the model.
image_size = Parameters().image_size

# Class names; predict_fn indexes this list with the argmax of the predictions.
str_labels = ["daisy", "dandelion", "roses", "sunflowers", "tulips"]
def get_model():
    """Instantiate the hybrid model and run it once on a dummy batch.

    Returns:
        The ``GradientAccumulation`` instance after it has been called on an
        all-ones tensor of shape ``(1, image_size, image_size, 3)``.
    """
    # NOTE(review): `params` is not defined in this file; presumably it is
    # provided by the wildcard import from utils.model_utils -- confirm.
    hybrid = GradientAccumulation(
        n_gradients=params.num_grad_accumulation,
        model_name="HybridModel",
    )

    # One forward pass on dummy data; presumably this is what creates the
    # model's variables so that weights can be loaded afterwards -- confirm.
    dummy_batch = tf.ones((1, params.image_size, params.image_size, 3))
    outputs = hybrid(dummy_batch)
    _ = outputs[0].shape
    return hybrid
def get_model_weight(model_id):
    """Return the local path of the trained weights, downloading them if needed.

    Args:
        model_id: File id handed to ``gdown.download`` when ``model.h5`` is
            not already present in the current working directory.

    Returns:
        The string ``"model.h5"``.

    Raises:
        RuntimeError: If the download did not produce a file.
    """
    weight_path = "model.h5"
    if os.path.exists(weight_path):
        return weight_path

    # Save under the exact name the existence check above looks for, so a
    # finished download is reused next time instead of being fetched again
    # under whatever file name the remote side reports.
    downloaded = gdown.download(id=model_id, output=weight_path, quiet=False)
    if downloaded is None:
        # gdown signals a failed download by returning None; fail here with a
        # clear message rather than passing None on to the weight loader.
        raise RuntimeError(
            f"Could not download model weights for id {model_id!r}"
        )
    return weight_path
# Models that have already been built and had their weights loaded, keyed by
# the model id they were requested with.
_loaded_models = {}


def load_model(model_id):
    """Return the trained model for ``model_id``, building it at most once.

    The first call for a given id fetches the weight file, constructs the
    model and loads the weights into it. Later calls with the same id return
    the same model object, so the model is not rebuilt and the weights are not
    re-read for every prediction request.

    Args:
        model_id: Identifier forwarded to ``get_model_weight``.

    Returns:
        The model with the trained weights loaded.
    """
    if model_id not in _loaded_models:
        weight = get_model_weight(model_id)
        model = get_model()
        model.load_weights(weight)
        # Only remember the model once loading has succeeded, so a failed
        # attempt is retried on the next call.
        _loaded_models[model_id] = model
    return _loaded_models[model_id]
def image_process(image):
    """Turn a raw image into a float32, resized, single-item batch.

    Args:
        image: The image to prepare; it is cast to ``tf.float32`` first.

    Returns:
        A tuple ``(batch, original_shape)`` where ``batch`` is the image
        resized to ``image_size`` x ``image_size`` with a new leading axis,
        and ``original_shape`` is the shape of the image before resizing.
    """
    as_float = tf.cast(image, dtype=tf.float32)
    # Remember the pre-resize shape so callers can map results back to it.
    original_shape = as_float.shape
    resized = tf.image.resize(as_float, [image_size, image_size])
    batch = tf.expand_dims(resized, axis=0)
    return batch, original_shape
def predict_fn(image):
    """Classify an image and build both GradCAM overlays for the gradio UI.

    Args:
        image: The input image handed over by gradio.

    Returns:
        A three-item list: the ``"Predicted: <label>"`` text, followed by the
        overlays built from the first and second heatmaps returned by
        ``make_gradcam_heatmap``, in that order.
    """
    model = load_model(model_id="1y6tseN0194T6d-4iIh5wo7RL9ttQERe0")
    batch, source_shape = image_process(image)
    heatmap_a, heatmap_b, preds = make_gradcam_heatmap(batch, model)

    # Highest-scoring class of the first (and only) item in the batch.
    class_index = tf.argmax(preds, axis=-1).numpy()[0]
    class_name = str_labels[class_index]

    # Render each heatmap on the model input at the original image's
    # height/width (first two entries of its shape).
    overlays = [
        save_and_display_gradcam(
            batch[0], heatmap, image_shape=source_shape[:2]
        )
        for heatmap in (heatmap_a, heatmap_b)
    ]
    return [f"Predicted: {class_name}", *overlays]
# Single image input shown to the user.
_input_component = gr.inputs.Image(label="Input Image")

# One text label plus two image panes, matching the three items that
# predict_fn returns.
# NOTE(review): the two image panes are created from gr.inputs.Image even
# though they are used as outputs -- confirm this is intended for the gradio
# version in use.
_output_components = [
    gr.outputs.Label(label="Prediction"),
    gr.inputs.Image(label="CNN GradCAM"),
    gr.inputs.Image(label="Transformer GradCAM"),
]

# Example inputs offered in the UI; each row holds the single image argument.
_example_rows = [
    ["examples/dandelion.jpg"],
    ["examples/sunflower.jpg"],
    ["examples/tulip.jpg"],
    ["examples/daisy.jpg"],
    ["examples/rose.jpg"],
]

iface = gr.Interface(
    fn=predict_fn,
    inputs=_input_component,
    outputs=_output_components,
    title="Hybrid EfficientNet Swin Transformer Demo",
    description="The model is trained on tf_flowers dataset.",
    examples=_example_rows,
)

# Start serving the demo as soon as the script is executed.
iface.launch()