Skip to content

Commit

Permalink
Refactor format_chat_entry to add input_urls for images, audio, and documents to content_text
Browse files Browse the repository at this point in the history
  • Loading branch information
milovate committed Mar 7, 2025
1 parent 21d9bec commit cb8d5bb
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 16 deletions.
4 changes: 2 additions & 2 deletions bots/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1440,8 +1440,8 @@ def as_llm_context(
for i, msg in enumerate(reversed(msgs)):
entries[i] = format_chat_entry(
role=msg.role,
content=msg.content,
images=msg.attachments.filter(
content_text=msg.content,
input_images=msg.attachments.filter(
metadata__mime_type__startswith="image/"
).values_list("url", flat=True),
)
Expand Down
35 changes: 28 additions & 7 deletions daras_ai_v2/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from enum import Enum
from functools import wraps

import furl
import aifail
import requests
import typing_extensions
Expand Down Expand Up @@ -702,7 +703,7 @@ def run_language_model(
if prompt and not messages:
# convert text prompt to chat messages
messages = [
format_chat_entry(role=CHATML_ROLE_USER, content=prompt),
format_chat_entry(role=CHATML_ROLE_USER, content_text=prompt),
]
if not model.is_vision_model:
# remove images from the messages
Expand All @@ -718,7 +719,7 @@ def run_language_model(
messages.insert(
0,
format_chat_entry(
role=CHATML_ROLE_SYSTEM, content=DEFAULT_JSON_PROMPT
role=CHATML_ROLE_SYSTEM, content_text=DEFAULT_JSON_PROMPT
),
)
else:
Expand Down Expand Up @@ -781,7 +782,7 @@ def run_language_model(
if stream:
ret = [
[
format_chat_entry(role=CHATML_ROLE_ASSISTANT, content=msg)
format_chat_entry(role=CHATML_ROLE_ASSISTANT, content_text=msg)
for msg in ret
]
]
Expand Down Expand Up @@ -1228,7 +1229,7 @@ def run_openai_chat(
if stream:
return _stream_openai_chunked(completion, used_model, messages)
if not completion or not completion.choices:
return [format_chat_entry(role=CHATML_ROLE_ASSISTANT, content="")]
return [format_chat_entry(role=CHATML_ROLE_ASSISTANT, content_text="")]
else:
ret = [choice.message.dict() for choice in completion.choices]
record_openai_llm_usage(used_model, completion, messages, ret)
Expand Down Expand Up @@ -1964,11 +1965,31 @@ def entry_to_prompt_str(entry: ConversationEntry) -> str:


def format_chat_entry(
*, role: str, content: str, images: list[str] = None
*,
role: str,
content_text: str,
input_images: typing.Optional[list[str]] = None,
input_audio: typing.Optional[str] = None,
input_documents: typing.Optional[list[str]] = None,
render_input_urls: typing.Optional[bool] = False,
) -> ConversationEntry:
if images:

input_urls = []
if input_images and not render_input_urls:
input_urls.append(f"Image URLs: {', '.join(input_images)}")
if input_audio and not render_input_urls:
input_urls.append(f"Audio URL: {input_audio}")
if input_documents and not render_input_urls:
filenames = ", ".join(
            f"`{furl.furl(url.strip('/')).path.segments[-1]}` ({url})"
for url in input_documents
)
input_urls.append(f"Document URLs: {filenames}")

content = content_text + ("\n" + "\n".join(input_urls) if input_urls else "")
if input_images:
content = [
{"type": "image_url", "image_url": {"url": url}} for url in images
{"type": "image_url", "image_url": {"url": url}} for url in input_images
] + [
{"type": "text", "text": content},
]
Expand Down
23 changes: 16 additions & 7 deletions recipes/VideoBots.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,12 +720,14 @@ def on_send(
gui.session_state["messages"] = gui.session_state.get("messages", []) + [
format_chat_entry(
role=CHATML_ROLE_USER,
content=prev_input,
images=prev_input_images,
content_text=prev_input,
input_images=prev_input_images,
input_audio=prev_input_audio,
input_documents=prev_input_documents,
),
format_chat_entry(
role=CHATML_ROLE_ASSISTANT,
content=prev_output,
content_text=prev_output,
),
]

Expand Down Expand Up @@ -974,7 +976,7 @@ def input_translation_step(self, request, user_input, ocr_texts):
return user_input

def build_final_prompt(self, request, response, user_input, model):
# consturct the system prompt
# construct the system prompt
bot_script = (request.bot_script or "").strip()
if bot_script:
bot_script = render_prompt_vars(bot_script, gui.session_state)
Expand All @@ -987,7 +989,11 @@ def build_final_prompt(self, request, response, user_input, model):
user_input = yield from self.search_step(request, response, user_input, model)
# construct user prompt
user_prompt = format_chat_entry(
role=CHATML_ROLE_USER, content=user_input, images=request.input_images
role=CHATML_ROLE_USER,
content_text=user_input,
input_images=request.input_images,
input_audio=request.input_audio,
input_documents=request.input_documents,
)
# truncate the history to fit the model's max tokens
max_history_tokens = (
Expand Down Expand Up @@ -1017,7 +1023,7 @@ def search_step(self, request, response, user_input, model):
if request.documents:
# formulate the search query as a history of all the messages
query_msgs = request.messages + [
format_chat_entry(role=CHATML_ROLE_USER, content=user_input)
format_chat_entry(role=CHATML_ROLE_USER, content_text=user_input)
]
clip_idx = convo_window_clipper(query_msgs, model.context_window // 2)
query_msgs = query_msgs[clip_idx:]
Expand Down Expand Up @@ -1658,7 +1664,10 @@ def render_chat_list_view(self):
if input_prompt or input_images or input_audio:
messages += [
format_chat_entry(
role=CHATML_ROLE_USER, content=input_prompt, images=input_images
role=CHATML_ROLE_USER,
content_text=input_prompt,
input_images=input_images,
render_input_urls=True,
),
]
# render history
Expand Down

0 comments on commit cb8d5bb

Please sign in to comment.