Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
0x3878f committed Jan 22, 2025
1 parent 6f88ae0 commit b971758
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 11 deletions.
21 changes: 12 additions & 9 deletions paddle2onnx/mapper/exporter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,9 @@ int32_t ModelExporter::GetCfBlockMinOpsetVersion(
pir_parser.sub_blocks_ops.push_back(op);
}
}
// Must generate All sub_block's op output names must be generated here
// because it's may used in OPMapper.GetMinOpsetVersion function.
pir_parser.GetAllSubBlockOpOutputName(pir_parser.sub_blocks_ops);
auto max_opset = GetMinOpsetVersion(pir_parser, &block, true);
pir_parser.sub_blocks_ops.clear();
pir_parser.sub_blocks_ops = sub_blocks_ops_copy;
Expand All @@ -223,18 +226,17 @@ int32_t ModelExporter::GetMinOpsetVersion(const PaddlePirParser& pir_parser,
std::set<std::string> verbose_log;
OnnxHelper helper;
std::vector<pir::Operation*> block_ops;
for (auto& op :
block->ops()) { // it's necessary to be same with global/sub_blocks_ops
if (op->name() == "builtin.parameter") {
continue;
// it's necessary to be same with global/sub_blocks_ops
for (auto& op : block->ops()) {
if (op->name() != "builtin.parameter") {
block_ops.push_back(op);
}
block_ops.push_back(op);
}
for (auto i = 0; i < block_ops.size(); ++i) {
auto op = block_ops[i];
std::string op_name = op->name();
if (op_name == "pd_op.data" || op_name == "pd_op.fetch" ||
op_name == "pd_op.yeild") {
op_name == "cf.yield") {
continue;
}
int current_opset = 7;
Expand Down Expand Up @@ -470,9 +472,10 @@ ONNX_NAMESPACE::GraphProto ModelExporter::ExportIfBlock(
pir_parser.sub_blocks_ops.push_back(op);
}
}
pir_parser.GetAllSubBlockOpOutputName(pir_parser.sub_blocks_ops);
// generate sub-block op outputs names in GetMinOpSetVersion() function.
// pir_parser.GetAllSubBlockOpOutputName(pir_parser.sub_blocks_ops);
if (!pir_parser.sub_blocks_ops.empty()) {
// get cf.yeild op input
// get cf.yield op input
pir::Operation* cf_yield_op = pir_parser.sub_blocks_ops.back();
// std::vector<std::string> sub_block_outpus;
for (int32_t idx = 0; idx < cf_yield_op->num_operands(); ++idx) {
Expand Down Expand Up @@ -590,7 +593,7 @@ ONNX_NAMESPACE::GraphProto ModelExporter::ExportBlock(
for (int32_t idx = 0; idx < outputs.size(); ++idx) {
auto output_item = outputs[idx];
if (output_item->name() == input_item->name()) {
output_item->set_name(pir_parser.GenOpInputOutputName("yeild"));
output_item->set_name(pir_parser.GenOpInputOutputName("yield"));
temp_helper.MakeNode(
"Identity", {input_item->name()}, {output_item->name()});
outputs[idx] = std::move(output_item);
Expand Down
2 changes: 1 addition & 1 deletion paddle2onnx/mapper/nn/interpolate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ void InterpolateMapper::Opset11() {
}
std::shared_ptr<ONNX_NAMESPACE::NodeProto> node;
if (size != "") {
node = helper_->MakeNode("Resize", {x_info[0].name, roi, "", size},
node = helper_->MakeNode("Resize", {x_info[0].name, roi, scale, size},
{out_info[0].name});
} else {
node = helper_->MakeNode("Resize", {x_info[0].name, roi, scale},
Expand Down
3 changes: 2 additions & 1 deletion paddle2onnx/mapper/while.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ void ModelExporter::ExportWhile(PaddlePirParser& pir_parser,
}
}

pir_parser.GetAllSubBlockOpOutputName(pir_parser.sub_blocks_ops);
// generate sub-block op outputs names in GetMinOpSetVersion() function.
// pir_parser.GetAllSubBlockOpOutputName(pir_parser.sub_blocks_ops);
if (!pir_parser.sub_blocks_ops.empty()) {
// get cf.yeild op input
pir::Operation* cf_yield_op = pir_parser.sub_blocks_ops.back();
Expand Down

0 comments on commit b971758

Please sign in to comment.