Skip to content

Commit

Permalink
ARROW-4464: [Rust] [DataFusion] Add support for LIMIT
Browse files Browse the repository at this point in the history
Author: Nicolas Trinquier <[email protected]>

Closes apache#3669 from ntrinquier/arrow-4464 and squashes the following commits:

facc5c2 <Nicolas Trinquier> Add Limit case to ProjectionPushDown
2ed488c <Nicolas Trinquier> Merge remote-tracking branch 'upstream/master' into arrow-4464
c78ae2c <Nicolas Trinquier> Use the previous batch's schema for Limit
e93df93 <Nicolas Trinquier> Remove redundant variable
dbc639f <Nicolas Trinquier> Make limit a usize and avoid evaluating the limit expression
eac5a24 <Nicolas Trinquier> Add support for Limit
  • Loading branch information
Nicolas Trinquier authored and andygrove committed Feb 17, 2019
1 parent 5e2445b commit 811c7dc
Show file tree
Hide file tree
Showing 7 changed files with 313 additions and 2 deletions.
36 changes: 35 additions & 1 deletion rust/datafusion/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use std::collections::HashMap;
use std::rc::Rc;
use std::sync::Arc;

use arrow::datatypes::{Field, Schema};
use arrow::datatypes::*;

use super::super::dfparser::{DFASTNode, DFParser};
use super::super::logicalplan::*;
Expand All @@ -30,6 +30,7 @@ use super::datasource::DataSource;
use super::error::{ExecutionError, Result};
use super::expression::*;
use super::filter::FilterRelation;
use super::limit::LimitRelation;
use super::projection::ProjectRelation;
use super::relation::{DataSourceRelation, Relation};

Expand Down Expand Up @@ -160,6 +161,39 @@ impl ExecutionContext {

Ok(Rc::new(RefCell::new(rel)))
}
LogicalPlan::Limit {
ref expr,
ref input,
..
} => {
let input_rel = self.execute(input)?;

let input_schema = input_rel.as_ref().borrow().schema().clone();

match expr {
&Expr::Literal(ref scalar_value) => {
let limit: usize = match scalar_value {
ScalarValue::Int8(x) => Ok(*x as usize),
ScalarValue::Int16(x) => Ok(*x as usize),
ScalarValue::Int32(x) => Ok(*x as usize),
ScalarValue::Int64(x) => Ok(*x as usize),
ScalarValue::UInt8(x) => Ok(*x as usize),
ScalarValue::UInt16(x) => Ok(*x as usize),
ScalarValue::UInt32(x) => Ok(*x as usize),
ScalarValue::UInt64(x) => Ok(*x as usize),
_ => Err(ExecutionError::ExecutionError(
"Limit only support positive integer literals"
.to_string(),
)),
}?;
let rel = LimitRelation::new(input_rel, limit, input_schema);
Ok(Rc::new(RefCell::new(rel)))
}
_ => Err(ExecutionError::ExecutionError(
"Limit only support positive integer literals".to_string(),
)),
}
}

_ => unimplemented!(),
}
Expand Down
182 changes: 182 additions & 0 deletions rust/datafusion/src/execution/limit.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Execution of a limit relation (returns at most a fixed number of rows)
use std::cell::RefCell;
use std::rc::Rc;
use std::sync::Arc;

use arrow::array::*;
use arrow::datatypes::{DataType, Schema};
use arrow::record_batch::RecordBatch;

use super::error::{ExecutionError, Result};
use super::relation::Relation;

/// Relation that forwards batches from its input until `limit` rows have been
/// produced, truncating the batch that crosses the limit.
pub struct LimitRelation {
/// Upstream relation that batches are pulled from
input: Rc<RefCell<Relation>>,
/// Schema reported by `schema()` and used when building truncated batches
schema: Arc<Schema>,
/// Maximum total number of rows this relation may emit
limit: usize,
/// Number of rows emitted so far
num_consumed_rows: usize,
}

