-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconvert_keras_to_pb.py
102 lines (75 loc) · 2.65 KB
/
convert_keras_to_pb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# References on saving and loading TensorFlow/Keras models:
#   https://williamjshipman.wordpress.com/2019/06/23/saving-and-loading-tensorflow-neural-networks-part-1-everythings-deprecated-what-now/
#   https://www.tensorflow.org/guide/keras/save_and_serialize
import tensorflow as tf
import argparse
import shutil
import os
from os.path import expanduser, join
# Convert a Keras HDF5 model into a TensorFlow SavedModel directory.
#
# Command-line options:
#   --modelPath    directory under which the export is written (default ./export)
#   --version      name of the sub-directory created for this export (default 1)
#   --sourceModel  Keras .h5 file to load (default: the bundled NSFW MobileNet)
#
# The model is written to <modelPath>/<version>.

# Parse input parameters
parser = argparse.ArgumentParser(description='MobileNet Keras Model')
parser.add_argument('--modelPath', type=str, dest='MODEL_DIR', default="./export", help='location to store the model artifacts')
parser.add_argument('--version', type=str, dest='VERSION', default="1", help='model version')
# The input path used to be hard-coded; the default keeps the old behaviour.
parser.add_argument('--sourceModel', type=str, dest='SOURCE_MODEL', default='./models/nsfw_mobilenet2.224x224.h5', help='Keras .h5 model file to convert')
args = parser.parse_args()

MODEL_DIR = args.MODEL_DIR
VERSION = args.VERSION
SOURCE_MODEL = args.SOURCE_MODEL

#tf.logging.set_verbosity(tf.logging.ERROR)
#print(tf.version)
#print(tf.keras.__version__)

# Names that may appear in the saved model's config and must be resolvable
# while it is being deserialised.
# NOTE(review): 'relu6' is mapped to the ReLU layer *class*, not to a relu6
# activation function (e.g. tf.nn.relu6). Kept unchanged here; confirm this
# matches what the .h5 config actually references.
custom_objects = {
    'relu6': tf.compat.v2.keras.layers.ReLU,
    'DepthwiseConv2D': tf.keras.layers.DepthwiseConv2D,
}
with tf.keras.utils.CustomObjectScope(custom_objects):
    nsfw_model = tf.keras.models.load_model(SOURCE_MODEL)
#nsfw_model.summary()

# exist_ok=True avoids the check-then-create race of a separate
# os.path.exists() test and is a no-op when the directory is already there.
os.makedirs(MODEL_DIR, exist_ok=True)

export_path = os.path.join(MODEL_DIR, VERSION)
print(export_path)
tf.saved_model.save(nsfw_model, export_path)
# --- Everything below is commented-out code and is not executed. ---
# tf.keras.models.save_model(
# model,
# filepath,
# overwrite=True,
# include_optimizer=True,
# save_format=None,
# signatures=None,
# options=None
# )
# #Save model
# export_path = os.path.join(MODEL_DIR, VERSION)
# print('export_path = {}\n'.format(export_path))
#
# tf.saved_model.save(
# tf.keras.backend.get_session(),
# export_path,
# inputs={'input_image': nsfw_model.input},
# outputs={t.name:t for t in nsfw_model.outputs})
#
# print('\nModel saved to ' + MODEL_DIR)
# else:
# print('\nExisting model found at ' + MODEL_DIR)
# print('\nDid not overwrite old model. Run the job again with a different location to store the model')
#
# print(os.listdir('./storage'))
#
# print(os.listdir(MODEL_DIR))
# shutil.rmtree('./storage/model3')
#
# shutil.rmtree(MODEL_DIR)
# print(os.listdir('./storage'))
#Save model
# if not os.path.exists(MODEL_DIR):
# os.makedirs(MODEL_DIR)
# export_path = os.path.join(MODEL_DIR, VERSION) + '.pb'
# print('export_path = {}\n'.format(export_path))
#
# shutil.copy('./storage/quant_nsfw_mobilenet.pb', export_path)
#
#
# print('\nModel saved to ' + MODEL_DIR)
# else:
# print('\nExisting model found at ' + MODEL_DIR)
# print('\nDid not overwrite old model. Run the job again with a different location to store the model')
#
#
# print(os.listdir(MODEL_DIR))
#/storage/quant_nsfw_mobilenet.pb