Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Make thread_id parameter of AIAssistant run method optional #156

Merged
merged 2 commits into from
Aug 6, 2024

Conversation

rodbv
Copy link
Contributor

@rodbv rodbv commented Jul 14, 2024

The following command in the "Manually calling an AI Assistant" section of the Tutorial is failing:

In [2]: assistant = WeatherAIAssistant()
   ...: output = assistant.run("What's the weather in New York City?")

TypeError: AIAssistant.run() missing 1 required positional argument: 'thread_id'

This PR is to provide a default value for the thread_id parameter, to make the example work again.

It's also including a regression test, but let me know if that's overkill, I can remove it.

@@ -12,7 +12,10 @@
DEFAULT_DOCUMENT_PROMPT,
DEFAULT_DOCUMENT_SEPARATOR,
)
from langchain_core.chat_history import BaseChatMessageHistory, InMemoryChatMessageHistory
from langchain_core.chat_history import (
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This block was auto-formatted by the pre-commit hook (ruff)

@@ -295,7 +298,9 @@ def get_message_history(self, thread_id: Any | None) -> BaseChatMessageHistory:
"""

# DjangoChatMessageHistory must be here because Django may not be loaded yet elsewhere:
from django_ai_assistant.langchain.chat_message_histories import DjangoChatMessageHistory
from django_ai_assistant.langchain.chat_message_histories import (
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This block was auto-formatted by the pre-commit hook (ruff)

@@ -538,7 +543,7 @@ def invoke(self, *args: Any, thread_id: Any | None, **kwargs: Any) -> dict:
return chain.invoke(*args, **kwargs)

@with_cast_id
def run(self, message: str, thread_id: Any | None, **kwargs: Any) -> str:
def run(self, message: str, thread_id: Any | None = None, **kwargs: Any) -> str:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the fix.

@@ -222,4 +235,9 @@ def tool_a(self, foo: str) -> str:
assistant = FooAssistant()

assert hasattr(assistant, "_method_tools")
assert [t.name for t in assistant._method_tools] == ["tool_d", "tool_c", "tool_b", "tool_a"]
assert [t.name for t in assistant._method_tools] == [
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This block was auto-formatted by the pre-commit hook (ruff)

@filipeximenes filipeximenes merged commit 0423bff into vintasoftware:main Aug 6, 2024
5 of 6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants