Skip to content

Commit

Permalink
Merge pull request #208 from tianxin1860/develop
Browse files Browse the repository at this point in the history
add classify inference using infer_program
  • Loading branch information
Yibing Liu authored Jul 17, 2019
2 parents 39d4571 + 2f2fe7a commit 85cf2ee
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 48 deletions.
2 changes: 1 addition & 1 deletion ERNIE/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -316,4 +316,4 @@ python -u predict_classifier.py \

实际使用时,需要通过 `init_checkpoint` 指定预测用的模型,通过 `predict_set` 指定待预测的数据文件,通过 `num_labels` 配置分类的类别数目;

**Note**: predict_set 的数据格式与 dev_set 和 test_set 的数据格式完全一致,是由 text_a、text_b(可选) 、label 组成的2列/3列 tsv 文件,predict_set 中的 label 列起到占位符的作用,全部置 0 即可;
**Note**: predict_set 的数据格式是由 text_a、text_b(可选) 组成的1列/2列 tsv 文件;
2 changes: 1 addition & 1 deletion ERNIE/finetune/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def create_model(args, pyreader_name, ernie_config, is_prediction=False):
if is_prediction:
probs = fluid.layers.softmax(logits)
feed_targets_name = [
src_ids.name, pos_ids.name, sent_ids.name, input_mask.name
src_ids.name, sent_ids.name, pos_ids.name, input_mask.name
]
return pyreader, probs, feed_targets_name

Expand Down
59 changes: 39 additions & 20 deletions ERNIE/predict_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
model_g = ArgumentGroup(parser, "model", "options to init, resume and save model.")
model_g.add_arg("ernie_config_path", str, None, "Path to the json file for bert model config.")
model_g.add_arg("init_checkpoint", str, None, "Init checkpoint to resume training from.")
model_g.add_arg("save_inference_model_path", str, "inference_model", "If set, save the inference model to this path.")
model_g.add_arg("use_fp16", bool, False, "Whether to resume parameters from fp16 checkpoint.")
model_g.add_arg("num_labels", int, 2, "num labels for classify")

Expand Down Expand Up @@ -65,7 +66,8 @@ def main(args):
label_map_config=args.label_map_config,
max_seq_len=args.max_seq_len,
do_lower_case=args.do_lower_case,
in_tokens=False)
in_tokens=False,
is_inference=True)

predict_prog = fluid.Program()
predict_startup = fluid.Program()
Expand Down Expand Up @@ -95,33 +97,50 @@ def main(args):
else:
raise ValueError("args 'init_checkpoint' should be set for prediction!")

predict_exe = fluid.Executor(place)
assert args.save_inference_model_path, "args save_inference_model_path should be set for prediction"
_, ckpt_dir = os.path.split(args.init_checkpoint.rstrip('/'))
dir_name = ckpt_dir + '_inference_model'
model_path = os.path.join(args.save_inference_model_path, dir_name)
print("save inference model to %s" % model_path)
fluid.io.save_inference_model(
model_path,
feed_target_names, [probs],
exe,
main_program=predict_prog)

print("load inference model from %s" % model_path)
infer_program, feed_target_names, probs = fluid.io.load_inference_model(
model_path, exe)

src_ids = feed_target_names[0]
sent_ids = feed_target_names[1]
pos_ids = feed_target_names[2]
input_mask = feed_target_names[3]

predict_data_generator = reader.data_generator(
input_file=args.predict_set,
batch_size=args.batch_size,
epoch=1,
shuffle=False)

predict_pyreader.decorate_tensor_provider(predict_data_generator)

predict_pyreader.start()
all_results = []
time_begin = time.time()
while True:
try:
results = predict_exe.run(program=predict_prog, fetch_list=[probs.name])
all_results.extend(results[0])
except fluid.core.EOFException:
predict_pyreader.reset()
break
time_end = time.time()

np.set_printoptions(precision=4, suppress=True)
print("-------------- prediction results --------------")
for index, result in enumerate(all_results):
print(str(index) + '\t{}'.format(result))

