From e4c7d6fa3dd81ae83e39d784dce2ff5245359657 Mon Sep 17 00:00:00 2001 From: Jack Eadie Date: Wed, 22 Jan 2025 22:51:17 +1000 Subject: [PATCH] Allow `ChatCompletionChunkResponse` (and therefore streaming) to have `Usage`. (#1078) * handle assistant messages with 'tool_calls' when used in chat_template * linting * add better methods for using tools in and update examples * fixes * Update interactive_mode.rs * add Usage to ChatCompletionChunkResponse * add usage telemetry to streaming messages * clppy --- mistralrs-core/src/pipeline/sampling.rs | 12 +++++++++++- mistralrs-core/src/response.rs | 1 + mistralrs-core/src/sequence.rs | 7 +++++-- 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/mistralrs-core/src/pipeline/sampling.rs b/mistralrs-core/src/pipeline/sampling.rs index fdb0f825a..49dd8e6cc 100644 --- a/mistralrs-core/src/pipeline/sampling.rs +++ b/mistralrs-core/src/pipeline/sampling.rs @@ -90,9 +90,19 @@ pub(crate) async fn finish_or_add_toks_to_seq( this.reset_non_granular_state(); } + // Send usage on final chunk. 
+ let usage_opt = if is_done.is_some() { + let usage = seq.get_mut_group().get_usage(); + seq.get_mut_group().total_prompt_toks = 0; + seq.get_mut_group().total_toks = 0; + Some(usage) + } else { + None + }; + if seq .get_mut_group() - .maybe_send_streaming_response(seq, this.name().clone()) + .maybe_send_streaming_response(seq, this.name().clone(), usage_opt) .await .is_err() { diff --git a/mistralrs-core/src/response.rs b/mistralrs-core/src/response.rs index d3c9f22fa..e405678cf 100644 --- a/mistralrs-core/src/response.rs +++ b/mistralrs-core/src/response.rs @@ -154,6 +154,7 @@ pub struct ChatCompletionChunkResponse { pub model: String, pub system_fingerprint: String, pub object: String, + pub usage: Option<Usage>, } generate_repr!(ChatCompletionChunkResponse); diff --git a/mistralrs-core/src/sequence.rs b/mistralrs-core/src/sequence.rs index 9aa9d49b9..e8b84ab6b 100644 --- a/mistralrs-core/src/sequence.rs +++ b/mistralrs-core/src/sequence.rs @@ -705,8 +705,8 @@ impl Sequence { get_mut_group!(self).total_time += now - self.timestamp; - get_mut_group!(self).total_prompt_toks += self.prompt_len; - get_mut_group!(self).total_toks += self.len(); + get_mut_group!(self).total_prompt_toks = self.prompt_len; + get_mut_group!(self).total_toks = self.len(); } pub fn add_image_choice_to_group(&self, choice: ImageChoice) { @@ -749,6 +749,7 @@ impl Sequence { pub fn add_streaming_chunk_choice_to_group(&self, chunk: ChunkChoice) { get_mut_group!(self).chat_streaming_chunks.push(chunk); + self.update_time_info(); } pub fn add_streaming_completion_chunk_choice_to_group(&self, chunk: CompletionChunkChoice) { @@ -920,6 +921,7 @@ impl SequenceGroup { &mut self, seq: &Sequence, model: String, + usage_opt: Option<Usage>, ) -> Result<(), Box<SendError<Response>>> { if self.chat_streaming_chunks.len() == self.n_choices && self.is_streaming { let mut swap_streaming_chunks = vec![]; @@ -934,6 +936,7 @@ impl SequenceGroup { model: model.clone(), system_fingerprint: SYSTEM_FINGERPRINT.to_string(), object: 
"chat.completion.chunk".to_string(), + usage: usage_opt, })) .await?; } else if self.completion_streaming_chunks.len() == self.n_choices && self.is_streaming {