Skip to content

Commit

Permalink
Merge pull request #175 from vintasoftware/feat/movie-improvements
Browse files Browse the repository at this point in the history
Improve MovieRecommendationAIAssistant
  • Loading branch information
fjsj authored Oct 1, 2024
2 parents c1f3a82 + 28f70d6 commit 0f2ec94
Show file tree
Hide file tree
Showing 13 changed files with 481 additions and 1,271 deletions.
73 changes: 27 additions & 46 deletions django_ai_assistant/helpers/assistants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import abc
import inspect
import json
import re
from typing import Annotated, Any, ClassVar, Dict, Sequence, Type, TypedDict, cast

Expand Down Expand Up @@ -75,6 +74,8 @@ class AIAssistant(abc.ABC): # noqa: F821
"""
temperature: float = 1.0
"""Temperature to use for the assistant LLM model.\nDefaults to `1.0`."""
tool_max_concurrency: int = 1
"""Maximum number of tools to run concurrently / in parallel.\nDefaults to `1` (no concurrency)."""
has_rag: bool = False
"""Whether the assistant uses RAG (Retrieval-Augmented Generation) or not.\n
Defaults to `False`.
Expand Down Expand Up @@ -430,7 +431,7 @@ def as_graph(self, thread_id: Any | None = None) -> Runnable[dict, dict]:
llm_with_tools = llm.bind_tools(tools) if tools else llm

def custom_add_messages(left: list[BaseMessage], right: list[BaseMessage]):
result = add_messages(left, right)
result = add_messages(left, right) # type: ignore

