Posted to commits@arrow.apache.org by ji...@apache.org on 2022/02/27 14:02:04 UTC

[arrow-datafusion] branch master updated: Fix incorrect aggregation in case that GROUP BY contains duplicate column names (#1855)

This is an automated email from the ASF dual-hosted git repository.

jiayuliu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 24a8624  Fix incorrect aggregation in case that GROUP BY contains duplicate column names (#1855)
24a8624 is described below

commit 24a86247fd4a4bc171e7380a87ce4cb41f887ecd
Author: Alexander Spies <88...@users.noreply.github.com>
AuthorDate: Sun Feb 27 15:01:56 2022 +0100

    Fix incorrect aggregation in case that GROUP BY contains duplicate column names (#1855)
    
    * Add test for SUM with groups with repeating column names
    
    Aggregation queries produce wrong results if the GROUP BY clause
    contains columns with identical names; this can happen when two joined
    tables share a column name.
    
    Add a test case for this bug.
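
    For example, a self cross join yields two group columns that are both
    named c1; before the fix, a query of this shape (table name
    illustrative, the new test below uses aggregate_test_100) returns
    wrong sums:

        SELECT a.c1, b.c1, SUM(a.c2)
        FROM t AS a CROSS JOIN t AS b
        GROUP BY a.c1, b.c1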
    
    * Fix incorrect aggregation in the case of duplicate column names
    
    When planning an Aggregate, DefaultPhysicalPlanner creates two
    HashAggregateExecs: a partial one and a final one. The column indices
    for the final HashAggregateExec's input were computed incorrectly when
    column names repeat, as can happen when two joined tables share a
    column name.
    
    Add an output_group_expr method to HashAggregateExec that returns its
    grouping expressions with their indices in the output schema; use it
    to obtain the group_expr for the final HashAggregateExec's input.
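
    In sketch form, the change in the planner (both snippets appear in
    the diff below; the comments are editorial):

        // Before: resolve each group column by name against the partial
        // aggregate's output schema. With duplicate names, every lookup
        // returns the first match, so the final HashAggregateExec would
        // group on the same input column twice.
        let final_group: Vec<Arc<dyn PhysicalExpr>> = (0..groups.len())
            .map(|i| col(&groups[i].1, &initial_aggr.schema()))
            .collect::<Result<_>>()?;

        // After: the group columns occupy the first group_expr.len()
        // slots of the partial aggregate's output schema, so positional
        // indices stay unambiguous even when names repeat.
        let final_group: Vec<Arc<dyn PhysicalExpr>> = initial_aggr.output_group_expr();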
---
 datafusion/src/physical_plan/hash_aggregate.rs | 13 +++++++++
 datafusion/src/physical_plan/planner.rs        |  5 +---
 datafusion/tests/sql/aggregates.rs             | 40 ++++++++++++++++++++++++++
 3 files changed, 54 insertions(+), 4 deletions(-)

diff --git a/datafusion/src/physical_plan/hash_aggregate.rs b/datafusion/src/physical_plan/hash_aggregate.rs
index 89877b3..33d3bcc 100644
--- a/datafusion/src/physical_plan/hash_aggregate.rs
+++ b/datafusion/src/physical_plan/hash_aggregate.rs
@@ -163,6 +163,19 @@ impl HashAggregateExec {
         &self.group_expr
     }
 
+    /// Grouping expressions as they occur in the output schema
+    pub fn output_group_expr(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        // Update column indices. Since the group by columns come first in the output schema, their
+        // indices are simply 0..self.group_expr.len().
+        self.group_expr
+            .iter()
+            .enumerate()
+            .map(|(index, (_col, name))| {
+                Arc::new(Column::new(name, index)) as Arc<dyn PhysicalExpr>
+            })
+            .collect()
+    }
+
     /// Aggregate expressions
     pub fn aggr_expr(&self) -> &[Arc<dyn AggregateExpr>] {
         &self.aggr_expr
diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs
index ce3351b..88ff9fa 100644
--- a/datafusion/src/physical_plan/planner.rs
+++ b/datafusion/src/physical_plan/planner.rs
@@ -61,7 +61,6 @@ use arrow::compute::SortOptions;
 use arrow::datatypes::{Schema, SchemaRef};
 use arrow::{compute::can_cast_types, datatypes::DataType};
 use async_trait::async_trait;
-use expressions::col;
 use futures::future::BoxFuture;
 use futures::{FutureExt, StreamExt, TryStreamExt};
 use log::{debug, trace};
@@ -524,9 +523,7 @@ impl DefaultPhysicalPlanner {
                     )?);
 
                     // update group column indices based on partial aggregate plan evaluation
-                    let final_group: Vec<Arc<dyn PhysicalExpr>> = (0..groups.len())
-                        .map(|i| col(&groups[i].1, &initial_aggr.schema()))
-                        .collect::<Result<_>>()?;
+                    let final_group: Vec<Arc<dyn PhysicalExpr>> = initial_aggr.output_group_expr();
 
                     // TODO: dictionary type not yet supported in Hash Repartition
                     let contains_dict = groups
diff --git a/datafusion/tests/sql/aggregates.rs b/datafusion/tests/sql/aggregates.rs
index 528386d..187778c 100644
--- a/datafusion/tests/sql/aggregates.rs
+++ b/datafusion/tests/sql/aggregates.rs
@@ -477,6 +477,46 @@ async fn csv_query_approx_percentile_cont() -> Result<()> {
 }
 
 #[tokio::test]
+async fn csv_query_sum_crossjoin() {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv_by_sql(&mut ctx).await;
+    let sql = "SELECT a.c1, b.c1, SUM(a.c2) FROM aggregate_test_100 as a CROSS JOIN aggregate_test_100 as b GROUP BY a.c1, b.c1 ORDER BY a.c1, b.c1";
+    let actual = execute_to_batches(&mut ctx, sql).await;
+    let expected = vec![
+        "+----+----+-----------+",
+        "| c1 | c1 | SUM(a.c2) |",
+        "+----+----+-----------+",
+        "| a  | a  | 1260      |",
+        "| a  | b  | 1140      |",
+        "| a  | c  | 1260      |",
+        "| a  | d  | 1080      |",
+        "| a  | e  | 1260      |",
+        "| b  | a  | 1302      |",
+        "| b  | b  | 1178      |",
+        "| b  | c  | 1302      |",
+        "| b  | d  | 1116      |",
+        "| b  | e  | 1302      |",
+        "| c  | a  | 1176      |",
+        "| c  | b  | 1064      |",
+        "| c  | c  | 1176      |",
+        "| c  | d  | 1008      |",
+        "| c  | e  | 1176      |",
+        "| d  | a  | 924       |",
+        "| d  | b  | 836       |",
+        "| d  | c  | 924       |",
+        "| d  | d  | 792       |",
+        "| d  | e  | 924       |",
+        "| e  | a  | 1323      |",
+        "| e  | b  | 1197      |",
+        "| e  | c  | 1323      |",
+        "| e  | d  | 1134      |",
+        "| e  | e  | 1323      |",
+        "+----+----+-----------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+}
+
+#[tokio::test]
 async fn query_count_without_from() -> Result<()> {
     let mut ctx = ExecutionContext::new();
     let sql = "SELECT count(1 + 1)";