
Commit

Merge pull request #171 from vintasoftware/structured-outputs
Enables structured output assistants
filipeximenes authored Sep 13, 2024
2 parents 2a2efc4 + b3109bc, commit 142f606
Showing 9 changed files with 1,315 additions and 59 deletions.
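The core of the change is a new `structured_output` class attribute on `AIAssistant`: set it to a Pydantic model (or another schema format accepted by LangChain's structured-output support) and `run()` returns an instance of that schema instead of free-form text. A minimal sketch of declaring such an assistant, using a hypothetical `WeatherReport` schema rather than the example shipped in this PR:

from pydantic import BaseModel, Field

from django_ai_assistant import AIAssistant


class WeatherReport(BaseModel):
    # Hypothetical schema for illustration only
    city: str = Field(description="The city the report is about")
    summary: str = Field(description="A one-sentence weather summary")


class WeatherAIAssistant(AIAssistant):
    id = "weather_assistant"  # noqa: A003
    name = "Weather Assistant"
    instructions = "You are a weather reporting assistant."
    model = "gpt-4o-2024-08-06"
    structured_output = WeatherReport  # run() will return a WeatherReport instance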
django_ai_assistant/helpers/assistants.py: 81 additions & 9 deletions
@@ -1,7 +1,8 @@
import abc
import inspect
import json
import re
from typing import Annotated, Any, ClassVar, Sequence, TypedDict, cast
from typing import Annotated, Any, ClassVar, Dict, Sequence, Type, TypedDict, cast

from langchain.chains.combine_documents.base import (
DEFAULT_DOCUMENT_PROMPT,
@@ -37,6 +38,7 @@
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from pydantic import BaseModel

from django_ai_assistant.decorators import with_cast_id
from django_ai_assistant.exceptions import (
@@ -79,6 +81,12 @@ class AIAssistant(abc.ABC):  # noqa: F821
When True, the assistant will use a retriever to get documents to provide as context to the LLM.
Additionally, the assistant class should implement the `get_retriever` method to return
the retriever to use."""
structured_output: Dict[str, Any] | Type[BaseModel] | Type | None = None
"""Structured output to use for the assistant.\n
Defaults to `None`.
When not `None`, the assistant will return a structured output in the provided format.
See https://python.langchain.com/v0.2/docs/how_to/structured_output/ for the available formats.
"""
_user: Any | None
"""The current user the assistant is helping. A model instance.\n
Set by the constructor.
@@ -269,6 +277,26 @@ def get_llm(self) -> BaseChatModel:
model_kwargs=model_kwargs,
)

def get_structured_output_llm(self) -> Runnable:
"""Get the LLM model to use for the structured output.
Returns:
BaseChatModel: The LLM model to use for the structured output.
"""
if not self.structured_output:
raise ValueError("structured_output is not defined")

llm = self.get_llm()

method = "json_mode"
if isinstance(llm, ChatOpenAI):
# When using ChatOpenAI, it's better to use json_schema method
# because it enables strict mode.
# https://platform.openai.com/docs/guides/structured-outputs
method = "json_schema"

return llm.with_structured_output(self.structured_output, method=method)

def get_tools(self) -> Sequence[BaseTool]:
"""Get the list of method tools the assistant can use.
By default, this is the `_method_tools` attribute, which are all `@method_tool`s.\n
@@ -422,7 +450,37 @@ class AgentState(TypedDict):
output: str

def setup(state: AgentState):
return {"messages": [SystemMessage(content=self.get_instructions())]}
messages: list[AnyMessage] = [SystemMessage(content=self.get_instructions())]

if self.structured_output:
schema = None

# If Pydantic
if inspect.isclass(self.structured_output) and issubclass(
self.structured_output, BaseModel
):
schema = json.dumps(self.structured_output.model_json_schema())

schema_information = (
f"JSON will have the following schema:\n\n{schema}\n\n" if schema else ""
)
tools_information = "Gather information using tools. " if tools else ""

# The assistant won't have access to the schema of the structured output before
# the last step of the chat. This message gives visibility about what fields the
# response should have so it can gather the necessary information by using tools.
messages.append(
HumanMessage(
content=(
"In the last step of this chat you will be asked to respond in JSON. "
+ schema_information
+ tools_information
+ "Don't generate JSON until you are explicitly told to. "
)
)
)

return {"messages": messages}

def history(state: AgentState):
history = message_history.messages if message_history else []
@@ -433,7 +491,9 @@ def retriever(state: AgentState):
return

retriever = self.get_history_aware_retriever()
messages_without_input = state["messages"][:-1]
# Remove the initial instructions to prevent having two SystemMessages
# This is necessary for compatibility with Anthropic
messages_without_input = state["messages"][1:-1]
docs = retriever.invoke({"input": state["input"], "history": messages_without_input})

document_separator = self.get_document_separator()
@@ -443,11 +503,10 @@
format_document(doc, document_prompt) for doc in docs
)

return {
"messages": SystemMessage(
content=f"---START OF CONTEXT---\n{formatted_docs}---END OF CONTEXT---\n"
)
}
system_message = state["messages"][0]
system_message.content += (
f"\n\n---START OF CONTEXT---\n{formatted_docs}---END OF CONTEXT---\n\n"
)

def agent(state: AgentState):
response = llm_with_tools.invoke(state["messages"])
@@ -463,7 +522,20 @@ def tool_selector(state: AgentState):
return "continue"

def record_response(state: AgentState):
return {"output": state["messages"][-1].content}
if self.structured_output:
llm_with_structured_output = self.get_structured_output_llm()
response = llm_with_structured_output.invoke(
[
*state["messages"],
HumanMessage(
content="Use the information gathered in the conversation to answer."
),
]
)
else:
response = state["messages"][-1].content

return {"output": response}

workflow = StateGraph(AgentState)

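Under the hood, `get_structured_output_llm()` wraps LangChain's `with_structured_output`, preferring the `json_schema` method for `ChatOpenAI` (strict mode) and falling back to `json_mode` for other chat models; `record_response()` then invokes that model once, at the very last step of the graph. A standalone sketch of the underlying call, assuming a recent `langchain-openai` that supports `method="json_schema"` and a hypothetical `Movie` schema:

from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field


class Movie(BaseModel):
    # Hypothetical schema for illustration only
    title: str = Field(description="The movie title")
    year: int = Field(description="The release year")


llm = ChatOpenAI(model="gpt-4o-2024-08-06")
# json_schema enables OpenAI's strict structured outputs; non-OpenAI models would use json_mode
structured_llm = llm.with_structured_output(Movie, method="json_schema")
result = structured_llm.invoke("Which movie won the Best Picture Oscar in 1998?")
# result is a Movie instance, e.g. Movie(title='Titanic', year=1997)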
example/demo/views.py: 1 addition & 3 deletions
@@ -1,5 +1,3 @@
import json

from django.contrib import messages
from django.http import JsonResponse
from django.shortcuts import get_object_or_404, redirect, render
@@ -122,4 +120,4 @@ def get(self, request, *args, **kwargs):
a = TourGuideAIAssistant()
data = a.run(f"My coordinates are: ({coordinates})")

return JsonResponse(json.loads(data))
return JsonResponse(data.model_dump())
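Since `run()` now hands back a Pydantic object when `structured_output` is set, the view serializes it with `model_dump()` instead of parsing a JSON string. A tiny illustrative sketch with a hypothetical one-field model (not the project's code, and the `JsonResponse` line assumes a configured Django project):

from django.http import JsonResponse
from pydantic import BaseModel


class Answer(BaseModel):
    # Hypothetical stand-in for a structured_output schema
    text: str


data = Answer(text="hello")
data.model_dump()                # {'text': 'hello'}, a plain dict
JsonResponse(data.model_dump())  # serializes cleanly, no json.loads needed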
example/tour_guide/ai_assistants.py: 15 additions & 27 deletions
@@ -2,49 +2,37 @@

from django.utils import timezone

from pydantic import BaseModel, Field

from django_ai_assistant import AIAssistant, method_tool
from tour_guide.integrations import fetch_points_of_interest


def _tour_guide_example_json():
return json.dumps(
{
"nearby_attractions": [
{
"attraction_name": f"<attraction-{i}-name-here>",
"attraction_description": f"<attraction-{i}-description-here>",
"attraction_url": f"<attraction-{i}-imdb-page-url-here>",
}
for i in range(1, 6)
]
},
indent=2,
).translate( # Necessary due to ChatPromptTemplate
str.maketrans(
{
"{": "{{",
"}": "}}",
}
)
class Attraction(BaseModel):
attraction_name: str = Field(description="The name of the attraction in english")
attraction_description: str = Field(
description="The description of the attraction, provide information in an entertaining way"
)
attraction_url: str = Field(
description="The URL of the attraction, keep empty if you don't have this information"
)


class TourGuide(BaseModel):
nearby_attractions: list[Attraction] = Field(description="The list of nearby attractions")


class TourGuideAIAssistant(AIAssistant):
id = "tour_guide_assistant" # noqa: A003
name = "Tour Guide Assistant"
instructions = (
"You are a tour guide assistant that offers information about nearby attractions. "
"You will receive the user coordinates and should use available tools to find nearby attractions. "
"Only include in your response the items that are relevant to a tourist visiting the area. "
"Only call the find_nearby_attractions tool once. "
"Your response should only contain valid JSON data. DON'T include '```json' in your response. "
"The JSON should be formatted according to the following structure: \n"
f"\n\n{_tour_guide_example_json()}\n\n\n"
"In the 'attraction_name' field provide the name of the attraction in english. "
"In the 'attraction_description' field generate an overview about the attraction with the most important information, "
"curiosities and interesting facts. "
"Only include a value for the 'attraction_url' field if you find a real value in the provided data otherwise keep it empty. "
)
model = "gpt-4o-2024-08-06"
structured_output = TourGuide

def get_instructions(self):
# Warning: this will use the server's timezone
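With the schema declared on the class, the example assistant can be exercised the same way the demo view does. A hedged end-to-end sketch, assuming the example project's Django settings, database, and OpenAI credentials are configured (the coordinates are illustrative):

from tour_guide.ai_assistants import TourGuideAIAssistant

a = TourGuideAIAssistant()
# run() returns a TourGuide instance rather than a JSON string
data = a.run("My coordinates are: (48.8584, 2.2945)")
for attraction in data.nearby_attractions:
    print(attraction.attraction_name, "-", attraction.attraction_url)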

