Skip to content

Commit

Permalink
Avoid double warning in deadline monitor
Browse files Browse the repository at this point in the history
  • Loading branch information
GnomedDev committed Nov 5, 2024
1 parent 63ac252 commit d9924e7
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 19 deletions.
8 changes: 6 additions & 2 deletions src/gtts.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use std::{sync::OnceLock, time::Duration};
use std::{
sync::{atomic::AtomicBool, Arc, OnceLock},
time::Duration,
};

use aformat::ToArrayString;
use ipgen::IpNetwork;
Expand Down Expand Up @@ -123,8 +126,9 @@ pub async fn get_tts(
state: &RwLock<State>,
text: &str,
voice: &str,
hit_any_deadline: Arc<AtomicBool>,
) -> Result<(bytes::Bytes, Option<reqwest::header::HeaderValue>)> {
let _guard = DeadlineMonitor::new(Duration::from_millis(1000), |took| {
let _guard = DeadlineMonitor::new(Duration::from_millis(1000), hit_any_deadline, |took| {
tracing::warn!("Fetching gTTS audio took {} millis!", took.as_millis());
});

Expand Down
58 changes: 41 additions & 17 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
use std::{
fmt::Display,
str::FromStr,
sync::OnceLock,
sync::{
atomic::{AtomicBool, Ordering},
Arc, OnceLock,
},
time::{Duration, Instant},
};

Expand Down Expand Up @@ -38,15 +41,17 @@ pub fn check_mp3_length(audio: &[u8], max_length: u64) -> bool {
}

pub struct DeadlineMonitor<F: FnOnce(Duration)> {
expected: Duration,
start: Instant,
expected: Duration,
on_not_hit: Option<F>,
hit_any_deadline: Arc<AtomicBool>,
}

impl<F: FnOnce(Duration)> DeadlineMonitor<F> {
pub fn new(expected: Duration, on_not_hit: F) -> Self {
pub fn new(expected: Duration, hit_any_deadline: Arc<AtomicBool>, on_not_hit: F) -> Self {
Self {
expected,
hit_any_deadline,
start: Instant::now(),
on_not_hit: Some(on_not_hit),
}
Expand All @@ -60,7 +65,7 @@ impl<F: FnOnce(Duration)> Drop for DeadlineMonitor<F> {
return;
};

if time_passed > self.expected {
if time_passed > self.expected && !self.hit_any_deadline.swap(true, Ordering::Relaxed) {
(self.on_not_hit.take().unwrap())(time_passed);
}
}
Expand Down Expand Up @@ -132,9 +137,14 @@ async fn get_tts(
tracing::debug!("Recieved request to TTS: {payload:?}");
}

let _guard = DeadlineMonitor::new(Duration::from_millis(5000), |took| {
tracing::warn!("get_tts took {} millis!", took.as_millis());
});
let hit_any_deadline = Arc::new(AtomicBool::new(false));
let _guard = DeadlineMonitor::new(
Duration::from_millis(5000),
hit_any_deadline.clone(),
|took| {
tracing::warn!("get_tts took {} millis!", took.as_millis());
},
);

let state = STATE.get().unwrap();
if let Some(auth_key) = state.auth_key.as_deref() {
Expand Down Expand Up @@ -169,9 +179,13 @@ async fn get_tts(
tracing::debug!("Recieved request to TTS: {cache_key}");

let redis_info = if let Some(redis_state) = &state.redis {
let _guard = DeadlineMonitor::new(Duration::from_millis(100), |took| {
tracing::warn!("Fetching from redis took {} millis!", took.as_millis());
});
let _guard = DeadlineMonitor::new(
Duration::from_millis(100),
hit_any_deadline.clone(),
|took| {
tracing::warn!("Fetching from redis took {} millis!", took.as_millis());
},
);

let cache_hash = sha2::Sha256::digest(&cache_key);

Expand Down Expand Up @@ -199,17 +213,23 @@ async fn get_tts(
return Err(Error::TranslationDisabled);
};

let _guard = DeadlineMonitor::new(Duration::from_millis(200), |took| {
tracing::warn!("Fetching translation took {} millis!", took.as_millis());
});
let _guard = DeadlineMonitor::new(
Duration::from_millis(200),
hit_any_deadline.clone(),
|took| {
tracing::warn!("Fetching translation took {} millis!", took.as_millis());
},
);

if let Some(translated) = translation::run(&state.reqwest, token, &text, &language).await? {
text = translated;
}
};

let (audio, content_type) = match mode {
TTSMode::gTTS => gtts::get_tts(&state.gtts, &text, &voice).await?,
TTSMode::gTTS => {
gtts::get_tts(&state.gtts, &text, &voice, hit_any_deadline.clone()).await?
}
TTSMode::eSpeak => {
espeak::get_tts(&text, &voice, speaking_rate.map_or(0, |r| r as u16)).await?
}
Expand Down Expand Up @@ -237,9 +257,13 @@ async fn get_tts(

tracing::debug!("Generated TTS from {cache_key}");
if let Some((mut redis_conn, key, cache_hash)) = redis_info {
let _guard = DeadlineMonitor::new(Duration::from_millis(100), |took| {
tracing::warn!("Caching audio result took {} millis!", took.as_millis());
});
let _guard = DeadlineMonitor::new(
Duration::from_millis(100),
hit_any_deadline.clone(),
|took| {
tracing::warn!("Caching audio result took {} millis!", took.as_millis());
},
);

if let Err(err) = redis_conn
.set::<_, _, ()>(&*cache_hash, key.encrypt(&audio))
Expand Down

0 comments on commit d9924e7

Please sign in to comment.