Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename .turns() -> .get_turns(); .last_turn() -> .get_last_turn() #19

Merged
merged 4 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -196,16 +196,16 @@ Easily get a full markdown or HTML export of a conversation:
chat.export("index.html", title="Python Q&A")
```

If the export doesn't have all the information you need, you can also access the full conversation history via the `.turns()` method:
If the export doesn't have all the information you need, you can also access the full conversation history via the `.get_turns()` method:

```python
chat.turns()
chat.get_turns()
```

And, if the conversation is too long, you can specify which turns to include:

```python
chat.export("index.html", turns=chat.turns()[-5:])
chat.export("index.html", turns=chat.get_turns()[-5:])
```

### Async
Expand Down Expand Up @@ -242,7 +242,7 @@ chat.chat("What is the capital of France?", echo="all")

This shows important information like tool call results, finish reasons, and more.

If the problem isn't self-evident, you can also reach into the `.last_turn()`, which contains the full response object, with full details about the completion.
If the problem isn't self-evident, you can also reach into the `.get_last_turn()`, which contains the full response object, with full details about the completion.


<div style="display:flex;justify-content:center;">
Expand Down
30 changes: 18 additions & 12 deletions chatlas/_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(
"css_styles": {},
}

def turns(
def get_turns(
self,
*,
include_system_prompt: bool = False,
Expand All @@ -115,7 +115,7 @@ def turns(
return self._turns[1:]
return self._turns

def last_turn(
def get_last_turn(
self,
*,
role: Literal["assistant", "user", "system"] = "assistant",
Expand Down Expand Up @@ -158,7 +158,12 @@ def set_turns(self, turns: Sequence[Turn]):
@property
def system_prompt(self) -> str | None:
"""
Get the system prompt for the chat.
A property to get (or set) the system prompt for the chat.

Returns
-------
str | None
The system prompt (if any).
"""
if self._turns and self._turns[0].role == "system":
return self._turns[0].text
Expand Down Expand Up @@ -228,7 +233,8 @@ def server(input): # noqa: A002
chat = ui.Chat(
"chat",
messages=[
{"role": turn.role, "content": turn.text} for turn in self.turns()
{"role": turn.role, "content": turn.text}
for turn in self.get_turns()
],
)

Expand Down Expand Up @@ -533,7 +539,7 @@ def extract_data(
for _ in response:
pass

turn = self.last_turn()
turn = self.get_last_turn()
assert turn is not None

res: list[ContentJson] = []
Expand Down Expand Up @@ -593,7 +599,7 @@ async def extract_data_async(
async for _ in response:
pass

turn = self.last_turn()
turn = self.get_last_turn()
assert turn is not None

res: list[ContentJson] = []
Expand Down Expand Up @@ -711,7 +717,7 @@ def export(
The filename to export the chat to. Currently this must
be a `.md` or `.html` file.
turns
The `.turns()` to export. If not provided, the chat's current turns
The `.get_turns()` to export. If not provided, the chat's current turns
will be used.
title
A title to place at the top of the exported file.
Expand All @@ -729,7 +735,7 @@ def export(
The path to the exported file.
"""
if not turns:
turns = self.turns(include_system_prompt=False)
turns = self.get_turns(include_system_prompt=False)
if not turns:
raise ValueError("No turns to export.")

Expand Down Expand Up @@ -986,7 +992,7 @@ def emit(text: str | Content):
self._turns.extend([user_turn, turn])

def _invoke_tools(self) -> Turn | None:
turn = self.last_turn()
turn = self.get_last_turn()
if turn is None:
return None

Expand All @@ -1003,7 +1009,7 @@ def _invoke_tools(self) -> Turn | None:
return Turn("user", results)

async def _invoke_tools_async(self) -> Turn | None:
turn = self.last_turn()
turn = self.get_last_turn()
if turn is None:
return None

Expand Down Expand Up @@ -1112,15 +1118,15 @@ def set_echo_options(
}

def __str__(self):
turns = self.turns(include_system_prompt=False)
turns = self.get_turns(include_system_prompt=False)
res = ""
for turn in turns:
icon = "👤" if turn.role == "user" else "🤖"
res += f"## {icon} {turn.role.capitalize()} turn:\n\n{str(turn)}\n\n"
return res

def __repr__(self):
turns = self.turns(include_system_prompt=True)
turns = self.get_turns(include_system_prompt=True)
tokens = sum(sum(turn.tokens) for turn in turns if turn.tokens)
res = f"<Chat turns={len(turns)} tokens={tokens}>"
for turn in turns:
Expand Down
4 changes: 2 additions & 2 deletions chatlas/_turn.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,15 @@ class Turn(Generic[CompletionT]):

chat = ChatOpenAI()
str(chat.chat("What is the capital of France?"))
turns = chat.turns()
turns = chat.get_turns()
assert len(turns) == 2
assert isinstance(turns[0], Turn)
assert turns[0].role == "user"
assert turns[1].role == "assistant"

# Load context into a new chat instance
chat2 = ChatAnthropic(turns=turns)
turns2 = chat2.turns()
turns2 = chat2.get_turns()
assert turns == turns2
```

