Skip to content

Commit

Permalink
handle escaped url params properly
Browse files Browse the repository at this point in the history
  • Loading branch information
jcdyer committed Jul 26, 2024
1 parent 7f6c3f6 commit 22b531a
Showing 1 changed file with 44 additions and 31 deletions.
75 changes: 44 additions & 31 deletions src/listen/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@
//!
//! [api]: https://developers.deepgram.com/api-reference/#transcription-streaming
use crate::common::stream_response::StreamResponse;
use serde_urlencoded;
use std::borrow::Cow;
use std::path::Path;
use std::pin::Pin;
use std::sync::Arc;
Expand All @@ -23,6 +20,7 @@ use futures::stream::StreamExt;
use futures::{SinkExt, Stream};
use http::Request;
use pin_project::pin_project;
use serde_urlencoded;
use tokio::fs::File;
use tokio::sync::Mutex;
use tokio::time;
Expand All @@ -32,6 +30,7 @@ use tungstenite::handshake::client;
use url::Url;

use crate::common::options::{Encoding, Endpointing, Options};
use crate::common::stream_response::StreamResponse;
use crate::{Deepgram, DeepgramError, Result, Transcription};

static LIVE_LISTEN_URL_PATH: &str = "v1/listen";
Expand All @@ -41,7 +40,7 @@ pub struct StreamRequestBuilder<'a, S, E>
where
S: Stream<Item = std::result::Result<Bytes, E>>,
{
config: &'a Deepgram,
deepgram: &'a Deepgram,
options: Options,
source: Option<S>,
encoding: Option<Encoding>,
Expand Down Expand Up @@ -80,7 +79,7 @@ impl Transcription<'_> {
S: Stream<Item = std::result::Result<Bytes, E>>,
{
StreamRequestBuilder {
config: self.0,
deepgram: self.0,
options,
source: None,
encoding: None,
Expand Down Expand Up @@ -154,52 +153,66 @@ where
S: Stream<Item = std::result::Result<Bytes, E>>,
{
pub fn as_url(&self) -> std::result::Result<Url, serde_urlencoded::ser::Error> {
// This unwrap is safe because we're parsing a static.
let mut url = self.stream_url.clone();
// Destructuring ensures we don't miss new fields if they get added
let Self {
deepgram: _,
source: _,
keep_alive: _,
options,
encoding,
sample_rate,
channels,
endpointing,
utterance_end_ms,
interim_results,
no_delay,
vad_events,
stream_url,
} = self;

let mut url = stream_url.clone();
{
let mut pairs = url.query_pairs_mut();

// Add standard pre-recorded options
let query_string = self.options.urlencoded().unwrap();
let query_pairs: Vec<(Cow<str>, Cow<str>)> = query_string
.split('&')
.map(|s| {
let mut split = s.splitn(2, '=');
(
Cow::from(split.next().unwrap_or_default()),
Cow::from(split.next().unwrap_or_default()),
)
})
.collect();
// Add standard pre-recorded options.
//
// Here we serialize the options and then deserialize
// in order to avoid duplicating serialization logic.
//
// TODO: We should be able to lean on the serde more
// to avoid multiple serialization rounds.
pairs.extend_pairs(
serde_urlencoded::from_str::<Vec<(String, String)>>(&options.urlencoded()?)
.expect("constructed query string can be deserialized"),
);

for (key, value) in query_pairs {
pairs.append_pair(&key, &value);
}
if let Some(encoding) = &self.encoding {
// Add streaming-specific options
if let Some(encoding) = encoding {
pairs.append_pair("encoding", encoding.as_str());
}
if let Some(sample_rate) = self.sample_rate {
if let Some(sample_rate) = sample_rate {
pairs.append_pair("sample_rate", &sample_rate.to_string());
}
if let Some(channels) = self.channels {
if let Some(channels) = channels {
pairs.append_pair("channels", &channels.to_string());
}
if let Some(endpointing) = self.endpointing {
if let Some(endpointing) = endpointing {
pairs.append_pair("endpointing", &endpointing.to_str());
}
if let Some(utterance_end_ms) = self.utterance_end_ms {
if let Some(utterance_end_ms) = utterance_end_ms {
pairs.append_pair("utterance_end_ms", &utterance_end_ms.to_string());
}
if let Some(interim_results) = self.interim_results {
if let Some(interim_results) = interim_results {
pairs.append_pair("interim_results", &interim_results.to_string());
}
if let Some(no_delay) = self.no_delay {
if let Some(no_delay) = no_delay {
pairs.append_pair("no_delay", &no_delay.to_string());
}
if let Some(vad_events) = self.vad_events {
if let Some(vad_events) = vad_events {
pairs.append_pair("vad_events", &vad_events.to_string());
}
}

Ok(url)
}

Expand Down Expand Up @@ -305,7 +318,7 @@ where
.header("upgrade", "websocket")
.header("sec-websocket-version", "13");

let builder = if let Some(api_key) = self.config.api_key.as_deref() {
let builder = if let Some(api_key) = self.deepgram.api_key.as_deref() {
builder.header("authorization", format!("token {}", api_key))
} else {
builder
Expand Down

0 comments on commit 22b531a

Please sign in to comment.