Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix url handling #81

Merged
merged 2 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/transcription/websocket/simple_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ async fn main() -> Result<(), DeepgramError> {

let mut results = dg
.transcription()
.stream_request_with_options(Some(&options))
.stream_request_with_options(options)
.keep_alive()
.encoding(Encoding::Linear16)
.sample_rate(44100)
Expand Down
16 changes: 7 additions & 9 deletions src/common/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
//!
//! [api]: https://developers.deepgram.com/documentation/features/

use std::collections::HashMap;
use std::{collections::HashMap, fmt};

use serde::{ser::SerializeSeq, Deserialize, Serialize};

Expand Down Expand Up @@ -172,14 +172,12 @@ pub enum Endpointing {
CustomDurationMs(u32),
}

/// Endpointing impl
impl Endpointing {
#[allow(missing_docs)]
pub fn to_str(&self) -> String {
impl fmt::Display for Endpointing {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Endpointing::Enabled => "true".to_string(),
Endpointing::Disabled => "false".to_string(),
Endpointing::CustomDurationMs(value) => value.to_string(),
Endpointing::Enabled => f.write_str("true"),
Endpointing::Disabled => f.write_str("false"),
Endpointing::CustomDurationMs(value) => f.write_fmt(format_args!("{value}")),
}
}
}
Expand Down Expand Up @@ -674,7 +672,7 @@ impl Options {
/// ```
///
pub fn urlencoded(&self) -> Result<String, serde_urlencoded::ser::Error> {
serde_urlencoded::to_string(SerializableOptions(self))
serde_urlencoded::to_string(SerializableOptions::from(self))
}
}

Expand Down
8 changes: 6 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,12 @@ pub enum DeepgramError {
WsError(#[from] tungstenite::Error),

/// Something went wrong during serialization/deserialization.
#[error("Something went wrong during serialization/deserialization: {0}")]
SerdeError(#[from] serde_json::Error),
#[error("Something went wrong during json serialization/deserialization: {0}")]
JsonError(#[from] serde_json::Error),

/// Something went wrong while serializing options into a URL query string.
#[error("Something went wrong during query serialization: {0}")]
UrlencodedError(#[from] serde_urlencoded::ser::Error),
}

#[cfg_attr(not(feature = "listen"), allow(unused))]
Expand Down
243 changes: 155 additions & 88 deletions src/listen/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,41 +8,46 @@
//!
//! [api]: https://developers.deepgram.com/api-reference/#transcription-streaming

use crate::common::stream_response::StreamResponse;
use serde_urlencoded;
use std::borrow::Cow;
use std::error::Error;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::path::Path;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use std::{
error::Error,
fmt::Debug,
marker::PhantomData,
path::Path,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::Duration,
};

use bytes::{Bytes, BytesMut};
use futures::channel::mpsc::{self, Receiver};
use futures::stream::StreamExt;
use futures::{SinkExt, Stream};
use futures::{
channel::mpsc::{self, Receiver},
stream::StreamExt,
SinkExt, Stream,
};
use http::Request;
use pin_project::pin_project;
use tokio::fs::File;
use tokio::sync::Mutex;
use tokio::time;
use serde_urlencoded;
use tokio::{fs::File, sync::Mutex, time};
use tokio_tungstenite::tungstenite::protocol::Message;
use tokio_util::io::ReaderStream;
use tungstenite::handshake::client;
use url::Url;

use crate::common::options::{Encoding, Endpointing, Options, SerializableOptions};
use crate::{Deepgram, DeepgramError, Result, Transcription};
use crate::{
common::{
options::{Encoding, Endpointing, Options},
stream_response::StreamResponse,
},
Deepgram, DeepgramError, Result, Transcription,
};

static LIVE_LISTEN_URL_PATH: &str = "v1/listen";

#[derive(Debug)]
pub struct StreamRequestBuilder<'a> {
config: &'a Deepgram,
options: Option<&'a Options>,
deepgram: &'a Deepgram,
options: Options,
encoding: Option<Encoding>,
sample_rate: Option<u32>,
channels: Option<u16>,
Expand All @@ -65,15 +70,12 @@ struct FileChunker {

impl Transcription<'_> {
pub fn stream_request(&self) -> StreamRequestBuilder<'_> {
self.stream_request_with_options(None)
self.stream_request_with_options(Options::builder().build())
}

pub fn stream_request_with_options<'a>(
&'a self,
options: Option<&'a Options>,
) -> StreamRequestBuilder<'a> {
pub fn stream_request_with_options(&self, options: Options) -> StreamRequestBuilder<'_> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Callout: I changed this public interface to take a crate::common::options::Options instead of an Option<&crate::common::options::Options>. The name says it takes options, so only sometimes taking them is weird, especially since we also provide a version that doesn't take them, and the same behavior is available by passing an empty Options.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just checked: This function was not available in 0.5.0, so this isn't a breaking change.

StreamRequestBuilder {
config: self.0,
deepgram: self.0,
Comment on lines -76 to +78
Copy link
Contributor Author

@jcdyer jcdyer Jul 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed this because I was getting confused about what a config is versus what options are. It's a strictly internal change, not visible to users.

options,
encoding: None,
sample_rate: None,
Expand Down Expand Up @@ -142,10 +144,110 @@ impl Stream for FileChunker {
}

impl<'a> StreamRequestBuilder<'a> {
pub fn keep_alive(mut self) -> Self {
self.keep_alive = Some(true);
/// Return the options in urlencoded format. If serialization would
/// fail, this will also return an error.
///
/// This is intended primarily to help with debugging API requests.
///
/// ```
/// use deepgram::{
/// Deepgram,
/// DeepgramError,
/// common::options::{
/// DetectLanguage,
/// Encoding,
/// Model,
/// Options,
/// },
/// };
/// # let mut need_token = std::env::var("DEEPGRAM_API_TOKEN").is_err();
/// # if need_token {
/// # std::env::set_var("DEEPGRAM_API_TOKEN", "abc")
/// # }
Comment on lines +163 to +166
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't love this pattern. The #s ensure it doesn't show up in the docs, but it's still kind of ugly.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the intent of those lines?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had a hard time fixing these imports in the examples in VSCode when they broke the build, since there is no autocomplete for them.

/// let dg = Deepgram::new(std::env::var("DEEPGRAM_API_TOKEN").unwrap());
/// let transcription = dg.transcription();
/// let options = Options::builder()
/// .model(Model::Nova2)
/// .detect_language(DetectLanguage::Enabled)
/// .build();
/// let builder = transcription
/// .stream_request_with_options(
/// options,
/// )
/// .no_delay(true);
///
/// # if need_token {
/// # std::env::remove_var("DEEPGRAM_API_TOKEN");
/// # }
///
/// assert_eq!(&builder.urlencoded().unwrap(), "model=nova-2&detect_language=true&no_delay=true")
/// ```
///
pub fn urlencoded(&self) -> std::result::Result<String, serde_urlencoded::ser::Error> {
Ok(self.as_url()?.query().unwrap_or_default().to_string())
}

self
/// Build the full request URL: a copy of `stream_url` with the serialized
/// `options` and every streaming-specific setting that is `Some` appended
/// as query pairs.
///
/// `deepgram` and `keep_alive` are deliberately ignored here; they do not
/// contribute to the query string.
///
/// # Errors
///
/// Returns the `serde_urlencoded::ser::Error` produced by
/// `options.urlencoded()` if the options cannot be serialized.
///
/// # Panics
///
/// Panics if the string produced by `options.urlencoded()` cannot be
/// deserialized back into key/value pairs.
fn as_url(&self) -> std::result::Result<Url, serde_urlencoded::ser::Error> {
// Destructuring ensures we don't miss new fields if they get added
let Self {
deepgram: _,
keep_alive: _,
options,
encoding,
sample_rate,
channels,
endpointing,
utterance_end_ms,
interim_results,
no_delay,
vad_events,
stream_url,
} = self;

let mut url = stream_url.clone();
// Inner scope: `pairs` mutably borrows `url`, and that borrow has to end
// before `url` is returned below.
{
let mut pairs = url.query_pairs_mut();

// Add standard pre-recorded options.
//
// Here we serialize the options and then deserialize
// in order to avoid duplicating serialization logic.
//
// Deserializing (rather than splitting the string on `&` and `=`)
// yields decoded keys and values, so they are escaped only once
// when `pairs` appends them.
//
// TODO: We should be able to lean on the serde more
// to avoid multiple serialization rounds.
pairs.extend_pairs(
serde_urlencoded::from_str::<Vec<(String, String)>>(&options.urlencoded()?)
.expect("constructed query string can be deserialized"),
);

// Add streaming-specific options
// Each setting is appended only when it was explicitly set on the builder.
if let Some(encoding) = encoding {
pairs.append_pair("encoding", encoding.as_str());
}
if let Some(sample_rate) = sample_rate {
pairs.append_pair("sample_rate", &sample_rate.to_string());
}
if let Some(channels) = channels {
pairs.append_pair("channels", &channels.to_string());
}
if let Some(endpointing) = endpointing {
pairs.append_pair("endpointing", &endpointing.to_string());
}
if let Some(utterance_end_ms) = utterance_end_ms {
pairs.append_pair("utterance_end_ms", &utterance_end_ms.to_string());
}
if let Some(interim_results) = interim_results {
pairs.append_pair("interim_results", &interim_results.to_string());
}
if let Some(no_delay) = no_delay {
pairs.append_pair("no_delay", &no_delay.to_string());
}
if let Some(vad_events) = vad_events {
pairs.append_pair("vad_events", &vad_events.to_string());
}
}

Ok(url)
}

pub fn encoding(mut self, encoding: Encoding) -> Self {
Expand Down Expand Up @@ -195,13 +297,12 @@ impl<'a> StreamRequestBuilder<'a> {

self
}
}

#[derive(Debug)]
pub struct StreamRequest<'a, S, E> {
stream: S,
builder: StreamRequestBuilder<'a>,
_err: PhantomData<E>,
pub fn keep_alive(mut self) -> Self {
self.keep_alive = Some(true);

self
}
}

impl<'a> StreamRequestBuilder<'a> {
Expand Down Expand Up @@ -237,9 +338,11 @@ impl<'a> StreamRequestBuilder<'a> {
}
}

fn options_to_query_string(options: &Options) -> String {
let serialized_options = SerializableOptions::from(options);
serde_urlencoded::to_string(serialized_options).unwrap_or_default()
#[derive(Debug)]
pub struct StreamRequest<'a, S, E> {
stream: S,
builder: StreamRequestBuilder<'a>,
_err: PhantomData<E>,
}

impl<S, E> StreamRequest<'_, S, E>
Expand All @@ -248,54 +351,7 @@ where
E: Error + Debug + Send + Unpin + 'static,
{
pub async fn start(self) -> Result<Receiver<Result<StreamResponse>>> {
let mut url = self.builder.stream_url;
{
let mut pairs = url.query_pairs_mut();

// Add standard pre-recorded options
if let Some(options) = &self.builder.options {
let query_string = options_to_query_string(options);
let query_pairs: Vec<(Cow<str>, Cow<str>)> = query_string
.split('&')
.map(|s| {
let mut split = s.splitn(2, '=');
(
Cow::from(split.next().unwrap_or_default()),
Cow::from(split.next().unwrap_or_default()),
)
})
.collect();

for (key, value) in query_pairs {
pairs.append_pair(&key, &value);
}
}
Comment on lines -255 to -272
Copy link
Contributor Author

@jcdyer jcdyer Jul 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was the original logic which double-quoted inputs. So if the original option contained an &, L257 would quote it as %26 (as it should be), and then after L270, the % in %26 would get quoted to %25, and the final result would be %2526. The new logic unquotes parameters instead of just splitting on separators.

if let Some(encoding) = &self.builder.encoding {
pairs.append_pair("encoding", encoding.as_str());
}
if let Some(sample_rate) = self.builder.sample_rate {
pairs.append_pair("sample_rate", &sample_rate.to_string());
}
if let Some(channels) = self.builder.channels {
pairs.append_pair("channels", &channels.to_string());
}
if let Some(endpointing) = self.builder.endpointing {
pairs.append_pair("endpointing", &endpointing.to_str());
}
if let Some(utterance_end_ms) = self.builder.utterance_end_ms {
pairs.append_pair("utterance_end_ms", &utterance_end_ms.to_string());
}
if let Some(interim_results) = self.builder.interim_results {
pairs.append_pair("interim_results", &interim_results.to_string());
}
if let Some(no_delay) = self.builder.no_delay {
pairs.append_pair("no_delay", &no_delay.to_string());
}
if let Some(vad_events) = self.builder.vad_events {
pairs.append_pair("vad_events", &vad_events.to_string());
}
}

let url = self.builder.as_url()?;
let mut source = self
.stream
.map(|res| res.map(|bytes| Message::binary(Vec::from(bytes.as_ref()))));
Expand All @@ -310,7 +366,7 @@ where
.header("upgrade", "websocket")
.header("sec-websocket-version", "13");

let builder = if let Some(api_key) = self.builder.config.api_key.as_deref() {
let builder = if let Some(api_key) = self.builder.deepgram.api_key.as_deref() {
builder.header("authorization", format!("token {}", api_key))
} else {
builder
Expand Down Expand Up @@ -397,6 +453,8 @@ where

#[cfg(test)]
mod tests {
use crate::common::options::Options;

#[test]
fn test_stream_url() {
let dg = crate::Deepgram::new("token");
Expand All @@ -414,4 +472,13 @@ mod tests {
"ws://localhost:8080/v1/listen",
);
}

// Regression check for query escaping: an option value containing `&`
// must appear in the builder's query string exactly as
// `Options::urlencoded` encodes it (i.e. it is not escaped a second time).
#[test]
fn query_escaping() {
let dg = crate::Deepgram::new("token");
let opts = Options::builder().custom_topics(["A&R"]).build();
let transcription = dg.transcription();
// No streaming-specific settings are applied, so the builder's query
// string consists solely of the serialized options.
let builder = transcription.stream_request_with_options(opts.clone());
assert_eq!(builder.urlencoded().unwrap(), opts.urlencoded().unwrap())
}
}
Loading