Allow chat streaming to use tools #1088

Open · wants to merge 7 commits into base: master

Changes from 3 commits
57 changes: 43 additions & 14 deletions mistralrs-core/src/pipeline/sampling.rs
@@ -6,11 +6,33 @@ use rand_isaac::Isaac64Rng;
use crate::{
prefix_cacher_v2::PrefixCacheManagerV2,
sampler::Logprobs,
sequence::{Sequence, SequenceRecognizer},
sequence::{Sequence, SequenceRecognizer, StopReason},
tools::ToolCallingMatcher,
ToolCallResponse,
};

use super::Pipeline;

/// Takes raw UTF-8 text and parses any possible tool calls from it.
fn parse_text_tools(
raw_text: &str,
matcher: Option<Arc<ToolCallingMatcher>>,
) -> Result<(Option<&str>, Vec<ToolCallResponse>)> {
let mut tool_calls = Vec::new();
let mut text_new = Some(raw_text);

if let Some(ref matcher) = matcher {
let calls = matcher
.get_call(raw_text)
.map_err(candle_core::Error::msg)?;
if !calls.is_empty() {
text_new = None;
tool_calls = calls;
}
};
Ok((text_new, tool_calls))
}

pub(crate) async fn finish_or_add_toks_to_seq(
this: &dyn Pipeline,
prefix_cacher: &mut PrefixCacheManagerV2,
@@ -19,7 +41,7 @@ pub(crate) async fn finish_or_add_toks_to_seq(
eos_tok: Option<&[u32]>,
use_prefix_cacher: bool,
) -> Result<()> {
let is_done = seq.is_done(logprobs.token, eos_tok, this.get_metadata().max_seq_len);
let mut is_done = seq.is_done(logprobs.token, eos_tok, this.get_metadata().max_seq_len);
seq.add_token(
logprobs.clone(),
this.get_metadata()
@@ -40,13 +62,28 @@ pub(crate) async fn finish_or_add_toks_to_seq(
let token_index = seq.get_toks().len();
let rate_limit_allowed = is_done.is_some() || token_index % STREAMING_RATE_LIMIT == 0;

if rate_limit_allowed {
let mut tool_use_still_possible = false;
let mut tool_use_is_done = false;
if let Some(ref t) = seq.tools {
if let Ok(Some(ref d)) = seq.peek_delta() {
(tool_use_still_possible, tool_use_is_done) = t.prefix_could_be_tool(d.as_str());
}
};

if (rate_limit_allowed && !tool_use_still_possible) || tool_use_is_done {
if let Some(delta) = crate::handle_seq_error_ok!(seq.get_delta(), seq.responder()) {
if seq.get_mut_group().is_chat {
let (text_new, tool_calls) =
parse_text_tools(delta.as_str(), seq.tools.clone())?;

if !tool_calls.is_empty() && is_done.is_none() {
is_done = Some(StopReason::Eos);
};
seq.add_streaming_chunk_choice_to_group(crate::ChunkChoice {
delta: crate::Delta {
content: delta.clone(),
content: text_new.map(ToString::to_string),
role: "assistant".to_string(),
tool_calls: Some(tool_calls),
},
index: seq.get_response_index(),
finish_reason: is_done.map(|x| x.to_string()),
@@ -175,20 +212,12 @@ pub(crate) async fn finish_or_add_toks_to_seq(
};

if seq.get_mut_group().is_chat {
let mut tool_calls = Vec::new();
let mut text_new = Some(text.clone());
if let Some(ref matcher) = seq.tools {
let calls = matcher.get_call(&text).map_err(candle_core::Error::msg)?;
if !calls.is_empty() {
text_new = None;
}
tool_calls = calls;
}
let (text_new, tool_calls) = parse_text_tools(text.as_str(), seq.tools.clone())?;
let choice = crate::Choice {
finish_reason: reason.to_string(),
index: seq.get_response_index(),
message: crate::ResponseMessage {
content: text_new,
content: text_new.map(ToString::to_string),
role: "assistant".to_string(),
tool_calls,
},
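The gating logic above holds streamed deltas back while the accumulated text could still become a tool call, and forces an `Eos` stop once a complete call is parsed so the stream terminates cleanly. A self-contained sketch of the flush decision (illustrative names only, not the crate's API):

```rust
// Mirrors `(rate_limit_allowed && !tool_use_still_possible) || tool_use_is_done`
// from the diff above. `prefix_state` stands in for the result of
// ToolCallingMatcher::prefix_could_be_tool on the pending delta.
fn should_flush(rate_limit_allowed: bool, prefix_state: Option<(bool, bool)>) -> bool {
    // `None` models a sequence with no tools attached.
    let (still_possible, is_done) = prefix_state.unwrap_or((false, false));
    // Flush on the normal rate limit unless the buffer might still grow into a
    // tool call; flush immediately once a complete tool call has been parsed.
    (rate_limit_allowed && !still_possible) || is_done
}

fn main() {
    assert!(should_flush(true, None)); // plain text, rate limit hit
    assert!(!should_flush(true, Some((true, false)))); // might be a tool call: hold back
    assert!(should_flush(false, Some((false, true)))); // complete tool call: flush now
    println!("flush gate behaves as expected");
}
```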
3 changes: 2 additions & 1 deletion mistralrs-core/src/response.rs
@@ -41,8 +41,9 @@ generate_repr!(ResponseMessage);
#[derive(Debug, Clone, Serialize)]
/// Delta in content for streaming response.
pub struct Delta {
pub content: String,
pub content: Option<String>,
pub role: String,
pub tool_calls: Option<Vec<ToolCallResponse>>,
}

generate_repr!(Delta);
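Since `content` is now `Option<String>`, a streamed chunk may carry tool calls and no text at all, so consumers can no longer append `delta.content` unconditionally. A minimal sketch of handling both cases, using local mirror types rather than the crate's (the tool-call field names here are assumptions):

```rust
// Local stand-ins for the crate's Delta/ToolCallResponse types.
struct ToolCall {
    name: String,
    arguments: String,
}

struct Delta {
    content: Option<String>,
    role: String,
    tool_calls: Option<Vec<ToolCall>>,
}

fn render(delta: &Delta) {
    // Ordinary streamed text, when present.
    if let Some(text) = &delta.content {
        print!("{text}");
    }
    // A tool-call chunk arrives with `content: None`; surface it separately.
    for call in delta.tool_calls.iter().flatten() {
        println!("[{} tool call: {}({})]", delta.role, call.name, call.arguments);
    }
}

fn main() {
    render(&Delta {
        content: None,
        role: "assistant".to_string(),
        tool_calls: Some(vec![ToolCall {
            name: "get_weather".to_string(),
            arguments: r#"{"city": "Paris"}"#.to_string(),
        }]),
    });
}
```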
10 changes: 9 additions & 1 deletion mistralrs-core/src/sequence.rs
@@ -667,13 +667,21 @@ impl Sequence {
pub fn get_delta(
&mut self,
) -> Result<Option<String>, Box<dyn std::error::Error + Send + Sync>> {
let new_decoded = self.peek_delta();
if matches!(new_decoded, Ok(Some(_))) {
self.stream_idx = self.completion_bytes.len();
}
new_decoded
}

/// Peeks at the delta between the last two decoded sequences, but does not advance the stream index.
pub fn peek_delta(&self) -> Result<Option<String>, Box<dyn std::error::Error + Send + Sync>> {
let is_first = self.stream_idx == 0;
let new_decoded = String::from_utf8_lossy(&self.completion_bytes[self.stream_idx..]);
// Check if the sequence ends with valid UTF-8; if not, skip it, as it is probably a multi-token sequence
if new_decoded.ends_with('�') {
return Ok(None);
}
self.stream_idx = self.completion_bytes.len();

// The first token usually starts with a space. We don't want to add that to the delta.
// Since we're using the completion_bytes, we need to take care of that ourselves.
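This refactor splits the old `get_delta` into a non-consuming `peek_delta` plus a thin wrapper that advances `stream_idx` only when a delta is actually produced, which is what lets the sampling loop inspect a pending delta for tool-call prefixes without consuming it. A simplified standalone model of the split (the real methods return `Result` and also trim the leading space of the first token):

```rust
// Toy version of the peek/get split over a growing UTF-8 byte buffer.
struct Stream {
    completion_bytes: Vec<u8>,
    stream_idx: usize,
}

impl Stream {
    /// Look at the undelivered suffix without consuming it. Returns `None` if the
    /// suffix ends mid-codepoint (lossy decoding leaves a trailing U+FFFD).
    fn peek_delta(&self) -> Option<String> {
        let s = String::from_utf8_lossy(&self.completion_bytes[self.stream_idx..]);
        if s.ends_with('\u{FFFD}') {
            None
        } else {
            Some(s.into_owned())
        }
    }

    /// Deliver the suffix and advance the index, so each byte is returned once.
    fn get_delta(&mut self) -> Option<String> {
        let out = self.peek_delta();
        if out.is_some() {
            self.stream_idx = self.completion_bytes.len();
        }
        out
    }
}

fn main() {
    let mut s = Stream { completion_bytes: b"hello".to_vec(), stream_idx: 0 };
    assert_eq!(s.peek_delta().as_deref(), Some("hello")); // peek does not consume
    assert_eq!(s.get_delta().as_deref(), Some("hello")); // get consumes

    // A buffer that ends mid-codepoint ("é" is 0xC3 0xA9) is held back.
    let t = Stream { completion_bytes: vec![b'h', 0xC3], stream_idx: 0 };
    assert_eq!(t.peek_delta(), None);
}
```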
45 changes: 45 additions & 0 deletions mistralrs-core/src/tools/mod.rs
@@ -30,6 +30,36 @@ impl ToolCallingMatcher {
Ok(Self { tool_choice })
}

// Checks if the `message_prefix` could be a tool call. If false, either
// [`ToolChoice::None`] was selected, or the prefix could not match.
//
// If the start of a message could be a tool call, it looks like incomplete JSON of a given structure, e.g. `{"name": "foo", "param`.
//
// Returns a tuple of `(could_be_tool, is_complete_tool)`.
pub fn prefix_could_be_tool(&self, message_prefix: &str) -> (bool, bool) {
if matches!(self.tool_choice, ToolChoice::None) {
return (false, false);
}

// Check if the prefix could be a JSON serialization of any of the following types.
[
could_be_json::<CalledFunctionParameters>,
could_be_json::<CalledFunctionArguments>,
could_be_json::<Vec<CalledFunctionParameters>>,
could_be_json::<Vec<CalledFunctionArguments>>,
]
.iter()
.find_map(|check| {
let (could_be_tool, is_complete_tool) = check(message_prefix);
if could_be_tool || is_complete_tool {
Some((could_be_tool, is_complete_tool))
} else {
None
}
})
.unwrap_or_default()
}

pub fn get_call(&self, message: &str) -> anyhow::Result<Vec<ToolCallResponse>> {
if matches!(self.tool_choice, ToolChoice::None) {
return Ok(Vec::new());
@@ -93,3 +123,18 @@ impl ToolCallingMatcher {
}
}
}

/// Checks if the given prefix could be the start of, or the entirety of, the JSON serialization of a given type `T`.
///
/// Returns a tuple of `(could_be_tool, is_entire_tool)`.
fn could_be_json<T>(text_prefix: &str) -> (bool, bool)
where
T: serde::de::DeserializeOwned,
{
match serde_json::from_str::<T>(text_prefix) {
Ok(_) => (false, true),
// An EOF error shows that JSON parsing was successful up to the end of the entire string.
Err(e) if e.is_eof() => (true, false),
_ => (false, false),
}
}
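The `is_eof()` heuristic works because serde_json distinguishes "this input is a valid prefix that simply ended early" from other parse failures. A small demonstration of the three outcomes, assuming `serde` (with the `derive` feature) and `serde_json` as dependencies; the `Call` struct is hypothetical:

```rust
use serde::Deserialize;

#[allow(dead_code)]
#[derive(Deserialize)]
struct Call {
    name: String,
}

fn main() {
    // Complete JSON parses -> could_be_json returns (false, true): an entire tool call.
    assert!(serde_json::from_str::<Call>(r#"{"name": "foo"}"#).is_ok());

    // A strict prefix fails with an EOF error -> (true, false): keep buffering.
    let err = serde_json::from_str::<Call>(r#"{"name": "fo"#).unwrap_err();
    assert!(err.is_eof());

    // Text that can never become this JSON fails with a non-EOF error -> (false, false).
    let err = serde_json::from_str::<Call>("hello there").unwrap_err();
    assert!(!err.is_eof());

    println!("prefix heuristic verified");
}
```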
64 changes: 42 additions & 22 deletions mistralrs-server/src/interactive_mode.rs
@@ -1,9 +1,9 @@
use either::Either;
use indexmap::IndexMap;
use mistralrs_core::{
Constraint, DiffusionGenerationParams, DrySamplingParams, ImageGenerationResponseFormat,
MessageContent, MistralRs, ModelCategory, NormalRequest, Request, RequestMessage, Response,
ResponseOk, SamplingParams, TERMINATE_ALL_NEXT_STEP,
ChunkChoice, Constraint, Delta, DiffusionGenerationParams, DrySamplingParams,
ImageGenerationResponseFormat, MessageContent, MistralRs, ModelCategory, NormalRequest,
Request, RequestMessage, Response, ResponseOk, SamplingParams, TERMINATE_ALL_NEXT_STEP,
};
use once_cell::sync::Lazy;
use regex::Regex;
@@ -59,7 +59,7 @@ Commands:
- `\system <system message here>`:
Add a system message to the chat without running the model.
Ex: `\system Always respond as a pirate.`
- `\image <image URL or local path here> <message here>`:
- `\image <image URL or local path here> <message here>`:
Add a message paired with an image. The image will be fed to the model as if it were the first item in this prompt.
You do not need to modify your prompt for specific models.
Ex: `\image path/to/image.jpg Describe what is in this image.`
@@ -187,16 +187,26 @@ async fn text_interactive_mode(mistralrs: Arc<MistralRs>, throughput: bool) {
while let Some(resp) = rx.recv().await {
match resp {
Response::Chunk(chunk) => {
let choice = &chunk.choices[0];
assistant_output.push_str(&choice.delta.content);
print!("{}", choice.delta.content);
toks += 3usize; // NOTE: we send toks every 3.
io::stdout().flush().unwrap();
if choice.finish_reason.is_some() {
if matches!(choice.finish_reason.as_ref().unwrap().as_str(), "length") {
print!("...");
if let ChunkChoice {
delta:
Delta {
content: Some(content),
..
},
finish_reason,
..
} = &chunk.choices[0]
{
assistant_output.push_str(content);
print!("{}", content);
toks += 3usize; // NOTE: we send toks every 3.
io::stdout().flush().unwrap();
if finish_reason.is_some() {
if matches!(finish_reason.as_ref().unwrap().as_str(), "length") {
print!("...");
}
break;
}
break;
}
}
Response::InternalError(e) => {
@@ -408,16 +418,26 @@ async fn vision_interactive_mode(mistralrs: Arc<MistralRs>, throughput: bool) {
while let Some(resp) = rx.recv().await {
match resp {
Response::Chunk(chunk) => {
let choice = &chunk.choices[0];
assistant_output.push_str(&choice.delta.content);
print!("{}", choice.delta.content);
toks += 3usize; // NOTE: we send toks every 3.
io::stdout().flush().unwrap();
if choice.finish_reason.is_some() {
if matches!(choice.finish_reason.as_ref().unwrap().as_str(), "length") {
print!("...");
if let ChunkChoice {
delta:
Delta {
content: Some(content),
..
},
finish_reason,
..
} = &chunk.choices[0]
{
assistant_output.push_str(content);
print!("{}", content);
toks += 3usize; // NOTE: we send toks every 3.
io::stdout().flush().unwrap();
if finish_reason.is_some() {
if matches!(finish_reason.as_ref().unwrap().as_str(), "length") {
print!("...");
}
break;
}
break;
}
}
Response::InternalError(e) => {
18 changes: 14 additions & 4 deletions mistralrs/examples/simple_stream/main.rs
@@ -1,7 +1,7 @@
use anyhow::Result;
use mistralrs::{
IsqType, PagedAttentionMetaBuilder, RequestBuilder, Response, TextMessageRole, TextMessages,
TextModelBuilder,
ChatCompletionChunkResponse, ChunkChoice, Delta, IsqType, PagedAttentionMetaBuilder,
RequestBuilder, Response, TextMessageRole, TextMessages, TextModelBuilder,
};
use std::io::Write;

@@ -44,8 +44,18 @@ async fn main() -> Result<()> {
let lock = stdout.lock();
let mut buf = std::io::BufWriter::new(lock);
while let Some(chunk) = stream.next().await {
if let Response::Chunk(chunk) = chunk {
buf.write_all(chunk.choices[0].delta.content.as_bytes())?;
if let Response::Chunk(ChatCompletionChunkResponse { choices, .. }) = chunk {
if let Some(ChunkChoice {
delta:
Delta {
content: Some(content),
..
},
..
}) = choices.first()
{
buf.write_all(content.as_bytes())?;
};
} else {
// Handle errors
}
74 changes: 42 additions & 32 deletions mistralrs/src/lib.rs
@@ -52,41 +52,51 @@
//!
//! ## Streaming example
//! ```no_run
//! use anyhow::Result;
//! use mistralrs::{
//! IsqType, PagedAttentionMetaBuilder, TextMessageRole, TextMessages, TextModelBuilder, Response
//! };
//! use anyhow::Result;
//! use mistralrs::{
//! IsqType, PagedAttentionMetaBuilder, Response, TextMessageRole, TextMessages,
//! TextModelBuilder,
//! };
//! use mistralrs_core::{ChatCompletionChunkResponse, ChunkChoice, Delta};
//!
//! #[tokio::main]
//! async fn main() -> Result<()> {
//! let model = TextModelBuilder::new("microsoft/Phi-3.5-mini-instruct".to_string())
//! .with_isq(IsqType::Q8_0)
//! .with_logging()
//! .with_paged_attn(|| PagedAttentionMetaBuilder::default().build())?
//! .build()
//! .await?;
//! #[tokio::main]
//! async fn main() -> Result<()> {
//! let model = TextModelBuilder::new("microsoft/Phi-3.5-mini-instruct".to_string())
//! .with_isq(IsqType::Q8_0)
//! .with_logging()
//! .with_paged_attn(|| PagedAttentionMetaBuilder::default().build())?
//! .build()
//! .await?;
//!
//! let messages = TextMessages::new()
//! .add_message(
//! TextMessageRole::System,
//! "You are an AI agent with a specialty in programming.",
//! )
//! .add_message(
//! TextMessageRole::User,
//! "Hello! How are you? Please write generic binary search function in Rust.",
//! );
//!
//! let mut stream = model.stream_chat_request(messages).await?;
//!
//! while let Some(chunk) = stream.next().await {
//! if let Response::Chunk(chunk) = chunk{
//! print!("{}", chunk.choices[0].delta.content);
//! }
//! // Handle the error cases.
//! let messages = TextMessages::new()
//! .add_message(
//! TextMessageRole::System,
//! "You are an AI agent with a specialty in programming.",
//! )
//! .add_message(
//! TextMessageRole::User,
//! "Hello! How are you? Please write generic binary search function in Rust.",
//! );
//!
//! }
//! Ok(())
//! }
//! let mut stream = model.stream_chat_request(messages).await?;

//! while let Some(chunk) = stream.next().await {
//! if let Response::Chunk(ChatCompletionChunkResponse { choices, .. }) = chunk {
//! if let Some(ChunkChoice {
//! delta:
//! Delta {
//! content: Some(content),
//! ..
//! },
//! ..
//! }) = choices.first()
//! {
//! print!("content");
//! };
//! }
//! }
//! Ok(())
//! }
//! ```

mod anymoe;