Skip to content

Commit

Permalink
✨ zb: Proper support for abstract unix sockets
Browse files Browse the repository at this point in the history
Add a new address::UnixPath variant for abstract sockets and replace our
hack with the abstract unix socket support from std.

Fixes #329.
  • Loading branch information
zeenix committed Jan 26, 2024
1 parent b616888 commit 5300ad4
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 86 deletions.
10 changes: 6 additions & 4 deletions zbus/src/address/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ mod tests {
Error::Address(e) => assert_eq!(e, "unix: address is invalid"),
_ => panic!(),
}
#[cfg(target_os = "linux")]
match Address::from_str("unix:path=/tmp,abstract=foo").unwrap_err() {
Error::Address(e) => {
assert_eq!(e, "unix: address is invalid")
Expand All @@ -235,9 +236,10 @@ mod tests {
Address::from_str("unix:path=/tmp/dbus-foo").unwrap(),
Transport::Unix(Unix::new(UnixPath::File("/tmp/dbus-foo".into()))).into(),
);
#[cfg(target_os = "linux")]
assert_eq!(
Address::from_str("unix:abstract=/tmp/dbus-foo").unwrap(),
Transport::Unix(Unix::new(UnixPath::File("\0/tmp/dbus-foo".into()))).into(),
Transport::Unix(Unix::new(UnixPath::Abstract("/tmp/dbus-foo".into()))).into(),
);
let guid = crate::Guid::generate();
assert_eq!(
Expand Down Expand Up @@ -340,10 +342,10 @@ mod tests {
"unix:tmpdir=/tmp/dbus-foo"
);
// FIXME: figure out how to handle abstract on Windows
#[cfg(unix)]
#[cfg(target_os = "linux")]
assert_eq!(
Address::from(Transport::Unix(Unix::new(UnixPath::File(
"\0/tmp/dbus-foo".into()
Address::from(Transport::Unix(Unix::new(UnixPath::Abstract(
"/tmp/dbus-foo".into()
))))
.to_string(),
"unix:abstract=/tmp/dbus-foo"
Expand Down
139 changes: 61 additions & 78 deletions zbus/src/address/transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,14 @@ use crate::{Error, Result};
use async_io::Async;
#[cfg(not(feature = "tokio"))]
use std::net::TcpStream;
#[cfg(all(unix, not(feature = "tokio")))]
use std::os::unix::net::UnixStream;
#[cfg(unix)]
use std::os::unix::net::{SocketAddr, UnixStream};
use std::{collections::HashMap, ffi::OsStr};
#[cfg(feature = "tokio")]
use tokio::net::TcpStream;
#[cfg(all(unix, feature = "tokio"))]
use tokio::net::UnixStream;
#[cfg(feature = "tokio-vsock")]
use tokio_vsock::VsockStream;
#[cfg(all(windows, not(feature = "tokio")))]
#[cfg(windows)]
use uds_windows::UnixStream;
#[cfg(all(feature = "vsock", not(feature = "tokio")))]
use vsock::VsockStream;
Expand Down Expand Up @@ -47,6 +45,8 @@ pub use launchd::Launchd;
#[path = "vsock.rs"]
// Gotta rename to avoid name conflict with the `vsock` crate.
mod vsock_transport;
#[cfg(target_os = "linux")]
use std::os::linux::net::SocketAddrExt;
#[cfg(any(
all(feature = "vsock", not(feature = "tokio")),
feature = "tokio-vsock"
Expand Down Expand Up @@ -83,54 +83,57 @@ impl Transport {
#[cfg_attr(any(target_os = "macos", windows), async_recursion::async_recursion)]
pub(super) async fn connect(self) -> Result<Stream> {
match self {
Transport::Unix(unix) => match unix.take_path() {
UnixPath::File(path) => {
#[cfg(not(feature = "tokio"))]
{
Transport::Unix(unix) => {
// This is a `path` in case of Windows until uds_windows provides the needed API:
// https://github.com/haraldh/rust_uds_windows/issues/14
let addr = match unix.take_path() {
#[cfg(unix)]
UnixPath::File(path) => SocketAddr::from_pathname(path)?,
#[cfg(windows)]
UnixPath::File(path) => path,
#[cfg(target_os = "linux")]
UnixPath::Abstract(name) => SocketAddr::from_abstract_name(name)?,
UnixPath::Dir(_) | UnixPath::TmpDir(_) => {
// you can't connect to a unix:dir
return Err(Error::Unsupported);
}
};
let stream = crate::Task::spawn_blocking(
move || -> Result<_> {
#[cfg(unix)]
let stream = UnixStream::connect_addr(&addr)?;
#[cfg(windows)]
{
let stream = crate::Task::spawn_blocking(
move || UnixStream::connect(path),
"unix stream connection",
)
.await?;
Async::new(stream)
.map(Stream::Unix)
.map_err(|e| Error::InputOutput(e.into()))
}
let stream = UnixStream::connect(addr)?;
stream.set_nonblocking(true)?;

Ok(stream)
},
"unix stream connection",
)
.await?;
#[cfg(not(feature = "tokio"))]
{
Async::new(stream)
.map(Stream::Unix)
.map_err(|e| Error::InputOutput(e.into()))
}

#[cfg(not(windows))]
{
Async::<UnixStream>::connect(path)
.await
.map(Stream::Unix)
.map_err(|e| Error::InputOutput(e.into()))
}
#[cfg(feature = "tokio")]
{
#[cfg(unix)]
{
tokio::net::UnixStream::from_std(stream)
.map(Stream::Unix)
.map_err(|e| Error::InputOutput(e.into()))
}

#[cfg(feature = "tokio")]
#[cfg(not(unix))]
{
#[cfg(unix)]
{
UnixStream::connect(path)
.await
.map(Stream::Unix)
.map_err(|e| Error::InputOutput(e.into()))
}

#[cfg(not(unix))]
{
let _ = path;
Err(Error::Unsupported)
}
let _ = path;
Err(Error::Unsupported)
}
}
UnixPath::Dir(_) | UnixPath::TmpDir(_) => {
// you can't connect to a unix:dir
Err(Error::Unsupported)
}
},

}
#[cfg(all(feature = "vsock", not(feature = "tokio")))]
Transport::Vsock(addr) => {
let stream = VsockStream::connect_with_cid_port(addr.cid(), addr.port())?;
Expand Down Expand Up @@ -239,7 +242,7 @@ pub(crate) enum Stream {
#[derive(Debug)]
pub(crate) enum Stream {
#[cfg(unix)]
Unix(UnixStream),
Unix(tokio::net::UnixStream),
Tcp(TcpStream),
#[cfg(feature = "tokio-vsock")]
Vsock(VsockStream),
Expand Down Expand Up @@ -327,21 +330,12 @@ pub(super) fn encode_percents(f: &mut Formatter<'_>, mut value: &[u8]) -> std::f

impl Display for Transport {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
fn fmt_unix_path(
f: &mut Formatter<'_>,
path: &OsStr,
_is_abstract: bool,
) -> std::fmt::Result {
fn fmt_unix_path(f: &mut Formatter<'_>, path: &OsStr) -> std::fmt::Result {
#[cfg(unix)]
{
use std::os::unix::ffi::OsStrExt;

let bytes = if _is_abstract {
&path.as_bytes()[1..]
} else {
path.as_bytes()
};
encode_percents(f, bytes)?;
encode_percents(f, path.as_bytes())?;
}

#[cfg(windows)]
Expand Down Expand Up @@ -378,32 +372,21 @@ impl Display for Transport {

Self::Unix(unix) => match unix.path() {
UnixPath::File(path) => {
let is_abstract = {
#[cfg(unix)]
{
use std::os::unix::ffi::OsStrExt;

path.as_bytes().first() == Some(&b'\0')
}
#[cfg(not(unix))]
false
};

if is_abstract {
f.write_str("unix:abstract=")?;
} else {
f.write_str("unix:path=")?;
}

fmt_unix_path(f, path, is_abstract)?;
f.write_str("unix:path=")?;
fmt_unix_path(f, path)?;
}
#[cfg(target_os = "linux")]
UnixPath::Abstract(name) => {
f.write_str("unix:abstract=")?;
encode_percents(f, name)?;
}
UnixPath::Dir(path) => {
f.write_str("unix:dir=")?;
fmt_unix_path(f, path, false)?;
fmt_unix_path(f, path)?;
}
UnixPath::TmpDir(path) => {
f.write_str("unix:tmpdir=")?;
fmt_unix_path(f, path, false)?;
fmt_unix_path(f, path)?;
}
},

Expand Down
14 changes: 10 additions & 4 deletions zbus/src/address/transport/unix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,13 @@ impl Unix {
let tmpdir = opts.get("tmpdir");
let path = match (path, abs, dir, tmpdir) {
(Some(p), None, None, None) => UnixPath::File(OsString::from(p)),
(None, Some(p), None, None) => {
let mut s = OsString::from("\0");
s.push(p);
UnixPath::File(s)
#[cfg(target_os = "linux")]
(None, Some(p), None, None) => UnixPath::Abstract(p.as_bytes().to_owned()),
#[cfg(not(target_os = "linux"))]
(None, Some(_), None, None) => {
return Err(crate::Error::Address(
"abstract sockets currently Linux-only".to_owned(),
));
}
(None, None, Some(p), None) => UnixPath::Dir(OsString::from(p)),
(None, None, None, Some(p)) => UnixPath::TmpDir(OsString::from(p)),
Expand All @@ -51,6 +54,9 @@ impl Unix {
pub enum UnixPath {
/// A path to a unix domain socket on the filesystem.
File(OsString),
/// A abstract unix domain socket name.
#[cfg(target_os = "linux")]
Abstract(Vec<u8>),
/// A listenable address using the specified path, in which a socket file with a random file
/// name starting with 'dbus-' will be created by the server. See [UNIX domain socket address]
/// reference documentation.
Expand Down

0 comments on commit 5300ad4

Please sign in to comment.