diff --git a/packages/media_core/src/cluster/room/media_track/publisher.rs b/packages/media_core/src/cluster/room/media_track/publisher.rs index 4bd5733d..134828f6 100644 --- a/packages/media_core/src/cluster/room/media_track/publisher.rs +++ b/packages/media_core/src/cluster/room/media_track/publisher.rs @@ -3,7 +3,7 @@ //! use std::{ - collections::{HashMap, VecDeque}, + collections::{HashMap, HashSet, VecDeque}, fmt::Debug, hash::Hash, }; @@ -41,7 +41,7 @@ impl TryFrom for FeedbackKind { pub struct RoomChannelPublisher { room: ClusterRoomHash, tracks: HashMap<(Endpoint, RemoteTrackId), (PeerId, TrackName, ChannelId)>, - tracks_source: HashMap, + tracks_source: HashMap>, // We allow multi sources here for avoiding crash queue: VecDeque>, } @@ -61,21 +61,23 @@ impl RoomChannelPublisher { pub fn on_track_feedback(&mut self, channel: ChannelId, fb: Feedback) { let fb = return_if_err!(FeedbackKind::try_from(fb)); - let (endpoint, track_id) = return_if_none!(self.tracks_source.get(&channel)); - match fb { - FeedbackKind::Bitrate { min, max } => { - log::debug!("[ClusterRoom {}/Publishers] channel {channel} limit bitrate [{min},{max}]", self.room); - self.queue.push_back(Output::Endpoint( - vec![*endpoint], - ClusterEndpointEvent::RemoteTrack(*track_id, ClusterRemoteTrackEvent::LimitBitrate { min, max }), - )); - } - FeedbackKind::KeyFrameRequest => { - log::debug!("[ClusterRoom {}/Publishers] channel {channel} request key_frame", self.room); - self.queue.push_back(Output::Endpoint( - vec![*endpoint], - ClusterEndpointEvent::RemoteTrack(*track_id, ClusterRemoteTrackEvent::RequestKeyFrame), - )); + let sources = return_if_none!(self.tracks_source.get(&channel)); + for (endpoint, track_id) in sources { + match fb { + FeedbackKind::Bitrate { min, max } => { + log::debug!("[ClusterRoom {}/Publishers] channel {channel} limit bitrate [{min},{max}]", self.room); + self.queue.push_back(Output::Endpoint( + vec![*endpoint], + ClusterEndpointEvent::RemoteTrack(*track_id, ClusterRemoteTrackEvent::LimitBitrate { min, max }), + )); + } + FeedbackKind::KeyFrameRequest => { + log::debug!("[ClusterRoom {}/Publishers] channel {channel} request key_frame", self.room); + self.queue.push_back(Output::Endpoint( + vec![*endpoint], + ClusterEndpointEvent::RemoteTrack(*track_id, ClusterRemoteTrackEvent::RequestKeyFrame), + )); + } } } } @@ -84,9 +86,11 @@ impl RoomChannelPublisher { log::info!("[ClusterRoom {}/Publishers] peer ({peer} started track ({name})", self.room); let channel_id = id_generator::gen_channel_id(self.room, &peer, &name); self.tracks.insert((endpoint, track), (peer.clone(), name.clone(), channel_id)); - self.tracks_source.insert(channel_id, (endpoint, track)); - - self.queue.push_back(Output::Pubsub(pubsub::Control(channel_id, ChannelControl::PubStart))); + let sources = self.tracks_source.entry(channel_id).or_default(); + if sources.is_empty() { + self.queue.push_back(Output::Pubsub(pubsub::Control(channel_id, ChannelControl::PubStart))); + } + sources.insert((endpoint, track)); } pub fn on_track_data(&mut self, endpoint: Endpoint, track: RemoteTrackId, media: MediaPacket) { @@ -104,9 +108,15 @@ impl RoomChannelPublisher { pub fn on_track_unpublish(&mut self, endpoint: Endpoint, track: RemoteTrackId) { let (peer, name, channel_id) = return_if_none!(self.tracks.remove(&(endpoint, track))); - self.tracks_source.remove(&channel_id).expect("Should have track_source"); + let sources = self.tracks_source.get_mut(&channel_id).expect("Should have track_source"); + let removed = sources.remove(&(endpoint, track)); + assert!(removed, "Should remove source child on unpublish"); + if sources.is_empty() { + self.tracks_source.remove(&channel_id).expect("Should remove source channel on unpublish"); + self.queue.push_back(Output::Pubsub(pubsub::Control(channel_id, ChannelControl::PubStop))); + } + log::info!("[ClusterRoom {}/Publishers] peer ({peer} stopped track {name})", self.room); - self.queue.push_back(Output::Pubsub(pubsub::Control(channel_id, ChannelControl::PubStop))); if self.tracks.is_empty() { self.queue.push_back(Output::OnResourceEmpty); } @@ -132,7 +142,10 @@ impl Drop for RoomChannelPublisher { #[cfg(test)] mod tests { use atm0s_sdn::features::pubsub::{ChannelControl, Control, Feedback}; - use media_server_protocol::media::{MediaMeta, MediaPacket}; + use media_server_protocol::{ + endpoint::{PeerId, TrackName}, + media::{MediaMeta, MediaPacket}, + }; use sans_io_runtime::TaskSwitcherChild; use crate::{ @@ -220,4 +233,29 @@ mod tests { assert_eq!(publisher.pop_output(()), Some(Output::OnResourceEmpty)); assert_eq!(publisher.pop_output(()), None); } + + #[test] + fn two_sessions_same_room_peer_should_not_crash() { + let room = 1.into(); + let mut publisher = RoomChannelPublisher::::new(room); + + let endpoint1 = 1; + let endpoint2 = 2; + let track = RemoteTrackId(3); + let peer: PeerId = "peer1".to_string().into(); + let name: TrackName = "audio_main".to_string().into(); + + publisher.on_track_publish(endpoint1, track, peer.clone(), name.clone()); + publisher.on_track_publish(endpoint2, track, peer, name); + + assert!(publisher.pop_output(()).is_some()); // PubStart + assert!(publisher.pop_output(()).is_none()); + + publisher.on_track_unpublish(endpoint1, track); + publisher.on_track_unpublish(endpoint2, track); + + assert!(publisher.pop_output(()).is_some()); // PubStop + assert!(publisher.pop_output(()).is_some()); // OnResourceEmpty + assert!(publisher.pop_output(()).is_none()); + } }