Skip to content

Commit

Permalink
Merge pull request #1484 from 0x3878f/pir_develop_14
Browse files Browse the repository at this point in the history
Fix issues with inference of onnx models converted from paddlex
  • Loading branch information
risemeup1 authored Jan 23, 2025
2 parents df6d78a + b971758 commit 040c051
Show file tree
Hide file tree
Showing 11 changed files with 86 additions and 31 deletions.
2 changes: 1 addition & 1 deletion paddle2onnx/mapper/detection/multiclass_nms.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ void NMSMapper::KeepTopK(const std::string& selected_indices) {
AddAttribute(ensemble_value, "axis", int64_t(0));

std::shared_ptr<ONNX_NAMESPACE::NodeProto> new_top_k;
if (OnnxHelper::GetOpsetVersion() > 13) {
if (OnnxHelper::GetOpsetVersion() >= 18) {
std::string reduce_min_axis = helper_->Constant(
{1}, ONNX_NAMESPACE::TensorProto::INT64, static_cast<int64_t>(0));
new_top_k = helper_->MakeNode(
Expand Down
88 changes: 68 additions & 20 deletions paddle2onnx/mapper/exporter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -200,32 +200,78 @@ int32_t ModelExporter::GetMinOpsetVersion(const PaddleParser& parser) {
return max_opset;
}

int32_t ModelExporter::GetMinOpsetVersion(const PaddlePirParser& pir_parser) {
int32_t ModelExporter::GetCfBlockMinOpsetVersion(
const PaddlePirParser& pir_parser, pir::Block& block) {
std::vector<pir::Operation*> sub_blocks_ops_copy(pir_parser.sub_blocks_ops);
pir_parser.sub_blocks_ops.clear();
std::vector<pir::Operation*> block_ops;
for (auto& op : block.ops()) {
if (op->name() != "builtin.parameter") {
pir_parser.sub_blocks_ops.push_back(op);
}
}
// Must generate All sub_block's op output names must be generated here
// because it's may used in OPMapper.GetMinOpsetVersion function.
pir_parser.GetAllSubBlockOpOutputName(pir_parser.sub_blocks_ops);
auto max_opset = GetMinOpsetVersion(pir_parser, &block, true);
pir_parser.sub_blocks_ops.clear();
pir_parser.sub_blocks_ops = sub_blocks_ops_copy;
return max_opset;
}

int32_t ModelExporter::GetMinOpsetVersion(const PaddlePirParser& pir_parser,
pir::Block* block,
bool if_in_sublock) {
int32_t max_opset = 7;
std::set<std::string> verbose_log;
OnnxHelper helper;
// TODO(wangmingkai02): consider the case of cf op
for (auto i = 0; i < pir_parser.global_blocks_ops.size(); i++) {
std::string op_name = pir_parser.global_blocks_ops[i]->name();
if (op_name == "pd_op.data" || op_name == "pd_op.fetch") {
continue;
std::vector<pir::Operation*> block_ops;
// it's necessary to be same with global/sub_blocks_ops
for (auto& op : block->ops()) {
if (op->name() != "builtin.parameter") {
block_ops.push_back(op);
}
if (op_name == "pd_op.if" || op_name == "pd_op.while") {
}
for (auto i = 0; i < block_ops.size(); ++i) {
auto op = block_ops[i];
std::string op_name = op->name();
if (op_name == "pd_op.data" || op_name == "pd_op.fetch" ||
op_name == "cf.yield") {
continue;
}
int current_opset = 7;
auto mapper = MapperHelper::Get()->CreateMapper(
convert_pir_op_name(op_name), pir_parser, &helper, i, false);
current_opset = mapper->GetMinOpsetVersion(verbose_);
delete mapper;
if (op_name == "pd_op.if") {
auto if_op = op->dyn_cast<paddle::dialect::IfOp>();
pir::Block& true_block = if_op.true_block();
auto true_block_opset_version =
GetCfBlockMinOpsetVersion(pir_parser, true_block);
pir::Block& false_block = if_op.false_block();
auto false_block_opset_version =
GetCfBlockMinOpsetVersion(pir_parser, false_block);
current_opset = true_block_opset_version > false_block_opset_version
? true_block_opset_version
: false_block_opset_version;
current_opset = current_opset > 11 ? current_opset : 11;
} else if (op_name == "pd_op.while") {
auto while_op = op->dyn_cast<paddle::dialect::WhileOp>();
current_opset = GetCfBlockMinOpsetVersion(pir_parser, while_op.body());
current_opset = current_opset > 11 ? current_opset : 11;

} else {
auto mapper = MapperHelper::Get()->CreateMapper(
convert_pir_op_name(op_name), pir_parser, &helper, i, if_in_sublock);
current_opset = mapper->GetMinOpsetVersion(verbose_);
delete mapper;
}
if (current_opset > max_opset) {
max_opset = current_opset;
if (current_opset > opset_version_) {
verbose_log.insert(
"Due to the operator: " + pir_parser.global_blocks_ops[i]->name() +
" " + "requires opset_version >= " + std::to_string(current_opset) +
".");
if (opset_version_ < 11 ||
(op_name != "pd_op.if" && op_name != "pd_op.while")) {
verbose_log.insert("Due to the operator: " + op_name + " " +
"requires opset_version >= " +
std::to_string(current_opset) + ".");
}
}
}
}
Expand All @@ -240,7 +286,8 @@ void ModelExporter::SetOpsetVersion(const PaddlePirParser& pir_parser,
bool auto_upgrade_opset) {
bool opset_is_legal = true;
// here
int32_t min_opset = GetMinOpsetVersion(pir_parser);
int32_t min_opset =
GetMinOpsetVersion(pir_parser, pir_parser.pir_program_->block(), false);
if (min_opset < 7 || min_opset > MAX_ONNX_OPSET_VERSION) {
P2OLogger(verbose_) << "The Opset Version must be between 7 and "
<< MAX_ONNX_OPSET_VERSION << std::endl;
Expand All @@ -249,7 +296,7 @@ void ModelExporter::SetOpsetVersion(const PaddlePirParser& pir_parser,
if (!auto_upgrade_opset) {
if (min_opset > opset_version_) {
P2OLogger(verbose_) << "Please set the opset_version to "
<< std::to_string(opset_version_)
<< std::to_string(min_opset)
<< " or set auto_upgrade_opset=true." << std::endl;
opset_is_legal = false;
}
Expand Down Expand Up @@ -425,9 +472,10 @@ ONNX_NAMESPACE::GraphProto ModelExporter::ExportIfBlock(
pir_parser.sub_blocks_ops.push_back(op);
}
}
pir_parser.GetAllSubBlockOpOutputName(pir_parser.sub_blocks_ops);
// generate sub-block op outputs names in GetMinOpSetVersion() function.
// pir_parser.GetAllSubBlockOpOutputName(pir_parser.sub_blocks_ops);
if (!pir_parser.sub_blocks_ops.empty()) {
// get cf.yeild op input
// get cf.yield op input
pir::Operation* cf_yield_op = pir_parser.sub_blocks_ops.back();
// std::vector<std::string> sub_block_outpus;
for (int32_t idx = 0; idx < cf_yield_op->num_operands(); ++idx) {
Expand Down Expand Up @@ -545,7 +593,7 @@ ONNX_NAMESPACE::GraphProto ModelExporter::ExportBlock(
for (int32_t idx = 0; idx < outputs.size(); ++idx) {
auto output_item = outputs[idx];
if (output_item->name() == input_item->name()) {
output_item->set_name(pir_parser.GenOpInputOutputName("yeild"));
output_item->set_name(pir_parser.GenOpInputOutputName("yield"));
temp_helper.MakeNode(
"Identity", {input_item->name()}, {output_item->name()});
outputs[idx] = std::move(output_item);
Expand Down
7 changes: 6 additions & 1 deletion paddle2onnx/mapper/exporter.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,13 @@ class ModelExporter {

ONNX_NAMESPACE::ModelProto onnx_model_;
// Opset Version

int32_t GetCfBlockMinOpsetVersion(const PaddlePirParser& pir_parser,
pir::Block& block);
int32_t GetMinOpsetVersion(const PaddleParser& parser);
int32_t GetMinOpsetVersion(const PaddlePirParser& parser);
int32_t GetMinOpsetVersion(const PaddlePirParser& pir_parser,
pir::Block* block,
bool if_in_sublock);
void SetOpsetVersion(const PaddleParser& parser, bool auto_upgrade_opset);
void SetOpsetVersion(const PaddlePirParser& pir_parser,
bool auto_upgrade_opset);
Expand Down
3 changes: 3 additions & 0 deletions paddle2onnx/mapper/tensor/linspace.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ void LinspaceMapper::Opset9() {
std::string range_tensor = helper_->AutoCast(
num_info[0].name, num_info[0].dtype, P2ODataType::INT64);

if(num_info[0].Rank() == 0) {
range_tensor = helper_->Unsqueeze(range_tensor, std::vector<int64_t>(1, 0));
}
std::string one_like_node = helper_->ConstOfShape(
range_tensor, GetOnnxDtype(P2ODataType::FP32), static_cast<float>(1));

Expand Down
4 changes: 2 additions & 2 deletions paddle2onnx/mapper/tensor/set_value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ int32_t SetValueMapper::GetMinOpsetVersion(bool verbose) {
<< std::endl;
return -1;
}
return 12;
return 17;
}

void SetValueMapper::Opset12() {
void SetValueMapper::Opset17() {
auto input_info = GetInput("Input");
auto output_info = GetOutput("Out");
std::string starts = "";
Expand Down
2 changes: 1 addition & 1 deletion paddle2onnx/mapper/tensor/set_value.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ class SetValueMapper : public Mapper {
}

int32_t GetMinOpsetVersion(bool verbose) override;
void Opset12() override;
void Opset17() override;

private:
std::vector<int64_t> axes_;
Expand Down
3 changes: 2 additions & 1 deletion paddle2onnx/mapper/while.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ void ModelExporter::ExportWhile(PaddlePirParser& pir_parser,
}
}

pir_parser.GetAllSubBlockOpOutputName(pir_parser.sub_blocks_ops);
// generate sub-block op outputs names in GetMinOpSetVersion() function.
// pir_parser.GetAllSubBlockOpOutputName(pir_parser.sub_blocks_ops);
if (!pir_parser.sub_blocks_ops.empty()) {
// get cf.yeild op input
pir::Operation* cf_yield_op = pir_parser.sub_blocks_ops.back();
Expand Down
2 changes: 1 addition & 1 deletion paddle2onnx/parser/pir_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class PaddlePirParser {
// recoring set of operators for global block
std::vector<pir::Operation*> global_blocks_ops;
// recoring set of operators for sub block
std::vector<pir::Operation*>
mutable std::vector<pir::Operation*>
sub_blocks_ops; // todo(wangmingkai02): delete sub_blocks_ops
// recording args of while op body name info
std::unordered_map<pir::detail::ValueImpl*, pir::detail::ValueImpl*>
Expand Down
1 change: 0 additions & 1 deletion tests/run.bat
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ set ignore=!ignore! test_auto_scan_isx_ops.py
set ignore=!ignore! test_auto_scan_masked_select.py
set ignore=!ignore! test_auto_scan_pad2d.py
set ignore=!ignore! test_auto_scan_roll.py
set ignore=!ignore! test_auto_scan_set_value.py
set ignore=!ignore! test_auto_scan_unfold.py
set ignore=!ignore! test_auto_scan_uniform_random_batch_size_like.py
set ignore=!ignore! test_auto_scan_uniform_random.py
Expand Down
1 change: 0 additions & 1 deletion tests/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ ignore="test_auto_scan_multiclass_nms.py
test_auto_scan_masked_select.py \
test_auto_scan_pad2d.py \
test_auto_scan_roll.py \
test_auto_scan_set_value.py \
test_auto_scan_unfold.py \
test_auto_scan_uniform_random_batch_size_like.py \
test_auto_scan_uniform_random.py \
Expand Down
4 changes: 2 additions & 2 deletions tests/test_auto_scan_set_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def forward(self, inputs, update_input):
class TestSetValueConvert(OPConvertAutoScanTest):
"""
api: set_value
OPset version: 12, 13, 15
OPset version: 17, 19
"""

def sample_convert_config(self, draw):
Expand All @@ -54,7 +54,7 @@ def sample_convert_config(self, draw):
"op_names": ["set_value"],
"test_data_shapes": [input_shape, update_input_shape],
"test_data_types": [[dtype], [dtype]],
"opset_version": [12, 13, 14, 15],
"opset_version": [17, 19],
"input_spec_shape": [],
}

Expand Down

0 comments on commit 040c051

Please sign in to comment.