Skip to content

Commit

Permalink
Allow ChatCompletionChunkResponse (and therefore streaming) to have…
Browse files Browse the repository at this point in the history
… `Usage`. (#1078)

* handle assistant messages with 'tool_calls' when used in chat_template

* linting

* add better methods for using tools in  and update examples

* fixes

* Update interactive_mode.rs

* add Usage to ChatCompletionChunkResponse

* add usage telemetry to streaming messages

* clippy
  • Loading branch information
Jeadie authored Jan 22, 2025
1 parent 710e1f1 commit e4c7d6f
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 3 deletions.
12 changes: 11 additions & 1 deletion mistralrs-core/src/pipeline/sampling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,19 @@ pub(crate) async fn finish_or_add_toks_to_seq(
this.reset_non_granular_state();
}

// Send usage on final chunk.
let usage_opt = if is_done.is_some() {
let usage = seq.get_mut_group().get_usage();
seq.get_mut_group().total_prompt_toks = 0;
seq.get_mut_group().total_toks = 0;
Some(usage)
} else {
None
};

if seq
.get_mut_group()
.maybe_send_streaming_response(seq, this.name().clone())
.maybe_send_streaming_response(seq, this.name().clone(), usage_opt)
.await
.is_err()
{
Expand Down
1 change: 1 addition & 0 deletions mistralrs-core/src/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ pub struct ChatCompletionChunkResponse {
pub model: String,
pub system_fingerprint: String,
pub object: String,
pub usage: Option<Usage>,
}

generate_repr!(ChatCompletionChunkResponse);
Expand Down
7 changes: 5 additions & 2 deletions mistralrs-core/src/sequence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -705,8 +705,8 @@ impl Sequence {

get_mut_group!(self).total_time += now - self.timestamp;

get_mut_group!(self).total_prompt_toks += self.prompt_len;
get_mut_group!(self).total_toks += self.len();
get_mut_group!(self).total_prompt_toks = self.prompt_len;
get_mut_group!(self).total_toks = self.len();
}

pub fn add_image_choice_to_group(&self, choice: ImageChoice) {
Expand Down Expand Up @@ -749,6 +749,7 @@ impl Sequence {

pub fn add_streaming_chunk_choice_to_group(&self, chunk: ChunkChoice) {
get_mut_group!(self).chat_streaming_chunks.push(chunk);
self.update_time_info();
}

pub fn add_streaming_completion_chunk_choice_to_group(&self, chunk: CompletionChunkChoice) {
Expand Down Expand Up @@ -920,6 +921,7 @@ impl SequenceGroup {
&mut self,
seq: &Sequence,
model: String,
usage_opt: Option<Usage>,
) -> Result<(), Box<SendError<Response>>> {
if self.chat_streaming_chunks.len() == self.n_choices && self.is_streaming {
let mut swap_streaming_chunks = vec![];
Expand All @@ -934,6 +936,7 @@ impl SequenceGroup {
model: model.clone(),
system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
object: "chat.completion.chunk".to_string(),
usage: usage_opt,
}))
.await?;
} else if self.completion_streaming_chunks.len() == self.n_choices && self.is_streaming {
Expand Down

0 comments on commit e4c7d6f

Please sign in to comment.