Skip to content

Commit

Permalink
Fix nested count optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
Dandandan committed Dec 7, 2023
1 parent 33fc110 commit e7daf66
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 3 deletions.
3 changes: 3 additions & 0 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3134,6 +3134,9 @@ digraph {
.is_nullable());
}




#[test]
fn test_filter_is_scalar() {
// test empty placeholder
Expand Down
43 changes: 40 additions & 3 deletions datafusion/optimizer/src/optimize_projections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ fn optimize_projections(
let new_group_bys = aggregate.group_expr.clone();

// Only use absolutely necessary aggregate expressions required by parent.
let new_aggr_expr = get_at_indices(&aggregate.aggr_expr, &aggregate_reqs);
let mut new_aggr_expr = get_at_indices(&aggregate.aggr_expr, &aggregate_reqs);
let all_exprs_iter = new_group_bys.iter().chain(new_aggr_expr.iter());
let necessary_indices =
indices_referred_by_exprs(&aggregate.input, all_exprs_iter)?;
Expand All @@ -213,6 +213,11 @@ fn optimize_projections(
let (aggregate_input, _is_added) =
add_projection_on_top_if_helpful(aggregate_input, necessary_exprs, true)?;

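// An Aggregate needs at least one group-by or aggregate expression: if pruning removed every aggregate expression and there is no group-by, keep the first original aggregate expression.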
if new_aggr_expr.len() == 0 && new_group_bys.len() == 0 && aggregate.aggr_expr.len() > 0 {
new_aggr_expr = vec![aggregate.aggr_expr[0].clone()];
}

// Create new aggregate plan with updated input, and absolutely necessary fields.
return Aggregate::try_new(
Arc::new(aggregate_input),
Expand Down Expand Up @@ -857,10 +862,11 @@ fn rewrite_projection_given_requirements(
#[cfg(test)]
mod tests {
use crate::optimize_projections::OptimizeProjections;
use datafusion_common::Result;
use arrow::datatypes::{Schema, Field, DataType};
use datafusion_common::{Result, TableReference};
use datafusion_expr::{
binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, LogicalPlan,
Operator,
Operator, count, Expr, table_scan,
};
use std::sync::Arc;

Expand Down Expand Up @@ -909,4 +915,35 @@ mod tests {
\n TableScan: test projection=[a]";
assert_optimized_plan_equal(&plan, expected)
}
/// Regression test for nested `COUNT` aggregates that have no group-by expressions.
///
/// The outer `COUNT(Int32(1))` references no column produced by the inner
/// aggregate. The expected plan asserts that the inner aggregate nevertheless
/// keeps its own `COUNT` expression after optimization, sitting under an empty
/// projection.
#[test]
fn test_nested_count() -> Result<()> {
    let schema = Schema::new(vec![Field::new("foo", DataType::Int32, false)]);

    // Both aggregates are global: neither has a group-by expression.
    let groups: Vec<Expr> = vec![];

    // Propagate builder errors through the test's `Result` instead of panicking.
    let plan = table_scan(TableReference::none(), &schema, None)?
        .aggregate(groups.clone(), vec![count(lit(1))])?
        .aggregate(groups, vec![count(lit(1))])?
        .build()?;

    let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(Int32(1))]]\
    \n Projection: \
    \n Aggregate: groupBy=[[]], aggr=[[COUNT(Int32(1))]]\
    \n TableScan: ?table? projection=[]";
    assert_optimized_plan_equal(&plan, expected)
}
}

0 comments on commit e7daf66

Please sign in to comment.