Skip to content

Commit

Permalink
[Bugfix] Mistral tokenizer encode accept list of str (vllm-project#12149)
Browse files Browse the repository at this point in the history

Signed-off-by: Kunshang Ji <[email protected]>
  • Loading branch information
jikunshang authored Jan 17, 2025
1 parent 58fd57f commit 54cacf0
Showing 1 changed file with 30 additions and 8 deletions.
38 changes: 30 additions & 8 deletions vllm/transformers_utils/tokenizers/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Tekkenizer)

from vllm.logger import init_logger
from vllm.utils import is_list_of

if TYPE_CHECKING:
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
Expand All @@ -27,7 +28,7 @@

@dataclass
class Encoding:
    """Container for tokenizer output ids (result of ``__call__``)."""

    # One token sequence (List[int]) for a single prompt, or a batch of
    # sequences (List[List[int]]) when the input was a list of prompt strings.
    input_ids: Union[List[int], List[List[int]]]


def maybe_serialize_tool_calls(request: ChatCompletionRequest):
Expand Down Expand Up @@ -223,17 +224,25 @@ def __len__(self) -> int:

def __call__(
    self,
    prompt: Union[str, List[str], List[int]],
    add_special_tokens: bool = False,
    truncation: bool = False,
    max_length: Optional[int] = None,
):
    """Tokenize a prompt (or batch of prompts) into an ``Encoding``.

    Args:
        prompt: A single prompt string, a list of prompt strings (batch),
            or an already-tokenized list of token ids.
        add_special_tokens: Ignored — Mistral tokenizers never add
            special tokens here; kept for HF-signature compatibility.
        truncation: If True, keep at most ``max_length`` leading tokens.
        max_length: Truncation bound; ``None`` means no bound.

    Returns:
        An ``Encoding`` whose ``input_ids`` is ``List[int]`` for a single
        prompt or ``List[List[int]]`` for a batch of prompt strings.
    """
    input_ids: Union[List[int], List[List[int]]]
    # Batch of prompt texts -> encode each one independently.
    if is_list_of(prompt, str):
        input_ids = [
            self.encode_one(p, truncation, max_length) for p in prompt
        ]
    # Already-tokenized input (e.g. chat-template output).
    elif is_list_of(prompt, int):
        # Bug fix: honor truncation for pre-tokenized input too, matching
        # the behavior of the other two branches (slicing with
        # max_length=None is a no-op).
        input_ids = prompt[:max_length] if truncation else prompt
    # Single prompt text.
    else:
        input_ids = self.encode_one(prompt, truncation, max_length)
    return Encoding(input_ids=input_ids)

def get_vocab(self) -> Dict[str, int]:
Expand All @@ -245,6 +254,19 @@ def get_added_vocab(self) -> Dict[str, int]:
# Mistral tokenizers have no added vocabulary
return {}

def encode_one(
    self,
    prompt: str,
    truncation: bool = False,
    max_length: Optional[int] = None,
) -> List[int]:
    """Encode a single prompt string into token ids.

    Mistral tokenizers never add special tokens here. When ``truncation``
    is set, at most ``max_length`` leading tokens are kept (``None``
    leaves the sequence unbounded).
    """
    token_ids = self.encode(prompt)
    if not truncation:
        return token_ids
    return token_ids[:max_length]

def encode(self, prompt: str) -> List[int]:
# `encode` should only be used for prompt completion
# it should never be used for chat_completion.
Expand Down

0 comments on commit 54cacf0

Please sign in to comment.