impl LimitRelation {
pub fn new(input: Rc<RefCell<Relation>>, limit: usize, schema: Arc<Schema>) -> Self {
Self {
input,
schema,
limit,
num_consumed_rows: 0,
}
}
}

impl Relation for LimitRelation {
    /// Returns the next batch, truncated if necessary so that the total number of
    /// rows produced by this relation never exceeds `limit`.
    ///
    /// Returns `Ok(None)` once the limit has been reached or the input is
    /// exhausted. Errors from the input relation or from truncating a column are
    /// propagated.
    fn next(&mut self) -> Result<Option<RecordBatch>> {
        // Check the remaining capacity *before* touching the input so that no
        // further batch is read once the limit is reached (including LIMIT 0).
        // `num_consumed_rows` only ever grows up to `limit`, so this cannot underflow.
        let capacity = self.limit - self.num_consumed_rows;
        if capacity == 0 {
            return Ok(None);
        }

        let batch = match self.input.borrow_mut().next()? {
            Some(batch) => batch,
            None => return Ok(None),
        };

        if batch.num_rows() > capacity {
            // Only the batch that crosses the limit needs to be rebuilt.
            let limited_columns = (0..batch.num_columns())
                .map(|i| limit(batch.column(i).as_ref(), capacity))
                .collect::<Result<Vec<ArrayRef>>>()?;
            self.num_consumed_rows += capacity;

            Ok(Some(RecordBatch::new(self.schema.clone(), limited_columns)))
        } else {
            // The whole batch fits (possibly exactly): forward it without copying.
            self.num_consumed_rows += batch.num_rows();
            Ok(Some(batch))
        }
    }

    /// Schema of the batches produced by this relation.
    fn schema(&self) -> &Arc<Schema> {
        &self.schema
    }
}

/// Copies the first `$n` slots of the primitive array `$array` (downcast to
/// `$array_ty`) into a freshly built array, preserving null slots.
macro_rules! limit_primitive_array {
    ($array:expr, $n:expr, $array_ty:ident) => {{
        let b = $array
            .as_any()
            .downcast_ref::<$array_ty>()
            .expect("array type must match its DataType");
        let mut builder = $array_ty::builder($n);
        for i in 0..$n {
            // Reading `value(i)` for a null slot would turn the null into an
            // arbitrary non-null value, so nulls are carried over explicitly.
            if b.is_null(i) {
                builder.append_null()?;
            } else {
                builder.append_value(b.value(i))?;
            }
        }
        Ok(Arc::new(builder.finish()))
    }};
}