np.set_printoptions(precision=4, suppress=True)
index = 0
for sample in predict_data_generator():
src_ids_data = sample[0]
sent_ids_data = sample[1]
pos_ids_data = sample[2]
input_mask_data = sample[3]
output = exe.run(
infer_program,
feed={src_ids: src_ids_data,
sent_ids: sent_ids_data,
pos_ids: pos_ids_data,
input_mask: input_mask_data},
fetch_list=probs)
for single_result in output[0]:
print("example_index:{}\t{}".format(index, single_result))
index += 1

if __name__ == '__main__':
print_arguments(args)
Expand Down
68 changes: 42 additions & 26 deletions ERNIE/reader/task_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(self,
max_seq_len=512,
do_lower_case=True,
in_tokens=False,
is_inference=False,
random_seed=None):
self.max_seq_len = max_seq_len
self.tokenizer = tokenization.FullTokenizer(
Expand All @@ -37,6 +38,7 @@ def __init__(self,
self.cls_id = self.vocab["[CLS]"]
self.sep_id = self.vocab["[SEP]"]
self.in_tokens = in_tokens
self.is_inference = is_inference

np.random.seed(random_seed)

Expand Down Expand Up @@ -141,25 +143,33 @@ def _convert_example_to_record(self, example, max_seq_length, tokenizer):
token_ids = tokenizer.convert_tokens_to_ids(tokens)
position_ids = list(range(len(token_ids)))

if self.label_map:
label_id = self.label_map[example.label]
if self.is_inference:
Record = namedtuple('Record',
['token_ids', 'text_type_ids', 'position_ids'])
record = Record(
token_ids=token_ids,
text_type_ids=text_type_ids,
position_ids=position_ids)
else:
label_id = example.label

Record = namedtuple(
'Record',
['token_ids', 'text_type_ids', 'position_ids', 'label_id', 'qid'])

qid = None
if "qid" in example._fields:
qid = example.qid

record = Record(
token_ids=token_ids,
text_type_ids=text_type_ids,
position_ids=position_ids,
label_id=label_id,
qid=qid)
if self.label_map:
label_id = self.label_map[example.label]
else:
label_id = example.label

Record = namedtuple('Record', [
'token_ids', 'text_type_ids', 'position_ids', 'label_id', 'qid'
])

qid = None
if "qid" in example._fields:
qid = example.qid

record = Record(
token_ids=token_ids,
text_type_ids=text_type_ids,
position_ids=position_ids,
label_id=label_id,
qid=qid)
return record

def _prepare_batch_data(self, examples, batch_size, phase=None):
Expand Down Expand Up @@ -235,14 +245,18 @@ def _pad_batch_records(self, batch_records):
batch_token_ids = [record.token_ids for record in batch_records]
batch_text_type_ids = [record.text_type_ids for record in batch_records]
batch_position_ids = [record.position_ids for record in batch_records]
batch_labels = [record.label_id for record in batch_records]
batch_labels = np.array(batch_labels).astype("int64").reshape([-1, 1])

if batch_records[0].qid is not None:
batch_qids = [record.qid for record in batch_records]
batch_qids = np.array(batch_qids).astype("int64").reshape([-1, 1])
else:
batch_qids = np.array([]).astype("int64").reshape([-1, 1])
if not self.is_inference:
batch_labels = [record.label_id for record in batch_records]
batch_labels = np.array(batch_labels).astype("int64").reshape(
[-1, 1])

if batch_records[0].qid is not None:
batch_qids = [record.qid for record in batch_records]
batch_qids = np.array(batch_qids).astype("int64").reshape(
[-1, 1])
else:
batch_qids = np.array([]).astype("int64").reshape([-1, 1])

# padding
padded_token_ids, input_mask = pad_batch_data(
Expand All @@ -254,8 +268,10 @@ def _pad_batch_records(self, batch_records):

return_list = [
padded_token_ids, padded_text_type_ids, padded_position_ids,
input_mask, batch_labels, batch_qids
input_mask
]
if not self.is_inference:
return_list += [batch_labels, batch_qids]

return return_list

Expand Down

0 comments on commit 85cf2ee

Please sign in to comment.