forked from rockyzhengwu/FoolNLTK
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
79 lines (60 loc) · 2.55 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#!/usr/bin/env python
#-*-coding:utf-8-*-
import tensorflow as tf
import pickle
import numpy as np
from tensorflow.contrib.crf import viterbi_decode
def decode(logits, trans, sequence_lengths, tag_num):
    """Run Viterbi decoding over a batch of per-step tag scores.

    Every sequence is cut to its true length, widened by one extra score
    column, and prefixed with a synthetic start row before being handed to
    ``viterbi_decode`` together with ``trans``.

    Args:
        logits: Iterable of 2-D score arrays, one per sequence
            (presumably ``[max_len, tag_num]`` each -- confirm with caller).
        trans: Transition matrix forwarded unchanged to ``viterbi_decode``
            (presumably ``[tag_num + 1, tag_num + 1]`` to match the extra
            column added here -- confirm against the trained model).
        sequence_lengths: True length of each sequence, aligned with
            ``logits``.
        tag_num: Number of score columns in each entry of ``logits``.

    Returns:
        A list with one decoded index sequence per input sequence, with the
        entry for the synthetic start row removed.
    """
    small = -1000.0
    # Start row: every real tag scores `small`; only the extra column scores 0.
    start_row = np.asarray([[small] * tag_num + [0]])

    decoded = []
    for sent_scores, sent_len in zip(logits, sequence_lengths):
        # Drop the padded time steps beyond the real sequence length.
        emissions = sent_scores[:sent_len]
        # Extra column scored `small` on every real time step.
        extra_col = np.full([sent_len, 1], small)
        widened = np.concatenate([emissions, extra_col], axis=1)
        with_start = np.concatenate([start_row, widened], axis=0)
        best_path, _ = viterbi_decode(with_start, trans)
        # Position 0 belongs to the synthetic start row, not to the input.
        decoded.append(best_path[1:])
    return decoded
def load_map(path):
    """Load the vocabulary maps stored in a pickle file.

    The file must contain a pickled 3-item sequence
    ``(char_to_id, tag_to_id, id_to_tag)``; the middle item is discarded.

    Args:
        path: Path of the pickle file to read.

    Returns:
        A ``(char_to_id, id_to_tag)`` tuple.
    """
    # NOTE: unpickling can execute arbitrary code -- only load map files
    # that come from a trusted source.
    with open(path, 'rb') as map_file:
        maps = pickle.load(map_file)
    char_to_id, _tag_to_id, id_to_tag = maps
    return char_to_id, id_to_tag
def load_graph(path):
    """Read a serialized ``GraphDef`` from disk and import it into a new graph.

    All imported nodes are placed under the name scope ``"prefix"``, so
    tensors must be looked up as ``"prefix/<original_name>"``.

    Args:
        path: Path of the serialized graph definition file.

    Returns:
        A new ``tf.Graph`` containing the imported definition.
    """
    with tf.gfile.GFile(path, mode='rb') as model_file:
        serialized = model_file.read()

    graph_def = tf.GraphDef()
    graph_def.ParseFromString(serialized)

    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, name="prefix")
    return graph
class Model(object):
    """Sequence tagger that runs an imported TensorFlow graph and Viterbi-decodes
    its output scores.

    Attributes are populated in ``__init__``:
        char_to_id / id_to_tag: lookup maps returned by ``load_map``.
        graph / sess: the imported graph and the session bound to it.
        input_x, lengths, dropout: input tensors fed on every prediction.
        logits, trans: output tensors fetched on every prediction.
        num_class: number of entries in ``id_to_tag``.
    """

    def __init__(self, map_file, model_file):
        """Load the lookup maps and the graph, then open a session on it.

        Args:
            map_file: Path handed to ``load_map``.
            model_file: Path handed to ``load_graph``.
        """
        self.char_to_id, self.id_to_tag = load_map(map_file)
        self.graph = load_graph(model_file)

        # Tensors are addressed under the "prefix/" scope used at import time.
        get_tensor = self.graph.get_tensor_by_name
        self.input_x = get_tensor("prefix/char_inputs:0")
        self.lengths = get_tensor("prefix/lengths:0")
        self.dropout = get_tensor("prefix/dropout:0")
        self.logits = get_tensor("prefix/project/logits:0")
        self.trans = get_tensor("prefix/crf_loss/transitions:0")

        # Every run goes through this session explicitly, so it does not
        # need to be installed as the default session.
        self.sess = tf.Session(graph=self.graph)
        self.num_class = len(self.id_to_tag)

    def close(self):
        """Close the underlying TensorFlow session."""
        self.sess.close()

    def predict(self, sents):
        """Predict one tag per element for every sentence in ``sents``.

        Args:
            sents: Iterable of sentences. Each sentence must support
                ``len()`` and iteration over items used as keys of
                ``char_to_id`` (presumably the characters of a string --
                confirm with callers).

        Returns:
            A list with one list of tags per sentence. Tags are looked up in
            ``id_to_tag`` and are ``None`` for ids missing from that map.
            An empty input yields an empty list.
        """
        # Materialize once: the input is traversed twice below, which would
        # silently exhaust a generator.
        sents = list(sents)
        if not sents:
            # Nothing to tag; also avoids max() failing on an empty list.
            return []

        lengths = [len(sent) for sent in sents]
        max_len = max(lengths)

        # Items missing from the vocabulary fall back to the "<OOV>" id.
        oov_id = self.char_to_id.get("<OOV>")
        inputs = []
        for sent in sents:
            sent_ids = [self.char_to_id.get(ch, oov_id) for ch in sent]
            # Right-pad with id 0 so every row has max_len entries.
            sent_ids += [0] * (max_len - len(sent_ids))
            inputs.append(sent_ids)
        inputs = np.array(inputs, dtype=np.int32)

        feed_dict = {
            self.input_x: inputs,
            self.lengths: lengths,
            # NOTE(review): presumably a keep-probability, so 1.0 disables
            # dropout at inference time -- confirm against the training graph.
            self.dropout: 1.0
        }
        logits, trans = self.sess.run([self.logits, self.trans], feed_dict=feed_dict)

        paths = decode(logits, trans, lengths, self.num_class)
        return [[self.id_to_tag.get(tag_id) for tag_id in path] for path in paths]