diff --git a/src/verbs/queue_pair.rs b/src/verbs/queue_pair.rs index c7d7131..cb209ad 100644 --- a/src/verbs/queue_pair.rs +++ b/src/verbs/queue_pair.rs @@ -136,8 +136,34 @@ pub trait QueuePair { /// return the basic handle of QP; /// we mark this method unsafe because the lifetime of ibv_qp is not /// associated with the return value. + /// + /// # Examples + /// + /// ```compile_fail + /// unsafe { + /// let qp_ptr = generic_queue_pair.qp(); + /// // Use qp_ptr carefully... + /// } unsafe fn qp(&self) -> NonNull; + /// Modifies the queue pair attributes. + /// + /// # Arguments + /// + /// * `attr` - A reference to the QueuePairAttribute to be applied. + /// + /// # Returns + /// + /// * `Ok(())` if the modification was successful. + /// * `Err(String)` if the modification failed, containing an error message. + /// + /// # Examples + /// + /// ```compile_fail + /// let mut attr = QueuePairAttribute::new(); + /// attr.setup_state(QueuePairState::ReadyToSend); + /// generic_queue_pair.modify(&attr)?; + /// ``` fn modify(&mut self, attr: &QueuePairAttribute) -> Result<(), String> { // ibv_qp_attr does not impl Clone trait, so we use struct update syntax here let mut qp_attr = ibv_qp_attr { ..attr.attr }; @@ -157,15 +183,34 @@ pub trait QueuePair { } } + /// Get the queue pair state. fn state(&self) -> QueuePairState { unsafe { self.qp().as_ref().state.into() } } + /// Get the queue pair number. fn qp_number(&self) -> u32 { unsafe { self.qp().as_ref().qp_num } } - // Every qp should hold only one PostSendGuard at the same time. + /// Could be [`ExtendedPostSendGuard`], [`BasicPostSendGuard`] or [`GenericPostSendGuard`] + type Guard<'g>: PostSendGuard + where + Self: 'g; + + /// Starts a post send operation, every qp should hold only one PostSendGuard at the same time. + /// + /// # Returns + /// + /// A `PostSendGuard` that can be used to construct and post send work requests. + /// + /// # Examples + /// + /// ```compile_fail + /// let mut guard = generic_queue_pair.start_post_send(); + /// let send_wr = guard.construct_wr(/* ... */); + /// guard.post()?; + /// ``` // // RPITIT could be used here, but with lifetime bound, there could be problems. // @@ -174,11 +219,21 @@ pub trait QueuePair { // https://github.com/rust-lang/rfcs/pull/3425 // https://github.com/rust-lang/rust/issues/125836 // - type Guard<'g>: PostSendGuard - where - Self: 'g; fn start_post_send(&mut self) -> Self::Guard<'_>; + /// Starts a post receive operation. + /// + /// # Returns + /// + /// A `PostRecvGuard` that can be used to construct and post receive work requests. + /// + /// # Examples + /// + /// ```compile_fail + /// let mut guard = generic_queue_pair.start_post_recv(); + /// let recv_wr = guard.construct_wr(/* ... */); + /// guard.post()?; + /// ``` fn start_post_recv(&mut self) -> PostRecvGuard<'_> { PostRecvGuard { qp: unsafe { self.qp() }, @@ -750,12 +805,12 @@ pub struct WorkRequestHandle<'g, G: PostSendGuard + ?Sized> { pub trait SetScatterGatherEntry { /// # Safety /// - /// set a local buffer to the request; note that the lifetime of the buffer + /// set a local buffer to the request; note that the lifetime of the buffer /// associated with the sge is managed by the caller. unsafe fn setup_sge(self, lkey: u32, addr: u64, length: u32); /// # Safety /// - /// set a list of local buffers to the request; note that the lifetime of + /// set a list of local buffers to the request; note that the lifetime of /// the buffer associated with the sge is managed by the caller. unsafe fn setup_sge_list(self, sg_list: &[ibv_sge]); } @@ -803,12 +858,12 @@ impl<'g, G: PostSendGuard> WorkRequestHandle<'g, G> { } } -pub struct BasicPostSendGuard<'qp> { +pub struct BasicPostSendGuard<'g> { qp: NonNull, wrs: Vec, sges: Vec, inline_buffers: Vec>, - _phantom: PhantomData<&'qp ()>, + _phantom: PhantomData<&'g ()>, } impl PostSendGuard for BasicPostSendGuard<'_> { @@ -1060,3 +1115,153 @@ impl SetScatterGatherEntry for RecvWorkRequestHandle<'_, '_> { self.guard.sges.extend_from_slice(sg_list); } } + +#[derive(Debug)] +pub enum GenericQueuePair<'qp> { + /// Variant for a Basic Queue Pair + Basic(BasicQueuePair<'qp>), + /// Variant for an Extended Queue Pair + Extended(ExtendedQueuePair<'qp>), +} + +impl QueuePair for GenericQueuePair<'_> { + unsafe fn qp(&self) -> NonNull { + match self { + GenericQueuePair::Basic(qp) => qp.qp(), + GenericQueuePair::Extended(qp) => qp.qp(), + } + } + + fn qp_number(&self) -> u32 { + match self { + GenericQueuePair::Basic(qp) => qp.qp_number(), + GenericQueuePair::Extended(qp) => qp.qp_number(), + } + } + + fn modify(&mut self, attr: &QueuePairAttribute) -> Result<(), String> { + match self { + GenericQueuePair::Basic(qp) => qp.modify(attr), + GenericQueuePair::Extended(qp) => qp.modify(attr), + } + } + + fn start_post_recv(&mut self) -> PostRecvGuard<'_> { + match self { + GenericQueuePair::Basic(qp) => qp.start_post_recv(), + GenericQueuePair::Extended(qp) => qp.start_post_recv(), + } + } + + type Guard<'g> = GenericPostSendGuard<'g> where Self: 'g; + + fn start_post_send(&mut self) -> Self::Guard<'_> { + match self { + GenericQueuePair::Basic(qp) => GenericPostSendGuard::Basic(qp.start_post_send()), + GenericQueuePair::Extended(qp) => GenericPostSendGuard::Extended(qp.start_post_send()), + } + } +} + +pub enum GenericPostSendGuard<'g> { + Basic(BasicPostSendGuard<'g>), + Extended(ExtendedPostSendGuard<'g>), +} + +impl<'g> PostSendGuard for GenericPostSendGuard<'g> { + fn construct_wr(&mut self, wr_id: u64, wr_flags: WorkRequestFlags) -> WorkRequestHandle<'_, Self> { + match self { + GenericPostSendGuard::Basic(guard) => { + guard.construct_wr(wr_id, wr_flags); + WorkRequestHandle { guard: self } + }, + GenericPostSendGuard::Extended(guard) => { + guard.construct_wr(wr_id, wr_flags); + WorkRequestHandle { guard: self } + }, + } + } + + fn post(self) -> Result<(), String> { + match self { + GenericPostSendGuard::Basic(guard) => guard.post(), + GenericPostSendGuard::Extended(guard) => guard.post(), + } + } +} + +impl<'g> private_traits::PostSendGuard for GenericPostSendGuard<'g> { + fn setup_send(&mut self) { + match self { + GenericPostSendGuard::Basic(guard) => guard.setup_send(), + GenericPostSendGuard::Extended(guard) => guard.setup_send(), + } + } + + fn setup_write(&mut self, rkey: u32, remote_addr: u64) { + match self { + GenericPostSendGuard::Basic(guard) => guard.setup_write(rkey, remote_addr), + GenericPostSendGuard::Extended(guard) => guard.setup_write(rkey, remote_addr), + } + } + + fn setup_inline_data(&mut self, buf: &[u8]) { + match self { + GenericPostSendGuard::Basic(guard) => guard.setup_inline_data(buf), + GenericPostSendGuard::Extended(guard) => guard.setup_inline_data(buf), + } + } + + fn setup_inline_data_list(&mut self, bufs: &[IoSlice<'_>]) { + match self { + GenericPostSendGuard::Basic(guard) => guard.setup_inline_data_list(bufs), + GenericPostSendGuard::Extended(guard) => guard.setup_inline_data_list(bufs), + } + } + + unsafe fn setup_sge(&mut self, lkey: u32, addr: u64, length: u32) { + match self { + GenericPostSendGuard::Basic(guard) => guard.setup_sge(lkey, addr, length), + GenericPostSendGuard::Extended(guard) => guard.setup_sge(lkey, addr, length), + } + } + + unsafe fn setup_sge_list(&mut self, sg_list: &[ibv_sge]) { + match self { + GenericPostSendGuard::Basic(guard) => guard.setup_sge_list(sg_list), + GenericPostSendGuard::Extended(guard) => guard.setup_sge_list(sg_list), + } + } +} + +impl<'qp> From> for GenericQueuePair<'qp> { + /// Converts a BasicQueuePair into a GenericQueuePair. + /// + /// This allows for easy creation of a GenericQueuePair from a BasicQueuePair. + /// + /// # Examples + /// + /// ```compile_fail + /// let basic_qp = builder.build().unwarp(); + /// let generic_qp: GenericQueuePair = basic_qp.into(); + /// ``` + fn from(qp: BasicQueuePair<'qp>) -> Self { + GenericQueuePair::Basic(qp) + } +} + +impl<'qp> From> for GenericQueuePair<'qp> { + /// Converts an ExtendedQueuePair into a GenericQueuePair. + /// + /// This allows for easy creation of a GenericQueuePair from an ExtendedQueuePair. + /// + /// # Examples + /// + /// ```compile_fail + /// let extended_qp = builder.build_ex().unwarp(); + /// let generic_qp: GenericQueuePair = extended_qp.into(); + /// ``` + fn from(qp: ExtendedQueuePair<'qp>) -> Self { + GenericQueuePair::Extended(qp) + } +} diff --git a/tests/test_ibv_wr.rs b/tests/test_ibv_wr.rs deleted file mode 100644 index 438ecaa..0000000 --- a/tests/test_ibv_wr.rs +++ /dev/null @@ -1,166 +0,0 @@ -use core::{slice, time}; -use std::{io::IoSlice, thread}; - -use rdma_mummy_sys::ibv_sge; -use sideway::verbs::{ - address::{AddressHandleAttribute, GidType}, - device, - device_context::Mtu, - queue_pair::{ - PostSendGuard, QueuePair, QueuePairAttribute, QueuePairState, SetInlineData, SetScatterGatherEntry, - WorkRequestFlags, - }, - AccessFlags, -}; - -#[test] -#[allow(clippy::while_let_on_iterator)] -fn main() -> Result<(), Box> { - let device_list = device::DeviceList::new()?; - for device in &device_list { - let ctx = device.open().unwrap(); - - let pd = ctx.alloc_pd().unwrap(); - let mr = pd.reg_managed_mr(64).unwrap(); - let recv_mr = pd.reg_managed_mr(64).unwrap(); - - let _comp_channel = ctx.create_comp_channel().unwrap(); - let mut cq_builder = ctx.create_cq_builder(); - let sq = cq_builder.setup_cqe(128).build_ex().unwrap(); - let rq = cq_builder.setup_cqe(128).build_ex().unwrap(); - - let mut builder = pd.create_qp_builder(); - - let mut qp = builder - .setup_max_inline_data(128) - .setup_send_cq(&sq) - .setup_recv_cq(&rq) - .build_ex() - .unwrap(); - - println!("qp pointer is {:?}", qp); - // modify QP to INIT state - let mut attr = QueuePairAttribute::new(); - attr.setup_state(QueuePairState::Init) - .setup_pkey_index(0) - .setup_port(1) - .setup_access_flags(AccessFlags::LocalWrite | AccessFlags::RemoteWrite); - qp.modify(&attr).unwrap(); - - assert_eq!(QueuePairState::Init, qp.state()); - - // modify QP to RTR state, set dest qp as itself - let mut attr = QueuePairAttribute::new(); - attr.setup_state(QueuePairState::ReadyToReceive) - .setup_path_mtu(Mtu::Mtu1024) - .setup_dest_qp_num(qp.qp_number()) - .setup_rq_psn(1) - .setup_max_dest_read_atomic(0) - .setup_min_rnr_timer(0); - // setup address vector - let mut ah_attr = AddressHandleAttribute::new(); - let gid_entries = ctx.query_gid_table().unwrap(); - let gid = gid_entries - .iter() - .find(|&&gid| !gid.gid().is_unicast_link_local() || gid.gid_type() == GidType::RoceV1) - .unwrap(); - - ah_attr - .setup_dest_lid(1) - .setup_port(1) - .setup_service_level(1) - .setup_grh_src_gid_index(gid.gid_index().try_into().unwrap()) - .setup_grh_dest_gid(&gid.gid()) - .setup_grh_hop_limit(64); - attr.setup_address_vector(&ah_attr); - qp.modify(&attr).unwrap(); - - assert_eq!(QueuePairState::ReadyToReceive, qp.state()); - - // modify QP to RTS state - let mut attr = QueuePairAttribute::new(); - attr.setup_state(QueuePairState::ReadyToSend) - .setup_sq_psn(1) - .setup_timeout(12) - .setup_retry_cnt(7) - .setup_rnr_retry(7) - .setup_max_read_atomic(0); - - qp.modify(&attr).unwrap(); - - assert_eq!(QueuePairState::ReadyToSend, qp.state()); - - // post one recv buf to the qp - let mut guard = qp.start_post_recv(); - let recv_handle = guard.construct_wr(114514); - unsafe { - recv_handle.setup_sge_list(slice::from_ref(&ibv_sge { - addr: recv_mr.buf.data.as_ptr() as _, - length: recv_mr.buf.len as _, - lkey: recv_mr.lkey(), - })) - }; - guard.post().unwrap(); - - let mut guard = qp.start_post_send(); - let buf = vec![0, 1, 2, 3]; - - let write_handle = guard - .construct_wr(233, WorkRequestFlags::Signaled | WorkRequestFlags::Inline) - .setup_write(mr.rkey(), mr.buf.data.as_ptr() as _); - - write_handle.setup_inline_data(&buf); - - // it's safe for users to drop the inline buffer after they calling setup inline data - drop(buf); - - let buf = vec![vec![b'H', b'e', b'l', b'l', b'o'], vec![b'R', b'D', b'M', b'A']]; - - let write_handle = unsafe { - guard - .construct_wr(234, WorkRequestFlags::Signaled | WorkRequestFlags::Inline) - .setup_write(mr.rkey(), mr.buf.data.byte_add(4).as_ptr() as _) - }; - - write_handle.setup_inline_data_list(&[IoSlice::new(buf[0].as_ref()), IoSlice::new(buf[1].as_ref())]); - - // use SEND to transmit the same data - let send_handle = guard.construct_wr(567, WorkRequestFlags::Signaled).setup_send(); - send_handle.setup_inline_data_list(&[IoSlice::new(buf[0].as_ref()), IoSlice::new(buf[1].as_ref())]); - - // it's safe for users to drop the inline buffer after they calling setup inline data - drop(buf); - - guard.post().unwrap(); - - thread::sleep(time::Duration::from_millis(10)); - - // poll send CQ for the completion - { - let mut poller = sq.start_poll().unwrap(); - while let Some(wc) = poller.next() { - println!("wr_id {}, status: {}, opcode: {}", wc.wr_id(), wc.status(), wc.opcode()) - } - } - - unsafe { - let slice = std::slice::from_raw_parts(mr.buf.data.as_ptr(), mr.buf.len); - println!("Buffer contents: {:?}", slice); - } - - // poll recv CQ for the completion - { - let mut poller = rq.start_poll().unwrap(); - while let Some(wc) = poller.next() { - println!("wr_id {}, status: {}, opcode: {}", wc.wr_id(), wc.status(), wc.opcode()) - } - } - - unsafe { - let slice = std::slice::from_raw_parts(recv_mr.buf.data.as_ptr(), recv_mr.buf.len); - println!("Recv Buffer contents: {:?}", slice); - } - } - - Ok(()) -} diff --git a/tests/test_post_send.rs b/tests/test_post_send.rs index d9251ea..16b33a2 100644 --- a/tests/test_post_send.rs +++ b/tests/test_post_send.rs @@ -1,6 +1,9 @@ +#![allow(clippy::while_let_on_iterator)] + use core::time; use std::{io::IoSlice, thread}; +use sideway::verbs::queue_pair::GenericQueuePair; use sideway::verbs::{ address::{AddressHandleAttribute, GidType}, device, @@ -12,9 +15,12 @@ use sideway::verbs::{ AccessFlags, }; -#[test] -#[allow(clippy::while_let_on_iterator)] -fn main() -> Result<(), Box> { +use rstest::rstest; + +#[rstest] +#[case(true)] +#[case(false)] +fn main(#[case] use_qp_ex: bool) -> Result<(), Box> { let device_list = device::DeviceList::new()?; for device in &device_list { let ctx = device.open().unwrap(); @@ -30,12 +36,23 @@ fn main() -> Result<(), Box> { let mut builder = pd.create_qp_builder(); - let mut qp = builder - .setup_max_inline_data(128) - .setup_send_cq(&sq) - .setup_recv_cq(&rq) - .build() - .unwrap(); + let mut qp: GenericQueuePair = if use_qp_ex { + builder + .setup_max_inline_data(128) + .setup_send_cq(&sq) + .setup_recv_cq(&rq) + .build_ex() + .unwrap() + .into() + } else { + builder + .setup_max_inline_data(128) + .setup_send_cq(&sq) + .setup_recv_cq(&rq) + .build() + .unwrap() + .into() + }; println!("qp pointer is {:?}", qp); // modify QP to INIT state