diff --git a/crates/llm-base/src/tokenizer/embedded.rs b/crates/llm-base/src/tokenizer/embedded.rs index 25387d23..6d7614ae 100644 --- a/crates/llm-base/src/tokenizer/embedded.rs +++ b/crates/llm-base/src/tokenizer/embedded.rs @@ -171,16 +171,12 @@ impl EmbeddedTokenizer { match self.model { GgufEmbeddedTokenizerModel::Llama => { - let text = escape_whitespace(format!(" {text}").as_bytes()); - - Ok(TokenizerSpm::new(self) - .tokenize(&text) - .into_iter() - .map(|id| { - // TODO: see if this can be made more efficient - (self.id_to_token[id as usize].text.clone(), id) - }) - .collect()) + let text = escape_whitespace(format!(" {}", text).as_bytes()); + for id in TokenizerSpm::new(self).tokenize(&text) { + // TODO: see if this can be made more efficient + output.push((self.id_to_token[id as usize].text.clone(), id)); + } + Ok(output) } _ => unimplemented!(), } @@ -212,6 +208,11 @@ impl EmbeddedTokenizer { _ => unimplemented!(), } + // remove first b' ' + if ret.first() == Some(&b' ') { + ret.remove(0); + } + ret } } @@ -321,7 +322,7 @@ impl TryFrom for TokenType { } } -#[derive(Clone)] +#[derive(Clone, Debug)] struct Symbol { prev: isize, next: isize, @@ -329,6 +330,7 @@ struct Symbol { n: usize, } +#[derive(Debug)] struct LlmBigramSpm { left: isize, right: isize, @@ -379,18 +381,16 @@ impl<'a> TokenizerSpm<'a> { let mut index = 0; let mut offs = 0; while offs < text.len() { - let len = text[offs..].len(); + let len = utf8_len(text[offs]); + let sym_text = text[offs..].to_vec(); + let sym_n = len.min(text.len() - offs); + offs += sym_n; let sym = Symbol { - text: text[offs..offs + len].to_vec(), - n: len.min(text.len() - offs), + text: sym_text, + n: sym_n, prev: index - 1, - next: if offs + len == text.len() { - -1 - } else { - index + 1 - }, + next: if offs == text.len() { -1 } else { index + 1 }, }; - offs += sym.n; index += 1; self.symbols.push(sym); } @@ -435,7 +435,7 @@ impl<'a> TokenizerSpm<'a> { } fn resegment(&self, symbol: &Symbol, output: &mut Vec) { - let text = symbol.text.clone(); + let text = symbol.text.clone()[..symbol.n].to_vec(); if let Some(&token_id) = self.vocab.token_to_id.get(&text) { output.push(token_id); return; @@ -457,11 +457,10 @@ impl<'a> TokenizerSpm<'a> { return; } - let text = [ - self.symbols[left as usize].text.clone(), - self.symbols[right as usize].text.clone(), - ] - .concat(); + let text = self.symbols[left as usize].text.clone() + [..(self.symbols[left as usize].n + self.symbols[right as usize].n)] + .to_vec(); + if let Some(&token_id) = self.vocab.token_to_id.get(&text) { if (token_id as usize) < self.vocab.id_to_token.len() { let tok_data = &self.vocab.id_to_token[token_id as usize]; @@ -520,3 +519,9 @@ fn unescape_whitespace(text: &[u8]) -> Vec { out } + +fn utf8_len(src: u8) -> usize { + const LOOKUP: &[u8] = &[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4]; + let highbits: u8 = src >> 4; + LOOKUP[highbits as usize] as usize +} diff --git a/crates/llm-base/src/tokenizer/mod.rs b/crates/llm-base/src/tokenizer/mod.rs index 9852993e..d1d50c56 100644 --- a/crates/llm-base/src/tokenizer/mod.rs +++ b/crates/llm-base/src/tokenizer/mod.rs @@ -79,10 +79,6 @@ impl Display for HuggingFaceTokenizerErrorSource { } } -/// At the time of writing, the embedded tokenizer is not enabled as it has -/// some bugs. We're just not enabling the option while it's broken. -const EMBEDDED_TOKENIZER_ENABLED: bool = false; - #[derive(Clone, Debug, PartialEq)] /// The source of a tokenizer. pub enum TokenizerSource { @@ -140,13 +136,7 @@ impl TokenizerSource { if let Ok(hf) = gguf.metadata.get_str("tokenizer.huggingface.json") { Ok(Self::load_huggingface_json(hf)?) } else if EmbeddedTokenizer::is_present_in_metadata(&gguf.metadata) { - if EMBEDDED_TOKENIZER_ENABLED { - Ok(EmbeddedTokenizer::from_metadata(&gguf.metadata)?.into()) - } else { - Err(TokenizerLoadError::NoSupportedTokenizersFound { - unsupported_tokenizers: vec!["embedded".to_owned()], - }) - } + Ok(EmbeddedTokenizer::from_metadata(&gguf.metadata)?.into()) } else { Err(TokenizerLoadError::NoSupportedTokenizersFound { unsupported_tokenizers: vec![], diff --git a/crates/llm/src/lib.rs b/crates/llm/src/lib.rs index 39069f06..540d1e91 100644 --- a/crates/llm/src/lib.rs +++ b/crates/llm/src/lib.rs @@ -116,7 +116,7 @@ macro_rules! define_models { impl ModelArchitecture { /// All available model architectures - pub const ALL: &[Self] = &[ + pub const ALL: &'static [Self] = &[ $( #[cfg(feature = $model_lowercase_str)] Self::$model_pascalcase,