Skip to content

Commit

Permalink
WIP: support pgwire ExtendedQueryHandler
Browse files Browse the repository at this point in the history
  • Loading branch information
holicc committed Jun 27, 2024
1 parent c2d1614 commit 2cfd730
Show file tree
Hide file tree
Showing 9 changed files with 124 additions and 73 deletions.
20 changes: 12 additions & 8 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,26 @@ edition = "2021"


[dependencies]
arrow = "51.0.0"
# sqlparser
sqlparser = { git = "https://github.com/holicc/sqlparser.git" }
# arrow
parquet = "52.0.0"
arrow = "52.0.0"
# file
url = "2.5.0"
parquet = "51.0.0"

# async
tokio = { version = "1.37.0", features = ["full"] }
async-trait = "0.1.80"
log = "0.4.21"
tokio-stream = "0.1.15"

# postgres
pgwire = { version = "0.23.0", optional = true }
tokio-postgres = { version = "0.7.10", optional = true }
chrono = "0.4.38"
futures = "0.3.30"
# log
log = "0.4.21"

[features]
postgresql = ["dep:pgwire", "dep:tokio-postgres"]
postgresql = ["dep:pgwire"]

[dev-dependencies]
arrow = { version = "51.0.0", features = ["prettyprint", "test_utils"] }
arrow = { version = "52.0.0", features = ["prettyprint", "test_utils"] }
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
.PHONY: check test-pg

# Type-check the crate with every optional feature enabled.
check:
	cargo check --all-features

# Run the PostgreSQL tests against the docker compose services.
# The Cargo feature is named `postgresql` (see [features] in Cargo.toml).
# The services are torn down even when the tests fail, and the recipe exits
# with the test run's status so failures are still reported to make.
test-pg:
	docker compose up -d && { cargo test --features postgresql; status=$$?; docker compose down; exit $$status; }
22 changes: 19 additions & 3 deletions src/error.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use arrow::error::ArrowError;
use std::fmt::Display;

use arrow::error::ArrowError;
use parquet::errors::ParquetError;

pub type Result<T, E = Error> = std::result::Result<T, E>;

#[derive(Debug)]
Expand All @@ -16,8 +18,22 @@ pub enum Error {
TableNotFound(String),
}

