Skip to content

Commit

Permalink
ensure keep alive is sending correctly and keeping the websocket open…
Browse files Browse the repository at this point in the history
…, also use a mutex to avoid concurrent writes
  • Loading branch information
DamienDeepgram committed Jul 19, 2024
1 parent 1a0bb25 commit 2807ac1
Showing 1 changed file with 39 additions and 30 deletions.
69 changes: 39 additions & 30 deletions src/listen/live.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use serde_urlencoded;
use std::borrow::Cow;
use std::path::Path;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;

Expand All @@ -23,6 +24,7 @@ use futures::{SinkExt, Stream};
use http::Request;
use pin_project::pin_project;
use tokio::fs::File;
use tokio::sync::Mutex;
use tokio::time;
use tokio_tungstenite::tungstenite::protocol::Message;
use tokio_util::io::ReaderStream;
Expand Down Expand Up @@ -316,26 +318,52 @@ where
builder.body(())?
};
let (ws_stream, _) = tokio_tungstenite::connect_async(request).await?;
let (mut write, mut read) = ws_stream.split();
let (write, mut read) = ws_stream.split();
let write = Arc::new(Mutex::new(write));
let (mut tx, rx) = mpsc::channel::<Result<StreamResponse>>(1);

// Spawn the keep-alive task
if self.keep_alive.unwrap_or(false) {
{
let write_clone = Arc::clone(&write);
tokio::spawn(async move {
let mut interval = time::interval(Duration::from_secs(1));
loop {
println!("Keep Alive");
interval.tick().await;
let keep_alive_message = Message::Text("{\"type\": \"KeepAlive\"}".to_string());
let mut write = write_clone.lock().await;
if let Err(e) = write.send(keep_alive_message).await {
println!("Error Sending Keep Alive: {:?}", e);
break;
}
}
})
};
}

let write_clone = Arc::clone(&write);
let send_task = async move {
loop {
match source.next().await {
None => break,
Some(Ok(frame)) => {
// This unwrap is not safe.
write.send(frame).await.unwrap();
while let Some(frame) = source.next().await {
match frame {
Ok(frame) => {
let mut write = write_clone.lock().await;
if let Err(e) = write.send(frame).await {
println!("Error sending frame: {:?}", e);
break;
}
}
Some(e) => {
let _ = dbg!(e);
Err(e) => {
println!("Error receiving from source: {:?}", e);
break;
}
}
}

// This unwrap is not safe.
write.send(Message::binary([])).await.unwrap();
let mut write = write_clone.lock().await;
if let Err(e) = write.send(Message::binary([])).await {
println!("Error sending final frame: {:?}", e);
}
};

let recv_task = async move {
Expand Down Expand Up @@ -367,25 +395,6 @@ where
}

pub fn keep_alive(mut self) -> Self {
tokio::spawn({
async move {
let mut interval = time::interval(Duration::from_secs(10));
loop {
interval.tick().await;
let (mut tx, _rx) = mpsc::channel(1);
let keep_alive_message = Message::Text("{\"type\": \"KeepAlive\"}".to_string());
if let Err(e) = tx.send(keep_alive_message).await {
if e.is_disconnected() {
println!("Keep Alive waiting for connection",);
} else {
println!("Error Sending Keep Alive: {:?}", e);
break;
}
}
}
}
});

self.keep_alive = Some(true);

self
Expand Down

0 comments on commit 2807ac1

Please sign in to comment.