You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ag...@apache.org on 2019/02/17 23:15:02 UTC
[arrow] branch master updated: ARROW-4464: [Rust] [DataFusion] Add support for LIMIT
This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 811c7dc ARROW-4464: [Rust] [DataFusion] Add support for LIMIT
811c7dc is described below
commit 811c7dc12268b58dbb35eda9e0404c7fb5a2e876
Author: Nicolas Trinquier <ns...@protonmail.ch>
AuthorDate: Sun Feb 17 16:14:51 2019 -0700
ARROW-4464: [Rust] [DataFusion] Add support for LIMIT
Author: Nicolas Trinquier <ns...@protonmail.ch>
Closes #3669 from ntrinquier/arrow-4464 and squashes the following commits:
facc5c2 <Nicolas Trinquier> Add Limit case to ProjectionPushDown
2ed488c <Nicolas Trinquier> Merge remote-tracking branch 'upstream/master' into arrow-4464
c78ae2c <Nicolas Trinquier> Use the previous batch's schema for Limit
e93df93 <Nicolas Trinquier> Remove redundant variable
dbc639f <Nicolas Trinquier> Make limit a usize and avoid evaluating the limit expression
eac5a24 <Nicolas Trinquier> Add support for Limit
---
rust/datafusion/src/execution/context.rs | 36 +++-
rust/datafusion/src/execution/limit.rs | 182 +++++++++++++++++++++
rust/datafusion/src/execution/mod.rs | 1 +
rust/datafusion/src/logicalplan.rs | 15 ++
.../src/optimizer/projection_push_down.rs | 9 +
rust/datafusion/src/sqlplanner.rs | 19 ++-
rust/datafusion/tests/sql.rs | 53 ++++++
7 files changed, 313 insertions(+), 2 deletions(-)
diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs
index 86d7c99..59c65a8 100644
--- a/rust/datafusion/src/execution/context.rs
+++ b/rust/datafusion/src/execution/context.rs
@@ -20,7 +20,7 @@ use std::collections::HashMap;
use std::rc::Rc;
use std::sync::Arc;
-use arrow::datatypes::{Field, Schema};
+use arrow::datatypes::*;
use super::super::dfparser::{DFASTNode, DFParser};
use super::super::logicalplan::*;
@@ -30,6 +30,7 @@ use super::datasource::DataSource;
use super::error::{ExecutionError, Result};
use super::expression::*;
use super::filter::FilterRelation;
+use super::limit::LimitRelation;
use super::projection::ProjectRelation;
use super::relation::{DataSourceRelation, Relation};
@@ -160,6 +161,39 @@ impl ExecutionContext {
Ok(Rc::new(RefCell::new(rel)))
}
+            LogicalPlan::Limit {
+                ref expr,
+                ref input,
+                ..
+            } => {
+                let input_rel = self.execute(input)?;
+
+                // The limit relation re-uses the schema of its input unchanged.
+                let input_schema = input_rel.as_ref().borrow().schema().clone();
+
+                // Only non-negative integer literals are accepted as a limit. Signed
+                // values are range-checked before the cast because `as usize` would
+                // silently wrap a negative number (e.g. `-1i64 as usize`) into an
+                // enormous limit instead of reporting an error.
+                let limit: usize = match expr {
+                    &Expr::Literal(ref scalar_value) => match scalar_value {
+                        ScalarValue::Int8(x) if *x >= 0 => Ok(*x as usize),
+                        ScalarValue::Int16(x) if *x >= 0 => Ok(*x as usize),
+                        ScalarValue::Int32(x) if *x >= 0 => Ok(*x as usize),
+                        ScalarValue::Int64(x) if *x >= 0 => Ok(*x as usize),
+                        ScalarValue::UInt8(x) => Ok(*x as usize),
+                        ScalarValue::UInt16(x) => Ok(*x as usize),
+                        ScalarValue::UInt32(x) => Ok(*x as usize),
+                        ScalarValue::UInt64(x) => Ok(*x as usize),
+                        _ => Err(ExecutionError::ExecutionError(
+                            "Limit only support positive integer literals".to_string(),
+                        )),
+                    },
+                    _ => Err(ExecutionError::ExecutionError(
+                        "Limit only support positive integer literals".to_string(),
+                    )),
+                }?;
+
+                let rel = LimitRelation::new(input_rel, limit, input_schema);
+                Ok(Rc::new(RefCell::new(rel)))
+            }
_ => unimplemented!(),
}
diff --git a/rust/datafusion/src/execution/limit.rs b/rust/datafusion/src/execution/limit.rs
new file mode 100644
index 0000000..d6258d6
--- /dev/null
+++ b/rust/datafusion/src/execution/limit.rs
@@ -0,0 +1,182 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Execution of a limit: emits at most a fixed number of rows from an input relation
+
+use std::cell::RefCell;
+use std::rc::Rc;
+use std::sync::Arc;
+
+use arrow::array::*;
+use arrow::datatypes::{DataType, Schema};
+use arrow::record_batch::RecordBatch;
+
+use super::error::{ExecutionError, Result};
+use super::relation::Relation;
+
+/// Relation that yields at most `limit` rows of its `input` relation, in input order.
+pub struct LimitRelation {
+    /// The relation whose rows are being limited
+    input: Rc<RefCell<Relation>>,
+    /// Schema attached to truncated batches; expected to match the input's schema
+    schema: Arc<Schema>,
+    /// Maximum total number of rows to emit
+    limit: usize,
+    /// Number of rows emitted so far across all calls to `next`
+    num_consumed_rows: usize,
+}
+
+impl LimitRelation {
+    /// Creates a relation that emits at most `limit` rows of `input`.
+    ///
+    /// `schema` is used when a batch has to be truncated and should be the schema of `input`.
+    pub fn new(input: Rc<RefCell<Relation>>, limit: usize, schema: Arc<Schema>) -> Self {
+        Self {
+            input,
+            schema,
+            limit,
+            num_consumed_rows: 0,
+        }
+    }
+}
+
+impl Relation for LimitRelation {
+    /// Returns the next batch of at most the remaining number of allowed rows, or `None`
+    /// once the limit has been reached or the input is exhausted.
+    fn next(&mut self) -> Result<Option<RecordBatch>> {
+        // Check the remaining budget *before* touching the input, so that no further
+        // batches are read from upstream once the limit is reached (including LIMIT 0).
+        let capacity = self.limit.saturating_sub(self.num_consumed_rows);
+        if capacity == 0 {
+            return Ok(None);
+        }
+
+        match self.input.borrow_mut().next()? {
+            Some(batch) => {
+                if batch.num_rows() > capacity {
+                    // Only part of this batch fits: copy the leading `capacity` rows.
+                    let limited_columns: Result<Vec<ArrayRef>> = (0..batch.num_columns())
+                        .map(|i| limit(batch.column(i).as_ref(), capacity))
+                        .collect();
+
+                    let limited_batch = RecordBatch::new(self.schema.clone(), limited_columns?);
+                    self.num_consumed_rows += capacity;
+
+                    Ok(Some(limited_batch))
+                } else {
+                    // The whole batch fits within the limit; pass it through without copying.
+                    self.num_consumed_rows += batch.num_rows();
+                    Ok(Some(batch))
+                }
+            }
+            None => Ok(None),
+        }
+    }
+
+    /// Schema of the batches produced by this relation.
+    fn schema(&self) -> &Arc<Schema> {
+        &self.schema
+    }
+}
+
+/// Copies the first `num_rows_to_read` values of `a` into a new array of the same type.
+///
+/// If `num_rows_to_read` exceeds the length of `a`, the whole array is copied. Null slots
+/// of primitive arrays are preserved as nulls.
+///
+/// # Errors
+///
+/// Returns an `ExecutionError` for unsupported data types and propagates any error
+/// reported by the Arrow array builders.
+//TODO: move into Arrow array_ops
+fn limit(a: &Array, num_rows_to_read: usize) -> Result<ArrayRef> {
+    // Never read past the end of the input array.
+    let num_rows = num_rows_to_read.min(a.len());
+
+    // Copies the first `$n` slots of `$array` (downcast to `$array_type`) into a new
+    // array of the same type, appending a null wherever the source slot is null.
+    macro_rules! limit_primitive {
+        ($array_type:ident, $array:expr, $n:expr) => {{
+            let b = $array.as_any().downcast_ref::<$array_type>().unwrap();
+            let mut builder = $array_type::builder($n);
+            for i in 0..$n {
+                if b.is_null(i) {
+                    builder.append_null()?;
+                } else {
+                    builder.append_value(b.value(i))?;
+                }
+            }
+            Ok(Arc::new(builder.finish()))
+        }};
+    }
+
+    match a.data_type() {
+        DataType::UInt8 => limit_primitive!(UInt8Array, a, num_rows),
+        DataType::UInt16 => limit_primitive!(UInt16Array, a, num_rows),
+        DataType::UInt32 => limit_primitive!(UInt32Array, a, num_rows),
+        DataType::UInt64 => limit_primitive!(UInt64Array, a, num_rows),
+        DataType::Int8 => limit_primitive!(Int8Array, a, num_rows),
+        DataType::Int16 => limit_primitive!(Int16Array, a, num_rows),
+        DataType::Int32 => limit_primitive!(Int32Array, a, num_rows),
+        DataType::Int64 => limit_primitive!(Int64Array, a, num_rows),
+        DataType::Float32 => limit_primitive!(Float32Array, a, num_rows),
+        DataType::Float64 => limit_primitive!(Float64Array, a, num_rows),
+        DataType::Utf8 => {
+            //TODO: this is inefficient and we should improve the Arrow impl to help make this more concise
+            // NOTE(review): `BinaryArray::from(Vec<&str>)` cannot express nulls, so null
+            // strings are not preserved by this branch — confirm against the Arrow API.
+            let b = a.as_any().downcast_ref::<BinaryArray>().unwrap();
+            let values: Vec<String> = (0..num_rows).map(|i| b.get_string(i)).collect();
+            let tmp: Vec<&str> = values.iter().map(|s| s.as_str()).collect();
+            Ok(Arc::new(BinaryArray::from(tmp)))
+        }
+        other => Err(ExecutionError::ExecutionError(format!(
+            "limit not supported for {:?}",
+            other
+        ))),
+    }
+}
diff --git a/rust/datafusion/src/execution/mod.rs b/rust/datafusion/src/execution/mod.rs
index 23144bb..9eb303f 100644
--- a/rust/datafusion/src/execution/mod.rs
+++ b/rust/datafusion/src/execution/mod.rs
@@ -21,6 +21,7 @@ pub mod datasource;
pub mod error;
pub mod expression;
pub mod filter;
+pub mod limit;
pub mod physicalplan;
pub mod projection;
pub mod relation;
diff --git a/rust/datafusion/src/logicalplan.rs b/rust/datafusion/src/logicalplan.rs
index b3e6bda..7dd4602 100644
--- a/rust/datafusion/src/logicalplan.rs
+++ b/rust/datafusion/src/logicalplan.rs
@@ -350,6 +350,12 @@ pub enum LogicalPlan {
},
/// An empty relation with an empty schema
EmptyRelation { schema: Arc<Schema> },
+    /// Limits the number of rows produced by `input` to the value of the integer literal `expr`
+ Limit {
+ expr: Expr,
+ input: Rc<LogicalPlan>,
+ schema: Arc<Schema>,
+ },
}
impl LogicalPlan {
@@ -362,6 +368,7 @@ impl LogicalPlan {
LogicalPlan::Selection { input, .. } => input.schema(),
LogicalPlan::Aggregate { schema, .. } => &schema,
LogicalPlan::Sort { schema, .. } => &schema,
+ LogicalPlan::Limit { schema, .. } => &schema,
}
}
}
@@ -430,6 +437,14 @@ impl LogicalPlan {
}
input.fmt_with_indent(f, indent + 1)
}
+ LogicalPlan::Limit {
+ ref input,
+ ref expr,
+ ..
+ } => {
+ write!(f, "Limit: {:?}", expr)?;
+ input.fmt_with_indent(f, indent + 1)
+ }
}
}
}
diff --git a/rust/datafusion/src/optimizer/projection_push_down.rs b/rust/datafusion/src/optimizer/projection_push_down.rs
index 8fd2e8c..b8d98fe 100644
--- a/rust/datafusion/src/optimizer/projection_push_down.rs
+++ b/rust/datafusion/src/optimizer/projection_push_down.rs
@@ -179,6 +179,15 @@ impl ProjectionPushDown {
projection: Some(projection),
}))
}
+ LogicalPlan::Limit {
+ expr,
+ input,
+ schema,
+ } => Ok(Rc::new(LogicalPlan::Limit {
+ expr: expr.clone(),
+ input: input.clone(),
+ schema: schema.clone(),
+ })),
}
}
diff --git a/rust/datafusion/src/sqlplanner.rs b/rust/datafusion/src/sqlplanner.rs
index dcb69eb..fc8048f 100644
--- a/rust/datafusion/src/sqlplanner.rs
+++ b/rust/datafusion/src/sqlplanner.rs
@@ -53,6 +53,7 @@ impl SqlToRel {
ref relation,
ref selection,
ref order_by,
+ ref limit,
ref group_by,
ref having,
..
@@ -167,7 +168,22 @@ impl SqlToRel {
_ => projection,
};
- Ok(Rc::new(order_by_plan))
+ let limit_plan = match limit {
+ &Some(ref limit_expr) => {
+ let input_schema = order_by_plan.schema();
+ let limit_rex =
+ self.sql_to_rex(&limit_expr, &input_schema.clone())?;
+
+ LogicalPlan::Limit {
+ expr: limit_rex,
+ input: Rc::new(order_by_plan.clone()),
+ schema: input_schema.clone(),
+ }
+ }
+ _ => order_by_plan,
+ };
+
+ Ok(Rc::new(limit_plan))
}
}
@@ -491,6 +507,7 @@ pub fn push_down_projection(
}),
LogicalPlan::Projection { .. } => plan.clone(),
LogicalPlan::Sort { .. } => plan.clone(),
+ LogicalPlan::Limit { .. } => plan.clone(),
LogicalPlan::EmptyRelation { .. } => plan.clone(),
}
}
diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs
index bd22808..6f96980 100644
--- a/rust/datafusion/tests/sql.rs
+++ b/rust/datafusion/tests/sql.rs
@@ -72,6 +72,59 @@ fn csv_query_cast() {
assert_eq!(expected, actual);
}
+/// Every value of column `c2` of `aggregate_test_100`, in file order, one per line.
+const AGGREGATE_TEST_100_C2: &str = "2\n5\n1\n1\n5\n4\n3\n3\n1\n4\n1\n4\n3\n2\n1\n1\n2\n1\n3\n2\n4\n1\n5\n4\n2\n1\n4\n5\n2\n3\n4\n2\n1\n5\n3\n1\n2\n3\n3\n3\n2\n4\n1\n3\n2\n5\n2\n1\n4\n1\n4\n2\n5\n4\n2\n3\n4\n4\n4\n5\n4\n2\n1\n2\n4\n2\n3\n5\n1\n1\n4\n2\n1\n2\n1\n1\n5\n4\n5\n2\n3\n2\n4\n1\n3\n4\n3\n2\n5\n3\n3\n2\n5\n5\n4\n1\n3\n3\n4\n4\n";
+
+/// A small LIMIT returns exactly that many rows.
+#[test]
+fn csv_query_limit() {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx);
+    let actual = execute(&mut ctx, "SELECT 0 FROM aggregate_test_100 LIMIT 2");
+    let expected = String::from("0\n0\n");
+    assert_eq!(expected, actual);
+}
+
+/// A LIMIT larger than the table returns every row.
+#[test]
+fn csv_query_limit_bigger_than_nbr_of_rows() {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx);
+    let actual = execute(&mut ctx, "SELECT c2 FROM aggregate_test_100 LIMIT 200");
+    let expected = String::from(AGGREGATE_TEST_100_C2);
+    assert_eq!(expected, actual);
+}
+
+/// A LIMIT equal to the table size returns every row.
+#[test]
+fn csv_query_limit_with_same_nbr_of_rows() {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx);
+    let actual = execute(&mut ctx, "SELECT c2 FROM aggregate_test_100 LIMIT 100");
+    let expected = String::from(AGGREGATE_TEST_100_C2);
+    assert_eq!(expected, actual);
+}
+
+/// LIMIT 0 returns no rows at all.
+#[test]
+fn csv_query_limit_zero() {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx);
+    let actual = execute(&mut ctx, "SELECT 0 FROM aggregate_test_100 LIMIT 0");
+    let expected = String::new();
+    assert_eq!(expected, actual);
+}
+
+//TODO Uncomment the following test when ORDER BY is implemented to be able to test ORDER BY + LIMIT
+/*
+#[test]
+fn csv_query_limit_with_order_by() {
+ let mut ctx = ExecutionContext::new();
+ register_aggregate_csv(&mut ctx);
+ let sql = "SELECT c7 FROM aggregate_test_100 ORDER BY c7 ASC LIMIT 2";
+ let actual = execute(&mut ctx, sql);
+ let expected = "0\n2\n".to_string();
+ assert_eq!(expected, actual);
+}
+*/
+
fn aggr_test_schema() -> Arc<Schema> {
Arc::new(Schema::new(vec![
Field::new("c1", DataType::Utf8, false),