//TODO: move into Arrow array_ops
/// Returns a new array holding the first `num_rows_to_read` elements of `a`.
///
/// If `a` has fewer than `num_rows_to_read` elements, all of them are returned.
///
/// # Errors
///
/// Returns an `ExecutionError` if the array's data type is not supported, or if
/// appending to the underlying Arrow builder fails.
fn limit(a: &Array, num_rows_to_read: usize) -> Result<ArrayRef> {
    // Never read past the end of the source array.
    let num_rows = num_rows_to_read.min(a.len());

    match a.data_type() {
        DataType::UInt8 => limit_primitive_array!(a, num_rows, UInt8Array),
        DataType::UInt16 => limit_primitive_array!(a, num_rows, UInt16Array),
        DataType::UInt32 => limit_primitive_array!(a, num_rows, UInt32Array),
        DataType::UInt64 => limit_primitive_array!(a, num_rows, UInt64Array),
        DataType::Int8 => limit_primitive_array!(a, num_rows, Int8Array),
        DataType::Int16 => limit_primitive_array!(a, num_rows, Int16Array),
        DataType::Int32 => limit_primitive_array!(a, num_rows, Int32Array),
        DataType::Int64 => limit_primitive_array!(a, num_rows, Int64Array),
        DataType::Float32 => limit_primitive_array!(a, num_rows, Float32Array),
        DataType::Float64 => limit_primitive_array!(a, num_rows, Float64Array),
        DataType::Utf8 => {
            //TODO: this is inefficient and we should improve the Arrow impl to help make this more concise
            // NOTE(review): null strings are not preserved here; confirm whether
            // BinaryArray can be built with nulls before relying on this for
            // nullable Utf8 columns.
            let b = a
                .as_any()
                .downcast_ref::<BinaryArray>()
                .expect("array type must match its DataType");
            let values: Vec<String> = (0..num_rows).map(|i| b.get_string(i)).collect();
            let tmp: Vec<&str> = values.iter().map(|s| s.as_str()).collect();
            Ok(Arc::new(BinaryArray::from(tmp)))
        }
        other => Err(ExecutionError::ExecutionError(format!(
            "limit not supported for {:?}",
            other
        ))),
    }
}
1 change: 1 addition & 0 deletions rust/datafusion/src/execution/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ pub mod datasource;
pub mod error;
pub mod expression;
pub mod filter;
pub mod limit;
pub mod physicalplan;
pub mod projection;
pub mod relation;
15 changes: 15 additions & 0 deletions rust/datafusion/src/logicalplan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,12 @@ pub enum LogicalPlan {
},
/// An empty relation with an empty schema
EmptyRelation { schema: Arc<Schema> },
/// Limits the number of records returned by `input`; `expr` is the maximum row count
Limit {
expr: Expr,
input: Rc<LogicalPlan>,
schema: Arc<Schema>,
},
}

impl LogicalPlan {
Expand All @@ -362,6 +368,7 @@ impl LogicalPlan {
LogicalPlan::Selection { input, .. } => input.schema(),
LogicalPlan::Aggregate { schema, .. } => &schema,
LogicalPlan::Sort { schema, .. } => &schema,
LogicalPlan::Limit { schema, .. } => &schema,
}
}
}
Expand Down Expand Up @@ -430,6 +437,14 @@ impl LogicalPlan {
}
input.fmt_with_indent(f, indent + 1)
}
LogicalPlan::Limit {
ref input,
ref expr,
..
} => {
write!(f, "Limit: {:?}", expr)?;
input.fmt_with_indent(f, indent + 1)
}
}
}
}
Expand Down
9 changes: 9 additions & 0 deletions rust/datafusion/src/optimizer/projection_push_down.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,15 @@ impl ProjectionPushDown {
projection: Some(projection),
}))
}
LogicalPlan::Limit {
expr,
input,
schema,
} => Ok(Rc::new(LogicalPlan::Limit {
expr: expr.clone(),
input: input.clone(),
schema: schema.clone(),
})),
}
}

Expand Down
19 changes: 18 additions & 1 deletion rust/datafusion/src/sqlplanner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ impl SqlToRel {
ref relation,
ref selection,
ref order_by,
ref limit,
ref group_by,
ref having,
..
Expand Down Expand Up @@ -167,7 +168,22 @@ impl SqlToRel {
_ => projection,
};

Ok(Rc::new(order_by_plan))
let limit_plan = match limit {
&Some(ref limit_expr) => {
let input_schema = order_by_plan.schema();
let limit_rex =
self.sql_to_rex(&limit_expr, &input_schema.clone())?;

LogicalPlan::Limit {
expr: limit_rex,
input: Rc::new(order_by_plan.clone()),
schema: input_schema.clone(),
}
}
_ => order_by_plan,
};

Ok(Rc::new(limit_plan))
}
}

Expand Down Expand Up @@ -491,6 +507,7 @@ pub fn push_down_projection(
}),
LogicalPlan::Projection { .. } => plan.clone(),
LogicalPlan::Sort { .. } => plan.clone(),
LogicalPlan::Limit { .. } => plan.clone(),
LogicalPlan::EmptyRelation { .. } => plan.clone(),
}
}
Expand Down
53 changes: 53 additions & 0 deletions rust/datafusion/tests/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,59 @@ fn csv_query_cast() {
assert_eq!(expected, actual);
}

