Skip to content

Commit

Permalink
Breaking change: .chat() now _always_ consumes and displays stream. A…
Browse files Browse the repository at this point in the history
…dd .stream() method
  • Loading branch information
cpsievert committed Nov 19, 2024
1 parent 2e0c570 commit 78590f7
Show file tree
Hide file tree
Showing 12 changed files with 68 additions and 64 deletions.
22 changes: 9 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ Again, keep in mind that the chat object retains state, so when you enter the ch

### The `.chat()` method

For a more programmatic approach, you can use the `.chat()` method to ask a question and get a response. If you're in a REPL (e.g., Jupyter, IPython, etc), the result of `.chat()` is automatically displayed using a [rich](https://github.com/Textualize/rich) console.
For a more programmatic approach, you can use the `.chat()` method to ask a question and get a response. By default, the response prints to a [rich](https://github.com/Textualize/rich) console as it streams in:

```python
chat.chat("What preceding languages most influenced Python?")
Expand All @@ -109,25 +109,21 @@ Python was primarily influenced by ABC, with additional inspiration from C,
Modula-3, and various other languages.
```

If you're not in a REPL (e.g., a non-interactive Python script), you can explicitly `.display()` the response:
To get the full response as a string, use the built-in `str()` function. Optionally, you can suppress the rich console output by setting `echo="none"` (support for this is coming soon):

```python
response = chat.chat("What is the Python programming language?")
response.display()
response = chat.chat("Who is Posit?", echo="none")
print(str(response))
```

The `response` is also an iterable, so you can loop over it to get the response in streaming chunks:
### The `.stream()` method

```python
result = ""
for chunk in response:
result += chunk
```

Or, if you just want the full response as a string, use the built-in `str()` function:
If you want to process the response in real time (i.e., as it arrives in chunks), use the `.stream()` method. It returns a generator that yields chunks of the response as they arrive, which is useful for long responses or for handling each chunk as soon as it is available:

```python
str(response)
response = chat.stream("Who is Posit?")
for chunk in response:
print(chunk, end="")
```


Expand Down
17 changes: 1 addition & 16 deletions chatlas/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import sys

from . import types
from ._anthropic import ChatAnthropic, ChatBedrockAnthropic
from ._chat import Chat, ChatResponse
from ._chat import Chat
from ._content_image import content_image_file, content_image_plot, content_image_url
from ._github import ChatGithub
from ._google import ChatGoogle
Expand Down Expand Up @@ -35,16 +33,3 @@
"Tool",
"Provider",
)

# ChatResponse objects are displayed in the REPL using rich
original_displayhook = sys.displayhook


def custom_displayhook(value):
if isinstance(value, ChatResponse):
value.display()
else:
original_displayhook(value)


