diff --git a/async-nats/src/lib.rs b/async-nats/src/lib.rs index ccda935d0..44086bec9 100644 --- a/async-nats/src/lib.rs +++ b/async-nats/src/lib.rs @@ -128,7 +128,7 @@ use tracing::{debug, error}; use core::fmt; use std::collections::HashMap; -use std::fmt::Display; +use std::fmt::{Display, Formatter}; use std::future::Future; use std::iter; use std::mem; @@ -143,7 +143,8 @@ use tokio::time::{interval, Duration, Interval, MissedTickBehavior}; use url::{Host, Url}; use bytes::Bytes; -use serde::{Deserialize, Serialize}; +use serde::de::{self, Visitor}; +use serde::{Deserialize, Deserializer, Serialize}; use serde_repr::{Deserialize_repr, Serialize_repr}; use tokio::io; use tokio::sync::mpsc; @@ -1357,6 +1358,32 @@ pub trait ToServerAddrs { fn to_server_addrs(&self) -> io::Result; } +struct ServerAddrVisitor; + +impl<'de> Visitor<'de> for ServerAddrVisitor { + type Value = ServerAddr; + + fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result { + formatter.write_str("a valid NATS server address") + } + + fn visit_str(self, v: &str) -> Result + where + E: de::Error, + { + ServerAddr::from_str(v).map_err(|e| de::Error::custom(e.to_string())) + } +} + +impl<'de> Deserialize<'de> for ServerAddr { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_str(ServerAddrVisitor) + } +} + impl ToServerAddrs for ServerAddr { type Iter = option::IntoIter; fn to_server_addrs(&self) -> io::Result { @@ -1441,4 +1468,24 @@ mod tests { let address = ServerAddr::from_str("nats://example.com").unwrap(); assert_eq!(address.host(), "example.com") } + + #[test] + fn deserialize_valid_server_address() { + let serialized = "\"nats://example.com\""; + let address = ServerAddr::from_str("nats://example.com").unwrap(); + + assert_eq!( + serde_json::from_str::(serialized).unwrap(), + address + ); + } + + #[test] + fn deserialize_invalid_server_address() { + let serialized = "\"this is not the address\""; + let result = serde_json::from_str::(serialized); + + assert!(result.is_err()); + assert!(format!("{}", result.unwrap_err()).contains("NATS server URL is invalid")); + } }