Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move all examples to a single Django project #77

Merged
merged 6 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 49 additions & 32 deletions django_ai_assistant/helpers/assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Any, ClassVar, Sequence, cast

from django.http import HttpRequest
from django.views import View

from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad.tools import (
Expand Down Expand Up @@ -223,7 +222,7 @@ def get_history_aware_retriever(self) -> Runnable[dict, RetrieverOutput]:
def as_chain(self, thread_id: int | None) -> Runnable[dict, dict]:
# Based on:
# - https://python.langchain.com/v0.2/docs/how_to/qa_chat_history_how_to/
# - https://python.langchain.com/v0.2/docs/how_to/migrate_agent/#memory
# - https://python.langchain.com/v0.2/docs/how_to/migrate_agent/
# TODO: use langgraph instead?
llm = self.get_llm()
tools = self.get_tools()
Expand Down Expand Up @@ -312,18 +311,43 @@ def register_assistant(cls: type[AIAssistant]):
return cls


def _get_assistant_cls(
assistant_id: str,
user: Any,
request: HttpRequest | None = None,
):
if assistant_id not in ASSISTANT_CLS_REGISTRY:
raise AIAssistantNotDefinedError(f"Assistant with id={assistant_id} not found")
assistant_cls = ASSISTANT_CLS_REGISTRY[assistant_id]
if not can_run_assistant(
assistant_cls=assistant_cls,
user=user,
request=request,
):
raise AIUserNotAllowedError("User is not allowed to use this assistant")
return assistant_cls


def get_single_assistant_info(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'll use that in another task to show the assistant proper name in the frontend: #78

assistant_id: str,
user: Any,
request: HttpRequest | None = None,
):
assistant_cls = _get_assistant_cls(assistant_id, user, request)

return {
"id": assistant_id,
"name": assistant_cls.name,
}


def get_assistants_info(
user: Any,
request: HttpRequest | None = None,
view: View | None = None,
):
return [
{
"id": assistant_id,
"name": assistant_cls.name,
}
for assistant_id, assistant_cls in ASSISTANT_CLS_REGISTRY.items()
if can_run_assistant(assistant_cls=assistant_cls, user=user, request=request, view=view)
_get_assistant_cls(assistant_id=assistant_id, user=user, request=request)
for assistant_id in ASSISTANT_CLS_REGISTRY.keys()
]


Expand All @@ -333,23 +357,14 @@ def create_message(
user: Any,
content: Any,
request: HttpRequest | None = None,
view: View | None = None,
):
if not can_create_message(thread=thread, user=user, request=request, view=view):
assistant_cls = _get_assistant_cls(assistant_id, user, request)

if not can_create_message(thread=thread, user=user, request=request):
raise AIUserNotAllowedError("User is not allowed to create messages in this thread")
if assistant_id not in ASSISTANT_CLS_REGISTRY:
raise AIAssistantNotDefinedError(f"Assistant with id={assistant_id} not found")
assistant_cls = ASSISTANT_CLS_REGISTRY[assistant_id]
if not can_run_assistant(
assistant_cls=assistant_cls,
user=user,
request=request,
view=view,
):
raise AIUserNotAllowedError("User is not allowed to use this assistant")

# TODO: Check if we can separate the message creation from the chain invoke
assistant = assistant_cls(user=user, request=request, view=view)
assistant = assistant_cls(user=user, request=request)
assistant_message = assistant.invoke(
{"input": content},
thread_id=thread.id,
Expand All @@ -361,19 +376,25 @@ def create_thread(
name: str,
user: Any,
request: HttpRequest | None = None,
view: View | None = None,
):
if not can_create_thread(user=user, request=request, view=view):
if not can_create_thread(user=user, request=request):
raise AIUserNotAllowedError("User is not allowed to create threads")

thread = Thread.objects.create(name=name, created_by=user)
return thread


def get_single_thread(
thread_id: str,
user: Any,
request: HttpRequest | None = None,
):
return Thread.objects.filter(created_by=user).get(id=thread_id)


def get_threads(
user: Any,
request: HttpRequest | None = None,
view: View | None = None,
):
return list(Thread.objects.filter(created_by=user))

Expand All @@ -382,9 +403,8 @@ def delete_thread(
thread: Thread,
user: Any,
request: HttpRequest | None = None,
view: View | None = None,
):
if not can_delete_thread(thread=thread, user=user, request=request, view=view):
if not can_delete_thread(thread=thread, user=user, request=request):
raise AIUserNotAllowedError("User is not allowed to delete this thread")

return thread.delete()
Expand All @@ -394,7 +414,6 @@ def get_thread_messages(
thread_id: str,
user: Any,
request: HttpRequest | None = None,
view: View | None = None,
) -> list[BaseMessage]:
# TODO: have more permissions for threads? View thread permission?
thread = Thread.objects.get(id=thread_id)
Expand All @@ -409,7 +428,6 @@ def create_thread_message_as_user(
content: str,
user: Any,
request: HttpRequest | None = None,
view: View | None = None,
):
# TODO: have more permissions for threads? View thread permission?
thread = Thread.objects.get(id=thread_id)
Expand All @@ -423,9 +441,8 @@ def delete_message(
message: Message,
user: Any,
request: HttpRequest | None = None,
view: View | None = None,
):
if not can_delete_message(message=message, user=user, request=request, view=view):
if not can_delete_message(message=message, user=user, request=request):
raise AIUserNotAllowedError("User is not allowed to delete this message")

return DjangoChatMessageHistory(thread_id=message.thread_id).remove_messages([message.id])
return DjangoChatMessageHistory(thread_id=message.thread_id).remove_messages([str(message.id)])
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Type error.

33 changes: 18 additions & 15 deletions django_ai_assistant/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from .helpers.assistants import (
create_message,
get_assistants_info,
get_single_assistant_info,
get_single_thread,
get_thread_messages,
get_threads,
)
Expand Down Expand Up @@ -39,29 +41,35 @@ def ai_user_not_allowed_handler(request, exc):

@api.get("assistants/", response=List[AssistantSchema], url_name="assistants_list")
def list_assistants(request):
return list(get_assistants_info(user=request.user, request=request, view=None))
return list(get_assistants_info(user=request.user, request=request))


@api.get("assistants/{assistant_id}/", response=AssistantSchema, url_name="assistant_detail")
def get_assistant(request, assistant_id: str):
return get_single_assistant_info(assistant_id=assistant_id, user=request.user, request=request)


@api.get("threads/", response=List[ThreadSchema], url_name="threads_list_create")
def list_threads(request):
return list(get_threads(user=request.user, request=request, view=None))
return list(get_threads(user=request.user, request=request))


@api.get("threads/{thread_id}/", response=ThreadSchema, url_name="thread_detail")
def get_thread(request, thread_id: str):
thread = get_single_thread(thread_id=thread_id, user=request.user, request=request)
return thread


@api.post("threads/", response=ThreadSchema, url_name="threads_list_create")
def create_thread(request, payload: ThreadSchemaIn):
name = payload.name
return assistants.create_thread(name=name, user=request.user, request=request, view=None)
return assistants.create_thread(name=name, user=request.user, request=request)


@api.delete("threads/{thread_id}/", response={204: None}, url_name="threads_delete")
fjsj marked this conversation as resolved.
Show resolved Hide resolved
def delete_thread(request, thread_id: str):
thread = get_object_or_404(Thread, id=thread_id)
assistants.delete_thread(
thread=thread,
user=request.user,
request=request,
view=None,
)
assistants.delete_thread(thread=thread, user=request.user, request=request)
return 204, None


Expand All @@ -71,9 +79,7 @@ def delete_thread(request, thread_id: str):
url_name="messages_list_create",
)
def list_thread_messages(request, thread_id: str):
messages = get_thread_messages(
thread_id=thread_id, user=request.user, request=request, view=None
)
messages = get_thread_messages(thread_id=thread_id, user=request.user, request=request)
return [message_to_dict(m)["data"] for m in messages]


Expand All @@ -92,9 +98,7 @@ def create_thread_message(request, thread_id: str, payload: ThreadMessagesSchema
user=request.user,
content=payload.content,
request=request,
view=None,
)

return 201, None


Expand All @@ -107,6 +111,5 @@ def delete_thread_message(request, thread_id: str, message_id: str):
message=message,
user=request.user,
request=request,
view=None,
)
return 204, None
60 changes: 57 additions & 3 deletions example/assets/js/App.tsx
Original file line number Diff line number Diff line change
@@ -1,19 +1,73 @@
import "@mantine/core/styles.css";

