From bc6facfdefc74e7209f35f01204e26f57daa5920 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Wed, 14 Aug 2024 14:55:13 +0200 Subject: [PATCH] Fix race condition due to incorrect use of Notify --- lib/tests/common/wait_map.rs | 30 +++++++++++++++++++----------- net/src/stun.rs | 6 +++++- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/lib/tests/common/wait_map.rs b/lib/tests/common/wait_map.rs index f9d7dfc85..6d2fc5c28 100644 --- a/lib/tests/common/wait_map.rs +++ b/lib/tests/common/wait_map.rs @@ -52,19 +52,27 @@ where Q: ToOwned + ?Sized, { loop { - let notify = match self.inner.lock().unwrap().entry(key.to_owned()) { - Entry::Occupied(entry) => match entry.get() { - Slot::Occupied(value) => return value.clone(), - Slot::Waiting(notify) => notify.clone(), - }, - Entry::Vacant(entry) => { - let notify = Arc::new(Notify::new()); - entry.insert(Slot::Waiting(notify.clone())); - notify - } + // We need to call `Notify::notified` while the mutex is still locked but then await it + // only after it's been unlocked. This is so we don't miss any notifications. + let mut notify = None; + let notified = { + let mut inner = self.inner.lock().unwrap(); + let new_notify = match inner.entry(key.to_owned()) { + Entry::Occupied(entry) => match entry.get() { + Slot::Occupied(value) => return value.clone(), + Slot::Waiting(notify) => notify.clone(), + }, + Entry::Vacant(entry) => { + let notify = Arc::new(Notify::new()); + entry.insert(Slot::Waiting(notify.clone())); + notify + } + }; + + notify.get_or_insert(new_notify).notified() }; - notify.notified().await; + notified.await; } } diff --git a/net/src/stun.rs b/net/src/stun.rs index d5296a984..92021c2e0 100644 --- a/net/src/stun.rs +++ b/net/src/stun.rs @@ -191,6 +191,10 @@ impl StunClient { let result = time::timeout(TIMEOUT, async move { loop { + // Need to obtain the `notified` future before removing the response but await it + // after. This is so we don't miss any notifications. + let notified = self.responses_notify.notified(); + if let Some(response) = self.remove_response(transaction_id) { return Ok(response); } @@ -205,7 +209,7 @@ impl StunClient { self.insert_response(response); } }, - _ = self.responses_notify.notified() => (), + _ = notified => (), } } })