From 80a59a9c7dfc6e6c1207ec4e1f04690649623ab9 Mon Sep 17 00:00:00 2001 From: Piotr Dulikowski Date: Fri, 20 Oct 2023 10:01:32 +0200 Subject: [PATCH 1/4] scylla-macros: introduce SerializeCql derive macro Introduce a derive macro which serializes a struct into a UDT. Unlike the previous IntoUserType, the new macro takes care to match the struct fields to UDT fields by their names. It does not assume that the order of the fields in the Rust struct is the same as in the UDT. --- Cargo.lock.msrv | 48 ++++ scylla-cql/src/lib.rs | 12 + scylla-cql/src/macros.rs | 62 +++++ scylla-cql/src/types/serialize/value.rs | 265 +++++++++++++++++++++- scylla-cql/src/types/serialize/writers.rs | 1 + scylla-macros/Cargo.toml | 1 + scylla-macros/src/lib.rs | 12 + scylla-macros/src/serialize/cql.rs | 224 ++++++++++++++++++ scylla-macros/src/serialize/mod.rs | 1 + scylla/tests/integration/hygiene.rs | 6 + 10 files changed, 630 insertions(+), 2 deletions(-) create mode 100644 scylla-macros/src/serialize/cql.rs create mode 100644 scylla-macros/src/serialize/mod.rs diff --git a/Cargo.lock.msrv b/Cargo.lock.msrv index 10393ac6eb..59c9ee1b56 100644 --- a/Cargo.lock.msrv +++ b/Cargo.lock.msrv @@ -340,6 +340,41 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "darling" +version = "0.20.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0209d94da627ab5605dcccf08bb18afa5009cfbef48d8a8b7d7bdbc79be25c5e" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.20.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "177e3443818124b357d8e76f53be906d60937f0d3a90773a664fa63fa253e621" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 2.0.38", +] + +[[package]] +name = "darling_macro" +version = "0.20.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "836a9bbc7ad63342d6d6e7b815ccab164bc77a2d95d84bc3117a8c0d5c98e2d5" +dependencies = [ + "darling_core", + "quote", + "syn 2.0.38", +] + [[package]] name = "dashmap" version = "5.5.3" @@ -444,6 +479,12 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + [[package]] name = "foreign-types" version = "0.3.2" @@ -651,6 +692,12 @@ dependencies = [ "cc", ] +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + [[package]] name = "idna" version = "0.4.0" @@ -1455,6 +1502,7 @@ dependencies = [ name = "scylla-macros" version = "0.2.0" dependencies = [ + "darling", "proc-macro2", "quote", "syn 2.0.38", diff --git a/scylla-cql/src/lib.rs b/scylla-cql/src/lib.rs index b8d7d28671..ab94470e10 100644 --- a/scylla-cql/src/lib.rs +++ b/scylla-cql/src/lib.rs @@ -20,4 +20,16 @@ pub mod _macro_internal { SerializedResult, SerializedValues, Value, ValueList, ValueTooBig, }; pub use crate::macros::*; + + pub use crate::types::serialize::value::{ + BuiltinSerializationError as BuiltinTypeSerializationError, + BuiltinSerializationErrorKind as BuiltinTypeSerializationErrorKind, + BuiltinTypeCheckError as BuiltinTypeTypeCheckError, + BuiltinTypeCheckErrorKind as BuiltinTypeTypeCheckErrorKind, SerializeCql, + UdtSerializationErrorKind, UdtTypeCheckErrorKind, + }; + pub use crate::types::serialize::writers::WrittenCellProof; + pub use crate::types::serialize::{CellValueBuilder, CellWriter, SerializationError}; + + pub use crate::frame::response::result::ColumnType; } diff --git a/scylla-cql/src/macros.rs b/scylla-cql/src/macros.rs index 8d60312145..56f1f43cf3 100644 --- a/scylla-cql/src/macros.rs +++ b/scylla-cql/src/macros.rs @@ -13,6 +13,68 @@ pub use scylla_macros::IntoUserType; /// #[derive(ValueList)] allows to pass struct as a list of values for a query pub use scylla_macros::ValueList; +/// Derive macro for the [`SerializeCql`](crate::types::serialize::value::SerializeCql) trait +/// which serializes given Rust structure as a User Defined Type (UDT). +/// +/// At the moment, only structs with named fields are supported. The generated +/// implementation of the trait will match the struct fields to UDT fields +/// by name automatically. +/// +/// Serialization will fail if there are some fields in the UDT that don't match +/// to any of the Rust struct fields, _or vice versa_. +/// +/// In case of failure, either [`BuiltinTypeCheckError`](crate::types::serialize::value::BuiltinTypeCheckError) +/// or [`BuiltinSerializationError`](crate::types::serialize::value::BuiltinSerializationError) +/// will be returned. +/// +/// # Example +/// +/// A UDT defined like this: +/// +/// ```notrust +/// CREATE TYPE ks.my_udt (a int, b text, c blob); +/// ``` +/// +/// ...can be serialized using the following struct: +/// +/// ```rust +/// # use scylla_cql::macros::SerializeCql; +/// #[derive(SerializeCql)] +/// # #[scylla(crate = scylla_cql)] +/// struct MyUdt { +/// a: i32, +/// b: Option, +/// c: Vec, +/// } +/// ``` +/// +/// # Attributes +/// +/// `#[scylla(crate = crate_name)]` +/// +/// By default, the code generated by the derive macro will refer to the items +/// defined by the driver (types, traits, etc.) via the `::scylla` path. +/// For example, it will refer to the [`SerializeCql`](crate::types::serialize::value::SerializeCql) trait +/// using the following path: +/// +/// ```rust,ignore +/// use ::scylla::_macro_internal::SerializeCql; +/// ``` +/// +/// Most users will simply add `scylla` to their dependencies, then use +/// the derive macro and the path above will work. However, there are some +/// niche cases where this path will _not_ work: +/// +/// - The `scylla` crate is imported under a different name, +/// - The `scylla` crate is _not imported at all_ - the macro actually +/// is defined in the `scylla-macros` crate and the generated code depends +/// on items defined in `scylla-cql`. +/// +/// It's not possible to automatically resolve those issues in the procedural +/// macro itself, so in those cases the user must provide an alternative path +/// to either the `scylla` or `scylla-cql` crate. +pub use scylla_macros::SerializeCql; + // Reexports for derive(IntoUserType) pub use bytes::{BufMut, Bytes, BytesMut}; diff --git a/scylla-cql/src/types/serialize/value.rs b/scylla-cql/src/types/serialize/value.rs index 37244f7073..85033dac25 100644 --- a/scylla-cql/src/types/serialize/value.rs +++ b/scylla-cql/src/types/serialize/value.rs @@ -1309,6 +1309,9 @@ pub enum UdtTypeCheckErrorKind { /// The name of the UDT being serialized to does not match. NameMismatch { keyspace: String, type_name: String }, + /// One of the fields that is required to be present by the Rust struct was not present in the CQL UDT type. + MissingField { field_name: String }, + /// The Rust data contains a field that is not present in the UDT UnexpectedFieldInDestination { field_name: String }, } @@ -1327,6 +1330,9 @@ impl Display for UdtTypeCheckErrorKind { f, "the Rust UDT name does not match the actual CQL UDT name ({keyspace}.{type_name})" ), + UdtTypeCheckErrorKind::MissingField { field_name } => { + write!(f, "the field {field_name} is missing from the CQL UDT type") + } UdtTypeCheckErrorKind::UnexpectedFieldInDestination { field_name } => write!( f, "the field {field_name} present in the Rust data is not present in the CQL type" @@ -1369,11 +1375,17 @@ pub enum ValueToSerializeCqlAdapterError { #[cfg(test)] mod tests { - use crate::frame::response::result::ColumnType; + use crate::frame::response::result::{ColumnType, CqlValue}; use crate::frame::value::{MaybeUnset, Value}; + use crate::types::serialize::value::{ + BuiltinSerializationError, BuiltinSerializationErrorKind, BuiltinTypeCheckError, + BuiltinTypeCheckErrorKind, + }; use crate::types::serialize::CellWriter; - use super::SerializeCql; + use scylla_macros::SerializeCql; + + use super::{SerializeCql, UdtSerializationErrorKind, UdtTypeCheckErrorKind}; fn check_compat(v: V) { let mut legacy_data = Vec::new(); @@ -1407,4 +1419,253 @@ mod tests { assert_eq!(typed_data, erased_data); } + + fn do_serialize(t: T, typ: &ColumnType) -> Vec { + let mut ret = Vec::new(); + let writer = CellWriter::new(&mut ret); + t.serialize(typ, writer).unwrap(); + ret + } + + // Do not remove. It's not used in tests but we keep it here to check that + // we properly ignore warnings about unused variables, unnecessary `mut`s + // etc. that usually pop up when generating code for empty structs. + #[derive(SerializeCql)] + #[scylla(crate = crate)] + struct TestUdtWithNoFields {} + + #[derive(SerializeCql, Debug, PartialEq, Eq, Default)] + #[scylla(crate = crate)] + struct TestUdtWithFieldSorting { + a: String, + b: i32, + c: Vec, + } + + #[test] + fn test_udt_serialization_with_field_sorting_correct_order() { + let typ = ColumnType::UserDefinedType { + type_name: "typ".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + ("a".to_string(), ColumnType::Text), + ("b".to_string(), ColumnType::Int), + ( + "c".to_string(), + ColumnType::List(Box::new(ColumnType::BigInt)), + ), + ], + }; + + let reference = do_serialize( + CqlValue::UserDefinedType { + keyspace: "ks".to_string(), + type_name: "typ".to_string(), + fields: vec![ + ( + "a".to_string(), + Some(CqlValue::Text(String::from("Ala ma kota"))), + ), + ("b".to_string(), Some(CqlValue::Int(42))), + ( + "c".to_string(), + Some(CqlValue::List(vec![ + CqlValue::BigInt(1), + CqlValue::BigInt(2), + CqlValue::BigInt(3), + ])), + ), + ], + }, + &typ, + ); + let udt = do_serialize( + TestUdtWithFieldSorting { + a: "Ala ma kota".to_owned(), + b: 42, + c: vec![1, 2, 3], + }, + &typ, + ); + + assert_eq!(reference, udt); + } + + #[test] + fn test_udt_serialization_with_field_sorting_incorrect_order() { + let typ = ColumnType::UserDefinedType { + type_name: "typ".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + // Two first columns are swapped + ("b".to_string(), ColumnType::Int), + ("a".to_string(), ColumnType::Text), + ( + "c".to_string(), + ColumnType::List(Box::new(ColumnType::BigInt)), + ), + ], + }; + + let reference = do_serialize( + CqlValue::UserDefinedType { + keyspace: "ks".to_string(), + type_name: "typ".to_string(), + fields: vec![ + // FIXME: UDTs in CqlValue should also honor the order + // For now, it's swapped here as well + ("b".to_string(), Some(CqlValue::Int(42))), + ( + "a".to_string(), + Some(CqlValue::Text(String::from("Ala ma kota"))), + ), + ( + "c".to_string(), + Some(CqlValue::List(vec![ + CqlValue::BigInt(1), + CqlValue::BigInt(2), + CqlValue::BigInt(3), + ])), + ), + ], + }, + &typ, + ); + let udt = do_serialize( + TestUdtWithFieldSorting { + a: "Ala ma kota".to_owned(), + b: 42, + c: vec![1, 2, 3], + }, + &typ, + ); + + assert_eq!(reference, udt); + } + + #[test] + fn test_udt_serialization_failing_type_check() { + let typ_not_udt = ColumnType::Ascii; + let udt = TestUdtWithFieldSorting::default(); + let mut data = Vec::new(); + + let err = udt + .serialize(&typ_not_udt, CellWriter::new(&mut data)) + .unwrap_err(); + let err = err.0.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::NotUdt) + )); + + let typ_without_c = ColumnType::UserDefinedType { + type_name: "typ".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + ("a".to_string(), ColumnType::Text), + ("b".to_string(), ColumnType::Int), + // Last field is missing + ], + }; + + let err = udt + .serialize(&typ_without_c, CellWriter::new(&mut data)) + .unwrap_err(); + let err = err.0.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::MissingField { .. }) + )); + + let typ_unexpected_field = ColumnType::UserDefinedType { + type_name: "typ".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + ("a".to_string(), ColumnType::Text), + ("b".to_string(), ColumnType::Int), + ( + "c".to_string(), + ColumnType::List(Box::new(ColumnType::BigInt)), + ), + // Unexpected field + ("d".to_string(), ColumnType::Counter), + ], + }; + + let err = udt + .serialize(&typ_unexpected_field, CellWriter::new(&mut data)) + .unwrap_err(); + let err = err.0.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinTypeCheckErrorKind::UdtError( + UdtTypeCheckErrorKind::UnexpectedFieldInDestination { .. } + ) + )); + + let typ_wrong_type = ColumnType::UserDefinedType { + type_name: "typ".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + ("a".to_string(), ColumnType::Text), + ("b".to_string(), ColumnType::Int), + ("c".to_string(), ColumnType::TinyInt), // Wrong column type + ], + }; + + let err = udt + .serialize(&typ_wrong_type, CellWriter::new(&mut data)) + .unwrap_err(); + let err = err.0.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinSerializationErrorKind::UdtError( + UdtSerializationErrorKind::FieldSerializationFailed { .. } + ) + )); + } + + #[derive(SerializeCql)] + #[scylla(crate = crate)] + struct TestUdtWithGenerics<'a, T: SerializeCql> { + a: &'a str, + b: T, + } + + #[test] + fn test_udt_serialization_with_generics() { + // A minimal smoke test just to test that it works. + fn check_with_type(typ: ColumnType, t: T, cql_t: CqlValue) { + let typ = ColumnType::UserDefinedType { + type_name: "typ".to_string(), + keyspace: "ks".to_string(), + field_types: vec![("a".to_string(), ColumnType::Text), ("b".to_string(), typ)], + }; + let reference = do_serialize( + CqlValue::UserDefinedType { + keyspace: "ks".to_string(), + type_name: "typ".to_string(), + fields: vec![ + ( + "a".to_string(), + Some(CqlValue::Text(String::from("Ala ma kota"))), + ), + ("b".to_string(), Some(cql_t)), + ], + }, + &typ, + ); + let udt = do_serialize( + TestUdtWithGenerics { + a: "Ala ma kota", + b: t, + }, + &typ, + ); + assert_eq!(reference, udt); + } + + check_with_type(ColumnType::Int, 123_i32, CqlValue::Int(123_i32)); + check_with_type(ColumnType::Double, 123_f64, CqlValue::Double(123_f64)); + } } diff --git a/scylla-cql/src/types/serialize/writers.rs b/scylla-cql/src/types/serialize/writers.rs index 4d350adc75..9b2be47998 100644 --- a/scylla-cql/src/types/serialize/writers.rs +++ b/scylla-cql/src/types/serialize/writers.rs @@ -183,6 +183,7 @@ impl<'buf> CellValueBuilder<'buf> { /// [`SerializeCql::serialize`](super::value::SerializeCql::serialize): either /// the method succeeds and returns a proof that it serialized itself /// into the given value, or it fails and returns an error or panics. +#[derive(Debug)] pub struct WrittenCellProof<'buf> { /// Using *mut &'buf () is deliberate and makes WrittenCellProof invariant /// on the 'buf lifetime parameter. diff --git a/scylla-macros/Cargo.toml b/scylla-macros/Cargo.toml index d39bd58116..ac5f5d16f1 100644 --- a/scylla-macros/Cargo.toml +++ b/scylla-macros/Cargo.toml @@ -12,6 +12,7 @@ license = "MIT OR Apache-2.0" proc-macro = true [dependencies] +darling = "0.20.0" syn = "2.0" quote = "1.0" proc-macro2 = "1.0" \ No newline at end of file diff --git a/scylla-macros/src/lib.rs b/scylla-macros/src/lib.rs index 59300a0020..84ee58bca0 100644 --- a/scylla-macros/src/lib.rs +++ b/scylla-macros/src/lib.rs @@ -1,4 +1,5 @@ use proc_macro::TokenStream; +use quote::ToTokens; mod from_row; mod from_user_type; @@ -6,6 +7,17 @@ mod into_user_type; mod parser; mod value_list; +mod serialize; + +/// See the documentation for this item in the `scylla` crate. +#[proc_macro_derive(SerializeCql, attributes(scylla))] +pub fn serialize_cql_derive(tokens_input: TokenStream) -> TokenStream { + match serialize::cql::derive_serialize_cql(tokens_input) { + Ok(t) => t.into_token_stream().into(), + Err(e) => e.into_compile_error().into(), + } +} + /// #[derive(FromRow)] derives FromRow for struct /// Works only on simple structs without generics etc #[proc_macro_derive(FromRow, attributes(scylla_crate))] diff --git a/scylla-macros/src/serialize/cql.rs b/scylla-macros/src/serialize/cql.rs new file mode 100644 index 0000000000..f19e47b27c --- /dev/null +++ b/scylla-macros/src/serialize/cql.rs @@ -0,0 +1,224 @@ +use darling::FromAttributes; +use proc_macro::TokenStream; +use proc_macro2::Span; +use syn::parse_quote; + +#[derive(FromAttributes)] +#[darling(attributes(scylla))] +struct Attributes { + #[darling(rename = "crate")] + crate_path: Option, +} + +impl Attributes { + fn crate_path(&self) -> syn::Path { + self.crate_path + .as_ref() + .map(|p| parse_quote!(#p::_macro_internal)) + .unwrap_or_else(|| parse_quote!(::scylla::_macro_internal)) + } +} + +struct Context { + attributes: Attributes, + fields: Vec, +} + +pub fn derive_serialize_cql(tokens_input: TokenStream) -> Result { + let input: syn::DeriveInput = syn::parse(tokens_input)?; + let struct_name = input.ident.clone(); + let named_fields = crate::parser::parse_named_fields(&input, "SerializeCql")?; + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + let attributes = Attributes::from_attributes(&input.attrs)?; + + let crate_path = attributes.crate_path(); + let implemented_trait: syn::Path = parse_quote!(#crate_path::SerializeCql); + + let fields = named_fields.named.iter().cloned().collect(); + let ctx = Context { attributes, fields }; + let gen = FieldSortingGenerator { ctx: &ctx }; + + let serialize_item = gen.generate_serialize(); + + let res = parse_quote! { + impl #impl_generics #implemented_trait for #struct_name #ty_generics #where_clause { + #serialize_item + } + }; + Ok(res) +} + +impl Context { + fn generate_udt_type_match(&self, err: syn::Expr) -> syn::Stmt { + let crate_path = self.attributes.crate_path(); + + parse_quote! { + let (type_name, keyspace, field_types) = match typ { + #crate_path::ColumnType::UserDefinedType { type_name, keyspace, field_types, .. } => { + (type_name, keyspace, field_types) + } + _ => return ::std::result::Result::Err(mk_typck_err(#err)), + }; + } + } + + fn generate_mk_typck_err(&self) -> syn::Stmt { + let crate_path = self.attributes.crate_path(); + parse_quote! { + let mk_typck_err = |kind: #crate_path::UdtTypeCheckErrorKind| -> #crate_path::SerializationError { + #crate_path::SerializationError::new( + #crate_path::BuiltinTypeTypeCheckError { + rust_name: ::std::any::type_name::(), + got: <_ as ::std::clone::Clone>::clone(typ), + kind: #crate_path::BuiltinTypeTypeCheckErrorKind::UdtError(kind), + } + ) + }; + } + } + + fn generate_mk_ser_err(&self) -> syn::Stmt { + let crate_path = self.attributes.crate_path(); + parse_quote! { + let mk_ser_err = |kind: #crate_path::UdtSerializationErrorKind| -> #crate_path::SerializationError { + #crate_path::SerializationError::new( + #crate_path::BuiltinTypeSerializationError { + rust_name: ::std::any::type_name::(), + got: <_ as ::std::clone::Clone>::clone(typ), + kind: #crate_path::BuiltinTypeSerializationErrorKind::UdtError(kind), + } + ) + }; + } + } +} + +// Generates an implementation of the trait which sorts the fields according +// to how it is defined in the database. +struct FieldSortingGenerator<'a> { + ctx: &'a Context, +} + +impl<'a> FieldSortingGenerator<'a> { + fn generate_serialize(&self) -> syn::TraitItemFn { + // Need to: + // - Check that all required fields are there and no more + // - Check that the field types match + let mut statements: Vec = Vec::new(); + + let crate_path = self.ctx.attributes.crate_path(); + + let rust_field_idents = self + .ctx + .fields + .iter() + .map(|f| f.ident.clone()) + .collect::>(); + let rust_field_names = rust_field_idents + .iter() + .map(|i| i.as_ref().unwrap().to_string()) + .collect::>(); + let udt_field_names = rust_field_names.clone(); // For now, it's the same + let field_types = self.ctx.fields.iter().map(|f| &f.ty).collect::>(); + + // Declare helper lambdas for creating errors + statements.push(self.ctx.generate_mk_typck_err()); + statements.push(self.ctx.generate_mk_ser_err()); + + // Check that the type we want to serialize to is a UDT + statements.push( + self.ctx + .generate_udt_type_match(parse_quote!(#crate_path::UdtTypeCheckErrorKind::NotUdt)), + ); + + // Generate a "visited" flag for each field + let visited_flag_names = rust_field_names + .iter() + .map(|s| syn::Ident::new(&format!("visited_flag_{}", s), Span::call_site())) + .collect::>(); + statements.extend::>(parse_quote! { + #(let mut #visited_flag_names = false;)* + }); + + // Generate a variable that counts down visited fields. + let field_count = self.ctx.fields.len(); + statements.push(parse_quote! { + let mut remaining_count = #field_count; + }); + + // Turn the cell writer into a value builder + statements.push(parse_quote! { + let mut builder = #crate_path::CellWriter::into_value_builder(writer); + }); + + // Generate a loop over the fields and a `match` block to match on + // the field name. + statements.push(parse_quote! { + for (field_name, field_type) in field_types { + match ::std::string::String::as_str(field_name) { + #( + #udt_field_names => { + let sub_builder = #crate_path::CellValueBuilder::make_sub_writer(&mut builder); + match <#field_types as #crate_path::SerializeCql>::serialize(&self.#rust_field_idents, field_type, sub_builder) { + ::std::result::Result::Ok(_proof) => {} + ::std::result::Result::Err(err) => { + return ::std::result::Result::Err(mk_ser_err( + #crate_path::UdtSerializationErrorKind::FieldSerializationFailed { + field_name: <_ as ::std::clone::Clone>::clone(field_name), + err, + } + )); + } + } + if !#visited_flag_names { + #visited_flag_names = true; + remaining_count -= 1; + } + } + )* + _ => return ::std::result::Result::Err(mk_typck_err( + #crate_path::UdtTypeCheckErrorKind::UnexpectedFieldInDestination { + field_name: <_ as ::std::clone::Clone>::clone(field_name), + } + )), + } + } + }); + + // Finally, check that all fields were consumed. + // If there are some missing fields, return an error + statements.push(parse_quote! { + if remaining_count > 0 { + #( + if !#visited_flag_names { + return ::std::result::Result::Err(mk_typck_err( + #crate_path::UdtTypeCheckErrorKind::MissingField { + field_name: <_ as ::std::string::ToString>::to_string(#rust_field_names), + } + )); + } + )* + ::std::unreachable!() + } + }); + + parse_quote! { + fn serialize<'b>( + &self, + typ: &#crate_path::ColumnType, + writer: #crate_path::CellWriter<'b>, + ) -> ::std::result::Result<#crate_path::WrittenCellProof<'b>, #crate_path::SerializationError> { + #(#statements)* + let proof = #crate_path::CellValueBuilder::finish(builder) + .map_err(|_| #crate_path::SerializationError::new( + #crate_path::BuiltinTypeSerializationError { + rust_name: ::std::any::type_name::(), + got: <_ as ::std::clone::Clone>::clone(typ), + kind: #crate_path::BuiltinTypeSerializationErrorKind::SizeOverflow, + } + ) as #crate_path::SerializationError)?; + ::std::result::Result::Ok(proof) + } + } + } +} diff --git a/scylla-macros/src/serialize/mod.rs b/scylla-macros/src/serialize/mod.rs new file mode 100644 index 0000000000..15fd9ae87c --- /dev/null +++ b/scylla-macros/src/serialize/mod.rs @@ -0,0 +1 @@ +pub(crate) mod cql; diff --git a/scylla/tests/integration/hygiene.rs b/scylla/tests/integration/hygiene.rs index 6195bb0256..12d55ccb61 100644 --- a/scylla/tests/integration/hygiene.rs +++ b/scylla/tests/integration/hygiene.rs @@ -63,6 +63,12 @@ macro_rules! test_crate { let sv2 = tuple_with_same_layout.serialized().unwrap().into_owned(); assert_eq!(sv, sv2); } + + #[derive(_scylla::macros::SerializeCql)] + #[scylla(crate = _scylla)] + struct TestStructNew { + x: ::core::primitive::i32, + } }; } From 30a69f84199498fc7f1c403d943fc926c9c79cd4 Mon Sep 17 00:00:00 2001 From: Piotr Dulikowski Date: Fri, 20 Oct 2023 10:08:40 +0200 Subject: [PATCH 2/4] scylla-macros: introduce SerializeRow derive macro Introduce a derive macro which serializes a struct into bind markers of a statement. Unlike the previous ValueList, the new macro takes care to match the struct fields to bind markers/columns by their names. --- scylla-cql/src/lib.rs | 11 +- scylla-cql/src/macros.rs | 64 ++++++++ scylla-cql/src/types/serialize/row.rs | 168 ++++++++++++++++++++- scylla-macros/src/lib.rs | 9 ++ scylla-macros/src/serialize/mod.rs | 1 + scylla-macros/src/serialize/row.rs | 202 ++++++++++++++++++++++++++ scylla/tests/integration/hygiene.rs | 2 +- 7 files changed, 454 insertions(+), 3 deletions(-) create mode 100644 scylla-macros/src/serialize/row.rs diff --git a/scylla-cql/src/lib.rs b/scylla-cql/src/lib.rs index ab94470e10..6d74b680ba 100644 --- a/scylla-cql/src/lib.rs +++ b/scylla-cql/src/lib.rs @@ -21,6 +21,13 @@ pub mod _macro_internal { }; pub use crate::macros::*; + pub use crate::types::serialize::row::{ + BuiltinSerializationError as BuiltinRowSerializationError, + BuiltinSerializationErrorKind as BuiltinRowSerializationErrorKind, + BuiltinTypeCheckError as BuiltinRowTypeCheckError, + BuiltinTypeCheckErrorKind as BuiltinRowTypeCheckErrorKind, RowSerializationContext, + SerializeRow, + }; pub use crate::types::serialize::value::{ BuiltinSerializationError as BuiltinTypeSerializationError, BuiltinSerializationErrorKind as BuiltinTypeSerializationErrorKind, @@ -29,7 +36,9 @@ pub mod _macro_internal { UdtSerializationErrorKind, UdtTypeCheckErrorKind, }; pub use crate::types::serialize::writers::WrittenCellProof; - pub use crate::types::serialize::{CellValueBuilder, CellWriter, SerializationError}; + pub use crate::types::serialize::{ + CellValueBuilder, CellWriter, RowWriter, SerializationError, + }; pub use crate::frame::response::result::ColumnType; } diff --git a/scylla-cql/src/macros.rs b/scylla-cql/src/macros.rs index 56f1f43cf3..8f53e24fa9 100644 --- a/scylla-cql/src/macros.rs +++ b/scylla-cql/src/macros.rs @@ -75,6 +75,70 @@ pub use scylla_macros::ValueList; /// to either the `scylla` or `scylla-cql` crate. pub use scylla_macros::SerializeCql; +/// Derive macro for the [`SerializeRow`](crate::types::serialize::row::SerializeRow) trait +/// which serializes given Rust structure into bind markers for a CQL statement. +/// +/// At the moment, only structs with named fields are supported. The generated +/// implementation of the trait will match the struct fields to bind markers/columns +/// by name automatically. +/// +/// Serialization will fail if there are some bind markers/columns in the statement +/// that don't match to any of the Rust struct fields, _or vice versa_. +/// +/// In case of failure, either [`BuiltinTypeCheckError`](crate::types::serialize::row::BuiltinTypeCheckError) +/// or [`BuiltinSerializationError`](crate::types::serialize::row::BuiltinSerializationError) +/// will be returned. +/// +/// # Example +/// +/// A UDT defined like this: +/// Given a table and a query: +/// +/// ```notrust +/// CREATE TABLE ks.my_t (a int PRIMARY KEY, b text, c blob); +/// INSERT INTO ks.my_t (a, b, c) VALUES (?, ?, ?); +/// ``` +/// +/// ...the values for the query can be serialized using the following struct: +/// +/// ```rust +/// # use scylla_cql::macros::SerializeRow; +/// #[derive(SerializeRow)] +/// # #[scylla(crate = scylla_cql)] +/// struct MyValues { +/// a: i32, +/// b: Option, +/// c: Vec, +/// } +/// ``` +/// +/// # Attributes +/// +/// `#[scylla(crate = crate_name)]` +/// +/// By default, the code generated by the derive macro will refer to the items +/// defined by the driver (types, traits, etc.) via the `::scylla` path. +/// For example, it will refer to the [`SerializeRow`](crate::types::serialize::row::SerializeRow) trait +/// using the following path: +/// +/// ```rust,ignore +/// use ::scylla::_macro_internal::SerializeRow; +/// ``` +/// +/// Most users will simply add `scylla` to their dependencies, then use +/// the derive macro and the path above will work. However, there are some +/// niche cases where this path will _not_ work: +/// +/// - The `scylla` crate is imported under a different name, +/// - The `scylla` crate is _not imported at all_ - the macro actually +/// is defined in the `scylla-macros` crate and the generated code depends +/// on items defined in `scylla-cql`. +/// +/// It's not possible to automatically resolve those issues in the procedural +/// macro itself, so in those cases the user must provide an alternative path +/// to either the `scylla` or `scylla-cql` crate. +pub use scylla_macros::SerializeRow; + // Reexports for derive(IntoUserType) pub use bytes::{BufMut, Bytes, BytesMut}; diff --git a/scylla-cql/src/types/serialize/row.rs b/scylla-cql/src/types/serialize/row.rs index d8702100b6..d398a42281 100644 --- a/scylla-cql/src/types/serialize/row.rs +++ b/scylla-cql/src/types/serialize/row.rs @@ -561,7 +561,12 @@ mod tests { use crate::frame::value::{MaybeUnset, SerializedValues, ValueList}; use crate::types::serialize::RowWriter; - use super::{RowSerializationContext, SerializeRow}; + use super::{ + BuiltinSerializationError, BuiltinSerializationErrorKind, BuiltinTypeCheckError, + BuiltinTypeCheckErrorKind, RowSerializationContext, SerializeCql, SerializeRow, + }; + + use scylla_macros::SerializeRow; fn col_spec(name: &str, typ: ColumnType) -> ColumnSpec { ColumnSpec { @@ -672,4 +677,165 @@ mod tests { ); assert_eq!(typed_data, erased_data); } + + fn do_serialize(t: T, columns: &[ColumnSpec]) -> Vec { + let ctx = RowSerializationContext { columns }; + let mut ret = Vec::new(); + let mut builder = RowWriter::new(&mut ret); + t.serialize(&ctx, &mut builder).unwrap(); + ret + } + + fn col(name: &str, typ: ColumnType) -> ColumnSpec { + ColumnSpec { + table_spec: TableSpec { + ks_name: "ks".to_string(), + table_name: "tbl".to_string(), + }, + name: name.to_string(), + typ, + } + } + + // Do not remove. It's not used in tests but we keep it here to check that + // we properly ignore warnings about unused variables, unnecessary `mut`s + // etc. that usually pop up when generating code for empty structs. + #[derive(SerializeRow)] + #[scylla(crate = crate)] + struct TestRowWithNoColumns {} + + #[derive(SerializeRow, Debug, PartialEq, Eq, Default)] + #[scylla(crate = crate)] + struct TestRowWithColumnSorting { + a: String, + b: i32, + c: Vec, + } + + #[test] + fn test_row_serialization_with_column_sorting_correct_order() { + let spec = [ + col("a", ColumnType::Text), + col("b", ColumnType::Int), + col("c", ColumnType::List(Box::new(ColumnType::BigInt))), + ]; + + let reference = do_serialize(("Ala ma kota", 42i32, vec![1i64, 2i64, 3i64]), &spec); + let row = do_serialize( + TestRowWithColumnSorting { + a: "Ala ma kota".to_owned(), + b: 42, + c: vec![1, 2, 3], + }, + &spec, + ); + + assert_eq!(reference, row); + } + + #[test] + fn test_row_serialization_with_column_sorting_incorrect_order() { + // The order of two last columns is swapped + let spec = [ + col("a", ColumnType::Text), + col("c", ColumnType::List(Box::new(ColumnType::BigInt))), + col("b", ColumnType::Int), + ]; + + let reference = do_serialize(("Ala ma kota", vec![1i64, 2i64, 3i64], 42i32), &spec); + let row = do_serialize( + TestRowWithColumnSorting { + a: "Ala ma kota".to_owned(), + b: 42, + c: vec![1, 2, 3], + }, + &spec, + ); + + assert_eq!(reference, row); + } + + #[test] + fn test_row_serialization_failing_type_check() { + let row = TestRowWithColumnSorting::default(); + let mut data = Vec::new(); + let mut row_writer = RowWriter::new(&mut data); + + let spec_without_c = [ + col("a", ColumnType::Text), + col("b", ColumnType::Int), + // Missing column c + ]; + + let ctx = RowSerializationContext { + columns: &spec_without_c, + }; + let err = <_ as SerializeRow>::serialize(&row, &ctx, &mut row_writer).unwrap_err(); + let err = err.0.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinTypeCheckErrorKind::ColumnMissingForValue { .. } + )); + + let spec_duplicate_column = [ + col("a", ColumnType::Text), + col("b", ColumnType::Int), + col("c", ColumnType::List(Box::new(ColumnType::BigInt))), + // Unexpected last column + col("d", ColumnType::Counter), + ]; + + let ctx = RowSerializationContext { + columns: &spec_duplicate_column, + }; + let err = <_ as SerializeRow>::serialize(&row, &ctx, &mut row_writer).unwrap_err(); + let err = err.0.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinTypeCheckErrorKind::MissingValueForColumn { .. } + )); + + let spec_wrong_type = [ + col("a", ColumnType::Text), + col("b", ColumnType::Int), + col("c", ColumnType::TinyInt), // Wrong type + ]; + + let ctx = RowSerializationContext { + columns: &spec_wrong_type, + }; + let err = <_ as SerializeRow>::serialize(&row, &ctx, &mut row_writer).unwrap_err(); + let err = err.0.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinSerializationErrorKind::ColumnSerializationFailed { .. } + )); + } + + #[derive(SerializeRow)] + #[scylla(crate = crate)] + struct TestRowWithGenerics<'a, T: SerializeCql> { + a: &'a str, + b: T, + } + + #[test] + fn test_row_serialization_with_generics() { + // A minimal smoke test just to test that it works. + fn check_with_type(typ: ColumnType, t: T) { + let spec = [col("a", ColumnType::Text), col("b", typ)]; + let reference = do_serialize(("Ala ma kota", t), &spec); + let row = do_serialize( + TestRowWithGenerics { + a: "Ala ma kota", + b: t, + }, + &spec, + ); + assert_eq!(reference, row); + } + + check_with_type(ColumnType::Int, 123_i32); + check_with_type(ColumnType::Double, 123_f64); + } } diff --git a/scylla-macros/src/lib.rs b/scylla-macros/src/lib.rs index 84ee58bca0..64ce0ee06e 100644 --- a/scylla-macros/src/lib.rs +++ b/scylla-macros/src/lib.rs @@ -18,6 +18,15 @@ pub fn serialize_cql_derive(tokens_input: TokenStream) -> TokenStream { } } +/// See the documentation for this item in the `scylla` crate. +#[proc_macro_derive(SerializeRow, attributes(scylla))] +pub fn serialize_row_derive(tokens_input: TokenStream) -> TokenStream { + match serialize::row::derive_serialize_row(tokens_input) { + Ok(t) => t.into_token_stream().into(), + Err(e) => e.into_compile_error().into(), + } +} + /// #[derive(FromRow)] derives FromRow for struct /// Works only on simple structs without generics etc #[proc_macro_derive(FromRow, attributes(scylla_crate))] diff --git a/scylla-macros/src/serialize/mod.rs b/scylla-macros/src/serialize/mod.rs index 15fd9ae87c..53abe0f296 100644 --- a/scylla-macros/src/serialize/mod.rs +++ b/scylla-macros/src/serialize/mod.rs @@ -1 +1,2 @@ pub(crate) mod cql; +pub(crate) mod row; diff --git a/scylla-macros/src/serialize/row.rs b/scylla-macros/src/serialize/row.rs new file mode 100644 index 0000000000..0dd2356041 --- /dev/null +++ b/scylla-macros/src/serialize/row.rs @@ -0,0 +1,202 @@ +use darling::FromAttributes; +use proc_macro::TokenStream; +use proc_macro2::Span; +use syn::parse_quote; + +#[derive(FromAttributes)] +#[darling(attributes(scylla))] +struct Attributes { + #[darling(rename = "crate")] + crate_path: Option, +} + +impl Attributes { + fn crate_path(&self) -> syn::Path { + self.crate_path + .as_ref() + .map(|p| parse_quote!(#p::_macro_internal)) + .unwrap_or_else(|| parse_quote!(::scylla::_macro_internal)) + } +} + +struct Context { + attributes: Attributes, + fields: Vec, +} + +pub fn derive_serialize_row(tokens_input: TokenStream) -> Result { + let input: syn::DeriveInput = syn::parse(tokens_input)?; + let struct_name = input.ident.clone(); + let named_fields = crate::parser::parse_named_fields(&input, "SerializeRow")?; + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + let attributes = Attributes::from_attributes(&input.attrs)?; + + let crate_path = attributes.crate_path(); + let implemented_trait: syn::Path = parse_quote!(#crate_path::SerializeRow); + + let fields = named_fields.named.iter().cloned().collect(); + let ctx = Context { attributes, fields }; + let gen = ColumnSortingGenerator { ctx: &ctx }; + + let serialize_item = gen.generate_serialize(); + let is_empty_item = gen.generate_is_empty(); + + let res = parse_quote! { + impl #impl_generics #implemented_trait for #struct_name #ty_generics #where_clause { + #serialize_item + #is_empty_item + } + }; + Ok(res) +} + +impl Context { + fn generate_mk_typck_err(&self) -> syn::Stmt { + let crate_path = self.attributes.crate_path(); + parse_quote! { + let mk_typck_err = |kind: #crate_path::BuiltinRowTypeCheckErrorKind| -> #crate_path::SerializationError { + #crate_path::SerializationError::new( + #crate_path::BuiltinRowTypeCheckError { + rust_name: ::std::any::type_name::(), + kind, + } + ) + }; + } + } + + fn generate_mk_ser_err(&self) -> syn::Stmt { + let crate_path = self.attributes.crate_path(); + parse_quote! { + let mk_ser_err = |kind: #crate_path::BuiltinRowSerializationErrorKind| -> #crate_path::SerializationError { + #crate_path::SerializationError::new( + #crate_path::BuiltinRowSerializationError { + rust_name: ::std::any::type_name::(), + kind, + } + ) + }; + } + } +} + +// Generates an implementation of the trait which sorts the columns according +// to how they are defined in prepared statement metadata. +struct ColumnSortingGenerator<'a> { + ctx: &'a Context, +} + +impl<'a> ColumnSortingGenerator<'a> { + fn generate_serialize(&self) -> syn::TraitItemFn { + // Need to: + // - Check that all required columns are there and no more + // - Check that the column types match + let mut statements: Vec = Vec::new(); + + let crate_path = self.ctx.attributes.crate_path(); + + let rust_field_idents = self + .ctx + .fields + .iter() + .map(|f| f.ident.clone()) + .collect::>(); + let rust_field_names = rust_field_idents + .iter() + .map(|i| i.as_ref().unwrap().to_string()) + .collect::>(); + let udt_field_names = rust_field_names.clone(); // For now, it's the same + let field_types = self.ctx.fields.iter().map(|f| &f.ty).collect::>(); + + // Declare a helper lambda for creating errors + statements.push(self.ctx.generate_mk_typck_err()); + statements.push(self.ctx.generate_mk_ser_err()); + + // Generate a "visited" flag for each field + let visited_flag_names = rust_field_names + .iter() + .map(|s| syn::Ident::new(&format!("visited_flag_{}", s), Span::call_site())) + .collect::>(); + statements.extend::>(parse_quote! { + #(let mut #visited_flag_names = false;)* + }); + + // Generate a variable that counts down visited fields. + let field_count = self.ctx.fields.len(); + statements.push(parse_quote! { + let mut remaining_count = #field_count; + }); + + // Generate a loop over the fields and a `match` block to match on + // the field name. + statements.push(parse_quote! { + for spec in ctx.columns() { + match ::std::string::String::as_str(&spec.name) { + #( + #udt_field_names => { + let sub_writer = #crate_path::RowWriter::make_cell_writer(writer); + match <#field_types as #crate_path::SerializeCql>::serialize(&self.#rust_field_idents, &spec.typ, sub_writer) { + ::std::result::Result::Ok(_proof) => {} + ::std::result::Result::Err(err) => { + return ::std::result::Result::Err(mk_ser_err( + #crate_path::BuiltinRowSerializationErrorKind::ColumnSerializationFailed { + name: <_ as ::std::clone::Clone>::clone(&spec.name), + err, + } + )); + } + } + if !#visited_flag_names { + #visited_flag_names = true; + remaining_count -= 1; + } + } + )* + _ => return ::std::result::Result::Err(mk_typck_err( + #crate_path::BuiltinRowTypeCheckErrorKind::MissingValueForColumn { + name: <_ as ::std::clone::Clone>::clone(&&spec.name), + } + )), + } + } + }); + + // Finally, check that all fields were consumed. + // If there are some missing fields, return an error + statements.push(parse_quote! { + if remaining_count > 0 { + #( + if !#visited_flag_names { + return ::std::result::Result::Err(mk_typck_err( + #crate_path::BuiltinRowTypeCheckErrorKind::ColumnMissingForValue { + name: <_ as ::std::string::ToString>::to_string(#rust_field_names), + } + )); + } + )* + ::std::unreachable!() + } + }); + + parse_quote! { + fn serialize<'b>( + &self, + ctx: &#crate_path::RowSerializationContext, + writer: &mut #crate_path::RowWriter<'b>, + ) -> ::std::result::Result<(), #crate_path::SerializationError> { + #(#statements)* + ::std::result::Result::Ok(()) + } + } + } + + fn generate_is_empty(&self) -> syn::TraitItemFn { + let is_empty = self.ctx.fields.is_empty(); + parse_quote! { + #[inline] + fn is_empty(&self) -> bool { + #is_empty + } + } + } +} diff --git a/scylla/tests/integration/hygiene.rs b/scylla/tests/integration/hygiene.rs index 12d55ccb61..cf2aaed7b3 100644 --- a/scylla/tests/integration/hygiene.rs +++ b/scylla/tests/integration/hygiene.rs @@ -64,7 +64,7 @@ macro_rules! test_crate { assert_eq!(sv, sv2); } - #[derive(_scylla::macros::SerializeCql)] + #[derive(_scylla::macros::SerializeCql, _scylla::macros::SerializeRow)] #[scylla(crate = _scylla)] struct TestStructNew { x: ::core::primitive::i32, From dcb4cf48f2841d7ad37d6849b570788398ddde41 Mon Sep 17 00:00:00 2001 From: Piotr Dulikowski Date: Fri, 27 Oct 2023 07:44:49 +0200 Subject: [PATCH 3/4] scylla-macros: implement enforce_order flavor of SerializeCql Some users might not need the additional robustness of `SerializeCql` that comes from sorting the fields before serializing, as they are used to the current behavior of `Value` and properly set the order of the fields in their Rust struct. In order to give them some performance boost, add an additional mode to `SerializeCql` called "enforce_order" which expects that the order of the fields in the struct is kept in sync with the DB definition of the UDT. It's still safe to use because, as the struct fields are serialized, their names are compared with the fields in the UDT definition order and serialization fails if the field name on some position is mismatched. --- scylla-cql/src/macros.rs | 19 ++- scylla-cql/src/types/serialize/value.rs | 170 ++++++++++++++++++++++++ scylla-macros/src/serialize/cql.rs | 121 ++++++++++++++++- scylla-macros/src/serialize/mod.rs | 18 +++ 4 files changed, 323 insertions(+), 5 deletions(-) diff --git a/scylla-cql/src/macros.rs b/scylla-cql/src/macros.rs index 8f53e24fa9..2b7b0b4ae7 100644 --- a/scylla-cql/src/macros.rs +++ b/scylla-cql/src/macros.rs @@ -16,9 +16,7 @@ pub use scylla_macros::ValueList; /// Derive macro for the [`SerializeCql`](crate::types::serialize::value::SerializeCql) trait /// which serializes given Rust structure as a User Defined Type (UDT). /// -/// At the moment, only structs with named fields are supported. The generated -/// implementation of the trait will match the struct fields to UDT fields -/// by name automatically. +/// At the moment, only structs with named fields are supported. /// /// Serialization will fail if there are some fields in the UDT that don't match /// to any of the Rust struct fields, _or vice versa_. @@ -50,6 +48,21 @@ pub use scylla_macros::ValueList; /// /// # Attributes /// +/// `#[scylla(flavor = "flavor_name")]` +/// +/// Allows to choose one of the possible "flavors", i.e. the way how the +/// generated code will approach serialization. Possible flavors are: +/// +/// - `"match_by_name"` (default) - the generated implementation _does not +/// require_ the fields in the Rust struct to be in the same order as the +/// fields in the UDT. During serialization, the implementation will take +/// care to serialize the fields in the order which the database expects. +/// - `"enforce_order"` - the generated implementation _requires_ the fields +/// in the Rust struct to be in the same order as the fields in the UDT. +/// If the order is incorrect, type checking/serialization will fail. +/// This is a less robust flavor than `"match_by_name"`, but should be +/// slightly more performant as it doesn't need to perform lookups by name. +/// /// `#[scylla(crate = crate_name)]` /// /// By default, the code generated by the derive macro will refer to the items diff --git a/scylla-cql/src/types/serialize/value.rs b/scylla-cql/src/types/serialize/value.rs index 85033dac25..567b59cfab 100644 --- a/scylla-cql/src/types/serialize/value.rs +++ b/scylla-cql/src/types/serialize/value.rs @@ -1314,6 +1314,12 @@ pub enum UdtTypeCheckErrorKind { /// The Rust data contains a field that is not present in the UDT UnexpectedFieldInDestination { field_name: String }, + + /// A different field name was expected at given position. + FieldNameMismatch { + rust_field_name: String, + db_field_name: String, + }, } impl Display for UdtTypeCheckErrorKind { @@ -1337,6 +1343,10 @@ impl Display for UdtTypeCheckErrorKind { f, "the field {field_name} present in the Rust data is not present in the CQL type" ), + UdtTypeCheckErrorKind::FieldNameMismatch { rust_field_name, db_field_name } => write!( + f, + "expected field with name {db_field_name} at given position, but the Rust field name is {rust_field_name}" + ), } } } @@ -1668,4 +1678,164 @@ mod tests { check_with_type(ColumnType::Int, 123_i32, CqlValue::Int(123_i32)); check_with_type(ColumnType::Double, 123_f64, CqlValue::Double(123_f64)); } + + #[derive(SerializeCql, Debug, PartialEq, Eq, Default)] + #[scylla(crate = crate, flavor = "enforce_order")] + struct TestUdtWithEnforcedOrder { + a: String, + b: i32, + c: Vec, + } + + #[test] + fn test_udt_serialization_with_enforced_order_correct_order() { + let typ = ColumnType::UserDefinedType { + type_name: "typ".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + ("a".to_string(), ColumnType::Text), + ("b".to_string(), ColumnType::Int), + ( + "c".to_string(), + ColumnType::List(Box::new(ColumnType::BigInt)), + ), + ], + }; + + let reference = do_serialize( + CqlValue::UserDefinedType { + keyspace: "ks".to_string(), + type_name: "typ".to_string(), + fields: vec![ + ( + "a".to_string(), + Some(CqlValue::Text(String::from("Ala ma kota"))), + ), + ("b".to_string(), Some(CqlValue::Int(42))), + ( + "c".to_string(), + Some(CqlValue::List(vec![ + CqlValue::BigInt(1), + CqlValue::BigInt(2), + CqlValue::BigInt(3), + ])), + ), + ], + }, + &typ, + ); + let udt = do_serialize( + TestUdtWithEnforcedOrder { + a: "Ala ma kota".to_owned(), + b: 42, + c: vec![1, 2, 3], + }, + &typ, + ); + + assert_eq!(reference, udt); + } + + #[test] + fn test_udt_serialization_with_enforced_order_failing_type_check() { + let typ_not_udt = ColumnType::Ascii; + let udt = TestUdtWithEnforcedOrder::default(); + + let mut data = Vec::new(); + + let err = <_ as SerializeCql>::serialize(&udt, &typ_not_udt, CellWriter::new(&mut data)) + .unwrap_err(); + let err = err.0.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::NotUdt) + )); + + let typ = ColumnType::UserDefinedType { + type_name: "typ".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + // Two first columns are swapped + ("b".to_string(), ColumnType::Int), + ("a".to_string(), ColumnType::Text), + ( + "c".to_string(), + ColumnType::List(Box::new(ColumnType::BigInt)), + ), + ], + }; + + let err = + <_ as SerializeCql>::serialize(&udt, &typ, CellWriter::new(&mut data)).unwrap_err(); + let err = err.0.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::FieldNameMismatch { .. }) + )); + + let typ_without_c = ColumnType::UserDefinedType { + type_name: "typ".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + ("a".to_string(), ColumnType::Text), + ("b".to_string(), ColumnType::Int), + // Last field is missing + ], + }; + + let err = <_ as SerializeCql>::serialize(&udt, &typ_without_c, CellWriter::new(&mut data)) + .unwrap_err(); + let err = err.0.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::MissingField { .. }) + )); + + let typ_unexpected_field = ColumnType::UserDefinedType { + type_name: "typ".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + ("a".to_string(), ColumnType::Text), + ("b".to_string(), ColumnType::Int), + ( + "c".to_string(), + ColumnType::List(Box::new(ColumnType::BigInt)), + ), + // Unexpected field + ("d".to_string(), ColumnType::Counter), + ], + }; + + let err = + <_ as SerializeCql>::serialize(&udt, &typ_unexpected_field, CellWriter::new(&mut data)) + .unwrap_err(); + let err = err.0.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinTypeCheckErrorKind::UdtError( + UdtTypeCheckErrorKind::UnexpectedFieldInDestination { .. } + ) + )); + + let typ_unexpected_field = ColumnType::UserDefinedType { + type_name: "typ".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + ("a".to_string(), ColumnType::Text), + ("b".to_string(), ColumnType::Int), + ("c".to_string(), ColumnType::TinyInt), // Wrong column type + ], + }; + + let err = + <_ as SerializeCql>::serialize(&udt, &typ_unexpected_field, CellWriter::new(&mut data)) + .unwrap_err(); + let err = err.0.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinSerializationErrorKind::UdtError( + UdtSerializationErrorKind::FieldSerializationFailed { .. } + ) + )); + } } diff --git a/scylla-macros/src/serialize/cql.rs b/scylla-macros/src/serialize/cql.rs index f19e47b27c..d3c5788401 100644 --- a/scylla-macros/src/serialize/cql.rs +++ b/scylla-macros/src/serialize/cql.rs @@ -3,11 +3,15 @@ use proc_macro::TokenStream; use proc_macro2::Span; use syn::parse_quote; +use super::Flavor; + #[derive(FromAttributes)] #[darling(attributes(scylla))] struct Attributes { #[darling(rename = "crate")] crate_path: Option, + + flavor: Option, } impl Attributes { @@ -36,7 +40,11 @@ pub fn derive_serialize_cql(tokens_input: TokenStream) -> Result = match ctx.attributes.flavor { + Some(Flavor::MatchByName) | None => Box::new(FieldSortingGenerator { ctx: &ctx }), + Some(Flavor::EnforceOrder) => Box::new(FieldOrderedGenerator { ctx: &ctx }), + }; let serialize_item = gen.generate_serialize(); @@ -93,13 +101,17 @@ impl Context { } } +trait Generator { + fn generate_serialize(&self) -> syn::TraitItemFn; +} + // Generates an implementation of the trait which sorts the fields according // to how it is defined in the database. struct FieldSortingGenerator<'a> { ctx: &'a Context, } -impl<'a> FieldSortingGenerator<'a> { +impl<'a> Generator for FieldSortingGenerator<'a> { fn generate_serialize(&self) -> syn::TraitItemFn { // Need to: // - Check that all required fields are there and no more @@ -222,3 +234,108 @@ impl<'a> FieldSortingGenerator<'a> { } } } + +// Generates an implementation of the trait which requires the fields +// to be placed in the same order as they are defined in the struct. +struct FieldOrderedGenerator<'a> { + ctx: &'a Context, +} + +impl<'a> Generator for FieldOrderedGenerator<'a> { + fn generate_serialize(&self) -> syn::TraitItemFn { + let mut statements: Vec = Vec::new(); + + let crate_path = self.ctx.attributes.crate_path(); + + // Declare a helper lambda for creating errors + statements.push(self.ctx.generate_mk_typck_err()); + statements.push(self.ctx.generate_mk_ser_err()); + + // Check that the type we want to serialize to is a UDT + statements.push( + self.ctx + .generate_udt_type_match(parse_quote!(#crate_path::UdtTypeCheckErrorKind::NotUdt)), + ); + + // Turn the cell writer into a value builder + statements.push(parse_quote! { + let mut builder = #crate_path::CellWriter::into_value_builder(writer); + }); + + // Create an iterator over fields + statements.push(parse_quote! { + let mut field_iter = field_types.iter(); + }); + + // Serialize each field + for field in self.ctx.fields.iter() { + let rust_field_ident = field.ident.as_ref().unwrap(); + let rust_field_name = rust_field_ident.to_string(); + let typ = &field.ty; + statements.push(parse_quote! { + match field_iter.next() { + Some((field_name, typ)) => { + if field_name == #rust_field_name { + let sub_builder = #crate_path::CellValueBuilder::make_sub_writer(&mut builder); + match <#typ as #crate_path::SerializeCql>::serialize(&self.#rust_field_ident, typ, sub_builder) { + Ok(_proof) => {}, + Err(err) => { + return ::std::result::Result::Err(mk_ser_err( + #crate_path::UdtSerializationErrorKind::FieldSerializationFailed { + field_name: <_ as ::std::clone::Clone>::clone(field_name), + err, + } + )); + } + } + } else { + return ::std::result::Result::Err(mk_typck_err( + #crate_path::UdtTypeCheckErrorKind::FieldNameMismatch { + rust_field_name: <_ as ::std::string::ToString>::to_string(#rust_field_name), + db_field_name: <_ as ::std::clone::Clone>::clone(field_name), + } + )); + } + } + None => { + return ::std::result::Result::Err(mk_typck_err( + #crate_path::UdtTypeCheckErrorKind::MissingField { + field_name: <_ as ::std::string::ToString>::to_string(#rust_field_name), + } + )); + } + } + }); + } + + // Check whether there are some fields remaining + statements.push(parse_quote! { + if let Some((field_name, typ)) = field_iter.next() { + return ::std::result::Result::Err(mk_typck_err( + #crate_path::UdtTypeCheckErrorKind::UnexpectedFieldInDestination { + field_name: <_ as ::std::clone::Clone>::clone(field_name), + } + )); + } + }); + + parse_quote! { + fn serialize<'b>( + &self, + typ: &#crate_path::ColumnType, + writer: #crate_path::CellWriter<'b>, + ) -> ::std::result::Result<#crate_path::WrittenCellProof<'b>, #crate_path::SerializationError> { + #(#statements)* + let proof = #crate_path::CellValueBuilder::finish(builder) + .map_err(|_| #crate_path::SerializationError::new( + #crate_path::BuiltinTypeSerializationError { + rust_name: ::std::any::type_name::(), + got: <_ as ::std::clone::Clone>::clone(typ), + kind: #crate_path::BuiltinTypeSerializationErrorKind::SizeOverflow, + } + ) as #crate_path::SerializationError)?; + ::std::result::Result::Ok(proof) + } + } + } +} diff --git a/scylla-macros/src/serialize/mod.rs b/scylla-macros/src/serialize/mod.rs index 53abe0f296..183183fa91 100644 --- a/scylla-macros/src/serialize/mod.rs +++ b/scylla-macros/src/serialize/mod.rs @@ -1,2 +1,20 @@ +use darling::FromMeta; + pub(crate) mod cql; pub(crate) mod row; + +#[derive(Copy, Clone, PartialEq, Eq)] +enum Flavor { + MatchByName, + EnforceOrder, +} + +impl FromMeta for Flavor { + fn from_string(value: &str) -> darling::Result { + match value { + "match_by_name" => Ok(Self::MatchByName), + "enforce_order" => Ok(Self::EnforceOrder), + _ => Err(darling::Error::unknown_value(value)), + } + } +} From a255d81aa3e0c5e2661755153022c535ed0d15ea Mon Sep 17 00:00:00 2001 From: Piotr Dulikowski Date: Fri, 27 Oct 2023 09:16:38 +0200 Subject: [PATCH 4/4] scylla-macros: implement enforce_order flavor of SerializeRow Like in the case of `SerializeRow`, some people might be used to working with the old `ValueList` and already order their Rust struct fields with accordance to the queries they are used with and don't need the overhead associated with looking up columns by name. The `enforce_order` mode is added to `SerializeRow` which works analogously as in `SerializeCql` - expects the columns to be in the correct order and verifies that this is the case when serializing, but just fails instead of reordering if that expectation is broken. --- scylla-cql/src/macros.rs | 19 ++++- scylla-cql/src/types/serialize/row.rs | 110 +++++++++++++++++++++++++ scylla-macros/src/serialize/row.rs | 113 +++++++++++++++++++++++++- 3 files changed, 237 insertions(+), 5 deletions(-) diff --git a/scylla-cql/src/macros.rs b/scylla-cql/src/macros.rs index 2b7b0b4ae7..51cc79ce24 100644 --- a/scylla-cql/src/macros.rs +++ b/scylla-cql/src/macros.rs @@ -91,9 +91,7 @@ pub use scylla_macros::SerializeCql; /// Derive macro for the [`SerializeRow`](crate::types::serialize::row::SerializeRow) trait /// which serializes given Rust structure into bind markers for a CQL statement. /// -/// At the moment, only structs with named fields are supported. The generated -/// implementation of the trait will match the struct fields to bind markers/columns -/// by name automatically. +/// At the moment, only structs with named fields are supported. /// /// Serialization will fail if there are some bind markers/columns in the statement /// that don't match to any of the Rust struct fields, _or vice versa_. @@ -127,6 +125,21 @@ pub use scylla_macros::SerializeCql; /// /// # Attributes /// +/// `#[scylla(flavor = "flavor_name")]` +/// +/// Allows to choose one of the possible "flavors", i.e. the way how the +/// generated code will approach serialization. Possible flavors are: +/// +/// - `"match_by_name"` (default) - the generated implementation _does not +/// require_ the fields in the Rust struct to be in the same order as the +/// columns/bind markers. During serialization, the implementation will take +/// care to serialize the fields in the order which the database expects. +/// - `"enforce_order"` - the generated implementation _requires_ the fields +/// in the Rust struct to be in the same order as the columns/bind markers. +/// If the order is incorrect, type checking/serialization will fail. +/// This is a less robust flavor than `"match_by_name"`, but should be +/// slightly more performant as it doesn't need to perform lookups by name. +/// /// `#[scylla(crate = crate_name)]` /// /// By default, the code generated by the derive macro will refer to the items diff --git a/scylla-cql/src/types/serialize/row.rs b/scylla-cql/src/types/serialize/row.rs index d398a42281..213af49c0f 100644 --- a/scylla-cql/src/types/serialize/row.rs +++ b/scylla-cql/src/types/serialize/row.rs @@ -504,6 +504,12 @@ pub enum BuiltinTypeCheckErrorKind { /// A value required by the statement is not provided by the Rust type. ColumnMissingForValue { name: String }, + + /// A different column name was expected at given position. + ColumnNameMismatch { + rust_column_name: String, + db_column_name: String, + }, } impl Display for BuiltinTypeCheckErrorKind { @@ -524,6 +530,10 @@ impl Display for BuiltinTypeCheckErrorKind { "value for column {name} was provided, but there is no bind marker for this column in the query" ) } + BuiltinTypeCheckErrorKind::ColumnNameMismatch { rust_column_name, db_column_name } => write!( + f, + "expected column with name {db_column_name} at given position, but the Rust field name is {rust_column_name}" + ), } } } @@ -838,4 +848,104 @@ mod tests { check_with_type(ColumnType::Int, 123_i32); check_with_type(ColumnType::Double, 123_f64); } + + #[derive(SerializeRow, Debug, PartialEq, Eq, Default)] + #[scylla(crate = crate, flavor = "enforce_order")] + struct TestRowWithEnforcedOrder { + a: String, + b: i32, + c: Vec, + } + + #[test] + fn test_row_serialization_with_enforced_order_correct_order() { + let spec = [ + col("a", ColumnType::Text), + col("b", ColumnType::Int), + col("c", ColumnType::List(Box::new(ColumnType::BigInt))), + ]; + + let reference = do_serialize(("Ala ma kota", 42i32, vec![1i64, 2i64, 3i64]), &spec); + let row = do_serialize( + TestRowWithEnforcedOrder { + a: "Ala ma kota".to_owned(), + b: 42, + c: vec![1, 2, 3], + }, + &spec, + ); + + assert_eq!(reference, row); + } + + #[test] + fn test_row_serialization_with_enforced_order_failing_type_check() { + let row = TestRowWithEnforcedOrder::default(); + let mut data = Vec::new(); + let mut writer = RowWriter::new(&mut data); + + // The order of two last columns is swapped + let spec = [ + col("a", ColumnType::Text), + col("c", ColumnType::List(Box::new(ColumnType::BigInt))), + col("b", ColumnType::Int), + ]; + let ctx = RowSerializationContext { columns: &spec }; + let err = <_ as SerializeRow>::serialize(&row, &ctx, &mut writer).unwrap_err(); + let err = err.0.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinTypeCheckErrorKind::ColumnNameMismatch { .. } + )); + + let spec_without_c = [ + col("a", ColumnType::Text), + col("b", ColumnType::Int), + // Missing column c + ]; + + let ctx = RowSerializationContext { + columns: &spec_without_c, + }; + let err = <_ as SerializeRow>::serialize(&row, &ctx, &mut writer).unwrap_err(); + let err = err.0.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinTypeCheckErrorKind::ColumnMissingForValue { .. } + )); + + let spec_duplicate_column = [ + col("a", ColumnType::Text), + col("b", ColumnType::Int), + col("c", ColumnType::List(Box::new(ColumnType::BigInt))), + // Unexpected last column + col("d", ColumnType::Counter), + ]; + + let ctx = RowSerializationContext { + columns: &spec_duplicate_column, + }; + let err = <_ as SerializeRow>::serialize(&row, &ctx, &mut writer).unwrap_err(); + let err = err.0.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinTypeCheckErrorKind::MissingValueForColumn { .. } + )); + + let spec_wrong_type = [ + col("a", ColumnType::Text), + col("b", ColumnType::Int), + col("c", ColumnType::TinyInt), // Wrong type + ]; + + let ctx = RowSerializationContext { + columns: &spec_wrong_type, + }; + let err = <_ as SerializeRow>::serialize(&row, &ctx, &mut writer).unwrap_err(); + let err = err.0.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinSerializationErrorKind::ColumnSerializationFailed { .. } + )); + } } diff --git a/scylla-macros/src/serialize/row.rs b/scylla-macros/src/serialize/row.rs index 0dd2356041..ee0f702d27 100644 --- a/scylla-macros/src/serialize/row.rs +++ b/scylla-macros/src/serialize/row.rs @@ -3,11 +3,15 @@ use proc_macro::TokenStream; use proc_macro2::Span; use syn::parse_quote; +use super::Flavor; + #[derive(FromAttributes)] #[darling(attributes(scylla))] struct Attributes { #[darling(rename = "crate")] crate_path: Option, + + flavor: Option, } impl Attributes { @@ -36,7 +40,11 @@ pub fn derive_serialize_row(tokens_input: TokenStream) -> Result = match ctx.attributes.flavor { + Some(Flavor::MatchByName) | None => Box::new(ColumnSortingGenerator { ctx: &ctx }), + Some(Flavor::EnforceOrder) => Box::new(ColumnOrderedGenerator { ctx: &ctx }), + }; let serialize_item = gen.generate_serialize(); let is_empty_item = gen.generate_is_empty(); @@ -80,13 +88,18 @@ impl Context { } } +trait Generator { + fn generate_serialize(&self) -> syn::TraitItemFn; + fn generate_is_empty(&self) -> syn::TraitItemFn; +} + // Generates an implementation of the trait which sorts the columns according // to how they are defined in prepared statement metadata. struct ColumnSortingGenerator<'a> { ctx: &'a Context, } -impl<'a> ColumnSortingGenerator<'a> { +impl<'a> Generator for ColumnSortingGenerator<'a> { fn generate_serialize(&self) -> syn::TraitItemFn { // Need to: // - Check that all required columns are there and no more @@ -200,3 +213,99 @@ impl<'a> ColumnSortingGenerator<'a> { } } } + +// Generates an implementation of the trait which requires the columns +// to be placed in the same order as they are defined in the struct. +struct ColumnOrderedGenerator<'a> { + ctx: &'a Context, +} + +impl<'a> Generator for ColumnOrderedGenerator<'a> { + fn generate_serialize(&self) -> syn::TraitItemFn { + let mut statements: Vec = Vec::new(); + + let crate_path = self.ctx.attributes.crate_path(); + + // Declare a helper lambda for creating errors + statements.push(self.ctx.generate_mk_typck_err()); + statements.push(self.ctx.generate_mk_ser_err()); + + // Create an iterator over fields + statements.push(parse_quote! { + let mut column_iter = ctx.columns().iter(); + }); + + // Serialize each field + for field in self.ctx.fields.iter() { + let rust_field_ident = field.ident.as_ref().unwrap(); + let rust_field_name = rust_field_ident.to_string(); + let typ = &field.ty; + statements.push(parse_quote! { + match column_iter.next() { + Some(spec) => { + if spec.name == #rust_field_name { + let cell_writer = #crate_path::RowWriter::make_cell_writer(writer); + match <#typ as #crate_path::SerializeCql>::serialize(&self.#rust_field_ident, &spec.typ, cell_writer) { + Ok(_proof) => {}, + Err(err) => { + return ::std::result::Result::Err(mk_ser_err( + #crate_path::BuiltinRowSerializationErrorKind::ColumnSerializationFailed { + name: <_ as ::std::clone::Clone>::clone(&spec.name), + err, + } + )); + } + } + } else { + return ::std::result::Result::Err(mk_typck_err( + #crate_path::BuiltinRowTypeCheckErrorKind::ColumnNameMismatch { + rust_column_name: <_ as ::std::string::ToString>::to_string(#rust_field_name), + db_column_name: <_ as ::std::clone::Clone>::clone(&spec.name), + } + )); + } + } + None => { + return ::std::result::Result::Err(mk_typck_err( + #crate_path::BuiltinRowTypeCheckErrorKind::ColumnMissingForValue { + name: <_ as ::std::string::ToString>::to_string(#rust_field_name), + } + )); + } + } + }); + } + + // Check whether there are some columns remaining + statements.push(parse_quote! { + if let Some(spec) = column_iter.next() { + return ::std::result::Result::Err(mk_typck_err( + #crate_path::BuiltinRowTypeCheckErrorKind::MissingValueForColumn { + name: <_ as ::std::clone::Clone>::clone(&spec.name), + } + )); + } + }); + + parse_quote! { + fn serialize<'b>( + &self, + ctx: &#crate_path::RowSerializationContext, + writer: &mut #crate_path::RowWriter<'b>, + ) -> ::std::result::Result<(), #crate_path::SerializationError> { + #(#statements)* + ::std::result::Result::Ok(()) + } + } + } + + fn generate_is_empty(&self) -> syn::TraitItemFn { + let is_empty = self.ctx.fields.is_empty(); + parse_quote! { + #[inline] + fn is_empty(&self) -> bool { + #is_empty + } + } + } +}