Skip to content

Commit

Permalink
Add handlers for errors, open, close and results
Browse files Browse the repository at this point in the history
  • Loading branch information
DamienDeepgram committed Jul 25, 2024
1 parent 7d837e9 commit e628cfa
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 20 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ categories = ["api-bindings", "multimedia::audio"]
audio = "0.2.0"
bytes = "1"
futures = "0.3"
futures-util = { version = "0.3" , optional = true }
http = "0.2"
pin-project = "1"
reqwest = { version = "0.11.22", default-features = false, features = ["json", "rustls-tls", "stream"] }
Expand Down
31 changes: 24 additions & 7 deletions examples/transcription/websocket/microphone_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
use cpal::Sample;
use crossbeam::channel::RecvError;
use deepgram::common::options::Encoding;
use deepgram::listen::websocket::Event;
use futures::channel::mpsc::{self, Receiver as FuturesReceiver};
use futures::stream::StreamExt;
use futures::SinkExt;
use futures_util::stream::StreamExt;

use deepgram::{Deepgram, DeepgramError};

Expand Down Expand Up @@ -92,8 +93,21 @@ fn microphone_as_stream() -> FuturesReceiver<Result<Bytes, RecvError>> {
async fn main() -> Result<(), DeepgramError> {
let dg = Deepgram::new(env::var("DEEPGRAM_API_KEY").unwrap());

let mut results = dg
.transcription()
let (event_tx, mut event_rx) = tokio::sync::mpsc::channel::<Event>(100);

// Event handling task
tokio::spawn(async move {
while let Some(event) = event_rx.recv().await {
match event {
Event::Open => println!("Connection opened"),
Event::Close => println!("Connection closed"),
Event::Error(e) => eprintln!("Error occurred: {:?}", e),
Event::Result(result) => println!("got: {:?}", result),
}
}
});

let mut transcription_stream = dg.transcription()
.stream_request()
.keep_alive()
.stream(microphone_as_stream())
Expand All @@ -102,12 +116,15 @@ async fn main() -> Result<(), DeepgramError> {
.sample_rate(44100)
// TODO Specific to my machine, not general enough example.
.channels(2)
.start()
.start(event_tx.clone())
.await?;

while let Some(result) = results.next().await {
println!("got: {:?}", result);
}
while let Some(response) = transcription_stream.next().await {
match response {
Ok(result) => println!("Transcription result: {:?}", result),
Err(e) => eprintln!("Transcription error: {:?}", e),
}
}

Ok(())
}
35 changes: 25 additions & 10 deletions examples/transcription/websocket/simple_stream.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use std::env;
use std::time::Duration;

use futures::stream::StreamExt;
use tokio::sync::mpsc;
use futures_util::stream::StreamExt;

use deepgram::{
common::options::{Encoding, Endpointing, Language, Options},
Deepgram, DeepgramError,
common::options::{Encoding, Endpointing, Language, Options}, listen::websocket::Event, Deepgram, DeepgramError
};

static PATH_TO_FILE: &str = "examples/audio/bueller.wav";
Expand All @@ -20,8 +19,21 @@ async fn main() -> Result<(), DeepgramError> {
.language(Language::en_US)
.build();

let mut results = dg
.transcription()
let (event_tx, mut event_rx) = mpsc::channel::<Event>(100);

// Event handling task
tokio::spawn(async move {
while let Some(event) = event_rx.recv().await {
match event {
Event::Open => println!("Connection opened"),
Event::Close => println!("Connection closed"),
Event::Error(e) => eprintln!("Error occurred: {:?}", e),
Event::Result(result) => println!("got: {:?}", result),
}
}
});

let mut transcription_stream = dg.transcription()
.stream_request_with_options(Some(&options))
.keep_alive()
.encoding(Encoding::Linear16)
Expand All @@ -32,13 +44,16 @@ async fn main() -> Result<(), DeepgramError> {
.utterance_end_ms(1000)
.vad_events(true)
.no_delay(true)
.file(PATH_TO_FILE, AUDIO_CHUNK_SIZE, Duration::from_millis(16))
.file(PATH_TO_FILE, AUDIO_CHUNK_SIZE, Duration::from_millis(16), event_tx.clone())
.await?
.start()
.start(event_tx.clone())
.await?;

while let Some(result) = results.next().await {
println!("got: {:?}", result);
while let Some(response) = transcription_stream.next().await {
match response {
Ok(result) => println!("Transcription result: {:?}", result),
Err(e) => eprintln!("Transcription error: {:?}", e),
}
}

Ok(())
Expand Down
13 changes: 12 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use core::fmt;
use std::io;
use std::ops::Deref;

use futures::channel::mpsc;
use reqwest::{
header::{HeaderMap, HeaderValue},
RequestBuilder,
Expand Down Expand Up @@ -143,14 +144,24 @@ pub enum DeepgramError {
#[error("Something went wrong during I/O: {0}")]
IoError(#[from] io::Error),

#[cfg(feature = "listen")]
/// Something went wrong with WS.
#[cfg(feature = "listen")]
#[error("Something went wrong with WS: {0}")]
WsError(#[from] tungstenite::Error),

/// Something went wrong during serialization/deserialization.
#[error("Something went wrong during serialization/deserialization: {0}")]
SerdeError(#[from] serde_json::Error),

/// Something went wrong with sending
#[cfg(feature = "listen")]
#[error("Something went wrong with WS: {0}")]
SendError(#[from] mpsc::SendError),

/// Something went wrong with receiving
#[cfg(feature = "listen")]
#[error("Channel receive error: {0}")]
ReceiveError(String),
}

#[cfg_attr(not(feature = "listen"), allow(unused))]
Expand Down
32 changes: 30 additions & 2 deletions src/listen/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ use std::time::Duration;

use bytes::{Bytes, BytesMut};
use futures::channel::mpsc::{self, Receiver};
use futures::channel::mpsc as futures_mpsc;
use futures::stream::StreamExt;
use futures::{SinkExt, Stream};
use http::Request;
use pin_project::pin_project;
use tokio::fs::File;
use tokio::sync::Mutex;
use tokio::sync::mpsc::Sender;
use tokio::time;
use tokio_tungstenite::tungstenite::protocol::Message;
use tokio_util::io::ReaderStream;
Expand All @@ -36,6 +38,15 @@ use crate::{Deepgram, DeepgramError, Result, Transcription};

static LIVE_LISTEN_URL_PATH: &str = "v1/listen";

// Define event types
#[derive(Debug)]
pub enum Event {
Open,
Close,
Error(DeepgramError),
Result(String),
}

#[derive(Debug)]
pub struct StreamRequestBuilder<'a, S, E>
where
Expand Down Expand Up @@ -214,6 +225,7 @@ impl<'a> StreamRequestBuilder<'a, Receiver<Result<Bytes>>, DeepgramError> {
filename: impl AsRef<Path>,
frame_size: usize,
frame_delay: Duration,
event_tx: Sender<Event>,
) -> Result<StreamRequestBuilder<'a, Receiver<Result<Bytes>>, DeepgramError>> {
let file = File::open(filename).await?;
let mut chunker = FileChunker::new(file, frame_size);
Expand All @@ -223,7 +235,7 @@ impl<'a> StreamRequestBuilder<'a, Receiver<Result<Bytes>>, DeepgramError> {
tokio::time::sleep(frame_delay).await;
if let Err(e) = tx.send(frame).await {
eprintln!("Failed to send frame: {:?}", e);
// TODO Handle the error, e.g., break the loop, retry, or log the error
let _ = event_tx.send(Event::Error(DeepgramError::from(e))).await;
break;
}
}
Expand All @@ -245,7 +257,10 @@ where
S: Stream<Item = std::result::Result<Bytes, E>> + Send + Unpin + 'static,
E: Send + std::fmt::Debug,
{
pub async fn start(self) -> Result<Receiver<Result<StreamResponse>>> {
pub async fn start(
self,
event_tx: Sender<Event>,
) -> std::result::Result<futures_mpsc::Receiver<std::result::Result<StreamResponse, DeepgramError>>, DeepgramError> {
// This unwrap is safe because we're parsing a static.
let mut url = self.stream_url;
{
Expand Down Expand Up @@ -322,6 +337,13 @@ where
let write = Arc::new(Mutex::new(write));
let (mut tx, rx) = mpsc::channel::<Result<StreamResponse>>(1);

let event_tx_open = event_tx.clone();
let event_tx_keep_alive = event_tx.clone();
let event_tx_send = event_tx.clone();
let event_tx_receive = event_tx.clone();

event_tx_open.send(Event::Open).await.unwrap();

// Spawn the keep-alive task
if self.keep_alive.unwrap_or(false) {
{
Expand All @@ -335,6 +357,7 @@ where
let mut write = write_clone.lock().await;
if let Err(e) = write.send(keep_alive_message).await {
eprintln!("Error Sending Keep Alive: {:?}", e);
let _ = event_tx_keep_alive.send(Event::Error(DeepgramError::from(e))).await;
break;
}
}
Expand All @@ -350,11 +373,13 @@ where
let mut write = write_clone.lock().await;
if let Err(e) = write.send(frame).await {
println!("Error sending frame: {:?}", e);
let _ = event_tx_send.send(Event::Error(DeepgramError::from(e))).await;
break;
}
}
Err(e) => {
println!("Error receiving from source: {:?}", e);
let _ = event_tx_send.send(Event::Error(DeepgramError::ReceiveError(format!("{:?}", e)))).await;
break;
}
}
Expand All @@ -363,6 +388,7 @@ where
let mut write = write_clone.lock().await;
if let Err(e) = write.send(Message::binary([])).await {
println!("Error sending final frame: {:?}", e);
let _ = event_tx_send.send(Event::Error(DeepgramError::from(e))).await;
}
};

Expand All @@ -376,6 +402,7 @@ where
let resp = serde_json::from_str(&txt).map_err(DeepgramError::from);
if let Err(e) = tx.send(resp).await {
eprintln!("Failed to send message: {:?}", e);
let _ = event_tx_receive.send(Event::Error(DeepgramError::from(e))).await;
// Handle the error appropriately, e.g., log it, retry, or break the loop
break;
}
Expand All @@ -386,6 +413,7 @@ where
let mut write = write.lock().await;
if let Err(e) = write.send(Message::Close(None)).await {
eprintln!("Failed to send close frame: {:?}", e);
let _ = event_tx_receive.send(Event::Error(DeepgramError::from(e))).await;
}
break;
}
Expand Down

0 comments on commit e628cfa

Please sign in to comment.