diff --git a/.cirrus.yml b/.cirrus.yml index 0bce537..d2b9b05 100644 --- a/.cirrus.yml +++ b/.cirrus.yml @@ -27,6 +27,7 @@ task: - export LD_LIBRARY_PATH=./rdma-core/build/lib - just test-basic-with-cov - just test-rc-pingpong-with-cov + - just test-cmtime-with-cov - just generate-cov - sed -i 's#/tmp/cirrus-ci-build/##g' lcov.info - ./codecov --verbose upload-process --disable-search --fail-on-error -t $CODECOV_TOKEN --git-service github -f ./lcov.info diff --git a/Justfile b/Justfile index 4a57f54..083ef91 100644 --- a/Justfile +++ b/Justfile @@ -20,5 +20,10 @@ test-rc-pingpong-with-cov: sleep 2 cargo llvm-cov --no-report run --features="debug" --example rc_pingpong_split -- -d {{rdma_dev}} -g 1 127.0.0.1 +test-cmtime-with-cov: + cargo llvm-cov --no-report run --example cmtime -- -b {{ip}} & + sleep 2 + cargo llvm-cov --no-report run --example cmtime -- -b {{ip}} -s {{ip}} + generate-cov: cargo llvm-cov report --lcov --output-path lcov.info diff --git a/examples/cmtime.rs b/examples/cmtime.rs new file mode 100644 index 0000000..c29727c --- /dev/null +++ b/examples/cmtime.rs @@ -0,0 +1,493 @@ +use sideway::cm::communication_manager::{ConnectionParameter, Event, EventChannel, Identifier, PortSpace}; +use sideway::verbs::completion::GenericCompletionQueue; +use sideway::verbs::device_context::DeviceContext; +use sideway::verbs::protection_domain::ProtectionDomain; +use sideway::verbs::queue_pair::{GenericQueuePair, QueuePair, QueuePairState}; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::str::FromStr; +use std::sync::atomic::{AtomicU32, Ordering}; +use std::sync::mpsc::{channel, Sender}; +use std::sync::{Arc, Mutex, Once}; +use std::thread; +use std::time::Duration; +use tabled::settings::object::Columns; + +use clap::Parser; +use lazy_static::lazy_static; +use quanta::Instant; +use tabled::{ + settings::{object::Segment, Alignment, Modify, Style}, + Table, Tabled, +}; + +#[derive(Debug, Parser)] +#[clap(name = "cmtime", version = "0.1.0")] +pub struct 
Args { + /// Listen on / connect to port + #[clap(long, short = 'p', default_value_t = 18515)] + port: u16, + /// Bind address + #[clap(long, short = 'b')] + bind_address: Option<String>, + /// If no value provided, start a server and wait for connection, otherwise, connect to server at [host] + #[clap(long, short = 's')] + server_address: Option<String>, + // Use self-created, self-modified QP + #[arg(long, short = 'q', default_value_t = false)] + self_modify: bool, + // Number of connections + #[arg(long, short = 'c', default_value_t = 100)] + connections: u32, +} + +#[repr(usize)] +#[derive(Debug)] +pub enum Step { + CreateId, + Bind, + ResolveAddr, + ResolveRoute, + CreateQueuePair, + Connect, + ModifyToInit, + ModifyToRTR, + ModifyToRTS, + Disconnect, + Destroy, + Count, +} + +static mut CTX: Option<Arc<DeviceContext>> = None; +static mut PD: Option<Arc<ProtectionDomain>> = None; +static mut CQ: Option<Arc<GenericCompletionQueue>> = None; + +static OPEN_VERBS: Once = Once::new(); + +struct Node<'a> { + id: Option<Arc<Identifier>>, + qp: Option<GenericQueuePair<'a>>, + times: [(Instant, Instant); Step::Count as usize], +} + +#[derive(Tabled)] +struct StageResult { + #[tabled(rename = "Step")] + stage: String, + #[tabled(rename = "Total (ms)", format = "{:.2}")] + total: f64, + #[tabled(rename = "Max (us)", format = "{:.2}")] + max: f64, + #[tabled(rename = "Min (us)", format = "{:.2}")] + min: f64, +} + +lazy_static! { + static ref STARTED: [AtomicU32; Step::Count as usize] = [const { AtomicU32::new(0) }; Step::Count as usize]; + static ref COMPLETED: [AtomicU32; Step::Count as usize] = [const { AtomicU32::new(0) }; Step::Count as usize]; + static ref TIMES: Mutex<[(Instant, Instant); Step::Count as usize]> = + Mutex::new([(Instant::recent(), Instant::recent()); Step::Count as usize]); + static ref CHANNEL: Mutex<EventChannel> = + Mutex::new(EventChannel::new().expect("Failed to create rdma cm event channel")); +} + +macro_rules! start_perf { + ($node:expr, $step:expr) => {{ + $node.lock().unwrap().times[$step as usize].0 = Instant::now(); + }}; +} + +macro_rules! 
end_perf { + ($node:expr, $step:expr) => {{ + $node.lock().unwrap().times[$step as usize].1 = Instant::now(); + }}; +} + +macro_rules! start_time { + ($step:expr) => {{ + { + let mut times = TIMES.lock().unwrap(); + times[$step as usize].0 = Instant::now(); + } + }}; +} + +macro_rules! end_time { + ($step:expr, $results:expr, $nodes:expr) => {{ + { + let mut times = TIMES.lock().unwrap(); + times[$step as usize].1 = Instant::now(); + + // Calculate min/max from individual node times + let mut max_us = 0.0f64; + let mut min_us = f64::MAX; + + for node in $nodes { + let node = node.lock().unwrap(); + let duration = node.times[$step as usize] + .1 + .duration_since(node.times[$step as usize].0) + .as_secs_f64() + * 1_000_000.0; // Convert to microseconds + + max_us = max_us.max(duration); + min_us = min_us.min(duration); + } + + // Handle case where no valid measurements exist + if min_us == f64::MAX { + min_us = 0.0; + } + + $results.push(StageResult { + stage: format!("{:?}", $step), + total: times[$step as usize] + .1 + .duration_since(times[$step as usize].0) + .as_secs_f64() + * 1000.0, // Keep total in milliseconds + max: max_us, + min: min_us, + }); + } + }}; +} + +fn cma_handler( + id: Arc<Identifier>, event: Event, resp_wq: Option<Sender<Arc<Identifier>>>, + req_wq: Option<Sender<Arc<Identifier>>>, disc_wq: Option<Sender<Arc<Identifier>>>, +) { + use sideway::cm::communication_manager::EventType::*; + let node: Option<Arc<Mutex<Node>>> = id.get_context(); + + match event.event_type() { + AddressResolved => { + end_perf!(node.unwrap(), Step::ResolveAddr); + COMPLETED[Step::ResolveAddr as usize].fetch_add(1, Ordering::Relaxed); + }, + RouteResolved => { + end_perf!(node.unwrap(), Step::ResolveRoute); + COMPLETED[Step::ResolveRoute as usize].fetch_add(1, Ordering::Relaxed); + }, + ConnectRequest => { + let cm_id = event.cm_id().clone().unwrap(); + OPEN_VERBS.call_once(|| unsafe { + CTX = Some(cm_id.get_device_context().unwrap().clone()); + PD = Some(Arc::new(CTX.as_ref().unwrap().alloc_pd().unwrap())); + CQ = Some(Arc::new( + CTX.as_ref() + .unwrap() + 
.create_cq_builder() + .setup_cqe(1) + .build_ex() + .unwrap() + .into(), + )); + }); + req_wq.unwrap().send(cm_id).unwrap(); + }, + ConnectResponse => { + if let Some(wq) = resp_wq { + wq.send(id).unwrap(); + } else { + end_perf!(node.unwrap(), Step::Connect); + } + }, + Established => { + if let Some(node) = node { + end_perf!(node, Step::Connect); + COMPLETED[Step::Connect as usize].fetch_add(1, Ordering::Relaxed); + } + }, + Disconnected => { + if let Some(wq) = disc_wq { + wq.send(id).unwrap(); + } else { + end_perf!(node.unwrap(), Step::Disconnect); + } + COMPLETED[Step::Disconnect as usize].fetch_add(1, Ordering::Relaxed); + }, + AddressError => { + println!("Event: {:?}, error: {}", event.event_type(), event.status()); + }, + ConnectError | Unreachable | Rejected => { + println!("Event: {:?}, error: {}", event.event_type(), event.status()); + }, + TimewaitExit => {}, + _ => { + println!("Other events: {:?}", event.event_type()); + }, + } + let _ = event.ack(); +} + +impl<'a> Node<'a> { + fn create_qp(&mut self) { + unsafe { + let pd = PD.as_ref().unwrap(); + let cq = CQ.as_ref().unwrap(); + + let mut qp_builder = pd.create_qp_builder(); + + qp_builder + .setup_max_send_wr(1) + .setup_max_send_sge(1) + .setup_max_recv_wr(1) + .setup_max_recv_sge(1) + .setup_send_cq(cq.as_ref()) + .setup_recv_cq(cq.as_ref()); + + let qp = qp_builder.build_ex().unwrap().into(); + + self.qp = Some(qp); + } + } +} + +fn main() -> Result<(), Box<dyn std::error::Error>> { + let args = Args::parse(); + let mut results: Vec<StageResult> = Vec::new(); + + if args.server_address.is_some() { + let (resp_tx, resp_rx) = channel(); + + let _resp_handler = thread::spawn(move || loop { + let cm_id: Arc<Identifier> = resp_rx.recv().expect("Failed to receive cm_id"); + + let node: Arc<Mutex<Node>> = cm_id.get_context().unwrap(); + + { + let mut guard = node.lock().unwrap(); + let qp = guard.qp.as_mut().unwrap(); + + let attr = cm_id.get_qp_attr(QueuePairState::Init).unwrap(); + qp.modify(&attr).unwrap(); + + let attr = 
cm_id.get_qp_attr(QueuePairState::ReadyToReceive).unwrap(); + qp.modify(&attr).unwrap(); + + let attr = cm_id.get_qp_attr(QueuePairState::ReadyToSend).unwrap(); + qp.modify(&attr).unwrap(); + + cm_id.establish().unwrap(); + } + + end_perf!(node, Step::Connect); + COMPLETED[Step::Connect as usize].fetch_add(1, Ordering::Relaxed); + }); + + let mut nodes = Vec::with_capacity(args.connections as usize); + + start_time!(Step::CreateId); + for _i in 0..args.connections { + let node = Mutex::new(Node { + id: None, + qp: None, + times: [(Instant::recent(), Instant::recent()); Step::Count as usize], + }); + start_perf!(node, Step::CreateId); + let id = CHANNEL.lock().unwrap().create_id(PortSpace::Tcp)?; + end_perf!(node, Step::CreateId); + node.lock().unwrap().id = Some(id.clone()); + id.setup_context(node); + let node: Arc<Mutex<Node>> = id.get_context().unwrap(); + nodes.push(node); + } + end_time!(Step::CreateId, results, &nodes); + + let _dispatcher = thread::spawn(move || loop { + match CHANNEL.lock().unwrap().get_cm_event() { + Ok(event) => cma_handler(event.cm_id().unwrap(), event, Some(resp_tx.clone()), None, None), + Err(err) => { + eprintln!("{err}"); + break; + }, + } + }); + + let ip = IpAddr::from_str(&args.server_address.unwrap()).expect("Invalid IP address"); + let server_addr = SocketAddr::from((ip, args.port)); + + let ip = IpAddr::from_str(&args.bind_address.unwrap()).expect("Invalid IP address"); + let client_addr = SocketAddr::from((ip, 0)); + + start_time!(Step::ResolveAddr); + for node in &nodes { + start_perf!(node, Step::ResolveAddr); + if let Some(ref id) = node.lock().unwrap().id { + id.resolve_addr(Some(client_addr), server_addr, Duration::new(2, 0))?; + STARTED[Step::ResolveAddr as usize].fetch_add(1, Ordering::Relaxed); + } + } + + while STARTED[Step::ResolveAddr as usize].load(Ordering::Acquire) + != COMPLETED[Step::ResolveAddr as usize].load(Ordering::Acquire) + { + thread::yield_now(); + } + end_time!(Step::ResolveAddr, results, &nodes); + + 
start_time!(Step::ResolveRoute); + for node in &nodes { + start_perf!(node, Step::ResolveRoute); + if let Some(ref id) = node.lock().unwrap().id { + id.resolve_route(Duration::new(2, 0))?; + STARTED[Step::ResolveRoute as usize].fetch_add(1, Ordering::Relaxed); + } + } + + while STARTED[Step::ResolveRoute as usize].load(Ordering::Acquire) + != COMPLETED[Step::ResolveRoute as usize].load(Ordering::Acquire) + { + thread::yield_now(); + } + end_time!(Step::ResolveRoute, results, &nodes); + + start_time!(Step::CreateQueuePair); + for node in &nodes { + start_perf!(node, Step::CreateQueuePair); + { + let mut guard = node.lock().unwrap(); + if let Some(ref id) = guard.id { + OPEN_VERBS.call_once(|| unsafe { + CTX = Some(id.get_device_context().unwrap().clone()); + PD = Some(Arc::new(CTX.as_ref().unwrap().alloc_pd().unwrap())); + CQ = Some(Arc::new( + CTX.as_ref() + .unwrap() + .create_cq_builder() + .setup_cqe(1) + .build_ex() + .unwrap() + .into(), + )); + }); + guard.create_qp(); + } + } + end_perf!(node, Step::CreateQueuePair); + } + end_time!(Step::CreateQueuePair, results, &nodes); + + start_time!(Step::Connect); + for node in &nodes { + start_perf!(node, Step::Connect); + let guard = node.lock().unwrap(); + if let Some(ref id) = guard.id { + let qp = guard.qp.as_ref().unwrap(); + + let mut conn_param = ConnectionParameter::default(); + conn_param.setup_qp_number(qp.qp_number()); + + id.connect(conn_param)?; + + STARTED[Step::Connect as usize].fetch_add(1, Ordering::Relaxed); + } + } + + while STARTED[Step::Connect as usize].load(Ordering::Acquire) + != COMPLETED[Step::Connect as usize].load(Ordering::Acquire) + { + thread::yield_now(); + } + end_time!(Step::Connect, results, &nodes); + + start_time!(Step::Disconnect); + for node in &nodes { + start_perf!(node, Step::Disconnect); + if let Some(ref id) = node.lock().unwrap().id { + id.disconnect()?; + STARTED[Step::Disconnect as usize].fetch_add(1, Ordering::Relaxed); + } + } + + while STARTED[Step::Disconnect as 
usize].load(Ordering::Acquire) + != COMPLETED[Step::Disconnect as usize].load(Ordering::Acquire) + { + thread::yield_now(); + } + end_time!(Step::Disconnect, results, &nodes); + + let style = Style::psql().remove_verticals(); + + let table = Table::new(results) + .with(Modify::new(Segment::all()).with(Alignment::right())) + .with(style) + .modify(Columns::first(), Alignment::left()) + .to_string(); + + println!("{}", table); + } else { + let id = CHANNEL.lock().unwrap().create_id(PortSpace::Tcp)?; + id.bind_addr(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), args.port))?; + id.listen(1024)?; + + let mut nodes = vec![None; args.connections as usize]; + let node_idx = AtomicU32::new(0); + + let (req_tx, req_rx) = channel(); + let (disc_tx, disc_rx) = channel(); + + let dispatcher = thread::spawn(move || loop { + match CHANNEL.lock().unwrap().get_cm_event() { + Ok(event) => cma_handler( + event.cm_id().unwrap(), + event, + None, + Some(req_tx.clone()), + Some(disc_tx.clone()), + ), + Err(err) => { + eprintln!("{err}"); + break; + }, + } + }); + + let req_handler = thread::spawn(move || loop { + let cm_id: Arc = req_rx.recv().expect("Failed to receive cm_id"); + + let node = Arc::new(Mutex::new(Node { + id: Some(cm_id.clone()), + qp: None, + times: [(Instant::recent(), Instant::recent()); Step::Count as usize], + })); + + let mut conn_param = ConnectionParameter::default(); + + { + let mut guard = node.lock().unwrap(); + + guard.create_qp(); + + let qp = guard.qp.as_mut().unwrap(); + + let attr = cm_id.get_qp_attr(QueuePairState::Init).unwrap(); + qp.modify(&attr).unwrap(); + + let attr = cm_id.get_qp_attr(QueuePairState::ReadyToReceive).unwrap(); + qp.modify(&attr).unwrap(); + + let attr = cm_id.get_qp_attr(QueuePairState::ReadyToSend).unwrap(); + qp.modify(&attr).unwrap(); + + conn_param.setup_qp_number(qp.qp_number()); + } + + cm_id.setup_context(node.clone()); + nodes[(node_idx.fetch_add(1, Ordering::Relaxed)) as usize] = Some(node); + + 
cm_id.accept(Some(conn_param)).unwrap(); + }); + + let disc_handler = thread::spawn(move || loop { + let cm_id: Arc = disc_rx.recv().expect("Failed to receive cm_id"); + cm_id.disconnect().unwrap(); + }); + + let _ = req_handler.join(); + let _ = disc_handler.join(); + let _ = dispatcher.join(); + } + + Ok(()) +} diff --git a/src/cm/communication_manager.rs b/src/cm/communication_manager.rs index 4e60f7e..545af22 100644 --- a/src/cm/communication_manager.rs +++ b/src/cm/communication_manager.rs @@ -1,29 +1,115 @@ -use std::{mem::MaybeUninit, net::SocketAddr, ptr::NonNull}; +use std::any::Any; +use std::collections::HashMap; +use std::marker::PhantomData; +use std::mem::ManuallyDrop; +use std::ptr::{null, null_mut}; +use std::sync::{LazyLock, Mutex, Weak}; +use std::time::Duration; +use std::{io, mem::MaybeUninit, net::SocketAddr, ptr::NonNull, sync::Arc}; use os_socketaddr::OsSocketAddr; use rdma_mummy_sys::{ - rdma_bind_addr, rdma_cm_event, rdma_cm_id, rdma_create_event_channel, rdma_create_id, rdma_destroy_event_channel, - rdma_destroy_id, rdma_event_channel, rdma_listen, rdma_port_space, + ibv_qp_attr, ibv_qp_cap, ibv_qp_init_attr, rdma_accept, rdma_ack_cm_event, rdma_bind_addr, rdma_cm_event, + rdma_cm_event_type, rdma_cm_id, rdma_conn_param, rdma_connect, rdma_create_event_channel, rdma_create_id, + rdma_create_qp, rdma_destroy_event_channel, rdma_destroy_id, rdma_destroy_qp, rdma_disconnect, rdma_establish, + rdma_event_channel, rdma_get_cm_event, rdma_init_qp_attr, rdma_listen, rdma_port_space, rdma_resolve_addr, + rdma_resolve_route, }; use crate::verbs::device_context::DeviceContext; +use crate::verbs::queue_pair::{BasicQueuePair, QueuePair, QueuePairAttribute, QueuePairState, QueuePairType}; + +#[repr(u32)] +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] +pub enum EventType { + AddressResolved = rdma_cm_event_type::RDMA_CM_EVENT_ADDR_RESOLVED, + AddressError = rdma_cm_event_type::RDMA_CM_EVENT_ADDR_ERROR, + RouteResolved = 
rdma_cm_event_type::RDMA_CM_EVENT_ROUTE_RESOLVED, + RouteError = rdma_cm_event_type::RDMA_CM_EVENT_ROUTE_ERROR, + ConnectRequest = rdma_cm_event_type::RDMA_CM_EVENT_CONNECT_REQUEST, + ConnectResponse = rdma_cm_event_type::RDMA_CM_EVENT_CONNECT_RESPONSE, + ConnectError = rdma_cm_event_type::RDMA_CM_EVENT_CONNECT_ERROR, + Unreachable = rdma_cm_event_type::RDMA_CM_EVENT_UNREACHABLE, + Rejected = rdma_cm_event_type::RDMA_CM_EVENT_REJECTED, + Established = rdma_cm_event_type::RDMA_CM_EVENT_ESTABLISHED, + Disconnected = rdma_cm_event_type::RDMA_CM_EVENT_DISCONNECTED, + DeviceRemoval = rdma_cm_event_type::RDMA_CM_EVENT_DEVICE_REMOVAL, + MulticastJoin = rdma_cm_event_type::RDMA_CM_EVENT_MULTICAST_JOIN, + MulticastError = rdma_cm_event_type::RDMA_CM_EVENT_MULTICAST_ERROR, + AddressChange = rdma_cm_event_type::RDMA_CM_EVENT_ADDR_CHANGE, + TimewaitExit = rdma_cm_event_type::RDMA_CM_EVENT_TIMEWAIT_EXIT, +} + +impl From<u32> for EventType { + fn from(event: u32) -> Self { + match event { + rdma_cm_event_type::RDMA_CM_EVENT_ADDR_RESOLVED => EventType::AddressResolved, + rdma_cm_event_type::RDMA_CM_EVENT_ADDR_ERROR => EventType::AddressError, + rdma_cm_event_type::RDMA_CM_EVENT_ROUTE_RESOLVED => EventType::RouteResolved, + rdma_cm_event_type::RDMA_CM_EVENT_ROUTE_ERROR => EventType::RouteError, + rdma_cm_event_type::RDMA_CM_EVENT_CONNECT_REQUEST => EventType::ConnectRequest, + rdma_cm_event_type::RDMA_CM_EVENT_CONNECT_RESPONSE => EventType::ConnectResponse, + rdma_cm_event_type::RDMA_CM_EVENT_CONNECT_ERROR => EventType::ConnectError, + rdma_cm_event_type::RDMA_CM_EVENT_UNREACHABLE => EventType::Unreachable, + rdma_cm_event_type::RDMA_CM_EVENT_REJECTED => EventType::Rejected, + rdma_cm_event_type::RDMA_CM_EVENT_ESTABLISHED => EventType::Established, + rdma_cm_event_type::RDMA_CM_EVENT_DISCONNECTED => EventType::Disconnected, + rdma_cm_event_type::RDMA_CM_EVENT_DEVICE_REMOVAL => EventType::DeviceRemoval, + rdma_cm_event_type::RDMA_CM_EVENT_MULTICAST_JOIN => EventType::MulticastJoin, + 
rdma_cm_event_type::RDMA_CM_EVENT_MULTICAST_ERROR => EventType::MulticastError, + rdma_cm_event_type::RDMA_CM_EVENT_ADDR_CHANGE => EventType::AddressChange, + rdma_cm_event_type::RDMA_CM_EVENT_TIMEWAIT_EXIT => EventType::TimewaitExit, + _ => panic!("Unknown RDMA CM event type: {event}"), + } + } +} + +static DEVICE_LISTS: LazyLock>>> = LazyLock::new(|| Mutex::new(HashMap::new())); pub struct Event { - _event: NonNull, + event: NonNull, + cm_id: Option>, + listener_id: Option>, } pub struct EventChannel { channel: NonNull, } -pub struct CommunicationManager { +// enum QueuePairStatus { +// SelfCreatedBound, +// SelfCreatedDestroyed, +// CommunicationManagerCreated, +// CommunicationManagerDestroyed, +// NoQueuePairBound, +// } + +pub struct Identifier { cm_id: NonNull, + // queue_pair_status: QueuePairStatus, + user_context: Mutex>>, } +pub struct ConnectionParameter(rdma_conn_param); + pub enum PortSpace { - Ib = rdma_port_space::RDMA_PS_IB as isize, - Ipoib = rdma_port_space::RDMA_PS_IPOIB as isize, + /// Provides for any InfiniBand services (UD, UC, RC, XRC, etc.). + InfiniBand = rdma_port_space::RDMA_PS_IB as isize, + IpOverInfiniband = rdma_port_space::RDMA_PS_IPOIB as isize, + /// Provides reliable, connection-oriented QP communication. Unlike TCP, the RDMA port space + /// provides message, not stream, based communication. In other words, this would create a + /// [`QueuePair`] for [`ReliableConnection`]. + /// + /// [`QueuePair`]: crate::verbs::queue_pair::QueuePair + /// [`ReliableConnection`]: crate::verbs::queue_pair::QueuePairType::ReliableConnection + /// Tcp = rdma_port_space::RDMA_PS_TCP as isize, + /// Provides unreliable, connectionless QP communication. Supports both datagram and multicast + /// communication. In other words, this would create a [`QueuePair`] for [`UnreliableDatagram`]. 
+ /// + /// [`QueuePair`]: crate::verbs::queue_pair::QueuePair + /// [`UnreliableDatagram`]: crate::verbs::queue_pair::QueuePairType::UnreliableDatagram + /// Udp = rdma_port_space::RDMA_PS_UDP as isize, } @@ -35,14 +121,104 @@ impl Drop for EventChannel { } } -impl Drop for CommunicationManager { +impl Event { + /// Get the [`CommunicationManager`] associated with this [`Event`]. + /// + /// # Special cases + /// + /// - For [`EventType::ConnectRequest`]: + /// A new [`CommunicationManager`] is automatically created to handle + /// the incoming connection request. This is distinct from the listener + /// [`CommunicationManager`]. + /// + /// - For other event types: + /// Returns the existing [`CommunicationManager`] associated with the event. + /// + /// # Note + /// + /// To access the listener [`CommunicationManager`] in case of a connect request, + /// use the [`listener_id`] method instead. + /// + /// [`listener_id`]: crate::cm::communication_manager::Event::listener_id + /// + pub fn cm_id(&self) -> Option> { + self.cm_id.clone() + } + + /// Get the listener [`CommunicationManager`] associated with this [`Event`]. + /// + /// # Note + /// + /// This method is primarily useful for [`EventType::ConnectRequest`] events, + /// allowing access to the listener that received the connection request, for + /// other events, this method would return [`None`]. + pub fn listener_id(&self) -> Option> { + self.listener_id.clone() + } + + /// Get the event type of this event. + pub fn event_type(&self) -> EventType { + unsafe { self.event.as_ref().event.into() } + } + + /// Get the event status of this event, this would be useful when you get an error + /// event, for example, [`EventType::Rejected`]. + pub fn status(&self) -> i32 { + unsafe { self.event.as_ref().status } + } + + /// Acknowledge and free the communication event. + /// + /// # Note + /// + /// This method should be called to release events allocated by [`get_cm_event`]. 
+ /// There should be a one-to-one correspondence between successful gets and acks. + /// This call frees the event structure and any memory that it references. + /// + /// [`get_cm_event`]: crate::cm::communication_manager::EventChannel::get_cm_event + /// + pub fn ack(mut self) -> Result<(), String> { + let ret = unsafe { rdma_ack_cm_event(self.event.as_mut()) }; + + if ret < 0 { + return Err(format!("Failed to ack cm event {:?}", io::Error::last_os_error())); + } + + self.cm_id.take(); + self.listener_id.take(); + + // The event has been freed by rdma_ack_cm_event, so we don't need to drop it. + std::mem::forget(self); + + Ok(()) + } +} + +impl Drop for Event { fn drop(&mut self) { unsafe { - rdma_destroy_id(self.cm_id.as_mut()); + rdma_ack_cm_event(self.event.as_mut()); } } } +fn new_cm_id_for_raw(raw: *mut rdma_cm_id) -> Result, String> { + let cm = Arc::new(Identifier { + cm_id: NonNull::new(raw).unwrap(), + user_context: Mutex::new(None), + }); + + let weak_cm = Arc::downgrade(&cm.clone()); + let boxed = Box::new(weak_cm); + let raw_box = Box::into_raw(boxed); + + unsafe { + (*raw).context = raw_box as *mut std::ffi::c_void; + } + + Ok(cm) +} + impl EventChannel { pub fn new() -> Result { let channel = unsafe { rdma_create_event_channel() }; @@ -56,40 +232,101 @@ impl EventChannel { }) } - pub fn create_id(&mut self, ctx: DeviceContext) -> Result { - let mut cm_id = MaybeUninit::<*mut rdma_cm_id>::uninit(); - let ret; - - unsafe { - ret = rdma_create_id( - self.channel.as_mut(), - cm_id.as_mut_ptr(), - ctx.context as _, - PortSpace::Tcp as u32, - ); - } + pub fn create_id(&mut self, port_space: PortSpace) -> Result, String> { + let mut cm_id_ptr: *mut rdma_cm_id = null_mut(); + let ret = unsafe { rdma_create_id(self.channel.as_mut(), &mut cm_id_ptr, null_mut(), port_space as u32) }; if ret < 0 { return Err(format!("Failed to create cm_id {ret}")); } - unsafe { - cm_id.assume_init(); + new_cm_id_for_raw(cm_id_ptr) + } + + pub fn get_cm_event(&mut self) -> 
Result { + let mut event_ptr = MaybeUninit::<*mut rdma_cm_event>::uninit(); + + let ret = unsafe { rdma_get_cm_event(self.channel.as_ptr(), event_ptr.as_mut_ptr()) }; + + if ret < 0 { + return Err(format!("Failed to get cm event {:?}", io::Error::last_os_error())); } - Ok(CommunicationManager { - cm_id: unsafe { NonNull::new(*cm_id.as_mut_ptr()).unwrap_unchecked() }, + let event = unsafe { NonNull::new(event_ptr.assume_init()).unwrap() }; + + let cm_id = unsafe { + let raw_cm_id = event.as_ref().id; + + assert_ne!(raw_cm_id, null_mut()); + if event.as_ref().event == EventType::ConnectRequest as u32 { + // For connect requests, create a new CommunicationManager + Some(new_cm_id_for_raw(raw_cm_id).unwrap()) + } else { + // For other events, return the existing CommunicationManager + let context_ptr = (*raw_cm_id).context as *mut Weak; + assert_ne!(context_ptr, null_mut()); + (*context_ptr).clone().upgrade() + } + }; + + let listener_id = unsafe { + let raw_listen_id = event.as_ref().listen_id; + + if !raw_listen_id.is_null() { + let context_ptr = (*raw_listen_id).context as *mut Weak; + assert_ne!(context_ptr, null_mut()); + (*context_ptr).clone().upgrade() + } else { + None + } + }; + + Ok(Event { + event, + cm_id, + listener_id, }) } +} + +unsafe impl Send for EventChannel {} +unsafe impl Sync for EventChannel {} - pub fn get_cm_event() -> Result { - todo!(); +impl Drop for Identifier { + fn drop(&mut self) { + let cm_id = self.cm_id; + unsafe { + let _ = Box::from_raw((*cm_id.as_ptr()).context as *mut Weak); + rdma_destroy_id(cm_id.as_ptr()); + } } } -impl CommunicationManager { - pub fn bind_addr(&mut self, addr: SocketAddr) -> Result<(), String> { - let ret = unsafe { rdma_bind_addr(self.cm_id.as_mut(), OsSocketAddr::from(addr).as_mut_ptr()) }; +// Mark CommunicationManager as Sync & Send, implying that we guarantee its thread-safety +unsafe impl Sync for Identifier {} +unsafe impl Send for Identifier {} + +impl Identifier { + pub fn setup_context(&self, ctx: C) 
{ + let mut user_data = self.user_context.lock().unwrap(); + *user_data = Some(Arc::new(ctx)); + } + + pub fn get_context(&self) -> Option> { + let user_data = self.user_context.lock().unwrap(); + let arc_any = user_data.as_ref()?.clone(); + arc_any.downcast::().ok() + } + + pub fn port(&self) -> u8 { + let cm_id = self.cm_id; + + unsafe { cm_id.as_ref().port_num } + } + + pub fn bind_addr(&self, addr: SocketAddr) -> Result<(), String> { + let cm_id = self.cm_id; + let ret = unsafe { rdma_bind_addr(cm_id.as_ptr(), OsSocketAddr::from(addr).as_mut_ptr()) }; if ret < 0 { return Err(format!("Failed to bind addr {addr:?}, returned {ret}")); @@ -98,8 +335,47 @@ impl CommunicationManager { Ok(()) } - pub fn listen(&mut self, backlog: i32) -> Result<(), String> { - let ret = unsafe { rdma_listen(self.cm_id.as_mut(), backlog) }; + pub fn resolve_addr( + &self, src_addr: Option, dst_addr: SocketAddr, timeout: Duration, + ) -> Result<(), String> { + let cm_id = self.cm_id; + let timeout_ms: i32 = timeout.as_millis().try_into().unwrap(); + + let ret = unsafe { + rdma_resolve_addr( + cm_id.as_ptr(), + match src_addr { + Some(addr) => OsSocketAddr::from(addr).as_mut_ptr(), + None => null_mut(), + }, + OsSocketAddr::from(dst_addr).as_mut_ptr(), + timeout_ms, + ) + }; + + if ret < 0 { + return Err(format!("Failed to resolve address {ret}")); + } + + Ok(()) + } + + pub fn resolve_route(&self, timeout: Duration) -> Result<(), String> { + let cm_id = self.cm_id; + let timeout_ms: i32 = timeout.as_millis().try_into().unwrap(); + + let ret = unsafe { rdma_resolve_route(cm_id.as_ptr(), timeout_ms) }; + + if ret < 0 { + return Err(format!("Failed to resolve route {ret}")); + } + + Ok(()) + } + + pub fn listen(&self, backlog: i32) -> Result<(), String> { + let cm_id = self.cm_id; + let ret = unsafe { rdma_listen(cm_id.as_ptr(), backlog) }; if ret < 0 { return Err(format!("Failed to listen {ret}")); @@ -107,4 +383,231 @@ impl CommunicationManager { Ok(()) } + + pub fn 
get_device_context(&self) -> Option> { + let cm_id = self.cm_id; + + unsafe { + let mut guard = DEVICE_LISTS.lock().unwrap(); + let device_ctx = guard + .entry((*cm_id.as_ptr()).verbs as usize) + .or_insert(Arc::new(DeviceContext { + context: (*cm_id.as_ptr()).verbs, + })); + + Some(device_ctx.clone()) + } + } + + pub fn connect(&self, mut conn_param: ConnectionParameter) -> Result<(), String> { + let cm_id = self.cm_id; + let ret = unsafe { rdma_connect(cm_id.as_ptr(), &mut conn_param.0) }; + + if ret < 0 { + return Err(format!("Failed to connect {:?}", io::Error::last_os_error())); + } + + Ok(()) + } + + pub fn disconnect(&self) -> Result<(), String> { + let cm_id = self.cm_id; + let ret = unsafe { rdma_disconnect(cm_id.as_ptr()) }; + + if ret < 0 { + return Err(format!("Failed to disconnect {:?}", io::Error::last_os_error())); + } + + Ok(()) + } + + pub fn accept(&self, conn_param: Option) -> Result<(), String> { + let cm_id = self.cm_id; + + let ret = match conn_param { + Some(mut param) => unsafe { rdma_accept(cm_id.as_ptr(), &mut param.0) }, + None => unsafe { rdma_accept(cm_id.as_ptr(), null_mut()) }, + }; + + if ret < 0 { + return Err(format!("Failed to accept {:?}", io::Error::last_os_error())); + } + + Ok(()) + } + + pub fn establish(&self) -> Result<(), String> { + let cm_id = self.cm_id; + let ret = unsafe { rdma_establish(cm_id.as_ptr()) }; + + if ret < 0 { + return Err(format!("Failed to establish {:?}", io::Error::last_os_error())); + } + + Ok(()) + } + + pub fn create_qp(&self) -> Result<(), String> { + let cm_id = self.cm_id; + let mut attr = ibv_qp_init_attr { + qp_context: null_mut(), + send_cq: null_mut(), + recv_cq: null_mut(), + srq: null_mut(), + cap: ibv_qp_cap { + max_inline_data: 16, + max_recv_sge: 1, + max_recv_wr: 1, + max_send_sge: 1, + max_send_wr: 1, + }, + qp_type: QueuePairType::ReliableConnection as _, + sq_sig_all: 0, + }; + + let ret = unsafe { rdma_create_qp(cm_id.as_ptr(), null_mut(), &mut attr) }; + + if ret < 0 { + return 
Err(format!("Failed to create qp {:?}", io::Error::last_os_error())); + } + + Ok(()) + } + + pub fn destroy_qp(&self) -> Result<(), String> { + let cm_id = self.cm_id; + + unsafe { rdma_destroy_qp(cm_id.as_ptr()) }; + + Ok(()) + } + + pub fn bind_qp(&self, qp: &impl QueuePair) -> Result<(), String> { + let cm_id = self.cm_id; + + unsafe { (*cm_id.as_ptr()).qp = qp.qp().as_ptr() }; + + Ok(()) + } + + pub fn qp(&self) -> Result, String> { + let cm_id = self.cm_id; + + let qp_ptr = unsafe { cm_id.as_ref().qp }; + + Ok(ManuallyDrop::new(BasicQueuePair { + qp: NonNull::new(qp_ptr).expect("Failed to get bound QP"), + _phantom: PhantomData, + })) + } + + pub fn get_qp_attr(&self, state: QueuePairState) -> Result { + let cm_id = self.cm_id; + let mut attr = MaybeUninit::::uninit(); + let mut mask = 0; + + unsafe { (*attr.as_mut_ptr()).qp_state = state as _ }; + + let ret = unsafe { rdma_init_qp_attr(cm_id.as_ptr(), attr.as_mut_ptr(), &mut mask) }; + + if ret < 0 { + return Err(format!("Failed to get qp attr {:?}", io::Error::last_os_error())); + } + + Ok(QueuePairAttribute::from(unsafe { attr.assume_init_ref() }, mask)) + } +} + +impl Default for ConnectionParameter { + fn default() -> Self { + Self(rdma_conn_param { + private_data: null(), + private_data_len: 0, + responder_resources: 1, + initiator_depth: 1, + flow_control: 0, + retry_count: 7, + rnr_retry_count: 7, + srq: 0, + qp_num: 0, + }) + } +} + +impl ConnectionParameter { + pub fn new() -> Self { + Self(rdma_conn_param { + private_data: null(), + private_data_len: 0, + responder_resources: 0, + initiator_depth: 0, + flow_control: 0, + retry_count: 0, + rnr_retry_count: 0, + srq: 0, + qp_num: 0, + }) + } + + pub fn setup_qp_number(&mut self, qp_number: u32) -> &mut Self { + self.0.qp_num = qp_number; + self + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::{IpAddr, SocketAddr}; + use std::str::FromStr; + + #[test] + fn test_cm_id_reference_count() -> Result<(), Box> { + match EventChannel::new() 
{ + Ok(mut channel) => { + let id = channel.create_id(PortSpace::Tcp).unwrap(); + + assert_eq!(Arc::strong_count(&id), 1); + + let _ = id.resolve_addr( + None, + SocketAddr::from((IpAddr::from_str("127.0.0.1").expect("Invalid IP address"), 0)), + Duration::new(0, 200000000), + ); + + assert_eq!(Arc::strong_count(&id), 1); + + let event = channel.get_cm_event().unwrap(); + + assert_eq!(Arc::strong_count(&id), 2); + + let cm_id = event.cm_id().unwrap(); + + assert_eq!(Arc::strong_count(&id), 3); + assert_eq!(Arc::strong_count(&cm_id), 3); + + event.ack().unwrap(); + + assert_eq!(Arc::strong_count(&id), 2); + assert_eq!(Arc::strong_count(&cm_id), 2); + + Ok(()) + }, + Err(_) => Ok(()), + } + } + + #[test] + fn test_conn_param() -> Result<(), Box> { + match EventChannel::new() { + Ok(mut channel) => { + let _id = channel.create_id(PortSpace::Tcp).unwrap(); + + let _param = ConnectionParameter::new(); + + Ok(()) + }, + Err(_) => Ok(()), + } + } } diff --git a/src/verbs/queue_pair.rs b/src/verbs/queue_pair.rs index e7b99e4..b8c4cef 100644 --- a/src/verbs/queue_pair.rs +++ b/src/verbs/queue_pair.rs @@ -531,7 +531,7 @@ lazy_static! { pub struct BasicQueuePair<'res> { pub(crate) qp: NonNull, // phantom data for protection domain & completion queues - _phantom: PhantomData<&'res ()>, + pub(crate) _phantom: PhantomData<&'res ()>, } unsafe impl Send for BasicQueuePair<'_> {}