Skip to content

Commit

Permalink
Move internal representation of groups to vector of integers
Browse files Browse the repository at this point in the history
  • Loading branch information
jbaublitz committed Jan 24, 2025
1 parent 7c43f29 commit 3423439
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 33 deletions.
4 changes: 2 additions & 2 deletions examples/genl_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ fn debug_stream() -> Result<(), Box<dyn Error>> {
let id = s
.resolve_nl_mcast_group(&family_name, &mc_group_name)
.await?;
s.add_mcast_membership(Groups::new_groups(&[id])?)?;
s.add_mcast_membership(Groups::new_groups(&[id]))?;
while let Some(Ok(msg)) = multicast.next::<u16, Genlmsghdr<u8, u16>>().await {
println!("{msg:?}");
}
Expand All @@ -53,7 +53,7 @@ fn debug_stream() -> Result<(), Box<dyn Error>> {
};
let (s, mc_recv) = NlRouter::connect(NlFamily::Generic, None, Groups::empty())?;
let id = s.resolve_nl_mcast_group(&family_name, &mc_group_name)?;
s.add_mcast_membership(Groups::new_groups(&[id])?)?;
s.add_mcast_membership(Groups::new_groups(&[id]))?;
for next in mc_recv {
println!("{:?}", next?);
}
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@
//! "my_family_name",
//! "my_multicast_group_name",
//! )?;
//! s.add_mcast_membership(Groups::new_groups(&[id])?)?;
//! s.add_mcast_membership(Groups::new_groups(&[id]))?;
//! for next in multicast {
//! // Do stuff here with parsed packets...
//!
Expand Down
2 changes: 1 addition & 1 deletion src/router/asynchronous.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ fn spawn_processing_thread(socket: Arc<NlSocketHandle>, senders: Senders) -> Pro
Ok(m) => {
let seq = *m.nl_seq();
let lock = senders.lock().await;
if group.as_bitmask() != 0 {
if !group.is_empty() {
if multicast_sender.send(Ok(m)).await.is_err() {
warn!("{}", RouterError::<u16, Buffer>::ClosedChannel);
}
Expand Down
13 changes: 7 additions & 6 deletions src/router/synchronous.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ fn spawn_processing_thread(socket: Arc<NlSocketHandle>, senders: Senders) -> Pro
Ok(m) => {
let seq = *m.nl_seq();
let lock = senders.lock();
if group.as_bitmask() != 0 {
if !group.is_empty() {
if multicast_sender.send(Ok(m)).is_err() {
warn!("{}", RouterError::<u16, Buffer>::ClosedChannel);
}
Expand Down Expand Up @@ -529,24 +529,25 @@ mod test {
fn real_test_mcast_groups() {
setup();

let (sock, _) = NlRouter::connect(NlFamily::Generic, None, Groups::empty()).unwrap();
let (sock, _multicast) =
NlRouter::connect(NlFamily::Generic, None, Groups::empty()).unwrap();
sock.enable_strict_checking(true).unwrap();
let notify_id_result = sock.resolve_nl_mcast_group("nlctrl", "notify");
let config_id_result = sock.resolve_nl_mcast_group("devlink", "config");

let ids = match (notify_id_result, config_id_result) {
(Ok(ni), Ok(ci)) => {
sock.add_mcast_membership(Groups::new_groups(&[ni, ci]).unwrap())
sock.add_mcast_membership(Groups::new_groups(&[ni, ci]))
.unwrap();
vec![ni, ci]
}
(Ok(ni), Err(RouterError::Nlmsgerr(_))) => {
sock.add_mcast_membership(Groups::new_groups(&[ni]).unwrap())
sock.add_mcast_membership(Groups::new_groups(&[ni]))
.unwrap();
vec![ni]
}
(Err(RouterError::Nlmsgerr(_)), Ok(ci)) => {
sock.add_mcast_membership(Groups::new_groups(&[ci]).unwrap())
sock.add_mcast_membership(Groups::new_groups(&[ci]))
.unwrap();
vec![ci]
}
Expand All @@ -562,7 +563,7 @@ mod test {
assert!(groups.is_set(*id as usize));
}

sock.drop_mcast_membership(Groups::new_groups(ids.as_slice()).unwrap())
sock.drop_mcast_membership(Groups::new_groups(ids.as_slice()))
.unwrap();
let groups = sock.list_mcast_membership().unwrap();

Expand Down
2 changes: 1 addition & 1 deletion src/socket/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ impl NlSocket {
let mut nladdr = unsafe { zeroed::<libc::sockaddr_nl>() };
nladdr.nl_family = c_int::from(AddrFamily::Netlink) as u16;
nladdr.nl_pid = pid.unwrap_or(0);
nladdr.nl_groups = groups.as_bitmask();
match unsafe {
libc::bind(
self.fd,
Expand All @@ -99,6 +98,7 @@ impl NlSocket {
i if i >= 0 => (),
_ => return Err(io::Error::last_os_error()),
};
self.add_mcast_membership(groups)?;
Ok(())
}

Expand Down
64 changes: 42 additions & 22 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,61 +137,84 @@ fn mask_to_vec(mask: u32) -> Vec<u32> {

/// Struct implementing handling of groups both as numerical values and as
/// bitmasks.
pub struct Groups(u32);
pub struct Groups(Vec<u32>);

impl Groups {
/// Create an empty set of netlink multicast groups
pub fn empty() -> Self {
Groups(0)
Groups(vec![])
}

/// Create a new set of groups with a bitmask. Each bit represents a group.
pub fn new_bitmask(mask: u32) -> Self {
Groups(mask)
Groups(mask_to_vec(mask))
}

/// Add a new bitmask to the existing group set. Each bit represents a group.
pub fn add_bitmask(&mut self, mask: u32) {
self.0 |= mask
for group in mask_to_vec(mask) {
if !self.0.contains(&group) {
self.0.push(group);
}
}
}

/// Remove a bitmask from the existing group set. Each bit represents a group
/// and each bit set to 1 will be removed.
pub fn remove_bitmask(&mut self, mask: u32) {
self.0 &= !mask
let remove_items = mask_to_vec(mask);
self.0 = self
.0
.drain(..)
.filter(|g| !remove_items.contains(g))
.collect::<Vec<_>>();
}

/// Create a new set of groups from a list of numerical groups values. This differs
/// from the bitmask representation where the value 3 represents group 3 in this
/// format as opposed to 0x4 in the bitmask format.
pub fn new_groups(groups: &[u32]) -> Result<Self, MsgError> {
Ok(Groups(slice_to_mask(groups)?))
pub fn new_groups(groups: &[u32]) -> Self {
let mut vec = groups.to_owned();
vec.retain(|g| g != &0);
Groups(vec)
}

/// Add a list of numerical groups values to the set of groups. This differs
/// from the bitmask representation where the value 3 represents group 3 in this
/// format as opposed to 0x4 in the bitmask format.
pub fn add_groups(&mut self, groups: &[u32]) -> Result<(), MsgError> {
self.add_bitmask(slice_to_mask(groups)?);
Ok(())
pub fn add_groups(&mut self, groups: &[u32]) {
for group in groups {
if *group != 0 && !self.0.contains(group) {
self.0.push(*group)
}
}
}

/// Remove a list of numerical groups values from the set of groups. This differs
/// from the bitmask representation where the value 3 represents group 3 in this
/// format as opposed to 0x4 in the bitmask format.
pub fn remove_groups(&mut self, groups: &[u32]) -> Result<(), MsgError> {
self.remove_bitmask(slice_to_mask(groups)?);
Ok(())
pub fn remove_groups(&mut self, groups: &[u32]) {
self.0.retain(|g| !groups.contains(g));
}

/// Return the set of groups as a bitmask. The representation of a bitmask is u32.
pub fn as_bitmask(&self) -> u32 {
self.0
pub fn as_bitmask(&self) -> Result<u32, MsgError> {
slice_to_mask(&self.0)
}

/// Return the set of groups as a vector of group values.
pub fn as_groups(&self) -> Vec<u32> {
mask_to_vec(self.0)
self.0.clone()
}

/// Return the set of groups as a vector of group values.
pub fn into_groups(self) -> Vec<u32> {
self.0
}

/// Returns true if no group is set.
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
}

Expand Down Expand Up @@ -587,11 +610,8 @@ mod test {
fn test_groups() {
setup();

Groups::new_groups(&[0, 0, 0, 0]).unwrap();

assert!(Groups::new_groups(&[100]).is_err());
assert!(Groups::new_groups(&[31]).is_ok());
assert!(Groups::new_groups(&[32]).is_ok());
assert!(Groups::new_groups(&[33]).is_err());
assert_eq!(Groups::new_groups(&[0, 0, 0, 0]).as_bitmask().unwrap(), 0);
let groups = Groups::new_groups(&[0, 0, 0, 0]).as_groups();
assert!(groups.is_empty());
}
}

0 comments on commit 3423439

Please sign in to comment.