/// A LIMIT smaller than the table size returns exactly that many rows.
#[test]
fn csv_query_limit() {
    let mut context = ExecutionContext::new();
    register_aggregate_csv(&mut context);

    let output = execute(&mut context, "SELECT 0 FROM aggregate_test_100 LIMIT 2");

    assert_eq!("0\n0\n".to_string(), output);
}

/// A LIMIT larger than the table size returns every row of the table.
#[test]
fn csv_query_limit_bigger_than_nbr_of_rows() {
    let mut context = ExecutionContext::new();
    register_aggregate_csv(&mut context);

    let output = execute(
        &mut context,
        "SELECT c2 FROM aggregate_test_100 LIMIT 200",
    );

    let all_rows = "2\n5\n1\n1\n5\n4\n3\n3\n1\n4\n1\n4\n3\n2\n1\n1\n2\n1\n3\n2\n4\n1\n5\n4\n2\n1\n4\n5\n2\n3\n4\n2\n1\n5\n3\n1\n2\n3\n3\n3\n2\n4\n1\n3\n2\n5\n2\n1\n4\n1\n4\n2\n5\n4\n2\n3\n4\n4\n4\n5\n4\n2\n1\n2\n4\n2\n3\n5\n1\n1\n4\n2\n1\n2\n1\n1\n5\n4\n5\n2\n3\n2\n4\n1\n3\n4\n3\n2\n5\n3\n3\n2\n5\n5\n4\n1\n3\n3\n4\n4\n".to_string();
    assert_eq!(all_rows, output);
}

/// A LIMIT equal to the table size returns every row of the table.
#[test]
fn csv_query_limit_with_same_nbr_of_rows() {
    let mut context = ExecutionContext::new();
    register_aggregate_csv(&mut context);

    let output = execute(
        &mut context,
        "SELECT c2 FROM aggregate_test_100 LIMIT 100",
    );

    let all_rows = "2\n5\n1\n1\n5\n4\n3\n3\n1\n4\n1\n4\n3\n2\n1\n1\n2\n1\n3\n2\n4\n1\n5\n4\n2\n1\n4\n5\n2\n3\n4\n2\n1\n5\n3\n1\n2\n3\n3\n3\n2\n4\n1\n3\n2\n5\n2\n1\n4\n1\n4\n2\n5\n4\n2\n3\n4\n4\n4\n5\n4\n2\n1\n2\n4\n2\n3\n5\n1\n1\n4\n2\n1\n2\n1\n1\n5\n4\n5\n2\n3\n2\n4\n1\n3\n4\n3\n2\n5\n3\n3\n2\n5\n5\n4\n1\n3\n3\n4\n4\n".to_string();
    assert_eq!(all_rows, output);
}

/// LIMIT 0 produces no output at all.
#[test]
fn csv_query_limit_zero() {
    let mut context = ExecutionContext::new();
    register_aggregate_csv(&mut context);

    let output = execute(&mut context, "SELECT 0 FROM aggregate_test_100 LIMIT 0");

    assert_eq!("".to_string(), output);
}

//TODO Uncomment the following test when ORDER BY is implemented to be able to test ORDER BY + LIMIT
/*
#[test]
fn csv_query_limit_with_order_by() {
let mut ctx = ExecutionContext::new();
register_aggregate_csv(&mut ctx);
let sql = "SELECT c7 FROM aggregate_test_100 ORDER BY c7 ASC LIMIT 2";
let actual = execute(&mut ctx, sql);
let expected = "0\n2\n".to_string();
assert_eq!(expected, actual);
}
*/

fn aggr_test_schema() -> Arc<Schema> {
Arc::new(Schema::new(vec![
Field::new("c1", DataType::Utf8, false),
Expand Down

0 comments on commit 811c7dc

Please sign in to comment.