Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add postres geometry line segment #3690

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sqlx-postgres/src/type_checking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ impl_type_checking!(

sqlx::postgres::types::PgLine,

sqlx::postgres::types::PgLSeg,

#[cfg(feature = "uuid")]
sqlx::types::Uuid,

Expand Down
282 changes: 282 additions & 0 deletions sqlx-postgres/src/types/geometry/line_segment.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
use crate::types::Type;
use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
use sqlx_core::bytes::Buf;
use std::str::FromStr;

const ERROR: &str = "error decoding LSEG";

/// ## Postgres Geometric Line Segment type
///
/// Description: Finite line segment
/// Representation: `((start_x,start_y),(end_x,end_y))`
///
///
/// Line segments are represented by pairs of points that are the endpoints of the segment. Values of type lseg are specified using any of the following syntaxes:
/// ```text
/// [ ( start_x , start_y ) , ( end_x , end_y ) ]
/// ( ( start_x , start_y ) , ( end_x , end_y ) )
/// ( start_x , start_y ) , ( end_x , end_y )
/// start_x , start_y , end_x , end_y
/// ```
/// where `(start_x,start_y) and (end_x,end_y)` are the end points of the line segment.
///
/// See https://www.postgresql.org/docs/16/datatype-geometric.html#DATATYPE-LSEG
#[derive(Debug, Clone, PartialEq)]
pub struct PgLSeg {
pub start_x: f64,
pub start_y: f64,
pub end_x: f64,
pub end_y: f64,
}

impl Type<Postgres> for PgLSeg {
fn type_info() -> PgTypeInfo {
PgTypeInfo::with_name("lseg")
}
}

impl PgHasArrayType for PgLSeg {
fn array_type_info() -> PgTypeInfo {
PgTypeInfo::with_name("_lseg")
}
}

impl<'r> Decode<'r, Postgres> for PgLSeg {
fn decode(value: PgValueRef<'r>) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
match value.format() {
PgValueFormat::Text => Ok(PgLSeg::from_str(value.as_str()?)?),
PgValueFormat::Binary => Ok(PgLSeg::from_bytes(value.as_bytes()?)?),
}
}
}

impl<'q> Encode<'q, Postgres> for PgLSeg {
fn produces(&self) -> Option<PgTypeInfo> {
Some(PgTypeInfo::with_name("lseg"))
}

fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
self.serialize(buf)?;
Ok(IsNull::No)
}
}

impl FromStr for PgLSeg {
type Err = BoxDynError;

fn from_str(s: &str) -> Result<Self, Self::Err> {
let sanitised = s.replace(['(', ')', '[', ']', ' '], "");
let mut parts = sanitised.split(',');

let start_x = parts
.next()
.and_then(|s| s.parse::<f64>().ok())
.ok_or_else(|| format!("{}: could not get start_x from {}", ERROR, s))?;

let start_y = parts
.next()
.and_then(|s| s.parse::<f64>().ok())
.ok_or_else(|| format!("{}: could not get start_y from {}", ERROR, s))?;

let end_x = parts
.next()
.and_then(|s| s.parse::<f64>().ok())
.ok_or_else(|| format!("{}: could not get end_x from {}", ERROR, s))?;

let end_y = parts
.next()
.and_then(|s| s.parse::<f64>().ok())
.ok_or_else(|| format!("{}: could not get end_y from {}", ERROR, s))?;

if parts.next().is_some() {
return Err(format!("{}: too many numbers inputted in {}", ERROR, s).into());
}

Ok(PgLSeg {
start_x,
start_y,
end_x,
end_y,
})
}
}

impl PgLSeg {
fn from_bytes(mut bytes: &[u8]) -> Result<PgLSeg, BoxDynError> {
let start_x = bytes.get_f64();
let start_y = bytes.get_f64();
let end_x = bytes.get_f64();
let end_y = bytes.get_f64();

Ok(PgLSeg {
start_x,
start_y,
end_x,
end_y,
})
}

fn serialize(&self, buff: &mut PgArgumentBuffer) -> Result<(), BoxDynError> {
buff.extend_from_slice(&self.start_x.to_be_bytes());
buff.extend_from_slice(&self.start_y.to_be_bytes());
buff.extend_from_slice(&self.end_x.to_be_bytes());
buff.extend_from_slice(&self.end_y.to_be_bytes());
Ok(())
}

#[cfg(test)]
fn serialize_to_vec(&self) -> Vec<u8> {
let mut buff = PgArgumentBuffer::default();
self.serialize(&mut buff).unwrap();
buff.to_vec()
}
}

