You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ag...@apache.org on 2020/03/26 19:03:40 UTC

[arrow] branch master updated: ARROW-4815: [Rust] [DataFusion] Add support for SQL wildcard operator

This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 76c6424  ARROW-4815: [Rust] [DataFusion] Add support for SQL wildcard operator
76c6424 is described below

commit 76c642442aa34c2a09cdb473bc68dff74b5870de
Author: Andy Grove <an...@gmail.com>
AuthorDate: Thu Mar 26 13:03:13 2020 -0600

    ARROW-4815: [Rust] [DataFusion] Add support for SQL wildcard operator
    
    Closes #6716 from andygrove/ARROW-4815
    
    Authored-by: Andy Grove <an...@gmail.com>
    Signed-off-by: Andy Grove <an...@gmail.com>
---
 rust/datafusion/src/logicalplan.rs                 | 61 +++++++++++++++-------
 .../src/optimizer/projection_push_down.rs          | 13 +++--
 rust/datafusion/src/optimizer/type_coercion.rs     |  4 +-
 rust/datafusion/src/optimizer/utils.rs             | 39 +++++++++-----
 rust/datafusion/src/sql/planner.rs                 | 33 +++++++-----
 5 files changed, 98 insertions(+), 52 deletions(-)

diff --git a/rust/datafusion/src/logicalplan.rs b/rust/datafusion/src/logicalplan.rs
index 48fbae9..fd23ff1 100644
--- a/rust/datafusion/src/logicalplan.rs
+++ b/rust/datafusion/src/logicalplan.rs
@@ -232,37 +232,42 @@ pub enum Expr {
         /// The `DataType` the expression will yield
         return_type: DataType,
     },
