You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ag...@apache.org on 2022/05/06 16:48:40 UTC

[arrow-datafusion] branch master updated: Fix bugs in SQL planner with GROUP BY scalar function and alias (#2457)

This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 675eb82f9 Fix bugs in SQL planner with GROUP BY scalar function and alias (#2457)
675eb82f9 is described below

commit 675eb82f95a4fcbdd695861f628dd31c1701c6c4
Author: Andy Grove <ag...@apache.org>
AuthorDate: Fri May 6 10:48:35 2022 -0600

    Fix bugs in SQL planner with GROUP BY scalar function and alias (#2457)
---
 datafusion/core/src/sql/planner.rs    | 31 +++++++++++++++++++++++--------
 datafusion/core/src/sql/utils.rs      | 16 ++++++++++++++++
 datafusion/core/tests/sql/group_by.rs | 26 ++++++++++++++++++++++++++
 3 files changed, 65 insertions(+), 8 deletions(-)

diff --git a/datafusion/core/src/sql/planner.rs b/datafusion/core/src/sql/planner.rs
index d4002997b..fe737d6e8 100644
--- a/datafusion/core/src/sql/planner.rs
+++ b/datafusion/core/src/sql/planner.rs
@@ -37,7 +37,7 @@ use crate::logical_plan::{
 use crate::optimizer::utils::exprlist_to_columns;
 use crate::prelude::JoinType;
 use crate::scalar::ScalarValue;
-use crate::sql::utils::{make_decimal_type, normalize_ident};
+use crate::sql::utils::{make_decimal_type, normalize_ident, resolve_columns};
 use crate::{
     error::{DataFusionError, Result},
     physical_plan::aggregates,
@@ -1144,30 +1144,45 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         group_by_exprs: Vec<Expr>,
         aggr_exprs: Vec<Expr>,
     ) -> Result<(LogicalPlan, Vec<Expr>, Option<Expr>)> {
+        // create the aggregate plan
+        let plan = LogicalPlanBuilder::from(input.clone())
+            .aggregate(group_by_exprs.clone(), aggr_exprs.clone())?
+            .build()?;
+
+        // in this next section of code we are re-writing the projection to refer to columns
+        // output by the aggregate plan. For example, if the projection contains the expression
+        // `SUM(a)` then we replace that with a reference to a column `#SUM(a)` produced by
+        // the aggregate plan.
+
+        // combine the original grouping and aggregate expressions into one list (note that
+        // we do not add the "having" expression since that is not part of the projection)
         let aggr_projection_exprs = group_by_exprs
             .iter()
             .chain(aggr_exprs.iter())
             .cloned()
             .collect::<Vec<Expr>>();
 
-        let plan = LogicalPlanBuilder::from(input.clone())
-            .aggregate(group_by_exprs, aggr_exprs)?
-            .build()?;
+        // now attempt to resolve columns and replace with fully-qualified columns
+        let aggr_projection_exprs = aggr_projection_exprs
+            .iter()
+            .map(|expr| resolve_columns(expr, &input))
+            .collect::<Result<Vec<Expr>>>()?;
 
-        // After aggregation, these are all of the columns that will be
-        // available to next phases of planning.
+        // next we replace any expressions that are not a column with a column referencing
+        // an output column from the aggregate schema
         let column_exprs_post_aggr = aggr_projection_exprs
             .iter()
             .map(|expr| expr_as_column_expr(expr, &input))
             .collect::<Result<Vec<Expr>>>()?;
 
-        // Rewrite the SELECT expression to use the columns produced by the
-        // aggregation.
+        // next we re-write the projection
         let select_exprs_post_aggr = select_exprs
             .iter()
             .map(|expr| rebase_expr(expr, &aggr_projection_exprs, &input))
             .collect::<Result<Vec<Expr>>>()?;
 
+        // finally, we have some validation that the re-written projection can be resolved
+        // from the aggregate output columns
         check_columns_satisfy_exprs(
             &column_exprs_post_aggr,
             &select_exprs_post_aggr,
diff --git a/datafusion/core/src/sql/utils.rs b/datafusion/core/src/sql/utils.rs
index cd1fb316b..4acaa21ef 100644
--- a/datafusion/core/src/sql/utils.rs
+++ b/datafusion/core/src/sql/utils.rs
@@ -155,6 +155,22 @@ pub(crate) fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result<Exp
     }
 }
 
+/// Rewrite each `Expr::Column` in the expression tree as the qualified column of its field in `plan`'s schema; note that a failed `field_from_column` lookup exits the closure with an error via `?`
+pub(crate) fn resolve_columns(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
+    clone_with_replacement(expr, &|nested_expr| {
+        match nested_expr {
+            Expr::Column(col) => {
+                let field = plan.schema().field_from_column(col)?;
+                Ok(Some(Expr::Column(field.qualified_column())))
+            }
+            _ => {
+                // not a column: return no replacement so the traversal keeps recursing
+                Ok(None)
+            }
+        }
+    })
+}
+
 /// Rebuilds an `Expr` as a projection on top of a collection of `Expr`'s.
 ///
 /// For example, the expression `a + b < 1` would require, as input, the 2
diff --git a/datafusion/core/tests/sql/group_by.rs b/datafusion/core/tests/sql/group_by.rs
index 41f2471f6..e3da1b021 100644
--- a/datafusion/core/tests/sql/group_by.rs
+++ b/datafusion/core/tests/sql/group_by.rs
@@ -211,6 +211,32 @@ async fn csv_query_having_without_group_by() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn csv_query_group_by_substr() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_aggregate_csv(&ctx).await?;
+    // there is an input column "c1" as well as a projection expression aliased as "c1"
+    let sql = "SELECT substr(c1, 1, 1) c1 \
+        FROM aggregate_test_100 \
+        GROUP BY substr(c1, 1, 1) \
+        ";
+    let actual = execute_to_batches(&ctx, sql).await;
+    #[rustfmt::skip]
+    let expected = vec![
+        "+----+",
+        "| c1 |",
+        "+----+",
+        "| a  |",
+        "| b  |",
+        "| c  |",
+        "| d  |",
+        "| e  |",
+        "+----+",
+    ];
+    assert_batches_sorted_eq!(expected, &actual);
+    Ok(())
+}
+
 #[tokio::test]
 async fn csv_query_group_by_avg() -> Result<()> {
     let ctx = SessionContext::new();