Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: server crash if two sessions leaved with same room peer #376

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 61 additions & 23 deletions packages/media_core/src/cluster/room/media_track/publisher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
//!

use std::{
collections::{HashMap, VecDeque},
collections::{HashMap, HashSet, VecDeque},
fmt::Debug,
hash::Hash,
};
Expand Down Expand Up @@ -41,7 +41,7 @@ impl TryFrom<Feedback> for FeedbackKind {
pub struct RoomChannelPublisher<Endpoint> {
room: ClusterRoomHash,
tracks: HashMap<(Endpoint, RemoteTrackId), (PeerId, TrackName, ChannelId)>,
tracks_source: HashMap<ChannelId, (Endpoint, RemoteTrackId)>,
tracks_source: HashMap<ChannelId, HashSet<(Endpoint, RemoteTrackId)>>, // We allow multi sources here for avoiding crash
queue: VecDeque<Output<Endpoint>>,
}

Expand All @@ -61,21 +61,23 @@ impl<Endpoint: Debug + Hash + Eq + Copy> RoomChannelPublisher<Endpoint> {

pub fn on_track_feedback(&mut self, channel: ChannelId, fb: Feedback) {
let fb = return_if_err!(FeedbackKind::try_from(fb));
let (endpoint, track_id) = return_if_none!(self.tracks_source.get(&channel));
match fb {
FeedbackKind::Bitrate { min, max } => {
log::debug!("[ClusterRoom {}/Publishers] channel {channel} limit bitrate [{min},{max}]", self.room);
self.queue.push_back(Output::Endpoint(
vec![*endpoint],
ClusterEndpointEvent::RemoteTrack(*track_id, ClusterRemoteTrackEvent::LimitBitrate { min, max }),
));
}
FeedbackKind::KeyFrameRequest => {
log::debug!("[ClusterRoom {}/Publishers] channel {channel} request key_frame", self.room);
self.queue.push_back(Output::Endpoint(
vec![*endpoint],
ClusterEndpointEvent::RemoteTrack(*track_id, ClusterRemoteTrackEvent::RequestKeyFrame),
));
let sources = return_if_none!(self.tracks_source.get(&channel));
for (endpoint, track_id) in sources {
match fb {
FeedbackKind::Bitrate { min, max } => {
log::debug!("[ClusterRoom {}/Publishers] channel {channel} limit bitrate [{min},{max}]", self.room);
self.queue.push_back(Output::Endpoint(
vec![*endpoint],
ClusterEndpointEvent::RemoteTrack(*track_id, ClusterRemoteTrackEvent::LimitBitrate { min, max }),
));
}
FeedbackKind::KeyFrameRequest => {
log::debug!("[ClusterRoom {}/Publishers] channel {channel} request key_frame", self.room);
self.queue.push_back(Output::Endpoint(
vec![*endpoint],
ClusterEndpointEvent::RemoteTrack(*track_id, ClusterRemoteTrackEvent::RequestKeyFrame),
));
}
}
}
}
Expand All @@ -84,9 +86,11 @@ impl<Endpoint: Debug + Hash + Eq + Copy> RoomChannelPublisher<Endpoint> {
log::info!("[ClusterRoom {}/Publishers] peer ({peer} started track ({name})", self.room);
let channel_id = id_generator::gen_channel_id(self.room, &peer, &name);
self.tracks.insert((endpoint, track), (peer.clone(), name.clone(), channel_id));
self.tracks_source.insert(channel_id, (endpoint, track));

self.queue.push_back(Output::Pubsub(pubsub::Control(channel_id, ChannelControl::PubStart)));
let sources = self.tracks_source.entry(channel_id).or_default();
if sources.is_empty() {
self.queue.push_back(Output::Pubsub(pubsub::Control(channel_id, ChannelControl::PubStart)));
}
sources.insert((endpoint, track));
}

pub fn on_track_data(&mut self, endpoint: Endpoint, track: RemoteTrackId, media: MediaPacket) {
Expand All @@ -104,9 +108,15 @@ impl<Endpoint: Debug + Hash + Eq + Copy> RoomChannelPublisher<Endpoint> {

pub fn on_track_unpublish(&mut self, endpoint: Endpoint, track: RemoteTrackId) {
let (peer, name, channel_id) = return_if_none!(self.tracks.remove(&(endpoint, track)));
self.tracks_source.remove(&channel_id).expect("Should have track_source");
let sources = self.tracks_source.get_mut(&channel_id).expect("Should have track_source");
let removed = sources.remove(&(endpoint, track));
assert!(removed, "Should remove source child on unpublish");
if sources.is_empty() {
self.tracks_source.remove(&channel_id).expect("Should remove source channel on unpublish");
self.queue.push_back(Output::Pubsub(pubsub::Control(channel_id, ChannelControl::PubStop)));
}

log::info!("[ClusterRoom {}/Publishers] peer ({peer} stopped track {name})", self.room);
self.queue.push_back(Output::Pubsub(pubsub::Control(channel_id, ChannelControl::PubStop)));
if self.tracks.is_empty() {
self.queue.push_back(Output::OnResourceEmpty);
}
Expand All @@ -132,7 +142,10 @@ impl<Endpoint> Drop for RoomChannelPublisher<Endpoint> {
#[cfg(test)]
mod tests {
use atm0s_sdn::features::pubsub::{ChannelControl, Control, Feedback};
use media_server_protocol::media::{MediaMeta, MediaPacket};
use media_server_protocol::{
endpoint::{PeerId, TrackName},
media::{MediaMeta, MediaPacket},
};
use sans_io_runtime::TaskSwitcherChild;

use crate::{
Expand Down Expand Up @@ -220,4 +233,29 @@ mod tests {
assert_eq!(publisher.pop_output(()), Some(Output::OnResourceEmpty));
assert_eq!(publisher.pop_output(()), None);
}

#[test]
fn two_sessions_same_room_peer_should_not_crash() {
let room = 1.into();
let mut publisher = RoomChannelPublisher::<u8>::new(room);

let endpoint1 = 1;
let endpoint2 = 2;
let track = RemoteTrackId(3);
let peer: PeerId = "peer1".to_string().into();
let name: TrackName = "audio_main".to_string().into();

publisher.on_track_publish(endpoint1, track, peer.clone(), name.clone());
publisher.on_track_publish(endpoint2, track, peer, name);

assert!(publisher.pop_output(()).is_some()); // PubStart
assert!(publisher.pop_output(()).is_none());

publisher.on_track_unpublish(endpoint1, track);
publisher.on_track_unpublish(endpoint2, track);

assert!(publisher.pop_output(()).is_some()); // PubStop
assert!(publisher.pop_output(()).is_some()); // OnResourceEmpty
assert!(publisher.pop_output(()).is_none());
}
}
Loading