Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add ipnet support #3710

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ _unstable-all-types = [
"json",
"time",
"chrono",
"ipnet",
"ipnetwork",
"mac_address",
"uuid",
Expand Down Expand Up @@ -116,6 +117,7 @@ json = ["sqlx-macros?/json", "sqlx-mysql?/json", "sqlx-postgres?/json", "sqlx-sq
bigdecimal = ["sqlx-core/bigdecimal", "sqlx-macros?/bigdecimal", "sqlx-mysql?/bigdecimal", "sqlx-postgres?/bigdecimal"]
bit-vec = ["sqlx-core/bit-vec", "sqlx-macros?/bit-vec", "sqlx-postgres?/bit-vec"]
chrono = ["sqlx-core/chrono", "sqlx-macros?/chrono", "sqlx-mysql?/chrono", "sqlx-postgres?/chrono", "sqlx-sqlite?/chrono"]
ipnet = ["sqlx-core/ipnet", "sqlx-macros?/ipnet", "sqlx-postgres?/ipnet"]
ipnetwork = ["sqlx-core/ipnetwork", "sqlx-macros?/ipnetwork", "sqlx-postgres?/ipnetwork"]
mac_address = ["sqlx-core/mac_address", "sqlx-macros?/mac_address", "sqlx-postgres?/mac_address"]
rust_decimal = ["sqlx-core/rust_decimal", "sqlx-macros?/rust_decimal", "sqlx-mysql?/rust_decimal", "sqlx-postgres?/rust_decimal"]
Expand All @@ -142,6 +144,7 @@ sqlx = { version = "=0.8.3", path = ".", default-features = false }
bigdecimal = "0.4.0"
bit-vec = "0.6.3"
chrono = { version = "0.4.34", default-features = false, features = ["std", "clock"] }
ipnet = "2.3.0"
ipnetwork = "0.20.0"
mac_address = "1.1.5"
rust_decimal = { version = "1.26.1", default-features = false, features = ["std"] }
Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,8 @@ be removed in the future.

- `rust_decimal`: Add support for `NUMERIC` using the `rust_decimal` crate.

- `ipnet`: Add support for `INET` and `CIDR` (in postgres) using the `ipnet` crate.

- `ipnetwork`: Add support for `INET` and `CIDR` (in postgres) using the `ipnetwork` crate.

- `json`: Add support for `JSON` and `JSONB` (in postgres) using the `serde_json` crate.
Expand Down
1 change: 1 addition & 0 deletions sqlx-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ bit-vec = { workspace = true, optional = true }
bigdecimal = { workspace = true, optional = true }
rust_decimal = { workspace = true, optional = true }
time = { workspace = true, optional = true }
ipnet = { workspace = true, optional = true }
ipnetwork = { workspace = true, optional = true }
mac_address = { workspace = true, optional = true }
uuid = { workspace = true, optional = true }
Expand Down
7 changes: 7 additions & 0 deletions sqlx-core/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,13 @@ pub use bigdecimal::BigDecimal;
#[doc(no_inline)]
pub use rust_decimal::Decimal;

#[cfg(feature = "ipnet")]
#[cfg_attr(docsrs, doc(cfg(feature = "ipnet")))]
pub mod ipnet {
#[doc(no_inline)]
pub use ipnet::{IpNet, Ipv4Net, Ipv6Net};
}

#[cfg(feature = "ipnetwork")]
#[cfg_attr(docsrs, doc(cfg(feature = "ipnetwork")))]
pub mod ipnetwork {
Expand Down
1 change: 1 addition & 0 deletions sqlx-macros-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ json = ["sqlx-core/json", "sqlx-mysql?/json", "sqlx-postgres?/json", "sqlx-sqlit
bigdecimal = ["sqlx-core/bigdecimal", "sqlx-mysql?/bigdecimal", "sqlx-postgres?/bigdecimal"]
bit-vec = ["sqlx-core/bit-vec", "sqlx-postgres?/bit-vec"]
chrono = ["sqlx-core/chrono", "sqlx-mysql?/chrono", "sqlx-postgres?/chrono", "sqlx-sqlite?/chrono"]
ipnet = ["sqlx-core/ipnet", "sqlx-postgres?/ipnet"]
ipnetwork = ["sqlx-core/ipnetwork", "sqlx-postgres?/ipnetwork"]
mac_address = ["sqlx-core/mac_address", "sqlx-postgres?/mac_address"]
rust_decimal = ["sqlx-core/rust_decimal", "sqlx-mysql?/rust_decimal", "sqlx-postgres?/rust_decimal"]
Expand Down
1 change: 1 addition & 0 deletions sqlx-macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ sqlite-unbundled = ["sqlx-macros-core/sqlite-unbundled"]
bigdecimal = ["sqlx-macros-core/bigdecimal"]
bit-vec = ["sqlx-macros-core/bit-vec"]
chrono = ["sqlx-macros-core/chrono"]
ipnet = ["sqlx-macros-core/ipnet"]
ipnetwork = ["sqlx-macros-core/ipnetwork"]
mac_address = ["sqlx-macros-core/mac_address"]
rust_decimal = ["sqlx-macros-core/rust_decimal"]
Expand Down
2 changes: 2 additions & 0 deletions sqlx-postgres/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ offline = ["sqlx-core/offline"]
bigdecimal = ["dep:bigdecimal", "dep:num-bigint", "sqlx-core/bigdecimal"]
bit-vec = ["dep:bit-vec", "sqlx-core/bit-vec"]
chrono = ["dep:chrono", "sqlx-core/chrono"]
ipnet = ["dep:ipnet", "sqlx-core/ipnet"]
ipnetwork = ["dep:ipnetwork", "sqlx-core/ipnetwork"]
mac_address = ["dep:mac_address", "sqlx-core/mac_address"]
rust_decimal = ["dep:rust_decimal", "rust_decimal/maths", "sqlx-core/rust_decimal"]
Expand All @@ -43,6 +44,7 @@ sha2 = { version = "0.10.0", default-features = false }
bigdecimal = { workspace = true, optional = true }
bit-vec = { workspace = true, optional = true }
chrono = { workspace = true, optional = true }
ipnet = { workspace = true, optional = true }
ipnetwork = { workspace = true, optional = true }
mac_address = { workspace = true, optional = true }
rust_decimal = { workspace = true, optional = true }
Expand Down
6 changes: 6 additions & 0 deletions sqlx-postgres/src/type_checking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ impl_type_checking!(
#[cfg(feature = "rust_decimal")]
sqlx::types::Decimal,

#[cfg(all(feature = "ipnet", not(feature = "ipnetwork")))]
sqlx::types::ipnet::IpNet,

#[cfg(feature = "ipnetwork")]
sqlx::types::ipnetwork::IpNetwork,

Expand Down Expand Up @@ -138,6 +141,9 @@ impl_type_checking!(
#[cfg(feature = "rust_decimal")]
Vec<sqlx::types::Decimal> | &[sqlx::types::Decimal],

#[cfg(all(feature = "ipnet", not(feature = "ipnetwork")))]
Vec<sqlx::types::ipnet::IpNet> | &[sqlx::types::ipnet::IpNet],

#[cfg(feature = "ipnetwork")]
Vec<sqlx::types::ipnetwork::IpNetwork> | &[sqlx::types::ipnetwork::IpNetwork],

Expand Down
62 changes: 62 additions & 0 deletions sqlx-postgres/src/types/ipnet/ipaddr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
use std::net::IpAddr;

use ipnet::IpNet;

use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
use crate::types::Type;
use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueRef, Postgres};

impl Type<Postgres> for IpAddr
where
IpNet: Type<Postgres>,
{
fn type_info() -> PgTypeInfo {
IpNet::type_info()
}

fn compatible(ty: &PgTypeInfo) -> bool {
IpNet::compatible(ty)
}
}

impl PgHasArrayType for IpAddr {
fn array_type_info() -> PgTypeInfo {
<IpNet as PgHasArrayType>::array_type_info()
}

fn array_compatible(ty: &PgTypeInfo) -> bool {
<IpNet as PgHasArrayType>::array_compatible(ty)
}
}

impl<'db> Encode<'db, Postgres> for IpAddr
where
IpNet: Encode<'db, Postgres>,
{
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
IpNet::from(*self).encode_by_ref(buf)
}

fn size_hint(&self) -> usize {
IpNet::from(*self).size_hint()
}
}

impl<'db> Decode<'db, Postgres> for IpAddr
where
IpNet: Decode<'db, Postgres>,
{
fn decode(value: PgValueRef<'db>) -> Result<Self, BoxDynError> {
let ipnet = IpNet::decode(value)?;

if matches!(ipnet, IpNet::V4(net) if net.prefix_len() != 32)
|| matches!(ipnet, IpNet::V6(net) if net.prefix_len() != 128)
{
Err("lossy decode from inet/cidr")?
}

Ok(ipnet.addr())
}
}
130 changes: 130 additions & 0 deletions sqlx-postgres/src/types/ipnet/ipnet.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

#[cfg(feature = "ipnet")]
use ipnet::{IpNet, Ipv4Net, Ipv6Net};

use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
use crate::types::Type;
use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};

// https://github.com/postgres/postgres/blob/574925bfd0a8175f6e161936ea11d9695677ba09/src/include/utils/inet.h#L39

// Technically this is a magic number here but it doesn't make sense to drag in the whole of `libc`
// just for one constant.
const PGSQL_AF_INET: u8 = 2; // AF_INET
const PGSQL_AF_INET6: u8 = PGSQL_AF_INET + 1;

impl Type<Postgres> for IpNet {
fn type_info() -> PgTypeInfo {
PgTypeInfo::INET
}

fn compatible(ty: &PgTypeInfo) -> bool {
*ty == PgTypeInfo::CIDR || *ty == PgTypeInfo::INET
}
}

impl PgHasArrayType for IpNet {
fn array_type_info() -> PgTypeInfo {
PgTypeInfo::INET_ARRAY
}

fn array_compatible(ty: &PgTypeInfo) -> bool {
*ty == PgTypeInfo::CIDR_ARRAY || *ty == PgTypeInfo::INET_ARRAY
}
}

impl Encode<'_, Postgres> for IpNet {
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
// https://github.com/postgres/postgres/blob/574925bfd0a8175f6e161936ea11d9695677ba09/src/backend/utils/adt/network.c#L293
// https://github.com/postgres/postgres/blob/574925bfd0a8175f6e161936ea11d9695677ba09/src/backend/utils/adt/network.c#L271

match self {
IpNet::V4(net) => {
buf.push(PGSQL_AF_INET); // ip_family
buf.push(net.prefix_len()); // ip_bits
buf.push(0); // is_cidr
buf.push(4); // nb (number of bytes)
buf.extend_from_slice(&net.addr().octets()) // address
}

IpNet::V6(net) => {
buf.push(PGSQL_AF_INET6); // ip_family
buf.push(net.prefix_len()); // ip_bits
buf.push(0); // is_cidr
buf.push(16); // nb (number of bytes)
buf.extend_from_slice(&net.addr().octets()); // address
}
}

Ok(IsNull::No)
}

fn size_hint(&self) -> usize {
match self {
IpNet::V4(_) => 8,
IpNet::V6(_) => 20,
}
}
}

impl Decode<'_, Postgres> for IpNet {
fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
let bytes = match value.format() {
PgValueFormat::Binary => value.as_bytes()?,
PgValueFormat::Text => {
let s = value.as_str()?;
println!("{s}");
if s.contains('/') {
return Ok(s.parse()?);
}
// IpNet::from_str doesn't handle conversion from IpAddr to IpNet
let addr: IpAddr = s.parse()?;
return Ok(addr.into());
}
};

if bytes.len() >= 8 {
let family = bytes[0];
let prefix = bytes[1];
let _is_cidr = bytes[2] != 0;
let len = bytes[3];

match family {
PGSQL_AF_INET => {
if bytes.len() == 8 && len == 4 {
let inet = Ipv4Net::new(
Ipv4Addr::new(bytes[4], bytes[5], bytes[6], bytes[7]),
prefix,
)?;

return Ok(IpNet::V4(inet));
}
}

PGSQL_AF_INET6 => {
if bytes.len() == 20 && len == 16 {
let inet = Ipv6Net::new(
Ipv6Addr::from([
bytes[4], bytes[5], bytes[6], bytes[7], bytes[8], bytes[9],
bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15],
bytes[16], bytes[17], bytes[18], bytes[19],
]),
prefix,
)?;

return Ok(IpNet::V6(inet));
}
}

_ => {
return Err(format!("unknown ip family {family}").into());
}
}
}

