Skip to content

Commit

Permalink
feat(qp): introduce ModifyQueuePairError and PostSendError
Browse files Browse the repository at this point in the history
Also introduce anyhow for examples.

Signed-off-by: Luke Yue <[email protected]>
  • Loading branch information
dragonJACson committed Oct 13, 2024
1 parent 06bab73 commit 55beae3
Show file tree
Hide file tree
Showing 4 changed files with 188 additions and 60 deletions.
13 changes: 13 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ os_socketaddr = "0.2"
bitmask-enum = "2.2"
lazy_static = "1.5.0"
serde = { version = "1.0", features = ["derive"] }
thiserror = "1.0.64"

[dev-dependencies]
trybuild = "1.0"
Expand All @@ -33,3 +34,15 @@ quanta = "0.12"
byte-unit = "5.1"
ouroboros = "0.18"
proptest = "1.5"
anyhow = "1.0"

[features]
debug = []

[[example]]
name = "rc_pingpong"
required-features = ["debug"]

[[example]]
name = "rc_pingpong_split"
required-features = ["debug"]
34 changes: 16 additions & 18 deletions examples/rc_pingpong.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ struct TimeStamps {
}

#[allow(clippy::while_let_on_iterator)]
fn main() -> Result<(), Box<dyn std::error::Error>> {
fn main() -> anyhow::Result<()> {
let args = Args::parse();
let mut scnt: u32 = 0;
let mut rcnt: u32 = 0;
Expand All @@ -135,16 +135,16 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let device = match args.ib_dev {
Some(ib_dev) => device_list
.iter()
.find(|dev| dev.name().unwrap().eq(&ib_dev))
.find(|dev| dev.name()?.eq(&ib_dev))
.unwrap_or_else(|| panic!("IB device {ib_dev} not found")),
None => device_list.iter().next().expect("No IB device found"),
};

let context = device
.open()
.unwrap_or_else(|_| panic!("Couldn't get context for {}", device.name().unwrap()));
.unwrap_or_else(|_| panic!("Couldn't get context for {}", device.name()?));

let attr = context.query_device().unwrap();
let attr = context.query_device()?;

if args.ts {
completion_timestamp_mask = attr.completion_timestamp_mask();
Expand All @@ -162,7 +162,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.reg_managed_mr(args.size as _)
.unwrap_or_else(|_| panic!("Couldn't register recv MR"));

let gid = context.query_gid(args.ib_port, args.gid_idx.into()).unwrap();
let gid = context.query_gid(args.ib_port, args.gid_idx.into())?;
let psn = rand::random::<u32>() & 0xFFFFFF;

let mut cq_builder = context.create_cq_builder();
Expand All @@ -174,7 +174,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
);
}

let cq = cq_builder.setup_cqe(rx_depth + 1).build_ex().unwrap();
let cq = cq_builder.setup_cqe(rx_depth + 1).build_ex()?;

let mut builder = pd.create_qp_builder();

Expand All @@ -192,7 +192,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.setup_pkey_index(0)
.setup_port(args.ib_port)
.setup_access_flags(AccessFlags::LocalWrite | AccessFlags::RemoteWrite);
qp.modify(&attr).unwrap();
qp.modify(&attr)?;

for _i in 0..rx_depth {
let mut guard = qp.start_post_recv();
Expand All @@ -203,7 +203,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
recv_handle.setup_sge(recv_mr.lkey(), recv_mr.buf.data.as_ptr() as _, args.size);
};

guard.post().unwrap();
guard.post()?;
}

rout += rx_depth;
Expand All @@ -225,7 +225,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
};

