feat: expose token usage on chat request
moldhouse committed Feb 6, 2025 · 1 parent 43e52ca · commit 8e3cdc8
Showing 3 changed files with 58 additions and 17 deletions.
51 changes: 35 additions & 16 deletions src/chat.rs
@@ -121,13 +121,36 @@ impl Default for ChatSampling {
     }
 }
 
+#[derive(Debug, PartialEq, Deserialize)]
+pub struct Usage {
+    pub prompt_tokens: u32,
+    pub completion_tokens: u32,
+}
+
 #[derive(Debug, PartialEq)]
 pub struct ChatOutput {
     pub message: Message<'static>,
     pub finish_reason: String,
     /// Contains the logprobs for the sampled and top n tokens, given that [`crate::Logprobs`] has
     /// been set to [`crate::Logprobs::Sampled`] or [`crate::Logprobs::Top`].
     pub logprobs: Vec<Distribution>,
+    pub usage: Usage,
 }
+
+impl ChatOutput {
+    pub fn new(
+        message: Message<'static>,
+        finish_reason: String,
+        logprobs: Vec<Distribution>,
+        usage: Usage,
+    ) -> Self {
+        Self {
+            message,
+            finish_reason,
+            logprobs,
+            usage,
+        }
+    }
+}
 
 #[derive(Deserialize, Debug, PartialEq)]
@@ -142,21 +165,6 @@ pub struct LogprobContent {
     content: Vec<Distribution>,
 }
 
-impl ResponseChoice {
-    fn into_chat_output(self) -> ChatOutput {
-        let ResponseChoice {
-            message,
-            finish_reason,
-            logprobs,
-        } = self;
-        ChatOutput {
-            message,
-            finish_reason,
-            logprobs: logprobs.unwrap_or_default().content,
-        }
-    }
-}
-
 /// Logprob information for a single token
 #[derive(Deserialize, Debug, PartialEq)]
 pub struct Distribution {
@@ -171,6 +179,7 @@ pub struct Distribution {
 #[derive(Deserialize, Debug, PartialEq)]
 pub struct ResponseChat {
     choices: Vec<ResponseChoice>,
+    usage: Usage,
 }
 
 #[derive(Serialize)]
@@ -262,7 +271,17 @@ impl Task for TaskChat<'_> {
     }
 
     fn body_to_output(&self, mut response: Self::ResponseBody) -> Self::Output {
-        response.choices.pop().unwrap().into_chat_output()
+        let ResponseChoice {
+            message,
+            finish_reason,
+            logprobs,
+        } = response.choices.pop().unwrap();
+        ChatOutput::new(
+            message,
+            finish_reason,
+            logprobs.unwrap_or_default().content,
+            response.usage,
+        )
     }
 }
 
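For context, the new `Usage` struct is plain serde-deserializable data. Below is a minimal standalone sketch of how it maps from an OpenAI-style `usage` JSON object; the payload is illustrative, not captured from the API:

```rust
// Sketch only: mirrors the `Usage` struct added above. Requires the `serde`
// crate (with the "derive" feature) and `serde_json`.
use serde::Deserialize;

#[derive(Debug, PartialEq, Deserialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
}

fn main() {
    // Illustrative response fragment; field names match the struct above.
    let body = r#"{"prompt_tokens": 19, "completion_tokens": 3}"#;
    let usage: Usage = serde_json::from_str(body).unwrap();
    assert_eq!(usage.prompt_tokens, 19);
    assert_eq!(usage.completion_tokens, 3);
}
```

Note that `ResponseChat` gains `usage` as a required field (no `Option` or `#[serde(default)]`), so deserialization of the whole response fails if the API omits the `usage` object.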
2 changes: 1 addition & 1 deletion src/lib.rs
@@ -43,7 +43,7 @@ use std::{pin::Pin, time::Duration};
 use tokenizers::Tokenizer;
 
 pub use self::{
-    chat::{ChatEvent, ChatOutput, ChatSampling, ChatStreamChunk, Distribution, Message, TaskChat},
+    chat::{ChatEvent, ChatOutput, ChatSampling, ChatStreamChunk, Distribution, Message, TaskChat, Usage},
     completion::{
         CompletionEvent, CompletionOutput, CompletionSummary, Sampling, Stopping, StreamChunk,
         StreamSummary, TaskCompletion,
22 changes: 22 additions & 0 deletions tests/integration.rs
@@ -814,3 +814,25 @@ async fn show_top_logprobs_completion() {
     assert_eq!(response.logprobs[0].top[1].token_as_str().unwrap(), " may");
     assert!(response.logprobs[0].top[0].logprob > response.logprobs[0].top[1].logprob);
 }
+
+#[tokio::test]
+async fn show_token_usage_chat() {
+    // Given
+    let model = "pharia-1-llm-7b-control";
+    let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();
+    let message = Message::user("An apple a day");
+
+    let task = TaskChat {
+        messages: vec![message],
+        stopping: Stopping::from_maximum_tokens(3),
+        sampling: ChatSampling::MOST_LIKELY,
+        logprobs: Logprobs::No,
+    };
+
+    // When
+    let response = client.chat(&task, model, &How::default()).await.unwrap();
+
+    // Then
+    assert_eq!(response.usage.prompt_tokens, 19);
+    assert_eq!(response.usage.completion_tokens, 3);
+}
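The two counters compose naturally; for example, a hypothetical helper (not part of this commit) for total token accounting:

```rust
/// Hypothetical helper, not part of the crate: total number of tokens a
/// chat request consumed, prompt and completion combined.
fn total_tokens(usage: &Usage) -> u32 {
    usage.prompt_tokens + usage.completion_tokens
}
```

For the request in the test above, this would report 19 + 3 = 22 tokens.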