+    /// Wildcard
+    Wildcard,
 }
 
 impl Expr {
     /// Find the `DataType` for the expression
-    pub fn get_type(&self, schema: &Schema) -> DataType {
+    pub fn get_type(&self, schema: &Schema) -> Result<DataType> {
         match self {
             Expr::Alias(expr, _) => expr.get_type(schema),
-            Expr::Column(n) => schema.field(*n).data_type().clone(),
-            Expr::Literal(l) => l.get_datatype(),
-            Expr::Cast { data_type, .. } => data_type.clone(),
-            Expr::ScalarFunction { return_type, .. } => return_type.clone(),
-            Expr::AggregateFunction { return_type, .. } => return_type.clone(),
-            Expr::Not(_) => DataType::Boolean,
-            Expr::IsNull(_) => DataType::Boolean,
-            Expr::IsNotNull(_) => DataType::Boolean,
+            Expr::Column(n) => Ok(schema.field(*n).data_type().clone()),
+            Expr::Literal(l) => Ok(l.get_datatype()),
+            Expr::Cast { data_type, .. } => Ok(data_type.clone()),
+            Expr::ScalarFunction { return_type, .. } => Ok(return_type.clone()),
+            Expr::AggregateFunction { return_type, .. } => Ok(return_type.clone()),
+            Expr::Not(_) => Ok(DataType::Boolean),
+            Expr::IsNull(_) => Ok(DataType::Boolean),
+            Expr::IsNotNull(_) => Ok(DataType::Boolean),
             Expr::BinaryExpr {
                 ref left,
                 ref right,
                 ref op,
             } => match op {
-                Operator::Eq | Operator::NotEq => DataType::Boolean,
-                Operator::Lt | Operator::LtEq => DataType::Boolean,
-                Operator::Gt | Operator::GtEq => DataType::Boolean,
-                Operator::And | Operator::Or => DataType::Boolean,
+                Operator::Eq | Operator::NotEq => Ok(DataType::Boolean),
+                Operator::Lt | Operator::LtEq => Ok(DataType::Boolean),
+                Operator::Gt | Operator::GtEq => Ok(DataType::Boolean),
+                Operator::And | Operator::Or => Ok(DataType::Boolean),
                 _ => {
-                    let left_type = left.get_type(schema);
-                    let right_type = right.get_type(schema);
-                    utils::get_supertype(&left_type, &right_type).unwrap()
+                    let left_type = left.get_type(schema)?;
+                    let right_type = right.get_type(schema)?;
+                    utils::get_supertype(&left_type, &right_type)
                 }
             },
             Expr::Sort { ref expr, .. } => expr.get_type(schema),
+            Expr::Wildcard => Err(ExecutionError::General(
+                "Wildcard expressions are not valid in a logical query plan".to_owned(),
+            )),
         }
     }
 
@@ -270,7 +275,7 @@ impl Expr {
     ///
     /// Will `Err` if the type cast cannot be performed.
     pub fn cast_to(&self, cast_to_type: &DataType, schema: &Schema) -> Result<Expr> {
-        let this_type = self.get_type(schema);
+        let this_type = self.get_type(schema)?;
         if this_type == *cast_to_type {
             Ok(self.clone())
         } else if can_coerce_from(cast_to_type, &this_type) {
@@ -414,6 +419,7 @@ impl fmt::Debug for Expr {
 
                 write!(f, ")")
             }
+            Expr::Wildcard => write!(f, "*"),
         }
     }
 }
@@ -698,12 +704,27 @@ impl LogicalPlanBuilder {
     /// Apply a projection
     pub fn project(&self, expr: &Vec<Expr>) -> Result<Self> {
         let input_schema = self.plan.schema();
+        let projected_expr = if expr.contains(&Expr::Wildcard) {
+            let mut expr_vec = vec![];
+            (0..expr.len()).for_each(|i| match &expr[i] {
+                Expr::Wildcard => {
+                    (0..input_schema.fields().len())
+                        .for_each(|i| expr_vec.push(col(i).clone()));
+                }
+                _ => expr_vec.push(expr[i].clone()),
+            });
+            expr_vec
+        } else {
+            expr.clone()
+        };
 
-        let schema =
-            Schema::new(utils::exprlist_to_fields(&expr, input_schema.as_ref())?);
+        let schema = Schema::new(utils::exprlist_to_fields(
+            &projected_expr,
+            input_schema.as_ref(),
+        )?);
 
         Ok(Self::from(&LogicalPlan::Projection {
-            expr: expr.clone(),
+            expr: projected_expr,
             input: Arc::new(self.plan.clone()),
             schema: Arc::new(schema),
         }))
diff --git a/rust/datafusion/src/optimizer/projection_push_down.rs b/rust/datafusion/src/optimizer/projection_push_down.rs
index 5a541c9..a05f865 100644
--- a/rust/datafusion/src/optimizer/projection_push_down.rs
+++ b/rust/datafusion/src/optimizer/projection_push_down.rs
@@ -58,7 +58,7 @@ impl ProjectionPushDown {
                 schema,
             } => {
                 // collect all columns referenced by projection expressions
-                utils::exprlist_to_column_indices(&expr, accum);
+                utils::exprlist_to_column_indices(&expr, accum)?;
 
                 // push projection down
                 let input = self.optimize_plan(&input, accum, mapping)?;
@@ -74,7 +74,7 @@ impl ProjectionPushDown {
             }
             LogicalPlan::Selection { expr, input } => {
                 // collect all columns referenced by filter expression
-                utils::expr_to_column_indices(expr, accum);
+                utils::expr_to_column_indices(expr, accum)?;
 
                 // push projection down
                 let input = self.optimize_plan(&input, accum, mapping)?;
@@ -94,8 +94,8 @@ impl ProjectionPushDown {
                 schema,
             } => {
                 // collect all columns referenced by grouping and aggregate expressions
-                utils::exprlist_to_column_indices(&group_expr, accum);
-                utils::exprlist_to_column_indices(&aggr_expr, accum);
+                utils::exprlist_to_column_indices(&group_expr, accum)?;
+                utils::exprlist_to_column_indices(&aggr_expr, accum)?;
 
                 // push projection down
                 let input = self.optimize_plan(&input, accum, mapping)?;
@@ -117,7 +117,7 @@ impl ProjectionPushDown {
                 schema,
             } => {
                 // collect all columns referenced by sort expressions
-                utils::exprlist_to_column_indices(&expr, accum);
+                utils::exprlist_to_column_indices(&expr, accum)?;
 
                 // push projection down
                 let input = self.optimize_plan(&input, accum, mapping)?;
@@ -271,6 +271,9 @@ impl ProjectionPushDown {
                 args: self.rewrite_exprs(args, mapping)?,
                 return_type: return_type.clone(),
             }),
+            Expr::Wildcard => Err(ExecutionError::General(
+                "Wildcard expressions are not valid in a logical query plan".to_owned(),
+            )),
         }
     }
 
diff --git a/rust/datafusion/src/optimizer/type_coercion.rs b/rust/datafusion/src/optimizer/type_coercion.rs
index bfd63b4..e93d01e 100644
--- a/rust/datafusion/src/optimizer/type_coercion.rs
+++ b/rust/datafusion/src/optimizer/type_coercion.rs
@@ -96,8 +96,8 @@ fn rewrite_expr(expr: &Expr, schema: &Schema) -> Result<Expr> {
         Expr::BinaryExpr { left, op, right } => {
             let left = rewrite_expr(left, schema)?;
             let right = rewrite_expr(right, schema)?;
-            let left_type = left.get_type(schema);
-            let right_type = right.get_type(schema);
+            let left_type = left.get_type(schema)?;
+            let right_type = right.get_type(schema)?;
             if left_type == right_type {
                 Ok(Expr::BinaryExpr {
                     left: Arc::new(left),
diff --git a/rust/datafusion/src/optimizer/utils.rs b/rust/datafusion/src/optimizer/utils.rs
index b4d9889..1755ac6 100644
--- a/rust/datafusion/src/optimizer/utils.rs
+++ b/rust/datafusion/src/optimizer/utils.rs
@@ -26,30 +26,44 @@ use crate::logicalplan::Expr;
 
 /// Recursively walk a list of expression trees, collecting the unique set of column
 /// indexes referenced in the expression
-pub fn exprlist_to_column_indices(expr: &Vec<Expr>, accum: &mut HashSet<usize>) {
-    expr.iter().for_each(|e| expr_to_column_indices(e, accum));
+pub fn exprlist_to_column_indices(
+    expr: &Vec<Expr>,
+    accum: &mut HashSet<usize>,
+) -> Result<()> {
+    for e in expr {
+        expr_to_column_indices(e, accum)?;
+    }
+    Ok(())
 }
 
 /// Recursively walk an expression tree, collecting the unique set of column indexes
 /// referenced in the expression
-pub fn expr_to_column_indices(expr: &Expr, accum: &mut HashSet<usize>) {
+pub fn expr_to_column_indices(expr: &Expr, accum: &mut HashSet<usize>) -> Result<()> {
     match expr {
         Expr::Alias(expr, _) => expr_to_column_indices(expr, accum),
         Expr::Column(i) => {
             accum.insert(*i);
+            Ok(())
+        }
+        Expr::Literal(_) => {
+            // not needed
+            Ok(())
         }
-        Expr::Literal(_) => { /* not needed */ }
         Expr::Not(e) => expr_to_column_indices(e, accum),
         Expr::IsNull(e) => expr_to_column_indices(e, accum),
         Expr::IsNotNull(e) => expr_to_column_indices(e, accum),
         Expr::BinaryExpr { left, right, .. } => {
-            expr_to_column_indices(left, accum);
-            expr_to_column_indices(right, accum);
+            expr_to_column_indices(left, accum)?;
+            expr_to_column_indices(right, accum)?;
+            Ok(())
         }
         Expr::Cast { expr, .. } => expr_to_column_indices(expr, accum),
         Expr::Sort { expr, .. } => expr_to_column_indices(expr, accum),
         Expr::AggregateFunction { args, .. } => exprlist_to_column_indices(args, accum),
         Expr::ScalarFunction { args, .. } => exprlist_to_column_indices(args, accum),
+        Expr::Wildcard => Err(ExecutionError::General(
+            "Wildcard expressions are not valid in a logical query plan".to_owned(),
+        )),
     }
 }
 
@@ -57,7 +71,7 @@ pub fn expr_to_column_indices(expr: &Expr, accum: &mut HashSet<usize>) {
 pub fn expr_to_field(e: &Expr, input_schema: &Schema) -> Result<Field> {
     match e {
         Expr::Alias(expr, name) => {
-            Ok(Field::new(name, expr.get_type(input_schema), true))
+            Ok(Field::new(name, expr.get_type(input_schema)?, true))
         }
         Expr::Column(i) => {
             let input_schema_field_count = input_schema.fields().len();
@@ -89,8 +103,8 @@ pub fn expr_to_field(e: &Expr, input_schema: &Schema) -> Result<Field> {
             ref right,
             ..
         } => {
-            let left_type = left.get_type(input_schema);
-            let right_type = right.get_type(input_schema);
+            let left_type = left.get_type(input_schema)?;
+            let right_type = right.get_type(input_schema)?;
             Ok(Field::new(
                 "binary_expr",
                 get_supertype(&left_type, &right_type).unwrap(),
@@ -235,7 +249,7 @@ mod tests {
     use std::sync::Arc;
 
     #[test]
-    fn test_collect_expr() {
+    fn test_collect_expr() -> Result<()> {
         let mut accum: HashSet<usize> = HashSet::new();
         expr_to_column_indices(
             &Expr::Cast {
@@ -243,15 +257,16 @@ mod tests {
                 data_type: DataType::Float64,
             },
             &mut accum,
-        );
+        )?;
         expr_to_column_indices(
             &Expr::Cast {
                 expr: Arc::new(Expr::Column(3)),
                 data_type: DataType::Float64,
             },
             &mut accum,
-        );
+        )?;
         assert_eq!(1, accum.len());
         assert!(accum.contains(&3));
+        Ok(())
     }
 }
diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs
index 000dd49..54a5895 100644
--- a/rust/datafusion/src/sql/planner.rs
+++ b/rust/datafusion/src/sql/planner.rs
@@ -287,9 +287,7 @@ impl<S: SchemaProvider> SqlToRel<S> {
                 }
             }
 
-            ASTNode::SQLWildcard => {
-                Err(ExecutionError::NotImplemented("SQL wildcard operator is not supported in projection - please use explicit column names".to_string()))
-            }
+            ASTNode::SQLWildcard => Ok(Expr::Wildcard),
 
             ASTNode::SQLCast {
                 ref expr,
@@ -307,17 +305,17 @@ impl<S: SchemaProvider> SqlToRel<S> {
                 Ok(Expr::IsNotNull(Arc::new(self.sql_to_rex(expr, schema)?)))
             }
 
-            ASTNode::SQLUnary{
+            ASTNode::SQLUnary {
                 ref operator,
                 ref expr,
-            } => {
-                match *operator {
-                    SQLOperator::Not => Ok(Expr::Not(Arc::new(self.sql_to_rex(expr, schema)?))),
-                    _ => Err(ExecutionError::InternalError(format!(
-                        "SQL binary operator cannot be interpreted as a unary operator"
-                    ))),
+            } => match *operator {
+                SQLOperator::Not => {
+                    Ok(Expr::Not(Arc::new(self.sql_to_rex(expr, schema)?)))
                 }
-            }
+                _ => Err(ExecutionError::InternalError(format!(
+                    "SQL binary operator cannot be interpreted as a unary operator"
+                ))),
+            },
 
             ASTNode::SQLBinaryExpr {
                 ref left,
@@ -370,7 +368,7 @@ impl<S: SchemaProvider> SqlToRel<S> {
 
                         // return type is same as the argument type for these aggregate
                         // functions
-                        let return_type = rex_args[0].get_type(schema).clone();
+                        let return_type = rex_args[0].get_type(schema)?.clone();
 
                         Ok(Expr::AggregateFunction {
                             name: id.clone(),
@@ -387,7 +385,7 @@ impl<S: SchemaProvider> SqlToRel<S> {
                                 }
                                 ASTNode::SQLWildcard => {
                                     Ok(Expr::Literal(ScalarValue::UInt8(1)))
-                                },
+                                }
                                 _ => self.sql_to_rex(a, schema),
                             })
                             .collect::<Result<Vec<Expr>>>()?;
@@ -576,6 +574,15 @@ mod tests {
     }
 
     #[test]
+    fn test_wildcard() {
+        quick_test(
+            "SELECT * from person",
+            "Projection: #0, #1, #2, #3, #4, #5, #6\
+            \n  TableScan: person projection=None",
+        );
+    }
+
+    #[test]
     fn select_count_one() {
         let sql = "SELECT COUNT(1) FROM person";
         let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\