Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
baur-krykpayev authored Jun 20, 2024
2 parents 8ffa16e + a7b4175 commit 2daaa13
Showing 1 changed file with 42 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -1,17 +1,40 @@
import base64
import json
from typing import Optional

import httpx
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage, ToolMessage
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessageChunk,
HumanMessage,
ToolMessage,
)
from langchain_core.tools import tool

from langchain_standard_tests.unit_tests.chat_models import (
ChatModelTests,
my_adder_tool,
)


@tool
def magic_function(input: int) -> int:
"""Applies a magic function to an input."""
return input + 2


def _validate_tool_call_message(message: AIMessage) -> None:
assert isinstance(message, AIMessage)
assert len(message.tool_calls) == 1
tool_call = message.tool_calls[0]
assert tool_call["name"] == "magic_function"
assert tool_call["args"] == {"input": 3}
assert tool_call["id"] is not None


class ChatModelIntegrationTests(ChatModelTests):
def test_invoke(self, model: BaseChatModel) -> None:
result = model.invoke("Hello")
Expand Down Expand Up @@ -98,6 +121,24 @@ def test_stop_sequence(self, model: BaseChatModel) -> None:
result = custom_model.invoke("hi")
assert isinstance(result, AIMessage)

def test_tool_calling(self, model: BaseChatModel) -> None:
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")
model_with_tools = model.bind_tools([magic_function])

# Test invoke
query = "What is the value of magic_function(3)? Use the tool."
result = model_with_tools.invoke(query)
assert isinstance(result, AIMessage)
_validate_tool_call_message(result)

# Test stream
full: Optional[BaseMessageChunk] = None
for chunk in model_with_tools.stream(query):
full = chunk if full is None else full + chunk # type: ignore
assert isinstance(full, AIMessage)
_validate_tool_call_message(full)

def test_tool_message_histories_string_content(
self,
model: BaseChatModel,
Expand Down

0 comments on commit 2daaa13

Please sign in to comment.