Skip to content

Commit

Permalink
Merge pull request #110 from vintasoftware/feat/support-any-type-of-id
Browse files Browse the repository at this point in the history
Support any type of id
  • Loading branch information
fjsj authored Jun 25, 2024
2 parents 7197c7e + 99623d3 commit 73503db
Show file tree
Hide file tree
Showing 11 changed files with 197 additions and 46 deletions.
21 changes: 14 additions & 7 deletions django_ai_assistant/api/views.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import Any, List

from django.http import Http404
from django.shortcuts import get_object_or_404
Expand All @@ -17,6 +17,7 @@
ThreadSchemaIn,
)
from django_ai_assistant.conf import app_settings
from django_ai_assistant.decorators import with_cast_id
from django_ai_assistant.exceptions import AIAssistantNotDefinedError, AIUserNotAllowedError
from django_ai_assistant.helpers import use_cases
from django_ai_assistant.models import Message, Thread
Expand Down Expand Up @@ -85,7 +86,8 @@ def create_thread(request, payload: ThreadSchemaIn):


@api.get("threads/{thread_id}/", response=ThreadSchema, url_name="thread_detail_update_delete")
def get_thread(request, thread_id: str):
@with_cast_id
def get_thread(request, thread_id: Any):
try:
thread = use_cases.get_single_thread(
thread_id=thread_id, user=request.user, request=request
Expand All @@ -96,14 +98,16 @@ def get_thread(request, thread_id: str):


@api.patch("threads/{thread_id}/", response=ThreadSchema, url_name="thread_detail_update_delete")
def update_thread(request, thread_id: str, payload: ThreadSchemaIn):
@with_cast_id
def update_thread(request, thread_id: Any, payload: ThreadSchemaIn):
thread = get_object_or_404(Thread, id=thread_id)
name = payload.name
return use_cases.update_thread(thread=thread, name=name, user=request.user, request=request)


@api.delete("threads/{thread_id}/", response={204: None}, url_name="thread_detail_update_delete")
def delete_thread(request, thread_id: str):
@with_cast_id
def delete_thread(request, thread_id: Any):
thread = get_object_or_404(Thread, id=thread_id)
use_cases.delete_thread(thread=thread, user=request.user, request=request)
return 204, None
Expand All @@ -114,7 +118,8 @@ def delete_thread(request, thread_id: str):
response=List[ThreadMessagesSchemaOut],
url_name="messages_list_create",
)
def list_thread_messages(request, thread_id: str):
@with_cast_id
def list_thread_messages(request, thread_id: Any):
thread = get_object_or_404(Thread, id=thread_id)
messages = use_cases.get_thread_messages(thread=thread, user=request.user, request=request)
return [message_to_dict(m)["data"] for m in messages]
Expand All @@ -126,7 +131,8 @@ def list_thread_messages(request, thread_id: str):
response={201: None},
url_name="messages_list_create",
)
def create_thread_message(request, thread_id: str, payload: ThreadMessagesSchemaIn):
@with_cast_id
def create_thread_message(request, thread_id: Any, payload: ThreadMessagesSchemaIn):
thread = Thread.objects.get(id=thread_id)

