Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow splitting client to support concurrent read/publish #27

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
272 changes: 171 additions & 101 deletions src/client/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,54 @@ pub struct Client {
///
/// This field uses a Mutex for interior mutability so that
/// `Client` is `Send`. It's not expected to be `Sync`.
free_write_pids: Mutex<FreePidList>,
free_write_pids: Arc<Mutex<FreePidList>>,
}

/// A clonable structure which can be used to publish messages from multiple
/// threads concurrently.
///
/// Publisher instances are invalidated when the client disconnects (after retrying).
/// After calling `connect` again, publishers will have to be recreated by calling `Client::publisher`.
///
/// ```no_run
/// use mqtt_async_client::client:: {
/// Client,
/// Publish
/// };
///
/// let client =
/// Client::builder()
/// .set_url_string("mqtt://example.com").unwrap()
/// .build()
/// .unwrap();
///
/// let publisher1 = client.publisher().unwrap();
/// let publisher2 = client.publisher().unwrap(); // Or use publisher1.clone()
///
/// tokio::spawn(async move {
/// publisher1.publish(&Publish::new(
/// String::from("topic"),
/// vec![0u8, 1u8]
/// ));
/// });
///
/// tokio::spawn(async move {
/// publisher2.publish(&Publish::new(
/// String::from("topic"),
/// vec![5u8, 8u8]
/// ));
/// });
/// ```
#[derive(Clone)]
pub struct ClientPublisher {
/// Sender to send IO requests to the IO task.
tx_io_requests: mpsc::Sender<IoRequest>,

/// Time to wait for a response before timing out
operation_timeout: Duration,

/// Tracks which Pids (MQTT packet IDs) are in use.
free_write_pids: Arc<Mutex<FreePidList>>,
}

