From 7a83077cd7798f0502d0f35af2a469f113d5e0f9 Mon Sep 17 00:00:00 2001
From: Anchen
Date: Tue, 28 Jan 2025 12:13:50 +1100
Subject: [PATCH] chore(mlx-lm): support text type content in messages (#1225)

* chore(mlx-lm): support text type content

* chore: optimize the message content processing

* nits + format

---------

Co-authored-by: Awni Hannun
---
 llms/mlx_lm/server.py     | 31 ++++++++++++++++++++++++++++++-
 llms/tests/test_server.py | 23 +++++++++++++++++++++++
 2 files changed, 53 insertions(+), 1 deletion(-)

diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index 4523e3ae4..de02704d9 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -114,6 +114,33 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
     return prompt.rstrip()
 
 
+def process_message_content(messages):
+    """
+    Convert message content to a format suitable for `apply_chat_template`.
+
+    The function operates on messages in place. It converts the 'content' field
+    to a string instead of a list of text fragments.
+
+    Args:
+        messages (list): A list of dictionaries, where each dictionary may
+          have a 'content' key containing a list of dictionaries with 'type' and
+          'text' keys.
+
+    Raises:
+        ValueError: If the 'content' type is not supported or if 'text' is missing.
+
+    """
+    for message in messages:
+        content = message["content"]
+        if isinstance(content, list):
+            text_fragments = [
+                fragment["text"] for fragment in content if fragment["type"] == "text"
+            ]
+            if len(text_fragments) != len(content):
+                raise ValueError("Only 'text' content type is supported.")
+            message["content"] = "".join(text_fragments)
+
+
 @dataclass
 class PromptCache:
     cache: List[Any] = field(default_factory=list)
@@ -591,8 +618,10 @@ def handle_chat_completions(self) -> List[int]:
         self.request_id = f"chatcmpl-{uuid.uuid4()}"
         self.object_type = "chat.completion.chunk" if self.stream else "chat.completion"
         if self.tokenizer.chat_template:
+            messages = body["messages"]
+            process_message_content(messages)
             prompt = self.tokenizer.apply_chat_template(
-                body["messages"],
+                messages,
                 body.get("tools", None),
                 add_generation_prompt=True,
             )
diff --git a/llms/tests/test_server.py b/llms/tests/test_server.py
index ad17554d1..ecf95f78d 100644
--- a/llms/tests/test_server.py
+++ b/llms/tests/test_server.py
@@ -80,6 +80,29 @@ def test_handle_chat_completions(self):
         self.assertIn("id", response_body)
         self.assertIn("choices", response_body)
 
+    def test_handle_chat_completions_with_content_fragments(self):
+        url = f"http://localhost:{self.port}/v1/chat/completions"
+        chat_post_data = {
+            "model": "chat_model",
+            "max_tokens": 10,
+            "temperature": 0.7,
+            "top_p": 0.85,
+            "repetition_penalty": 1.2,
+            "messages": [
+                {
+                    "role": "system",
+                    "content": [
+                        {"type": "text", "text": "You are a helpful assistant."}
+                    ],
+                },
+                {"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
+            ],
+        }
+        response = requests.post(url, json=chat_post_data)
+        response_body = response.text
+        self.assertIn("id", response_body)
+        self.assertIn("choices", response_body)
+
     def test_handle_models(self):
         url = f"http://localhost:{self.port}/v1/models"
         response = requests.get(url)
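
For quick reference, below is a minimal, self-contained sketch of the normalization the new helper performs outside the server. The helper body is copied from the patched llms/mlx_lm/server.py; the sample messages, the plain-string user message, and the assertions are illustrative additions, not part of the patch.

def process_message_content(messages):
    # Flatten OpenAI-style list content into a single string, in place.
    for message in messages:
        content = message["content"]
        if isinstance(content, list):
            text_fragments = [
                fragment["text"] for fragment in content if fragment["type"] == "text"
            ]
            if len(text_fragments) != len(content):
                raise ValueError("Only 'text' content type is supported.")
            message["content"] = "".join(text_fragments)

messages = [
    {
        "role": "system",
        "content": [{"type": "text", "text": "You are a helpful assistant."}],
    },
    {"role": "user", "content": "Hello!"},  # plain strings pass through unchanged
]
process_message_content(messages)
assert messages[0]["content"] == "You are a helpful assistant."
assert messages[1]["content"] == "Hello!"

The normalized list can then be handed to apply_chat_template exactly as handle_chat_completions does in the patch.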