diff --git a/serde/src/private/ser.rs b/serde/src/private/ser.rs index ebfeba97e..f30680b8e 100644 --- a/serde/src/private/ser.rs +++ b/serde/src/private/ser.rs @@ -19,7 +19,7 @@ pub fn serialize_tagged_newtype( type_ident: &'static str, variant_ident: &'static str, tag: &'static str, - variant_name: &'static str, + variant_name: VariantName, value: &T, ) -> Result where @@ -39,7 +39,7 @@ struct TaggedSerializer { type_ident: &'static str, variant_ident: &'static str, tag: &'static str, - variant_name: &'static str, + variant_name: VariantName, delegate: S, } @@ -179,13 +179,13 @@ where fn serialize_unit(self) -> Result { let mut map = tri!(self.delegate.serialize_map(Some(1))); - tri!(map.serialize_entry(self.tag, self.variant_name)); + tri!(map.serialize_entry(self.tag, &self.variant_name)); map.end() } fn serialize_unit_struct(self, _: &'static str) -> Result { let mut map = tri!(self.delegate.serialize_map(Some(1))); - tri!(map.serialize_entry(self.tag, self.variant_name)); + tri!(map.serialize_entry(self.tag, &self.variant_name)); map.end() } @@ -196,7 +196,7 @@ where inner_variant: &'static str, ) -> Result { let mut map = tri!(self.delegate.serialize_map(Some(2))); - tri!(map.serialize_entry(self.tag, self.variant_name)); + tri!(map.serialize_entry(self.tag, &self.variant_name)); tri!(map.serialize_entry(inner_variant, &())); map.end() } @@ -223,7 +223,7 @@ where T: ?Sized + Serialize, { let mut map = tri!(self.delegate.serialize_map(Some(2))); - tri!(map.serialize_entry(self.tag, self.variant_name)); + tri!(map.serialize_entry(self.tag, &self.variant_name)); tri!(map.serialize_entry(inner_variant, inner_value)); map.end() } @@ -266,7 +266,7 @@ where len: usize, ) -> Result { let mut map = tri!(self.delegate.serialize_map(Some(2))); - tri!(map.serialize_entry(self.tag, self.variant_name)); + tri!(map.serialize_entry(self.tag, &self.variant_name)); tri!(map.serialize_key(inner_variant)); Ok(SerializeTupleVariantAsMapValue::new( map, @@ -277,7 +277,7 @@ where fn serialize_map(self, len: Option) -> Result { let mut map = tri!(self.delegate.serialize_map(len.map(|len| len + 1))); - tri!(map.serialize_entry(self.tag, self.variant_name)); + tri!(map.serialize_entry(self.tag, &self.variant_name)); Ok(map) } @@ -287,7 +287,7 @@ where len: usize, ) -> Result { let mut state = tri!(self.delegate.serialize_struct(name, len + 1)); - tri!(state.serialize_field(self.tag, self.variant_name)); + tri!(state.serialize_field(self.tag, &self.variant_name)); Ok(state) } @@ -313,7 +313,7 @@ where len: usize, ) -> Result { let mut map = tri!(self.delegate.serialize_map(Some(2))); - tri!(map.serialize_entry(self.tag, self.variant_name)); + tri!(map.serialize_entry(self.tag, &self.variant_name)); tri!(map.serialize_key(inner_variant)); Ok(SerializeStructVariantAsMapValue::new( map, @@ -1331,6 +1331,47 @@ where } } +pub enum Integer { + U8(u8), + U16(u16), + U32(u32), + U64(u64), + + I8(i8), + I16(i16), + I32(i32), + I64(i64), +} + +pub enum VariantName { + String(&'static str), + Integer(Integer), + Boolean(bool), +} + +impl Serialize for VariantName { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + match self { + VariantName::String(s) => serializer.serialize_str(s), + VariantName::Integer(i) => match *i { + Integer::U8(i) => serializer.serialize_u8(i), + Integer::U16(i) => serializer.serialize_u16(i), + Integer::U32(i) => serializer.serialize_u32(i), + Integer::U64(i) => serializer.serialize_u64(i), + + Integer::I8(i) => serializer.serialize_i8(i), + Integer::I16(i) => serializer.serialize_i16(i), + Integer::I32(i) => serializer.serialize_i32(i), + Integer::I64(i) => serializer.serialize_i64(i), + }, + VariantName::Boolean(b) => serializer.serialize_bool(*b), + } + } +} + pub struct AdjacentlyTaggedEnumVariant { pub enum_name: &'static str, pub variant_index: u32, diff --git a/serde_derive/src/de.rs b/serde_derive/src/de.rs index 996e97e86..864617486 100644 --- a/serde_derive/src/de.rs +++ b/serde_derive/src/de.rs @@ -1,6 +1,7 @@ use crate::fragment::{Expr, Fragment, Match, Stmts}; use crate::internals::ast::{Container, Data, Field, Style, Variant}; -use crate::internals::{attr, replace_receiver, ungroup, Ctxt, Derive}; +use crate::internals::attr::{self, AsVariant, IntRepr, VariantMix}; +use crate::internals::{replace_receiver, ungroup, Ctxt, Derive}; use crate::{bound, dummy, pretend, this}; use proc_macro2::{Literal, Span, TokenStream}; use quote::{quote, quote_spanned, ToTokens}; @@ -280,7 +281,10 @@ fn deserialize_body(cont: &Container, params: &Parameters) -> Fragment { deserialize_try_from(type_try_from) } else if let attr::Identifier::No = cont.attrs.identifier() { match &cont.data { - Data::Enum(variants) => deserialize_enum(params, variants, &cont.attrs), + Data::Enum(variants) => { + let mix = VariantMix::from_de(variants); + deserialize_enum(mix, params, variants, &cont.attrs) + } Data::Struct(Style::Struct, fields) => { deserialize_struct(params, fields, &cont.attrs, StructForm::Struct) } @@ -1192,6 +1196,7 @@ fn deserialize_struct_in_place( } fn deserialize_enum( + mix: VariantMix, params: &Parameters, variants: &[Variant], cattrs: &attr::Container, @@ -1200,14 +1205,15 @@ fn deserialize_enum( match variants.iter().position(|var| var.attrs.untagged()) { Some(variant_idx) => { let (tagged, untagged) = variants.split_at(variant_idx); - let tagged_frag = Expr(deserialize_homogeneous_enum(params, tagged, cattrs)); + let tagged_frag = Expr(deserialize_homogeneous_enum(mix, params, tagged, cattrs)); deserialize_untagged_enum_after(params, untagged, cattrs, Some(tagged_frag)) } - None => deserialize_homogeneous_enum(params, variants, cattrs), + None => deserialize_homogeneous_enum(mix, params, variants, cattrs), } } fn deserialize_homogeneous_enum( + mix: VariantMix, params: &Parameters, variants: &[Variant], cattrs: &attr::Container, @@ -1215,16 +1221,19 @@ fn deserialize_homogeneous_enum( match cattrs.tag() { attr::TagType::External => deserialize_externally_tagged_enum(params, variants, cattrs), attr::TagType::Internal { tag } => { - deserialize_internally_tagged_enum(params, variants, cattrs, tag) + deserialize_internally_tagged_enum(mix, params, variants, cattrs, tag) } attr::TagType::Adjacent { tag, content } => { - deserialize_adjacently_tagged_enum(params, variants, cattrs, tag, content) + deserialize_adjacently_tagged_enum(mix, params, variants, cattrs, tag, content) } attr::TagType::None => deserialize_untagged_enum(params, variants, cattrs), } } -fn prepare_enum_variant_enum(variants: &[Variant]) -> (TokenStream, Stmts) { +fn prepare_enum_variant_enum( + mix: VariantMix, + variants: &[Variant], +) -> (TokenStream, Stmts) { let mut deserialized_variants = variants .iter() .enumerate() @@ -1249,7 +1258,9 @@ fn prepare_enum_variant_enum(variants: &[Variant]) -> (TokenStream, Stmts) { }); let variants_stmt = { - let variant_names = variant_names_idents.iter().map(|(name, _, _)| name); + let variant_names = variant_names_idents + .iter() + .map(|(name, _, _)| name.to_variant_string()); quote! { #[doc(hidden)] const VARIANTS: &'static [&'static str] = &[ #(#variant_names),* ]; @@ -1257,6 +1268,7 @@ fn prepare_enum_variant_enum(variants: &[Variant]) -> (TokenStream, Stmts) { }; let variant_visitor = Stmts(deserialize_generated_identifier( + mix, &variant_names_idents, false, // variant identifiers do not depend on the presence of flatten fields true, @@ -1281,7 +1293,8 @@ fn deserialize_externally_tagged_enum( let expecting = format!("enum {}", params.type_name()); let expecting = cattrs.expecting().unwrap_or(&expecting); - let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants); + let (variants_stmt, variant_visitor) = + prepare_enum_variant_enum(VariantMix::OnlyStrings, variants); // Match arms to extract a variant from a string let variant_arms = variants @@ -1361,12 +1374,13 @@ fn deserialize_externally_tagged_enum( } fn deserialize_internally_tagged_enum( + mix: VariantMix, params: &Parameters, variants: &[Variant], cattrs: &attr::Container, tag: &str, ) -> Fragment { - let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants); + let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(mix, variants); // Match arms to extract a variant from a string let variant_arms = variants @@ -1408,6 +1422,7 @@ fn deserialize_internally_tagged_enum( } fn deserialize_adjacently_tagged_enum( + mix: VariantMix, params: &Parameters, variants: &[Variant], cattrs: &attr::Container, @@ -1420,7 +1435,7 @@ fn deserialize_adjacently_tagged_enum( split_with_de_lifetime(params); let delife = params.borrowed.de_lifetime(); - let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants); + let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(mix, variants); let variant_arms: &Vec<_> = &variants .iter() @@ -1463,14 +1478,6 @@ fn deserialize_adjacently_tagged_enum( } }; - let variant_seed = quote! { - _serde::__private::de::AdjacentlyTaggedEnumVariantSeed::<__Field> { - enum_name: #rust_name, - variants: VARIANTS, - fields_enum: _serde::__private::PhantomData - } - }; - let mut missing_content = quote! { _serde::__private::Err(<__A::Error as _serde::de::Error>::missing_field(#content)) }; @@ -1518,8 +1525,21 @@ fn deserialize_adjacently_tagged_enum( _serde::de::MapAccess::next_key_seed(&mut __map, #tag_or_content)? }; - let variant_from_map = quote! { - _serde::de::MapAccess::next_value_seed(&mut __map, #variant_seed)? + let variant_from_map = match mix { + // if all variants are string variants, we can deserialize by variant index + VariantMix::OnlyStrings => { + quote! { + _serde::de::MapAccess::next_value_seed( + &mut __map, + _serde::__private::de::AdjacentlyTaggedEnumVariantSeed::<__Field> { + enum_name: #rust_name, + variants: VARIANTS, + fields_enum: _serde::__private::PhantomData, + } + )? + } + } + _ => quote! { _serde::de::MapAccess::next_value::<__Field>(&mut __map)? }, }; // When allowing unknown fields, we want to transparently step through keys @@ -1994,8 +2014,9 @@ fn deserialize_untagged_newtype_variant( } } -fn deserialize_generated_identifier( - fields: &[(&str, Ident, &BTreeSet)], +fn deserialize_generated_identifier( + mix: VariantMix, + fields: &[(&T, Ident, &BTreeSet)], has_flatten: bool, is_variant: bool, ignore_variant: Option, @@ -2020,6 +2041,51 @@ fn deserialize_generated_identifier( None }; + let deserialize_call = if is_variant { + match mix { + VariantMix::OnlyStrings => { + quote! { _serde::Deserializer::deserialize_string(__deserializer, __FieldVisitor) } + } + VariantMix::OnlyIntegers(IntRepr::U8) => { + quote! { _serde::Deserializer::deserialize_u8(__deserializer, __FieldVisitor) } + } + VariantMix::OnlyIntegers(IntRepr::U16) => { + quote! { _serde::Deserializer::deserialize_u16(__deserializer, __FieldVisitor) } + } + VariantMix::OnlyIntegers(IntRepr::U32) => { + quote! { _serde::Deserializer::deserialize_u32(__deserializer, __FieldVisitor) } + } + VariantMix::OnlyIntegers(IntRepr::U64) => { + quote! { _serde::Deserializer::deserialize_u64(__deserializer, __FieldVisitor) } + } + VariantMix::OnlyIntegers(IntRepr::I8) => { + quote! { _serde::Deserializer::deserialize_i8(__deserializer, __FieldVisitor) } + } + VariantMix::OnlyIntegers(IntRepr::I16) => { + quote! { _serde::Deserializer::deserialize_i16(__deserializer, __FieldVisitor) } + } + VariantMix::OnlyIntegers(IntRepr::I32) => { + quote! { _serde::Deserializer::deserialize_i32(__deserializer, __FieldVisitor) } + } + VariantMix::OnlyIntegers(IntRepr::I64) => { + quote! { _serde::Deserializer::deserialize_i64(__deserializer, __FieldVisitor) } + } + VariantMix::UnknownIntegers => { + quote! { _serde::Deserializer::deserialize_i64(__deserializer, __FieldVisitor) } + } + VariantMix::OnlyBooleans => { + quote! { _serde::Deserializer::deserialize_bool(__deserializer, __FieldVisitor) } + } + VariantMix::Any => { + quote! { _serde::Deserializer::deserialize_any(__deserializer, __FieldVisitor) } + } + } + } else { + quote! { + _serde::Deserializer::deserialize_identifier(__deserializer, __FieldVisitor) + } + }; + quote_block! { #[allow(non_camel_case_types)] #[doc(hidden)] @@ -2043,7 +2109,7 @@ fn deserialize_generated_identifier( where __D: _serde::Deserializer<'de>, { - _serde::Deserializer::deserialize_identifier(__deserializer, __FieldVisitor) + #deserialize_call } } } @@ -2051,8 +2117,8 @@ fn deserialize_generated_identifier( /// Generates enum and its `Deserialize` implementation that represents each /// non-skipped field of the struct -fn deserialize_field_identifier( - fields: &[(&str, Ident, &BTreeSet)], +fn deserialize_field_identifier( + fields: &[(&T, Ident, &BTreeSet)], cattrs: &attr::Container, has_flatten: bool, ) -> Stmts { @@ -2069,6 +2135,7 @@ fn deserialize_field_identifier( }; Stmts(deserialize_generated_identifier( + VariantMix::OnlyStrings, fields, has_flatten, false, @@ -2192,26 +2259,77 @@ fn deserialize_custom_identifier( } } -fn deserialize_identifier( +fn deserialize_identifier( this_value: &TokenStream, - fields: &[(&str, Ident, &BTreeSet)], + fields: &[(&T, Ident, &BTreeSet)], is_variant: bool, fallthrough: Option, fallthrough_borrowed: Option, collect_other_fields: bool, expecting: Option<&str>, ) -> Fragment { - let str_mapping = fields.iter().map(|(_, ident, aliases)| { - // `aliases` also contains a main name - quote!(#(#aliases)|* => _serde::__private::Ok(#this_value::#ident)) - }); - let bytes_mapping = fields.iter().map(|(_, ident, aliases)| { - // `aliases` also contains a main name - let aliases = aliases - .iter() - .map(|alias| Literal::byte_string(alias.as_bytes())); - quote!(#(#aliases)|* => _serde::__private::Ok(#this_value::#ident)) - }); + let mut flat_fields = Vec::new(); + for (_, ident, aliases) in fields { + flat_fields.extend(aliases.iter().map(|alias| (alias, ident))); + } + + let field_strs: &Vec<_> = &flat_fields + .iter() + .filter_map(|(name, _)| name.as_string()) + .collect(); + let field_ints: &Vec<_> = &flat_fields + .iter() + .filter_map(|(name, _)| { + name.as_int().map(|i| match i.negative { + true => Literal::i64_unsuffixed(if i.negative { + -((i.magnitude - 1) as i64) - 1 + } else { + i.magnitude as i64 + }), + false => Literal::u64_unsuffixed(i.magnitude), + }) + }) + .collect(); + let field_ints_positive: &Vec<_> = &flat_fields + .iter() + .filter_map(|(name, _)| { + name.as_int().and_then(|i| match i.negative { + true => None, + false => Some(Literal::u64_unsuffixed(i.magnitude)), + }) + }) + .collect(); + let field_bools: &Vec<_> = &flat_fields + .iter() + .filter_map(|(name, _)| name.as_bool()) + .collect(); + let field_bytes: &Vec<_> = &flat_fields + .iter() + .filter_map(|(name, _)| name.as_string().map(|s| Literal::byte_string(s.as_bytes()))) + .collect(); + + let constructor_strs: &Vec<_> = &flat_fields + .iter() + .filter(|(name, _)| name.as_string().is_some()) + .map(|(_, ident)| quote!(#this_value::#ident)) + .collect(); + + let constructor_ints: &Vec<_> = &flat_fields + .iter() + .filter(|(name, _)| name.as_int().is_some()) + .map(|(_, ident)| quote!(#this_value::#ident)) + .collect(); + let constructor_ints_positive: &Vec<_> = &flat_fields + .iter() + .filter(|(name, _)| name.as_int().map(|i| !i.negative).unwrap_or(false)) + .map(|(_, ident)| quote!(#this_value::#ident)) + .collect(); + + let constructor_bools: &Vec<_> = &flat_fields + .iter() + .filter(|(name, _)| name.as_bool().is_some()) + .map(|(_, ident)| quote!(#this_value::#ident)) + .collect(); let expecting = expecting.unwrap_or(if is_variant { "variant identifier" @@ -2219,12 +2337,21 @@ fn deserialize_identifier( "field identifier" }); - let bytes_to_str = if fallthrough.is_some() || collect_other_fields { - None + let (bytes_to_str, bool_to_str, int_to_str) = if fallthrough.is_some() || collect_other_fields { + (None, None, None) } else { - Some(quote! { - let __value = &_serde::__private::from_utf8_lossy(__value); - }) + ( + Some(quote! { + let __value = &_serde::__private::from_utf8_lossy(__value); + }), + Some(quote! { + let __value = if __value { "true" } else { "false" }; + }), + Some(quote! { + let __value = _serde::__private::ToString::to_string(&__value); + let __value = &__value; + }), + ) }; let ( @@ -2232,6 +2359,7 @@ fn deserialize_identifier( value_as_borrowed_str_content, value_as_bytes_content, value_as_borrowed_bytes_content, + value_as_bool_content, ) = if collect_other_fields { ( Some(quote! { @@ -2246,9 +2374,12 @@ fn deserialize_identifier( Some(quote! { let __value = _serde::__private::de::Content::Bytes(__value); }), + Some(quote! { + let __value = _serde::__private::de::Content::Bool(__value); + }), ) } else { - (None, None, None, None) + (None, None, None, None, None) }; let fallthrough_arm_tokens; @@ -2266,144 +2397,345 @@ fn deserialize_identifier( &fallthrough_arm_tokens }; - let visit_other = if collect_other_fields { - quote! { + let visit_bool = if constructor_bools.is_empty() { + if collect_other_fields { + Some(quote! { + fn visit_bool<__E>(self, __value: bool) -> _serde::__private::Result + where + __E: _serde::de::Error, + { + _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::Bool(__value))) + } + }) + } else { + None + } + } else { + let missing_true_arm = field_bools.iter().all(|b| !*b); + let missing_false_arm = field_bools.iter().all(|b| *b); + + let fallthrough_true_arm = if missing_true_arm { + Some(quote! { + true => { + #bool_to_str + #value_as_bool_content + #fallthrough_arm + }, + }) + } else { + None + }; + + let fallthrough_false_arm = if missing_false_arm { + Some(quote! { + false => { + #bool_to_str + #value_as_bool_content + #fallthrough_arm + }, + }) + } else { + None + }; + + Some(quote! { fn visit_bool<__E>(self, __value: bool) -> _serde::__private::Result where __E: _serde::de::Error, { - _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::Bool(__value))) + #[allow(unreachable_patterns)] + match __value { + #( + #field_bools => _serde::__private::Ok(#constructor_bools), + )* + #fallthrough_true_arm + #fallthrough_false_arm + } } + }) + }; - fn visit_i8<__E>(self, __value: i8) -> _serde::__private::Result - where - __E: _serde::de::Error, - { - _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::I8(__value))) - } + let visit_int = if constructor_ints.is_empty() { + if collect_other_fields { + quote! { + fn visit_i8<__E>(self, __value: i8) -> _serde::__private::Result + where + __E: _serde::de::Error, + { + _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::I8(__value))) + } - fn visit_i16<__E>(self, __value: i16) -> _serde::__private::Result - where - __E: _serde::de::Error, - { - _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::I16(__value))) - } + fn visit_i16<__E>(self, __value: i16) -> _serde::__private::Result + where + __E: _serde::de::Error, + { + _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::I16(__value))) + } - fn visit_i32<__E>(self, __value: i32) -> _serde::__private::Result - where - __E: _serde::de::Error, - { - _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::I32(__value))) - } + fn visit_i32<__E>(self, __value: i32) -> _serde::__private::Result + where + __E: _serde::de::Error, + { + _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::I32(__value))) + } - fn visit_i64<__E>(self, __value: i64) -> _serde::__private::Result - where - __E: _serde::de::Error, - { - _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::I64(__value))) + fn visit_u8<__E>(self, __value: u8) -> _serde::__private::Result + where + __E: _serde::de::Error, + { + _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::U8(__value))) + } + + fn visit_u16<__E>(self, __value: u16) -> _serde::__private::Result + where + __E: _serde::de::Error, + { + _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::U16(__value))) + } + + fn visit_u32<__E>(self, __value: u32) -> _serde::__private::Result + where + __E: _serde::de::Error, + { + _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::U32(__value))) + } + + fn visit_i64<__E>(self, __value: i64) -> _serde::__private::Result + where + __E: _serde::de::Error, + { + _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::I64(__value))) + } + + fn visit_u64<__E>(self, __value: u64) -> _serde::__private::Result + where + __E: _serde::de::Error, + { + _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::U64(__value))) + } } + } else { + let variant_indices = 0_u64..; + let u64_fallthrough_arm_tokens; + let u64_fallthrough_arm = if let Some(fallthrough) = &fallthrough { + fallthrough + } else { + let index_expecting = if is_variant { "variant" } else { "field" }; + let fallthrough_msg = + format!("{} index 0 <= i < {}", index_expecting, fields.len()); + u64_fallthrough_arm_tokens = quote! { + _serde::__private::Err(_serde::de::Error::invalid_value( + _serde::de::Unexpected::Unsigned(__value), + &#fallthrough_msg, + )) + }; + &u64_fallthrough_arm_tokens + }; - fn visit_u8<__E>(self, __value: u8) -> _serde::__private::Result - where - __E: _serde::de::Error, - { - _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::U8(__value))) + let main_constructors: &Vec<_> = &fields + .iter() + .map(|(_, ident, _)| quote!(#this_value::#ident)) + .collect(); + + quote! { + fn visit_u64<__E>(self, __value: u64) -> _serde::__private::Result + where + __E: _serde::de::Error, + { + match __value { + #( + #variant_indices => _serde::__private::Ok(#main_constructors), + )* + _ => { + #u64_fallthrough_arm + }, + } + } } + } + } else { + let (value_as_u8_content, value_as_u16_content, value_as_u32_content, value_as_u64_content) = + if collect_other_fields { + ( + Some(quote! { + let __value = _serde::__private::de::Content::U8(__value); + }), + Some(quote! { + let __value = _serde::__private::de::Content::U16(__value); + }), + Some(quote! { + let __value = _serde::__private::de::Content::U32(__value); + }), + Some(quote! { + let __value = _serde::__private::de::Content::U64(__value); + }), + ) + } else { + (None, None, None, None) + }; + let (value_as_i8_content, value_as_i16_content, value_as_i32_content, value_as_i64_content) = + if collect_other_fields { + ( + Some(quote! { + let __value = _serde::__private::de::Content::I8(__value); + }), + Some(quote! { + let __value = _serde::__private::de::Content::I16(__value); + }), + Some(quote! { + let __value = _serde::__private::de::Content::I32(__value); + }), + Some(quote! { + let __value = _serde::__private::de::Content::I64(__value); + }), + ) + } else { + (None, None, None, None) + }; - fn visit_u16<__E>(self, __value: u16) -> _serde::__private::Result + quote! { + fn visit_i8<__E>(self, __value: i8) -> _serde::__private::Result where __E: _serde::de::Error, { - _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::U16(__value))) + match __value { + #( + #field_ints => _serde::__private::Ok(#constructor_ints), + )* + _ => { + #int_to_str + #value_as_i8_content + #fallthrough_arm + } + } } - fn visit_u32<__E>(self, __value: u32) -> _serde::__private::Result + fn visit_i16<__E>(self, __value: i16) -> _serde::__private::Result where __E: _serde::de::Error, { - _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::U32(__value))) + match __value { + #( + #field_ints => _serde::__private::Ok(#constructor_ints), + )* + _ => { + #int_to_str + #value_as_i16_content + #fallthrough_arm + } + } } - fn visit_u64<__E>(self, __value: u64) -> _serde::__private::Result + fn visit_i32<__E>(self, __value: i32) -> _serde::__private::Result where __E: _serde::de::Error, { - _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::U64(__value))) + match __value { + #( + #field_ints => _serde::__private::Ok(#constructor_ints), + )* + _ => { + #int_to_str + #value_as_i32_content + #fallthrough_arm + } + } } - fn visit_f32<__E>(self, __value: f32) -> _serde::__private::Result + fn visit_u8<__E>(self, __value: u8) -> _serde::__private::Result where __E: _serde::de::Error, { - _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::F32(__value))) + match __value { + #( + #field_ints_positive => _serde::__private::Ok(#constructor_ints_positive), + )* + _ => { + #int_to_str + #value_as_u8_content + #fallthrough_arm + } + } } - fn visit_f64<__E>(self, __value: f64) -> _serde::__private::Result + fn visit_u16<__E>(self, __value: u16) -> _serde::__private::Result where __E: _serde::de::Error, { - _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::F64(__value))) + match __value { + #( + #field_ints_positive => _serde::__private::Ok(#constructor_ints_positive), + )* + _ => { + #int_to_str + #value_as_u16_content + #fallthrough_arm + } + } } - fn visit_char<__E>(self, __value: char) -> _serde::__private::Result + fn visit_u32<__E>(self, __value: u32) -> _serde::__private::Result where __E: _serde::de::Error, { - _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::Char(__value))) + match __value { + #( + #field_ints_positive => _serde::__private::Ok(#constructor_ints_positive), + )* + _ => { + #int_to_str + #value_as_u32_content + #fallthrough_arm + } + } } - fn visit_unit<__E>(self) -> _serde::__private::Result + fn visit_i64<__E>(self, __value: i64) -> _serde::__private::Result where __E: _serde::de::Error, { - _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::Unit)) + match __value { + #( + #field_ints => _serde::__private::Ok(#constructor_ints), + )* + _ => { + #int_to_str + #value_as_i64_content + #fallthrough_arm + } + } } - } - } else { - let u64_mapping = fields.iter().enumerate().map(|(i, (_, ident, _))| { - let i = i as u64; - quote!(#i => _serde::__private::Ok(#this_value::#ident)) - }); - let u64_fallthrough_arm_tokens; - let u64_fallthrough_arm = if let Some(fallthrough) = &fallthrough { - fallthrough - } else { - let index_expecting = if is_variant { "variant" } else { "field" }; - let fallthrough_msg = format!("{} index 0 <= i < {}", index_expecting, fields.len()); - u64_fallthrough_arm_tokens = quote! { - _serde::__private::Err(_serde::de::Error::invalid_value( - _serde::de::Unexpected::Unsigned(__value), - &#fallthrough_msg, - )) - }; - &u64_fallthrough_arm_tokens - }; - - quote! { fn visit_u64<__E>(self, __value: u64) -> _serde::__private::Result where __E: _serde::de::Error, { match __value { - #(#u64_mapping,)* - _ => #u64_fallthrough_arm, + #( + #field_ints_positive => _serde::__private::Ok(#constructor_ints_positive), + )* + _ => { + #int_to_str + #value_as_u64_content + #fallthrough_arm + } } } } }; let visit_borrowed = if fallthrough_borrowed.is_some() || collect_other_fields { - let str_mapping = str_mapping.clone(); - let bytes_mapping = bytes_mapping.clone(); let fallthrough_borrowed_arm = fallthrough_borrowed.as_ref().unwrap_or(fallthrough_arm); + Some(quote! { fn visit_borrowed_str<__E>(self, __value: &'de str) -> _serde::__private::Result where __E: _serde::de::Error, { match __value { - #(#str_mapping,)* + #( + #field_strs => _serde::__private::Ok(#constructor_strs), + )* _ => { #value_as_borrowed_str_content #fallthrough_borrowed_arm @@ -2416,7 +2748,9 @@ fn deserialize_identifier( __E: _serde::de::Error, { match __value { - #(#bytes_mapping,)* + #( + #field_bytes => _serde::__private::Ok(#constructor_strs), + )* _ => { #bytes_to_str #value_as_borrowed_bytes_content @@ -2429,19 +2763,15 @@ fn deserialize_identifier( None }; - quote_block! { - fn expecting(&self, __formatter: &mut _serde::__private::Formatter) -> _serde::__private::fmt::Result { - _serde::__private::Formatter::write_str(__formatter, #expecting) - } - - #visit_other - + let visit_str_and_bytes = quote! { fn visit_str<__E>(self, __value: &str) -> _serde::__private::Result where __E: _serde::de::Error, { match __value { - #(#str_mapping,)* + #( + #field_strs => _serde::__private::Ok(#constructor_strs), + )* _ => { #value_as_str_content #fallthrough_arm @@ -2454,7 +2784,9 @@ fn deserialize_identifier( __E: _serde::de::Error, { match __value { - #(#bytes_mapping,)* + #( + #field_bytes => _serde::__private::Ok(#constructor_strs), + )* _ => { #bytes_to_str #value_as_bytes_content @@ -2462,6 +2794,54 @@ fn deserialize_identifier( } } } + }; + + let visit_other = if collect_other_fields { + Some(quote! { + fn visit_f32<__E>(self, __value: f32) -> _serde::__private::Result + where + __E: _serde::de::Error, + { + _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::F32(__value))) + } + + fn visit_f64<__E>(self, __value: f64) -> _serde::__private::Result + where + __E: _serde::de::Error, + { + _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::F64(__value))) + } + + fn visit_char<__E>(self, __value: char) -> _serde::__private::Result + where + __E: _serde::de::Error, + { + _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::Char(__value))) + } + + fn visit_unit<__E>(self) -> _serde::__private::Result + where + __E: _serde::de::Error, + { + _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::Unit)) + } + }) + } else { + None + }; + + quote_block! { + fn expecting(&self, __formatter: &mut _serde::__private::Formatter) -> _serde::__private::fmt::Result { + _serde::__private::Formatter::write_str(__formatter, #expecting) + } + + #visit_other + + #visit_bool + + #visit_int + + #visit_str_and_bytes #visit_borrowed } diff --git a/serde_derive/src/internals/attr.rs b/serde_derive/src/internals/attr.rs index ac5f5d9a5..8767a5b73 100644 --- a/serde_derive/src/internals/attr.rs +++ b/serde_derive/src/internals/attr.rs @@ -1,6 +1,7 @@ use crate::internals::symbol::*; -use crate::internals::{ungroup, Ctxt}; +use crate::internals::{ast, ungroup, Ctxt}; use proc_macro2::{Spacing, Span, TokenStream, TokenTree}; +use quote::quote; use quote::ToTokens; use std::borrow::Cow; use std::collections::BTreeSet; @@ -129,25 +130,363 @@ impl<'c, T> VecAttr<'c, T> { } } -pub struct Name { - serialize: String, - serialize_renamed: bool, - deserialize: String, - deserialize_renamed: bool, - deserialize_aliases: BTreeSet, +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)] +pub enum IntRepr { + I8, + U8, + + I16, + U16, + + I32, + U32, + + I64, + U64, } -fn unraw(ident: &Ident) -> String { - ident.to_string().trim_start_matches("r#").to_owned() +impl IntRepr { + fn from_integer(i: Integer) -> Option { + if let Some(repr) = i.repr { + Some(repr) + } else { + match i.negative { + false if i.magnitude <= u64::from(u8::MAX) => Some(Self::U8), + false if i.magnitude <= u64::from(u16::MAX) => Some(Self::U16), + false if i.magnitude <= u64::from(u32::MAX) => Some(Self::U32), + false if i.magnitude <= u64::from(u64::MAX) => Some(Self::U64), + + true if i.magnitude <= u64::from(u8::MAX) / 2 + 1 => Some(Self::I8), + true if i.magnitude <= u64::from(u16::MAX) / 2 + 1 => Some(Self::I16), + true if i.magnitude <= u64::from(u32::MAX) / 2 + 1 => Some(Self::I32), + true if i.magnitude <= u64::from(u64::MAX) / 2 + 1 => Some(Self::I64), + + _ => None, + } + } + } + fn suffix(self) -> &'static str { + match self { + IntRepr::I8 => "i8", + IntRepr::U8 => "u8", + IntRepr::I16 => "i16", + IntRepr::U16 => "u16", + IntRepr::I32 => "i32", + IntRepr::U32 => "u32", + IntRepr::I64 => "i64", + IntRepr::U64 => "u64", + } + } +} + +#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Debug)] +pub struct Integer { + pub negative: bool, + pub magnitude: u64, + pub repr: Option, +} + +#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Debug)] +pub enum VariantName { + String(String), + Integer(Integer), + Boolean(bool), +} + +impl VariantName { + pub fn to_variant_string(&self) -> String { + match self { + VariantName::String(s) => s.clone(), + VariantName::Integer(i) => { + let suffix = i.repr.map(|repr| repr.suffix()).unwrap_or(""); + if i.negative { + format!("-{}{}", i.magnitude, suffix) + } else { + format!("{}{}", i.magnitude, suffix) + } + } + VariantName::Boolean(b) => b.to_string(), + } + } + pub fn to_literal(&self, mix: VariantMix) -> TokenStream { + match self { + VariantName::String(s) => quote!(#s), + VariantName::Boolean(b) => quote!(#b), + VariantName::Integer(n) => { + let repr = match mix { + VariantMix::OnlyIntegers(repr) => repr, + _ => n.repr.unwrap_or(IntRepr::I64), + }; + match repr { + IntRepr::U8 => { + let i = n.magnitude as u8; + quote!(#i) + } + IntRepr::U16 => { + let i = n.magnitude as u16; + quote!(#i) + } + IntRepr::U32 => { + let i = n.magnitude as u32; + quote!(#i) + } + IntRepr::U64 => { + let i = n.magnitude; + quote!(#i) + } + IntRepr::I8 => { + let i = n.magnitude as i8; + if n.negative { + let i = -i; + quote!(#i) + } else { + quote!(#i) + } + } + IntRepr::I16 => { + let i = n.magnitude as i16; + if n.negative { + let i = -i; + quote!(#i) + } else { + quote!(#i) + } + } + IntRepr::I32 => { + let i = n.magnitude as i32; + if n.negative { + let i = -i; + quote!(#i) + } else { + quote!(#i) + } + } + IntRepr::I64 => { + if n.negative { + let i = -((n.magnitude - 1) as i64) - 1; + quote!(#i) + } else { + let i = n.magnitude as i64; + quote!(#i) + } + } + } + } + } + } +} + +/// An enum representing the distribution of different variant names +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +pub enum VariantMix { + OnlyStrings, + UnknownIntegers, + OnlyIntegers(IntRepr), + OnlyBooleans, + Any, +} + +impl VariantMix { + fn from_single(variant: &VariantName) -> Self { + match variant { + &VariantName::String(_) => Self::OnlyStrings, + &VariantName::Integer(i) => IntRepr::from_integer(i) + .map(VariantMix::OnlyIntegers) + .unwrap_or(VariantMix::UnknownIntegers), + &VariantName::Boolean(_) => Self::OnlyBooleans, + } + } + pub fn from_ser(variants: &[ast::Variant]) -> Self { + let mut iter = variants.iter().map(|v| v.attrs.name.serialize_name()); + match iter.next() { + Some(first) => { + let mut mix = Self::from_single(first); + for rest in iter { + if mix == VariantMix::Any { + return mix; + } + mix = match (Self::from_single(rest), mix) { + (VariantMix::OnlyStrings, VariantMix::OnlyStrings) => { + VariantMix::OnlyStrings + } + (VariantMix::OnlyBooleans, VariantMix::OnlyBooleans) => { + VariantMix::OnlyBooleans + } + (VariantMix::OnlyIntegers(a), VariantMix::OnlyIntegers(b)) if a == b => { + VariantMix::OnlyIntegers(a) + } + (VariantMix::OnlyIntegers(a), VariantMix::UnknownIntegers) => { + VariantMix::OnlyIntegers(a) + } + (VariantMix::UnknownIntegers, VariantMix::OnlyIntegers(b)) => { + VariantMix::OnlyIntegers(b) + } + _ => VariantMix::Any, + }; + } + if mix == VariantMix::UnknownIntegers { + let negative = + variants + .iter() + .map(|v| v.attrs.name.serialize_name()) + .any(|v| match v { + VariantName::Integer(i) => i.negative, + _ => false, + }); + let max = variants + .iter() + .map(|v| v.attrs.name.serialize_name()) + .filter_map(|v| match v { + VariantName::Integer(i) => Some(i.magnitude), + _ => None, + }) + .max() + .unwrap(); + match IntRepr::from_integer(Integer { + negative, + magnitude: max, + repr: None, + }) { + Some(repr) => VariantMix::OnlyIntegers(repr), + None => VariantMix::UnknownIntegers, + } + } else { + mix + } + } + + // string variants are the base case as they were the original case + None => Self::OnlyStrings, + } + } + pub fn from_de(variants: &[ast::Variant]) -> Self { + let mut iter = variants.iter().map(|v| v.attrs.name.deserialize_name()); + match iter.next() { + Some(first) => { + let mut mix = Self::from_single(first); + for rest in iter { + if mix == VariantMix::Any { + return mix; + } + mix = match (Self::from_single(rest), mix) { + (VariantMix::OnlyStrings, VariantMix::OnlyStrings) => { + VariantMix::OnlyStrings + } + (VariantMix::OnlyBooleans, VariantMix::OnlyBooleans) => { + VariantMix::OnlyBooleans + } + (VariantMix::OnlyIntegers(a), VariantMix::OnlyIntegers(b)) if a == b => { + VariantMix::OnlyIntegers(a) + } + (VariantMix::OnlyIntegers(a), VariantMix::UnknownIntegers) => { + VariantMix::OnlyIntegers(a) + } + (VariantMix::UnknownIntegers, VariantMix::OnlyIntegers(b)) => { + VariantMix::OnlyIntegers(b) + } + _ => VariantMix::Any, + }; + } + if mix == VariantMix::UnknownIntegers { + let negative = variants + .iter() + .map(|v| v.attrs.name.deserialize_name()) + .any(|v| match v { + VariantName::Integer(i) => i.negative, + _ => false, + }); + let max = variants + .iter() + .map(|v| v.attrs.name.deserialize_name()) + .filter_map(|v| match v { + VariantName::Integer(i) => Some(i.magnitude), + _ => None, + }) + .max() + .unwrap(); + match IntRepr::from_integer(Integer { + negative, + magnitude: max, + repr: None, + }) { + Some(repr) => VariantMix::OnlyIntegers(repr), + None => VariantMix::UnknownIntegers, + } + } else { + mix + } + } + + // string variants are the base case as they were the original case + None => Self::OnlyStrings, + } + } +} + +pub trait AsVariant { + fn as_string(&self) -> Option<&str>; + fn as_int(&self) -> Option; + fn as_bool(&self) -> Option; +} + +impl AsVariant for VariantName { + fn as_string(&self) -> Option<&str> { + match self { + VariantName::String(s) => Some(s), + _ => None, + } + } + + fn as_int(&self) -> Option { + match self { + VariantName::Integer(i) => Some(*i), + _ => None, + } + } + + fn as_bool(&self) -> Option { + match self { + VariantName::Boolean(b) => Some(*b), + _ => None, + } + } +} + +impl AsVariant for String { + fn as_string(&self) -> Option<&str> { + Some(self) + } + + fn as_int(&self) -> Option { + None + } + + fn as_bool(&self) -> Option { + None + } +} + +impl ToTokens for VariantName { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.to_variant_string().to_tokens(tokens) + } } -impl Name { +pub struct Names { + serialize: T, + serialize_renamed: bool, + deserialize: T, + deserialize_renamed: bool, + deserialize_aliases: BTreeSet, +} + +impl Names { fn from_attrs( - source_name: String, - ser_name: Attr, - de_name: Attr, - de_aliases: Option>, - ) -> Name { + source_name: T, + ser_name: Attr, + de_name: Attr, + de_aliases: Option>, + ) -> Names { let mut alias_set = BTreeSet::new(); if let Some(de_aliases) = de_aliases { for alias_name in de_aliases.get() { @@ -159,7 +498,7 @@ impl Name { let ser_renamed = ser_name.is_some(); let de_name = de_name.get(); let de_renamed = de_name.is_some(); - Name { + Names { serialize: ser_name.unwrap_or_else(|| source_name.clone()), serialize_renamed: ser_renamed, deserialize: de_name.unwrap_or(source_name), @@ -169,20 +508,27 @@ impl Name { } /// Return the container name for the container when serializing. - pub fn serialize_name(&self) -> &str { + pub fn serialize_name(&self) -> &T { &self.serialize } /// Return the container name for the container when deserializing. - pub fn deserialize_name(&self) -> &str { + pub fn deserialize_name(&self) -> &T { &self.deserialize } - fn deserialize_aliases(&self) -> &BTreeSet { + fn deserialize_aliases(&self) -> &BTreeSet { &self.deserialize_aliases } } +pub type VariantNames = Names; +pub type StringNames = Names; + +fn unraw(ident: &Ident) -> String { + ident.to_string().trim_start_matches("r#").to_owned() +} + #[derive(Copy, Clone)] pub struct RenameAllRules { pub serialize: RenameRule, @@ -202,7 +548,7 @@ impl RenameAllRules { /// Represents struct or enum attribute information. pub struct Container { - name: Name, + name: StringNames, transparent: bool, deny_unknown_fields: bool, default: Default, @@ -566,7 +912,7 @@ impl Container { } Container { - name: Name::from_attrs(unraw(&item.ident), ser_name, de_name, None), + name: StringNames::from_attrs(unraw(&item.ident), ser_name, de_name, None), transparent: transparent.get(), deny_unknown_fields: deny_unknown_fields.get(), default: default.get().unwrap_or(Default::None), @@ -593,7 +939,7 @@ impl Container { } } - pub fn name(&self) -> &Name { + pub fn name(&self) -> &StringNames { &self.name } @@ -780,7 +1126,7 @@ fn decide_identifier( /// Represents variant attribute information pub struct Variant { - name: Name, + name: VariantNames, rename_all_rules: RenameAllRules, ser_bound: Option>, de_bound: Option>, @@ -830,16 +1176,16 @@ impl Variant { if meta.path == RENAME { // #[serde(rename = "foo")] // #[serde(rename(serialize = "foo", deserialize = "bar"))] - let (ser, de) = get_multiple_renames(cx, &meta)?; - ser_name.set_opt(&meta.path, ser.as_ref().map(syn::LitStr::value)); + let (ser, de) = get_multiple_variant_renames(cx, &meta)?; + ser_name.set_opt(&meta.path, ser); for de_value in de { - de_name.set_if_none(de_value.value()); - de_aliases.insert(&meta.path, de_value.value()); + de_name.set_if_none(de_value.clone()); + de_aliases.insert(&meta.path, de_value); } } else if meta.path == ALIAS { // #[serde(alias = "foo")] - if let Some(s) = get_lit_str(cx, ALIAS, &meta)? { - de_aliases.insert(&meta.path, s.value()); + if let Some(name) = get_variant_name(cx, ALIAS, &meta)? { + de_aliases.insert(&meta.path, name); } } else if meta.path == RENAME_ALL { // #[serde(rename_all = "foo")] @@ -946,7 +1292,12 @@ impl Variant { } Variant { - name: Name::from_attrs(unraw(&variant.ident), ser_name, de_name, Some(de_aliases)), + name: VariantNames::from_attrs( + VariantName::String(unraw(&variant.ident)), + ser_name, + de_name, + Some(de_aliases), + ), rename_all_rules: RenameAllRules { serialize: rename_all_ser_rule.get().unwrap_or(RenameRule::None), deserialize: rename_all_de_rule.get().unwrap_or(RenameRule::None), @@ -963,11 +1314,11 @@ impl Variant { } } - pub fn name(&self) -> &Name { + pub fn name(&self) -> &VariantNames { &self.name } - pub fn aliases(&self) -> &BTreeSet { + pub fn aliases(&self) -> &BTreeSet { self.name.deserialize_aliases() } @@ -1022,7 +1373,7 @@ impl Variant { /// Represents field attribute information pub struct Field { - name: Name, + name: StringNames, skip_serializing: bool, skip_deserializing: bool, skip_serializing_if: Option, @@ -1289,7 +1640,7 @@ impl Field { } Field { - name: Name::from_attrs(ident, ser_name, de_name, Some(de_aliases)), + name: StringNames::from_attrs(ident, ser_name, de_name, Some(de_aliases)), skip_serializing: skip_serializing.get(), skip_deserializing: skip_deserializing.get(), skip_serializing_if: skip_serializing_if.get(), @@ -1305,7 +1656,7 @@ impl Field { } } - pub fn name(&self) -> &Name { + pub fn name(&self) -> &StringNames { &self.name } @@ -1434,6 +1785,14 @@ fn get_renames( Ok((ser.at_most_one(), de.at_most_one())) } +fn get_multiple_variant_renames( + cx: &Ctxt, + meta: &ParseNestedMeta, +) -> syn::Result<(Option, Vec)> { + let (ser, de) = get_ser_and_de(cx, RENAME, meta, get_variant_name2)?; + Ok((ser.at_most_one(), de.get())) +} + fn get_multiple_renames( cx: &Ctxt, meta: &ParseNestedMeta, @@ -1458,21 +1817,11 @@ fn get_lit_str( get_lit_str2(cx, attr_name, attr_name, meta) } -fn get_lit_str2( - cx: &Ctxt, - attr_name: Symbol, - meta_item_name: Symbol, - meta: &ParseNestedMeta, -) -> syn::Result> { - let expr: syn::Expr = meta.value()?.parse()?; - let mut value = &expr; - while let syn::Expr::Group(e) = value { - value = &e.expr; - } +fn try_get_lit_str<'a>(cx: &Ctxt, lit: &'a syn::Expr) -> Result<&'a syn::LitStr, ()> { if let syn::Expr::Lit(syn::ExprLit { lit: syn::Lit::Str(lit), .. - }) = value + }) = lit { let suffix = lit.suffix(); if !suffix.is_empty() { @@ -1481,6 +1830,196 @@ fn get_lit_str2( format!("unexpected suffix `{}` on string literal", suffix), ); } + Ok(lit) + } else { + Err(()) + } +} + +fn try_get_lit_int(lit: &syn::Expr) -> Result<(bool, &syn::LitInt), ()> { + if let syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Int(lit), + .. + }) = lit + { + Ok((false, lit)) + } else if let syn::Expr::Unary(syn::ExprUnary { + op: syn::UnOp::Neg(_), + expr, + .. + }) = lit + { + if let syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Int(lit), + .. + }) = &**expr + { + Ok((true, lit)) + } else { + Err(()) + } + } else { + Err(()) + } +} + +fn try_get_lit_bool(lit: &syn::Expr) -> Result<&syn::LitBool, ()> { + if let syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Bool(lit), + .. + }) = lit + { + Ok(lit) + } else { + Err(()) + } +} + +fn get_variant_name( + cx: &Ctxt, + attr_name: Symbol, + meta: &ParseNestedMeta, +) -> syn::Result> { + get_variant_name2(cx, attr_name, attr_name, meta) +} + +fn get_variant_name2( + cx: &Ctxt, + attr_name: Symbol, + meta_item_name: Symbol, + meta: &ParseNestedMeta, +) -> syn::Result> { + let expr: syn::Expr = meta.value()?.parse()?; + let mut value = &expr; + while let syn::Expr::Group(e) = value { + value = &e.expr; + } + if let Ok(lit) = try_get_lit_str(cx, value) { + Ok(Some(VariantName::String(lit.value()))) + } else if let Ok((negative, lit)) = try_get_lit_int(value) { + let integer = match (negative, lit.suffix()) { + (false, "u8") => { + let parse = lit.base10_parse(); + parse.ok().map(|parse: u8| Integer { + negative: false, + magnitude: u64::from(parse), + repr: Some(IntRepr::U8), + }) + } + (false, "u16") => { + let parse = lit.base10_parse(); + parse.ok().map(|parse: u16| Integer { + negative: false, + magnitude: u64::from(parse), + repr: Some(IntRepr::U16), + }) + } + (false, "u32") => { + let parse = lit.base10_parse(); + parse.ok().map(|parse: u32| Integer { + negative: false, + magnitude: u64::from(parse), + repr: Some(IntRepr::U32), + }) + } + (false, "u64") => { + let parse = lit.base10_parse(); + parse.ok().map(|parse: u64| Integer { + negative: false, + magnitude: parse, + repr: Some(IntRepr::U64), + }) + } + (_, "i8") => { + let parse = lit.base10_parse(); + parse.ok().map(|parse: i8| Integer { + negative, + magnitude: u64::from(parse as u8), + repr: Some(IntRepr::I8), + }) + } + (_, "i16") => { + let parse = lit.base10_parse(); + parse.ok().map(|parse: i16| Integer { + negative, + magnitude: u64::from(parse as u16), + repr: Some(IntRepr::I16), + }) + } + (_, "i32") => { + let parse = lit.base10_parse(); + parse.ok().map(|parse: i32| Integer { + negative, + magnitude: u64::from(parse as u32), + repr: Some(IntRepr::I32), + }) + } + (_, "i64") => { + let parse = lit.base10_parse(); + parse.ok().map(|parse: i64| Integer { + negative, + magnitude: parse as u64, + repr: Some(IntRepr::I64), + }) + } + (_, "") => { + let parse = lit.base10_parse(); + parse.ok().map(|parse: i64| Integer { + negative, + magnitude: parse as u64, + repr: None, + }) + } + (true, "u8" | "u16" | "u32" | "u64") => None, + (_, suffix) => { + cx.error_spanned_by( + lit, + format!( + "serde {} attribute has an integer value of unsupported integer type {}", + attr_name, suffix, + ), + ); + return Ok(None); + } + }; + match integer { + Some(integer) => Ok(Some(VariantName::Integer(integer))), + None => { + let suffix = match lit.suffix() { + "" => "i64", + suffix => suffix, + }; + cx.error_spanned_by( + lit, + format!( + "serde {} attribute has an integer value that cannot be represented as type {}", + attr_name, + suffix, + ), + ); + Ok(None) + } + } + } else if let Ok(lit) = try_get_lit_bool(value) { + Ok(Some(VariantName::Boolean(lit.value()))) + } else { + cx.error_spanned_by(expr, format!("expected serde {} attribute to be a string, integer, or boolean literal: `{} = \"...\"`", attr_name, meta_item_name)); + Ok(None) + } +} + +fn get_lit_str2( + cx: &Ctxt, + attr_name: Symbol, + meta_item_name: Symbol, + meta: &ParseNestedMeta, +) -> syn::Result> { + let expr: syn::Expr = meta.value()?.parse()?; + let mut value = &expr; + while let syn::Expr::Group(e) = value { + value = &e.expr; + } + if let Ok(lit) = try_get_lit_str(cx, value) { Ok(Some(lit.clone())) } else { cx.error_spanned_by( diff --git a/serde_derive/src/internals/case.rs b/serde_derive/src/internals/case.rs index 8c8c02e75..2c65e9df1 100644 --- a/serde_derive/src/internals/case.rs +++ b/serde_derive/src/internals/case.rs @@ -4,6 +4,8 @@ use self::RenameRule::*; use std::fmt::{self, Debug, Display}; +use super::attr::VariantName; + /// The different possible ways to change case of fields in a struct, or variants in an enum. #[derive(Copy, Clone, PartialEq)] pub enum RenameRule { @@ -54,7 +56,7 @@ impl RenameRule { } /// Apply a renaming rule to an enum variant, returning the version expected in the source. - pub fn apply_to_variant(self, variant: &str) -> String { + fn apply_to_variant_str(self, variant: &str) -> String { match self { None | PascalCase => variant.to_owned(), LowerCase => variant.to_ascii_lowercase(), @@ -70,14 +72,24 @@ impl RenameRule { } snake } - ScreamingSnakeCase => SnakeCase.apply_to_variant(variant).to_ascii_uppercase(), - KebabCase => SnakeCase.apply_to_variant(variant).replace('_', "-"), + ScreamingSnakeCase => SnakeCase.apply_to_variant_str(variant).to_ascii_uppercase(), + KebabCase => SnakeCase.apply_to_variant_str(variant).replace('_', "-"), ScreamingKebabCase => ScreamingSnakeCase - .apply_to_variant(variant) + .apply_to_variant_str(variant) .replace('_', "-"), } } + /// Apply a renaming rule to an enum variant, returning the version expected in the source. + pub fn apply_to_variant(&self, variant: &VariantName) -> VariantName { + match variant { + VariantName::String(variant) => { + VariantName::String(self.apply_to_variant_str(&variant)) + } + _ => variant.clone(), + } + } + /// Apply a renaming rule to a struct field, returning the version expected in the source. pub fn apply_to_field(self, field: &str) -> String { match self { @@ -155,16 +167,16 @@ fn rename_variants() { ("A", "a", "A", "a", "a", "A", "a", "A"), ("Z42", "z42", "Z42", "z42", "z42", "Z42", "z42", "Z42"), ] { - assert_eq!(None.apply_to_variant(original), original); - assert_eq!(LowerCase.apply_to_variant(original), lower); - assert_eq!(UpperCase.apply_to_variant(original), upper); - assert_eq!(PascalCase.apply_to_variant(original), original); - assert_eq!(CamelCase.apply_to_variant(original), camel); - assert_eq!(SnakeCase.apply_to_variant(original), snake); - assert_eq!(ScreamingSnakeCase.apply_to_variant(original), screaming); - assert_eq!(KebabCase.apply_to_variant(original), kebab); + assert_eq!(None.apply_to_variant_str(original), original); + assert_eq!(LowerCase.apply_to_variant_str(original), lower); + assert_eq!(UpperCase.apply_to_variant_str(original), upper); + assert_eq!(PascalCase.apply_to_variant_str(original), original); + assert_eq!(CamelCase.apply_to_variant_str(original), camel); + assert_eq!(SnakeCase.apply_to_variant_str(original), snake); + assert_eq!(ScreamingSnakeCase.apply_to_variant_str(original), screaming); + assert_eq!(KebabCase.apply_to_variant_str(original), kebab); assert_eq!( - ScreamingKebabCase.apply_to_variant(original), + ScreamingKebabCase.apply_to_variant_str(original), screaming_kebab ); } diff --git a/serde_derive/src/internals/check.rs b/serde_derive/src/internals/check.rs index 52b0f379f..84d016cc8 100644 --- a/serde_derive/src/internals/check.rs +++ b/serde_derive/src/internals/check.rs @@ -1,6 +1,6 @@ use crate::internals::ast::{Container, Data, Field, Style}; -use crate::internals::attr::{Default, Identifier, TagType}; -use crate::internals::{ungroup, Ctxt, Derive}; +use crate::internals::attr::{Default, Identifier, TagType, VariantName}; +use crate::internals::{symbol, ungroup, Ctxt, Derive}; use syn::{Member, Type}; // Cross-cutting checks that require looking at more than a single attrs object. @@ -16,6 +16,7 @@ pub fn check(cx: &Ctxt, cont: &mut Container, derive: Derive) { check_adjacent_tag_conflict(cx, cont); check_transparent(cx, cont, derive); check_from_and_try_from(cx, cont); + check_non_string_renames(cx, cont); } // If some field of a tuple struct is marked #[serde(default)] then all fields @@ -445,6 +446,68 @@ fn check_transparent(cx: &Ctxt, cont: &mut Container, derive: Derive) { } } +/// Externally tagged/untagged enum variants must have string names. +fn check_non_string_renames(cx: &Ctxt, cont: &mut Container) { + let (details, variants) = match &cont.data { + Data::Enum(variants) => match cont.attrs.tag() { + TagType::Adjacent { .. } | TagType::Internal { .. } => return, + TagType::External => ("externally tagged enums", variants), + TagType::None => ("untagged enums", variants), + }, + Data::Struct(_, _) => return, + }; + + for v in variants { + let name = v.attrs.name(); + let ser_name = name.serialize_name(); + let de_name = name.deserialize_name(); + + let attr = v + .original + .attrs + .iter() + .filter_map(|attr| { + if attr.path() != symbol::SERDE { + None? + } + + let meta: syn::Meta = attr.parse_args().ok()?; + if meta.path() != symbol::RENAME { + None? + } + + Some(meta) + }) + .next(); + + match ser_name { + VariantName::String(_) => {} + _ => cx.error_spanned_by( + attr.as_ref().unwrap(), + format!("#[serde(rename)] must use a string name in {}", details), + ), + } + + match de_name { + VariantName::String(_) => {} + _ => cx.error_spanned_by( + attr.as_ref().unwrap(), + format!("#[serde(rename)] must use a string name in {}", details), + ), + } + + for alias in v.attrs.aliases() { + match alias { + VariantName::String(_) => {} + _ => cx.error_spanned_by( + attr.as_ref().unwrap(), + format!("#[serde(rename)] must use a string name in {}", details), + ), + } + } + } +} + fn member_message(member: &Member) -> String { match member { Member::Named(ident) => format!("`{}`", ident), diff --git a/serde_derive/src/ser.rs b/serde_derive/src/ser.rs index 35f8ca4bd..0cace45a0 100644 --- a/serde_derive/src/ser.rs +++ b/serde_derive/src/ser.rs @@ -1,6 +1,7 @@ use crate::fragment::{Fragment, Match, Stmts}; use crate::internals::ast::{Container, Data, Field, Style, Variant}; -use crate::internals::{attr, replace_receiver, Ctxt, Derive}; +use crate::internals::attr::{self, IntRepr, VariantMix, VariantName}; +use crate::internals::{replace_receiver, Ctxt, Derive}; use crate::{bound, dummy, pretend, this}; use proc_macro2::{Span, TokenStream}; use quote::{quote, quote_spanned}; @@ -170,7 +171,10 @@ fn serialize_body(cont: &Container, params: &Parameters) -> Fragment { serialize_into(params, type_into) } else { match &cont.data { - Data::Enum(variants) => serialize_enum(params, variants, &cont.attrs), + Data::Enum(variants) => { + let mix = VariantMix::from_ser(variants); + serialize_enum(mix, params, variants, &cont.attrs) + } Data::Struct(Style::Struct, fields) => serialize_struct(params, fields, &cont.attrs), Data::Struct(Style::Tuple, fields) => { serialize_tuple_struct(params, fields, &cont.attrs) @@ -387,7 +391,12 @@ fn serialize_struct_as_map( } } -fn serialize_enum(params: &Parameters, variants: &[Variant], cattrs: &attr::Container) -> Fragment { +fn serialize_enum( + mix: VariantMix, + params: &Parameters, + variants: &[Variant], + cattrs: &attr::Container, +) -> Fragment { assert!(variants.len() as u64 <= u64::from(u32::MAX)); let self_var = ¶ms.self_var; @@ -396,7 +405,7 @@ fn serialize_enum(params: &Parameters, variants: &[Variant], cattrs: &attr::Cont .iter() .enumerate() .map(|(variant_index, variant)| { - serialize_variant(params, variant, variant_index as u32, cattrs) + serialize_variant(mix, params, variant, variant_index as u32, cattrs) }) .collect(); @@ -414,6 +423,7 @@ fn serialize_enum(params: &Parameters, variants: &[Variant], cattrs: &attr::Cont } fn serialize_variant( + mix: VariantMix, params: &Parameters, variant: &Variant, variant_index: u32, @@ -472,10 +482,11 @@ fn serialize_variant( serialize_externally_tagged_variant(params, variant, variant_index, cattrs) } (attr::TagType::Internal { tag }, false) => { - serialize_internally_tagged_variant(params, variant, cattrs, tag) + serialize_internally_tagged_variant(mix, params, variant, cattrs, tag) } (attr::TagType::Adjacent { tag, content }, false) => { serialize_adjacently_tagged_variant( + mix, params, variant, cattrs, @@ -502,7 +513,13 @@ fn serialize_externally_tagged_variant( cattrs: &attr::Container, ) -> Fragment { let type_name = cattrs.name().serialize_name(); - let variant_name = variant.attrs.name().serialize_name(); + + let variant_name = match variant.attrs.name().serialize_name() { + VariantName::String(name) => name, + // An externally tagged variant with a non-string name will fail to + // check, so this branch is unreachable. + _ => unreachable!(), + }; if let Some(path) = variant.attrs.serialize_with() { let ser = wrap_serialize_variant_with(params, path, variant); @@ -557,6 +574,7 @@ fn serialize_externally_tagged_variant( &variant.fields, ), Style::Struct => serialize_struct_variant( + VariantMix::OnlyStrings, StructVariant::ExternallyTagged { variant_index, variant_name, @@ -569,6 +587,7 @@ fn serialize_externally_tagged_variant( } fn serialize_internally_tagged_variant( + mix: VariantMix, params: &Parameters, variant: &Variant, cattrs: &attr::Container, @@ -582,6 +601,7 @@ fn serialize_internally_tagged_variant( if let Some(path) = variant.attrs.serialize_with() { let ser = wrap_serialize_variant_with(params, path, variant); + let variant_name = serialize_variant_name(mix, &variant_name); return quote_expr! { _serde::__private::ser::serialize_tagged_newtype( __serializer, @@ -596,11 +616,12 @@ fn serialize_internally_tagged_variant( match effective_style(variant) { Style::Unit => { + let variant_name = variant_name.to_literal(mix); quote_block! { let mut __struct = _serde::Serializer::serialize_struct( __serializer, #type_name, 1)?; _serde::ser::SerializeStruct::serialize_field( - &mut __struct, #tag, #variant_name)?; + &mut __struct, #tag, &#variant_name)?; _serde::ser::SerializeStruct::end(__struct) } } @@ -613,6 +634,7 @@ fn serialize_internally_tagged_variant( let span = field.original.span(); let func = quote_spanned!(span=> _serde::__private::ser::serialize_tagged_newtype); + let variant_name = serialize_variant_name(mix, &variant_name); quote_expr! { #func( __serializer, @@ -625,6 +647,7 @@ fn serialize_internally_tagged_variant( } } Style::Struct => serialize_struct_variant( + mix, StructVariant::InternallyTagged { tag, variant_name }, params, &variant.fields, @@ -635,6 +658,7 @@ fn serialize_internally_tagged_variant( } fn serialize_adjacently_tagged_variant( + mix: VariantMix, params: &Parameters, variant: &Variant, cattrs: &attr::Container, @@ -645,12 +669,19 @@ fn serialize_adjacently_tagged_variant( let this_type = ¶ms.this_type; let type_name = cattrs.name().serialize_name(); let variant_name = variant.attrs.name().serialize_name(); - let serialize_variant = quote! { - &_serde::__private::ser::AdjacentlyTaggedEnumVariant { - enum_name: #type_name, - variant_index: #variant_index, - variant_name: #variant_name, + + let serialize_variant = match mix { + // if these are all strings, then we can serialize by variant index + VariantMix::OnlyStrings => { + quote! { + _serde::__private::ser::AdjacentlyTaggedEnumVariant { + enum_name: #type_name, + variant_index: #variant_index, + variant_name: #variant_name, + } + } } + _ => variant_name.to_literal(mix), }; let inner = Stmts(if let Some(path) = variant.attrs.serialize_with() { @@ -665,7 +696,7 @@ fn serialize_adjacently_tagged_variant( let mut __struct = _serde::Serializer::serialize_struct( __serializer, #type_name, 1)?; _serde::ser::SerializeStruct::serialize_field( - &mut __struct, #tag, #serialize_variant)?; + &mut __struct, #tag, &#serialize_variant)?; _serde::ser::SerializeStruct::end(__struct) }; } @@ -682,7 +713,7 @@ fn serialize_adjacently_tagged_variant( let mut __struct = _serde::Serializer::serialize_struct( __serializer, #type_name, 2)?; _serde::ser::SerializeStruct::serialize_field( - &mut __struct, #tag, #serialize_variant)?; + &mut __struct, #tag, &#serialize_variant)?; #func( &mut __struct, #content, #field_expr)?; _serde::ser::SerializeStruct::end(__struct) @@ -692,10 +723,11 @@ fn serialize_adjacently_tagged_variant( serialize_tuple_variant(TupleVariant::Untagged, params, &variant.fields) } Style::Struct => serialize_struct_variant( + mix, StructVariant::Untagged, params, &variant.fields, - variant_name, + &variant_name.to_variant_string(), ), } }); @@ -747,7 +779,7 @@ fn serialize_adjacently_tagged_variant( let mut __struct = _serde::Serializer::serialize_struct( __serializer, #type_name, 2)?; _serde::ser::SerializeStruct::serialize_field( - &mut __struct, #tag, #serialize_variant)?; + &mut __struct, #tag, &#serialize_variant)?; _serde::ser::SerializeStruct::serialize_field( &mut __struct, #content, &__AdjacentlyTagged { data: (#(#fields_ident,)*), @@ -791,7 +823,13 @@ fn serialize_untagged_variant( Style::Tuple => serialize_tuple_variant(TupleVariant::Untagged, params, &variant.fields), Style::Struct => { let type_name = cattrs.name().serialize_name(); - serialize_struct_variant(StructVariant::Untagged, params, &variant.fields, type_name) + serialize_struct_variant( + VariantMix::Any, + StructVariant::Untagged, + params, + &variant.fields, + type_name, + ) } } } @@ -871,19 +909,20 @@ enum StructVariant<'a> { }, InternallyTagged { tag: &'a str, - variant_name: &'a str, + variant_name: &'a VariantName, }, Untagged, } fn serialize_struct_variant( + mix: VariantMix, context: StructVariant, params: &Parameters, fields: &[Field], name: &str, ) -> Fragment { if fields.iter().any(|field| field.attrs.flatten()) { - return serialize_struct_variant_with_flatten(context, params, fields, name); + return serialize_struct_variant_with_flatten(mix, context, params, fields, name); } let struct_trait = match context { @@ -931,6 +970,8 @@ fn serialize_struct_variant( } } StructVariant::InternallyTagged { tag, variant_name } => { + let variant_name = serialize_variant_name(mix, &variant_name); + quote_block! { let mut __serde_state = _serde::Serializer::serialize_struct( __serializer, @@ -940,7 +981,7 @@ fn serialize_struct_variant( _serde::ser::SerializeStruct::serialize_field( &mut __serde_state, #tag, - #variant_name, + &#variant_name, )?; #(#serialize_fields)* _serde::ser::SerializeStruct::end(__serde_state) @@ -961,6 +1002,7 @@ fn serialize_struct_variant( } fn serialize_struct_variant_with_flatten( + mix: VariantMix, context: StructVariant, params: &Parameters, fields: &[Field], @@ -1022,6 +1064,7 @@ fn serialize_struct_variant_with_flatten( } } StructVariant::InternallyTagged { tag, variant_name } => { + let variant_name = serialize_variant_name(mix, &variant_name); quote_block! { let #let_mut __serde_state = _serde::Serializer::serialize_map( __serializer, @@ -1029,7 +1072,7 @@ fn serialize_struct_variant_with_flatten( _serde::ser::SerializeMap::serialize_entry( &mut __serde_state, #tag, - #variant_name, + &#variant_name, )?; #(#serialize_fields)* _serde::ser::SerializeMap::end(__serde_state) @@ -1348,3 +1391,28 @@ impl TupleTrait { } } } + +fn serialize_variant_name(mix: VariantMix, name: &VariantName) -> TokenStream { + match name { + VariantName::String(s) => quote!(_serde::__private::ser::VariantName::String(#s)), + VariantName::Integer(i) => { + let num = name.to_literal(mix); + let repr = match mix { + VariantMix::OnlyIntegers(repr) => repr, + _ => i.repr.unwrap_or(IntRepr::I64), + }; + let variant = match repr { + IntRepr::U8 => quote!(_serde::__private::ser::Integer::U8), + IntRepr::U16 => quote!(_serde::__private::ser::Integer::U16), + IntRepr::U32 => quote!(_serde::__private::ser::Integer::U32), + IntRepr::U64 => quote!(_serde::__private::ser::Integer::U64), + IntRepr::I8 => quote!(_serde::__private::ser::Integer::I8), + IntRepr::I16 => quote!(_serde::__private::ser::Integer::I16), + IntRepr::I32 => quote!(_serde::__private::ser::Integer::I32), + IntRepr::I64 => quote!(_serde::__private::ser::Integer::I64), + }; + quote!(_serde::__private::ser::VariantName::Integer(#variant(#num))) + } + VariantName::Boolean(b) => quote!(_serde::__private::ser::VariantName::Boolean(#b)), + } +} diff --git a/test_suite/tests/test_annotations.rs b/test_suite/tests/test_annotations.rs index 9aa328725..cef7731e7 100644 --- a/test_suite/tests/test_annotations.rs +++ b/test_suite/tests/test_annotations.rs @@ -2065,6 +2065,204 @@ fn test_expecting_message_identifier_enum() { ); } +#[test] +fn test_non_string_renames() { + #[derive(Deserialize, Serialize, PartialEq, Eq, Debug)] + #[serde(tag = "op")] + enum SpecialEnum { + #[serde(rename = -1)] + A, + #[serde(rename = true)] + B, + } + + assert_de_tokens( + &SpecialEnum::A, + &[ + Token::Map { len: None }, + Token::Str("op"), + Token::I64(-1), + Token::MapEnd, + ], + ); + + assert_de_tokens( + &SpecialEnum::B, + &[ + Token::Map { len: None }, + Token::Str("op"), + Token::Bool(true), + Token::MapEnd, + ], + ); + + assert_ser_tokens( + &SpecialEnum::A, + &[ + Token::Struct { + name: "SpecialEnum", + len: 1, + }, + Token::Str("op"), + Token::I64(-1), + Token::StructEnd, + ], + ); + + assert_ser_tokens( + &SpecialEnum::B, + &[ + Token::Struct { + name: "SpecialEnum", + len: 1, + }, + Token::Str("op"), + Token::Bool(true), + Token::StructEnd, + ], + ); + + #[derive(Deserialize, Serialize, PartialEq, Eq, Debug)] + #[serde(tag = "op", content = "d")] + enum AdjacentEnum { + #[serde(rename = -1i64)] + A { a: u64 }, + #[serde(rename = true)] + B, + } + + assert_de_tokens( + &AdjacentEnum::A { a: 1 }, + &[ + Token::Map { len: None }, + Token::Str("op"), + Token::I64(-1), + Token::Str("d"), + Token::Map { len: Some(1) }, + Token::Str("a"), + Token::U64(1), + Token::MapEnd, + Token::MapEnd, + ], + ); + + assert_de_tokens( + &AdjacentEnum::B, + &[ + Token::Map { len: None }, + Token::Str("op"), + Token::Bool(true), + Token::MapEnd, + ], + ); + + assert_ser_tokens( + &AdjacentEnum::A { a: 1 }, + &[ + Token::Struct { + name: "AdjacentEnum", + len: 2, + }, + Token::Str("op"), + Token::I64(-1), + Token::Str("d"), + Token::Struct { + name: "-1i64", + len: 1, + }, + Token::Str("a"), + Token::U64(1), + Token::StructEnd, + Token::StructEnd, + ], + ); + + assert_ser_tokens( + &AdjacentEnum::B, + &[ + Token::Struct { + name: "AdjacentEnum", + len: 1, + }, + Token::Str("op"), + Token::Bool(true), + Token::StructEnd, + ], + ) +} + +#[test] +fn test_non_string_aliases() { + #[derive(Deserialize, Serialize, PartialEq, Eq, Debug)] + #[serde(tag = "op")] + enum AliasedEnum { + #[serde(rename = 1, alias = 2, alias = "foo")] + A, + #[serde(rename = 3, alias = "bar", alias = false)] + B, + } + + assert_de_tokens( + &AliasedEnum::A, + &[ + Token::Map { len: None }, + Token::Str("op"), + Token::I64(1), + Token::MapEnd, + ], + ); + + assert_de_tokens( + &AliasedEnum::A, + &[ + Token::Map { len: None }, + Token::Str("op"), + Token::I64(2), + Token::MapEnd, + ], + ); + + assert_de_tokens( + &AliasedEnum::A, + &[ + Token::Map { len: None }, + Token::Str("op"), + Token::Str("foo"), + Token::MapEnd, + ], + ); + + assert_de_tokens( + &AliasedEnum::B, + &[ + Token::Map { len: None }, + Token::Str("op"), + Token::I64(3), + Token::MapEnd, + ], + ); + + assert_de_tokens( + &AliasedEnum::B, + &[ + Token::Map { len: None }, + Token::Str("op"), + Token::Str("bar"), + Token::MapEnd, + ], + ); + + assert_de_tokens( + &AliasedEnum::B, + &[ + Token::Map { len: None }, + Token::Str("op"), + Token::Bool(false), + Token::MapEnd, + ], + ); +} + mod flatten { use super::*; diff --git a/test_suite/tests/ui/rename/non_string_on_externally_tagged.rs b/test_suite/tests/ui/rename/non_string_on_externally_tagged.rs new file mode 100644 index 000000000..abddc767e --- /dev/null +++ b/test_suite/tests/ui/rename/non_string_on_externally_tagged.rs @@ -0,0 +1,9 @@ +use serde_derive::Serialize; + +#[derive(Serialize)] +enum S { + #[serde(rename = 1)] + A, +} + +fn main() {} diff --git a/test_suite/tests/ui/rename/non_string_on_externally_tagged.stderr b/test_suite/tests/ui/rename/non_string_on_externally_tagged.stderr new file mode 100644 index 000000000..837477a3a --- /dev/null +++ b/test_suite/tests/ui/rename/non_string_on_externally_tagged.stderr @@ -0,0 +1,5 @@ +error: #[serde(rename)] must use a string name in externally tagged enums + --> $DIR/non_string_on_externally_tagged.rs:5:13 + | +5 | #[serde(rename = 1)] + | ^^^^^^^^^^ diff --git a/test_suite/tests/ui/rename/non_string_on_untagged_enum.rs b/test_suite/tests/ui/rename/non_string_on_untagged_enum.rs new file mode 100644 index 000000000..45d77d8bc --- /dev/null +++ b/test_suite/tests/ui/rename/non_string_on_untagged_enum.rs @@ -0,0 +1,10 @@ +use serde_derive::Serialize; + +#[derive(Serialize)] +#[serde(untagged)] +enum S { + #[serde(rename = 1)] + A, +} + +fn main() {} diff --git a/test_suite/tests/ui/rename/non_string_on_untagged_enum.stderr b/test_suite/tests/ui/rename/non_string_on_untagged_enum.stderr new file mode 100644 index 000000000..1bca4c724 --- /dev/null +++ b/test_suite/tests/ui/rename/non_string_on_untagged_enum.stderr @@ -0,0 +1,5 @@ +error: #[serde(rename)] must use a string name in untagged enums + --> $DIR/non_string_on_untagged_enum.rs:6:13 + | +6 | #[serde(rename = 1)] + | ^^^^^^^^^^