diff --git a/libs/langchain/langchain/chat_models/base.py b/libs/langchain/langchain/chat_models/base.py index 83c6a45fa8a3b..f69c74f0407e8 100644 --- a/libs/langchain/langchain/chat_models/base.py +++ b/libs/langchain/langchain/chat_models/base.py @@ -11,7 +11,6 @@ List, Optional, Sequence, - Union, cast, ) @@ -38,12 +37,10 @@ from langchain.schema.language_model import BaseLanguageModel, LanguageModelInput from langchain.schema.messages import ( AIMessage, + AnyMessage, BaseMessage, BaseMessageChunk, - ChatMessage, - FunctionMessage, HumanMessage, - SystemMessage, ) from langchain.schema.output import ChatGenerationChunk from langchain.schema.runnable import RunnableConfig @@ -79,7 +76,7 @@ async def _agenerate_from_stream( return ChatResult(generations=[generation]) -class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC): +class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): """Base class for Chat models.""" cache: Optional[bool] = None @@ -116,9 +113,7 @@ class Config: @property def OutputType(self) -> Any: """Get the output type for this runnable.""" - return Union[ - HumanMessage, AIMessage, ChatMessage, FunctionMessage, SystemMessage - ] + return AnyMessage def _convert_input(self, input: LanguageModelInput) -> PromptValue: if isinstance(input, PromptValue): @@ -140,23 +135,20 @@ def invoke( *, stop: Optional[List[str]] = None, **kwargs: Any, - ) -> BaseMessageChunk: + ) -> BaseMessage: config = config or {} return cast( - BaseMessageChunk, - cast( - ChatGeneration, - self.generate_prompt( - [self._convert_input(input)], - stop=stop, - callbacks=config.get("callbacks"), - tags=config.get("tags"), - metadata=config.get("metadata"), - run_name=config.get("run_name"), - **kwargs, - ).generations[0][0], - ).message, - ) + ChatGeneration, + self.generate_prompt( + [self._convert_input(input)], + stop=stop, + callbacks=config.get("callbacks"), + tags=config.get("tags"), + metadata=config.get("metadata"), + run_name=config.get("run_name"), + **kwargs, + ).generations[0][0], + ).message async def ainvoke( self, @@ -165,7 +157,7 @@ async def ainvoke( *, stop: Optional[List[str]] = None, **kwargs: Any, - ) -> BaseMessageChunk: + ) -> BaseMessage: config = config or {} llm_result = await self.agenerate_prompt( [self._convert_input(input)], @@ -176,9 +168,7 @@ async def ainvoke( run_name=config.get("run_name"), **kwargs, ) - return cast( - BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message - ) + return cast(ChatGeneration, llm_result.generations[0][0]).message def stream( self, @@ -190,7 +180,9 @@ def stream( ) -> Iterator[BaseMessageChunk]: if type(self)._stream == BaseChatModel._stream: # model doesn't implement streaming, so use default implementation - yield self.invoke(input, config=config, stop=stop, **kwargs) + yield cast( + BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs) + ) else: config = config or {} messages = self._convert_input(input).to_messages() @@ -241,7 +233,9 @@ async def astream( ) -> AsyncIterator[BaseMessageChunk]: if type(self)._astream == BaseChatModel._astream: # model doesn't implement streaming, so use default implementation - yield self.invoke(input, config=config, stop=stop, **kwargs) + yield cast( + BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs) + ) else: config = config or {} messages = self._convert_input(input).to_messages() diff --git a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr index b4b607f3fa2ac..e78280380bf35 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr +++ b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr @@ -2163,19 +2163,19 @@ dict({ 'anyOf': list([ dict({ - '$ref': '#/definitions/HumanMessage', + '$ref': '#/definitions/AIMessage', }), dict({ - '$ref': '#/definitions/AIMessage', + '$ref': '#/definitions/HumanMessage', }), dict({ '$ref': '#/definitions/ChatMessage', }), dict({ - '$ref': '#/definitions/FunctionMessage', + '$ref': '#/definitions/SystemMessage', }), dict({ - '$ref': '#/definitions/SystemMessage', + '$ref': '#/definitions/FunctionMessage', }), ]), 'definitions': dict({