Skip to content

Commit

Permalink
feat: HAVING clause
Browse files Browse the repository at this point in the history
  • Loading branch information
holicc committed Oct 22, 2024
1 parent 60eb0ee commit 30ad83b
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 103 deletions.
2 changes: 1 addition & 1 deletion qurious/src/execution/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ mod tests {
// session.sql("INSERT INTO test VALUES (1, 1), (2, 2), (3, 3), (3, 5), (NULL, NULL);")?;
// session.sql("select a, b, c, d from x join y on a = c")?;
println!("++++++++++++++");
let batch = session.sql("select v1, v2 from t order by v1 asc, v2 desc")?;
let batch = session.sql("select 42")?;

print_batches(&batch)?;

Expand Down
12 changes: 11 additions & 1 deletion qurious/src/logical/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::Arc;

use super::{
expr::{LogicalExpr, SortExpr},
plan::{Aggregate, CrossJoin, EmptyRelation, Join, Limit, LogicalPlan, Projection, Sort, TableScan},
plan::{Aggregate, CrossJoin, EmptyRelation, Filter, Join, Limit, LogicalPlan, Projection, Sort, TableScan},
};
use crate::{common::join_type::JoinType, provider::table::TableProvider};
use crate::{common::table_relation::TableRelation, error::Result};
Expand All @@ -27,6 +27,16 @@ impl LogicalPlanBuilder {
Projection::try_new(input, exprs.into_iter().map(|exp| exp.into()).collect()).map(LogicalPlan::Projection)
}

/// Builds a `LogicalPlan::Filter` node that applies `predicate` to `input`.
///
/// # Errors
/// Returns whatever error `Filter::try_new` reports for this input/predicate pair.
pub fn filter(input: LogicalPlan, predicate: LogicalExpr) -> Result<LogicalPlan> {
    let node = Filter::try_new(input, predicate)?;
    Ok(LogicalPlan::Filter(node))
}

pub fn having(self, predicate: LogicalExpr) -> Result<Self> {
Ok(LogicalPlanBuilder {
plan: LogicalPlan::Filter(Filter::try_new(self.plan, predicate)?.into()),
})
}

pub fn add_project(self, exprs: impl IntoIterator<Item = impl Into<LogicalExpr>>) -> Result<Self> {
Projection::try_new(self.plan, exprs.into_iter().map(|exp| exp.into()).collect())
.map(|s| LogicalPlanBuilder::from(LogicalPlan::Projection(s)))
Expand Down
12 changes: 11 additions & 1 deletion qurious/src/logical/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pub use literal::*;
pub use sort::*;

use crate::common::table_relation::TableRelation;
use crate::common::transformed::{TransformNode, Transformed, TreeNodeRecursion};
use crate::common::transformed::{TransformNode, Transformed, TransformedResult, TreeNodeRecursion};
use crate::datatypes::scalar::ScalarValue;
use crate::error::{Error, Result};
use crate::internal_err;
Expand Down Expand Up @@ -97,6 +97,16 @@ impl Display for LogicalExpr {
}

impl LogicalExpr {
/// Rewrites this expression so that every nested expression equal to one of
/// `base_exprs` is replaced by the result of its `as_column()` conversion.
///
/// Nested expressions that match nothing in `base_exprs` are left as they are.
/// The SQL planner calls this with the aggregate and GROUP BY expressions.
///
/// # Errors
/// Propagates any error from `as_column()` or from the `transform` traversal.
pub fn rebase_expr(self, base_exprs: &[&LogicalExpr]) -> Result<Self> {
    self.transform(|candidate| {
        let is_base = base_exprs.iter().any(|base| **base == candidate);
        if is_base {
            candidate.as_column().map(Transformed::yes)
        } else {
            Ok(Transformed::no(candidate))
        }
    })
    .data()
}

pub fn using_columns(&self) -> HashSet<Column> {
let mut columns = HashSet::new();
let mut stack = vec![self];
Expand Down
156 changes: 97 additions & 59 deletions qurious/src/planner/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -502,48 +502,80 @@ impl<'a> SqlQueryPlanner<'a> {
let plan = self.table_scan_to_plan(select.from)?;
let empty_from = matches!(plan, LogicalPlan::EmptyRelation(_));
// process the WHERE clause
let mut plan = self.filter_expr(plan, select.r#where)?;
let plan = self.filter_expr(plan, select.r#where)?;
// process the SELECT expressions
let column_exprs = self.column_exprs(&plan, empty_from, select.columns)?;
// sort exprs
let sort_exprs = self.order_by_exprs(select.order_by.unwrap_or_default())?;
// get aggregate expressions
let aggr_exprs = find_aggregate_exprs(&column_exprs);
// process the HAVING clause
let having = if let Some(having_expr) = select.having {
Some(self.sql_to_expr(having_expr)?)
} else {
None
};
let having = select.having.map(|expr| self.sql_to_expr(expr)).transpose()?;
// get aggregate expressions
let aggr_exprs = find_aggregate_exprs(column_exprs.iter().chain(having.iter()));
// process the GROUP BY clause or process aggregation in SELECT
if select.group_by.is_some() || !aggr_exprs.is_empty() {
let group_by_exprs = select
.group_by
.unwrap_or_default()
.into_iter()
.map(|expr| {
let col = self.sql_to_expr(expr)?;

col.transform(|expr| match expr {
LogicalExpr::Column(col) => {
if col.relation.is_none() {
if let Some(data) = self.get_column_alias(&col.name) {
return Ok(Transformed::yes(data));
let (mut plan, select_exprs_post_aggr, having_expr_post_aggr) =
if select.group_by.is_some() || !aggr_exprs.is_empty() {
let group_by_exprs = select
.group_by
.unwrap_or_default()
.into_iter()
.map(|expr| {
let col = self.sql_to_expr(expr)?;

col.transform(|expr| match expr {
LogicalExpr::Column(col) => {
if col.relation.is_none() {
if let Some(data) = self.get_column_alias(&col.name) {
return Ok(Transformed::yes(data));
}
}
Ok(Transformed::no(LogicalExpr::Column(col)))
}
Ok(Transformed::no(LogicalExpr::Column(col)))
}
_ => Ok(Transformed::no(expr)),
_ => Ok(Transformed::no(expr)),
})
.data()
})
.data()
})
.collect::<Result<_>>()?;
.collect::<Result<_>>()?;

let having = having
.map(|expr| {
expr.transform(|expr| match expr {
LogicalExpr::Column(col) => {
if col.relation.is_none() {
if let Some(data) = self.get_column_alias(&col.name) {
return Ok(Transformed::yes(data));
}
}
Ok(Transformed::no(LogicalExpr::Column(col)))
}
_ => Ok(Transformed::no(expr)),
})
.data()
})
.transpose()?;

plan = self.aggregate_plan(plan, column_exprs.clone(), aggr_exprs, group_by_exprs, having)?;
} else {
plan = LogicalPlanBuilder::project(plan, column_exprs)?;
}
self.aggregate_plan(plan, column_exprs.clone(), aggr_exprs, group_by_exprs, having)?
} else {
match having {
Some(having_expr) => {
// // allow scalar having
// having_expr.apply(|expr| {

// })
return internal_err!(
"HAVING clause [{having_expr}] requires a GROUP BY clause or be used in an aggregate function"
);
}
None => (plan, column_exprs, None),
}
};
// process the HAVING clause
if let Some(having_expr) = having_expr_post_aggr {
plan = LogicalPlanBuilder::from(plan)
.having(having_expr)
.map(|builder| builder.build())?;
}
// do the final projection
plan = LogicalPlanBuilder::project(plan, select_exprs_post_aggr)?;
// process the ORDER BY clause
let plan = if !sort_exprs.is_empty() {
LogicalPlanBuilder::from(plan)
Expand Down Expand Up @@ -692,29 +724,33 @@ impl<'a> SqlQueryPlanner<'a> {
select_exprs: Vec<LogicalExpr>,
aggr_exprs: Vec<LogicalExpr>,
group_exprs: Vec<LogicalExpr>,
_having_expr: Option<LogicalExpr>,
) -> Result<LogicalPlan> {
having: Option<LogicalExpr>,
) -> Result<(LogicalPlan, Vec<LogicalExpr>, Option<LogicalExpr>)> {
let agg_and_group_by_column_exprs = aggr_exprs.iter().chain(group_exprs.iter()).collect::<Vec<_>>();

let select_exprs_post_aggr = select_exprs
.into_iter()
.map(|expr| {
expr.transform(|nested_expr| {
// if expr is one of the group by columns or aggregate columns, we should convert to column
if agg_and_group_by_column_exprs.contains(&&nested_expr) {
return nested_expr.as_column().map(Transformed::yes);
}
Ok(Transformed::no(nested_expr))
})
.data()
})
.map(|expr| expr.rebase_expr(&agg_and_group_by_column_exprs))
.collect::<Result<Vec<_>>>()?;
let having_expr_post_aggr = having
.map(|expr| expr.rebase_expr(&agg_and_group_by_column_exprs))
.transpose()?;

let agg_and_group_columns = agg_and_group_by_column_exprs
.iter()
.map(|expr| expr.as_column())
.collect::<Result<Vec<_>>>()?;

for col_expr in select_exprs_post_aggr.iter().flat_map(find_columns_exprs) {
let mut check_columns = select_exprs_post_aggr
.iter()
.flat_map(find_columns_exprs)
.collect::<Vec<_>>();

if let Some(having_expr) = &having_expr_post_aggr {
check_columns.extend(find_columns_exprs(having_expr));
}

for col_expr in check_columns {
if !agg_and_group_columns.contains(&col_expr) {
return internal_err!("column [{}] must appear in the GROUP BY clause or be used in an aggregate function, validate columns: [{}]",
col_expr,
Expand All @@ -726,10 +762,11 @@ impl<'a> SqlQueryPlanner<'a> {
}
}

LogicalPlanBuilder::from(input)
.aggregate(group_exprs, aggr_exprs)?
.add_project(select_exprs_post_aggr)
.map(|plan| plan.build())
let plan = LogicalPlanBuilder::from(input)
.aggregate(group_exprs, aggr_exprs)
.map(|plan| plan.build())?;

Ok((plan, select_exprs_post_aggr, having_expr_post_aggr))
}

fn order_by_exprs(&self, order_by: Vec<(Expression, Order)>) -> Result<Vec<SortExpr>> {
Expand Down Expand Up @@ -1071,9 +1108,9 @@ fn find_columns_exprs(expr: &LogicalExpr) -> Vec<LogicalExpr> {
columns
}

fn find_aggregate_exprs(exprs: &Vec<LogicalExpr>) -> Vec<LogicalExpr> {
fn find_aggregate_exprs<'a>(exprs: impl IntoIterator<Item = &'a LogicalExpr>) -> Vec<LogicalExpr> {
exprs
.iter()
.into_iter()
.flat_map(|expr| {
let mut exprs = vec![];
expr.apply(|nested_expr| {
Expand Down Expand Up @@ -1393,17 +1430,18 @@ mod tests {

#[test]
fn test_group_by() {
let sql = "SELECT name,max(name) FROM person GROUP BY name";
let expected = "Projection: (person.name, MAX(person.name))\n Aggregate: group_expr=[person.name], aggregat_expr=[MAX(person.name)]\n TableScan: person\n";
quick_test(sql, expected);
quick_test("SELECT name FROM person HAVING count(name) > 1", "Internal Error: column [person.name] must appear in the GROUP BY clause or be used in an aggregate function, validate columns: [COUNT(person.name)]");

quick_test(
"SELECT name FROM person WHERE name = 'abc' GROUP BY name HAVING count(name) > 1",
"Projection: (person.name)\n Filter: COUNT(person.name) > Int64(1)\n Aggregate: group_expr=[person.name], aggregat_expr=[COUNT(person.name)]\n Filter: person.name = Utf8('abc')\n TableScan: person\n",
);

quick_test("SELECT name,max(name) FROM person GROUP BY name", "Projection: (person.name, MAX(person.name))\n Aggregate: group_expr=[person.name], aggregat_expr=[MAX(person.name)]\n TableScan: person\n");

let sql = "SELECT name, COUNT(*) FROM person GROUP BY name";
let expected = "Projection: (person.name, COUNT(*))\n Aggregate: group_expr=[person.name], aggregat_expr=[COUNT(*)]\n TableScan: person\n";
quick_test(sql, expected);
quick_test("SELECT name, COUNT(*) FROM person GROUP BY name", "Projection: (person.name, COUNT(*))\n Aggregate: group_expr=[person.name], aggregat_expr=[COUNT(*)]\n TableScan: person\n");

let sql = "SELECT * FROM person GROUP BY name";
let expected = "Internal Error: column [person.age] must appear in the GROUP BY clause or be used in an aggregate function, validate columns: [person.name]";
quick_test(sql, expected);
quick_test("SELECT * FROM person GROUP BY name", "Internal Error: column [person.age] must appear in the GROUP BY clause or be used in an aggregate function, validate columns: [person.name]");
}

#[test]
Expand Down
11 changes: 0 additions & 11 deletions qurious/tests/having.slt → qurious/tests/sql/having.slt
Original file line number Diff line number Diff line change
@@ -1,14 +1,3 @@
# scalar having
query I
select 42 having 42 > 18
----
42

# scalar having
query I
select 42 having 42 > 801
----

statement ok
CREATE TABLE test (x INT, y INT);

Expand Down
Loading

0 comments on commit 30ad83b

Please sign in to comment.