diff --git a/.github/workflows/ci_test.yml b/.github/workflows/ci_test.yml index 579fad74..31ac965c 100644 --- a/.github/workflows/ci_test.yml +++ b/.github/workflows/ci_test.yml @@ -43,6 +43,9 @@ jobs: target: "x86_64-pc-windows-gnu" - os: "windows-latest" target: "i686-pc-windows-msvc" + - os: "windows-latest" + target: "x86_64-pc-windows-msvc" + features: "iocp-global" - os: "macos-12" - os: "macos-13" - os: "macos-14" diff --git a/compio-dispatcher/tests/listener.rs b/compio-dispatcher/tests/listener.rs index 9ffda59c..4be7cdb0 100644 --- a/compio-dispatcher/tests/listener.rs +++ b/compio-dispatcher/tests/listener.rs @@ -1,10 +1,10 @@ use std::{num::NonZeroUsize, panic::resume_unwind}; -use compio_buf::{arrayvec::ArrayVec, IntoInner}; +use compio_buf::arrayvec::ArrayVec; use compio_dispatcher::Dispatcher; use compio_io::{AsyncReadExt, AsyncWriteExt}; use compio_net::{TcpListener, TcpStream}; -use compio_runtime::{spawn, Unattached}; +use compio_runtime::spawn; use futures_util::{stream::FuturesUnordered, StreamExt}; #[compio_macros::test] @@ -27,15 +27,11 @@ async fn listener_dispatch() { }); let mut handles = FuturesUnordered::new(); for _i in 0..CLIENT_NUM { - let (srv, _) = listener.accept().await.unwrap(); - let srv = Unattached::new(srv).unwrap(); + let (mut srv, _) = listener.accept().await.unwrap(); let handle = dispatcher - .dispatch(move || { - let mut srv = srv.into_inner(); - async move { - let (_, buf) = srv.read_exact(ArrayVec::::new()).await.unwrap(); - assert_eq!(buf.as_slice(), b"Hello world!"); - } + .dispatch(move || async move { + let (_, buf) = srv.read_exact(ArrayVec::::new()).await.unwrap(); + assert_eq!(buf.as_slice(), b"Hello world!"); }) .unwrap(); handles.push(handle.join()); diff --git a/compio-driver/Cargo.toml b/compio-driver/Cargo.toml index f44ce070..c1fddf52 100644 --- a/compio-driver/Cargo.toml +++ b/compio-driver/Cargo.toml @@ -70,6 +70,7 @@ polling = "3.3.0" os_pipe = { workspace = true } [target.'cfg(unix)'.dependencies] +crossbeam-channel = { workspace = true } crossbeam-queue = { workspace = true } libc = { workspace = true } @@ -83,6 +84,8 @@ polling = ["dep:polling", "dep:os_pipe"] io-uring-sqe128 = [] io-uring-cqe32 = [] +iocp-global = [] + # Nightly features once_cell_try = [] nightly = ["once_cell_try"] diff --git a/compio-driver/src/fusion/mod.rs b/compio-driver/src/fusion/mod.rs index de13546e..65fddd03 100644 --- a/compio-driver/src/fusion/mod.rs +++ b/compio-driver/src/fusion/mod.rs @@ -132,6 +132,13 @@ impl Driver { } } + pub fn create_op(&self, user_data: usize, op: T) -> RawOp { + match &self.fuse { + FuseDriver::Poll(driver) => driver.create_op(user_data, op), + FuseDriver::IoUring(driver) => driver.create_op(user_data, op), + } + } + pub fn attach(&mut self, fd: RawFd) -> io::Result<()> { match &mut self.fuse { FuseDriver::Poll(driver) => driver.attach(fd), diff --git a/compio-driver/src/iocp/cp/global.rs b/compio-driver/src/iocp/cp/global.rs new file mode 100644 index 00000000..fd093215 --- /dev/null +++ b/compio-driver/src/iocp/cp/global.rs @@ -0,0 +1,149 @@ +#[cfg(feature = "once_cell_try")] +use std::sync::OnceLock; +use std::{ + io, + os::windows::io::{AsRawHandle, RawHandle}, + time::Duration, +}; + +use compio_log::*; +#[cfg(not(feature = "once_cell_try"))] +use once_cell::sync::OnceCell as OnceLock; +use windows_sys::Win32::System::IO::PostQueuedCompletionStatus; + +use super::CompletionPort; +use crate::{syscall, Entry, Overlapped, RawFd}; + +struct GlobalPort { + port: CompletionPort, +} + +impl GlobalPort { + pub fn new() -> io::Result { + Ok(Self { + port: CompletionPort::new()?, + }) + } + + pub fn attach(&self, fd: RawFd) -> io::Result<()> { + self.port.attach(fd) + } + + pub fn post( + &self, + res: io::Result, + optr: *mut Overlapped, + ) -> io::Result<()> { + self.port.post(res, optr) + } + + pub fn post_raw(&self, optr: *const Overlapped) -> io::Result<()> { + self.port.post_raw(optr) + } +} + +impl AsRawHandle for GlobalPort { + fn as_raw_handle(&self) -> RawHandle { + self.port.as_raw_handle() + } +} + +static IOCP_PORT: OnceLock = OnceLock::new(); + +#[inline] +fn iocp_port() -> io::Result<&'static GlobalPort> { + IOCP_PORT.get_or_try_init(GlobalPort::new) +} + +fn iocp_start() -> io::Result<()> { + let port = iocp_port()?; + std::thread::spawn(move || { + instrument!(compio_log::Level::TRACE, "iocp_start"); + loop { + for entry in port.port.poll_raw(None)? { + // Any thin pointer is OK because we don't use the type of opcode. + let overlapped_ptr: *mut Overlapped<()> = entry.lpOverlapped.cast(); + let overlapped = unsafe { &*overlapped_ptr }; + if let Err(_e) = syscall!( + BOOL, + PostQueuedCompletionStatus( + overlapped.driver as _, + entry.dwNumberOfBytesTransferred, + entry.lpCompletionKey, + entry.lpOverlapped, + ) + ) { + error!( + "fail to dispatch entry ({}, {}, {:p}) to driver {:p}: {:?}", + entry.dwNumberOfBytesTransferred, + entry.lpCompletionKey, + entry.lpOverlapped, + overlapped.driver, + _e + ); + } + } + } + #[allow(unreachable_code)] + io::Result::Ok(()) + }); + Ok(()) +} + +static IOCP_INIT_ONCE: OnceLock<()> = OnceLock::new(); + +pub struct Port { + port: CompletionPort, + global_port: &'static GlobalPort, +} + +impl Port { + pub fn new() -> io::Result { + IOCP_INIT_ONCE.get_or_try_init(iocp_start)?; + + Ok(Self { + port: CompletionPort::new()?, + global_port: iocp_port()?, + }) + } + + pub fn attach(&mut self, fd: RawFd) -> io::Result<()> { + self.global_port.attach(fd) + } + + pub fn handle(&self) -> PortHandle { + PortHandle::new(self.global_port) + } + + pub fn poll(&self, timeout: Option) -> io::Result + '_> { + self.port.poll(timeout, None) + } +} + +impl AsRawHandle for Port { + fn as_raw_handle(&self) -> RawHandle { + self.port.as_raw_handle() + } +} + +pub struct PortHandle { + port: &'static GlobalPort, +} + +impl PortHandle { + fn new(port: &'static GlobalPort) -> Self { + Self { port } + } + + pub fn post( + &self, + res: io::Result, + optr: *mut Overlapped, + ) -> io::Result<()> { + self.port.post(res, optr) + } + + pub fn post_raw(&self, optr: *const Overlapped) -> io::Result<()> { + self.port.post_raw(optr) + } +} diff --git a/compio-driver/src/iocp/cp/mod.rs b/compio-driver/src/iocp/cp/mod.rs new file mode 100644 index 00000000..f78972f0 --- /dev/null +++ b/compio-driver/src/iocp/cp/mod.rs @@ -0,0 +1,201 @@ +//! Completion Port +//! +//! This mod contains utilities of IOCP. It provides 2 working modes: +//! IOCP-per-thread, and IOCP-global. +//! +//! ## IOCP-per-thread +//! In `mod multi`. Each driver hosts a seperate port. If the port receives +//! entry that doesn't belong to the current port, it will try to repost it to +//! the correct port. +//! +//! ## IOCP-global +//! In `mod global`. A main port runs in a separate thread, and dispatches all +//! entries to the correct driver. + +use std::{ + io, + os::windows::io::{AsRawHandle, FromRawHandle, OwnedHandle, RawHandle}, + time::Duration, +}; + +use compio_buf::arrayvec::ArrayVec; +use compio_log::*; +use windows_sys::Win32::{ + Foundation::{ + RtlNtStatusToDosError, ERROR_BAD_COMMAND, ERROR_HANDLE_EOF, ERROR_IO_INCOMPLETE, + ERROR_NO_DATA, FACILITY_NTWIN32, INVALID_HANDLE_VALUE, NTSTATUS, STATUS_PENDING, + STATUS_SUCCESS, + }, + Storage::FileSystem::SetFileCompletionNotificationModes, + System::{ + SystemServices::ERROR_SEVERITY_ERROR, + Threading::INFINITE, + WindowsProgramming::{FILE_SKIP_COMPLETION_PORT_ON_SUCCESS, FILE_SKIP_SET_EVENT_ON_HANDLE}, + IO::{ + CreateIoCompletionPort, GetQueuedCompletionStatusEx, PostQueuedCompletionStatus, + OVERLAPPED_ENTRY, + }, + }, +}; + +use crate::{syscall, Entry, Overlapped, RawFd}; + +cfg_if::cfg_if! { + if #[cfg(feature = "iocp-global")] { + mod global; + pub use global::*; + } else { + mod multi; + pub use multi::*; + } +} + +struct CompletionPort { + port: OwnedHandle, +} + +impl CompletionPort { + pub fn new() -> io::Result { + let port = syscall!(BOOL, CreateIoCompletionPort(INVALID_HANDLE_VALUE, 0, 0, 1))?; + trace!("new iocp handle: {port}"); + let port = unsafe { OwnedHandle::from_raw_handle(port as _) }; + Ok(Self { port }) + } + + pub fn attach(&self, fd: RawFd) -> io::Result<()> { + syscall!( + BOOL, + CreateIoCompletionPort(fd as _, self.port.as_raw_handle() as _, 0, 0) + )?; + syscall!( + BOOL, + SetFileCompletionNotificationModes( + fd as _, + (FILE_SKIP_COMPLETION_PORT_ON_SUCCESS | FILE_SKIP_SET_EVENT_ON_HANDLE) as _ + ) + )?; + Ok(()) + } + + pub fn post( + &self, + res: io::Result, + optr: *mut Overlapped, + ) -> io::Result<()> { + if let Some(overlapped) = unsafe { optr.as_mut() } { + match &res { + Ok(transferred) => { + overlapped.base.Internal = STATUS_SUCCESS as _; + overlapped.base.InternalHigh = *transferred; + } + Err(e) => { + let code = e.raw_os_error().unwrap_or(ERROR_BAD_COMMAND as _); + overlapped.base.Internal = ntstatus_from_win32(code) as _; + } + } + } + self.post_raw(optr) + } + + pub fn post_raw(&self, optr: *const Overlapped) -> io::Result<()> { + syscall!( + BOOL, + PostQueuedCompletionStatus(self.port.as_raw_handle() as _, 0, 0, optr.cast()) + )?; + Ok(()) + } + + pub fn poll_raw( + &self, + timeout: Option, + ) -> io::Result> { + const DEFAULT_CAPACITY: usize = 1024; + + let mut entries = ArrayVec::::new(); + let mut recv_count = 0; + let timeout = match timeout { + Some(timeout) => timeout.as_millis() as u32, + None => INFINITE, + }; + syscall!( + BOOL, + GetQueuedCompletionStatusEx( + self.port.as_raw_handle() as _, + entries.as_mut_ptr(), + DEFAULT_CAPACITY as _, + &mut recv_count, + timeout, + 0 + ) + )?; + trace!("recv_count: {recv_count}"); + unsafe { entries.set_len(recv_count as _) }; + + Ok(entries.into_iter()) + } + + // If current_driver is specified, any entry that doesn't belong the driver will + // be reposted. The driver id will be used as IOCP handle. + pub fn poll( + &self, + timeout: Option, + current_driver: Option, + ) -> io::Result> { + Ok(self.poll_raw(timeout)?.map(move |entry| { + // Any thin pointer is OK because we don't use the type of opcode. + let overlapped_ptr: *mut Overlapped<()> = entry.lpOverlapped.cast(); + let overlapped = unsafe { &*overlapped_ptr }; + if let Some(current_driver) = current_driver { + if overlapped.driver != current_driver { + // Repose the entry to correct port. + if let Err(_e) = syscall!( + BOOL, + PostQueuedCompletionStatus( + overlapped.driver as _, + entry.dwNumberOfBytesTransferred, + entry.lpCompletionKey, + entry.lpOverlapped, + ) + ) { + error!( + "fail to repost entry ({}, {}, {:p}) to driver {:p}: {:?}", + entry.dwNumberOfBytesTransferred, + entry.lpCompletionKey, + entry.lpOverlapped, + overlapped.driver, + _e + ); + } + } + } + let res = if matches!( + overlapped.base.Internal as NTSTATUS, + STATUS_SUCCESS | STATUS_PENDING + ) { + Ok(overlapped.base.InternalHigh) + } else { + let error = unsafe { RtlNtStatusToDosError(overlapped.base.Internal as _) }; + match error { + ERROR_IO_INCOMPLETE | ERROR_HANDLE_EOF | ERROR_NO_DATA => Ok(0), + _ => Err(io::Error::from_raw_os_error(error as _)), + } + }; + Entry::new(overlapped.user_data, res) + })) + } +} + +impl AsRawHandle for CompletionPort { + fn as_raw_handle(&self) -> RawHandle { + self.port.as_raw_handle() + } +} + +#[inline] +fn ntstatus_from_win32(x: i32) -> NTSTATUS { + if x <= 0 { + x + } else { + ((x) & 0x0000FFFF) | (FACILITY_NTWIN32 << 16) as NTSTATUS | ERROR_SEVERITY_ERROR as NTSTATUS + } +} diff --git a/compio-driver/src/iocp/cp/multi.rs b/compio-driver/src/iocp/cp/multi.rs new file mode 100644 index 00000000..5f99956c --- /dev/null +++ b/compio-driver/src/iocp/cp/multi.rs @@ -0,0 +1,62 @@ +use std::{ + io, + os::windows::io::{AsRawHandle, RawHandle}, + sync::Arc, + time::Duration, +}; + +use super::CompletionPort; +use crate::{Entry, Overlapped, RawFd}; + +pub struct Port { + port: Arc, +} + +impl Port { + pub fn new() -> io::Result { + Ok(Self { + port: Arc::new(CompletionPort::new()?), + }) + } + + pub fn attach(&mut self, fd: RawFd) -> io::Result<()> { + self.port.attach(fd) + } + + pub fn handle(&self) -> PortHandle { + PortHandle::new(self.port.clone()) + } + + pub fn poll(&self, timeout: Option) -> io::Result + '_> { + let current_id = self.as_raw_handle(); + self.port.poll(timeout, Some(current_id)) + } +} + +impl AsRawHandle for Port { + fn as_raw_handle(&self) -> RawHandle { + self.port.as_raw_handle() + } +} + +pub struct PortHandle { + port: Arc, +} + +impl PortHandle { + fn new(port: Arc) -> Self { + Self { port } + } + + pub fn post( + &self, + res: io::Result, + optr: *mut Overlapped, + ) -> io::Result<()> { + self.port.post(res, optr) + } + + pub fn post_raw(&self, optr: *const Overlapped) -> io::Result<()> { + self.port.post_raw(optr) + } +} diff --git a/compio-driver/src/iocp/mod.rs b/compio-driver/src/iocp/mod.rs index f818b976..4f900d65 100644 --- a/compio-driver/src/iocp/mod.rs +++ b/compio-driver/src/iocp/mod.rs @@ -4,41 +4,30 @@ use std::{ mem::ManuallyDrop, os::windows::prelude::{ AsRawHandle, AsRawSocket, FromRawHandle, FromRawSocket, IntoRawHandle, IntoRawSocket, - OwnedHandle, RawHandle, + RawHandle, }, pin::Pin, - ptr::{null_mut, NonNull}, + ptr::NonNull, sync::Arc, task::Poll, time::Duration, }; -use compio_buf::{arrayvec::ArrayVec, BufResult}; +use compio_buf::BufResult; use compio_log::{instrument, trace}; use slab::Slab; use windows_sys::Win32::{ - Foundation::{ - RtlNtStatusToDosError, ERROR_BAD_COMMAND, ERROR_BUSY, ERROR_HANDLE_EOF, - ERROR_IO_INCOMPLETE, ERROR_NO_DATA, ERROR_OPERATION_ABORTED, FACILITY_NTWIN32, - INVALID_HANDLE_VALUE, NTSTATUS, STATUS_PENDING, STATUS_SUCCESS, - }, + Foundation::{ERROR_BUSY, ERROR_OPERATION_ABORTED}, Networking::WinSock::{WSACleanup, WSAStartup, WSADATA}, - Storage::FileSystem::SetFileCompletionNotificationModes, - System::{ - SystemServices::ERROR_SEVERITY_ERROR, - Threading::INFINITE, - WindowsProgramming::{FILE_SKIP_COMPLETION_PORT_ON_SUCCESS, FILE_SKIP_SET_EVENT_ON_HANDLE}, - IO::{ - CreateIoCompletionPort, GetQueuedCompletionStatusEx, PostQueuedCompletionStatus, - OVERLAPPED, OVERLAPPED_ENTRY, - }, - }, + System::IO::OVERLAPPED, }; use crate::{syscall, AsyncifyPool, Entry, OutEntries, ProactorBuilder}; pub(crate) mod op; +mod cp; + pub(crate) use windows_sys::Win32::Networking::WinSock::{ socklen_t, SOCKADDR_STORAGE as sockaddr_storage, }; @@ -141,24 +130,15 @@ pub trait OpCode { } } -fn ntstatus_from_win32(x: i32) -> NTSTATUS { - if x <= 0 { - x - } else { - ((x) & 0x0000FFFF) | (FACILITY_NTWIN32 << 16) as NTSTATUS | ERROR_SEVERITY_ERROR as NTSTATUS - } -} - /// Low-level driver of IOCP. pub(crate) struct Driver { - // IOCP handle could not be duplicated. - port: Arc, + port: cp::Port, cancelled: HashSet, pool: AsyncifyPool, + notify_overlapped: Arc>, } impl Driver { - const DEFAULT_CAPACITY: usize = 1024; const NOTIFY: usize = usize::MAX; pub fn new(builder: &ProactorBuilder) -> io::Result { @@ -166,96 +146,22 @@ impl Driver { let mut data: WSADATA = unsafe { std::mem::zeroed() }; syscall!(SOCKET, WSAStartup(0x202, &mut data))?; - let port = syscall!(BOOL, CreateIoCompletionPort(INVALID_HANDLE_VALUE, 0, 0, 0))?; - trace!("new iocp driver at port: {port}"); - let port = unsafe { OwnedHandle::from_raw_handle(port as _) }; + let port = cp::Port::new()?; + let driver = port.as_raw_handle() as _; Ok(Self { - port: Arc::new(port), + port, cancelled: HashSet::default(), pool: builder.create_or_get_thread_pool(), + notify_overlapped: Arc::new(Overlapped::new(driver, Self::NOTIFY, ())), }) } - #[inline] - fn poll_impl( - &mut self, - timeout: Option, - iocp_entries: &mut ArrayVec, - ) -> io::Result<()> { - instrument!(compio_log::Level::TRACE, "poll_impl", ?timeout); - let mut recv_count = 0; - let timeout = match timeout { - Some(timeout) => timeout.as_millis() as u32, - None => INFINITE, - }; - syscall!( - BOOL, - GetQueuedCompletionStatusEx( - self.port.as_raw_handle() as _, - iocp_entries.as_mut_ptr(), - N as _, - &mut recv_count, - timeout, - 0, - ) - )?; - trace!("recv_count: {recv_count}"); - unsafe { - iocp_entries.set_len(recv_count as _); - } - Ok(()) - } - - fn create_entry(&mut self, iocp_entry: OVERLAPPED_ENTRY) -> Option { - if iocp_entry.lpOverlapped.is_null() { - // This entry is posted by `post_driver_nop`. - let user_data = iocp_entry.lpCompletionKey; - trace!("entry {user_data} is posted by post_driver_nop"); - if user_data != Self::NOTIFY { - let result = if self.cancelled.remove(&user_data) { - Err(io::Error::from_raw_os_error(ERROR_OPERATION_ABORTED as _)) - } else { - Ok(0) - }; - Some(Entry::new(user_data, result)) - } else { - None - } - } else { - let transferred = iocp_entry.dwNumberOfBytesTransferred; - // Any thin pointer is OK because we don't use the type of opcode. - trace!("entry transferred: {transferred}"); - let overlapped_ptr: *mut Overlapped<()> = iocp_entry.lpOverlapped.cast(); - let overlapped = unsafe { &*overlapped_ptr }; - let res = if matches!( - overlapped.base.Internal as NTSTATUS, - STATUS_SUCCESS | STATUS_PENDING - ) { - Ok(transferred as _) - } else { - let error = unsafe { RtlNtStatusToDosError(overlapped.base.Internal as _) }; - match error { - ERROR_IO_INCOMPLETE | ERROR_HANDLE_EOF | ERROR_NO_DATA => Ok(0), - _ => Err(io::Error::from_raw_os_error(error as _)), - } - }; - Some(Entry::new(overlapped.user_data, res)) - } + pub fn create_op(&self, user_data: usize, op: T) -> RawOp { + RawOp::new(self.port.as_raw_handle() as _, user_data, op) } pub fn attach(&mut self, fd: RawFd) -> io::Result<()> { - syscall!( - BOOL, - CreateIoCompletionPort(fd as _, self.port.as_raw_handle() as _, 0, 0) - )?; - syscall!( - BOOL, - SetFileCompletionNotificationModes( - fd as _, - (FILE_SKIP_COMPLETION_PORT_ON_SUCCESS | FILE_SKIP_SET_EVENT_ON_HANDLE) as _ - ) - )?; - Ok(()) + self.port.attach(fd) } pub fn cancel(&mut self, user_data: usize, registry: &mut Slab) { @@ -284,7 +190,7 @@ impl Driver { let op_pin = op.as_op_pin(); if op_pin.is_overlapped() { unsafe { op_pin.operate(optr.cast()) } - } else if self.push_blocking(op) { + } else if self.push_blocking(op)? { Poll::Pending } else { Poll::Ready(Err(io::Error::from_raw_os_error(ERROR_BUSY as _))) @@ -292,14 +198,15 @@ impl Driver { } } - fn push_blocking(&mut self, op: &mut RawOp) -> bool { + fn push_blocking(&mut self, op: &mut RawOp) -> io::Result { // Safety: the RawOp is not released before the operation returns. struct SendWrapper(T); unsafe impl Send for SendWrapper {} let optr = SendWrapper(NonNull::from(op)); - let handle = self.as_raw_fd() as _; - self.pool + let port = self.port.handle(); + Ok(self + .pool .dispatch(move || { #[allow(clippy::redundant_locals)] let mut optr = optr; @@ -312,22 +219,23 @@ impl Driver { Poll::Pending => unreachable!("this operation is not overlapped"), Poll::Ready(res) => res, }; - if let Err(e) = &res { - let code = e.raw_os_error().unwrap_or(ERROR_BAD_COMMAND as _); - unsafe { &mut *optr }.base.Internal = ntstatus_from_win32(code) as _; - } - syscall!( - BOOL, - PostQueuedCompletionStatus( - handle, - res.unwrap_or_default() as _, - 0, - optr.cast() - ) - ) - .ok(); + port.post(res, optr).ok(); }) - .is_ok() + .is_ok()) + } + + fn create_entry(cancelled: &mut HashSet, entry: Entry) -> Option { + let user_data = entry.user_data(); + if user_data != Self::NOTIFY { + let result = if cancelled.remove(&user_data) { + Err(io::Error::from_raw_os_error(ERROR_OPERATION_ABORTED as _)) + } else { + entry.into_result() + }; + Some(Entry::new(user_data, result)) + } else { + None + } } pub unsafe fn poll( @@ -336,36 +244,21 @@ impl Driver { mut entries: OutEntries>, ) -> io::Result<()> { instrument!(compio_log::Level::TRACE, "poll", ?timeout); - // Prevent stack growth. - let mut iocp_entries = ArrayVec::::new(); - self.poll_impl(timeout, &mut iocp_entries)?; - entries.extend(iocp_entries.drain(..).filter_map(|e| self.create_entry(e))); - - // See if there are remaining entries. - loop { - match self.poll_impl(Some(Duration::ZERO), &mut iocp_entries) { - Ok(()) => { - entries.extend(iocp_entries.drain(..).filter_map(|e| self.create_entry(e))); - } - Err(e) => match e.kind() { - io::ErrorKind::TimedOut => { - trace!("poll timeout"); - break; - } - _ => return Err(e), - }, - } - } + + entries.extend( + self.port + .poll(timeout)? + .filter_map(|e| Self::create_entry(&mut self.cancelled, e)), + ); Ok(()) } pub fn handle(&self) -> io::Result { - self.handle_for(Self::NOTIFY) - } - - pub fn handle_for(&self, user_data: usize) -> io::Result { - Ok(NotifyHandle::new(user_data, self.port.clone())) + Ok(NotifyHandle::new( + self.port.handle(), + self.notify_overlapped.clone(), + )) } } @@ -383,30 +276,18 @@ impl Drop for Driver { /// A notify handle to the inner driver. pub struct NotifyHandle { - user_data: usize, - handle: Arc, + port: cp::PortHandle, + overlapped: Arc>, } -unsafe impl Send for NotifyHandle {} -unsafe impl Sync for NotifyHandle {} - impl NotifyHandle { - fn new(user_data: usize, handle: Arc) -> Self { - Self { user_data, handle } + fn new(port: cp::PortHandle, overlapped: Arc>) -> Self { + Self { port, overlapped } } /// Notify the inner driver. pub fn notify(&self) -> io::Result<()> { - syscall!( - BOOL, - PostQueuedCompletionStatus( - self.handle.as_raw_handle() as _, - 0, - self.user_data, - null_mut() - ) - )?; - Ok(()) + self.port.post_raw(self.overlapped.as_ref()) } } @@ -415,6 +296,8 @@ impl NotifyHandle { pub struct Overlapped { /// The base [`OVERLAPPED`]. pub base: OVERLAPPED, + /// The unique ID of created driver. + pub driver: RawFd, /// The registered user defined data. pub user_data: usize, /// The opcode. @@ -423,15 +306,20 @@ pub struct Overlapped { } impl Overlapped { - pub(crate) fn new(user_data: usize, op: T) -> Self { + pub(crate) fn new(driver: RawFd, user_data: usize, op: T) -> Self { Self { base: unsafe { std::mem::zeroed() }, + driver, user_data, op, } } } +// SAFETY: neither field of `OVERLAPPED` is used +unsafe impl Send for Overlapped<()> {} +unsafe impl Sync for Overlapped<()> {} + pub(crate) struct RawOp { op: NonNull>, // The two flags here are manual reference counting. The driver holds the strong ref until it @@ -441,8 +329,8 @@ pub(crate) struct RawOp { } impl RawOp { - pub(crate) fn new(user_data: usize, op: impl OpCode + 'static) -> Self { - let op = Overlapped::new(user_data, op); + pub(crate) fn new(driver: RawFd, user_data: usize, op: impl OpCode + 'static) -> Self { + let op = Overlapped::new(driver, user_data, op); let op = Box::new(op) as Box>; Self { op: unsafe { NonNull::new_unchecked(Box::into_raw(op)) }, diff --git a/compio-driver/src/iour/mod.rs b/compio-driver/src/iour/mod.rs index d6884be7..13df621f 100644 --- a/compio-driver/src/iour/mod.rs +++ b/compio-driver/src/iour/mod.rs @@ -196,6 +196,10 @@ impl Driver { entries.extend(completed_entries); } + pub fn create_op(&self, user_data: usize, op: T) -> RawOp { + RawOp::new(user_data, op) + } + pub fn attach(&mut self, _fd: RawFd) -> io::Result<()> { Ok(()) } diff --git a/compio-driver/src/lib.rs b/compio-driver/src/lib.rs index 328644b7..39ce24f9 100644 --- a/compio-driver/src/lib.rs +++ b/compio-driver/src/lib.rs @@ -177,20 +177,13 @@ impl Proactor { }) } - /// Attach an fd to the driver. It will cause unexpected result to attach - /// the handle with one driver and push an op to another driver. + /// Attach an fd to the driver. /// /// ## Platform specific /// * IOCP: it will be attached to the completion port. An fd could only be /// attached to one driver, and could only be attached once, even if you /// `try_clone` it. - /// * io-uring: it will do nothing and return `Ok(())`. - /// * polling: it will initialize inner queue and register to the driver. On - /// Linux and Android, if the fd is a normal file or a directory, this - /// method will do nothing. For other fd and systems, you should only call - /// this method once for a specific resource. If this method is called - /// twice with the same fd, we assume that the old fd has been closed, and - /// it's a new fd. + /// * io-uring & polling: it will do nothing but return `Ok(())`. pub fn attach(&mut self, fd: RawFd) -> io::Result<()> { self.driver.attach(fd) } @@ -223,7 +216,7 @@ impl Proactor { pub fn push(&mut self, op: T) -> PushEntry, BufResult> { let entry = self.ops.vacant_entry(); let user_data = entry.key(); - let op = RawOp::new(user_data, op); + let op = self.driver.create_op(user_data, op); let op = entry.insert(op); match self.driver.push(user_data, op) { Poll::Pending => PushEntry::Pending(unsafe { Key::new(user_data) }), @@ -277,16 +270,6 @@ impl Proactor { pub fn handle(&self) -> io::Result { self.driver.handle() } - - /// Create a notify handle for specified user_data. - /// - /// # Safety - /// - /// The caller should ensure `user_data` being valid. - #[cfg(windows)] - pub unsafe fn handle_for(&self, user_data: usize) -> io::Result { - self.driver.handle_for(user_data) - } } impl AsRawFd for Proactor { diff --git a/compio-driver/src/poll/mod.rs b/compio-driver/src/poll/mod.rs index 4722c4aa..a979c87d 100644 --- a/compio-driver/src/poll/mod.rs +++ b/compio-driver/src/poll/mod.rs @@ -138,11 +138,6 @@ impl FdQueue { } None } - - pub fn clear(&mut self) { - self.read_queue.clear(); - self.write_queue.clear(); - } } /// Low-level driver of polling. @@ -187,41 +182,28 @@ impl Driver { }) } - fn submit(&mut self, user_data: usize, arg: WaitArg) -> io::Result<()> { - let queue = self - .registry - .get_mut(&arg.fd) - .expect("the fd should be attached"); + pub fn create_op(&self, user_data: usize, op: T) -> RawOp { + RawOp::new(user_data, op) + } + + /// # Safety + /// The input fd should be valid. + unsafe fn submit(&mut self, user_data: usize, arg: WaitArg) -> io::Result<()> { + let need_add = !self.registry.contains_key(&arg.fd); + let queue = self.registry.entry(arg.fd).or_default(); queue.push_back_interest(user_data, arg.interest); // We use fd as the key. let event = queue.event(arg.fd as usize); - unsafe { + if need_add { + self.poll.add(arg.fd, event)?; + } else { let fd = BorrowedFd::borrow_raw(arg.fd); self.poll.modify(fd, event)?; } Ok(()) } - pub fn attach(&mut self, fd: RawFd) -> io::Result<()> { - if cfg!(any(target_os = "linux", target_os = "android")) { - let mut stat = unsafe { std::mem::zeroed() }; - syscall!(libc::fstat(fd, &mut stat))?; - if matches!(stat.st_mode & libc::S_IFMT, libc::S_IFREG | libc::S_IFDIR) { - return Ok(()); - } - } - let queue = self.registry.entry(fd).or_default(); - unsafe { - match self.poll.add(fd, Event::none(0)) { - Ok(()) => {} - Err(e) if e.kind() == io::ErrorKind::AlreadyExists => { - queue.clear(); - let fd = BorrowedFd::borrow_raw(fd); - self.poll.modify(fd, Event::none(0))?; - } - Err(e) => return Err(e), - } - } + pub fn attach(&mut self, _fd: RawFd) -> io::Result<()> { Ok(()) } @@ -236,7 +218,10 @@ impl Driver { let op_pin = op.as_pin(); match op_pin.pre_submit() { Ok(Decision::Wait(arg)) => { - self.submit(user_data, arg)?; + // SAFETY: fd is from the OpCode. + unsafe { + self.submit(user_data, arg)?; + } Poll::Pending } Ok(Decision::Completed(res)) => Poll::Ready(Ok(res)), diff --git a/compio-fs/src/file.rs b/compio-fs/src/file.rs index a114740c..5907e49f 100644 --- a/compio-fs/src/file.rs +++ b/compio-fs/src/file.rs @@ -1,14 +1,13 @@ use std::{future::Future, io, mem::ManuallyDrop, path::Path}; -use compio_buf::{buf_try, BufResult, IntoInner, IoBuf, IoBufMut}; +use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut}; use compio_driver::{ + impl_raw_fd, op::{BufResultExt, CloseFile, FileStat, ReadAt, Sync, WriteAt}, - syscall, + syscall, AsRawFd, }; use compio_io::{AsyncReadAt, AsyncWriteAt}; -use compio_runtime::{ - impl_attachable, impl_try_as_raw_fd, Attacher, Runtime, TryAsRawFd, TryClone, -}; +use compio_runtime::{impl_try_clone, Attacher, Runtime}; #[cfg(unix)] use { compio_buf::{IoVectoredBuf, IoVectoredBufMut}, @@ -29,6 +28,12 @@ pub struct File { } impl File { + pub(crate) fn new(file: std::fs::File) -> io::Result { + Ok(Self { + inner: Attacher::new(file)?, + }) + } + /// Attempts to open a file in read-only mode. /// /// See the [`OpenOptions::open`] method for more details. @@ -59,24 +64,15 @@ impl File { // `close` should be cancelled. let this = ManuallyDrop::new(self); async move { - let op = CloseFile::new(this.inner.try_as_raw_fd()?); + let op = CloseFile::new(this.inner.as_raw_fd()); Runtime::current().submit(op).await.0?; Ok(()) } } - /// Creates a new `File` instance that shares the same underlying file - /// handle as the existing `File` instance. - /// - /// It does not clear the attach state. - pub fn try_clone(&self) -> io::Result { - let inner = self.inner.try_clone()?; - Ok(Self { inner }) - } - /// Queries metadata about the underlying file. pub async fn metadata(&self) -> io::Result { - let op = FileStat::new(self.try_as_raw_fd()?); + let op = FileStat::new(self.as_raw_fd()); let BufResult(res, op) = Runtime::current().submit(op).await; res.map(|_| Metadata::from_stat(op.into_inner())) } @@ -88,7 +84,7 @@ impl File { FileBasicInfo, SetFileInformationByHandle, FILE_BASIC_INFO, }; - let fd = self.try_as_raw_fd()? as _; + let fd = self.as_raw_fd() as _; Runtime::current() .spawn_blocking(move || { let info = FILE_BASIC_INFO { @@ -117,7 +113,7 @@ impl File { pub async fn set_permissions(&self, perm: Permissions) -> io::Result<()> { use std::os::unix::fs::PermissionsExt; - let fd = self.try_as_raw_fd()? as _; + let fd = self.as_raw_fd() as _; Runtime::current() .spawn_blocking(move || { syscall!(libc::fchmod(fd, perm.mode() as libc::mode_t))?; @@ -127,7 +123,7 @@ impl File { } async fn sync_impl(&self, datasync: bool) -> io::Result<()> { - let op = Sync::new(self.try_as_raw_fd()?, datasync); + let op = Sync::new(self.as_raw_fd(), datasync); Runtime::current().submit(op).await.0?; Ok(()) } @@ -158,7 +154,7 @@ impl File { impl AsyncReadAt for File { async fn read_at(&self, buffer: T, pos: u64) -> BufResult { - let (fd, buffer) = buf_try!(self.try_as_raw_fd(), buffer); + let fd = self.as_raw_fd(); let op = ReadAt::new(fd, pos, buffer); Runtime::current() .submit(op) @@ -173,7 +169,7 @@ impl AsyncReadAt for File { buffer: T, pos: u64, ) -> BufResult { - let (fd, buffer) = buf_try!(self.try_as_raw_fd(), buffer); + let fd = self.as_raw_fd(); let op = ReadVectoredAt::new(fd, pos, buffer); Runtime::current() .submit(op) @@ -202,7 +198,7 @@ impl AsyncWriteAt for File { impl AsyncWriteAt for &File { async fn write_at(&mut self, buffer: T, pos: u64) -> BufResult { - let (fd, buffer) = buf_try!(self.try_as_raw_fd(), buffer); + let fd = self.as_raw_fd(); let op = WriteAt::new(fd, pos, buffer); Runtime::current().submit(op).await.into_inner() } @@ -213,12 +209,12 @@ impl AsyncWriteAt for &File { buffer: T, pos: u64, ) -> BufResult { - let (fd, buffer) = buf_try!(self.try_as_raw_fd(), buffer); + let fd = self.as_raw_fd(); let op = WriteVectoredAt::new(fd, pos, buffer); Runtime::current().submit(op).await.into_inner() } } -impl_try_as_raw_fd!(File, inner); +impl_raw_fd!(File, inner); -impl_attachable!(File, inner); +impl_try_clone!(File, inner); diff --git a/compio-fs/src/named_pipe.rs b/compio-fs/src/named_pipe.rs index 32fedaef..6d785c57 100644 --- a/compio-fs/src/named_pipe.rs +++ b/compio-fs/src/named_pipe.rs @@ -7,9 +7,9 @@ use std::ptr::null_mut; use std::{ffi::OsStr, io, ptr::null}; use compio_buf::{BufResult, IoBuf, IoBufMut}; -use compio_driver::{op::ConnectNamedPipe, syscall, FromRawFd, RawFd}; +use compio_driver::{impl_raw_fd, op::ConnectNamedPipe, syscall, AsRawFd, FromRawFd, RawFd}; use compio_io::{AsyncRead, AsyncReadAt, AsyncWrite, AsyncWriteAt}; -use compio_runtime::{impl_attachable, impl_try_as_raw_fd, Runtime, TryAsRawFd}; +use compio_runtime::{impl_try_clone, Runtime}; use widestring::U16CString; use windows_sys::Win32::{ Security::SECURITY_ATTRIBUTES, @@ -93,15 +93,6 @@ pub struct NamedPipeServer { } impl NamedPipeServer { - /// Creates a new independently owned handle to the underlying file handle. - /// - /// It does not clear the attach state. - pub fn try_clone(&self) -> io::Result { - Ok(Self { - handle: self.handle.try_clone()?, - }) - } - /// Retrieves information about the named pipe the server is associated /// with. /// @@ -125,8 +116,7 @@ impl NamedPipeServer { /// ``` pub fn info(&self) -> io::Result { // Safety: we're ensuring the lifetime of the named pipe. - // Safety: getting info doesn't need to be attached. - unsafe { named_pipe_info(self.as_raw_fd_unchecked()) } + unsafe { named_pipe_info(self.as_raw_fd()) } } /// Enables a named pipe server process to wait for a client process to @@ -152,7 +142,7 @@ impl NamedPipeServer { /// # std::io::Result::Ok(()) }); /// ``` pub async fn connect(&self) -> io::Result<()> { - let op = ConnectNamedPipe::new(self.handle.try_as_raw_fd()?); + let op = ConnectNamedPipe::new(self.handle.as_raw_fd()); Runtime::current().submit(op).await.0?; Ok(()) } @@ -185,7 +175,7 @@ impl NamedPipeServer { /// # }) /// ``` pub fn disconnect(&self) -> io::Result<()> { - syscall!(BOOL, DisconnectNamedPipe(self.try_as_raw_fd()? as _))?; + syscall!(BOOL, DisconnectNamedPipe(self.as_raw_fd() as _))?; Ok(()) } } @@ -240,9 +230,9 @@ impl AsyncWrite for &NamedPipeServer { } } -impl_try_as_raw_fd!(NamedPipeServer, handle); +impl_raw_fd!(NamedPipeServer, handle); -impl_attachable!(NamedPipeServer, handle); +impl_try_clone!(NamedPipeServer, handle); /// A [Windows named pipe] client. /// @@ -289,15 +279,6 @@ pub struct NamedPipeClient { } impl NamedPipeClient { - /// Creates a new independently owned handle to the underlying file handle. - /// - /// It does not clear the attach state. - pub fn try_clone(&self) -> io::Result { - Ok(Self { - handle: self.handle.try_clone()?, - }) - } - /// Retrieves information about the named pipe the client is associated /// with. /// @@ -318,8 +299,7 @@ impl NamedPipeClient { /// ``` pub fn info(&self) -> io::Result { // Safety: we're ensuring the lifetime of the named pipe. - // Safety: getting info doesn't need to be attached. - unsafe { named_pipe_info(self.as_raw_fd_unchecked()) } + unsafe { named_pipe_info(self.as_raw_fd()) } } } @@ -373,9 +353,9 @@ impl AsyncWrite for &NamedPipeClient { } } -impl_try_as_raw_fd!(NamedPipeClient, handle); +impl_raw_fd!(NamedPipeClient, handle); -impl_attachable!(NamedPipeClient, handle); +impl_try_clone!(NamedPipeClient, handle); /// A builder structure for construct a named pipe with named pipe-specific /// options. This is required to use for named pipe servers who wants to modify @@ -410,7 +390,9 @@ impl ServerOptions { /// /// const PIPE_NAME: &str = r"\\.\pipe\compio-named-pipe-new"; /// + /// # compio_runtime::Runtime::new().unwrap().block_on(async move { /// let server = ServerOptions::new().create(PIPE_NAME).unwrap(); + /// # }) /// ``` pub fn new() -> ServerOptions { ServerOptions { @@ -738,8 +720,8 @@ impl ServerOptions { /// ``` /// use std::{io, ptr}; /// + /// use compio_driver::AsRawFd; /// use compio_fs::named_pipe::ServerOptions; - /// use compio_runtime::TryAsRawFd; /// use windows_sys::Win32::{ /// Foundation::ERROR_SUCCESS, /// Security::{ @@ -759,7 +741,7 @@ impl ServerOptions { /// assert_eq!( /// ERROR_SUCCESS, /// SetSecurityInfo( - /// pipe.as_raw_fd_unchecked() as _, + /// pipe.as_raw_fd() as _, /// SE_KERNEL_OBJECT, /// DACL_SECURITY_INFORMATION, /// ptr::null_mut(), @@ -775,8 +757,8 @@ impl ServerOptions { /// ``` /// use std::{io, ptr}; /// + /// use compio_driver::AsRawFd; /// use compio_fs::named_pipe::ServerOptions; - /// use compio_runtime::TryAsRawFd; /// use windows_sys::Win32::{ /// Foundation::ERROR_ACCESS_DENIED, /// Security::{ @@ -796,7 +778,7 @@ impl ServerOptions { /// assert_eq!( /// ERROR_ACCESS_DENIED, /// SetSecurityInfo( - /// pipe.as_raw_fd_unchecked() as _, + /// pipe.as_raw_fd() as _, /// SE_KERNEL_OBJECT, /// DACL_SECURITY_INFORMATION, /// ptr::null_mut(), @@ -1009,7 +991,9 @@ impl ServerOptions { ) )?; - Ok(unsafe { NamedPipeServer::from_raw_fd(h as _) }) + Ok(NamedPipeServer { + handle: File::new(unsafe { std::fs::File::from_raw_fd(h as _) })?, + }) } } @@ -1165,7 +1149,7 @@ impl ClientOptions { let mode = PIPE_READMODE_MESSAGE; syscall!( BOOL, - SetNamedPipeHandleState(file.as_raw_fd_unchecked() as _, &mode, null(), null()) + SetNamedPipeHandleState(file.as_raw_fd() as _, &mode, null(), null()) )?; } diff --git a/compio-fs/src/open_options/unix.rs b/compio-fs/src/open_options/unix.rs index f35dce4c..a544653c 100644 --- a/compio-fs/src/open_options/unix.rs +++ b/compio-fs/src/open_options/unix.rs @@ -92,6 +92,6 @@ impl OpenOptions { let p = path_string(p)?; let op = OpenFile::new(p, flags, self.mode); let fd = Runtime::current().submit(op).await.0? as RawFd; - Ok(unsafe { File::from_raw_fd(fd) }) + File::new(unsafe { std::fs::File::from_raw_fd(fd) }) } } diff --git a/compio-fs/src/open_options/windows.rs b/compio-fs/src/open_options/windows.rs index 004314a6..9bf5e14f 100644 --- a/compio-fs/src/open_options/windows.rs +++ b/compio-fs/src/open_options/windows.rs @@ -160,6 +160,6 @@ impl OpenOptions { ) )?; } - Ok(unsafe { File::from_raw_fd(fd) }) + File::new(unsafe { std::fs::File::from_raw_fd(fd) }) } } diff --git a/compio-fs/src/pipe.rs b/compio-fs/src/pipe.rs index 01a376f8..51adeba2 100644 --- a/compio-fs/src/pipe.rs +++ b/compio-fs/src/pipe.rs @@ -2,13 +2,14 @@ use std::{future::Future, io, path::Path}; -use compio_buf::{buf_try, BufResult, IntoInner, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut}; +use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut}; use compio_driver::{ + impl_raw_fd, op::{BufResultExt, Recv, RecvVectored, Send, SendVectored}, - syscall, FromRawFd, IntoRawFd, + syscall, AsRawFd, FromRawFd, IntoRawFd, }; use compio_io::{AsyncRead, AsyncWrite}; -use compio_runtime::{impl_attachable, impl_try_as_raw_fd, Runtime, TryAsRawFd}; +use compio_runtime::{impl_try_clone, Runtime}; use crate::File; @@ -358,13 +359,13 @@ impl AsyncWrite for Sender { impl AsyncWrite for &Sender { async fn write(&mut self, buffer: T) -> BufResult { - let (fd, buffer) = buf_try!(self.try_as_raw_fd(), buffer); + let fd = self.as_raw_fd(); let op = Send::new(fd, buffer); Runtime::current().submit(op).await.into_inner() } async fn write_vectored(&mut self, buffer: T) -> BufResult { - let (fd, buffer) = buf_try!(self.try_as_raw_fd(), buffer); + let fd = self.as_raw_fd(); let op = SendVectored::new(fd, buffer); Runtime::current().submit(op).await.into_inner() } @@ -380,9 +381,9 @@ impl AsyncWrite for &Sender { } } -impl_try_as_raw_fd!(Sender, file); +impl_raw_fd!(Sender, file); -impl_attachable!(Sender, file); +impl_try_clone!(Sender, file); /// Reading end of a Unix pipe. /// @@ -483,7 +484,7 @@ impl AsyncRead for Receiver { impl AsyncRead for &Receiver { async fn read(&mut self, buffer: B) -> BufResult { - let (fd, buffer) = buf_try!(self.try_as_raw_fd(), buffer); + let fd = self.as_raw_fd(); let op = Recv::new(fd, buffer); Runtime::current() .submit(op) @@ -493,7 +494,7 @@ impl AsyncRead for &Receiver { } async fn read_vectored(&mut self, buffer: V) -> BufResult { - let (fd, buffer) = buf_try!(self.try_as_raw_fd(), buffer); + let fd = self.as_raw_fd(); let op = RecvVectored::new(fd, buffer); Runtime::current() .submit(op) @@ -503,9 +504,9 @@ impl AsyncRead for &Receiver { } } -impl_try_as_raw_fd!(Receiver, file); +impl_raw_fd!(Receiver, file); -impl_attachable!(Receiver, file); +impl_try_clone!(Receiver, file); /// Checks if file is a FIFO async fn is_fifo(file: &File) -> io::Result { @@ -515,9 +516,9 @@ async fn is_fifo(file: &File) -> io::Result { } /// Sets file's flags with O_NONBLOCK by fcntl. -fn set_nonblocking(file: &impl TryAsRawFd) -> io::Result<()> { +fn set_nonblocking(file: &impl AsRawFd) -> io::Result<()> { if cfg!(not(all(target_os = "linux", feature = "io-uring"))) { - let fd = file.try_as_raw_fd()?; + let fd = file.as_raw_fd(); let current_flags = syscall!(libc::fcntl(fd, libc::F_GETFL))?; let flags = current_flags | libc::O_NONBLOCK; if flags != current_flags { diff --git a/compio-fs/src/stdio/unix.rs b/compio-fs/src/stdio/unix.rs index f883f926..fe9c359e 100644 --- a/compio-fs/src/stdio/unix.rs +++ b/compio-fs/src/stdio/unix.rs @@ -1,9 +1,8 @@ use std::{io, mem::ManuallyDrop}; use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut}; -use compio_driver::{FromRawFd, RawFd}; +use compio_driver::{AsRawFd, FromRawFd, RawFd}; use compio_io::{AsyncRead, AsyncWrite}; -use compio_runtime::TryAsRawFd; use crate::pipe::{Receiver, Sender}; @@ -31,13 +30,9 @@ impl AsyncRead for Stdin { } } -impl TryAsRawFd for Stdin { - fn try_as_raw_fd(&self) -> io::Result { - self.0.try_as_raw_fd() - } - - unsafe fn as_raw_fd_unchecked(&self) -> RawFd { - self.0.as_raw_fd_unchecked() +impl AsRawFd for Stdin { + fn as_raw_fd(&self) -> RawFd { + self.0.as_raw_fd() } } @@ -73,13 +68,9 @@ impl AsyncWrite for Stdout { } } -impl TryAsRawFd for Stdout { - fn try_as_raw_fd(&self) -> io::Result { - self.0.try_as_raw_fd() - } - - unsafe fn as_raw_fd_unchecked(&self) -> RawFd { - self.0.as_raw_fd_unchecked() +impl AsRawFd for Stdout { + fn as_raw_fd(&self) -> RawFd { + self.0.as_raw_fd() } } @@ -115,12 +106,8 @@ impl AsyncWrite for Stderr { } } -impl TryAsRawFd for Stderr { - fn try_as_raw_fd(&self) -> io::Result { - self.0.try_as_raw_fd() - } - - unsafe fn as_raw_fd_unchecked(&self) -> RawFd { - self.0.as_raw_fd_unchecked() +impl AsRawFd for Stderr { + fn as_raw_fd(&self) -> RawFd { + self.0.as_raw_fd() } } diff --git a/compio-net/Cargo.toml b/compio-net/Cargo.toml index d4396955..f88f7207 100644 --- a/compio-net/Cargo.toml +++ b/compio-net/Cargo.toml @@ -39,6 +39,5 @@ libc = { workspace = true } # Shared dev dependencies for all platforms [dev-dependencies] compio-macros = { workspace = true } -futures-channel = { workspace = true } futures-util = { workspace = true } tempfile = { workspace = true } diff --git a/compio-net/src/socket.rs b/compio-net/src/socket.rs index c3d5e195..873f2228 100644 --- a/compio-net/src/socket.rs +++ b/compio-net/src/socket.rs @@ -1,13 +1,15 @@ use std::{future::Future, io, mem::ManuallyDrop}; -use compio_buf::{buf_try, BufResult, IntoInner, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut}; -use compio_driver::op::{ - Accept, BufResultExt, CloseSocket, Connect, Recv, RecvFrom, RecvFromVectored, RecvResultExt, - RecvVectored, Send, SendTo, SendToVectored, SendVectored, ShutdownSocket, -}; -use compio_runtime::{ - impl_attachable, impl_try_as_raw_fd, Attacher, Runtime, TryAsRawFd, TryClone, +use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut}; +use compio_driver::{ + impl_raw_fd, + op::{ + Accept, BufResultExt, CloseSocket, Connect, Recv, RecvFrom, RecvFromVectored, + RecvResultExt, RecvVectored, Send, SendTo, SendToVectored, SendVectored, ShutdownSocket, + }, + AsRawFd, }; +use compio_runtime::{impl_try_clone, Attacher, Runtime}; use socket2::{Domain, Protocol, SockAddr, Socket as Socket2, Type}; #[derive(Debug)] @@ -16,23 +18,18 @@ pub struct Socket { } impl Socket { - pub fn from_socket2(socket: Socket2) -> Self { - Self { - socket: Attacher::new(socket), - } - } - - pub fn try_clone(&self) -> io::Result { - let socket = self.socket.try_clone()?; - Ok(Self { socket }) + pub fn from_socket2(socket: Socket2) -> io::Result { + Ok(Self { + socket: Attacher::new(socket)?, + }) } pub fn peer_addr(&self) -> io::Result { - unsafe { self.socket.get_unchecked() }.peer_addr() + self.socket.peer_addr() } pub fn local_addr(&self) -> io::Result { - unsafe { self.socket.get_unchecked() }.local_addr() + self.socket.local_addr() } pub fn new(domain: Domain, ty: Type, protocol: Option) -> io::Result { @@ -48,25 +45,25 @@ impl Socket { )) { socket.set_nonblocking(true)?; } - Ok(Self::from_socket2(socket)) + Self::from_socket2(socket) } pub fn bind(addr: &SockAddr, ty: Type, protocol: Option) -> io::Result { let socket = Self::new(addr.domain(), ty, protocol)?; - unsafe { socket.socket.get_unchecked() }.bind(addr)?; + socket.socket.bind(addr)?; Ok(socket) } pub fn listen(&self, backlog: i32) -> io::Result<()> { - unsafe { self.socket.get_unchecked() }.listen(backlog) + self.socket.listen(backlog) } pub fn connect(&self, addr: &SockAddr) -> io::Result<()> { - self.socket.try_get()?.connect(addr) + self.socket.connect(addr) } pub async fn connect_async(&self, addr: &SockAddr) -> io::Result<()> { - let op = Connect::new(self.try_as_raw_fd()?, addr.clone()); + let op = Connect::new(self.as_raw_fd(), addr.clone()); let BufResult(res, _op) = Runtime::current().submit(op).await; #[cfg(windows)] { @@ -84,7 +81,7 @@ impl Socket { pub async fn accept(&self) -> io::Result<(Self, SockAddr)> { use compio_driver::FromRawFd; - let op = Accept::new(self.try_as_raw_fd()?); + let op = Accept::new(self.as_raw_fd()); let BufResult(res, op) = Runtime::current().submit(op).await; let accept_sock = unsafe { Socket2::from_raw_fd(res? as _) }; if cfg!(all( @@ -93,28 +90,26 @@ impl Socket { )) { accept_sock.set_nonblocking(true)?; } - let accept_sock = Self::from_socket2(accept_sock); + let accept_sock = Self::from_socket2(accept_sock)?; let addr = op.into_addr(); Ok((accept_sock, addr)) } #[cfg(windows)] pub async fn accept(&self) -> io::Result<(Self, SockAddr)> { - use compio_driver::AsRawFd; - let local_addr = self.local_addr()?; // We should allow users sending this accepted socket to a new thread. let accept_sock = Socket2::new( local_addr.domain(), - unsafe { self.socket.get_unchecked() }.r#type()?, - unsafe { self.socket.get_unchecked() }.protocol()?, + self.socket.r#type()?, + self.socket.protocol()?, )?; - let op = Accept::new(self.try_as_raw_fd()?, accept_sock.as_raw_fd() as _); + let op = Accept::new(self.as_raw_fd(), accept_sock.as_raw_fd() as _); let BufResult(res, op) = Runtime::current().submit(op).await; res?; op.update_context()?; let addr = op.into_addr()?; - Ok((Self::from_socket2(accept_sock), addr)) + Ok((Self::from_socket2(accept_sock)?, addr)) } pub fn close(self) -> impl Future> { @@ -123,20 +118,20 @@ impl Socket { // `close` should be cancelled. let this = ManuallyDrop::new(self); async move { - let op = CloseSocket::new(this.try_as_raw_fd()?); + let op = CloseSocket::new(this.as_raw_fd()); Runtime::current().submit(op).await.0?; Ok(()) } } pub async fn shutdown(&self) -> io::Result<()> { - let op = ShutdownSocket::new(self.try_as_raw_fd()?, std::net::Shutdown::Write); + let op = ShutdownSocket::new(self.as_raw_fd(), std::net::Shutdown::Write); Runtime::current().submit(op).await.0?; Ok(()) } pub async fn recv(&self, buffer: B) -> BufResult { - let (fd, buffer) = buf_try!(self.try_as_raw_fd(), buffer); + let fd = self.as_raw_fd(); let op = Recv::new(fd, buffer); Runtime::current() .submit(op) @@ -146,7 +141,7 @@ impl Socket { } pub async fn recv_vectored(&self, buffer: V) -> BufResult { - let (fd, buffer) = buf_try!(self.try_as_raw_fd(), buffer); + let fd = self.as_raw_fd(); let op = RecvVectored::new(fd, buffer); Runtime::current() .submit(op) @@ -156,19 +151,19 @@ impl Socket { } pub async fn send(&self, buffer: T) -> BufResult { - let (fd, buffer) = buf_try!(self.try_as_raw_fd(), buffer); + let fd = self.as_raw_fd(); let op = Send::new(fd, buffer); Runtime::current().submit(op).await.into_inner() } pub async fn send_vectored(&self, buffer: T) -> BufResult { - let (fd, buffer) = buf_try!(self.try_as_raw_fd(), buffer); + let fd = self.as_raw_fd(); let op = SendVectored::new(fd, buffer); Runtime::current().submit(op).await.into_inner() } pub async fn recv_from(&self, buffer: T) -> BufResult<(usize, SockAddr), T> { - let (fd, buffer) = buf_try!(self.try_as_raw_fd(), buffer); + let fd = self.as_raw_fd(); let op = RecvFrom::new(fd, buffer); Runtime::current() .submit(op) @@ -182,7 +177,7 @@ impl Socket { &self, buffer: T, ) -> BufResult<(usize, SockAddr), T> { - let (fd, buffer) = buf_try!(self.try_as_raw_fd(), buffer); + let fd = self.as_raw_fd(); let op = RecvFromVectored::new(fd, buffer); Runtime::current() .submit(op) @@ -193,7 +188,7 @@ impl Socket { } pub async fn send_to(&self, buffer: T, addr: &SockAddr) -> BufResult { - let (fd, buffer) = buf_try!(self.try_as_raw_fd(), buffer); + let fd = self.as_raw_fd(); let op = SendTo::new(fd, buffer, addr.clone()); Runtime::current().submit(op).await.into_inner() } @@ -203,12 +198,12 @@ impl Socket { buffer: T, addr: &SockAddr, ) -> BufResult { - let (fd, buffer) = buf_try!(self.try_as_raw_fd(), buffer); + let fd = self.as_raw_fd(); let op = SendToVectored::new(fd, buffer, addr.clone()); Runtime::current().submit(op).await.into_inner() } } -impl_try_as_raw_fd!(Socket, socket); +impl_raw_fd!(Socket, socket); -impl_attachable!(Socket, socket); +impl_try_clone!(Socket, socket); diff --git a/compio-net/src/tcp.rs b/compio-net/src/tcp.rs index cd025a4f..7700650a 100644 --- a/compio-net/src/tcp.rs +++ b/compio-net/src/tcp.rs @@ -1,8 +1,9 @@ use std::{future::Future, io, net::SocketAddr}; use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut}; +use compio_driver::impl_raw_fd; use compio_io::{AsyncRead, AsyncWrite}; -use compio_runtime::{impl_attachable, impl_try_as_raw_fd}; +use compio_runtime::impl_try_clone; use socket2::{Protocol, SockAddr, Type}; use crate::{OwnedReadHalf, OwnedWriteHalf, ReadHalf, Socket, ToSocketAddrsAsync, WriteHalf}; @@ -67,15 +68,6 @@ impl TcpListener { self.inner.close() } - /// Creates a new independently owned handle to the underlying socket. - /// - /// It does not clear the attach state. - pub fn try_clone(&self) -> io::Result { - Ok(Self { - inner: self.inner.try_clone()?, - }) - } - /// Accepts a new incoming connection from this listener. /// /// This function will yield once a new TCP connection is established. When @@ -117,9 +109,9 @@ impl TcpListener { } } -impl_try_as_raw_fd!(TcpListener, inner); +impl_raw_fd!(TcpListener, inner); -impl_attachable!(TcpListener, inner); +impl_try_clone!(TcpListener, inner); /// A TCP stream between a local and a remote socket. /// @@ -181,15 +173,6 @@ impl TcpStream { self.inner.close() } - /// Creates a new independently owned handle to the underlying socket. - /// - /// It does not clear the attach state. - pub fn try_clone(&self) -> io::Result { - Ok(Self { - inner: self.inner.try_clone()?, - }) - } - /// Returns the socket address of the remote peer of this TCP connection. pub fn peer_addr(&self) -> io::Result { self.inner @@ -292,6 +275,6 @@ impl AsyncWrite for &TcpStream { } } -impl_try_as_raw_fd!(TcpStream, inner); +impl_raw_fd!(TcpStream, inner); -impl_attachable!(TcpStream, inner); +impl_try_clone!(TcpStream, inner); diff --git a/compio-net/src/udp.rs b/compio-net/src/udp.rs index 92e2abeb..e865f07b 100644 --- a/compio-net/src/udp.rs +++ b/compio-net/src/udp.rs @@ -1,7 +1,8 @@ use std::{future::Future, io, net::SocketAddr}; use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut}; -use compio_runtime::{impl_attachable, impl_try_as_raw_fd}; +use compio_driver::impl_raw_fd; +use compio_runtime::impl_try_clone; use socket2::{Protocol, SockAddr, Type}; use crate::{Socket, ToSocketAddrsAsync}; @@ -121,15 +122,6 @@ impl UdpSocket { self.inner.close() } - /// Creates a new independently owned handle to the underlying socket. - /// - /// It does not clear the attach state. - pub fn try_clone(&self) -> io::Result { - Ok(Self { - inner: self.inner.try_clone()?, - }) - } - /// Returns the socket address of the remote peer this socket was connected /// to. /// @@ -259,6 +251,6 @@ impl UdpSocket { } } -impl_try_as_raw_fd!(UdpSocket, inner); +impl_raw_fd!(UdpSocket, inner); -impl_attachable!(UdpSocket, inner); +impl_try_clone!(UdpSocket, inner); diff --git a/compio-net/src/unix.rs b/compio-net/src/unix.rs index 0eff2b02..89928428 100644 --- a/compio-net/src/unix.rs +++ b/compio-net/src/unix.rs @@ -1,8 +1,9 @@ use std::{future::Future, io, path::Path}; use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut}; +use compio_driver::impl_raw_fd; use compio_io::{AsyncRead, AsyncWrite}; -use compio_runtime::{impl_attachable, impl_try_as_raw_fd}; +use compio_runtime::impl_try_clone; use socket2::{SockAddr, Type}; use crate::{OwnedReadHalf, OwnedWriteHalf, ReadHalf, Socket, WriteHalf}; @@ -70,15 +71,6 @@ impl UnixListener { self.inner.close() } - /// Creates a new independently owned handle to the underlying socket. - /// - /// It does not clear the attach state. - pub fn try_clone(&self) -> io::Result { - Ok(Self { - inner: self.inner.try_clone()?, - }) - } - /// Accepts a new incoming connection from this listener. /// /// This function will yield once a new Unix domain socket connection @@ -96,9 +88,9 @@ impl UnixListener { } } -impl_try_as_raw_fd!(UnixListener, inner); +impl_raw_fd!(UnixListener, inner); -impl_attachable!(UnixListener, inner); +impl_try_clone!(UnixListener, inner); /// A Unix stream between two local sockets on Windows & WSL. /// @@ -164,15 +156,6 @@ impl UnixStream { self.inner.close() } - /// Creates a new independently owned handle to the underlying socket. - /// - /// It does not clear the attach state. - pub fn try_clone(&self) -> io::Result { - Ok(Self { - inner: self.inner.try_clone()?, - }) - } - /// Returns the socket path of the remote peer of this connection. pub fn peer_addr(&self) -> io::Result { #[allow(unused_mut)] @@ -277,9 +260,7 @@ impl AsyncWrite for &UnixStream { } } -impl_try_as_raw_fd!(UnixStream, inner); - -impl_attachable!(UnixStream, inner); +impl_raw_fd!(UnixStream, inner); #[cfg(windows)] #[inline] @@ -324,3 +305,5 @@ fn fix_unix_socket_length(addr: &mut SockAddr) { addr.set_length(addr_len as _); } } + +impl_try_clone!(UnixStream, inner); diff --git a/compio-net/tests/tcp_accept.rs b/compio-net/tests/tcp_accept.rs index 2ca12ba5..8a44de0d 100644 --- a/compio-net/tests/tcp_accept.rs +++ b/compio-net/tests/tcp_accept.rs @@ -3,14 +3,12 @@ use compio_net::{TcpListener, TcpStream, ToSocketAddrsAsync}; async fn test_impl(addr: impl ToSocketAddrsAsync) { let listener = TcpListener::bind(addr).await.unwrap(); let addr = listener.local_addr().unwrap(); - let (tx, rx) = futures_channel::oneshot::channel(); - compio_runtime::spawn(async move { + let task = compio_runtime::spawn(async move { let (socket, _) = listener.accept().await.unwrap(); - assert!(tx.send(socket).is_ok()); - }) - .detach(); + socket + }); let cli = TcpStream::connect(&addr).await.unwrap(); - let srv = rx.await.unwrap(); + let srv = task.await; assert_eq!(cli.local_addr().unwrap(), srv.peer_addr().unwrap()); } diff --git a/compio-net/tests/tcp_connect.rs b/compio-net/tests/tcp_connect.rs index ab9bbd81..1057326c 100644 --- a/compio-net/tests/tcp_connect.rs +++ b/compio-net/tests/tcp_connect.rs @@ -10,17 +10,14 @@ async fn test_connect_ip_impl( let addr = listener.local_addr().unwrap(); assert!(assert_fn(&addr)); - let (tx, rx) = futures_channel::oneshot::channel(); - - compio_runtime::spawn(async move { + let task = compio_runtime::spawn(async move { let (socket, addr) = listener.accept().await.unwrap(); assert_eq!(addr, socket.peer_addr().unwrap()); - assert!(tx.send(socket).is_ok()); - }) - .detach(); + socket + }); let mine = TcpStream::connect(&addr).await.unwrap(); - let theirs = rx.await.unwrap(); + let theirs = task.await; assert_eq!(mine.local_addr().unwrap(), theirs.peer_addr().unwrap()); assert_eq!(theirs.local_addr().unwrap(), mine.peer_addr().unwrap()); diff --git a/compio-runtime/src/attacher.rs b/compio-runtime/src/attacher.rs index 29430865..7c533321 100644 --- a/compio-runtime/src/attacher.rs +++ b/compio-runtime/src/attacher.rs @@ -4,7 +4,10 @@ use std::os::fd::OwnedFd; use std::os::windows::prelude::{OwnedHandle, OwnedSocket}; #[cfg(feature = "once_cell_try")] use std::sync::OnceLock; -use std::{io, marker::PhantomData}; +use std::{ + io, + ops::{Deref, DerefMut}, +}; use compio_buf::IntoInner; use compio_driver::AsRawFd; @@ -23,77 +26,29 @@ use crate::Runtime; pub struct Attacher { source: S, // Make it thread safe. - once: OnceLock, - _p: PhantomData<*mut ()>, + once: OnceLock<()>, } -impl Attacher { - /// Create [`Attacher`]. - pub const fn new(source: S) -> Self { - Self { +impl Attacher { + /// Create [`Attacher`]. It tries to attach the source, and will return + /// [`Err`] if it fails. + pub fn new(source: S) -> io::Result { + let this = Self { source, once: OnceLock::new(), - _p: PhantomData, - } + }; + this.attach()?; + Ok(this) } -} -impl Attacher { /// Attach the source. This method could be called many times, but if the - /// action fails, the error will only return once. + /// action fails, it will try to attach the source during each call. fn attach(&self) -> io::Result<()> { let r = Runtime::current(); let inner = r.inner(); - let id = self.once.get_or_try_init(|| { - inner.attach(self.source.as_raw_fd())?; - io::Result::Ok(inner.id()) - })?; - if id != &inner.id() { - Err(io::Error::new( - io::ErrorKind::InvalidInput, - "the current runtime is not the attached runtime", - )) - } else { - Ok(()) - } - } - - /// Attach the inner source and get the reference. - pub fn try_get(&self) -> io::Result<&S> { - self.attach()?; - Ok(&self.source) - } - - /// Get the reference of the inner source without attaching it. - /// - /// # Safety - /// - /// The caller should ensure it is attached before submit an operation with - /// it. - pub unsafe fn get_unchecked(&self) -> &S { - &self.source - } - - /// Attach the inner source and get the mutable reference. - pub fn try_get_mut(&mut self) -> io::Result<&mut S> { - self.attach()?; - Ok(&mut self.source) - } - - /// Get the mutable reference of the inner source without attaching it. - /// - /// # Safety - /// - /// The caller should ensure it is attached before submit an operation with - /// it. - pub unsafe fn get_unchecked_mut(&mut self) -> &mut S { - &mut self.source - } -} - -impl Attachable for Attacher { - fn is_attached(&self) -> bool { - self.once.get().is_some() + self.once + .get_or_try_init(|| inner.attach(self.source.as_raw_fd()))?; + Ok(()) } } @@ -105,33 +60,22 @@ impl IntoRawFd for Attacher { impl FromRawFd for Attacher { unsafe fn from_raw_fd(fd: RawFd) -> Self { - Self::new(S::from_raw_fd(fd)) + Self { + source: S::from_raw_fd(fd), + once: OnceLock::from(()), + } } } -impl TryClone for Attacher { +impl TryClone for Attacher { /// Try clone self with the cloned source. The attach state will be /// reserved. - /// - /// ## Platform specific - /// * io-uring/polling: it will try to attach in the current thread if - /// needed. fn try_clone(&self) -> io::Result { let source = self.source.try_clone()?; - let new_self = if cfg!(windows) { - Self { - source, - once: self.once.clone(), - _p: PhantomData, - } - } else { - let new_self = Self::new(source); - if self.is_attached() { - new_self.attach()?; - } - new_self - }; - Ok(new_self) + Ok(Self { + source, + once: self.once.clone(), + }) } } @@ -143,10 +87,18 @@ impl IntoInner for Attacher { } } -/// Represents an attachable resource to driver. -pub trait Attachable { - /// Check if [`Attachable::attach`] has been called. - fn is_attached(&self) -> bool; +impl Deref for Attacher { + type Target = S; + + fn deref(&self) -> &Self::Target { + &self.source + } +} + +impl DerefMut for Attacher { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.source + } } /// Duplicatable file or socket. @@ -188,107 +140,15 @@ impl TryClone for OwnedFd { } } -/// Extracts raw fds. -pub trait TryAsRawFd { - /// Get the inner raw fd, while ensuring the source being attached. - fn try_as_raw_fd(&self) -> io::Result; - - /// Get the inner raw fd and don't check if it has been attached. - /// - /// # Safety - /// - /// The caller should ensure it is attached before submit an operation with - /// it. - unsafe fn as_raw_fd_unchecked(&self) -> RawFd; -} - -impl TryAsRawFd for T { - fn try_as_raw_fd(&self) -> io::Result { - Ok(self.as_raw_fd()) - } - - unsafe fn as_raw_fd_unchecked(&self) -> RawFd { - self.as_raw_fd() - } -} - -impl TryAsRawFd for Attacher { - fn try_as_raw_fd(&self) -> io::Result { - Ok(self.try_get()?.as_raw_fd()) - } - - unsafe fn as_raw_fd_unchecked(&self) -> RawFd { - self.source.as_raw_fd() - } -} - -/// A [`Send`] wrapper for attachable resource that has not been attached. The -/// resource should be able to send to another thread before attaching. -pub struct Unattached(T); - -impl Unattached { - /// Create the [`Unattached`] wrapper, or fail if the resource has already - /// been attached. - pub fn new(a: T) -> Result { - if a.is_attached() { Err(a) } else { Ok(Self(a)) } - } - - /// Create [`Unattached`] without checking. - /// - /// # Safety - /// - /// The caller should ensure that the resource has not been attached. - pub unsafe fn new_unchecked(a: T) -> Self { - Self(a) - } -} - -impl IntoInner for Unattached { - type Inner = T; - - fn into_inner(self) -> Self::Inner { - self.0 - } -} - -unsafe impl Send for Unattached {} -unsafe impl Sync for Unattached {} - -#[macro_export] -#[doc(hidden)] -macro_rules! impl_attachable { - ($t:ty, $inner:ident) => { - impl $crate::Attachable for $t { - fn is_attached(&self) -> bool { - self.$inner.is_attached() - } - } - }; -} - #[macro_export] #[doc(hidden)] -macro_rules! impl_try_as_raw_fd { +macro_rules! impl_try_clone { ($t:ty, $inner:ident) => { - impl $crate::TryAsRawFd for $t { - fn try_as_raw_fd(&self) -> ::std::io::Result<$crate::RawFd> { - self.$inner.try_as_raw_fd() - } - - unsafe fn as_raw_fd_unchecked(&self) -> $crate::RawFd { - self.$inner.as_raw_fd_unchecked() - } - } - impl $crate::FromRawFd for $t { - unsafe fn from_raw_fd(fd: $crate::RawFd) -> Self { - Self { - $inner: $crate::FromRawFd::from_raw_fd(fd), - } - } - } - impl $crate::IntoRawFd for $t { - fn into_raw_fd(self) -> $crate::RawFd { - self.$inner.into_raw_fd() + impl $crate::TryClone for $t { + fn try_clone(&self) -> ::std::io::Result { + Ok(Self { + $inner: self.$inner.try_clone()?, + }) } } }; diff --git a/compio-runtime/src/runtime/mod.rs b/compio-runtime/src/runtime/mod.rs index 2923f49a..83d36257 100644 --- a/compio-runtime/src/runtime/mod.rs +++ b/compio-runtime/src/runtime/mod.rs @@ -3,10 +3,7 @@ use std::{ future::{ready, Future}, io, rc::{Rc, Weak}, - sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, - }, + sync::Arc, task::{Context, Poll, Waker}, time::Duration, }; @@ -43,10 +40,7 @@ impl Default for FutureState { } } -static RUNTIME_COUNTER: AtomicUsize = AtomicUsize::new(0); - pub(crate) struct RuntimeInner { - id: usize, driver: RefCell, runnables: Arc>, op_runtime: RefCell, @@ -57,7 +51,6 @@ pub(crate) struct RuntimeInner { impl RuntimeInner { pub fn new(builder: &ProactorBuilder) -> io::Result { Ok(Self { - id: RUNTIME_COUNTER.fetch_add(1, Ordering::AcqRel), driver: RefCell::new(builder.build()?), runnables: Arc::new(SegQueue::new()), op_runtime: RefCell::default(), @@ -66,10 +59,6 @@ impl RuntimeInner { }) } - pub fn id(&self) -> usize { - self.id - } - // Safety: be careful about the captured lifetime. pub unsafe fn spawn_unchecked(&self, future: F) -> Task { let runnables = self.runnables.clone(); diff --git a/compio/examples/dispatcher.rs b/compio/examples/dispatcher.rs index 63e2647c..dc5c6903 100644 --- a/compio/examples/dispatcher.rs +++ b/compio/examples/dispatcher.rs @@ -1,11 +1,10 @@ use std::{num::NonZeroUsize, panic::resume_unwind}; use compio::{ - buf::IntoInner, dispatcher::Dispatcher, io::{AsyncRead, AsyncWriteExt}, net::{TcpListener, TcpStream}, - runtime::{spawn, Unattached}, + runtime::spawn, BufResult, }; use futures_util::{stream::FuturesUnordered, StreamExt}; @@ -34,16 +33,12 @@ async fn main() { .detach(); let mut handles = FuturesUnordered::new(); for _i in 0..CLIENT_NUM { - let (srv, _) = listener.accept().await.unwrap(); - let srv = Unattached::new(srv).unwrap(); + let (mut srv, _) = listener.accept().await.unwrap(); let handle = dispatcher - .dispatch(move || { - let mut srv = srv.into_inner(); - async move { - let BufResult(res, buf) = srv.read(Vec::with_capacity(20)).await; - res.unwrap(); - println!("{}", std::str::from_utf8(&buf).unwrap()); - } + .dispatch(move || async move { + let BufResult(res, buf) = srv.read(Vec::with_capacity(20)).await; + res.unwrap(); + println!("{}", std::str::from_utf8(&buf).unwrap()); }) .unwrap(); handles.push(handle.join()); diff --git a/compio/tests/runtime.rs b/compio/tests/runtime.rs index 080dba9c..b3522835 100644 --- a/compio/tests/runtime.rs +++ b/compio/tests/runtime.rs @@ -7,7 +7,7 @@ use compio::{ fs::File, io::{AsyncReadAt, AsyncReadExt, AsyncWriteAt, AsyncWriteExt}, net::{TcpListener, TcpStream}, - runtime::Unattached, + runtime::TryClone, }; use tempfile::NamedTempFile; @@ -18,24 +18,22 @@ async fn multi_threading() { let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap(); let addr = listener.local_addr().unwrap(); - let (mut tx, (rx, _)) = + let (mut tx, (mut rx, _)) = futures_util::try_join!(TcpStream::connect(&addr), listener.accept()).unwrap(); tx.write_all(DATA).await.0.unwrap(); + tx.write_all(DATA).await.0.unwrap(); - let rx = Unattached::new(rx).unwrap(); - if let Err(e) = std::thread::spawn(move || { - let mut rx = rx.into_inner(); + let ((), buffer) = rx.read_exact(Vec::with_capacity(DATA.len())).await.unwrap(); + assert_eq!(DATA, String::from_utf8(buffer).unwrap()); + + compio::runtime::spawn_blocking(move || { compio::runtime::Runtime::new().unwrap().block_on(async { - let buffer = Vec::with_capacity(DATA.len()); - let ((), buffer) = rx.read_exact(buffer).await.unwrap(); + let ((), buffer) = rx.read_exact(Vec::with_capacity(DATA.len())).await.unwrap(); assert_eq!(DATA, String::from_utf8(buffer).unwrap()); }); }) - .join() - { - std::panic::resume_unwind(e) - } + .await } #[compio_macros::test] @@ -45,15 +43,13 @@ async fn try_clone() { let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap(); let addr = listener.local_addr().unwrap(); - let (tx, (rx, _)) = + let (tx, (mut rx, _)) = futures_util::try_join!(TcpStream::connect(&addr), listener.accept()).unwrap(); let mut tx = tx.try_clone().unwrap(); tx.write_all(DATA).await.0.unwrap(); - let rx = Unattached::new(rx.try_clone().unwrap()).unwrap(); if let Err(e) = std::thread::spawn(move || { - let mut rx = rx.into_inner(); compio::runtime::Runtime::new().unwrap().block_on(async { let buffer = Vec::with_capacity(DATA.len()); let ((), buffer) = rx.read_exact(buffer).await.unwrap();