From 5a1420c0656cd4a1c83615d7d14787203809dc0c Mon Sep 17 00:00:00 2001 From: Brent George <49082060+bd-g@users.noreply.github.com> Date: Thu, 22 Aug 2024 09:02:27 -0600 Subject: [PATCH] Streaming callbacks (#85) * add warning for callback_method for streaming * add callback parameter to WebsocketBuilder * fix clippy * add method on websocket to get request id * documentation and error fix --- Cargo.toml | 5 ++ .../websocket/callback_stream.rs | 54 +++++++++++++++++++ .../websocket/microphone_stream.rs | 1 + .../transcription/websocket/simple_stream.rs | 1 + src/common/options.rs | 5 ++ src/lib.rs | 4 ++ src/listen/websocket.rs | 54 ++++++++++++++++++- 7 files changed, 122 insertions(+), 2 deletions(-) create mode 100644 examples/transcription/websocket/callback_stream.rs diff --git a/Cargo.toml b/Cargo.toml index 81060cb9..2594fb58 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -72,6 +72,11 @@ name = "simple_stream" path = "examples/transcription/websocket/simple_stream.rs" required-features = ["listen"] +[[example]] +name = "callback_stream" +path = "examples/transcription/websocket/callback_stream.rs" +required-features = ["listen"] + [[example]] name = "microphone_stream" path = "examples/transcription/websocket/microphone_stream.rs" diff --git a/examples/transcription/websocket/callback_stream.rs b/examples/transcription/websocket/callback_stream.rs new file mode 100644 index 00000000..898183e1 --- /dev/null +++ b/examples/transcription/websocket/callback_stream.rs @@ -0,0 +1,54 @@ +use std::env; +use std::time::Duration; + +use futures::stream::StreamExt; + +use deepgram::{ + common::options::{Encoding, Endpointing, Language, Options}, + Deepgram, DeepgramError, +}; + +static PATH_TO_FILE: &str = "examples/audio/bueller.wav"; +static AUDIO_CHUNK_SIZE: usize = 3174; +static FRAME_DELAY: Duration = Duration::from_millis(16); + +#[tokio::main] +async fn main() -> Result<(), DeepgramError> { + let deepgram_api_key = + env::var("DEEPGRAM_API_KEY").expect("DEEPGRAM_API_KEY environmental variable"); + + let dg_client = Deepgram::new(&deepgram_api_key)?; + + let options = Options::builder() + .smart_format(true) + .language(Language::en_US) + .build(); + + let callback_url = env::var("DEEPGRAM_CALLBACK_URL") + .expect("DEEPGRAM_CALLBACK_URL environmental variable") + .parse() + .expect("DEEPGRAM_CALLBACK_URL not a valid URL"); + + let mut results = dg_client + .transcription() + .stream_request_with_options(options) + .keep_alive() + .encoding(Encoding::Linear16) + .sample_rate(44100) + .channels(2) + .endpointing(Endpointing::CustomDurationMs(300)) + .interim_results(true) + .utterance_end_ms(1000) + .vad_events(true) + .no_delay(true) + .callback(callback_url) + .file(PATH_TO_FILE, AUDIO_CHUNK_SIZE, FRAME_DELAY) + .await?; + + println!("Deepgram Request ID: {}", results.request_id()); + while let Some(result) = results.next().await { + println!("got: {:?}", result); + } + + Ok(()) +} diff --git a/examples/transcription/websocket/microphone_stream.rs b/examples/transcription/websocket/microphone_stream.rs index ab86fe7f..ac3ece33 100644 --- a/examples/transcription/websocket/microphone_stream.rs +++ b/examples/transcription/websocket/microphone_stream.rs @@ -107,6 +107,7 @@ async fn main() -> Result<(), DeepgramError> { .stream(microphone_as_stream()) .await?; + println!("Deepgram Request ID: {}", results.request_id()); while let Some(result) = results.next().await { println!("got: {:?}", result); } diff --git a/examples/transcription/websocket/simple_stream.rs b/examples/transcription/websocket/simple_stream.rs index 5aa6f8a4..90462d53 100644 --- a/examples/transcription/websocket/simple_stream.rs +++ b/examples/transcription/websocket/simple_stream.rs @@ -39,6 +39,7 @@ async fn main() -> Result<(), DeepgramError> { .file(PATH_TO_FILE, AUDIO_CHUNK_SIZE, FRAME_DELAY) .await?; + println!("Deepgram Request ID: {}", results.request_id()); while let Some(result) = results.next().await { println!("got: {:?}", result); } diff --git a/src/common/options.rs b/src/common/options.rs index 4044ace1..d533dddd 100644 --- a/src/common/options.rs +++ b/src/common/options.rs @@ -1887,7 +1887,12 @@ impl OptionsBuilder { /// /// See the [Deepgram Callback Method feature docs][docs] for more info. /// + /// Note that modifying the callback method is only available for pre-recorded audio. + /// See the [Deepgram Callback feature docs for streaming][streaming-docs] for details + /// on streaming callbacks. + /// /// [docs]: https://developers.deepgram.com/docs/callback#pre-recorded-audio + /// [streaming-docs]: https://developers.deepgram.com/docs/callback#streaming-audio /// /// # Examples /// diff --git a/src/lib.rs b/src/lib.rs index a1d9834c..ed1a21f3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -172,6 +172,10 @@ pub enum DeepgramError { /// An unexpected error occurred in the client #[error("an unepected error occurred in the deepgram client: {0}")] InternalClientError(anyhow::Error), + + /// A Deepgram API server response was not in the expected format. + #[error("The Deepgram API server response was not in the expected format: {0}")] + UnexpectedServerResponse(anyhow::Error), } #[cfg_attr(not(feature = "listen"), allow(unused))] diff --git a/src/listen/websocket.rs b/src/listen/websocket.rs index 2bd2fbf8..5f0b375f 100644 --- a/src/listen/websocket.rs +++ b/src/listen/websocket.rs @@ -18,6 +18,7 @@ use std::{ time::Duration, }; +use anyhow::anyhow; use bytes::Bytes; use futures::{ channel::mpsc::{self, Receiver, Sender}, @@ -36,6 +37,7 @@ use tungstenite::{ protocol::frame::coding::{Data, OpCode}, }; use url::Url; +use uuid::Uuid; use self::file_chunker::FileChunker; use crate::{ @@ -62,6 +64,7 @@ pub struct WebsocketBuilder<'a> { vad_events: Option, stream_url: Url, keep_alive: Option, + callback: Option, } impl Transcription<'_> { @@ -143,6 +146,7 @@ impl Transcription<'_> { vad_events: None, stream_url: self.listen_stream_url(), keep_alive: None, + callback: None, } } @@ -214,6 +218,7 @@ impl<'a> WebsocketBuilder<'a> { no_delay, vad_events, stream_url, + callback, } = self; let mut url = stream_url.clone(); @@ -257,6 +262,9 @@ impl<'a> WebsocketBuilder<'a> { if let Some(vad_events) = vad_events { pairs.append_pair("vad_events", &vad_events.to_string()); } + if let Some(callback) = callback { + pairs.append_pair("callback", callback.as_ref()); + } } Ok(url) @@ -315,6 +323,12 @@ impl<'a> WebsocketBuilder<'a> { self } + + pub fn callback(mut self, callback: Url) -> Self { + self.callback = Some(callback); + + self + } } impl<'a> WebsocketBuilder<'a> { @@ -351,6 +365,7 @@ impl<'a> WebsocketBuilder<'a> { let (tx, rx) = mpsc::channel(1); let mut is_done = false; + let request_id = handle.request_id(); tokio::task::spawn(async move { let mut handle = handle; let mut tx = tx; @@ -421,7 +436,11 @@ impl<'a> WebsocketBuilder<'a> { } } }); - Ok(TranscriptionStream { rx, done: false }) + Ok(TranscriptionStream { + rx, + done: false, + request_id, + }) } /// A low level interface to the Deepgram websocket transcription API. @@ -628,6 +647,7 @@ impl Deref for Audio { pub struct WebsocketHandle { message_tx: Sender, response_rx: Receiver>, + request_id: Uuid, } impl<'a> WebsocketHandle { @@ -652,7 +672,21 @@ impl<'a> WebsocketHandle { builder.body(())? }; - let (ws_stream, _) = tokio_tungstenite::connect_async(request).await?; + let (ws_stream, upgrade_response) = tokio_tungstenite::connect_async(request).await?; + + let request_id = upgrade_response + .headers() + .get("dg-request-id") + .ok_or(DeepgramError::UnexpectedServerResponse(anyhow!( + "Websocket upgrade headers missing request ID" + )))? + .to_str() + .ok() + .and_then(|req_header_str| Uuid::parse_str(req_header_str).ok()) + .ok_or(DeepgramError::UnexpectedServerResponse(anyhow!( + "Received malformed request ID in websocket upgrade headers" + )))?; + let (message_tx, message_rx) = mpsc::channel(256); let (response_tx, response_rx) = mpsc::channel(256); @@ -670,6 +704,7 @@ impl<'a> WebsocketHandle { Ok(WebsocketHandle { message_tx, response_rx, + request_id, }) } @@ -725,6 +760,10 @@ impl<'a> WebsocketHandle { // eprintln!(" receiving response: {resp:?}"); resp } + + pub fn request_id(&self) -> Uuid { + self.request_id + } } #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize)] @@ -741,6 +780,7 @@ pub struct TranscriptionStream { #[pin] rx: Receiver>, done: bool, + request_id: Uuid, } impl Stream for TranscriptionStream { @@ -752,6 +792,16 @@ impl Stream for TranscriptionStream { } } +impl TranscriptionStream { + /// Returns the Deepgram request ID for the speech-to-text live request. + /// + /// A request ID needs to be provided to Deepgram as part of any support + /// or troubleshooting assistance related to a specific request. + pub fn request_id(&self) -> Uuid { + self.request_id + } +} + mod file_chunker { use bytes::{Bytes, BytesMut}; use futures::Stream;