Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2020/12/11 19:40:46 UTC

[GitHub] [arrow] alamb commented on a change in pull request #8836: ARROW-10808: [Rust][DataFusion] Support nested expressions in aggregations.

alamb commented on a change in pull request #8836:
URL: https://github.com/apache/arrow/pull/8836#discussion_r541192133



##########
File path: rust/datafusion/src/sql/planner.rs
##########
@@ -527,6 +559,29 @@ impl<'a, S: SchemaProvider> SqlToRel<'a, S> {
         LogicalPlanBuilder::from(&plan).sort(order_by_rex?)?.build()
     }
 
+    /// Validate the schema provides all of the columns referenced in the expressions.
+    fn validate_schema_satisfies_exprs(

Review comment:
       This is a good check.

##########
File path: rust/datafusion/src/sql/planner.rs
##########
@@ -407,80 +412,107 @@ impl<'a, S: SchemaProvider> SqlToRel<'a, S> {
         };
         let plan = plan?;
 
-        let projection_expr: Vec<Expr> = select
-            .projection
-            .iter()
-            .map(|e| self.sql_select_to_rex(&e, &plan.schema()))
-            .collect::<Result<Vec<Expr>>>()?;
+        // The SELECT expressions, with wildcards expanded.
+        let select_exprs = self.prepare_select_exprs(&plan, &select.projection)?;
 
-        let aggr_expr: Vec<Expr> = projection_expr
-            .iter()
-            .filter(|e| is_aggregate_expr(e))
-            .map(|e| e.clone())
-            .collect();
+        // All of the aggregate expressions (deduplicated).
+        let aggr_exprs = find_aggregate_exprs(&select_exprs);
 
-        // apply projection or aggregate
-        let plan = if (select.group_by.len() > 0) | (aggr_expr.len() > 0) {
-            self.aggregate(&plan, projection_expr, &select.group_by, aggr_expr)?
-        } else {
-            self.project(&plan, projection_expr)?
-        };
-        Ok(plan)
+        let (plan, select_exprs_post_aggr) =
+            if select.group_by.len() > 0 || aggr_exprs.len() > 0 {
+                self.aggregate(&plan, &select_exprs, &select.group_by, &aggr_exprs)?
+            } else {
+                (plan, select_exprs)
+            };
+
+        self.project(&plan, select_exprs_post_aggr, false)
+    }
+
+    /// Returns the `Expr`'s corresponding to a SQL query's SELECT expressions.
+    ///
+    /// Wildcards are expanded into the concrete list of columns.
+    fn prepare_select_exprs(
+        &self,
+        plan: &LogicalPlan,
+        projection: &Vec<SelectItem>,
+    ) -> Result<Vec<Expr>> {
+        let input_schema = plan.schema();
+
+        Ok(projection
+            .iter()
+            .map(|expr| self.sql_select_to_rex(&expr, &input_schema))
+            .collect::<Result<Vec<Expr>>>()?
+            .iter()
+            .flat_map(|expr| expand_wildcard(&expr, &input_schema))
+            .collect::<Vec<Expr>>())
     }
 
     /// Wrap a plan in a projection
-    fn project(&self, input: &LogicalPlan, expr: Vec<Expr>) -> Result<LogicalPlan> {
-        LogicalPlanBuilder::from(input).project(expr)?.build()
+    fn project(

Review comment:
       Documenting what `force` does here might be helpful. A follow-on PR would be fine.

##########
File path: rust/datafusion/src/physical_plan/udf.rs
##########
@@ -56,6 +56,12 @@ impl Debug for ScalarUDF {
     }
 }
 
+impl PartialEq for ScalarUDF {
+    fn eq(&self, other: &Self) -> bool {
+        self.name == other.name && self.signature == other.signature

Review comment:
       I think that to handle this in the general case, we would have to require that the user-defined function / aggregate itself define equality (perhaps with a default implementation that compares function pointers, as suggested by @drusso).
   
   I think user-defined functions in this kind of framework are also tagged with other properties (such as whether they have side effects, and thus can't be optimized away).
   
   I suggest filing a ticket to track this for the future -- it falls into the category of "more mature user-defined function support".
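
   To make the function-pointer idea concrete, here is a rough sketch, not code from this PR. `ToyUdf`, `ScalarFn`, and `arg_count` are hypothetical stand-ins (the real `ScalarUDF` carries a richer signature type); the only point is that equality can additionally require both values to share the same implementation, checked with `Arc::ptr_eq`.

```rust
use std::sync::Arc;

// Hypothetical stand-in for a scalar function implementation.
// This is NOT DataFusion's actual function type.
type ScalarFn = Arc<dyn Fn(f64) -> f64 + Send + Sync>;

// Hypothetical, simplified UDF: a name, a toy "signature", and the implementation.
struct ToyUdf {
    name: String,
    arg_count: usize, // assumption: stands in for a real signature type
    fun: ScalarFn,
}

impl PartialEq for ToyUdf {
    fn eq(&self, other: &Self) -> bool {
        // Name and signature must match, and both values must point at the
        // same shared implementation (same `Arc` allocation).
        self.name == other.name
            && self.arg_count == other.arg_count
            && Arc::ptr_eq(&self.fun, &other.fun)
    }
}
```

   With this approach, two UDFs that merely share a name and signature but were built from separately allocated closures compare as unequal, which is conservative: it can miss functions that are semantically identical, but it should not treat different implementations as the same expression.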

##########
File path: rust/datafusion/src/sql/utils.rs
##########
@@ -0,0 +1,305 @@
+use crate::error::{DataFusionError, Result};
+use crate::logical_plan::{Expr, LogicalPlan};
+use arrow::datatypes::Schema;
+
+/// Resolves an `Expr::Wildcard` to a collection of `Expr::Column`'s.
+pub(crate) fn expand_wildcard(expr: &Expr, schema: &Schema) -> Vec<Expr> {
+    match expr {
+        Expr::Wildcard => schema
+            .fields()
+            .iter()
+            .map(|f| Expr::Column(f.name().to_string()))
+            .collect::<Vec<Expr>>(),
+        _ => vec![expr.clone()],
+    }
+}
+
+/// Collect all deeply nested `Expr::AggregateFunction` and
+/// `Expr::AggregateUDF`. They are returned in order of occurrence (depth
+/// first), with duplicates omitted.
+pub(crate) fn find_aggregate_exprs(exprs: &Vec<Expr>) -> Vec<Expr> {
+    find_exprs_in_exprs(exprs, &|nested_expr| match nested_expr {
+        Expr::AggregateFunction { .. } | Expr::AggregateUDF { .. } => true,
+        _ => false,
+    })
+}
+
+/// Collect all deeply nested `Expr::Column`'s. They are returned in order of
+/// appearance (depth first), with duplicates omitted.
+pub(crate) fn find_column_exprs(exprs: &Vec<Expr>) -> Vec<Expr> {
+    find_exprs_in_exprs(exprs, &|nested_expr| match nested_expr {
+        Expr::Column(_) => true,
+        _ => false,
+    })
+}
+
+/// Search the provided `Expr`'s, and all of their nested `Expr`, for any that
+/// pass the provided test. The returned `Expr`'s are deduplicated and returned
+/// in order of appearance (depth first).
+fn find_exprs_in_exprs<F>(exprs: &Vec<Expr>, test_fn: &F) -> Vec<Expr>
+where
+    F: Fn(&Expr) -> bool,
+{
+    exprs
+        .iter()
+        .flat_map(|expr| find_exprs_in_expr(expr, test_fn))
+        .fold(vec![], |mut acc, expr| {
+            if !acc.contains(&expr) {
+                acc.push(expr)
+            }
+            acc
+        })
+}
+
+/// Search an `Expr`, and all of its nested `Expr`'s, for any that pass the
+/// provided test. The returned `Expr`'s are deduplicated and returned in order
+/// of appearance (depth first).
+fn find_exprs_in_expr<F>(expr: &Expr, test_fn: &F) -> Vec<Expr>

Review comment:
       👍 -- this is a good basic visitor function. Nice!
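
   For readers skimming this thread, the general shape of such a visitor can be sketched as below. This is an illustration on a toy `Expr` enum, not the body of `find_exprs_in_expr` from the PR (which is cut off in the hunk above). The variants `Binary` and `Sum`, the helper `collect`, and the choice to stop descending once an expression passes the test are all assumptions made for the example.

```rust
// Toy expression tree; the variants are hypothetical and far smaller than
// DataFusion's real `Expr`.
#[derive(Clone, Debug, PartialEq)]
enum Expr {
    Column(String),
    Binary(Box<Expr>, Box<Expr>),
    Sum(Box<Expr>), // stand-in for an aggregate expression
}

// Return every nested expression that passes `test_fn`, in depth-first order
// of first appearance, with duplicates omitted.
fn find_in_expr<F>(expr: &Expr, test_fn: &F) -> Vec<Expr>
where
    F: Fn(&Expr) -> bool,
{
    let mut found = Vec::new();
    collect(expr, test_fn, &mut found);
    found
}

fn collect<F>(expr: &Expr, test_fn: &F, found: &mut Vec<Expr>)
where
    F: Fn(&Expr) -> bool,
{
    if test_fn(expr) {
        // Deduplicate using `PartialEq`, keeping the first occurrence.
        if !found.contains(expr) {
            found.push(expr.clone());
        }
        // Assumption: do not search inside an expression that already matched.
        return;
    }
    match expr {
        Expr::Column(_) => {}
        Expr::Binary(left, right) => {
            collect(left, test_fn, found);
            collect(right, test_fn, found);
        }
        Expr::Sum(inner) => collect(inner, test_fn, found),
    }
}
```

   Note that the deduplication relies on `PartialEq` for the expression type, which is why the equality semantics of user-defined functions matter for this change.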

##########
File path: rust/datafusion/src/sql/planner.rs
##########
@@ -1059,6 +1163,21 @@ mod tests {
         quick_test(sql, expected);
     }
 
+    #[test]
+    fn select_wildcard_with_groupby() {
+        quick_test(
+            "SELECT * FROM person GROUP BY id, first_name, last_name, age, state, salary, birth_date",
+            "Aggregate: groupBy=[[#id, #first_name, #last_name, #age, #state, #salary, #birth_date]], aggr=[[]]\
+             \n  TableScan: person projection=None",
+        );
+        quick_test(
+            "SELECT * FROM (SELECT first_name, last_name FROM person) GROUP BY first_name, last_name",

Review comment:
       Wow, I didn't realize we handled subqueries like this. 👍



