diff --git a/Cargo.lock b/Cargo.lock index c112899415..a0eabbeddc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1987,6 +1987,12 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "ipnet" +version = "2.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddc24109865250148c2e0f3d25d4f0f479571723792d3802153c60922a4fb708" + [[package]] name = "ipnetwork" version = "0.20.0" @@ -3580,6 +3586,7 @@ dependencies = [ "hashbrown 0.15.2", "hashlink", "indexmap 2.7.0", + "ipnet", "ipnetwork", "log", "mac_address", @@ -3836,6 +3843,7 @@ dependencies = [ "hkdf", "hmac", "home", + "ipnet", "ipnetwork", "itoa", "log", diff --git a/Cargo.toml b/Cargo.toml index f93ed3dded..0ec0039db1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -68,6 +68,7 @@ _unstable-all-types = [ "json", "time", "chrono", + "ipnet", "ipnetwork", "mac_address", "uuid", @@ -116,6 +117,7 @@ json = ["sqlx-macros?/json", "sqlx-mysql?/json", "sqlx-postgres?/json", "sqlx-sq bigdecimal = ["sqlx-core/bigdecimal", "sqlx-macros?/bigdecimal", "sqlx-mysql?/bigdecimal", "sqlx-postgres?/bigdecimal"] bit-vec = ["sqlx-core/bit-vec", "sqlx-macros?/bit-vec", "sqlx-postgres?/bit-vec"] chrono = ["sqlx-core/chrono", "sqlx-macros?/chrono", "sqlx-mysql?/chrono", "sqlx-postgres?/chrono", "sqlx-sqlite?/chrono"] +ipnet = ["sqlx-core/ipnet", "sqlx-macros?/ipnet", "sqlx-postgres?/ipnet"] ipnetwork = ["sqlx-core/ipnetwork", "sqlx-macros?/ipnetwork", "sqlx-postgres?/ipnetwork"] mac_address = ["sqlx-core/mac_address", "sqlx-macros?/mac_address", "sqlx-postgres?/mac_address"] rust_decimal = ["sqlx-core/rust_decimal", "sqlx-macros?/rust_decimal", "sqlx-mysql?/rust_decimal", "sqlx-postgres?/rust_decimal"] @@ -142,6 +144,7 @@ sqlx = { version = "=0.8.3", path = ".", default-features = false } bigdecimal = "0.4.0" bit-vec = "0.6.3" chrono = { version = "0.4.34", default-features = false, features = ["std", "clock"] } +ipnet = "2.3.0" ipnetwork = "0.20.0" mac_address = "1.1.5" rust_decimal = { version = "1.26.1", default-features = false, features = ["std"] } diff --git a/README.md b/README.md index 15d68bbb42..fc8d3b7427 100644 --- a/README.md +++ b/README.md @@ -220,6 +220,8 @@ be removed in the future. - `rust_decimal`: Add support for `NUMERIC` using the `rust_decimal` crate. +- `ipnet`: Add support for `INET` and `CIDR` (in postgres) using the `ipnet` crate. + - `ipnetwork`: Add support for `INET` and `CIDR` (in postgres) using the `ipnetwork` crate. - `json`: Add support for `JSON` and `JSONB` (in postgres) using the `serde_json` crate. diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index f767507bb4..b8b7d0eda3 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -49,6 +49,7 @@ bit-vec = { workspace = true, optional = true } bigdecimal = { workspace = true, optional = true } rust_decimal = { workspace = true, optional = true } time = { workspace = true, optional = true } +ipnet = { workspace = true, optional = true } ipnetwork = { workspace = true, optional = true } mac_address = { workspace = true, optional = true } uuid = { workspace = true, optional = true } diff --git a/sqlx-core/src/types/mod.rs b/sqlx-core/src/types/mod.rs index 25837b1e77..d36bec49ed 100644 --- a/sqlx-core/src/types/mod.rs +++ b/sqlx-core/src/types/mod.rs @@ -67,6 +67,13 @@ pub use bigdecimal::BigDecimal; #[doc(no_inline)] pub use rust_decimal::Decimal; +#[cfg(feature = "ipnet")] +#[cfg_attr(docsrs, doc(cfg(feature = "ipnet")))] +pub mod ipnet { + #[doc(no_inline)] + pub use ipnet::{IpNet, Ipv4Net, Ipv6Net}; +} + #[cfg(feature = "ipnetwork")] #[cfg_attr(docsrs, doc(cfg(feature = "ipnetwork")))] pub mod ipnetwork { diff --git a/sqlx-macros-core/Cargo.toml b/sqlx-macros-core/Cargo.toml index 46786b7d8d..85efa80912 100644 --- a/sqlx-macros-core/Cargo.toml +++ b/sqlx-macros-core/Cargo.toml @@ -38,6 +38,7 @@ json = ["sqlx-core/json", "sqlx-mysql?/json", "sqlx-postgres?/json", "sqlx-sqlit bigdecimal = ["sqlx-core/bigdecimal", "sqlx-mysql?/bigdecimal", "sqlx-postgres?/bigdecimal"] bit-vec = ["sqlx-core/bit-vec", "sqlx-postgres?/bit-vec"] chrono = ["sqlx-core/chrono", "sqlx-mysql?/chrono", "sqlx-postgres?/chrono", "sqlx-sqlite?/chrono"] +ipnet = ["sqlx-core/ipnet", "sqlx-postgres?/ipnet"] ipnetwork = ["sqlx-core/ipnetwork", "sqlx-postgres?/ipnetwork"] mac_address = ["sqlx-core/mac_address", "sqlx-postgres?/mac_address"] rust_decimal = ["sqlx-core/rust_decimal", "sqlx-mysql?/rust_decimal", "sqlx-postgres?/rust_decimal"] diff --git a/sqlx-macros/Cargo.toml b/sqlx-macros/Cargo.toml index 5617d3f251..b513c3e808 100644 --- a/sqlx-macros/Cargo.toml +++ b/sqlx-macros/Cargo.toml @@ -37,6 +37,7 @@ sqlite-unbundled = ["sqlx-macros-core/sqlite-unbundled"] bigdecimal = ["sqlx-macros-core/bigdecimal"] bit-vec = ["sqlx-macros-core/bit-vec"] chrono = ["sqlx-macros-core/chrono"] +ipnet = ["sqlx-macros-core/ipnet"] ipnetwork = ["sqlx-macros-core/ipnetwork"] mac_address = ["sqlx-macros-core/mac_address"] rust_decimal = ["sqlx-macros-core/rust_decimal"] diff --git a/sqlx-postgres/Cargo.toml b/sqlx-postgres/Cargo.toml index 174a73b3fa..818aadbab7 100644 --- a/sqlx-postgres/Cargo.toml +++ b/sqlx-postgres/Cargo.toml @@ -19,6 +19,7 @@ offline = ["sqlx-core/offline"] bigdecimal = ["dep:bigdecimal", "dep:num-bigint", "sqlx-core/bigdecimal"] bit-vec = ["dep:bit-vec", "sqlx-core/bit-vec"] chrono = ["dep:chrono", "sqlx-core/chrono"] +ipnet = ["dep:ipnet", "sqlx-core/ipnet"] ipnetwork = ["dep:ipnetwork", "sqlx-core/ipnetwork"] mac_address = ["dep:mac_address", "sqlx-core/mac_address"] rust_decimal = ["dep:rust_decimal", "rust_decimal/maths", "sqlx-core/rust_decimal"] @@ -43,6 +44,7 @@ sha2 = { version = "0.10.0", default-features = false } bigdecimal = { workspace = true, optional = true } bit-vec = { workspace = true, optional = true } chrono = { workspace = true, optional = true } +ipnet = { workspace = true, optional = true } ipnetwork = { workspace = true, optional = true } mac_address = { workspace = true, optional = true } rust_decimal = { workspace = true, optional = true } diff --git a/sqlx-postgres/src/type_checking.rs b/sqlx-postgres/src/type_checking.rs index 314fda0b17..cf7d29f6ce 100644 --- a/sqlx-postgres/src/type_checking.rs +++ b/sqlx-postgres/src/type_checking.rs @@ -77,6 +77,9 @@ impl_type_checking!( #[cfg(feature = "rust_decimal")] sqlx::types::Decimal, + #[cfg(all(feature = "ipnet", not(feature = "ipnetwork")))] + sqlx::types::ipnet::IpNet, + #[cfg(feature = "ipnetwork")] sqlx::types::ipnetwork::IpNetwork, @@ -138,6 +141,9 @@ impl_type_checking!( #[cfg(feature = "rust_decimal")] Vec | &[sqlx::types::Decimal], + #[cfg(all(feature = "ipnet", not(feature = "ipnetwork")))] + Vec | &[sqlx::types::ipnet::IpNet], + #[cfg(feature = "ipnetwork")] Vec | &[sqlx::types::ipnetwork::IpNetwork], diff --git a/sqlx-postgres/src/types/ipnet/ipaddr.rs b/sqlx-postgres/src/types/ipnet/ipaddr.rs new file mode 100644 index 0000000000..b157eff3c6 --- /dev/null +++ b/sqlx-postgres/src/types/ipnet/ipaddr.rs @@ -0,0 +1,62 @@ +use std::net::IpAddr; + +use ipnet::IpNet; + +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueRef, Postgres}; + +impl Type for IpAddr +where + IpNet: Type, +{ + fn type_info() -> PgTypeInfo { + IpNet::type_info() + } + + fn compatible(ty: &PgTypeInfo) -> bool { + IpNet::compatible(ty) + } +} + +impl PgHasArrayType for IpAddr { + fn array_type_info() -> PgTypeInfo { + ::array_type_info() + } + + fn array_compatible(ty: &PgTypeInfo) -> bool { + ::array_compatible(ty) + } +} + +impl<'db> Encode<'db, Postgres> for IpAddr +where + IpNet: Encode<'db, Postgres>, +{ + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + IpNet::from(*self).encode_by_ref(buf) + } + + fn size_hint(&self) -> usize { + IpNet::from(*self).size_hint() + } +} + +impl<'db> Decode<'db, Postgres> for IpAddr +where + IpNet: Decode<'db, Postgres>, +{ + fn decode(value: PgValueRef<'db>) -> Result { + let ipnet = IpNet::decode(value)?; + + if matches!(ipnet, IpNet::V4(net) if net.prefix_len() != 32) + || matches!(ipnet, IpNet::V6(net) if net.prefix_len() != 128) + { + Err("lossy decode from inet/cidr")? + } + + Ok(ipnet.addr()) + } +} diff --git a/sqlx-postgres/src/types/ipnet/ipnet.rs b/sqlx-postgres/src/types/ipnet/ipnet.rs new file mode 100644 index 0000000000..1f986174b8 --- /dev/null +++ b/sqlx-postgres/src/types/ipnet/ipnet.rs @@ -0,0 +1,130 @@ +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; + +#[cfg(feature = "ipnet")] +use ipnet::{IpNet, Ipv4Net, Ipv6Net}; + +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; + +// https://github.com/postgres/postgres/blob/574925bfd0a8175f6e161936ea11d9695677ba09/src/include/utils/inet.h#L39 + +// Technically this is a magic number here but it doesn't make sense to drag in the whole of `libc` +// just for one constant. +const PGSQL_AF_INET: u8 = 2; // AF_INET +const PGSQL_AF_INET6: u8 = PGSQL_AF_INET + 1; + +impl Type for IpNet { + fn type_info() -> PgTypeInfo { + PgTypeInfo::INET + } + + fn compatible(ty: &PgTypeInfo) -> bool { + *ty == PgTypeInfo::CIDR || *ty == PgTypeInfo::INET + } +} + +impl PgHasArrayType for IpNet { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::INET_ARRAY + } + + fn array_compatible(ty: &PgTypeInfo) -> bool { + *ty == PgTypeInfo::CIDR_ARRAY || *ty == PgTypeInfo::INET_ARRAY + } +} + +impl Encode<'_, Postgres> for IpNet { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + // https://github.com/postgres/postgres/blob/574925bfd0a8175f6e161936ea11d9695677ba09/src/backend/utils/adt/network.c#L293 + // https://github.com/postgres/postgres/blob/574925bfd0a8175f6e161936ea11d9695677ba09/src/backend/utils/adt/network.c#L271 + + match self { + IpNet::V4(net) => { + buf.push(PGSQL_AF_INET); // ip_family + buf.push(net.prefix_len()); // ip_bits + buf.push(0); // is_cidr + buf.push(4); // nb (number of bytes) + buf.extend_from_slice(&net.addr().octets()) // address + } + + IpNet::V6(net) => { + buf.push(PGSQL_AF_INET6); // ip_family + buf.push(net.prefix_len()); // ip_bits + buf.push(0); // is_cidr + buf.push(16); // nb (number of bytes) + buf.extend_from_slice(&net.addr().octets()); // address + } + } + + Ok(IsNull::No) + } + + fn size_hint(&self) -> usize { + match self { + IpNet::V4(_) => 8, + IpNet::V6(_) => 20, + } + } +} + +impl Decode<'_, Postgres> for IpNet { + fn decode(value: PgValueRef<'_>) -> Result { + let bytes = match value.format() { + PgValueFormat::Binary => value.as_bytes()?, + PgValueFormat::Text => { + let s = value.as_str()?; + println!("{s}"); + if s.contains('/') { + return Ok(s.parse()?); + } + // IpNet::from_str doesn't handle conversion from IpAddr to IpNet + let addr: IpAddr = s.parse()?; + return Ok(addr.into()); + } + }; + + if bytes.len() >= 8 { + let family = bytes[0]; + let prefix = bytes[1]; + let _is_cidr = bytes[2] != 0; + let len = bytes[3]; + + match family { + PGSQL_AF_INET => { + if bytes.len() == 8 && len == 4 { + let inet = Ipv4Net::new( + Ipv4Addr::new(bytes[4], bytes[5], bytes[6], bytes[7]), + prefix, + )?; + + return Ok(IpNet::V4(inet)); + } + } + + PGSQL_AF_INET6 => { + if bytes.len() == 20 && len == 16 { + let inet = Ipv6Net::new( + Ipv6Addr::from([ + bytes[4], bytes[5], bytes[6], bytes[7], bytes[8], bytes[9], + bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15], + bytes[16], bytes[17], bytes[18], bytes[19], + ]), + prefix, + )?; + + return Ok(IpNet::V6(inet)); + } + } + + _ => { + return Err(format!("unknown ip family {family}").into()); + } + } + } + + Err("invalid data received when expecting an INET".into()) + } +} diff --git a/sqlx-postgres/src/types/ipnet/mod.rs b/sqlx-postgres/src/types/ipnet/mod.rs new file mode 100644 index 0000000000..cd40cf30da --- /dev/null +++ b/sqlx-postgres/src/types/ipnet/mod.rs @@ -0,0 +1,7 @@ +// Prefer `ipnetwork` over `ipnet` because it was implemented first (want to avoid breaking change). +#[cfg(not(feature = "ipnetwork"))] +mod ipaddr; + +// Parent module is named after the `ipnet` crate, this is named after the `IpNet` type. +#[allow(clippy::module_inception)] +mod ipnet; diff --git a/sqlx-postgres/src/types/ipaddr.rs b/sqlx-postgres/src/types/ipnetwork/ipaddr.rs similarity index 100% rename from sqlx-postgres/src/types/ipaddr.rs rename to sqlx-postgres/src/types/ipnetwork/ipaddr.rs diff --git a/sqlx-postgres/src/types/ipnetwork.rs b/sqlx-postgres/src/types/ipnetwork/ipnetwork.rs similarity index 100% rename from sqlx-postgres/src/types/ipnetwork.rs rename to sqlx-postgres/src/types/ipnetwork/ipnetwork.rs diff --git a/sqlx-postgres/src/types/ipnetwork/mod.rs b/sqlx-postgres/src/types/ipnetwork/mod.rs new file mode 100644 index 0000000000..de40244c65 --- /dev/null +++ b/sqlx-postgres/src/types/ipnetwork/mod.rs @@ -0,0 +1,5 @@ +mod ipaddr; + +// Parent module is named after the `ipnetwork` crate, this is named after the `IpNetwork` type. +#[allow(clippy::module_inception)] +mod ipnetwork; diff --git a/sqlx-postgres/src/types/mod.rs b/sqlx-postgres/src/types/mod.rs index b5b3266cbc..4fef56bb23 100644 --- a/sqlx-postgres/src/types/mod.rs +++ b/sqlx-postgres/src/types/mod.rs @@ -83,7 +83,7 @@ //! //! ### [`ipnetwork`](https://crates.io/crates/ipnetwork) //! -//! Requires the `ipnetwork` Cargo feature flag. +//! Requires the `ipnetwork` Cargo feature flag (takes precedence over `ipnet` if both are used). //! //! | Rust type | Postgres type(s) | //! |---------------------------------------|------------------------------------------------------| @@ -96,6 +96,17 @@ //! //! `IpNetwork` does not have this limitation. //! +//! ### [`ipnet`](https://crates.io/crates/ipnet) +//! +//! Requires the `ipnet` Cargo feature flag. +//! +//! | Rust type | Postgres type(s) | +//! |---------------------------------------|------------------------------------------------------| +//! | `ipnet::IpNet` | INET, CIDR | +//! | `std::net::IpAddr` | INET, CIDR | +//! +//! The same `IpAddr` limitation for smaller network prefixes applies as with `ipnet`. +//! //! ### [`mac_address`](https://crates.io/crates/mac_address) //! //! Requires the `mac_address` Cargo feature flag. @@ -244,11 +255,11 @@ mod time; #[cfg(feature = "uuid")] mod uuid; -#[cfg(feature = "ipnetwork")] -mod ipnetwork; +#[cfg(feature = "ipnet")] +mod ipnet; #[cfg(feature = "ipnetwork")] -mod ipaddr; +mod ipnetwork; #[cfg(feature = "mac_address")] mod mac_address; diff --git a/tests/postgres/types.rs b/tests/postgres/types.rs index c1cf87983c..34c326ffee 100644 --- a/tests/postgres/types.rs +++ b/tests/postgres/types.rs @@ -2,6 +2,7 @@ extern crate time_ as time; use std::net::SocketAddr; use std::ops::Bound; +use std::str::FromStr; use sqlx::postgres::types::{Oid, PgCiText, PgInterval, PgMoney, PgRange}; use sqlx::postgres::Postgres; @@ -9,7 +10,6 @@ use sqlx_test::{new, test_decode_type, test_prepared_type, test_type}; use sqlx_core::executor::Executor; use sqlx_core::types::Text; -use std::str::FromStr; test_type!(null>(Postgres, "NULL::int2" == None:: @@ -171,6 +171,38 @@ test_type!(uuid_vec>(Postgres, ] )); +#[cfg(feature = "ipnet")] +test_type!(ipnet(Postgres, + "'127.0.0.1'::inet" + == "127.0.0.1/32" + .parse::() + .unwrap(), + "'8.8.8.8/24'::inet" + == "8.8.8.8/24" + .parse::() + .unwrap(), + "'10.1.1/24'::inet" + == "10.1.1.0/24" + .parse::() + .unwrap(), + "'::ffff:1.2.3.0'::inet" + == "::ffff:1.2.3.0/128" + .parse::() + .unwrap(), + "'2001:4f8:3:ba::/64'::inet" + == "2001:4f8:3:ba::/64" + .parse::() + .unwrap(), + "'192.168'::cidr" + == "192.168.0.0/24" + .parse::() + .unwrap(), + "'::ffff:1.2.3.0/120'::cidr" + == "::ffff:1.2.3.0/120" + .parse::() + .unwrap(), +)); + #[cfg(feature = "ipnetwork")] test_type!(ipnetwork(Postgres, "'127.0.0.1'::inet" @@ -232,6 +264,15 @@ test_type!(bitvec( }, )); +#[cfg(feature = "ipnet")] +test_type!(ipnet_vec>(Postgres, + "'{127.0.0.1,8.8.8.8/24}'::inet[]" + == vec![ + "127.0.0.1/32".parse::().unwrap(), + "8.8.8.8/24".parse::().unwrap() + ] +)); + #[cfg(feature = "ipnetwork")] test_type!(ipnetwork_vec>(Postgres, "'{127.0.0.1,8.8.8.8/24}'::inet[]" diff --git a/tests/ui-tests.rs b/tests/ui-tests.rs index f74694b870..4a5ca240e1 100644 --- a/tests/ui-tests.rs +++ b/tests/ui-tests.rs @@ -17,7 +17,7 @@ fn ui_tests() { t.compile_fail("tests/ui/postgres/gated/uuid.rs"); } - if cfg!(not(feature = "ipnetwork")) { + if cfg!(not(feature = "ipnet")) && cfg!(not(feature = "ipnetwork")) { t.compile_fail("tests/ui/postgres/gated/ipnetwork.rs"); } }