Skip to content

Commit

Permalink
chore: find out ref columns in subquery
Browse files Browse the repository at this point in the history
  • Loading branch information
holicc committed Dec 9, 2024
1 parent dbd612a commit 45621e3
Show file tree
Hide file tree
Showing 12 changed files with 271 additions and 81 deletions.
9 changes: 8 additions & 1 deletion qurious/src/common/table_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ pub struct TableSchema {
}

impl TableSchema {
pub fn new(field_qualifiers: Vec<Option<TableRelation>>, schema: SchemaRef) -> Self {
Self {
field_qualifiers,
schema,
}
}

pub fn try_from_qualified_schema(relation: impl Into<TableRelation>, schema: SchemaRef) -> Result<Self> {
Ok(Self {
field_qualifiers: vec![Some(relation.into()); schema.fields().len()],
Expand Down Expand Up @@ -48,7 +55,7 @@ impl TableSchema {
.fields()
.iter()
.zip(self.field_qualifiers.iter())
.map(|(f, q)| Column::new(f.name(), q.clone()))
.map(|(f, q)| Column::new(f.name(), q.clone(), false))
.collect()
}
}
Expand Down
6 changes: 6 additions & 0 deletions qurious/src/common/transformed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,9 @@ fn apply_impl<'n, N: TransformNode, F: FnMut(&'n N) -> Result<TreeNodeRecursion>
) -> Result<TreeNodeRecursion> {
f(node)?.visit_children(|| node.apply_children(|c| apply_impl(c, f)))
}

/// Something that holds tree nodes of type `T` and can hand each of them to a
/// visitor callback.
pub trait TreeNodeContainer<'a, T: 'a> {
    /// Runs the visitor `f` over the nodes held by this container.
    ///
    /// The callback receives references that live as long as the container
    /// borrow (`'a`) and reports, through `TreeNodeRecursion`, how the
    /// traversal should proceed. The overall traversal outcome is returned.
    fn apply<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(&'a self, f: F) -> Result<TreeNodeRecursion>;
}
1 change: 1 addition & 0 deletions qurious/src/logical/expr/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ impl AggregateExpr {
LogicalExpr::Column(Column {
name: format!("{}({})", self.op, inner_col),
relation: None,
is_outer_ref: false,
})
})
}
Expand Down
6 changes: 5 additions & 1 deletion qurious/src/logical/expr/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@ use super::LogicalExpr;
/// A reference to a named column, optionally qualified by the table relation
/// it belongs to.
pub struct Column {
/// Name of the column.
pub name: String,
/// Table relation qualifying the column; `None` when the column is unqualified.
pub relation: Option<TableRelation>,
/// Whether this column is an outer reference.
/// NOTE(review): presumably set for columns inside a subquery that refer to
/// the enclosing query — confirm against the planner that sets it.
pub is_outer_ref: bool,
}

