
Commit

Merge branch 'main' into token-usage
cpsievert committed Dec 19, 2024
2 parents bb528d7 + e033684 commit da070a3
Showing 4 changed files with 142 additions and 7 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -11,7 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### New features

- * The `Chat` class gains a `.token_count()` method to help estimate input tokens before sending it to the LLM. (#23)
* The `Chat` class gains a `.token_count()` method to help estimate the token cost of new input before generating a response to it. (#23)
* `Chat`'s `.tokens()` method gains a `values` argument. Set it to `"discrete"` to get a result that can be summed to determine the token cost of submitting the current turns. The default (`"cumulative"`) remains the same: the result can be summed to determine the overall token cost of the conversation.

### Bug fixes

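For a concrete sense of the two modes, here is a minimal sketch (the token values are illustrative; actual counts depend on the model and provider):

```python
from chatlas import ChatOpenAI

chat = ChatOpenAI()
chat.chat("What is 1 + 1?")
chat.chat("And 2 + 2?")

# Cumulative (input, output) usage per turn; user turns report None.
chat.tokens(values="cumulative")  # e.g., [None, (27, 2), None, (40, 3)]

# Discrete counts: summable to estimate the cost of submitting the turns.
chat.tokens(values="discrete")    # e.g., [27, 2, 11, 3]
```
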
116 changes: 111 additions & 5 deletions chatlas/_chat.py
@@ -16,6 +16,7 @@
    Optional,
    Sequence,
    TypeVar,
    overload,
)

from pydantic import BaseModel
@@ -176,17 +177,122 @@ def system_prompt(self, value: str | None):
        if value is not None:
            self._turns.insert(0, Turn("system", value))

-    def tokens(self) -> list[tuple[int, int] | None]:
    @overload
    def tokens(self) -> list[tuple[int, int] | None]: ...

    @overload
    def tokens(
        self,
        values: Literal["cumulative"],
    ) -> list[tuple[int, int] | None]: ...

    @overload
    def tokens(
        self,
        values: Literal["discrete"],
    ) -> list[int]: ...

    def tokens(
        self,
        values: Literal["cumulative", "discrete"] = "cumulative",
    ) -> list[int] | list[tuple[int, int] | None]:
"""
Get the tokens for each turn in the chat.
Parameters
----------
values
If "cumulative" (the default), the result can be summed to get the
chat's overall token usage (helpful for computing overall cost of
the chat). If "discrete", the result can be summed to get the number of
tokens the turns will cost to generate the next response (helpful
for estimating cost of the next response, or for determining if you
are about to exceed the token limit).
Returns
-------
list[tuple[int, int] | None]
A list of tuples, where each tuple contains the start and end token
indices for a turn.
list[int]
A list of token counts for each (non-system) turn in the chat. The
1st turn includes the tokens count for the system prompt (if any).
Raises
------
ValueError
If the chat's turns (i.e., `.get_turns()`) are not in an expected
format. This may happen if the chat history is manually set (i.e.,
`.set_turns()`). In this case, you can inspect the "raw" token
values via the `.get_turns()` method (each turn has a `.tokens`
attribute).
"""
-        return [turn.tokens for turn in self._turns]

        turns = self.get_turns(include_system_prompt=False)

if values == "cumulative":
return [turn.tokens for turn in turns]

if len(turns) == 0:
return []

err_info = (
"This can happen if the chat history is manually set (i.e., `.set_turns()`). "
"Consider getting the 'raw' token values via the `.get_turns()` method "
"(each turn has a `.tokens` attribute)."
)

# Sanity checks for the assumptions made to figure out user token counts
if len(turns) == 1:
raise ValueError(
"Expected at least two turns in the chat history. " + err_info
)

if len(turns) % 2 != 0:
raise ValueError(
"Expected an even number of turns in the chat history. " + err_info
)

        if turns[0].role != "user":
            raise ValueError(
                "Expected the 1st non-system turn to have role='user'. " + err_info
            )

        if turns[1].role != "assistant":
            raise ValueError(
                "Expected the 2nd non-system turn to have role='assistant'. " + err_info
            )

        if turns[1].tokens is None:
            raise ValueError(
                "Expected the 1st assistant turn to contain token counts. " + err_info
            )

        res: list[int] = [
            # Implied token count for the 1st user input
            turns[1].tokens[0],
            # The token count for the 1st assistant response
            turns[1].tokens[1],
        ]
        for i in range(1, len(turns) - 1, 2):
            ti = turns[i]
            tj = turns[i + 2]
            if ti.role != "assistant" or tj.role != "assistant":
                raise ValueError(
                    "Expected even-numbered turns to have role='assistant'. " + err_info
                )
            if ti.tokens is None or tj.tokens is None:
                raise ValueError(
                    "Expected role='assistant' turns to contain token counts. " + err_info
                )
            res.extend(
                [
                    # Implied token count for the user input
                    tj.tokens[0] - sum(ti.tokens),
                    # The token count for the assistant response
                    tj.tokens[1],
                ]
            )

        return res
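
To make the bookkeeping above concrete, here is the same derivation as a standalone sketch, using the made-up cumulative counts from the tests below: each assistant turn reports cumulative `(input, output)` usage, and a later user turn's discrete cost is implied by subtracting everything already in the context.

```python
# Cumulative (input, output) usage from two assistant turns (made-up values).
cumulative = [(2, 10), (14, 10)]

# The first pair is already discrete: 2 user tokens in, 10 assistant tokens out.
discrete = [cumulative[0][0], cumulative[0][1]]

for prev, cur in zip(cumulative, cumulative[1:]):
    # The new user input's cost is the current input count minus the
    # whole prior context (previous input + output counts).
    discrete.append(cur[0] - sum(prev))  # 14 - (2 + 10) = 2
    # The assistant response's cost is reported directly.
    discrete.append(cur[1])              # 10

assert discrete == [2, 10, 2, 10]
```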

    def token_count(
        self,
2 changes: 1 addition & 1 deletion tests/test_provider_openai.py
@@ -21,7 +21,7 @@ def test_openai_simple_request():
    chat.chat("What is 1 + 1?")
    turn = chat.get_last_turn()
    assert turn is not None
-    assert turn.tokens == (27, 1)
    assert turn.tokens == (27, 2)
    assert turn.finish_reason == "stop"


28 changes: 28 additions & 0 deletions tests/test_tokens.py
@@ -1,7 +1,35 @@
from chatlas import ChatOpenAI, Turn
from chatlas._openai import OpenAIAzureProvider, OpenAIProvider
from chatlas._tokens import token_usage, tokens_log, tokens_reset


def test_tokens_method():
    chat = ChatOpenAI()
    assert chat.tokens(values="discrete") == []

    chat = ChatOpenAI(
        turns=[
            Turn(role="user", contents="Hi"),
            Turn(role="assistant", contents="Hello", tokens=(2, 10)),
        ]
    )

    assert chat.tokens(values="discrete") == [2, 10]

    chat = ChatOpenAI(
        turns=[
            Turn(role="user", contents="Hi"),
            Turn(role="assistant", contents="Hello", tokens=(2, 10)),
            Turn(role="user", contents="Hi"),
            Turn(role="assistant", contents="Hello", tokens=(14, 10)),
        ]
    )

    assert chat.tokens(values="discrete") == [2, 10, 2, 10]

    assert chat.tokens(values="cumulative") == [None, (2, 10), None, (14, 10)]


def test_usage_is_none():
    tokens_reset()
    assert token_usage() is None