import { createTheme, MantineProvider } from "@mantine/core";
import { Chat } from "@/Chat";
import { Container, createTheme, MantineProvider } from "@mantine/core";
import { Chat } from "@/components";
import { createBrowserRouter, Link, RouterProvider } from "react-router-dom";
import { configAIAssistant } from "django-ai-assistant-client";
import React from "react";
fjsj marked this conversation as resolved.
Show resolved Hide resolved

const theme = createTheme({});

// Relates to path("ai-assistant/", include("django_ai_assistant.urls"))
// which can be found at example/demo/urls.py)
configAIAssistant({ baseURL: "ai-assistant" });

const ExampleIndex = () => {
return (
<Container>
<h1>Examples</h1>
<ul>
<li>
<Link to="/weather-chat">Weather Chat</Link>
</li>
<li>
<Link to="/movies-chat">Movie Recommendation Chat</Link>
</li>
<li>
<Link to="/rag-chat">Django Docs RAG Chat</Link>
</li>
<li>
<Link to="/htmx">HTMX demo (no React)</Link>
</li>
</ul>
</Container>
);
};
fjsj marked this conversation as resolved.
Show resolved Hide resolved

const Redirect = ({ to }: { to: string }) => {
window.location.href = to;
return null;
};

const router = createBrowserRouter([
{
path: "/",
element: <ExampleIndex />,
},
{
path: "/weather-chat",
element: <Chat assistantId="weather_assistant" />,
},
{
path: "/movies-chat",
element: <Chat assistantId="movie_recommendation_assistant" />,
},
{
path: "/rag-chat",
element: <Chat assistantId="django_docs_assistant" />,
},
{
path: "/htmx",
element: <Redirect to="/htmx/" />,
},
]);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please review.


const App = () => {
return (
<MantineProvider theme={theme}>
<Chat />
<React.StrictMode>
<RouterProvider router={router} />
</React.StrictMode>
</MantineProvider>
);
};
Expand Down
1 change: 0 additions & 1 deletion example/assets/js/Chat/index.ts

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,7 @@
.chat {
height: 100%;
}

.mdMessage p {
margin: 0;
}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reduced the margins:
Screenshot from 2024-06-14 15-33-58

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks better!

Loading
Loading