From d8d9d04b456cf6f3e3035093e3bfd97f2c671d18 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Fri, 31 May 2024 13:25:08 -0700 Subject: [PATCH 1/6] fix(postgres): get correctly qualified type name in describe --- sqlx-postgres/src/connection/describe.rs | 11 +++++++++- tests/postgres/derives.rs | 26 ++++++++++++++++++++++++ tests/postgres/setup.sql | 4 ++++ 3 files changed, 40 insertions(+), 1 deletion(-) diff --git a/sqlx-postgres/src/connection/describe.rs b/sqlx-postgres/src/connection/describe.rs index 71f9b9b316..1c5816b4a1 100644 --- a/sqlx-postgres/src/connection/describe.rs +++ b/sqlx-postgres/src/connection/describe.rs @@ -187,7 +187,16 @@ impl PgConnection { fn fetch_type_by_oid(&mut self, oid: Oid) -> BoxFuture<'_, Result> { Box::pin(async move { let (name, typ_type, category, relation_id, element, base_type): (String, i8, i8, Oid, Oid, Oid) = query_as( - "SELECT typname, typtype, typcategory, typrelid, typelem, typbasetype FROM pg_catalog.pg_type WHERE oid = $1", + // Converting the OID to `regtype` and then `text` will give us the name that + // the type will need to be found at by search_path. + "SELECT oid::regtype::text, \ + typtype, \ + typcategory, \ + typrelid, \ + typelem, \ + typbasetype \ + FROM pg_catalog.pg_type \ + WHERE oid = $1", ) .bind(oid) .fetch_one(&mut *self) diff --git a/tests/postgres/derives.rs b/tests/postgres/derives.rs index dd05e92907..d840589f78 100644 --- a/tests/postgres/derives.rs +++ b/tests/postgres/derives.rs @@ -724,3 +724,29 @@ async fn test_skip() -> anyhow::Result<()> { Ok(()) } + +#[cfg(feature = "macros")] +#[sqlx_macros::test] +async fn test_enum_with_schema() -> anyhow::Result<()> { + #[derive(Debug, PartialEq, Eq, sqlx::Type)] + #[sqlx(type_name = "foo.\"Foo\"")] + enum Foo { + Bar, + Baz, + } + + let mut conn = new::().await?; + + let foo: Foo = sqlx::query_scalar("SELECT $1::foo.\"Foo\"") + .bind(Foo::Bar) + .fetch_one(&mut conn).await?; + + assert_eq!(foo, Foo::Bar); + + let foo: Foo = sqlx::query_scalar("SELECT $1::foo.\"Foo\"") + .bind(Foo::Baz) + .fetch_one(&mut conn) + .await?; + + assert_eq!(foo, Foo::Baz); +} \ No newline at end of file diff --git a/tests/postgres/setup.sql b/tests/postgres/setup.sql index 5a415324d8..425bd4c534 100644 --- a/tests/postgres/setup.sql +++ b/tests/postgres/setup.sql @@ -51,3 +51,7 @@ CREATE OR REPLACE PROCEDURE forty_two(INOUT forty_two INT = NULL) CREATE TABLE test_citext ( foo CITEXT NOT NULL ); + +CREATE SCHEMA IF NOT EXISTS foo; + +CREATE ENUM foo."Foo" ('Bar', 'Baz'); \ No newline at end of file From 43dd68186a1ef39f1c75375382943adcb64ef2e7 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Fri, 31 May 2024 13:33:43 -0700 Subject: [PATCH 2/6] fix(postgres): derive `PgHasArrayType` for enums --- sqlx-macros-core/src/derives/type.rs | 22 +++++++++++++++++++++- tests/postgres/derives.rs | 19 +++++++++++++++---- 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/sqlx-macros-core/src/derives/type.rs b/sqlx-macros-core/src/derives/type.rs index a4962ee960..140802ebda 100644 --- a/sqlx-macros-core/src/derives/type.rs +++ b/sqlx-macros-core/src/derives/type.rs @@ -130,7 +130,7 @@ fn expand_derive_has_sql_type_weak_enum( let attr = check_weak_enum_attributes(input, variants)?; let repr = attr.repr.unwrap(); let ident = &input.ident; - let ts = quote!( + let mut ts = quote!( #[automatically_derived] impl ::sqlx::Type for #ident where @@ -146,6 +146,16 @@ fn expand_derive_has_sql_type_weak_enum( } ); + if cfg!(feature = "postgres") && !attributes.no_pg_array { + ts.extend(quote!( + impl ::sqlx::postgres::PgHasArrayType for #ident { + fn array_type_info() -> ::sqlx::postgres::PgTypeInfo { + <#ty as ::sqlx::postgres::PgHasArrayType>::array_type_info() + } + } + )); + } + Ok(ts) } @@ -184,6 +194,16 @@ fn expand_derive_has_sql_type_strong_enum( } } )); + + if !attributes.no_pg_array { + tts.extend(quote!( + impl ::sqlx::postgres::PgHasArrayType for #ident { + fn array_type_info() -> ::sqlx::postgres::PgTypeInfo { + <#ty as ::sqlx::postgres::PgHasArrayType>::array_type_info() + } + } + )); + } } if cfg!(feature = "sqlite") { diff --git a/tests/postgres/derives.rs b/tests/postgres/derives.rs index d840589f78..086fcc6db9 100644 --- a/tests/postgres/derives.rs +++ b/tests/postgres/derives.rs @@ -154,13 +154,15 @@ test_type!(transparent_array(Postgres, test_type!(weak_enum(Postgres, "0::int4" == Weak::One, "2::int4" == Weak::Two, - "4::int4" == Weak::Three + "4::int4" == Weak::Three, + "'{0, 2, 4}'::int4[]" == vec![Weak::One, Weak::Two, Weak::Three], )); test_type!(strong_enum(Postgres, "'one'::text" == Strong::One, "'two'::text" == Strong::Two, - "'four'::text" == Strong::Three + "'four'::text" == Strong::Three, + "'{'one', 'two', 'four'}'::text[]" == vec![Strong::One, Strong::Two, Strong::Three], )); test_type!(floatrange(Postgres, @@ -739,7 +741,8 @@ async fn test_enum_with_schema() -> anyhow::Result<()> { let foo: Foo = sqlx::query_scalar("SELECT $1::foo.\"Foo\"") .bind(Foo::Bar) - .fetch_one(&mut conn).await?; + .fetch_one(&mut conn) + .await?; assert_eq!(foo, Foo::Bar); @@ -749,4 +752,12 @@ async fn test_enum_with_schema() -> anyhow::Result<()> { .await?; assert_eq!(foo, Foo::Baz); -} \ No newline at end of file + + let foos: Vec = sqlx::query_scalar!("SELECT ARRAY($1::foo.\"Foo\", $2::foo.\"Foo\")") + .bind(Foo::Bar) + .bind(Foo::Baz) + .fetch_one(&mut conn) + .await?; + + assert_eq!(foos, [Foo::Bar, Foo::Baz]); +} From 95351bbf732d29205197919aec694535221cd31e Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Fri, 31 May 2024 13:37:43 -0700 Subject: [PATCH 3/6] fix: run `rustfmt` --- sqlx-postgres/src/connection/describe.rs | 9 ++++++++- sqlx-postgres/src/copy.rs | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/sqlx-postgres/src/connection/describe.rs b/sqlx-postgres/src/connection/describe.rs index 1c5816b4a1..1eedd44656 100644 --- a/sqlx-postgres/src/connection/describe.rs +++ b/sqlx-postgres/src/connection/describe.rs @@ -186,7 +186,14 @@ impl PgConnection { fn fetch_type_by_oid(&mut self, oid: Oid) -> BoxFuture<'_, Result> { Box::pin(async move { - let (name, typ_type, category, relation_id, element, base_type): (String, i8, i8, Oid, Oid, Oid) = query_as( + let (name, typ_type, category, relation_id, element, base_type): ( + String, + i8, + i8, + Oid, + Oid, + Oid, + ) = query_as( // Converting the OID to `regtype` and then `text` will give us the name that // the type will need to be found at by search_path. "SELECT oid::regtype::text, \ diff --git a/sqlx-postgres/src/copy.rs b/sqlx-postgres/src/copy.rs index d2fe7215d6..26e02e3b67 100644 --- a/sqlx-postgres/src/copy.rs +++ b/sqlx-postgres/src/copy.rs @@ -216,7 +216,7 @@ impl> PgCopyIn { let buf = conn.stream.write_buffer_mut(); // Write the CopyData format code and reserve space for the length. - // This may end up sending an empty `CopyData` packet if, after this point, + // This may end up sending an empty `CopyData` packet if, after this point, // we get canceled or read 0 bytes, but that should be fine. buf.put_slice(b"d\0\0\0\x04"); From 501d8ad2e677b8b4c4e1194e2b5e9251bf8692d3 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Fri, 31 May 2024 13:45:21 -0700 Subject: [PATCH 4/6] fix compilation errors --- sqlx-macros-core/src/derives/attributes.rs | 6 ------ sqlx-macros-core/src/derives/type.rs | 10 +++++----- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/sqlx-macros-core/src/derives/attributes.rs b/sqlx-macros-core/src/derives/attributes.rs index 9d7be0e85f..90b1f0b1c4 100644 --- a/sqlx-macros-core/src/derives/attributes.rs +++ b/sqlx-macros-core/src/derives/attributes.rs @@ -218,12 +218,6 @@ pub fn check_enum_attributes(input: &DeriveInput) -> syn::Result, ) -> syn::Result { - let attr = check_weak_enum_attributes(input, variants)?; - let repr = attr.repr.unwrap(); + let attrs = check_weak_enum_attributes(input, variants)?; + let repr = attrs.repr.unwrap(); let ident = &input.ident; let mut ts = quote!( #[automatically_derived] @@ -146,11 +146,11 @@ fn expand_derive_has_sql_type_weak_enum( } ); - if cfg!(feature = "postgres") && !attributes.no_pg_array { + if cfg!(feature = "postgres") && !attrs.no_pg_array { ts.extend(quote!( impl ::sqlx::postgres::PgHasArrayType for #ident { fn array_type_info() -> ::sqlx::postgres::PgTypeInfo { - <#ty as ::sqlx::postgres::PgHasArrayType>::array_type_info() + <#ident as ::sqlx::postgres::PgHasArrayType>::array_type_info() } } )); @@ -199,7 +199,7 @@ fn expand_derive_has_sql_type_strong_enum( tts.extend(quote!( impl ::sqlx::postgres::PgHasArrayType for #ident { fn array_type_info() -> ::sqlx::postgres::PgTypeInfo { - <#ty as ::sqlx::postgres::PgHasArrayType>::array_type_info() + <#ident as ::sqlx::postgres::PgHasArrayType>::array_type_info() } } )); From 82211ae186ee0a5069dc491107abc5b65a05548f Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Fri, 31 May 2024 13:54:12 -0700 Subject: [PATCH 5/6] fix trailing line break --- tests/postgres/setup.sql | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/postgres/setup.sql b/tests/postgres/setup.sql index 425bd4c534..ba38d7c182 100644 --- a/tests/postgres/setup.sql +++ b/tests/postgres/setup.sql @@ -54,4 +54,4 @@ CREATE TABLE test_citext ( CREATE SCHEMA IF NOT EXISTS foo; -CREATE ENUM foo."Foo" ('Bar', 'Baz'); \ No newline at end of file +CREATE ENUM foo."Foo" ('Bar', 'Baz'); From 6b427abea6b855e3bea31ef56b2b1751cd385699 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Fri, 31 May 2024 14:35:53 -0700 Subject: [PATCH 6/6] fix(postgres): case-aware type name equality --- sqlx-postgres/src/type_info.rs | 98 +++++++++++++++++++++++++++++++++- 1 file changed, 97 insertions(+), 1 deletion(-) diff --git a/sqlx-postgres/src/type_info.rs b/sqlx-postgres/src/type_info.rs index 5952291e6e..d9d42a24c5 100644 --- a/sqlx-postgres/src/type_info.rs +++ b/sqlx-postgres/src/type_info.rs @@ -1154,7 +1154,103 @@ impl PartialEq for PgType { true } else { // Otherwise, perform a match on the name - self.name().eq_ignore_ascii_case(other.name()) + name_eq(self.name(), other.name()) } } } + +/// Check type names for equality, respecting Postgres' case sensitivity rules for identifiers. +/// +/// https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS +fn name_eq(name1: &str, name2: &str) -> bool { + // Cop-out of processing Unicode escapes by just using string equality. + if name1.starts_with("U&") { + // If `name2` doesn't start with `U&` this will automatically be `false`. + return name1 == name2; + } + + let mut chars1 = identifier_chars(name1); + let mut chars2 = identifier_chars(name2); + + while let (Some(a), Some(b)) = (chars1.next(), chars2.next()) { + if !a.eq(&b) { + return false; + } + } + + chars1.next().is_none() && chars2.next().is_none() +} + +struct IdentifierChar { + ch: char, + case_sensitive: bool, +} + +impl IdentifierChar { + fn eq(&self, other: &Self) -> bool { + if self.case_sensitive || other.case_sensitive { + self.ch == other.ch + } else { + self.ch.eq_ignore_ascii_case(&other.ch) + } + } +} + +/// Return an iterator over all significant characters of an identifier. +/// +/// Ignores non-escaped quotation marks. +fn identifier_chars(ident: &str) -> impl Iterator + '_ { + let mut case_sensitive = false; + let mut last_char_quote = false; + + ident.chars().filter_map(move |ch| { + if ch == '"' { + if last_char_quote { + last_char_quote = false; + } else { + last_char_quote = true; + return None; + } + } else if last_char_quote { + last_char_quote = false; + case_sensitive = !case_sensitive; + } + + Some(IdentifierChar { ch, case_sensitive }) + }) +} + +#[test] +fn test_name_eq() { + let test_values = [ + ("foo", "foo", true), + ("foo", "Foo", true), + ("foo", "FOO", true), + ("foo", r#""foo""#, true), + ("foo", r#""Foo""#, false), + ("foo", "foo.foo", false), + ("foo.foo", "foo.foo", true), + ("foo.foo", "foo.Foo", true), + ("foo.foo", "foo.FOO", true), + ("foo.foo", "Foo.foo", true), + ("foo.foo", "Foo.Foo", true), + ("foo.foo", "FOO.FOO", true), + ("foo.foo", "foo", false), + ("foo.foo", r#"foo."foo""#, true), + ("foo.foo", r#"foo."Foo""#, false), + ("foo.foo", r#"foo."FOO""#, false), + ]; + + for (left, right, eq) in test_values { + assert_eq!( + name_eq(left, right), + eq, + "failed check for name_eq({left:?}, {right:?})" + ); + assert_eq!( + name_eq(right, left), + eq, + "failed check for name_eq({right:?}, {left:?})" + ); + } +}