if message_history:
messages_to_store = [
Expand All @@ -447,40 +448,11 @@ class AgentState(TypedDict):
messages: Annotated[list[AnyMessage], custom_add_messages]
input: str # noqa: A003
context: str
output: str
output: Any

def setup(state: AgentState):
messages: list[AnyMessage] = [SystemMessage(content=self.get_instructions())]

if self.structured_output:
schema = None

# If Pydantic
if inspect.isclass(self.structured_output) and issubclass(
self.structured_output, BaseModel
):
schema = json.dumps(self.structured_output.model_json_schema())

schema_information = (
f"JSON will have the following schema:\n\n{schema}\n\n" if schema else ""
)
tools_information = "Gather information using tools. " if tools else ""

# The assistant won't have access to the schema of the structured output before
# the last step of the chat. This message gives visibility about what fields the
# response should have so it can gather the necessary information by using tools.
messages.append(
HumanMessage(
content=(
"In the last step of this chat you will be asked to respond in JSON. "
+ schema_information
+ tools_information
+ "Don't generate JSON until you are explicitly told to. "
)
)
)

return {"messages": messages}
system_prompt = self.get_instructions()
return {"messages": [SystemMessage(content=system_prompt)]}

def history(state: AgentState):
history = message_history.messages if message_history else []
Expand Down Expand Up @@ -522,16 +494,23 @@ def tool_selector(state: AgentState):
return "continue"

def record_response(state: AgentState):
# Structured output must happen in the end, to avoid disabling tool calling.
# Tool calling + structured output is not supported by OpenAI:
if self.structured_output:
llm_with_structured_output = self.get_structured_output_llm()
response = llm_with_structured_output.invoke(
[
*state["messages"],
HumanMessage(
content="Use the information gathered in the conversation to answer."
),
]
messages = state["messages"]

# Change the original system prompt:
if isinstance(messages[0], SystemMessage):
messages[0].content += "\nUse the chat history to produce a JSON output."

# Add a final message asking for JSON generation / structured output:
json_request_message = HumanMessage(
content="Use the chat history to produce a JSON output."
)
messages.append(json_request_message)

llm_with_structured_output = self.get_structured_output_llm()
response = llm_with_structured_output.invoke(messages)
else:
response = state["messages"][-1].content

Expand Down Expand Up @@ -581,10 +560,12 @@ def invoke(self, *args: Any, thread_id: Any | None, **kwargs: Any) -> dict:
structured like `{"output": "assistant response", "history": ...}`.
"""
graph = self.as_graph(thread_id)
return graph.invoke(*args, **kwargs)
config = kwargs.pop("config", {})
config["max_concurrency"] = config.pop("max_concurrency", self.tool_max_concurrency)
return graph.invoke(*args, config=config, **kwargs)

@with_cast_id
def run(self, message: str, thread_id: Any | None = None, **kwargs: Any) -> str:
def run(self, message: str, thread_id: Any | None = None, **kwargs: Any) -> Any:
"""Run the assistant with the given message and thread ID.\n
This is the higher-level method to run the assistant.\n
Expand All @@ -595,7 +576,7 @@ def run(self, message: str, thread_id: Any | None = None, **kwargs: Any) -> str:
**kwargs: Additional keyword arguments to pass to the graph.
Returns:
str: The assistant response to the user message.
Any: The assistant response to the user message.
"""
return self.invoke(
{
Expand All @@ -605,7 +586,7 @@ def run(self, message: str, thread_id: Any | None = None, **kwargs: Any) -> str:
**kwargs,
)["output"]

def _run_as_tool(self, message: str, **kwargs: Any) -> str:
def _run_as_tool(self, message: str, **kwargs: Any) -> Any:
return self.run(message, thread_id=None, **kwargs)

def as_tool(self, description: str) -> BaseTool:
Expand Down
5 changes: 2 additions & 3 deletions example/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,8 @@ cp .env.example .env
Fill the `.env` file with the necessary API keys. You'll need accounts on:

- [OpenAI](https://platform.openai.com/)
- [Weather](https://www.weatherapi.com/)
- [Tavily](https://app.tavily.com/)
- [Firecrawl](https://www.firecrawl.dev/)
- [Weather API](https://www.weatherapi.com/)
- [Brave Search API](https://brave.com/search/api/)

Activate the poetry shell:

Expand Down
1 change: 0 additions & 1 deletion example/demo/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,5 +119,4 @@ def get(self, request, *args, **kwargs):

a = TourGuideAIAssistant()
data = a.run(f"My coordinates are: ({coordinates})")

return JsonResponse(data.model_dump())
5 changes: 5 additions & 0 deletions example/example/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@
# Necessary to avoid "OperationalError: database is locked" errors
# on parallel tool calling:
"init_command": "PRAGMA journal_mode=WAL;",
"transaction_mode": "IMMEDIATE",
"timeout": 20,
},
}
}
Expand Down Expand Up @@ -178,4 +180,7 @@
# Example specific settings:

WEATHER_API_KEY = os.getenv("WEATHER_API_KEY") # get for free at https://www.weatherapi.com/
BRAVE_SEARCH_API_KEY = os.getenv(
"BRAVE_SEARCH_API_KEY"
) # get for free at https://brave.com/search/api/
DJANGO_DOCS_BRANCH = "stable/5.0.x"
2 changes: 1 addition & 1 deletion example/issue_tracker/ai_assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class IssueTrackerAIAssistant(AIAssistant):
"Make sure to include issue IDs in your responses, "
"to know which issue you or the user are referring to. "
)
model = "gpt-4o"
model = "gpt-4o-mini"
_user: User

@method_tool
Expand Down
2 changes: 1 addition & 1 deletion example/movies/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ class MovieBacklogItemAdmin(admin.ModelAdmin):
list_select_related = ("user",)
raw_id_fields = ("user",)

@admin.display(ordering="imdb_url", description="IMDB URL")
@admin.display(ordering="imdb_url", description="IMDb URL")
def imdb_url_link(self, obj):
return mark_safe(f'<a href="{obj.imdb_url}">{obj.imdb_url}</a>') # noqa: S308
116 changes: 62 additions & 54 deletions example/movies/ai_assistants.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,42 @@
from typing import Sequence

from django.conf import settings
from django.db import transaction
from django.db.models import Max
from django.utils import timezone

from firecrawl import FirecrawlApp
from langchain_community.tools import WikipediaQueryRun
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.utilities import WikipediaAPIWrapper
import requests
from langchain_community.tools import BraveSearch
from langchain_core.tools import BaseTool
from pydantic import BaseModel

from django_ai_assistant import AIAssistant, method_tool
from movies.models import MovieBacklogItem


class IMDbMovie(BaseModel):
imdb_url: str
imdb_rating: float
scrapped_imdb_page_markdown: str


# Note this assistant is not registered, but we'll use it as a tool on the other.
# This one shouldn't be used directly, as it does web searches and scraping.
class IMDBURLFinderTool(AIAssistant):
id = "imdb_url_finder" # noqa: A003
class IMDbScraper(AIAssistant):
id = "imdb_scraper" # noqa: A003
instructions = (
"You're a tool to find the IMDB URL of a given movie. "
"Use the Tavily Search API to find the IMDB URL. "
"You're a function to find the IMDb URL of a given movie, "
"and scrape this URL to get the movie rating and other information.\n"
"Use the search function to find the IMDb URL. "
"Make search queries like: \n"
"- IMDB page of The Matrix\n"
"- IMDB page of The Godfather\n"
"- IMDB page of The Shawshank Redemption\n"
"Then check results and provide only the IMDB URL to the user."
"- IMDb page of The Matrix\n"
"- IMDb page of The Godfather\n"
"- IMDb page of The Shawshank Redemption\n"
"Then check results, scrape the IMDb URL, process the page, and produce a JSON output."
)
name = "IMDB URL Finder"
model = "gpt-4o"
name = "IMDb Scraper"
model = "gpt-4o-mini"
structured_output = IMDbMovie

def get_instructions(self):
# Warning: this will use the server's timezone
Expand All @@ -36,9 +45,16 @@ def get_instructions(self):
current_date_str = timezone.now().date().isoformat()
return f"{self.instructions} Today is: {current_date_str}."

@method_tool
def scrape_imdb_url(self, url: str) -> str:
"""Scrape the IMDb URL and return the content as markdown."""
return requests.get("https://r.jina.ai/" + url, timeout=20).text[:10000]

def get_tools(self) -> Sequence[BaseTool]:
return [
TavilySearchResults(),
BraveSearch.from_api_key(
api_key=settings.BRAVE_SEARCH_API_KEY, search_kwargs={"count": 5}
),
*super().get_tools(),
]

Expand All @@ -48,21 +64,17 @@ class MovieRecommendationAIAssistant(AIAssistant):
instructions = (
"You're a helpful movie recommendation assistant. "
"Help the user find movies to watch and manage their movie backlogs. "
"By using the provided tools, you can:\n"
"- Get the IMDB URL of a movie\n"
"- Visit the IMDB page of a movie to get its rating\n"
"- Research for upcoming movies\n"
"- Research for similar movies\n"
"- Research more information about movies\n"
"- Get what movies are on user's backlog\n"
"- Add a movie to user's backlog\n"
"- Remove a movie from user's backlog\n"
"- Reorder movies in user's backlog\n"
"Use the provided functions to answer questions and run operations.\n"
"Note the backlog is stored in a DB. "
"When managing the backlog, you must call the functions, to keep the sync with the DB. "
"The backlog has an order, and you should respect it. Call `reorder_backlog` when necessary.\n"
"Include the IMDb URL and rating of the movies when displaying the backlog. "
"You must use the IMDb Scraper to get the IMDb URL and rating of the movies. \n"
"Ask the user if they want to add your recommended movies to their backlog, "
"but only if the movie is not on the user's backlog yet."
)
name = "Movie Recommendation Assistant"
model = "gpt-4o"
model = "gpt-4o-mini"

def get_instructions(self):
# Warning: this will use the server's timezone
Expand All @@ -81,28 +93,21 @@ def get_instructions(self):

def get_tools(self) -> Sequence[BaseTool]:
return [
TavilySearchResults(),
WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper()), # pyright: ignore[reportCallIssue]
IMDBURLFinderTool().as_tool(description="Tool to find the IMDB URL of a given movie."),
BraveSearch.from_api_key(
api_key=settings.BRAVE_SEARCH_API_KEY, search_kwargs={"count": 5}
),
IMDbScraper().as_tool(description="IMDb Scraper to get the IMDb data of a given movie."),
*super().get_tools(),
]

@method_tool
def firecrawl_scrape_url(self, url: str) -> str:
"""Visit the provided website URL and return the content as markdown."""

firecrawl_app = FirecrawlApp()
response = firecrawl_app.scrape_url(url, params={"formats": ["markdown"]})
return response["markdown"]

@method_tool
def get_movies_backlog(self) -> str:
"""Get what movies are on user's backlog."""

return (
"\n".join(
[
f"{item.position} - [{item.movie_name}]({item.imdb_url}) - Rating {item.imdb_rating}"
f"{item.position} - [{item.movie_name}]({item.imdb_url}) - {item.imdb_rating}"
for item in MovieBacklogItem.objects.filter(user=self._user)
]
)
Expand All @@ -113,28 +118,31 @@ def get_movies_backlog(self) -> str:
def add_movie_to_backlog(self, movie_name: str, imdb_url: str, imdb_rating: float) -> str:
"""Add a movie to user's backlog. Must pass the movie_name, imdb_url, and imdb_rating."""

MovieBacklogItem.objects.update_or_create(
imdb_url=imdb_url.strip(),
user=self._user,
defaults={
"movie_name": movie_name.strip(),
"imdb_rating": imdb_rating,
"position": MovieBacklogItem.objects.filter(user=self._user).aggregate(
Max("position", default=1)
)["position__max"],
},
)
with transaction.atomic():
MovieBacklogItem.objects.update_or_create(
imdb_url=imdb_url.strip(),
user=self._user,
defaults={
"movie_name": movie_name.strip(),
"imdb_rating": imdb_rating,
"position": MovieBacklogItem.objects.filter(user=self._user).aggregate(
Max("position", default=0)
)["position__max"]
+ 1,
},
)
return f"Added {movie_name} to backlog."

@method_tool
def remove_movie_from_backlog(self, movie_name: str) -> str:
"""Remove a movie from user's backlog."""

MovieBacklogItem.objects.filter(
user=self._user,
movie_name=movie_name.strip(),
).delete()
MovieBacklogItem.reorder_backlog(self._user)
with transaction.atomic():
MovieBacklogItem.objects.filter(
user=self._user,
movie_name=movie_name.strip(),
).delete()
MovieBacklogItem.reorder_backlog(self._user)
return f"Removed {movie_name} from backlog."

@method_tool
Expand Down
2 changes: 1 addition & 1 deletion example/rag/ai_assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class DjangoDocsAssistant(AIAssistant):
"the user's question. If you don't know the answer, say that you don't know. "
"Use three sentences maximum and keep the answer concise."
)
model = "gpt-4o"
model = "gpt-4o-mini"
has_rag = True

def get_retriever(self) -> BaseRetriever:
Expand Down
2 changes: 1 addition & 1 deletion example/weather/ai_assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
class WeatherAIAssistant(AIAssistant):
id = "weather_assistant" # noqa: A003
name = "Weather Assistant"
model = "gpt-4o"
model = "gpt-4o-mini"

def get_instructions(self):
# Warning: this will use the server's timezone
Expand Down
Loading

0 comments on commit 0f2ec94

Please sign in to comment.