diff --git a/sqlx-postgres/src/type_checking.rs b/sqlx-postgres/src/type_checking.rs index eb18c5a999..314fda0b17 100644 --- a/sqlx-postgres/src/type_checking.rs +++ b/sqlx-postgres/src/type_checking.rs @@ -36,6 +36,8 @@ impl_type_checking!( sqlx::postgres::types::PgLine, + sqlx::postgres::types::PgLSeg, + #[cfg(feature = "uuid")] sqlx::types::Uuid, diff --git a/sqlx-postgres/src/types/geometry/line_segment.rs b/sqlx-postgres/src/types/geometry/line_segment.rs new file mode 100644 index 0000000000..ebe32d97d0 --- /dev/null +++ b/sqlx-postgres/src/types/geometry/line_segment.rs @@ -0,0 +1,282 @@ +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; +use sqlx_core::bytes::Buf; +use std::str::FromStr; + +const ERROR: &str = "error decoding LSEG"; + +/// ## Postgres Geometric Line Segment type +/// +/// Description: Finite line segment +/// Representation: `((start_x,start_y),(end_x,end_y))` +/// +/// +/// Line segments are represented by pairs of points that are the endpoints of the segment. Values of type lseg are specified using any of the following syntaxes: +/// ```text +/// [ ( start_x , start_y ) , ( end_x , end_y ) ] +/// ( ( start_x , start_y ) , ( end_x , end_y ) ) +/// ( start_x , start_y ) , ( end_x , end_y ) +/// start_x , start_y , end_x , end_y +/// ``` +/// where `(start_x,start_y) and (end_x,end_y)` are the end points of the line segment. +/// +/// See https://www.postgresql.org/docs/16/datatype-geometric.html#DATATYPE-LSEG +#[derive(Debug, Clone, PartialEq)] +pub struct PgLSeg { + pub start_x: f64, + pub start_y: f64, + pub end_x: f64, + pub end_y: f64, +} + +impl Type for PgLSeg { + fn type_info() -> PgTypeInfo { + PgTypeInfo::with_name("lseg") + } +} + +impl PgHasArrayType for PgLSeg { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::with_name("_lseg") + } +} + +impl<'r> Decode<'r, Postgres> for PgLSeg { + fn decode(value: PgValueRef<'r>) -> Result> { + match value.format() { + PgValueFormat::Text => Ok(PgLSeg::from_str(value.as_str()?)?), + PgValueFormat::Binary => Ok(PgLSeg::from_bytes(value.as_bytes()?)?), + } + } +} + +impl<'q> Encode<'q, Postgres> for PgLSeg { + fn produces(&self) -> Option { + Some(PgTypeInfo::with_name("lseg")) + } + + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + self.serialize(buf)?; + Ok(IsNull::No) + } +} + +impl FromStr for PgLSeg { + type Err = BoxDynError; + + fn from_str(s: &str) -> Result { + let sanitised = s.replace(['(', ')', '[', ']', ' '], ""); + let mut parts = sanitised.split(','); + + let start_x = parts + .next() + .and_then(|s| s.parse::().ok()) + .ok_or_else(|| format!("{}: could not get start_x from {}", ERROR, s))?; + + let start_y = parts + .next() + .and_then(|s| s.parse::().ok()) + .ok_or_else(|| format!("{}: could not get start_y from {}", ERROR, s))?; + + let end_x = parts + .next() + .and_then(|s| s.parse::().ok()) + .ok_or_else(|| format!("{}: could not get end_x from {}", ERROR, s))?; + + let end_y = parts + .next() + .and_then(|s| s.parse::().ok()) + .ok_or_else(|| format!("{}: could not get end_y from {}", ERROR, s))?; + + if parts.next().is_some() { + return Err(format!("{}: too many numbers inputted in {}", ERROR, s).into()); + } + + Ok(PgLSeg { + start_x, + start_y, + end_x, + end_y, + }) + } +} + +impl PgLSeg { + fn from_bytes(mut bytes: &[u8]) -> Result { + let start_x = bytes.get_f64(); + let start_y = bytes.get_f64(); + let end_x = bytes.get_f64(); + let end_y = bytes.get_f64(); + + Ok(PgLSeg { + start_x, + start_y, + end_x, + end_y, + }) + } + + fn serialize(&self, buff: &mut PgArgumentBuffer) -> Result<(), BoxDynError> { + buff.extend_from_slice(&self.start_x.to_be_bytes()); + buff.extend_from_slice(&self.start_y.to_be_bytes()); + buff.extend_from_slice(&self.end_x.to_be_bytes()); + buff.extend_from_slice(&self.end_y.to_be_bytes()); + Ok(()) + } + + #[cfg(test)] + fn serialize_to_vec(&self) -> Vec { + let mut buff = PgArgumentBuffer::default(); + self.serialize(&mut buff).unwrap(); + buff.to_vec() + } +} + +#[cfg(test)] +mod lseg_tests { + + use std::str::FromStr; + + use super::PgLSeg; + + const LINE_SEGMENT_BYTES: &[u8] = &[ + 63, 241, 153, 153, 153, 153, 153, 154, 64, 1, 153, 153, 153, 153, 153, 154, 64, 10, 102, + 102, 102, 102, 102, 102, 64, 17, 153, 153, 153, 153, 153, 154, + ]; + + #[test] + fn can_deserialise_lseg_type_bytes() { + let lseg = PgLSeg::from_bytes(LINE_SEGMENT_BYTES).unwrap(); + assert_eq!( + lseg, + PgLSeg { + start_x: 1.1, + start_y: 2.2, + end_x: 3.3, + end_y: 4.4 + } + ) + } + + #[test] + fn can_deserialise_lseg_type_str_first_syntax() { + let lseg = PgLSeg::from_str("[( 1, 2), (3, 4 )]").unwrap(); + assert_eq!( + lseg, + PgLSeg { + start_x: 1., + start_y: 2., + end_x: 3., + end_y: 4. + } + ); + } + #[test] + fn can_deserialise_lseg_type_str_second_syntax() { + let lseg = PgLSeg::from_str("(( 1, 2), (3, 4 ))").unwrap(); + assert_eq!( + lseg, + PgLSeg { + start_x: 1., + start_y: 2., + end_x: 3., + end_y: 4. + } + ); + } + + #[test] + fn can_deserialise_lseg_type_str_third_syntax() { + let lseg = PgLSeg::from_str("(1, 2), (3, 4 )").unwrap(); + assert_eq!( + lseg, + PgLSeg { + start_x: 1., + start_y: 2., + end_x: 3., + end_y: 4. + } + ); + } + + #[test] + fn can_deserialise_lseg_type_str_fourth_syntax() { + let lseg = PgLSeg::from_str("1, 2, 3, 4").unwrap(); + assert_eq!( + lseg, + PgLSeg { + start_x: 1., + start_y: 2., + end_x: 3., + end_y: 4. + } + ); + } + + #[test] + fn can_deserialise_too_many_numbers() { + let input_str = "1, 2, 3, 4, 5"; + let lseg = PgLSeg::from_str(input_str); + assert!(lseg.is_err()); + if let Err(err) = lseg { + assert_eq!( + err.to_string(), + format!("error decoding LSEG: too many numbers inputted in {input_str}") + ) + } + } + + #[test] + fn can_deserialise_too_few_numbers() { + let input_str = "1, 2, 3"; + let lseg = PgLSeg::from_str(input_str); + assert!(lseg.is_err()); + if let Err(err) = lseg { + assert_eq!( + err.to_string(), + format!("error decoding LSEG: could not get end_y from {input_str}") + ) + } + } + + #[test] + fn can_deserialise_invalid_numbers() { + let input_str = "1, 2, 3, FOUR"; + let lseg = PgLSeg::from_str(input_str); + assert!(lseg.is_err()); + if let Err(err) = lseg { + assert_eq!( + err.to_string(), + format!("error decoding LSEG: could not get end_y from {input_str}") + ) + } + } + + #[test] + fn can_deserialise_lseg_type_str_float() { + let lseg = PgLSeg::from_str("(1.1, 2.2), (3.3, 4.4)").unwrap(); + assert_eq!( + lseg, + PgLSeg { + start_x: 1.1, + start_y: 2.2, + end_x: 3.3, + end_y: 4.4 + } + ); + } + + #[test] + fn can_serialise_lseg_type() { + let lseg = PgLSeg { + start_x: 1.1, + start_y: 2.2, + end_x: 3.3, + end_y: 4.4, + }; + assert_eq!(lseg.serialize_to_vec(), LINE_SEGMENT_BYTES,) + } +} diff --git a/sqlx-postgres/src/types/geometry/mod.rs b/sqlx-postgres/src/types/geometry/mod.rs index daf9f1deb9..0da73fef08 100644 --- a/sqlx-postgres/src/types/geometry/mod.rs +++ b/sqlx-postgres/src/types/geometry/mod.rs @@ -1,2 +1,3 @@ pub mod line; +pub mod line_segment; pub mod point; diff --git a/sqlx-postgres/src/types/mod.rs b/sqlx-postgres/src/types/mod.rs index 747345518a..b5b3266cbc 100644 --- a/sqlx-postgres/src/types/mod.rs +++ b/sqlx-postgres/src/types/mod.rs @@ -23,6 +23,7 @@ //! | [`PgCube`] | CUBE | //! | [`PgPoint] | POINT | //! | [`PgLine] | LINE | +//! | [`PgLSeg] | LSEG | //! | [`PgHstore`] | HSTORE | //! //! 1 SQLx generally considers `CITEXT` to be compatible with `String`, `&str`, etc., @@ -259,6 +260,7 @@ pub use array::PgHasArrayType; pub use citext::PgCiText; pub use cube::PgCube; pub use geometry::line::PgLine; +pub use geometry::line_segment::PgLSeg; pub use geometry::point::PgPoint; pub use hstore::PgHstore; pub use interval::PgInterval; diff --git a/tests/postgres/types.rs b/tests/postgres/types.rs index 3f6c362043..c1cf87983c 100644 --- a/tests/postgres/types.rs +++ b/tests/postgres/types.rs @@ -509,6 +509,11 @@ test_type!(line(Postgres, "line('((0.0, 0.0), (1.0,1.0))')" == sqlx::postgres::types::PgLine { a: 1., b: -1., c: 0. }, )); +#[cfg(any(postgres_12, postgres_13, postgres_14, postgres_15))] +test_type!(lseg(Postgres, + "lseg('((1.0, 2.0), (3.0,4.0))')" == sqlx::postgres::types::PgLSeg { start_x: 1., start_y: 2., end_x: 3. , end_y: 4.}, +)); + #[cfg(feature = "rust_decimal")] test_type!(decimal(Postgres, "0::numeric" == sqlx::types::Decimal::from_str("0").unwrap(),