Skip to content

Commit

Permalink
[2.x] Fix Amazon Nova support (use StrOutputParser) (#1203)
Browse files Browse the repository at this point in the history
* use StrOutputParser in default chat

* encourage using StrOutputParser in docs

* pre-commit

* use StrOutputParser in /fix
  • Loading branch information
dlqqq authored Jan 16, 2025
1 parent 92262ba commit a8c52c5
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 3 deletions.
2 changes: 1 addition & 1 deletion docs/source/developers/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ def create_llm_chain(
prompt_template = FIX_PROMPT_TEMPLATE
self.prompt_template = prompt_template

runnable = prompt_template | llm # type:ignore
runnable = prompt_template | llm | StrOutputParser() # type:ignore
self.llm_chain = runnable
```

Expand Down
3 changes: 2 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from jupyter_ai.models import HumanChatMessage
from jupyter_ai_magics.providers import BaseProvider
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import ConfigurableFieldSpec
from langchain_core.runnables.history import RunnableWithMessageHistory

Expand Down Expand Up @@ -37,7 +38,7 @@ def create_llm_chain(
self.llm = llm
self.prompt_template = prompt_template

runnable = prompt_template | llm # type:ignore
runnable = prompt_template | llm | StrOutputParser() # type:ignore
if not llm.manages_history:
runnable = RunnableWithMessageHistory(
runnable=runnable, # type:ignore[arg-type]
Expand Down
3 changes: 2 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from jupyter_ai.models import CellWithErrorSelection, HumanChatMessage
from jupyter_ai_magics.providers import BaseProvider
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser

from .base import BaseChatHandler, SlashCommandRoutingType

Expand Down Expand Up @@ -76,7 +77,7 @@ def create_llm_chain(
self.llm = llm
prompt_template = FIX_PROMPT_TEMPLATE

runnable = prompt_template | llm # type:ignore
runnable = prompt_template | llm | StrOutputParser() # type:ignore
self.llm_chain = runnable

async def process_message(self, message: HumanChatMessage):
Expand Down

0 comments on commit a8c52c5

Please sign in to comment.