diff --git a/src/de.rs b/src/de.rs index 170e0593..fb000134 100644 --- a/src/de.rs +++ b/src/de.rs @@ -972,6 +972,160 @@ where } } +// A wrapper around `visitor` to convert strings to integers. An version of this +// struct is created for each integer type `$numtype` (e.g. u16). Any call to +// `visit_str()` for this visitor will attempt to parse it to `$numtype` and pass +// call `visit_$numtype()` on the wrapped visitor. Calls to `visit_$numtype()` are +// passed through directly. All other visits result in an error. +macro_rules! int_or_string_visitor { + ($visitor_name:ident, $visit:ident, $numtype:ty) => { + struct $visitor_name<'de, V: de::Visitor<'de>> { + inner_visitor: V, + _phantom_data: std::marker::PhantomData<&'de ()>, + } + + impl<'de, V: de::Visitor<'de>> de::Visitor<'_> for $visitor_name<'de, V> { + type Value = V::Value; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "an integer or string-encoded integer key") + } + + fn visit_str(self, v: &str) -> std::result::Result + where + E: de::Error, + { + let num = v.parse().map_err(|e| E::custom(e))?; + self.inner_visitor.$visit(num) + } + + fn $visit(self, v: $numtype) -> std::result::Result + where + E: de::Error, + { + self.inner_visitor.$visit(v) + } + } + }; +} + +int_or_string_visitor!(IntOrStringVisitorI8, visit_i8, i8); +int_or_string_visitor!(IntOrStringVisitorI16, visit_i16, i16); +int_or_string_visitor!(IntOrStringVisitorI32, visit_i32, i32); +int_or_string_visitor!(IntOrStringVisitorI64, visit_i64, i64); +int_or_string_visitor!(IntOrStringVisitorU8, visit_u8, u8); +int_or_string_visitor!(IntOrStringVisitorU16, visit_u16, u16); +int_or_string_visitor!(IntOrStringVisitorU32, visit_u32, u32); +int_or_string_visitor!(IntOrStringVisitorU64, visit_u64, u64); + +serde::serde_if_integer128! { + int_or_string_visitor!(IntOrStringVisitorI128, visit_i128, i128); + int_or_string_visitor!(IntOrStringVisitorU128, visit_u128, u128); +} + +// A type that can deserialize strings of integers (e.g. "4") to integers. +// This is necessary for compatibility with serde_json. +struct MapKey<'a, R> { + de: &'a mut Deserializer, +} + +macro_rules! deserialize_integer_key { + ($visitor_name:ident, $method:ident) => { + fn $method(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + // Deserialize the next value, which should be either an integer + // of the type corresponding to $visitor_name, (e.g. u16), or it + // should be a string that can be parsed to that type. + self.de.deserialize_any($visitor_name { + inner_visitor: visitor, + _phantom_data: Default::default(), + }) + } + }; +} + +impl<'de, 'a, R> de::Deserializer<'de> for MapKey<'a, R> +where + R: Read<'de>, +{ + type Error = Error; + + #[inline] + fn deserialize_any(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + self.de.deserialize_any(visitor) + } + + deserialize_integer_key!(IntOrStringVisitorI8, deserialize_i8); + deserialize_integer_key!(IntOrStringVisitorI16, deserialize_i16); + deserialize_integer_key!(IntOrStringVisitorI32, deserialize_i32); + deserialize_integer_key!(IntOrStringVisitorI64, deserialize_i64); + deserialize_integer_key!(IntOrStringVisitorU8, deserialize_u8); + deserialize_integer_key!(IntOrStringVisitorU16, deserialize_u16); + deserialize_integer_key!(IntOrStringVisitorU32, deserialize_u32); + deserialize_integer_key!(IntOrStringVisitorU64, deserialize_u64); + + serde::serde_if_integer128! { + deserialize_integer_key!(IntOrStringVisitorI128, deserialize_i128); + deserialize_integer_key!(IntOrStringVisitorU128, deserialize_u128); + } + + #[inline] + fn deserialize_option(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + // Map keys cannot be null. + visitor.visit_some(self) + } + + #[inline] + fn deserialize_newtype_struct(self, _name: &'static str, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + + #[inline] + fn deserialize_enum( + self, + name: &'static str, + variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + self.de.deserialize_enum(name, variants, visitor) + } + + #[inline] + fn deserialize_bytes(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + self.de.deserialize_bytes(visitor) + } + + #[inline] + fn deserialize_byte_buf(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + self.de.deserialize_bytes(visitor) + } + + serde::forward_to_deserialize_any! { + bool f32 f64 char str string unit unit_struct seq tuple tuple_struct map + struct identifier ignored_any + } +} + struct MapAccess<'a, R> { de: &'a mut Deserializer, len: &'a mut usize, @@ -1004,7 +1158,8 @@ where _ => {} }; - let value = seed.deserialize(&mut *self.de)?; + let value = seed.deserialize(MapKey { de: &mut *self.de })?; + Ok(Some(value)) } diff --git a/tests/de.rs b/tests/de.rs index 01d79145..6e25f729 100644 --- a/tests/de.rs +++ b/tests/de.rs @@ -744,4 +744,77 @@ mod std_tests { let err = serde_cbor::from_slice::(&input).expect_err("recursion limit"); assert!(err.is_syntax()); } + + #[test] + fn test_int_as_string_map_keys() { + use std::collections::HashMap; + + // Given a map with keys that are strings, but happen to be strings of integers + // e.g. "4", try to deserialize it into a HashMap. This should + // work to have compatibility with serde_json. + let mut input = HashMap::::new(); + input.insert("1".to_string(), 1); + input.insert("12345".to_string(), 2); + input.insert("-2".to_string(), 3); + let buf = to_vec(&input).unwrap(); + + let deserialized = from_slice::>(&buf).unwrap(); + + assert_eq!(deserialized.len(), 3); + assert_eq!(deserialized.get(&1), Some(&1)); + assert_eq!(deserialized.get(&12345), Some(&2)); + assert_eq!(deserialized.get(&-2), Some(&3)); + } + + #[test] + fn test_int_as_string_map_keys_unsigned_err() { + use std::collections::HashMap; + + let mut input = HashMap::::new(); + input.insert("1".to_string(), 1); + input.insert("-2".to_string(), 3); + let buf = to_vec(&input).unwrap(); + + // Should fail because key is negative. + from_slice::>(&buf).expect_err(""); + } + + #[test] + fn test_int_as_string_map_keys_out_of_range() { + use std::collections::HashMap; + + let mut input = HashMap::::new(); + input.insert("1".to_string(), 1); + input.insert("12345".to_string(), 2); + let buf = to_vec(&input).unwrap(); + + // Should fail because key is out of range. + from_slice::>(&buf).expect_err(""); + } + + #[test] + fn test_int_as_string_map_keys_empty() { + use std::collections::HashMap; + + let mut input = HashMap::::new(); + input.insert("1".to_string(), 1); + input.insert("".to_string(), 2); + let buf = to_vec(&input).unwrap(); + + // Should fail because key is empty + from_slice::>(&buf).expect_err(""); + } + + #[test] + fn test_int_as_string_map_keys_invalid() { + use std::collections::HashMap; + + let mut input = HashMap::::new(); + input.insert("1".to_string(), 1); + input.insert("12.5".to_string(), 2); + let buf = to_vec(&input).unwrap(); + + // Should fail because key is not an integer + from_slice::>(&buf).expect_err(""); + } }