Skip to content

Commit

Permalink
Fix broken RecvMsgOut parsing (tokio-rs#257)
Browse files Browse the repository at this point in the history
* Fix broken RecvMsgOut parsing

Signed-off-by: Alex Saveau <[email protected]>

* Add payload truncation test

Signed-off-by: Alex Saveau <[email protected]>

---------

Signed-off-by: Alex Saveau <[email protected]>
  • Loading branch information
SUPERCILEX authored Jan 31, 2024
1 parent cc8060a commit 297f02b
Show file tree
Hide file tree
Showing 6 changed files with 208 additions and 36 deletions.
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions io-uring-test/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ anyhow = "1"
tempfile = "3"
once_cell = "1"
socket2 = "0.5"
semver = "1.0.21"

[features]
direct-syscall = [ "io-uring/direct-syscall" ]
Expand Down
20 changes: 20 additions & 0 deletions io-uring-test/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,20 @@ mod tests;

use io_uring::{cqueue, squeue, IoUring, Probe};
use std::cell::Cell;
use std::ffi::CStr;
use std::mem;

pub struct Test {
probe: Probe,
target: Option<String>,
count: Cell<usize>,
kernel_version: semver::Version,
}

impl Test {
fn check_kernel_version(&self, min_version: &str) -> bool {
self.kernel_version >= semver::Version::parse(min_version).unwrap()
}
}

fn main() -> anyhow::Result<()> {
Expand Down Expand Up @@ -63,6 +72,16 @@ fn test<S: squeue::EntryMarker, C: cqueue::EntryMarker>(
probe,
target: std::env::args().nth(1),
count: Cell::new(0),
kernel_version: {
let mut uname: libc::utsname = unsafe { mem::zeroed() };
unsafe {
assert!(libc::uname(&mut uname) >= 0);
}

let version = unsafe { CStr::from_ptr(uname.release.as_ptr()) };
let version = version.to_str().unwrap();
semver::Version::parse(version).unwrap()
},
};

tests::queue::test_nop(&mut ring, &test)?;
Expand Down Expand Up @@ -132,6 +151,7 @@ fn test<S: squeue::EntryMarker, C: cqueue::EntryMarker>(
tests::net::test_shutdown(&mut ring, &test)?;
tests::net::test_socket(&mut ring, &test)?;
tests::net::test_udp_recvmsg_multishot(&mut ring, &test)?;
tests::net::test_udp_recvmsg_multishot_trunc(&mut ring, &test)?;
tests::net::test_udp_sendzc_with_dest(&mut ring, &test)?;

// queue
Expand Down
136 changes: 136 additions & 0 deletions io-uring-test/src/tests/net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1436,6 +1436,142 @@ pub fn test_udp_recvmsg_multishot<S: squeue::EntryMarker, C: cqueue::EntryMarker

Ok(())
}
pub fn test_udp_recvmsg_multishot_trunc<S: squeue::EntryMarker, C: cqueue::EntryMarker>(
ring: &mut IoUring<S, C>,
test: &Test,
) -> anyhow::Result<()> {
require!(
test;
test.probe.is_supported(opcode::RecvMsgMulti::CODE);
test.probe.is_supported(opcode::ProvideBuffers::CODE);
test.probe.is_supported(opcode::SendMsg::CODE);
test.check_kernel_version("6.6.0" /* 6.2 is totally broken and returns nonsense upon truncation */);
);

println!("test udp_recvmsg_multishot_trunc");

let server_socket: socket2::Socket = std::net::UdpSocket::bind("127.0.0.1:0").unwrap().into();
let server_addr = server_socket.local_addr()?;

const BUF_GROUP: u16 = 33;
const DATA: &[u8] = b"testfooo for me";
let mut buf1 = [0u8; 20]; // 20 = size_of::<io_uring_recvmsg_out>() + msghdr.msg_namelen
let mut buf2 = [0u8; 20 + DATA.len()];
let mut buf3 = [0u8; 20 + DATA.len()];
let mut buffers = [
buf1.as_mut_slice(),
buf2.as_mut_slice(),
buf3.as_mut_slice(),
];

for (index, buf) in buffers.iter_mut().enumerate() {
let provide_bufs_e = io_uring::opcode::ProvideBuffers::new(
(**buf).as_mut_ptr(),
buf.len() as i32,
1,
BUF_GROUP,
index as u16,
)
.build()
.user_data(11)
.into();
unsafe { ring.submission().push(&provide_bufs_e)? };
ring.submitter().submit_and_wait(1)?;
let cqes: Vec<io_uring::cqueue::Entry> = ring.completion().map(Into::into).collect();
assert_eq!(cqes.len(), 1);
assert_eq!(cqes[0].user_data(), 11);
assert_eq!(cqes[0].result(), 0);
assert_eq!(cqes[0].flags(), 0);
}

// This structure is actually only used for input arguments to the kernel
// (and only name length and control length are actually relevant).
let mut msghdr: libc::msghdr = unsafe { std::mem::zeroed() };
msghdr.msg_namelen = 4;

let recvmsg_e = opcode::RecvMsgMulti::new(
Fd(server_socket.as_raw_fd()),
&msghdr as *const _,
BUF_GROUP,
)
.flags(libc::MSG_TRUNC as u32)
.build()
.user_data(77)
.into();
unsafe { ring.submission().push(&recvmsg_e)? };
ring.submitter().submit().unwrap();

let client_socket: socket2::Socket = std::net::UdpSocket::bind("127.0.0.1:0").unwrap().into();

let data = [io::IoSlice::new(DATA)];
let mut msghdr1: libc::msghdr = unsafe { mem::zeroed() };
msghdr1.msg_name = server_addr.as_ptr() as *const _ as *mut _;
msghdr1.msg_namelen = server_addr.len();
msghdr1.msg_iov = data.as_ptr() as *const _ as *mut _;
msghdr1.msg_iovlen = 1;

let send_msgs = (0..2)
.map(|_| {
opcode::SendMsg::new(Fd(client_socket.as_raw_fd()), &msghdr1 as *const _)
.build()
.user_data(55)
.into()
})
.collect::<Vec<_>>();
unsafe { ring.submission().push_multiple(&send_msgs)? };
ring.submitter().submit().unwrap();

ring.submitter().submit_and_wait(4).unwrap();
let cqes: Vec<io_uring::cqueue::Entry> = ring.completion().map(Into::into).collect();
assert_eq!(cqes.len(), 4);
let mut i = 0;
for cqe in cqes {
let is_more = io_uring::cqueue::more(cqe.flags());
match cqe.user_data() {
// send notifications
55 => {
assert!(cqe.result() > 0);
assert!(!is_more);
}
// RecvMsgMulti
77 => {
assert!(cqe.result() > 0);
assert!(is_more);
let buf_id = io_uring::cqueue::buffer_select(cqe.flags()).unwrap();
let tmp_buf = &buffers[buf_id as usize];
let msg = types::RecvMsgOut::parse(tmp_buf, &msghdr);

match i {
0 => {
let msg = msg.unwrap();
assert!(msg.is_payload_truncated());
assert!(msg.is_name_data_truncated());
assert_eq!(DATA.len(), msg.incoming_payload_len() as usize);
assert!(msg.payload_data().is_empty());
assert!(4 < msg.incoming_name_len());
assert_eq!(4, msg.name_data().len());
}
1 => {
let msg = msg.unwrap();
assert!(!msg.is_payload_truncated());
assert!(msg.is_name_data_truncated());
assert_eq!(DATA.len(), msg.incoming_payload_len() as usize);
assert_eq!(DATA, msg.payload_data());
assert!(4 < msg.incoming_name_len());
assert_eq!(4, msg.name_data().len());
}
_ => unreachable!(),
}
i += 1;
}
_ => {
unreachable!()
}
}
}

Ok(())
}
pub fn test_udp_sendzc_with_dest<S: squeue::EntryMarker, C: cqueue::EntryMarker>(
ring: &mut IoUring<S, C>,
test: &Test,
Expand Down
5 changes: 3 additions & 2 deletions io-uring-test/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ macro_rules! require {
$test:expr;
$( $cond:expr ; )*
) => {
let test = $test;
let mut cond = true;

if let Some(target) = $test.target.as_ref() {
if let Some(target) = test.target.as_ref() {
cond &= function_name!().contains(target);
}

Expand All @@ -20,7 +21,7 @@ macro_rules! require {
return Ok(());
}

$test.count.set($test.count.get() + 1);
test.count.set(test.count.get() + 1);
}
}

Expand Down
75 changes: 41 additions & 34 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ pub(crate) mod sealed {
use crate::sys;
use crate::util::{cast_ptr, unwrap_nonzero, unwrap_u32};
use bitflags::bitflags;
use std::convert::TryFrom;
use std::marker::PhantomData;
use std::num::NonZeroU32;
use std::os::unix::io::RawFd;
Expand Down Expand Up @@ -377,10 +378,7 @@ pub struct RecvMsgOut<'buf> {
/// If it is smaller, it gets 0-padded to fill the whole field. In either case,
/// this fixed amount of space is reserved in the result buffer.
msghdr_name_len: usize,
/// The fixed length of the control field, in bytes.
///
/// This follows the same semantics as the field above, but for control data.
msghdr_control_len: usize,

name_data: &'buf [u8],
control_data: &'buf [u8],
payload_data: &'buf [u8],
Expand All @@ -396,7 +394,15 @@ impl<'buf> RecvMsgOut<'buf> {
/// (only `msg_namelen` and `msg_controllen` fields are relevant).
#[allow(clippy::result_unit_err)]
pub fn parse(buffer: &'buf [u8], msghdr: &libc::msghdr) -> Result<Self, ()> {
if buffer.len() < std::mem::size_of::<sys::io_uring_recvmsg_out>() {
let msghdr_name_len = usize::try_from(msghdr.msg_namelen).unwrap();
let msghdr_control_len = usize::try_from(msghdr.msg_controllen).unwrap();

if Self::DATA_START
.checked_add(msghdr_name_len)
.and_then(|acc| acc.checked_add(msghdr_control_len))
.map(|header_len| buffer.len() < header_len)
.unwrap_or(true)
{
return Err(());
}
// SAFETY: buffer (minimum) length is checked here above.
Expand All @@ -407,45 +413,36 @@ impl<'buf> RecvMsgOut<'buf> {
.read_unaligned()
};

let msghdr_name_len = msghdr.msg_namelen as _;
let msghdr_control_len = msghdr.msg_controllen as _;

// Check total length upfront, so that further logic here
// below can safely use unchecked/saturating math.
let length_overflow = Some(Self::DATA_START)
.and_then(|acc| acc.checked_add(msghdr_name_len))
.and_then(|acc| acc.checked_add(msghdr_control_len))
.and_then(|acc| acc.checked_add(header.payloadlen as usize))
.map(|total_len| total_len > buffer.len())
.unwrap_or(true);
if length_overflow {
return Err(());
}

// min is used because the header may indicate the true size of the data
// while what we received was truncated.
let (name_data, control_start) = {
let name_start = Self::DATA_START;
let name_size = usize::min(header.namelen as usize, msghdr_name_len);
let name_data_end = name_start.saturating_add(name_size);
let name_data = &buffer[name_start..name_data_end];
let name_field_end = name_start.saturating_add(msghdr_name_len);
(name_data, name_field_end)
let name_data_end =
name_start + usize::min(usize::try_from(header.namelen).unwrap(), msghdr_name_len);
let name_field_end = name_start + msghdr_name_len;
(&buffer[name_start..name_data_end], name_field_end)
};
let (control_data, payload_start) = {
let control_size = usize::min(header.controllen as usize, msghdr_control_len);
let control_data_end = control_start.saturating_add(control_size);
let control_data = &buffer[control_start..control_data_end];
let control_field_end = control_start.saturating_add(msghdr_control_len);
(control_data, control_field_end)
let control_data_end = control_start
+ usize::min(
usize::try_from(header.controllen).unwrap(),
msghdr_control_len,
);
let control_field_end = control_start + msghdr_control_len;
(&buffer[control_start..control_data_end], control_field_end)
};
let payload_data = {
let payload_data_end = payload_start.saturating_add(header.payloadlen as usize);
let payload_data_end = payload_start
+ usize::min(
usize::try_from(header.payloadlen).unwrap(),
buffer.len() - payload_start,
);
&buffer[payload_start..payload_data_end]
};

Ok(Self {
header,
msghdr_name_len,
msghdr_control_len,
name_data,
control_data,
payload_data,
Expand Down Expand Up @@ -490,7 +487,7 @@ impl<'buf> RecvMsgOut<'buf> {
/// When `true`, data returned by `control_data()` is truncated and
/// incomplete.
pub fn is_control_data_truncated(&self) -> bool {
self.header.controllen as usize > self.msghdr_control_len
(self.header.flags & u32::try_from(libc::MSG_CTRUNC).unwrap()) != 0
}

/// Message control data, with the same semantics as `msghdr.msg_control`.
Expand All @@ -503,14 +500,24 @@ impl<'buf> RecvMsgOut<'buf> {
/// When `true`, data returned by `payload_data()` is truncated and
/// incomplete.
pub fn is_payload_truncated(&self) -> bool {
self.header.flags & (libc::MSG_TRUNC as u32) != 0
(self.header.flags & u32::try_from(libc::MSG_TRUNC).unwrap()) != 0
}

/// Message payload, as buffered by the kernel.
pub fn payload_data(&self) -> &[u8] {
self.payload_data
}

/// Return the length of the incoming `payload` data.
///
/// This may be larger than the size of the content returned by
/// `payload_data()`, if the kernel could not fit all the incoming
/// data in the provided buffer size. In that case, payload data in
/// the result buffer gets truncated.
pub fn incoming_payload_len(&self) -> u32 {
self.header.payloadlen
}

/// Message flags, with the same semantics as `msghdr.msg_flags`.
pub fn flags(&self) -> u32 {
self.header.flags
Expand Down

0 comments on commit 297f02b

Please sign in to comment.