Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing Whisper Model Token Normalization #1904

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,5 @@ us_gold.json
us_silver.json
kokoro-multi-lang-v1_0
sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16
cmake-build-debug
README-DEV.txt
48 changes: 25 additions & 23 deletions sherpa-onnx/csrc/offline-recognizer-whisper-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,28 +23,6 @@

namespace sherpa_onnx {

static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
const SymbolTable &sym_table) {
OfflineRecognitionResult r;
r.tokens.reserve(src.tokens.size());

std::string text;
for (auto i : src.tokens) {
if (!sym_table.Contains(i)) {
continue;
}

const auto &s = sym_table[i];
text += s;
r.tokens.push_back(s);
}

r.text = text;
r.lang = src.lang;

return r;
}

class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
public:
explicit OfflineRecognizerWhisperImpl(const OfflineRecognizerConfig &config)
Expand Down Expand Up @@ -156,7 +134,6 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
std::move(cross_kv.second));

auto r = Convert(results[0], symbol_table_);
r.text = ApplyInverseTextNormalization(std::move(r.text));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved to Convert()

s->SetResult(r);
} catch (const Ort::Exception &ex) {
SHERPA_ONNX_LOGE(
Expand All @@ -169,6 +146,31 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
}
}

private:
OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
const SymbolTable &sym_table) const {
OfflineRecognitionResult r;
r.tokens.reserve(src.tokens.size());

std::string text;
for (auto i : src.tokens) {
if (!sym_table.Contains(i)) {
continue;
}

std::string s = sym_table[i];
s = ApplyInverseTextNormalization(s);

text += s;
r.tokens.push_back(s);
}

r.text = text;
r.lang = src.lang;

return r;
}

private:
OfflineRecognizerConfig config_;
SymbolTable symbol_table_;
Expand Down
73 changes: 73 additions & 0 deletions sherpa-onnx/csrc/text-utils-test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,77 @@ TEST(RemoveInvalidUtf8Sequences, Case1) {
EXPECT_EQ(s.size() + 4, v.size());
}


Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a few tests, verified the characters are being removed correctly.

// Tests for sanitizeUtf8
TEST(RemoveInvalidUtf8Sequences, ValidUtf8StringPassesUnchanged) {
std::string input = "Valid UTF-8 🌍";
EXPECT_EQ(RemoveInvalidUtf8Sequences(input), input);
}

TEST(RemoveInvalidUtf8Sequences, SingleInvalidByteReplaced) {
std::string input = "Invalid \xFF UTF-8";
std::string expected = "Invalid UTF-8";
EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected);
}

TEST(RemoveInvalidUtf8Sequences, TruncatedUtf8SequenceReplaced) {
std::string input = "Broken \xE2\x82"; // Incomplete UTF-8 sequence
std::string expected = "Broken ";
EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected);
}

TEST(RemoveInvalidUtf8Sequences, MultipleInvalidBytes) {
std::string input = "Test \xC0\xC0\xF8\xA0"; // Multiple invalid sequences
std::string expected = "Test ";
EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected);
}

TEST(RemoveInvalidUtf8Sequences, BreakingCase_SpaceFollowedByInvalidByte) {
std::string input = "\x20\xC4"; // Space followed by an invalid byte
std::string expected = " "; // 0xC4 removed
EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected);
}

TEST(RemoveInvalidUtf8Sequences, ValidUtf8WithEdgeCaseCharacters) {
std::string input = "Edge 🏆💯";
EXPECT_EQ(RemoveInvalidUtf8Sequences(input), input);
}

TEST(RemoveInvalidUtf8Sequences, MixedValidAndInvalidBytes) {
std::string input = "Mix \xE2\x82\xAC \xF0\x9F\x98\x81 \xFF";
std::string expected = "Mix € 😁 "; // Invalid bytes removed
EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected);
}

TEST(RemoveInvalidUtf8Sequences, SpaceFollowedByInvalidByte) {
std::string input = "\x20\xC4"; // Space (0x20) followed by invalid (0xC4)
std::string expected = " "; // Space remains, 0xC4 is removed
EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected);
}

TEST(RemoveInvalidUtf8Sequences, RemoveTruncatedC4) {
std::string input = "Hello \xc4 world"; // Invalid `0xC4`
std::string expected = "Hello world"; // `0xC4` should be removed
EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected);
}

TEST(RemoveInvalidUtf8Sequences, SpaceFollowedByInvalidByte_Breaking) {
std::string input = "\x20\xc4"; // Space followed by invalid `0xc4`
std::string expected = " "; // `0xc4` should be removed, space remains
EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected);
}

TEST(RemoveInvalidUtf8Sequences, DebugSpaceFollowedByInvalidByte) {
std::string input = "\x20\xc4"; // Space followed by invalid `0xc4`
std::string output = RemoveInvalidUtf8Sequences(input);

std::cout << "Processed string: ";
for (unsigned char c : output) {
printf("\\x%02x ", c);
}
std::cout << std::endl;

EXPECT_EQ(output, " "); // Expect `0xc4` to be removed, leaving only space
}

} // namespace sherpa_onnx
Loading