Skip to content

Commit

Permalink
Support filter in cross join elimination
Browse files Browse the repository at this point in the history
  • Loading branch information
Dandandan committed Oct 20, 2024
1 parent 80a9e7f commit 228d747
Showing 1 changed file with 13 additions and 16 deletions.
29 changes: 13 additions & 16 deletions datafusion/optimizer/src/eliminate_cross_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use crate::{OptimizerConfig, OptimizerRule};

use crate::join_key_set::JoinKeySet;
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{internal_err, Result};
use datafusion_common::Result;
use datafusion_expr::expr::{BinaryExpr, Expr};
use datafusion_expr::logical_plan::{
Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection,
Expand Down Expand Up @@ -269,13 +269,7 @@ fn flatten_join_inputs(
fn can_flatten_join_inputs(plan: &LogicalPlan) -> bool {
// can only flatten inner / cross joins
match plan {
LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {
// The filter of inner join will lost, skip this rule.
// issue: https://github.com/apache/datafusion/issues/4844
// if join.filter.is_some() {
// return false;
// }
}
LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {}
LogicalPlan::CrossJoin(_) => {}
_ => return false,
};
Expand Down Expand Up @@ -483,12 +477,6 @@ mod tests {
assert_eq!(&starting_schema, optimized_plan.schema())
}

fn assert_optimization_rule_fails(plan: LogicalPlan) {
let rule = EliminateCrossJoin::new();
let transformed_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap();
assert!(!transformed_plan.transformed)
}

#[test]
fn eliminate_cross_with_simple_and() -> Result<()> {
let t1 = test_table_scan_with_name("t1")?;
Expand Down Expand Up @@ -658,7 +646,6 @@ mod tests {
}

#[test]
/// See https://github.com/apache/datafusion/issues/7530
fn eliminate_cross_not_possible_nested_inner_join_with_filter() -> Result<()> {
let t1 = test_table_scan_with_name("t1")?;
let t2 = test_table_scan_with_name("t2")?;
Expand All @@ -676,7 +663,17 @@ mod tests {
.filter(col("t1.a").gt(lit(15u32)))?
.build()?;

assert_optimization_rule_fails(plan);
let expected = vec![
"Filter: t1.a > UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Filter: t1.a > UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"
];

assert_optimized_plan_eq(plan, expected);

Ok(())
}
Expand Down

0 comments on commit 228d747

Please sign in to comment.