From d2f108e9d2d9fd8c3077bcdf8b2b43f7ba053b1c Mon Sep 17 00:00:00 2001 From: joe Date: Wed, 5 Jun 2024 23:29:07 +0800 Subject: [PATCH] WIP: Support Postgresql Protocol --- Cargo.toml | 7 +-- src/datasource/adbc.rs | 83 -------------------------------- src/datasource/mod.rs | 1 - src/error.rs | 7 --- src/lib.rs | 1 + src/server/mod.rs | 2 + src/server/postgresql/handler.rs | 17 +++++++ src/server/postgresql/mod.rs | 5 ++ src/server/postgresql/server.rs | 77 +++++++++++++++++++++++++++++ src/server/server.rs | 5 ++ 10 files changed, 111 insertions(+), 94 deletions(-) delete mode 100644 src/datasource/adbc.rs create mode 100644 src/server/mod.rs create mode 100644 src/server/postgresql/handler.rs create mode 100644 src/server/postgresql/mod.rs create mode 100644 src/server/postgresql/server.rs create mode 100644 src/server/server.rs diff --git a/Cargo.toml b/Cargo.toml index 50335cd..2bb149f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,10 +8,11 @@ edition = "2021" arrow = "51.0.0" sqlparser = { git = "https://github.com/holicc/sqlparser.git" } url = "2.5.0" -adbc_core = { git = "https://github.com/alexandreyc/adbc-rs", branch = "main", features = [ - "driver_manager", -] } parquet = "51.0.0" +pgwire = "0.22.0" +tokio = { version = "1.37.0", features = ["macros","time"] } +async-trait = "0.1.80" +log = "0.4.21" [dev-dependencies] arrow = { version = "51.0.0", features = ["prettyprint", "test_utils"] } diff --git a/src/datasource/adbc.rs b/src/datasource/adbc.rs deleted file mode 100644 index e809e89..0000000 --- a/src/datasource/adbc.rs +++ /dev/null @@ -1,83 +0,0 @@ -use std::sync::Arc; - -use crate::error::Result; -use adbc_core::{ - driver_manager::ManagedDriver, - options::{AdbcVersion, OptionDatabase}, - Connection, Database, Driver, Statement, -}; - -use super::{memory::MemoryDataSource, DataSource}; - -pub enum AdbcDriver { - Postgres, -} - -impl AdbcDriver { - pub fn get_dirver(&self) -> Result { - ManagedDriver::load_dynamic_from_name("adbc_driver_postgresql", None, AdbcVersion::V110) - } -} - -pub struct PostgresOption { - driver: AdbcDriver, - url: String, - user: Option, - password: Option, -} - -impl PostgresOption { - pub fn url(url: &str) -> PostgresOption { - PostgresOption { - driver: AdbcDriver::Postgres, - url: url.into(), - user: None, - password: None, - } - } -} - -pub fn read_postgres(sql: &str, ops: PostgresOption) -> Result> { - let mut driver = ops.driver.get_dirver()?; - let mut opts = vec![]; - // add url - opts.push((OptionDatabase::Uri, ops.url.into())); - // add user - if let Some(user) = ops.user { - opts.push((OptionDatabase::Username, user.into())); - } - // add password - if let Some(password) = ops.password { - opts.push((OptionDatabase::Password, password.into())); - } - let mut db = driver.new_database_with_opts(opts)?; - let mut conn = db.new_connection()?; - let mut stmt = conn.new_statement()?; - - stmt.set_sql_query(sql)?; - - let schema = stmt.execute_schema()?; - stmt.execute().map(|reader| { - let batch = reader.collect::, arrow::error::ArrowError>>()?; - Ok(Arc::new(MemoryDataSource::new(Arc::new(schema), batch)) as Arc) - })? -} - -#[cfg(test)] -mod tests { - use crate::datasource::adbc::{read_postgres, PostgresOption}; - - #[test] - fn test_read_postgres() { - let sql = "SELECT * FROM public.schools"; - let source = read_postgres(sql, PostgresOption::url("postgres://root:root@localhost:5433/qurious")).unwrap(); - - let schema = source.schema(); - - let batch = source.scan(None, &[]).unwrap(); - - println!("{}", arrow::util::pretty::pretty_format_batches(&batch).unwrap()); - - assert_eq!(schema.fields().len(), 2); - } -} diff --git a/src/datasource/mod.rs b/src/datasource/mod.rs index ab7b0f3..db46594 100644 --- a/src/datasource/mod.rs +++ b/src/datasource/mod.rs @@ -1,4 +1,3 @@ -pub mod adbc; pub mod file; pub mod memory; diff --git a/src/error.rs b/src/error.rs index 33c65b3..8862d19 100644 --- a/src/error.rs +++ b/src/error.rs @@ -35,13 +35,6 @@ impl From for Error { } } -impl From for Error{ - fn from(e: adbc_core::error::Error) -> Self { - Error::InternalError(e.to_string()) - } - -} - impl Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { diff --git a/src/lib.rs b/src/lib.rs index 3654d0f..d5c55b6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,6 +9,7 @@ pub mod optimizer; pub mod physical; pub mod planner; pub mod utils; +pub mod server; #[cfg(test)] pub mod test_utils; diff --git a/src/server/mod.rs b/src/server/mod.rs new file mode 100644 index 0000000..6e84a07 --- /dev/null +++ b/src/server/mod.rs @@ -0,0 +1,2 @@ +pub mod postgresql; +pub mod server; diff --git a/src/server/postgresql/handler.rs b/src/server/postgresql/handler.rs new file mode 100644 index 0000000..3377159 --- /dev/null +++ b/src/server/postgresql/handler.rs @@ -0,0 +1,17 @@ +use async_trait::async_trait; +use pgwire::{ + api::{query::SimpleQueryHandler, results::Response, ClientInfo}, + error::PgWireResult, +}; + +pub struct PostgresqlHandler; + +#[async_trait] +impl SimpleQueryHandler for PostgresqlHandler { + async fn do_query<'a, C>(&self, _client: &mut C, sql: &'a str) -> PgWireResult>> + where + C: ClientInfo + Unpin + Send + Sync, + { + todo!("Implement PostgresqlHandler::do_query()") + } +} diff --git a/src/server/postgresql/mod.rs b/src/server/postgresql/mod.rs new file mode 100644 index 0000000..1b2d0be --- /dev/null +++ b/src/server/postgresql/mod.rs @@ -0,0 +1,5 @@ +mod handler; +mod server; + +pub use server::PostgresqlServer; +pub use handler::PostgresqlHandler; diff --git a/src/server/postgresql/server.rs b/src/server/postgresql/server.rs new file mode 100644 index 0000000..5fd6edd --- /dev/null +++ b/src/server/postgresql/server.rs @@ -0,0 +1,77 @@ +use crate::error::Result; +use log::error; +use pgwire::api::MakeHandler; +use pgwire::api::{auth::noop::NoopStartupHandler, query::PlaceholderExtendedQueryHandler, StatelessMakeHandler}; +use std::{net::SocketAddr, sync::Arc}; + +use super::PostgresqlHandler; + +pub struct PostgresqlServer { + addr: SocketAddr, +} + +impl PostgresqlServer { + pub fn new(addr: SocketAddr) -> Self { + PostgresqlServer { addr } + } +} + +impl PostgresqlServer { + pub async fn start(&self) -> Result<()> { + tokio::spawn(Self::listen(self.addr)); + + Ok(()) + } + + pub fn shutdown(&self) -> Result<()> { + todo!("Implement PostgresqlServer::shutdown()") + } + + async fn listen(addr: SocketAddr) { + let listener = tokio::net::TcpListener::bind(addr) + .await + .unwrap_or_else(|e| panic!("PostgreSQL Server bind fail. err: {}", e)); + + let authenticator = Arc::new(StatelessMakeHandler::new(Arc::new(NoopStartupHandler))); + let processor = Arc::new(StatelessMakeHandler::new(Arc::new(PostgresqlHandler))); + let placeholder = Arc::new(StatelessMakeHandler::new(Arc::new(PlaceholderExtendedQueryHandler))); + + loop { + tokio::select! { + peer = listener.accept() => { + match peer { + Ok((socket, _)) => { + tokio::spawn(pgwire::tokio::process_socket( + socket, + None, + authenticator.make(), + processor.make(), + placeholder.make(), + )); + } + Err(e) => { + error!("PostgreSQL Server accept new connection fail. err: {}", e); + } + } + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::SocketAddr; + use std::str::FromStr; + + #[tokio::test] + async fn test_postgresql_server() { + let addr = SocketAddr::from_str("127.0.0.1:5434").unwrap(); + let server = PostgresqlServer::new(addr); + server.start().await.unwrap(); + + // wait + tokio::time::sleep(tokio::time::Duration::from_secs(10000000)).await; + } +} diff --git a/src/server/server.rs b/src/server/server.rs new file mode 100644 index 0000000..036e2ba --- /dev/null +++ b/src/server/server.rs @@ -0,0 +1,5 @@ +use crate::server::postgresql; + +pub struct Server { + postgres: postgresql::PostgresqlServer, +}