sys.displayhook = custom_displayhook
55 changes: 39 additions & 16 deletions chatlas/_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,11 @@ def chat(
A response from the chat.
"""
turn = user_turn(*args)
return ChatResponse(self._chat_impl(turn, stream=stream, kwargs=kwargs))
resp = ChatResponse(self._chat_impl(turn, stream=stream, kwargs=kwargs))

resp.display()

return resp

async def chat_async(
self,
Expand All @@ -299,8 +303,39 @@ async def chat_async(
the response.
"""
turn = user_turn(*args)
gen = self._chat_impl_async(turn, stream=stream, kwargs=kwargs)
return ChatResponseAsync(gen)
resp = ChatResponseAsync(
self._chat_impl_async(turn, stream=stream, kwargs=kwargs),
)

await resp.display()

return resp

def stream(
self,
*args: Content | str,
kwargs: Optional[SubmitInputArgsT] = None,
) -> ChatResponse:
"""
TODO: Add docstring.
"""
turn = user_turn(*args)
return ChatResponse(
self._chat_impl(turn, stream=True, kwargs=kwargs),
)

async def stream_async(
self,
*args: Content | str,
kwargs: Optional[SubmitInputArgsT] = None,
) -> ChatResponseAsync:
"""
TODO: Add docstring.
"""
turn = user_turn(*args)
return ChatResponseAsync(
self._chat_impl_async(turn, stream=True, kwargs=kwargs),
)

def extract_data(
self,
Expand Down Expand Up @@ -746,12 +781,6 @@ def consumed(self) -> bool:
def __str__(self) -> str:
return self.get_content()

def __repr__(self) -> str:
return (
"ChatResponse object. Call `.display()` to show it in a rich"
"console or `.get_content()` to get the content."
)


class ChatResponseAsync:
"""
Expand Down Expand Up @@ -789,7 +818,7 @@ async def __anext__(self) -> str:
self.content += chunk # Keep track of accumulated content
return chunk

async def display(self) -> None:
async def display(self):
"Display the content in a rich console."
from rich.live import Live
from rich.markdown import Markdown
Expand All @@ -810,12 +839,6 @@ async def get_content(self) -> str:
def consumed(self) -> bool:
return self._generator.ag_frame is None

def __repr__(self) -> str:
return (
"ChatResponseAsync object. Call `.display()` to show it in a rich"
"console or `.get_content()` to get the content."
)


@contextmanager
def JupyterFriendlyConsole():
Expand Down
2 changes: 1 addition & 1 deletion docs/reference/types.ChatResponse.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ This method gets called automatically when the object is displayed.
### get_string { #chatlas.types.ChatResponse.get_string }

```python
types.ChatResponse.get_string()
types.ChatResponse.get_content()
```

Get the chat response content as a string.
2 changes: 1 addition & 1 deletion docs/reference/types.ChatResponseAsync.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ Display the content in a rich console.
### get_string { #chatlas.types.ChatResponseAsync.get_string }

```python
types.ChatResponseAsync.get_string()
types.ChatResponseAsync.get_content()
```

Get the chat response content as a string.
2 changes: 1 addition & 1 deletion docs/web-apps.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ chat_model = ChatAnthropic()

@chat.on_user_submit
def _():
response = chat_model.chat(chat.user_input())
response = chat_model.stream(chat.user_input())
chat.append_message_stream(response)
```

Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ async def get_current_date():
response = await chat.chat_async(
"What's the current date in YMD format?", stream=stream
)
assert "2024-01-01" in await response.get_string()
assert "2024-01-01" in await response.get_content()

with pytest.raises(Exception, match="async tools in a synchronous chat"):
str(chat.chat("Great. Do it again.", stream=stream))
Expand Down Expand Up @@ -203,6 +203,6 @@ def assert_images_remote_error(chat_fun: ChatFun):
image_remote = content_image_url("https://httr2.r-lib.org/logo.png")

with pytest.raises(Exception, match="Remote images aren't supported"):
_ = str(chat.chat("What's in this image?", image_remote))
chat.chat("What's in this image?", image_remote)

assert len(chat.turns()) == 0
8 changes: 4 additions & 4 deletions tests/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ async def test_simple_async_batch_chat():
response = await chat.chat_async(
"What's 1 + 1. Just give me the answer, no punctuation",
)
assert "2" == await response.get_string()
assert "2" == await response.get_content()


def test_simple_streaming_chat():
chat = ChatOpenAI()
res = chat.chat("""
res = chat.stream("""
What are the canonical colors of the ROYGBIV rainbow?
Put each colour on its own line. Don't use punctuation.
""")
Expand All @@ -39,7 +39,7 @@ def test_simple_streaming_chat():
@pytest.mark.asyncio
async def test_simple_streaming_chat_async():
chat = ChatOpenAI()
res = await chat.chat_async("""
res = await chat.stream_async("""
What are the canonical colors of the ROYGBIV rainbow?
Put each colour on its own line. Don't use punctuation.
""")
Expand Down Expand Up @@ -97,7 +97,7 @@ def test_last_turn_retrieval():
assert chat.last_turn(role="user") is None
assert chat.last_turn(role="assistant") is None

_ = str(chat.chat("Hi"))
chat.chat("Hi")
user_turn = chat.last_turn(role="user")
assert user_turn is not None and user_turn.role == "user"
turn = chat.last_turn(role="assistant")
Expand Down
4 changes: 2 additions & 2 deletions tests/test_provider_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_anthropic_simple_request():
chat = ChatAnthropic(
system_prompt="Be as terse as possible; no punctuation",
)
_ = str(chat.chat("What is 1 + 1?"))
chat.chat("What is 1 + 1?")
turn = chat.last_turn()
assert turn is not None
assert turn.tokens == (26, 5)
Expand All @@ -32,7 +32,7 @@ async def test_anthropic_simple_streaming_request():
system_prompt="Be as terse as possible; no punctuation",
)
res = []
foo = await chat.chat_async("What is 1 + 1?")
foo = await chat.stream_async("What is 1 + 1?")
async for x in foo:
res.append(x)
assert "2" in "".join(res)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_provider_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_azure_simple_request():
)

response = chat.chat("What is 1 + 1?")
assert "2" == response.get_string()
assert "2" == response.get_content()
turn = chat.last_turn()
assert turn is not None
assert turn.tokens == (27, 1)
Expand All @@ -33,7 +33,7 @@ async def test_azure_simple_request_async():
)

response = await chat.chat_async("What is 1 + 1?")
assert "2" == await response.get_string()
assert "2" == await response.get_content()
turn = chat.last_turn()
assert turn is not None
assert turn.tokens == (27, 1)
6 changes: 3 additions & 3 deletions tests/test_provider_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ def test_google_simple_request():
chat = ChatGoogle(
system_prompt="Be as terse as possible; no punctuation",
)
_ = str(chat.chat("What is 1 + 1?"))
chat.chat("What is 1 + 1?")
turn = chat.last_turn()
assert turn is not None
assert turn.tokens == (17, 1)
assert turn.tokens == (17, 2)


@pytest.mark.asyncio
Expand All @@ -36,7 +36,7 @@ async def test_google_simple_streaming_request():
system_prompt="Be as terse as possible; no punctuation",
)
res = []
async for x in await chat.chat_async("What is 1 + 1?"):
async for x in await chat.stream_async("What is 1 + 1?"):
res.append(x)
assert "2" in "".join(res)

Expand Down
6 changes: 3 additions & 3 deletions tests/test_provider_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_openai_simple_request():
chat = ChatOpenAI(
system_prompt="Be as terse as possible; no punctuation",
)
_ = str(chat.chat("What is 1 + 1?"))
chat.chat("What is 1 + 1?")
turn = chat.last_turn()
assert turn is not None
assert turn.tokens == (27, 1)
Expand All @@ -30,7 +30,7 @@ async def test_openai_simple_streaming_request():
system_prompt="Be as terse as possible; no punctuation",
)
res = []
async for x in await chat.chat_async("What is 1 + 1?"):
async for x in await chat.stream_async("What is 1 + 1?"):
res.append(x)
assert "2" in "".join(res)

Expand Down Expand Up @@ -68,7 +68,7 @@ async def test_openai_logprobs():
chat = ChatOpenAI()

pieces = []
async for x in await chat.chat_async("Hi", kwargs={"logprobs": True}):
async for x in await chat.stream_async("Hi", kwargs={"logprobs": True}):
pieces.append(x)

turn = chat.last_turn()
Expand Down

0 comments on commit 78590f7

Please sign in to comment.