From 2807ac19156e06375da30cce04fb4a690f67b34c Mon Sep 17 00:00:00 2001 From: Damien Murphy Date: Fri, 19 Jul 2024 10:40:03 -0700 Subject: [PATCH] ensure keep alive is sending correctly and keeping the websocket open, also use a mutex to avoid concurrent writes --- src/listen/live.rs | 69 ++++++++++++++++++++++++++-------------------- 1 file changed, 39 insertions(+), 30 deletions(-) diff --git a/src/listen/live.rs b/src/listen/live.rs index 88bb07fc..6ae3f8d2 100644 --- a/src/listen/live.rs +++ b/src/listen/live.rs @@ -13,6 +13,7 @@ use serde_urlencoded; use std::borrow::Cow; use std::path::Path; use std::pin::Pin; +use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; @@ -23,6 +24,7 @@ use futures::{SinkExt, Stream}; use http::Request; use pin_project::pin_project; use tokio::fs::File; +use tokio::sync::Mutex; use tokio::time; use tokio_tungstenite::tungstenite::protocol::Message; use tokio_util::io::ReaderStream; @@ -316,26 +318,52 @@ where builder.body(())? }; let (ws_stream, _) = tokio_tungstenite::connect_async(request).await?; - let (mut write, mut read) = ws_stream.split(); + let (write, mut read) = ws_stream.split(); + let write = Arc::new(Mutex::new(write)); let (mut tx, rx) = mpsc::channel::>(1); + // Spawn the keep-alive task + if self.keep_alive.unwrap_or(false) { + { + let write_clone = Arc::clone(&write); + tokio::spawn(async move { + let mut interval = time::interval(Duration::from_secs(1)); + loop { + println!("Keep Alive"); + interval.tick().await; + let keep_alive_message = Message::Text("{\"type\": \"KeepAlive\"}".to_string()); + let mut write = write_clone.lock().await; + if let Err(e) = write.send(keep_alive_message).await { + println!("Error Sending Keep Alive: {:?}", e); + break; + } + } + }) + }; + } + + let write_clone = Arc::clone(&write); let send_task = async move { - loop { - match source.next().await { - None => break, - Some(Ok(frame)) => { - // This unwrap is not safe. - write.send(frame).await.unwrap(); + while let Some(frame) = source.next().await { + match frame { + Ok(frame) => { + let mut write = write_clone.lock().await; + if let Err(e) = write.send(frame).await { + println!("Error sending frame: {:?}", e); + break; + } } - Some(e) => { - let _ = dbg!(e); + Err(e) => { + println!("Error receiving from source: {:?}", e); break; } } } - // This unwrap is not safe. - write.send(Message::binary([])).await.unwrap(); + let mut write = write_clone.lock().await; + if let Err(e) = write.send(Message::binary([])).await { + println!("Error sending final frame: {:?}", e); + } }; let recv_task = async move { @@ -367,25 +395,6 @@ where } pub fn keep_alive(mut self) -> Self { - tokio::spawn({ - async move { - let mut interval = time::interval(Duration::from_secs(10)); - loop { - interval.tick().await; - let (mut tx, _rx) = mpsc::channel(1); - let keep_alive_message = Message::Text("{\"type\": \"KeepAlive\"}".to_string()); - if let Err(e) = tx.send(keep_alive_message).await { - if e.is_disconnected() { - println!("Keep Alive waiting for connection",); - } else { - println!("Error Sending Keep Alive: {:?}", e); - break; - } - } - } - } - }); - self.keep_alive = Some(true); self