diff --git a/src/client/client.rs b/src/client/client.rs index 10e8157..9106a96 100644 --- a/src/client/client.rs +++ b/src/client/client.rs @@ -102,7 +102,54 @@ pub struct Client { /// /// This field uses a Mutex for interior mutability so that /// `Client` is `Send`. It's not expected to be `Sync`. - free_write_pids: Mutex, + free_write_pids: Arc>, +} + +/// A clonable structure which can be used to publish messages from multiple +/// threads concurrently. +/// +/// Publisher instances are invalidated when the client disconnects (after retrying). +/// After calling `connect` again, publishers will have to be recreated by calling `Client::publisher`. +/// +/// ```no_run +/// use mqtt_async_client::client:: { +/// Client, +/// Publish +/// }; +/// +/// let client = +/// Client::builder() +/// .set_url_string("mqtt://example.com").unwrap() +/// .build() +/// .unwrap(); +/// +/// let publisher1 = client.publisher().unwrap(); +/// let publisher2 = client.publisher().unwrap(); // Or use publisher1.clone() +/// +/// tokio::spawn(async move { +/// publisher1.publish(&Publish::new( +/// String::from("topic"), +/// vec![0u8, 1u8] +/// )); +/// }); +/// +/// tokio::spawn(async move { +/// publisher2.publish(&Publish::new( +/// String::from("topic"), +/// vec![5u8, 8u8] +/// )); +/// }); +/// ``` +#[derive(Clone)] +pub struct ClientPublisher { + /// Sender to send IO requests to the IO task. + tx_io_requests: mpsc::Sender, + + /// Time to wait for a response before timing out + operation_timeout: Duration, + + /// Tracks which Pids (MQTT packet IDs) are in use. + free_write_pids: Arc>, } impl fmt::Debug for Client { @@ -151,7 +198,7 @@ impl fmt::Debug for ClientOptions { /// The client side of the communication channels to an IO task. struct IoTaskHandle { /// Sender to send IO requests to the IO task. - tx_io_requests: mpsc::Sender, + publisher: ClientPublisher, /// Receiver to receive Publish packets from the IO task. rx_recv_published: mpsc::Receiver>, @@ -257,7 +304,7 @@ impl Client { Ok(Client { options: opts, io_task_handle: None, - free_write_pids: Mutex::new(FreePidList::new()), + free_write_pids: Arc::new(Mutex::new(FreePidList::new())), }) } @@ -276,7 +323,11 @@ impl Client { mpsc::channel::>(self.options.packet_buffer_len); let halt = Arc::new(AtomicBool::new(false)); self.io_task_handle = Some(IoTaskHandle { - tx_io_requests, + publisher: ClientPublisher { + tx_io_requests, + operation_timeout: self.options.operation_timeout, + free_write_pids: self.free_write_pids.clone() + }, rx_recv_published, halt: halt.clone(), }); @@ -298,56 +349,22 @@ impl Client { /// create several publish futures to publish several payloads of /// data simultaneously without waiting for responses. pub async fn publish(&self, p: &Publish) -> Result<()> { - let qos = p.qos(); - if qos == QoS::ExactlyOnce { - return Err("QoS::ExactlyOnce is not supported".into()); - } - let p2 = Packet::Publish(mqttrs::Publish { - dup: false, // TODO. - qospid: match qos { - QoS::AtMostOnce => QosPid::AtMostOnce, - QoS::AtLeastOnce => QosPid::AtLeastOnce(self.alloc_write_pid()?), - QoS::ExactlyOnce => panic!("Not reached"), - }, - retain: p.retain(), - topic_name: p.topic().to_owned(), - payload: p.payload().to_owned(), - }); - match qos { - QoS::AtMostOnce => { - let res = timeout(self.options.operation_timeout, - self.write_only_packet(&p2)).await; - if let Err(Elapsed { .. }) = res { - return Err(format!("Timeout writing publish after {}ms", - self.options.operation_timeout.as_millis()).into()); - } - res.expect("No timeout")?; - } - QoS::AtLeastOnce => { - let res = timeout(self.options.operation_timeout, - self.write_response_packet(&p2)).await; - if let Err(Elapsed { .. }) = res { - // We report this but can't really deal with it properly. - // The protocol says we can't re-use the packet ID so we have to leak it - // and potentially run out of packet IDs. - return Err(format!("Timeout waiting for Puback after {}ms", - self.options.operation_timeout.as_millis()).into()); - } - let res = res.expect("No timeout")?; - match res { - Packet::Puback(pid) => self.free_write_pid(pid)?, - _ => error!("Bad packet response for publish: {:#?}", res), - } - }, - QoS::ExactlyOnce => panic!("Not reached"), - }; - Ok(()) + let c = self.check_io_task()?; + c.publisher.publish(p).await + } + + /// Creates a `ClientPublisher` which can be cloned and used to publish messages + /// from multiple threads concurrently. + pub fn publisher(&self) -> Result { + let c = self.check_io_task()?; + Ok(c.publisher.clone()) } /// Subscribe to some topics.`read_subscriptions` will return /// data for them. pub async fn subscribe(&mut self, s: Subscribe) -> Result { - let pid = self.alloc_write_pid()?; + let c = self.check_io_task()?; + let pid = c.publisher.alloc_write_pid()?; // TODO: Support subscribe to qos == ExactlyOnce. if s.topics().iter().any(|t| t.qos == QoS::ExactlyOnce) { return Err("Qos::ExactlyOnce is not supported right now".into()) @@ -356,7 +373,7 @@ impl Client { pid, topics: s.topics().to_owned(), }); - let res = timeout(self.options.operation_timeout, self.write_response_packet(&p)).await; + let res = timeout(self.options.operation_timeout, c.publisher.write_response_packet(&p)).await; if let Err(Elapsed { .. }) = res { // We report this but can't really deal with it properly. // The protocol says we can't re-use the packet ID so we have to leak it @@ -370,7 +387,7 @@ impl Client { pid: suback_pid, return_codes: rcs, }) if suback_pid == pid => { - self.free_write_pid(pid)?; + c.publisher.free_write_pid(pid)?; Ok(SubscribeResult { return_codes: rcs }) @@ -386,13 +403,14 @@ impl Client { /// Unsubscribe from some topics. `read_subscriptions` will no /// longer return data for them. pub async fn unsubscribe(&mut self, u: Unsubscribe) -> Result<()> { - let pid = self.alloc_write_pid()?; + let c = self.check_io_task()?; + let pid = c.publisher.alloc_write_pid()?; let p = Packet::Unsubscribe(mqttrs::Unsubscribe { pid, topics: u.topics().iter().map(|ut| ut.topic_name().to_owned()) .collect::>(), }); - let res = timeout(self.options.operation_timeout, self.write_response_packet(&p)).await; + let res = timeout(self.options.operation_timeout, c.publisher.write_response_packet(&p)).await; if let Err(Elapsed { .. }) = res { // We report this but can't really deal with it properly. // The protocol says we can't re-use the packet ID so we have to leak it @@ -404,7 +422,7 @@ impl Client { match res { Packet::Unsuback(ack_pid) if ack_pid == pid => { - self.free_write_pid(pid)?; + c.publisher.free_write_pid(pid)?; Ok(()) }, _ => { @@ -431,7 +449,7 @@ impl Client { match p.qospid { QosPid::AtMostOnce => (), QosPid::AtLeastOnce(pid) => { - self.write_only_packet(&Packet::Puback(pid)).await?; + h.publisher.write_only_packet(&Packet::Puback(pid)).await?; }, QosPid::ExactlyOnce(_) => { error!("Received publish with unimplemented QoS: ExactlyOnce"); @@ -451,11 +469,11 @@ impl Client { /// Gracefully close the connection to the server. pub async fn disconnect(&mut self) -> Result<()> { - self.check_io_task()?; + let c = self.check_io_task()?; debug!("Disconnecting"); let p = Packet::Disconnect; let res = timeout(self.options.operation_timeout, - self.write_only_packet(&p)).await; + c.publisher.write_only_packet(&p)).await; if let Err(Elapsed { .. }) = res { return Err(format!("Timeout waiting for Disconnect to send after {}ms", self.options.operation_timeout.as_millis()).into()); @@ -465,57 +483,14 @@ impl Client { Ok(()) } - fn alloc_write_pid(&self) -> Result { - match self.free_write_pids.lock().expect("not poisoned").alloc() { - Some(pid) => Ok(Pid::try_from(pid).expect("Non-zero Pid")), - None => Err(Error::from("No free Pids")), - } - } - - fn free_write_pid(&self, p: Pid) -> Result<()> { - match self.free_write_pids.lock().expect("not poisoned").free(p.get()) { - true => Err(Error::from("Pid was already free")), - false => Ok(()) - } - } - async fn shutdown(&mut self) -> Result <()> { let c = self.check_io_task()?; c.halt.store(true, Ordering::SeqCst); - self.write_request(IoType::ShutdownConnection, None).await?; + c.publisher.write_request(IoType::ShutdownConnection, None).await?; self.io_task_handle = None; Ok(()) } - async fn write_only_packet(&self, p: &Packet) -> Result<()> { - self.write_request(IoType::WriteOnly { packet: p.clone(), }, None) - .await.map(|_v| ()) - - } - - async fn write_response_packet(&self, p: &Packet) -> Result { - let io_type = IoType::WriteAndResponse { - packet: p.clone(), - response_pid: packet_pid(p).expect("packet_pid"), - }; - let (tx, rx) = oneshot::channel::(); - self.write_request(io_type, Some(tx)) - .await?; - // TODO: Add a timeout? - let res = rx.await.map_err(Error::from_std_err)?; - res.result.map(|v| v.expect("return packet")) - } - - async fn write_request(&self, io_type: IoType, tx_result: Option>) -> Result<()> { - // NB: Some duplication in IoTask::replay_subscriptions. - - let c = self.check_io_task()?; - let req = IoRequest { tx_result, io_type }; - c.tx_io_requests.clone().send(req).await - .map_err(|e| Error::from_std_err(e))?; - Ok(()) - } - fn check_io_task_mut(&mut self) -> Result<&mut IoTaskHandle> { match self.io_task_handle { Some(ref mut h) => Ok(h), @@ -630,6 +605,101 @@ async fn connect_stream(opts: &ClientOptions) -> Result { } } +impl ClientPublisher { + /// Publish some data on a topic. + /// + /// Note that this method takes `&self`. This means a caller can + /// create several publish futures to publish several payloads of + /// data simultaneously without waiting for responses. + pub async fn publish(&self, p: &Publish) -> Result<()> { + let qos = p.qos(); + if qos == QoS::ExactlyOnce { + return Err("QoS::ExactlyOnce is not supported".into()); + } + let p2 = Packet::Publish(mqttrs::Publish { + dup: false, // TODO. + qospid: match qos { + QoS::AtMostOnce => QosPid::AtMostOnce, + QoS::AtLeastOnce => QosPid::AtLeastOnce(self.alloc_write_pid()?), + QoS::ExactlyOnce => panic!("Not reached"), + }, + retain: p.retain(), + topic_name: p.topic().to_owned(), + payload: p.payload().to_owned(), + }); + match qos { + QoS::AtMostOnce => { + let res = timeout(self.operation_timeout, + self.write_only_packet(&p2)).await; + if let Err(Elapsed { .. }) = res { + return Err(format!("Timeout writing publish after {}ms", + self.operation_timeout.as_millis()).into()); + } + res.expect("No timeout")?; + } + QoS::AtLeastOnce => { + let res = timeout(self.operation_timeout, + self.write_response_packet(&p2)).await; + if let Err(Elapsed { .. }) = res { + // We report this but can't really deal with it properly. + // The protocol says we can't re-use the packet ID so we have to leak it + // and potentially run out of packet IDs. + return Err(format!("Timeout waiting for Puback after {}ms", + self.operation_timeout.as_millis()).into()); + } + let res = res.expect("No timeout")?; + match res { + Packet::Puback(pid) => self.free_write_pid(pid)?, + _ => error!("Bad packet response for publish: {:#?}", res), + } + }, + QoS::ExactlyOnce => panic!("Not reached"), + }; + Ok(()) + } + + async fn write_only_packet(&self, p: &Packet) -> Result<()> { + self.write_request(IoType::WriteOnly { packet: p.clone(), }, None) + .await.map(|_v| ()) + + } + + async fn write_response_packet(&self, p: &Packet) -> Result { + let io_type = IoType::WriteAndResponse { + packet: p.clone(), + response_pid: packet_pid(p).expect("packet_pid"), + }; + let (tx, rx) = oneshot::channel::(); + self.write_request(io_type, Some(tx)) + .await?; + // TODO: Add a timeout? + let res = rx.await.map_err(Error::from_std_err)?; + res.result.map(|v| v.expect("return packet")) + } + + async fn write_request(&self, io_type: IoType, tx_result: Option>) -> Result<()> { + // NB: Some duplication in IoTask::replay_subscriptions. + let req = IoRequest { tx_result, io_type }; + self.tx_io_requests.clone().send(req).await + .map_err(|e| Error::from_std_err(e))?; + Ok(()) + } + + fn alloc_write_pid(&self) -> Result { + match self.free_write_pids.lock().expect("not poisoned").alloc() { + Some(pid) => Ok(Pid::try_from(pid).expect("Non-zero Pid")), + None => Err(Error::from("No free Pids")), + } + } + + fn free_write_pid(&self, p: Pid) -> Result<()> { + match self.free_write_pids.lock().expect("not poisoned").free(p.get()) { + true => Err(Error::from("Pid was already free")), + false => Ok(()) + } + } +} + /// Build a connect packet from ClientOptions. fn connect_packet(opts: &ClientOptions) -> Result { Ok(Packet::Connect(mqttrs::Connect { diff --git a/src/client/mod.rs b/src/client/mod.rs index 2138232..d9c7446 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -4,7 +4,10 @@ mod builder; pub use builder::ClientBuilder; mod client; -pub use client::Client; +pub use client::{ + Client, + ClientPublisher +}; pub(crate) use client::ClientOptions; mod value_types;