use_cases.create_message(
Expand All @@ -142,7 +148,8 @@ def create_thread_message(request, thread_id: str, payload: ThreadMessagesSchema
@api.delete(
"threads/{thread_id}/messages/{message_id}/", response={204: None}, url_name="messages_delete"
)
def delete_thread_message(request, thread_id: str, message_id: str):
@with_cast_id
def delete_thread_message(request, thread_id: Any, message_id: Any):
message = get_object_or_404(Message, id=message_id, thread_id=thread_id)
use_cases.delete_message(
message=message,
Expand Down
35 changes: 35 additions & 0 deletions django_ai_assistant/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import uuid
from functools import wraps


def _cast_id(item_id, model):
if isinstance(item_id, str) and "UUID" in model._meta.pk.get_internal_type():
return uuid.UUID(item_id)
return item_id


# Decorator to cast ids to the correct type when using workaround UUIDAutoField
def with_cast_id(func):
@wraps(func)
def wrapper(*args, **kwargs):
from django_ai_assistant.models import Message, Thread

thread_id = kwargs.get("thread_id")
message_id = kwargs.get("message_id")
message_ids = kwargs.get("message_ids")

if thread_id:
thread_id = _cast_id(thread_id, Thread)
kwargs["thread_id"] = thread_id

if message_id:
message_id = _cast_id(message_id, Message)
kwargs["message_id"] = message_id

if message_ids:
message_ids = [_cast_id(message_id, Message) for message_id in message_ids]
kwargs["message_ids"] = message_ids

return func(*args, **kwargs)

return wrapper
21 changes: 13 additions & 8 deletions django_ai_assistant/helpers/assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from langchain_core.tools import BaseTool
from langchain_openai import ChatOpenAI

from django_ai_assistant.decorators import with_cast_id
from django_ai_assistant.exceptions import (
AIAssistantMisconfiguredError,
)
Expand Down Expand Up @@ -279,13 +280,14 @@ def get_prompt_template(self) -> ChatPromptTemplate:
]
)

def get_message_history(self, thread_id: int | None) -> BaseChatMessageHistory:
@with_cast_id
def get_message_history(self, thread_id: Any | None) -> BaseChatMessageHistory:
"""Get the chat message history instance for the given `thread_id`.\n
The Langchain chain uses the return of this method to get the thread messages
for the assistant, filling the `history` placeholder in the `get_prompt_template`.\n
Args:
thread_id (int | None): The thread ID for the chat message history.
thread_id (Any | None): The thread ID for the chat message history.
If `None`, an in-memory chat message history is used.
Returns:
Expand Down Expand Up @@ -430,7 +432,8 @@ def get_history_aware_retriever(self) -> Runnable[dict, RetrieverOutput]:
prompt | llm | StrOutputParser() | retriever,
)

def as_chain(self, thread_id: int | None) -> Runnable[dict, dict]:
@with_cast_id
def as_chain(self, thread_id: Any | None) -> Runnable[dict, dict]:
"""Create the Langchain chain for the assistant.\n
This chain is an agent that supports chat history, tool calling, and RAG (if `has_rag=True`).\n
`as_chain` uses many other methods to create the chain.\n
Expand All @@ -442,7 +445,7 @@ def as_chain(self, thread_id: int | None) -> Runnable[dict, dict]:
along with the key `"history"` containing the previous chat history.
Args:
thread_id (int | None): The thread ID for the chat message history.
thread_id (Any | None): The thread ID for the chat message history.
If `None`, an in-memory chat message history is used.
Returns:
Expand Down Expand Up @@ -514,15 +517,16 @@ def as_chain(self, thread_id: int | None) -> Runnable[dict, dict]:

return agent_with_chat_history

def invoke(self, *args: Any, thread_id: int | None, **kwargs: Any) -> dict:
@with_cast_id
def invoke(self, *args: Any, thread_id: Any | None, **kwargs: Any) -> dict:
"""Invoke the assistant Langchain chain with the given arguments and keyword arguments.\n
This is the lower-level method to run the assistant.\n
The chain is created by the `as_chain` method.\n
Args:
*args: Positional arguments to pass to the chain.
Make sure to include a `dict` like `{"input": "user message"}`.
thread_id (int | None): The thread ID for the chat message history.
thread_id (Any | None): The thread ID for the chat message history.
If `None`, an in-memory chat message history is used.
**kwargs: Keyword arguments to pass to the chain.
Expand All @@ -533,13 +537,14 @@ def invoke(self, *args: Any, thread_id: int | None, **kwargs: Any) -> dict:
chain = self.as_chain(thread_id)
return chain.invoke(*args, **kwargs)

def run(self, message: str, thread_id: int | None, **kwargs: Any) -> str:
@with_cast_id
def run(self, message: str, thread_id: Any | None, **kwargs: Any) -> str:
"""Run the assistant with the given message and thread ID.\n
This is the higher-level method to run the assistant.\n
Args:
message (str): The user message to pass to the assistant.
thread_id (int | None): The thread ID for the chat message history.
thread_id (Any | None): The thread ID for the chat message history.
If `None`, an in-memory chat message history is used.
**kwargs: Additional keyword arguments to pass to the chain.
Expand Down
6 changes: 4 additions & 2 deletions django_ai_assistant/helpers/use_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def create_thread(


def get_single_thread(
thread_id: str,
thread_id: Any,
user: Any,
request: HttpRequest | None = None,
) -> Thread:
Expand Down Expand Up @@ -287,4 +287,6 @@ def delete_message(
if not can_delete_message(message=message, user=user, request=request):
raise AIUserNotAllowedError("User is not allowed to delete this message")

return DjangoChatMessageHistory(thread_id=message.thread_id).remove_messages([str(message.id)])
return DjangoChatMessageHistory(thread_id=message.thread_id).remove_messages(
message_ids=[str(message.id)]
)
8 changes: 6 additions & 2 deletions django_ai_assistant/langchain/chat_message_histories.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict

from django_ai_assistant.decorators import with_cast_id
from django_ai_assistant.models import Message


logger = logging.getLogger(__name__)


class DjangoChatMessageHistory(BaseChatMessageHistory):
@with_cast_id
def __init__(
self,
thread_id: Any,
Expand Down Expand Up @@ -103,15 +105,17 @@ async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:

await Message.objects.abulk_update(created_messages, ["message"])

def remove_messages(self, message_ids: List[str]) -> None:
@with_cast_id
def remove_messages(self, message_ids: List[Any]) -> None:
"""Remove messages from the chat thread.
Args:
message_ids: A list of message IDs to remove.
"""
Message.objects.filter(id__in=message_ids).delete()

async def aremove_messages(self, message_ids: List[str]) -> None:
@with_cast_id
async def aremove_messages(self, message_ids: List[Any]) -> None:
"""Remove messages from the chat thread.
Args:
Expand Down
5 changes: 2 additions & 3 deletions django_ai_assistant/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from typing import Any

from django.conf import settings
from django.db import models
Expand All @@ -9,7 +10,6 @@ class Thread(models.Model):
"""Thread model. A thread is a collection of messages between a user and the AI assistant.
Also called conversation or session."""

id: int # noqa: A003
messages: Manager["Message"]
name = models.CharField(max_length=255, blank=True)
"""Name of the thread. Can be blank."""
Expand Down Expand Up @@ -47,10 +47,9 @@ class Message(models.Model):
A message can be sent by a user or the AI assistant.\n
The message data is stored as a JSON field called `message`."""

id: int # noqa: A003
thread = models.ForeignKey(Thread, on_delete=models.CASCADE, related_name="messages")
"""Thread to which the message belongs."""
thread_id: int # noqa: A003
thread_id: Any # noqa: A003
message = models.JSONField()
"""Message content. This is a serialized Langchain `BaseMessage` that was serialized
with `message_to_dict` and can be deserialized with `messages_from_dict`."""
Expand Down
45 changes: 45 additions & 0 deletions docs/tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,51 @@ class DocsAssistant(AIAssistant):
The `rag/ai_assistants.py` file in the [example project](https://github.com/vintasoftware/django-ai-assistant/tree/main/example)
shows an example of a RAG-powered AI Assistant that's able to answer questions about Django using the Django Documentation as context.

### Support for other types of Primary Key (PK)

You can have Django AI Assistant models use other types of primary key, such as strings, UUIDs, etc.
This is useful if you're concerned about leaking IDs that exponse thread count, message count, etc. to the frontend.
When using UUIDs, it will prevent users from figuring out if a thread or message exist or not (due to HTTP 404 vs 403).

Here are the files you have to change if you need the ids to be UUID:

```{.python title="myapp/fields.py"}
import uuid
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.models import AutoField, UUIDField
BaseDatabaseOperations.integer_field_ranges['UUIDField'] = (0, 0)
class UUIDAutoField(UUIDField, AutoField):
def __init__(self, *args, **kwargs):
kwargs.setdefault('default', uuid.uuid4)
kwargs.setdefault('editable', False)
super().__init__(*args, **kwargs)
```

```{.python title="myapp/apps.py"}
from django_ai_assistant.apps import AIAssistantConfig
class AIAssistantConfigOverride(AIAssistantConfig):
default_auto_field = "django_ai_assistant.api.fields.UUIDAutoField"
```

```{.python title="myproject/settings.py"}
INSTALLED_APPS = [
# "django_ai_assistant", remove this line and add the one below
"example.apps.AIAssistantConfigOverride",
]
```

Make sure to run migrations after those changes:

```bash
python manage.py makemigrations
python manage.py migrate
```

For more information, check [Django docs on overriding AppConfig](https://docs.djangoproject.com/en/5.0/ref/applications/#for-application-users).

### Further configuration of AI Assistants

You can further configure the `AIAssistant` subclasses by overriding its public methods. Check the [Reference](reference/assistants-ref.md) for more information.
21 changes: 7 additions & 14 deletions frontend/openapi_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,7 @@
"in": "path",
"name": "thread_id",
"schema": {
"title": "Thread Id",
"type": "string"
"title": "Thread Id"
},
"required": true
}
Expand Down Expand Up @@ -169,8 +168,7 @@
"in": "path",
"name": "thread_id",
"schema": {
"title": "Thread Id",
"type": "string"
"title": "Thread Id"
},
"required": true
}
Expand Down Expand Up @@ -211,8 +209,7 @@
"in": "path",
"name": "thread_id",
"schema": {
"title": "Thread Id",
"type": "string"
"title": "Thread Id"
},
"required": true
}
Expand All @@ -238,8 +235,7 @@
"in": "path",
"name": "thread_id",
"schema": {
"title": "Thread Id",
"type": "string"
"title": "Thread Id"
},
"required": true
}
Expand Down Expand Up @@ -274,8 +270,7 @@
"in": "path",
"name": "thread_id",
"schema": {
"title": "Thread Id",
"type": "string"
"title": "Thread Id"
},
"required": true
}
Expand Down Expand Up @@ -311,17 +306,15 @@
"in": "path",
"name": "thread_id",
"schema": {
"title": "Thread Id",
"type": "string"
"title": "Thread Id"
},
"required": true
},
{
"in": "path",
"name": "message_id",
"schema": {
"title": "Message Id",
"type": "string"
"title": "Message Id"
},
"required": true
}
Expand Down
Loading

0 comments on commit 73503db

Please sign in to comment.