Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add attn convertor. #1761

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions csrc/mmdeploy/codebase/mmocr/attn.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include <algorithm>

#include "mmdeploy/core/device.h"
#include "mmdeploy/core/registry.h"
#include "mmdeploy/core/tensor.h"
#include "mmdeploy/core/utils/device_utils.h"
#include "base.h"
#include "mmocr.h"

namespace mmdeploy::mmocr {

using std::string;
using std::vector;

class AttnConvertor : public BaseConvertor {
public:
explicit AttnConvertor(const Value& cfg) : BaseConvertor(cfg) {
auto model = cfg["context"]["model"].get<Model>();
if (!cfg.contains("params")) {
MMDEPLOY_ERROR("'params' is required, but it's not in the config");
throw_exception(eInvalidArgument);
}
auto& _cfg = cfg["params"];

// unknwon
if (_cfg.value("with_unknown", false)) {
unknown_idx_ = static_cast<int>(idx2char_.size());
idx2char_.emplace_back("<UKN>");
}

// BOS/EOS
constexpr char start_end_token[] = "<BOS/EOS>";
constexpr char padding_token[] = "<PAD>";
start_idx_ = static_cast<int>(idx2char_.size());
end_idx_ = start_idx_;
idx2char_.emplace_back(start_end_token);
if (!_cfg.value("start_end_same", true)) {
end_idx_ = static_cast<int>(idx2char_.size());
idx2char_.emplace_back(start_end_token);
}

// padding
padding_idx_ = static_cast<int>(idx2char_.size());
idx2char_.emplace_back(padding_token);

model_ = model;
}

Result<Value> operator()(const Value& _data, const Value& _prob) {
auto d_conf = _prob["output"].get<Tensor>();

if (!(d_conf.shape().size() == 3 && d_conf.data_type() == DataType::kFLOAT)) {
MMDEPLOY_ERROR("unsupported `output` tensor, shape: {}, dtype: {}", d_conf.shape(),
(int)d_conf.data_type());
return Status(eNotSupported);
}

OUTCOME_TRY(auto h_conf, MakeAvailableOnDevice(d_conf, Device{0}, stream()));
OUTCOME_TRY(stream().Wait());

auto data = h_conf.data<float>();

auto shape = d_conf.shape();
auto w = static_cast<int>(shape[1]);
auto c = static_cast<int>(shape[2]);

auto valid_ratio = _data["img_metas"]["valid_ratio"].get<float>();
auto [indexes, scores] = Tensor2Idx(data, w, c, valid_ratio);

auto text = Idx2Str(indexes);
MMDEPLOY_DEBUG("text: {}", text);

TextRecognition output{text, scores};

return make_pointer(to_value(output));
}

std::pair<vector<int>, vector<float> > Tensor2Idx(const float* data, int w, int c,
float valid_ratio) {
auto decode_len = std::min(w, static_cast<int>(std::ceil(w * valid_ratio)));
vector<int> indexes;
indexes.reserve(decode_len);
vector<float> scores;
scores.reserve(decode_len);

for (int t = 0; t < decode_len; ++t, data += c) {
auto iter = std::max_element(data, data + c);
auto index = static_cast<int>(iter - data);
if (index == padding_idx_) continue;
if (index == end_idx_) break;
indexes.push_back(index);
scores.push_back(*iter);
}

return {indexes, scores};
}

private:
int start_idx_{-1};
int end_idx_{-1};
int padding_idx_{-1};
};

MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMOCR, AttnConvertor);

} // namespace mmdeploy::mmocr
78 changes: 78 additions & 0 deletions csrc/mmdeploy/codebase/mmocr/base.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include <sstream>

#include "mmdeploy/codebase/mmocr/base.h"

namespace mmdeploy {
namespace mmocr {

using std::string;
using std::vector;

BaseConvertor::BaseConvertor(const Value& cfg) : MMOCR(cfg) {
auto model = cfg["context"]["model"].get<Model>();
if (!cfg.contains("params")) {
MMDEPLOY_ERROR("'params' is required, but it's not in the config");
throw_exception(eInvalidArgument);
}
// BaseConverter
auto& _cfg = cfg["params"];
if (_cfg.contains("dict_file")) {
auto filename = _cfg["dict_file"].get<std::string>();
auto content = model.ReadFile(filename).value();
idx2char_ = SplitLines(content);
} else if (_cfg.contains("dict_list")) {
from_value(_cfg["dict_list"], idx2char_);
} else if (_cfg.contains("dict_type")) {
auto dict_type = _cfg["dict_type"].get<std::string>();
if (dict_type == "DICT36") {
idx2char_ = SplitChars(DICT36);
} else if (dict_type == "DICT90") {
idx2char_ = SplitChars(DICT90);
} else {
MMDEPLOY_ERROR("unknown dict_type: {}", dict_type);
throw_exception(eInvalidArgument);
}
} else {
MMDEPLOY_ERROR("either dict_file, dict_list or dict_type must be specified");
throw_exception(eInvalidArgument);
}

model_ = model;
}

string BaseConvertor::Idx2Str(const vector<int>& indexes) {
size_t count = 0;
for (const auto& idx : indexes) {
count += idx2char_[idx].size();
}
std::string text;
text.reserve(count);
for (const auto& idx : indexes) {
text += idx2char_[idx];
}
return text;
}

vector<string> BaseConvertor::SplitLines(const string& s) {
std::istringstream is(s);
vector<string> ret;
string line;
while (std::getline(is, line)) {
ret.push_back(std::move(line));
}
return ret;
}

vector<string> BaseConvertor::SplitChars(const string& s) {
vector<string> ret;
ret.reserve(s.size());
for (char c : s) {
ret.push_back({c});
}
return ret;
}

}
}
40 changes: 40 additions & 0 deletions csrc/mmdeploy/codebase/mmocr/base.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include <string>
#include <vector>