Err("invalid data received when expecting an INET".into())
}
}
7 changes: 7 additions & 0 deletions sqlx-postgres/src/types/ipnet/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
// Prefer `ipnetwork` over `ipnet` because it was implemented first (want to avoid breaking change).
#[cfg(not(feature = "ipnetwork"))]
mod ipaddr;

// Parent module is named after the `ipnet` crate, this is named after the `IpNet` type.
#[allow(clippy::module_inception)]
mod ipnet;
5 changes: 5 additions & 0 deletions sqlx-postgres/src/types/ipnetwork/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
mod ipaddr;

// Parent module is named after the `ipnetwork` crate, this is named after the `IpNetwork` type.
#[allow(clippy::module_inception)]
mod ipnetwork;
19 changes: 15 additions & 4 deletions sqlx-postgres/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@
//!
//! ### [`ipnetwork`](https://crates.io/crates/ipnetwork)
//!
//! Requires the `ipnetwork` Cargo feature flag.
//! Requires the `ipnetwork` Cargo feature flag (takes precedence over `ipnet` if both are used).
//!
//! | Rust type | Postgres type(s) |
//! |---------------------------------------|------------------------------------------------------|
Expand All @@ -96,6 +96,17 @@
//!
//! `IpNetwork` does not have this limitation.
//!
//! ### [`ipnet`](https://crates.io/crates/ipnet)
//!
//! Requires the `ipnet` Cargo feature flag.
//!
//! | Rust type | Postgres type(s) |
//! |---------------------------------------|------------------------------------------------------|
//! | `ipnet::IpNet` | INET, CIDR |
//! | `std::net::IpAddr` | INET, CIDR |
//!
//! The same `IpAddr` limitation for smaller network prefixes applies as with `ipnet`.
//!
//! ### [`mac_address`](https://crates.io/crates/mac_address)
//!
//! Requires the `mac_address` Cargo feature flag.
Expand Down Expand Up @@ -244,11 +255,11 @@ mod time;
#[cfg(feature = "uuid")]
mod uuid;

#[cfg(feature = "ipnetwork")]
mod ipnetwork;
#[cfg(feature = "ipnet")]
mod ipnet;

#[cfg(feature = "ipnetwork")]
mod ipaddr;
mod ipnetwork;

#[cfg(feature = "mac_address")]
mod mac_address;
Expand Down
Loading