Posted to commits@arrow.apache.org by ko...@apache.org on 2021/11/10 02:06:55 UTC

[arrow] 10/12: ARROW-14630: [C++] Fix aggregation over scalar key columns

This is an automated email from the ASF dual-hosted git repository.

kou pushed a commit to branch maint-6.0.x
in repository https://gitbox.apache.org/repos/asf/arrow.git

commit 031c894a4b8de2530c5ea33cab52ab2290711ad2
Author: David Li <li...@gmail.com>
AuthorDate: Mon Nov 8 16:01:03 2021 -0500

    ARROW-14630: [C++] Fix aggregation over scalar key columns
    
    This fixes two issues:
    - GroupByNode would try to finish its future twice
    - GroupByNode wouldn't set the length of the key batch, which broke aggregation when a scalar column was used as the key (a sketch of this failure mode follows the commit message)
    
    Closes #11640 from lidavidm/arrow-14630
    
    Authored-by: David Li <li...@gmail.com>
    Signed-off-by: David Li <li...@gmail.com>
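
Context for the second fix: the old code built the key batch with ExecBatch::Make(keys), which presumably has to infer the row count from the values it is given; when every key column is a scalar there is no array to read a length from, so the grouper saw the wrong batch length. The patch instead constructs the key batch with the explicit length of the input batch. The following standalone sketch uses hypothetical types (not Arrow's actual ExecBatch internals) purely to illustrate why an all-scalar column set leaves the length undetermined.

    #include <cstdint>
    #include <iostream>
    #include <optional>
    #include <vector>

    // Hypothetical stand-in for a column value in a batch: either an array with a
    // known length or a scalar that is logically broadcast to the batch length.
    struct ColumnValue {
      bool is_scalar;
      int64_t array_length;  // meaningful only when is_scalar == false
    };

    // Length inference in the spirit of a factory like ExecBatch::Make(): the row
    // count can only be read off an array column, so a key set made entirely of
    // scalars leaves the length undetermined.
    std::optional<int64_t> InferLength(const std::vector<ColumnValue>& values) {
      for (const auto& v : values) {
        if (!v.is_scalar) return v.array_length;
      }
      return std::nullopt;  // every column is a scalar: no length to infer
    }

    int main() {
      // Key columns like the scalar boolean "b" in the regression test below.
      std::vector<ColumnValue> scalar_only_keys = {
          {/*is_scalar=*/true, /*array_length=*/0}};
      if (!InferLength(scalar_only_keys).has_value()) {
        std::cout << "Length is ambiguous; it must be supplied explicitly, which is\n"
                     "what the patched code does via ExecBatch(std::move(keys), "
                     "batch.length).\n";
      }
      return 0;
    }

Passing batch.length through means the grouper's Consume() call sees the true row count even when every key is constant; this is the behavior exercised by the ScalarSourceGroupedSum test added in the diff below.
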
---
 cpp/src/arrow/compute/exec/aggregate_node.cc |  7 ++-
 cpp/src/arrow/compute/exec/plan_test.cc      | 76 ++++++++++++++++++++++++++--
 2 files changed, 76 insertions(+), 7 deletions(-)

diff --git a/cpp/src/arrow/compute/exec/aggregate_node.cc b/cpp/src/arrow/compute/exec/aggregate_node.cc
index 904fa4e..ddf6f79 100644
--- a/cpp/src/arrow/compute/exec/aggregate_node.cc
+++ b/cpp/src/arrow/compute/exec/aggregate_node.cc
@@ -372,7 +372,7 @@ class GroupByNode : public ExecNode {
     for (size_t i = 0; i < key_field_ids_.size(); ++i) {
       keys[i] = batch.values[key_field_ids_[i]];
     }
-    ARROW_ASSIGN_OR_RAISE(ExecBatch key_batch, ExecBatch::Make(keys));
+    ExecBatch key_batch(std::move(keys), batch.length);
 
     // Create a batch with group ids
     ARROW_ASSIGN_OR_RAISE(Datum id_batch, state->grouper->Consume(key_batch));
@@ -527,9 +527,8 @@ class GroupByNode : public ExecNode {
   void StopProducing(ExecNode* output) override {
     DCHECK_EQ(output, outputs_[0]);
 
-    if (input_counter_.Cancel()) {
-      finished_.MarkFinished();
-    } else if (output_counter_.Cancel()) {
+    ARROW_UNUSED(input_counter_.Cancel());
+    if (output_counter_.Cancel()) {
       finished_.MarkFinished();
     }
     inputs_[0]->StopProducing(this);
diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc
index 437a93f..54d807e 100644
--- a/cpp/src/arrow/compute/exec/plan_test.cc
+++ b/cpp/src/arrow/compute/exec/plan_test.cc
@@ -546,7 +546,7 @@ TEST(ExecPlanExecution, StressSourceSink) {
     for (bool parallel : {false, true}) {
       SCOPED_TRACE(parallel ? "parallel" : "single threaded");
 
-      int num_batches = slow && !parallel ? 30 : 300;
+      int num_batches = (slow && !parallel) ? 30 : 300;
 
       ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
       AsyncGenerator<util::optional<ExecBatch>> sink_gen;
@@ -576,7 +576,7 @@ TEST(ExecPlanExecution, StressSourceOrderBy) {
     for (bool parallel : {false, true}) {
       SCOPED_TRACE(parallel ? "parallel" : "single threaded");
 
-      int num_batches = slow && !parallel ? 30 : 300;
+      int num_batches = (slow && !parallel) ? 30 : 300;
 
       ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
       AsyncGenerator<util::optional<ExecBatch>> sink_gen;
@@ -605,6 +605,42 @@ TEST(ExecPlanExecution, StressSourceOrderBy) {
   }
 }
 
+TEST(ExecPlanExecution, StressSourceGroupedSumStop) {
+  auto input_schema = schema({field("a", int32()), field("b", boolean())});
+  for (bool slow : {false, true}) {
+    SCOPED_TRACE(slow ? "slowed" : "unslowed");
+
+    for (bool parallel : {false, true}) {
+      SCOPED_TRACE(parallel ? "parallel" : "single threaded");
+
+      int num_batches = (slow && !parallel) ? 30 : 300;
+
+      ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+      AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+
+      auto random_data = MakeRandomBatches(input_schema, num_batches);
+
+      SortOptions options({SortKey("a", SortOrder::Ascending)});
+      ASSERT_OK(Declaration::Sequence(
+                    {
+                        {"source", SourceNodeOptions{random_data.schema,
+                                                     random_data.gen(parallel, slow)}},
+                        {"aggregate",
+                         AggregateNodeOptions{/*aggregates=*/{{"hash_sum", nullptr}},
+                                              /*targets=*/{"a"}, /*names=*/{"sum(a)"},
+                                              /*keys=*/{"b"}}},
+                        {"sink", SinkNodeOptions{&sink_gen}},
+                    })
+                    .AddToPlan(plan.get()));
+
+      ASSERT_OK(plan->Validate());
+      ASSERT_OK(plan->StartProducing());
+      plan->StopProducing();
+      ASSERT_FINISHES_OK(plan->finished());
+    }
+  }
+}
+
 TEST(ExecPlanExecution, StressSourceSinkStopped) {
   for (bool slow : {false, true}) {
     SCOPED_TRACE(slow ? "slowed" : "unslowed");
@@ -612,7 +648,7 @@ TEST(ExecPlanExecution, StressSourceSinkStopped) {
     for (bool parallel : {false, true}) {
       SCOPED_TRACE(parallel ? "parallel" : "single threaded");
 
-      int num_batches = slow && !parallel ? 30 : 300;
+      int num_batches = (slow && !parallel) ? 30 : 300;
 
       ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
       AsyncGenerator<util::optional<ExecBatch>> sink_gen;
@@ -1005,6 +1041,40 @@ TEST(ExecPlanExecution, ScalarSourceScalarAggSink) {
       }))));
 }
 
+TEST(ExecPlanExecution, ScalarSourceGroupedSum) {
+  // ARROW-14630: ensure grouped aggregation with a scalar key/array input doesn't error
+  ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+  AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+
+  BatchesWithSchema scalar_data;
+  scalar_data.batches = {
+      ExecBatchFromJSON({int32(), ValueDescr::Scalar(boolean())},
+                        "[[5, false], [6, false], [7, false]]"),
+      ExecBatchFromJSON({int32(), ValueDescr::Scalar(boolean())},
+                        "[[1, true], [2, true], [3, true]]"),
+  };
+  scalar_data.schema = schema({field("a", int32()), field("b", boolean())});
+
+  SortOptions options({SortKey("b", SortOrder::Descending)});
+  ASSERT_OK(Declaration::Sequence(
+                {
+                    {"source", SourceNodeOptions{scalar_data.schema,
+                                                 scalar_data.gen(/*parallel=*/false,
+                                                                 /*slow=*/false)}},
+                    {"aggregate",
+                     AggregateNodeOptions{/*aggregates=*/{{"hash_sum", nullptr}},
+                                          /*targets=*/{"a"}, /*names=*/{"hash_sum(a)"},
+                                          /*keys=*/{"b"}}},
+                    {"order_by_sink", OrderBySinkNodeOptions{options, &sink_gen}},
+                })
+                .AddToPlan(plan.get()));
+
+  ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
+              Finishes(ResultWith(UnorderedElementsAreArray({
+                  ExecBatchFromJSON({int64(), boolean()}, R"([[6, true], [18, false]])"),
+              }))));
+}
+
 TEST(ExecPlanExecution, SelfInnerHashJoinSink) {
   for (bool parallel : {false, true}) {
     SCOPED_TRACE(parallel ? "parallel/merged" : "serial");