Skip to content

Commit

Permalink
Add .token_count_async(); require the whole data_model
Browse files Browse the repository at this point in the history
  • Loading branch information
cpsievert committed Dec 19, 2024
1 parent 69436ff commit 1d27b5b
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 26 deletions.
41 changes: 31 additions & 10 deletions chatlas/_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,15 +384,43 @@ def token_count(
self,
*args: Content | str,
tools: dict[str, Tool],
has_data_model: bool,
data_model: Optional[type[BaseModel]],
) -> int:
kwargs = self._token_count_args(
*args,
tools=tools,
data_model=data_model,
)
res = self._client.messages.count_tokens(**kwargs)
return res.input_tokens

async def token_count_async(
    self,
    *args: Content | str,
    tools: dict[str, Tool],
    data_model: Optional[type[BaseModel]],
) -> int:
    """
    Asynchronously count the input tokens for the given content.

    Builds the request arguments (including tool and structured-data
    configuration) via `_token_count_args()`, then queries Anthropic's
    async token-counting endpoint and returns its `input_tokens` total.
    """
    count_args = self._token_count_args(*args, tools=tools, data_model=data_model)
    response = await self._async_client.messages.count_tokens(**count_args)
    return response.input_tokens

def _token_count_args(
self,
*args: Content | str,
tools: dict[str, Tool],
data_model: Optional[type[BaseModel]],
) -> dict[str, Any]:
turn = user_turn(*args)

kwargs = self._chat_perform_args(
stream=False,
turns=[turn],
tools=tools,
data_model=None if not has_data_model else BaseModel,
data_model=data_model,
)

args_to_keep = [
Expand All @@ -403,14 +431,7 @@ def token_count(
"tool_choice",
]

kwargs_final = {}
for arg in args_to_keep:
if arg in kwargs:
kwargs_final[arg] = kwargs[arg]

res = self._client.messages.count_tokens(**kwargs_final)

return res.input_tokens
return {arg: kwargs[arg] for arg in args_to_keep if arg in kwargs}

def _as_message_params(self, turns: list[Turn]) -> list["MessageParam"]:
messages: list["MessageParam"] = []
Expand Down
45 changes: 40 additions & 5 deletions chatlas/_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,21 +191,23 @@ def tokens(self) -> list[tuple[int, int] | None]:
def token_count(
self,
*args: Content | str,
extract_data: bool = False,
data_model: Optional[type[BaseModel]] = None,
) -> int:
"""
Get an estimated token count for the given input.
Estimate the token size of input content. This can help determine whether input(s)
Estimate the token size of input content. This can help determine whether input(s)
and/or conversation history (i.e., `.get_turns()`) should be reduced in size before
sending it to the model.
Parameters
----------
args
The input to get a token count for.
extract_data
Whether or not the input is for data extraction (i.e., `.extract_data()`).
data_model
If the input is meant for data extraction (i.e., `.extract_data()`), then
this should be the Pydantic model that describes the structure of the data to
extract.
Returns
-------
Expand All @@ -231,7 +233,40 @@ def token_count(
return self.provider.token_count(
*args,
tools=self._tools,
has_data_model=extract_data,
data_model=data_model,
)

async def token_count_async(
    self,
    *args: Content | str,
    data_model: Optional[type[BaseModel]] = None,
) -> int:
    """
    Get an estimated token count for the given input asynchronously.

    Estimate the token size of input content. This can help determine whether
    input(s) and/or conversation history (i.e., `.get_turns()`) should be
    reduced in size before sending it to the model.

    Parameters
    ----------
    args
        The input to get a token count for.
    data_model
        If this input is meant for data extraction (i.e.,
        `.extract_data_async()`), then this should be the Pydantic model that
        describes the structure of the data to extract.

    Returns
    -------
    int
        The token count for the input.
    """
    # Delegate to the provider, passing along the registered tools so the
    # count reflects the full request payload.
    count = await self.provider.token_count_async(
        *args,
        tools=self._tools,
        data_model=data_model,
    )
    return count

def app(
Expand Down
42 changes: 33 additions & 9 deletions chatlas/_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,26 +337,50 @@ def token_count(
self,
*args: Content | str,
tools: dict[str, Tool],
has_data_model: bool,
data_model: Optional[type[BaseModel]],
):
kwargs = self._token_count_args(
*args,
tools=tools,
data_model=data_model,
)

res = self._client.count_tokens(**kwargs)
return res.total_tokens

async def token_count_async(
    self,
    *args: Content | str,
    tools: dict[str, Tool],
    data_model: Optional[type[BaseModel]],
):
    """
    Asynchronously estimate the token total for the given input.

    Assembles the request arguments via `_token_count_args()` and queries
    Google's async token-counting endpoint, returning its `total_tokens`.
    """
    request_args = self._token_count_args(*args, tools=tools, data_model=data_model)
    response = await self._client.count_tokens_async(**request_args)
    return response.total_tokens

def _token_count_args(
self,
*args: Content | str,
tools: dict[str, Tool],
data_model: Optional[type[BaseModel]],
) -> dict[str, Any]:
turn = user_turn(*args)

kwargs = self._chat_perform_args(
stream=False,
turns=[turn],
tools=tools,
data_model=None if not has_data_model else BaseModel,
data_model=data_model,
)

args_to_keep = ["contents", "tools"]

kwargs_final = {}
for arg in args_to_keep:
if arg in kwargs:
kwargs_final[arg] = kwargs[arg]

res = self._client.count_tokens(**kwargs_final)
return res.total_tokens
return {arg: kwargs[arg] for arg in args_to_keep if arg in kwargs}

def _google_contents(self, turns: list[Turn]) -> list["ContentDict"]:
contents: list["ContentDict"] = []
Expand Down
10 changes: 9 additions & 1 deletion chatlas/_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def token_count(
self,
*args: Content | str,
tools: dict[str, Tool],
has_data_model: bool,
data_model: Optional[type[BaseModel]],
) -> int:
try:
import tiktoken
Expand Down Expand Up @@ -383,6 +383,14 @@ def token_count(

return res

async def token_count_async(
    self,
    *args: Content | str,
    tools: dict[str, Tool],
    data_model: Optional[type[BaseModel]],
) -> int:
    """
    Async counterpart of `.token_count()`.

    The OpenAI token estimate is computed locally (no network request),
    so there is nothing to await — simply delegate to the sync method.
    """
    result = self.token_count(*args, tools=tools, data_model=data_model)
    return result

@staticmethod
def _image_token_count(image: ContentImage) -> int:
if isinstance(image, ContentImageRemote) and image.detail == "low":
Expand Down
10 changes: 9 additions & 1 deletion chatlas/_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,5 +148,13 @@ def token_count(
self,
*args: Content | str,
tools: dict[str, Tool],
has_data_model: bool,
data_model: Optional[type[BaseModel]],
) -> int: ...

# Async counterpart of `token_count()`: providers must implement an
# awaitable estimate of the input-token total for the given content,
# tools, and (optional) structured-data model.
@abstractmethod
async def token_count_async(
    self,
    *args: Content | str,
    tools: dict[str, Tool],
    data_model: Optional[type[BaseModel]],
) -> int: ...

0 comments on commit 1d27b5b

Please sign in to comment.