diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index de624564e..8aa28eabf 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-latest env: RUSTFLAGS: -D warnings - MSRV: 1.64.0 + MSRV: 1.65.0 steps: - uses: actions/checkout@v3 - uses: dtolnay/rust-toolchain@master diff --git a/zbus/src/addr/address.rs b/zbus/src/addr/address.rs new file mode 100644 index 000000000..cf162169a --- /dev/null +++ b/zbus/src/addr/address.rs @@ -0,0 +1,219 @@ +use std::{borrow::Cow, collections::HashSet, fmt}; + +use crate::{Error, Guid, Result}; + +use super::{ + percent::{decode_percents, decode_percents_str, Encodable}, + transport::Transport, +}; + +/// A bus address. +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct DBusAddr<'a> { + pub(super) addr: Cow<'a, str>, +} + +impl<'a> DBusAddr<'a> { + /// The connection UUID (guid=) if any. + pub fn guid(&self) -> Option> { + match self.get_string("guid") { + Some(Ok(v)) => Some(Guid::try_from(v.as_ref())), + Some(Err(e)) => Some(Err(e)), + _ => None, + } + } + + /// Transport connection details + pub fn transport(&self) -> Result> { + self.try_into() + } + + fn validate(&self) -> Result<()> { + self.transport()?; + let mut set = HashSet::new(); + for (k, v) in self.key_val_iter() { + if !set.insert(k) { + return Err(Error::Address(format!("Duplicate key `{k}`"))); + } + if let Some(v) = v { + decode_percents(v)?; + } + } + Ok(()) + } + + fn new>>(addr: A) -> Result { + let addr = addr.into(); + let addr = Self { addr }; + + addr.validate()?; + Ok(addr) + } + + pub(super) fn key_val_iter(&'a self) -> KeyValIter<'a> { + let mut split = self.addr.splitn(2, ':'); + // skip transport:.. + split.next(); + let kv = split.next().unwrap_or(""); + KeyValIter::new(kv) + } + + fn get_string(&'a self, key: &str) -> Option>> { + for (k, v) in self.key_val_iter() { + if key == k { + return v.map(decode_percents_str); + } + } + None + } +} + +impl DBusAddr<'_> { + pub(crate) fn to_owned(&self) -> DBusAddr<'static> { + let addr = self.addr.to_string(); + DBusAddr { addr: addr.into() } + } +} + +impl<'a> TryFrom for DBusAddr<'a> { + type Error = Error; + + fn try_from(addr: String) -> Result { + Self::new(addr) + } +} + +impl<'a> TryFrom<&'a str> for DBusAddr<'a> { + type Error = Error; + + fn try_from(addr: &'a str) -> Result { + Self::new(addr) + } +} + +impl fmt::Display for DBusAddr<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let kv = KeyValFmt::new().add("guid", self.guid().and_then(|v| v.ok())); + let t = self.transport().map_err(|_| fmt::Error)?; + let kv = t.key_val_fmt_add(kv); + write!(f, "{t}:{kv}")?; + Ok(()) + } +} + +pub(super) struct KeyValIter<'a> { + data: &'a str, + next_index: usize, +} + +impl<'a> KeyValIter<'a> { + fn new(data: &'a str) -> Self { + KeyValIter { + data, + next_index: 0, + } + } +} + +impl<'a> Iterator for KeyValIter<'a> { + type Item = (&'a str, Option<&'a str>); + + fn next(&mut self) -> Option { + if self.next_index >= self.data.len() { + return None; + } + + let mut pair = &self.data[self.next_index..]; + if let Some(end) = pair.find(',') { + pair = &pair[..end]; + self.next_index += end + 1; + } else { + self.next_index = self.data.len(); + } + let mut split = pair.split('='); + // SAFETY: first split always returns something + let key = split.next().unwrap(); + Some((key, split.next())) + } +} + +pub(crate) trait KeyValFmtAdd { + fn key_val_fmt_add<'a: 'b, 'b>(&'a self, kv: KeyValFmt<'b>) -> KeyValFmt<'b>; +} + +pub(crate) struct KeyValFmt<'a> { + fields: Vec<(Box, Box)>, +} + +impl<'a> KeyValFmt<'a> { + fn new() -> Self { + Self { fields: vec![] } + } + + pub(crate) fn add(mut self, key: K, val: Option) -> Self + where + K: fmt::Display + 'a, + V: Encodable + 'a, + { + if let Some(val) = val { + self.fields.push((Box::new(key), Box::new(val))); + } + self + } +} + +impl fmt::Display for KeyValFmt<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut first = true; + for (k, v) in self.fields.iter() { + if !first { + write!(f, ",")?; + } + write!(f, "{k}=")?; + v.encode(f)?; + first = false; + } + Ok(()) + } +} + +/// A trait for objects which can be converted or resolved to one or more [`DBusAddr`] values. +pub trait ToDBusAddrs<'a> { + type Iter: Iterator>>; + + fn to_dbus_addrs(&'a self) -> Self::Iter; +} + +impl<'a> ToDBusAddrs<'a> for DBusAddr<'a> { + type Iter = std::iter::Once>>; + + /// Get an iterator over the D-Bus addresses. + fn to_dbus_addrs(&'a self) -> Self::Iter { + std::iter::once(Ok(self.clone())) + } +} + +impl<'a> ToDBusAddrs<'a> for str { + type Iter = std::iter::Once>>; + + fn to_dbus_addrs(&'a self) -> Self::Iter { + std::iter::once(self.try_into()) + } +} + +impl<'a> ToDBusAddrs<'a> for String { + type Iter = std::iter::Once>>; + + fn to_dbus_addrs(&'a self) -> Self::Iter { + std::iter::once(self.as_str().try_into()) + } +} + +impl<'a> ToDBusAddrs<'a> for Vec>> { + type Iter = std::iter::Cloned>>>; + + /// Get an iterator over the D-Bus addresses. + fn to_dbus_addrs(&'a self) -> Self::Iter { + self.iter().cloned() + } +} diff --git a/zbus/src/addr/address_list.rs b/zbus/src/addr/address_list.rs new file mode 100644 index 000000000..2001f040f --- /dev/null +++ b/zbus/src/addr/address_list.rs @@ -0,0 +1,82 @@ +use std::{borrow::Cow, fmt}; + +use crate::{Error, Result}; + +use super::{DBusAddr, ToDBusAddrs}; + +/// A bus address list. +/// +/// D-Bus addresses are `;`-separated. +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct DBusAddrList<'a> { + addr: Cow<'a, str>, +} + +impl<'a> ToDBusAddrs<'a> for DBusAddrList<'a> { + type Iter = DBusAddrListIter<'a>; + + /// Get an iterator over the D-Bus addresses. + fn to_dbus_addrs(&'a self) -> Self::Iter { + DBusAddrListIter::new(self) + } +} + +impl<'a> Iterator for DBusAddrListIter<'a> { + type Item = Result>; + + fn next(&mut self) -> Option { + if self.next_index >= self.data.len() { + return None; + } + + let mut addr = &self.data[self.next_index..]; + if let Some(end) = addr.find(';') { + addr = &addr[..end]; + self.next_index += end + 1; + } else { + self.next_index = self.data.len(); + } + Some(DBusAddr::try_from(addr)) + } +} + +/// An iterator of D-Bus addresses. +pub struct DBusAddrListIter<'a> { + data: &'a str, + next_index: usize, +} + +impl<'a> DBusAddrListIter<'a> { + fn new(list: &'a DBusAddrList<'_>) -> Self { + Self { + data: list.addr.as_ref(), + next_index: 0, + } + } +} + +impl<'a> TryFrom for DBusAddrList<'a> { + type Error = Error; + + fn try_from(value: String) -> Result { + Ok(Self { + addr: Cow::Owned(value), + }) + } +} + +impl<'a> TryFrom<&'a str> for DBusAddrList<'a> { + type Error = Error; + + fn try_from(value: &'a str) -> Result { + Ok(Self { + addr: Cow::Borrowed(value), + }) + } +} + +impl fmt::Display for DBusAddrList<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.addr) + } +} diff --git a/zbus/src/addr/mod.rs b/zbus/src/addr/mod.rs new file mode 100644 index 000000000..4e7457d83 --- /dev/null +++ b/zbus/src/addr/mod.rs @@ -0,0 +1,264 @@ +//! D-Bus address handling. +//! +//! Server addresses consist of a transport name followed by a colon, and then an optional, +//! comma-separated list of keys and values in the form key=value. +//! +//! See also: +//! +//! * [Server addresses] in the D-Bus specification. +//! +//! [Server addresses]: https://dbus.freedesktop.org/doc/dbus-specification.html#addresses + +// note: assumes values are utf-8 encoded - this should be clarified in the spec +// otherwise, fail to read them or use lossy representation for display +// +// assumes that empty key=val is accepted, so "transport:,,guid=..." is valid +// +// allows key only, so "transport:foo,bar" is ok +// +// mostly ignores unknown keys and transport + +use std::env; + +#[cfg(all(unix, not(target_os = "macos")))] +use nix::unistd::Uid; + +use crate::Result; + +pub mod transport; + +mod address; +pub use address::{DBusAddr, ToDBusAddrs}; + +mod address_list; +pub use address_list::{DBusAddrList, DBusAddrListIter}; + +mod percent; +pub use percent::*; + +/// Get the address for session socket respecting the DBUS_SESSION_BUS_ADDRESS environment +/// variable. If we don't recognize the value (or it's not set) we fall back to +/// $XDG_RUNTIME_DIR/bus +pub fn session() -> Result> { + match env::var("DBUS_SESSION_BUS_ADDRESS") { + Ok(val) => DBusAddrList::try_from(val), + _ => { + #[cfg(windows)] + { + #[cfg(feature = "windows-gdbus")] + return DBusAddrList::try_from("autolaunch:"); + + #[cfg(not(feature = "windows-gdbus"))] + return DBusAddrList::try_from("autolaunch:scope=*user"); + } + + #[cfg(all(unix, not(target_os = "macos")))] + { + let runtime_dir = env::var("XDG_RUNTIME_DIR") + .unwrap_or_else(|_| format!("/run/user/{}", Uid::effective())); + let path = format!("unix:path={runtime_dir}/bus"); + + DBusAddrList::try_from(path) + } + + #[cfg(target_os = "macos")] + return DBusAddrList::try_from("launchd:env=DBUS_LAUNCHD_SESSION_BUS_SOCKET"); + } + } +} + +/// Get the address for system bus respecting the DBUS_SYSTEM_BUS_ADDRESS environment +/// variable. If we don't recognize the value (or it's not set) we fall back to +/// /var/run/dbus/system_bus_socket +pub fn system() -> Result> { + match env::var("DBUS_SYSTEM_BUS_ADDRESS") { + Ok(val) => DBusAddrList::try_from(val), + _ => { + #[cfg(all(unix, not(target_os = "macos")))] + return DBusAddrList::try_from("unix:path=/var/run/dbus/system_bus_socket"); + + #[cfg(windows)] + return DBusAddrList::try_from("autolaunch:"); + + #[cfg(target_os = "macos")] + return DBusAddrList::try_from("launchd:env=DBUS_LAUNCHD_SESSION_BUS_SOCKET"); + } + } +} + +#[cfg(test)] +mod tests { + use std::{borrow::Cow, ffi::OsStr}; + + use crate::addr::transport::{AutolaunchScope, TcpFamily}; + + use super::{ + transport::{Transport, UnixAddrKind}, + DBusAddr, + }; + + #[test] + fn parse_err() { + assert_eq!( + DBusAddr::try_from("").unwrap_err().to_string(), + "address error: DBusAddr has no transport" + ); + assert_eq!( + DBusAddr::try_from("foo").unwrap_err().to_string(), + "address error: DBusAddr has no transport" + ); + DBusAddr::try_from("foo:opt").unwrap(); + assert_eq!( + DBusAddr::try_from("foo:opt=1,opt=2") + .unwrap_err() + .to_string(), + "address error: Duplicate key `opt`" + ); + assert_eq!( + DBusAddr::try_from("foo:opt=%1").unwrap_err().to_string(), + "address error: Incomplete percent-encoded sequence" + ); + assert_eq!( + DBusAddr::try_from("foo:opt=%1z").unwrap_err().to_string(), + "address error: Invalid hexadecimal character in percent-encoded sequence" + ); + assert_eq!( + DBusAddr::try_from("foo:opt=1\rz").unwrap_err().to_string(), + "address error: Invalid character in address" + ); + + let addr = DBusAddr::try_from("foo:guid=9406e28972c595c590766c9564ce623f").unwrap(); + addr.guid().unwrap().unwrap(); + } + + #[test] + fn parse_unix() { + let addr = + DBusAddr::try_from("unix:path=/tmp/dbus-foo,guid=9406e28972c595c590766c9564ce623f") + .unwrap(); + let Transport::Unix(u) = addr.transport().unwrap() else { + panic!(); + }; + assert_eq!( + u.kind(), + &UnixAddrKind::Path(Cow::Borrowed(OsStr::new("/tmp/dbus-foo"))) + ); + + assert_eq!( + DBusAddr::try_from("unix:foo=blah").unwrap_err().to_string(), + "address error: Invalid `unix:` address, missing required key" + ); + assert_eq!( + DBusAddr::try_from("unix:path=/blah,abstract=foo").unwrap_err().to_string(), + "address error: Invalid address: only one of `path` `dir` `tmpdir` `abstract` or `runtime` expected" + ); + assert_eq!( + DBusAddr::try_from("unix:runtime=no") + .unwrap_err() + .to_string(), + "address error: Invalid runtime=no value" + ); + DBusAddr::try_from(String::from("unix:path=/tmp/foo")).unwrap(); + } + + #[test] + fn parse_launchd() { + let addr = DBusAddr::try_from("launchd:env=FOOBAR").unwrap(); + let Transport::Launchd(t) = addr.transport().unwrap() else { + panic!(); + }; + assert_eq!(t.env(), "FOOBAR"); + + assert_eq!( + DBusAddr::try_from("launchd:weof").unwrap_err().to_string(), + "address error: Missing env=" + ); + } + + #[test] + fn parse_systemd() { + let addr = DBusAddr::try_from("systemd:").unwrap(); + let Transport::Systemd(_) = addr.transport().unwrap() else { + panic!(); + }; + } + + #[test] + fn parse_tcp() { + let addr = DBusAddr::try_from("tcp:host=localhost,bind=*,port=0,family=ipv4").unwrap(); + let Transport::Tcp(t) = addr.transport().unwrap() else { + panic!(); + }; + assert_eq!(t.host().unwrap(), "localhost"); + assert_eq!(t.bind().unwrap(), "*"); + assert_eq!(t.port().unwrap(), 0); + assert_eq!(t.family().unwrap(), TcpFamily::IPv4); + + let addr = DBusAddr::try_from("tcp:").unwrap(); + let Transport::Tcp(t) = addr.transport().unwrap() else { + panic!(); + }; + assert!(t.host().is_none()); + assert!(t.bind().is_none()); + assert!(t.port().is_none()); + assert!(t.family().is_none()); + } + + #[test] + fn parse_nonce_tcp() { + let addr = + DBusAddr::try_from("nonce-tcp:host=localhost,bind=*,port=0,family=ipv6,noncefile=foo") + .unwrap(); + let Transport::NonceTcp(t) = addr.transport().unwrap() else { + panic!(); + }; + assert_eq!(t.host().unwrap(), "localhost"); + assert_eq!(t.bind().unwrap(), "*"); + assert_eq!(t.port().unwrap(), 0); + assert_eq!(t.family().unwrap(), TcpFamily::IPv6); + assert_eq!(t.noncefile().unwrap(), "foo"); + } + + #[test] + fn parse_unixexec() { + let addr = DBusAddr::try_from("unixexec:path=/bin/test,argv2=foo").unwrap(); + let Transport::Unixexec(t) = addr.transport().unwrap() else { + panic!(); + }; + + assert_eq!(t.path(), "/bin/test"); + assert_eq!(t.argv(), &[(2, Cow::from("foo"))]); + + assert_eq!( + DBusAddr::try_from("unixexec:weof").unwrap_err().to_string(), + "address error: Missing path=" + ); + } + + #[test] + fn parse_autolaunch() { + let addr = DBusAddr::try_from("autolaunch:scope=*user").unwrap(); + let Transport::Autolaunch(t) = addr.transport().unwrap() else { + panic!(); + }; + assert_eq!(t.scope().unwrap(), &AutolaunchScope::User); + } + + #[test] + #[cfg(feature = "vsock")] + fn parse_vsock() { + let addr = DBusAddr::try_from("vsock:cid=12,port=32").unwrap(); + let Transport::Vsock(t) = addr.transport().unwrap() else { + panic!(); + }; + assert_eq!(t.port(), Some(32)); + assert_eq!(t.cid(), Some(12)); + + assert_eq!( + DBusAddr::try_from("vsock:port=abc") + .unwrap_err() + .to_string(), + "address error: Invalid port: invalid digit found in string" + ); + } +} diff --git a/zbus/src/addr/percent.rs b/zbus/src/addr/percent.rs new file mode 100644 index 000000000..4d05fdda6 --- /dev/null +++ b/zbus/src/addr/percent.rs @@ -0,0 +1,161 @@ +use std::{ + borrow::Cow, + ffi::{OsStr, OsString}, + fmt, + str::from_utf8_unchecked, +}; + +use crate::{Error, Result}; + +pub(crate) trait Encodable { + fn encode(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result; +} + +impl Encodable for T { + fn encode(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + encode_percents(f, self.to_string().as_bytes()) + } +} + +pub(crate) struct EncData(pub T); + +impl> Encodable for EncData { + fn encode(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + encode_percents(f, self.0.as_ref()) + } +} + +pub(crate) struct EncOsStr(pub T); + +impl Encodable for EncOsStr<&Cow<'_, OsStr>> { + fn encode(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + encode_percents(f, self.0.to_string_lossy().as_bytes()) + } +} + +impl Encodable for EncOsStr<&OsStr> { + fn encode(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + encode_percents(f, self.0.to_string_lossy().as_bytes()) + } +} + +/// Percent-encode the value. +pub fn encode_percents(f: &mut dyn fmt::Write, mut value: &[u8]) -> std::fmt::Result { + const LOOKUP: &str = "\ +%00%01%02%03%04%05%06%07%08%09%0a%0b%0c%0d%0e%0f\ +%10%11%12%13%14%15%16%17%18%19%1a%1b%1c%1d%1e%1f\ +%20%21%22%23%24%25%26%27%28%29%2a%2b%2c%2d%2e%2f\ +%30%31%32%33%34%35%36%37%38%39%3a%3b%3c%3d%3e%3f\ +%40%41%42%43%44%45%46%47%48%49%4a%4b%4c%4d%4e%4f\ +%50%51%52%53%54%55%56%57%58%59%5a%5b%5c%5d%5e%5f\ +%60%61%62%63%64%65%66%67%68%69%6a%6b%6c%6d%6e%6f\ +%70%71%72%73%74%75%76%77%78%79%7a%7b%7c%7d%7e%7f\ +%80%81%82%83%84%85%86%87%88%89%8a%8b%8c%8d%8e%8f\ +%90%91%92%93%94%95%96%97%98%99%9a%9b%9c%9d%9e%9f\ +%a0%a1%a2%a3%a4%a5%a6%a7%a8%a9%aa%ab%ac%ad%ae%af\ +%b0%b1%b2%b3%b4%b5%b6%b7%b8%b9%ba%bb%bc%bd%be%bf\ +%c0%c1%c2%c3%c4%c5%c6%c7%c8%c9%ca%cb%cc%cd%ce%cf\ +%d0%d1%d2%d3%d4%d5%d6%d7%d8%d9%da%db%dc%dd%de%df\ +%e0%e1%e2%e3%e4%e5%e6%e7%e8%e9%ea%eb%ec%ed%ee%ef\ +%f0%f1%f2%f3%f4%f5%f6%f7%f8%f9%fa%fb%fc%fd%fe%ff"; + + loop { + let pos = value.iter().position( + |c| !matches!(c, b'-' | b'0'..=b'9' | b'A'..=b'Z' | b'a'..=b'z' | b'_' | b'/' | b'.' | b'\\' | b'*'), + ); + + if let Some(pos) = pos { + // SAFETY: The above `position()` call made sure that only ASCII chars are in the string + // up to `pos` + f.write_str(unsafe { from_utf8_unchecked(&value[..pos]) })?; + + let c = value[pos]; + value = &value[pos + 1..]; + + let pos = c as usize * 3; + f.write_str(&LOOKUP[pos..pos + 3])?; + } else { + // SAFETY: The above `position()` call made sure that only ASCII chars are in the rest + // of the string + f.write_str(unsafe { from_utf8_unchecked(value) })?; + return Ok(()); + } + } +} + +fn decode_hex(c: char) -> Result { + match c { + '0'..='9' => Ok(c as u8 - b'0'), + 'a'..='f' => Ok(c as u8 - b'a' + 10), + 'A'..='F' => Ok(c as u8 - b'A' + 10), + + _ => Err(Error::Address( + "Invalid hexadecimal character in percent-encoded sequence".to_owned(), + )), + } +} + +/// Percent-decode the string. +pub fn decode_percents(value: &str) -> Result> { + if value.find('%').is_none() { + if value.find(|c| { + !matches!(c, '-' | '0'..='9' | 'A'..='Z' | 'a'..='z' | '_' | '/' | '.' | '\\' | '*') + }).is_some() { + return Err(Error::Address("Invalid character in address".into())); + } + return Ok(value.as_bytes().into()); + } + + let mut iter = value.chars(); + let mut decoded = Vec::new(); + + while let Some(c) = iter.next() { + if matches!(c, '-' | '0'..='9' | 'A'..='Z' | 'a'..='z' | '_' | '/' | '.' | '\\' | '*') { + decoded.push(c as u8) + } else if c == '%' { + decoded.push( + decode_hex(iter.next().ok_or_else(|| { + Error::Address("Incomplete percent-encoded sequence".into()) + })?)? + << 4 + | decode_hex(iter.next().ok_or_else(|| { + Error::Address("Incomplete percent-encoded sequence".into()) + })?)?, + ); + } else { + return Err(Error::Address("Invalid character in address".into())); + } + } + + Ok(decoded.into()) +} + +pub(super) fn decode_percents_str(value: &str) -> Result> { + cow_bytes_to_str(decode_percents(value)?) +} + +fn cow_bytes_to_str(cow: Cow<'_, [u8]>) -> Result> { + match cow { + Cow::Borrowed(bytes) => Ok(Cow::Borrowed( + std::str::from_utf8(bytes).map_err(|e| Error::Address(format!("{e}")))?, + )), + Cow::Owned(bytes) => Ok(Cow::Owned( + String::from_utf8(bytes).map_err(|e| Error::Address(format!("{e}")))?, + )), + } +} + +pub(super) fn decode_percents_os_str(value: &str) -> Result> { + cow_bytes_to_os_str(decode_percents(value)?) +} + +fn cow_bytes_to_os_str(cow: Cow<'_, [u8]>) -> Result> { + match cow { + Cow::Borrowed(bytes) => Ok(Cow::Borrowed(OsStr::new( + std::str::from_utf8(bytes).map_err(|e| Error::Address(format!("{e}")))?, + ))), + Cow::Owned(bytes) => Ok(Cow::Owned(OsString::from( + String::from_utf8(bytes).map_err(|e| Error::Address(format!("{e}")))?, + ))), + } +} diff --git a/zbus/src/addr/transport.rs b/zbus/src/addr/transport.rs new file mode 100644 index 000000000..b9ec8cfd4 --- /dev/null +++ b/zbus/src/addr/transport.rs @@ -0,0 +1,108 @@ +//! D-Bus supported transports. + +use std::fmt; + +use crate::{Error, Result}; + +use super::{ + address::{KeyValFmt, KeyValFmtAdd}, + percent, DBusAddr, +}; + +mod autolaunch; +pub use autolaunch::{Autolaunch, AutolaunchScope}; + +mod launchd; +pub use launchd::Launchd; + +mod nonce_tcp; +pub use nonce_tcp::NonceTcp; + +mod systemd; +pub use systemd::Systemd; + +mod tcp; +pub use tcp::{Tcp, TcpFamily}; + +mod unix; +pub use unix::{Unix, UnixAddrKind}; + +mod unixexec; +pub use unixexec::Unixexec; + +mod vsock; +#[cfg(any(feature = "vsock", feature = "tokio-vsock"))] +pub use self::vsock::*; + +/// A D-Bus transport. +#[derive(Debug, PartialEq, Eq)] +#[non_exhaustive] +pub enum Transport<'a> { + Unix(unix::Unix<'a>), + Launchd(launchd::Launchd<'a>), + Systemd(systemd::Systemd<'a>), + Tcp(tcp::Tcp<'a>), + NonceTcp(nonce_tcp::NonceTcp<'a>), + Unixexec(unixexec::Unixexec<'a>), + Autolaunch(autolaunch::Autolaunch<'a>), + #[cfg(any(feature = "vsock", feature = "tokio-vsock"))] + Vsock(vsock::Vsock<'a>), + Other(&'a str), +} + +impl fmt::Display for Transport<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Unix(_) => write!(f, "unix"), + Self::Launchd(_) => write!(f, "launchd"), + Self::Systemd(_) => write!(f, "systemd"), + Self::Tcp(_) => write!(f, "tcp"), + Self::NonceTcp(_) => write!(f, "nonce-tcp"), + Self::Unixexec(_) => write!(f, "unixexec"), + Self::Autolaunch(_) => write!(f, "autolaunch"), + #[cfg(any(feature = "vsock", feature = "tokio-vsock"))] + Self::Vsock(_) => write!(f, "vsock"), + Self::Other(o) => write!(f, "{o}"), + } + } +} + +impl KeyValFmtAdd for Transport<'_> { + fn key_val_fmt_add<'a: 'b, 'b>(&'a self, kv: KeyValFmt<'b>) -> KeyValFmt<'b> { + match self { + Self::Unix(t) => t.key_val_fmt_add(kv), + Self::Launchd(t) => t.key_val_fmt_add(kv), + Self::Systemd(t) => t.key_val_fmt_add(kv), + Self::Tcp(t) => t.key_val_fmt_add(kv), + Self::NonceTcp(t) => t.key_val_fmt_add(kv), + Self::Unixexec(t) => t.key_val_fmt_add(kv), + Self::Autolaunch(t) => t.key_val_fmt_add(kv), + #[cfg(any(feature = "vsock", feature = "tokio-vsock"))] + Self::Vsock(t) => t.key_val_fmt_add(kv), + Self::Other(_) => kv, + } + } +} + +impl<'a> TryFrom<&'a DBusAddr<'a>> for Transport<'a> { + type Error = Error; + + fn try_from(s: &'a DBusAddr<'a>) -> Result { + let col = s + .addr + .find(':') + .ok_or_else(|| Error::Address("DBusAddr has no transport".into()))?; + match &s.addr[..col] { + "unix" => Ok(Self::Unix(s.try_into()?)), + "launchd" => Ok(Self::Launchd(s.try_into()?)), + "systemd" => Ok(Self::Systemd(s.try_into()?)), + "tcp" => Ok(Self::Tcp(s.try_into()?)), + "nonce-tcp" => Ok(Self::NonceTcp(s.try_into()?)), + "unixexec" => Ok(Self::Unixexec(s.try_into()?)), + "autolaunch" => Ok(Self::Autolaunch(s.try_into()?)), + #[cfg(any(feature = "vsock", feature = "tokio-vsock"))] + "vsock" => Ok(Self::Vsock(s.try_into()?)), + o => Ok(Self::Other(o)), + } + } +} diff --git a/zbus/src/addr/transport/autolaunch.rs b/zbus/src/addr/transport/autolaunch.rs new file mode 100644 index 000000000..e42ae086f --- /dev/null +++ b/zbus/src/addr/transport/autolaunch.rs @@ -0,0 +1,77 @@ +use std::{borrow::Cow, fmt}; + +use crate::{Error, Result}; + +use super::{percent::decode_percents_str, DBusAddr, KeyValFmt, KeyValFmtAdd}; + +/// Scope of autolaunch (Windows only) +#[derive(Debug, PartialEq, Eq)] +#[non_exhaustive] +pub enum AutolaunchScope<'a> { + /// Limit session bus to dbus installation path. + InstallPath, + /// Limit session bus to the recent user. + User, + /// other values - specify dedicated session bus like "release", "debug" or other. + Other(Cow<'a, str>), +} + +impl fmt::Display for AutolaunchScope<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::InstallPath => write!(f, "*install-path"), + Self::User => write!(f, "*user"), + Self::Other(o) => write!(f, "{o}"), + } + } +} + +impl<'a> TryFrom> for AutolaunchScope<'a> { + type Error = Error; + + fn try_from(s: Cow<'a, str>) -> Result { + match s.as_ref() { + "*install-path" => Ok(Self::InstallPath), + "*user" => Ok(Self::User), + _ => Ok(Self::Other(s)), + } + } +} + +/// `autolaunch:` D-Bus transport. +#[derive(Debug, PartialEq, Eq, Default)] +pub struct Autolaunch<'a> { + scope: Option>, +} + +impl<'a> Autolaunch<'a> { + /// Scope of autolaunch (Windows only) + pub fn scope(&self) -> Option<&AutolaunchScope<'a>> { + self.scope.as_ref() + } +} + +impl<'a> TryFrom<&'a DBusAddr<'a>> for Autolaunch<'a> { + type Error = Error; + + fn try_from(s: &'a DBusAddr<'a>) -> Result { + let mut res = Autolaunch::default(); + + for (k, v) in s.key_val_iter() { + match (k, v) { + ("scope", Some(v)) => { + res.scope = Some(decode_percents_str(v)?.try_into()?); + } + _ => continue, + } + } + + Ok(res) + } +} + +impl KeyValFmtAdd for Autolaunch<'_> { + fn key_val_fmt_add<'a: 'b, 'b>(&'a self, kv: KeyValFmt<'b>) -> KeyValFmt<'b> { + kv.add("scope", self.scope()) + } +} diff --git a/zbus/src/addr/transport/launchd.rs b/zbus/src/addr/transport/launchd.rs new file mode 100644 index 000000000..be2cd22aa --- /dev/null +++ b/zbus/src/addr/transport/launchd.rs @@ -0,0 +1,44 @@ +use std::borrow::Cow; + +use crate::{Error, Result}; + +use super::{percent::decode_percents_str, DBusAddr, KeyValFmt, KeyValFmtAdd}; + +/// `launchd:` D-Bus transport. +#[derive(Debug, PartialEq, Eq)] +pub struct Launchd<'a> { + env: Cow<'a, str>, +} + +impl<'a> Launchd<'a> { + /// Environment variable used to get the path of the unix domain socket for the launchd created + /// dbus-daemon. + pub fn env(&self) -> &str { + self.env.as_ref() + } +} + +impl<'a> TryFrom<&'a DBusAddr<'a>> for Launchd<'a> { + type Error = Error; + + fn try_from(s: &'a DBusAddr<'a>) -> Result { + for (k, v) in s.key_val_iter() { + match (k, v) { + ("env", Some(v)) => { + return Ok(Launchd { + env: decode_percents_str(v)?, + }); + } + _ => continue, + } + } + + Err(Error::Address("Missing env=".into())) + } +} + +impl KeyValFmtAdd for Launchd<'_> { + fn key_val_fmt_add<'a: 'b, 'b>(&'a self, kv: KeyValFmt<'b>) -> KeyValFmt<'b> { + kv.add("env", Some(self.env())) + } +} diff --git a/zbus/src/addr/transport/nonce_tcp.rs b/zbus/src/addr/transport/nonce_tcp.rs new file mode 100644 index 000000000..f76759b4b --- /dev/null +++ b/zbus/src/addr/transport/nonce_tcp.rs @@ -0,0 +1,92 @@ +use std::{borrow::Cow, ffi::OsStr}; + +use crate::{ + addr::percent::{decode_percents_os_str, EncOsStr}, + Error, Result, +}; + +use super::{percent::decode_percents_str, tcp::TcpFamily, DBusAddr, KeyValFmt, KeyValFmtAdd}; + +/// `nonce-tcp:` D-Bus transport. +#[derive(Debug, Default, PartialEq, Eq)] +pub struct NonceTcp<'a> { + host: Option>, + bind: Option>, + port: Option, + family: Option, + noncefile: Option>, +} + +impl<'a> NonceTcp<'a> { + /// DNS name or IP address. + pub fn host(&self) -> Option<&str> { + self.host.as_ref().map(|v| v.as_ref()) + } + + /// Used in a listenable address to configure the interface on which the server will listen: + /// either the IP address of one of the local machine's interfaces (most commonly `127.0.0.1`), + /// or a DNS name that resolves to one of those IP addresses, or `*` to listen on all interfaces + /// simultaneously. + pub fn bind(&self) -> Option<&str> { + self.bind.as_ref().map(|v| v.as_ref()) + } + + /// The TCP port the server will open. A zero value let the server choose a free port provided + /// from the underlaying operating system. + pub fn port(&self) -> Option { + self.port + } + + /// If set, provide the type of socket family. + pub fn family(&self) -> Option { + self.family + } + + /// File location containing the secret. This is only meaningful in connectable addresses. + pub fn noncefile(&self) -> Option<&OsStr> { + self.noncefile.as_ref().map(|v| v.as_ref()) + } +} + +impl KeyValFmtAdd for NonceTcp<'_> { + fn key_val_fmt_add<'a: 'b, 'b>(&'a self, kv: KeyValFmt<'b>) -> KeyValFmt<'b> { + kv.add("host", self.host()) + .add("bind", self.bind()) + .add("port", self.port()) + .add("family", self.family()) + .add("noncefile", self.noncefile().map(EncOsStr)) + } +} + +impl<'a> TryFrom<&'a DBusAddr<'a>> for NonceTcp<'a> { + type Error = Error; + + fn try_from(s: &'a DBusAddr<'a>) -> Result { + let mut res = NonceTcp::default(); + for (k, v) in s.key_val_iter() { + match (k, v) { + ("host", Some(v)) => { + res.host = Some(decode_percents_str(v)?); + } + ("bind", Some(v)) => { + res.bind = Some(decode_percents_str(v)?); + } + ("port", Some(v)) => { + res.port = Some( + decode_percents_str(v)? + .parse() + .map_err(|e| Error::Address(format!("Invalid port: {e}")))?, + ); + } + ("family", Some(v)) => { + res.family = Some(decode_percents_str(v)?.as_ref().try_into()?); + } + ("noncefile", Some(v)) => { + res.noncefile = Some(decode_percents_os_str(v)?); + } + _ => continue, + } + } + Ok(res) + } +} diff --git a/zbus/src/addr/transport/systemd.rs b/zbus/src/addr/transport/systemd.rs new file mode 100644 index 000000000..34128bed6 --- /dev/null +++ b/zbus/src/addr/transport/systemd.rs @@ -0,0 +1,34 @@ +use std::{fmt, marker::PhantomData}; + +use crate::{Error, Result}; + +use super::{DBusAddr, KeyValFmt, KeyValFmtAdd}; + +/// `systemd:` D-Bus transport. +#[derive(Debug, PartialEq, Eq)] +pub struct Systemd<'a> { + // use a phantom lifetime for eventually future fields and consistency + phantom: PhantomData<&'a ()>, +} + +impl<'a> TryFrom<&'a DBusAddr<'a>> for Systemd<'a> { + type Error = Error; + + fn try_from(_s: &'a DBusAddr<'a>) -> Result { + Ok(Systemd { + phantom: PhantomData, + }) + } +} + +impl<'a> fmt::Display for Systemd<'a> { + fn fmt(&self, _f: &mut fmt::Formatter<'_>) -> fmt::Result { + Ok(()) + } +} + +impl KeyValFmtAdd for Systemd<'_> { + fn key_val_fmt_add<'a: 'b, 'b>(&'a self, kv: KeyValFmt<'b>) -> KeyValFmt<'b> { + kv + } +} diff --git a/zbus/src/addr/transport/tcp.rs b/zbus/src/addr/transport/tcp.rs new file mode 100644 index 000000000..29f994403 --- /dev/null +++ b/zbus/src/addr/transport/tcp.rs @@ -0,0 +1,110 @@ +use std::{borrow::Cow, fmt}; + +use crate::{Error, Result}; + +use super::{percent::decode_percents_str, DBusAddr, KeyValFmt, KeyValFmtAdd}; + +/// TCP IP address family +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[non_exhaustive] +pub enum TcpFamily { + /// IPv4 + IPv4, + /// IPv6 + IPv6, +} + +impl fmt::Display for TcpFamily { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::IPv4 => write!(f, "ipv4"), + Self::IPv6 => write!(f, "ipv6"), + } + } +} + +impl TryFrom<&str> for TcpFamily { + type Error = Error; + + fn try_from(s: &str) -> Result { + match s { + "ipv4" => Ok(Self::IPv4), + "ipv6" => Ok(Self::IPv6), + _ => Err(Error::Address(format!("Unknown TCP family: {s}"))), + } + } +} + +/// `tcp:` D-Bus transport. +#[derive(Debug, Default, PartialEq, Eq)] +pub struct Tcp<'a> { + host: Option>, + bind: Option>, + port: Option, + family: Option, +} + +impl<'a> Tcp<'a> { + /// DNS name or IP address. + pub fn host(&self) -> Option<&str> { + self.host.as_ref().map(|v| v.as_ref()) + } + + /// Used in a listenable address to configure the interface on which the server will listen: + /// either the IP address of one of the local machine's interfaces (most commonly `127.0.0.1`), + /// or a DNS name that resolves to one of those IP addresses, or `*` to listen on all interfaces + /// simultaneously. + pub fn bind(&self) -> Option<&str> { + self.bind.as_ref().map(|v| v.as_ref()) + } + + /// The TCP port the server will open. A zero value let the server choose a free port provided + /// from the underlaying operating system. + pub fn port(&self) -> Option { + self.port + } + + /// If set, provide the type of socket family. + pub fn family(&self) -> Option { + self.family + } +} + +impl KeyValFmtAdd for Tcp<'_> { + fn key_val_fmt_add<'a: 'b, 'b>(&'a self, kv: KeyValFmt<'b>) -> KeyValFmt<'b> { + kv.add("host", self.host()) + .add("bind", self.bind()) + .add("port", self.port()) + .add("family", self.family()) + } +} + +impl<'a> TryFrom<&'a DBusAddr<'a>> for Tcp<'a> { + type Error = Error; + + fn try_from(s: &'a DBusAddr<'a>) -> Result { + let mut res = Tcp::default(); + for (k, v) in s.key_val_iter() { + match (k, v) { + ("host", Some(v)) => { + res.host = Some(decode_percents_str(v)?); + } + ("bind", Some(v)) => { + res.bind = Some(decode_percents_str(v)?); + } + ("port", Some(v)) => { + res.port = Some( + decode_percents_str(v)? + .parse() + .map_err(|e| Error::Address(format!("Invalid port: {e}")))?, + ); + } + ("family", Some(v)) => { + res.family = Some(decode_percents_str(v)?.as_ref().try_into()?); + } + _ => continue, + } + } + Ok(res) + } +} diff --git a/zbus/src/addr/transport/unix.rs b/zbus/src/addr/transport/unix.rs new file mode 100644 index 000000000..1866726ec --- /dev/null +++ b/zbus/src/addr/transport/unix.rs @@ -0,0 +1,114 @@ +use std::{borrow::Cow, ffi::OsStr}; + +use crate::{Error, Result}; + +use super::{ + percent::{decode_percents, decode_percents_os_str, decode_percents_str, EncData, EncOsStr}, + DBusAddr, KeyValFmt, KeyValFmtAdd, +}; + +/// A sub-type of `unix:` transport. +#[derive(Debug, PartialEq, Eq)] +#[non_exhaustive] +pub enum UnixAddrKind<'a> { + /// Path of the unix domain socket. + Path(Cow<'a, OsStr>), + /// Directory in which a socket file with a random file name starting with 'dbus-' should be + /// created by a server. + Dir(Cow<'a, OsStr>), + /// The same as "dir", except that on platforms with abstract sockets, a server may attempt to + /// create an abstract socket whose name starts with this directory instead of a path-based + /// socket. + Tmpdir(Cow<'a, OsStr>), + /// Unique string in the abstract namespace, often syntactically resembling a path but + /// unconnected to the filesystem namespace + Abstract(Cow<'a, [u8]>), + /// Listen on $XDG_RUNTIME_DIR/bus. + Runtime, +} + +impl KeyValFmtAdd for UnixAddrKind<'_> { + fn key_val_fmt_add<'a: 'b, 'b>(&'a self, kv: KeyValFmt<'b>) -> KeyValFmt<'b> { + match self { + UnixAddrKind::Path(p) => kv.add("path", Some(EncOsStr(p))), + UnixAddrKind::Dir(p) => kv.add("dir", Some(EncOsStr(p))), + UnixAddrKind::Tmpdir(p) => kv.add("tmpdir", Some(EncOsStr(p))), + UnixAddrKind::Abstract(p) => kv.add("abstract", Some(EncData(p))), + UnixAddrKind::Runtime => kv.add("runtime", Some("yes")), + } + } +} + +/// `unix:` D-Bus transport. +#[derive(Debug, PartialEq, Eq)] +pub struct Unix<'a> { + kind: UnixAddrKind<'a>, +} + +impl<'a> Unix<'a> { + /// One of the various `unix:` addresses. + pub fn kind(&self) -> &UnixAddrKind<'a> { + &self.kind + } +} + +impl<'a> TryFrom<&'a DBusAddr<'a>> for Unix<'a> { + type Error = Error; + + fn try_from(s: &'a DBusAddr<'a>) -> Result { + let mut kind = None; + let mut iter = s.key_val_iter(); + for (k, v) in &mut iter { + match k { + "path" | "dir" | "tmpdir" => { + let v = v.ok_or_else(|| Error::Address(format!("Missing {}= value", k)))?; + let v = decode_percents_os_str(v)?; + kind = Some(match k { + "path" => UnixAddrKind::Path(v), + "dir" => UnixAddrKind::Dir(v), + "tmpdir" => UnixAddrKind::Tmpdir(v), + // can't happen, this we matched those earlier + _ => panic!(), + }); + break; + } + "abstract" => { + let v = v.ok_or_else(|| Error::Address(format!("Missing {}= value", k)))?; + let v = decode_percents(v)?; + kind = Some(UnixAddrKind::Abstract(v)); + break; + } + "runtime" => { + let v = v.ok_or_else(|| Error::Address(format!("Missing {}= value", k)))?; + let v = decode_percents_str(v)?; + if v != "yes" { + return Err(Error::Address(format!("Invalid runtime={} value", v))); + } + kind = Some(UnixAddrKind::Runtime); + break; + } + _ => continue, + } + } + let Some(kind) = kind else { + return Err(Error::Address( + "Invalid `unix:` address, missing required key".into(), + )); + }; + for (k, _) in iter { + match k { + "path" | "dir" | "tmpdir" | "abstract" | "runtime" => { + return Err(Error::Address("Invalid address: only one of `path` `dir` `tmpdir` `abstract` or `runtime` expected".into())); + } + _ => (), + } + } + Ok(Unix { kind }) + } +} + +impl KeyValFmtAdd for Unix<'_> { + fn key_val_fmt_add<'a: 'b, 'b>(&'a self, kv: KeyValFmt<'b>) -> KeyValFmt<'b> { + self.kind().key_val_fmt_add(kv) + } +} diff --git a/zbus/src/addr/transport/unixexec.rs b/zbus/src/addr/transport/unixexec.rs new file mode 100644 index 000000000..ee4a876fe --- /dev/null +++ b/zbus/src/addr/transport/unixexec.rs @@ -0,0 +1,78 @@ +use std::{borrow::Cow, ffi::OsStr, fmt}; + +use crate::{Error, Result}; + +use super::{ + percent::{decode_percents_os_str, decode_percents_str, EncOsStr}, + DBusAddr, KeyValFmt, KeyValFmtAdd, +}; + +#[derive(Debug, PartialEq, Eq)] +struct Argv(usize); + +impl fmt::Display for Argv { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let n = self.0; + + write!(f, "argv{n}") + } +} + +/// `unixexec:` D-Bus transport. +#[derive(Debug, PartialEq, Eq)] +pub struct Unixexec<'a> { + path: Cow<'a, OsStr>, + argv: Vec<(usize, Cow<'a, str>)>, +} + +impl<'a> Unixexec<'a> { + pub fn path(&self) -> &OsStr { + self.path.as_ref() + } + + pub fn argv(&self) -> &[(usize, Cow<'a, str>)] { + self.argv.as_ref() + } +} + +impl<'a> TryFrom<&'a DBusAddr<'a>> for Unixexec<'a> { + type Error = Error; + + fn try_from(s: &'a DBusAddr<'a>) -> Result { + let mut path = None; + let mut argv = Vec::new(); + + for (k, v) in s.key_val_iter() { + match (k, v) { + ("path", Some(v)) => { + path = Some(decode_percents_os_str(v)?); + } + (k, Some(v)) if k.starts_with("argv") => { + let n: usize = k[4..] + .parse() + .map_err(|e| Error::Address(format!("Invalid argv: {e}")))?; + let arg = decode_percents_str(v)?; + argv.push((n, arg)); + } + _ => continue, + } + } + + let Some(path) = path else { + return Err(Error::Address("Missing path=".into())); + }; + + argv.sort_by(|a, b| a.0.cmp(&b.0)); + Ok(Self { path, argv }) + } +} + +impl KeyValFmtAdd for Unixexec<'_> { + fn key_val_fmt_add<'a: 'b, 'b>(&'a self, mut kv: KeyValFmt<'b>) -> KeyValFmt<'b> { + kv = kv.add("path", Some(EncOsStr(self.path()))); + for (n, arg) in self.argv() { + kv = kv.add(Argv(*n), Some(arg)); + } + kv + } +} diff --git a/zbus/src/addr/transport/vsock.rs b/zbus/src/addr/transport/vsock.rs new file mode 100644 index 000000000..f4597ce8d --- /dev/null +++ b/zbus/src/addr/transport/vsock.rs @@ -0,0 +1,71 @@ +#![cfg(any(feature = "vsock", feature = "tokio-vsock"))] + +use std::marker::PhantomData; + +use crate::{Error, Result}; + +use super::{percent::decode_percents_str, DBusAddr, KeyValFmt, KeyValFmtAdd}; + +/// `vsock:` D-Bus transport. +#[derive(Debug, PartialEq, Eq)] +pub struct Vsock<'a> { + // no cid means ANY + cid: Option, + // no port means ANY + port: Option, + // use a phantom lifetime for eventually future fields and consistency + phantom: PhantomData<&'a ()>, +} + +impl<'a> Vsock<'a> { + /// The VSOCK port. + pub fn port(&self) -> Option { + self.port + } + + /// The VSOCK CID. + pub fn cid(&self) -> Option { + self.cid + } +} + +impl<'a> TryFrom<&'a DBusAddr<'a>> for Vsock<'a> { + type Error = Error; + + fn try_from(s: &'a DBusAddr<'a>) -> Result { + let mut port = None; + let mut cid = None; + + for (k, v) in s.key_val_iter() { + match (k, v) { + ("port", Some(v)) => { + port = Some( + decode_percents_str(v)? + .parse() + .map_err(|e| Error::Address(format!("Invalid port: {e}")))?, + ); + } + ("cid", Some(v)) => { + cid = Some( + decode_percents_str(v)? + .parse() + .map_err(|e| Error::Address(format!("Invalid cid: {e}")))?, + ) + } + _ => continue, + } + } + + Ok(Vsock { + port, + cid, + phantom: PhantomData, + }) + } +} + +impl KeyValFmtAdd for Vsock<'_> { + fn key_val_fmt_add<'a: 'b, 'b>(&'a self, kv: KeyValFmt<'b>) -> KeyValFmt<'b> { + kv.add("cid", self.cid()).add("port", self.port()) + } +} diff --git a/zbus/src/address.rs b/zbus/src/address.rs index 27efe6c37..34f4c17b0 100644 --- a/zbus/src/address.rs +++ b/zbus/src/address.rs @@ -9,30 +9,10 @@ //! //! [Server addresses]: https://dbus.freedesktop.org/doc/dbus-specification.html#addresses -#[cfg(target_os = "macos")] -use crate::process::run; -#[cfg(windows)] -use crate::win32::windows_autolaunch_bus_address; use crate::{Error, Result}; -#[cfg(not(feature = "tokio"))] -use async_io::Async; #[cfg(all(unix, not(target_os = "macos")))] use nix::unistd::Uid; -#[cfg(not(feature = "tokio"))] -use std::net::{SocketAddr, TcpStream, ToSocketAddrs}; -#[cfg(all(unix, not(feature = "tokio")))] -use std::os::unix::net::UnixStream; -use std::{collections::HashMap, env, str::FromStr}; -#[cfg(feature = "tokio")] -use tokio::net::TcpStream; -#[cfg(all(unix, feature = "tokio"))] -use tokio::net::UnixStream; -#[cfg(feature = "tokio-vsock")] -use tokio_vsock::VsockStream; -#[cfg(all(windows, not(feature = "tokio")))] -use uds_windows::UnixStream; -#[cfg(all(feature = "vsock", not(feature = "tokio")))] -use vsock::VsockStream; +use std::{collections::HashMap, convert::TryFrom, env, str::FromStr}; use std::{ ffi::OsString, @@ -192,214 +172,7 @@ pub enum Address { UnixTmpDir(OsString), } -#[cfg(not(feature = "tokio"))] -#[derive(Debug)] -pub(crate) enum Stream { - Unix(Async), - Tcp(Async), - #[cfg(feature = "vsock")] - Vsock(Async), -} - -#[cfg(feature = "tokio")] -#[derive(Debug)] -pub(crate) enum Stream { - #[cfg(unix)] - Unix(UnixStream), - Tcp(TcpStream), - #[cfg(feature = "tokio-vsock")] - Vsock(VsockStream), -} - -#[cfg(not(feature = "tokio"))] -async fn connect_tcp(addr: TcpAddress) -> Result> { - let addrs = crate::Task::spawn_blocking( - move || -> Result> { - let addrs = (addr.host(), addr.port()).to_socket_addrs()?.filter(|a| { - if let Some(family) = addr.family() { - if family == TcpAddressFamily::Ipv4 { - a.is_ipv4() - } else { - a.is_ipv6() - } - } else { - true - } - }); - Ok(addrs.collect()) - }, - "connect tcp", - ) - .await - .map_err(|e| Error::Address(format!("Failed to receive TCP addresses: {e}")))?; - - // we could attempt connections in parallel? - let mut last_err = Error::Address("Failed to connect".into()); - for addr in addrs { - match Async::::connect(addr).await { - Ok(stream) => return Ok(stream), - Err(e) => last_err = e.into(), - } - } - - Err(last_err) -} - -#[cfg(feature = "tokio")] -async fn connect_tcp(addr: TcpAddress) -> Result { - TcpStream::connect((addr.host(), addr.port())) - .await - .map_err(|e| Error::InputOutput(e.into())) -} - -#[cfg(target_os = "macos")] -pub(crate) async fn macos_launchd_bus_address(env_key: &str) -> Result
{ - let output = run("launchctl", ["getenv", env_key]) - .await - .expect("failed to wait on launchctl output"); - - if !output.status.success() { - return Err(crate::Error::Address(format!( - "launchctl terminated with code: {}", - output.status - ))); - } - - let addr = String::from_utf8(output.stdout).map_err(|e| { - crate::Error::Address(format!("Unable to parse launchctl output as UTF-8: {}", e)) - })?; - - format!("unix:path={}", addr.trim()).parse() -} - impl Address { - #[cfg_attr(any(target_os = "macos", windows), async_recursion::async_recursion)] - pub(crate) async fn connect(self) -> Result { - match self { - Address::Unix(p) => { - #[cfg(not(feature = "tokio"))] - { - #[cfg(windows)] - { - let stream = crate::Task::spawn_blocking( - move || UnixStream::connect(p), - "unix stream connection", - ) - .await?; - Async::new(stream) - .map(Stream::Unix) - .map_err(|e| Error::InputOutput(e.into())) - } - - #[cfg(not(windows))] - { - Async::::connect(p) - .await - .map(Stream::Unix) - .map_err(|e| Error::InputOutput(e.into())) - } - } - - #[cfg(feature = "tokio")] - { - #[cfg(unix)] - { - UnixStream::connect(p) - .await - .map(Stream::Unix) - .map_err(|e| Error::InputOutput(e.into())) - } - - #[cfg(not(unix))] - { - let _ = p; - Err(Error::Unsupported) - } - } - } - - #[cfg(all(feature = "vsock", not(feature = "tokio")))] - Address::Vsock(addr) => { - let stream = VsockStream::connect_with_cid_port(addr.cid, addr.port)?; - Async::new(stream).map(Stream::Vsock).map_err(Into::into) - } - - #[cfg(feature = "tokio-vsock")] - Address::Vsock(addr) => VsockStream::connect(addr.cid, addr.port) - .await - .map(Stream::Vsock) - .map_err(Into::into), - - Address::Tcp(addr) => connect_tcp(addr).await.map(Stream::Tcp), - - Address::NonceTcp { addr, nonce_file } => { - let mut stream = connect_tcp(addr).await?; - - #[cfg(unix)] - let nonce_file = { - use std::os::unix::ffi::OsStrExt; - std::ffi::OsStr::from_bytes(&nonce_file) - }; - - #[cfg(windows)] - let nonce_file = std::str::from_utf8(&nonce_file) - .map_err(|_| Error::Address("nonce file path is invalid UTF-8".to_owned()))?; - - #[cfg(not(feature = "tokio"))] - { - let nonce = std::fs::read(nonce_file)?; - let mut nonce = &nonce[..]; - - while !nonce.is_empty() { - let len = stream - .write_with_mut(|s| std::io::Write::write(s, nonce)) - .await?; - nonce = &nonce[len..]; - } - } - - #[cfg(feature = "tokio")] - { - let nonce = tokio::fs::read(nonce_file).await?; - tokio::io::AsyncWriteExt::write_all(&mut stream, &nonce).await?; - } - - Ok(Stream::Tcp(stream)) - } - - #[cfg(not(windows))] - Address::Autolaunch(_) => Err(Error::Address( - "Autolaunch addresses are only supported on Windows".to_owned(), - )), - - #[cfg(windows)] - Address::Autolaunch(Some(_)) => Err(Error::Address( - "Autolaunch scopes are currently unsupported".to_owned(), - )), - - #[cfg(windows)] - Address::Autolaunch(None) => { - let addr = windows_autolaunch_bus_address()?; - addr.connect().await - } - - #[cfg(not(target_os = "macos"))] - Address::Launchd(_) => Err(Error::Address( - "Launchd addresses are only supported on macOS".to_owned(), - )), - - #[cfg(target_os = "macos")] - Address::Launchd(env) => { - let addr = macos_launchd_bus_address(&env).await?; - addr.connect().await - } - Address::UnixDir(_) | Address::UnixTmpDir(_) => { - // you can't connect to a unix:dir - Err(Error::Unsupported) - } - } - } - /// Get the address for session socket respecting the DBUS_SESSION_BUS_ADDRESS environment /// variable. If we don't recognize the value (or it's not set) we fall back to /// $XDG_RUNTIME_DIR/bus @@ -790,6 +563,7 @@ mod tests { Error::Address(e) => assert_eq!(e, "Key `opt` specified multiple times"), _ => panic!(), } + match Address::from_str("tcp:host=localhost").unwrap_err() { Error::Address(e) => assert_eq!(e, "tcp address is missing `port`"), _ => panic!(), @@ -990,67 +764,4 @@ mod tests { "vsock:cid=98,port=2934", // no support for guid= yet.. ); } - - #[test] - fn connect_tcp() { - let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); - let port = listener.local_addr().unwrap().port(); - let addr = Address::from_str(&format!("tcp:host=localhost,port={port}")).unwrap(); - crate::utils::block_on(async { addr.connect().await }).unwrap(); - } - - #[test] - fn connect_nonce_tcp() { - struct PercentEncoded<'a>(&'a [u8]); - - impl std::fmt::Display for PercentEncoded<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - super::encode_percents(f, self.0) - } - } - - use std::io::Write; - - const TEST_COOKIE: &[u8] = b"VERILY SECRETIVE"; - - let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); - let port = listener.local_addr().unwrap().port(); - - let mut cookie = tempfile::NamedTempFile::new().unwrap(); - cookie.as_file_mut().write_all(TEST_COOKIE).unwrap(); - - let encoded_path = format!( - "{}", - PercentEncoded(cookie.path().to_str().unwrap().as_ref()) - ); - - let addr = Address::from_str(&format!( - "nonce-tcp:host=localhost,port={port},noncefile={encoded_path}" - )) - .unwrap(); - - let (sender, receiver) = std::sync::mpsc::sync_channel(1); - - std::thread::spawn(move || { - use std::io::Read; - - let mut client = listener.incoming().next().unwrap().unwrap(); - - let mut buf = [0u8; 16]; - client.read_exact(&mut buf).unwrap(); - - sender.send(buf == TEST_COOKIE).unwrap(); - }); - - crate::utils::block_on(addr.connect()).unwrap(); - - let saw_cookie = receiver - .recv_timeout(std::time::Duration::from_millis(100)) - .expect("nonce file content hasn't been received by server thread in time"); - - assert!( - saw_cookie, - "nonce file content has been received, but was invalid" - ); - } } diff --git a/zbus/src/blocking/connection/builder.rs b/zbus/src/blocking/connection/builder.rs index 62c9e052c..4207f40c3 100644 --- a/zbus/src/blocking/connection/builder.rs +++ b/zbus/src/blocking/connection/builder.rs @@ -13,7 +13,7 @@ use uds_windows::UnixStream; use zvariant::{ObjectPath, Str}; use crate::{ - address::Address, + addr::ToDBusAddrs, blocking::Connection, names::{UniqueName, WellKnownName}, object_server::Interface, @@ -42,10 +42,9 @@ impl<'a> Builder<'a> { /// Create a builder for connection that will use the given [D-Bus bus address]. /// /// [D-Bus bus address]: https://dbus.freedesktop.org/doc/dbus-specification.html#addresses - pub fn address(address: A) -> Result + pub fn address<'t, A>(address: &'t A) -> Result where - A: TryInto
, - A::Error: Into, + A: ToDBusAddrs<'t> + ?Sized, { crate::connection::Builder::address(address).map(Self) } diff --git a/zbus/src/connection/builder.rs b/zbus/src/connection/builder.rs index c1ad1acf9..24e701b17 100644 --- a/zbus/src/connection/builder.rs +++ b/zbus/src/connection/builder.rs @@ -2,29 +2,14 @@ use async_io::Async; use event_listener::Event; use static_assertions::assert_impl_all; -#[cfg(not(feature = "tokio"))] -use std::net::TcpStream; -#[cfg(all(unix, not(feature = "tokio")))] -use std::os::unix::net::UnixStream; use std::{ collections::{HashMap, HashSet, VecDeque}, sync::Arc, }; -#[cfg(feature = "tokio")] -use tokio::net::TcpStream; -#[cfg(all(unix, feature = "tokio"))] -use tokio::net::UnixStream; -#[cfg(feature = "tokio-vsock")] -use tokio_vsock::VsockStream; -#[cfg(windows)] -use uds_windows::UnixStream; -#[cfg(all(feature = "vsock", not(feature = "tokio")))] -use vsock::VsockStream; - use zvariant::{ObjectPath, Str}; use crate::{ - address::{self, Address}, + addr::{self, DBusAddr, ToDBusAddrs}, async_lock::RwLock, names::{InterfaceName, UniqueName, WellKnownName}, object_server::Interface, @@ -33,13 +18,22 @@ use crate::{ use super::{ handshake::{AuthMechanism, Authenticated}, - raw::Socket, + raw::{Socket, Stream, TcpStream}, }; +#[cfg(any(unix, all(windows, not(feature = "tokio"))))] +use super::raw::UnixStream; +#[cfg(any( + all(feature = "vsock", not(feature = "tokio")), + feature = "tokio-vsock" +))] +use super::raw::VsockStream; + const DEFAULT_MAX_QUEUED: usize = 64; #[derive(Debug)] enum Target { + #[cfg(any(unix, all(windows, not(feature = "tokio"))))] UnixStream(UnixStream), TcpStream(TcpStream), #[cfg(any( @@ -47,7 +41,8 @@ enum Target { feature = "tokio-vsock" ))] VsockStream(VsockStream), - Address(Address), + // FIXME: we should be able to keep a instead, but lifetime issues + Address(Vec>>), Socket(Box), } @@ -78,12 +73,12 @@ assert_impl_all!(Builder<'_>: Send, Sync, Unpin); impl<'a> Builder<'a> { /// Create a builder for the session/user message bus connection. pub fn session() -> Result { - Ok(Self::new(Target::Address(Address::session()?))) + Self::address(&addr::session()?) } /// Create a builder for the system-wide message bus connection. pub fn system() -> Result { - Ok(Self::new(Target::Address(Address::system()?))) + Self::address(&addr::system()?) } /// Create a builder for connection that will use the given [D-Bus bus address]. @@ -117,14 +112,18 @@ impl<'a> Builder<'a> { /// current session using `ibus address` command. /// /// [D-Bus bus address]: https://dbus.freedesktop.org/doc/dbus-specification.html#addresses - pub fn address(address: A) -> Result + pub fn address<'t, A>(address: &'t A) -> Result> where - A: TryInto
, - A::Error: Into, + A: ToDBusAddrs<'t> + ?Sized, { - Ok(Self::new(Target::Address( - address.try_into().map_err(Into::into)?, - ))) + let mut addr: Vec>> = vec![]; + for a in address.to_dbus_addrs() { + match a { + Ok(a) => addr.push(Ok(a.to_owned())), + _ => continue, + } + } + Ok(Builder::new(Target::Address(addr))) } /// Create a builder for connection that will use the given unix stream. @@ -132,6 +131,7 @@ impl<'a> Builder<'a> { /// If the default `async-io` feature is disabled, this method will expect /// [`tokio::net::UnixStream`](https://docs.rs/tokio/latest/tokio/net/struct.UnixStream.html) /// argument. + #[cfg(any(unix, all(windows, not(feature = "tokio"))))] pub fn unix_stream(stream: UnixStream) -> Self { Self::new(Target::UnixStream(stream)) } @@ -339,29 +339,27 @@ impl<'a> Builder<'a> { async fn build_(self, executor: Executor<'static>) -> Result { let stream = match self.target { - #[cfg(not(feature = "tokio"))] - Target::UnixStream(stream) => Box::new(Async::new(stream)?) as Box, + #[cfg(all(any(unix, windows), not(feature = "tokio")))] + Target::UnixStream(stream) => Box::new(Async::new(stream)?) as _, #[cfg(all(unix, feature = "tokio"))] Target::UnixStream(stream) => Box::new(stream) as Box, - #[cfg(all(not(unix), feature = "tokio"))] - Target::UnixStream(_) => return Err(Error::Unsupported), #[cfg(not(feature = "tokio"))] - Target::TcpStream(stream) => Box::new(Async::new(stream)?) as Box, + Target::TcpStream(stream) => Box::new(Async::new(stream)?) as _, #[cfg(feature = "tokio")] - Target::TcpStream(stream) => Box::new(stream) as Box, + Target::TcpStream(stream) => Box::new(stream) as _, #[cfg(all(feature = "vsock", not(feature = "tokio")))] - Target::VsockStream(stream) => Box::new(Async::new(stream)?) as Box, + Target::VsockStream(stream) => Box::new(Async::new(stream)?) as _, #[cfg(feature = "tokio-vsock")] - Target::VsockStream(stream) => Box::new(stream) as Box, - Target::Address(address) => match address.connect().await? { - #[cfg(any(unix, not(feature = "tokio")))] - address::Stream::Unix(stream) => Box::new(stream) as Box, - address::Stream::Tcp(stream) => Box::new(stream) as Box, + Target::VsockStream(stream) => Box::new(stream) as _, + Target::Address(address) => match Stream::connect(address).await? { + #[cfg(any(unix, all(windows, not(feature = "tokio"))))] + Stream::Unix(stream) => Box::new(stream) as _, + Stream::Tcp(stream) => Box::new(stream) as _, #[cfg(any( all(feature = "vsock", not(feature = "tokio")), feature = "tokio-vsock" ))] - address::Stream::Vsock(stream) => Box::new(stream) as Box, + Stream::Vsock(stream) => Box::new(stream) as _, }, Target::Socket(stream) => stream, }; diff --git a/zbus/src/connection/raw/macos.rs b/zbus/src/connection/raw/macos.rs new file mode 100644 index 000000000..013947892 --- /dev/null +++ b/zbus/src/connection/raw/macos.rs @@ -0,0 +1,21 @@ +#![cfg(target_os = "macos")] + +use crate::{addr::DBusAddr, process::run, Error, Result}; + +pub(crate) async fn launchd_bus_address(env_key: &str) -> Result> { + let output = run("launchctl", ["getenv", env_key]) + .await + .expect("failed to wait on launchctl output"); + + if !output.status.success() { + return Err(Error::Address(format!( + "launchctl terminated with code: {}", + output.status + ))); + } + + let addr = String::from_utf8(output.stdout) + .map_err(|e| Error::Address(format!("Unable to parse launchctl output as UTF-8: {}", e)))?; + + format!("unix:path={}", addr.trim()).try_into() +} diff --git a/zbus/src/connection/raw/mod.rs b/zbus/src/connection/raw/mod.rs index 68bece35e..4083c2507 100644 --- a/zbus/src/connection/raw/mod.rs +++ b/zbus/src/connection/raw/mod.rs @@ -1,5 +1,28 @@ mod connection; mod socket; +mod stream; +pub(crate) use stream::Stream; + +mod macos; +mod win32; + pub use connection::Connection; pub use socket::Socket; + +#[cfg(not(feature = "tokio"))] +pub(crate) type TcpStream = std::net::TcpStream; +#[cfg(feature = "tokio")] +pub(crate) use tokio::net::TcpStream; + +#[cfg(all(unix, not(feature = "tokio")))] +pub(crate) type UnixStream = std::os::unix::net::UnixStream; +#[cfg(all(windows, not(feature = "tokio")))] +pub(crate) type UnixStream = uds_windows::UnixStream; +#[cfg(all(unix, feature = "tokio"))] +pub(crate) use tokio::net::UnixStream; + +#[cfg(all(feature = "vsock", not(feature = "tokio")))] +pub(crate) type VsockStream = vsock::VsockStream; +#[cfg(feature = "tokio-vsock")] +pub(crate) use tokio_vsock::VsockStream; diff --git a/zbus/src/connection/raw/stream.rs b/zbus/src/connection/raw/stream.rs new file mode 100644 index 000000000..67c79c0c7 --- /dev/null +++ b/zbus/src/connection/raw/stream.rs @@ -0,0 +1,345 @@ +#[cfg(not(feature = "tokio"))] +use async_io::Async; +use std::{ffi::OsString, future::Future, pin::Pin}; + +use crate::{ + addr::{ + transport::{NonceTcp, Tcp, TcpFamily, Transport, UnixAddrKind}, + DBusAddr, ToDBusAddrs, + }, + Error, Result, +}; + +#[cfg(any(feature = "vsock", feature = "tokio-vsock"))] +use crate::addr::transport::Vsock; + +#[cfg(target_os = "macos")] +use crate::addr::transport::Launchd; + +#[cfg(target_os = "windows")] +use crate::addr::transport::Autolaunch; + +#[cfg(all(any(unix, windows), not(feature = "tokio")))] +type UnixStream = Async; +#[cfg(all(unix, feature = "tokio"))] +use super::UnixStream; + +#[cfg(not(feature = "tokio"))] +type TcpStream = Async; +#[cfg(feature = "tokio")] +use super::TcpStream; + +#[cfg(all(feature = "vsock", not(feature = "tokio")))] +type VsockStream = Async; +#[cfg(feature = "tokio-vsock")] +use super::VsockStream; + +#[derive(Debug)] +pub(crate) enum Stream { + #[cfg(any(unix, all(windows, not(feature = "tokio"))))] + Unix(UnixStream), + Tcp(TcpStream), + #[cfg(any( + all(feature = "vsock", not(feature = "tokio")), + feature = "tokio-vsock" + ))] + Vsock(VsockStream), +} + +async fn tcp_stream_connect(host: &str, port: u16, family: Option) -> Result { + #[cfg(not(feature = "tokio"))] + { + use std::net::ToSocketAddrs; + + let host = host.to_string(); + let addrs = crate::Task::spawn_blocking( + move || -> Result> { + let addrs = (host, port).to_socket_addrs()?.filter(|a| { + if let Some(family) = family { + if family == TcpFamily::IPv4 { + a.is_ipv4() + } else { + a.is_ipv6() + } + } else { + true + } + }); + Ok(addrs.collect()) + }, + "connect tcp", + ) + .await + .map_err(|e| Error::Address(format!("Failed to receive TCP addresses: {e}")))?; + + // we could attempt connections in parallel? + let mut last_err = Error::Address("Failed to connect".into()); + for addr in addrs { + match TcpStream::connect(addr).await { + Ok(stream) => return Ok(stream), + Err(e) => last_err = e.into(), + } + } + + Err(last_err) + } + + #[cfg(feature = "tokio")] + { + // FIXME: doesn't handle family + let _ = family; + TcpStream::connect((host, port)) + .await + .map_err(|e| Error::InputOutput(e.into())) + } +} + +impl Stream { + async fn connect_unix(addr: &UnixAddrKind<'_>) -> Result { + let mut s = OsString::from("\0"); + + let p = match addr { + // We should construct a SocketAddr instead, but this is not supported by all APIs + // So we limit ourself to utf-8 paths + UnixAddrKind::Path(p) => p.as_ref(), + UnixAddrKind::Abstract(a) => { + s.push( + std::str::from_utf8(a) + .map_err(|_| Error::Address("Unhandled abstract path".into()))?, + ); + s.as_os_str() + } + _ => return Err(Error::Address("Address is not connectable".into())), + }; + + #[cfg(not(feature = "tokio"))] + { + #[cfg(windows)] + { + let p = p.to_os_string(); + let stream = crate::Task::spawn_blocking( + move || uds_windows::UnixStream::connect(p), + "unix stream connection", + ) + .await?; + Async::new(stream) + .map(Stream::Unix) + .map_err(|e| Error::InputOutput(e.into())) + } + + #[cfg(not(windows))] + { + UnixStream::connect(p) + .await + .map(Stream::Unix) + .map_err(|e| Error::InputOutput(e.into())) + } + } + + #[cfg(feature = "tokio")] + { + #[cfg(unix)] + { + UnixStream::connect(p) + .await + .map(Stream::Unix) + .map_err(|e| Error::InputOutput(e.into())) + } + + #[cfg(not(unix))] + { + let _ = p; + Err(Error::Unsupported) + } + } + } + + #[cfg(target_os = "macos")] + async fn connect_launchd(addr: &Launchd<'_>) -> Result { + let addr = super::macos::launchd_bus_address(addr.env()).await?; + match addr.transport()? { + Transport::Unix(t) => Self::connect_unix(t.kind()).await, + _ => Err(Error::Address(format!("Address is unsupported: {}", addr))), + } + } + + async fn connect_tcp(addr: &Tcp<'_>) -> Result { + let Some(host) = addr.host() else { + return Err(Error::Address("No host in address".into())); + }; + let Some(port) = addr.port() else { + return Err(Error::Address("No port in address".into())); + }; + + tcp_stream_connect(host, port, addr.family()) + .await + .map(Stream::Tcp) + } + + async fn connect_nonce_tcp(addr: &NonceTcp<'_>) -> Result { + let Some(host) = addr.host() else { + return Err(Error::Address("No host in address".into())); + }; + let Some(port) = addr.port() else { + return Err(Error::Address("No port in address".into())); + }; + let Some(noncefile) = addr.noncefile() else { + return Err(Error::Address("No noncefile in address".into())); + }; + + let mut stream = tcp_stream_connect(host, port, addr.family()).await?; + + #[cfg(not(feature = "tokio"))] + { + let nonce = std::fs::read(noncefile)?; + let mut nonce = &nonce[..]; + + while !nonce.is_empty() { + let len = stream + .write_with_mut(|s| std::io::Write::write(s, nonce)) + .await?; + nonce = &nonce[len..]; + } + } + + #[cfg(feature = "tokio")] + { + let nonce = tokio::fs::read(noncefile).await?; + tokio::io::AsyncWriteExt::write_all(&mut stream, &nonce).await?; + } + + Ok(Stream::Tcp(stream)) + } + + #[cfg(target_os = "windows")] + async fn connect_autolaunch(addr: &Autolaunch<'_>) -> Result { + let addr = super::win32::autolaunch_bus_address(addr.scope())?; + + if let Transport::Autolaunch(_) = addr.transport()? { + return Err(Error::Address("Recursive autolaunch: address".into())); + } + + Self::connect_addr(addr).await + } + + #[cfg(any(feature = "vsock", feature = "tokio-vsock"))] + async fn connect_vsock(addr: &Vsock<'_>) -> Result { + let Some(cid) = addr.cid() else { + return Err(Error::Address("No cid in address".into())); + }; + let Some(port) = addr.port() else { + return Err(Error::Address("No port in address".into())); + }; + + #[cfg(all(feature = "vsock", not(feature = "tokio")))] + { + let stream = crate::Task::spawn_blocking( + move || vsock::VsockStream::connect_with_cid_port(cid, port), + "connect vsock", + ) + .await + .map_err(|e| Error::Address(format!("Failed to connect: {e}")))?; + Async::new(stream).map(Stream::Vsock).map_err(Into::into) + } + + #[cfg(feature = "tokio-vsock")] + VsockStream::connect(cid, port) + .await + .map(Stream::Vsock) + .map_err(Into::into) + } + + fn connect_addr(addr: DBusAddr<'_>) -> Pin> + '_>> { + Box::pin(async move { + match addr.transport()? { + Transport::Unix(t) => Self::connect_unix(t.kind()).await, + #[cfg(target_os = "macos")] + Transport::Launchd(t) => Self::connect_launchd(&t).await, + Transport::Tcp(t) => Self::connect_tcp(&t).await, + Transport::NonceTcp(t) => Self::connect_nonce_tcp(&t).await, + #[cfg(target_os = "windows")] + Transport::Autolaunch(t) => Self::connect_autolaunch(&t).await, + #[cfg(any(feature = "vsock", feature = "tokio-vsock"))] + Transport::Vsock(t) => Self::connect_vsock(&t).await, + _ => Err(Error::Address(format!("Address is unsupported: {}", addr))), + } + }) + } + + pub(crate) async fn connect(addr: A) -> Result + where + A: for<'a> ToDBusAddrs<'a>, + { + let mut last_err = None; + for addr in addr.to_dbus_addrs() { + let addr = match addr { + Ok(addr) => addr, + Err(e) => { + last_err = Some(e); + continue; + } + }; + match Self::connect_addr(addr).await { + Ok(l) => return Ok(l), + Err(e) => last_err = Some(e), + } + } + Err(last_err.unwrap_or_else(|| Error::Address("Could not resolve to any addresses".into()))) + } +} + +#[cfg(test)] +mod tests { + use super::Stream; + + #[test] + fn connect_tcp() { + let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + let port = listener.local_addr().unwrap().port(); + let addr = format!("tcp:host=localhost,port={port}"); + crate::utils::block_on(async { Stream::connect(addr).await }).unwrap(); + } + + #[test] + fn connect_nonce_tcp() { + use std::io::Write; + + const TEST_COOKIE: &[u8] = b"VERILY SECRETIVE"; + + let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + let port = listener.local_addr().unwrap().port(); + + let mut cookie = tempfile::NamedTempFile::new().unwrap(); + cookie.as_file_mut().write_all(TEST_COOKIE).unwrap(); + + let mut encoded_path = String::new(); + crate::addr::encode_percents(&mut encoded_path, cookie.path().to_str().unwrap().as_ref()) + .unwrap(); + + let addr = format!("nonce-tcp:host=localhost,port={port},noncefile={encoded_path}"); + + let (sender, receiver) = std::sync::mpsc::sync_channel(1); + + std::thread::spawn(move || { + use std::io::Read; + + let mut client = listener.incoming().next().unwrap().unwrap(); + + let mut buf = [0u8; 16]; + client.read_exact(&mut buf).unwrap(); + + sender.send(buf == TEST_COOKIE).unwrap(); + }); + + crate::utils::block_on(Stream::connect(addr)).unwrap(); + + let saw_cookie = receiver + .recv_timeout(std::time::Duration::from_millis(100)) + .expect("nonce file content hasn't been received by server thread in time"); + + assert!( + saw_cookie, + "nonce file content has been received, but was invalid" + ); + } +} diff --git a/zbus/src/connection/raw/win32.rs b/zbus/src/connection/raw/win32.rs new file mode 100644 index 000000000..dd1433797 --- /dev/null +++ b/zbus/src/connection/raw/win32.rs @@ -0,0 +1,24 @@ +#![cfg(target_os = "windows")] + +use crate::{ + addr::{transport::AutolaunchScope, DBusAddr}, + win32::{read_shm, Mutex}, + Error, Result, +}; + +pub fn autolaunch_bus_address(scope: Option<&AutolaunchScope<'_>>) -> Result> { + if scope.is_some() { + return Err(Error::Address( + "autolaunch with scope isn't supported yet".into(), + )); + } + + let mutex = Mutex::new("DBusAutolaunchMutex")?; + let _guard = mutex.lock(); + + let addr = read_shm("DBusDaemonAddressInfo")?; + let addr = String::from_utf8(addr) + .map_err(|e| Error::Address(format!("Unable to parse address as UTF-8: {}", e)))?; + + addr.try_into() +} diff --git a/zbus/src/lib.rs b/zbus/src/lib.rs index 92b4a4db3..1c865db1a 100644 --- a/zbus/src/lib.rs +++ b/zbus/src/lib.rs @@ -41,20 +41,25 @@ pub use dbus_error::*; mod error; pub use error::*; +pub mod addr; + +#[doc(hidden)] pub mod address; +#[deprecated(note = "Use `addr::DBusAddress` instead")] +#[doc(hidden)] pub use address::Address; -#[deprecated(note = "Use `address::TcpAddress` instead")] +#[deprecated(note = "Use `addr::transport::Tcp` instead")] #[doc(hidden)] pub use address::TcpAddress; -#[deprecated(note = "Use `address::TcpAddressFamily` instead")] +#[deprecated(note = "Use `addr::transport::TcpFamily` instead")] #[doc(hidden)] pub use address::TcpAddressFamily; #[cfg(any( all(feature = "vsock", not(feature = "tokio")), feature = "tokio-vsock" ))] -#[deprecated(note = "Use `address::VsockAddress` instead")] +#[deprecated(note = "Use `addr::transport::Vsock` instead")] #[doc(hidden)] pub use address::VsockAddress; diff --git a/zbus/src/proxy/mod.rs b/zbus/src/proxy/mod.rs index d4b2fde5d..dfefd3d9f 100644 --- a/zbus/src/proxy/mod.rs +++ b/zbus/src/proxy/mod.rs @@ -443,9 +443,7 @@ impl PropertiesCache { } trace!("Property `{interface}.{property_name}` updated"); - let entry = values - .entry(property_name.to_string()) - .or_insert_with(PropertyValue::default); + let entry = values.entry(property_name.to_string()).or_default(); entry.value = Some(OwnedValue::from(value)); entry.event.notify(usize::MAX); diff --git a/zbus/src/win32.rs b/zbus/src/win32.rs index cbbd65a35..5b3f33162 100644 --- a/zbus/src/win32.rs +++ b/zbus/src/win32.rs @@ -26,7 +26,6 @@ use winapi::{ }, }; -use crate::Address; #[cfg(not(feature = "tokio"))] use uds_windows::UnixStream; @@ -51,7 +50,7 @@ impl Drop for OwnedHandle { } } -struct Mutex(OwnedHandle); +pub(crate) struct Mutex(OwnedHandle); impl Mutex { pub fn new(name: &str) -> Result { @@ -66,7 +65,7 @@ impl Mutex { Ok(Self(unsafe { OwnedHandle::new(handle) })) } - pub fn lock(&self) -> MutexGuard<'_> { + pub(crate) fn lock(&self) -> MutexGuard<'_> { match unsafe { WaitForSingleObject(self.0.get(), INFINITE) } { WAIT_ABANDONED | WAIT_OBJECT_0 => MutexGuard(self), err => panic!("WaitForSingleObject() failed: return code {}", err), @@ -74,7 +73,7 @@ impl Mutex { } } -struct MutexGuard<'a>(&'a Mutex); +pub(crate) struct MutexGuard<'a>(&'a Mutex); impl Drop for MutexGuard<'_> { fn drop(&mut self) { @@ -272,7 +271,7 @@ pub fn unix_stream_get_peer_pid(stream: &UnixStream) -> Result { Ok(ret) } -fn read_shm(name: &str) -> Result, crate::Error> { +pub(crate) fn read_shm(name: &str) -> Result, crate::Error> { let handle = { let wide_name = OsStr::new(name) .encode_wide() @@ -301,17 +300,6 @@ fn read_shm(name: &str) -> Result, crate::Error> { Ok(data.to_bytes().to_owned()) } -pub fn windows_autolaunch_bus_address() -> Result { - let mutex = Mutex::new("DBusAutolaunchMutex")?; - let _guard = mutex.lock(); - - let addr = read_shm("DBusDaemonAddressInfo")?; - let addr = String::from_utf8(addr) - .map_err(|e| crate::Error::Address(format!("Unable to parse address as UTF-8: {}", e)))?; - - addr.parse() -} - #[cfg(test)] mod tests { use super::*;