Posted to commits@arrow.apache.org by ag...@apache.org on 2019/02/17 23:15:02 UTC

[arrow] branch master updated: ARROW-4464: [Rust] [DataFusion] Add support for LIMIT

This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 811c7dc  ARROW-4464: [Rust] [DataFusion] Add support for LIMIT
811c7dc is described below

commit 811c7dc12268b58dbb35eda9e0404c7fb5a2e876
Author: Nicolas Trinquier <ns...@protonmail.ch>
AuthorDate: Sun Feb 17 16:14:51 2019 -0700

    ARROW-4464: [Rust] [DataFusion] Add support for LIMIT
    
    Author: Nicolas Trinquier <ns...@protonmail.ch>
    
    Closes #3669 from ntrinquier/arrow-4464 and squashes the following commits:
    
    facc5c2 <Nicolas Trinquier> Add Limit case to ProjectionPushDown
    2ed488c <Nicolas Trinquier> Merge remote-tracking branch 'upstream/master' into arrow-4464
    c78ae2c <Nicolas Trinquier> Use the previous batch's schema for Limit
    e93df93 <Nicolas Trinquier> Remove redundant variable
    dbc639f <Nicolas Trinquier> Make limit a usize and avoid evaluating the limit expression
    eac5a24 <Nicolas Trinquier> Add support for Limit
---
 rust/datafusion/src/execution/context.rs           |  36 +++-
 rust/datafusion/src/execution/limit.rs             | 182 +++++++++++++++++++++
 rust/datafusion/src/execution/mod.rs               |   1 +
 rust/datafusion/src/logicalplan.rs                 |  15 ++
 .../src/optimizer/projection_push_down.rs          |   9 +
 rust/datafusion/src/sqlplanner.rs                  |  19 ++-
 rust/datafusion/tests/sql.rs                       |  53 ++++++
 7 files changed, 313 insertions(+), 2 deletions(-)
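
For orientation before the diff: the feature is exercised at the SQL level by the new
tests at the end of this patch. A minimal sketch of that usage (register_aggregate_csv
and execute are the test-harness helpers used in rust/datafusion/tests/sql.rs, not
public API):

    let mut ctx = ExecutionContext::new();
    register_aggregate_csv(&mut ctx);                      // test helper: registers the aggregate_test_100 CSV
    let sql = "SELECT c2 FROM aggregate_test_100 LIMIT 5"; // return at most 5 rows
    let actual = execute(&mut ctx, sql);                   // test helper: runs the query and renders rows as text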

diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs
index 86d7c99..59c65a8 100644
--- a/rust/datafusion/src/execution/context.rs
+++ b/rust/datafusion/src/execution/context.rs
@@ -20,7 +20,7 @@ use std::collections::HashMap;
 use std::rc::Rc;
 use std::sync::Arc;
 
-use arrow::datatypes::{Field, Schema};
+use arrow::datatypes::*;
 
 use super::super::dfparser::{DFASTNode, DFParser};
 use super::super::logicalplan::*;
@@ -30,6 +30,7 @@ use super::datasource::DataSource;
 use super::error::{ExecutionError, Result};
 use super::expression::*;
 use super::filter::FilterRelation;
+use super::limit::LimitRelation;
 use super::projection::ProjectRelation;
 use super::relation::{DataSourceRelation, Relation};
 
@@ -160,6 +161,39 @@ impl ExecutionContext {
 
                 Ok(Rc::new(RefCell::new(rel)))
             }
+            LogicalPlan::Limit {
+                ref expr,
+                ref input,
+                ..
+            } => {
+                let input_rel = self.execute(input)?;
+
+                let input_schema = input_rel.as_ref().borrow().schema().clone();
+
+                match expr {
+                    &Expr::Literal(ref scalar_value) => {
+                        let limit: usize = match scalar_value {
+                            ScalarValue::Int8(x) => Ok(*x as usize),
+                            ScalarValue::Int16(x) => Ok(*x as usize),
+                            ScalarValue::Int32(x) => Ok(*x as usize),
+                            ScalarValue::Int64(x) => Ok(*x as usize),
+                            ScalarValue::UInt8(x) => Ok(*x as usize),
+                            ScalarValue::UInt16(x) => Ok(*x as usize),
+                            ScalarValue::UInt32(x) => Ok(*x as usize),
+                            ScalarValue::UInt64(x) => Ok(*x as usize),
+                            _ => Err(ExecutionError::ExecutionError(
+                                "Limit only supports positive integer literals"
+                                    .to_string(),
+                            )),
+                        }?;
+                        let rel = LimitRelation::new(input_rel, limit, input_schema);
+                        Ok(Rc::new(RefCell::new(rel)))
+                    }
+                    _ => Err(ExecutionError::ExecutionError(
+                        "Limit only supports positive integer literals".to_string(),
+                    )),
+                }
+            }
 
             _ => unimplemented!(),
         }
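
As an aside, the scalar-to-usize conversion in the hunk above could be factored into a
small helper so the error message is not repeated; a minimal sketch using only types
this file already imports (the helper name is hypothetical, not part of the commit):

    fn limit_from_literal(scalar_value: &ScalarValue) -> Result<usize> {
        // Like the inline match above, this does not reject negative literals.
        match scalar_value {
            ScalarValue::Int8(x) => Ok(*x as usize),
            ScalarValue::Int16(x) => Ok(*x as usize),
            ScalarValue::Int32(x) => Ok(*x as usize),
            ScalarValue::Int64(x) => Ok(*x as usize),
            ScalarValue::UInt8(x) => Ok(*x as usize),
            ScalarValue::UInt16(x) => Ok(*x as usize),
            ScalarValue::UInt32(x) => Ok(*x as usize),
            ScalarValue::UInt64(x) => Ok(*x as usize),
            _ => Err(ExecutionError::ExecutionError(
                "Limit only supports positive integer literals".to_string(),
            )),
        }
    }
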
diff --git a/rust/datafusion/src/execution/limit.rs b/rust/datafusion/src/execution/limit.rs
new file mode 100644
index 0000000..d6258d6
--- /dev/null
+++ b/rust/datafusion/src/execution/limit.rs
@@ -0,0 +1,182 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Execution of a limit relation (returns at most a fixed number of rows from its input)
+
+use std::cell::RefCell;
+use std::rc::Rc;
+use std::sync::Arc;
+
+use arrow::array::*;
+use arrow::datatypes::{DataType, Schema};
+use arrow::record_batch::RecordBatch;
+
+use super::error::{ExecutionError, Result};
+use super::relation::Relation;
+
+pub struct LimitRelation {
+    input: Rc<RefCell<Relation>>,
+    schema: Arc<Schema>,
+    limit: usize,
+    num_consumed_rows: usize,
+}
+
+impl LimitRelation {
+    pub fn new(input: Rc<RefCell<Relation>>, limit: usize, schema: Arc<Schema>) -> Self {
+        Self {
+            input,
+            schema,
+            limit,
+            num_consumed_rows: 0,
+        }
+    }
+}
+
+impl Relation for LimitRelation {
+    fn next(&mut self) -> Result<Option<RecordBatch>> {
+        match self.input.borrow_mut().next()? {
+            Some(batch) => {
+                let capacity = self.limit - self.num_consumed_rows;
+
+                if capacity == 0 {
+                    return Ok(None);
+                }
+
+                if batch.num_rows() >= capacity {
+                    let limited_columns: Result<Vec<ArrayRef>> = (0..batch.num_columns())
+                        .map(|i| limit(batch.column(i).as_ref(), capacity))
+                        .collect();
+
+                    let limited_batch: RecordBatch =
+                        RecordBatch::new(self.schema.clone(), limited_columns?);
+                    self.num_consumed_rows += capacity;
+
+                    Ok(Some(limited_batch))
+                } else {
+                    self.num_consumed_rows += batch.num_rows();
+                    Ok(Some(batch))
+                }
+            }
+            None => Ok(None),
+        }
+    }
+
+    fn schema(&self) -> &Arc<Schema> {
+        &self.schema
+    }
+}
+
+//TODO: move into Arrow array_ops
+fn limit(a: &Array, num_rows_to_read: usize) -> Result<ArrayRef> {
+    //TODO use macros
+    match a.data_type() {
+        DataType::UInt8 => {
+            let b = a.as_any().downcast_ref::<UInt8Array>().unwrap();
+            let mut builder = UInt8Array::builder(num_rows_to_read as usize);
+            for i in 0..num_rows_to_read {
+                builder.append_value(b.value(i as usize))?;
+            }
+            Ok(Arc::new(builder.finish()))
+        }
+        DataType::UInt16 => {
+            let b = a.as_any().downcast_ref::<UInt16Array>().unwrap();
+            let mut builder = UInt16Array::builder(num_rows_to_read as usize);
+            for i in 0..num_rows_to_read {
+                builder.append_value(b.value(i as usize))?;
+            }
+            Ok(Arc::new(builder.finish()))
+        }
+        DataType::UInt32 => {
+            let b = a.as_any().downcast_ref::<UInt32Array>().unwrap();
+            let mut builder = UInt32Array::builder(num_rows_to_read as usize);
+            for i in 0..num_rows_to_read {
+                builder.append_value(b.value(i as usize))?;
+            }
+            Ok(Arc::new(builder.finish()))
+        }
+        DataType::UInt64 => {
+            let b = a.as_any().downcast_ref::<UInt64Array>().unwrap();
+            let mut builder = UInt64Array::builder(num_rows_to_read as usize);
+            for i in 0..num_rows_to_read {
+                builder.append_value(b.value(i as usize))?;
+            }
+            Ok(Arc::new(builder.finish()))
+        }
+        DataType::Int8 => {
+            let b = a.as_any().downcast_ref::<Int8Array>().unwrap();
+            let mut builder = Int8Array::builder(num_rows_to_read as usize);
+            for i in 0..num_rows_to_read {
+                builder.append_value(b.value(i as usize))?;
+            }
+            Ok(Arc::new(builder.finish()))
+        }
+        DataType::Int16 => {
+            let b = a.as_any().downcast_ref::<Int16Array>().unwrap();
+            let mut builder = Int16Array::builder(num_rows_to_read as usize);
+            for i in 0..num_rows_to_read {
+                builder.append_value(b.value(i as usize))?;
+            }
+            Ok(Arc::new(builder.finish()))
+        }
+        DataType::Int32 => {
+            let b = a.as_any().downcast_ref::<Int32Array>().unwrap();
+            let mut builder = Int32Array::builder(num_rows_to_read as usize);
+            for i in 0..num_rows_to_read {
+                builder.append_value(b.value(i as usize))?;
+            }
+            Ok(Arc::new(builder.finish()))
+        }
+        DataType::Int64 => {
+            let b = a.as_any().downcast_ref::<Int64Array>().unwrap();
+            let mut builder = Int64Array::builder(num_rows_to_read as usize);
+            for i in 0..num_rows_to_read {
+                builder.append_value(b.value(i as usize))?;
+            }
+            Ok(Arc::new(builder.finish()))
+        }
+        DataType::Float32 => {
+            let b = a.as_any().downcast_ref::<Float32Array>().unwrap();
+            let mut builder = Float32Array::builder(num_rows_to_read as usize);
+            for i in 0..num_rows_to_read {
+                builder.append_value(b.value(i as usize))?;
+            }
+            Ok(Arc::new(builder.finish()))
+        }
+        DataType::Float64 => {
+            let b = a.as_any().downcast_ref::<Float64Array>().unwrap();
+            let mut builder = Float64Array::builder(num_rows_to_read as usize);
+            for i in 0..num_rows_to_read {
+                builder.append_value(b.value(i as usize))?;
+            }
+            Ok(Arc::new(builder.finish()))
+        }
+        DataType::Utf8 => {
+            //TODO: this is inefficient and we should improve the Arrow impl to help make this more concise
+            let b = a.as_any().downcast_ref::<BinaryArray>().unwrap();
+            let mut values: Vec<String> = Vec::with_capacity(num_rows_to_read as usize);
+            for i in 0..num_rows_to_read {
+                values.push(b.get_string(i as usize));
+            }
+            let tmp: Vec<&str> = values.iter().map(|s| s.as_str()).collect();
+            Ok(Arc::new(BinaryArray::from(tmp)))
+        }
+        other => Err(ExecutionError::ExecutionError(format!(
+            "limit not supported for {:?}",
+            other
+        ))),
+    }
+}
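
The per-type arms in limit() above differ only in the array type; the "TODO use macros"
could be addressed with something along these lines (a sketch only, the macro name is
made up for illustration):

    // Expands to the same body as each hand-written primitive arm above; it must be
    // invoked inside limit(), whose Result return type absorbs the `?`.
    macro_rules! limit_primitive {
        ($ARRAY:expr, $ARRAY_TY:ty, $N:expr) => {{
            let b = $ARRAY.as_any().downcast_ref::<$ARRAY_TY>().unwrap();
            let mut builder = <$ARRAY_TY>::builder($N);
            for i in 0..$N {
                builder.append_value(b.value(i))?;
            }
            Ok(Arc::new(builder.finish()) as ArrayRef)
        }};
    }

    // The numeric match arms in limit() would then collapse to one line each, e.g.:
    //     DataType::UInt8 => limit_primitive!(a, UInt8Array, num_rows_to_read),
    //     DataType::Int64 => limit_primitive!(a, Int64Array, num_rows_to_read),
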
diff --git a/rust/datafusion/src/execution/mod.rs b/rust/datafusion/src/execution/mod.rs
index 23144bb..9eb303f 100644
--- a/rust/datafusion/src/execution/mod.rs
+++ b/rust/datafusion/src/execution/mod.rs
@@ -21,6 +21,7 @@ pub mod datasource;
 pub mod error;
 pub mod expression;
 pub mod filter;
+pub mod limit;
 pub mod physicalplan;
 pub mod projection;
 pub mod relation;
diff --git a/rust/datafusion/src/logicalplan.rs b/rust/datafusion/src/logicalplan.rs
index b3e6bda..7dd4602 100644
--- a/rust/datafusion/src/logicalplan.rs
+++ b/rust/datafusion/src/logicalplan.rs
@@ -350,6 +350,12 @@ pub enum LogicalPlan {
     },
     /// An empty relation with an empty schema
     EmptyRelation { schema: Arc<Schema> },
+    /// Represents the maximum number of records to return
+    Limit {
+        expr: Expr,
+        input: Rc<LogicalPlan>,
+        schema: Arc<Schema>,
+    },
 }
 
 impl LogicalPlan {
@@ -362,6 +368,7 @@ impl LogicalPlan {
             LogicalPlan::Selection { input, .. } => input.schema(),
             LogicalPlan::Aggregate { schema, .. } => &schema,
             LogicalPlan::Sort { schema, .. } => &schema,
+            LogicalPlan::Limit { schema, .. } => &schema,
         }
     }
 }
@@ -430,6 +437,14 @@ impl LogicalPlan {
                 }
                 input.fmt_with_indent(f, indent + 1)
             }
+            LogicalPlan::Limit {
+                ref input,
+                ref expr,
+                ..
+            } => {
+                write!(f, "Limit: {:?}", expr)?;
+                input.fmt_with_indent(f, indent + 1)
+            }
         }
     }
 }
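
Constructed by hand, the new plan variant looks like this (input_plan and schema stand
in for whichever plan is being limited; the UInt32 literal is only an example):

    let limit_plan = LogicalPlan::Limit {
        expr: Expr::Literal(ScalarValue::UInt32(10)),
        input: Rc::new(input_plan),
        schema: schema.clone(),
    };
    // fmt_with_indent (above) renders this as "Limit: <expr>" followed by the
    // input plan indented one more level.
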
diff --git a/rust/datafusion/src/optimizer/projection_push_down.rs b/rust/datafusion/src/optimizer/projection_push_down.rs
index 8fd2e8c..b8d98fe 100644
--- a/rust/datafusion/src/optimizer/projection_push_down.rs
+++ b/rust/datafusion/src/optimizer/projection_push_down.rs
@@ -179,6 +179,15 @@ impl ProjectionPushDown {
                     projection: Some(projection),
                 }))
             }
+            LogicalPlan::Limit {
+                expr,
+                input,
+                schema,
+            } => Ok(Rc::new(LogicalPlan::Limit {
+                expr: expr.clone(),
+                input: input.clone(),
+                schema: schema.clone(),
+            })),
         }
     }
 
diff --git a/rust/datafusion/src/sqlplanner.rs b/rust/datafusion/src/sqlplanner.rs
index dcb69eb..fc8048f 100644
--- a/rust/datafusion/src/sqlplanner.rs
+++ b/rust/datafusion/src/sqlplanner.rs
@@ -53,6 +53,7 @@ impl SqlToRel {
                 ref relation,
                 ref selection,
                 ref order_by,
+                ref limit,
                 ref group_by,
                 ref having,
                 ..
@@ -167,7 +168,22 @@ impl SqlToRel {
                         _ => projection,
                     };
 
-                    Ok(Rc::new(order_by_plan))
+                    let limit_plan = match limit {
+                        &Some(ref limit_expr) => {
+                            let input_schema = order_by_plan.schema();
+                            let limit_rex =
+                                self.sql_to_rex(&limit_expr, &input_schema.clone())?;
+
+                            LogicalPlan::Limit {
+                                expr: limit_rex,
+                                input: Rc::new(order_by_plan.clone()),
+                                schema: input_schema.clone(),
+                            }
+                        }
+                        _ => order_by_plan,
+                    };
+
+                    Ok(Rc::new(limit_plan))
                 }
             }
 
@@ -491,6 +507,7 @@ pub fn push_down_projection(
         }),
         LogicalPlan::Projection { .. } => plan.clone(),
         LogicalPlan::Sort { .. } => plan.clone(),
+        LogicalPlan::Limit { .. } => plan.clone(),
         LogicalPlan::EmptyRelation { .. } => plan.clone(),
     }
 }
diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs
index bd22808..6f96980 100644
--- a/rust/datafusion/tests/sql.rs
+++ b/rust/datafusion/tests/sql.rs
@@ -72,6 +72,59 @@ fn csv_query_cast() {
     assert_eq!(expected, actual);
 }
 
+#[test]
+fn csv_query_limit() {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx);
+    let sql = "SELECT 0 FROM aggregate_test_100 LIMIT 2";
+    let actual = execute(&mut ctx, sql);
+    let expected = "0\n0\n".to_string();
+    assert_eq!(expected, actual);
+}
+
+#[test]
+fn csv_query_limit_bigger_than_nbr_of_rows() {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx);
+    let sql = "SELECT c2 FROM aggregate_test_100 LIMIT 200";
+    let actual = execute(&mut ctx, sql);
+    let expected = "2\n5\n1\n1\n5\n4\n3\n3\n1\n4\n1\n4\n3\n2\n1\n1\n2\n1\n3\n2\n4\n1\n5\n4\n2\n1\n4\n5\n2\n3\n4\n2\n1\n5\n3\n1\n2\n3\n3\n3\n2\n4\n1\n3\n2\n5\n2\n1\n4\n1\n4\n2\n5\n4\n2\n3\n4\n4\n4\n5\n4\n2\n1\n2\n4\n2\n3\n5\n1\n1\n4\n2\n1\n2\n1\n1\n5\n4\n5\n2\n3\n2\n4\n1\n3\n4\n3\n2\n5\n3\n3\n2\n5\n5\n4\n1\n3\n3\n4\n4\n".to_string();
+    assert_eq!(expected, actual);
+}
+
+#[test]
+fn csv_query_limit_with_same_nbr_of_rows() {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx);
+    let sql = "SELECT c2 FROM aggregate_test_100 LIMIT 100";
+    let actual = execute(&mut ctx, sql);
+    let expected = "2\n5\n1\n1\n5\n4\n3\n3\n1\n4\n1\n4\n3\n2\n1\n1\n2\n1\n3\n2\n4\n1\n5\n4\n2\n1\n4\n5\n2\n3\n4\n2\n1\n5\n3\n1\n2\n3\n3\n3\n2\n4\n1\n3\n2\n5\n2\n1\n4\n1\n4\n2\n5\n4\n2\n3\n4\n4\n4\n5\n4\n2\n1\n2\n4\n2\n3\n5\n1\n1\n4\n2\n1\n2\n1\n1\n5\n4\n5\n2\n3\n2\n4\n1\n3\n4\n3\n2\n5\n3\n3\n2\n5\n5\n4\n1\n3\n3\n4\n4\n".to_string();
+    assert_eq!(expected, actual);
+}
+
+#[test]
+fn csv_query_limit_zero() {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx);
+    let sql = "SELECT 0 FROM aggregate_test_100 LIMIT 0";
+    let actual = execute(&mut ctx, sql);
+    let expected = "".to_string();
+    assert_eq!(expected, actual);
+}
+
+//TODO: Uncomment the following test once ORDER BY is implemented, to cover ORDER BY + LIMIT
+/*
+#[test]
+fn csv_query_limit_with_order_by() {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx);
+    let sql = "SELECT c7 FROM aggregate_test_100 ORDER BY c7 ASC LIMIT 2";
+    let actual = execute(&mut ctx, sql);
+    let expected = "0\n2\n".to_string();
+    assert_eq!(expected, actual);
+}
+*/
+
 fn aggr_test_schema() -> Arc<Schema> {
     Arc::new(Schema::new(vec![
         Field::new("c1", DataType::Utf8, false),