From e628cfa822855edff5888691e7b11869ff690f19 Mon Sep 17 00:00:00 2001 From: Damien Murphy Date: Wed, 24 Jul 2024 21:10:09 -0700 Subject: [PATCH] Add handlers for errors, open, close and results --- Cargo.toml | 1 + .../websocket/microphone_stream.rs | 31 ++++++++++++---- .../transcription/websocket/simple_stream.rs | 35 +++++++++++++------ src/lib.rs | 13 ++++++- src/listen/websocket.rs | 32 +++++++++++++++-- 5 files changed, 92 insertions(+), 20 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4c2c91a5..35b67d96 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ categories = ["api-bindings", "multimedia::audio"] audio = "0.2.0" bytes = "1" futures = "0.3" +futures-util = { version = "0.3" , optional = true } http = "0.2" pin-project = "1" reqwest = { version = "0.11.22", default-features = false, features = ["json", "rustls-tls", "stream"] } diff --git a/examples/transcription/websocket/microphone_stream.rs b/examples/transcription/websocket/microphone_stream.rs index 3295d071..c92ed21c 100644 --- a/examples/transcription/websocket/microphone_stream.rs +++ b/examples/transcription/websocket/microphone_stream.rs @@ -6,9 +6,10 @@ use cpal::traits::{DeviceTrait, HostTrait, StreamTrait}; use cpal::Sample; use crossbeam::channel::RecvError; use deepgram::common::options::Encoding; +use deepgram::listen::websocket::Event; use futures::channel::mpsc::{self, Receiver as FuturesReceiver}; -use futures::stream::StreamExt; use futures::SinkExt; +use futures_util::stream::StreamExt; use deepgram::{Deepgram, DeepgramError}; @@ -92,8 +93,21 @@ fn microphone_as_stream() -> FuturesReceiver> { async fn main() -> Result<(), DeepgramError> { let dg = Deepgram::new(env::var("DEEPGRAM_API_KEY").unwrap()); - let mut results = dg - .transcription() + let (event_tx, mut event_rx) = tokio::sync::mpsc::channel::(100); + + // Event handling task + tokio::spawn(async move { + while let Some(event) = event_rx.recv().await { + match event { + Event::Open => println!("Connection opened"), + Event::Close => println!("Connection closed"), + Event::Error(e) => eprintln!("Error occurred: {:?}", e), + Event::Result(result) => println!("got: {:?}", result), + } + } + }); + + let mut transcription_stream = dg.transcription() .stream_request() .keep_alive() .stream(microphone_as_stream()) @@ -102,12 +116,15 @@ async fn main() -> Result<(), DeepgramError> { .sample_rate(44100) // TODO Specific to my machine, not general enough example. .channels(2) - .start() + .start(event_tx.clone()) .await?; - while let Some(result) = results.next().await { - println!("got: {:?}", result); - } + while let Some(response) = transcription_stream.next().await { + match response { + Ok(result) => println!("Transcription result: {:?}", result), + Err(e) => eprintln!("Transcription error: {:?}", e), + } + } Ok(()) } diff --git a/examples/transcription/websocket/simple_stream.rs b/examples/transcription/websocket/simple_stream.rs index f3040af3..eab61a04 100644 --- a/examples/transcription/websocket/simple_stream.rs +++ b/examples/transcription/websocket/simple_stream.rs @@ -1,11 +1,10 @@ use std::env; use std::time::Duration; - -use futures::stream::StreamExt; +use tokio::sync::mpsc; +use futures_util::stream::StreamExt; use deepgram::{ - common::options::{Encoding, Endpointing, Language, Options}, - Deepgram, DeepgramError, + common::options::{Encoding, Endpointing, Language, Options}, listen::websocket::Event, Deepgram, DeepgramError }; static PATH_TO_FILE: &str = "examples/audio/bueller.wav"; @@ -20,8 +19,21 @@ async fn main() -> Result<(), DeepgramError> { .language(Language::en_US) .build(); - let mut results = dg - .transcription() + let (event_tx, mut event_rx) = mpsc::channel::(100); + + // Event handling task + tokio::spawn(async move { + while let Some(event) = event_rx.recv().await { + match event { + Event::Open => println!("Connection opened"), + Event::Close => println!("Connection closed"), + Event::Error(e) => eprintln!("Error occurred: {:?}", e), + Event::Result(result) => println!("got: {:?}", result), + } + } + }); + + let mut transcription_stream = dg.transcription() .stream_request_with_options(Some(&options)) .keep_alive() .encoding(Encoding::Linear16) @@ -32,13 +44,16 @@ async fn main() -> Result<(), DeepgramError> { .utterance_end_ms(1000) .vad_events(true) .no_delay(true) - .file(PATH_TO_FILE, AUDIO_CHUNK_SIZE, Duration::from_millis(16)) + .file(PATH_TO_FILE, AUDIO_CHUNK_SIZE, Duration::from_millis(16), event_tx.clone()) .await? - .start() + .start(event_tx.clone()) .await?; - while let Some(result) = results.next().await { - println!("got: {:?}", result); + while let Some(response) = transcription_stream.next().await { + match response { + Ok(result) => println!("Transcription result: {:?}", result), + Err(e) => eprintln!("Transcription error: {:?}", e), + } } Ok(()) diff --git a/src/lib.rs b/src/lib.rs index da064e42..cb9e0db7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,6 +10,7 @@ use core::fmt; use std::io; use std::ops::Deref; +use futures::channel::mpsc; use reqwest::{ header::{HeaderMap, HeaderValue}, RequestBuilder, @@ -143,14 +144,24 @@ pub enum DeepgramError { #[error("Something went wrong during I/O: {0}")] IoError(#[from] io::Error), - #[cfg(feature = "listen")] /// Something went wrong with WS. + #[cfg(feature = "listen")] #[error("Something went wrong with WS: {0}")] WsError(#[from] tungstenite::Error), /// Something went wrong during serialization/deserialization. #[error("Something went wrong during serialization/deserialization: {0}")] SerdeError(#[from] serde_json::Error), + + /// Something went wrong with sending + #[cfg(feature = "listen")] + #[error("Something went wrong with WS: {0}")] + SendError(#[from] mpsc::SendError), + + /// Something went wrong with receiving + #[cfg(feature = "listen")] + #[error("Channel receive error: {0}")] + ReceiveError(String), } #[cfg_attr(not(feature = "listen"), allow(unused))] diff --git a/src/listen/websocket.rs b/src/listen/websocket.rs index e7c41265..518e527b 100644 --- a/src/listen/websocket.rs +++ b/src/listen/websocket.rs @@ -19,12 +19,14 @@ use std::time::Duration; use bytes::{Bytes, BytesMut}; use futures::channel::mpsc::{self, Receiver}; +use futures::channel::mpsc as futures_mpsc; use futures::stream::StreamExt; use futures::{SinkExt, Stream}; use http::Request; use pin_project::pin_project; use tokio::fs::File; use tokio::sync::Mutex; +use tokio::sync::mpsc::Sender; use tokio::time; use tokio_tungstenite::tungstenite::protocol::Message; use tokio_util::io::ReaderStream; @@ -36,6 +38,15 @@ use crate::{Deepgram, DeepgramError, Result, Transcription}; static LIVE_LISTEN_URL_PATH: &str = "v1/listen"; +// Define event types +#[derive(Debug)] +pub enum Event { + Open, + Close, + Error(DeepgramError), + Result(String), +} + #[derive(Debug)] pub struct StreamRequestBuilder<'a, S, E> where @@ -214,6 +225,7 @@ impl<'a> StreamRequestBuilder<'a, Receiver>, DeepgramError> { filename: impl AsRef, frame_size: usize, frame_delay: Duration, + event_tx: Sender, ) -> Result>, DeepgramError>> { let file = File::open(filename).await?; let mut chunker = FileChunker::new(file, frame_size); @@ -223,7 +235,7 @@ impl<'a> StreamRequestBuilder<'a, Receiver>, DeepgramError> { tokio::time::sleep(frame_delay).await; if let Err(e) = tx.send(frame).await { eprintln!("Failed to send frame: {:?}", e); - // TODO Handle the error, e.g., break the loop, retry, or log the error + let _ = event_tx.send(Event::Error(DeepgramError::from(e))).await; break; } } @@ -245,7 +257,10 @@ where S: Stream> + Send + Unpin + 'static, E: Send + std::fmt::Debug, { - pub async fn start(self) -> Result>> { + pub async fn start( + self, + event_tx: Sender, + ) -> std::result::Result>, DeepgramError> { // This unwrap is safe because we're parsing a static. let mut url = self.stream_url; { @@ -322,6 +337,13 @@ where let write = Arc::new(Mutex::new(write)); let (mut tx, rx) = mpsc::channel::>(1); + let event_tx_open = event_tx.clone(); + let event_tx_keep_alive = event_tx.clone(); + let event_tx_send = event_tx.clone(); + let event_tx_receive = event_tx.clone(); + + event_tx_open.send(Event::Open).await.unwrap(); + // Spawn the keep-alive task if self.keep_alive.unwrap_or(false) { { @@ -335,6 +357,7 @@ where let mut write = write_clone.lock().await; if let Err(e) = write.send(keep_alive_message).await { eprintln!("Error Sending Keep Alive: {:?}", e); + let _ = event_tx_keep_alive.send(Event::Error(DeepgramError::from(e))).await; break; } } @@ -350,11 +373,13 @@ where let mut write = write_clone.lock().await; if let Err(e) = write.send(frame).await { println!("Error sending frame: {:?}", e); + let _ = event_tx_send.send(Event::Error(DeepgramError::from(e))).await; break; } } Err(e) => { println!("Error receiving from source: {:?}", e); + let _ = event_tx_send.send(Event::Error(DeepgramError::ReceiveError(format!("{:?}", e)))).await; break; } } @@ -363,6 +388,7 @@ where let mut write = write_clone.lock().await; if let Err(e) = write.send(Message::binary([])).await { println!("Error sending final frame: {:?}", e); + let _ = event_tx_send.send(Event::Error(DeepgramError::from(e))).await; } }; @@ -376,6 +402,7 @@ where let resp = serde_json::from_str(&txt).map_err(DeepgramError::from); if let Err(e) = tx.send(resp).await { eprintln!("Failed to send message: {:?}", e); + let _ = event_tx_receive.send(Event::Error(DeepgramError::from(e))).await; // Handle the error appropriately, e.g., log it, retry, or break the loop break; } @@ -386,6 +413,7 @@ where let mut write = write.lock().await; if let Err(e) = write.send(Message::Close(None)).await { eprintln!("Failed to send close frame: {:?}", e); + let _ = event_tx_receive.send(Event::Error(DeepgramError::from(e))).await; } break; }