Skip to content

Commit

Permalink
fix: rewrite fetch, skip of the Limit node in correct order (#14496)
Browse files Browse the repository at this point in the history
* fix: rewrite fetch, skip of the Limit node in correct order

* style: fix clippy
  • Loading branch information
evenyag authored Feb 6, 2025
1 parent d1308f0 commit 7fd04a3
Showing 1 changed file with 46 additions and 18 deletions.
64 changes: 46 additions & 18 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -959,8 +959,9 @@ impl LogicalPlan {
expr.len()
);
}
let new_skip = skip.as_ref().and_then(|_| expr.pop());
// `LogicalPlan::expressions()` returns in [skip, fetch] order, so we can pop from the end.
let new_fetch = fetch.as_ref().and_then(|_| expr.pop());
let new_skip = skip.as_ref().and_then(|_| expr.pop());
let input = self.only_input(inputs)?;
Ok(LogicalPlan::Limit(Limit {
skip: new_skip.map(Box::new),
Expand Down Expand Up @@ -4293,23 +4294,50 @@ digraph {

#[test]
fn test_limit_with_new_children() {
let limit = LogicalPlan::Limit(Limit {
skip: None,
fetch: Some(Box::new(Expr::Literal(
ScalarValue::new_ten(&DataType::UInt32).unwrap(),
))),
input: Arc::new(LogicalPlan::Values(Values {
schema: Arc::new(DFSchema::empty()),
values: vec![vec![]],
})),
});
let new_limit = limit
.with_new_exprs(
limit.expressions(),
limit.inputs().into_iter().cloned().collect(),
)
.unwrap();
assert_eq!(limit, new_limit);
let input = Arc::new(LogicalPlan::Values(Values {
schema: Arc::new(DFSchema::empty()),
values: vec![vec![]],
}));
let cases = [
LogicalPlan::Limit(Limit {
skip: None,
fetch: None,
input: Arc::clone(&input),
}),
LogicalPlan::Limit(Limit {
skip: None,
fetch: Some(Box::new(Expr::Literal(
ScalarValue::new_ten(&DataType::UInt32).unwrap(),
))),
input: Arc::clone(&input),
}),
LogicalPlan::Limit(Limit {
skip: Some(Box::new(Expr::Literal(
ScalarValue::new_ten(&DataType::UInt32).unwrap(),
))),
fetch: None,
input: Arc::clone(&input),
}),
LogicalPlan::Limit(Limit {
skip: Some(Box::new(Expr::Literal(
ScalarValue::new_one(&DataType::UInt32).unwrap(),
))),
fetch: Some(Box::new(Expr::Literal(
ScalarValue::new_ten(&DataType::UInt32).unwrap(),
))),
input,
}),
];

for limit in cases {
let new_limit = limit
.with_new_exprs(
limit.expressions(),
limit.inputs().into_iter().cloned().collect(),
)
.unwrap();
assert_eq!(limit, new_limit);
}
}

#[test]
Expand Down

0 comments on commit 7fd04a3

Please sign in to comment.