diff --git a/docs/source/developers/index.md b/docs/source/developers/index.md index 46b6c5719..465b5d1ee 100644 --- a/docs/source/developers/index.md +++ b/docs/source/developers/index.md @@ -492,7 +492,7 @@ def create_llm_chain( prompt_template = FIX_PROMPT_TEMPLATE self.prompt_template = prompt_template - runnable = prompt_template | llm # type:ignore + runnable = prompt_template | llm | StrOutputParser() # type:ignore self.llm_chain = runnable ``` diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py index 266ad73ad..852220be1 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py @@ -3,6 +3,7 @@ from jupyter_ai.models import HumanChatMessage from jupyter_ai_magics.providers import BaseProvider +from langchain_core.output_parsers import StrOutputParser from langchain_core.runnables import ConfigurableFieldSpec from langchain_core.runnables.history import RunnableWithMessageHistory @@ -37,7 +38,7 @@ def create_llm_chain( self.llm = llm self.prompt_template = prompt_template - runnable = prompt_template | llm # type:ignore + runnable = prompt_template | llm | StrOutputParser() # type:ignore if not llm.manages_history: runnable = RunnableWithMessageHistory( runnable=runnable, # type:ignore[arg-type] diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py index 390b93cf6..27ec4d024 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py @@ -3,6 +3,7 @@ from jupyter_ai.models import CellWithErrorSelection, HumanChatMessage from jupyter_ai_magics.providers import BaseProvider from langchain.prompts import PromptTemplate +from langchain_core.output_parsers import StrOutputParser from .base import BaseChatHandler, SlashCommandRoutingType @@ -76,7 +77,7 @@ def create_llm_chain( self.llm = llm prompt_template = FIX_PROMPT_TEMPLATE - runnable = prompt_template | llm # type:ignore + runnable = prompt_template | llm | StrOutputParser() # type:ignore self.llm_chain = runnable async def process_message(self, message: HumanChatMessage):