diff --git a/Cargo.toml b/Cargo.toml index 080cffe..58fdb0c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ os_socketaddr = "0.2" bitmask-enum = "2.2" lazy_static = "1.5.0" serde = { version = "1.0", features = ["derive"] } +thiserror = "1.0.64" [dev-dependencies] trybuild = "1.0" @@ -33,3 +34,15 @@ quanta = "0.12" byte-unit = "5.1" ouroboros = "0.18" proptest = "1.5" +anyhow = "1.0" + +[features] +debug = [] + +[[example]] +name = "rc_pingpong" +required-features = ["debug"] + +[[example]] +name = "rc_pingpong_split" +required-features = ["debug"] diff --git a/examples/rc_pingpong.rs b/examples/rc_pingpong.rs index a2d8207..3e9b0ab 100644 --- a/examples/rc_pingpong.rs +++ b/examples/rc_pingpong.rs @@ -115,7 +115,7 @@ struct TimeStamps { } #[allow(clippy::while_let_on_iterator)] -fn main() -> Result<(), Box> { +fn main() -> anyhow::Result<()> { let args = Args::parse(); let mut scnt: u32 = 0; let mut rcnt: u32 = 0; @@ -135,16 +135,16 @@ fn main() -> Result<(), Box> { let device = match args.ib_dev { Some(ib_dev) => device_list .iter() - .find(|dev| dev.name().unwrap().eq(&ib_dev)) + .find(|dev| dev.name()?.eq(&ib_dev)) .unwrap_or_else(|| panic!("IB device {ib_dev} not found")), None => device_list.iter().next().expect("No IB device found"), }; let context = device .open() - .unwrap_or_else(|_| panic!("Couldn't get context for {}", device.name().unwrap())); + .unwrap_or_else(|_| panic!("Couldn't get context for {}", device.name()?)); - let attr = context.query_device().unwrap(); + let attr = context.query_device()?; if args.ts { completion_timestamp_mask = attr.completion_timestamp_mask(); @@ -162,7 +162,7 @@ fn main() -> Result<(), Box> { .reg_managed_mr(args.size as _) .unwrap_or_else(|_| panic!("Couldn't register recv MR")); - let gid = context.query_gid(args.ib_port, args.gid_idx.into()).unwrap(); + let gid = context.query_gid(args.ib_port, args.gid_idx.into())?; let psn = rand::random::() & 0xFFFFFF; let mut cq_builder = context.create_cq_builder(); @@ -174,7 +174,7 @@ fn main() -> Result<(), Box> { ); } - let cq = cq_builder.setup_cqe(rx_depth + 1).build_ex().unwrap(); + let cq = cq_builder.setup_cqe(rx_depth + 1).build_ex()?; let mut builder = pd.create_qp_builder(); @@ -192,7 +192,7 @@ fn main() -> Result<(), Box> { .setup_pkey_index(0) .setup_port(args.ib_port) .setup_access_flags(AccessFlags::LocalWrite | AccessFlags::RemoteWrite); - qp.modify(&attr).unwrap(); + qp.modify(&attr)?; for _i in 0..rx_depth { let mut guard = qp.start_post_recv(); @@ -203,7 +203,7 @@ fn main() -> Result<(), Box> { recv_handle.setup_sge(recv_mr.lkey(), recv_mr.buf.data.as_ptr() as _, args.size); }; - guard.post().unwrap(); + guard.post()?; } rout += rx_depth; @@ -225,7 +225,7 @@ fn main() -> Result<(), Box> { }; let send_context = |stream: &mut TcpStream, dest: &PingPongDestination| { - let msg_buf = to_allocvec(dest).unwrap(); + let msg_buf = to_allocvec(dest)?; let size = msg_buf.len().to_be_bytes(); stream.write_all(&size)?; stream.write_all(&msg_buf)?; @@ -240,7 +240,7 @@ fn main() -> Result<(), Box> { msg_buf.clear(); msg_buf.resize(usize::from_be_bytes(size), 0); stream.read_exact(&mut *msg_buf)?; - let dest: PingPongDestination = from_bytes(msg_buf).unwrap(); + let dest: PingPongDestination = from_bytes(msg_buf)?; Ok::(dest) }; @@ -278,7 +278,7 @@ fn main() -> Result<(), Box> { .setup_grh_dest_gid(&remote_context.gid) .setup_grh_hop_limit(1); attr.setup_address_vector(&ah_attr); - qp.modify(&attr).unwrap(); + qp.modify(&attr)?; let mut attr = QueuePairAttribute::new(); attr.setup_state(QueuePairState::ReadyToSend) @@ -288,7 +288,7 @@ fn main() -> Result<(), Box> { .setup_rnr_retry(7) .setup_max_read_atomic(0); - qp.modify(&attr).unwrap(); + qp.modify(&attr)?; let clock = quanta::Clock::new(); let start_time = clock.now(); @@ -303,7 +303,7 @@ fn main() -> Result<(), Box> { send_handle.setup_sge(send_mr.lkey(), send_mr.buf.data.as_ptr() as _, send_mr.buf.len as _); } - guard.post().unwrap(); + guard.post()?; outstanding_send = true; } // poll for the completion @@ -342,7 +342,7 @@ fn main() -> Result<(), Box> { args.size, ); }; - guard.post().unwrap(); + guard.post()?; } rout += to_post; } @@ -388,7 +388,7 @@ fn main() -> Result<(), Box> { send_mr.buf.len as _, ); } - guard.post().unwrap(); + guard.post()?; outstanding_send = true; } } @@ -414,9 +414,7 @@ fn main() -> Result<(), Box> { "{} bytes in {:.2} seconds = {:.2}/s", bytes, time.as_secs_f64(), - Byte::from_f64(bytes_per_second) - .unwrap() - .get_appropriate_unit(UnitType::Binary) + Byte::from_f64(bytes_per_second)?.get_appropriate_unit(UnitType::Binary) ); println!( "{} iters in {:.2} seconds = {:#.2?}/iter", diff --git a/examples/rc_pingpong_split.rs b/examples/rc_pingpong_split.rs index a3cb3e6..079ac94 100644 --- a/examples/rc_pingpong_split.rs +++ b/examples/rc_pingpong_split.rs @@ -140,12 +140,12 @@ struct PingPongContext { } impl PingPongContext { - fn build(device: &Device, size: u32, rx_depth: u32, ib_port: u8, use_ts: bool) -> Result { + fn build(device: &Device, size: u32, rx_depth: u32, ib_port: u8, use_ts: bool) -> anyhow::Result { let context = device .open() - .unwrap_or_else(|_| panic!("Couldn't get context for {}", device.name().unwrap())); + .unwrap_or_else(|_| panic!("Couldn't get context for {}", device.name()?)); - let attr = context.query_device().unwrap(); + let attr = context.query_device()?; let completion_timestamp_mask = if use_ts { match attr.completion_timestamp_mask() { @@ -167,7 +167,7 @@ impl PingPongContext { | CreateCompletionQueueWorkCompletionFlags::CompletionTimestamp, ); } - let cq = cq_builder.setup_cqe(rx_depth + 1).build_ex().unwrap(); + let cq = cq_builder.setup_cqe(rx_depth + 1).build_ex()?; cq }, |pd, cq| { @@ -209,7 +209,7 @@ impl PingPongContext { )) } - fn post_recv(&mut self, num: u32) -> Result<(), String> { + fn post_recv(&mut self, num: u32) -> anyhow::Result<()> { for _i in 0..num { let (mut guard, lkey, ptr, size) = self.with_mut(|fields| { ( @@ -226,13 +226,13 @@ impl PingPongContext { recv_handle.setup_sge(lkey, ptr, size); }; - guard.post().unwrap(); + guard.post()?; } Ok(()) } - fn post_send(&mut self) -> Result<(), String> { + fn post_send(&mut self) -> anyhow::Result<()> { let (mut guard, lkey, ptr, size) = self.with_mut(|fields| { ( fields.qp.start_post_send(), @@ -253,7 +253,7 @@ impl PingPongContext { fn connect( &mut self, remote_context: &PingPongDestination, ib_port: u8, psn: u32, mtu: Mtu, sl: u8, gid_idx: u8, - ) -> Result<(), String> { + ) -> anyhow::Result<()> { let mut attr = QueuePairAttribute::new(); attr.setup_state(QueuePairState::ReadyToReceive) .setup_path_mtu(mtu) @@ -362,7 +362,7 @@ struct TimeStamps { } #[allow(clippy::while_let_on_iterator)] -fn main() -> Result<(), Box> { +fn main() -> anyhow::Result<()> { let args = Args::parse(); let mut scnt: u32 = 0; let mut rcnt: u32 = 0; @@ -381,17 +381,17 @@ fn main() -> Result<(), Box> { let device = match args.ib_dev { Some(ib_dev) => device_list .iter() - .find(|dev| dev.name().unwrap().eq(&ib_dev)) + .find(|dev| dev.name()?.eq(&ib_dev)) .unwrap_or_else(|| panic!("IB device {ib_dev} not found")), None => device_list.iter().next().expect("No IB device found"), }; - let mut ctx = PingPongContext::build(&device, args.size, rx_depth, args.ib_port, args.ts).unwrap(); + let mut ctx = PingPongContext::build(&device, args.size, rx_depth, args.ib_port, args.ts)?; - let gid = ctx.borrow_ctx().query_gid(args.ib_port, args.gid_idx.into()).unwrap(); + let gid = ctx.borrow_ctx().query_gid(args.ib_port, args.gid_idx.into())?; let psn = rand::random::() & 0xFFFFFF; - ctx.post_recv(rx_depth).unwrap(); + ctx.post_recv(rx_depth)?; rout += rx_depth; println!( @@ -414,7 +414,7 @@ fn main() -> Result<(), Box> { }; let send_context = |stream: &mut TcpStream, dest: &PingPongDestination| { - let msg_buf = to_allocvec(dest).unwrap(); + let msg_buf = to_allocvec(dest)?; let size = msg_buf.len().to_be_bytes(); stream.write_all(&size)?; stream.write_all(&msg_buf)?; @@ -429,7 +429,7 @@ fn main() -> Result<(), Box> { msg_buf.clear(); msg_buf.resize(usize::from_be_bytes(size), 0); stream.read_exact(&mut *msg_buf)?; - let dest: PingPongDestination = from_bytes(msg_buf).unwrap(); + let dest: PingPongDestination = from_bytes(msg_buf)?; Ok::(dest) }; @@ -448,15 +448,14 @@ fn main() -> Result<(), Box> { remote_context.qp_number, remote_context.packet_seq_number, remote_context.gid ); - ctx.connect(&remote_context, args.ib_port, psn, args.mtu.0, args.sl, args.gid_idx) - .unwrap(); + ctx.connect(&remote_context, args.ib_port, psn, args.mtu.0, args.sl, args.gid_idx)?; let clock = quanta::Clock::new(); let start_time = clock.now(); let mut outstanding_send = false; if args.server_ip.is_some() { - ctx.post_send().unwrap(); + ctx.post_send()?; outstanding_send = true; } // poll for the completion @@ -497,12 +496,12 @@ fn main() -> Result<(), Box> { } if need_post_recv { - ctx.post_recv(to_post_recv).unwrap(); + ctx.post_recv(to_post_recv)?; rout += to_post_recv; } if need_post_send { - ctx.post_send().unwrap(); + ctx.post_send()?; } // Check if we're done @@ -521,8 +520,7 @@ fn main() -> Result<(), Box> { "{} bytes in {:.2} seconds = {:.2}/s", bytes, time.as_secs_f64(), - Byte::from_f64(bytes_per_second) - .unwrap() + Byte::from_f64(bytes_per_second)? .get_appropriate_unit(UnitType::Binary) ); println!( diff --git a/src/verbs/queue_pair.rs b/src/verbs/queue_pair.rs index c7d7131..5503361 100644 --- a/src/verbs/queue_pair.rs +++ b/src/verbs/queue_pair.rs @@ -1,5 +1,4 @@ use bitmask_enum::bitmask; -use lazy_static::lazy_static; use rdma_mummy_sys::{ ibv_create_qp, ibv_create_qp_ex, ibv_data_buf, ibv_destroy_qp, ibv_modify_qp, ibv_post_recv, ibv_post_send, ibv_qp, ibv_qp_attr, ibv_qp_attr_mask, ibv_qp_cap, ibv_qp_create_send_ops_flags, ibv_qp_ex, ibv_qp_init_attr, @@ -20,6 +19,62 @@ use super::{ protection_domain::ProtectionDomain, AccessFlags, }; +#[cfg(feature = "debug")] +use crate::verbs::address::Gid; + +#[derive(Debug, thiserror::Error)] +#[non_exhaustive] +pub enum ModifyQueuePairError { + #[error("modify queue pair failed")] + GenericError(#[from] io::Error), + #[cfg(feature = "debug")] + #[error("invalid transition from {cur_state:?} to {next_state:?}")] + InvalidTransition { + cur_state: QueuePairState, + next_state: QueuePairState, + source: io::Error, + }, + #[cfg(feature = "debug")] + #[error("invalid transition from {cur_state:?} to {next_state:?}, possible invalid masks {invalid:?}, possible needed masks {needed:?}")] + InvalidAttributeMask { + cur_state: QueuePairState, + next_state: QueuePairState, + invalid: QueuePairAttributeMask, + needed: QueuePairAttributeMask, + source: io::Error, + }, + #[cfg(feature = "debug")] + #[error("resolve route timed out, source gid index: {sgid_index}, destination gid: {gid}")] + ResolveRouteTimedout { + sgid_index: u8, + gid: Gid, + source: io::Error, + }, + #[cfg(feature = "debug")] + #[error("network unreachable, source gid index: {sgid_index}, destination gid: {gid}")] + NetworkUnreachable { + sgid_index: u8, + gid: Gid, + source: io::Error, + }, +} + +#[derive(Debug, thiserror::Error)] +#[non_exhaustive] +pub enum PostSendError { + #[error("post send failed")] + GenericError(#[from] io::Error), + #[cfg(feature = "debug")] + #[error("invalid value provided in work request")] + InvalidWorkRequest(#[source] io::Error), + #[cfg(feature = "debug")] + #[error("invalid value provided in queue pair")] + InvalidValueInQueuePair(#[source] io::Error), + #[cfg(feature = "debug")] + #[error("send queue is full or not enough resources to complete this operation")] + NotEnoughResources(#[source] io::Error), +} + #[repr(u32)] #[derive(Debug, Clone, Copy)] pub enum QueuePairType { @@ -138,22 +193,44 @@ pub trait QueuePair { /// associated with the return value. unsafe fn qp(&self) -> NonNull; - fn modify(&mut self, attr: &QueuePairAttribute) -> Result<(), String> { + fn modify(&mut self, attr: &QueuePairAttribute) -> Result<(), ModifyQueuePairError> { // ibv_qp_attr does not impl Clone trait, so we use struct update syntax here let mut qp_attr = ibv_qp_attr { ..attr.attr }; let ret = unsafe { ibv_modify_qp(self.qp().as_ptr(), &mut qp_attr as *mut _, attr.attr_mask.bits) }; if ret == 0 { Ok(()) } else { - // User doesn't pass in a mask with IBV_QP_STATE, we just assume user doesn't - // want to change the state, pass self.state() as next_state - if attr.attr_mask.contains(QueuePairAttributeMask::State) { - attr_mask_check(attr.attr_mask, self.state(), attr.attr.qp_state.into()).unwrap(); - } else { - attr_mask_check(attr.attr_mask, self.state(), self.state()).unwrap(); + match ret { + #[cfg(feature = "debug")] + libc::EINVAL => { + // User doesn't pass in a mask with IBV_QP_STATE, we just assume user doesn't + // want to change the state, pass self.state() as next_state + let err = if attr.attr_mask.contains(QueuePairAttributeMask::State) { + attr_mask_check(attr.attr_mask, self.state(), attr.attr.qp_state.into()) + } else { + attr_mask_check(attr.attr_mask, self.state(), self.state()) + }; + match err { + Ok(()) => Err(ModifyQueuePairError::GenericError(io::Error::from_raw_os_error( + libc::EINVAL, + ))), + Err(err) => Err(err), + } + }, + #[cfg(feature = "debug")] + libc::ETIMEDOUT => Err(ModifyQueuePairError::ResolveRouteTimedout { + sgid_index: attr.attr.ah_attr.grh.sgid_index, + gid: attr.attr.ah_attr.grh.dgid.into(), + source: io::Error::from_raw_os_error(libc::ETIMEDOUT), + }), + #[cfg(feature = "debug")] + libc::ENETUNREACH => Err(ModifyQueuePairError::NetworkUnreachable { + sgid_index: attr.attr.ah_attr.grh.sgid_index, + gid: attr.attr.ah_attr.grh.dgid.into(), + source: io::Error::from_raw_os_error(libc::ENETUNREACH), + }), + err => Err(ModifyQueuePairError::GenericError(io::Error::from_raw_os_error(err))), } - - Err(format!("ibv_modify_qp failed, err={ret}")) } } @@ -217,7 +294,7 @@ pub trait PostSendGuard: private_traits::PostSendGuard { // every qp should hold only one WorkRequestHandle at the same time fn construct_wr(&mut self, wr_id: u64, wr_flags: WorkRequestFlags) -> WorkRequestHandle<'_, Self>; - fn post(self) -> Result<(), String>; + fn post(self) -> Result<(), PostSendError>; } // According to C standard, enums should be int, but Rust just uses whatever @@ -263,6 +340,7 @@ pub enum QueuePairAttributeMask { // // We should consider using `std::mem::variant_count` here, after it stablized. // +#[cfg(feature = "debug")] #[derive(Debug, Copy, Clone)] struct QueuePairStateTableEntry { // whether this state transition is valid. @@ -271,6 +349,10 @@ struct QueuePairStateTableEntry { optional_mask: QueuePairAttributeMask, } +#[cfg(feature = "debug")] +use lazy_static::lazy_static; + +#[cfg(feature = "debug")] lazy_static! { static ref RC_QP_STATE_TABLE: [[QueuePairStateTableEntry; QueuePairState::Error as usize + 1]; QueuePairState::Error as usize + 1] = { @@ -713,11 +795,13 @@ impl QueuePairAttribute { // TODO(zhp): trait for QueuePair +#[cfg(feature = "debug")] #[inline] fn get_needed_mask(cur_mask: QueuePairAttributeMask, required_mask: QueuePairAttributeMask) -> QueuePairAttributeMask { required_mask.and(required_mask.xor(cur_mask)) } +#[cfg(feature = "debug")] #[inline] fn get_invalid_mask( cur_mask: QueuePairAttributeMask, required_mask: QueuePairAttributeMask, optional_mask: QueuePairAttributeMask, @@ -725,11 +809,16 @@ fn get_invalid_mask( cur_mask.and(required_mask.or(optional_mask).not()) } +#[cfg(feature = "debug")] fn attr_mask_check( attr_mask: QueuePairAttributeMask, cur_state: QueuePairState, next_state: QueuePairState, -) -> Result<(), String> { +) -> Result<(), ModifyQueuePairError> { if !RC_QP_STATE_TABLE[cur_state as usize][next_state as usize].valid { - return Err(format!("Invalid transition from {cur_state:?} to {next_state:?}")); + return Err(ModifyQueuePairError::InvalidTransition { + cur_state, + next_state, + source: io::Error::from_raw_os_error(libc::EINVAL), + }); } let required = RC_QP_STATE_TABLE[cur_state as usize][next_state as usize].required_mask; @@ -739,7 +828,13 @@ fn attr_mask_check( if invalid.bits == 0 && needed.bits == 0 { Ok(()) } else { - Err(format!("Invalid transition from {cur_state:?} to {next_state:?}, possible invalid masks {invalid:?}, possible needed masks {needed:?}")) + Err(ModifyQueuePairError::InvalidAttributeMask { + cur_state, + next_state, + invalid, + needed, + source: io::Error::from_raw_os_error(libc::EINVAL), + }) } } @@ -750,12 +845,12 @@ pub struct WorkRequestHandle<'g, G: PostSendGuard + ?Sized> { pub trait SetScatterGatherEntry { /// # Safety /// - /// set a local buffer to the request; note that the lifetime of the buffer + /// set a local buffer to the request; note that the lifetime of the buffer /// associated with the sge is managed by the caller. unsafe fn setup_sge(self, lkey: u32, addr: u64, length: u32); /// # Safety /// - /// set a list of local buffers to the request; note that the lifetime of + /// set a list of local buffers to the request; note that the lifetime of /// the buffer associated with the sge is managed by the caller. unsafe fn setup_sge_list(self, sg_list: &[ibv_sge]); } @@ -826,7 +921,7 @@ impl PostSendGuard for BasicPostSendGuard<'_> { WorkRequestHandle { guard: self } } - fn post(mut self) -> Result<(), String> { + fn post(mut self) -> Result<(), PostSendError> { let mut sge_index = 0; for i in 0..self.wrs.len() { @@ -848,7 +943,19 @@ impl PostSendGuard for BasicPostSendGuard<'_> { let ret = unsafe { ibv_post_send(self.qp.as_ptr(), self.wrs.as_mut_ptr(), &mut bad_wr) }; match ret { 0 => Ok(()), - err => Err(format!("ibv_post_send failed, ret={err}")), + #[cfg(feature = "debug")] + libc::EINVAL => Err(PostSendError::InvalidWorkRequest(io::Error::from_raw_os_error( + libc::EINVAL, + ))), + #[cfg(feature = "debug")] + libc::ENOMEM => Err(PostSendError::NotEnoughResources(io::Error::from_raw_os_error( + libc::ENOMEM, + ))), + #[cfg(feature = "debug")] + libc::EFAULT => Err(PostSendError::InvalidValueInQueuePair(io::Error::from_raw_os_error( + libc::EFAULT, + ))), + err => Err(PostSendError::GenericError(io::Error::from_raw_os_error(err))), } } } @@ -932,14 +1039,26 @@ impl PostSendGuard for ExtendedPostSendGuard<'_> { WorkRequestHandle { guard: self } } - fn post(mut self) -> Result<(), String> { + fn post(mut self) -> Result<(), PostSendError> { let ret: i32 = unsafe { ibv_wr_complete(self.qp_ex.unwrap_unchecked().as_ptr()) }; self.qp_ex = None; match ret { 0 => Ok(()), - err => Err(format!("failed to ibv_wr_complete: ret {err}")), + #[cfg(feature = "debug")] + libc::EINVAL => Err(PostSendError::InvalidWorkRequest(io::Error::from_raw_os_error( + libc::EINVAL, + ))), + #[cfg(feature = "debug")] + libc::ENOMEM => Err(PostSendError::NotEnoughResources(io::Error::from_raw_os_error( + libc::ENOMEM, + ))), + #[cfg(feature = "debug")] + libc::EFAULT => Err(PostSendError::InvalidValueInQueuePair(io::Error::from_raw_os_error( + libc::EFAULT, + ))), + err => Err(PostSendError::GenericError(io::Error::from_raw_os_error(err))), } } }