Skip to content

Commit

Permalink
Encapsulate SetTensorArrayName and GetTensorArrayName within class Ma…
Browse files Browse the repository at this point in the history
…pper.
  • Loading branch information
0x3878f committed Jan 24, 2025
1 parent aa9d4c0 commit d5ef33b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 10 deletions.
8 changes: 8 additions & 0 deletions paddle2onnx/mapper/mapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -386,5 +386,13 @@ class Mapper {
bool TryGetValue(const TensorInfo& info, std::vector<T>* data) {
return parser_->TryGetTensorValue(block_idx_, info.name, data);
}

void SetTensorArrayName(const std::string &arr_name) {
pir_parser_->SetTensorArrayName(pir_op_idx_, if_in_cf_block, arr_name);
}

std::string GetTensorArrayName() {
return pir_parser_->GetTensorArrayName(pir_op_idx_, if_in_cf_block);
}
};
} // namespace paddle2onnx
15 changes: 5 additions & 10 deletions paddle2onnx/mapper/tensor/tensor_array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,15 @@ void CreateArrayMapper::Opset11() {
auto output_info = GetOutput(0);
auto node = helper_->MakeNode("SequenceEmpty", {}, {output_info[0].name});
AddAttribute(node, "dtype", GetOnnxDtype(dtype_));
pir_parser_->SetTensorArrayName(
pir_op_idx_, if_in_cf_block, output_info[0].name);
SetTensorArrayName(output_info[0].name);
}
int32_t ArrayLengthMapper::GetMinOpsetVersion(bool verbose) {
Logger(verbose, 11) << "ArrayLengthMapper " << RequireOpset(11) << std::endl;
return 11;
}
void ArrayLengthMapper::Opset11() {
auto output_info = GetOutput(0);
std::string arr_name =
pir_parser_->GetTensorArrayName(pir_op_idx_, if_in_cf_block);
std::string arr_name = GetTensorArrayName();
helper_->MakeNode("SequenceLength", {arr_name}, {output_info[0].name});
}
int32_t ArrayWriteMapper::GetMinOpsetVersion(bool verbose) {
Expand All @@ -55,13 +53,11 @@ void ArrayWriteMapper::Opset11() {
auto index_info = GetInput(2);
auto output_info = GetOutput(0);
auto squeeze_node = helper_->MakeNode("Squeeze", {index_info[0].name});
std::string arr_name =
pir_parser_->GetTensorArrayName(pir_op_idx_, if_in_cf_block);
std::string arr_name = GetTensorArrayName();
helper_->MakeNode("SequenceInsert",
{arr_name, tensor_info[0].name, squeeze_node->output(0)},
{output_info[0].name});
pir_parser_->SetTensorArrayName(
pir_op_idx_, if_in_cf_block, output_info[0].name);
SetTensorArrayName(output_info[0].name);
}
int32_t ArrayReadMapper::GetMinOpsetVersion(bool verbose) {
Logger(verbose, 11) << "ArrayReadMapper " << RequireOpset(11) << std::endl;
Expand All @@ -71,8 +67,7 @@ void ArrayReadMapper::Opset11() {
auto index_info = GetInput(1);
auto output_info = GetOutput(0);
auto squeeze_node = helper_->MakeNode("Squeeze", {index_info[0].name});
std::string arr_name =
pir_parser_->GetTensorArrayName(pir_op_idx_, if_in_cf_block);
std::string arr_name = GetTensorArrayName();
helper_->MakeNode(
"SequenceAt", {arr_name, squeeze_node->output(0)}, {output_info[0].name});
}
Expand Down

0 comments on commit d5ef33b

Please sign in to comment.