You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ag...@apache.org on 2023/01/18 07:53:32 UTC

[arrow-datafusion] branch maint-16.x updated: Fix column indices in EnforceDistribution optimizer in Partial AggregateMode (#4878) (#4959)

This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch maint-16.x
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/maint-16.x by this push:
     new 35e34d4d7 Fix column indices in EnforceDistribution optimizer in Partial AggregateMode (#4878) (#4959)
35e34d4d7 is described below

commit 35e34d4d7581c150acb2016477a6115e0de9987c
Author: Jon Mease <jo...@gmail.com>
AuthorDate: Wed Jan 18 02:53:25 2023 -0500

    Fix column indices in EnforceDistribution optimizer in Partial AggregateMode (#4878) (#4959)
---
 .../src/physical_optimizer/dist_enforcement.rs     | 25 +++++++---
 datafusion/core/tests/sql/joins.rs                 | 58 ++++++++++++++++++++++
 2 files changed, 76 insertions(+), 7 deletions(-)

diff --git a/datafusion/core/src/physical_optimizer/dist_enforcement.rs b/datafusion/core/src/physical_optimizer/dist_enforcement.rs
index aa8b07569..cc94ad14a 100644
--- a/datafusion/core/src/physical_optimizer/dist_enforcement.rs
+++ b/datafusion/core/src/physical_optimizer/dist_enforcement.rs
@@ -431,11 +431,22 @@ fn reorder_aggregate_keys(
                     None
                 };
                 if let Some(partial_agg) = new_partial_agg {
-                    let mut new_group_exprs = vec![];
-                    for idx in positions.into_iter() {
-                        new_group_exprs.push(group_by.expr()[idx].clone());
-                    }
-                    let new_group_by = PhysicalGroupBy::new_single(new_group_exprs);
+                    // Build new group expressions that correspond to the output of partial_agg
+                    let new_final_group: Vec<Arc<dyn PhysicalExpr>> =
+                        partial_agg.output_group_expr();
+                    let new_group_by = PhysicalGroupBy::new_single(
+                        new_final_group
+                            .iter()
+                            .enumerate()
+                            .map(|(i, expr)| {
+                                (
+                                    expr.clone(),
+                                    partial_agg.group_expr().expr()[i].1.clone(),
+                                )
+                            })
+                            .collect(),
+                    );
+
                     let new_final_agg = Arc::new(AggregateExec::try_new(
                         AggregateMode::FinalPartitioned,
                         new_group_by,
@@ -1494,7 +1505,7 @@ mod tests {
         let expected = &[
             "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"b1\", index: 1 }, Column { name: \"b\", index: 0 }), (Column { name: \"a1\", index: 0 }, Column { name: \"a\", index: 1 })]",
             "ProjectionExec: expr=[a1@1 as a1, b1@0 as b1]",
-            "AggregateExec: mode=FinalPartitioned, gby=[b1@1 as b1, a1@0 as a1], aggr=[]",
+            "AggregateExec: mode=FinalPartitioned, gby=[b1@0 as b1, a1@1 as a1], aggr=[]",
             "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 0 }, Column { name: \"a1\", index: 1 }], 10)",
             "AggregateExec: mode=Partial, gby=[b@1 as b1, a@0 as a1], aggr=[]",
             "ParquetExec: limit=None, partitions={1 group: [[x]]}, projection=[a, b, c, d, e]",
@@ -2057,7 +2068,7 @@ mod tests {
             "SortExec: [b3@1 ASC,a3@0 ASC]",
             "ProjectionExec: expr=[a1@0 as a3, b1@1 as b3]",
             "ProjectionExec: expr=[a1@1 as a1, b1@0 as b1]",
-            "AggregateExec: mode=FinalPartitioned, gby=[b1@1 as b1, a1@0 as a1], aggr=[]",
+            "AggregateExec: mode=FinalPartitioned, gby=[b1@0 as b1, a1@1 as a1], aggr=[]",
             "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 0 }, Column { name: \"a1\", index: 1 }], 10)",
             "AggregateExec: mode=Partial, gby=[b@1 as b1, a@0 as a1], aggr=[]",
             "ParquetExec: limit=None, partitions={1 group: [[x]]}, projection=[a, b, c, d, e]",
diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs
index 1de20c29c..db5c706d3 100644
--- a/datafusion/core/tests/sql/joins.rs
+++ b/datafusion/core/tests/sql/joins.rs
@@ -2810,3 +2810,61 @@ async fn type_coercion_join_with_filter_and_equi_expr() -> Result<()> {
 
     Ok(())
 }
+
+#[tokio::test]
+async fn test_cross_join_to_groupby_with_different_key_ordering() -> Result<()> {
+    // Regression test for GH #4873: q3 groups by (col1, col2) but is joined USING(col2, col1)
+    let col1 = Arc::new(StringArray::from(vec![
+        "A", "A", "A", "A", "A", "A", "A", "A", "BB", "BB", "BB", "BB",
+    ])) as ArrayRef;
+
+    let col2 =
+        Arc::new(UInt64Array::from(vec![1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6])) as ArrayRef;
+
+    let col3 =
+        Arc::new(UInt64Array::from(vec![1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])) as ArrayRef;
+
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("col1", DataType::Utf8, true),
+        Field::new("col2", DataType::UInt64, true),
+        Field::new("col3", DataType::UInt64, true),
+    ])) as SchemaRef;
+
+    let batch = RecordBatch::try_new(schema.clone(), vec![col1, col2, col3]).unwrap();
+    let mem_table = MemTable::try_new(schema, vec![vec![batch]]).unwrap();
+
+    // Create a session context and register the in-memory table under the name `tbl`
+    let ctx = SessionContext::new();
+    ctx.register_table("tbl", Arc::new(mem_table)).unwrap();
+
+    let sql = "select col1, col2, coalesce(sum_col3, 0) as sum_col3 \
+                     from (select distinct col2 from tbl) AS q1 \
+                     cross join (select distinct col1 from tbl) AS q2 \
+                     left outer join (SELECT col1, col2, sum(col3) as sum_col3 FROM tbl GROUP BY col1, col2) AS q3 \
+                     USING(col2, col1) \
+                     ORDER BY col1, col2";
+
+    let expected = vec![ // pairs absent from tbl have no q3 match, so coalesce gives 0
+        "+------+------+----------+",
+        "| col1 | col2 | sum_col3 |",
+        "+------+------+----------+",
+        "| A    | 1    | 2        |",
+        "| A    | 2    | 2        |",
+        "| A    | 3    | 2        |",
+        "| A    | 4    | 2        |",
+        "| A    | 5    | 0        |",
+        "| A    | 6    | 0        |",
+        "| BB   | 1    | 0        |",
+        "| BB   | 2    | 0        |",
+        "| BB   | 3    | 0        |",
+        "| BB   | 4    | 0        |",
+        "| BB   | 5    | 2        |",
+        "| BB   | 6    | 2        |",
+        "+------+------+----------+",
+    ];
+
+    let results = execute_to_batches(&ctx, sql).await;
+    assert_batches_sorted_eq!(expected, &results);
+
+    Ok(())
+}