Skip to content

Commit

Permalink
Add API to get status of strict checking
Browse files Browse the repository at this point in the history
  • Loading branch information
jbaublitz committed Jan 25, 2025
1 parent 5c62bb5 commit 0e619c5
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/router/asynchronous.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,15 @@ impl NlRouter {
.map_err(RouterError::from)
}

/// Return [`true`] if strict checking is enabled for this socket.
/// Only supported by `NlFamily::Route` sockets.
/// Requires Linux >= 4.20.
pub fn get_strict_checking_enabled(&self) -> Result<bool, RouterError<u16, Buffer>> {
self.socket
.get_strict_checking_enabled()
.map_err(RouterError::from)
}

/// Get the PID for the current socket.
pub fn pid(&self) -> u32 {
self.socket.pid()
Expand Down
9 changes: 9 additions & 0 deletions src/router/synchronous.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,15 @@ impl NlRouter {
.map_err(RouterError::from)
}

/// Return [`true`] if strict checking is enabled for this socket.
/// Only supported by `NlFamily::Route` sockets.
/// Requires Linux >= 4.20.
pub fn get_strict_checking_enabled(&self) -> Result<bool, RouterError<u16, Buffer>> {
self.socket
.get_strict_checking_enabled()
.map_err(RouterError::from)
}

/// Get the PID for the current socket.
pub fn pid(&self) -> u32 {
self.socket.pid()
Expand Down
10 changes: 10 additions & 0 deletions src/socket/asynchronous.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,16 @@ impl NlSocketHandle {
.enable_strict_checking(enable)
.map_err(SocketError::from)
}

/// Return [`true`] if strict checking is enabled for this socket.
/// Only supported by `NlFamily::Route` sockets.
/// Requires Linux >= 4.20.
pub fn get_strict_checking_enabled(&self) -> Result<bool, SocketError> {
self.socket
.get_ref()
.get_strict_checking_enabled()
.map_err(SocketError::from)
}
}

impl AsRawFd for NlSocketHandle {
Expand Down
30 changes: 30 additions & 0 deletions src/socket/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,26 @@ impl NlSocket {
_ => Err(io::Error::last_os_error()),
}
}

/// Return [`true`] if strict checking is enabled for this socket.
/// Only supported by `NlFamily::Route` sockets.
/// Requires Linux >= 4.20.
pub fn get_strict_checking_enabled(&self) -> Result<bool, io::Error> {
let mut sock_len = size_of::<libc::c_int>() as libc::socklen_t;
let mut sock_val: MaybeUninit<libc::c_int> = MaybeUninit::uninit();
match unsafe {
libc::getsockopt(
self.fd,
libc::SOL_NETLINK,
libc::NETLINK_GET_STRICT_CHK,
&mut sock_val as *mut _ as *mut libc::c_void,
&mut sock_len as *mut _ as *mut libc::socklen_t,
)
} {
0 => Ok(unsafe { sock_val.assume_init() } != 0),
_ => Err(io::Error::last_os_error()),
}
}
}

#[cfg(feature = "sync")]
Expand Down Expand Up @@ -329,4 +349,14 @@ mod test {
let s = NlSocket::connect(NlFamily::Generic, Some(5555), Groups::empty()).unwrap();
assert_eq!(s.pid().unwrap(), 5555);
}

#[test]
fn real_strict_checking() {
setup();

let s = NlSocket::connect(NlFamily::Route, None, Groups::empty()).unwrap();
assert!(!s.get_strict_checking_enabled().unwrap());
s.enable_strict_checking(true).unwrap();
assert!(s.get_strict_checking_enabled().unwrap());
}
}
9 changes: 9 additions & 0 deletions src/socket/synchronous.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,15 @@ impl NlSocketHandle {
.enable_strict_checking(enable)
.map_err(SocketError::from)
}

/// Return [`true`] if strict checking is enabled for this socket.
/// Only supported by `NlFamily::Route` sockets.
/// Requires Linux >= 4.20.
pub fn get_strict_checking_enabled(&self) -> Result<bool, SocketError> {
self.socket
.get_strict_checking_enabled()
.map_err(SocketError::from)
}
}

impl AsRawFd for NlSocketHandle {
Expand Down

0 comments on commit 0e619c5

Please sign in to comment.