Expand Down
6 changes: 3 additions & 3 deletions docs/reference/Chat.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ You should generally not create this object yourself, but instead call
| [console](#chatlas.Chat.console) | Enter a chat console to interact with the LLM. |
| [extract_data](#chatlas.Chat.extract_data) | Extract structured data from the given input. |
| [extract_data_async](#chatlas.Chat.extract_data_async) | Extract structured data from the given input asynchronously. |
| [last_turn](#chatlas.Chat.last_turn) | Get the last turn in the chat with a specific role. |
| [get_last_turn](#chatlas.Chat.get_last_turn) | Get the last turn in the chat with a specific role. |
| [register_tool](#chatlas.Chat.register_tool) | Register a tool (function) with the chat. |
| [set_turns](#chatlas.Chat.set_turns) | Set the turns of the chat. |
| [tokens](#chatlas.Chat.tokens) | Get the tokens for each turn in the chat. |
Expand Down Expand Up @@ -158,7 +158,7 @@ Extract structured data from the given input asynchronously.
|--------|-----------------------------------------------------|---------------------|
| | [dict](`dict`)\[[str](`str`), [Any](`typing.Any`)\] | The extracted data. |

### last_turn { #chatlas.Chat.last_turn }
### get_last_turn { #chatlas.Chat.get_last_turn }

```python
Chat.get_last_turn(role='assistant')
Expand Down Expand Up @@ -284,7 +284,7 @@ Get the tokens for each turn in the chat.
### turns { #chatlas.Chat.turns }

```python
Chat.turns(include_system_prompt=False)
Chat.get_turns(include_system_prompt=False)
```

Get all the turns (i.e., message contents) in the chat.
Expand Down
4 changes: 2 additions & 2 deletions docs/reference/Turn.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ from chatlas import Turn, ChatOpenAI, ChatAnthropic

chat = ChatOpenAI()
str(chat.chat("What is the capital of France?"))
turns = chat.turns()
turns = chat.get_turns()
assert len(turns) == 2
assert isinstance(turns[0], Turn)
assert turns[0].role == "user"
assert turns[1].role == "assistant"

# Load context into a new chat instance
chat2 = ChatAnthropic(turns=turns)
turns2 = chat2.turns()
turns2 = chat2.get_turns()
assert turns == turns2
```

Expand Down
2 changes: 1 addition & 1 deletion docs/web-apps.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -73,5 +73,5 @@ if prompt := st.chat_input():
with st.chat_message("assistant"):
st.write_stream(response)

st.session_state["turns"] = chat.turns()
st.session_state["turns"] = chat.get_turns()
```
16 changes: 8 additions & 8 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,13 @@ def assert_turns_system(chat_fun: ChatFun):
chat = chat_fun(system_prompt=system_prompt)
response = chat.chat("What is the name of Winnie the Pooh's human friend?")
response_text = str(response)
assert len(chat.turns()) == 2
assert len(chat.get_turns()) == 2
assert "CHRISTOPHER ROBIN" in response_text

chat = chat_fun(turns=[Turn("system", system_prompt)])
response = chat.chat("What is the name of Winnie the Pooh's human friend?")
assert "CHRISTOPHER ROBIN" in str(response)
assert len(chat.turns()) == 2
assert len(chat.get_turns()) == 2


def assert_turns_existing(chat_fun: ChatFun):
Expand All @@ -70,11 +70,11 @@ def assert_turns_existing(chat_fun: ChatFun):
),
]
)
assert len(chat.turns()) == 2
assert len(chat.get_turns()) == 2

response = chat.chat("Who is the remaining one? Just give the name")
assert "Prancer" in str(response)
assert len(chat.turns()) == 4
assert len(chat.get_turns()) == 4


def assert_tools_simple(chat_fun: ChatFun, stream: bool = True):
Expand Down Expand Up @@ -133,7 +133,7 @@ def favorite_color(person: str):

assert "Joe: sage green" in str(response)
assert "Hadley: red" in str(response)
assert len(chat.turns()) == 4
assert len(chat.get_turns()) == 4


def assert_tools_sequential(chat_fun: ChatFun, total_calls: int, stream: bool = True):
Expand All @@ -156,7 +156,7 @@ def equipment(weather: str):
stream=stream,
)
assert "umbrella" in str(response).lower()
assert len(chat.turns()) == total_calls
assert len(chat.get_turns()) == total_calls


def assert_data_extraction(chat_fun: ChatFun):
Expand All @@ -178,7 +178,7 @@ def assert_images_inline(chat_fun: ChatFun, stream: bool = True):
chat = chat_fun()
response = chat.chat(
"What's in this image?",
content_image_file(str(img_path)),
content_image_file(str(img_path), resize="low"),
stream=stream,
)
assert "red" in str(response).lower()
Expand All @@ -202,4 +202,4 @@ def assert_images_remote_error(chat_fun: ChatFun):
with pytest.raises(Exception, match="Remote images aren't supported"):
chat.chat("What's in this image?", image_remote)

assert len(chat.turns()) == 0
assert len(chat.get_turns()) == 0
16 changes: 8 additions & 8 deletions tests/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_simple_streaming_chat():
result = "".join(chunks)
rainbow_re = "^red *\norange *\nyellow *\ngreen *\nblue *\nindigo *\nviolet *\n?$"
assert re.match(rainbow_re, result.lower())
turn = chat.last_turn()
turn = chat.get_last_turn()
assert turn is not None
assert re.match(rainbow_re, turn.text.lower())

Expand All @@ -50,7 +50,7 @@ async def test_simple_streaming_chat_async():
result = "".join(chunks)
rainbow_re = "^red *\norange *\nyellow *\ngreen *\nblue *\nindigo *\nviolet *\n?$"
assert re.match(rainbow_re, result.lower())
turn = chat.last_turn()
turn = chat.get_last_turn()
assert turn is not None
assert re.match(rainbow_re, turn.text.lower())

Expand Down Expand Up @@ -119,24 +119,24 @@ class Person(BaseModel):

def test_last_turn_retrieval():
chat = ChatOpenAI()
assert chat.last_turn(role="user") is None
assert chat.last_turn(role="assistant") is None
assert chat.get_last_turn(role="user") is None
assert chat.get_last_turn(role="assistant") is None

chat.chat("Hi")
user_turn = chat.last_turn(role="user")
user_turn = chat.get_last_turn(role="user")
assert user_turn is not None and user_turn.role == "user"
turn = chat.last_turn(role="assistant")
turn = chat.get_last_turn(role="assistant")
assert turn is not None and turn.role == "assistant"


def test_system_prompt_retrieval():
chat1 = ChatOpenAI()
assert chat1.system_prompt is None
assert chat1.last_turn(role="system") is None
assert chat1.get_last_turn(role="system") is None

chat2 = ChatOpenAI(system_prompt="You are from New Zealand")
assert chat2.system_prompt == "You are from New Zealand"
turn = chat2.last_turn(role="system")
turn = chat2.get_last_turn(role="system")
assert turn is not None and turn.text == "You are from New Zealand"


Expand Down
5 changes: 3 additions & 2 deletions tests/test_content_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_can_create_image_from_path(tmp_path):
path = tmp_path / "test.png"
img.save(path)

obj = content_image_file(str(path))
obj = content_image_file(str(path), resize="low")
assert isinstance(obj, ContentImageInline)


Expand Down Expand Up @@ -65,7 +65,8 @@ def test_image_resizing(tmp_path):
content_image_file(str(tmp_path / "test.txt"))

# Test valid resize options
assert content_image_file(str(img_path)) is not None
with pytest.warns(RuntimeWarning):
assert content_image_file(str(img_path)) is not None
assert content_image_file(str(img_path), resize="low") is not None
assert content_image_file(str(img_path), resize="high") is not None
assert content_image_file(str(img_path), resize="none") is not None
Expand Down
4 changes: 2 additions & 2 deletions tests/test_provider_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_anthropic_simple_request():
system_prompt="Be as terse as possible; no punctuation",
)
chat.chat("What is 1 + 1?")
turn = chat.last_turn()
turn = chat.get_last_turn()
assert turn is not None
assert turn.tokens == (26, 5)
assert turn.finish_reason == "end_turn"
Expand All @@ -37,7 +37,7 @@ async def test_anthropic_simple_streaming_request():
async for x in foo:
res.append(x)
assert "2" in "".join(res)
turn = chat.last_turn()
turn = chat.get_last_turn()
assert turn is not None
assert turn.finish_reason == "end_turn"

Expand Down
4 changes: 2 additions & 2 deletions tests/test_provider_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_azure_simple_request():

response = chat.chat("What is 1 + 1?")
assert "2" == response.get_content()
turn = chat.last_turn()
turn = chat.get_last_turn()
assert turn is not None
assert turn.tokens == (27, 1)

Expand All @@ -34,6 +34,6 @@ async def test_azure_simple_request_async():

response = await chat.chat_async("What is 1 + 1?")
assert "2" == await response.get_content()
turn = chat.last_turn()
turn = chat.get_last_turn()
assert turn is not None
assert turn.tokens == (27, 1)
2 changes: 1 addition & 1 deletion tests/test_provider_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
# system_prompt="Be as terse as possible; no punctuation",
# )
# _ = str(chat.chat("What is 1 + 1?"))
# turn = chat.last_turn()
# turn = chat.get_last_turn()
# assert turn is not None
# assert turn.tokens == (26, 5)

Expand Down
Loading
Loading