From bf6dc9de23e8d9a9e5a8f00d425b7be9e657f51d Mon Sep 17 00:00:00 2001 From: Brent George Date: Thu, 22 Aug 2024 09:30:02 -0400 Subject: [PATCH] add method on websocket to get request id --- .../websocket/callback_stream.rs | 1 + .../websocket/microphone_stream.rs | 1 + .../transcription/websocket/simple_stream.rs | 1 + src/lib.rs | 4 ++ src/listen/websocket.rs | 37 ++++++++++++++++++- 5 files changed, 42 insertions(+), 2 deletions(-) diff --git a/examples/transcription/websocket/callback_stream.rs b/examples/transcription/websocket/callback_stream.rs index 43863afe..898183e1 100644 --- a/examples/transcription/websocket/callback_stream.rs +++ b/examples/transcription/websocket/callback_stream.rs @@ -45,6 +45,7 @@ async fn main() -> Result<(), DeepgramError> { .file(PATH_TO_FILE, AUDIO_CHUNK_SIZE, FRAME_DELAY) .await?; + println!("Deepgram Request ID: {}", results.request_id()); while let Some(result) = results.next().await { println!("got: {:?}", result); } diff --git a/examples/transcription/websocket/microphone_stream.rs b/examples/transcription/websocket/microphone_stream.rs index ab86fe7f..ac3ece33 100644 --- a/examples/transcription/websocket/microphone_stream.rs +++ b/examples/transcription/websocket/microphone_stream.rs @@ -107,6 +107,7 @@ async fn main() -> Result<(), DeepgramError> { .stream(microphone_as_stream()) .await?; + println!("Deepgram Request ID: {}", results.request_id()); while let Some(result) = results.next().await { println!("got: {:?}", result); } diff --git a/examples/transcription/websocket/simple_stream.rs b/examples/transcription/websocket/simple_stream.rs index 5aa6f8a4..90462d53 100644 --- a/examples/transcription/websocket/simple_stream.rs +++ b/examples/transcription/websocket/simple_stream.rs @@ -39,6 +39,7 @@ async fn main() -> Result<(), DeepgramError> { .file(PATH_TO_FILE, AUDIO_CHUNK_SIZE, FRAME_DELAY) .await?; + println!("Deepgram Request ID: {}", results.request_id()); while let Some(result) = results.next().await { println!("got: {:?}", result); } diff --git a/src/lib.rs b/src/lib.rs index a1d9834c..d11c95b5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -172,6 +172,10 @@ pub enum DeepgramError { /// An unexpected error occurred in the client #[error("an unepected error occurred in the deepgram client: {0}")] InternalClientError(anyhow::Error), + + /// A Deepgram API server response was not in the expected format. + #[error("The Deepgram API server response was not in the expected format: {0}")] + UnexpectedServerResponse(String), } #[cfg_attr(not(feature = "listen"), allow(unused))] diff --git a/src/listen/websocket.rs b/src/listen/websocket.rs index 42b07630..4d064dd3 100644 --- a/src/listen/websocket.rs +++ b/src/listen/websocket.rs @@ -36,6 +36,7 @@ use tungstenite::{ protocol::frame::coding::{Data, OpCode}, }; use url::Url; +use uuid::Uuid; use self::file_chunker::FileChunker; use crate::{ @@ -363,6 +364,7 @@ impl<'a> WebsocketBuilder<'a> { let (tx, rx) = mpsc::channel(1); let mut is_done = false; + let request_id = handle.request_id(); tokio::task::spawn(async move { let mut handle = handle; let mut tx = tx; @@ -433,7 +435,11 @@ impl<'a> WebsocketBuilder<'a> { } } }); - Ok(TranscriptionStream { rx, done: false }) + Ok(TranscriptionStream { + rx, + done: false, + request_id, + }) } /// A low level interface to the Deepgram websocket transcription API. @@ -640,6 +646,7 @@ impl Deref for Audio { pub struct WebsocketHandle { message_tx: Sender, response_rx: Receiver>, + request_id: Uuid, } impl<'a> WebsocketHandle { @@ -664,7 +671,21 @@ impl<'a> WebsocketHandle { builder.body(())? }; - let (ws_stream, _) = tokio_tungstenite::connect_async(request).await?; + let (ws_stream, upgrade_response) = tokio_tungstenite::connect_async(request).await?; + + let request_id = upgrade_response + .headers() + .get("dg-request-id") + .ok_or(DeepgramError::UnexpectedServerResponse( + "Websocket upgrade headers missing request ID".to_string(), + ))? + .to_str() + .ok() + .and_then(|req_header_str| Uuid::parse_str(req_header_str).ok()) + .ok_or(DeepgramError::UnexpectedServerResponse( + "Received malformed request ID in websocket upgrade headers".to_string(), + ))?; + let (message_tx, message_rx) = mpsc::channel(256); let (response_tx, response_rx) = mpsc::channel(256); @@ -682,6 +703,7 @@ impl<'a> WebsocketHandle { Ok(WebsocketHandle { message_tx, response_rx, + request_id, }) } @@ -737,6 +759,10 @@ impl<'a> WebsocketHandle { // eprintln!(" receiving response: {resp:?}"); resp } + + pub fn request_id(&self) -> Uuid { + self.request_id + } } #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize)] @@ -753,6 +779,7 @@ pub struct TranscriptionStream { #[pin] rx: Receiver>, done: bool, + request_id: Uuid, } impl Stream for TranscriptionStream { @@ -764,6 +791,12 @@ impl Stream for TranscriptionStream { } } +impl TranscriptionStream { + pub fn request_id(&self) -> Uuid { + self.request_id + } +} + mod file_chunker { use bytes::{Bytes, BytesMut}; use futures::Stream;