Skip to content

Commit

Permalink
add method on websocket to get request id
Browse files Browse the repository at this point in the history
  • Loading branch information
bd-g committed Aug 22, 2024
1 parent 41b5851 commit bf6dc9d
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 2 deletions.
1 change: 1 addition & 0 deletions examples/transcription/websocket/callback_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ async fn main() -> Result<(), DeepgramError> {
.file(PATH_TO_FILE, AUDIO_CHUNK_SIZE, FRAME_DELAY)
.await?;

println!("Deepgram Request ID: {}", results.request_id());
while let Some(result) = results.next().await {
println!("got: {:?}", result);
}
Expand Down
1 change: 1 addition & 0 deletions examples/transcription/websocket/microphone_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ async fn main() -> Result<(), DeepgramError> {
.stream(microphone_as_stream())
.await?;

println!("Deepgram Request ID: {}", results.request_id());
while let Some(result) = results.next().await {
println!("got: {:?}", result);
}
Expand Down
1 change: 1 addition & 0 deletions examples/transcription/websocket/simple_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ async fn main() -> Result<(), DeepgramError> {
.file(PATH_TO_FILE, AUDIO_CHUNK_SIZE, FRAME_DELAY)
.await?;

println!("Deepgram Request ID: {}", results.request_id());
while let Some(result) = results.next().await {
println!("got: {:?}", result);
}
Expand Down
4 changes: 4 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,10 @@ pub enum DeepgramError {
/// An unexpected error occurred in the client
#[error("an unepected error occurred in the deepgram client: {0}")]
InternalClientError(anyhow::Error),

/// A Deepgram API server response was not in the expected format.
#[error("The Deepgram API server response was not in the expected format: {0}")]
UnexpectedServerResponse(String),
}

#[cfg_attr(not(feature = "listen"), allow(unused))]
Expand Down
37 changes: 35 additions & 2 deletions src/listen/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ use tungstenite::{
protocol::frame::coding::{Data, OpCode},
};
use url::Url;
use uuid::Uuid;

use self::file_chunker::FileChunker;
use crate::{
Expand Down Expand Up @@ -363,6 +364,7 @@ impl<'a> WebsocketBuilder<'a> {

let (tx, rx) = mpsc::channel(1);
let mut is_done = false;
let request_id = handle.request_id();
tokio::task::spawn(async move {
let mut handle = handle;
let mut tx = tx;
Expand Down Expand Up @@ -433,7 +435,11 @@ impl<'a> WebsocketBuilder<'a> {
}
}
});
Ok(TranscriptionStream { rx, done: false })
Ok(TranscriptionStream {
rx,
done: false,
request_id,
})
}

/// A low level interface to the Deepgram websocket transcription API.
Expand Down Expand Up @@ -640,6 +646,7 @@ impl Deref for Audio {
pub struct WebsocketHandle {
message_tx: Sender<WsMessage>,
response_rx: Receiver<Result<StreamResponse>>,
request_id: Uuid,
}

impl<'a> WebsocketHandle {
Expand All @@ -664,7 +671,21 @@ impl<'a> WebsocketHandle {
builder.body(())?
};

let (ws_stream, _) = tokio_tungstenite::connect_async(request).await?;
let (ws_stream, upgrade_response) = tokio_tungstenite::connect_async(request).await?;

let request_id = upgrade_response
.headers()
.get("dg-request-id")
.ok_or(DeepgramError::UnexpectedServerResponse(
"Websocket upgrade headers missing request ID".to_string(),
))?
.to_str()
.ok()
.and_then(|req_header_str| Uuid::parse_str(req_header_str).ok())
.ok_or(DeepgramError::UnexpectedServerResponse(
"Received malformed request ID in websocket upgrade headers".to_string(),
))?;

let (message_tx, message_rx) = mpsc::channel(256);
let (response_tx, response_rx) = mpsc::channel(256);

Expand All @@ -682,6 +703,7 @@ impl<'a> WebsocketHandle {
Ok(WebsocketHandle {
message_tx,
response_rx,
request_id,
})
}

Expand Down Expand Up @@ -737,6 +759,10 @@ impl<'a> WebsocketHandle {
// eprintln!("<handle> receiving response: {resp:?}");
resp
}

pub fn request_id(&self) -> Uuid {
self.request_id
}
}

#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize)]
Expand All @@ -753,6 +779,7 @@ pub struct TranscriptionStream {
#[pin]
rx: Receiver<Result<StreamResponse>>,
done: bool,
request_id: Uuid,
}

impl Stream for TranscriptionStream {
Expand All @@ -764,6 +791,12 @@ impl Stream for TranscriptionStream {
}
}

impl TranscriptionStream {
pub fn request_id(&self) -> Uuid {
self.request_id
}
}

mod file_chunker {
use bytes::{Bytes, BytesMut};
use futures::Stream;
Expand Down

0 comments on commit bf6dc9d

Please sign in to comment.