metrics for completion API
dineshyv committed Feb 5, 2025
1 parent 7fbd6ca commit 5f9375e
Showing 1 changed file with 84 additions and 34 deletions.
118 changes: 84 additions & 34 deletions llama_stack/distribution/routers/routers.py
@@ -162,6 +162,37 @@ async def _log_token_usage(
             )
         )
 
+    async def _add_token_metrics(
+        self,
+        prompt_tokens: int,
+        completion_tokens: int,
+        total_tokens: int,
+        model: Model,
+        target: Any,
+    ) -> None:
+        metrics = getattr(target, "metrics", None)
+        if metrics is None:
+            target.metrics = []
+        target.metrics.append(
+            TokenUsage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
+            )
+        )
+        if self.telemetry:
+            await self._log_token_usage(prompt_tokens, completion_tokens, total_tokens, model)
+
+    async def _count_tokens(
+        self,
+        messages: Union[List[Message], List[RawMessage]],
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
+    ) -> Optional[int]:
+        if not self.telemetry:
+            return None
+        encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
+        return len(encoded.tokens) if encoded and encoded.tokens else 0
+
     async def chat_completion(
         self,
         model_id: str,
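Note on the two new helpers above: _add_token_metrics lazily creates a metrics list on whatever response object it is handed and appends a TokenUsage entry, while _count_tokens returns None when telemetry is disabled. A minimal, self-contained sketch of that behavior (the TokenUsage dataclass and SimpleNamespace stand-in below are illustrative, not the actual llama_stack types):

    # Illustrative sketch only: stand-ins for the llama_stack types.
    from dataclasses import dataclass
    from types import SimpleNamespace
    from typing import Any


    @dataclass
    class TokenUsage:
        prompt_tokens: int
        completion_tokens: int
        total_tokens: int


    def add_token_metrics(target: Any, prompt_tokens: int, completion_tokens: int) -> None:
        # Same shape as _add_token_metrics: create the list on first use, then append.
        if getattr(target, "metrics", None) is None:
            target.metrics = []
        target.metrics.append(
            TokenUsage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
            )
        )


    response = SimpleNamespace(metrics=None)
    add_token_metrics(response, prompt_tokens=12, completion_tokens=30)
    assert response.metrics[0].total_tokens == 42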
@@ -203,60 +234,46 @@ async def chat_completion(
             tool_config=tool_config,
         )
         provider = self.routing_table.get_provider_impl(model_id)
-        model_input = self.formatter.encode_dialog_prompt(
-            messages,
-            tool_config.tool_prompt_format,
-        )
+        prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
 
         if stream:
 
             async def stream_generator():
-                prompt_tokens = len(model_input.tokens) if model_input.tokens else 0
                 completion_text = ""
                 async for chunk in await provider.chat_completion(**params):
                     if chunk.event.event_type == ChatCompletionResponseEventType.progress:
                         if chunk.event.delta.type == "text":
                             completion_text += chunk.event.delta.text
                     if chunk.event.event_type == ChatCompletionResponseEventType.complete:
-                        model_output = self.formatter.encode_dialog_prompt(
+                        completion_tokens = await self._count_tokens(
                             [RawMessage(role="assistant", content=completion_text)],
                             tool_config.tool_prompt_format,
                         )
-                        completion_tokens = len(model_output.tokens) if model_output.tokens else 0
-                        total_tokens = prompt_tokens + completion_tokens
-                        if chunk.metrics is None:
-                            chunk.metrics = []
-                        chunk.metrics.append(
-                            TokenUsage(
-                                prompt_tokens=prompt_tokens,
-                                completion_tokens=completion_tokens,
-                                total_tokens=total_tokens,
-                            )
+                        total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
+                        await self._add_token_metrics(
+                            prompt_tokens or 0,
+                            completion_tokens or 0,
+                            total_tokens,
+                            model,
+                            chunk,
                         )
-                        if self.telemetry:
-                            await self._log_token_usage(prompt_tokens, completion_tokens, total_tokens, model)
                     yield chunk
 
             return stream_generator()
         else:
             response = await provider.chat_completion(**params)
-            model_output = self.formatter.encode_dialog_prompt(
+            completion_tokens = await self._count_tokens(
                 [response.completion_message],
                 tool_config.tool_prompt_format,
             )
-            prompt_tokens = len(model_input.tokens) if model_input.tokens else 0
-            completion_tokens = len(model_output.tokens) if model_output.tokens else 0
-            total_tokens = prompt_tokens + completion_tokens
-            if response.metrics is None:
-                response.metrics = []
-            response.metrics.append(
-                TokenUsage(
-                    prompt_tokens=prompt_tokens,
-                    completion_tokens=completion_tokens,
-                    total_tokens=total_tokens,
-                )
+            total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
+            await self._add_token_metrics(
+                prompt_tokens or 0,
+                completion_tokens or 0,
+                total_tokens,
+                model,
+                response,
             )
-            if self.telemetry:
-                await self._log_token_usage(prompt_tokens, completion_tokens, total_tokens, model)
             return response
 
     async def completion(
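With the chat_completion changes above, both branches now attach token metrics to the response; in the streaming case they arrive on the chunk whose event type is complete. A hypothetical caller-side sketch (the router object, model id, and message shape are assumptions for illustration, not the exact llama_stack client API):

    # Hypothetical consumer; assumes `router` behaves like the inference router above.
    async def print_usage(router) -> None:
        stream = await router.chat_completion(
            model_id="my-model",  # assumed identifier
            messages=[{"role": "user", "content": "Hello"}],  # shape is illustrative
            stream=True,
        )
        async for chunk in stream:
            # Only the final chunk should carry a populated metrics list.
            for usage in getattr(chunk, "metrics", None) or []:
                print(usage.prompt_tokens, usage.completion_tokens, usage.total_tokens)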
@@ -282,10 +299,43 @@ async def completion(
             stream=stream,
             logprobs=logprobs,
         )
+
+        prompt_tokens = await self._count_tokens([RawMessage(role="user", content=str(content))])
+
         if stream:
-            return (chunk async for chunk in await provider.completion(**params))
+
+            async def stream_generator():
+                completion_text = ""
+                async for chunk in await provider.completion(**params):
+                    if hasattr(chunk, "delta"):
+                        completion_text += chunk.delta
+                    if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
+                        completion_tokens = await self._count_tokens(
+                            [RawMessage(role="assistant", content=completion_text)]
+                        )
+                        total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
+                        await self._add_token_metrics(
+                            prompt_tokens or 0,
+                            completion_tokens or 0,
+                            total_tokens,
+                            model,
+                            chunk,
+                        )
+                    yield chunk
+
+            return stream_generator()
         else:
-            return await provider.completion(**params)
+            response = await provider.completion(**params)
+            completion_tokens = await self._count_tokens([RawMessage(role="assistant", content=str(response.content))])
+            total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
+            await self._add_token_metrics(
+                prompt_tokens or 0,
+                completion_tokens or 0,
+                total_tokens,
+                model,
+                response,
+            )
+            return response
 
     async def embeddings(
         self,
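One subtlety worth calling out: _count_tokens returns None when telemetry is disabled, which is why every call site computes (prompt_tokens or 0) + (completion_tokens or 0) rather than adding the raw values. A toy sketch of that arithmetic (the whitespace tokenizer is a stand-in for formatter.encode_dialog_prompt):

    from typing import Optional


    def count_tokens(text: str, telemetry_enabled: bool) -> Optional[int]:
        # Stand-in for _count_tokens: returns None when telemetry is off.
        if not telemetry_enabled:
            return None
        return len(text.split())  # toy whitespace "tokenizer"


    prompt_tokens = count_tokens("What is the capital of France?", telemetry_enabled=False)
    completion_tokens = count_tokens("Paris.", telemetry_enabled=False)
    # `or 0` turns None into 0, so the sum never raises a TypeError.
    total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
    assert total_tokens == 0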
