diff --git a/borsh-derive/src/internals/attributes/field/mod.rs b/borsh-derive/src/internals/attributes/field/mod.rs index 0ff2555c1..3bc61b395 100644 --- a/borsh-derive/src/internals/attributes/field/mod.rs +++ b/borsh-derive/src/internals/attributes/field/mod.rs @@ -185,37 +185,41 @@ impl Attributes { #[cfg(feature = "schema")] impl Attributes { fn check_schema(&self, attr: &Attribute) -> Result<(), syn::Error> { - if let Some(ref schema) = self.schema { - if self.skip && schema.params.is_some() { - return Err(syn::Error::new_spanned( - attr, - format!( - "`{}` cannot be used at the same time as `{}({})`", - SKIP.0, SCHEMA.0, PARAMS.1 - ), - )); - } + let Some(ref schema) = self.schema else { + return Ok(()); + }; - if self.skip && schema.with_funcs.is_some() { - return Err(syn::Error::new_spanned( - attr, - format!( - "`{}` cannot be used at the same time as `{}({})`", - SKIP.0, SCHEMA.0, WITH_FUNCS.1 - ), - )); - } + if self.skip && schema.params.is_some() { + return Err(syn::Error::new_spanned( + attr, + format!( + "`{}` cannot be used at the same time as `{}({})`", + SKIP.0, SCHEMA.0, PARAMS.1 + ), + )); } + + if self.skip && schema.with_funcs.is_some() { + return Err(syn::Error::new_spanned( + attr, + format!( + "`{}` cannot be used at the same time as `{}({})`", + SKIP.0, SCHEMA.0, WITH_FUNCS.1 + ), + )); + } + Ok(()) } pub(crate) fn needs_schema_params_derive(&self) -> bool { - if let Some(ref schema) = self.schema { - if schema.params.is_some() { - return false; - } - } - true + !matches!( + &self.schema, + Some(schema::Attributes { + params: Some(_), + .. + }) + ) } pub(crate) fn schema_declaration(&self) -> Option { diff --git a/borsh-derive/src/internals/deserialize/mod.rs b/borsh-derive/src/internals/deserialize/mod.rs index f9732c05f..9133b26d6 100644 --- a/borsh-derive/src/internals/deserialize/mod.rs +++ b/borsh-derive/src/internals/deserialize/mod.rs @@ -1,6 +1,6 @@ use proc_macro2::TokenStream as TokenStream2; use quote::quote; -use syn::{ExprPath, Generics, Ident, Path}; +use syn::{parse_quote, ExprPath, Generics, Ident, Path, Type}; use super::{ attributes::{field, BoundType}, @@ -25,13 +25,14 @@ impl GenericsOutput { default_visitor: generics::FindTyParams::new(generics), } } + fn extend(self, where_clause: &mut syn::WhereClause, cratename: &Path) { - let de_trait: Path = syn::parse2(quote! { #cratename::de::BorshDeserialize }).unwrap(); - let default_trait: Path = syn::parse2(quote! { core::default::Default }).unwrap(); + let de_trait: Path = parse_quote! { #cratename::de::BorshDeserialize }; + let default_trait: Path = parse_quote! { ::core::default::Default }; let de_predicates = generics::compute_predicates(self.deserialize_visitor.process_for_bounds(), &de_trait); let default_predicates = - generics::compute_predicates(self.default_visitor.process_for_bounds(), &default_trait); + generics::compute_predicates(self.default_visitor.process_for_bounds(), &default_trait); // FIXME: this is not correct, the `Default` trait should be requested for field types, rather than their type parameters where_clause.predicates.extend(de_predicates); where_clause.predicates.extend(default_predicates); where_clause.predicates.extend(self.overrides); @@ -61,7 +62,7 @@ fn process_field( if needs_bounds_derive { generics.deserialize_visitor.visit_field(field); } - field_output(field_name, cratename, parsed.deserialize_with) + field_output(field_name, &field.ty, cratename, parsed.deserialize_with) }; body.extend(delta); Ok(()) @@ -71,20 +72,18 @@ fn process_field( /// of code, which deserializes single field fn field_output( field_name: Option<&Ident>, + field_type: &Type, cratename: &Path, deserialize_with: Option, ) -> TokenStream2 { - let default_path: ExprPath = - syn::parse2(quote! { #cratename::BorshDeserialize::deserialize_reader }).unwrap(); - let path: ExprPath = deserialize_with.unwrap_or(default_path); + let default_path = || { + parse_quote! { <#field_type as #cratename::BorshDeserialize>::deserialize_reader } + }; + let path: ExprPath = deserialize_with.unwrap_or_else(default_path); if let Some(field_name) = field_name { - quote! { - #field_name: #path(reader)?, - } + quote! { #field_name: #path(reader)?, } } else { - quote! { - #path(reader)?, - } + quote! { #path(reader)?, } } } @@ -92,10 +91,8 @@ fn field_output( /// of code, which deserializes single skipped field fn field_default_output(field_name: Option<&Ident>) -> TokenStream2 { if let Some(field_name) = field_name { - quote! { - #field_name: core::default::Default::default(), - } + quote! { #field_name: ::core::default::Default::default(), } } else { - quote! { core::default::Default::default(), } + quote! { ::core::default::Default::default(), } } } diff --git a/borsh-derive/src/internals/generics.rs b/borsh-derive/src/internals/generics.rs index 7914b3b00..e3cf67d17 100644 --- a/borsh-derive/src/internals/generics.rs +++ b/borsh-derive/src/internals/generics.rs @@ -1,9 +1,10 @@ use std::collections::{HashMap, HashSet}; -use quote::{quote, ToTokens}; +use quote::ToTokens; use syn::{ - punctuated::Pair, Field, GenericArgument, Generics, Ident, Macro, Path, PathArguments, - PathSegment, ReturnType, Type, TypeParamBound, TypePath, WhereClause, WherePredicate, + parse_quote, punctuated::Pair, Field, GenericArgument, Generics, Ident, Macro, Path, + PathArguments, PathSegment, ReturnType, Type, TypeGroup, TypeParamBound, TypePath, WhereClause, + WherePredicate, }; pub fn default_where(where_clause: Option<&WhereClause>) -> WhereClause { @@ -19,12 +20,7 @@ pub fn default_where(where_clause: Option<&WhereClause>) -> WhereClause { pub fn compute_predicates(params: Vec, traitname: &Path) -> Vec { params .into_iter() - .map(|param| { - syn::parse2(quote! { - #param: #traitname - }) - .unwrap() - }) + .map(|param| parse_quote! { #param: #traitname }) .collect() } @@ -32,7 +28,7 @@ pub fn compute_predicates(params: Vec, traitname: &Path) -> Vec Generics { - syn::Generics { + Generics { params: generics .params .iter() @@ -74,8 +70,8 @@ pub struct FindTyParams { } fn ungroup(mut ty: &Type) -> &Type { - while let Type::Group(group) = ty { - ty = &group.elem; + while let Type::Group(TypeGroup { elem, .. }) = ty { + ty = &**elem; } ty } @@ -99,6 +95,7 @@ impl FindTyParams { associated_type_params_usage: HashMap::new(), } } + pub fn process_for_bounds(self) -> Vec { let relevant_type_params = self.relevant_type_params; let associated_type_params_usage = self.associated_type_params_usage;