Skip to content

Commit

Permalink
feat: expose token usage on completion request
Browse files Browse the repository at this point in the history
  • Loading branch information
moldhouse committed Feb 6, 2025
1 parent 8e3cdc8 commit fbdae55
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 5 deletions.
9 changes: 8 additions & 1 deletion src/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::collections::HashMap;

use serde::{Deserialize, Serialize};

use crate::{http::Task, Distribution, Logprob, Logprobs, Prompt, StreamTask};
use crate::{http::Task, Distribution, Logprob, Logprobs, Prompt, StreamTask, Usage};

/// Completes a prompt. E.g. continues a text.
pub struct TaskCompletion<'a> {
Expand Down Expand Up @@ -231,6 +231,8 @@ impl<'a> BodyCompletion<'a> {
pub struct ResponseCompletion {
model_version: String,
completions: Vec<DeserializedCompletion>,
num_tokens_prompt_total: u32,
num_tokens_generated: u32,
}

#[derive(Deserialize, Debug, PartialEq)]
Expand All @@ -250,6 +252,7 @@ pub struct CompletionOutput {
pub completion: String,
pub finish_reason: String,
pub logprobs: Vec<Distribution>,
pub usage: Usage,
}

impl Task for TaskCompletion<'_> {
Expand Down Expand Up @@ -289,6 +292,10 @@ impl Task for TaskCompletion<'_> {
completion_tokens,
self.logprobs.top_logprobs().unwrap_or_default(),
),
usage: Usage {
prompt_tokens: response.num_tokens_prompt_total,
completion_tokens: response.num_tokens_generated,
},
}
}
}
Expand Down
5 changes: 4 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@ use std::{pin::Pin, time::Duration};
use tokenizers::Tokenizer;

pub use self::{
chat::{ChatEvent, ChatOutput, ChatSampling, ChatStreamChunk, Distribution, Message, TaskChat, Usage},
chat::{
ChatEvent, ChatOutput, ChatSampling, ChatStreamChunk, Distribution, Message, TaskChat,
Usage,
},
completion::{
CompletionEvent, CompletionOutput, CompletionSummary, Sampling, Stopping, StreamChunk,
StreamSummary, TaskCompletion,
Expand Down
22 changes: 20 additions & 2 deletions tests/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,6 @@ async fn completion_with_luminous_base() {
.await
.unwrap();

eprintln!("{}", response.completion);

// Then
assert!(!response.completion.is_empty())
}
Expand Down Expand Up @@ -836,3 +834,23 @@ async fn show_token_usage_chat() {
assert_eq!(response.usage.prompt_tokens, 19);
assert_eq!(response.usage.completion_tokens, 3);
}

#[tokio::test]
async fn show_token_usage_completion() {
    // Given an authenticated client and a short completion task with logprobs disabled
    let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();
    let model = "pharia-1-llm-7b-control";
    let task = TaskCompletion::from_text("An apple a day")
        .with_maximum_tokens(3)
        .with_logprobs(Logprobs::No);

    // When the completion request is executed against the inference API
    let response = client
        .completion(&task, model, &How::default())
        .await
        .unwrap();

    // Then the response reports token usage for both the prompt and the
    // generated completion (the latter capped by the 3-token budget above)
    assert_eq!(response.usage.prompt_tokens, 5);
    assert_eq!(response.usage.completion_tokens, 3);
}
2 changes: 1 addition & 1 deletion tests/unit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ async fn completion_with_luminous_base() {
// Start a background HTTP server on a random local part
let mock_server = MockServer::start().await;

let answer = r#"{"model_version":"2021-12","completions":[{"completion":"\n","finish_reason":"maximum_tokens"}]}"#;
let answer = r#"{"model_version":"2021-12","completions":[{"completion":"\n","finish_reason":"maximum_tokens"}],"num_tokens_prompt_total":5,"num_tokens_generated":1}"#;
let body = r#"{
"model": "luminous-base",
"prompt": [{"type": "text", "data": "Hello,"}],
Expand Down

0 comments on commit fbdae55

Please sign in to comment.