From 7a83077cd7798f0502d0f35af2a469f113d5e0f9 Mon Sep 17 00:00:00 2001
From: Anchen
Date: Tue, 28 Jan 2025 12:13:50 +1100
Subject: [PATCH] chore(mlx-lm): support text type content in messages (#1225)
* chore(mlx-lm): support text type content
* chore: optimize the message content processing
* nits + format
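
For context, OpenAI-style clients may send `content` as a list of typed fragments rather than a plain string. A minimal sketch (values illustrative only) of a message list this change now accepts; non-'text' fragment types are still rejected:

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
    ]

The list-form content is flattened to a plain string before being passed to `apply_chat_template`.
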
---------
Co-authored-by: Awni Hannun
---
llms/mlx_lm/server.py | 31 ++++++++++++++++++++++++++++++-
llms/tests/test_server.py | 23 +++++++++++++++++++++++
2 files changed, 53 insertions(+), 1 deletion(-)
diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index 4523e3ae4..de02704d9 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -114,6 +114,33 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
     return prompt.rstrip()
 
 
+def process_message_content(messages):
+    """
+    Convert message content to a format suitable for `apply_chat_template`.
+
+    The function operates on the messages in place: any 'content' field given
+    as a list of text fragments is replaced with a single concatenated string.
+
+    Args:
+        messages (list): A list of message dictionaries, where each dictionary
+        may have a 'content' key containing either a string or a list of
+        dictionaries with 'type' and 'text' keys.
+
+    Raises:
+        ValueError: If any content fragment has a 'type' other than 'text'.
+
+    """
+    for message in messages:
+        content = message["content"]
+        if isinstance(content, list):
+            text_fragments = [
+                fragment["text"] for fragment in content if fragment["type"] == "text"
+            ]
+            if len(text_fragments) != len(content):
+                raise ValueError("Only 'text' content type is supported.")
+            message["content"] = "".join(text_fragments)
+
+
 @dataclass
 class PromptCache:
     cache: List[Any] = field(default_factory=list)
@@ -591,8 +618,10 @@ def handle_chat_completions(self) -> List[int]:
         self.request_id = f"chatcmpl-{uuid.uuid4()}"
         self.object_type = "chat.completion.chunk" if self.stream else "chat.completion"
         if self.tokenizer.chat_template:
+            messages = body["messages"]
+            process_message_content(messages)
             prompt = self.tokenizer.apply_chat_template(
-                body["messages"],
+                messages,
                 body.get("tools", None),
                 add_generation_prompt=True,
             )
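
For illustration, a minimal standalone sketch (not part of the patch) of how process_message_content flattens list-form content in place, based on the hunk above:

    messages = [
        {"role": "user", "content": [{"type": "text", "text": "Hi "},
                                     {"type": "text", "text": "there!"}]},
        {"role": "assistant", "content": "Plain string content is left untouched."},
    ]
    process_message_content(messages)
    # messages[0]["content"] == "Hi there!"
    # A fragment with any other "type" (e.g. "image_url") raises ValueError.
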
diff --git a/llms/tests/test_server.py b/llms/tests/test_server.py
index ad17554d1..ecf95f78d 100644
--- a/llms/tests/test_server.py
+++ b/llms/tests/test_server.py
@@ -80,6 +80,29 @@ def test_handle_chat_completions(self):
         self.assertIn("id", response_body)
         self.assertIn("choices", response_body)
 
+    def test_handle_chat_completions_with_content_fragments(self):
+        url = f"http://localhost:{self.port}/v1/chat/completions"
+        chat_post_data = {
+            "model": "chat_model",
+            "max_tokens": 10,
+            "temperature": 0.7,
+            "top_p": 0.85,
+            "repetition_penalty": 1.2,
+            "messages": [
+                {
+                    "role": "system",
+                    "content": [
+                        {"type": "text", "text": "You are a helpful assistant."}
+                    ],
+                },
+                {"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
+            ],
+        }
+        response = requests.post(url, json=chat_post_data)
+        response_body = response.text
+        self.assertIn("id", response_body)
+        self.assertIn("choices", response_body)
+
     def test_handle_models(self):
         url = f"http://localhost:{self.port}/v1/models"
         response = requests.get(url)