#[cfg(test)]
mod lseg_tests {

use std::str::FromStr;

use super::PgLSeg;

const LINE_SEGMENT_BYTES: &[u8] = &[
63, 241, 153, 153, 153, 153, 153, 154, 64, 1, 153, 153, 153, 153, 153, 154, 64, 10, 102,
102, 102, 102, 102, 102, 64, 17, 153, 153, 153, 153, 153, 154,
];

#[test]
fn can_deserialise_lseg_type_bytes() {
let lseg = PgLSeg::from_bytes(LINE_SEGMENT_BYTES).unwrap();
assert_eq!(
lseg,
PgLSeg {
start_x: 1.1,
start_y: 2.2,
end_x: 3.3,
end_y: 4.4
}
)
}

#[test]
fn can_deserialise_lseg_type_str_first_syntax() {
let lseg = PgLSeg::from_str("[( 1, 2), (3, 4 )]").unwrap();
assert_eq!(
lseg,
PgLSeg {
start_x: 1.,
start_y: 2.,
end_x: 3.,
end_y: 4.
}
);
}
#[test]
fn can_deserialise_lseg_type_str_second_syntax() {
let lseg = PgLSeg::from_str("(( 1, 2), (3, 4 ))").unwrap();
assert_eq!(
lseg,
PgLSeg {
start_x: 1.,
start_y: 2.,
end_x: 3.,
end_y: 4.
}
);
}

#[test]
fn can_deserialise_lseg_type_str_third_syntax() {
let lseg = PgLSeg::from_str("(1, 2), (3, 4 )").unwrap();
assert_eq!(
lseg,
PgLSeg {
start_x: 1.,
start_y: 2.,
end_x: 3.,
end_y: 4.
}
);
}

#[test]
fn can_deserialise_lseg_type_str_fourth_syntax() {
let lseg = PgLSeg::from_str("1, 2, 3, 4").unwrap();
assert_eq!(
lseg,
PgLSeg {
start_x: 1.,
start_y: 2.,
end_x: 3.,
end_y: 4.
}
);
}

#[test]
fn can_deserialise_too_many_numbers() {
let input_str = "1, 2, 3, 4, 5";
let lseg = PgLSeg::from_str(input_str);
assert!(lseg.is_err());
if let Err(err) = lseg {
assert_eq!(
err.to_string(),
format!("error decoding LSEG: too many numbers inputted in {input_str}")
)
}
}

#[test]
fn can_deserialise_too_few_numbers() {
let input_str = "1, 2, 3";
let lseg = PgLSeg::from_str(input_str);
assert!(lseg.is_err());
if let Err(err) = lseg {
assert_eq!(
err.to_string(),
format!("error decoding LSEG: could not get end_y from {input_str}")
)
}
}

#[test]
fn can_deserialise_invalid_numbers() {
let input_str = "1, 2, 3, FOUR";
let lseg = PgLSeg::from_str(input_str);
assert!(lseg.is_err());
if let Err(err) = lseg {
assert_eq!(
err.to_string(),
format!("error decoding LSEG: could not get end_y from {input_str}")
)
}
}

#[test]
fn can_deserialise_lseg_type_str_float() {
let lseg = PgLSeg::from_str("(1.1, 2.2), (3.3, 4.4)").unwrap();
assert_eq!(
lseg,
PgLSeg {
start_x: 1.1,
start_y: 2.2,
end_x: 3.3,
end_y: 4.4
}
);
}

#[test]
fn can_serialise_lseg_type() {
let lseg = PgLSeg {
start_x: 1.1,
start_y: 2.2,
end_x: 3.3,
end_y: 4.4,
};
assert_eq!(lseg.serialize_to_vec(), LINE_SEGMENT_BYTES,)
}
}
1 change: 1 addition & 0 deletions sqlx-postgres/src/types/geometry/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
pub mod line;
pub mod line_segment;
pub mod point;
2 changes: 2 additions & 0 deletions sqlx-postgres/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
//! | [`PgCube`] | CUBE |
//! | [`PgPoint] | POINT |
//! | [`PgLine] | LINE |
//! | [`PgLSeg] | LSEG |
//! | [`PgHstore`] | HSTORE |
//!
//! <sup>1</sup> SQLx generally considers `CITEXT` to be compatible with `String`, `&str`, etc.,
Expand Down Expand Up @@ -259,6 +260,7 @@ pub use array::PgHasArrayType;
pub use citext::PgCiText;
pub use cube::PgCube;
pub use geometry::line::PgLine;
pub use geometry::line_segment::PgLSeg;
pub use geometry::point::PgPoint;
pub use hstore::PgHstore;
pub use interval::PgInterval;
Expand Down
5 changes: 5 additions & 0 deletions tests/postgres/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,11 @@ test_type!(line<sqlx::postgres::types::PgLine>(Postgres,
"line('((0.0, 0.0), (1.0,1.0))')" == sqlx::postgres::types::PgLine { a: 1., b: -1., c: 0. },
));

#[cfg(any(postgres_12, postgres_13, postgres_14, postgres_15))]
test_type!(lseg<sqlx::postgres::types::PgLSeg>(Postgres,
"lseg('((1.0, 2.0), (3.0,4.0))')" == sqlx::postgres::types::PgLSeg { start_x: 1., start_y: 2., end_x: 3. , end_y: 4.},
));

#[cfg(feature = "rust_decimal")]
test_type!(decimal<sqlx::types::Decimal>(Postgres,
"0::numeric" == sqlx::types::Decimal::from_str("0").unwrap(),
Expand Down
Loading