From 54cacf008f00d35d46273fed4d538cf5740d0965 Mon Sep 17 00:00:00 2001 From: Kunshang Ji Date: Sat, 18 Jan 2025 00:47:53 +0800 Subject: [PATCH] [Bugfix] Mistral tokenizer encode accept list of str (#12149) Signed-off-by: Kunshang Ji --- vllm/transformers_utils/tokenizers/mistral.py | 38 +++++++++++++++---- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index 17d722e3d88fe..d801cf4e4c7b1 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -18,6 +18,7 @@ Tekkenizer) from vllm.logger import init_logger +from vllm.utils import is_list_of if TYPE_CHECKING: from vllm.entrypoints.chat_utils import ChatCompletionMessageParam @@ -27,7 +28,7 @@ @dataclass class Encoding: - input_ids: List[int] + input_ids: Union[List[int], List[List[int]]] def maybe_serialize_tool_calls(request: ChatCompletionRequest): @@ -223,17 +224,25 @@ def __len__(self) -> int: def __call__( self, - prompt: str, + prompt: Union[str, List[str], List[int]], add_special_tokens: bool = False, truncation: bool = False, max_length: Optional[int] = None, ): - # Mistral Tokenizers should not add special tokens - input_ids = self.encode(prompt) - - if truncation: - input_ids = input_ids[:max_length] - + input_ids: Union[List[int], List[List[int]]] + # For List[str], original prompt text + if is_list_of(prompt, str): + input_ids_: List[List[int]] = [] + for p in prompt: + each_input_ids = self.encode_one(p, truncation, max_length) + input_ids_.append(each_input_ids) + input_ids = input_ids_ + # For List[int], apply chat template output, already tokens. + elif is_list_of(prompt, int): + input_ids = prompt + # For str, single prompt text + else: + input_ids = self.encode_one(prompt, truncation, max_length) return Encoding(input_ids=input_ids) def get_vocab(self) -> Dict[str, int]: @@ -245,6 +254,19 @@ def get_added_vocab(self) -> Dict[str, int]: # Mistral tokenizers have no added vocabulary return {} + def encode_one( + self, + prompt: str, + truncation: bool = False, + max_length: Optional[int] = None, + ) -> List[int]: + # Mistral Tokenizers should not add special tokens + input_ids = self.encode(prompt) + + if truncation: + input_ids = input_ids[:max_length] + return input_ids + def encode(self, prompt: str) -> List[int]: # `encode` should only be used for prompt completion # it should never be used for chat_completion.