Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make #[diesel(embed)] fields be somewhat-checked by #[check_for_backend] #4484

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ struct UserCorrect {
name: String,
}

#[derive(Selectable, Queryable)]
#[diesel(check_for_backend(diesel::pg::Pg))]
struct SelectableWithEmbed {
#[diesel(embed)]
embed_user: User,
}

fn main() {
let mut conn = PgConnection::establish("...").unwrap();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,38 @@ error[E0277]: cannot deserialize a value of the database type `diesel::sql_types
= note: required for `std::string::String` to implement `FromSqlRow<diesel::sql_types::Integer, Pg>`
= help: see issue #48214

error[E0277]: the trait bound `(std::string::String, i32): FromStaticSqlRow<(diesel::sql_types::Integer, diesel::sql_types::Text), Pg>` is not satisfied
--> tests/fail/selectable_with_typemisamatch.rs:32:17
|
32 | embed_user: User,
| ^^^^ the trait `FromStaticSqlRow<(diesel::sql_types::Integer, diesel::sql_types::Text), Pg>` is not implemented for `(std::string::String, i32)`
|
= help: the following other types implement trait `FromStaticSqlRow<ST, DB>`:
`(T0,)` implements `FromStaticSqlRow<(ST0,), __DB>`
`(T1, T0)` implements `FromStaticSqlRow<(ST1, ST0), __DB>`
`(T1, T2, T0)` implements `FromStaticSqlRow<(ST1, ST2, ST0), __DB>`
`(T1, T2, T3, T0)` implements `FromStaticSqlRow<(ST1, ST2, ST3, ST0), __DB>`
`(T1, T2, T3, T4, T0)` implements `FromStaticSqlRow<(ST1, ST2, ST3, ST4, ST0), __DB>`
`(T1, T2, T3, T4, T5, T0)` implements `FromStaticSqlRow<(ST1, ST2, ST3, ST4, ST5, ST0), __DB>`
`(T1, T2, T3, T4, T5, T6, T0)` implements `FromStaticSqlRow<(ST1, ST2, ST3, ST4, ST5, ST6, ST0), __DB>`
`(T1, T2, T3, T4, T5, T6, T7, T0)` implements `FromStaticSqlRow<(ST1, ST2, ST3, ST4, ST5, ST6, ST7, ST0), __DB>`
and $N others
note: required for `User` to implement `diesel::Queryable<(diesel::sql_types::Integer, diesel::sql_types::Text), Pg>`
--> tests/fail/selectable_with_typemisamatch.rs:12:22
|
12 | #[derive(Selectable, Queryable)]
| ^^^^^^^^^ unsatisfied trait bound introduced in this `derive` macro
...
15 | struct User {
| ^^^^
= note: required for `User` to implement `FromSqlRow<(diesel::sql_types::Integer, diesel::sql_types::Text), Pg>`
= help: see issue #48214
= note: this error originates in the derive macro `Queryable` (in Nightly builds, run with -Z macro-backtrace for more info)

error[E0277]: the trait bound `diesel::expression::select_by::SelectBy<User, _>: load_dsl::private::CompatibleType<_, _>` is not satisfied
--> tests/fail/selectable_with_typemisamatch.rs:33:15
--> tests/fail/selectable_with_typemisamatch.rs:40:15
|
33 | .load(&mut conn)
40 | .load(&mut conn)
| ---- ^^^^^^^^^ the trait `load_dsl::private::CompatibleType<_, _>` is not implemented for `diesel::expression::select_by::SelectBy<User, _>`
| |
| required by a bound introduced by this call
Expand Down
69 changes: 47 additions & 22 deletions diesel_derives/src/selectable.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use proc_macro2::TokenStream;
use quote::quote;
use std::borrow::Cow;
use syn::spanned::Spanned;
use syn::DeriveInput;
use syn::{parse_quote, Result};
use syn::{parse_quote, DeriveInput, Result};

use crate::field::Field;
use crate::model::Model;
Expand All @@ -11,7 +11,8 @@ use crate::util::wrap_in_dummy_mod;
pub fn derive(item: DeriveInput) -> Result<TokenStream> {
let model = Model::from_item(&item, false, false)?;

let (_, ty_generics, original_where_clause) = item.generics.split_for_impl();
let (original_impl_generics, ty_generics, original_where_clause) =
item.generics.split_for_impl();

let mut generics = item.generics.clone();
generics
Expand All @@ -21,8 +22,7 @@ pub fn derive(item: DeriveInput) -> Result<TokenStream> {
for embed_field in model.fields().iter().filter(|f| f.embed()) {
let embed_ty = &embed_field.ty;
generics
.where_clause
.get_or_insert_with(|| parse_quote!(where))
.make_where_clause()
.predicates
.push(parse_quote!(#embed_ty: Selectable<__DB>));
}
Expand All @@ -32,12 +32,16 @@ pub fn derive(item: DeriveInput) -> Result<TokenStream> {
let struct_name = &item.ident;

let mut compile_errors: Vec<syn::Error> = Vec::new();
let field_columns_ty = model
let field_select_expression_type_builders = model
.fields()
.iter()
.map(|f| field_column_ty(f, &model, &mut compile_errors))
.map(|f| field_select_expression_ty_builder(f, &model, &mut compile_errors))
.collect::<Result<Vec<_>>>()?;
let field_columns_inst = model
let field_select_expression_types = field_select_expression_type_builders
.iter()
.map(|f| f.type_with_backend(&parse_quote!(__DB)))
.collect::<Vec<_>>();
let field_select_expressions = model
.fields()
.iter()
.map(|f| field_column_inst(f, &model))
Expand All @@ -47,12 +51,12 @@ pub fn derive(item: DeriveInput) -> Result<TokenStream> {
let field_check_bound = model
.fields()
.iter()
.zip(&field_columns_ty)
.filter(|(f, _)| !f.embed())
.flat_map(|(f, ty)| {
.zip(&field_select_expression_type_builders)
.flat_map(|(f, ty_builder)| {
backends.iter().map(move |b| {
let span = f.ty.span();
let field_ty = to_field_ty_bound(f.ty_for_deserialize())?;
let ty = ty_builder.type_with_backend(b);
Ok(syn::parse_quote_spanned! {span =>
#field_ty: diesel::deserialize::FromSqlRow<diesel::dsl::SqlTypeOf<#ty>, #b>
})
Expand All @@ -65,7 +69,7 @@ pub fn derive(item: DeriveInput) -> Result<TokenStream> {
where_clause.predicates.push(field_check);
}
Some(quote::quote! {
fn _check_field_compatibility #impl_generics()
fn _check_field_compatibility #original_impl_generics()
#where_clause
{}
})
Expand All @@ -85,10 +89,10 @@ pub fn derive(item: DeriveInput) -> Result<TokenStream> {
for #struct_name #ty_generics
#where_clause
{
type SelectExpression = (#(#field_columns_ty,)*);
type SelectExpression = (#(#field_select_expression_types,)*);

fn construct_selection() -> Self::SelectExpression {
(#(#field_columns_inst,)*)
(#(#field_select_expressions,)*)
}
}

Expand Down Expand Up @@ -124,11 +128,11 @@ fn to_field_ty_bound(field_ty: &syn::Type) -> Result<TokenStream> {
}
}

fn field_column_ty(
field: &Field,
fn field_select_expression_ty_builder<'a>(
field: &'a Field,
model: &Model,
compile_errors: &mut Vec<syn::Error>,
) -> Result<TokenStream> {
) -> Result<FieldSelectExpressionTyBuilder<'a>> {
if let Some(ref select_expression) = field.select_expression {
use dsl_auto_type::auto_type::expression_type_inference as type_inference;
let expr = &select_expression.item;
Expand All @@ -142,17 +146,38 @@ fn field_column_ty(
.build(),
);
compile_errors.extend(errors);
Ok(quote::quote!(#inferred_type))
Ok(FieldSelectExpressionTyBuilder::Always(
quote::quote!(#inferred_type),
))
} else if let Some(ref select_expression_type) = field.select_expression_type {
let ty = &select_expression_type.item;
Ok(quote!(#ty))
Ok(FieldSelectExpressionTyBuilder::Always(quote!(#ty)))
} else if field.embed() {
let embed_ty = &field.ty;
Ok(quote!(<#embed_ty as Selectable<__DB>>::SelectExpression))
Ok(FieldSelectExpressionTyBuilder::EmbedSelectable {
embed_ty: &field.ty,
})
} else {
let table_name = &model.table_names()[0];
let column_name = field.column_name()?.to_ident()?;
Ok(quote!(#table_name::#column_name))
Ok(FieldSelectExpressionTyBuilder::Always(
quote!(#table_name::#column_name),
))
}
}

enum FieldSelectExpressionTyBuilder<'a> {
Always(TokenStream),
EmbedSelectable { embed_ty: &'a syn::Type },
}

impl FieldSelectExpressionTyBuilder<'_> {
fn type_with_backend(&self, backend: &syn::TypePath) -> Cow<'_, TokenStream> {
match self {
FieldSelectExpressionTyBuilder::Always(ty) => Cow::Borrowed(ty),
FieldSelectExpressionTyBuilder::EmbedSelectable { embed_ty } => {
Cow::Owned(quote!(<#embed_ty as Selectable<#backend>>::SelectExpression))
}
}
}
}

Expand Down
Loading