From 73a1b2b3efc01f600be3378072b895d2af578778 Mon Sep 17 00:00:00 2001 From: Cliff Dyer Date: Fri, 26 Jul 2024 12:04:01 -0400 Subject: [PATCH 1/2] Expose query parameters from StreamRequestBuilder, and demonstrate double-escaping problem with a failing test. --- .../transcription/websocket/simple_stream.rs | 2 +- src/common/options.rs | 16 +- src/lib.rs | 8 +- src/listen/websocket.rs | 199 ++++++++++++------ 4 files changed, 146 insertions(+), 79 deletions(-) diff --git a/examples/transcription/websocket/simple_stream.rs b/examples/transcription/websocket/simple_stream.rs index f3040af3..51333153 100644 --- a/examples/transcription/websocket/simple_stream.rs +++ b/examples/transcription/websocket/simple_stream.rs @@ -22,7 +22,7 @@ async fn main() -> Result<(), DeepgramError> { let mut results = dg .transcription() - .stream_request_with_options(Some(&options)) + .stream_request_with_options(options) .keep_alive() .encoding(Encoding::Linear16) .sample_rate(44100) diff --git a/src/common/options.rs b/src/common/options.rs index bbf0ca4d..0bf9bb8e 100644 --- a/src/common/options.rs +++ b/src/common/options.rs @@ -4,7 +4,7 @@ //! //! 
[api]: https://developers.deepgram.com/documentation/features/ -use std::collections::HashMap; +use std::{collections::HashMap, fmt}; use serde::{ser::SerializeSeq, Deserialize, Serialize}; @@ -172,14 +172,12 @@ pub enum Endpointing { CustomDurationMs(u32), } -/// Endpointing impl -impl Endpointing { - #[allow(missing_docs)] - pub fn to_str(&self) -> String { +impl fmt::Display for Endpointing { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - Endpointing::Enabled => "true".to_string(), - Endpointing::Disabled => "false".to_string(), - Endpointing::CustomDurationMs(value) => value.to_string(), + Endpointing::Enabled => f.write_str("true"), + Endpointing::Disabled => f.write_str("false"), + Endpointing::CustomDurationMs(value) => f.write_fmt(format_args!("{value}")), } } } @@ -674,7 +672,7 @@ impl Options { /// ``` /// pub fn urlencoded(&self) -> Result { - serde_urlencoded::to_string(SerializableOptions(self)) + serde_urlencoded::to_string(SerializableOptions::from(self)) } } diff --git a/src/lib.rs b/src/lib.rs index 699bde6f..fce3c6dc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -145,8 +145,12 @@ pub enum DeepgramError { WsError(#[from] tungstenite::Error), /// Something went wrong during serialization/deserialization. - #[error("Something went wrong during serialization/deserialization: {0}")] - SerdeError(#[from] serde_json::Error), + #[error("Something went wrong during json serialization/deserialization: {0}")] + JsonError(#[from] serde_json::Error), + + /// Something went wrong while serializing options into a URL query string.
+ #[error("Something went wrong during query serialization: {0}")] + UrlencodedError(#[from] serde_urlencoded::ser::Error), } #[cfg_attr(not(feature = "listen"), allow(unused))] diff --git a/src/listen/websocket.rs b/src/listen/websocket.rs index 51cff716..71ad6e6e 100644 --- a/src/listen/websocket.rs +++ b/src/listen/websocket.rs @@ -34,7 +34,7 @@ use tokio_util::io::ReaderStream; use tungstenite::handshake::client; use url::Url; -use crate::common::options::{Encoding, Endpointing, Options, SerializableOptions}; +use crate::common::options::{Encoding, Endpointing, Options}; use crate::{Deepgram, DeepgramError, Result, Transcription}; static LIVE_LISTEN_URL_PATH: &str = "v1/listen"; @@ -42,7 +42,7 @@ static LIVE_LISTEN_URL_PATH: &str = "v1/listen"; #[derive(Debug)] pub struct StreamRequestBuilder<'a> { config: &'a Deepgram, - options: Option<&'a Options>, + options: Options, encoding: Option, sample_rate: Option, channels: Option, @@ -65,13 +65,10 @@ struct FileChunker { impl Transcription<'_> { pub fn stream_request(&self) -> StreamRequestBuilder<'_> { - self.stream_request_with_options(None) + self.stream_request_with_options(Options::builder().build()) } - pub fn stream_request_with_options<'a>( - &'a self, - options: Option<&'a Options>, - ) -> StreamRequestBuilder<'a> { + pub fn stream_request_with_options(&self, options: Options) -> StreamRequestBuilder<'_> { StreamRequestBuilder { config: self.0, options, @@ -142,10 +139,112 @@ impl Stream for FileChunker { } impl<'a> StreamRequestBuilder<'a> { - pub fn keep_alive(mut self) -> Self { - self.keep_alive = Some(true); + /// Return the options in urlencoded format. If serialization would + /// fail, this will also return an error. + /// + /// This is intended primarily to help with debugging API requests. 
+ /// + /// ``` + /// use deepgram::{ + /// Deepgram, + /// DeepgramError, + /// common::options::{ + /// DetectLanguage, + /// Encoding, + /// Model, + /// Options, + /// }, + /// }; + /// # let mut need_token = std::env::var("DEEPGRAM_API_TOKEN").is_err(); + /// # if need_token { + /// # std::env::set_var("DEEPGRAM_API_TOKEN", "abc") + /// # } + /// let dg = Deepgram::new(std::env::var("DEEPGRAM_API_TOKEN").unwrap()); + /// let transcription = dg.transcription(); + /// let options = Options::builder() + /// .model(Model::Nova2) + /// .detect_language(DetectLanguage::Enabled) + /// .build(); + /// let builder = transcription + /// .stream_request_with_options::>>( + /// options, + /// ) + /// .no_delay(true); + /// + /// # if need_token { + /// # std::env::remove_var("DEEPGRAM_API_TOKEN"); + /// # } + /// + /// assert_eq!(&builder.urlencoded().unwrap(), "model=nova-2&detect_language=true&no_delay=true") + /// ``` + /// + pub fn urlencoded(&self) -> std::result::Result { + Ok(self.as_url()?.query().unwrap_or_default().to_string()) + } - self + fn as_url(&self) -> std::result::Result { + // Destructuring ensures we don't miss new fields if they get added + let Self { + config: _, + keep_alive: _, + options, + encoding, + sample_rate, + channels, + endpointing, + utterance_end_ms, + interim_results, + no_delay, + vad_events, + stream_url, + } = self; + + let mut url = stream_url.clone(); + { + let mut pairs = url.query_pairs_mut(); + + // Add standard pre-recorded options + let query_string = options.urlencoded().unwrap(); + let query_pairs: Vec<(Cow, Cow)> = query_string + .split('&') + .map(|s| { + let mut split = s.splitn(2, '='); + ( + Cow::from(split.next().unwrap_or_default()), + Cow::from(split.next().unwrap_or_default()), + ) + }) + .collect(); + + for (key, value) in query_pairs { + pairs.append_pair(&key, &value); + } + if let Some(encoding) = encoding { + pairs.append_pair("encoding", encoding.as_str()); + } + if let Some(sample_rate) = sample_rate { + 
pairs.append_pair("sample_rate", &sample_rate.to_string()); + } + if let Some(channels) = channels { + pairs.append_pair("channels", &channels.to_string()); + } + if let Some(endpointing) = endpointing { + pairs.append_pair("endpointing", &endpointing.to_string()); + } + if let Some(utterance_end_ms) = utterance_end_ms { + pairs.append_pair("utterance_end_ms", &utterance_end_ms.to_string()); + } + if let Some(interim_results) = interim_results { + pairs.append_pair("interim_results", &interim_results.to_string()); + } + if let Some(no_delay) = no_delay { + pairs.append_pair("no_delay", &no_delay.to_string()); + } + if let Some(vad_events) = vad_events { + pairs.append_pair("vad_events", &vad_events.to_string()); + } + } + Ok(url) } pub fn encoding(mut self, encoding: Encoding) -> Self { @@ -195,13 +294,12 @@ impl<'a> StreamRequestBuilder<'a> { self } -} -#[derive(Debug)] -pub struct StreamRequest<'a, S, E> { - stream: S, - builder: StreamRequestBuilder<'a>, - _err: PhantomData, + pub fn keep_alive(mut self) -> Self { + self.keep_alive = Some(true); + + self + } } impl<'a> StreamRequestBuilder<'a> { @@ -237,9 +335,11 @@ impl<'a> StreamRequestBuilder<'a> { } } -fn options_to_query_string(options: &Options) -> String { - let serialized_options = SerializableOptions::from(options); - serde_urlencoded::to_string(serialized_options).unwrap_or_default() +#[derive(Debug)] +pub struct StreamRequest<'a, S, E> { + stream: S, + builder: StreamRequestBuilder<'a>, + _err: PhantomData, } impl StreamRequest<'_, S, E> @@ -248,54 +348,7 @@ where E: Error + Debug + Send + Unpin + 'static, { pub async fn start(self) -> Result>> { - let mut url = self.builder.stream_url; - { - let mut pairs = url.query_pairs_mut(); - - // Add standard pre-recorded options - if let Some(options) = &self.builder.options { - let query_string = options_to_query_string(options); - let query_pairs: Vec<(Cow, Cow)> = query_string - .split('&') - .map(|s| { - let mut split = s.splitn(2, '='); - ( - 
Cow::from(split.next().unwrap_or_default()), - Cow::from(split.next().unwrap_or_default()), - ) - }) - .collect(); - - for (key, value) in query_pairs { - pairs.append_pair(&key, &value); - } - } - if let Some(encoding) = &self.builder.encoding { - pairs.append_pair("encoding", encoding.as_str()); - } - if let Some(sample_rate) = self.builder.sample_rate { - pairs.append_pair("sample_rate", &sample_rate.to_string()); - } - if let Some(channels) = self.builder.channels { - pairs.append_pair("channels", &channels.to_string()); - } - if let Some(endpointing) = self.builder.endpointing { - pairs.append_pair("endpointing", &endpointing.to_str()); - } - if let Some(utterance_end_ms) = self.builder.utterance_end_ms { - pairs.append_pair("utterance_end_ms", &utterance_end_ms.to_string()); - } - if let Some(interim_results) = self.builder.interim_results { - pairs.append_pair("interim_results", &interim_results.to_string()); - } - if let Some(no_delay) = self.builder.no_delay { - pairs.append_pair("no_delay", &no_delay.to_string()); - } - if let Some(vad_events) = self.builder.vad_events { - pairs.append_pair("vad_events", &vad_events.to_string()); - } - } - + let url = self.builder.as_url()?; let mut source = self .stream .map(|res| res.map(|bytes| Message::binary(Vec::from(bytes.as_ref())))); @@ -397,6 +450,8 @@ where #[cfg(test)] mod tests { + use crate::common::options::Options; + #[test] fn test_stream_url() { let dg = crate::Deepgram::new("token"); @@ -414,4 +469,14 @@ mod tests { "ws://localhost:8080/v1/listen", ); } + + #[test] + fn query_escaping() { + let dg = crate::Deepgram::new("token"); + let opts = Options::builder().custom_topics(["A&R"]).build(); + let transcription = dg.transcription(); + let builder = transcription.stream_request_with_options(opts.clone()); + // Currently fails because A&R is double escaped in the streaming URL + assert_eq!(builder.urlencoded().unwrap(), opts.urlencoded().unwrap()) + } } From e49900e8a9b7f35f76d04af32053c48eaf094292 Mon 
Sep 17 00:00:00 2001 From: Cliff Dyer Date: Fri, 26 Jul 2024 12:24:05 -0400 Subject: [PATCH 2/2] handle escaped url params properly --- src/listen/websocket.rs | 84 +++++++++++++++++++++-------------------- 1 file changed, 43 insertions(+), 41 deletions(-) diff --git a/src/listen/websocket.rs b/src/listen/websocket.rs index 71ad6e6e..91e58461 100644 --- a/src/listen/websocket.rs +++ b/src/listen/websocket.rs @@ -8,40 +8,45 @@ //! //! [api]: https://developers.deepgram.com/api-reference/#transcription-streaming -use crate::common::stream_response::StreamResponse; -use serde_urlencoded; -use std::borrow::Cow; -use std::error::Error; -use std::fmt::Debug; -use std::marker::PhantomData; -use std::path::Path; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; -use std::time::Duration; +use std::{ + error::Error, + fmt::Debug, + marker::PhantomData, + path::Path, + pin::Pin, + sync::Arc, + task::{Context, Poll}, + time::Duration, +}; use bytes::{Bytes, BytesMut}; -use futures::channel::mpsc::{self, Receiver}; -use futures::stream::StreamExt; -use futures::{SinkExt, Stream}; +use futures::{ + channel::mpsc::{self, Receiver}, + stream::StreamExt, + SinkExt, Stream, +}; use http::Request; use pin_project::pin_project; -use tokio::fs::File; -use tokio::sync::Mutex; -use tokio::time; +use serde_urlencoded; +use tokio::{fs::File, sync::Mutex, time}; use tokio_tungstenite::tungstenite::protocol::Message; use tokio_util::io::ReaderStream; use tungstenite::handshake::client; use url::Url; -use crate::common::options::{Encoding, Endpointing, Options}; -use crate::{Deepgram, DeepgramError, Result, Transcription}; +use crate::{ + common::{ + options::{Encoding, Endpointing, Options}, + stream_response::StreamResponse, + }, + Deepgram, DeepgramError, Result, Transcription, +}; static LIVE_LISTEN_URL_PATH: &str = "v1/listen"; #[derive(Debug)] pub struct StreamRequestBuilder<'a> { - config: &'a Deepgram, + deepgram: &'a Deepgram, options: Options, encoding: 
Option, sample_rate: Option, @@ -70,7 +75,7 @@ impl Transcription<'_> { pub fn stream_request_with_options(&self, options: Options) -> StreamRequestBuilder<'_> { StreamRequestBuilder { - config: self.0, + deepgram: self.0, options, encoding: None, sample_rate: None, @@ -166,7 +171,7 @@ impl<'a> StreamRequestBuilder<'a> { /// .detect_language(DetectLanguage::Enabled) /// .build(); /// let builder = transcription - /// .stream_request_with_options::>>( + /// .stream_request_with_options( /// options, /// ) /// .no_delay(true); @@ -185,7 +190,7 @@ impl<'a> StreamRequestBuilder<'a> { fn as_url(&self) -> std::result::Result { // Destructuring ensures we don't miss new fields if they get added let Self { - config: _, + deepgram: _, keep_alive: _, options, encoding, @@ -203,22 +208,19 @@ impl<'a> StreamRequestBuilder<'a> { { let mut pairs = url.query_pairs_mut(); - // Add standard pre-recorded options - let query_string = options.urlencoded().unwrap(); - let query_pairs: Vec<(Cow, Cow)> = query_string - .split('&') - .map(|s| { - let mut split = s.splitn(2, '='); - ( - Cow::from(split.next().unwrap_or_default()), - Cow::from(split.next().unwrap_or_default()), - ) - }) - .collect(); - - for (key, value) in query_pairs { - pairs.append_pair(&key, &value); - } + // Add standard pre-recorded options. + // + // Here we serialize the options and then deserialize + // in order to avoid duplicating serialization logic. + // + // TODO: We should be able to lean on serde more + // to avoid multiple serialization rounds. + pairs.extend_pairs( + serde_urlencoded::from_str::<Vec<(String, String)>>(&options.urlencoded()?)
+ .expect("constructed query string can be deserialized"), + ); + + // Add streaming-specific options if let Some(encoding) = encoding { pairs.append_pair("encoding", encoding.as_str()); } @@ -244,6 +246,7 @@ impl<'a> StreamRequestBuilder<'a> { pairs.append_pair("vad_events", &vad_events.to_string()); } } + Ok(url) } @@ -363,7 +366,7 @@ where .header("upgrade", "websocket") .header("sec-websocket-version", "13"); - let builder = if let Some(api_key) = self.builder.config.api_key.as_deref() { + let builder = if let Some(api_key) = self.builder.deepgram.api_key.as_deref() { builder.header("authorization", format!("token {}", api_key)) } else { builder @@ -476,7 +479,6 @@ mod tests { let opts = Options::builder().custom_topics(["A&R"]).build(); let transcription = dg.transcription(); let builder = transcription.stream_request_with_options(opts.clone()); - // Currently fails because A&R is double escaped in the streaming URL assert_eq!(builder.urlencoded().unwrap(), opts.urlencoded().unwrap()) } }