diff --git a/zbus/src/address/mod.rs b/zbus/src/address/mod.rs index d35cdca75..bdbd5a04c 100644 --- a/zbus/src/address/mod.rs +++ b/zbus/src/address/mod.rs @@ -225,6 +225,7 @@ mod tests { Error::Address(e) => assert_eq!(e, "unix: address is invalid"), _ => panic!(), } + #[cfg(target_os = "linux")] match Address::from_str("unix:path=/tmp,abstract=foo").unwrap_err() { Error::Address(e) => { assert_eq!(e, "unix: address is invalid") @@ -235,9 +236,10 @@ mod tests { Address::from_str("unix:path=/tmp/dbus-foo").unwrap(), Transport::Unix(Unix::new(UnixPath::File("/tmp/dbus-foo".into()))).into(), ); + #[cfg(target_os = "linux")] assert_eq!( Address::from_str("unix:abstract=/tmp/dbus-foo").unwrap(), - Transport::Unix(Unix::new(UnixPath::File("\0/tmp/dbus-foo".into()))).into(), + Transport::Unix(Unix::new(UnixPath::Abstract("/tmp/dbus-foo".into()))).into(), ); let guid = crate::Guid::generate(); assert_eq!( @@ -340,10 +342,10 @@ mod tests { "unix:tmpdir=/tmp/dbus-foo" ); // FIXME: figure out how to handle abstract on Windows - #[cfg(unix)] + #[cfg(target_os = "linux")] assert_eq!( - Address::from(Transport::Unix(Unix::new(UnixPath::File( - "\0/tmp/dbus-foo".into() + Address::from(Transport::Unix(Unix::new(UnixPath::Abstract( + "/tmp/dbus-foo".into() )))) .to_string(), "unix:abstract=/tmp/dbus-foo" diff --git a/zbus/src/address/transport/mod.rs b/zbus/src/address/transport/mod.rs index 37f43b95d..e1d6c7690 100644 --- a/zbus/src/address/transport/mod.rs +++ b/zbus/src/address/transport/mod.rs @@ -9,16 +9,14 @@ use crate::{Error, Result}; use async_io::Async; #[cfg(not(feature = "tokio"))] use std::net::TcpStream; -#[cfg(all(unix, not(feature = "tokio")))] -use std::os::unix::net::UnixStream; +#[cfg(unix)] +use std::os::unix::net::{SocketAddr, UnixStream}; use std::{collections::HashMap, ffi::OsStr}; #[cfg(feature = "tokio")] use tokio::net::TcpStream; -#[cfg(all(unix, feature = "tokio"))] -use tokio::net::UnixStream; #[cfg(feature = "tokio-vsock")] use tokio_vsock::VsockStream; -#[cfg(all(windows, not(feature = "tokio")))] +#[cfg(windows)] use uds_windows::UnixStream; #[cfg(all(feature = "vsock", not(feature = "tokio")))] use vsock::VsockStream; @@ -47,6 +45,8 @@ pub use launchd::Launchd; #[path = "vsock.rs"] // Gotta rename to avoid name conflict with the `vsock` crate. mod vsock_transport; +#[cfg(target_os = "linux")] +use std::os::linux::net::SocketAddrExt; #[cfg(any( all(feature = "vsock", not(feature = "tokio")), feature = "tokio-vsock" @@ -83,54 +83,57 @@ impl Transport { #[cfg_attr(any(target_os = "macos", windows), async_recursion::async_recursion)] pub(super) async fn connect(self) -> Result { match self { - Transport::Unix(unix) => match unix.take_path() { - UnixPath::File(path) => { - #[cfg(not(feature = "tokio"))] - { + Transport::Unix(unix) => { + // This is a `path` in case of Windows until uds_windows provides the needed API: + // https://github.com/haraldh/rust_uds_windows/issues/14 + let addr = match unix.take_path() { + #[cfg(unix)] + UnixPath::File(path) => SocketAddr::from_pathname(path)?, + #[cfg(windows)] + UnixPath::File(path) => path, + #[cfg(target_os = "linux")] + UnixPath::Abstract(name) => SocketAddr::from_abstract_name(name)?, + UnixPath::Dir(_) | UnixPath::TmpDir(_) => { + // you can't connect to a unix:dir + return Err(Error::Unsupported); + } + }; + let stream = crate::Task::spawn_blocking( + move || -> Result<_> { + #[cfg(unix)] + let stream = UnixStream::connect_addr(&addr)?; #[cfg(windows)] - { - let stream = crate::Task::spawn_blocking( - move || UnixStream::connect(path), - "unix stream connection", - ) - .await?; - Async::new(stream) - .map(Stream::Unix) - .map_err(|e| Error::InputOutput(e.into())) - } + let stream = UnixStream::connect(addr)?; + stream.set_nonblocking(true)?; + + Ok(stream) + }, + "unix stream connection", + ) + .await?; + #[cfg(not(feature = "tokio"))] + { + Async::new(stream) + .map(Stream::Unix) + .map_err(|e| Error::InputOutput(e.into())) + } - #[cfg(not(windows))] - { - Async::::connect(path) - .await - .map(Stream::Unix) - .map_err(|e| Error::InputOutput(e.into())) - } + #[cfg(feature = "tokio")] + { + #[cfg(unix)] + { + tokio::net::UnixStream::from_std(stream) + .map(Stream::Unix) + .map_err(|e| Error::InputOutput(e.into())) } - #[cfg(feature = "tokio")] + #[cfg(not(unix))] { - #[cfg(unix)] - { - UnixStream::connect(path) - .await - .map(Stream::Unix) - .map_err(|e| Error::InputOutput(e.into())) - } - - #[cfg(not(unix))] - { - let _ = path; - Err(Error::Unsupported) - } + let _ = path; + Err(Error::Unsupported) } } - UnixPath::Dir(_) | UnixPath::TmpDir(_) => { - // you can't connect to a unix:dir - Err(Error::Unsupported) - } - }, - + } #[cfg(all(feature = "vsock", not(feature = "tokio")))] Transport::Vsock(addr) => { let stream = VsockStream::connect_with_cid_port(addr.cid(), addr.port())?; @@ -239,7 +242,7 @@ pub(crate) enum Stream { #[derive(Debug)] pub(crate) enum Stream { #[cfg(unix)] - Unix(UnixStream), + Unix(tokio::net::UnixStream), Tcp(TcpStream), #[cfg(feature = "tokio-vsock")] Vsock(VsockStream), @@ -327,21 +330,12 @@ pub(super) fn encode_percents(f: &mut Formatter<'_>, mut value: &[u8]) -> std::f impl Display for Transport { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - fn fmt_unix_path( - f: &mut Formatter<'_>, - path: &OsStr, - _is_abstract: bool, - ) -> std::fmt::Result { + fn fmt_unix_path(f: &mut Formatter<'_>, path: &OsStr) -> std::fmt::Result { #[cfg(unix)] { use std::os::unix::ffi::OsStrExt; - let bytes = if _is_abstract { - &path.as_bytes()[1..] - } else { - path.as_bytes() - }; - encode_percents(f, bytes)?; + encode_percents(f, path.as_bytes())?; } #[cfg(windows)] @@ -378,32 +372,21 @@ impl Display for Transport { Self::Unix(unix) => match unix.path() { UnixPath::File(path) => { - let is_abstract = { - #[cfg(unix)] - { - use std::os::unix::ffi::OsStrExt; - - path.as_bytes().first() == Some(&b'\0') - } - #[cfg(not(unix))] - false - }; - - if is_abstract { - f.write_str("unix:abstract=")?; - } else { - f.write_str("unix:path=")?; - } - - fmt_unix_path(f, path, is_abstract)?; + f.write_str("unix:path=")?; + fmt_unix_path(f, path)?; + } + #[cfg(target_os = "linux")] + UnixPath::Abstract(name) => { + f.write_str("unix:abstract=")?; + encode_percents(f, name)?; } UnixPath::Dir(path) => { f.write_str("unix:dir=")?; - fmt_unix_path(f, path, false)?; + fmt_unix_path(f, path)?; } UnixPath::TmpDir(path) => { f.write_str("unix:tmpdir=")?; - fmt_unix_path(f, path, false)?; + fmt_unix_path(f, path)?; } }, diff --git a/zbus/src/address/transport/unix.rs b/zbus/src/address/transport/unix.rs index b02441ca9..6c3387553 100644 --- a/zbus/src/address/transport/unix.rs +++ b/zbus/src/address/transport/unix.rs @@ -30,10 +30,13 @@ impl Unix { let tmpdir = opts.get("tmpdir"); let path = match (path, abs, dir, tmpdir) { (Some(p), None, None, None) => UnixPath::File(OsString::from(p)), - (None, Some(p), None, None) => { - let mut s = OsString::from("\0"); - s.push(p); - UnixPath::File(s) + #[cfg(target_os = "linux")] + (None, Some(p), None, None) => UnixPath::Abstract(p.as_bytes().to_owned()), + #[cfg(not(target_os = "linux"))] + (None, Some(_), None, None) => { + return Err(crate::Error::Address( + "abstract sockets currently Linux-only".to_owned(), + )); } (None, None, Some(p), None) => UnixPath::Dir(OsString::from(p)), (None, None, None, Some(p)) => UnixPath::TmpDir(OsString::from(p)), @@ -51,6 +54,9 @@ impl Unix { pub enum UnixPath { /// A path to a unix domain socket on the filesystem. File(OsString), + /// A abstract unix domain socket name. + #[cfg(target_os = "linux")] + Abstract(Vec), /// A listenable address using the specified path, in which a socket file with a random file /// name starting with 'dbus-' will be created by the server. See [UNIX domain socket address] /// reference documentation.