#include "mmdeploy/core/model.h"
#include "mmocr.h"

namespace mmdeploy::mmocr {

using std::string;
using std::vector;

class BaseConvertor : public MMOCR {
public:
explicit BaseConvertor(const Value& cfg);

string Idx2Str(const vector<int>& indexes);

protected:
static vector<string> SplitLines(const string& s);

static vector<string> SplitChars(const string& s);

static constexpr const auto DICT36 = R"(0123456789abcdefghijklmnopqrstuvwxyz)";
static constexpr const auto DICT90 = R"(0123456789abcdefghijklmnopqrstuvwxyz)"
R"(ABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&'())"
R"(*+,-./:;<=>?@[\]_`~)";

static constexpr const auto kHost = Device(0);

Model model_;

static constexpr const int blank_idx_{0};
int unknown_idx_{-1};

vector<string> idx2char_;
};

} // namespace mmdeploy::mmocr
80 changes: 3 additions & 77 deletions csrc/mmdeploy/codebase/mmocr/crnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,52 +11,24 @@
#include "mmdeploy/core/utils/formatter.h"
#include "mmdeploy/core/value.h"
#include "mmdeploy/experimental/module_adapter.h"
#include "mmocr.h"
#include "base.h"

namespace mmdeploy::mmocr {

using std::string;
using std::vector;

class CTCConvertor : public MMOCR {
class CTCConvertor : public BaseConvertor {
public:
explicit CTCConvertor(const Value& cfg) : MMOCR(cfg) {
auto model = cfg["context"]["model"].get<Model>();
if (!cfg.contains("params")) {
MMDEPLOY_ERROR("'params' is required, but it's not in the config");
throw_exception(eInvalidArgument);
}
// BaseConverter
explicit CTCConvertor(const Value& cfg) : BaseConvertor(cfg) {
auto& _cfg = cfg["params"];
if (_cfg.contains("dict_file")) {
auto filename = _cfg["dict_file"].get<std::string>();
auto content = model.ReadFile(filename).value();
idx2char_ = SplitLines(content);
} else if (_cfg.contains("dict_list")) {
from_value(_cfg["dict_list"], idx2char_);
} else if (_cfg.contains("dict_type")) {
auto dict_type = _cfg["dict_type"].get<std::string>();
if (dict_type == "DICT36") {
idx2char_ = SplitChars(DICT36);
} else if (dict_type == "DICT90") {
idx2char_ = SplitChars(DICT90);
} else {
MMDEPLOY_ERROR("unknown dict_type: {}", dict_type);
throw_exception(eInvalidArgument);
}
} else {
MMDEPLOY_ERROR("either dict_file, dict_list or dict_type must be specified");
throw_exception(eInvalidArgument);
}
// CTCConverter
idx2char_.insert(begin(idx2char_), "<BLK>");

if (_cfg.value("with_unknown", false)) {
unknown_idx_ = static_cast<int>(idx2char_.size());
idx2char_.emplace_back("<UKN>");
}

model_ = model;
}

Result<Value> operator()(const Value& _data, const Value& _prob) {
Expand Down Expand Up @@ -110,19 +82,6 @@ class CTCConvertor : public MMOCR {
return {indexes, scores};
}

string Idx2Str(const vector<int>& indexes) {
size_t count = 0;
for (const auto& idx : indexes) {
count += idx2char_[idx].size();
}
std::string text;
text.reserve(count);
for (const auto& idx : indexes) {
text += idx2char_[idx];
}
return text;
}

// TODO: move softmax & top-k into model
static void softmax(const float* src, float* dst, int n) {
auto max_val = *std::max_element(src, src + n);
Expand All @@ -136,39 +95,6 @@ class CTCConvertor : public MMOCR {
}
}

protected:
static vector<string> SplitLines(const string& s) {
std::istringstream is(s);
vector<string> ret;
string line;
while (std::getline(is, line)) {
ret.push_back(std::move(line));
}
return ret;
}

static vector<string> SplitChars(const string& s) {
vector<string> ret;
ret.reserve(s.size());
for (char c : s) {
ret.push_back({c});
}
return ret;
}

static constexpr const auto DICT36 = R"(0123456789abcdefghijklmnopqrstuvwxyz)";
static constexpr const auto DICT90 = R"(0123456789abcdefghijklmnopqrstuvwxyz)"
R"(ABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&'())"
R"(*+,-./:;<=>?@[\]_`~)";

static constexpr const auto kHost = Device(0);

Model model_;

static constexpr const int blank_idx_{0};
int unknown_idx_{-1};

vector<string> idx2char_;
};

MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMOCR, CTCConvertor);
Expand Down