impl Column {
pub fn new(name: impl Into<String>, relation: Option<impl Into<TableRelation>>) -> Self {
/// Builds a `Column` from anything convertible into its parts.
///
/// * `name` - column name, converted into an owned `String`.
/// * `relation` - optional qualifier, converted into a `TableRelation` when present.
/// * `is_outer_ref` - stored as-is in the `is_outer_ref` field.
pub fn new(name: impl Into<String>, relation: Option<impl Into<TableRelation>>, is_outer_ref: bool) -> Self {
    let name: String = name.into();
    let relation: Option<TableRelation> = relation.map(Into::into);

    Self {
        name,
        relation,
        is_outer_ref,
    }
}

Expand Down Expand Up @@ -58,6 +60,7 @@ impl FromStr for Column {
Ok(Self {
name: s.to_string(),
relation: None,
is_outer_ref: false,
})
}
}
Expand All @@ -66,5 +69,6 @@ pub fn column(name: &str) -> LogicalExpr {
LogicalExpr::Column(Column {
name: name.to_string(),
relation: None,
is_outer_ref: false,
})
}
27 changes: 20 additions & 7 deletions qurious/src/logical/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ pub enum LogicalExpr {
IsNotNull(Box<LogicalExpr>),
Like(Like),
Negative(Box<LogicalExpr>),
SubQuery(Box<LogicalPlan>),
SubQuery(SubQuery),
}

macro_rules! impl_logical_expr_methods {
Expand Down Expand Up @@ -94,7 +94,7 @@ impl Display for LogicalExpr {
LogicalExpr::Function(function) => write!(f, "{function}",),
LogicalExpr::IsNull(logical_expr) => write!(f, "{} IS NULL", logical_expr),
LogicalExpr::IsNotNull(logical_expr) => write!(f, "{} IS NOT NULLni", logical_expr),
LogicalExpr::SubQuery(logical_plan) => write!(f, "(\n{})\n", utils::format(logical_plan, 5)),
LogicalExpr::SubQuery(subquery) => write!(f, "(\n{})\n", utils::format(&subquery.subquery, 5)),
LogicalExpr::Like(like) => {
if like.negated {
write!(f, "{} NOT LIKE {}", like.expr, like.pattern)
Expand All @@ -107,6 +107,14 @@ impl Display for LogicalExpr {
}

impl LogicalExpr {
/// Returns the qualifier to associate with this expression, if any.
///
/// * A column yields a clone of its own `relation` (which may be `None`).
/// * An alias yields its alias name converted into a `TableRelation`.
/// * Every other expression yields `None`.
pub fn qualified_name(&self) -> Option<TableRelation> {
    if let LogicalExpr::Column(column) = self {
        return column.relation.clone();
    }

    if let LogicalExpr::Alias(alias) = self {
        let qualifier: TableRelation = alias.name.clone().into();
        return Some(qualifier);
    }

    None
}

pub fn rebase_expr(self, base_exprs: &[&LogicalExpr]) -> Result<Self> {
self.transform(|nested_expr| {
if base_exprs.contains(&&nested_expr) {
Expand Down Expand Up @@ -162,7 +170,7 @@ impl LogicalExpr {
LogicalExpr::Column(_) => Ok(self.clone()),
LogicalExpr::AggregateExpr(agg) => agg.as_column(),
LogicalExpr::Literal(_) | LogicalExpr::Wildcard | LogicalExpr::BinaryExpr(_) => Ok(LogicalExpr::Column(
Column::new(format!("{}", self), None::<TableRelation>),
Column::new(format!("{}", self), None::<TableRelation>, false),
)),
_ => Err(Error::InternalError(format!("Expect column, got {:?}", self))),
}
Expand Down Expand Up @@ -196,7 +204,7 @@ impl LogicalExpr {
LogicalExpr::AggregateExpr(AggregateExpr { op, expr }) => op.infer_type(&expr.data_type(schema)?),
LogicalExpr::SortExpr(SortExpr { expr, .. }) | LogicalExpr::Negative(expr) => expr.data_type(schema),
LogicalExpr::Like(_) | LogicalExpr::IsNull(_) | LogicalExpr::IsNotNull(_) => Ok(DataType::Boolean),
LogicalExpr::SubQuery(plan) => Ok(plan.schema().fields[0].data_type().clone()),
LogicalExpr::SubQuery(subquery) => Ok(subquery.subquery.schema().fields[0].data_type().clone()),
_ => internal_err!("[{}] has no data type", self),
}
}
Expand Down Expand Up @@ -252,7 +260,12 @@ impl TransformNode for LogicalExpr {
LogicalExpr::IsNull(expr) => f(*expr)?.update(|expr| LogicalExpr::IsNull(Box::new(expr))),
LogicalExpr::IsNotNull(expr) => f(*expr)?.update(|expr| LogicalExpr::IsNotNull(Box::new(expr))),
LogicalExpr::Negative(expr) => f(*expr)?.update(|expr| LogicalExpr::Negative(Box::new(expr))),
LogicalExpr::SubQuery(plan) => plan.map_exprs(f)?.update(|plan| LogicalExpr::SubQuery(Box::new(plan))),
LogicalExpr::SubQuery(subquery) => subquery.subquery.map_exprs(f)?.update(|plan| {
LogicalExpr::SubQuery(SubQuery {
subquery: Box::new(plan),
outer_ref_columns: subquery.outer_ref_columns,
})
}),

LogicalExpr::Wildcard | LogicalExpr::Column(_) | LogicalExpr::Literal(_) => Transformed::no(self),
LogicalExpr::Like(like) => f(*like.expr)?.update(|expr| {
Expand All @@ -267,7 +280,7 @@ impl TransformNode for LogicalExpr {

fn apply_children<'n, F>(&'n self, mut f: F) -> Result<TreeNodeRecursion>
where
F: FnMut(&'n Self) -> Result<TreeNodeRecursion>,
F: FnMut(&'n LogicalExpr) -> Result<TreeNodeRecursion>,
{
let children = match self {
LogicalExpr::BinaryExpr(BinaryExpr { left, right, .. }) => vec![left.as_ref(), right.as_ref()],
Expand Down Expand Up @@ -312,6 +325,6 @@ pub struct Like {

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SubQuery {
pub subquery: Arc<LogicalPlan>,
pub subquery: Box<LogicalPlan>,
pub outer_ref_columns: Vec<LogicalExpr>,
}
74 changes: 70 additions & 4 deletions qurious/src/logical/plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ pub use sub_query::SubqueryAlias;

use arrow::datatypes::SchemaRef;

use super::expr::LogicalExpr;
use super::expr::{Column, LogicalExpr};
use crate::common::table_relation::TableRelation;
use crate::common::table_schema::TableSchemaRef;
use crate::common::transformed::{TransformNode, Transformed, TransformedResult};
use crate::common::transformed::{TransformNode, Transformed, TransformedResult, TreeNodeContainer, TreeNodeRecursion};
use crate::error::Result;

#[macro_export]
Expand Down Expand Up @@ -70,6 +70,36 @@ pub enum LogicalPlan {
}

impl LogicalPlan {
/// Collects the outer-reference columns used by this plan and all of its
/// descendant plans.
///
/// The plan tree is walked iteratively with an explicit stack. For every plan
/// node, each expression handed out by `apply_exprs` is searched in full —
/// the expression itself and everything nested beneath it — for
/// `LogicalExpr::Column` values whose `is_outer_ref` flag is `true`.
///
/// Each distinct column appears once in the result, in the order it is first
/// encountered.
///
/// # Errors
///
/// Propagates any error returned while visiting expressions.
pub fn outer_ref_columns(&self) -> Result<Vec<LogicalExpr>> {
    let mut outer_ref_columns: Vec<LogicalExpr> = vec![];
    let mut stack = vec![self];

    while let Some(plan) = stack.pop() {
        let recursion = plan.apply_exprs(|expr| {
            // Use `apply` rather than `apply_children`: the expression handed
            // to us may itself be an outer-reference column, and references
            // may sit at any depth (e.g. below an `AND` of comparisons).
            expr.apply(|nested| {
                let is_outer_ref = matches!(nested, LogicalExpr::Column(Column { is_outer_ref: true, .. }));

                // `apply_exprs` may present a nested node more than once
                // (once on its own and once beneath an ancestor), so only
                // record columns that have not been seen yet.
                if is_outer_ref && !outer_ref_columns.contains(nested) {
                    outer_ref_columns.push(nested.clone());
                }

                Ok(TreeNodeRecursion::Continue)
            })
        })?;

        match recursion {
            TreeNodeRecursion::Continue => {
                if let Some(children) = plan.children() {
                    stack.extend(children);
                }
            }
            // A visitor asked to stop: return what has been gathered so far.
            TreeNodeRecursion::Stop => break,
        }
    }

    Ok(outer_ref_columns)
}

pub fn relation(&self) -> Option<TableRelation> {
match self {
LogicalPlan::TableScan(s) => Some(s.table_name.clone()),
Expand Down Expand Up @@ -103,6 +133,7 @@ impl LogicalPlan {
LogicalPlan::CrossJoin(s) => s.schema.clone(),
LogicalPlan::SubqueryAlias(s) => s.schema.clone(),
LogicalPlan::Filter(f) => f.input.table_schema(),
LogicalPlan::Projection(p) => p.schema.clone(),
_ => todo!("[{}] not implement table_schema", self),
}
}
Expand All @@ -124,6 +155,25 @@ impl LogicalPlan {
}
}

/// Visits the expressions owned directly by this plan node with `f`.
///
/// * `Projection` - its `exprs`.
/// * `Aggregate` - its `group_expr`, then its `aggr_expr`.
/// * `Filter` - its predicate `expr`.
/// * Any other plan - nothing is visited and `Continue` is returned.
///
/// Every expression is visited through `apply`, so all three plan kinds are
/// traversed the same way. Child plans are not visited.
///
/// Returns the `TreeNodeRecursion` reported by the traversal. A `Stop`
/// reported while visiting the group expressions of an `Aggregate` ends the
/// traversal immediately and is passed back to the caller.
///
/// # Errors
///
/// Propagates the first error returned by `f`.
pub fn apply_exprs<F>(&self, mut f: F) -> Result<TreeNodeRecursion>
where
    F: FnMut(&LogicalExpr) -> Result<TreeNodeRecursion>,
{
    match self {
        LogicalPlan::Projection(Projection { exprs, .. }) => exprs.apply(f),
        LogicalPlan::Aggregate(Aggregate {
            group_expr, aggr_expr, ..
        }) => match group_expr.apply(&mut f)? {
            // Only move on to the aggregate expressions when the visitor has
            // not asked to stop; the final outcome is whatever they report.
            TreeNodeRecursion::Continue => aggr_expr.apply(f),
            TreeNodeRecursion::Stop => Ok(TreeNodeRecursion::Stop),
        },
        LogicalPlan::Filter(Filter { expr, .. }) => expr.apply(f),
        _ => Ok(TreeNodeRecursion::Continue),
    }
}

pub fn map_exprs<F>(self, mut f: F) -> Result<Transformed<Self>>
where
F: FnMut(LogicalExpr) -> Result<Transformed<LogicalExpr>>,
Expand Down Expand Up @@ -211,14 +261,30 @@ impl TransformNode for LogicalPlan {
})
}

fn apply_children<'n, F>(&'n self, _f: F) -> Result<crate::common::transformed::TreeNodeRecursion>
fn apply_children<'n, F>(&'n self, _f: F) -> Result<TreeNodeRecursion>
where
F: FnMut(&'n Self) -> Result<crate::common::transformed::TreeNodeRecursion>,
F: FnMut(&'n LogicalPlan) -> Result<TreeNodeRecursion>,
{
todo!()
}
}

/// Lets a `Vec` of tree nodes be visited as a single container.
impl<'a, T: TransformNode + 'a> TreeNodeContainer<'a, T> for Vec<T> {
    /// Calls `apply` on each element in order, passing the same visitor `f`.
    ///
    /// Traversal ends early with `Stop` as soon as one element reports `Stop`;
    /// otherwise `Continue` is returned once every element has been visited.
    /// Errors from an element's traversal are propagated immediately.
    fn apply<F>(&'a self, mut f: F) -> Result<TreeNodeRecursion>
    where
        F: FnMut(&'a T) -> Result<TreeNodeRecursion>,
    {
        for node in self.iter() {
            if let TreeNodeRecursion::Stop = node.apply(&mut f)? {
                return Ok(TreeNodeRecursion::Stop);
            }
        }

        Ok(TreeNodeRecursion::Continue)
    }
}

pub fn base_plan(plan: &LogicalPlan) -> &LogicalPlan {
match plan {
LogicalPlan::Aggregate(Aggregate { input, .. }) => base_plan(&input),
Expand Down
25 changes: 16 additions & 9 deletions qurious/src/logical/plan/projection.rs
Original file line number Diff line number Diff line change
@@ -1,31 +1,38 @@
use arrow::datatypes::{FieldRef, Schema, SchemaRef};
use arrow::datatypes::{Schema, SchemaRef};

use crate::common::table_schema::{TableSchema, TableSchemaRef};
use crate::error::Result;
use crate::{logical::expr::LogicalExpr, logical::plan::LogicalPlan};
use std::fmt::Display;
use std::sync::Arc;

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Projection {
pub schema: SchemaRef,
pub schema: TableSchemaRef,
pub input: Box<LogicalPlan>,
pub exprs: Vec<LogicalExpr>,
}

impl Projection {
pub fn try_new(input: LogicalPlan, exprs: Vec<LogicalExpr>) -> Result<Self> {
let mut field_qualifiers = vec![];
let mut fields = vec![];

for expr in &exprs {
field_qualifiers.push(expr.qualified_name());
fields.push(expr.field(&input)?);
}

let schema = TableSchema::new(field_qualifiers, Arc::new(Schema::new(fields)));

Ok(Self {
schema: exprs
.iter()
.map(|f| f.field(&input))
.collect::<Result<Vec<FieldRef>>>()
.map(|fields| Arc::new(Schema::new(fields)))?,
schema: Arc::new(schema),
input: Box::new(input),
exprs,
})
}

pub fn try_new_with_schema(input: LogicalPlan, exprs: Vec<LogicalExpr>, schema: SchemaRef) -> Result<Self> {
pub fn try_new_with_schema(input: LogicalPlan, exprs: Vec<LogicalExpr>, schema: TableSchemaRef) -> Result<Self> {
Ok(Self {
schema,
input: Box::new(input),
Expand All @@ -34,7 +41,7 @@ impl Projection {
}

pub fn schema(&self) -> SchemaRef {
self.schema.clone()
self.schema.arrow_schema()
}

pub fn children(&self) -> Option<Vec<&LogicalPlan>> {
Expand Down
2 changes: 1 addition & 1 deletion qurious/src/optimizer/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
mod count_wildcard_rule;
mod pushdown_filter_inner_join;
mod scalar_subquery_to_join;
// mod scalar_subquery_to_join;
mod type_coercion;

use crate::{error::Result, logical::plan::LogicalPlan};
Expand Down
11 changes: 8 additions & 3 deletions qurious/src/optimizer/pushdown_filter_inner_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::common::join_type::JoinType;
use crate::common::transformed::{TransformNode, Transformed, TransformedResult};
use crate::datatypes::operator::Operator;
use crate::error::{Error, Result};
use crate::logical::expr::{BinaryExpr, Column, LogicalExpr};
use crate::logical::expr::{BinaryExpr, Column, LogicalExpr, SubQuery};
use crate::logical::plan::{CrossJoin, Filter, LogicalPlan};
use crate::logical::LogicalPlanBuilder;

Expand Down Expand Up @@ -45,8 +45,13 @@ impl OptimizerRule for PushdownFilterInnerJoin {
.map_exprs(|expr| {
expr.transform(|expr| match expr {
LogicalExpr::SubQuery(query) => self
.optimize(*query)
.map(|rewritten_query| LogicalExpr::SubQuery(Box::new(rewritten_query)))
.optimize(*query.subquery)
.map(|rewritten_query| {
LogicalExpr::SubQuery(SubQuery {
subquery: Box::new(rewritten_query),
outer_ref_columns: query.outer_ref_columns,
})
})
.map(Transformed::yes),
_ => Ok(Transformed::no(expr)),
})
Expand Down
Loading

0 comments on commit 45621e3

Please sign in to comment.