let send_context = |stream: &mut TcpStream, dest: &PingPongDestination| {
let msg_buf = to_allocvec(dest).unwrap();
let msg_buf = to_allocvec(dest)?;
let size = msg_buf.len().to_be_bytes();
stream.write_all(&size)?;
stream.write_all(&msg_buf)?;
Expand All @@ -240,7 +240,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
msg_buf.clear();
msg_buf.resize(usize::from_be_bytes(size), 0);
stream.read_exact(&mut *msg_buf)?;
let dest: PingPongDestination = from_bytes(msg_buf).unwrap();
let dest: PingPongDestination = from_bytes(msg_buf)?;

Ok::<PingPongDestination, Error>(dest)
};
Expand Down Expand Up @@ -278,7 +278,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.setup_grh_dest_gid(&remote_context.gid)
.setup_grh_hop_limit(1);
attr.setup_address_vector(&ah_attr);
qp.modify(&attr).unwrap();
qp.modify(&attr)?;

let mut attr = QueuePairAttribute::new();
attr.setup_state(QueuePairState::ReadyToSend)
Expand All @@ -288,7 +288,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.setup_rnr_retry(7)
.setup_max_read_atomic(0);

qp.modify(&attr).unwrap();
qp.modify(&attr)?;

let clock = quanta::Clock::new();
let start_time = clock.now();
Expand All @@ -303,7 +303,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
send_handle.setup_sge(send_mr.lkey(), send_mr.buf.data.as_ptr() as _, send_mr.buf.len as _);
}

