From 2ad717e797034828a731600ebea9bdfc72bc0935 Mon Sep 17 00:00:00 2001 From: Erin Power Date: Sat, 20 Nov 2021 16:24:42 +0100 Subject: [PATCH] Add serde support for Value --- pbjson-types/build.rs | 6 +- pbjson-types/src/lib.rs | 4 + pbjson-types/src/list_value.rs | 27 +++ pbjson-types/src/struct.rs | 48 +++++ pbjson-types/src/value.rs | 317 +++++++++++++++++++++++++++++++++ 5 files changed, 401 insertions(+), 1 deletion(-) create mode 100644 pbjson-types/src/list_value.rs create mode 100644 pbjson-types/src/struct.rs create mode 100644 pbjson-types/src/value.rs diff --git a/pbjson-types/build.rs b/pbjson-types/build.rs index 696e648..9bc0655 100644 --- a/pbjson-types/build.rs +++ b/pbjson-types/build.rs @@ -28,7 +28,11 @@ fn main() -> Result<()> { let descriptor_set = std::fs::read(descriptor_path)?; pbjson_build::Builder::new() .register_descriptors(&descriptor_set)? - .exclude([".google.protobuf.Duration", ".google.protobuf.Timestamp"]) + .exclude([ + ".google.protobuf.Duration", + ".google.protobuf.Timestamp", + ".google.protobuf.Value", + ]) .build(&[".google"])?; Ok(()) diff --git a/pbjson-types/src/lib.rs b/pbjson-types/src/lib.rs index e073a8e..46ae5a9 100644 --- a/pbjson-types/src/lib.rs +++ b/pbjson-types/src/lib.rs @@ -24,6 +24,7 @@ clippy::redundant_closure, clippy::redundant_field_names, clippy::clone_on_ref_ptr, + clippy::enum_variant_names, clippy::use_self )] mod pb { @@ -36,6 +37,9 @@ mod pb { } mod duration; +mod list_value; +mod r#struct; mod timestamp; +mod value; pub use pb::google::protobuf::*; diff --git a/pbjson-types/src/list_value.rs b/pbjson-types/src/list_value.rs new file mode 100644 index 0000000..c80ed6d --- /dev/null +++ b/pbjson-types/src/list_value.rs @@ -0,0 +1,27 @@ +impl From> for crate::ListValue { + fn from(values: Vec) -> Self { + Self { values } + } +} + +impl FromIterator for crate::ListValue { + fn from_iter(iter: T) -> Self + where + T: IntoIterator, + { + Self { + values: iter.into_iter().map(Into::into).collect(), + } + } +} + +impl FromIterator for crate::ListValue { + fn from_iter(iter: T) -> Self + where + T: IntoIterator, + { + Self { + values: iter.into_iter().collect(), + } + } +} diff --git a/pbjson-types/src/struct.rs b/pbjson-types/src/struct.rs new file mode 100644 index 0000000..52e7b5a --- /dev/null +++ b/pbjson-types/src/struct.rs @@ -0,0 +1,48 @@ +impl From> for crate::Struct { + fn from(fields: std::collections::HashMap) -> Self { + Self { fields } + } +} + +impl FromIterator<(String, crate::Value)> for crate::Struct { + fn from_iter(iter: T) -> Self + where + T: IntoIterator, + { + Self { + fields: iter.into_iter().collect(), + } + } +} + +#[cfg(test)] +mod tests { + #[test] + fn it_works() { + let map = std::collections::HashMap::from([ + (String::from("bool"), true.into()), + (String::from("unit"), crate::value::Kind::NullValue(0)), + (String::from("number"), 5.0.into()), + (String::from("string"), "string".into()), + (String::from("list"), vec![1.0.into(), 2.0.into()].into()), + ( + String::from("map"), + std::collections::HashMap::from([(String::from("key"), "value".into())]).into(), + ), + ]); + + assert_eq!( + serde_json::to_value(map).unwrap(), + serde_json::json!({ + "bool": true, + "unit": null, + "number": 5.0, + "string": "string", + "list": [1.0, 2.0], + "map": { + "key": "value", + } + }) + ); + } +} diff --git a/pbjson-types/src/value.rs b/pbjson-types/src/value.rs new file mode 100644 index 0000000..003daf5 --- /dev/null +++ b/pbjson-types/src/value.rs @@ -0,0 +1,317 @@ +pub use crate::pb::google::protobuf::value::Kind; + +use serde::{ + de::{self, MapAccess, SeqAccess}, + ser, Deserialize, Deserializer, Serialize, Serializer, +}; + +macro_rules! from { + ($($typ: ty [$id:ident] => {$($from_type:ty => $exp:expr),+ $(,)?})+) => { + $($( + impl From<$from_type> for $typ { + fn from($id: $from_type) -> Self { + $exp + } + } + )+)+ + } +} + +from! { + crate::Value[value] => { + bool => Kind::from(value).into(), + f64 => Kind::from(value).into(), + String => Kind::from(value).into(), + &'static str => Kind::from(value).into(), + Vec => Kind::from(value).into(), + std::collections::HashMap => Kind::from(value).into(), + Kind => Self { kind: Some(value) }, + Option => Self { kind: value }, + } + + Kind[value] => { + bool => Self::BoolValue(value), + f64 => Self::NumberValue(value), + String => Self::StringValue(value), + &'static str => Self::StringValue(value.into()), + Vec => Self::ListValue(value.into()), + std::collections::HashMap => Self::StructValue(value.into()), + } +} + +impl Serialize for crate::Value { + fn serialize(&self, ser: S) -> Result + where + S: Serializer, + { + self.kind.serialize(ser) + } +} + +impl<'de> Deserialize<'de> for crate::Value { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + Ok(Self { + kind: <_>::deserialize(deserializer)?, + }) + } +} + +impl Serialize for Kind { + fn serialize(&self, ser: S) -> Result + where + S: Serializer, + { + match self { + Self::NullValue(_) => ().serialize(ser), + Self::StringValue(value) => value.serialize(ser), + Self::BoolValue(value) => value.serialize(ser), + Self::StructValue(value) => value.fields.serialize(ser), + Self::ListValue(list) => list.values.serialize(ser), + Self::NumberValue(value) => { + // Kind does not allow NaN's or Infinity as they are + // indistinguishable from strings. + if value.is_nan() { + Err(ser::Error::custom( + "Cannot serialize NaN as google.protobuf.Value.number_value", + )) + } else if value.is_infinite() { + Err(ser::Error::custom( + "Cannot serialize infinity as google.protobuf.Value.number_value", + )) + } else { + value.serialize(ser) + } + } + } + } +} + +impl<'de> Deserialize<'de> for Kind { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_any(KindVisitor) + } +} + +struct KindVisitor; + +impl<'de> serde::de::Visitor<'de> for KindVisitor { + type Value = Kind; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("google.protobuf.Value") + } + + fn visit_bool(self, v: bool) -> Result + where + E: de::Error, + { + Ok(Kind::BoolValue(v)) + } + + fn visit_i8(self, v: i8) -> Result + where + E: de::Error, + { + Ok(Kind::NumberValue(v.into())) + } + + fn visit_i16(self, v: i16) -> Result + where + E: de::Error, + { + Ok(Kind::NumberValue(v.into())) + } + + fn visit_i32(self, v: i32) -> Result + where + E: de::Error, + { + Ok(Kind::NumberValue(v.into())) + } + + fn visit_i64(self, v: i64) -> Result + where + E: de::Error, + { + self.visit_i32(v.try_into().map_err(de::Error::custom)?) + } + + fn visit_i128(self, v: i128) -> Result + where + E: de::Error, + { + self.visit_i32(v.try_into().map_err(de::Error::custom)?) + } + + fn visit_u8(self, v: u8) -> Result + where + E: de::Error, + { + Ok(Kind::NumberValue(v.into())) + } + + fn visit_u16(self, v: u16) -> Result + where + E: de::Error, + { + Ok(Kind::NumberValue(v.into())) + } + + fn visit_u32(self, v: u32) -> Result + where + E: de::Error, + { + Ok(Kind::NumberValue(v.into())) + } + + fn visit_u64(self, v: u64) -> Result + where + E: de::Error, + { + self.visit_u32(v.try_into().map_err(de::Error::custom)?) + } + + fn visit_u128(self, v: u128) -> Result + where + E: de::Error, + { + self.visit_u32(v.try_into().map_err(de::Error::custom)?) + } + + fn visit_f32(self, v: f32) -> Result + where + E: de::Error, + { + Ok(Kind::NumberValue(v.into())) + } + + fn visit_f64(self, v: f64) -> Result + where + E: de::Error, + { + Ok(Kind::NumberValue(v)) + } + + fn visit_char(self, v: char) -> Result + where + E: de::Error, + { + Ok(Kind::StringValue(v.into())) + } + + fn visit_str(self, v: &str) -> Result + where + E: de::Error, + { + Ok(Kind::StringValue(v.into())) + } + + fn visit_borrowed_str(self, v: &'de str) -> Result + where + E: de::Error, + { + Ok(Kind::StringValue(v.into())) + } + + fn visit_string(self, v: String) -> Result + where + E: de::Error, + { + Ok(Kind::StringValue(v)) + } + + fn visit_bytes(self, v: &[u8]) -> Result + where + E: de::Error, + { + Ok(Kind::ListValue( + v.iter() + .copied() + .map(f64::from) + .map(Kind::NumberValue) + .collect(), + )) + } + + fn visit_borrowed_bytes(self, v: &'de [u8]) -> Result + where + E: de::Error, + { + self.visit_bytes(v) + } + + fn visit_byte_buf(self, v: Vec) -> Result + where + E: de::Error, + { + Ok(Kind::ListValue( + v.into_iter() + .map(f64::from) + .map(Kind::NumberValue) + .collect(), + )) + } + + fn visit_none(self) -> Result + where + E: de::Error, + { + Ok(Kind::NullValue(0)) + } + + fn visit_some(self, de: D) -> Result + where + D: Deserializer<'de>, + { + Deserialize::deserialize(de) + } + + fn visit_unit(self) -> Result + where + E: de::Error, + { + Ok(Kind::NullValue(0)) + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: SeqAccess<'de>, + { + let mut list = Vec::new(); + + while let Some(value) = seq.next_element()? { + list.push(value); + } + + Ok(Kind::ListValue(list.into())) + } + + fn visit_map(self, mut map_access: A) -> Result + where + A: MapAccess<'de>, + { + let mut map = std::collections::HashMap::new(); + + while let Some((key, value)) = map_access.next_entry()? { + map.insert(key, value); + } + + Ok(Kind::StructValue(map.into())) + } +} + +#[cfg(test)] +mod tests { + #[test] + fn float_special_cases() { + assert!(serde_json::to_value(crate::Value::from(f64::NAN)).is_err()); + assert!(serde_json::to_value(crate::Value::from(f64::INFINITY)).is_err()); + assert!(serde_json::to_value(crate::Value::from(f64::NEG_INFINITY)).is_err()); + } +}