impl<T: std::error::Error> From<T> for Error {
fn from(value: T) -> Self {
impl std::error::Error for Error {}

impl From<ArrowError> for Error {
fn from(e: ArrowError) -> Self {
Error::ArrowError(e)
}
}

impl From<std::io::Error> for Error {
fn from(value: std::io::Error) -> Self {
Error::InternalError(value.to_string())
}
}

impl From<ParquetError> for Error {
fn from(value: ParquetError) -> Self {
Error::InternalError(value.to_string())
}
}
Expand Down
11 changes: 9 additions & 2 deletions src/execution/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use arrow::array::RecordBatch;

use super::registry::TableRegistry;
use crate::datasource::DataSource;
use crate::error::Error;
use crate::execution::registry::HashMapTableRegistry;
use crate::planner::sql::SqlQueryPlanner;
use crate::planner::QueryPlanner;
Expand Down Expand Up @@ -31,8 +32,14 @@ impl ExecuteSession {
}

pub fn register_table(&mut self, name: &str, table: Arc<dyn DataSource>) -> Result<()> {
let mut write = self.tables.write()?;
write.register_table(name, table)
self.tables
.write()
.map_err(|e| Error::InternalError(e.to_string()))?
.register_table(name, table)
}

/// Returns a clone of the session's `tables` handle as an
/// `Arc<RwLock<dyn TableRegistry>>`, so crate-internal callers can hold the
/// registry alongside the session.
pub(crate) fn get_tables(&self) -> Arc<RwLock<dyn TableRegistry>> {
self.tables.clone()
}
}

Expand Down
24 changes: 17 additions & 7 deletions src/planner/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ use self::alias::Alias;

use super::{normalize_col_with_schemas_and_ambiguity_check, TableSchemaInfo};

pub struct SqlQueryPlanner<'a> {
pub struct SqlQueryPlanner {
table_registry: Arc<RwLock<dyn TableRegistry>>,
ctes: HashMap<String, Arc<LogicalPlan>>,
new_tables: HashMap<String, Arc<dyn DataSource>>,
relations: HashMap<TableRelation<'a>, TableSchemaInfo>,
relations: HashMap<OwnedTableRelation, TableSchemaInfo>,
}

impl<'a> SqlQueryPlanner<'a> {
impl SqlQueryPlanner {
pub fn create_logical_plan(table_registry: Arc<RwLock<dyn TableRegistry>>, sql: &str) -> Result<LogicalPlan> {
let mut planner = SqlQueryPlanner {
table_registry,
Expand Down Expand Up @@ -340,9 +340,7 @@ impl<'a> SqlQueryPlanner<'a> {
})
.collect()
}
}

impl<'a> SqlQueryPlanner<'a> {
fn cte_tables(&mut self, ctes: Vec<Cte>) -> Result<()> {
for cte in ctes {
let plan = self
Expand Down Expand Up @@ -522,7 +520,16 @@ impl<'a> SqlQueryPlanner<'a> {
}
}

impl<'a> SqlQueryPlanner<'a> {
impl SqlQueryPlanner {
pub fn new(table_registry: Arc<RwLock<dyn TableRegistry>>) -> Self {
SqlQueryPlanner {
table_registry,
ctes: HashMap::default(),
new_tables: HashMap::default(),
relations: HashMap::default(),
}
}

fn add_cte_table(&mut self, name: &str, plan: Arc<LogicalPlan>) {
let cte_table_name = name.to_owned();
self.relations.insert(
Expand Down Expand Up @@ -553,7 +560,10 @@ impl<'a> SqlQueryPlanner<'a> {
}

fn get_table_source(&self, table_name: &str) -> Result<Arc<dyn DataSource>> {
self.table_registry.read()?.get_table_source(table_name)
self.table_registry
.read()
.map_err(|e| Error::InternalError(e.to_string()))?
.get_table_source(table_name)
}

fn get_cte_table(&self, name: &str) -> Option<LogicalPlan> {
Expand Down
3 changes: 1 addition & 2 deletions src/server/postgresql/datatypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use std::sync::Arc;

use arrow::{array::*, datatypes::*};
use chrono::{NaiveDate, NaiveDateTime};
use futures::{stream, Stream, StreamExt};
use pgwire::{
api::{
results::{DataRowEncoder, FieldInfo, QueryResponse, Response},
Expand All @@ -11,7 +10,7 @@ use pgwire::{
error::{ErrorInfo, PgWireError, PgWireResult},
messages::data::DataRow,
};
use tokio_postgres::row;
use tokio_stream::{self as stream, StreamExt};

fn get_bool_value(arr: &Arc<dyn Array>, idx: usize) -> bool {
arr.as_any().downcast_ref::<BooleanArray>().unwrap().value(idx)
Expand Down
64 changes: 40 additions & 24 deletions src/server/postgresql/handler.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
use std::{sync::Arc, vec};
use std::{
sync::{Arc, Mutex, RwLock},
vec,
};

use crate::{
error::{Error, Result},
server::server::send_and_receive,
execution::{registry::TableRegistry, session::ExecuteSession},
logical::plan::LogicalPlan,
planner::sql::SqlQueryPlanner,
};
use arrow::array::RecordBatch;
use async_trait::async_trait;
use pgwire::{
api::{
auth::noop::NoopStartupHandler,
copy::NoopCopyHandler,
portal::{Format, Portal},
query::{ExtendedQueryHandler, SimpleQueryHandler},
query::{ExtendedQueryHandler, PlaceholderExtendedQueryHandler, SimpleQueryHandler},
results::{DescribePortalResponse, DescribeStatementResponse, FieldInfo, QueryResponse, Response},
stmt::{NoopQueryParser, StoredStatement},
ClientInfo, PgWireHandlerFactory,
stmt::{NoopQueryParser, QueryParser, StoredStatement},
ClientInfo, PgWireHandlerFactory, Type,
},
error::{ErrorInfo, PgWireError, PgWireResult},
};
Expand All @@ -29,15 +33,15 @@ pub struct HandlerFactory(pub Arc<PostgresqlHandler>);
impl PgWireHandlerFactory for HandlerFactory {
type StartupHandler = NoopStartupHandler;
type SimpleQueryHandler = PostgresqlHandler;
type ExtendedQueryHandler = PostgresqlHandler;
type ExtendedQueryHandler = PlaceholderExtendedQueryHandler;
type CopyHandler = NoopCopyHandler;

fn simple_query_handler(&self) -> Arc<Self::SimpleQueryHandler> {
self.0.clone()
}

fn extended_query_handler(&self) -> Arc<Self::ExtendedQueryHandler> {
self.0.clone()
Arc::new(PlaceholderExtendedQueryHandler)
}

fn startup_handler(&self) -> Arc<Self::StartupHandler> {
Expand All @@ -50,7 +54,7 @@ impl PgWireHandlerFactory for HandlerFactory {
}

pub struct PostgresqlHandler {
pub tx: Sender<Message>,
pub(crate) session: Arc<ExecuteSession>,
}

#[async_trait]
Expand All @@ -59,8 +63,8 @@ impl SimpleQueryHandler for PostgresqlHandler {
where
C: ClientInfo + Unpin + Send + Sync,
{
send_and_receive(self.tx.clone(), sql.to_owned())
.await
self.session
.sql(sql)
.map_err(|e| {
PgWireError::UserError(Box::new(ErrorInfo::new(
"FATAL".to_owned(),
Expand All @@ -75,12 +79,12 @@ impl SimpleQueryHandler for PostgresqlHandler {

#[async_trait]
impl ExtendedQueryHandler for PostgresqlHandler {
type Statement = String;
type Statement = LogicalPlan;

type QueryParser = NoopQueryParser;
type QueryParser = Parser;

fn query_parser(&self) -> Arc<Self::QueryParser> {
Arc::new(NoopQueryParser)
Arc::new(Parser(self.session.get_tables()))
}

async fn do_describe_statement<C>(
Expand Down Expand Up @@ -110,15 +114,27 @@ impl ExtendedQueryHandler for PostgresqlHandler {
where
C: ClientInfo + Unpin + Send + Sync,
{
send_and_receive(self.tx.clone(), portal.statement.statement.clone())
.await
.map_err(|e| {
PgWireError::UserError(Box::new(ErrorInfo::new(
"FATAL".to_owned(),
"28P01".to_owned(),
e.to_string(),
)))
})
.and_then(|batch| into_pg_reponse(batch))
// send_and_receive(self.tx.clone(), portal.statement.statement.clone())
// .await
// .map_err(|e| {
// PgWireError::UserError(Box::new(ErrorInfo::new(
// "FATAL".to_owned(),
// "28P01".to_owned(),
// e.to_string(),
// )))
// })
// .and_then(|batch| into_pg_reponse(batch))
todo!()
}
}

/// Query parser for the extended query protocol that turns SQL text into a
/// [`LogicalPlan`] using the wrapped table registry.
pub struct Parser(Arc<RwLock<dyn TableRegistry>>);

#[async_trait]
impl QueryParser for Parser {
    type Statement = LogicalPlan;

    /// Plans `sql` against the wrapped registry.
    ///
    /// The parameter type hints in `_types` are ignored. Planner failures are
    /// boxed into `PgWireError::ApiError`.
    async fn parse_sql(&self, sql: &str, _types: &[Type]) -> PgWireResult<Self::Statement> {
        let registry = Arc::clone(&self.0);
        let planned = SqlQueryPlanner::create_logical_plan(registry, sql);
        planned.map_err(|plan_err| PgWireError::ApiError(Box::new(plan_err)))
    }
}
45 changes: 24 additions & 21 deletions src/server/postgresql/server.rs
Original file line number Diff line number Diff line change
@@ -1,58 +1,61 @@
use crate::error::{Error, Result};
use crate::execution::session::ExecuteSession;
use crate::server::server::Message;
use log::error;
use pgwire::api::auth::noop::NoopStartupHandler;
use std::{net::SocketAddr, sync::Arc};
use tokio::sync::mpsc::{self, Sender};
use tokio_postgres::{Client, NoTls};

use super::handler::HandlerFactory;
use super::PostgresqlHandler;

pub struct PostgresqlServer {
session: Arc<ExecuteSession>,
tx: Sender<Message>,
addr: SocketAddr,
}

impl PostgresqlServer {
pub fn try_new(tx: Sender<Message>, svr_addr: SocketAddr) -> Result<Self> {
Ok(PostgresqlServer { tx, addr: svr_addr })
pub fn try_new(session: Arc<ExecuteSession>, tx: Sender<Message>, svr_addr: SocketAddr) -> Result<Self> {
Ok(PostgresqlServer {
session,
tx,
addr: svr_addr,
})
}

async fn connect_pg_backend(url: &str) -> Result<Client> {
let (cli, connection) = tokio_postgres::connect(url, NoTls)
.await
.map_err(|e| Error::InternalError(e.to_string()))?;
// async fn connect_pg_backend(url: &str) -> Result<Client> {
// let (cli, connection) = tokio_postgres::connect(url, NoTls)
// .await
// .map_err(|e| Error::InternalError(e.to_string()))?;

// The connection object performs the actual communication with the database,
// so spawn it off to run on its own.
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {}", e);
}
});
// // The connection object performs the actual communication with the database,
// // so spawn it off to run on its own.
// tokio::spawn(async move {
// if let Err(e) = connection.await {
// eprintln!("connection error: {}", e);
// }
// });

Ok(cli)
}
// Ok(cli)
// }
}

impl PostgresqlServer {
pub async fn start(&self) -> Result<()> {
// tokio::spawn();
Self::listen(self.tx.clone(), self.addr).await;
Self::listen(self.session.clone(), self.tx.clone(), self.addr).await;
Ok(())
}

pub fn shutdown(&self) -> Result<()> {
todo!("Implement PostgresqlServer::shutdown()")
}

async fn listen(tx: mpsc::Sender<Message>, addr: SocketAddr) {
async fn listen(session: Arc<ExecuteSession>, tx: mpsc::Sender<Message>, addr: SocketAddr) {
let listener = tokio::net::TcpListener::bind(addr)
.await
.unwrap_or_else(|e| panic!("PostgreSQL Server bind fail. err: {}", e));

let processor = Arc::new(HandlerFactory(Arc::new(PostgresqlHandler { tx })));
let processor = Arc::new(HandlerFactory(Arc::new(PostgresqlHandler { session })));

loop {
tokio::select! {
Expand Down
6 changes: 0 additions & 6 deletions src/server/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,3 @@ impl Server {
todo!()
}
}

/// Sends `sql` over `tx` as a `Message::Query` and waits for the reply.
///
/// A fresh oneshot channel is created per call; its sender travels inside the
/// message as `resp` and the receiver is awaited for the result.
///
/// # Errors
///
/// Propagates (via `?`) a failure to send on `tx` or to receive on the
/// oneshot channel; otherwise returns the `Result` delivered on that channel.
pub(crate) async fn send_and_receive(tx: mpsc::Sender<Message>, sql: String) -> Result<Vec<RecordBatch>> {
let (otx, orx) = oneshot::channel();
tx.send(Message::Query { sql, resp: otx }).await?;
orx.await?
}

0 comments on commit 2cfd730

Please sign in to comment.