From 811c7dc12268b58dbb35eda9e0404c7fb5a2e876 Mon Sep 17 00:00:00 2001 From: Nicolas Trinquier Date: Sun, 17 Feb 2019 16:14:51 -0700 Subject: [PATCH] ARROW-4464: [Rust] [DataFusion] Add support for LIMIT Author: Nicolas Trinquier Closes #3669 from ntrinquier/arrow-4464 and squashes the following commits: facc5c2 Add Limit case to ProjectionPushDown 2ed488c Merge remote-tracking branch 'upstream/master' into arrow-4464 c78ae2c Use the previous batch's schema for Limit e93df93 Remove redundant variable dbc639f Make limit an usize and avoid evaluting the limit expression eac5a24 Add support for Limit --- rust/datafusion/src/execution/context.rs | 36 +++- rust/datafusion/src/execution/limit.rs | 182 ++++++++++++++++++ rust/datafusion/src/execution/mod.rs | 1 + rust/datafusion/src/logicalplan.rs | 15 ++ .../src/optimizer/projection_push_down.rs | 9 + rust/datafusion/src/sqlplanner.rs | 19 +- rust/datafusion/tests/sql.rs | 53 +++++ 7 files changed, 313 insertions(+), 2 deletions(-) create mode 100644 rust/datafusion/src/execution/limit.rs diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 86d7c99c21900..59c65a8ee5972 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -20,7 +20,7 @@ use std::collections::HashMap; use std::rc::Rc; use std::sync::Arc; -use arrow::datatypes::{Field, Schema}; +use arrow::datatypes::*; use super::super::dfparser::{DFASTNode, DFParser}; use super::super::logicalplan::*; @@ -30,6 +30,7 @@ use super::datasource::DataSource; use super::error::{ExecutionError, Result}; use super::expression::*; use super::filter::FilterRelation; +use super::limit::LimitRelation; use super::projection::ProjectRelation; use super::relation::{DataSourceRelation, Relation}; @@ -160,6 +161,39 @@ impl ExecutionContext { Ok(Rc::new(RefCell::new(rel))) } + LogicalPlan::Limit { + ref expr, + ref input, + .. + } => { + let input_rel = self.execute(input)?; + + let input_schema = input_rel.as_ref().borrow().schema().clone(); + + match expr { + &Expr::Literal(ref scalar_value) => { + let limit: usize = match scalar_value { + ScalarValue::Int8(x) => Ok(*x as usize), + ScalarValue::Int16(x) => Ok(*x as usize), + ScalarValue::Int32(x) => Ok(*x as usize), + ScalarValue::Int64(x) => Ok(*x as usize), + ScalarValue::UInt8(x) => Ok(*x as usize), + ScalarValue::UInt16(x) => Ok(*x as usize), + ScalarValue::UInt32(x) => Ok(*x as usize), + ScalarValue::UInt64(x) => Ok(*x as usize), + _ => Err(ExecutionError::ExecutionError( + "Limit only support positive integer literals" + .to_string(), + )), + }?; + let rel = LimitRelation::new(input_rel, limit, input_schema); + Ok(Rc::new(RefCell::new(rel))) + } + _ => Err(ExecutionError::ExecutionError( + "Limit only support positive integer literals".to_string(), + )), + } + } _ => unimplemented!(), } diff --git a/rust/datafusion/src/execution/limit.rs b/rust/datafusion/src/execution/limit.rs new file mode 100644 index 0000000000000..d6258d63db99f --- /dev/null +++ b/rust/datafusion/src/execution/limit.rs @@ -0,0 +1,182 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Execution of a limit (predicate) + +use std::cell::RefCell; +use std::rc::Rc; +use std::sync::Arc; + +use arrow::array::*; +use arrow::datatypes::{DataType, Schema}; +use arrow::record_batch::RecordBatch; + +use super::error::{ExecutionError, Result}; +use super::relation::Relation; + +pub struct LimitRelation { + input: Rc>, + schema: Arc, + limit: usize, + num_consumed_rows: usize, +} + +impl LimitRelation { + pub fn new(input: Rc>, limit: usize, schema: Arc) -> Self { + Self { + input, + schema, + limit, + num_consumed_rows: 0, + } + } +} + +impl Relation for LimitRelation { + fn next(&mut self) -> Result> { + match self.input.borrow_mut().next()? { + Some(batch) => { + let capacity = self.limit - self.num_consumed_rows; + + if capacity <= 0 { + return Ok(None); + } + + if batch.num_rows() >= capacity { + let limited_columns: Result> = (0..batch.num_columns()) + .map(|i| limit(batch.column(i).as_ref(), capacity)) + .collect(); + + let limited_batch: RecordBatch = + RecordBatch::new(self.schema.clone(), limited_columns?); + self.num_consumed_rows += capacity; + + Ok(Some(limited_batch)) + } else { + self.num_consumed_rows += batch.num_rows(); + Ok(Some(batch)) + } + } + None => Ok(None), + } + } + + fn schema(&self) -> &Arc { + &self.schema + } +} + +//TODO: move into Arrow array_ops +fn limit(a: &Array, num_rows_to_read: usize) -> Result { + //TODO use macros + match a.data_type() { + DataType::UInt8 => { + let b = a.as_any().downcast_ref::().unwrap(); + let mut builder = UInt8Array::builder(num_rows_to_read as usize); + for i in 0..num_rows_to_read { + builder.append_value(b.value(i as usize))?; + } + Ok(Arc::new(builder.finish())) + } + DataType::UInt16 => { + let b = a.as_any().downcast_ref::().unwrap(); + let mut builder = UInt16Array::builder(num_rows_to_read as usize); + for i in 0..num_rows_to_read { + builder.append_value(b.value(i as usize))?; + } + Ok(Arc::new(builder.finish())) + } + DataType::UInt32 => { + let b = a.as_any().downcast_ref::().unwrap(); + let mut builder = UInt32Array::builder(num_rows_to_read as usize); + for i in 0..num_rows_to_read { + builder.append_value(b.value(i as usize))?; + } + Ok(Arc::new(builder.finish())) + } + DataType::UInt64 => { + let b = a.as_any().downcast_ref::().unwrap(); + let mut builder = UInt64Array::builder(num_rows_to_read as usize); + for i in 0..num_rows_to_read { + builder.append_value(b.value(i as usize))?; + } + Ok(Arc::new(builder.finish())) + } + DataType::Int8 => { + let b = a.as_any().downcast_ref::().unwrap(); + let mut builder = Int8Array::builder(num_rows_to_read as usize); + for i in 0..num_rows_to_read { + builder.append_value(b.value(i as usize))?; + } + Ok(Arc::new(builder.finish())) + } + DataType::Int16 => { + let b = a.as_any().downcast_ref::().unwrap(); + let mut builder = Int16Array::builder(num_rows_to_read as usize); + for i in 0..num_rows_to_read { + builder.append_value(b.value(i as usize))?; + } + Ok(Arc::new(builder.finish())) + } + DataType::Int32 => { + let b = a.as_any().downcast_ref::().unwrap(); + let mut builder = Int32Array::builder(num_rows_to_read as usize); + for i in 0..num_rows_to_read { + builder.append_value(b.value(i as usize))?; + } + Ok(Arc::new(builder.finish())) + } + DataType::Int64 => { + let b = a.as_any().downcast_ref::().unwrap(); + let mut builder = Int64Array::builder(num_rows_to_read as usize); + for i in 0..num_rows_to_read { + builder.append_value(b.value(i as usize))?; + } + Ok(Arc::new(builder.finish())) + } + DataType::Float32 => { + let b = a.as_any().downcast_ref::().unwrap(); + let mut builder = Float32Array::builder(num_rows_to_read as usize); + for i in 0..num_rows_to_read { + builder.append_value(b.value(i as usize))?; + } + Ok(Arc::new(builder.finish())) + } + DataType::Float64 => { + let b = a.as_any().downcast_ref::().unwrap(); + let mut builder = Float64Array::builder(num_rows_to_read as usize); + for i in 0..num_rows_to_read { + builder.append_value(b.value(i as usize))?; + } + Ok(Arc::new(builder.finish())) + } + DataType::Utf8 => { + //TODO: this is inefficient and we should improve the Arrow impl to help make this more concise + let b = a.as_any().downcast_ref::().unwrap(); + let mut values: Vec = Vec::with_capacity(num_rows_to_read as usize); + for i in 0..num_rows_to_read { + values.push(b.get_string(i as usize)); + } + let tmp: Vec<&str> = values.iter().map(|s| s.as_str()).collect(); + Ok(Arc::new(BinaryArray::from(tmp))) + } + other => Err(ExecutionError::ExecutionError(format!( + "filter not supported for {:?}", + other + ))), + } +} diff --git a/rust/datafusion/src/execution/mod.rs b/rust/datafusion/src/execution/mod.rs index 23144bb5173ca..9eb303f3d4fbc 100644 --- a/rust/datafusion/src/execution/mod.rs +++ b/rust/datafusion/src/execution/mod.rs @@ -21,6 +21,7 @@ pub mod datasource; pub mod error; pub mod expression; pub mod filter; +pub mod limit; pub mod physicalplan; pub mod projection; pub mod relation; diff --git a/rust/datafusion/src/logicalplan.rs b/rust/datafusion/src/logicalplan.rs index b3e6bda545996..7dd4602b30bd1 100644 --- a/rust/datafusion/src/logicalplan.rs +++ b/rust/datafusion/src/logicalplan.rs @@ -350,6 +350,12 @@ pub enum LogicalPlan { }, /// An empty relation with an empty schema EmptyRelation { schema: Arc }, + // Represents the maximum number of records to return + Limit { + expr: Expr, + input: Rc, + schema: Arc, + }, } impl LogicalPlan { @@ -362,6 +368,7 @@ impl LogicalPlan { LogicalPlan::Selection { input, .. } => input.schema(), LogicalPlan::Aggregate { schema, .. } => &schema, LogicalPlan::Sort { schema, .. } => &schema, + LogicalPlan::Limit { schema, .. } => &schema, } } } @@ -430,6 +437,14 @@ impl LogicalPlan { } input.fmt_with_indent(f, indent + 1) } + LogicalPlan::Limit { + ref input, + ref expr, + .. + } => { + write!(f, "Limit: {:?}", expr)?; + input.fmt_with_indent(f, indent + 1) + } } } } diff --git a/rust/datafusion/src/optimizer/projection_push_down.rs b/rust/datafusion/src/optimizer/projection_push_down.rs index 8fd2e8cc1f096..b8d98fe7229fc 100644 --- a/rust/datafusion/src/optimizer/projection_push_down.rs +++ b/rust/datafusion/src/optimizer/projection_push_down.rs @@ -179,6 +179,15 @@ impl ProjectionPushDown { projection: Some(projection), })) } + LogicalPlan::Limit { + expr, + input, + schema, + } => Ok(Rc::new(LogicalPlan::Limit { + expr: expr.clone(), + input: input.clone(), + schema: schema.clone(), + })), } } diff --git a/rust/datafusion/src/sqlplanner.rs b/rust/datafusion/src/sqlplanner.rs index dcb69ebc33f0a..fc8048f856424 100644 --- a/rust/datafusion/src/sqlplanner.rs +++ b/rust/datafusion/src/sqlplanner.rs @@ -53,6 +53,7 @@ impl SqlToRel { ref relation, ref selection, ref order_by, + ref limit, ref group_by, ref having, .. @@ -167,7 +168,22 @@ impl SqlToRel { _ => projection, }; - Ok(Rc::new(order_by_plan)) + let limit_plan = match limit { + &Some(ref limit_expr) => { + let input_schema = order_by_plan.schema(); + let limit_rex = + self.sql_to_rex(&limit_expr, &input_schema.clone())?; + + LogicalPlan::Limit { + expr: limit_rex, + input: Rc::new(order_by_plan.clone()), + schema: input_schema.clone(), + } + } + _ => order_by_plan, + }; + + Ok(Rc::new(limit_plan)) } } @@ -491,6 +507,7 @@ pub fn push_down_projection( }), LogicalPlan::Projection { .. } => plan.clone(), LogicalPlan::Sort { .. } => plan.clone(), + LogicalPlan::Limit { .. } => plan.clone(), LogicalPlan::EmptyRelation { .. } => plan.clone(), } } diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index bd228087dcc6e..6f96980ca9607 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -72,6 +72,59 @@ fn csv_query_cast() { assert_eq!(expected, actual); } +#[test] +fn csv_query_limit() { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx); + let sql = "SELECT 0 FROM aggregate_test_100 LIMIT 2"; + let actual = execute(&mut ctx, sql); + let expected = "0\n0\n".to_string(); + assert_eq!(expected, actual); +} + +#[test] +fn csv_query_limit_bigger_than_nbr_of_rows() { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx); + let sql = "SELECT c2 FROM aggregate_test_100 LIMIT 200"; + let actual = execute(&mut ctx, sql); + let expected = "2\n5\n1\n1\n5\n4\n3\n3\n1\n4\n1\n4\n3\n2\n1\n1\n2\n1\n3\n2\n4\n1\n5\n4\n2\n1\n4\n5\n2\n3\n4\n2\n1\n5\n3\n1\n2\n3\n3\n3\n2\n4\n1\n3\n2\n5\n2\n1\n4\n1\n4\n2\n5\n4\n2\n3\n4\n4\n4\n5\n4\n2\n1\n2\n4\n2\n3\n5\n1\n1\n4\n2\n1\n2\n1\n1\n5\n4\n5\n2\n3\n2\n4\n1\n3\n4\n3\n2\n5\n3\n3\n2\n5\n5\n4\n1\n3\n3\n4\n4\n".to_string(); + assert_eq!(expected, actual); +} + +#[test] +fn csv_query_limit_with_same_nbr_of_rows() { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx); + let sql = "SELECT c2 FROM aggregate_test_100 LIMIT 100"; + let actual = execute(&mut ctx, sql); + let expected = "2\n5\n1\n1\n5\n4\n3\n3\n1\n4\n1\n4\n3\n2\n1\n1\n2\n1\n3\n2\n4\n1\n5\n4\n2\n1\n4\n5\n2\n3\n4\n2\n1\n5\n3\n1\n2\n3\n3\n3\n2\n4\n1\n3\n2\n5\n2\n1\n4\n1\n4\n2\n5\n4\n2\n3\n4\n4\n4\n5\n4\n2\n1\n2\n4\n2\n3\n5\n1\n1\n4\n2\n1\n2\n1\n1\n5\n4\n5\n2\n3\n2\n4\n1\n3\n4\n3\n2\n5\n3\n3\n2\n5\n5\n4\n1\n3\n3\n4\n4\n".to_string(); + assert_eq!(expected, actual); +} + +#[test] +fn csv_query_limit_zero() { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx); + let sql = "SELECT 0 FROM aggregate_test_100 LIMIT 0"; + let actual = execute(&mut ctx, sql); + let expected = "".to_string(); + assert_eq!(expected, actual); +} + +//TODO Uncomment the following test when ORDER BY is implemented to be able to test ORDER BY + LIMIT +/* +#[test] +fn csv_query_limit_with_order_by() { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx); + let sql = "SELECT c7 FROM aggregate_test_100 ORDER BY c7 ASC LIMIT 2"; + let actual = execute(&mut ctx, sql); + let expected = "0\n2\n".to_string(); + assert_eq!(expected, actual); +} +*/ + fn aggr_test_schema() -> Arc { Arc::new(Schema::new(vec![ Field::new("c1", DataType::Utf8, false),