Skip to content

Commit

Permalink
chat.tokens() gains a values argument (#27)
Browse files Browse the repository at this point in the history
* The .tokens() method now returns a list of ints, where each int represents the number of tokens each turn takes

* Add format argument to .tokens(); make default behavior same as before

* Improvements and tests

* Update changelog

* Update test expectation
  • Loading branch information
cpsievert authored Dec 19, 2024
1 parent f5a300f commit e033684
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 6 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### New features

* `Chat`'s `.tokens()` method gains a `values` argument. Set it to `"discrete"` to get a result that can be summed to determine the token cost of submitting the current turns. The default (`"cumulative"`) remains the same (the result can be summed to determine the overall token cost of the conversation).

### Bug fixes

* `ChatOllama` no longer fails when an `OPENAI_API_KEY` environment variable is not set.
Expand Down
116 changes: 111 additions & 5 deletions chatlas/_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Optional,
Sequence,
TypeVar,
overload,
)

from pydantic import BaseModel
Expand Down Expand Up @@ -176,17 +177,122 @@ def system_prompt(self, value: str | None):
if value is not None:
self._turns.insert(0, Turn("system", value))

def tokens(self) -> list[tuple[int, int] | None]:
@overload
def tokens(self) -> list[tuple[int, int] | None]: ...

@overload
def tokens(
self,
values: Literal["cumulative"],
) -> list[tuple[int, int] | None]: ...

@overload
def tokens(
self,
values: Literal["discrete"],
) -> list[int]: ...

def tokens(
self,
values: Literal["cumulative", "discrete"] = "discrete",
) -> list[int] | list[tuple[int, int] | None]:
"""
Get the tokens for each turn in the chat.
Parameters
----------
values
If "cumulative" (the default), the result can be summed to get the
chat's overall token usage (helpful for computing overall cost of
the chat). If "discrete", the result can be summed to get the number of
tokens the turns will cost to generate the next response (helpful
for estimating cost of the next response, or for determining if you
are about to exceed the token limit).
Returns
-------
list[tuple[int, int] | None]
A list of tuples, where each tuple contains the start and end token
indices for a turn.
list[int]
A list of token counts for each (non-system) turn in the chat. The
1st turn includes the tokens count for the system prompt (if any).
Raises
------
ValueError
If the chat's turns (i.e., `.get_turns()`) are not in an expected
format. This may happen if the chat history is manually set (i.e.,
`.set_turns()`). In this case, you can inspect the "raw" token
values via the `.get_turns()` method (each turn has a `.tokens`
attribute).
"""
return [turn.tokens for turn in self._turns]

turns = self.get_turns(include_system_prompt=False)

if values == "cumulative":
return [turn.tokens for turn in turns]

if len(turns) == 0:
return []

err_info = (
"This can happen if the chat history is manually set (i.e., `.set_turns()`). "
"Consider getting the 'raw' token values via the `.get_turns()` method "
"(each turn has a `.tokens` attribute)."
)

# Sanity checks for the assumptions made to figure out user token counts
if len(turns) == 1:
raise ValueError(
"Expected at least two turns in the chat history. " + err_info
)

if len(turns) % 2 != 0:
raise ValueError(
"Expected an even number of turns in the chat history. " + err_info
)

if turns[0].role != "user":
raise ValueError(
"Expected the 1st non-system turn to have role='user'. " + err_info
)

if turns[1].role != "assistant":
raise ValueError(
"Expected the 2nd turn non-system to have role='assistant'. " + err_info
)

if turns[1].tokens is None:
raise ValueError(
"Expected the 1st assistant turn to contain token counts. " + err_info
)

res: list[int] = [
# Implied token count for the 1st user input
turns[1].tokens[0],
# The token count for the 1st assistant response
turns[1].tokens[1],
]
for i in range(1, len(turns) - 1, 2):
ti = turns[i]
tj = turns[i + 2]
if ti.role != "assistant" or tj.role != "assistant":
raise ValueError(
"Expected even turns to have role='assistant'." + err_info
)
if ti.tokens is None or tj.tokens is None:
raise ValueError(
"Expected role='assistant' turns to contain token counts."
+ err_info
)
res.extend(
[
# Implied token count for the user input
tj.tokens[0] - sum(ti.tokens),
# The token count for the assistant response
tj.tokens[1],
]
)

return res

def app(
self,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_provider_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_openai_simple_request():
chat.chat("What is 1 + 1?")
turn = chat.get_last_turn()
assert turn is not None
assert turn.tokens == (27, 1)
assert turn.tokens == (27, 2)
assert turn.finish_reason == "stop"


Expand Down
28 changes: 28 additions & 0 deletions tests/test_tokens.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,35 @@
from chatlas import ChatOpenAI, Turn
from chatlas._openai import OpenAIAzureProvider, OpenAIProvider
from chatlas._tokens import token_usage, tokens_log, tokens_reset


def test_tokens_method():
    """Check `Chat.tokens()` on an empty chat, one exchange, and two exchanges."""
    # No turns at all: the discrete view is simply empty.
    empty_chat = ChatOpenAI()
    assert empty_chat.tokens(values="discrete") == []

    # One user/assistant exchange: the assistant's (input, output) pair maps
    # directly onto the two per-turn counts.
    one_exchange = [
        Turn(role="user", contents="Hi"),
        Turn(role="assistant", contents="Hello", tokens=(2, 10)),
    ]
    single_chat = ChatOpenAI(turns=one_exchange)
    assert single_chat.tokens(values="discrete") == [2, 10]

    # Two exchanges: the 2nd assistant turn reports 14 input tokens, of which
    # 2 + 10 belong to the 1st exchange, leaving 2 for the 2nd user turn.
    two_exchanges = [
        Turn(role="user", contents="Hi"),
        Turn(role="assistant", contents="Hello", tokens=(2, 10)),
        Turn(role="user", contents="Hi"),
        Turn(role="assistant", contents="Hello", tokens=(14, 10)),
    ]
    double_chat = ChatOpenAI(turns=two_exchanges)

    discrete_counts = double_chat.tokens(values="discrete")
    assert discrete_counts == [2, 10, 2, 10]

    cumulative_counts = double_chat.tokens(values="cumulative")
    assert cumulative_counts == [None, (2, 10), None, (14, 10)]


def test_usage_is_none():
tokens_reset()
assert token_usage() is None
Expand Down

0 comments on commit e033684

Please sign in to comment.