diff --git a/tools/paddle/prune_paddle_model.py b/tools/paddle/prune_paddle_model.py index 8578b75ad..36cbbbe89 100755 --- a/tools/paddle/prune_paddle_model.py +++ b/tools/paddle/prune_paddle_model.py @@ -6,15 +6,13 @@ import os -def new_prepend_feed_ops(inference_program, - feed_target_names, - feed_holder_name='feed'): +def prepend_feed_ops(program, feed_target_names): if len(feed_target_names) == 0: return - global_block = inference_program.global_block() + global_block = program.global_block() feed_var = global_block.create_var( - name=feed_holder_name, + name='feed', type=core.VarDesc.VarType.FEED_MINIBATCH, persistable=True) @@ -33,13 +31,13 @@ def new_prepend_feed_ops(inference_program, attrs={'col': i}) -def append_fetch_ops(program, fetch_target_names, fetch_holder_name='fetch'): +def append_fetch_ops(program, fetch_target_names): """ In this palce, we will add the fetch op """ global_block = program.global_block() fetch_var = global_block.create_var( - name=fetch_holder_name, + name='fetch', type=core.VarDesc.VarType.FETCH_LIST, persistable=True) print("the len of fetch_target_names:%d" % (len(fetch_target_names))) @@ -51,16 +49,20 @@ def append_fetch_ops(program, fetch_target_names, fetch_holder_name='fetch'): attrs={'col': i}) -def insert_fetch(program, fetchs, fetch_holder_name="fetch"): +def insert_by_op_type(program, op_names, op_type): global_block = program.global_block() need_to_remove_op_index = list() for i, op in enumerate(global_block.ops): - if op.type == 'fetch': + if op.type == op_type: need_to_remove_op_index.append(i) for index in need_to_remove_op_index[::-1]: global_block._remove_op(index) program.desc.flush() - append_fetch_ops(program, fetchs, fetch_holder_name) + + if op_type == "feed": + prepend_feed_ops(program, op_names) + else: + append_fetch_ops(program, op_names) def parse_arguments(): @@ -73,6 +75,10 @@ def parse_arguments(): '--model_filename', required=True, help='The input model file name.') parser.add_argument( '--params_filename', required=True, help='The parameters file name.') + parser.add_argument( + '--input_names', + nargs='+', + help='The inputs of pruned model.') parser.add_argument( '--output_names', required=True, @@ -94,7 +100,6 @@ def parse_arguments(): sys.exit(-1) paddle.enable_static() - paddle.static.io.prepend_feed_ops = new_prepend_feed_ops print("Start to load paddle model...") exe = static.Executor(paddle.CPUPlace()) [program, feed_target_names, fetch_targets] = static.io.load_inference_model( @@ -102,9 +107,18 @@ def parse_arguments(): exe, model_filename=args.model_filename, params_filename=args.params_filename) - insert_fetch(program, args.output_names) - feed_vars = [program.global_block().var(name) for name in feed_target_names] - fetch_vars = [program.global_block().var(out_name) for out_name in args.output_names] + + if args.input_names is not None: + insert_by_op_type(program, args.input_names, 'feed') + feed_vars = [program.global_block().var(name) for name in args.input_names] + else: + feed_vars = [program.global_block().var(name) for name in feed_target_names] + + if args.output_names is not None: + insert_by_op_type(program, args.output_names, 'fetch') + fetch_vars = [program.global_block().var(out_name) for out_name in args.output_names] + else: + fetch_vars = [out_var for out_var in fetch_targets] model_name = args.model_filename.split(".")[0] path_prefix = os.path.join(args.save_dir, model_name)