impl fmt::Debug for Client {
Expand Down Expand Up @@ -151,7 +198,7 @@ impl fmt::Debug for ClientOptions {
/// The client side of the communication channels to an IO task.
struct IoTaskHandle {
/// Sender to send IO requests to the IO task.
tx_io_requests: mpsc::Sender<IoRequest>,
publisher: ClientPublisher,

/// Receiver to receive Publish packets from the IO task.
rx_recv_published: mpsc::Receiver<Result<Packet>>,
Expand Down Expand Up @@ -257,7 +304,7 @@ impl Client {
Ok(Client {
options: opts,
io_task_handle: None,
free_write_pids: Mutex::new(FreePidList::new()),
free_write_pids: Arc::new(Mutex::new(FreePidList::new())),
})
}

Expand All @@ -276,7 +323,11 @@ impl Client {
mpsc::channel::<Result<Packet>>(self.options.packet_buffer_len);
let halt = Arc::new(AtomicBool::new(false));
self.io_task_handle = Some(IoTaskHandle {
tx_io_requests,
publisher: ClientPublisher {
tx_io_requests,
operation_timeout: self.options.operation_timeout,
free_write_pids: self.free_write_pids.clone()
},
rx_recv_published,
halt: halt.clone(),
});
Expand All @@ -298,56 +349,22 @@ impl Client {
/// create several publish futures to publish several payloads of
/// data simultaneously without waiting for responses.
pub async fn publish(&self, p: &Publish) -> Result<()> {
let qos = p.qos();
if qos == QoS::ExactlyOnce {
return Err("QoS::ExactlyOnce is not supported".into());
}
let p2 = Packet::Publish(mqttrs::Publish {
dup: false, // TODO.
qospid: match qos {
QoS::AtMostOnce => QosPid::AtMostOnce,
QoS::AtLeastOnce => QosPid::AtLeastOnce(self.alloc_write_pid()?),
QoS::ExactlyOnce => panic!("Not reached"),
},
retain: p.retain(),
topic_name: p.topic().to_owned(),
payload: p.payload().to_owned(),
});
match qos {
QoS::AtMostOnce => {
let res = timeout(self.options.operation_timeout,
self.write_only_packet(&p2)).await;
if let Err(Elapsed { .. }) = res {
return Err(format!("Timeout writing publish after {}ms",
self.options.operation_timeout.as_millis()).into());
}
res.expect("No timeout")?;
}
QoS::AtLeastOnce => {
let res = timeout(self.options.operation_timeout,
self.write_response_packet(&p2)).await;
if let Err(Elapsed { .. }) = res {
// We report this but can't really deal with it properly.
// The protocol says we can't re-use the packet ID so we have to leak it
// and potentially run out of packet IDs.
return Err(format!("Timeout waiting for Puback after {}ms",
self.options.operation_timeout.as_millis()).into());
}
let res = res.expect("No timeout")?;
match res {
Packet::Puback(pid) => self.free_write_pid(pid)?,
_ => error!("Bad packet response for publish: {:#?}", res),
}
},
QoS::ExactlyOnce => panic!("Not reached"),
};
Ok(())
let c = self.check_io_task()?;
c.publisher.publish(p).await
}

/// Creates a `ClientPublisher` which can be cloned and used to publish messages
/// from multiple threads concurrently.
pub fn publisher(&self) -> Result<ClientPublisher> {
let c = self.check_io_task()?;
Ok(c.publisher.clone())
}

/// Subscribe to some topics.`read_subscriptions` will return
/// data for them.
pub async fn subscribe(&mut self, s: Subscribe) -> Result<SubscribeResult> {
let pid = self.alloc_write_pid()?;
let c = self.check_io_task()?;
let pid = c.publisher.alloc_write_pid()?;
// TODO: Support subscribe to qos == ExactlyOnce.
if s.topics().iter().any(|t| t.qos == QoS::ExactlyOnce) {
return Err("Qos::ExactlyOnce is not supported right now".into())
Expand All @@ -356,7 +373,7 @@ impl Client {
pid,
topics: s.topics().to_owned(),
});
let res = timeout(self.options.operation_timeout, self.write_response_packet(&p)).await;
let res = timeout(self.options.operation_timeout, c.publisher.write_response_packet(&p)).await;
if let Err(Elapsed { .. }) = res {
// We report this but can't really deal with it properly.
// The protocol says we can't re-use the packet ID so we have to leak it
Expand All @@ -370,7 +387,7 @@ impl Client {
pid: suback_pid,
return_codes: rcs,
}) if suback_pid == pid => {
self.free_write_pid(pid)?;
c.publisher.free_write_pid(pid)?;
Ok(SubscribeResult {
return_codes: rcs
})
Expand All @@ -386,13 +403,14 @@ impl Client {
/// Unsubscribe from some topics. `read_subscriptions` will no
/// longer return data for them.
pub async fn unsubscribe(&mut self, u: Unsubscribe) -> Result<()> {
let pid = self.alloc_write_pid()?;
let c = self.check_io_task()?;
let pid = c.publisher.alloc_write_pid()?;
let p = Packet::Unsubscribe(mqttrs::Unsubscribe {
pid,
topics: u.topics().iter().map(|ut| ut.topic_name().to_owned())
.collect::<Vec<String>>(),
});
let res = timeout(self.options.operation_timeout, self.write_response_packet(&p)).await;
let res = timeout(self.options.operation_timeout, c.publisher.write_response_packet(&p)).await;
if let Err(Elapsed { .. }) = res {
// We report this but can't really deal with it properly.
// The protocol says we can't re-use the packet ID so we have to leak it
Expand All @@ -404,7 +422,7 @@ impl Client {
match res {
Packet::Unsuback(ack_pid)
if ack_pid == pid => {
self.free_write_pid(pid)?;
c.publisher.free_write_pid(pid)?;
Ok(())
},
_ => {
Expand All @@ -431,7 +449,7 @@ impl Client {
match p.qospid {
QosPid::AtMostOnce => (),
QosPid::AtLeastOnce(pid) => {
self.write_only_packet(&Packet::Puback(pid)).await?;
h.publisher.write_only_packet(&Packet::Puback(pid)).await?;
},
QosPid::ExactlyOnce(_) => {
error!("Received publish with unimplemented QoS: ExactlyOnce");
Expand All @@ -451,11 +469,11 @@ impl Client {

/// Gracefully close the connection to the server.
pub async fn disconnect(&mut self) -> Result<()> {
self.check_io_task()?;
let c = self.check_io_task()?;
debug!("Disconnecting");
let p = Packet::Disconnect;
let res = timeout(self.options.operation_timeout,
self.write_only_packet(&p)).await;
c.publisher.write_only_packet(&p)).await;
if let Err(Elapsed { .. }) = res {
return Err(format!("Timeout waiting for Disconnect to send after {}ms",
self.options.operation_timeout.as_millis()).into());
Expand All @@ -465,57 +483,14 @@ impl Client {
Ok(())
}

fn alloc_write_pid(&self) -> Result<Pid> {
match self.free_write_pids.lock().expect("not poisoned").alloc() {
Some(pid) => Ok(Pid::try_from(pid).expect("Non-zero Pid")),
None => Err(Error::from("No free Pids")),
}
}

fn free_write_pid(&self, p: Pid) -> Result<()> {
match self.free_write_pids.lock().expect("not poisoned").free(p.get()) {
true => Err(Error::from("Pid was already free")),
false => Ok(())
}
}

async fn shutdown(&mut self) -> Result <()> {
let c = self.check_io_task()?;
c.halt.store(true, Ordering::SeqCst);
self.write_request(IoType::ShutdownConnection, None).await?;
c.publisher.write_request(IoType::ShutdownConnection, None).await?;
self.io_task_handle = None;
Ok(())
}

async fn write_only_packet(&self, p: &Packet) -> Result<()> {
self.write_request(IoType::WriteOnly { packet: p.clone(), }, None)
.await.map(|_v| ())

}

async fn write_response_packet(&self, p: &Packet) -> Result<Packet> {
let io_type = IoType::WriteAndResponse {
packet: p.clone(),
response_pid: packet_pid(p).expect("packet_pid"),
};
let (tx, rx) = oneshot::channel::<IoResult>();
self.write_request(io_type, Some(tx))
.await?;
// TODO: Add a timeout?
let res = rx.await.map_err(Error::from_std_err)?;
res.result.map(|v| v.expect("return packet"))
}

async fn write_request(&self, io_type: IoType, tx_result: Option<oneshot::Sender<IoResult>>) -> Result<()> {
// NB: Some duplication in IoTask::replay_subscriptions.

let c = self.check_io_task()?;
let req = IoRequest { tx_result, io_type };
c.tx_io_requests.clone().send(req).await
.map_err(|e| Error::from_std_err(e))?;
Ok(())
}

fn check_io_task_mut(&mut self) -> Result<&mut IoTaskHandle> {
match self.io_task_handle {
Some(ref mut h) => Ok(h),
Expand Down Expand Up @@ -630,6 +605,101 @@ async fn connect_stream(opts: &ClientOptions) -> Result<AsyncStream> {
}
}

impl ClientPublisher {
/// Publish some data on a topic.
///
/// Note that this method takes `&self`. This means a caller can
/// create several publish futures to publish several payloads of
/// data simultaneously without waiting for responses.
pub async fn publish(&self, p: &Publish) -> Result<()> {
let qos = p.qos();
if qos == QoS::ExactlyOnce {
return Err("QoS::ExactlyOnce is not supported".into());
}
let p2 = Packet::Publish(mqttrs::Publish {
dup: false, // TODO.
qospid: match qos {
QoS::AtMostOnce => QosPid::AtMostOnce,
QoS::AtLeastOnce => QosPid::AtLeastOnce(self.alloc_write_pid()?),
QoS::ExactlyOnce => panic!("Not reached"),
},
retain: p.retain(),
topic_name: p.topic().to_owned(),
payload: p.payload().to_owned(),
});
match qos {
QoS::AtMostOnce => {
let res = timeout(self.operation_timeout,
self.write_only_packet(&p2)).await;
if let Err(Elapsed { .. }) = res {
return Err(format!("Timeout writing publish after {}ms",
self.operation_timeout.as_millis()).into());
}
res.expect("No timeout")?;
}
QoS::AtLeastOnce => {
let res = timeout(self.operation_timeout,
self.write_response_packet(&p2)).await;
if let Err(Elapsed { .. }) = res {
// We report this but can't really deal with it properly.
// The protocol says we can't re-use the packet ID so we have to leak it
// and potentially run out of packet IDs.
return Err(format!("Timeout waiting for Puback after {}ms",
self.operation_timeout.as_millis()).into());
}
let res = res.expect("No timeout")?;
match res {
Packet::Puback(pid) => self.free_write_pid(pid)?,
_ => error!("Bad packet response for publish: {:#?}", res),
}
},
QoS::ExactlyOnce => panic!("Not reached"),
};
Ok(())
}

async fn write_only_packet(&self, p: &Packet) -> Result<()> {
self.write_request(IoType::WriteOnly { packet: p.clone(), }, None)
.await.map(|_v| ())

}

async fn write_response_packet(&self, p: &Packet) -> Result<Packet> {
let io_type = IoType::WriteAndResponse {
packet: p.clone(),
response_pid: packet_pid(p).expect("packet_pid"),
};
let (tx, rx) = oneshot::channel::<IoResult>();
self.write_request(io_type, Some(tx))
.await?;
// TODO: Add a timeout?
let res = rx.await.map_err(Error::from_std_err)?;
res.result.map(|v| v.expect("return packet"))
}

async fn write_request(&self, io_type: IoType, tx_result: Option<oneshot::Sender<IoResult>>) -> Result<()> {
// NB: Some duplication in IoTask::replay_subscriptions.
let req = IoRequest { tx_result, io_type };
self.tx_io_requests.clone().send(req).await
.map_err(|e| Error::from_std_err(e))?;
Ok(())
}

fn alloc_write_pid(&self) -> Result<Pid> {
match self.free_write_pids.lock().expect("not poisoned").alloc() {
Some(pid) => Ok(Pid::try_from(pid).expect("Non-zero Pid")),
None => Err(Error::from("No free Pids")),
}
}

fn free_write_pid(&self, p: Pid) -> Result<()> {
match self.free_write_pids.lock().expect("not poisoned").free(p.get()) {
true => Err(Error::from("Pid was already free")),
false => Ok(())
}
}
}

/// Build a connect packet from ClientOptions.
fn connect_packet(opts: &ClientOptions) -> Result<Packet> {
Ok(Packet::Connect(mqttrs::Connect {
Expand Down
5 changes: 4 additions & 1 deletion src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ mod builder;
pub use builder::ClientBuilder;

mod client;
pub use client::Client;
pub use client::{
Client,
ClientPublisher
};
pub(crate) use client::ClientOptions;

mod value_types;
Expand Down