chore(mlx-lm): support text type content in messages (#1225)
* chore(mlx-lm): support text type content

* chore: optimize the message content processing

* nits + format

---------

Co-authored-by: Awni Hannun <[email protected]>
mzbac and awni authored Jan 28, 2025
1 parent f44a52e commit 7a83077
Showing 2 changed files with 53 additions and 1 deletion.
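In practice, the change lets clients send OpenAI-style structured content, where a message's 'content' is a list of text fragments rather than a plain string. A minimal sketch of such a request, assuming the server is running on its default 127.0.0.1:8080 and with a placeholder model name:

import requests

# 'content' may now be a list of {"type": "text", "text": ...} fragments;
# the server flattens them into a single string before applying the chat template.
response = requests.post(
    "http://127.0.0.1:8080/v1/chat/completions",
    json={
        "model": "default_model",  # placeholder model name
        "max_tokens": 10,
        "messages": [
            {"role": "user", "content": [{"type": "text", "text": "Hello!"}]}
        ],
    },
)
print(response.json())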
31 changes: 30 additions & 1 deletion llms/mlx_lm/server.py
@@ -114,6 +114,33 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
     return prompt.rstrip()
 
 
+def process_message_content(messages):
+    """
+    Convert message content to a format suitable for `apply_chat_template`.
+
+    The function operates on messages in place. It converts the 'content' field
+    to a string instead of a list of text fragments.
+
+    Args:
+        messages (list): A list of message dictionaries, where each dictionary
+            may have a 'content' key containing a list of dictionaries with
+            'type' and 'text' keys.
+
+    Raises:
+        ValueError: If the 'content' type is not supported or if 'text' is missing.
+    """
+    for message in messages:
+        content = message["content"]
+        if isinstance(content, list):
+            text_fragments = [
+                fragment["text"] for fragment in content if fragment["type"] == "text"
+            ]
+            if len(text_fragments) != len(content):
+                raise ValueError("Only 'text' content type is supported.")
+            message["content"] = "".join(text_fragments)
+
+
 @dataclass
 class PromptCache:
     cache: List[Any] = field(default_factory=list)
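For illustration, a minimal sketch of the helper's in-place behavior, assuming mlx_lm is installed so that process_message_content is importable:

from mlx_lm.server import process_message_content

# Fragment lists are flattened into a single string, in place.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Hello, "},
            {"type": "text", "text": "world!"},
        ],
    }
]
process_message_content(messages)
assert messages[0]["content"] == "Hello, world!"

# A non-text fragment (e.g. {"type": "image_url", ...}) raises ValueError,
# since only 'text' fragments are supported. Plain-string contents are left
# untouched.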
@@ -591,8 +618,10 @@ def handle_chat_completions(self) -> List[int]:
         self.request_id = f"chatcmpl-{uuid.uuid4()}"
         self.object_type = "chat.completion.chunk" if self.stream else "chat.completion"
         if self.tokenizer.chat_template:
+            messages = body["messages"]
+            process_message_content(messages)
             prompt = self.tokenizer.apply_chat_template(
-                body["messages"],
+                messages,
                 body.get("tools", None),
                 add_generation_prompt=True,
             )
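For context, a sketch of the round trip the handler now performs, run outside the server. The checkpoint name is only an example of a chat-templated tokenizer and is not part of the commit:

from transformers import AutoTokenizer

from mlx_lm.server import process_message_content

# Any checkpoint with a chat template works here; this one is just an example.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

body = {
    "messages": [
        {"role": "user", "content": [{"type": "text", "text": "Hello!"}]}
    ]
}
messages = body["messages"]
process_message_content(messages)  # contents are now plain strings

prompt = tokenizer.apply_chat_template(
    messages,
    body.get("tools", None),
    add_generation_prompt=True,
)
print(prompt)  # token ids by default; pass tokenize=False to get the text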
23 changes: 23 additions & 0 deletions llms/tests/test_server.py
@@ -80,6 +80,29 @@ def test_handle_chat_completions(self):
         self.assertIn("id", response_body)
         self.assertIn("choices", response_body)
 
+    def test_handle_chat_completions_with_content_fragments(self):
+        url = f"http://localhost:{self.port}/v1/chat/completions"
+        chat_post_data = {
+            "model": "chat_model",
+            "max_tokens": 10,
+            "temperature": 0.7,
+            "top_p": 0.85,
+            "repetition_penalty": 1.2,
+            "messages": [
+                {
+                    "role": "system",
+                    "content": [
+                        {"type": "text", "text": "You are a helpful assistant."}
+                    ],
+                },
+                {"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
+            ],
+        }
+        response = requests.post(url, json=chat_post_data)
+        response_body = response.text
+        self.assertIn("id", response_body)
+        self.assertIn("choices", response_body)
+
     def test_handle_models(self):
         url = f"http://localhost:{self.port}/v1/models"
         response = requests.get(url)
