From 90dd39a75527294921a1c09e6338ba17ed8fc830 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Thu, 15 Aug 2024 02:17:00 -0700 Subject: [PATCH 01/40] fix(ci): enable unit-tests for all relevant packages --- .github/workflows/sqlx.yml | 46 +++++++++++++++++++++++++++++++------- 1 file changed, 38 insertions(+), 8 deletions(-) diff --git a/.github/workflows/sqlx.yml b/.github/workflows/sqlx.yml index 04b0ae41e6..a7d93f5e6c 100644 --- a/.github/workflows/sqlx.yml +++ b/.github/workflows/sqlx.yml @@ -61,12 +61,8 @@ jobs: - run: cargo build --all-features test: - name: Unit Test + name: Unit Tests runs-on: ubuntu-22.04 - strategy: - matrix: - runtime: [async-std, tokio] - tls: [native-tls, rustls-aws-lc-rs, rustls-ring, none] steps: - uses: actions/checkout@v4 @@ -74,10 +70,44 @@ jobs: with: key: ${{ runner.os }}-test - - run: > + - name: Install Rust + run: rustup update + + - name: Test sqlx-core + run: > + cargo test + -p sqlx-core + --all-features + + - name: Test sqlx-mysql + run: > + cargo test + -p sqlx-mysql + --all-features + + - name: Test sqlx-postgres + run: > + cargo test + -p sqlx-postgres + --all-features + + - name: Test sqlx-sqlite + run: > + cargo test + -p sqlx-sqlite + --all-features + + - name: Test sqlx-macros-core + run: > + cargo test + -p sqlx-macros-core + --all-features + + - name: Test sqlx + run: > cargo test - --manifest-path sqlx-core/Cargo.toml - --features json,_rt-${{ matrix.runtime }},_tls-${{ matrix.tls }} + -p sqlx + --all-features sqlite: name: SQLite From 9bef3973b2474bf753d6983e9410f0b33e1bc375 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Thu, 15 Aug 2024 03:30:23 -0700 Subject: [PATCH 02/40] fix: tests in `sqlx-postgres` --- sqlx-core/src/lib.rs | 2 + sqlx-postgres/Cargo.toml | 17 +++++--- sqlx-postgres/src/advisory_lock.rs | 1 - sqlx-postgres/src/connection/describe.rs | 54 +++++++++++++----------- sqlx-postgres/src/lib.rs | 2 + sqlx-postgres/src/listener.rs | 14 +++--- sqlx-postgres/src/migrate.rs | 1 + 
sqlx-postgres/src/options/mod.rs | 3 +- sqlx-postgres/src/type_info.rs | 2 +- 9 files changed, 55 insertions(+), 41 deletions(-) diff --git a/sqlx-core/src/lib.rs b/sqlx-core/src/lib.rs index cc0122c907..8636760401 100644 --- a/sqlx-core/src/lib.rs +++ b/sqlx-core/src/lib.rs @@ -17,6 +17,8 @@ #![allow(clippy::needless_doctest_main, clippy::type_complexity)] // See `clippy.toml` at the workspace root #![deny(clippy::disallowed_methods)] +#![deny(clippy::cast_possible_truncation)] +#![deny(clippy::cast_possible_wrap)] // The only unsafe code in SQLx is that necessary to interact with native APIs like with SQLite, // and that can live in its own separate driver crate. #![forbid(unsafe_code)] diff --git a/sqlx-postgres/Cargo.toml b/sqlx-postgres/Cargo.toml index 13dac25868..6534592d27 100644 --- a/sqlx-postgres/Cargo.toml +++ b/sqlx-postgres/Cargo.toml @@ -15,9 +15,15 @@ json = ["sqlx-core/json"] migrate = ["sqlx-core/migrate"] offline = ["sqlx-core/offline"] -# Type integration features which require additional dependencies -rust_decimal = ["dep:rust_decimal", "rust_decimal/maths"] -bigdecimal = ["dep:bigdecimal", "dep:num-bigint"] +# Type Integration features +bigdecimal = ["dep:bigdecimal", "dep:num-bigint", "sqlx-core/bigdecimal"] +bit-vec = ["dep:bit-vec", "sqlx-core/bit-vec"] +chrono = ["dep:chrono", "sqlx-core/chrono"] +ipnetwork = ["dep:ipnetwork", "sqlx-core/ipnetwork"] +mac_address = ["dep:mac_address", "sqlx-core/mac_address"] +rust_decimal = ["dep:rust_decimal", "rust_decimal/maths", "sqlx-core/rust_decimal"] +time = ["dep:time", "sqlx-core/time"] +uuid = ["dep:uuid", "sqlx-core/uuid"] [dependencies] # Futures crates @@ -71,8 +77,9 @@ workspace = true # We use JSON in the driver implementation itself so there's no reason not to enable it here. 
features = ["json"] -[dev-dependencies] -sqlx.workspace = true +[dev-dependencies.sqlx] +workspace = true +features = ["postgres", "derive"] [target.'cfg(target_os = "windows")'.dependencies] etcetera = "0.8.0" diff --git a/sqlx-postgres/src/advisory_lock.rs b/sqlx-postgres/src/advisory_lock.rs index 82191726f2..982744137f 100644 --- a/sqlx-postgres/src/advisory_lock.rs +++ b/sqlx-postgres/src/advisory_lock.rs @@ -98,7 +98,6 @@ impl PgAdvisoryLock { /// [hkdf]: https://datatracker.ietf.org/doc/html/rfc5869 /// ### Example /// ```rust - /// # extern crate sqlx_core as sqlx; /// use sqlx::postgres::{PgAdvisoryLock, PgAdvisoryLockKey}; /// /// let lock = PgAdvisoryLock::new("my first Postgres advisory lock!"); diff --git a/sqlx-postgres/src/connection/describe.rs b/sqlx-postgres/src/connection/describe.rs index 82bba18f60..a579f9217a 100644 --- a/sqlx-postgres/src/connection/describe.rs +++ b/sqlx-postgres/src/connection/describe.rs @@ -2,17 +2,17 @@ use crate::error::Error; use crate::ext::ustr::UStr; use crate::message::{ParameterDescription, RowDescription}; use crate::query_as::query_as; -use crate::query_scalar::{query_scalar, query_scalar_with}; +use crate::query_scalar::{query_scalar}; use crate::statement::PgStatementMetadata; use crate::type_info::{PgArrayOf, PgCustomType, PgType, PgTypeKind}; use crate::types::Json; use crate::types::Oid; use crate::HashMap; -use crate::{PgArguments, PgColumn, PgConnection, PgTypeInfo}; +use crate::{PgColumn, PgConnection, PgTypeInfo}; use futures_core::future::BoxFuture; use smallvec::SmallVec; -use std::fmt::Write; use std::sync::Arc; +use sqlx_core::query_builder::QueryBuilder; /// Describes the type of the `pg_type.typtype` column /// @@ -423,29 +423,34 @@ WHERE rngtypid = $1 return Ok(vec![]); } - let mut nullable_query = String::from("SELECT NOT pg_attribute.attnotnull FROM (VALUES "); - let mut args = PgArguments::default(); - - for (i, (column, bind)) in meta.columns.iter().zip((1..).step_by(3)).enumerate() { - if 
!args.buffer.is_empty() { - nullable_query += ", "; - } - - let _ = write!( - nullable_query, - "(${}::int4, ${}::int4, ${}::int2)", - bind, - bind + 1, - bind + 2 + if meta.columns.len() * 3 > 65535 { + tracing::debug!( + ?stmt_id, + num_columns=meta.columns.len(), + "number of columns in query is too large to pull nullability for" ); - - args.add(i as i32).map_err(Error::Encode)?; - args.add(column.relation_id).map_err(Error::Encode)?; - args.add(column.relation_attribute_no) - .map_err(Error::Encode)?; } - nullable_query.push_str( + // Query for NOT NULL constraints for each column in the query. + // + // This will include columns that don't have a `relation_id` (are not from a table); + // assuming those are a minority of columns, it's less code to _not_ work around it + // and just let Postgres return `NULL`. + let mut nullable_query = QueryBuilder::new( + "SELECT NOT pg_attribute.attnotnull FROM ( " + ); + + nullable_query.push_values( + meta.columns.iter().zip(0i32..), + |mut tuple, (column, i)| { + // ({i}::int4, {column.relation_id}::int4, {column.relation_attribute_no}::int2) + tuple.push_bind(i).push_unseparated("::int4"); + tuple.push_bind(column.relation_id).push_unseparated("::int4"); + tuple.push_bind(column.relation_attribute_no).push_bind_unseparated("::int2"); + }, + ); + + nullable_query.push( ") as col(idx, table_id, col_idx) \ LEFT JOIN pg_catalog.pg_attribute \ ON table_id IS NOT NULL \ @@ -454,7 +459,8 @@ WHERE rngtypid = $1 ORDER BY col.idx", ); - let mut nullables = query_scalar_with::<_, Option, _>(&nullable_query, args) + let mut nullables: Vec> = nullable_query + .build_query_scalar() .fetch_all(&mut *self) .await?; diff --git a/sqlx-postgres/src/lib.rs b/sqlx-postgres/src/lib.rs index c50f53067e..2423acb8f5 100644 --- a/sqlx-postgres/src/lib.rs +++ b/sqlx-postgres/src/lib.rs @@ -1,4 +1,6 @@ //! **PostgreSQL** database driver. 
+#![deny(clippy::cast_possible_truncation)] +#![deny(clippy::cast_possible_wrap)] #[macro_use] extern crate sqlx_core; diff --git a/sqlx-postgres/src/listener.rs b/sqlx-postgres/src/listener.rs index f23b81498a..ca4f78a275 100644 --- a/sqlx-postgres/src/listener.rs +++ b/sqlx-postgres/src/listener.rs @@ -188,19 +188,17 @@ impl PgListener { /// # Example /// /// ```rust,no_run - /// # use sqlx_core::postgres::PgListener; - /// # use sqlx_core::error::Error; + /// # use sqlx::postgres::PgListener; /// # - /// # #[cfg(feature = "_rt")] /// # sqlx::__rt::test_block_on(async move { - /// # let mut listener = PgListener::connect("postgres:// ...").await?; + /// let mut listener = PgListener::connect("postgres:// ...").await?; /// loop { /// // ask for next notification, re-connecting (transparently) if needed /// let notification = listener.recv().await?; /// /// // handle notification, do something interesting /// } - /// # Result::<(), Error>::Ok(()) + /// # Result::<(), sqlx::Error>::Ok(()) /// # }).unwrap(); /// ``` pub async fn recv(&mut self) -> Result { @@ -219,10 +217,8 @@ impl PgListener { /// # Example /// /// ```rust,no_run - /// # use sqlx_core::postgres::PgListener; - /// # use sqlx_core::error::Error; + /// # use sqlx::postgres::PgListener; /// # - /// # #[cfg(feature = "_rt")] /// # sqlx::__rt::test_block_on(async move { /// # let mut listener = PgListener::connect("postgres:// ...").await?; /// loop { @@ -233,7 +229,7 @@ impl PgListener { /// /// // connection lost, do something interesting /// } - /// # Result::<(), Error>::Ok(()) + /// # Result::<(), sqlx::Error>::Ok(()) /// # }).unwrap(); /// ``` pub async fn try_recv(&mut self) -> Result, Error> { diff --git a/sqlx-postgres/src/migrate.rs b/sqlx-postgres/src/migrate.rs index 5e62a6287b..da3080581e 100644 --- a/sqlx-postgres/src/migrate.rs +++ b/sqlx-postgres/src/migrate.rs @@ -230,6 +230,7 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( let elapsed = start.elapsed(); // language=SQL + 
#[allow(clippy::cast_possible_truncation)] let _ = query( r#" UPDATE _sqlx_migrations diff --git a/sqlx-postgres/src/options/mod.rs b/sqlx-postgres/src/options/mod.rs index b99edf67c9..a0b222606a 100644 --- a/sqlx-postgres/src/options/mod.rs +++ b/sqlx-postgres/src/options/mod.rs @@ -82,7 +82,8 @@ mod ssl_mode; /// // Information about SQL queries is logged at `DEBUG` level by default. /// opts = opts.log_statements(log::LevelFilter::Trace); /// -/// let pool = PgPool::connect_with(&opts).await?; +/// let pool = PgPool::connect_with(opts).await?; +/// # Ok(()) /// # } /// ``` #[derive(Debug, Clone)] diff --git a/sqlx-postgres/src/type_info.rs b/sqlx-postgres/src/type_info.rs index f50ea7fb10..3d948f73d4 100644 --- a/sqlx-postgres/src/type_info.rs +++ b/sqlx-postgres/src/type_info.rs @@ -294,7 +294,7 @@ impl PgTypeInfo { /// in quotes, e.g.: /// ``` /// use sqlx::postgres::PgTypeInfo; - /// use sqlx::Type; + /// use sqlx::{Type, TypeInfo}; /// /// /// `CREATE TYPE "_foo" AS ENUM ('Bar', 'Baz');` /// #[derive(sqlx::Type)] From 5c595cb46ad3a47cd2e8485d8e07722445dee7eb Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Thu, 15 Aug 2024 02:50:04 -0700 Subject: [PATCH 03/40] fix(postgres): prevent integer overflow when binding arguments --- sqlx-postgres/src/arguments.rs | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/sqlx-postgres/src/arguments.rs b/sqlx-postgres/src/arguments.rs index 7911a0660d..23859fbb79 100644 --- a/sqlx-postgres/src/arguments.rs +++ b/sqlx-postgres/src/arguments.rs @@ -79,7 +79,8 @@ impl PgArguments { // encode the value into our buffer if let Err(error) = self.buffer.encode(value) { - // reset the value buffer to its previous value if encoding failed so we don't leave a half-encoded value behind + // reset the value buffer to its previous value if encoding failed, + // so we don't leave a half-encoded value behind self.buffer.reset_to_snapshot(buffer_snapshot); return Err(error); }; @@ -154,13 +155,18 @@ impl 
PgArgumentBuffer { where T: Encode<'q, Postgres>, { + // Won't catch everything but is a good sanity check + value_size_int4_checked(value.size_hint())?; + // reserve space to write the prefixed length of the value let offset = self.len(); + self.extend(&[0; 4]); // encode the value into our buffer let len = if let IsNull::No = value.encode(self)? { - (self.len() - offset - 4) as i32 + // Ensure that the value size does not overflow i32 + value_size_int4_checked(self.len() - offset - 4)? } else { // Write a -1 to indicate NULL // NOTE: It is illegal for [encode] to write any data @@ -169,6 +175,7 @@ impl PgArgumentBuffer { }; // write the len to the beginning of the value + // (offset + 4) cannot overflow because it would have failed at `self.extend()`. self[offset..(offset + 4)].copy_from_slice(&len.to_be_bytes()); Ok(()) @@ -265,3 +272,8 @@ impl DerefMut for PgArgumentBuffer { &mut self.buffer } } + +pub(crate) fn value_size_int4_checked(size: usize) -> Result { + i32::try_from(size) + .map_err(|_| format!("value size would overflow in the binary protocol encoding: {size} > {}", i32::MAX)) +} From 2ed868d4d3b26bc604948a3643ecabdd97d65d1a Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Thu, 15 Aug 2024 03:29:32 -0700 Subject: [PATCH 04/40] fix: audit `sqlx_postgres::types::rust_decimal` for overflowing casts --- sqlx-postgres/src/types/numeric.rs | 8 ++ sqlx-postgres/src/types/rust_decimal.rs | 146 +++++++++++++++++------- 2 files changed, 113 insertions(+), 41 deletions(-) diff --git a/sqlx-postgres/src/types/numeric.rs b/sqlx-postgres/src/types/numeric.rs index b281de46fb..6416872913 100644 --- a/sqlx-postgres/src/types/numeric.rs +++ b/sqlx-postgres/src/types/numeric.rs @@ -75,6 +75,14 @@ impl PgNumericSign { } impl PgNumeric { + /// Equivalent value of `0::numeric`. 
+ pub const ZERO: Self = PgNumeric::Number { + sign: PgNumericSign::Positive, + digits: vec![], + weight: 0, + scale: 0, + }; + pub(crate) fn decode(mut buf: &[u8]) -> Result { // https://github.com/postgres/postgres/blob/bcd1c3630095e48bc3b1eb0fc8e8c8a7c851eba1/src/backend/utils/adt/numeric.c#L874 let num_digits = buf.get_u16(); diff --git a/sqlx-postgres/src/types/rust_decimal.rs b/sqlx-postgres/src/types/rust_decimal.rs index fa66eb393b..d94dfe34cd 100644 --- a/sqlx-postgres/src/types/rust_decimal.rs +++ b/sqlx-postgres/src/types/rust_decimal.rs @@ -1,4 +1,4 @@ -use rust_decimal::{prelude::Zero, Decimal}; +use rust_decimal::Decimal; use crate::decode::Decode; use crate::encode::{Encode, IsNull}; @@ -25,9 +25,17 @@ impl TryFrom for Decimal { type Error = BoxDynError; fn try_from(numeric: PgNumeric) -> Result { - let (digits, sign, mut weight, scale) = match numeric { + Decimal::try_from(&numeric) + } +} + +impl TryFrom<&'_ PgNumeric> for Decimal { + type Error = BoxDynError; + + fn try_from(numeric: &'_ PgNumeric) -> Result { + let (digits, sign, mut weight, scale) = match *numeric { PgNumeric::Number { - digits, + ref digits, sign, weight, scale, @@ -40,13 +48,13 @@ impl TryFrom for Decimal { if digits.is_empty() { // Postgres returns an empty digit array for 0 - return Ok(0u64.into()); + return Ok(Decimal::ZERO); } let mut value = Decimal::ZERO; // Sum over `digits`, multiply each by its weight and add it to `value`. - for digit in digits { + for &digit in digits { let mul = Decimal::from(10_000i16) .checked_powi(weight as i64) .ok_or("value not representable as rust_decimal::Decimal")?; @@ -71,40 +79,40 @@ impl TryFrom for Decimal { } } +impl From for PgNumeric { + fn from(value: Decimal) -> Self { + PgNumeric::from(&value) + } +} + // This impl is effectively infallible because `NUMERIC` has a greater range than `Decimal`. impl From<&'_ Decimal> for PgNumeric { + // Impl has been manually validated. 
+ #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)] fn from(decimal: &Decimal) -> Self { - // `Decimal` added `is_zero()` as an inherent method in a more recent version - if Zero::is_zero(decimal) { - PgNumeric::Number { - sign: PgNumericSign::Positive, - scale: 0, - weight: 0, - digits: vec![], - }; + if Decimal::is_zero(decimal) { + return PgNumeric::ZERO; } - let scale = decimal.scale() as u16; + assert!( + (0u32..=28).contains(&decimal.scale()), + "decimal scale out of range {:?}", + decimal.unpack(), + ); - // A serialized version of the decimal number. The resulting byte array - // will have the following representation: - // - // Bytes 1-4: flags - // Bytes 5-8: lo portion of m - // Bytes 9-12: mid portion of m - // Bytes 13-16: high portion of m - let mut mantissa = u128::from_le_bytes(decimal.serialize()); + // Cannot overflow: always in the range [0, 28] + let scale = decimal.scale() as u16; - // chop off the flags - mantissa >>= 32; + let mut mantissa = decimal.mantissa().unsigned_abs(); - // If our scale is not a multiple of 4, we need to go to the next - // multiple. + // If our scale is not a multiple of 4, we need to go to the next multiple. let groups_diff = scale % 4; if groups_diff > 0 { let remainder = 4 - groups_diff as u32; let power = 10u32.pow(remainder) as u128; + // Impossible to overflow; 0 <= mantissa <= 2^96, + // and we're multiplying by at most 1,000 (giving us a result < 2^106) mantissa *= power; } @@ -113,16 +121,32 @@ impl From<&'_ Decimal> for PgNumeric { // Convert to base-10000. while mantissa != 0 { + // Cannot overflow or wrap because of the modulus digits.push((mantissa % 10_000) as i16); mantissa /= 10_000; } - // Change the endianness. + // We started with the low digits first, but they should actually be at the end. digits.reverse(); - // Weight is number of digits on the left side of the decimal. 
- let digits_after_decimal = (scale + 3) / 4; - let weight = digits.len() as i16 - digits_after_decimal as i16 - 1; + // Cannot overflow: strictly smaller than `scale`. + let digits_after_decimal = scale.div_ceil(4) as i16; + + // `mantissa` contains at most 29 decimal digits (log10(2^96)), + // split into at most 8 4-digit segments. + assert!( + digits.len() <= 8, + "digits.len() out of range: {}; unpacked: {:?}", + digits.len(), + decimal.unpack() + ); + + // Cannot overflow; at most 8 + let num_digits = digits.len() as i16; + + // Find how many 4-digit segments should go before the decimal point. + // `weight = 0` puts just `digit[0]` before the decimal point, and the rest after. + let weight = num_digits - digits_after_decimal - 1; // Remove non-significant zeroes. while let Some(&0) = digits.last() { @@ -134,6 +158,7 @@ impl From<&'_ Decimal> for PgNumeric { false => PgNumericSign::Positive, true => PgNumericSign::Negative, }, + // Cannot overflow; between 0 and 28 scale: scale as i16, weight, digits, @@ -160,7 +185,7 @@ impl Decode<'_, Postgres> for Decimal { } #[cfg(test)] -mod decimal_to_pgnumeric { +mod tests { use super::{Decimal, PgNumeric, PgNumericSign}; use std::convert::TryFrom; @@ -169,13 +194,13 @@ mod decimal_to_pgnumeric { let zero: Decimal = "0".parse().unwrap(); assert_eq!( - PgNumeric::try_from(&zero).unwrap(), - PgNumeric::Number { - sign: PgNumericSign::Positive, - scale: 0, - weight: 0, - digits: vec![] - } + PgNumeric::from(&zero), + PgNumeric::ZERO, + ); + + assert_eq!( + Decimal::try_from(&PgNumeric::ZERO).unwrap(), + Decimal::ZERO ); } @@ -343,6 +368,48 @@ mod decimal_to_pgnumeric { assert_eq!(actual_decimal.scale(), 8); } + #[test] + fn max_value() { + let expected_numeric = PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 0, + weight: 7, + digits: vec![7, 9228, 1625, 1426, 4337, 5935, 4395, 0335], + }; + assert_eq!( + PgNumeric::try_from(&Decimal::MAX).unwrap(), + expected_numeric + ); + + let actual_decimal = 
Decimal::try_from(expected_numeric).unwrap(); + assert_eq!(actual_decimal, Decimal::MAX); + // Value split by 10,000's to match the expected digits[] + assert_eq!(actual_decimal.mantissa(), 7_9228_1625_1426_4337_5935_4395_0335); + assert_eq!(actual_decimal.scale(), 0); + } + + #[test] + fn max_value_max_scale() { + let mut max_value_max_scale = Decimal::MAX; + max_value_max_scale.set_scale(28).unwrap(); + + let expected_numeric = PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 28, + weight: 0, + digits: vec![7, 9228, 1625, 1426, 4337, 5935, 4395, 0335], + }; + assert_eq!( + PgNumeric::try_from(&max_value_max_scale).unwrap(), + expected_numeric + ); + + let actual_decimal = Decimal::try_from(expected_numeric).unwrap(); + assert_eq!(actual_decimal, max_value_max_scale); + assert_eq!(actual_decimal.mantissa(), 79_228_162_514_264_337_593_543_950_335); + assert_eq!(actual_decimal.scale(), 28); + } + #[test] fn issue_423_four_digit() { // This is a regression test for https://github.com/launchbadge/sqlx/issues/423 @@ -420,7 +487,4 @@ mod decimal_to_pgnumeric { assert_eq!(actual_decimal.mantissa(), 10000); assert_eq!(actual_decimal.scale(), 2); } - - #[test] - fn issue_666_trailing_zeroes_at_max_precision() {} } From e1f04cb04d49c3fa1290ce9bf7a9790b977a4581 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Thu, 15 Aug 2024 03:29:45 -0700 Subject: [PATCH 05/40] fix: audit `sqlx_postgres::types::bit_vec` for overflowing casts --- sqlx-postgres/src/types/bit_vec.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sqlx-postgres/src/types/bit_vec.rs b/sqlx-postgres/src/types/bit_vec.rs index 6b48c722d1..299aaf828a 100644 --- a/sqlx-postgres/src/types/bit_vec.rs +++ b/sqlx-postgres/src/types/bit_vec.rs @@ -8,6 +8,7 @@ use crate::{ use bit_vec::BitVec; use sqlx_core::bytes::Buf; use std::{io, mem}; +use crate::arguments::value_size_int4_checked; impl Type for BitVec { fn type_info() -> PgTypeInfo { @@ -31,7 +32,9 @@ impl PgHasArrayType for 
BitVec { impl Encode<'_, Postgres> for BitVec { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { - buf.extend(&(self.len() as i32).to_be_bytes()); + let len = value_size_int4_checked(self.len())?; + + buf.extend(len.to_be_bytes()); buf.extend(self.to_bytes()); Ok(IsNull::No) From 74e720e4177ac9909709587d92ac55c64565decf Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Thu, 15 Aug 2024 03:32:14 -0700 Subject: [PATCH 06/40] fix: audit `sqlx_postgres::types::time` for overflowing casts --- sqlx-postgres/src/types/time/date.rs | 7 +++++-- sqlx-postgres/src/types/time/datetime.rs | 4 +++- sqlx-postgres/src/types/time/time.rs | 7 +++++-- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/sqlx-postgres/src/types/time/date.rs b/sqlx-postgres/src/types/time/date.rs index d1a2d34df8..619537c5af 100644 --- a/sqlx-postgres/src/types/time/date.rs +++ b/sqlx-postgres/src/types/time/date.rs @@ -22,8 +22,11 @@ impl PgHasArrayType for Date { impl Encode<'_, Postgres> for Date { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { - // DATE is encoded as the days since epoch - let days = (*self - PG_EPOCH).whole_days() as i32; + // DATE is encoded as number of days since epoch (2000-01-01) + let days: i32 = (*self - PG_EPOCH) + .whole_days() + .try_into() + .map_err(|_| format!("value {self:?} would overflow binary encoding for Postgres DATE"))?; Encode::::encode(days, buf) } diff --git a/sqlx-postgres/src/types/time/datetime.rs b/sqlx-postgres/src/types/time/datetime.rs index 3dc9e849f6..d51002286d 100644 --- a/sqlx-postgres/src/types/time/datetime.rs +++ b/sqlx-postgres/src/types/time/datetime.rs @@ -37,7 +37,9 @@ impl PgHasArrayType for OffsetDateTime { impl Encode<'_, Postgres> for PrimitiveDateTime { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { // TIMESTAMP is encoded as the microseconds since the epoch - let micros = (*self - PG_EPOCH.midnight()).whole_microseconds() as i64; + let micros: i64 = (*self - 
PG_EPOCH.midnight()).whole_microseconds() + .try_into() + .map_err(|_| format!("value {self:?} would overflow binary encoding for Postgres TIME"))?; Encode::::encode(micros, buf) } diff --git a/sqlx-postgres/src/types/time/time.rs b/sqlx-postgres/src/types/time/time.rs index 61be6f19f5..635170d14b 100644 --- a/sqlx-postgres/src/types/time/time.rs +++ b/sqlx-postgres/src/types/time/time.rs @@ -21,8 +21,11 @@ impl PgHasArrayType for Time { impl Encode<'_, Postgres> for Time { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { - // TIME is encoded as the microseconds since midnight - let micros = (*self - Time::MIDNIGHT).whole_microseconds() as i64; + // TIME is encoded as the microseconds since midnight. + // + // A truncating cast is fine because `self - Time::MIDNIGHT` cannot exceed a span of 24 hours. + #[allow(clippy::cast_possible_truncation)] + let micros: i64 = (*self - Time::MIDNIGHT).whole_microseconds() as i64; Encode::::encode(micros, buf) } From 13561cdafc426c4cfccffbe631e79bd9a8db65e6 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Thu, 15 Aug 2024 15:09:32 -0700 Subject: [PATCH 07/40] fix: audit `sqlx_postgres::types::chrono` for overflowing casts --- sqlx-postgres/src/types/chrono/date.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sqlx-postgres/src/types/chrono/date.rs b/sqlx-postgres/src/types/chrono/date.rs index 475f41f400..5fc0a8f08d 100644 --- a/sqlx-postgres/src/types/chrono/date.rs +++ b/sqlx-postgres/src/types/chrono/date.rs @@ -23,7 +23,11 @@ impl PgHasArrayType for NaiveDate { impl Encode<'_, Postgres> for NaiveDate { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { // DATE is encoded as the days since epoch - let days = (*self - postgres_epoch_date()).num_days() as i32; + let days: i32 = (*self - postgres_epoch_date()) + .num_days() + .try_into() + .map_err(|_| format!("value {self:?} would overflow binary encoding for Postgres DATE"))?; + Encode::::encode(days, buf) } From 
6599f1cda61c289f67031b7e4a935e38c4b9f112 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Thu, 15 Aug 2024 22:15:24 -0700 Subject: [PATCH 08/40] fix: audit `sqlx_postgres::types::cube` for overflowing casts --- sqlx-postgres/src/lib.rs | 2 + sqlx-postgres/src/types/cube.rs | 325 +++++++++++++++++++++----------- 2 files changed, 219 insertions(+), 108 deletions(-) diff --git a/sqlx-postgres/src/lib.rs b/sqlx-postgres/src/lib.rs index 2423acb8f5..2bfc30d88e 100644 --- a/sqlx-postgres/src/lib.rs +++ b/sqlx-postgres/src/lib.rs @@ -1,6 +1,8 @@ //! **PostgreSQL** database driver. +// https://github.com/launchbadge/sqlx/issues/3440 #![deny(clippy::cast_possible_truncation)] #![deny(clippy::cast_possible_wrap)] +#![deny(clippy::cast_sign_loss)] #[macro_use] extern crate sqlx_core; diff --git a/sqlx-postgres/src/types/cube.rs b/sqlx-postgres/src/types/cube.rs index bf778b8b91..a489e31d80 100644 --- a/sqlx-postgres/src/types/cube.rs +++ b/sqlx-postgres/src/types/cube.rs @@ -3,24 +3,52 @@ use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::types::Type; use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; +use sqlx_core::bytes::Buf; use sqlx_core::Error; use std::str::FromStr; const BYTE_WIDTH: usize = 8; -const CUBE_TYPE_ZERO_VOLUME: usize = 128; -const CUBE_TYPE_DEFAULT: usize = 0; -const CUBE_DIMENSION_ONE: usize = 1; -const DIMENSIONALITY_POSITION: usize = 3; -const START_INDEX: usize = 4; +/// +const MAX_DIMENSIONS: usize = 100; + +const IS_POINT_FLAG: u32 = 1 << 31; + +// FIXME(breaking): these variants are confusingly named and structured +// consider changing them or making this an opaque wrapper around `Vec` #[derive(Debug, Clone, PartialEq)] pub enum PgCube { + /// A one-dimensional point. + // FIXME: `Point1D(f64) Point(f64), + /// An N-dimensional point ("represented internally as a zero-volume cube"). 
+ // FIXME: `PointND(f64)` ZeroVolume(Vec), + + /// A one-dimensional interval with starting and ending points. + // FIXME: `Interval1D { start: f64, end: f64 }` OneDimensionInterval(f64, f64), + + // FIXME: add `Cube3D { lower_left: [f64; 3], upper_right: [f64; 3] }`? + /// An N-dimensional cube with points representing lower-left and upper-right corners, respectively. + // FIXME: CubeND { lower_left: Vec, upper_right: Vec }` MultiDimension(Vec>), } +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +struct Header { + dimensions: usize, + is_point: bool, +} + +#[derive(Debug, thiserror::Error)] +#[error("error decoding CUBE (is_point: {is_point}, dimensions: {dimensions})")] +struct DecodeError { + is_point: bool, + dimensions: usize, + message: String, +} + impl Type for PgCube { fn type_info() -> PgTypeInfo { PgTypeInfo::with_name("cube") @@ -37,7 +65,7 @@ impl<'r> Decode<'r, Postgres> for PgCube { fn decode(value: PgValueRef<'r>) -> Result> { match value.format() { PgValueFormat::Text => Ok(PgCube::from_str(value.as_str()?)?), - PgValueFormat::Binary => Ok(pg_cube_from_bytes(value.as_bytes()?)?), + PgValueFormat::Binary => Ok(PgCube::from_bytes(value.as_bytes()?)?), } } } @@ -51,6 +79,10 @@ impl<'q> Encode<'q, Postgres> for PgCube { self.serialize(buf)?; Ok(IsNull::No) } + + fn size_hint(&self) -> usize { + self.header().encoded_size() + } } impl FromStr for PgCube { @@ -81,86 +113,84 @@ impl FromStr for PgCube { } } -fn pg_cube_from_bytes(bytes: &[u8]) -> Result { - let cube_type = bytes - .first() - .map(|&byte| byte as usize) - .ok_or(Error::Decode( - format!("Could not decode cube bytes: {:?}", bytes).into(), - ))?; - - let dimensionality = bytes - .get(DIMENSIONALITY_POSITION) - .map(|&byte| byte as usize) - .ok_or(Error::Decode( - format!("Could not decode cube bytes: {:?}", bytes).into(), - ))?; - - match (cube_type, dimensionality) { - (CUBE_TYPE_ZERO_VOLUME, CUBE_DIMENSION_ONE) => { - let point = get_f64_from_bytes(bytes, 4)?; - Ok(PgCube::Point(point)) 
+impl PgCube { + fn header(&self) -> Header { + match self { + PgCube::Point(..) => Header { + is_point: true, + dimensions: 1, + }, + PgCube::ZeroVolume(values) => Header { + is_point: true, + dimensions: values.len(), + }, + PgCube::OneDimensionInterval(..) => Header { + is_point: false, + dimensions: 1, + }, + PgCube::MultiDimension(multi_values) => Header { + is_point: false, + dimensions: multi_values.first().map(|arr| arr.len()).unwrap_or(0), + }, } - (CUBE_TYPE_ZERO_VOLUME, _) => { - Ok(PgCube::ZeroVolume(deserialize_vector(bytes, START_INDEX)?)) + } + + fn from_bytes(mut bytes: &[u8]) -> Result { + let header = Header::try_read(&mut bytes)?; + + if bytes.len() != header.data_size() { + return Err(DecodeError::new( + &header, + format!( + "expected {} bytes after header, got {}", + header.data_size(), + bytes.len() + ), + ) + .into()); } - (CUBE_TYPE_DEFAULT, CUBE_DIMENSION_ONE) => { - let x_start = 4; - let y_start = x_start + BYTE_WIDTH; - let x = get_f64_from_bytes(bytes, x_start)?; - let y = get_f64_from_bytes(bytes, y_start)?; - Ok(PgCube::OneDimensionInterval(x, y)) + + match (header.is_point, header.dimensions) { + (true, 1) => Ok(PgCube::Point(bytes.get_f64())), + (true, _) => Ok(PgCube::ZeroVolume( + read_vec(&mut bytes).map_err(|e| DecodeError::new(&header, e))?, + )), + (false, 1) => Ok(PgCube::OneDimensionInterval( + bytes.get_f64(), + bytes.get_f64(), + )), + (false, _) => Ok(PgCube::MultiDimension(read_cube(&header, bytes)?)), } - (CUBE_TYPE_DEFAULT, dim) => Ok(PgCube::MultiDimension(deserialize_matrix( - bytes, - START_INDEX, - dim, - )?)), - (flag, dimension) => Err(Error::Decode( - format!( - "Could not deserialise cube with flag {} and dimension {}: {:?}", - flag, dimension, bytes - ) - .into(), - )), } -} -impl PgCube { - fn serialize(&self, buff: &mut PgArgumentBuffer) -> Result<(), Error> { + fn serialize(&self, buff: &mut PgArgumentBuffer) -> Result<(), String> { + let header = self.header(); + + buff.reserve(header.data_size()); + + 
header.try_write(buff)?; + match self { PgCube::Point(value) => { - buff.extend(&[CUBE_TYPE_ZERO_VOLUME as u8, 0, 0, CUBE_DIMENSION_ONE as u8]); buff.extend_from_slice(&value.to_be_bytes()); } PgCube::ZeroVolume(values) => { - let dimension = values.len() as u8; - buff.extend_from_slice(&[CUBE_TYPE_ZERO_VOLUME as u8, 0, 0]); - buff.extend_from_slice(&dimension.to_be_bytes()); - let bytes = values - .iter() - .flat_map(|v| v.to_be_bytes()) - .collect::>(); - buff.extend_from_slice(&bytes); + buff.extend(values.iter().flat_map(|v| v.to_be_bytes())); } PgCube::OneDimensionInterval(x, y) => { - buff.extend_from_slice(&[0, 0, 0, CUBE_DIMENSION_ONE as u8]); buff.extend_from_slice(&x.to_be_bytes()); buff.extend_from_slice(&y.to_be_bytes()); } PgCube::MultiDimension(multi_values) => { - let dimension = multi_values - .first() - .map(|arr| arr.len() as u8) - .unwrap_or(1_u8); - buff.extend_from_slice(&[0, 0, 0]); - buff.extend_from_slice(&dimension.to_be_bytes()); - let bytes = multi_values - .iter() - .flatten() - .flat_map(|v| v.to_be_bytes()) - .collect::>(); - buff.extend_from_slice(&bytes); + if multi_values.len() != 2 { + return Err(format!("invalid CUBE value: {self:?}")); + } + + buff.extend( + multi_values + .iter() + .flat_map(|point| point.iter().flat_map(|scalar| scalar.to_be_bytes())), + ); } }; Ok(()) @@ -174,41 +204,46 @@ impl PgCube { } } -fn get_f64_from_bytes(bytes: &[u8], start: usize) -> Result { - bytes - .get(start..start + BYTE_WIDTH) - .ok_or(Error::Decode( - format!("Could not decode cube bytes: {:?}", bytes).into(), - ))? 
- .try_into() - .map(f64::from_be_bytes) - .map_err(|err| Error::Decode(format!("Invalid bytes slice: {:?}", err).into())) -} +fn read_vec(bytes: &mut &[u8]) -> Result, String> { + if bytes.len() % BYTE_WIDTH != 0 { + return Err(format!( + "data length not divisible by {BYTE_WIDTH}: {}", + bytes.len() + )); + } + + let mut out = Vec::with_capacity(bytes.len() / BYTE_WIDTH); -fn deserialize_vector(bytes: &[u8], start_index: usize) -> Result, Error> { - let steps = (bytes.len() - start_index) / BYTE_WIDTH; - (0..steps) - .map(|i| get_f64_from_bytes(bytes, start_index + i * BYTE_WIDTH)) - .collect() + while bytes.has_remaining() { + out.push(bytes.get_f64()); + } + + Ok(out) } -fn deserialize_matrix( - bytes: &[u8], - start_index: usize, - dim: usize, -) -> Result>, Error> { - let step = BYTE_WIDTH * dim; - let steps = (bytes.len() - start_index) / step; - - (0..steps) - .map(|step_idx| { - (0..dim) - .map(|dim_idx| { - get_f64_from_bytes(bytes, start_index + step_idx * step + dim_idx * BYTE_WIDTH) - }) - .collect() - }) - .collect() +fn read_cube(header: &Header, mut bytes: &[u8]) -> Result>, String> { + if bytes.len() != header.data_size() { + return Err(format!( + "expected {} bytes, got {}", + header.data_size(), + bytes.len() + )); + } + + let mut out = Vec::with_capacity(2); + + // Expecting exactly 2 N-dimensional points + for _ in 0..2 { + let mut point = Vec::new(); + + for _ in 0..header.dimensions { + point.push(bytes.get_f64()); + } + + out.push(point); + } + + Ok(out) } fn parse_float_from_str(s: &str, error_msg: &str) -> Result { @@ -268,12 +303,86 @@ fn remove_parentheses(s: &str) -> String { s.trim_matches(|c| c == '(' || c == ')').to_string() } +impl Header { + const PACKED_WIDTH: usize = size_of::(); + + fn encoded_size(&self) -> usize { + Self::PACKED_WIDTH + self.data_size() + } + + fn data_size(&self) -> usize { + if self.is_point { + self.dimensions * BYTE_WIDTH + } else { + self.dimensions * BYTE_WIDTH * 2 + } + } + + fn try_write(&self, buff: 
&mut PgArgumentBuffer) -> Result<(), String> { + if self.dimensions > MAX_DIMENSIONS { + return Err(format!( + "CUBE dimensionality exceeds allowed maximum ({} > {MAX_DIMENSIONS})", + self.dimensions + )); + } + + // Cannot overflow thanks to the above check. + #[allow(clippy::cast_possible_truncation)] + let mut packed = self.dimensions as u32; + + // https://github.com/postgres/postgres/blob/e3ec9dc1bf4983fcedb6f43c71ea12ee26aefc7a/contrib/cube/cubedata.h#L18-L24 + if self.is_point { + packed |= IS_POINT_FLAG; + } + + buff.extend(packed.to_be_bytes()); + + Ok(()) + } + + fn try_read(buf: &mut &[u8]) -> Result { + if buf.len() < Self::PACKED_WIDTH { + return Err(format!( + "expected CUBE data to contain at least {} bytes, got {}", + Self::PACKED_WIDTH, + buf.len() + )); + } + + let packed = buf.get_u32(); + + let is_point = packed & IS_POINT_FLAG != 0; + let dimensions = packed & !IS_POINT_FLAG; + + // can only overflow on 16-bit platforms + let dimensions = usize::try_from(dimensions) + .ok() + .filter(|&it| it <= MAX_DIMENSIONS) + .ok_or_else(|| format!("received CUBE data with higher than expected dimensionality: {dimensions} (is_point: {is_point})"))?; + + Ok(Self { + is_point, + dimensions, + }) + } +} + +impl DecodeError { + fn new(header: &Header, message: String) -> Self { + DecodeError { + is_point: header.is_point, + dimensions: header.dimensions, + message, + } + } +} + #[cfg(test)] mod cube_tests { use std::str::FromStr; - use crate::types::{cube::pg_cube_from_bytes, PgCube}; + use super::PgCube; const POINT_BYTES: &[u8] = &[128, 0, 0, 1, 64, 0, 0, 0, 0, 0, 0, 0]; const ZERO_VOLUME_BYTES: &[u8] = &[ @@ -293,7 +402,7 @@ mod cube_tests { #[test] fn can_deserialise_point_type_byes() { - let cube = pg_cube_from_bytes(POINT_BYTES).unwrap(); + let cube = PgCube::from_bytes(POINT_BYTES).unwrap(); assert_eq!(cube, PgCube::Point(2.)) } @@ -311,7 +420,7 @@ mod cube_tests { } #[test] fn can_deserialise_zero_volume_bytes() { - let cube = 
pg_cube_from_bytes(ZERO_VOLUME_BYTES).unwrap(); + let cube = PgCube::from_bytes(ZERO_VOLUME_BYTES).unwrap(); assert_eq!(cube, PgCube::ZeroVolume(vec![2., 3.])); } @@ -333,7 +442,7 @@ mod cube_tests { #[test] fn can_deserialise_one_dimension_interval_bytes() { - let cube = pg_cube_from_bytes(ONE_DIMENSIONAL_INTERVAL_BYTES).unwrap(); + let cube = PgCube::from_bytes(ONE_DIMENSIONAL_INTERVAL_BYTES).unwrap(); assert_eq!(cube, PgCube::OneDimensionInterval(7., 8.)) } @@ -355,7 +464,7 @@ mod cube_tests { #[test] fn can_deserialise_multi_dimension_2_dimension_byte() { - let cube = pg_cube_from_bytes(MULTI_DIMENSION_2_DIM_BYTES).unwrap(); + let cube = PgCube::from_bytes(MULTI_DIMENSION_2_DIM_BYTES).unwrap(); assert_eq!( cube, PgCube::MultiDimension(vec![vec![1., 2.], vec![3., 4.]]) @@ -396,7 +505,7 @@ mod cube_tests { #[test] fn can_deserialise_multi_dimension_3_dimension_bytes() { - let cube = pg_cube_from_bytes(MULTI_DIMENSION_3_DIM_BYTES).unwrap(); + let cube = PgCube::from_bytes(MULTI_DIMENSION_3_DIM_BYTES).unwrap(); assert_eq!( cube, PgCube::MultiDimension(vec![vec![2., 3., 4.], vec![5., 6., 7.]]) From 63349ded19a77db317053597336aee81235e0369 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Thu, 15 Aug 2024 22:15:38 -0700 Subject: [PATCH 09/40] chore: run `cargo fmt` --- sqlx-postgres/src/arguments.rs | 8 ++++-- sqlx-postgres/src/connection/describe.rs | 31 ++++++++++++------------ sqlx-postgres/src/types/bit_vec.rs | 2 +- sqlx-postgres/src/types/chrono/date.rs | 4 ++- sqlx-postgres/src/types/cube.rs | 2 +- sqlx-postgres/src/types/rust_decimal.rs | 20 +++++++-------- sqlx-postgres/src/types/time/date.rs | 7 +++--- sqlx-postgres/src/types/time/datetime.rs | 7 ++++-- 8 files changed, 44 insertions(+), 37 deletions(-) diff --git a/sqlx-postgres/src/arguments.rs b/sqlx-postgres/src/arguments.rs index 23859fbb79..2e7d5fd9d4 100644 --- a/sqlx-postgres/src/arguments.rs +++ b/sqlx-postgres/src/arguments.rs @@ -274,6 +274,10 @@ impl DerefMut for PgArgumentBuffer { } 
pub(crate) fn value_size_int4_checked(size: usize) -> Result { - i32::try_from(size) - .map_err(|_| format!("value size would overflow in the binary protocol encoding: {size} > {}", i32::MAX)) + i32::try_from(size).map_err(|_| { + format!( + "value size would overflow in the binary protocol encoding: {size} > {}", + i32::MAX + ) + }) } diff --git a/sqlx-postgres/src/connection/describe.rs b/sqlx-postgres/src/connection/describe.rs index a579f9217a..e53a054a70 100644 --- a/sqlx-postgres/src/connection/describe.rs +++ b/sqlx-postgres/src/connection/describe.rs @@ -2,7 +2,7 @@ use crate::error::Error; use crate::ext::ustr::UStr; use crate::message::{ParameterDescription, RowDescription}; use crate::query_as::query_as; -use crate::query_scalar::{query_scalar}; +use crate::query_scalar::query_scalar; use crate::statement::PgStatementMetadata; use crate::type_info::{PgArrayOf, PgCustomType, PgType, PgTypeKind}; use crate::types::Json; @@ -11,8 +11,8 @@ use crate::HashMap; use crate::{PgColumn, PgConnection, PgTypeInfo}; use futures_core::future::BoxFuture; use smallvec::SmallVec; -use std::sync::Arc; use sqlx_core::query_builder::QueryBuilder; +use std::sync::Arc; /// Describes the type of the `pg_type.typtype` column /// @@ -426,7 +426,7 @@ WHERE rngtypid = $1 if meta.columns.len() * 3 > 65535 { tracing::debug!( ?stmt_id, - num_columns=meta.columns.len(), + num_columns = meta.columns.len(), "number of columns in query is too large to pull nullability for" ); } @@ -436,19 +436,18 @@ WHERE rngtypid = $1 // This will include columns that don't have a `relation_id` (are not from a table); // assuming those are a minority of columns, it's less code to _not_ work around it // and just let Postgres return `NULL`. 
- let mut nullable_query = QueryBuilder::new( - "SELECT NOT pg_attribute.attnotnull FROM ( " - ); - - nullable_query.push_values( - meta.columns.iter().zip(0i32..), - |mut tuple, (column, i)| { - // ({i}::int4, {column.relation_id}::int4, {column.relation_attribute_no}::int2) - tuple.push_bind(i).push_unseparated("::int4"); - tuple.push_bind(column.relation_id).push_unseparated("::int4"); - tuple.push_bind(column.relation_attribute_no).push_bind_unseparated("::int2"); - }, - ); + let mut nullable_query = QueryBuilder::new("SELECT NOT pg_attribute.attnotnull FROM ( "); + + nullable_query.push_values(meta.columns.iter().zip(0i32..), |mut tuple, (column, i)| { + // ({i}::int4, {column.relation_id}::int4, {column.relation_attribute_no}::int2) + tuple.push_bind(i).push_unseparated("::int4"); + tuple + .push_bind(column.relation_id) + .push_unseparated("::int4"); + tuple + .push_bind(column.relation_attribute_no) + .push_bind_unseparated("::int2"); + }); nullable_query.push( ") as col(idx, table_id, col_idx) \ diff --git a/sqlx-postgres/src/types/bit_vec.rs b/sqlx-postgres/src/types/bit_vec.rs index 299aaf828a..dfc3b16922 100644 --- a/sqlx-postgres/src/types/bit_vec.rs +++ b/sqlx-postgres/src/types/bit_vec.rs @@ -1,3 +1,4 @@ +use crate::arguments::value_size_int4_checked; use crate::{ decode::Decode, encode::{Encode, IsNull}, @@ -8,7 +9,6 @@ use crate::{ use bit_vec::BitVec; use sqlx_core::bytes::Buf; use std::{io, mem}; -use crate::arguments::value_size_int4_checked; impl Type for BitVec { fn type_info() -> PgTypeInfo { diff --git a/sqlx-postgres/src/types/chrono/date.rs b/sqlx-postgres/src/types/chrono/date.rs index 5fc0a8f08d..0327d5c45d 100644 --- a/sqlx-postgres/src/types/chrono/date.rs +++ b/sqlx-postgres/src/types/chrono/date.rs @@ -26,7 +26,9 @@ impl Encode<'_, Postgres> for NaiveDate { let days: i32 = (*self - postgres_epoch_date()) .num_days() .try_into() - .map_err(|_| format!("value {self:?} would overflow binary encoding for Postgres DATE"))?; + .map_err(|_| 
{ + format!("value {self:?} would overflow binary encoding for Postgres DATE") + })?; Encode::::encode(days, buf) } diff --git a/sqlx-postgres/src/types/cube.rs b/sqlx-postgres/src/types/cube.rs index a489e31d80..4247798490 100644 --- a/sqlx-postgres/src/types/cube.rs +++ b/sqlx-postgres/src/types/cube.rs @@ -147,7 +147,7 @@ impl PgCube { bytes.len() ), ) - .into()); + .into()); } match (header.is_point, header.dimensions) { diff --git a/sqlx-postgres/src/types/rust_decimal.rs b/sqlx-postgres/src/types/rust_decimal.rs index d94dfe34cd..83a2d0e08a 100644 --- a/sqlx-postgres/src/types/rust_decimal.rs +++ b/sqlx-postgres/src/types/rust_decimal.rs @@ -193,15 +193,9 @@ mod tests { fn zero() { let zero: Decimal = "0".parse().unwrap(); - assert_eq!( - PgNumeric::from(&zero), - PgNumeric::ZERO, - ); + assert_eq!(PgNumeric::from(&zero), PgNumeric::ZERO,); - assert_eq!( - Decimal::try_from(&PgNumeric::ZERO).unwrap(), - Decimal::ZERO - ); + assert_eq!(Decimal::try_from(&PgNumeric::ZERO).unwrap(), Decimal::ZERO); } #[test] @@ -384,7 +378,10 @@ mod tests { let actual_decimal = Decimal::try_from(expected_numeric).unwrap(); assert_eq!(actual_decimal, Decimal::MAX); // Value split by 10,000's to match the expected digits[] - assert_eq!(actual_decimal.mantissa(), 7_9228_1625_1426_4337_5935_4395_0335); + assert_eq!( + actual_decimal.mantissa(), + 7_9228_1625_1426_4337_5935_4395_0335 + ); assert_eq!(actual_decimal.scale(), 0); } @@ -406,7 +403,10 @@ mod tests { let actual_decimal = Decimal::try_from(expected_numeric).unwrap(); assert_eq!(actual_decimal, max_value_max_scale); - assert_eq!(actual_decimal.mantissa(), 79_228_162_514_264_337_593_543_950_335); + assert_eq!( + actual_decimal.mantissa(), + 79_228_162_514_264_337_593_543_950_335 + ); assert_eq!(actual_decimal.scale(), 28); } diff --git a/sqlx-postgres/src/types/time/date.rs b/sqlx-postgres/src/types/time/date.rs index 619537c5af..2afa57ee0d 100644 --- a/sqlx-postgres/src/types/time/date.rs +++ 
b/sqlx-postgres/src/types/time/date.rs @@ -23,10 +23,9 @@ impl PgHasArrayType for Date { impl Encode<'_, Postgres> for Date { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { // DATE is encoded as number of days since epoch (2000-01-01) - let days: i32 = (*self - PG_EPOCH) - .whole_days() - .try_into() - .map_err(|_| format!("value {self:?} would overflow binary encoding for Postgres DATE"))?; + let days: i32 = (*self - PG_EPOCH).whole_days().try_into().map_err(|_| { + format!("value {self:?} would overflow binary encoding for Postgres DATE") + })?; Encode::::encode(days, buf) } diff --git a/sqlx-postgres/src/types/time/datetime.rs b/sqlx-postgres/src/types/time/datetime.rs index d51002286d..3484116bd1 100644 --- a/sqlx-postgres/src/types/time/datetime.rs +++ b/sqlx-postgres/src/types/time/datetime.rs @@ -37,9 +37,12 @@ impl PgHasArrayType for OffsetDateTime { impl Encode<'_, Postgres> for PrimitiveDateTime { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { // TIMESTAMP is encoded as the microseconds since the epoch - let micros: i64 = (*self - PG_EPOCH.midnight()).whole_microseconds() + let micros: i64 = (*self - PG_EPOCH.midnight()) + .whole_microseconds() .try_into() - .map_err(|_| format!("value {self:?} would overflow binary encoding for Postgres TIME"))?; + .map_err(|_| { + format!("value {self:?} would overflow binary encoding for Postgres TIME") + })?; Encode::::encode(micros, buf) } From dd92def007b3775b51f228ce6d5a2beeb22afc71 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Fri, 16 Aug 2024 13:19:13 -0700 Subject: [PATCH 10/40] fix: audit `PgValueRef::get()` and usage sites for bad casts --- sqlx-postgres/src/types/array.rs | 13 +++++++------ sqlx-postgres/src/types/range.rs | 4 ++-- sqlx-postgres/src/types/record.rs | 2 +- sqlx-postgres/src/value.rs | 26 ++++++++++++++++---------- 4 files changed, 26 insertions(+), 19 deletions(-) diff --git a/sqlx-postgres/src/types/array.rs b/sqlx-postgres/src/types/array.rs index 
d624222741..f936996788 100644 --- a/sqlx-postgres/src/types/array.rs +++ b/sqlx-postgres/src/types/array.rs @@ -242,6 +242,9 @@ where // length of the array axis let len = buf.get_i32(); + let len = usize::try_from(len) + .map_err(|_| format!("overflow converting array len ({len}) to usize"))?; + // the lower bound, we only support arrays starting from "1" let lower = buf.get_i32(); @@ -249,14 +252,12 @@ where return Err(format!("encountered an array with a lower bound of {lower} in the first dimension; only arrays starting at one are supported").into()); } - let mut elements = Vec::with_capacity(len as usize); + let mut elements = Vec::with_capacity(len); for _ in 0..len { - elements.push(T::decode(PgValueRef::get( - &mut buf, - format, - element_type_info.clone(), - ))?) + let value_ref = PgValueRef::get(&mut buf, format, element_type_info.clone())?; + + elements.push(T::decode(value_ref)?); } Ok(elements) diff --git a/sqlx-postgres/src/types/range.rs b/sqlx-postgres/src/types/range.rs index 82134b4726..5e1346d86c 100644 --- a/sqlx-postgres/src/types/range.rs +++ b/sqlx-postgres/src/types/range.rs @@ -350,7 +350,7 @@ where if !flags.contains(RangeFlags::LB_INF) { let value = - T::decode(PgValueRef::get(&mut buf, value.format, element_ty.clone()))?; + T::decode(PgValueRef::get(&mut buf, value.format, element_ty.clone())?)?; start = if flags.contains(RangeFlags::LB_INC) { Bound::Included(value) @@ -361,7 +361,7 @@ where if !flags.contains(RangeFlags::UB_INF) { let value = - T::decode(PgValueRef::get(&mut buf, value.format, element_ty.clone()))?; + T::decode(PgValueRef::get(&mut buf, value.format, element_ty.clone())?)?; end = if flags.contains(RangeFlags::UB_INC) { Bound::Included(value) diff --git a/sqlx-postgres/src/types/record.rs b/sqlx-postgres/src/types/record.rs index a119f29dce..c4eb639368 100644 --- a/sqlx-postgres/src/types/record.rs +++ b/sqlx-postgres/src/types/record.rs @@ -137,7 +137,7 @@ impl<'r> PgRecordDecoder<'r> { self.ind += 1; - 
T::decode(PgValueRef::get(&mut self.buf, self.fmt, element_type)) + T::decode(PgValueRef::get(&mut self.buf, self.fmt, element_type)?) } PgValueFormat::Text => { diff --git a/sqlx-postgres/src/value.rs b/sqlx-postgres/src/value.rs index ee15412adc..90c015bb0e 100644 --- a/sqlx-postgres/src/value.rs +++ b/sqlx-postgres/src/value.rs @@ -1,11 +1,10 @@ use crate::error::{BoxDynError, UnexpectedNullError}; use crate::{PgTypeInfo, Postgres}; use sqlx_core::bytes::{Buf, Bytes}; +pub(crate) use sqlx_core::value::{Value, ValueRef}; use std::borrow::Cow; use std::str::from_utf8; -pub(crate) use sqlx_core::value::{Value, ValueRef}; - #[derive(Debug, Clone, Copy, Eq, PartialEq)] #[repr(u8)] pub enum PgValueFormat { @@ -31,24 +30,31 @@ pub struct PgValue { } impl<'r> PgValueRef<'r> { - pub(crate) fn get(buf: &mut &'r [u8], format: PgValueFormat, ty: PgTypeInfo) -> Self { - let mut element_len = buf.get_i32(); + pub(crate) fn get( + buf: &mut &'r [u8], + format: PgValueFormat, + ty: PgTypeInfo, + ) -> Result { + let element_len = buf.get_i32(); let element_val = if element_len == -1 { - element_len = 0; None } else { - Some(&buf[..(element_len as usize)]) - }; + let element_len: usize = element_len + .try_into() + .map_err(|_| format!("overflow converting element_len ({element_len}) to usize"))?; - buf.advance(element_len as usize); + let val = &buf[..element_len]; + buf.advance(element_len); + Some(val) + }; - PgValueRef { + Ok(PgValueRef { value: element_val, row: None, type_info: ty, format, - } + }) } pub fn format(&self) -> PgValueFormat { From 31fc4ed1c6ce22c04603e0525f94ba31efae7be9 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Fri, 16 Aug 2024 13:21:23 -0700 Subject: [PATCH 11/40] fix: audit `sqlx_postgres::types::bit_vec` for casts involving sign loss --- sqlx-postgres/src/types/bit_vec.rs | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/sqlx-postgres/src/types/bit_vec.rs b/sqlx-postgres/src/types/bit_vec.rs index 
dfc3b16922..b519a5f24c 100644 --- a/sqlx-postgres/src/types/bit_vec.rs +++ b/sqlx-postgres/src/types/bit_vec.rs @@ -52,15 +52,10 @@ impl Decode<'_, Postgres> for BitVec { let mut bytes = value.as_bytes()?; let len = bytes.get_i32(); - if len < 0 { - Err(io::Error::new( - io::ErrorKind::InvalidData, - "Negative VARBIT length.", - ))? - } + let len = usize::try_from(len).map_err(|_| format!("invalid VARBIT len: {len}"))?; // The smallest amount of data we can read is one byte - let bytes_len = (len as usize + 7) / 8; + let bytes_len = (len + 7) / 8; if bytes.remaining() != bytes_len { Err(io::Error::new( @@ -74,7 +69,7 @@ impl Decode<'_, Postgres> for BitVec { // Chop off zeroes from the back. We get bits in bytes, so if // our bitvec is not in full bytes, extra zeroes are added to // the end. - while bitvec.len() > len as usize { + while bitvec.len() > len { bitvec.pop(); } From de957f9e88db5f8246c12ff13e63cd06d64b72a6 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Fri, 16 Aug 2024 13:22:48 -0700 Subject: [PATCH 12/40] fix: audit `sqlx_postgres::types::rust_decimal` for casts involving sign loss --- sqlx-postgres/src/types/rust_decimal.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sqlx-postgres/src/types/rust_decimal.rs b/sqlx-postgres/src/types/rust_decimal.rs index 83a2d0e08a..281bc7e468 100644 --- a/sqlx-postgres/src/types/rust_decimal.rs +++ b/sqlx-postgres/src/types/rust_decimal.rs @@ -50,6 +50,9 @@ impl TryFrom<&'_ PgNumeric> for Decimal { // Postgres returns an empty digit array for 0 return Ok(Decimal::ZERO); } + + let scale = u32::try_from(scale) + .map_err(|_| format!("invalid scale value for Pg NUMERIC: {scale}"))?; let mut value = Decimal::ZERO; @@ -73,7 +76,7 @@ impl TryFrom<&'_ PgNumeric> for Decimal { PgNumericSign::Negative => value.set_sign_negative(true), } - value.rescale(scale as u32); + value.rescale(scale); Ok(value) } From 072139a9033c5506c8b0e68694937069555f0554 Mon Sep 17 00:00:00 2001 From: Austin Bonander 
Date: Fri, 16 Aug 2024 14:06:38 -0700 Subject: [PATCH 13/40] fix: audit `PgNumeric` and usages for casts involving sign loss --- sqlx-postgres/src/types/bigdecimal.rs | 71 ++++++++++++++++++++----- sqlx-postgres/src/types/numeric.rs | 44 ++++++++++++--- sqlx-postgres/src/types/rust_decimal.rs | 4 +- 3 files changed, 95 insertions(+), 24 deletions(-) diff --git a/sqlx-postgres/src/types/bigdecimal.rs b/sqlx-postgres/src/types/bigdecimal.rs index 5a6e500d32..869f850797 100644 --- a/sqlx-postgres/src/types/bigdecimal.rs +++ b/sqlx-postgres/src/types/bigdecimal.rs @@ -1,7 +1,6 @@ -use std::cmp; - use bigdecimal::BigDecimal; use num_bigint::{BigInt, Sign}; +use std::cmp; use crate::decode::Decode; use crate::encode::{Encode, IsNull}; @@ -26,9 +25,17 @@ impl TryFrom for BigDecimal { type Error = BoxDynError; fn try_from(numeric: PgNumeric) -> Result { - let (digits, sign, weight) = match numeric { + Self::try_from(&numeric) + } +} + +impl TryFrom<&'_ PgNumeric> for BigDecimal { + type Error = BoxDynError; + + fn try_from(numeric: &'_ PgNumeric) -> Result { + let (digits, sign, weight) = match *numeric { PgNumeric::Number { - digits, + ref digits, sign, weight, .. @@ -50,11 +57,27 @@ impl TryFrom for BigDecimal { }; // weight is 0 if the decimal point falls after the first base-10000 digit + // + // `Vec` capacity cannot exceed `isize::MAX` bytes, so this cast can't wrap in practice. 
+ #[allow(clippy::cast_possible_wrap)] let scale = (digits.len() as i64 - weight as i64 - 1) * 4; // no optimized algorithm for base-10 so use base-100 for faster processing let mut cents = Vec::with_capacity(digits.len() * 2); - for digit in &digits { + + #[allow( + clippy::cast_possible_truncation, + clippy::cast_possible_wrap, + clippy::cast_sign_loss + )] + for (i, &digit) in digits.iter().enumerate() { + if !PgNumeric::is_valid_digit(digit) { + return Err(format!( + "PgNumeric to BigDecimal: {i}th digit is out of range {digit}" + ) + .into()); + } + cents.push((digit / 100) as u8); cents.push((digit % 100) as u8); } @@ -79,9 +102,16 @@ impl TryFrom<&'_ BigDecimal> for PgNumeric { // FIXME: is there a way to iterate over the digits to avoid the Vec allocation let (sign, base_10) = integer.to_radix_be(10); + let base_10_len = i64::try_from(base_10.len()).map_err(|_| { + format!( + "BigDecimal base-10 length out of range for PgNumeric: {}", + base_10.len() + ) + })?; + // weight is positive power of 10000 // exp is the negative power of 10 - let weight_10 = base_10.len() as i64 - exp; + let weight_10 = base_10_len - exp; // scale is only nonzero when we have fractional digits // since `exp` is the _negative_ decimal exponent, it tells us @@ -103,19 +133,34 @@ impl TryFrom<&'_ BigDecimal> for PgNumeric { base_10.len() / 4 }; - let offset = weight_10.rem_euclid(4) as usize; + // For efficiency, we want to process the base-10 digits in chunks of 4, + // but that means we need to deal with the non-divisible remainder first. + let offset = weight_10.rem_euclid(4); + + // Do a checked conversion to the smallest integer, + // so we can widen arbitrarily without triggering lints. 
+ let offset = u8::try_from(offset).unwrap_or_else(|_| { + panic!("BUG: `offset` should be in the range [0, 4) but is {offset}") + }); let mut digits = Vec::with_capacity(digits_len); - if let Some(first) = base_10.get(..offset) { + if let Some(first) = base_10.get(..offset as usize) { if !first.is_empty() { digits.push(base_10_to_10000(first)); } } else if offset != 0 { - digits.push(base_10_to_10000(&base_10) * 10i16.pow((offset - base_10.len()) as u32)); + // If we didn't hit the `if let Some` branch, + // then `base_10.len()` must strictly be smaller + #[allow(clippy::cast_possible_truncation)] + let power = (offset as usize - base_10.len()) as u32; + + digits.push(base_10_to_10000(&base_10) * 10i16.pow(power)); } - if let Some(rest) = base_10.get(offset..) { + if let Some(rest) = base_10.get(offset as usize..) { + // `chunk.len()` is always between 1 and 4 + #[allow(clippy::cast_possible_truncation)] digits.extend( rest.chunks(4) .map(|chunk| base_10_to_10000(chunk) * 10i16.pow(4 - chunk.len() as u32)), @@ -138,15 +183,13 @@ impl TryFrom<&'_ BigDecimal> for PgNumeric { #[doc=include_str!("bigdecimal-range.md")] impl Encode<'_, Postgres> for BigDecimal { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { - PgNumeric::try_from(self)?.encode(buf); + PgNumeric::try_from(self)?.encode(buf)?; Ok(IsNull::No) } fn size_hint(&self) -> usize { - // BigDecimal::digits() gives us base-10 digits, so we divide by 4 to get base-10000 digits - // and since this is just a hint we just always round up - 8 + (self.digits() / 4 + 1) as usize * 2 + PgNumeric::size_hint(self.digits()) } } diff --git a/sqlx-postgres/src/types/numeric.rs b/sqlx-postgres/src/types/numeric.rs index 6416872913..3a01f2e621 100644 --- a/sqlx-postgres/src/types/numeric.rs +++ b/sqlx-postgres/src/types/numeric.rs @@ -1,4 +1,5 @@ use sqlx_core::bytes::Buf; +use std::num::Saturating; use crate::error::BoxDynError; use crate::PgArgumentBuffer; @@ -83,6 +84,27 @@ impl PgNumeric { scale: 0, }; + 
pub(crate) fn is_valid_digit(digit: i16) -> bool { + (0..10_000).contains(&digit) + } + + pub(crate) fn size_hint(decimal_digits: u64) -> usize { + let mut size_hint = Saturating(decimal_digits); + + // BigDecimal::digits() gives us base-10 digits, so we divide by 4 to get base-10000 digits + // and since this is just a hint we just always round up + size_hint /= 4; + size_hint += 1; + + // Times two bytes for each base-10000 digit + size_hint *= 2; + + // Plus `weight` and `scale` + size_hint += 8; + + usize::try_from(size_hint.0).unwrap_or(usize::MAX) + } + pub(crate) fn decode(mut buf: &[u8]) -> Result { // https://github.com/postgres/postgres/blob/bcd1c3630095e48bc3b1eb0fc8e8c8a7c851eba1/src/backend/utils/adt/numeric.c#L874 let num_digits = buf.get_u16(); @@ -104,11 +126,11 @@ impl PgNumeric { } } - /// ### Panics + /// ### Errors /// /// * If `digits.len()` overflows `i16` /// * If any element in `digits` is greater than or equal to 10000 - pub(crate) fn encode(&self, buf: &mut PgArgumentBuffer) { + pub(crate) fn encode(&self, buf: &mut PgArgumentBuffer) -> Result<(), String> { match *self { PgNumeric::Number { ref digits, @@ -116,18 +138,22 @@ impl PgNumeric { scale, weight, } => { - let digits_len: i16 = digits - .len() - .try_into() - .expect("PgNumeric.digits.len() should not overflow i16"); + let digits_len = i16::try_from(digits.len()).map_err(|_| { + format!( + "PgNumeric digits.len() ({}) should not overflow i16", + digits.len() + ) + })?; buf.extend(&digits_len.to_be_bytes()); buf.extend(&weight.to_be_bytes()); buf.extend(&(sign as i16).to_be_bytes()); buf.extend(&scale.to_be_bytes()); - for digit in digits { - debug_assert!(*digit < 10000, "PgNumeric digits must be in base-10000"); + for (i, &digit) in digits.iter().enumerate() { + if Self::is_valid_digit(digit) { + return Err(format!("{i}th PgNumeric digit out of range {digit}")); + } buf.extend(&digit.to_be_bytes()); } @@ -140,5 +166,7 @@ impl PgNumeric { buf.extend(&0_i16.to_be_bytes()); } } + + 
Ok(()) } } diff --git a/sqlx-postgres/src/types/rust_decimal.rs b/sqlx-postgres/src/types/rust_decimal.rs index 281bc7e468..8321e82811 100644 --- a/sqlx-postgres/src/types/rust_decimal.rs +++ b/sqlx-postgres/src/types/rust_decimal.rs @@ -50,7 +50,7 @@ impl TryFrom<&'_ PgNumeric> for Decimal { // Postgres returns an empty digit array for 0 return Ok(Decimal::ZERO); } - + let scale = u32::try_from(scale) .map_err(|_| format!("invalid scale value for Pg NUMERIC: {scale}"))?; @@ -171,7 +171,7 @@ impl From<&'_ Decimal> for PgNumeric { impl Encode<'_, Postgres> for Decimal { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { - PgNumeric::from(self).encode(buf); + PgNumeric::from(self).encode(buf)?; Ok(IsNull::No) } From 627c289f85f969f1860339a7ff87667c390d0cc7 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Fri, 16 Aug 2024 14:08:13 -0700 Subject: [PATCH 14/40] fix: audit `sqlx_postgres::type::int` for bad casts --- sqlx-postgres/src/types/int.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sqlx-postgres/src/types/int.rs b/sqlx-postgres/src/types/int.rs index 465d040f0e..b8255f1b08 100644 --- a/sqlx-postgres/src/types/int.rs +++ b/sqlx-postgres/src/types/int.rs @@ -71,6 +71,8 @@ impl Decode<'_, Postgres> for i8 { return Ok(i8::from_str_radix(text.trim_start_matches('\\'), 8)?); } + // Wrapping is the whole idea. 
+ #[allow(clippy::cast_possible_wrap)] Ok(text.as_bytes()[0] as i8) } } From 8a17e465efc395a6c632d5412ab4e33ab249dc7a Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Fri, 16 Aug 2024 14:30:37 -0700 Subject: [PATCH 15/40] fix: audit `sqlx_postgres::types::hstore` for bad casts --- sqlx-postgres/src/types/hstore.rs | 93 ++++++++++++++++++++----------- 1 file changed, 61 insertions(+), 32 deletions(-) diff --git a/sqlx-postgres/src/types/hstore.rs b/sqlx-postgres/src/types/hstore.rs index 0dcc71943c..bb61cc5479 100644 --- a/sqlx-postgres/src/types/hstore.rs +++ b/sqlx-postgres/src/types/hstore.rs @@ -2,11 +2,9 @@ use std::{ collections::{btree_map, BTreeMap}, mem::size_of, ops::{Deref, DerefMut}, - str::from_utf8, + str, }; -use serde::{Deserialize, Serialize}; - use crate::{ decode::Decode, encode::{Encode, IsNull}, @@ -14,6 +12,8 @@ use crate::{ types::Type, PgArgumentBuffer, PgTypeInfo, PgValueRef, Postgres, }; +use serde::{Deserialize, Serialize}; +use sqlx_core::bytes::Buf; /// Key-value support (`hstore`) for Postgres. /// @@ -143,41 +143,64 @@ impl<'r> Decode<'r, Postgres> for PgHstore { let mut buf = <&[u8] as Decode>::decode(value)?; let len = read_length(&mut buf)?; - if len < 0 { - Err(format!("hstore, invalid entry count: {len}"))?; - } + let len = + usize::try_from(len).map_err(|_| format!("PgHstore: length out of range: {len}"))?; let mut result = Self::default(); - while !buf.is_empty() { - let key_len = read_length(&mut buf)?; - let key = read_value(&mut buf, key_len)?.ok_or("hstore, key not found")?; + for i in 0..len { + let key = read_string(&mut buf) + .map_err(|e| format!("PgHstore: error reading {i}th key: {e}"))? 
+ .ok_or_else(|| format!("PgHstore: expected {i}th key, got nothing"))?; - let value_len = read_length(&mut buf)?; - let value = read_value(&mut buf, value_len)?; + let value = read_string(&mut buf) + .map_err(|e| format!("PgHstore: error reading value for key {key:?}: {e}"))?; result.insert(key, value); } + if !buf.is_empty() { + tracing::warn!("{} unread bytes at the end of HSTORE value", buf.len()); + } + Ok(result) } } impl Encode<'_, Postgres> for PgHstore { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { - buf.extend_from_slice(&i32::to_be_bytes(self.0.len() as i32)); - - for (key, val) in &self.0 { + buf.extend_from_slice(&i32::to_be_bytes( + self.0 + .len() + .try_into() + .map_err(|_| format!("PgHstore length out of range: {}", self.0.len()))?, + )); + + for (i, (key, val)) in self.0.iter().enumerate() { let key_bytes = key.as_bytes(); - buf.extend_from_slice(&i32::to_be_bytes(key_bytes.len() as i32)); + let key_len = i32::try_from(key_bytes.len()).map_err(|_| { + // Doesn't make sense to print the key itself: it's more than 2 GiB long! 
+ format!( + "PgHstore: length of {i}th key out of range: {} bytes", + key_bytes.len() + ) + })?; + + buf.extend_from_slice(&i32::to_be_bytes(key_len)); buf.extend_from_slice(key_bytes); match val { Some(val) => { let val_bytes = val.as_bytes(); - buf.extend_from_slice(&i32::to_be_bytes(val_bytes.len() as i32)); + let val_len = i32::try_from(val_bytes.len()).map_err(|_| { + format!( + "PgHstore: value length for key {key:?} out of range: {} bytes", + val_bytes.len() + ) + })?; + buf.extend_from_slice(&i32::to_be_bytes(val_len)); buf.extend_from_slice(val_bytes); } None => { @@ -190,30 +213,36 @@ impl Encode<'_, Postgres> for PgHstore { } } -fn read_length(buf: &mut &[u8]) -> Result { - let (bytes, rest) = buf.split_at(size_of::()); - - *buf = rest; +fn read_length(buf: &mut &[u8]) -> Result { + if buf.len() < size_of::() { + return Err(format!( + "expected {} bytes, got {}", + size_of::(), + buf.len() + )); + } - Ok(i32::from_be_bytes( - bytes - .try_into() - .map_err(|err| format!("hstore, reading length: {err}"))?, - )) + Ok(buf.get_i32()) } -fn read_value(buf: &mut &[u8], len: i32) -> Result, BoxDynError> { +fn read_string(buf: &mut &[u8]) -> Result, String> { + let len = read_length(buf)?; + match len { - len if len <= 0 => Ok(None), + -1 => Ok(None), len => { - let (val, rest) = buf.split_at(len as usize); + let len = + usize::try_from(len).map_err(|_| format!("string length out of range: {len}"))?; + + if buf.len() < len { + return Err(format!("expected {len} bytes, got {}", buf.len())); + } + let (val, rest) = buf.split_at(len); *buf = rest; Ok(Some( - from_utf8(val) - .map_err(|err| format!("hstore, reading value: {err}"))? 
- .to_string(), + str::from_utf8(val).map_err(|e| e.to_string())?.to_string(), )) } } @@ -258,7 +287,7 @@ mod test { } #[test] - #[should_panic(expected = "hstore, invalid entry count: -5")] + #[should_panic(expected = "PgHstore: length out of range: -5")] fn hstore_deserialize_buffer_length_error() { let buf = PgValueRef { value: Some(&[255, 255, 255, 251]), From 2cb621715a3f4dcdc91266209d7a3c3f140a0abc Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Fri, 16 Aug 2024 16:37:22 -0700 Subject: [PATCH 16/40] fix: audit `sqlx_postgres::types::array` for bad casts --- sqlx-postgres/src/types/array.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/sqlx-postgres/src/types/array.rs b/sqlx-postgres/src/types/array.rs index f936996788..9b8be63412 100644 --- a/sqlx-postgres/src/types/array.rs +++ b/sqlx-postgres/src/types/array.rs @@ -174,7 +174,14 @@ where } } - buf.extend(&(self.len() as i32).to_be_bytes()); // len + let array_len = i32::try_from(self.len()).map_err(|_| { + format!( + "encoded array length is too large for Postgres: {}", + self.len() + ) + })?; + + buf.extend(array_len.to_be_bytes()); // len buf.extend(&1_i32.to_be_bytes()); // lower bound for element in self.iter() { From ec5326e5c9be7afbbb0c811df4859a2b8eb12641 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Sat, 17 Aug 2024 04:54:15 -0700 Subject: [PATCH 17/40] refactor: rename `sqlx_core::io::{Encode, Decode}` for clarity --- sqlx-core/src/io/decode.rs | 8 ++++---- sqlx-core/src/io/encode.rs | 15 ++++++++------- sqlx-core/src/io/mod.rs | 4 ++-- sqlx-core/src/net/socket/buffered.rs | 20 ++++++++++++-------- 4 files changed, 26 insertions(+), 21 deletions(-) diff --git a/sqlx-core/src/io/decode.rs b/sqlx-core/src/io/decode.rs index 2f397127b8..798fc8befb 100644 --- a/sqlx-core/src/io/decode.rs +++ b/sqlx-core/src/io/decode.rs @@ -2,13 +2,13 @@ use bytes::Bytes; use crate::error::Error; -pub trait Decode<'de, Context = ()> +pub trait ProtocolDecode<'de, Context = ()> 
where Self: Sized, { fn decode(buf: Bytes) -> Result where - Self: Decode<'de, ()>, + Self: ProtocolDecode<'de, ()>, { Self::decode_with(buf, ()) } @@ -16,13 +16,13 @@ where fn decode_with(buf: Bytes, context: Context) -> Result; } -impl Decode<'_> for Bytes { +impl ProtocolDecode<'_> for Bytes { fn decode_with(buf: Bytes, _: ()) -> Result { Ok(buf) } } -impl Decode<'_> for () { +impl ProtocolDecode<'_> for () { fn decode_with(_: Bytes, _: ()) -> Result<(), Error> { Ok(()) } diff --git a/sqlx-core/src/io/encode.rs b/sqlx-core/src/io/encode.rs index a417ef9eb9..a603ea9325 100644 --- a/sqlx-core/src/io/encode.rs +++ b/sqlx-core/src/io/encode.rs @@ -1,16 +1,17 @@ -pub trait Encode<'en, Context = ()> { - fn encode(&self, buf: &mut Vec) +pub trait ProtocolEncode<'en, Context = ()> { + fn encode(&self, buf: &mut Vec) -> Result<(), crate::Error> where - Self: Encode<'en, ()>, + Self: ProtocolEncode<'en, ()>, { - self.encode_with(buf, ()); + self.encode_with(buf, ()) } - fn encode_with(&self, buf: &mut Vec, context: Context); + fn encode_with(&self, buf: &mut Vec, context: Context) -> Result<(), crate::Error>; } -impl<'en, C> Encode<'en, C> for &'_ [u8] { - fn encode_with(&self, buf: &mut Vec, _: C) { +impl<'en, C> ProtocolEncode<'en, C> for &'_ [u8] { + fn encode_with(&self, buf: &mut Vec, _context: C) -> Result<(), crate::Error> { buf.extend_from_slice(self); + Ok(()) } } diff --git a/sqlx-core/src/io/mod.rs b/sqlx-core/src/io/mod.rs index 84a09d7c3b..2765abe02f 100644 --- a/sqlx-core/src/io/mod.rs +++ b/sqlx-core/src/io/mod.rs @@ -9,8 +9,8 @@ mod read_buf; pub use buf::BufExt; pub use buf_mut::BufMutExt; //pub use buf_stream::BufStream; -pub use decode::Decode; -pub use encode::Encode; +pub use decode::ProtocolDecode; +pub use encode::ProtocolEncode; pub use read_buf::ReadBuf; #[cfg(not(feature = "_rt-tokio"))] diff --git a/sqlx-core/src/net/socket/buffered.rs b/sqlx-core/src/net/socket/buffered.rs index 5e032e2b1d..25e1276432 100644 --- 
a/sqlx-core/src/net/socket/buffered.rs +++ b/sqlx-core/src/net/socket/buffered.rs @@ -4,7 +4,7 @@ use std::{cmp, io}; use crate::error::Error; -use crate::io::{AsyncRead, AsyncReadExt, Decode, Encode}; +use crate::io::{AsyncRead, AsyncReadExt, ProtocolDecode, ProtocolEncode}; // Tokio, async-std, and std all use this as the default capacity for their buffered I/O. const DEFAULT_BUF_SIZE: usize = 8192; @@ -59,32 +59,36 @@ impl BufferedSocket { pub async fn read<'de, T>(&mut self, byte_len: usize) -> Result where - T: Decode<'de, ()>, + T: ProtocolDecode<'de, ()>, { self.read_with(byte_len, ()).await } pub async fn read_with<'de, T, C>(&mut self, byte_len: usize, context: C) -> Result where - T: Decode<'de, C>, + T: ProtocolDecode<'de, C>, { T::decode_with(self.read_buffered(byte_len).await?.freeze(), context) } - pub fn write<'en, T>(&mut self, value: T) + #[inline(always)] + pub fn write<'en, T>(&mut self, value: T) -> Result<(), Error> where - T: Encode<'en, ()>, + T: ProtocolEncode<'en, ()>, { self.write_with(value, ()) } - pub fn write_with<'en, T, C>(&mut self, value: T, context: C) + #[inline(always)] + pub fn write_with<'en, T, C>(&mut self, value: T, context: C) -> Result<(), Error> where - T: Encode<'en, C>, + T: ProtocolEncode<'en, C>, { - value.encode_with(self.write_buf.buf_mut(), context); + value.encode_with(self.write_buf.buf_mut(), context)?; self.write_buf.bytes_written = self.write_buf.buf.len(); self.write_buf.sanity_check(); + + Ok(()) } pub async fn flush(&mut self) -> io::Result<()> { From c2f7339004207c4d57fa53d88675409f03e8d085 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Sat, 17 Aug 2024 04:54:40 -0700 Subject: [PATCH 18/40] refactor(postgres): make better use of traits to improve protocol handling --- sqlx-postgres/src/advisory_lock.rs | 3 +- sqlx-postgres/src/arguments.rs | 1 + sqlx-postgres/src/connection/describe.rs | 26 +-- sqlx-postgres/src/connection/establish.rs | 30 ++- sqlx-postgres/src/connection/executor.rs | 92 +++++---- 
sqlx-postgres/src/connection/mod.rs | 30 +-- sqlx-postgres/src/connection/sasl.rs | 12 +- sqlx-postgres/src/connection/stream.rs | 50 ++--- sqlx-postgres/src/copy.rs | 48 ++--- sqlx-postgres/src/io/buf_mut.rs | 62 +++--- sqlx-postgres/src/io/mod.rs | 125 ++++++++++++ sqlx-postgres/src/listener.rs | 6 +- sqlx-postgres/src/message/authentication.rs | 14 +- sqlx-postgres/src/message/backend_key_data.rs | 12 +- sqlx-postgres/src/message/bind.rs | 76 ++++--- sqlx-postgres/src/message/close.rs | 37 ++-- sqlx-postgres/src/message/command_complete.rs | 21 +- sqlx-postgres/src/message/copy.rs | 125 ++++++++---- sqlx-postgres/src/message/data_row.rs | 74 +++++-- sqlx-postgres/src/message/describe.rs | 162 +++++++-------- sqlx-postgres/src/message/execute.rs | 80 +++++--- sqlx-postgres/src/message/flush.rs | 30 ++- sqlx-postgres/src/message/mod.rs | 162 ++++++++++++--- sqlx-postgres/src/message/notification.rs | 12 +- .../src/message/parameter_description.rs | 14 +- sqlx-postgres/src/message/parameter_status.rs | 17 +- sqlx-postgres/src/message/parse.rs | 61 ++++-- sqlx-postgres/src/message/parse_complete.rs | 13 ++ sqlx-postgres/src/message/password.rs | 191 ++++++++++-------- sqlx-postgres/src/message/query.rs | 28 ++- sqlx-postgres/src/message/ready_for_query.rs | 10 +- sqlx-postgres/src/message/response.rs | 66 +++--- sqlx-postgres/src/message/row_description.rs | 24 ++- sqlx-postgres/src/message/sasl.rs | 78 +++++-- sqlx-postgres/src/message/ssl_request.rs | 31 ++- sqlx-postgres/src/message/startup.rs | 13 +- sqlx-postgres/src/message/sync.rs | 19 +- sqlx-postgres/src/message/terminate.rs | 19 +- sqlx-postgres/src/transaction.rs | 5 +- sqlx-postgres/src/types/oid.rs | 6 - 40 files changed, 1222 insertions(+), 663 deletions(-) create mode 100644 sqlx-postgres/src/message/parse_complete.rs diff --git a/sqlx-postgres/src/advisory_lock.rs b/sqlx-postgres/src/advisory_lock.rs index 982744137f..d1aef176fb 100644 --- a/sqlx-postgres/src/advisory_lock.rs +++ 
b/sqlx-postgres/src/advisory_lock.rs @@ -414,7 +414,8 @@ impl<'lock, C: AsMut> Drop for PgAdvisoryLockGuard<'lock, C> { // The `async fn` versions can safely use the prepared statement protocol, // but this is the safest way to queue a query to execute on the next opportunity. conn.as_mut() - .queue_simple_query(self.lock.get_release_query()); + .queue_simple_query(self.lock.get_release_query()) + .expect("BUG: PgAdvisoryLock::get_release_query() somehow too long for protocol"); } } } diff --git a/sqlx-postgres/src/arguments.rs b/sqlx-postgres/src/arguments.rs index 2e7d5fd9d4..bc7e861c52 100644 --- a/sqlx-postgres/src/arguments.rs +++ b/sqlx-postgres/src/arguments.rs @@ -145,6 +145,7 @@ impl<'q> Arguments<'q> for PgArguments { write!(writer, "${}", self.buffer.count) } + #[inline(always)] fn len(&self) -> usize { self.buffer.count } diff --git a/sqlx-postgres/src/connection/describe.rs b/sqlx-postgres/src/connection/describe.rs index e53a054a70..d9c55201a0 100644 --- a/sqlx-postgres/src/connection/describe.rs +++ b/sqlx-postgres/src/connection/describe.rs @@ -1,5 +1,6 @@ use crate::error::Error; use crate::ext::ustr::UStr; +use crate::io::StatementId; use crate::message::{ParameterDescription, RowDescription}; use crate::query_as::query_as; use crate::query_scalar::query_scalar; @@ -27,10 +28,12 @@ enum TypType { Range, } -impl TryFrom for TypType { +impl TryFrom for TypType { type Error = (); - fn try_from(t: u8) -> Result { + fn try_from(t: i8) -> Result { + let t = u8::try_from(t).or(Err(()))?; + let t = match t { b'b' => Self::Base, b'c' => Self::Composite, @@ -66,10 +69,12 @@ enum TypCategory { Unknown, } -impl TryFrom for TypCategory { +impl TryFrom for TypCategory { type Error = (); - fn try_from(c: u8) -> Result { + fn try_from(c: i8) -> Result { + let c = u8::try_from(c).or(Err(()))?; + let c = match c { b'A' => Self::Array, b'B' => Self::Boolean, @@ -209,8 +214,8 @@ impl PgConnection { .fetch_one(&mut *self) .await?; - let typ_type = 
TypType::try_from(typ_type as u8); - let category = TypCategory::try_from(category as u8); + let typ_type = TypType::try_from(typ_type); + let category = TypCategory::try_from(category); match (typ_type, category) { (Ok(TypType::Domain), _) => self.fetch_domain_by_oid(oid, base_type, name).await, @@ -416,7 +421,7 @@ WHERE rngtypid = $1 pub(crate) async fn get_nullable_for_columns( &mut self, - stmt_id: Oid, + stmt_id: StatementId, meta: &PgStatementMetadata, ) -> Result>, Error> { if meta.columns.is_empty() { @@ -486,13 +491,10 @@ WHERE rngtypid = $1 /// and returns `None` for all others. async fn nullables_from_explain( &mut self, - stmt_id: Oid, + stmt_id: StatementId, params_len: usize, ) -> Result>, Error> { - let mut explain = format!( - "EXPLAIN (VERBOSE, FORMAT JSON) EXECUTE sqlx_s_{}", - stmt_id.0 - ); + let mut explain = format!("EXPLAIN (VERBOSE, FORMAT JSON) EXECUTE {stmt_id}"); let mut comma = false; if params_len > 0 { diff --git a/sqlx-postgres/src/connection/establish.rs b/sqlx-postgres/src/connection/establish.rs index 83b9843a12..a730f5c161 100644 --- a/sqlx-postgres/src/connection/establish.rs +++ b/sqlx-postgres/src/connection/establish.rs @@ -3,11 +3,10 @@ use crate::HashMap; use crate::common::StatementCache; use crate::connection::{sasl, stream::PgStream}; use crate::error::Error; -use crate::io::Decode; +use crate::io::StatementId; use crate::message::{ - Authentication, BackendKeyData, MessageFormat, Password, ReadyForQuery, Startup, + Authentication, BackendKeyData, BackendMessageFormat, Password, ReadyForQuery, Startup, }; -use crate::types::Oid; use crate::{PgConnectOptions, PgConnection}; // https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.3 @@ -44,13 +43,13 @@ impl PgConnection { params.push(("options", options)); } - stream - .send(Startup { - username: Some(&options.username), - database: options.database.as_deref(), - params: ¶ms, - }) - .await?; + stream.write(Startup { + username: Some(&options.username), + 
database: options.database.as_deref(), + params: ¶ms, + })?; + + stream.flush().await?; // The server then uses this information and the contents of // its configuration files (such as pg_hba.conf) to determine whether the connection is @@ -64,7 +63,7 @@ impl PgConnection { loop { let message = stream.recv().await?; match message.format { - MessageFormat::Authentication => match message.decode()? { + BackendMessageFormat::Authentication => match message.decode()? { Authentication::Ok => { // the authentication exchange is successfully completed // do nothing; no more information is required to continue @@ -108,7 +107,7 @@ impl PgConnection { } }, - MessageFormat::BackendKeyData => { + BackendMessageFormat::BackendKeyData => { // provides secret-key data that the frontend must save if it wants to be // able to issue cancel requests later @@ -118,10 +117,9 @@ impl PgConnection { secret_key = data.secret_key; } - MessageFormat::ReadyForQuery => { + BackendMessageFormat::ReadyForQuery => { // start-up is completed. 
The frontend can now issue commands - transaction_status = - ReadyForQuery::decode(message.contents)?.transaction_status; + transaction_status = message.decode::()?.transaction_status; break; } @@ -142,7 +140,7 @@ impl PgConnection { transaction_status, transaction_depth: 0, pending_ready_for_query_count: 0, - next_statement_id: Oid(1), + next_statement_id: StatementId::NAMED_START, cache_statement: StatementCache::new(options.statement_cache_capacity), cache_type_oid: HashMap::new(), cache_type_info: HashMap::new(), diff --git a/sqlx-postgres/src/connection/executor.rs b/sqlx-postgres/src/connection/executor.rs index bb73db1e38..d2f6bcddf1 100644 --- a/sqlx-postgres/src/connection/executor.rs +++ b/sqlx-postgres/src/connection/executor.rs @@ -1,13 +1,13 @@ use crate::describe::Describe; use crate::error::Error; use crate::executor::{Execute, Executor}; +use crate::io::{PortalId, StatementId}; use crate::logger::QueryLogger; use crate::message::{ - self, Bind, Close, CommandComplete, DataRow, MessageFormat, ParameterDescription, Parse, Query, - RowDescription, + self, BackendMessageFormat, Bind, Close, CommandComplete, DataRow, ParameterDescription, Parse, + ParseComplete, Query, RowDescription, }; use crate::statement::PgStatementMetadata; -use crate::types::Oid; use crate::{ statement::PgStatement, PgArguments, PgConnection, PgQueryResult, PgRow, PgTypeInfo, PgValueFormat, Postgres, @@ -16,6 +16,7 @@ use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; use futures_core::Stream; use futures_util::{pin_mut, TryStreamExt}; +use sqlx_core::arguments::Arguments; use sqlx_core::Either; use std::{borrow::Cow, sync::Arc}; @@ -24,9 +25,9 @@ async fn prepare( sql: &str, parameters: &[PgTypeInfo], metadata: Option>, -) -> Result<(Oid, Arc), Error> { +) -> Result<(StatementId, Arc), Error> { let id = conn.next_statement_id; - conn.next_statement_id.incr_one(); + conn.next_statement_id = id.next(); // build a list of type OIDs to send to the database in 
the PARSE command // we have not yet started the query sequence, so we are *safe* to cleanly make @@ -42,15 +43,15 @@ async fn prepare( conn.wait_until_ready().await?; // next we send the PARSE command to the server - conn.stream.write(Parse { + conn.stream.write_msg(Parse { param_types: ¶m_types, query: sql, statement: id, - }); + })?; if metadata.is_none() { // get the statement columns and parameters - conn.stream.write(message::Describe::Statement(id)); + conn.stream.write_msg(message::Describe::Statement(id))?; } // we ask for the server to immediately send us the result of the PARSE command @@ -58,9 +59,7 @@ async fn prepare( conn.stream.flush().await?; // indicates that the SQL query string is now successfully parsed and has semantic validity - conn.stream - .recv_expect(MessageFormat::ParseComplete) - .await?; + conn.stream.recv_expect::().await?; let metadata = if let Some(metadata) = metadata { // each SYNC produces one READY FOR QUERY @@ -95,18 +94,18 @@ async fn prepare( } async fn recv_desc_params(conn: &mut PgConnection) -> Result { - conn.stream - .recv_expect(MessageFormat::ParameterDescription) - .await + conn.stream.recv_expect().await } async fn recv_desc_rows(conn: &mut PgConnection) -> Result, Error> { let rows: Option = match conn.stream.recv().await? { // describes the rows that will be returned when the statement is eventually executed - message if message.format == MessageFormat::RowDescription => Some(message.decode()?), + message if message.format == BackendMessageFormat::RowDescription => { + Some(message.decode()?) + } // no data would be returned if this statement was executed - message if message.format == MessageFormat::NoData => None, + message if message.format == BackendMessageFormat::NoData => None, message => { return Err(err_protocol!( @@ -125,12 +124,12 @@ impl PgConnection { // we need to wait for the [CloseComplete] to be returned from the server while count > 0 { match self.stream.recv().await? 
{ - message if message.format == MessageFormat::PortalSuspended => { + message if message.format == BackendMessageFormat::PortalSuspended => { // there was an open portal // this can happen if the last time a statement was used it was not fully executed } - message if message.format == MessageFormat::CloseComplete => { + message if message.format == BackendMessageFormat::CloseComplete => { // successfully closed the statement (and freed up the server resources) count -= 1; } @@ -147,8 +146,11 @@ impl PgConnection { Ok(()) } + #[inline(always)] pub(crate) fn write_sync(&mut self) { - self.stream.write(message::Sync); + self.stream + .write_msg(message::Sync) + .expect("BUG: Sync should not be too big for protocol"); // all SYNC messages will return a ReadyForQuery self.pending_ready_for_query_count += 1; @@ -163,7 +165,7 @@ impl PgConnection { // optional metadata that was provided by the user, this means they are reusing // a statement object metadata: Option>, - ) -> Result<(Oid, Arc), Error> { + ) -> Result<(StatementId, Arc), Error> { if let Some(statement) = self.cache_statement.get_mut(sql) { return Ok((*statement).clone()); } @@ -172,7 +174,7 @@ impl PgConnection { if store_to_cache && self.cache_statement.is_enabled() { if let Some((id, _)) = self.cache_statement.insert(sql, statement.clone()) { - self.stream.write(Close::Statement(id)); + self.stream.write_msg(Close::Statement(id))?; self.write_sync(); self.stream.flush().await?; @@ -201,6 +203,14 @@ impl PgConnection { let mut metadata: Arc; let format = if let Some(mut arguments) = arguments { + // Check this before we write anything to the stream. 
+ let num_params = i16::try_from(arguments.len()).map_err(|_| { + err_protocol!( + "PgConnection::run(): too many arguments for query: {}", + arguments.len() + ) + })?; + // prepare the statement if this our first time executing it // always return the statement ID here let (statement, metadata_) = self @@ -216,21 +226,21 @@ impl PgConnection { self.wait_until_ready().await?; // bind to attach the arguments to the statement and create a portal - self.stream.write(Bind { - portal: None, + self.stream.write_msg(Bind { + portal: PortalId::UNNAMED, statement, formats: &[PgValueFormat::Binary], - num_params: arguments.types.len() as i16, + num_params, params: &arguments.buffer, result_formats: &[PgValueFormat::Binary], - }); + })?; // executes the portal up to the passed limit // the protocol-level limit acts nearly identically to the `LIMIT` in SQL - self.stream.write(message::Execute { - portal: None, + self.stream.write_msg(message::Execute { + portal: PortalId::UNNAMED, limit: limit.into(), - }); + })?; // From https://www.postgresql.org/docs/current/protocol-flow.html: // // "An unnamed portal is destroyed at the end of the transaction, or as @@ -240,7 +250,7 @@ impl PgConnection { // we ask the database server to close the unnamed portal and free the associated resources // earlier - after the execution of the current query. - self.stream.write(message::Close::Portal(None)); + self.stream.write_msg(Close::Portal(PortalId::UNNAMED))?; // finally, [Sync] asks postgres to process the messages that we sent and respond with // a [ReadyForQuery] message when it's completely done. 
Theoretically, we could send @@ -253,7 +263,7 @@ impl PgConnection { PgValueFormat::Binary } else { // Query will trigger a ReadyForQuery - self.stream.write(Query(query)); + self.stream.write_msg(Query(query))?; self.pending_ready_for_query_count += 1; // metadata starts out as "nothing" @@ -270,12 +280,12 @@ impl PgConnection { let message = self.stream.recv().await?; match message.format { - MessageFormat::BindComplete - | MessageFormat::ParseComplete - | MessageFormat::ParameterDescription - | MessageFormat::NoData + BackendMessageFormat::BindComplete + | BackendMessageFormat::ParseComplete + | BackendMessageFormat::ParameterDescription + | BackendMessageFormat::NoData // unnamed portal has been closed - | MessageFormat::CloseComplete + | BackendMessageFormat::CloseComplete => { // harmless messages to ignore } @@ -284,7 +294,7 @@ impl PgConnection { // exactly one of these messages: CommandComplete, // EmptyQueryResponse (if the portal was created from an // empty query string), ErrorResponse, or PortalSuspended" - MessageFormat::CommandComplete => { + BackendMessageFormat::CommandComplete => { // a SQL command completed normally let cc: CommandComplete = message.decode()?; @@ -295,16 +305,16 @@ impl PgConnection { })); } - MessageFormat::EmptyQueryResponse => { + BackendMessageFormat::EmptyQueryResponse => { // empty query string passed to an unprepared execute } // Message::ErrorResponse is handled in self.stream.recv() // incomplete query execution has finished - MessageFormat::PortalSuspended => {} + BackendMessageFormat::PortalSuspended => {} - MessageFormat::RowDescription => { + BackendMessageFormat::RowDescription => { // indicates that a *new* set of rows are about to be returned let (columns, column_names) = self .handle_row_description(Some(message.decode()?), false) @@ -317,7 +327,7 @@ impl PgConnection { }); } - MessageFormat::DataRow => { + BackendMessageFormat::DataRow => { logger.increment_rows_returned(); // one of the set of rows returned by 
a SELECT, FETCH, etc query @@ -331,7 +341,7 @@ impl PgConnection { r#yield!(Either::Right(row)); } - MessageFormat::ReadyForQuery => { + BackendMessageFormat::ReadyForQuery => { // processing of the query string is complete self.handle_ready_for_query(message)?; break; diff --git a/sqlx-postgres/src/connection/mod.rs b/sqlx-postgres/src/connection/mod.rs index 1c7a468240..9003dcb338 100644 --- a/sqlx-postgres/src/connection/mod.rs +++ b/sqlx-postgres/src/connection/mod.rs @@ -8,9 +8,10 @@ use futures_util::FutureExt; use crate::common::StatementCache; use crate::error::Error; use crate::ext::ustr::UStr; -use crate::io::Decode; +use crate::io::StatementId; use crate::message::{ - Close, Message, MessageFormat, Query, ReadyForQuery, Terminate, TransactionStatus, + BackendMessageFormat, Close, Query, ReadyForQuery, ReceivedMessage, Terminate, + TransactionStatus, }; use crate::statement::PgStatementMetadata; use crate::transaction::Transaction; @@ -47,10 +48,10 @@ pub struct PgConnection { // sequence of statement IDs for use in preparing statements // in PostgreSQL, the statement is prepared to a user-supplied identifier - next_statement_id: Oid, + next_statement_id: StatementId, // cache statement by query string to the id and columns - cache_statement: StatementCache<(Oid, Arc)>, + cache_statement: StatementCache<(StatementId, Arc)>, // cache user-defined types by id <-> info cache_type_info: HashMap, @@ -82,7 +83,7 @@ impl PgConnection { while self.pending_ready_for_query_count > 0 { let message = self.stream.recv().await?; - if let MessageFormat::ReadyForQuery = message.format { + if let BackendMessageFormat::ReadyForQuery = message.format { self.handle_ready_for_query(message)?; } } @@ -91,10 +92,7 @@ impl PgConnection { } async fn recv_ready_for_query(&mut self) -> Result<(), Error> { - let r: ReadyForQuery = self - .stream - .recv_expect(MessageFormat::ReadyForQuery) - .await?; + let r: ReadyForQuery = self.stream.recv_expect().await?; 
self.pending_ready_for_query_count -= 1; self.transaction_status = r.transaction_status; @@ -102,9 +100,10 @@ impl PgConnection { Ok(()) } - fn handle_ready_for_query(&mut self, message: Message) -> Result<(), Error> { + #[inline(always)] + fn handle_ready_for_query(&mut self, message: ReceivedMessage) -> Result<(), Error> { self.pending_ready_for_query_count -= 1; - self.transaction_status = ReadyForQuery::decode(message.contents)?.transaction_status; + self.transaction_status = message.decode::()?.transaction_status; Ok(()) } @@ -112,9 +111,12 @@ impl PgConnection { /// Queue a simple query (not prepared) to execute the next time this connection is used. /// /// Used for rolling back transactions and releasing advisory locks. - pub(crate) fn queue_simple_query(&mut self, query: &str) { + #[inline(always)] + pub(crate) fn queue_simple_query(&mut self, query: &str) -> Result<(), Error> { + self.stream.write_msg(Query(query))?; self.pending_ready_for_query_count += 1; - self.stream.write(Query(query)); + + Ok(()) } } @@ -184,7 +186,7 @@ impl Connection for PgConnection { self.wait_until_ready().await?; while let Some((id, _)) = self.cache_statement.remove_lru() { - self.stream.write(Close::Statement(id)); + self.stream.write_msg(Close::Statement(id))?; cleared += 1; } diff --git a/sqlx-postgres/src/connection/sasl.rs b/sqlx-postgres/src/connection/sasl.rs index 11f36eec56..729cc1fcc5 100644 --- a/sqlx-postgres/src/connection/sasl.rs +++ b/sqlx-postgres/src/connection/sasl.rs @@ -1,8 +1,6 @@ use crate::connection::stream::PgStream; use crate::error::Error; -use crate::message::{ - Authentication, AuthenticationSasl, MessageFormat, SaslInitialResponse, SaslResponse, -}; +use crate::message::{Authentication, AuthenticationSasl, SaslInitialResponse, SaslResponse}; use crate::PgConnectOptions; use hmac::{Hmac, Mac}; use rand::Rng; @@ -76,7 +74,7 @@ pub(crate) async fn authenticate( }) .await?; - let cont = match stream.recv_expect(MessageFormat::Authentication).await? 
{ + let cont = match stream.recv_expect().await? { Authentication::SaslContinue(data) => data, auth => { @@ -147,7 +145,7 @@ pub(crate) async fn authenticate( stream.send(SaslResponse(&client_final_message)).await?; - let data = match stream.recv_expect(MessageFormat::Authentication).await? { + let data = match stream.recv_expect().await? { Authentication::SaslFinal(data) => data, auth => { @@ -172,10 +170,10 @@ fn gen_nonce() -> String { // ;; a valid "value". let nonce: String = std::iter::repeat(()) .map(|()| { - let mut c = rng.gen_range(0x21..0x7F) as u8; + let mut c = rng.gen_range(0x21u8..0x7F); while c == 0x2C { - c = rng.gen_range(0x21..0x7F) as u8; + c = rng.gen_range(0x21u8..0x7F); } c diff --git a/sqlx-postgres/src/connection/stream.rs b/sqlx-postgres/src/connection/stream.rs index 0cbf405d25..a7c7d1aea0 100644 --- a/sqlx-postgres/src/connection/stream.rs +++ b/sqlx-postgres/src/connection/stream.rs @@ -9,8 +9,10 @@ use sqlx_core::bytes::{Buf, Bytes}; use crate::connection::tls::MaybeUpgradeTls; use crate::error::Error; -use crate::io::{Decode, Encode}; -use crate::message::{Message, MessageFormat, Notice, Notification, ParameterStatus}; +use crate::message::{ + BackendMessage, BackendMessageFormat, EncodeMessage, FrontendMessage, Notice, Notification, + ParameterStatus, ReceivedMessage, +}; use crate::net::{self, BufferedSocket, Socket}; use crate::{PgConnectOptions, PgDatabaseError, PgSeverity}; @@ -55,59 +57,51 @@ impl PgStream { }) } - pub(crate) async fn send<'en, T>(&mut self, message: T) -> Result<(), Error> + #[inline(always)] + pub(crate) fn write_msg(&mut self, message: impl FrontendMessage) -> Result<(), Error> { + self.write(EncodeMessage(message)) + } + + pub(crate) async fn send(&mut self, message: T) -> Result<(), Error> where - T: Encode<'en>, + T: FrontendMessage, { - self.write(message); + self.write_msg(message)?; self.flush().await?; Ok(()) } // Expect a specific type and format - pub(crate) async fn recv_expect<'de, T: Decode<'de>>( 
- &mut self, - format: MessageFormat, - ) -> Result { - let message = self.recv().await?; - - if message.format != format { - return Err(err_protocol!( - "expecting {:?} but received {:?}", - format, - message.format - )); - } - - message.decode() + pub(crate) async fn recv_expect(&mut self) -> Result { + self.recv().await?.decode() } - pub(crate) async fn recv_unchecked(&mut self) -> Result { + pub(crate) async fn recv_unchecked(&mut self) -> Result { // all packets in postgres start with a 5-byte header // this header contains the message type and the total length of the message let mut header: Bytes = self.inner.read(5).await?; - let format = MessageFormat::try_from_u8(header.get_u8())?; + let format = BackendMessageFormat::try_from_u8(header.get_u8())?; let size = (header.get_u32() - 4) as usize; let contents = self.inner.read(size).await?; - Ok(Message { format, contents }) + Ok(ReceivedMessage { format, contents }) } // Get the next message from the server // May wait for more data from the server - pub(crate) async fn recv(&mut self) -> Result { + pub(crate) async fn recv(&mut self) -> Result { loop { let message = self.recv_unchecked().await?; match message.format { - MessageFormat::ErrorResponse => { + BackendMessageFormat::ErrorResponse => { // An error returned from the database server. return Err(PgDatabaseError(message.decode()?).into()); } - MessageFormat::NotificationResponse => { + BackendMessageFormat::NotificationResponse => { if let Some(buffer) = &mut self.notifications { let notification: Notification = message.decode()?; let _ = buffer.send(notification).await; @@ -116,7 +110,7 @@ impl PgStream { } } - MessageFormat::ParameterStatus => { + BackendMessageFormat::ParameterStatus => { // informs the frontend about the current (initial) // setting of backend parameters @@ -135,7 +129,7 @@ impl PgStream { continue; } - MessageFormat::NoticeResponse => { + BackendMessageFormat::NoticeResponse => { // do we need this to be more configurable? 
// if you are reading this comment and think so, open an issue diff --git a/sqlx-postgres/src/copy.rs b/sqlx-postgres/src/copy.rs index 98efbba051..347877c36b 100644 --- a/sqlx-postgres/src/copy.rs +++ b/sqlx-postgres/src/copy.rs @@ -11,7 +11,8 @@ use crate::error::{Error, Result}; use crate::ext::async_stream::TryAsyncStream; use crate::io::AsyncRead; use crate::message::{ - CommandComplete, CopyData, CopyDone, CopyFail, CopyResponse, MessageFormat, Query, + BackendMessageFormat, CommandComplete, CopyData, CopyDone, CopyFail, CopyInResponse, + CopyOutResponse, CopyResponseData, Query, ReadyForQuery, }; use crate::pool::{Pool, PoolConnection}; use crate::Postgres; @@ -138,7 +139,7 @@ impl PgPoolCopyExt for Pool { #[must_use = "connection will error on next use if `.finish()` or `.abort()` is not called"] pub struct PgCopyIn> { conn: Option, - response: CopyResponse, + response: CopyResponseData, } impl> PgCopyIn { @@ -146,8 +147,8 @@ impl> PgCopyIn { conn.wait_until_ready().await?; conn.stream.send(Query(statement)).await?; - let response = match conn.stream.recv_expect(MessageFormat::CopyInResponse).await { - Ok(res) => res, + let response = match conn.stream.recv_expect::().await { + Ok(res) => res.0, Err(e) => { conn.stream.recv().await?; return Err(e); @@ -168,7 +169,7 @@ impl> PgCopyIn { /// Returns the number of columns expected in the input. 
pub fn num_columns(&self) -> usize { assert_eq!( - self.response.num_columns as usize, + self.response.num_columns.unsigned_abs() as usize, self.response.format_codes.len(), "num_columns does not match format_codes.len()" ); @@ -261,9 +262,7 @@ impl> PgCopyIn { match e.code() { Some(Cow::Borrowed("57014")) => { // postgres abort received error code - conn.stream - .recv_expect(MessageFormat::ReadyForQuery) - .await?; + conn.stream.recv_expect::().await?; Ok(()) } _ => Err(Error::Database(e)), @@ -283,11 +282,7 @@ impl> PgCopyIn { .expect("CopyWriter::finish: conn taken illegally"); conn.stream.send(CopyDone).await?; - let cc: CommandComplete = match conn - .stream - .recv_expect(MessageFormat::CommandComplete) - .await - { + let cc: CommandComplete = match conn.stream.recv_expect().await { Ok(cc) => cc, Err(e) => { conn.stream.recv().await?; @@ -295,9 +290,7 @@ impl> PgCopyIn { } }; - conn.stream - .recv_expect(MessageFormat::ReadyForQuery) - .await?; + conn.stream.recv_expect::().await?; Ok(cc.rows_affected()) } @@ -306,9 +299,11 @@ impl> PgCopyIn { impl> Drop for PgCopyIn { fn drop(&mut self) { if let Some(mut conn) = self.conn.take() { - conn.stream.write(CopyFail::new( - "PgCopyIn dropped without calling finish() or fail()", - )); + conn.stream + .write_msg(CopyFail::new( + "PgCopyIn dropped without calling finish() or fail()", + )) + .expect("BUG: PgCopyIn abort message should not be too large"); } } } @@ -320,24 +315,21 @@ async fn pg_begin_copy_out<'c, C: DerefMut + Send + 'c>( conn.wait_until_ready().await?; conn.stream.send(Query(statement)).await?; - let _: CopyResponse = conn - .stream - .recv_expect(MessageFormat::CopyOutResponse) - .await?; + let _: CopyOutResponse = conn.stream.recv_expect().await?; let stream: TryAsyncStream<'c, Bytes> = try_stream! 
{ loop { match conn.stream.recv().await { Err(e) => { - conn.stream.recv_expect(MessageFormat::ReadyForQuery).await?; + conn.stream.recv_expect::().await?; return Err(e); }, Ok(msg) => match msg.format { - MessageFormat::CopyData => r#yield!(msg.decode::>()?.0), - MessageFormat::CopyDone => { + BackendMessageFormat::CopyData => r#yield!(msg.decode::>()?.0), + BackendMessageFormat::CopyDone => { let _ = msg.decode::()?; - conn.stream.recv_expect(MessageFormat::CommandComplete).await?; - conn.stream.recv_expect(MessageFormat::ReadyForQuery).await?; + conn.stream.recv_expect::().await?; + conn.stream.recv_expect::().await?; return Ok(()) }, _ => return Err(err_protocol!("unexpected message format during copy out: {:?}", msg.format)) diff --git a/sqlx-postgres/src/io/buf_mut.rs b/sqlx-postgres/src/io/buf_mut.rs index b5688f3bf3..ff6fe03df3 100644 --- a/sqlx-postgres/src/io/buf_mut.rs +++ b/sqlx-postgres/src/io/buf_mut.rs @@ -1,54 +1,64 @@ -use crate::types::Oid; +use crate::io::{PortalId, StatementId}; pub trait PgBufMutExt { - fn put_length_prefixed(&mut self, f: F) + fn put_length_prefixed(&mut self, f: F) -> Result<(), crate::Error> where - F: FnOnce(&mut Vec); + F: FnOnce(&mut Vec) -> Result<(), crate::Error>; - fn put_statement_name(&mut self, id: Oid); + fn put_statement_name(&mut self, id: StatementId); - fn put_portal_name(&mut self, id: Option); + fn put_portal_name(&mut self, id: PortalId); } impl PgBufMutExt for Vec { // writes a length-prefixed message, this is used when encoding nearly all messages as postgres // wants us to send the length of the often-variable-sized messages up front - fn put_length_prefixed(&mut self, f: F) + fn put_length_prefixed(&mut self, write_contents: F) -> Result<(), crate::Error> where - F: FnOnce(&mut Vec), + F: FnOnce(&mut Vec) -> Result<(), crate::Error>, { // reserve space to write the prefixed length let offset = self.len(); self.extend(&[0; 4]); // write the main body of the message - f(self); + let write_result = 
write_contents(self); - // now calculate the size of what we wrote and set the length value - let size = (self.len() - offset) as i32; - self[offset..(offset + 4)].copy_from_slice(&size.to_be_bytes()); + let size_result = write_result.and_then(|_| { + let size = self.len() - offset; + i32::try_from(size) + .map_err(|_| err_protocol!("message size out of range for Pg protocol: {size}")) + }); + + match size_result { + Ok(size) => { + // now calculate the size of what we wrote and set the length value + self[offset..(offset + 4)].copy_from_slice(&size.to_be_bytes()); + Ok(()) + } + Err(e) => { + // Put the buffer back to where it was. + self.truncate(offset); + Err(e) + } + } } // writes a statement name by ID #[inline] - fn put_statement_name(&mut self, id: Oid) { - // N.B. if you change this don't forget to update it in ../describe.rs - self.extend(b"sqlx_s_"); - - self.extend(itoa::Buffer::new().format(id.0).as_bytes()); - - self.push(0); + fn put_statement_name(&mut self, id: StatementId) { + let _: Result<(), ()> = id.write_name(|s| { + self.extend_from_slice(s.as_bytes()); + Ok(()) + }); } // writes a portal name by ID #[inline] - fn put_portal_name(&mut self, id: Option) { - if let Some(id) = id { - self.extend(b"sqlx_p_"); - - self.extend(itoa::Buffer::new().format(id.0).as_bytes()); - } - - self.push(0); + fn put_portal_name(&mut self, id: PortalId) { + let _: Result<(), ()> = id.write_name(|s| { + self.extend_from_slice(s.as_bytes()); + Ok(()) + }); } } diff --git a/sqlx-postgres/src/io/mod.rs b/sqlx-postgres/src/io/mod.rs index 1a6d070257..f90db85d93 100644 --- a/sqlx-postgres/src/io/mod.rs +++ b/sqlx-postgres/src/io/mod.rs @@ -1,5 +1,130 @@ mod buf_mut; pub use buf_mut::PgBufMutExt; +use std::fmt; +use std::fmt::{Display, Formatter}; +use std::num::{NonZeroU32, Saturating}; pub(crate) use sqlx_core::io::*; + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub(crate) struct StatementId(IdInner); + +#[allow(dead_code)] +#[derive(Debug, Copy, Clone, PartialEq, 
Eq)] +pub(crate) struct PortalId(IdInner); + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +struct IdInner(Option); + +impl StatementId { + pub const UNNAMED: Self = Self(IdInner::UNNAMED); + + pub const NAMED_START: Self = Self(IdInner::NAMED_START); + + #[cfg(test)] + pub const TEST_VAL: Self = Self(IdInner::TEST_VAL); + + const NAME_PREFIX: &'static str = "sqlx_s_"; + + pub fn next(&self) -> Self { + Self(self.0.next()) + } + + pub fn name_len(&self) -> Saturating { + self.0.name_len(Self::NAME_PREFIX) + } + + // There's no common trait implemented by `Formatter` and `Vec` for this purpose; + // we're deliberately avoiding the formatting machinery because it's known to be slow. + pub fn write_name(&self, write: impl FnMut(&str) -> Result<(), E>) -> Result<(), E> { + self.0.write_name(Self::NAME_PREFIX, write) + } +} + +impl Display for StatementId { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + self.write_name(|s| f.write_str(s)) + } +} + +#[allow(dead_code)] +impl PortalId { + // None selects the unnamed portal + pub const UNNAMED: Self = PortalId(IdInner::UNNAMED); + + pub const NAMED_START: Self = PortalId(IdInner::NAMED_START); + + #[cfg(test)] + pub const TEST_VAL: Self = Self(IdInner::TEST_VAL); + + const NAME_PREFIX: &'static str = "sqlx_p_"; + + /// If ID represents a named portal, return the next ID, wrapping on overflow. + /// + /// If this ID represents the unnamed portal, return the same. + pub fn next(&self) -> Self { + Self(self.0.next()) + } + + /// Calculate the number of bytes that will be written by [`Self::write_name()`]. 
+ pub fn name_len(&self) -> Saturating { + self.0.name_len(Self::NAME_PREFIX) + } + + pub fn write_name(&self, write: impl FnMut(&str) -> Result<(), E>) -> Result<(), E> { + self.0.write_name(Self::NAME_PREFIX, write) + } +} + +impl IdInner { + const UNNAMED: Self = Self(None); + + const NAMED_START: Self = Self(Some(NonZeroU32::MIN)); + + #[cfg(test)] + pub const TEST_VAL: Self = Self(NonZeroU32::new(1234567890)); + + #[inline(always)] + fn next(&self) -> Self { + Self( + self.0 + .map(|id| id.checked_add(1).unwrap_or(NonZeroU32::MIN)), + ) + } + + #[inline(always)] + fn name_len(&self, name_prefix: &str) -> Saturating { + let mut len = Saturating(0); + + if let Some(id) = self.0 { + len += name_prefix.len(); + // estimate the length of the ID in decimal + // `.ilog10()` can't panic since the value is never zero + len += id.get().ilog10() as usize; + // add one to compensate for `ilog10()` rounding down. + len += 1; + } + + // count the NUL terminator + len += 1; + + len + } + + #[inline(always)] + fn write_name( + &self, + name_prefix: &str, + mut write: impl FnMut(&str) -> Result<(), E>, + ) -> Result<(), E> { + if let Some(id) = self.0 { + write(name_prefix)?; + write(itoa::Buffer::new().format(id.get()))?; + } + + write("\0")?; + + Ok(()) + } +} diff --git a/sqlx-postgres/src/listener.rs b/sqlx-postgres/src/listener.rs index ca4f78a275..43bd3c8ff5 100644 --- a/sqlx-postgres/src/listener.rs +++ b/sqlx-postgres/src/listener.rs @@ -11,7 +11,7 @@ use sqlx_core::Either; use crate::describe::Describe; use crate::error::Error; use crate::executor::{Execute, Executor}; -use crate::message::{MessageFormat, Notification}; +use crate::message::{BackendMessageFormat, Notification}; use crate::pool::PoolOptions; use crate::pool::{Pool, PoolConnection}; use crate::{PgConnection, PgQueryResult, PgRow, PgStatement, PgTypeInfo, Postgres}; @@ -277,12 +277,12 @@ impl PgListener { match message.format { // We've received an async notification, return it. 
- MessageFormat::NotificationResponse => { + BackendMessageFormat::NotificationResponse => { return Ok(Some(PgNotification(message.decode()?))); } // Mark the connection as ready for another query - MessageFormat::ReadyForQuery => { + BackendMessageFormat::ReadyForQuery => { self.connection().await?.pending_ready_for_query_count -= 1; } diff --git a/sqlx-postgres/src/message/authentication.rs b/sqlx-postgres/src/message/authentication.rs index 2e55c11f93..3a3cf7ff6e 100644 --- a/sqlx-postgres/src/message/authentication.rs +++ b/sqlx-postgres/src/message/authentication.rs @@ -4,10 +4,10 @@ use memchr::memchr; use sqlx_core::bytes::{Buf, Bytes}; use crate::error::Error; -use crate::io::Decode; +use crate::io::ProtocolDecode; +use crate::message::{BackendMessage, BackendMessageFormat}; use base64::prelude::{Engine as _, BASE64_STANDARD}; - // On startup, the server sends an appropriate authentication request message, // to which the frontend must reply with an appropriate authentication // response message (such as a password). 
@@ -60,8 +60,10 @@ pub enum Authentication { SaslFinal(AuthenticationSaslFinal), } -impl Decode<'_> for Authentication { - fn decode_with(mut buf: Bytes, _: ()) -> Result { +impl BackendMessage for Authentication { + const FORMAT: BackendMessageFormat = BackendMessageFormat::Authentication; + + fn decode_body(mut buf: Bytes) -> Result { Ok(match buf.get_u32() { 0 => Authentication::Ok, @@ -129,7 +131,7 @@ pub struct AuthenticationSaslContinue { pub message: String, } -impl Decode<'_> for AuthenticationSaslContinue { +impl ProtocolDecode<'_> for AuthenticationSaslContinue { fn decode_with(buf: Bytes, _: ()) -> Result { let mut iterations: u32 = 4096; let mut salt = Vec::new(); @@ -173,7 +175,7 @@ pub struct AuthenticationSaslFinal { pub verifier: Vec, } -impl Decode<'_> for AuthenticationSaslFinal { +impl ProtocolDecode<'_> for AuthenticationSaslFinal { fn decode_with(buf: Bytes, _: ()) -> Result { let mut verifier = Vec::new(); diff --git a/sqlx-postgres/src/message/backend_key_data.rs b/sqlx-postgres/src/message/backend_key_data.rs index d89df65fb0..f2dc2f232f 100644 --- a/sqlx-postgres/src/message/backend_key_data.rs +++ b/sqlx-postgres/src/message/backend_key_data.rs @@ -2,7 +2,7 @@ use byteorder::{BigEndian, ByteOrder}; use sqlx_core::bytes::Bytes; use crate::error::Error; -use crate::io::Decode; +use crate::message::{BackendMessage, BackendMessageFormat}; /// Contains cancellation key data. The frontend must save these values if it /// wishes to be able to issue `CancelRequest` messages later. 
@@ -15,8 +15,10 @@ pub struct BackendKeyData { pub secret_key: u32, } -impl Decode<'_> for BackendKeyData { - fn decode_with(buf: Bytes, _: ()) -> Result { +impl BackendMessage for BackendKeyData { + const FORMAT: BackendMessageFormat = BackendMessageFormat::BackendKeyData; + + fn decode_body(buf: Bytes) -> Result { let process_id = BigEndian::read_u32(&buf); let secret_key = BigEndian::read_u32(&buf[4..]); @@ -31,7 +33,7 @@ impl Decode<'_> for BackendKeyData { fn test_decode_backend_key_data() { const DATA: &[u8] = b"\0\0'\xc6\x89R\xc5+"; - let m = BackendKeyData::decode(DATA.into()).unwrap(); + let m = BackendKeyData::decode_body(DATA.into()).unwrap(); assert_eq!(m.process_id, 10182); assert_eq!(m.secret_key, 2303903019); @@ -43,6 +45,6 @@ fn bench_decode_backend_key_data(b: &mut test::Bencher) { const DATA: &[u8] = b"\0\0'\xc6\x89R\xc5+"; b.iter(|| { - BackendKeyData::decode(test::black_box(Bytes::from_static(DATA))).unwrap(); + BackendKeyData::decode_body(test::black_box(Bytes::from_static(DATA))).unwrap(); }); } diff --git a/sqlx-postgres/src/message/bind.rs b/sqlx-postgres/src/message/bind.rs index b8db9679bb..83631fea5c 100644 --- a/sqlx-postgres/src/message/bind.rs +++ b/sqlx-postgres/src/message/bind.rs @@ -1,15 +1,15 @@ -use crate::io::Encode; -use crate::io::PgBufMutExt; -use crate::types::Oid; +use crate::io::{PgBufMutExt, PortalId, StatementId}; +use crate::message::{FrontendMessage, FrontendMessageFormat}; use crate::PgValueFormat; +use std::num::Saturating; #[derive(Debug)] pub struct Bind<'a> { - /// The ID of the destination portal (`None` selects the unnamed portal). - pub portal: Option, + /// The ID of the destination portal (`PortalId::UNNAMED` selects the unnamed portal). + pub portal: PortalId, /// The id of the source prepared statement. - pub statement: Oid, + pub statement: StatementId, /// The parameter format codes. Each must presently be zero (text) or one (binary). 
/// @@ -19,6 +19,8 @@ pub struct Bind<'a> { pub formats: &'a [PgValueFormat], /// The number of parameters. + /// + /// May be different from `formats.len()` pub num_params: i16, /// The value of each parameter, in the indicated format. @@ -33,31 +35,59 @@ pub struct Bind<'a> { pub result_formats: &'a [PgValueFormat], } -impl Encode<'_> for Bind<'_> { - fn encode_with(&self, buf: &mut Vec, _: ()) { - buf.push(b'B'); +impl FrontendMessage for Bind<'_> { + const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Bind; + + fn body_size_hint(&self) -> Saturating { + let mut size = Saturating(0); + size += self.portal.name_len(); + size += self.statement.name_len(); + + // Parameter formats and length prefix + size += 2; + size += self.formats.len(); + + // `num_params` + size += 2; + + size += self.params.len(); + + // Result formats and length prefix + size += 2; + size += self.result_formats.len(); + + size + } + + fn encode_body(&self, buf: &mut Vec) -> Result<(), crate::Error> { + buf.put_portal_name(self.portal); + + buf.put_statement_name(self.statement); + + let formats_len = i16::try_from(self.formats.len()).map_err(|_| { + err_protocol!("too many parameter format codes ({})", self.formats.len()) + })?; - buf.put_length_prefixed(|buf| { - buf.put_portal_name(self.portal); + buf.extend(formats_len.to_be_bytes()); - buf.put_statement_name(self.statement); + for &format in self.formats { + buf.extend((format as i16).to_be_bytes()); + } - buf.extend(&(self.formats.len() as i16).to_be_bytes()); + buf.extend(self.num_params.to_be_bytes()); - for &format in self.formats { - buf.extend(&(format as i16).to_be_bytes()); - } + buf.extend(self.params); - buf.extend(&self.num_params.to_be_bytes()); + let result_formats_len = i16::try_from(self.formats.len()) + .map_err(|_| err_protocol!("too many result format codes ({})", self.formats.len()))?; - buf.extend(self.params); + buf.extend(result_formats_len.to_be_bytes()); - buf.extend(&(self.result_formats.len() as 
i16).to_be_bytes()); + for &format in self.result_formats { + buf.extend((format as i16).to_be_bytes()); + } - for &format in self.result_formats { - buf.extend(&(format as i16).to_be_bytes()); - } - }); + Ok(()) } } diff --git a/sqlx-postgres/src/message/close.rs b/sqlx-postgres/src/message/close.rs index 0ffa638c0b..172f244c17 100644 --- a/sqlx-postgres/src/message/close.rs +++ b/sqlx-postgres/src/message/close.rs @@ -1,6 +1,6 @@ -use crate::io::Encode; -use crate::io::PgBufMutExt; -use crate::types::Oid; +use crate::io::{PgBufMutExt, PortalId, StatementId}; +use crate::message::{FrontendMessage, FrontendMessageFormat}; +use std::num::Saturating; const CLOSE_PORTAL: u8 = b'P'; const CLOSE_STATEMENT: u8 = b'S'; @@ -8,18 +8,27 @@ const CLOSE_STATEMENT: u8 = b'S'; #[derive(Debug)] #[allow(dead_code)] pub enum Close { - Statement(Oid), - // None selects the unnamed portal - Portal(Option), + Statement(StatementId), + Portal(PortalId), } -impl Encode<'_> for Close { - fn encode_with(&self, buf: &mut Vec, _: ()) { - // 15 bytes for 1-digit statement/portal IDs - buf.reserve(20); - buf.push(b'C'); +impl FrontendMessage for Close { + const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Close; - buf.put_length_prefixed(|buf| match self { + fn body_size_hint(&self) -> Saturating { + // Either `CLOSE_PORTAL` or `CLOSE_STATEMENT` + let mut size = Saturating(1); + + match self { + Close::Statement(id) => size += id.name_len(), + Close::Portal(id) => size += id.name_len(), + } + + size + } + + fn encode_body(&self, buf: &mut Vec) -> Result<(), crate::Error> { + match self { Close::Statement(id) => { buf.push(CLOSE_STATEMENT); buf.put_statement_name(*id); @@ -29,6 +38,8 @@ impl Encode<'_> for Close { buf.push(CLOSE_PORTAL); buf.put_portal_name(*id); } - }) + } + + Ok(()) } } diff --git a/sqlx-postgres/src/message/command_complete.rs b/sqlx-postgres/src/message/command_complete.rs index c2c8e1580e..eb33c512d9 100644 --- a/sqlx-postgres/src/message/command_complete.rs +++ 
b/sqlx-postgres/src/message/command_complete.rs @@ -3,7 +3,7 @@ use memchr::memrchr; use sqlx_core::bytes::Bytes; use crate::error::Error; -use crate::io::Decode; +use crate::message::{BackendMessage, BackendMessageFormat}; #[derive(Debug)] pub struct CommandComplete { @@ -12,10 +12,11 @@ pub struct CommandComplete { tag: Bytes, } -impl Decode<'_> for CommandComplete { - #[inline] - fn decode_with(buf: Bytes, _: ()) -> Result { - Ok(CommandComplete { tag: buf }) +impl BackendMessage for CommandComplete { + const FORMAT: BackendMessageFormat = BackendMessageFormat::CommandComplete; + + fn decode_body(bytes: Bytes) -> Result { + Ok(CommandComplete { tag: bytes }) } } @@ -35,7 +36,7 @@ impl CommandComplete { fn test_decode_command_complete_for_insert() { const DATA: &[u8] = b"INSERT 0 1214\0"; - let cc = CommandComplete::decode(Bytes::from_static(DATA)).unwrap(); + let cc = CommandComplete::decode_body(Bytes::from_static(DATA)).unwrap(); assert_eq!(cc.rows_affected(), 1214); } @@ -44,7 +45,7 @@ fn test_decode_command_complete_for_insert() { fn test_decode_command_complete_for_begin() { const DATA: &[u8] = b"BEGIN\0"; - let cc = CommandComplete::decode(Bytes::from_static(DATA)).unwrap(); + let cc = CommandComplete::decode_body(Bytes::from_static(DATA)).unwrap(); assert_eq!(cc.rows_affected(), 0); } @@ -53,7 +54,7 @@ fn test_decode_command_complete_for_begin() { fn test_decode_command_complete_for_update() { const DATA: &[u8] = b"UPDATE 5\0"; - let cc = CommandComplete::decode(Bytes::from_static(DATA)).unwrap(); + let cc = CommandComplete::decode_body(Bytes::from_static(DATA)).unwrap(); assert_eq!(cc.rows_affected(), 5); } @@ -64,7 +65,7 @@ fn bench_decode_command_complete(b: &mut test::Bencher) { const DATA: &[u8] = b"INSERT 0 1214\0"; b.iter(|| { - let _ = CommandComplete::decode(test::black_box(Bytes::from_static(DATA))); + let _ = CommandComplete::decode_body(test::black_box(Bytes::from_static(DATA))); }); } @@ -73,7 +74,7 @@ fn bench_decode_command_complete(b: &mut 
test::Bencher) { fn bench_decode_command_complete_rows_affected(b: &mut test::Bencher) { const DATA: &[u8] = b"INSERT 0 1214\0"; - let data = CommandComplete::decode(Bytes::from_static(DATA)).unwrap(); + let data = CommandComplete::decode_body(Bytes::from_static(DATA)).unwrap(); b.iter(|| { let _rows = test::black_box(&data).rows_affected(); diff --git a/sqlx-postgres/src/message/copy.rs b/sqlx-postgres/src/message/copy.rs index db0e7398cf..837d849a06 100644 --- a/sqlx-postgres/src/message/copy.rs +++ b/sqlx-postgres/src/message/copy.rs @@ -1,15 +1,25 @@ use crate::error::Result; -use crate::io::{BufExt, BufMutExt, Decode, Encode}; -use sqlx_core::bytes::{Buf, BufMut, Bytes}; +use crate::io::BufMutExt; +use crate::message::{ + BackendMessage, BackendMessageFormat, FrontendMessage, FrontendMessageFormat, +}; +use sqlx_core::bytes::{Buf, Bytes}; +use sqlx_core::Error; +use std::num::Saturating; use std::ops::Deref; /// The same structure is sent for both `CopyInResponse` and `CopyOutResponse` -pub struct CopyResponse { +pub struct CopyResponseData { pub format: i8, pub num_columns: i16, pub format_codes: Vec, } +pub struct CopyInResponse(pub CopyResponseData); + +#[allow(dead_code)] +pub struct CopyOutResponse(pub CopyResponseData); + pub struct CopyData(pub B); pub struct CopyFail { @@ -18,14 +28,15 @@ pub struct CopyFail { pub struct CopyDone; -impl Decode<'_> for CopyResponse { - fn decode_with(mut buf: Bytes, _: ()) -> Result { +impl CopyResponseData { + #[inline] + fn decode(mut buf: Bytes) -> Result { let format = buf.get_i8(); let num_columns = buf.get_i16(); let format_codes = (0..num_columns).map(|_| buf.get_i16()).collect(); - Ok(CopyResponse { + Ok(CopyResponseData { format, num_columns, format_codes, @@ -33,40 +44,65 @@ impl Decode<'_> for CopyResponse { } } -impl Decode<'_> for CopyData { - fn decode_with(buf: Bytes, _: ()) -> Result { - // well.. 
that was easy - Ok(CopyData(buf)) +impl BackendMessage for CopyInResponse { + const FORMAT: BackendMessageFormat = BackendMessageFormat::CopyInResponse; + + #[inline(always)] + fn decode_body(buf: Bytes) -> std::result::Result { + Ok(Self(CopyResponseData::decode(buf)?)) } } -impl> Encode<'_> for CopyData { - fn encode_with(&self, buf: &mut Vec, _context: ()) { - buf.push(b'd'); - buf.put_u32(self.0.len() as u32 + 4); - buf.extend_from_slice(&self.0); +impl BackendMessage for CopyOutResponse { + const FORMAT: BackendMessageFormat = BackendMessageFormat::CopyOutResponse; + + #[inline(always)] + fn decode_body(buf: Bytes) -> std::result::Result { + Ok(Self(CopyResponseData::decode(buf)?)) } } -impl Decode<'_> for CopyFail { - fn decode_with(mut buf: Bytes, _: ()) -> Result { - Ok(CopyFail { - message: buf.get_str_nul()?, - }) +impl BackendMessage for CopyData { + const FORMAT: BackendMessageFormat = BackendMessageFormat::CopyData; + + #[inline(always)] + fn decode_body(buf: Bytes) -> std::result::Result { + Ok(Self(buf)) + } +} + +impl> FrontendMessage for CopyData { + const FORMAT: FrontendMessageFormat = FrontendMessageFormat::CopyData; + + #[inline(always)] + fn body_size_hint(&self) -> Saturating { + Saturating(self.0.len()) + } + + #[inline(always)] + fn encode_body(&self, buf: &mut Vec) -> Result<(), Error> { + buf.extend_from_slice(&self.0); + Ok(()) } } -impl Encode<'_> for CopyFail { - fn encode_with(&self, buf: &mut Vec, _: ()) { - let len = 4 + self.message.len() + 1; +impl FrontendMessage for CopyFail { + const FORMAT: FrontendMessageFormat = FrontendMessageFormat::CopyFail; - buf.push(b'f'); // to pay respects - buf.put_u32(len as u32); + #[inline(always)] + fn body_size_hint(&self) -> Saturating { + Saturating(self.message.len()) + } + + #[inline(always)] + fn encode_body(&self, buf: &mut Vec) -> std::result::Result<(), Error> { buf.put_str_nul(&self.message); + Ok(()) } } impl CopyFail { + #[inline(always)] pub fn new(msg: impl Into) -> CopyFail { 
CopyFail { message: msg.into(), @@ -74,23 +110,32 @@ impl CopyFail { } } -impl Decode<'_> for CopyDone { - fn decode_with(buf: Bytes, _: ()) -> Result { - if buf.is_empty() { - Ok(CopyDone) - } else { - Err(err_protocol!( - "expected no data for CopyDone, got: {:?}", - buf - )) - } +impl FrontendMessage for CopyDone { + const FORMAT: FrontendMessageFormat = FrontendMessageFormat::CopyDone; + #[inline(always)] + fn body_size_hint(&self) -> Saturating { + Saturating(0) + } + + #[inline(always)] + fn encode_body(&self, _buf: &mut Vec) -> std::result::Result<(), Error> { + Ok(()) } } -impl Encode<'_> for CopyDone { - fn encode_with(&self, buf: &mut Vec, _: ()) { - buf.reserve(4); - buf.push(b'c'); - buf.put_u32(4); +impl BackendMessage for CopyDone { + const FORMAT: BackendMessageFormat = BackendMessageFormat::CopyDone; + + #[inline(always)] + fn decode_body(bytes: Bytes) -> std::result::Result { + if !bytes.is_empty() { + // Not fatal but may indicate a protocol change + tracing::debug!( + "Postgres backend returned non-empty message for CopyDone: \"{}\"", + bytes.escape_ascii() + ) + } + + Ok(CopyDone) } } diff --git a/sqlx-postgres/src/message/data_row.rs b/sqlx-postgres/src/message/data_row.rs index 3e08d22f20..ae9d0d9b26 100644 --- a/sqlx-postgres/src/message/data_row.rs +++ b/sqlx-postgres/src/message/data_row.rs @@ -1,10 +1,9 @@ -use std::ops::Range; - use byteorder::{BigEndian, ByteOrder}; use sqlx_core::bytes::Bytes; +use std::ops::Range; use crate::error::Error; -use crate::io::Decode; +use crate::message::{BackendMessage, BackendMessageFormat}; /// A row of data from the database. 
#[derive(Debug)] @@ -26,25 +25,55 @@ impl DataRow { } } -impl Decode<'_> for DataRow { - fn decode_with(buf: Bytes, _: ()) -> Result { +impl BackendMessage for DataRow { + const FORMAT: BackendMessageFormat = BackendMessageFormat::DataRow; + + fn decode_body(buf: Bytes) -> Result { + if buf.len() < 2 { + return Err(err_protocol!( + "expected at least 2 bytes, got {}", + buf.len() + )); + } + let cnt = BigEndian::read_u16(&buf) as usize; let mut values = Vec::with_capacity(cnt); - let mut offset = 2; + let mut offset: u32 = 2; for _ in 0..cnt { + let value_start = offset + .checked_add(4) + .ok_or_else(|| err_protocol!("next value start out of range (offset: {offset})"))?; + + // widen both to a larger type for a safe comparison + if (buf.len() as u64) < (value_start as u64) { + return Err(err_protocol!( + "expected 4 bytes at offset {offset}, got {}", + (value_start as u64) - (buf.len() as u64) + )); + } + // Length of the column value, in bytes (this count does not include itself). // Can be zero. As a special case, -1 indicates a NULL column value. // No value bytes follow in the NULL case. + // + // we know `offset` is within range of `buf.len()` from the above check + #[allow(clippy::cast_possible_truncation)] let length = BigEndian::read_i32(&buf[(offset as usize)..]); - offset += 4; - if length < 0 { - values.push(None); + if let Ok(length) = u32::try_from(length) { + let value_end = value_start.checked_add(length).ok_or_else(|| { + err_protocol!("value_start + length out of range ({offset} + {length})") + })?; + + values.push(Some(value_start..value_end)); + offset = value_end; } else { - values.push(Some(offset..(offset + length as u32))); - offset += length as u32; + // Negative values signify NULL + values.push(None); + // `value_start` is actually the next value now. 
+ offset = value_start; } } @@ -57,9 +86,22 @@ impl Decode<'_> for DataRow { #[test] fn test_decode_data_row() { - const DATA: &[u8] = b"\x00\x08\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\n\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\x14\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00(\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00P"; - - let row = DataRow::decode(DATA.into()).unwrap(); + const DATA: &[u8] = b"\ + \x00\x08\ + \xff\xff\xff\xff\ + \x00\x00\x00\x04\ + \x00\x00\x00\n\ + \xff\xff\xff\xff\ + \x00\x00\x00\x04\ + \x00\x00\x00\x14\ + \xff\xff\xff\xff\ + \x00\x00\x00\x04\ + \x00\x00\x00(\ + \xff\xff\xff\xff\ + \x00\x00\x00\x04\ + \x00\x00\x00P"; + + let row = DataRow::decode_body(DATA.into()).unwrap(); assert_eq!(row.values.len(), 8); @@ -78,7 +120,7 @@ fn test_decode_data_row() { fn bench_data_row_get(b: &mut test::Bencher) { const DATA: &[u8] = b"\x00\x08\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\n\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\x14\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00(\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00P"; - let row = DataRow::decode(test::black_box(Bytes::from_static(DATA))).unwrap(); + let row = DataRow::decode_body(test::black_box(Bytes::from_static(DATA))).unwrap(); b.iter(|| { let _value = test::black_box(&row).get(3); @@ -91,6 +133,6 @@ fn bench_decode_data_row(b: &mut test::Bencher) { const DATA: &[u8] = b"\x00\x08\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\n\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\x14\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00(\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00P"; b.iter(|| { - let _ = DataRow::decode(test::black_box(Bytes::from_static(DATA))); + let _ = DataRow::decode_body(test::black_box(Bytes::from_static(DATA))); }); } diff --git a/sqlx-postgres/src/message/describe.rs b/sqlx-postgres/src/message/describe.rs index 382f6e70f5..d6ea7e89cc 100644 --- a/sqlx-postgres/src/message/describe.rs +++ b/sqlx-postgres/src/message/describe.rs @@ -1,127 +1,103 @@ -use crate::io::Encode; -use 
crate::io::PgBufMutExt; -use crate::types::Oid; +use crate::io::{PgBufMutExt, PortalId, StatementId}; +use crate::message::{FrontendMessage, FrontendMessageFormat}; +use sqlx_core::Error; +use std::num::Saturating; const DESCRIBE_PORTAL: u8 = b'P'; const DESCRIBE_STATEMENT: u8 = b'S'; -// [Describe] will emit both a [RowDescription] and a [ParameterDescription] message - +/// Note: will emit both a RowDescription and a ParameterDescription message #[derive(Debug)] #[allow(dead_code)] pub enum Describe { - UnnamedStatement, - Statement(Oid), - - UnnamedPortal, - Portal(Oid), + Statement(StatementId), + Portal(PortalId), } -impl Encode<'_> for Describe { - fn encode_with(&self, buf: &mut Vec, _: ()) { - // 15 bytes for 1-digit statement/portal IDs - buf.reserve(20); - buf.push(b'D'); - - buf.put_length_prefixed(|buf| { - match self { - // #[likely] - Describe::Statement(id) => { - buf.push(DESCRIBE_STATEMENT); - buf.put_statement_name(*id); - } - - Describe::UnnamedPortal => { - buf.push(DESCRIBE_PORTAL); - buf.push(0); - } - - Describe::UnnamedStatement => { - buf.push(DESCRIBE_STATEMENT); - buf.push(0); - } - - Describe::Portal(id) => { - buf.push(DESCRIBE_PORTAL); - buf.put_portal_name(Some(*id)); - } - } - }); - } -} +impl FrontendMessage for Describe { + const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Describe; -#[test] -fn test_encode_describe_portal() { - const EXPECTED: &[u8] = b"D\0\0\0\x0EPsqlx_p_5\0"; + fn body_size_hint(&self) -> Saturating { + // Either `DESCRIBE_PORTAL` or `DESCRIBE_STATEMENT` + let mut size = Saturating(1); - let mut buf = Vec::new(); - let m = Describe::Portal(Oid(5)); + match self { + Describe::Statement(id) => size += id.name_len(), + Describe::Portal(id) => size += id.name_len(), + } - m.encode(&mut buf); + size + } - assert_eq!(buf, EXPECTED); -} + fn encode_body(&self, buf: &mut Vec) -> Result<(), Error> { + match self { + // #[likely] + Describe::Statement(id) => { + buf.push(DESCRIBE_STATEMENT); + 
buf.put_statement_name(*id); + } -#[test] -fn test_encode_describe_unnamed_portal() { - const EXPECTED: &[u8] = b"D\0\0\0\x06P\0"; + Describe::Portal(id) => { + buf.push(DESCRIBE_PORTAL); + buf.put_portal_name(*id); + } + } - let mut buf = Vec::new(); - let m = Describe::UnnamedPortal; + Ok(()) + } +} - m.encode(&mut buf); +#[cfg(test)] +mod tests { + use crate::message::FrontendMessage; - assert_eq!(buf, EXPECTED); -} + use super::{Describe, PortalId, StatementId}; -#[test] -fn test_encode_describe_statement() { - const EXPECTED: &[u8] = b"D\0\0\0\x0ESsqlx_s_5\0"; + #[test] + fn test_encode_describe_portal() { + const EXPECTED: &[u8] = b"D\0\0\0\x17Psqlx_p_1234567890\0"; - let mut buf = Vec::new(); - let m = Describe::Statement(Oid(5)); + let mut buf = Vec::new(); + let m = Describe::Portal(PortalId::TEST_VAL); - m.encode(&mut buf); + m.encode_msg(&mut buf).unwrap(); - assert_eq!(buf, EXPECTED); -} + assert_eq!(buf, EXPECTED); + } -#[test] -fn test_encode_describe_unnamed_statement() { - const EXPECTED: &[u8] = b"D\0\0\0\x06S\0"; + #[test] + fn test_encode_describe_unnamed_portal() { + const EXPECTED: &[u8] = b"D\0\0\0\x06P\0"; - let mut buf = Vec::new(); - let m = Describe::UnnamedStatement; + let mut buf = Vec::new(); + let m = Describe::Portal(PortalId::UNNAMED); - m.encode(&mut buf); + m.encode_msg(&mut buf).unwrap(); - assert_eq!(buf, EXPECTED); -} + assert_eq!(buf, EXPECTED); + } -#[cfg(all(test, not(debug_assertions)))] -#[bench] -fn bench_encode_describe_portal(b: &mut test::Bencher) { - use test::black_box; + #[test] + fn test_encode_describe_statement() { + const EXPECTED: &[u8] = b"D\0\0\0\x17Ssqlx_s_1234567890\0"; - let mut buf = Vec::with_capacity(128); + let mut buf = Vec::new(); + let m = Describe::Statement(StatementId::TEST_VAL); - b.iter(|| { - buf.clear(); + m.encode_msg(&mut buf).unwrap(); - black_box(Describe::Portal(5)).encode(&mut buf); - }); -} + assert_eq!(buf, EXPECTED); + } -#[cfg(all(test, not(debug_assertions)))] -#[bench] -fn 
bench_encode_describe_unnamed_statement(b: &mut test::Bencher) { - use test::black_box; + #[test] + fn test_encode_describe_unnamed_statement() { + const EXPECTED: &[u8] = b"D\0\0\0\x06S\0"; - let mut buf = Vec::with_capacity(128); + let mut buf = Vec::new(); + let m = Describe::Statement(StatementId::UNNAMED); - b.iter(|| { - buf.clear(); + m.encode_msg(&mut buf).unwrap(); - black_box(Describe::UnnamedStatement).encode(&mut buf); - }); + assert_eq!(buf, EXPECTED); + } } diff --git a/sqlx-postgres/src/message/execute.rs b/sqlx-postgres/src/message/execute.rs index 3550ae7824..f82b7884bc 100644 --- a/sqlx-postgres/src/message/execute.rs +++ b/sqlx-postgres/src/message/execute.rs @@ -1,39 +1,73 @@ -use crate::io::Encode; -use crate::io::PgBufMutExt; -use crate::types::Oid; +use std::num::Saturating; + +use sqlx_core::Error; + +use crate::io::{PgBufMutExt, PortalId}; +use crate::message::{FrontendMessage, FrontendMessageFormat}; pub struct Execute { - /// The id of the portal to execute (`None` selects the unnamed portal). - pub portal: Option, + /// The id of the portal to execute. + pub portal: PortalId, /// Maximum number of rows to return, if portal contains a query /// that returns rows (ignored otherwise). Zero denotes “no limit”. 
pub limit: u32, } -impl Encode<'_> for Execute { - fn encode_with(&self, buf: &mut Vec, _: ()) { - buf.reserve(20); - buf.push(b'E'); +impl FrontendMessage for Execute { + const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Execute; + + fn body_size_hint(&self) -> Saturating { + let mut size = Saturating(0); + + size += self.portal.name_len(); + size += 2; // limit - buf.put_length_prefixed(|buf| { - buf.put_portal_name(self.portal); - buf.extend(&self.limit.to_be_bytes()); - }); + size + } + + fn encode_body(&self, buf: &mut Vec) -> Result<(), Error> { + buf.put_portal_name(self.portal); + buf.extend(&self.limit.to_be_bytes()); + + Ok(()) } } -#[test] -fn test_encode_execute() { - const EXPECTED: &[u8] = b"E\0\0\0\x11sqlx_p_5\0\0\0\0\x02"; +#[cfg(test)] +mod tests { + use crate::io::PortalId; + use crate::message::FrontendMessage; + + use super::Execute; - let mut buf = Vec::new(); - let m = Execute { - portal: Some(Oid(5)), - limit: 2, - }; + #[test] + fn test_encode_execute_named_portal() { + const EXPECTED: &[u8] = b"E\0\0\0\x1Asqlx_p_1234567890\0\0\0\0\x02"; - m.encode(&mut buf); + let mut buf = Vec::new(); + let m = Execute { + portal: PortalId::TEST_VAL, + limit: 2, + }; - assert_eq!(buf, EXPECTED); + m.encode_msg(&mut buf).unwrap(); + + assert_eq!(buf, EXPECTED); + } + + #[test] + fn test_encode_execute_unnamed_portal() { + const EXPECTED: &[u8] = b"E\0\0\0\x09\0\x49\x96\x02\xD2"; + + let mut buf = Vec::new(); + let m = Execute { + portal: PortalId::UNNAMED, + limit: 1234567890, + }; + + m.encode_msg(&mut buf).unwrap(); + + assert_eq!(buf, EXPECTED); + } } diff --git a/sqlx-postgres/src/message/flush.rs b/sqlx-postgres/src/message/flush.rs index fc21d3f10f..d1dfabbfaf 100644 --- a/sqlx-postgres/src/message/flush.rs +++ b/sqlx-postgres/src/message/flush.rs @@ -1,17 +1,25 @@ -use crate::io::Encode; - -// The Flush message does not cause any specific output to be generated, -// but forces the backend to deliver any data pending in its output buffers. 
- -// A Flush must be sent after any extended-query command except Sync, if the -// frontend wishes to examine the results of that command before issuing more commands. +use crate::message::{FrontendMessage, FrontendMessageFormat}; +use sqlx_core::Error; +use std::num::Saturating; +/// The Flush message does not cause any specific output to be generated, +/// but forces the backend to deliver any data pending in its output buffers. +/// +/// A Flush must be sent after any extended-query command except Sync, if the +/// frontend wishes to examine the results of that command before issuing more commands. #[derive(Debug)] pub struct Flush; -impl Encode<'_> for Flush { - fn encode_with(&self, buf: &mut Vec, _: ()) { - buf.push(b'H'); - buf.extend(&4_i32.to_be_bytes()); +impl FrontendMessage for Flush { + const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Flush; + + #[inline(always)] + fn body_size_hint(&self) -> Saturating { + Saturating(0) + } + + #[inline(always)] + fn encode_body(&self, _buf: &mut Vec) -> Result<(), Error> { + Ok(()) } } diff --git a/sqlx-postgres/src/message/mod.rs b/sqlx-postgres/src/message/mod.rs index ef1dbfabf0..e62f9bebb3 100644 --- a/sqlx-postgres/src/message/mod.rs +++ b/sqlx-postgres/src/message/mod.rs @@ -1,7 +1,8 @@ use sqlx_core::bytes::Bytes; +use std::num::Saturating; use crate::error::Error; -use crate::io::Decode; +use crate::io::PgBufMutExt; mod authentication; mod backend_key_data; @@ -17,6 +18,7 @@ mod notification; mod parameter_description; mod parameter_status; mod parse; +mod parse_complete; mod password; mod query; mod ready_for_query; @@ -33,7 +35,7 @@ pub use backend_key_data::BackendKeyData; pub use bind::Bind; pub use close::Close; pub use command_complete::CommandComplete; -pub use copy::{CopyData, CopyDone, CopyFail, CopyResponse}; +pub use copy::{CopyData, CopyDone, CopyFail, CopyInResponse, CopyOutResponse, CopyResponseData}; pub use data_row::DataRow; pub use describe::Describe; pub use execute::Execute; 
@@ -43,20 +45,51 @@ pub use notification::Notification; pub use parameter_description::ParameterDescription; pub use parameter_status::ParameterStatus; pub use parse::Parse; +pub use parse_complete::ParseComplete; pub use password::Password; pub use query::Query; pub use ready_for_query::{ReadyForQuery, TransactionStatus}; pub use response::{Notice, PgSeverity}; pub use row_description::RowDescription; pub use sasl::{SaslInitialResponse, SaslResponse}; +use sqlx_core::io::ProtocolEncode; pub use ssl_request::SslRequest; pub use startup::Startup; pub use sync::Sync; pub use terminate::Terminate; +// Note: we can't use the same enum for both frontend and backend message formats +// because there are duplicated format codes between them. +// +// For example, `Close` (frontend) and `CommandComplete` (backend) both use format code `C`. +// #[derive(Debug, PartialOrd, PartialEq)] #[repr(u8)] -pub enum MessageFormat { +pub enum FrontendMessageFormat { + Bind = b'B', + Close = b'C', + CopyData = b'd', + CopyDone = b'c', + CopyFail = b'f', + Describe = b'D', + Execute = b'E', + Flush = b'H', + Parse = b'P', + /// This message format is polymorphic. 
It's used for: + /// + /// * Plain password responses + /// * MD5 password responses + /// * SASL responses + /// * GSSAPI/SSPI responses + PasswordPolymorphic = b'p', + Query = b'Q', + Sync = b'S', + Terminate = b'X', +} + +#[derive(Debug, PartialOrd, PartialEq)] +#[repr(u8)] +pub enum BackendMessageFormat { Authentication, BackendKeyData, BindComplete, @@ -81,49 +114,116 @@ pub enum MessageFormat { } #[derive(Debug)] -pub struct Message { - pub format: MessageFormat, +pub struct ReceivedMessage { + pub format: BackendMessageFormat, pub contents: Bytes, } -impl Message { +impl ReceivedMessage { #[inline] - pub fn decode<'de, T>(self) -> Result + pub fn decode(self) -> Result where - T: Decode<'de>, + T: BackendMessage, { - T::decode(self.contents) + if T::FORMAT != self.format { + return Err(err_protocol!( + "Postgres protocol error: expected {:?}, got {:?}", + T::FORMAT, + self.format + )); + } + + T::decode_body(self.contents).map_err(|e| match e { + Error::Protocol(s) => { + err_protocol!("Postgres protocol error (reading {:?}): {s}", self.format) + } + other => other, + }) } } -impl MessageFormat { +impl BackendMessageFormat { pub fn try_from_u8(v: u8) -> Result { // https://www.postgresql.org/docs/current/protocol-message-formats.html Ok(match v { - b'1' => MessageFormat::ParseComplete, - b'2' => MessageFormat::BindComplete, - b'3' => MessageFormat::CloseComplete, - b'C' => MessageFormat::CommandComplete, - b'd' => MessageFormat::CopyData, - b'c' => MessageFormat::CopyDone, - b'G' => MessageFormat::CopyInResponse, - b'H' => MessageFormat::CopyOutResponse, - b'D' => MessageFormat::DataRow, - b'E' => MessageFormat::ErrorResponse, - b'I' => MessageFormat::EmptyQueryResponse, - b'A' => MessageFormat::NotificationResponse, - b'K' => MessageFormat::BackendKeyData, - b'N' => MessageFormat::NoticeResponse, - b'R' => MessageFormat::Authentication, - b'S' => MessageFormat::ParameterStatus, - b'T' => MessageFormat::RowDescription, - b'Z' => MessageFormat::ReadyForQuery, 
- b'n' => MessageFormat::NoData, - b's' => MessageFormat::PortalSuspended, - b't' => MessageFormat::ParameterDescription, + b'1' => BackendMessageFormat::ParseComplete, + b'2' => BackendMessageFormat::BindComplete, + b'3' => BackendMessageFormat::CloseComplete, + b'C' => BackendMessageFormat::CommandComplete, + b'd' => BackendMessageFormat::CopyData, + b'c' => BackendMessageFormat::CopyDone, + b'G' => BackendMessageFormat::CopyInResponse, + b'H' => BackendMessageFormat::CopyOutResponse, + b'D' => BackendMessageFormat::DataRow, + b'E' => BackendMessageFormat::ErrorResponse, + b'I' => BackendMessageFormat::EmptyQueryResponse, + b'A' => BackendMessageFormat::NotificationResponse, + b'K' => BackendMessageFormat::BackendKeyData, + b'N' => BackendMessageFormat::NoticeResponse, + b'R' => BackendMessageFormat::Authentication, + b'S' => BackendMessageFormat::ParameterStatus, + b'T' => BackendMessageFormat::RowDescription, + b'Z' => BackendMessageFormat::ReadyForQuery, + b'n' => BackendMessageFormat::NoData, + b's' => BackendMessageFormat::PortalSuspended, + b't' => BackendMessageFormat::ParameterDescription, _ => return Err(err_protocol!("unknown message type: {:?}", v as char)), }) } } + +pub(crate) trait FrontendMessage: Sized { + /// The format prefix of this message. + const FORMAT: FrontendMessageFormat; + + /// Return the amount of space, in bytes, to reserve in the buffer passed to [`Self::encode_body()`]. + fn body_size_hint(&self) -> Saturating; + + /// Encode this type as a Frontend message in the Postgres protocol. + /// + /// The implementation should *not* include `Self::FORMAT` or the length prefix. + fn encode_body(&self, buf: &mut Vec) -> Result<(), Error>; + + #[inline(always)] + #[cfg_attr(not(test), allow(dead_code))] + fn encode_msg(self, buf: &mut Vec) -> Result<(), Error> { + EncodeMessage(self).encode(buf) + } +} + +pub(crate) trait BackendMessage: Sized { + /// The expected message format. 
+ /// + /// + const FORMAT: BackendMessageFormat; + + /// Decode this type from a Backend message in the Postgres protocol. + /// + /// The format code and length prefix have already been read and are not at the start of `bytes`. + fn decode_body(buf: Bytes) -> Result; +} + +pub struct EncodeMessage(pub F); + +impl ProtocolEncode<'_, ()> for EncodeMessage { + fn encode_with(&self, buf: &mut Vec, _context: ()) -> Result<(), Error> { + let mut size_hint = self.0.body_size_hint(); + // plus format code and length prefix + size_hint += 5; + + // don't panic if `size_hint` is ridiculous + buf.try_reserve(size_hint.0).map_err(|e| { + err_protocol!( + "Postgres protocol: error allocating {} bytes for encoding message {:?}: {e}", + size_hint.0, + F::FORMAT, + ) + })?; + + buf.push(F::FORMAT as u8); + + buf.put_length_prefixed(|buf| self.0.encode_body(buf)) + } +} diff --git a/sqlx-postgres/src/message/notification.rs b/sqlx-postgres/src/message/notification.rs index 34303908ac..7bf029839c 100644 --- a/sqlx-postgres/src/message/notification.rs +++ b/sqlx-postgres/src/message/notification.rs @@ -1,7 +1,8 @@ use sqlx_core::bytes::{Buf, Bytes}; use crate::error::Error; -use crate::io::{BufExt, Decode}; +use crate::io::BufExt; +use crate::message::{BackendMessage, BackendMessageFormat}; #[derive(Debug)] pub struct Notification { @@ -10,9 +11,10 @@ pub struct Notification { pub(crate) payload: Bytes, } -impl Decode<'_> for Notification { - #[inline] - fn decode_with(mut buf: Bytes, _: ()) -> Result { +impl BackendMessage for Notification { + const FORMAT: BackendMessageFormat = BackendMessageFormat::NotificationResponse; + + fn decode_body(mut buf: Bytes) -> Result { let process_id = buf.get_u32(); let channel = buf.get_bytes_nul()?; let payload = buf.get_bytes_nul()?; @@ -29,7 +31,7 @@ impl Decode<'_> for Notification { fn test_decode_notification_response() { const NOTIFICATION_RESPONSE: &[u8] = b"\x34\x20\x10\x02TEST-CHANNEL\0THIS IS A TEST\0"; - let message = 
Notification::decode(Bytes::from(NOTIFICATION_RESPONSE)).unwrap(); + let message = Notification::decode_body(Bytes::from(NOTIFICATION_RESPONSE)).unwrap(); assert_eq!(message.process_id, 0x34201002); assert_eq!(&*message.channel, &b"TEST-CHANNEL"[..]); diff --git a/sqlx-postgres/src/message/parameter_description.rs b/sqlx-postgres/src/message/parameter_description.rs index 8d525d05c5..8aa361a8eb 100644 --- a/sqlx-postgres/src/message/parameter_description.rs +++ b/sqlx-postgres/src/message/parameter_description.rs @@ -2,7 +2,7 @@ use smallvec::SmallVec; use sqlx_core::bytes::{Buf, Bytes}; use crate::error::Error; -use crate::io::Decode; +use crate::message::{BackendMessage, BackendMessageFormat}; use crate::types::Oid; #[derive(Debug)] @@ -10,8 +10,10 @@ pub struct ParameterDescription { pub types: SmallVec<[Oid; 6]>, } -impl Decode<'_> for ParameterDescription { - fn decode_with(mut buf: Bytes, _: ()) -> Result { +impl BackendMessage for ParameterDescription { + const FORMAT: BackendMessageFormat = BackendMessageFormat::ParameterDescription; + + fn decode_body(mut buf: Bytes) -> Result { let cnt = buf.get_u16(); let mut types = SmallVec::with_capacity(cnt as usize); @@ -27,7 +29,7 @@ impl Decode<'_> for ParameterDescription { fn test_decode_parameter_description() { const DATA: &[u8] = b"\x00\x02\x00\x00\x00\x00\x00\x00\x05\x00"; - let m = ParameterDescription::decode(DATA.into()).unwrap(); + let m = ParameterDescription::decode_body(DATA.into()).unwrap(); assert_eq!(m.types.len(), 2); assert_eq!(m.types[0], Oid(0x0000_0000)); @@ -38,7 +40,7 @@ fn test_decode_parameter_description() { fn test_decode_empty_parameter_description() { const DATA: &[u8] = b"\x00\x00"; - let m = ParameterDescription::decode(DATA.into()).unwrap(); + let m = ParameterDescription::decode_body(DATA.into()).unwrap(); assert!(m.types.is_empty()); } @@ -49,6 +51,6 @@ fn bench_decode_parameter_description(b: &mut test::Bencher) { const DATA: &[u8] = b"\x00\x02\x00\x00\x00\x00\x00\x00\x05\x00"; 
b.iter(|| { - ParameterDescription::decode(test::black_box(Bytes::from_static(DATA))).unwrap(); + ParameterDescription::decode_body(test::black_box(Bytes::from_static(DATA))).unwrap(); }); } diff --git a/sqlx-postgres/src/message/parameter_status.rs b/sqlx-postgres/src/message/parameter_status.rs index 37abe4e38e..d979d1895e 100644 --- a/sqlx-postgres/src/message/parameter_status.rs +++ b/sqlx-postgres/src/message/parameter_status.rs @@ -1,7 +1,8 @@ use sqlx_core::bytes::Bytes; use crate::error::Error; -use crate::io::{BufExt, Decode}; +use crate::io::BufExt; +use crate::message::{BackendMessage, BackendMessageFormat}; #[derive(Debug)] pub struct ParameterStatus { @@ -9,8 +10,10 @@ pub struct ParameterStatus { pub value: String, } -impl Decode<'_> for ParameterStatus { - fn decode_with(mut buf: Bytes, _: ()) -> Result { +impl BackendMessage for ParameterStatus { + const FORMAT: BackendMessageFormat = BackendMessageFormat::ParameterStatus; + + fn decode_body(mut buf: Bytes) -> Result { let name = buf.get_str_nul()?; let value = buf.get_str_nul()?; @@ -22,7 +25,7 @@ impl Decode<'_> for ParameterStatus { fn test_decode_parameter_status() { const DATA: &[u8] = b"client_encoding\x00UTF8\x00"; - let m = ParameterStatus::decode(DATA.into()).unwrap(); + let m = ParameterStatus::decode_body(DATA.into()).unwrap(); assert_eq!(&m.name, "client_encoding"); assert_eq!(&m.value, "UTF8") @@ -32,7 +35,7 @@ fn test_decode_parameter_status() { fn test_decode_empty_parameter_status() { const DATA: &[u8] = b"\x00\x00"; - let m = ParameterStatus::decode(DATA.into()).unwrap(); + let m = ParameterStatus::decode_body(DATA.into()).unwrap(); assert!(m.name.is_empty()); assert!(m.value.is_empty()); @@ -44,7 +47,7 @@ fn bench_decode_parameter_status(b: &mut test::Bencher) { const DATA: &[u8] = b"client_encoding\x00UTF8\x00"; b.iter(|| { - ParameterStatus::decode(test::black_box(Bytes::from_static(DATA))).unwrap(); + 
ParameterStatus::decode_body(test::black_box(Bytes::from_static(DATA))).unwrap(); }); } @@ -52,7 +55,7 @@ fn bench_decode_parameter_status(b: &mut test::Bencher) { fn test_decode_parameter_status_response() { const PARAMETER_STATUS_RESPONSE: &[u8] = b"crdb_version\0CockroachDB CCL v21.1.0 (x86_64-unknown-linux-gnu, built 2021/05/17 13:49:40, go1.15.11)\0"; - let message = ParameterStatus::decode(Bytes::from(PARAMETER_STATUS_RESPONSE)).unwrap(); + let message = ParameterStatus::decode_body(Bytes::from(PARAMETER_STATUS_RESPONSE)).unwrap(); assert_eq!(message.name, "crdb_version"); assert_eq!( diff --git a/sqlx-postgres/src/message/parse.rs b/sqlx-postgres/src/message/parse.rs index 6bcbdb6bb0..3e77c3024c 100644 --- a/sqlx-postgres/src/message/parse.rs +++ b/sqlx-postgres/src/message/parse.rs @@ -1,11 +1,14 @@ -use crate::io::PgBufMutExt; -use crate::io::{BufMutExt, Encode}; +use crate::io::BufMutExt; +use crate::io::{PgBufMutExt, StatementId}; +use crate::message::{FrontendMessage, FrontendMessageFormat}; use crate::types::Oid; +use sqlx_core::Error; +use std::num::Saturating; #[derive(Debug)] pub struct Parse<'a> { /// The ID of the destination prepared statement. - pub statement: Oid, + pub statement: StatementId, /// The query string to be parsed. 
pub query: &'a str, @@ -16,39 +19,59 @@ pub struct Parse<'a> { pub param_types: &'a [Oid], } -impl Encode<'_> for Parse<'_> { - fn encode_with(&self, buf: &mut Vec, _: ()) { - buf.push(b'P'); +impl FrontendMessage for Parse<'_> { + const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Parse; - buf.put_length_prefixed(|buf| { - buf.put_statement_name(self.statement); + fn body_size_hint(&self) -> Saturating { + let mut size = Saturating(0); - buf.put_str_nul(self.query); + size += self.statement.name_len(); - // TODO: Return an error here instead - assert!(self.param_types.len() <= (u16::MAX as usize)); + size += self.query.len(); + size += 1; // NUL terminator - buf.extend(&(self.param_types.len() as i16).to_be_bytes()); + size += 2; // param_types_len - for &oid in self.param_types { - buf.extend(&oid.0.to_be_bytes()); - } - }) + // `param_types` + size += self.param_types.len().saturating_mul(4); + + size + } + + fn encode_body(&self, buf: &mut Vec) -> Result<(), Error> { + buf.put_statement_name(self.statement); + + buf.put_str_nul(self.query); + + let param_types_len = i16::try_from(self.param_types.len()).map_err(|_| { + err_protocol!( + "param_types.len() too large for binary protocol: {}", + self.param_types.len() + ) + })?; + + buf.extend(param_types_len.to_be_bytes()); + + for &oid in self.param_types { + buf.extend(oid.0.to_be_bytes()); + } + + Ok(()) } } #[test] fn test_encode_parse() { - const EXPECTED: &[u8] = b"P\0\0\0\x1dsqlx_s_1\0SELECT $1\0\0\x01\0\0\0\x19"; + const EXPECTED: &[u8] = b"P\0\0\0\x26sqlx_s_1234567890\0SELECT $1\0\0\x01\0\0\0\x19"; let mut buf = Vec::new(); let m = Parse { - statement: Oid(1), + statement: StatementId::TEST_VAL, query: "SELECT $1", param_types: &[Oid(25)], }; - m.encode(&mut buf); + m.encode_msg(&mut buf).unwrap(); assert_eq!(buf, EXPECTED); } diff --git a/sqlx-postgres/src/message/parse_complete.rs b/sqlx-postgres/src/message/parse_complete.rs new file mode 100644 index 0000000000..3051f5ff97 --- /dev/null +++ 
b/sqlx-postgres/src/message/parse_complete.rs @@ -0,0 +1,13 @@ +use crate::message::{BackendMessage, BackendMessageFormat}; +use sqlx_core::bytes::Bytes; +use sqlx_core::Error; + +pub struct ParseComplete; + +impl BackendMessage for ParseComplete { + const FORMAT: BackendMessageFormat = BackendMessageFormat::ParseComplete; + + fn decode_body(_bytes: Bytes) -> Result { + Ok(ParseComplete) + } +} diff --git a/sqlx-postgres/src/message/password.rs b/sqlx-postgres/src/message/password.rs index ba8b5ac68e..4eaaeb15af 100644 --- a/sqlx-postgres/src/message/password.rs +++ b/sqlx-postgres/src/message/password.rs @@ -1,9 +1,9 @@ -use std::fmt::Write; - +use crate::io::BufMutExt; +use crate::message::{FrontendMessage, FrontendMessageFormat}; use md5::{Digest, Md5}; - -use crate::io::PgBufMutExt; -use crate::io::{BufMutExt, Encode}; +use sqlx_core::Error; +use std::fmt::Write; +use std::num::Saturating; #[derive(Debug)] pub enum Password<'a> { @@ -16,117 +16,138 @@ pub enum Password<'a> { }, } -impl Password<'_> { - #[inline] - fn len(&self) -> usize { +impl FrontendMessage for Password<'_> { + const FORMAT: FrontendMessageFormat = FrontendMessageFormat::PasswordPolymorphic; + + #[inline(always)] + fn body_size_hint(&self) -> Saturating { + let mut size = Saturating(0); + match self { - Password::Cleartext(s) => s.len() + 5, - Password::Md5 { .. } => 35 + 5, + Password::Cleartext(password) => { + // To avoid reporting the exact password length anywhere, + // we deliberately give a bad estimate. + // + // This shouldn't affect performance in the long run. + size += password + .len() + .saturating_add(1) // NUL terminator + .checked_next_power_of_two() + .unwrap_or(usize::MAX); + } + Password::Md5 { .. 
} => { + // "md5<32 hex chars>\0" + size += 36; + } } - } -} -impl Encode<'_> for Password<'_> { - fn encode_with(&self, buf: &mut Vec, _: ()) { - buf.reserve(1 + 4 + self.len()); - buf.push(b'p'); + size + } - buf.put_length_prefixed(|buf| { - match self { - Password::Cleartext(password) => { - buf.put_str_nul(password); - } + fn encode_body(&self, buf: &mut Vec) -> Result<(), Error> { + match self { + Password::Cleartext(password) => { + buf.put_str_nul(password); + } - Password::Md5 { - username, - password, - salt, - } => { - // The actual `PasswordMessage` can be computed in SQL as - // `concat('md5', md5(concat(md5(concat(password, username)), random-salt)))`. + Password::Md5 { + username, + password, + salt, + } => { + // The actual `PasswordMessage` can be computed in SQL as + // `concat('md5', md5(concat(md5(concat(password, username)), random-salt)))`. - // Keep in mind the md5() function returns its result as a hex string. + // Keep in mind the md5() function returns its result as a hex string. 
- let mut hasher = Md5::new(); + let mut hasher = Md5::new(); - hasher.update(password); - hasher.update(username); + hasher.update(password); + hasher.update(username); - let mut output = String::with_capacity(35); + let mut output = String::with_capacity(35); - let _ = write!(output, "{:x}", hasher.finalize_reset()); + let _ = write!(output, "{:x}", hasher.finalize_reset()); - hasher.update(&output); - hasher.update(salt); + hasher.update(&output); + hasher.update(salt); - output.clear(); + output.clear(); - let _ = write!(output, "md5{:x}", hasher.finalize()); + let _ = write!(output, "md5{:x}", hasher.finalize()); - buf.put_str_nul(&output); - } + buf.put_str_nul(&output); } - }); + } + + Ok(()) } } -#[test] -fn test_encode_clear_password() { - const EXPECTED: &[u8] = b"p\0\0\0\rpassword\0"; +#[cfg(test)] +mod tests { + use crate::message::FrontendMessage; - let mut buf = Vec::new(); - let m = Password::Cleartext("password"); + use super::Password; - m.encode(&mut buf); + #[test] + fn test_encode_clear_password() { + const EXPECTED: &[u8] = b"p\0\0\0\rpassword\0"; - assert_eq!(buf, EXPECTED); -} + let mut buf = Vec::new(); + let m = Password::Cleartext("password"); -#[test] -fn test_encode_md5_password() { - const EXPECTED: &[u8] = b"p\0\0\0(md53e2c9d99d49b201ef867a36f3f9ed62c\0"; + m.encode_msg(&mut buf).unwrap(); - let mut buf = Vec::new(); - let m = Password::Md5 { - password: "password", - username: "root", - salt: [147, 24, 57, 152], - }; + assert_eq!(buf, EXPECTED); + } - m.encode(&mut buf); + #[test] + fn test_encode_md5_password() { + const EXPECTED: &[u8] = b"p\0\0\0(md53e2c9d99d49b201ef867a36f3f9ed62c\0"; - assert_eq!(buf, EXPECTED); -} + let mut buf = Vec::new(); + let m = Password::Md5 { + password: "password", + username: "root", + salt: [147, 24, 57, 152], + }; + + m.encode_msg(&mut buf).unwrap(); -#[cfg(all(test, not(debug_assertions)))] -#[bench] -fn bench_encode_clear_password(b: &mut test::Bencher) { - use test::black_box; + assert_eq!(buf, 
EXPECTED); + } - let mut buf = Vec::with_capacity(128); + #[cfg(all(test, not(debug_assertions)))] + #[bench] + fn bench_encode_clear_password(b: &mut test::Bencher) { + use test::black_box; - b.iter(|| { - buf.clear(); + let mut buf = Vec::with_capacity(128); - black_box(Password::Cleartext("password")).encode(&mut buf); - }); -} + b.iter(|| { + buf.clear(); -#[cfg(all(test, not(debug_assertions)))] -#[bench] -fn bench_encode_md5_password(b: &mut test::Bencher) { - use test::black_box; + black_box(Password::Cleartext("password")).encode_msg(&mut buf); + }); + } - let mut buf = Vec::with_capacity(128); + #[cfg(all(test, not(debug_assertions)))] + #[bench] + fn bench_encode_md5_password(b: &mut test::Bencher) { + use test::black_box; - b.iter(|| { - buf.clear(); + let mut buf = Vec::with_capacity(128); - black_box(Password::Md5 { - password: "password", - username: "root", - salt: [147, 24, 57, 152], - }) - .encode(&mut buf); - }); + b.iter(|| { + buf.clear(); + + black_box(Password::Md5 { + password: "password", + username: "root", + salt: [147, 24, 57, 152], + }) + .encode_msg(&mut buf); + }); + } } diff --git a/sqlx-postgres/src/message/query.rs b/sqlx-postgres/src/message/query.rs index 8f49aabc30..788d7808fc 100644 --- a/sqlx-postgres/src/message/query.rs +++ b/sqlx-postgres/src/message/query.rs @@ -1,27 +1,37 @@ -use crate::io::{BufMutExt, Encode}; +use crate::io::BufMutExt; +use crate::message::{FrontendMessage, FrontendMessageFormat}; +use sqlx_core::Error; +use std::num::Saturating; #[derive(Debug)] pub struct Query<'a>(pub &'a str); -impl Encode<'_> for Query<'_> { - fn encode_with(&self, buf: &mut Vec, _: ()) { - let len = 4 + self.0.len() + 1; +impl FrontendMessage for Query<'_> { + const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Query; - buf.reserve(len + 1); - buf.push(b'Q'); - buf.extend(&(len as i32).to_be_bytes()); + fn body_size_hint(&self) -> Saturating { + let mut size = Saturating(0); + + size += self.0.len(); + size += 1; // NUL 
terminator + + size + } + + fn encode_body(&self, buf: &mut Vec) -> Result<(), Error> { buf.put_str_nul(self.0); + Ok(()) } } #[test] fn test_encode_query() { - const EXPECTED: &[u8] = b"Q\0\0\0\rSELECT 1\0"; + const EXPECTED: &[u8] = b"Q\0\0\0\x0DSELECT 1\0"; let mut buf = Vec::new(); let m = Query("SELECT 1"); - m.encode(&mut buf); + m.encode_msg(&mut buf).unwrap(); assert_eq!(buf, EXPECTED); } diff --git a/sqlx-postgres/src/message/ready_for_query.rs b/sqlx-postgres/src/message/ready_for_query.rs index 21e6540d01..a1f6761b89 100644 --- a/sqlx-postgres/src/message/ready_for_query.rs +++ b/sqlx-postgres/src/message/ready_for_query.rs @@ -1,7 +1,7 @@ use sqlx_core::bytes::Bytes; use crate::error::Error; -use crate::io::Decode; +use crate::message::{BackendMessage, BackendMessageFormat}; #[derive(Debug)] #[repr(u8)] @@ -21,8 +21,10 @@ pub struct ReadyForQuery { pub transaction_status: TransactionStatus, } -impl Decode<'_> for ReadyForQuery { - fn decode_with(buf: Bytes, _: ()) -> Result { +impl BackendMessage for ReadyForQuery { + const FORMAT: BackendMessageFormat = BackendMessageFormat::ReadyForQuery; + + fn decode_body(buf: Bytes) -> Result { let status = match buf[0] { b'I' => TransactionStatus::Idle, b'T' => TransactionStatus::Transaction, @@ -46,7 +48,7 @@ impl Decode<'_> for ReadyForQuery { fn test_decode_ready_for_query() -> Result<(), Error> { const DATA: &[u8] = b"E"; - let m = ReadyForQuery::decode(Bytes::from_static(DATA))?; + let m = ReadyForQuery::decode_body(Bytes::from_static(DATA))?; assert!(matches!(m.transaction_status, TransactionStatus::Error)); diff --git a/sqlx-postgres/src/message/response.rs b/sqlx-postgres/src/message/response.rs index ec3c880886..d6e43e0871 100644 --- a/sqlx-postgres/src/message/response.rs +++ b/sqlx-postgres/src/message/response.rs @@ -1,10 +1,13 @@ +use std::ops::Range; use std::str::from_utf8; use memchr::memchr; + use sqlx_core::bytes::Bytes; use crate::error::Error; -use crate::io::Decode; +use 
crate::io::ProtocolDecode; +use crate::message::{BackendMessage, BackendMessageFormat}; #[derive(Debug, Copy, Clone, Eq, PartialEq)] #[repr(u8)] @@ -53,8 +56,8 @@ impl TryFrom<&str> for PgSeverity { pub struct Notice { storage: Bytes, severity: PgSeverity, - message: (u16, u16), - code: (u16, u16), + message: Range, + code: Range, } impl Notice { @@ -65,12 +68,12 @@ impl Notice { #[inline] pub fn code(&self) -> &str { - self.get_cached_str(self.code) + self.get_cached_str(self.code.clone()) } #[inline] pub fn message(&self) -> &str { - self.get_cached_str(self.message) + self.get_cached_str(self.message.clone()) } // Field descriptions available here: @@ -84,7 +87,7 @@ impl Notice { pub fn get_raw(&self, ty: u8) -> Option<&[u8]> { self.fields() .filter(|(field, _)| *field == ty) - .map(|(_, (start, end))| &self.storage[start as usize..end as usize]) + .map(|(_, range)| &self.storage[range]) .next() } } @@ -99,13 +102,13 @@ impl Notice { } #[inline] - fn get_cached_str(&self, cache: (u16, u16)) -> &str { + fn get_cached_str(&self, cache: Range) -> &str { // unwrap: this cannot fail at this stage - from_utf8(&self.storage[cache.0 as usize..cache.1 as usize]).unwrap() + from_utf8(&self.storage[cache]).unwrap() } } -impl Decode<'_> for Notice { +impl ProtocolDecode<'_> for Notice { fn decode_with(buf: Bytes, _: ()) -> Result { // In order to support PostgreSQL 9.5 and older we need to parse the localized S field. // Newer versions additionally come with the V field that is guaranteed to be in English. 
@@ -113,8 +116,8 @@ impl Decode<'_> for Notice { const DEFAULT_SEVERITY: PgSeverity = PgSeverity::Log; let mut severity_v = None; let mut severity_s = None; - let mut message = (0, 0); - let mut code = (0, 0); + let mut message = 0..0; + let mut code = 0..0; // we cache the three always present fields // this enables to keep the access time down for the fields most likely accessed @@ -125,7 +128,7 @@ impl Decode<'_> for Notice { }; for (field, v) in fields { - if message.0 != 0 && code.0 != 0 { + if !(message.is_empty() || code.is_empty()) { // stop iterating when we have the 3 fields we were looking for // we assume V (severity) was the first field as it should be break; @@ -133,7 +136,7 @@ impl Decode<'_> for Notice { match field { b'S' => { - severity_s = from_utf8(&buf[v.0 as usize..v.1 as usize]) + severity_s = from_utf8(&buf[v.clone()]) // If the error string is not UTF-8, we have no hope of interpreting it, // localized or not. The `V` field would likely fail to parse as well. .map_err(|_| notice_protocol_err())? @@ -146,21 +149,19 @@ impl Decode<'_> for Notice { // Propagate errors here, because V is not localized and // thus we are missing a possible variant. severity_v = Some( - from_utf8(&buf[v.0 as usize..v.1 as usize]) + from_utf8(&buf[v.clone()]) .map_err(|_| notice_protocol_err())? 
.try_into()?, ); } b'M' => { - _ = from_utf8(&buf[v.0 as usize..v.1 as usize]) - .map_err(|_| notice_protocol_err())?; + _ = from_utf8(&buf[v.clone()]).map_err(|_| notice_protocol_err())?; message = v; } b'C' => { - _ = from_utf8(&buf[v.0 as usize..v.1 as usize]) - .map_err(|_| notice_protocol_err())?; + _ = from_utf8(&buf[v.clone()]).map_err(|_| notice_protocol_err())?; code = v; } @@ -179,31 +180,46 @@ impl Decode<'_> for Notice { } } +impl BackendMessage for Notice { + const FORMAT: BackendMessageFormat = BackendMessageFormat::NoticeResponse; + + fn decode_body(buf: Bytes) -> Result { + // Keeping both impls for now + Self::decode_with(buf, ()) + } +} + /// An iterator over each field in the Error (or Notice) response. struct Fields<'a> { storage: &'a [u8], - offset: u16, + offset: usize, } impl<'a> Iterator for Fields<'a> { - type Item = (u8, (u16, u16)); + type Item = (u8, Range); fn next(&mut self) -> Option { // The fields in the response body are sequentially stored as [tag][string], // ending in a final, additional [nul] - let ty = self.storage[self.offset as usize]; + let ty = *self.storage.get(self.offset)?; if ty == 0 { return None; } - let nul = memchr(b'\0', &self.storage[(self.offset + 1) as usize..])? as u16; - let offset = self.offset; + // Consume the type byte + self.offset = self.offset.checked_add(1)?; + + let start = self.offset; + + let len = memchr(b'\0', self.storage.get(start..)?)?; - self.offset += nul + 2; + // Neither can overflow as they will always be `<= self.storage.len()`. 
+ let end = self.offset + len; + self.offset = end + 1; - Some((ty, (offset + 1, offset + nul + 1))) + Some((ty, start..end)) } } diff --git a/sqlx-postgres/src/message/row_description.rs b/sqlx-postgres/src/message/row_description.rs index 32121386ae..3f3155ed5b 100644 --- a/sqlx-postgres/src/message/row_description.rs +++ b/sqlx-postgres/src/message/row_description.rs @@ -1,7 +1,8 @@ use sqlx_core::bytes::{Buf, Bytes}; use crate::error::Error; -use crate::io::{BufExt, Decode}; +use crate::io::BufExt; +use crate::message::{BackendMessage, BackendMessageFormat}; use crate::types::Oid; #[derive(Debug)] @@ -40,13 +41,30 @@ pub struct Field { pub format: i16, } -impl Decode<'_> for RowDescription { - fn decode_with(mut buf: Bytes, _: ()) -> Result { +impl BackendMessage for RowDescription { + const FORMAT: BackendMessageFormat = BackendMessageFormat::RowDescription; + + fn decode_body(mut buf: Bytes) -> Result { + if buf.len() < 2 { + return Err(err_protocol!( + "expected at least 2 bytes, got {}", + buf.len() + )); + } + let cnt = buf.get_u16(); let mut fields = Vec::with_capacity(cnt as usize); for _ in 0..cnt { let name = buf.get_str_nul()?.to_owned(); + + if buf.len() < 18 { + return Err(err_protocol!( + "expected at least 18 bytes after field name {name:?}, got {}", + buf.len() + )); + } + let relation_id = buf.get_i32(); let relation_attribute_no = buf.get_i16(); let data_type_id = Oid(buf.get_u32()); diff --git a/sqlx-postgres/src/message/sasl.rs b/sqlx-postgres/src/message/sasl.rs index 77d0bf8dfe..9d393189bf 100644 --- a/sqlx-postgres/src/message/sasl.rs +++ b/sqlx-postgres/src/message/sasl.rs @@ -1,35 +1,69 @@ -use crate::io::PgBufMutExt; -use crate::io::{BufMutExt, Encode}; +use crate::io::BufMutExt; +use crate::message::{FrontendMessage, FrontendMessageFormat}; +use sqlx_core::Error; +use std::num::Saturating; pub struct SaslInitialResponse<'a> { pub response: &'a str, pub plus: bool, } -impl Encode<'_> for SaslInitialResponse<'_> { - fn encode_with(&self, 
buf: &mut Vec, _: ()) { - buf.push(b'p'); - buf.put_length_prefixed(|buf| { - // name of the SASL authentication mechanism that the client selected - buf.put_str_nul(if self.plus { - "SCRAM-SHA-256-PLUS" - } else { - "SCRAM-SHA-256" - }); - - buf.extend(&(self.response.as_bytes().len() as i32).to_be_bytes()); - buf.extend(self.response.as_bytes()); - }); +impl SaslInitialResponse<'_> { + #[inline(always)] + fn selected_mechanism(&self) -> &'static str { + if self.plus { + "SCRAM-SHA-256-PLUS" + } else { + "SCRAM-SHA-256" + } + } +} + +impl FrontendMessage for SaslInitialResponse<'_> { + const FORMAT: FrontendMessageFormat = FrontendMessageFormat::PasswordPolymorphic; + + #[inline(always)] + fn body_size_hint(&self) -> Saturating { + let mut size = Saturating(0); + + size += self.selected_mechanism().len(); + size += 1; // NUL terminator + + size += 4; // response_len + size += self.response.len(); + + size + } + + fn encode_body(&self, buf: &mut Vec) -> Result<(), Error> { + // name of the SASL authentication mechanism that the client selected + buf.put_str_nul(self.selected_mechanism()); + + let response_len = i32::try_from(self.response.len()).map_err(|_| { + err_protocol!( + "SASL Initial Response length too long for protcol: {}", + self.response.len() + ) + })?; + + buf.extend_from_slice(&response_len.to_be_bytes()); + buf.extend_from_slice(self.response.as_bytes()); + + Ok(()) } } pub struct SaslResponse<'a>(pub &'a str); -impl Encode<'_> for SaslResponse<'_> { - fn encode_with(&self, buf: &mut Vec, _: ()) { - buf.push(b'p'); - buf.put_length_prefixed(|buf| { - buf.extend(self.0.as_bytes()); - }); +impl FrontendMessage for SaslResponse<'_> { + const FORMAT: FrontendMessageFormat = FrontendMessageFormat::PasswordPolymorphic; + + fn body_size_hint(&self) -> Saturating { + Saturating(self.0.len()) + } + + fn encode_body(&self, buf: &mut Vec) -> Result<(), Error> { + buf.extend(self.0.as_bytes()); + Ok(()) } } diff --git a/sqlx-postgres/src/message/ssl_request.rs 
b/sqlx-postgres/src/message/ssl_request.rs index fa57faf064..09c886221a 100644 --- a/sqlx-postgres/src/message/ssl_request.rs +++ b/sqlx-postgres/src/message/ssl_request.rs @@ -1,23 +1,38 @@ -use crate::io::Encode; +use crate::io::ProtocolEncode; pub struct SslRequest; impl SslRequest { - pub const BYTES: &'static [u8] = b"\x00\x00\x00\x08\x04\xd2\x16/"; + // https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-SSLREQUEST + pub const BYTES: &'static [u8] = b"\x00\x00\x00\x08\x04\xd2\x16\x2f"; } -impl Encode<'_> for SslRequest { - #[inline] - fn encode_with(&self, buf: &mut Vec, _: ()) { - buf.extend(&8_u32.to_be_bytes()); - buf.extend(&(((1234 << 16) | 5679) as u32).to_be_bytes()); +// Cannot impl FrontendMessage because it does not have a format code +impl ProtocolEncode<'_> for SslRequest { + #[inline(always)] + fn encode_with(&self, buf: &mut Vec, _context: ()) -> Result<(), crate::Error> { + buf.extend_from_slice(Self::BYTES); + Ok(()) } } #[test] fn test_encode_ssl_request() { let mut buf = Vec::new(); - SslRequest.encode(&mut buf); + + // Int32(8) + // Length of message contents in bytes, including self. + buf.extend_from_slice(&8_u32.to_be_bytes()); + + // Int32(80877103) + // The SSL request code. The value is chosen to contain 1234 in the most significant 16 bits, + // and 5679 in the least significant 16 bits. + // (To avoid confusion, this code must not be the same as any protocol version number.) 
+ buf.extend_from_slice(&(((1234 << 16) | 5679) as u32).to_be_bytes()); + + let mut encoded = Vec::new(); + SslRequest.encode(&mut encoded).unwrap(); assert_eq!(buf, SslRequest::BYTES); + assert_eq!(buf, encoded); } diff --git a/sqlx-postgres/src/message/startup.rs b/sqlx-postgres/src/message/startup.rs index 838695843f..1c6d735ab7 100644 --- a/sqlx-postgres/src/message/startup.rs +++ b/sqlx-postgres/src/message/startup.rs @@ -1,5 +1,5 @@ use crate::io::PgBufMutExt; -use crate::io::{BufMutExt, Encode}; +use crate::io::{BufMutExt, ProtocolEncode}; // To begin a session, a frontend opens a connection to the server and sends a startup message. // This message includes the names of the user and of the database the user wants to connect to; @@ -19,8 +19,9 @@ pub struct Startup<'a> { pub params: &'a [(&'a str, &'a str)], } -impl Encode<'_> for Startup<'_> { - fn encode_with(&self, buf: &mut Vec, _: ()) { +// Startup cannot impl FrontendMessage because it doesn't have a format code. +impl ProtocolEncode<'_> for Startup<'_> { + fn encode_with(&self, buf: &mut Vec, _context: ()) -> Result<(), crate::Error> { buf.reserve(120); buf.put_length_prefixed(|buf| { @@ -47,7 +48,9 @@ impl Encode<'_> for Startup<'_> { // A zero byte is required as a terminator // after the last name/value pair. 
buf.push(0); - }); + + Ok(()) + }) } } @@ -68,7 +71,7 @@ fn test_encode_startup() { params: &[], }; - m.encode(&mut buf); + m.encode(&mut buf).unwrap(); assert_eq!(buf, EXPECTED); } diff --git a/sqlx-postgres/src/message/sync.rs b/sqlx-postgres/src/message/sync.rs index bc30114ef3..56f4498746 100644 --- a/sqlx-postgres/src/message/sync.rs +++ b/sqlx-postgres/src/message/sync.rs @@ -1,11 +1,20 @@ -use crate::io::Encode; +use crate::message::{FrontendMessage, FrontendMessageFormat}; +use sqlx_core::Error; +use std::num::Saturating; #[derive(Debug)] pub struct Sync; -impl Encode<'_> for Sync { - fn encode_with(&self, buf: &mut Vec, _: ()) { - buf.push(b'S'); - buf.extend(&4_i32.to_be_bytes()); +impl FrontendMessage for Sync { + const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Sync; + + #[inline(always)] + fn body_size_hint(&self) -> Saturating { + Saturating(0) + } + + #[inline(always)] + fn encode_body(&self, _buf: &mut Vec) -> Result<(), Error> { + Ok(()) } } diff --git a/sqlx-postgres/src/message/terminate.rs b/sqlx-postgres/src/message/terminate.rs index 98e41fdbaa..39f8ff6e6c 100644 --- a/sqlx-postgres/src/message/terminate.rs +++ b/sqlx-postgres/src/message/terminate.rs @@ -1,10 +1,19 @@ -use crate::io::Encode; +use crate::message::{FrontendMessage, FrontendMessageFormat}; +use sqlx_core::Error; +use std::num::Saturating; pub struct Terminate; -impl Encode<'_> for Terminate { - fn encode_with(&self, buf: &mut Vec, _: ()) { - buf.push(b'X'); - buf.extend(&4_u32.to_be_bytes()); +impl FrontendMessage for Terminate { + const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Terminate; + + #[inline(always)] + fn body_size_hint(&self) -> Saturating { + Saturating(0) + } + + #[inline(always)] + fn encode_body(&self, _buf: &mut Vec) -> Result<(), Error> { + Ok(()) } } diff --git a/sqlx-postgres/src/transaction.rs b/sqlx-postgres/src/transaction.rs index 02028624e1..b9330d5292 100644 --- a/sqlx-postgres/src/transaction.rs +++ 
b/sqlx-postgres/src/transaction.rs @@ -17,7 +17,7 @@ impl TransactionManager for PgTransactionManager { Box::pin(async move { let rollback = Rollback::new(conn); let query = begin_ansi_transaction_sql(rollback.conn.transaction_depth); - rollback.conn.queue_simple_query(&query); + rollback.conn.queue_simple_query(&query)?; rollback.conn.transaction_depth += 1; rollback.conn.wait_until_ready().await?; rollback.defuse(); @@ -54,7 +54,8 @@ impl TransactionManager for PgTransactionManager { fn start_rollback(conn: &mut PgConnection) { if conn.transaction_depth > 0 { - conn.queue_simple_query(&rollback_ansi_transaction_sql(conn.transaction_depth)); + conn.queue_simple_query(&rollback_ansi_transaction_sql(conn.transaction_depth)) + .expect("BUG: Rollback query somehow too large for protocol"); conn.transaction_depth -= 1; } diff --git a/sqlx-postgres/src/types/oid.rs b/sqlx-postgres/src/types/oid.rs index caa90dfcc7..04c5ef837a 100644 --- a/sqlx-postgres/src/types/oid.rs +++ b/sqlx-postgres/src/types/oid.rs @@ -17,12 +17,6 @@ pub struct Oid( pub u32, ); -impl Oid { - pub(crate) fn incr_one(&mut self) { - self.0 = self.0.wrapping_add(1); - } -} - impl Type for Oid { fn type_info() -> PgTypeInfo { PgTypeInfo::OID From b783dbdc2fb1dbf295c0404830809a7b27a189eb Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Sat, 17 Aug 2024 04:54:59 -0700 Subject: [PATCH 19/40] fix(mysql): fallout from ec5326e5 --- sqlx-mysql/Cargo.toml | 7 ++++++ sqlx-mysql/src/connection/stream.rs | 8 +++---- .../src/protocol/connect/auth_switch.rs | 11 +++++---- sqlx-mysql/src/protocol/connect/handshake.rs | 4 ++-- .../protocol/connect/handshake_response.rs | 24 ++++++++++++------- .../src/protocol/connect/ssl_request.rs | 14 ++++++----- sqlx-mysql/src/protocol/packet.rs | 16 +++++++------ sqlx-mysql/src/protocol/response/eof.rs | 4 ++-- sqlx-mysql/src/protocol/response/err.rs | 4 ++-- sqlx-mysql/src/protocol/response/ok.rs | 4 ++-- sqlx-mysql/src/protocol/statement/execute.rs | 8 ++++--- 
sqlx-mysql/src/protocol/statement/prepare.rs | 7 +++--- .../src/protocol/statement/prepare_ok.rs | 4 ++-- sqlx-mysql/src/protocol/statement/row.rs | 4 ++-- .../src/protocol/statement/stmt_close.rs | 7 +++--- sqlx-mysql/src/protocol/text/column.rs | 4 ++-- sqlx-mysql/src/protocol/text/ping.rs | 7 +++--- sqlx-mysql/src/protocol/text/query.rs | 9 +++---- sqlx-mysql/src/protocol/text/quit.rs | 7 +++--- sqlx-mysql/src/protocol/text/row.rs | 4 ++-- 20 files changed, 91 insertions(+), 66 deletions(-) diff --git a/sqlx-mysql/Cargo.toml b/sqlx-mysql/Cargo.toml index cc0100a913..174b529141 100644 --- a/sqlx-mysql/Cargo.toml +++ b/sqlx-mysql/Cargo.toml @@ -15,6 +15,13 @@ any = ["sqlx-core/any"] offline = ["sqlx-core/offline", "serde/derive"] migrate = ["sqlx-core/migrate"] +# Type Integration features +bigdecimal = ["dep:bigdecimal", "sqlx-core/bigdecimal"] +chrono = ["dep:chrono", "sqlx-core/chrono"] +rust_decimal = ["dep:rust_decimal", "rust_decimal/maths", "sqlx-core/rust_decimal"] +time = ["dep:time", "sqlx-core/time"] +uuid = ["dep:uuid", "sqlx-core/uuid"] + [dependencies] sqlx-core = { workspace = true } diff --git a/sqlx-mysql/src/connection/stream.rs b/sqlx-mysql/src/connection/stream.rs index c225256da2..465acc17cb 100644 --- a/sqlx-mysql/src/connection/stream.rs +++ b/sqlx-mysql/src/connection/stream.rs @@ -6,7 +6,7 @@ use bytes::{Buf, Bytes, BytesMut}; use crate::collation::{CharSet, Collation}; use crate::error::Error; use crate::io::MySqlBufExt; -use crate::io::{Decode, Encode}; +use crate::io::{ProtocolDecode, ProtocolEncode}; use crate::net::{BufferedSocket, Socket}; use crate::protocol::response::{EofPacket, ErrPacket, OkPacket, Status}; use crate::protocol::{Capabilities, Packet}; @@ -110,7 +110,7 @@ impl MySqlStream { pub(crate) async fn send_packet<'en, T>(&mut self, payload: T) -> Result<(), Error> where - T: Encode<'en, Capabilities>, + T: ProtocolEncode<'en, Capabilities>, { self.sequence_id = 0; self.write_packet(payload); @@ -120,7 +120,7 @@ impl 
MySqlStream { pub(crate) fn write_packet<'en, T>(&mut self, payload: T) where - T: Encode<'en, Capabilities>, + T: ProtocolEncode<'en, Capabilities>, { self.socket .write_with(Packet(payload), (self.capabilities, &mut self.sequence_id)); @@ -184,7 +184,7 @@ impl MySqlStream { pub(crate) async fn recv<'de, T>(&mut self) -> Result where - T: Decode<'de, Capabilities>, + T: ProtocolDecode<'de, Capabilities>, { self.recv_packet().await?.decode_with(self.capabilities) } diff --git a/sqlx-mysql/src/protocol/connect/auth_switch.rs b/sqlx-mysql/src/protocol/connect/auth_switch.rs index 58b7fbb2ef..7ecb1b2b17 100644 --- a/sqlx-mysql/src/protocol/connect/auth_switch.rs +++ b/sqlx-mysql/src/protocol/connect/auth_switch.rs @@ -1,8 +1,8 @@ use bytes::{Buf, Bytes}; use crate::error::Error; -use crate::io::Encode; -use crate::io::{BufExt, Decode}; +use crate::io::ProtocolEncode; +use crate::io::{BufExt, ProtocolDecode}; use crate::protocol::auth::AuthPlugin; use crate::protocol::Capabilities; @@ -14,7 +14,7 @@ pub struct AuthSwitchRequest { pub data: Bytes, } -impl Decode<'_, bool> for AuthSwitchRequest { +impl ProtocolDecode<'_, bool> for AuthSwitchRequest { fn decode_with(mut buf: Bytes, enable_cleartext_plugin: bool) -> Result { let header = buf.get_u8(); if header != 0xfe { @@ -58,9 +58,10 @@ impl Decode<'_, bool> for AuthSwitchRequest { #[derive(Debug)] pub struct AuthSwitchResponse(pub Vec); -impl Encode<'_, Capabilities> for AuthSwitchResponse { - fn encode_with(&self, buf: &mut Vec, _: Capabilities) { +impl ProtocolEncode<'_, Capabilities> for AuthSwitchResponse { + fn encode_with(&self, buf: &mut Vec, _: Capabilities) -> Result<(), Error> { buf.extend_from_slice(&self.0); + Ok(()) } } diff --git a/sqlx-mysql/src/protocol/connect/handshake.rs b/sqlx-mysql/src/protocol/connect/handshake.rs index 166fbbf06d..faf9634bc8 100644 --- a/sqlx-mysql/src/protocol/connect/handshake.rs +++ b/sqlx-mysql/src/protocol/connect/handshake.rs @@ -3,7 +3,7 @@ use bytes::{Buf, Bytes}; use 
std::cmp; use crate::error::Error; -use crate::io::{BufExt, Decode}; +use crate::io::{BufExt, ProtocolDecode}; use crate::protocol::auth::AuthPlugin; use crate::protocol::response::Status; use crate::protocol::Capabilities; @@ -27,7 +27,7 @@ pub(crate) struct Handshake { pub(crate) auth_plugin_data: Chain, } -impl Decode<'_> for Handshake { +impl ProtocolDecode<'_> for Handshake { fn decode_with(mut buf: Bytes, _: ()) -> Result { let protocol_version = buf.get_u8(); // int<1> let server_version = buf.get_str_nul()?; // string diff --git a/sqlx-mysql/src/protocol/connect/handshake_response.rs b/sqlx-mysql/src/protocol/connect/handshake_response.rs index c5d982e441..aace8a9d55 100644 --- a/sqlx-mysql/src/protocol/connect/handshake_response.rs +++ b/sqlx-mysql/src/protocol/connect/handshake_response.rs @@ -1,5 +1,5 @@ use crate::io::MySqlBufMutExt; -use crate::io::{BufMutExt, Encode}; +use crate::io::{BufMutExt, ProtocolEncode}; use crate::protocol::auth::AuthPlugin; use crate::protocol::connect::ssl_request::SslRequest; use crate::protocol::Capabilities; @@ -27,11 +27,15 @@ pub struct HandshakeResponse<'a> { pub auth_response: Option<&'a [u8]>, } -impl Encode<'_, Capabilities> for HandshakeResponse<'_> { - fn encode_with(&self, buf: &mut Vec, mut capabilities: Capabilities) { +impl ProtocolEncode<'_, Capabilities> for HandshakeResponse<'_> { + fn encode_with( + &self, + buf: &mut Vec, + mut context: Capabilities, + ) -> Result<(), crate::Error> { if self.auth_plugin.is_none() { // ensure PLUGIN_AUTH is set *only* if we have a defined plugin - capabilities.remove(Capabilities::PLUGIN_AUTH); + context.remove(Capabilities::PLUGIN_AUTH); } // NOTE: Half of this packet is identical to the SSL Request packet @@ -39,13 +43,13 @@ impl Encode<'_, Capabilities> for HandshakeResponse<'_> { max_packet_size: self.max_packet_size, collation: self.collation, } - .encode_with(buf, capabilities); + .encode_with(buf, context)?; buf.put_str_nul(self.username); - if 
capabilities.contains(Capabilities::PLUGIN_AUTH_LENENC_DATA) { + if context.contains(Capabilities::PLUGIN_AUTH_LENENC_DATA) { buf.put_bytes_lenenc(self.auth_response.unwrap_or_default()); - } else if capabilities.contains(Capabilities::SECURE_CONNECTION) { + } else if context.contains(Capabilities::SECURE_CONNECTION) { let response = self.auth_response.unwrap_or_default(); buf.push(response.len() as u8); @@ -54,7 +58,7 @@ impl Encode<'_, Capabilities> for HandshakeResponse<'_> { buf.push(0); } - if capabilities.contains(Capabilities::CONNECT_WITH_DB) { + if context.contains(Capabilities::CONNECT_WITH_DB) { if let Some(database) = &self.database { buf.put_str_nul(database); } else { @@ -62,12 +66,14 @@ impl Encode<'_, Capabilities> for HandshakeResponse<'_> { } } - if capabilities.contains(Capabilities::PLUGIN_AUTH) { + if context.contains(Capabilities::PLUGIN_AUTH) { if let Some(plugin) = &self.auth_plugin { buf.put_str_nul(plugin.name()); } else { buf.push(0); } } + + Ok(()) } } diff --git a/sqlx-mysql/src/protocol/connect/ssl_request.rs b/sqlx-mysql/src/protocol/connect/ssl_request.rs index c7d501db24..4b85c7b275 100644 --- a/sqlx-mysql/src/protocol/connect/ssl_request.rs +++ b/sqlx-mysql/src/protocol/connect/ssl_request.rs @@ -1,4 +1,4 @@ -use crate::io::Encode; +use crate::io::ProtocolEncode; use crate::protocol::Capabilities; // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_handshake_response.html @@ -10,21 +10,23 @@ pub struct SslRequest { pub collation: u8, } -impl Encode<'_, Capabilities> for SslRequest { - fn encode_with(&self, buf: &mut Vec, capabilities: Capabilities) { - buf.extend(&(capabilities.bits() as u32).to_le_bytes()); +impl ProtocolEncode<'_, Capabilities> for SslRequest { + fn encode_with(&self, buf: &mut Vec, context: Capabilities) -> Result<(), crate::Error> { + buf.extend(&(context.bits() as u32).to_le_bytes()); buf.extend(&self.max_packet_size.to_le_bytes()); buf.push(self.collation); // 
reserved: string<19> buf.extend(&[0_u8; 19]); - if capabilities.contains(Capabilities::MYSQL) { + if context.contains(Capabilities::MYSQL) { // reserved: string<4> buf.extend(&[0_u8; 4]); } else { // extended client capabilities (MariaDB-specified): int<4> - buf.extend(&((capabilities.bits() >> 32) as u32).to_le_bytes()); + buf.extend(&((context.bits() >> 32) as u32).to_le_bytes()); } + + Ok(()) } } diff --git a/sqlx-mysql/src/protocol/packet.rs b/sqlx-mysql/src/protocol/packet.rs index 9d0d46c35a..8805bd2e38 100644 --- a/sqlx-mysql/src/protocol/packet.rs +++ b/sqlx-mysql/src/protocol/packet.rs @@ -4,22 +4,22 @@ use std::ops::{Deref, DerefMut}; use bytes::Bytes; use crate::error::Error; -use crate::io::{Decode, Encode}; +use crate::io::{ProtocolDecode, ProtocolEncode}; use crate::protocol::response::{EofPacket, OkPacket}; use crate::protocol::Capabilities; #[derive(Debug)] pub struct Packet(pub(crate) T); -impl<'en, 'stream, T> Encode<'stream, (Capabilities, &'stream mut u8)> for Packet +impl<'en, 'stream, T> ProtocolEncode<'stream, (Capabilities, &'stream mut u8)> for Packet where - T: Encode<'en, Capabilities>, + T: ProtocolEncode<'en, Capabilities>, { fn encode_with( &self, buf: &mut Vec, (capabilities, sequence_id): (Capabilities, &'stream mut u8), - ) { + ) -> Result<(), Error> { let mut next_header = |len: u32| { let mut buf = len.to_le_bytes(); buf[3] = *sequence_id; @@ -33,7 +33,7 @@ where buf.extend(&[0_u8; 4]); // encode the payload - self.0.encode_with(buf, capabilities); + self.0.encode_with(buf, capabilities)?; // determine the length of the encoded payload // and write to our reserved space @@ -59,20 +59,22 @@ where buf.extend(&next_header(remainder.len() as u32)); buf.extend(remainder); } + + Ok(()) } } impl Packet { pub(crate) fn decode<'de, T>(self) -> Result where - T: Decode<'de, ()>, + T: ProtocolDecode<'de, ()>, { self.decode_with(()) } pub(crate) fn decode_with<'de, T, C>(self, context: C) -> Result where - T: Decode<'de, C>, + T: 
ProtocolDecode<'de, C>, { T::decode_with(self.0, context) } diff --git a/sqlx-mysql/src/protocol/response/eof.rs b/sqlx-mysql/src/protocol/response/eof.rs index 7bf4cecc28..89de9a32ce 100644 --- a/sqlx-mysql/src/protocol/response/eof.rs +++ b/sqlx-mysql/src/protocol/response/eof.rs @@ -1,7 +1,7 @@ use bytes::{Buf, Bytes}; use crate::error::Error; -use crate::io::Decode; +use crate::io::ProtocolDecode; use crate::protocol::response::Status; use crate::protocol::Capabilities; @@ -18,7 +18,7 @@ pub struct EofPacket { pub status: Status, } -impl Decode<'_, Capabilities> for EofPacket { +impl ProtocolDecode<'_, Capabilities> for EofPacket { fn decode_with(mut buf: Bytes, _: Capabilities) -> Result { let header = buf.get_u8(); if header != 0xfe { diff --git a/sqlx-mysql/src/protocol/response/err.rs b/sqlx-mysql/src/protocol/response/err.rs index cc48c7c58e..085d24f434 100644 --- a/sqlx-mysql/src/protocol/response/err.rs +++ b/sqlx-mysql/src/protocol/response/err.rs @@ -1,7 +1,7 @@ use bytes::{Buf, Bytes}; use crate::error::Error; -use crate::io::{BufExt, Decode}; +use crate::io::{BufExt, ProtocolDecode}; use crate::protocol::Capabilities; // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_err_packet.html @@ -15,7 +15,7 @@ pub struct ErrPacket { pub error_message: String, } -impl Decode<'_, Capabilities> for ErrPacket { +impl ProtocolDecode<'_, Capabilities> for ErrPacket { fn decode_with(mut buf: Bytes, capabilities: Capabilities) -> Result { let header = buf.get_u8(); if header != 0xff { diff --git a/sqlx-mysql/src/protocol/response/ok.rs b/sqlx-mysql/src/protocol/response/ok.rs index 0eada2e8ff..d16127d5f2 100644 --- a/sqlx-mysql/src/protocol/response/ok.rs +++ b/sqlx-mysql/src/protocol/response/ok.rs @@ -1,8 +1,8 @@ use bytes::{Buf, Bytes}; use crate::error::Error; -use crate::io::Decode; use crate::io::MySqlBufExt; +use crate::io::ProtocolDecode; use crate::protocol::response::Status; /// Indicates successful completion of a previous command 
sent by the client. @@ -14,7 +14,7 @@ pub struct OkPacket { pub warnings: u16, } -impl Decode<'_> for OkPacket { +impl ProtocolDecode<'_> for OkPacket { fn decode_with(mut buf: Bytes, _: ()) -> Result { let header = buf.get_u8(); if header != 0 && header != 0xfe { diff --git a/sqlx-mysql/src/protocol/statement/execute.rs b/sqlx-mysql/src/protocol/statement/execute.rs index e1bf998b43..6e51e7b564 100644 --- a/sqlx-mysql/src/protocol/statement/execute.rs +++ b/sqlx-mysql/src/protocol/statement/execute.rs @@ -1,4 +1,4 @@ -use crate::io::Encode; +use crate::io::ProtocolEncode; use crate::protocol::text::ColumnFlags; use crate::protocol::Capabilities; use crate::MySqlArguments; @@ -11,8 +11,8 @@ pub struct Execute<'q> { pub arguments: &'q MySqlArguments, } -impl<'q> Encode<'_, Capabilities> for Execute<'q> { - fn encode_with(&self, buf: &mut Vec, _: Capabilities) { +impl<'q> ProtocolEncode<'_, Capabilities> for Execute<'q> { + fn encode_with(&self, buf: &mut Vec, _: Capabilities) -> Result<(), crate::Error> { buf.push(0x17); // COM_STMT_EXECUTE buf.extend(&self.statement.to_le_bytes()); buf.push(0); // NO_CURSOR @@ -34,5 +34,7 @@ impl<'q> Encode<'_, Capabilities> for Execute<'q> { buf.extend(&*self.arguments.values); } + + Ok(()) } } diff --git a/sqlx-mysql/src/protocol/statement/prepare.rs b/sqlx-mysql/src/protocol/statement/prepare.rs index 106e177475..6012b11939 100644 --- a/sqlx-mysql/src/protocol/statement/prepare.rs +++ b/sqlx-mysql/src/protocol/statement/prepare.rs @@ -1,4 +1,4 @@ -use crate::io::Encode; +use crate::io::ProtocolEncode; use crate::protocol::Capabilities; // https://dev.mysql.com/doc/internals/en/com-stmt-prepare.html#packet-COM_STMT_PREPARE @@ -7,9 +7,10 @@ pub struct Prepare<'a> { pub query: &'a str, } -impl Encode<'_, Capabilities> for Prepare<'_> { - fn encode_with(&self, buf: &mut Vec, _: Capabilities) { +impl ProtocolEncode<'_, Capabilities> for Prepare<'_> { + fn encode_with(&self, buf: &mut Vec, _: Capabilities) -> Result<(), crate::Error> 
{ buf.push(0x16); // COM_STMT_PREPARE buf.extend(self.query.as_bytes()); + Ok(()) } } diff --git a/sqlx-mysql/src/protocol/statement/prepare_ok.rs b/sqlx-mysql/src/protocol/statement/prepare_ok.rs index 842fe25bba..da25047a1d 100644 --- a/sqlx-mysql/src/protocol/statement/prepare_ok.rs +++ b/sqlx-mysql/src/protocol/statement/prepare_ok.rs @@ -1,7 +1,7 @@ use bytes::{Buf, Bytes}; use crate::error::Error; -use crate::io::Decode; +use crate::io::ProtocolDecode; use crate::protocol::Capabilities; // https://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html#packet-COM_STMT_PREPARE_OK @@ -15,7 +15,7 @@ pub(crate) struct PrepareOk { pub(crate) warnings: u16, } -impl Decode<'_, Capabilities> for PrepareOk { +impl ProtocolDecode<'_, Capabilities> for PrepareOk { fn decode_with(buf: Bytes, _: Capabilities) -> Result { const SIZE: usize = 12; diff --git a/sqlx-mysql/src/protocol/statement/row.rs b/sqlx-mysql/src/protocol/statement/row.rs index fcb2f4d4e5..d5791ce749 100644 --- a/sqlx-mysql/src/protocol/statement/row.rs +++ b/sqlx-mysql/src/protocol/statement/row.rs @@ -2,7 +2,7 @@ use bytes::{Buf, Bytes}; use crate::error::Error; use crate::io::MySqlBufExt; -use crate::io::{BufExt, Decode}; +use crate::io::{BufExt, ProtocolDecode}; use crate::protocol::text::ColumnType; use crate::protocol::Row; use crate::MySqlColumn; @@ -13,7 +13,7 @@ use crate::MySqlColumn; #[derive(Debug)] pub(crate) struct BinaryRow(pub(crate) Row); -impl<'de> Decode<'de, &'de [MySqlColumn]> for BinaryRow { +impl<'de> ProtocolDecode<'de, &'de [MySqlColumn]> for BinaryRow { fn decode_with(mut buf: Bytes, columns: &'de [MySqlColumn]) -> Result { let header = buf.get_u8(); if header != 0 { diff --git a/sqlx-mysql/src/protocol/statement/stmt_close.rs b/sqlx-mysql/src/protocol/statement/stmt_close.rs index 150a57e2c6..a92f03108e 100644 --- a/sqlx-mysql/src/protocol/statement/stmt_close.rs +++ b/sqlx-mysql/src/protocol/statement/stmt_close.rs @@ -1,4 +1,4 @@ -use crate::io::Encode; +use 
crate::io::ProtocolEncode; use crate::protocol::Capabilities; // https://dev.mysql.com/doc/internals/en/com-stmt-close.html @@ -8,9 +8,10 @@ pub struct StmtClose { pub statement: u32, } -impl Encode<'_, Capabilities> for StmtClose { - fn encode_with(&self, buf: &mut Vec, _: Capabilities) { +impl ProtocolEncode<'_, Capabilities> for StmtClose { + fn encode_with(&self, buf: &mut Vec, _: Capabilities) -> Result<(), crate::Error> { buf.push(0x19); // COM_STMT_CLOSE buf.extend(&self.statement.to_le_bytes()); + Ok(()) } } diff --git a/sqlx-mysql/src/protocol/text/column.rs b/sqlx-mysql/src/protocol/text/column.rs index 36778fec2a..15c990d5be 100644 --- a/sqlx-mysql/src/protocol/text/column.rs +++ b/sqlx-mysql/src/protocol/text/column.rs @@ -4,8 +4,8 @@ use bitflags::bitflags; use bytes::{Buf, Bytes}; use crate::error::Error; -use crate::io::Decode; use crate::io::MySqlBufExt; +use crate::io::ProtocolDecode; use crate::protocol::Capabilities; // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/group__group__cs__column__definition__flags.html @@ -134,7 +134,7 @@ impl ColumnDefinition { } } -impl Decode<'_, Capabilities> for ColumnDefinition { +impl ProtocolDecode<'_, Capabilities> for ColumnDefinition { fn decode_with(mut buf: Bytes, _: Capabilities) -> Result { let catalog = buf.get_bytes_lenenc(); let schema = buf.get_bytes_lenenc(); diff --git a/sqlx-mysql/src/protocol/text/ping.rs b/sqlx-mysql/src/protocol/text/ping.rs index 217a826477..4eb8ab2eb4 100644 --- a/sqlx-mysql/src/protocol/text/ping.rs +++ b/sqlx-mysql/src/protocol/text/ping.rs @@ -1,4 +1,4 @@ -use crate::io::Encode; +use crate::io::ProtocolEncode; use crate::protocol::Capabilities; // https://dev.mysql.com/doc/internals/en/com-ping.html @@ -6,8 +6,9 @@ use crate::protocol::Capabilities; #[derive(Debug)] pub(crate) struct Ping; -impl Encode<'_, Capabilities> for Ping { - fn encode_with(&self, buf: &mut Vec, _: Capabilities) { +impl ProtocolEncode<'_, Capabilities> for Ping { + fn encode_with(&self, buf: 
&mut Vec, _: Capabilities) -> Result<(), crate::Error> { buf.push(0x0e); // COM_PING + Ok(()) } } diff --git a/sqlx-mysql/src/protocol/text/query.rs b/sqlx-mysql/src/protocol/text/query.rs index caca9e46bf..b3533adb42 100644 --- a/sqlx-mysql/src/protocol/text/query.rs +++ b/sqlx-mysql/src/protocol/text/query.rs @@ -1,4 +1,4 @@ -use crate::io::Encode; +use crate::io::ProtocolEncode; use crate::protocol::Capabilities; // https://dev.mysql.com/doc/internals/en/com-query.html @@ -6,9 +6,10 @@ use crate::protocol::Capabilities; #[derive(Debug)] pub(crate) struct Query<'q>(pub(crate) &'q str); -impl Encode<'_, Capabilities> for Query<'_> { - fn encode_with(&self, buf: &mut Vec, _: Capabilities) { +impl ProtocolEncode<'_, Capabilities> for Query<'_> { + fn encode_with(&self, buf: &mut Vec, _: Capabilities) -> Result<(), crate::Error> { buf.push(0x03); // COM_QUERY - buf.extend(self.0.as_bytes()) + buf.extend(self.0.as_bytes()); + Ok(()) } } diff --git a/sqlx-mysql/src/protocol/text/quit.rs b/sqlx-mysql/src/protocol/text/quit.rs index e4f6525f20..c0d8729e9d 100644 --- a/sqlx-mysql/src/protocol/text/quit.rs +++ b/sqlx-mysql/src/protocol/text/quit.rs @@ -1,4 +1,4 @@ -use crate::io::Encode; +use crate::io::ProtocolEncode; use crate::protocol::Capabilities; // https://dev.mysql.com/doc/internals/en/com-quit.html @@ -6,8 +6,9 @@ use crate::protocol::Capabilities; #[derive(Debug)] pub(crate) struct Quit; -impl Encode<'_, Capabilities> for Quit { - fn encode_with(&self, buf: &mut Vec, _: Capabilities) { +impl ProtocolEncode<'_, Capabilities> for Quit { + fn encode_with(&self, buf: &mut Vec, _: Capabilities) -> Result<(), crate::Error> { buf.push(0x01); // COM_QUIT + Ok(()) } } diff --git a/sqlx-mysql/src/protocol/text/row.rs b/sqlx-mysql/src/protocol/text/row.rs index 852c47604e..e32f6e9b82 100644 --- a/sqlx-mysql/src/protocol/text/row.rs +++ b/sqlx-mysql/src/protocol/text/row.rs @@ -2,14 +2,14 @@ use bytes::{Buf, Bytes}; use crate::column::MySqlColumn; use crate::error::Error; 
-use crate::io::Decode; use crate::io::MySqlBufExt; +use crate::io::ProtocolDecode; use crate::protocol::Row; #[derive(Debug)] pub(crate) struct TextRow(pub(crate) Row); -impl<'de> Decode<'de, &'de [MySqlColumn]> for TextRow { +impl<'de> ProtocolDecode<'de, &'de [MySqlColumn]> for TextRow { fn decode_with(mut buf: Bytes, columns: &'de [MySqlColumn]) -> Result { let storage = buf.clone(); let offset = buf.len(); From 15fc55c19bf252eef9abf4a60a2660e0b389552d Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Sun, 18 Aug 2024 00:20:35 -0700 Subject: [PATCH 20/40] chore(mysql): deny bad-cast lints --- sqlx-mysql/src/lib.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sqlx-mysql/src/lib.rs b/sqlx-mysql/src/lib.rs index d77d0daf3b..c35be3c6fa 100644 --- a/sqlx-mysql/src/lib.rs +++ b/sqlx-mysql/src/lib.rs @@ -1,4 +1,7 @@ //! **MySQL** database driver. +#![deny(clippy::cast_possible_truncation)] +#![deny(clippy::cast_possible_wrap)] +#![deny(clippy::cast_sign_loss)] #[macro_use] extern crate sqlx_core; From a9510c8318e9293a4c8c665513f0cd48964d9e33 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Tue, 20 Aug 2024 01:15:35 -0700 Subject: [PATCH 21/40] fix(mysql): audit for bad casts --- sqlx-mysql/src/any.rs | 2 + sqlx-mysql/src/connection/auth.rs | 4 +- sqlx-mysql/src/connection/establish.rs | 4 +- sqlx-mysql/src/connection/executor.rs | 4 +- sqlx-mysql/src/connection/stream.rs | 8 +- sqlx-mysql/src/connection/tls.rs | 2 +- sqlx-mysql/src/io/buf.rs | 14 ++- sqlx-mysql/src/io/buf_mut.rs | 27 +++-- sqlx-mysql/src/migrate.rs | 4 +- sqlx-mysql/src/protocol/capabilities.rs | 52 ++++++--- sqlx-mysql/src/protocol/connect/handshake.rs | 4 +- .../protocol/connect/handshake_response.rs | 5 +- .../src/protocol/connect/ssl_request.rs | 2 + sqlx-mysql/src/protocol/packet.rs | 8 ++ sqlx-mysql/src/protocol/statement/row.rs | 13 ++- sqlx-mysql/src/protocol/text/column.rs | 12 +-- sqlx-mysql/src/protocol/text/row.rs | 5 +- sqlx-mysql/src/transaction.rs | 3 +- 
sqlx-mysql/src/types/chrono.rs | 100 +++++++++++------- sqlx-mysql/src/types/float.rs | 1 + sqlx-mysql/src/types/mysql_time.rs | 2 + sqlx-mysql/src/types/time.rs | 72 ++++++++----- 22 files changed, 228 insertions(+), 120 deletions(-) diff --git a/sqlx-mysql/src/any.rs b/sqlx-mysql/src/any.rs index fa8d34f8db..0466bfc0a4 100644 --- a/sqlx-mysql/src/any.rs +++ b/sqlx-mysql/src/any.rs @@ -213,6 +213,8 @@ impl<'a> TryFrom<&'a AnyConnectOptions> for MySqlConnectOptions { fn map_result(result: MySqlQueryResult) -> AnyQueryResult { AnyQueryResult { rows_affected: result.rows_affected, + // Don't expect this to be a problem + #[allow(clippy::cast_possible_wrap)] last_insert_id: Some(result.last_insert_id as i64), } } diff --git a/sqlx-mysql/src/connection/auth.rs b/sqlx-mysql/src/connection/auth.rs index 75115f5bec..613f8e702f 100644 --- a/sqlx-mysql/src/connection/auth.rs +++ b/sqlx-mysql/src/connection/auth.rs @@ -53,7 +53,7 @@ impl AuthPlugin { 0x04 => { let payload = encrypt_rsa(stream, 0x02, password, nonce).await?; - stream.write_packet(&*payload); + stream.write_packet(&*payload)?; stream.flush().await?; Ok(false) @@ -143,7 +143,7 @@ async fn encrypt_rsa<'s>( } // client sends a public key request - stream.write_packet(&[public_key_request_id][..]); + stream.write_packet(&[public_key_request_id][..])?; stream.flush().await?; // server sends a public key response diff --git a/sqlx-mysql/src/connection/establish.rs b/sqlx-mysql/src/connection/establish.rs index 72590324f7..468478e550 100644 --- a/sqlx-mysql/src/connection/establish.rs +++ b/sqlx-mysql/src/connection/establish.rs @@ -131,7 +131,7 @@ impl<'a> DoHandshake<'a> { database: options.database.as_deref(), auth_plugin: plugin, auth_response: auth_response.as_deref(), - }); + })?; stream.flush().await?; @@ -160,7 +160,7 @@ impl<'a> DoHandshake<'a> { ) .await?; - stream.write_packet(AuthSwitchResponse(response)); + stream.write_packet(AuthSwitchResponse(response))?; stream.flush().await?; } diff --git 
a/sqlx-mysql/src/connection/executor.rs b/sqlx-mysql/src/connection/executor.rs index 9742cdf224..07c7979b08 100644 --- a/sqlx-mysql/src/connection/executor.rs +++ b/sqlx-mysql/src/connection/executor.rs @@ -187,7 +187,9 @@ impl MySqlConnection { // otherwise, this first packet is the start of the result-set metadata, *self.inner.stream.waiting.front_mut().unwrap() = Waiting::Row; - let num_columns = packet.get_uint_lenenc() as usize; // column count + let num_columns = packet.get_uint_lenenc(); // column count + let num_columns = usize::try_from(num_columns) + .map_err(|_| err_protocol!("column count overflows usize: {num_columns}"))?; if needs_metadata { column_names = Arc::new(recv_result_metadata(&mut self.inner.stream, num_columns, Arc::make_mut(&mut columns)).await?); diff --git a/sqlx-mysql/src/connection/stream.rs b/sqlx-mysql/src/connection/stream.rs index 465acc17cb..1f93ed11a5 100644 --- a/sqlx-mysql/src/connection/stream.rs +++ b/sqlx-mysql/src/connection/stream.rs @@ -113,17 +113,17 @@ impl MySqlStream { T: ProtocolEncode<'en, Capabilities>, { self.sequence_id = 0; - self.write_packet(payload); + self.write_packet(payload)?; self.flush().await?; Ok(()) } - pub(crate) fn write_packet<'en, T>(&mut self, payload: T) + pub(crate) fn write_packet<'en, T>(&mut self, payload: T) -> Result<(), Error> where T: ProtocolEncode<'en, Capabilities>, { self.socket - .write_with(Packet(payload), (self.capabilities, &mut self.sequence_id)); + .write_with(Packet(payload), (self.capabilities, &mut self.sequence_id)) } async fn recv_packet_part(&mut self) -> Result { @@ -132,6 +132,8 @@ impl MySqlStream { let mut header: Bytes = self.socket.read(4).await?; + // cannot overflow + #[allow(clippy::cast_possible_truncation)] let packet_size = header.get_uint_le(3) as usize; let sequence_id = header.get_u8(); diff --git a/sqlx-mysql/src/connection/tls.rs b/sqlx-mysql/src/connection/tls.rs index f98c5c532f..22b3c487b5 100644 --- a/sqlx-mysql/src/connection/tls.rs +++ 
b/sqlx-mysql/src/connection/tls.rs @@ -72,7 +72,7 @@ pub(super) async fn maybe_upgrade( stream.write_packet(SslRequest { max_packet_size: super::MAX_PACKET_SIZE, collation: stream.collation as u8, - }); + })?; stream.flush().await?; diff --git a/sqlx-mysql/src/io/buf.rs b/sqlx-mysql/src/io/buf.rs index 98bb3407d9..685d5bfda7 100644 --- a/sqlx-mysql/src/io/buf.rs +++ b/sqlx-mysql/src/io/buf.rs @@ -15,7 +15,7 @@ pub trait MySqlBufExt: Buf { fn get_str_lenenc(&mut self) -> Result; // Read a length-encoded byte sequence. - fn get_bytes_lenenc(&mut self) -> Bytes; + fn get_bytes_lenenc(&mut self) -> Result; } impl MySqlBufExt for Bytes { @@ -31,11 +31,17 @@ impl MySqlBufExt for Bytes { fn get_str_lenenc(&mut self) -> Result { let size = self.get_uint_lenenc(); - self.get_str(size as usize) + let size = usize::try_from(size) + .map_err(|_| err_protocol!("string length overflows usize: {size}"))?; + + self.get_str(size) } - fn get_bytes_lenenc(&mut self) -> Bytes { + fn get_bytes_lenenc(&mut self) -> Result { let size = self.get_uint_lenenc(); - self.split_to(size as usize) + let size = usize::try_from(size) + .map_err(|_| err_protocol!("string length overflows usize: {size}"))?; + + Ok(self.split_to(size)) } } diff --git a/sqlx-mysql/src/io/buf_mut.rs b/sqlx-mysql/src/io/buf_mut.rs index 95bf04f319..0cf58794de 100644 --- a/sqlx-mysql/src/io/buf_mut.rs +++ b/sqlx-mysql/src/io/buf_mut.rs @@ -13,17 +13,22 @@ impl MySqlBufMutExt for Vec { // https://dev.mysql.com/doc/internals/en/integer.html // https://mariadb.com/kb/en/library/protocol-data-types/#length-encoded-integers - if v < 251 { - self.push(v as u8); - } else if v < 0x1_00_00 { - self.push(0xfc); - self.extend(&(v as u16).to_le_bytes()); - } else if v < 0x1_00_00_00 { - self.push(0xfd); - self.extend(&(v as u32).to_le_bytes()[..3]); - } else { - self.push(0xfe); - self.extend(&v.to_le_bytes()); + let encoded_le = v.to_le_bytes(); + + match v { + ..251 => self.push(encoded_le[0]), + 251..0x1_00_00 => { + 
self.push(0xfc); + self.extend_from_slice(&encoded_le[..2]); + } + 0x1_00_00..0x1_00_00_00 => { + self.push(0xfd); + self.extend_from_slice(&encoded_le[..3]); + } + _ => { + self.push(0xfe); + self.extend_from_slice(&encoded_le); + } } } diff --git a/sqlx-mysql/src/migrate.rs b/sqlx-mysql/src/migrate.rs index 7c514d8631..c595ca2cb4 100644 --- a/sqlx-mysql/src/migrate.rs +++ b/sqlx-mysql/src/migrate.rs @@ -231,7 +231,9 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( WHERE version = ? "#, ) - .bind(elapsed.as_nanos() as i64) + // Unlikely unless the execution time exceeds ~292 years, + // then we're probably okay losing that information. + .bind(i64::try_from(elapsed.as_nanos()).ok()) .bind(migration.version) .execute(self) .await?; diff --git a/sqlx-mysql/src/protocol/capabilities.rs b/sqlx-mysql/src/protocol/capabilities.rs index a1c6824985..a9c5cc581b 100644 --- a/sqlx-mysql/src/protocol/capabilities.rs +++ b/sqlx-mysql/src/protocol/capabilities.rs @@ -1,5 +1,9 @@ // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/group__group__cs__capabilities__flags.html // https://mariadb.com/kb/en/library/connection/#capabilities +// +// MySQL defines the capabilities flags as fitting in an `int<4>` but MariaDB +// extends this with more bits sent in a separate field. +// For simplicity, we've chosen to combine these into one type. bitflags::bitflags! { #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct Capabilities: u64 { @@ -43,45 +47,65 @@ bitflags::bitflags! 
{ const TRANSACTIONS = 8192; // 4.1+ authentication - const SECURE_CONNECTION = (1 << 15); + const SECURE_CONNECTION = 1 << 15; // Enable/disable multi-statement support for COM_QUERY *and* COM_STMT_PREPARE - const MULTI_STATEMENTS = (1 << 16); + const MULTI_STATEMENTS = 1 << 16; // Enable/disable multi-results for COM_QUERY - const MULTI_RESULTS = (1 << 17); + const MULTI_RESULTS = 1 << 17; // Enable/disable multi-results for COM_STMT_PREPARE - const PS_MULTI_RESULTS = (1 << 18); + const PS_MULTI_RESULTS = 1 << 18; // Client supports plugin authentication - const PLUGIN_AUTH = (1 << 19); + const PLUGIN_AUTH = 1 << 19; // Client supports connection attributes - const CONNECT_ATTRS = (1 << 20); + const CONNECT_ATTRS = 1 << 20; // Enable authentication response packet to be larger than 255 bytes. - const PLUGIN_AUTH_LENENC_DATA = (1 << 21); + const PLUGIN_AUTH_LENENC_DATA = 1 << 21; // Don't close the connection for a user account with expired password. - const CAN_HANDLE_EXPIRED_PASSWORDS = (1 << 22); + const CAN_HANDLE_EXPIRED_PASSWORDS = 1 << 22; // Capable of handling server state change information. - const SESSION_TRACK = (1 << 23); + const SESSION_TRACK = 1 << 23; // Client no longer needs EOF_Packet and will use OK_Packet instead. 
- const DEPRECATE_EOF = (1 << 24); + const DEPRECATE_EOF = 1 << 24; // Support ZSTD protocol compression - const ZSTD_COMPRESSION_ALGORITHM = (1 << 26); + const ZSTD_COMPRESSION_ALGORITHM = 1 << 26; // Verify server certificate - const SSL_VERIFY_SERVER_CERT = (1 << 30); + const SSL_VERIFY_SERVER_CERT = 1 << 30; // The client can handle optional metadata information in the resultset - const OPTIONAL_RESULTSET_METADATA = (1 << 25); + const OPTIONAL_RESULTSET_METADATA = 1 << 25; // Don't reset the options after an unsuccessful connect - const REMEMBER_OPTIONS = (1 << 31); + const REMEMBER_OPTIONS = 1 << 31; + + // Extended capabilities (MariaDB only, as of writing) + // Client support progress indicator (since 10.2) + const MARIADB_CLIENT_PROGRESS = 1 << 32; + + // Permit COM_MULTI protocol + const MARIADB_CLIENT_MULTI = 1 << 33; + + // Permit bulk insert + const MARIADB_CLIENT_STMT_BULK_OPERATIONS = 1 << 34; + + // Add extended metadata information + const MARIADB_CLIENT_EXTENDED_TYPE_INFO = 1 << 35; + + // Permit skipping metadata + const MARIADB_CLIENT_CACHE_METADATA = 1 << 36; + + // when enabled, indicate that Bulk command can use STMT_BULK_FLAG_SEND_UNIT_RESULTS flag + // that permit to return a result-set of all affected rows and auto-increment values + const MARIADB_CLIENT_BULK_UNIT_RESULTS = 1 << 37; } } diff --git a/sqlx-mysql/src/protocol/connect/handshake.rs b/sqlx-mysql/src/protocol/connect/handshake.rs index faf9634bc8..84afe74ea6 100644 --- a/sqlx-mysql/src/protocol/connect/handshake.rs +++ b/sqlx-mysql/src/protocol/connect/handshake.rs @@ -62,8 +62,8 @@ impl ProtocolDecode<'_> for Handshake { } let auth_plugin_data_2 = if capabilities.contains(Capabilities::SECURE_CONNECTION) { - let len = cmp::max((auth_plugin_data_len as isize) - 9, 12) as usize; - let v = buf.get_bytes(len); + let len = cmp::max(auth_plugin_data_len.saturating_sub(9), 12); + let v = buf.get_bytes(len as usize); buf.advance(1); // NUL-terminator v diff --git 
a/sqlx-mysql/src/protocol/connect/handshake_response.rs b/sqlx-mysql/src/protocol/connect/handshake_response.rs index aace8a9d55..2e6fec1c0d 100644 --- a/sqlx-mysql/src/protocol/connect/handshake_response.rs +++ b/sqlx-mysql/src/protocol/connect/handshake_response.rs @@ -52,7 +52,10 @@ impl ProtocolEncode<'_, Capabilities> for HandshakeResponse<'_> { } else if context.contains(Capabilities::SECURE_CONNECTION) { let response = self.auth_response.unwrap_or_default(); - buf.push(response.len() as u8); + let response_len = u8::try_from(response.len()) + .map_err(|_| err_protocol!("auth_response.len() too long: {}", response.len()))?; + + buf.push(response_len); buf.extend(response); } else { buf.push(0); diff --git a/sqlx-mysql/src/protocol/connect/ssl_request.rs b/sqlx-mysql/src/protocol/connect/ssl_request.rs index 4b85c7b275..cdfc9e5178 100644 --- a/sqlx-mysql/src/protocol/connect/ssl_request.rs +++ b/sqlx-mysql/src/protocol/connect/ssl_request.rs @@ -12,6 +12,8 @@ pub struct SslRequest { impl ProtocolEncode<'_, Capabilities> for SslRequest { fn encode_with(&self, buf: &mut Vec<u8>, context: Capabilities) -> Result<(), crate::Error> { + // truncation is intended + #[allow(clippy::cast_possible_truncation)] buf.extend(&(context.bits() as u32).to_le_bytes()); buf.extend(&self.max_packet_size.to_le_bytes()); buf.push(self.collation); diff --git a/sqlx-mysql/src/protocol/packet.rs b/sqlx-mysql/src/protocol/packet.rs index 8805bd2e38..d43338dcd4 100644 --- a/sqlx-mysql/src/protocol/packet.rs +++ b/sqlx-mysql/src/protocol/packet.rs @@ -40,6 +40,8 @@ where let len = buf.len() - offset - 4; let header = &mut buf[offset..]; + // `min(.., 0xFF_FF_FF)` cannot overflow + #[allow(clippy::cast_possible_truncation)] header[..4].copy_from_slice(&next_header(min(len, 0xFF_FF_FF) as u32)); // add more packets if we need to split the data @@ -49,6 +51,9 @@ where for chunk in chunks.by_ref() { buf.reserve(chunk.len() + 4); + + // `chunk.len() == 0xFF_FF_FF` + 
#[allow(clippy::cast_possible_truncation)] buf.extend(&next_header(chunk.len() as u32)); buf.extend(chunk); } @@ -56,6 +61,9 @@ where // this will also handle adding a zero sized packet if the data size is a multiple of 0xFF_FF_FF let remainder = chunks.remainder(); buf.reserve(remainder.len() + 4); + + // `remainder.len() < 0xFF_FF_FF` + #[allow(clippy::cast_possible_truncation)] buf.extend(&next_header(remainder.len() as u32)); buf.extend(remainder); } diff --git a/sqlx-mysql/src/protocol/statement/row.rs b/sqlx-mysql/src/protocol/statement/row.rs index d5791ce749..3007884c72 100644 --- a/sqlx-mysql/src/protocol/statement/row.rs +++ b/sqlx-mysql/src/protocol/statement/row.rs @@ -34,8 +34,11 @@ impl<'de> ProtocolDecode<'de, &'de [MySqlColumn]> for BinaryRow { for (column_idx, column) in columns.iter().enumerate() { // NOTE: the column index starts at the 3rd bit let column_null_idx = column_idx + 2; - let is_null = - null_bitmap[column_null_idx / 8] & (1 << (column_null_idx % 8) as u8) != 0; + + let byte_idx = column_null_idx / 8; + let bit_idx = column_null_idx % 8; + + let is_null = null_bitmap[byte_idx] & (1u8 << bit_idx) != 0; if is_null { values.push(None); @@ -72,7 +75,11 @@ impl<'de> ProtocolDecode<'de, &'de [MySqlColumn]> for BinaryRow { | ColumnType::Bit | ColumnType::Decimal | ColumnType::Json - | ColumnType::NewDecimal => buf.get_uint_lenenc() as usize, + | ColumnType::NewDecimal => { + let size = buf.get_uint_lenenc(); + usize::try_from(size) + .map_err(|_| err_protocol!("BLOB length out of range: {size}"))? + } // Like strings and blobs, these values are variable-length. // Unlike strings and blobs, however, they exclusively use one byte for length. 
diff --git a/sqlx-mysql/src/protocol/text/column.rs b/sqlx-mysql/src/protocol/text/column.rs index 15c990d5be..425a5cdc47 100644 --- a/sqlx-mysql/src/protocol/text/column.rs +++ b/sqlx-mysql/src/protocol/text/column.rs @@ -136,12 +136,12 @@ impl ColumnDefinition { impl ProtocolDecode<'_, Capabilities> for ColumnDefinition { fn decode_with(mut buf: Bytes, _: Capabilities) -> Result { - let catalog = buf.get_bytes_lenenc(); - let schema = buf.get_bytes_lenenc(); - let table_alias = buf.get_bytes_lenenc(); - let table = buf.get_bytes_lenenc(); - let alias = buf.get_bytes_lenenc(); - let name = buf.get_bytes_lenenc(); + let catalog = buf.get_bytes_lenenc()?; + let schema = buf.get_bytes_lenenc()?; + let table_alias = buf.get_bytes_lenenc()?; + let table = buf.get_bytes_lenenc()?; + let alias = buf.get_bytes_lenenc()?; + let name = buf.get_bytes_lenenc()?; let _next_len = buf.get_uint_lenenc(); // always 0x0c let collation = buf.get_u16_le(); let max_size = buf.get_u32_le(); diff --git a/sqlx-mysql/src/protocol/text/row.rs b/sqlx-mysql/src/protocol/text/row.rs index e32f6e9b82..0b81cd7f4f 100644 --- a/sqlx-mysql/src/protocol/text/row.rs +++ b/sqlx-mysql/src/protocol/text/row.rs @@ -22,7 +22,10 @@ impl<'de> ProtocolDecode<'de, &'de [MySqlColumn]> for TextRow { values.push(None); buf.advance(1); } else { - let size = buf.get_uint_lenenc() as usize; + let size = buf.get_uint_lenenc(); + let size = usize::try_from(size) + .map_err(|_| err_protocol!("TextRow length out of range: {size}"))?; + let offset = offset - buf.len(); values.push(Some(offset..(offset + size))); diff --git a/sqlx-mysql/src/transaction.rs b/sqlx-mysql/src/transaction.rs index 99d6526392..d8538cc2b3 100644 --- a/sqlx-mysql/src/transaction.rs +++ b/sqlx-mysql/src/transaction.rs @@ -59,7 +59,8 @@ impl TransactionManager for MySqlTransactionManager { conn.inner.stream.sequence_id = 0; conn.inner .stream - .write_packet(Query(&rollback_ansi_transaction_sql(depth))); + 
.write_packet(Query(&rollback_ansi_transaction_sql(depth))) + .expect("BUG: unexpected error queueing ROLLBACK"); conn.inner.transaction_depth = depth - 1; } diff --git a/sqlx-mysql/src/types/chrono.rs b/sqlx-mysql/src/types/chrono.rs index ed39844b09..39e215bec5 100644 --- a/sqlx-mysql/src/types/chrono.rs +++ b/sqlx-mysql/src/types/chrono.rs @@ -70,8 +70,8 @@ impl Type for NaiveTime { impl Encode<'_, MySql> for NaiveTime { fn encode_by_ref(&self, buf: &mut Vec) -> Result { - let len = Encode::::size_hint(self) - 1; - buf.push(len as u8); + let len = naive_time_encoded_len(self); + buf.push(len); // NaiveTime is not negative buf.push(0); @@ -80,19 +80,13 @@ impl Encode<'_, MySql> for NaiveTime { // https://mariadb.com/kb/en/resultset-row/#teimstamp-binary-encoding buf.extend_from_slice(&[0_u8; 4]); - encode_time(self, len > 9, buf); + encode_time(self, len > 8, buf); Ok(IsNull::No) } fn size_hint(&self) -> usize { - if self.nanosecond() == 0 { - // if micro_seconds is 0, length is 8 and micro_seconds is not sent - 9 - } else { - // otherwise length is 12 - 13 - } + naive_time_encoded_len(self) as usize + 1 // plus length byte } } @@ -217,38 +211,20 @@ impl Type for NaiveDateTime { impl Encode<'_, MySql> for NaiveDateTime { fn encode_by_ref(&self, buf: &mut Vec) -> Result { - let len = Encode::::size_hint(self) - 1; - buf.push(len as u8); + let len = naive_dt_encoded_len(self); + buf.push(len); encode_date(&self.date(), buf)?; if len > 4 { - encode_time(&self.time(), len > 8, buf); + encode_time(&self.time(), len > 7, buf); } Ok(IsNull::No) } fn size_hint(&self) -> usize { - // to save space the packet can be compressed: - match ( - self.hour(), - self.minute(), - self.second(), - #[allow(deprecated)] - self.timestamp_subsec_nanos(), - ) { - // if hour, minutes, seconds and micro_seconds are all 0, - // length is 4 and no other field is sent - (0, 0, 0, 0) => 5, - - // if micro_seconds is 0, length is 7 - // and micro_seconds is not sent - (_, _, _, 0) => 8, - - // 
otherwise length is 11 - (_, _, _, _) => 12, - } + naive_dt_encoded_len(self) as usize + 1 // plus length byte } } @@ -284,13 +260,18 @@ impl<'r> Decode<'r, MySql> for NaiveDateTime { } fn encode_date(date: &NaiveDate, buf: &mut Vec) -> Result<(), BoxDynError> { - // MySQL supports years from 1000 - 9999 + // MySQL supports years 1000 - 9999 let year = u16::try_from(date.year()) .map_err(|_| format!("NaiveDateTime out of range for Mysql: {date}"))?; buf.extend_from_slice(&year.to_le_bytes()); - buf.push(date.month() as u8); - buf.push(date.day() as u8); + + // `NaiveDate` guarantees the ranges of these values + #[allow(clippy::cast_possible_truncation)] + { + buf.push(date.month() as u8); + buf.push(date.day() as u8); + } Ok(()) } @@ -314,9 +295,13 @@ fn decode_date(mut buf: &[u8]) -> Result, BoxDynError> { } fn encode_time(time: &NaiveTime, include_micros: bool, buf: &mut Vec) { - buf.push(time.hour() as u8); - buf.push(time.minute() as u8); - buf.push(time.second() as u8); + // `NaiveTime` API guarantees the ranges of these values + #[allow(clippy::cast_possible_truncation)] + { + buf.push(time.hour() as u8); + buf.push(time.minute() as u8); + buf.push(time.second() as u8); + } if include_micros { buf.extend((time.nanosecond() / 1000).to_le_bytes()); @@ -335,6 +320,43 @@ fn decode_time(len: u8, mut buf: &[u8]) -> Result { 0 }; - NaiveTime::from_hms_micro_opt(hour as u32, minute as u32, seconds as u32, micros as u32) + let micros = u32::try_from(micros) + .map_err(|_| format!("server returned microseconds out of range: {micros}"))?; + + NaiveTime::from_hms_micro_opt(hour as u32, minute as u32, seconds as u32, micros) .ok_or_else(|| format!("server returned invalid time: {hour:02}:{minute:02}:{seconds:02}; micros: {micros}").into()) } + +#[inline(always)] +fn naive_dt_encoded_len(time: &NaiveDateTime) -> u8 { + // to save space the packet can be compressed: + match ( + time.hour(), + time.minute(), + time.second(), + #[allow(deprecated)] + 
time.timestamp_subsec_nanos(), + ) { + // if hour, minutes, seconds and micro_seconds are all 0, + // length is 4 and no other field is sent + (0, 0, 0, 0) => 4, + + // if micro_seconds is 0, length is 7 + // and micro_seconds is not sent + (_, _, _, 0) => 7, + + // otherwise length is 11 + (_, _, _, _) => 11, + } +} + +#[inline(always)] +fn naive_time_encoded_len(time: &NaiveTime) -> u8 { + if time.nanosecond() == 0 { + // if micro_seconds is 0, length is 8 and micro_seconds is not sent + 8 + } else { + // otherwise length is 12 + 12 + } +} diff --git a/sqlx-mysql/src/types/float.rs b/sqlx-mysql/src/types/float.rs index 13809f39fe..44acb31b3a 100644 --- a/sqlx-mysql/src/types/float.rs +++ b/sqlx-mysql/src/types/float.rs @@ -59,6 +59,7 @@ impl Decode<'_, MySql> for f32 { 4 => LittleEndian::read_f32(buf), // MySQL can return 8-byte DOUBLE values for a FLOAT // We take and truncate to f32 as that's the same behavior as *in* MySQL, + #[allow(clippy::cast_possible_truncation)] 8 => LittleEndian::read_f64(buf) as f32, other => { // Users may try to decode a DECIMAL as floating point; diff --git a/sqlx-mysql/src/types/mysql_time.rs b/sqlx-mysql/src/types/mysql_time.rs index 1e91b82e08..b549af5765 100644 --- a/sqlx-mysql/src/types/mysql_time.rs +++ b/sqlx-mysql/src/types/mysql_time.rs @@ -617,6 +617,8 @@ fn parse_microseconds(micros: &str) -> Result { len @ ..=EXPECTED_DIGITS => { // Fewer than 6 digits, multiply to the correct magnitude let micros: u32 = micros.parse()?; + // cast cannot overflow + #[allow(clippy::cast_possible_truncation)] Ok(micros * 10u32.pow((EXPECTED_DIGITS - len) as u32)) } // More digits than expected, truncate diff --git a/sqlx-mysql/src/types/time.rs b/sqlx-mysql/src/types/time.rs index 4fc46a33e4..e04f8928c9 100644 --- a/sqlx-mysql/src/types/time.rs +++ b/sqlx-mysql/src/types/time.rs @@ -47,29 +47,23 @@ impl Type for Time { impl Encode<'_, MySql> for Time { fn encode_by_ref(&self, buf: &mut Vec) -> Result { - let len = Encode::::size_hint(self) 
- 1; - buf.push(len as u8); + let len = time_encoded_len(self); + buf.push(len); - // Time is not negative + // sign byte: Time is never negative buf.push(0); // Number of days in the interval; always 0 for time-of-day values. // https://mariadb.com/kb/en/resultset-row/#teimstamp-binary-encoding buf.extend_from_slice(&[0_u8; 4]); - encode_time(self, len > 9, buf); + encode_time(self, len > 8, buf); Ok(IsNull::No) } fn size_hint(&self) -> usize { - if self.nanosecond() == 0 { - // if micro_seconds is 0, length is 8 and micro_seconds is not sent - 9 - } else { - // otherwise length is 12 - 13 - } + time_encoded_len(self) as usize + 1 // plus length byte } } @@ -99,6 +93,7 @@ impl TryFrom for Time { return Err(format!("MySqlTime value out of range for `time::Time`: {time}").into()); } + #[allow(clippy::cast_possible_truncation)] Ok(Time::from_hms_micro( // `is_valid_time_of_day()` ensures this won't overflow time.hours() as u8, @@ -111,6 +106,8 @@ impl TryFrom for Time { impl From for time::Duration { fn from(time: MySqlTime) -> Self { + // `subsec_nanos()` is guaranteed to be between 0 and 10^9 + #[allow(clippy::cast_possible_wrap)] time::Duration::new(time.whole_seconds_signed(), time.subsec_nanos() as i32) } } @@ -191,32 +188,20 @@ impl Type for PrimitiveDateTime { impl Encode<'_, MySql> for PrimitiveDateTime { fn encode_by_ref(&self, buf: &mut Vec) -> Result { - let len = Encode::::size_hint(self) - 1; - buf.push(len as u8); + let len = primitive_dt_encoded_len(self); + buf.push(len); encode_date(&self.date(), buf)?; if len > 4 { - encode_time(&self.time(), len > 8, buf); + encode_time(&self.time(), len > 7, buf); } Ok(IsNull::No) } fn size_hint(&self) -> usize { - // to save space the packet can be compressed: - match (self.hour(), self.minute(), self.second(), self.nanosecond()) { - // if hour, minutes, seconds and micro_seconds are all 0, - // length is 4 and no other field is sent - (0, 0, 0, 0) => 5, - - // if micro_seconds is 0, length is 7 - // and 
micro_seconds is not sent - (_, _, _, 0) => 8, - - // otherwise length is 11 - (_, _, _, _) => 12, - } + primitive_dt_encoded_len(self) as usize + 1 // plus length byte } } @@ -316,6 +301,37 @@ fn decode_time(mut buf: &[u8]) -> Result { 0 }; - Time::from_hms_micro(hour, minute, seconds, micros as u32) + let micros = u32::try_from(micros) + .map_err(|_| format!("MySQL returned microseconds out of range: {micros}"))?; + + Time::from_hms_micro(hour, minute, seconds, micros) .map_err(|e| format!("Time out of range for MySQL: {e}").into()) } + +#[inline(always)] +fn primitive_dt_encoded_len(time: &PrimitiveDateTime) -> u8 { + // to save space the packet can be compressed: + match (time.hour(), time.minute(), time.second(), time.nanosecond()) { + // if hour, minutes, seconds and micro_seconds are all 0, + // length is 4 and no other field is sent + (0, 0, 0, 0) => 4, + + // if micro_seconds is 0, length is 7 + // and micro_seconds is not sent + (_, _, _, 0) => 7, + + // otherwise length is 11 + (_, _, _, _) => 11, + } +} + +#[inline(always)] +fn time_encoded_len(time: &Time) -> u8 { + if time.nanosecond() == 0 { + // if micro_seconds is 0, length is 8 and micro_seconds is not sent + 8 + } else { + // otherwise length is 12 + 12 + } +} From 0d9f2c81a698a8b5ad74c4552bb90382ddf27e7c Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Tue, 20 Aug 2024 01:26:53 -0700 Subject: [PATCH 22/40] chore: configure clippy cast lints at workspace level --- Cargo.toml | 9 +++++++++ sqlx-bench/Cargo.toml | 3 +++ sqlx-cli/Cargo.toml | 3 +++ sqlx-core/Cargo.toml | 3 +++ sqlx-core/src/lib.rs | 4 ---- sqlx-macros/Cargo.toml | 3 +++ sqlx-mysql/Cargo.toml | 3 +++ sqlx-postgres/Cargo.toml | 3 +++ sqlx-postgres/src/lib.rs | 4 ---- sqlx-sqlite/Cargo.toml | 3 +++ sqlx-test/Cargo.toml | 3 +++ 11 files changed, 33 insertions(+), 8 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index b508f4dde4..fad4f9b1f9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -187,6 +187,15 @@ criterion = { version = 
"0.5.1", features = ["async_tokio"] } # Enable testing with SQLCipher if specifically requested. libsqlite3-sys = { version = "0.30.1", features = ["bundled-sqlcipher"] } +# Common lint settings for the workspace +[workspace.lints.clippy] +# https://github.com/launchbadge/sqlx/issues/3440 +cast_possible_truncation = 'deny' +cast_possible_wrap = 'deny' +cast_sign_loss = 'deny' +# See `clippy.toml` +disallowed_methods = 'deny' + # # Any # diff --git a/sqlx-bench/Cargo.toml b/sqlx-bench/Cargo.toml index a2028e40f4..0aa9532034 100644 --- a/sqlx-bench/Cargo.toml +++ b/sqlx-bench/Cargo.toml @@ -42,3 +42,6 @@ required-features = ["postgres"] name = "sqlite_fetch_all" harness = false required-features = ["sqlite"] + +[lints] +workspace = true diff --git a/sqlx-cli/Cargo.toml b/sqlx-cli/Cargo.toml index 1ae27836fe..6ddc71de7c 100644 --- a/sqlx-cli/Cargo.toml +++ b/sqlx-cli/Cargo.toml @@ -68,3 +68,6 @@ completions = ["dep:clap_complete"] [dev-dependencies] assert_cmd = "2.0.11" tempfile = "3.10.1" + +[lints] +workspace = true diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index 5d1198bc9b..60f9573aae 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -94,3 +94,6 @@ hashbrown = "0.14.5" [dev-dependencies] sqlx = { workspace = true, features = ["postgres", "sqlite", "mysql", "migrate", "macros", "time", "uuid"] } tokio = { version = "1", features = ["rt"] } + +[lints] +workspace = true diff --git a/sqlx-core/src/lib.rs b/sqlx-core/src/lib.rs index 8636760401..df4b2cc27d 100644 --- a/sqlx-core/src/lib.rs +++ b/sqlx-core/src/lib.rs @@ -15,10 +15,6 @@ #![recursion_limit = "512"] #![warn(future_incompatible, rust_2018_idioms)] #![allow(clippy::needless_doctest_main, clippy::type_complexity)] -// See `clippy.toml` at the workspace root -#![deny(clippy::disallowed_methods)] -#![deny(clippy::cast_possible_truncation)] -#![deny(clippy::cast_possible_wrap)] // The only unsafe code in SQLx is that necessary to interact with native APIs like with SQLite, // and 
that can live in its own separate driver crate. #![forbid(unsafe_code)] diff --git a/sqlx-macros/Cargo.toml b/sqlx-macros/Cargo.toml index 813a00b46d..b34c812309 100644 --- a/sqlx-macros/Cargo.toml +++ b/sqlx-macros/Cargo.toml @@ -49,3 +49,6 @@ sqlx-macros-core = { workspace = true } proc-macro2 = { version = "1.0.36", default-features = false } syn = { version = "2.0.52", default-features = false, features = ["parsing", "proc-macro"] } quote = { version = "1.0.26", default-features = false } + +[lints] +workspace = true diff --git a/sqlx-mysql/Cargo.toml b/sqlx-mysql/Cargo.toml index 174b529141..493562c750 100644 --- a/sqlx-mysql/Cargo.toml +++ b/sqlx-mysql/Cargo.toml @@ -71,3 +71,6 @@ tracing = { version = "0.1.37", features = ["log"] } whoami = { version = "1.2.1", default-features = false } serde = { version = "1.0.144", optional = true } + +[lints] +workspace = true diff --git a/sqlx-postgres/Cargo.toml b/sqlx-postgres/Cargo.toml index 6534592d27..55a94eceb1 100644 --- a/sqlx-postgres/Cargo.toml +++ b/sqlx-postgres/Cargo.toml @@ -83,3 +83,6 @@ features = ["postgres", "derive"] [target.'cfg(target_os = "windows")'.dependencies] etcetera = "0.8.0" + +[lints] +workspace = true diff --git a/sqlx-postgres/src/lib.rs b/sqlx-postgres/src/lib.rs index 2bfc30d88e..c50f53067e 100644 --- a/sqlx-postgres/src/lib.rs +++ b/sqlx-postgres/src/lib.rs @@ -1,8 +1,4 @@ //! **PostgreSQL** database driver. 
-// https://github.com/launchbadge/sqlx/issues/3440 -#![deny(clippy::cast_possible_truncation)] -#![deny(clippy::cast_possible_wrap)] -#![deny(clippy::cast_sign_loss)] #[macro_use] extern crate sqlx_core; diff --git a/sqlx-sqlite/Cargo.toml b/sqlx-sqlite/Cargo.toml index c530d578fe..6b4923534d 100644 --- a/sqlx-sqlite/Cargo.toml +++ b/sqlx-sqlite/Cargo.toml @@ -60,3 +60,6 @@ workspace = true [dev-dependencies] sqlx = { workspace = true, default-features = false, features = ["macros", "runtime-tokio", "tls-none"] } + +[lints] +workspace = true diff --git a/sqlx-test/Cargo.toml b/sqlx-test/Cargo.toml index 32708596e9..af76d5562f 100644 --- a/sqlx-test/Cargo.toml +++ b/sqlx-test/Cargo.toml @@ -9,3 +9,6 @@ sqlx = { default-features = false, path = ".." } env_logger = "0.11" dotenvy = "0.15.0" anyhow = "1.0.26" + +[lints] +workspace = true From 8d63ec7bc421eea54747e8053057ef147b1a649e Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Tue, 20 Aug 2024 02:22:27 -0700 Subject: [PATCH 23/40] fix(sqlite): forward optional features correctly --- sqlx-sqlite/Cargo.toml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sqlx-sqlite/Cargo.toml b/sqlx-sqlite/Cargo.toml index 6b4923534d..80a03ca0ba 100644 --- a/sqlx-sqlite/Cargo.toml +++ b/sqlx-sqlite/Cargo.toml @@ -16,7 +16,11 @@ json = ["sqlx-core/json", "serde"] offline = ["sqlx-core/offline", "serde"] migrate = ["sqlx-core/migrate"] -chrono = ["dep:chrono"] +# Type integrations +chrono = ["dep:chrono", "sqlx-core/chrono"] +time = ["dep:time", "sqlx-core/time"] +uuid = ["dep:uuid", "sqlx-core/uuid"] + regexp = ["dep:regex"] [dependencies] From d121704b8d9733fcbb351780355080520e53e2f1 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Tue, 20 Aug 2024 02:24:36 -0700 Subject: [PATCH 24/40] fix: use same fix for the same cast in `Migrate::apply()` everywhere --- sqlx-mysql/src/migrate.rs | 5 ++--- sqlx-sqlite/src/migrate.rs | 1 + 2 files changed, 3 insertions(+), 3 deletions(-) diff --git 
a/sqlx-mysql/src/migrate.rs b/sqlx-mysql/src/migrate.rs index c595ca2cb4..79b55ace3c 100644 --- a/sqlx-mysql/src/migrate.rs +++ b/sqlx-mysql/src/migrate.rs @@ -224,6 +224,7 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( let elapsed = start.elapsed(); + #[allow(clippy::cast_possible_truncation)] let _ = query( r#" UPDATE _sqlx_migrations @@ -231,9 +232,7 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( WHERE version = ? "#, ) - // Unlikely unless the execution time exceeds ~292 years, - // then we're probably okay losing that information. - .bind(i64::try_from(elapsed.as_nanos()).ok()) + .bind(elapsed.as_nanos() as i64) .bind(migration.version) .execute(self) .await?; diff --git a/sqlx-sqlite/src/migrate.rs b/sqlx-sqlite/src/migrate.rs index ac434996b1..b9ce22dccd 100644 --- a/sqlx-sqlite/src/migrate.rs +++ b/sqlx-sqlite/src/migrate.rs @@ -168,6 +168,7 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( let elapsed = start.elapsed(); // language=SQL + #[allow(clippy::cast_possible_truncation)] let _ = query( r#" UPDATE _sqlx_migrations From 445b7895ad113ff2e2528079023809b53dd0e134 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Tue, 20 Aug 2024 03:04:49 -0700 Subject: [PATCH 25/40] fix(sqlite): audit for bad casts --- sqlx-sqlite/src/connection/collation.rs | 16 ++++- sqlx-sqlite/src/connection/explain.rs | 7 +++ sqlx-sqlite/src/connection/intmap.rs | 7 +++ sqlx-sqlite/src/logger.rs | 8 +++ sqlx-sqlite/src/statement/handle.rs | 72 ++++++++++++++-------- sqlx-sqlite/src/statement/unlock_notify.rs | 3 +- sqlx-sqlite/src/statement/virtual.rs | 8 ++- sqlx-sqlite/src/types/chrono.rs | 16 ++++- sqlx-sqlite/src/types/float.rs | 2 + sqlx-sqlite/src/value.rs | 8 ++- 10 files changed, 114 insertions(+), 33 deletions(-) diff --git a/sqlx-sqlite/src/connection/collation.rs b/sqlx-sqlite/src/connection/collation.rs index 8cffda84c5..573a9af892 100644 --- a/sqlx-sqlite/src/connection/collation.rs +++ b/sqlx-sqlite/src/connection/collation.rs @@ -127,13 +127,23 @@ where C: 
Fn(&str, &str) -> Ordering, { let boxed_f: *mut C = data as *mut C; - debug_assert!(!boxed_f.is_null()); + + // Note: unwinding is now caught at the FFI boundary: + // https://doc.rust-lang.org/nomicon/ffi.html#ffi-and-unwinding + assert!(!boxed_f.is_null()); + + let left_len = + usize::try_from(left_len).unwrap_or_else(|_| panic!("left_len out of range: {left_len}")); + + let right_len = usize::try_from(right_len) + .unwrap_or_else(|_| panic!("right_len out of range: {right_len}")); + let s1 = { - let c_slice = slice::from_raw_parts(left_ptr as *const u8, left_len as usize); + let c_slice = slice::from_raw_parts(left_ptr as *const u8, left_len); from_utf8_unchecked(c_slice) }; let s2 = { - let c_slice = slice::from_raw_parts(right_ptr as *const u8, right_len as usize); + let c_slice = slice::from_raw_parts(right_ptr as *const u8, right_len); from_utf8_unchecked(c_slice) }; let t = (*boxed_f)(s1, s2); diff --git a/sqlx-sqlite/src/connection/explain.rs b/sqlx-sqlite/src/connection/explain.rs index a18cd58a53..89762d171f 100644 --- a/sqlx-sqlite/src/connection/explain.rs +++ b/sqlx-sqlite/src/connection/explain.rs @@ -1,3 +1,10 @@ +// Bad casts in this module SHOULD NOT result in a SQL injection +// https://github.com/launchbadge/sqlx/issues/3440 +#![allow( + clippy::cast_possible_truncation, + clippy::cast_possible_wrap, + clippy::cast_sign_loss +)] use crate::connection::intmap::IntMap; use crate::connection::{execute, ConnectionState}; use crate::error::Error; diff --git a/sqlx-sqlite/src/connection/intmap.rs b/sqlx-sqlite/src/connection/intmap.rs index 05a27ba9d8..dc09162f64 100644 --- a/sqlx-sqlite/src/connection/intmap.rs +++ b/sqlx-sqlite/src/connection/intmap.rs @@ -1,3 +1,10 @@ +// Bad casts in this module SHOULD NOT result in a SQL injection +// https://github.com/launchbadge/sqlx/issues/3440 +#![allow( + clippy::cast_possible_truncation, + clippy::cast_possible_wrap, + clippy::cast_sign_loss +)] use std::cmp::Ordering; use std::{fmt::Debug, hash::Hash}; 
diff --git a/sqlx-sqlite/src/logger.rs b/sqlx-sqlite/src/logger.rs index a3de1374e3..40fabd48ed 100644 --- a/sqlx-sqlite/src/logger.rs +++ b/sqlx-sqlite/src/logger.rs @@ -1,3 +1,11 @@ +// Bad casts in this module SHOULD NOT result in a SQL injection +// https://github.com/launchbadge/sqlx/issues/3440 +#![allow( + clippy::cast_possible_truncation, + clippy::cast_possible_wrap, + clippy::cast_sign_loss +)] + use crate::connection::intmap::IntMap; use std::collections::HashSet; use std::fmt::Debug; diff --git a/sqlx-sqlite/src/statement/handle.rs b/sqlx-sqlite/src/statement/handle.rs index e1a7ab3de1..2925d1a199 100644 --- a/sqlx-sqlite/src/statement/handle.rs +++ b/sqlx-sqlite/src/statement/handle.rs @@ -34,6 +34,23 @@ pub(crate) struct StatementHandle(NonNull); unsafe impl Send for StatementHandle {} +macro_rules! expect_ret_valid { + ($fn_name:ident($($args:tt)*)) => {{ + let val = $fn_name($($args)*); + + TryFrom::try_from(val) + // This likely means UB in SQLite itself or our usage of it; + // signed integer overflow is UB in the C standard. + .unwrap_or_else(|_| panic!("{}() returned invalid value: {val:?}", stringify!($fn_name))) + }} +} + +macro_rules! check_col_idx { + ($idx:ident) => { + c_int::try_from($idx).unwrap_or_else(|_| panic!("invalid column index: {}", $idx)) + }; +} + // might use some of this later #[allow(dead_code)] impl StatementHandle { @@ -71,7 +88,7 @@ impl StatementHandle { #[inline] pub(crate) fn column_count(&self) -> usize { // https://sqlite.org/c3ref/column_count.html - unsafe { sqlite3_column_count(self.0.as_ptr()) as usize } + unsafe { expect_ret_valid!(sqlite3_column_count(self.0.as_ptr())) } } #[inline] @@ -79,14 +96,14 @@ impl StatementHandle { // returns the number of changes of the *last* statement; not // necessarily this statement. 
// https://sqlite.org/c3ref/changes.html - unsafe { sqlite3_changes(self.db_handle()) as u64 } + unsafe { expect_ret_valid!(sqlite3_changes(self.db_handle())) } } #[inline] pub(crate) fn column_name(&self, index: usize) -> &str { // https://sqlite.org/c3ref/column_name.html unsafe { - let name = sqlite3_column_name(self.0.as_ptr(), index as c_int); + let name = sqlite3_column_name(self.0.as_ptr(), check_col_idx!(index)); debug_assert!(!name.is_null()); from_utf8_unchecked(CStr::from_ptr(name).to_bytes()) @@ -107,7 +124,7 @@ impl StatementHandle { #[inline] pub(crate) fn column_decltype(&self, index: usize) -> Option { unsafe { - let decl = sqlite3_column_decltype(self.0.as_ptr(), index as c_int); + let decl = sqlite3_column_decltype(self.0.as_ptr(), check_col_idx!(index)); if decl.is_null() { // If the Nth column of the result set is an expression or subquery, // then a NULL pointer is returned. @@ -123,6 +140,8 @@ impl StatementHandle { pub(crate) fn column_nullable(&self, index: usize) -> Result, Error> { unsafe { + let index = check_col_idx!(index); + // https://sqlite.org/c3ref/column_database_name.html // // ### Note @@ -130,9 +149,9 @@ impl StatementHandle { // sqlite3_finalize() or until the statement is automatically reprepared by the // first call to sqlite3_step() for a particular run or until the same information // is requested again in a different encoding. 
- let db_name = sqlite3_column_database_name(self.0.as_ptr(), index as c_int); - let table_name = sqlite3_column_table_name(self.0.as_ptr(), index as c_int); - let origin_name = sqlite3_column_origin_name(self.0.as_ptr(), index as c_int); + let db_name = sqlite3_column_database_name(self.0.as_ptr(), index); + let table_name = sqlite3_column_table_name(self.0.as_ptr(), index); + let origin_name = sqlite3_column_origin_name(self.0.as_ptr(), index); if db_name.is_null() || table_name.is_null() || origin_name.is_null() { return Ok(None); @@ -174,7 +193,7 @@ impl StatementHandle { #[inline] pub(crate) fn bind_parameter_count(&self) -> usize { // https://www.sqlite.org/c3ref/bind_parameter_count.html - unsafe { sqlite3_bind_parameter_count(self.0.as_ptr()) as usize } + unsafe { expect_ret_valid!(sqlite3_bind_parameter_count(self.0.as_ptr())) } } // Name Of A Host Parameter @@ -183,7 +202,7 @@ impl StatementHandle { pub(crate) fn bind_parameter_name(&self, index: usize) -> Option<&str> { unsafe { // https://www.sqlite.org/c3ref/bind_parameter_name.html - let name = sqlite3_bind_parameter_name(self.0.as_ptr(), index as c_int); + let name = sqlite3_bind_parameter_name(self.0.as_ptr(), check_col_idx!(index)); if name.is_null() { return None; } @@ -200,7 +219,7 @@ impl StatementHandle { unsafe { sqlite3_bind_blob64( self.0.as_ptr(), - index as c_int, + check_col_idx!(index), v.as_ptr() as *const c_void, v.len() as u64, SQLITE_TRANSIENT(), @@ -210,36 +229,39 @@ impl StatementHandle { #[inline] pub(crate) fn bind_text(&self, index: usize, v: &str) -> c_int { + #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] + let encoding = SQLITE_UTF8 as u8; + unsafe { sqlite3_bind_text64( self.0.as_ptr(), - index as c_int, + check_col_idx!(index), v.as_ptr() as *const c_char, v.len() as u64, SQLITE_TRANSIENT(), - SQLITE_UTF8 as u8, + encoding, ) } } #[inline] pub(crate) fn bind_int(&self, index: usize, v: i32) -> c_int { - unsafe { sqlite3_bind_int(self.0.as_ptr(), index as 
c_int, v as c_int) } + unsafe { sqlite3_bind_int(self.0.as_ptr(), check_col_idx!(index), v as c_int) } } #[inline] pub(crate) fn bind_int64(&self, index: usize, v: i64) -> c_int { - unsafe { sqlite3_bind_int64(self.0.as_ptr(), index as c_int, v) } + unsafe { sqlite3_bind_int64(self.0.as_ptr(), check_col_idx!(index), v) } } #[inline] pub(crate) fn bind_double(&self, index: usize, v: f64) -> c_int { - unsafe { sqlite3_bind_double(self.0.as_ptr(), index as c_int, v) } + unsafe { sqlite3_bind_double(self.0.as_ptr(), check_col_idx!(index), v) } } #[inline] pub(crate) fn bind_null(&self, index: usize) -> c_int { - unsafe { sqlite3_bind_null(self.0.as_ptr(), index as c_int) } + unsafe { sqlite3_bind_null(self.0.as_ptr(), check_col_idx!(index)) } } // result values from the query @@ -247,39 +269,41 @@ impl StatementHandle { #[inline] pub(crate) fn column_type(&self, index: usize) -> c_int { - unsafe { sqlite3_column_type(self.0.as_ptr(), index as c_int) } + unsafe { sqlite3_column_type(self.0.as_ptr(), check_col_idx!(index)) } } #[inline] pub(crate) fn column_int(&self, index: usize) -> i32 { - unsafe { sqlite3_column_int(self.0.as_ptr(), index as c_int) as i32 } + unsafe { sqlite3_column_int(self.0.as_ptr(), check_col_idx!(index)) as i32 } } #[inline] pub(crate) fn column_int64(&self, index: usize) -> i64 { - unsafe { sqlite3_column_int64(self.0.as_ptr(), index as c_int) as i64 } + unsafe { sqlite3_column_int64(self.0.as_ptr(), check_col_idx!(index)) as i64 } } #[inline] pub(crate) fn column_double(&self, index: usize) -> f64 { - unsafe { sqlite3_column_double(self.0.as_ptr(), index as c_int) } + unsafe { sqlite3_column_double(self.0.as_ptr(), check_col_idx!(index)) } } #[inline] pub(crate) fn column_value(&self, index: usize) -> *mut sqlite3_value { - unsafe { sqlite3_column_value(self.0.as_ptr(), index as c_int) } + unsafe { sqlite3_column_value(self.0.as_ptr(), check_col_idx!(index)) } } pub(crate) fn column_blob(&self, index: usize) -> &[u8] { - let index = index as 
c_int; - let len = unsafe { sqlite3_column_bytes(self.0.as_ptr(), index) } as usize; + let len = unsafe { + expect_ret_valid!(sqlite3_column_bytes(self.0.as_ptr(), check_col_idx!(index))) + }; if len == 0 { // empty blobs are NULL so just return an empty slice return &[]; } - let ptr = unsafe { sqlite3_column_blob(self.0.as_ptr(), index) } as *const u8; + let ptr = + unsafe { sqlite3_column_blob(self.0.as_ptr(), check_col_idx!(index)) } as *const u8; debug_assert!(!ptr.is_null()); unsafe { from_raw_parts(ptr, len) } diff --git a/sqlx-sqlite/src/statement/unlock_notify.rs b/sqlx-sqlite/src/statement/unlock_notify.rs index b7e723a3f3..5821c23ae3 100644 --- a/sqlx-sqlite/src/statement/unlock_notify.rs +++ b/sqlx-sqlite/src/statement/unlock_notify.rs @@ -27,7 +27,8 @@ pub unsafe fn wait(conn: *mut sqlite3) -> Result<(), SqliteError> { unsafe extern "C" fn unlock_notify_cb(ptr: *mut *mut c_void, len: c_int) { let ptr = ptr as *mut &Notify; - let slice = slice::from_raw_parts(ptr, len as usize); + // We don't have a choice; we can't panic and unwind into FFI here. 
+ let slice = slice::from_raw_parts(ptr, usize::try_from(len).unwrap_or(0)); for notify in slice { notify.fire(); diff --git a/sqlx-sqlite/src/statement/virtual.rs b/sqlx-sqlite/src/statement/virtual.rs index 3c17428912..6be980c36a 100644 --- a/sqlx-sqlite/src/statement/virtual.rs +++ b/sqlx-sqlite/src/statement/virtual.rs @@ -163,7 +163,13 @@ fn prepare( let mut tail: *const c_char = null(); let query_ptr = query.as_ptr() as *const c_char; - let query_len = query.len() as i32; + let query_len = i32::try_from(query.len()).map_err(|_| { + err_protocol!( + "query string too large for SQLite3 API ({} bytes); \ + try breaking it into smaller chunks (< 2 GiB), executed separately", + query.len() + ) + })?; // let status = unsafe { diff --git a/sqlx-sqlite/src/types/chrono.rs b/sqlx-sqlite/src/types/chrono.rs index c491a9aa66..7424720444 100644 --- a/sqlx-sqlite/src/types/chrono.rs +++ b/sqlx-sqlite/src/types/chrono.rs @@ -167,10 +167,20 @@ fn decode_datetime_from_float(value: f64) -> Option> { let epoch_in_julian_days = 2_440_587.5; let seconds_in_day = 86400.0; let timestamp = (value - epoch_in_julian_days) * seconds_in_day; - let seconds = timestamp as i64; - let nanos = (timestamp.fract() * 1E9) as u32; - Utc.fix().timestamp_opt(seconds, nanos).single() + if !timestamp.is_finite() { + return None; + } + + // We don't really have a choice but to do lossy casts for this conversion + // We checked above if the value is infinite or NaN which could otherwise cause problems + #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] + { + let seconds = timestamp.trunc() as i64; + let nanos = (timestamp.fract() * 1E9).abs() as u32; + + Utc.fix().timestamp_opt(seconds, nanos).single() + } } impl<'r> Decode<'r, Sqlite> for NaiveDateTime { diff --git a/sqlx-sqlite/src/types/float.rs b/sqlx-sqlite/src/types/float.rs index 499a694242..79224f5451 100644 --- a/sqlx-sqlite/src/types/float.rs +++ b/sqlx-sqlite/src/types/float.rs @@ -24,6 +24,8 @@ impl<'q> Encode<'q, Sqlite> 
for f32 { impl<'r> Decode<'r, Sqlite> for f32 { fn decode(value: SqliteValueRef<'r>) -> Result { + // Truncation is intentional + #[allow(clippy::cast_possible_truncation)] Ok(value.double() as f32) } } diff --git a/sqlx-sqlite/src/value.rs b/sqlx-sqlite/src/value.rs index 1a4d8898a4..967b3f7476 100644 --- a/sqlx-sqlite/src/value.rs +++ b/sqlx-sqlite/src/value.rs @@ -120,7 +120,13 @@ impl SqliteValue { } fn blob(&self) -> &[u8] { - let len = unsafe { sqlite3_value_bytes(self.handle.0.as_ptr()) } as usize; + let len = unsafe { sqlite3_value_bytes(self.handle.0.as_ptr()) }; + + // This likely means UB in SQLite itself or our usage of it; + // signed integer overflow is UB in the C standard. + let len = usize::try_from(len).unwrap_or_else(|_| { + panic!("sqlite3_value_bytes() returned value out of range for usize: {len}") + }); if len == 0 { // empty blobs are NULL so just return an empty slice From 2e9ba07cb60f8691313de908a46319764cab1555 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Wed, 21 Aug 2024 13:00:10 -0700 Subject: [PATCH 26/40] fix(postgres): dead code `StatementId::UNNAMED` --- sqlx-postgres/src/io/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlx-postgres/src/io/mod.rs b/sqlx-postgres/src/io/mod.rs index f90db85d93..df064b1e2c 100644 --- a/sqlx-postgres/src/io/mod.rs +++ b/sqlx-postgres/src/io/mod.rs @@ -10,7 +10,6 @@ pub(crate) use sqlx_core::io::*; #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub(crate) struct StatementId(IdInner); -#[allow(dead_code)] #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub(crate) struct PortalId(IdInner); @@ -18,6 +17,7 @@ pub(crate) struct PortalId(IdInner); struct IdInner(Option); impl StatementId { + #[allow(dead_code)] pub const UNNAMED: Self = Self(IdInner::UNNAMED); pub const NAMED_START: Self = Self(IdInner::NAMED_START); From d3e75e7852000afcc278e56c86d4d5491fccd80c Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Wed, 21 Aug 2024 13:17:17 -0700 Subject: [PATCH 27/40] 
fix(postgres): decode `PgDatabaseError` for `ErrorResponse` --- sqlx-postgres/src/connection/stream.rs | 2 +- sqlx-postgres/src/error.rs | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/sqlx-postgres/src/connection/stream.rs b/sqlx-postgres/src/connection/stream.rs index a7c7d1aea0..7817399925 100644 --- a/sqlx-postgres/src/connection/stream.rs +++ b/sqlx-postgres/src/connection/stream.rs @@ -98,7 +98,7 @@ impl PgStream { match message.format { BackendMessageFormat::ErrorResponse => { // An error returned from the database server. - return Err(PgDatabaseError(message.decode()?).into()); + return Err(message.decode::()?.into()); } BackendMessageFormat::NotificationResponse => { diff --git a/sqlx-postgres/src/error.rs b/sqlx-postgres/src/error.rs index b9df865736..db8bcc8a10 100644 --- a/sqlx-postgres/src/error.rs +++ b/sqlx-postgres/src/error.rs @@ -3,10 +3,10 @@ use std::fmt::{self, Debug, Display, Formatter}; use atoi::atoi; use smallvec::alloc::borrow::Cow; - +use sqlx_core::bytes::Bytes; pub(crate) use sqlx_core::error::*; -use crate::message::{Notice, PgSeverity}; +use crate::message::{BackendMessage, BackendMessageFormat, Notice, PgSeverity}; /// An error returned from the PostgreSQL database. pub struct PgDatabaseError(pub(crate) Notice); @@ -219,6 +219,16 @@ impl DatabaseError for PgDatabaseError { } } +// ErrorResponse is the same structure as NoticeResponse but a different format code. +impl BackendMessage for PgDatabaseError { + const FORMAT: BackendMessageFormat = BackendMessageFormat::ErrorResponse; + + #[inline(always)] + fn decode_body(buf: Bytes) -> std::result::Result { + Ok(Self(Notice::decode_body(buf)?)) + } +} + /// For reference: pub(crate) mod error_codes { /// Caused when a unique or primary key is violated. 
From 8db2055ed88c2ccfe461143d8cc143832a7bfde9 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Wed, 21 Aug 2024 13:50:48 -0700 Subject: [PATCH 28/40] chore(postgres): include nullables query in error --- sqlx-core/src/error.rs | 17 +++++++++++------ sqlx-postgres/src/connection/describe.rs | 8 +++++++- sqlx-postgres/src/io/buf_mut.rs | 2 +- 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/sqlx-core/src/error.rs b/sqlx-core/src/error.rs index 042342ef9f..17774addd2 100644 --- a/sqlx-core/src/error.rs +++ b/sqlx-core/src/error.rs @@ -312,11 +312,16 @@ impl From for Error { /// Format an error message as a `Protocol` error #[macro_export] macro_rules! err_protocol { - ($expr:expr) => { - $crate::error::Error::Protocol($expr.into()) - }; - - ($fmt:expr, $($arg:tt)*) => { - $crate::error::Error::Protocol(format!($fmt, $($arg)*)) + ($($fmt_args:tt)*) => { + $crate::error::Error::Protocol( + format!( + "{} ({}:{})", + // Note: the format string needs to be unmodified (e.g. by `concat!()`) + // for implicit formatting arguments to work + format_args!($($fmt_args)*), + module_path!(), + line!(), + ) + ) }; } diff --git a/sqlx-postgres/src/connection/describe.rs b/sqlx-postgres/src/connection/describe.rs index d9c55201a0..9d532a5178 100644 --- a/sqlx-postgres/src/connection/describe.rs +++ b/sqlx-postgres/src/connection/describe.rs @@ -466,7 +466,13 @@ WHERE rngtypid = $1 let mut nullables: Vec> = nullable_query .build_query_scalar() .fetch_all(&mut *self) - .await?; + .await + .map_err(|e| { + err_protocol!( + "error from nullables query: {e}; query: {:?}", + nullable_query.sql() + ) + })?; // If the server is CockroachDB or Materialize, skip this step (#1248). 
if !self.stream.parameter_statuses.contains_key("crdb_version") diff --git a/sqlx-postgres/src/io/buf_mut.rs b/sqlx-postgres/src/io/buf_mut.rs index ff6fe03df3..eea9d34acd 100644 --- a/sqlx-postgres/src/io/buf_mut.rs +++ b/sqlx-postgres/src/io/buf_mut.rs @@ -27,7 +27,7 @@ impl PgBufMutExt for Vec { let size_result = write_result.and_then(|_| { let size = self.len() - offset; i32::try_from(size) - .map_err(|_| err_protocol!("message size out of range for Pg protocol: {size")) + .map_err(|_| err_protocol!("message size out of range for protocol: {size}")) }); match size_result { From e731cfd3d48397579d40a8b5d8b359fa185866f0 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Wed, 21 Aug 2024 14:04:21 -0700 Subject: [PATCH 29/40] fix(postgres): syntax error in nullables query --- sqlx-postgres/src/connection/describe.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlx-postgres/src/connection/describe.rs b/sqlx-postgres/src/connection/describe.rs index 9d532a5178..145616fec1 100644 --- a/sqlx-postgres/src/connection/describe.rs +++ b/sqlx-postgres/src/connection/describe.rs @@ -451,7 +451,7 @@ WHERE rngtypid = $1 .push_unseparated("::int4"); tuple .push_bind(column.relation_attribute_no) - .push_bind_unseparated("::int2"); + .push_unseparated("::int2"); }); nullable_query.push( From 37f53cc7e92421fe5277399ec6219b898047a556 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Wed, 21 Aug 2024 14:49:21 -0700 Subject: [PATCH 30/40] fix(postgres): syntax error in EXPLAIN query --- sqlx-postgres/src/connection/describe.rs | 6 ++- sqlx-postgres/src/io/buf_mut.rs | 10 +--- sqlx-postgres/src/io/mod.rs | 64 +++++++++++++++++------- 3 files changed, 52 insertions(+), 28 deletions(-) diff --git a/sqlx-postgres/src/connection/describe.rs b/sqlx-postgres/src/connection/describe.rs index 145616fec1..9a46a202d5 100644 --- a/sqlx-postgres/src/connection/describe.rs +++ b/sqlx-postgres/src/connection/describe.rs @@ -500,7 +500,11 @@ WHERE rngtypid = $1 stmt_id: 
StatementId, params_len: usize, ) -> Result>, Error> { - let mut explain = format!("EXPLAIN (VERBOSE, FORMAT JSON) EXECUTE {stmt_id}"); + let stmt_id_display = stmt_id + .display() + .ok_or_else(|| err_protocol!("cannot EXPLAIN unnamed statement: {stmt_id:?}"))?; + + let mut explain = format!("EXPLAIN (VERBOSE, FORMAT JSON) EXECUTE {stmt_id_display}"); let mut comma = false; if params_len > 0 { diff --git a/sqlx-postgres/src/io/buf_mut.rs b/sqlx-postgres/src/io/buf_mut.rs index eea9d34acd..0fe3809b57 100644 --- a/sqlx-postgres/src/io/buf_mut.rs +++ b/sqlx-postgres/src/io/buf_mut.rs @@ -47,18 +47,12 @@ impl PgBufMutExt for Vec { // writes a statement name by ID #[inline] fn put_statement_name(&mut self, id: StatementId) { - let _: Result<(), ()> = id.write_name(|s| { - self.extend_from_slice(s.as_bytes()); - Ok(()) - }); + id.put_name_with_nul(self); } // writes a portal name by ID #[inline] fn put_portal_name(&mut self, id: PortalId) { - let _: Result<(), ()> = id.write_name(|s| { - self.extend_from_slice(s.as_bytes()); - Ok(()) - }); + id.put_name_with_nul(self); } } diff --git a/sqlx-postgres/src/io/mod.rs b/sqlx-postgres/src/io/mod.rs index df064b1e2c..72f2a978c8 100644 --- a/sqlx-postgres/src/io/mod.rs +++ b/sqlx-postgres/src/io/mod.rs @@ -16,6 +16,11 @@ pub(crate) struct PortalId(IdInner); #[derive(Debug, Copy, Clone, PartialEq, Eq)] struct IdInner(Option); +pub(crate) struct DisplayId { + prefix: &'static str, + id: NonZeroU32, +} + impl StatementId { #[allow(dead_code)] pub const UNNAMED: Self = Self(IdInner::UNNAMED); @@ -35,16 +40,22 @@ impl StatementId { self.0.name_len(Self::NAME_PREFIX) } - // There's no common trait implemented by `Formatter` and `Vec` for this purpose; - // we're deliberately avoiding the formatting machinery because it's known to be slow. - pub fn write_name(&self, write: impl FnMut(&str) -> Result<(), E>) -> Result<(), E> { - self.0.write_name(Self::NAME_PREFIX, write) + /// Get a type to format this statement ID with [`Display`]. 
+ /// + /// Returns `None` if this is the unnamed statement. + #[inline(always)] + pub fn display(&self) -> Option { + self.0.display(Self::NAME_PREFIX) + } + + pub fn put_name_with_nul(&self, buf: &mut Vec) { + self.0.put_name_with_nul(Self::NAME_PREFIX, buf) } } -impl Display for StatementId { +impl Display for DisplayId { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - self.write_name(|s| f.write_str(s)) + write!(f, "{}{}", self.prefix, self.id) } } @@ -67,13 +78,13 @@ impl PortalId { Self(self.0.next()) } - /// Calculate the number of bytes that will be written by [`Self::write_name()`]. + /// Calculate the number of bytes that will be written by [`Self::put_name_with_nul()`]. pub fn name_len(&self) -> Saturating { self.0.name_len(Self::NAME_PREFIX) } - pub fn write_name(&self, write: impl FnMut(&str) -> Result<(), E>) -> Result<(), E> { - self.0.write_name(Self::NAME_PREFIX, write) + pub fn put_name_with_nul(&self, buf: &mut Vec) { + self.0.put_name_with_nul(Self::NAME_PREFIX, buf) } } @@ -93,6 +104,11 @@ impl IdInner { ) } + #[inline(always)] + fn display(&self, prefix: &'static str) -> Option { + self.0.map(|id| DisplayId { prefix, id }) + } + #[inline(always)] fn name_len(&self, name_prefix: &str) -> Saturating { let mut len = Saturating(0); @@ -113,18 +129,28 @@ impl IdInner { } #[inline(always)] - fn write_name( - &self, - name_prefix: &str, - mut write: impl FnMut(&str) -> Result<(), E>, - ) -> Result<(), E> { + fn put_name_with_nul(&self, name_prefix: &str, buf: &mut Vec) { if let Some(id) = self.0 { - write(name_prefix)?; - write(itoa::Buffer::new().format(id.get()))?; + buf.extend_from_slice(name_prefix.as_bytes()); + buf.extend_from_slice(itoa::Buffer::new().format(id.get()).as_bytes()); } - write("\0")?; - - Ok(()) + buf.push(0); } } + +#[test] +fn statement_id_display_matches_encoding() { + const EXPECTED_STR: &str = "sqlx_s_1234567890"; + const EXPECTED_BYTES: &[u8] = b"sqlx_s_1234567890\0"; + + let mut bytes = Vec::new(); + + 
StatementId::TEST_VAL.put_name_with_nul(&mut bytes); + + assert_eq!(bytes, EXPECTED_BYTES); + + let str = StatementId::TEST_VAL.display().unwrap().to_string(); + + assert_eq!(str, EXPECTED_STR); +} From 59f5cd0bd74518971b0fe6c356418b4b7bf82638 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Wed, 21 Aug 2024 15:12:26 -0700 Subject: [PATCH 31/40] fix(mysql): correct `Capabilities` assertions in unit tests --- .../src/protocol/connect/auth_switch.rs | 11 ++++++++--- sqlx-mysql/src/protocol/connect/handshake.rs | 19 ++++++++++--------- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/sqlx-mysql/src/protocol/connect/auth_switch.rs b/sqlx-mysql/src/protocol/connect/auth_switch.rs index 7ecb1b2b17..e61d26d793 100644 --- a/sqlx-mysql/src/protocol/connect/auth_switch.rs +++ b/sqlx-mysql/src/protocol/connect/auth_switch.rs @@ -81,9 +81,14 @@ fn test_decode_auth_switch_cleartext_disabled() { let e = AuthSwitchRequest::decode_with(AUTH_SWITCH_CLEARTEXT.into(), false).unwrap_err(); - assert_eq!( - e.to_string(), - "encountered unexpected or invalid data: mysql_cleartext_plugin disabled" + let e_str = e.to_string(); + + let expected = "encountered unexpected or invalid data: mysql_cleartext_plugin disabled"; + + assert!( + // Don't want to assert the full string since it contains the module path now. + e_str.starts_with(expected), + "expected error string to start with {expected:?}, got {e_str:?}" ); } diff --git a/sqlx-mysql/src/protocol/connect/handshake.rs b/sqlx-mysql/src/protocol/connect/handshake.rs index 84afe74ea6..3fef521652 100644 --- a/sqlx-mysql/src/protocol/connect/handshake.rs +++ b/sqlx-mysql/src/protocol/connect/handshake.rs @@ -94,11 +94,12 @@ impl ProtocolDecode<'_> for Handshake { fn test_decode_handshake_mysql_8_0_18() { const HANDSHAKE_MYSQL_8_0_18: &[u8] = b"\n8.0.18\x00\x19\x00\x00\x00\x114aB0c\x06g\x00\xff\xff\xff\x02\x00\xff\xc7\x15\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00tL\x03s\x0f[4\rl4. 
\x00caching_sha2_password\x00"; - let mut p = Handshake::decode(HANDSHAKE_MYSQL_8_0_18.into()).unwrap(); + let p = Handshake::decode(HANDSHAKE_MYSQL_8_0_18.into()).unwrap(); assert_eq!(p.protocol_version, 10); - p.server_capabilities.toggle( + assert_eq!( + p.server_capabilities, Capabilities::MYSQL | Capabilities::FOUND_ROWS | Capabilities::LONG_FLAG @@ -128,8 +129,6 @@ fn test_decode_handshake_mysql_8_0_18() { | Capabilities::REMEMBER_OPTIONS, ); - assert!(p.server_capabilities.is_empty()); - assert_eq!(p.server_default_collation, 255); assert!(p.status.contains(Status::SERVER_STATUS_AUTOCOMMIT)); @@ -148,7 +147,7 @@ fn test_decode_handshake_mysql_8_0_18() { fn test_decode_handshake_mariadb_10_4_7() { const HANDSHAKE_MARIA_DB_10_4_7: &[u8] = b"\n5.5.5-10.4.7-MariaDB-1:10.4.7+maria~bionic\x00\x0b\x00\x00\x00t6L\\j\"dS\x00\xfe\xf7\x08\x02\x00\xff\x81\x15\x00\x00\x00\x00\x00\x00\x07\x00\x00\x00U14Oph9\" Date: Thu, 22 Aug 2024 15:01:20 -0700 Subject: [PATCH 32/40] fix(mysql): add `sqlx` as a dev-dependency for doctests --- .github/workflows/sqlx.yml | 2 ++ Cargo.lock | 1 + sqlx-mysql/Cargo.toml | 3 +++ 3 files changed, 6 insertions(+) diff --git a/.github/workflows/sqlx.yml b/.github/workflows/sqlx.yml index a7d93f5e6c..599b90c33a 100644 --- a/.github/workflows/sqlx.yml +++ b/.github/workflows/sqlx.yml @@ -103,10 +103,12 @@ jobs: -p sqlx-macros-core --all-features + # Note: use `--lib` to not run integration tests that require a DB - name: Test sqlx run: > cargo test -p sqlx + --lib --all-features sqlite: diff --git a/Cargo.lock b/Cargo.lock index 3f711abedd..1a437b0519 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3639,6 +3639,7 @@ dependencies = [ "sha1", "sha2", "smallvec", + "sqlx", "sqlx-core", "stringprep", "thiserror", diff --git a/sqlx-mysql/Cargo.toml b/sqlx-mysql/Cargo.toml index 493562c750..a904bc0eef 100644 --- a/sqlx-mysql/Cargo.toml +++ b/sqlx-mysql/Cargo.toml @@ -72,5 +72,8 @@ whoami = { version = "1.2.1", default-features = false } serde = { version = 
"1.0.144", optional = true } +[dev-dependencies] +sqlx = { workspace = true, features = ["mysql"] } + [lints] workspace = true From 98fe7ca2844e801e89ca968ed6e9fe8dc6289e58 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Thu, 22 Aug 2024 16:23:35 -0700 Subject: [PATCH 33/40] fix(mysql): fix doctests --- sqlx-mysql/src/options/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sqlx-mysql/src/options/mod.rs b/sqlx-mysql/src/options/mod.rs index 7881fe8ca2..db2b20c19d 100644 --- a/sqlx-mysql/src/options/mod.rs +++ b/sqlx-mysql/src/options/mod.rs @@ -53,9 +53,9 @@ pub use ssl_mode::MySqlSslMode; /// /// // Change the log verbosity level for queries. /// // Information about SQL queries is logged at `DEBUG` level by default. -/// opts.log_statements(log::LevelFilter::Trace); +/// opts = opts.log_statements(log::LevelFilter::Trace); /// -/// let pool = MySqlPool::connect_with(&opts).await?; +/// let pool = MySqlPool::connect_with(opts).await?; /// # Ok(()) /// # } /// ``` From 0c1ff60336523ce854ef8d2f77e436d649e4c893 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Thu, 22 Aug 2024 17:20:22 -0700 Subject: [PATCH 34/40] fix(sqlite): fix unit and doctests --- sqlx-sqlite/Cargo.toml | 2 +- sqlx-sqlite/src/connection/explain.rs | 102 +++++++++++++------------- sqlx-sqlite/src/options/parse.rs | 36 ++++++--- sqlx-sqlite/src/regexp.rs | 2 +- 4 files changed, 80 insertions(+), 62 deletions(-) diff --git a/sqlx-sqlite/Cargo.toml b/sqlx-sqlite/Cargo.toml index 80a03ca0ba..1ad87de102 100644 --- a/sqlx-sqlite/Cargo.toml +++ b/sqlx-sqlite/Cargo.toml @@ -63,7 +63,7 @@ features = [ workspace = true [dev-dependencies] -sqlx = { workspace = true, default-features = false, features = ["macros", "runtime-tokio", "tls-none"] } +sqlx = { workspace = true, default-features = false, features = ["macros", "runtime-tokio", "tls-none", "sqlite"] } [lints] workspace = true diff --git a/sqlx-sqlite/src/connection/explain.rs b/sqlx-sqlite/src/connection/explain.rs 
index 89762d171f..bfa66aa12f 100644 --- a/sqlx-sqlite/src/connection/explain.rs +++ b/sqlx-sqlite/src/connection/explain.rs @@ -1633,147 +1633,147 @@ fn test_root_block_columns_has_types() { { let table_db_block = table_block_nums["t"]; assert_eq!( - ColumnType::Single { + Some(&ColumnType::Single { datatype: DataType::Integer, nullable: Some(true) //sqlite primary key columns are nullable unless declared not null - }, - root_block_cols[&table_db_block][&0] + }), + root_block_cols[&table_db_block].get(&0) ); assert_eq!( - ColumnType::Single { + Some(&ColumnType::Single { datatype: DataType::Text, nullable: Some(true) - }, - root_block_cols[&table_db_block][&1] + }), + root_block_cols[&table_db_block].get(&1) ); assert_eq!( - ColumnType::Single { + Some(&ColumnType::Single { datatype: DataType::Text, nullable: Some(false) - }, - root_block_cols[&table_db_block][&2] + }), + root_block_cols[&table_db_block].get(&2) ); } { let table_db_block = table_block_nums["i1"]; assert_eq!( - ColumnType::Single { + Some(&ColumnType::Single { datatype: DataType::Integer, nullable: Some(true) //sqlite primary key columns are nullable unless declared not null - }, - root_block_cols[&table_db_block][&0] + }), + root_block_cols[&table_db_block].get(&0) ); assert_eq!( - ColumnType::Single { + Some(&ColumnType::Single { datatype: DataType::Text, nullable: Some(true) - }, - root_block_cols[&table_db_block][&1] + }), + root_block_cols[&table_db_block].get(&1) ); } { let table_db_block = table_block_nums["i2"]; assert_eq!( - ColumnType::Single { + Some(&ColumnType::Single { datatype: DataType::Integer, nullable: Some(true) //sqlite primary key columns are nullable unless declared not null - }, - root_block_cols[&table_db_block][&0] + }), + root_block_cols[&table_db_block].get(&0) ); assert_eq!( - ColumnType::Single { + Some(&ColumnType::Single { datatype: DataType::Text, nullable: Some(true) - }, - root_block_cols[&table_db_block][&1] + }), + root_block_cols[&table_db_block].get(&1) ); } { 
let table_db_block = table_block_nums["t2"]; assert_eq!( - ColumnType::Single { + Some(&ColumnType::Single { datatype: DataType::Integer, nullable: Some(false) - }, - root_block_cols[&table_db_block][&0] + }), + root_block_cols[&table_db_block].get(&0) ); assert_eq!( - ColumnType::Single { + Some(&ColumnType::Single { datatype: DataType::Null, nullable: Some(true) - }, - root_block_cols[&table_db_block][&1] + }), + root_block_cols[&table_db_block].get(&1) ); assert_eq!( - ColumnType::Single { + Some(&ColumnType::Single { datatype: DataType::Null, nullable: Some(false) - }, - root_block_cols[&table_db_block][&2] + }), + root_block_cols[&table_db_block].get(&2) ); } { let table_db_block = table_block_nums["t2i1"]; assert_eq!( - ColumnType::Single { + Some(&ColumnType::Single { datatype: DataType::Integer, nullable: Some(false) - }, - root_block_cols[&table_db_block][&0] + }), + root_block_cols[&table_db_block].get(&0) ); assert_eq!( - ColumnType::Single { + Some(&ColumnType::Single { datatype: DataType::Null, nullable: Some(true) - }, - root_block_cols[&table_db_block][&1] + }), + root_block_cols[&table_db_block].get(&1) ); } { let table_db_block = table_block_nums["t2i2"]; assert_eq!( - ColumnType::Single { + Some(&ColumnType::Single { datatype: DataType::Integer, nullable: Some(false) - }, - root_block_cols[&table_db_block][&0] + }), + root_block_cols[&table_db_block].get(&0) ); assert_eq!( - ColumnType::Single { + Some(&ColumnType::Single { datatype: DataType::Null, nullable: Some(false) - }, - root_block_cols[&table_db_block][&1] + }), + root_block_cols[&table_db_block].get(&1) ); } { let table_db_block = table_block_nums["t3"]; assert_eq!( - ColumnType::Single { + Some(&ColumnType::Single { datatype: DataType::Text, nullable: Some(true) - }, - root_block_cols[&table_db_block][&0] + }), + root_block_cols[&table_db_block].get(&0) ); assert_eq!( - ColumnType::Single { + Some(&ColumnType::Single { datatype: DataType::Float, nullable: Some(false) - }, - 
root_block_cols[&table_db_block][&1] + }), + root_block_cols[&table_db_block].get(&1) ); assert_eq!( - ColumnType::Single { + Some(&ColumnType::Single { datatype: DataType::Float, nullable: Some(true) - }, - root_block_cols[&table_db_block][&2] + }), + root_block_cols[&table_db_block].get(&2) ); } } diff --git a/sqlx-sqlite/src/options/parse.rs b/sqlx-sqlite/src/options/parse.rs index f06cf0c65c..0530f4204c 100644 --- a/sqlx-sqlite/src/options/parse.rs +++ b/sqlx-sqlite/src/options/parse.rs @@ -1,12 +1,14 @@ -use crate::error::Error; -use crate::SqliteConnectOptions; -use percent_encoding::{percent_decode_str, utf8_percent_encode, NON_ALPHANUMERIC}; use std::borrow::Cow; use std::path::{Path, PathBuf}; use std::str::FromStr; use std::sync::atomic::{AtomicUsize, Ordering}; + +use percent_encoding::{percent_decode_str, percent_encode, AsciiSet}; use url::Url; +use crate::error::Error; +use crate::SqliteConnectOptions; + // https://www.sqlite.org/uri.html static IN_MEMORY_DB_SEQ: AtomicUsize = AtomicUsize::new(0); @@ -114,10 +116,25 @@ impl SqliteConnectOptions { } pub(crate) fn build_url(&self) -> Url { - let filename = - utf8_percent_encode(&self.filename.to_string_lossy(), NON_ALPHANUMERIC).to_string(); - let mut url = - Url::parse(&format!("sqlite://{}", filename)).expect("BUG: generated un-parseable URL"); + // https://url.spec.whatwg.org/#path-percent-encode-set + static PATH_ENCODE_SET: AsciiSet = percent_encoding::CONTROLS + .add(b' ') + .add(b'"') + .add(b'#') + .add(b'<') + .add(b'>') + .add(b'?') + .add(b'`') + .add(b'{') + .add(b'}'); + + let filename_encoded = percent_encode( + self.filename.as_os_str().as_encoded_bytes(), + &PATH_ENCODE_SET, + ); + + let mut url = Url::parse(&format!("sqlite://{filename_encoded}")) + .expect("BUG: generated un-parseable URL"); let mode = match (self.in_memory, self.create_if_missing, self.read_only) { (true, _, _) => "memory", @@ -133,8 +150,9 @@ impl SqliteConnectOptions { }; url.query_pairs_mut().append_pair("cache", 
cache); - url.query_pairs_mut() - .append_pair("immutable", &self.immutable.to_string()); + if self.immutable { + url.query_pairs_mut().append_pair("immutable", "true"); + } if let Some(vfs) = &self.vfs { url.query_pairs_mut().append_pair("vfs", vfs); diff --git a/sqlx-sqlite/src/regexp.rs b/sqlx-sqlite/src/regexp.rs index ee19482ee1..eb14fffc77 100644 --- a/sqlx-sqlite/src/regexp.rs +++ b/sqlx-sqlite/src/regexp.rs @@ -170,7 +170,7 @@ unsafe extern "C" fn cleanup_arc_regex_pointer(ptr: *mut std::ffi::c_void) { #[cfg(test)] mod tests { - use sqlx::{ConnectOptions, Connection, Row}; + use sqlx::{ConnectOptions, Row}; use std::str::FromStr; async fn test_db() -> crate::SqliteConnection { From 5889f76436e35f45406b18d810b0355b8f398f6a Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Thu, 22 Aug 2024 17:48:53 -0700 Subject: [PATCH 35/40] fix(postgres): fix missing inversion on `PgNumeric::is_valid_digit()` --- sqlx-postgres/src/types/numeric.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sqlx-postgres/src/types/numeric.rs b/sqlx-postgres/src/types/numeric.rs index 3a01f2e621..67713d7694 100644 --- a/sqlx-postgres/src/types/numeric.rs +++ b/sqlx-postgres/src/types/numeric.rs @@ -151,8 +151,8 @@ impl PgNumeric { buf.extend(&scale.to_be_bytes()); for (i, &digit) in digits.iter().enumerate() { - if Self::is_valid_digit(digit) { - return Err(format!("{i}th PgNumeric digit out of range {digit}")); + if !Self::is_valid_digit(digit) { + return Err(format!("{i}th PgNumeric digit out of range: {digit}")); } buf.extend(&digit.to_be_bytes()); From 6aaec279d54ec05e7f4b5c08706c6944dc17e1fc Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Fri, 23 Aug 2024 21:58:16 -0700 Subject: [PATCH 36/40] fix(mysql): don't use an arbitrary `cfg` for one test --- sqlx-mysql/src/error.rs | 20 ++++++++++++++++++++ tests/mysql/error.rs | 1 - 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/sqlx-mysql/src/error.rs b/sqlx-mysql/src/error.rs index 
89c6df47ca..e7363399d3 100644 --- a/sqlx-mysql/src/error.rs +++ b/sqlx-mysql/src/error.rs @@ -100,6 +100,16 @@ impl DatabaseError for MySqlDatabaseError { error_codes::ER_CHECK_CONSTRAINT_VIOLATED => ErrorKind::CheckViolation, + // https://mariadb.com/kb/en/e4025/ + error_codes::mariadb::ER_CONSTRAINT_FAILED + // MySQL uses this code for a completely different error, + // but we can differentiate by SQLSTATE: + // <https://dev.mysql.com/doc/mysql-errors/8.4/en/server-error-reference.html> + if self.0.sql_state.as_deref() == Some("23000") => + { + ErrorKind::CheckViolation + } + _ => ErrorKind::Other, } } @@ -154,4 +164,14 @@ pub(crate) mod error_codes { /// /// Only available after 8.0.16. pub const ER_CHECK_CONSTRAINT_VIOLATED: u16 = 3819; + + pub(crate) mod mariadb { /// Error code emitted by MariaDB for constraint errors: /// <https://mariadb.com/kb/en/e4025/> + /// MySQL emits this code for a completely different error: + /// <https://dev.mysql.com/doc/mysql-errors/8.4/en/server-error-reference.html> + /// + /// You should also check that SQLSTATE is `23000`. + pub const ER_CONSTRAINT_FAILED: u16 = 4025; + } } diff --git a/tests/mysql/error.rs b/tests/mysql/error.rs index bb174e5e8d..7c84266c32 100644 --- a/tests/mysql/error.rs +++ b/tests/mysql/error.rs @@ -57,7 +57,6 @@ async fn it_fails_with_not_null_violation() -> anyhow::Result<()> { Ok(()) } -#[cfg(mysql_8)] #[sqlx_macros::test] async fn it_fails_with_check_violation() -> anyhow::Result<()> { let mut conn = new::<MySql>().await?; From 985392eefb8e6edad0a71a914c2f725fb45e0ee6 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Fri, 23 Aug 2024 22:03:19 -0700 Subject: [PATCH 37/40] fix(postgres): use checked decrement on `pending_ready_for_query_count` --- sqlx-postgres/src/connection/mod.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sqlx-postgres/src/connection/mod.rs b/sqlx-postgres/src/connection/mod.rs index 9003dcb338..5a6a597ead 100644 --- a/sqlx-postgres/src/connection/mod.rs +++ b/sqlx-postgres/src/connection/mod.rs @@ -102,7 +102,11 @@ impl PgConnection { #[inline(always)] fn handle_ready_for_query(&mut self, message: ReceivedMessage) -> Result<(), Error> { - self.pending_ready_for_query_count -= 1; + 
self.pending_ready_for_query_count = self + .pending_ready_for_query_count + .checked_sub(1) + .ok_or_else(|| err_protocol!("received more ReadyForQuery messages than expected"))?; + self.transaction_status = message.decode::<ReadyForQuery>()?.transaction_status; Ok(()) From f9e51769e905cbc1be54122b20f12b025ee7eaf0 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Fri, 23 Aug 2024 01:25:25 -0700 Subject: [PATCH 38/40] chore(postgres): create regression test for RUSTSEC-2024-0363 --- .github/workflows/sqlx.yml | 4 +- Cargo.toml | 5 + tests/postgres/fixtures/rustsec/2024_0363.sql | 4 + tests/postgres/rustsec.rs | 146 ++++++++++++++++++ 4 files changed, 157 insertions(+), 2 deletions(-) create mode 100644 tests/postgres/fixtures/rustsec/2024_0363.sql create mode 100644 tests/postgres/rustsec.rs diff --git a/.github/workflows/sqlx.yml b/.github/workflows/sqlx.yml index 599b90c33a..508f036eba 100644 --- a/.github/workflows/sqlx.yml +++ b/.github/workflows/sqlx.yml @@ -204,7 +204,7 @@ jobs: - run: > cargo test --no-default-features - --features any,postgres,macros,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + --features any,postgres,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: DATABASE_URL: postgres://postgres:password@localhost:5432/sqlx SQLX_OFFLINE_DIR: .sqlx @@ -216,7 +216,7 @@ jobs: run: > cargo test --no-default-features - --features any,postgres,macros,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + --features any,postgres,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: DATABASE_URL: postgres://postgres:password@localhost:5432/sqlx?sslmode=verify-ca&sslrootcert=.%2Ftests%2Fcerts%2Fca.crt SQLX_OFFLINE_DIR: .sqlx diff --git a/Cargo.toml b/Cargo.toml index fad4f9b1f9..3345aac5f7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -372,3 +372,8 @@ required-features = ["postgres", "macros", "migrate"] name = "postgres-query-builder" path = 
"tests/postgres/query_builder.rs" required-features = ["postgres"] + +[[test]] +name = "postgres-rustsec" +path = "tests/postgres/rustsec.rs" +required-features = ["postgres", "macros", "migrate"] diff --git a/tests/postgres/fixtures/rustsec/2024_0363.sql b/tests/postgres/fixtures/rustsec/2024_0363.sql new file mode 100644 index 0000000000..c3bb5b0920 --- /dev/null +++ b/tests/postgres/fixtures/rustsec/2024_0363.sql @@ -0,0 +1,4 @@ +-- https://rustsec.org/advisories/RUSTSEC-2024-0363.html +-- https://github.com/launchbadge/sqlx/issues/3440 +CREATE TABLE injection_target(id BIGSERIAL PRIMARY KEY, message TEXT); +INSERT INTO injection_target(message) VALUES ('existing message'); diff --git a/tests/postgres/rustsec.rs b/tests/postgres/rustsec.rs new file mode 100644 index 0000000000..45fd76b9db --- /dev/null +++ b/tests/postgres/rustsec.rs @@ -0,0 +1,146 @@ +use sqlx::{Error, PgPool}; + +use std::{cmp, str}; + +// https://rustsec.org/advisories/RUSTSEC-2024-0363.html +#[sqlx::test(migrations = false, fixtures("./fixtures/rustsec/2024_0363.sql"))] +async fn rustsec_2024_0363(pool: PgPool) -> anyhow::Result<()> { + let overflow_len = 4 * 1024 * 1024 * 1024; // 4 GiB + + // These three strings concatenated together will be the first query the Postgres backend "sees" + // + // Rather contrived because this already represents an injection vulnerability, + // but it's easier to demonstrate the bug with a simple `Query` message + // than the `Prepare` -> `Bind` -> `Execute` flow. + let real_query_prefix = "INSERT INTO injection_target(message) VALUES ('"; + let fake_message = "fake_msg') RETURNING id;\0"; + let real_query_suffix = "') RETURNING id"; + + // Our payload is another simple `Query` message + let real_payload = + "Q\0\0\0\x4DUPDATE injection_target SET message = 'you''ve been pwned!' 
WHERE id = 1\0"; + + // This is the value we want the length prefix to overflow to (including the length of the prefix itself) + // This will leave the backend's buffer pointing at our real payload. + let fake_payload_len = real_query_prefix.len() + fake_message.len() + 4; + + // Pretty easy to see that this should overflow to `fake_payload_len` + let target_payload_len = overflow_len + fake_payload_len; + + // This is the length we expect `injected_value` to be + let expected_inject_len = target_payload_len + - 4 // Length prefix + - real_query_prefix.len() + - (real_query_suffix.len() + 1 /* NUL terminator */); + + let pad_to_len = expected_inject_len - 5; // Header for FLUSH message that eats `real_query_suffix` (see below) + + let expected_payload_len = 4 // length prefix + + real_query_prefix.len() + + expected_inject_len + + real_query_suffix.len() + + 1; // NUL terminator + + let expected_wrapped_len = expected_payload_len % overflow_len; + assert_eq!(expected_wrapped_len, fake_payload_len); + + // This will be the string we inject into the query. + let mut injected_value = String::with_capacity(expected_inject_len); + + injected_value.push_str(fake_message); + injected_value.push_str(real_payload); + + // The Postgres backend reads the `FLUSH` message but ignores its contents. + // This gives us a variable-length NOP that lets us pad to the length we want, + // as well as a way to eat `real_query_suffix` without breaking the connection. + let flush_fill = "\0".repeat(9996); + + let flush_fmt_code = 'H'; // note: 'F' is `FunctionCall`. + + 'outer: while injected_value.len() < pad_to_len { + let remaining_len = pad_to_len - injected_value.len(); + + // The max length of a FLUSH message is 10,000, including the length prefix. + let flush_len = cmp::min( + remaining_len - 1, // minus format code + 10000, + ); + + // We need `flush_len` to be valid UTF-8 when encoded in big-endian + // in order to push it to the string. 
+ // + // Not every value is going to be valid though, so we search for one that is. + 'inner: for flush_len in (4..=flush_len).rev() { + let flush_len_be = (flush_len as i32).to_be_bytes(); + + let Ok(flush_len_str) = str::from_utf8(&flush_len_be) else { + continue 'inner; + }; + + let fill_len = flush_len - 4; + + injected_value.push(flush_fmt_code); + injected_value.push_str(flush_len_str); + injected_value.push_str(&flush_fill[..fill_len]); + + continue 'outer; + } + + panic!("unable to find a valid encoding/split for {flush_len}"); + } + + assert_eq!(injected_value.len(), pad_to_len); + + // The amount of data the last FLUSH message has to eat + let eat_len = real_query_suffix.len() + 1; // plus NUL terminator + + // Push the FLUSH message that will eat `real_query_suffix` + injected_value.push(flush_fmt_code); + injected_value.push_str(str::from_utf8(&(eat_len as i32).to_be_bytes()).unwrap()); + // The value will be in the buffer already. + + assert_eq!(expected_inject_len, injected_value.len()); + + let query = format!("{real_query_prefix}{injected_value}{real_query_suffix}"); + + // The length of the `Query` message we've created + let final_payload_len = 4 // length prefix + + query.len() + + 1; // NUL terminator + + assert_eq!(expected_payload_len, final_payload_len); + + let wrapped_len = final_payload_len % overflow_len; + + assert_eq!(wrapped_len, fake_payload_len); + + let res = sqlx::raw_sql(&query) + // Note: the connection may hang afterward + // because `pending_ready_for_query_count` will underflow. + .execute(&pool) + .await; + + if let Err(e) = res { + // Connection rejected the query; we're happy. 
+ if matches!(e, Error::Protocol(_)) { + return Ok(()); + } + + panic!("unexpected error: {e:?}"); + } + + let messages: Vec<String> = + sqlx::query_scalar("SELECT message FROM injection_target ORDER BY id") + .fetch_all(&pool) + .await?; + + // If the injection succeeds, `messages` will look like: + // ["you've been pwned!".to_string(), "fake_msg".to_string()] + assert_eq!( + messages, + ["existing message".to_string(), "fake_msg".to_string()] + ); + + // Injection didn't affect our database; we're happy. + Ok(()) +} From 127d6174f9a96b2fddb2f1809e800643d8bf2ae7 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Fri, 23 Aug 2024 18:27:36 -0700 Subject: [PATCH 39/40] chore(sqlite): create regression test for RUSTSEC-2024-0363 --- Cargo.toml | 5 +++ tests/sqlite/rustsec.rs | 78 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+) create mode 100644 tests/sqlite/rustsec.rs diff --git a/Cargo.toml b/Cargo.toml index 3345aac5f7..58c66388d9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -273,6 +273,11 @@ name = "sqlite-migrate" path = "tests/sqlite/migrate.rs" required-features = ["sqlite", "macros", "migrate"] +[[test]] +name = "sqlite-rustsec" +path = "tests/sqlite/rustsec.rs" +required-features = ["sqlite"] + [[bench]] name = "sqlite-describe" path = "benches/sqlite/describe.rs" diff --git a/tests/sqlite/rustsec.rs b/tests/sqlite/rustsec.rs new file mode 100644 index 0000000000..3ff9c524fa --- /dev/null +++ b/tests/sqlite/rustsec.rs @@ -0,0 +1,78 @@ +use sqlx::{Connection, Error, SqliteConnection}; + +// https://rustsec.org/advisories/RUSTSEC-2024-0363.html +// +// Similar theory to the Postgres exploit in `tests/postgres/rustsec.rs` but much simpler +// since we just want to overflow the query length itself. 
+#[sqlx::test] +async fn rustsec_2024_0363() -> anyhow::Result<()> { + let overflow_len = 4 * 1024 * 1024 * 1024; // 4 GiB + + // `real_query_prefix` plus `fake_message` will be the first query that SQLite "sees" + // + // Rather contrived because this already represents a regular SQL injection, + // but this is the easiest way to demonstrate the exploit for SQLite. + let real_query_prefix = "INSERT INTO injection_target(message) VALUES ('"; + let fake_message = "fake_msg') RETURNING id;"; + let real_query_suffix = "') RETURNING id"; + + // Our actual payload is another query + let real_payload = + "\nUPDATE injection_target SET message = 'you''ve been pwned!' WHERE id = 1;\n--"; + + // This will parse the query up to `real_payload`. + let fake_payload_len = real_query_prefix.len() + fake_message.len(); + + // Pretty easy to see that this will overflow to `fake_payload_len` + let target_len = overflow_len + fake_payload_len; + + let inject_len = target_len - real_query_prefix.len() - real_query_suffix.len(); + + let pad_len = inject_len - fake_message.len() - real_payload.len(); + + let mut injected_value = String::with_capacity(inject_len); + injected_value.push_str(fake_message); + injected_value.push_str(real_payload); + + let padding = " ".repeat(pad_len); + injected_value.push_str(&padding); + + let query = format!("{real_query_prefix}{injected_value}{real_query_suffix}"); + + assert_eq!(query.len(), target_len); + + let mut conn = SqliteConnection::connect("sqlite://:memory:").await?; + + sqlx::raw_sql( + "CREATE TABLE injection_target(id INTEGER PRIMARY KEY, message TEXT);\n\ + INSERT INTO injection_target(message) VALUES ('existing message');", + ) + .execute(&mut conn) + .await?; + + let res = sqlx::raw_sql(&query).execute(&mut conn).await; + + if let Err(e) = res { + // Connection rejected the query; we're happy. 
+ if matches!(e, Error::Protocol(_)) { + return Ok(()); + } + + panic!("unexpected error: {e:?}"); + } + + let messages: Vec<String> = + sqlx::query_scalar("SELECT message FROM injection_target ORDER BY id") + .fetch_all(&mut conn) + .await?; + + // If the injection succeeds, `messages` will look like: + // ["you've been pwned!".to_string(), "fake_msg".to_string()] + assert_eq!( + messages, + ["existing message".to_string(), "fake_msg".to_string()] + ); + + // Injection didn't affect our database; we're happy. + Ok(()) +} From 7e48dbe3d3746a276d7cef9916241d482fde2d4a Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Fri, 23 Aug 2024 20:07:55 -0700 Subject: [PATCH 40/40] chore(mysql): create regression test for RUSTSEC-2024-0363 --- Cargo.toml | 5 ++++ tests/mysql/rustsec.rs | 67 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+) create mode 100644 tests/mysql/rustsec.rs diff --git a/Cargo.toml b/Cargo.toml index 58c66388d9..e50eb48b86 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -328,6 +328,11 @@ name = "mysql-migrate" path = "tests/mysql/migrate.rs" required-features = ["mysql", "macros", "migrate"] +[[test]] +name = "mysql-rustsec" +path = "tests/mysql/rustsec.rs" +required-features = ["mysql"] + # # PostgreSQL # diff --git a/tests/mysql/rustsec.rs b/tests/mysql/rustsec.rs new file mode 100644 index 0000000000..8d8db0c250 --- /dev/null +++ b/tests/mysql/rustsec.rs @@ -0,0 +1,67 @@ +use sqlx::{Error, MySql}; +use std::io; + +use sqlx_test::new; + +// https://rustsec.org/advisories/RUSTSEC-2024-0363.html +// +// During the audit the MySQL driver was found to be *unlikely* to be vulnerable to the exploit, +// so this just serves as a sanity check. +#[sqlx::test] +async fn rustsec_2024_0363() -> anyhow::Result<()> { + let overflow_len = 4 * 1024 * 1024 * 1024; // 4 GiB + + let padding = " ".repeat(overflow_len); + + let payload = "UPDATE injection_target SET message = 'you''ve been pwned!' 
 WHERE id = 1"; + + let mut injected_value = String::with_capacity(overflow_len + payload.len()); + + injected_value.push_str(&padding); + injected_value.push_str(payload); + + // Since this is so large, keeping it around until the end *can* lead to getting OOM-killed. + drop(padding); + + let mut conn = new::<MySql>().await?; + + sqlx::raw_sql( + "CREATE TEMPORARY TABLE injection_target(id INTEGER PRIMARY KEY AUTO_INCREMENT, message TEXT);\n\ + INSERT INTO injection_target(message) VALUES ('existing message');", + ) + .execute(&mut conn) + .await?; + + // We can't concatenate a query string together like the other tests + // because it would just demonstrate a regular old SQL injection. + let res = sqlx::query("INSERT INTO injection_target(message) VALUES (?)") + .bind(&injected_value) + .execute(&mut conn) + .await; + + if let Err(e) = res { + // Connection rejected the query; we're happy. + // + // Current observed behavior is that `mysqld` closes the connection before we're even done + // sending the message, giving us a "Broken pipe" error. + // + // As it turns out, MySQL has a tight limit on packet sizes (even after splitting) + // by default: https://dev.mysql.com/doc/refman/8.4/en/packet-too-large.html + if matches!(e, Error::Io(ref ioe) if ioe.kind() == io::ErrorKind::BrokenPipe) { + return Ok(()); + } + + panic!("unexpected error: {e:?}"); + } + + let messages: Vec<String> = + sqlx::query_scalar("SELECT message FROM injection_target ORDER BY id") + .fetch_all(&mut conn) + .await?; + + assert_eq!(messages[0], "existing message"); + assert_eq!(messages[1].len(), injected_value.len()); + + // Injection didn't affect our database; we're happy. + Ok(()) +}