From 34234391f9c086bc605f1c61bc1013f51371b192 Mon Sep 17 00:00:00 2001 From: John Baublitz Date: Fri, 24 Jan 2025 15:21:29 -0500 Subject: [PATCH] Move internal representation of groups to vector of integers --- examples/genl_stream.rs | 4 +-- src/lib.rs | 2 +- src/router/asynchronous.rs | 2 +- src/router/synchronous.rs | 13 ++++---- src/socket/shared.rs | 2 +- src/utils.rs | 64 +++++++++++++++++++++++++------------- 6 files changed, 54 insertions(+), 33 deletions(-) diff --git a/examples/genl_stream.rs b/examples/genl_stream.rs index f9882c2c..1db95132 100644 --- a/examples/genl_stream.rs +++ b/examples/genl_stream.rs @@ -28,7 +28,7 @@ fn debug_stream() -> Result<(), Box> { let id = s .resolve_nl_mcast_group(&family_name, &mc_group_name) .await?; - s.add_mcast_membership(Groups::new_groups(&[id])?)?; + s.add_mcast_membership(Groups::new_groups(&[id]))?; while let Some(Ok(msg)) = multicast.next::>().await { println!("{msg:?}"); } @@ -53,7 +53,7 @@ fn debug_stream() -> Result<(), Box> { }; let (s, mc_recv) = NlRouter::connect(NlFamily::Generic, None, Groups::empty())?; let id = s.resolve_nl_mcast_group(&family_name, &mc_group_name)?; - s.add_mcast_membership(Groups::new_groups(&[id])?)?; + s.add_mcast_membership(Groups::new_groups(&[id]))?; for next in mc_recv { println!("{:?}", next?); } diff --git a/src/lib.rs b/src/lib.rs index eb279659..7516aa1b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -112,7 +112,7 @@ //! "my_family_name", //! "my_multicast_group_name", //! )?; -//! s.add_mcast_membership(Groups::new_groups(&[id])?)?; +//! s.add_mcast_membership(Groups::new_groups(&[id]))?; //! for next in multicast { //! // Do stuff here with parsed packets... //! diff --git a/src/router/asynchronous.rs b/src/router/asynchronous.rs index 39823a7c..46edcfb8 100644 --- a/src/router/asynchronous.rs +++ b/src/router/asynchronous.rs @@ -64,7 +64,7 @@ fn spawn_processing_thread(socket: Arc, senders: Senders) -> Pro Ok(m) => { let seq = *m.nl_seq(); let lock = senders.lock().await; - if group.as_bitmask() != 0 { + if !group.is_empty() { if multicast_sender.send(Ok(m)).await.is_err() { warn!("{}", RouterError::::ClosedChannel); } diff --git a/src/router/synchronous.rs b/src/router/synchronous.rs index 44b04302..df68cea5 100644 --- a/src/router/synchronous.rs +++ b/src/router/synchronous.rs @@ -68,7 +68,7 @@ fn spawn_processing_thread(socket: Arc, senders: Senders) -> Pro Ok(m) => { let seq = *m.nl_seq(); let lock = senders.lock(); - if group.as_bitmask() != 0 { + if !group.is_empty() { if multicast_sender.send(Ok(m)).is_err() { warn!("{}", RouterError::::ClosedChannel); } @@ -529,24 +529,25 @@ mod test { fn real_test_mcast_groups() { setup(); - let (sock, _) = NlRouter::connect(NlFamily::Generic, None, Groups::empty()).unwrap(); + let (sock, _multicast) = + NlRouter::connect(NlFamily::Generic, None, Groups::empty()).unwrap(); sock.enable_strict_checking(true).unwrap(); let notify_id_result = sock.resolve_nl_mcast_group("nlctrl", "notify"); let config_id_result = sock.resolve_nl_mcast_group("devlink", "config"); let ids = match (notify_id_result, config_id_result) { (Ok(ni), Ok(ci)) => { - sock.add_mcast_membership(Groups::new_groups(&[ni, ci]).unwrap()) + sock.add_mcast_membership(Groups::new_groups(&[ni, ci])) .unwrap(); vec![ni, ci] } (Ok(ni), Err(RouterError::Nlmsgerr(_))) => { - sock.add_mcast_membership(Groups::new_groups(&[ni]).unwrap()) + sock.add_mcast_membership(Groups::new_groups(&[ni])) .unwrap(); vec![ni] } (Err(RouterError::Nlmsgerr(_)), Ok(ci)) => { - sock.add_mcast_membership(Groups::new_groups(&[ci]).unwrap()) + sock.add_mcast_membership(Groups::new_groups(&[ci])) .unwrap(); vec![ci] } @@ -562,7 +563,7 @@ mod test { assert!(groups.is_set(*id as usize)); } - sock.drop_mcast_membership(Groups::new_groups(ids.as_slice()).unwrap()) + sock.drop_mcast_membership(Groups::new_groups(ids.as_slice())) .unwrap(); let groups = sock.list_mcast_membership().unwrap(); diff --git a/src/socket/shared.rs b/src/socket/shared.rs index 2189f97d..ed244939 100644 --- a/src/socket/shared.rs +++ b/src/socket/shared.rs @@ -88,7 +88,6 @@ impl NlSocket { let mut nladdr = unsafe { zeroed::() }; nladdr.nl_family = c_int::from(AddrFamily::Netlink) as u16; nladdr.nl_pid = pid.unwrap_or(0); - nladdr.nl_groups = groups.as_bitmask(); match unsafe { libc::bind( self.fd, @@ -99,6 +98,7 @@ impl NlSocket { i if i >= 0 => (), _ => return Err(io::Error::last_os_error()), }; + self.add_mcast_membership(groups)?; Ok(()) } diff --git a/src/utils.rs b/src/utils.rs index ddd5ab49..7005463a 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -137,61 +137,84 @@ fn mask_to_vec(mask: u32) -> Vec { /// Struct implementing handling of groups both as numerical values and as /// bitmasks. -pub struct Groups(u32); +pub struct Groups(Vec); impl Groups { /// Create an empty set of netlink multicast groups pub fn empty() -> Self { - Groups(0) + Groups(vec![]) } /// Create a new set of groups with a bitmask. Each bit represents a group. pub fn new_bitmask(mask: u32) -> Self { - Groups(mask) + Groups(mask_to_vec(mask)) } /// Add a new bitmask to the existing group set. Each bit represents a group. pub fn add_bitmask(&mut self, mask: u32) { - self.0 |= mask + for group in mask_to_vec(mask) { + if !self.0.contains(&group) { + self.0.push(group); + } + } } /// Remove a bitmask from the existing group set. Each bit represents a group /// and each bit set to 1 will be removed. pub fn remove_bitmask(&mut self, mask: u32) { - self.0 &= !mask + let remove_items = mask_to_vec(mask); + self.0 = self + .0 + .drain(..) + .filter(|g| !remove_items.contains(g)) + .collect::>(); } /// Create a new set of groups from a list of numerical groups values. This differs /// from the bitmask representation where the value 3 represents group 3 in this /// format as opposed to 0x4 in the bitmask format. - pub fn new_groups(groups: &[u32]) -> Result { - Ok(Groups(slice_to_mask(groups)?)) + pub fn new_groups(groups: &[u32]) -> Self { + let mut vec = groups.to_owned(); + vec.retain(|g| g != &0); + Groups(vec) } /// Add a list of numerical groups values to the set of groups. This differs /// from the bitmask representation where the value 3 represents group 3 in this /// format as opposed to 0x4 in the bitmask format. - pub fn add_groups(&mut self, groups: &[u32]) -> Result<(), MsgError> { - self.add_bitmask(slice_to_mask(groups)?); - Ok(()) + pub fn add_groups(&mut self, groups: &[u32]) { + for group in groups { + if *group != 0 && !self.0.contains(group) { + self.0.push(*group) + } + } } /// Remove a list of numerical groups values from the set of groups. This differs /// from the bitmask representation where the value 3 represents group 3 in this /// format as opposed to 0x4 in the bitmask format. - pub fn remove_groups(&mut self, groups: &[u32]) -> Result<(), MsgError> { - self.remove_bitmask(slice_to_mask(groups)?); - Ok(()) + pub fn remove_groups(&mut self, groups: &[u32]) { + self.0.retain(|g| !groups.contains(g)); } /// Return the set of groups as a bitmask. The representation of a bitmask is u32. - pub fn as_bitmask(&self) -> u32 { - self.0 + pub fn as_bitmask(&self) -> Result { + slice_to_mask(&self.0) } /// Return the set of groups as a vector of group values. pub fn as_groups(&self) -> Vec { - mask_to_vec(self.0) + self.0.clone() + } + + /// Return the set of groups as a vector of group values. + pub fn into_groups(self) -> Vec { + self.0 + } + + /// Returns true if no group is set. + pub fn is_empty(&self) -> bool { + self.0.is_empty() } } @@ -587,11 +610,8 @@ mod test { fn test_groups() { setup(); - Groups::new_groups(&[0, 0, 0, 0]).unwrap(); - - assert!(Groups::new_groups(&[100]).is_err()); - assert!(Groups::new_groups(&[31]).is_ok()); - assert!(Groups::new_groups(&[32]).is_ok()); - assert!(Groups::new_groups(&[33]).is_err()); + assert_eq!(Groups::new_groups(&[0, 0, 0, 0]).as_bitmask().unwrap(), 0); + let groups = Groups::new_groups(&[0, 0, 0, 0]).as_groups(); + assert!(groups.is_empty()); } }