You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ko...@apache.org on 2021/11/10 02:06:55 UTC
[arrow] 10/12: ARROW-14630: [C++] Fix aggregation over scalar key columns
This is an automated email from the ASF dual-hosted git repository.
kou pushed a commit to branch maint-6.0.x
in repository https://gitbox.apache.org/repos/asf/arrow.git
commit 031c894a4b8de2530c5ea33cab52ab2290711ad2
Author: David Li <li...@gmail.com>
AuthorDate: Mon Nov 8 16:01:03 2021 -0500
ARROW-14630: [C++] Fix aggregation over scalar key columns
This fixes two issues:
- GroupByNode would try to finish its future twice
- GroupByNode wouldn't set the length of the key column batch, which broke down when using a scalar column as the key
Closes #11640 from lidavidm/arrow-14630
Authored-by: David Li <li...@gmail.com>
Signed-off-by: David Li <li...@gmail.com>
---
cpp/src/arrow/compute/exec/aggregate_node.cc | 7 ++-
cpp/src/arrow/compute/exec/plan_test.cc | 76 ++++++++++++++++++++++++++--
2 files changed, 76 insertions(+), 7 deletions(-)
diff --git a/cpp/src/arrow/compute/exec/aggregate_node.cc b/cpp/src/arrow/compute/exec/aggregate_node.cc
index 904fa4e..ddf6f79 100644
--- a/cpp/src/arrow/compute/exec/aggregate_node.cc
+++ b/cpp/src/arrow/compute/exec/aggregate_node.cc
@@ -372,7 +372,7 @@ class GroupByNode : public ExecNode {
for (size_t i = 0; i < key_field_ids_.size(); ++i) {
keys[i] = batch.values[key_field_ids_[i]];
}
- ARROW_ASSIGN_OR_RAISE(ExecBatch key_batch, ExecBatch::Make(keys));
+ ExecBatch key_batch(std::move(keys), batch.length);
// Create a batch with group ids
ARROW_ASSIGN_OR_RAISE(Datum id_batch, state->grouper->Consume(key_batch));
@@ -527,9 +527,8 @@ class GroupByNode : public ExecNode {
void StopProducing(ExecNode* output) override {
DCHECK_EQ(output, outputs_[0]);
- if (input_counter_.Cancel()) {
- finished_.MarkFinished();
- } else if (output_counter_.Cancel()) {
+ ARROW_UNUSED(input_counter_.Cancel());
+ if (output_counter_.Cancel()) {
finished_.MarkFinished();
}
inputs_[0]->StopProducing(this);
diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc
index 437a93f..54d807e 100644
--- a/cpp/src/arrow/compute/exec/plan_test.cc
+++ b/cpp/src/arrow/compute/exec/plan_test.cc
@@ -546,7 +546,7 @@ TEST(ExecPlanExecution, StressSourceSink) {
for (bool parallel : {false, true}) {
SCOPED_TRACE(parallel ? "parallel" : "single threaded");
- int num_batches = slow && !parallel ? 30 : 300;
+ int num_batches = (slow && !parallel) ? 30 : 300;
ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
AsyncGenerator<util::optional<ExecBatch>> sink_gen;
@@ -576,7 +576,7 @@ TEST(ExecPlanExecution, StressSourceOrderBy) {
for (bool parallel : {false, true}) {
SCOPED_TRACE(parallel ? "parallel" : "single threaded");
- int num_batches = slow && !parallel ? 30 : 300;
+ int num_batches = (slow && !parallel) ? 30 : 300;
ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
AsyncGenerator<util::optional<ExecBatch>> sink_gen;
@@ -605,6 +605,42 @@ TEST(ExecPlanExecution, StressSourceOrderBy) {
}
}
+TEST(ExecPlanExecution, StressSourceGroupedSumStop) {
+ auto input_schema = schema({field("a", int32()), field("b", boolean())});
+ for (bool slow : {false, true}) {
+ SCOPED_TRACE(slow ? "slowed" : "unslowed");
+
+ for (bool parallel : {false, true}) {
+ SCOPED_TRACE(parallel ? "parallel" : "single threaded");
+
+ int num_batches = (slow && !parallel) ? 30 : 300;
+
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+ AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+
+ auto random_data = MakeRandomBatches(input_schema, num_batches);
+
+ SortOptions options({SortKey("a", SortOrder::Ascending)});
+ ASSERT_OK(Declaration::Sequence(
+ {
+ {"source", SourceNodeOptions{random_data.schema,
+ random_data.gen(parallel, slow)}},
+ {"aggregate",
+ AggregateNodeOptions{/*aggregates=*/{{"hash_sum", nullptr}},
+ /*targets=*/{"a"}, /*names=*/{"sum(a)"},
+ /*keys=*/{"b"}}},
+ {"sink", SinkNodeOptions{&sink_gen}},
+ })
+ .AddToPlan(plan.get()));
+
+ ASSERT_OK(plan->Validate());
+ ASSERT_OK(plan->StartProducing());
+ plan->StopProducing();
+ ASSERT_FINISHES_OK(plan->finished());
+ }
+ }
+}
+
TEST(ExecPlanExecution, StressSourceSinkStopped) {
for (bool slow : {false, true}) {
SCOPED_TRACE(slow ? "slowed" : "unslowed");
@@ -612,7 +648,7 @@ TEST(ExecPlanExecution, StressSourceSinkStopped) {
for (bool parallel : {false, true}) {
SCOPED_TRACE(parallel ? "parallel" : "single threaded");
- int num_batches = slow && !parallel ? 30 : 300;
+ int num_batches = (slow && !parallel) ? 30 : 300;
ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
AsyncGenerator<util::optional<ExecBatch>> sink_gen;
@@ -1005,6 +1041,40 @@ TEST(ExecPlanExecution, ScalarSourceScalarAggSink) {
}))));
}
+TEST(ExecPlanExecution, ScalarSourceGroupedSum) {
+ // ARROW-14630: ensure grouped aggregation with a scalar key/array input doesn't error
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+ AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+
+ BatchesWithSchema scalar_data;
+ scalar_data.batches = {
+ ExecBatchFromJSON({int32(), ValueDescr::Scalar(boolean())},
+ "[[5, false], [6, false], [7, false]]"),
+ ExecBatchFromJSON({int32(), ValueDescr::Scalar(boolean())},
+ "[[1, true], [2, true], [3, true]]"),
+ };
+ scalar_data.schema = schema({field("a", int32()), field("b", boolean())});
+
+ SortOptions options({SortKey("b", SortOrder::Descending)});
+ ASSERT_OK(Declaration::Sequence(
+ {
+ {"source", SourceNodeOptions{scalar_data.schema,
+ scalar_data.gen(/*parallel=*/false,
+ /*slow=*/false)}},
+ {"aggregate",
+ AggregateNodeOptions{/*aggregates=*/{{"hash_sum", nullptr}},
+ /*targets=*/{"a"}, /*names=*/{"hash_sum(a)"},
+ /*keys=*/{"b"}}},
+ {"order_by_sink", OrderBySinkNodeOptions{options, &sink_gen}},
+ })
+ .AddToPlan(plan.get()));
+
+ ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
+ Finishes(ResultWith(UnorderedElementsAreArray({
+ ExecBatchFromJSON({int64(), boolean()}, R"([[6, true], [18, false]])"),
+ }))));
+}
+
TEST(ExecPlanExecution, SelfInnerHashJoinSink) {
for (bool parallel : {false, true}) {
SCOPED_TRACE(parallel ? "parallel/merged" : "serial");