guard.post().unwrap();
guard.post()?;
outstanding_send = true;
}
// poll for the completion
Expand Down Expand Up @@ -342,7 +342,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
args.size,
);
};
guard.post().unwrap();
guard.post()?;
}
rout += to_post;
}
Expand Down Expand Up @@ -388,7 +388,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
send_mr.buf.len as _,
);
}
guard.post().unwrap();
guard.post()?;
outstanding_send = true;
}
}
Expand All @@ -414,9 +414,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
"{} bytes in {:.2} seconds = {:.2}/s",
bytes,
time.as_secs_f64(),
Byte::from_f64(bytes_per_second)
.unwrap()
.get_appropriate_unit(UnitType::Binary)
Byte::from_f64(bytes_per_second)?.get_appropriate_unit(UnitType::Binary)
);
println!(
"{} iters in {:.2} seconds = {:#.2?}/iter",
Expand Down
42 changes: 20 additions & 22 deletions examples/rc_pingpong_split.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,12 @@ struct PingPongContext {
}

impl PingPongContext {
fn build(device: &Device, size: u32, rx_depth: u32, ib_port: u8, use_ts: bool) -> Result<PingPongContext, String> {
fn build(device: &Device, size: u32, rx_depth: u32, ib_port: u8, use_ts: bool) -> anyhow::Result<PingPongContext> {
let context = device
.open()
.unwrap_or_else(|_| panic!("Couldn't get context for {}", device.name().unwrap()));
.unwrap_or_else(|_| panic!("Couldn't get context for {}", device.name()?));

let attr = context.query_device().unwrap();
let attr = context.query_device()?;

let completion_timestamp_mask = if use_ts {
match attr.completion_timestamp_mask() {
Expand All @@ -167,7 +167,7 @@ impl PingPongContext {
| CreateCompletionQueueWorkCompletionFlags::CompletionTimestamp,
);
}
let cq = cq_builder.setup_cqe(rx_depth + 1).build_ex().unwrap();
let cq = cq_builder.setup_cqe(rx_depth + 1).build_ex()?;
cq
},
|pd, cq| {
Expand Down Expand Up @@ -209,7 +209,7 @@ impl PingPongContext {
))
}

fn post_recv(&mut self, num: u32) -> Result<(), String> {
fn post_recv(&mut self, num: u32) -> anyhow::Result<()> {
for _i in 0..num {
let (mut guard, lkey, ptr, size) = self.with_mut(|fields| {
(
Expand All @@ -226,13 +226,13 @@ impl PingPongContext {
recv_handle.setup_sge(lkey, ptr, size);
};

guard.post().unwrap();
guard.post()?;
}

Ok(())
}

fn post_send(&mut self) -> Result<(), String> {
fn post_send(&mut self) -> anyhow::Result<()> {
let (mut guard, lkey, ptr, size) = self.with_mut(|fields| {
(
fields.qp.start_post_send(),
Expand All @@ -253,7 +253,7 @@ impl PingPongContext {

fn connect(
&mut self, remote_context: &PingPongDestination, ib_port: u8, psn: u32, mtu: Mtu, sl: u8, gid_idx: u8,
) -> Result<(), String> {
) -> anyhow::Result<()> {
let mut attr = QueuePairAttribute::new();
attr.setup_state(QueuePairState::ReadyToReceive)
.setup_path_mtu(mtu)
Expand Down Expand Up @@ -362,7 +362,7 @@ struct TimeStamps {
}

#[allow(clippy::while_let_on_iterator)]
fn main() -> Result<(), Box<dyn std::error::Error>> {
fn main() -> anyhow::Result<()> {
let args = Args::parse();
let mut scnt: u32 = 0;
let mut rcnt: u32 = 0;
Expand All @@ -381,17 +381,17 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let device = match args.ib_dev {
Some(ib_dev) => device_list
.iter()
.find(|dev| dev.name().unwrap().eq(&ib_dev))
.find(|dev| dev.name()?.eq(&ib_dev))
.unwrap_or_else(|| panic!("IB device {ib_dev} not found")),
None => device_list.iter().next().expect("No IB device found"),
};

let mut ctx = PingPongContext::build(&device, args.size, rx_depth, args.ib_port, args.ts).unwrap();
let mut ctx = PingPongContext::build(&device, args.size, rx_depth, args.ib_port, args.ts)?;

let gid = ctx.borrow_ctx().query_gid(args.ib_port, args.gid_idx.into()).unwrap();
let gid = ctx.borrow_ctx().query_gid(args.ib_port, args.gid_idx.into())?;
let psn = rand::random::<u32>() & 0xFFFFFF;

ctx.post_recv(rx_depth).unwrap();
ctx.post_recv(rx_depth)?;
rout += rx_depth;

println!(
Expand All @@ -414,7 +414,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
};

let send_context = |stream: &mut TcpStream, dest: &PingPongDestination| {
let msg_buf = to_allocvec(dest).unwrap();
let msg_buf = to_allocvec(dest)?;
let size = msg_buf.len().to_be_bytes();
stream.write_all(&size)?;
stream.write_all(&msg_buf)?;
Expand All @@ -429,7 +429,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
msg_buf.clear();
msg_buf.resize(usize::from_be_bytes(size), 0);
stream.read_exact(&mut *msg_buf)?;
let dest: PingPongDestination = from_bytes(msg_buf).unwrap();
let dest: PingPongDestination = from_bytes(msg_buf)?;

Ok::<PingPongDestination, Error>(dest)
};
Expand All @@ -448,15 +448,14 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
remote_context.qp_number, remote_context.packet_seq_number, remote_context.gid
);

ctx.connect(&remote_context, args.ib_port, psn, args.mtu.0, args.sl, args.gid_idx)
.unwrap();
ctx.connect(&remote_context, args.ib_port, psn, args.mtu.0, args.sl, args.gid_idx)?;

let clock = quanta::Clock::new();
let start_time = clock.now();
let mut outstanding_send = false;

if args.server_ip.is_some() {
ctx.post_send().unwrap();
ctx.post_send()?;
outstanding_send = true;
}
// poll for the completion
Expand Down Expand Up @@ -497,12 +496,12 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
}

if need_post_recv {
ctx.post_recv(to_post_recv).unwrap();
ctx.post_recv(to_post_recv)?;
rout += to_post_recv;
}

if need_post_send {
ctx.post_send().unwrap();
ctx.post_send()?;
}

// Check if we're done
Expand All @@ -521,8 +520,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
"{} bytes in {:.2} seconds = {:.2}/s",
bytes,
time.as_secs_f64(),
Byte::from_f64(bytes_per_second)
.unwrap()
Byte::from_f64(bytes_per_second)?
.get_appropriate_unit(UnitType::Binary)
);
println!(
Expand Down
Loading

0 comments on commit 55beae3

Please sign in to comment.