You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by we...@apache.org on 2023/01/26 17:01:13 UTC
[arrow] branch master updated: GH-33566: [C++] Add support for nullary and n-ary aggregate functions (#15083)
This is an automated email from the ASF dual-hosted git repository.
westonpace pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new af400a81e6 GH-33566: [C++] Add support for nullary and n-ary aggregate functions (#15083)
af400a81e6 is described below
commit af400a81e698938b9d21ef4b1aeff1ade64562eb
Author: Felipe Oliveira Carvalho <fe...@gmail.com>
AuthorDate: Thu Jan 26 14:01:01 2023 -0300
GH-33566: [C++] Add support for nullary and n-ary aggregate functions (#15083)
- [x] Add ability to pass 0 or more than 1 target fields via the Aggregate API
- [x] Add support for nullary `count` -- `count(*)`
- [x] Add an n-ary aggregate function to test the changes `*`
`*` I implemented a `"covariant(y, x)"` aggregation function and used it to test the Aggregate API changes, but it is not included in this PR because I intend to focus on passing the CI tests and getting a final review.
* Closes: #33566
Lead-authored-by: Felipe Oliveira Carvalho <fe...@gmail.com>
Co-authored-by: Antoine Pitrou <pi...@free.fr>
Co-authored-by: Weston Pace <we...@gmail.com>
Signed-off-by: Weston Pace <we...@gmail.com>
---
cpp/src/arrow/compute/api_aggregate.h | 30 +++-
cpp/src/arrow/compute/exec.cc | 61 +++++++-
cpp/src/arrow/compute/exec.h | 6 +-
cpp/src/arrow/compute/exec/aggregate.cc | 156 ++++++++++++++-------
cpp/src/arrow/compute/exec/aggregate.h | 7 +-
cpp/src/arrow/compute/exec/aggregate_node.cc | 112 +++++++++------
cpp/src/arrow/compute/exec/plan_test.cc | 63 +++++++--
cpp/src/arrow/compute/exec/test_util.cc | 14 +-
cpp/src/arrow/compute/kernels/aggregate_basic.cc | 42 +++++-
cpp/src/arrow/compute/kernels/hash_aggregate.cc | 82 ++++++++++-
.../arrow/compute/kernels/hash_aggregate_test.cc | 71 ++++++++--
cpp/src/arrow/engine/substrait/extension_set.cc | 42 ++++--
cpp/src/arrow/engine/substrait/function_test.cc | 18 ++-
.../arrow/engine/substrait/relation_internal.cc | 30 ++--
cpp/src/arrow/engine/substrait/serde_test.cc | 2 +-
.../arrow/engine/substrait/test_plan_builder.cc | 20 +--
cpp/src/arrow/engine/substrait/test_plan_builder.h | 7 +-
docs/source/cpp/compute.rst | 152 ++++++++++----------
docs/source/python/api/compute.rst | 4 +
python/pyarrow/_compute.pyx | 10 +-
python/pyarrow/compute.py | 4 +
python/pyarrow/includes/libarrow.pxd | 2 +
python/pyarrow/table.pxi | 48 ++++---
python/pyarrow/tests/test_table.py | 35 +++++
24 files changed, 743 insertions(+), 275 deletions(-)
diff --git a/cpp/src/arrow/compute/api_aggregate.h b/cpp/src/arrow/compute/api_aggregate.h
index c9a84d2018..97c654266e 100644
--- a/cpp/src/arrow/compute/api_aggregate.h
+++ b/cpp/src/arrow/compute/api_aggregate.h
@@ -20,6 +20,8 @@
#pragma once
+#include <vector>
+
#include "arrow/compute/function.h"
#include "arrow/datum.h"
#include "arrow/result.h"
@@ -186,16 +188,38 @@ class ARROW_EXPORT IndexOptions : public FunctionOptions {
/// \brief Configure a grouped aggregation
struct ARROW_EXPORT Aggregate {
+ Aggregate() = default;
+
+ Aggregate(std::string function, std::shared_ptr<FunctionOptions> options,
+ std::vector<FieldRef> target, std::string name = "")
+ : function(std::move(function)),
+ options(std::move(options)),
+ target(std::move(target)),
+ name(std::move(name)) {}
+
+ Aggregate(std::string function, std::shared_ptr<FunctionOptions> options,
+ FieldRef target, std::string name = "")
+ : Aggregate(std::move(function), std::move(options),
+ std::vector<FieldRef>{std::move(target)}, std::move(name)) {}
+
+ Aggregate(std::string function, FieldRef target, std::string name)
+ : Aggregate(std::move(function), /*options=*/NULLPTR,
+ std::vector<FieldRef>{std::move(target)}, std::move(name)) {}
+
+ Aggregate(std::string function, std::string name)
+ : Aggregate(std::move(function), /*options=*/NULLPTR,
+ /*target=*/std::vector<FieldRef>{}, std::move(name)) {}
+
/// the name of the aggregation function
std::string function;
/// options for the aggregation function
std::shared_ptr<FunctionOptions> options;
- // fields to which aggregations will be applied
- FieldRef target;
+ /// zero or more fields to which aggregations will be applied
+ std::vector<FieldRef> target;
- // output field name for aggregations
+ /// optional output field name for aggregations
std::string name;
};
diff --git a/cpp/src/arrow/compute/exec.cc b/cpp/src/arrow/compute/exec.cc
index ee02b26845..5f73ca93c7 100644
--- a/cpp/src/arrow/compute/exec.cc
+++ b/cpp/src/arrow/compute/exec.cc
@@ -147,9 +147,21 @@ ExecBatch ExecBatch::Slice(int64_t offset, int64_t length) const {
return out;
}
-Result<ExecBatch> ExecBatch::Make(std::vector<Datum> values) {
+namespace {
+
+enum LengthInferenceError {
+ kEmptyInput = -1,
+ kInvalidValues = -2,
+};
+
+/// \brief Infer the ExecBatch length from values.
+///
+/// \return the inferred length of the batch. If there are no values in the
+/// batch then kEmptyInput (-1) is returned. If the values in the batch have
+/// different lengths then kInvalidValues (-2) is returned.
+int64_t DoInferLength(const std::vector<Datum>& values) {
if (values.empty()) {
- return Status::Invalid("Cannot infer ExecBatch length without at least one value");
+ return kEmptyInput;
}
int64_t length = -1;
@@ -164,13 +176,52 @@ Result<ExecBatch> ExecBatch::Make(std::vector<Datum> values) {
}
if (length != value.length()) {
+ // all the arrays should have the same length
+ return kInvalidValues;
+ }
+ }
+
+ return length == -1 ? 1 : length;
+}
+
+} // namespace
+
+Result<int64_t> ExecBatch::InferLength(const std::vector<Datum>& values) {
+ const int64_t length = DoInferLength(values);
+ switch (length) {
+ case kInvalidValues:
return Status::Invalid(
"Arrays used to construct an ExecBatch must have equal length");
- }
+ case kEmptyInput:
+ return Status::Invalid("Cannot infer ExecBatch length without at least one value");
+ default:
+ break;
}
+ return {length};
+}
- if (length == -1) {
- length = 1;
+Result<ExecBatch> ExecBatch::Make(std::vector<Datum> values, int64_t length) {
+ // Infer the length again and/or validate the given length.
+ const int64_t inferred_length = DoInferLength(values);
+ switch (inferred_length) {
+ case kEmptyInput:
+ if (length < 0) {
+ return Status::Invalid(
+ "Cannot infer ExecBatch length without at least one value");
+ }
+ break;
+
+ case kInvalidValues:
+ return Status::Invalid(
+ "Arrays used to construct an ExecBatch must have equal length");
+
+ default:
+ if (length < 0) {
+ length = inferred_length;
+ } else if (length != inferred_length) {
+ return Status::Invalid("Length used to construct an ExecBatch is invalid");
+ }
+ break;
}
return ExecBatch(std::move(values), length);
diff --git a/cpp/src/arrow/compute/exec.h b/cpp/src/arrow/compute/exec.h
index 30f4113f6c..487b8d120a 100644
--- a/cpp/src/arrow/compute/exec.h
+++ b/cpp/src/arrow/compute/exec.h
@@ -24,6 +24,7 @@
#include <cstdint>
#include <limits>
#include <memory>
+#include <optional>
#include <string>
#include <utility>
#include <vector>
@@ -174,7 +175,10 @@ struct ARROW_EXPORT ExecBatch {
explicit ExecBatch(const RecordBatch& batch);
- static Result<ExecBatch> Make(std::vector<Datum> values);
+ /// \brief Infer the ExecBatch length from values.
+ static Result<int64_t> InferLength(const std::vector<Datum>& values);
+
+ static Result<ExecBatch> Make(std::vector<Datum> values, int64_t length = -1);
Result<std::shared_ptr<RecordBatch>> ToRecordBatch(
std::shared_ptr<Schema> schema, MemoryPool* pool = default_memory_pool()) const;
diff --git a/cpp/src/arrow/compute/exec/aggregate.cc b/cpp/src/arrow/compute/exec/aggregate.cc
index d5f347f34a..5e99bbba92 100644
--- a/cpp/src/arrow/compute/exec/aggregate.cc
+++ b/cpp/src/arrow/compute/exec/aggregate.cc
@@ -36,54 +36,74 @@ using internal::ToChars;
namespace compute {
namespace internal {
+namespace {
+
+std::vector<TypeHolder> ExtendWithGroupIdType(const std::vector<TypeHolder>& in_types) {
+ std::vector<TypeHolder> aggr_in_types;
+ aggr_in_types.reserve(in_types.size() + 1);
+ aggr_in_types = in_types;
+ aggr_in_types.emplace_back(uint32());
+ return aggr_in_types;
+}
+
+Result<const HashAggregateKernel*> GetKernel(ExecContext* ctx, const Aggregate& aggregate,
+ const std::vector<TypeHolder>& in_types) {
+ const auto aggr_in_types = ExtendWithGroupIdType(in_types);
+ ARROW_ASSIGN_OR_RAISE(auto function,
+ ctx->func_registry()->GetFunction(aggregate.function));
+ ARROW_ASSIGN_OR_RAISE(const Kernel* kernel, function->DispatchExact(aggr_in_types));
+ return static_cast<const HashAggregateKernel*>(kernel);
+}
+
+Result<std::unique_ptr<KernelState>> InitKernel(const HashAggregateKernel* kernel,
+ ExecContext* ctx,
+ const Aggregate& aggregate,
+ const std::vector<TypeHolder>& in_types) {
+ const auto aggr_in_types = ExtendWithGroupIdType(in_types);
+
+ KernelContext kernel_ctx{ctx};
+ const auto* options =
+ arrow::internal::checked_cast<const FunctionOptions*>(aggregate.options.get());
+ if (options == nullptr) {
+ // use known default options for the named function if possible
+ auto maybe_function = ctx->func_registry()->GetFunction(aggregate.function);
+ if (maybe_function.ok()) {
+ options = maybe_function.ValueOrDie()->default_options();
+ }
+ }
+
+ ARROW_ASSIGN_OR_RAISE(
+ auto state,
+ kernel->init(&kernel_ctx, KernelInitArgs{kernel, aggr_in_types, options}));
+ return std::move(state);
+}
+
+} // namespace
+
Result<std::vector<const HashAggregateKernel*>> GetKernels(
ExecContext* ctx, const std::vector<Aggregate>& aggregates,
- const std::vector<TypeHolder>& in_types) {
+ const std::vector<std::vector<TypeHolder>>& in_types) {
if (aggregates.size() != in_types.size()) {
return Status::Invalid(aggregates.size(), " aggregate functions were specified but ",
in_types.size(), " arguments were provided.");
}
std::vector<const HashAggregateKernel*> kernels(in_types.size());
-
for (size_t i = 0; i < aggregates.size(); ++i) {
- ARROW_ASSIGN_OR_RAISE(auto function,
- ctx->func_registry()->GetFunction(aggregates[i].function));
- ARROW_ASSIGN_OR_RAISE(const Kernel* kernel,
- function->DispatchExact({in_types[i], uint32()}));
- kernels[i] = static_cast<const HashAggregateKernel*>(kernel);
+ ARROW_ASSIGN_OR_RAISE(kernels[i], GetKernel(ctx, aggregates[i], in_types[i]));
}
return kernels;
}
Result<std::vector<std::unique_ptr<KernelState>>> InitKernels(
const std::vector<const HashAggregateKernel*>& kernels, ExecContext* ctx,
- const std::vector<Aggregate>& aggregates, const std::vector<TypeHolder>& in_types) {
+ const std::vector<Aggregate>& aggregates,
+ const std::vector<std::vector<TypeHolder>>& in_types) {
std::vector<std::unique_ptr<KernelState>> states(kernels.size());
-
for (size_t i = 0; i < aggregates.size(); ++i) {
- const FunctionOptions* options =
- arrow::internal::checked_cast<const FunctionOptions*>(
- aggregates[i].options.get());
-
- if (options == nullptr) {
- // use known default options for the named function if possible
- auto maybe_function = ctx->func_registry()->GetFunction(aggregates[i].function);
- if (maybe_function.ok()) {
- options = maybe_function.ValueOrDie()->default_options();
- }
- }
-
- KernelContext kernel_ctx{ctx};
ARROW_ASSIGN_OR_RAISE(states[i],
- kernels[i]->init(&kernel_ctx, KernelInitArgs{kernels[i],
- {
- in_types[i],
- uint32(),
- },
- options}));
+ InitKernel(kernels[i], ctx, aggregates[i], in_types[i]));
}
-
return std::move(states);
}
@@ -91,15 +111,16 @@ Result<FieldVector> ResolveKernels(
const std::vector<Aggregate>& aggregates,
const std::vector<const HashAggregateKernel*>& kernels,
const std::vector<std::unique_ptr<KernelState>>& states, ExecContext* ctx,
- const std::vector<TypeHolder>& types) {
+ const std::vector<std::vector<TypeHolder>>& types) {
FieldVector fields(types.size());
for (size_t i = 0; i < kernels.size(); ++i) {
KernelContext kernel_ctx{ctx};
kernel_ctx.SetState(states[i].get());
- ARROW_ASSIGN_OR_RAISE(auto type, kernels[i]->signature->out_type().Resolve(
- &kernel_ctx, {types[i], uint32()}));
+ const auto aggr_in_types = ExtendWithGroupIdType(types[i]);
+ ARROW_ASSIGN_OR_RAISE(
+ auto type, kernels[i]->signature->out_type().Resolve(&kernel_ctx, aggr_in_types));
fields[i] = field(aggregates[i].function, type.GetSharedPtr());
}
return fields;
@@ -121,27 +142,50 @@ Result<Datum> GroupBy(const std::vector<Datum>& arguments, const std::vector<Dat
ExecSpanIterator argument_iterator;
ExecBatch args_batch;
- if (!arguments.empty()) {
- ARROW_ASSIGN_OR_RAISE(args_batch, ExecBatch::Make(arguments));
+ Result<int64_t> inferred_length = ExecBatch::InferLength(arguments);
+ if (!inferred_length.ok()) {
+ inferred_length = ExecBatch::InferLength(keys);
+ }
+ ARROW_ASSIGN_OR_RAISE(const int64_t length, std::move(inferred_length));
+ if (!aggregates.empty()) {
+ ARROW_ASSIGN_OR_RAISE(args_batch, ExecBatch::Make(arguments, length));
// Construct and initialize HashAggregateKernels
- auto argument_types = args_batch.GetTypes();
+ std::vector<std::vector<TypeHolder>> aggs_argument_types;
+ aggs_argument_types.reserve(aggregates.size());
+ size_t i = 0;
+ for (const auto& aggregate : aggregates) {
+ auto& agg_types = aggs_argument_types.emplace_back();
+ const size_t num_needed = aggregate.target.size();
+ for (size_t j = 0; j < num_needed && i < arguments.size(); j++, i++) {
+ agg_types.emplace_back(arguments[i].type());
+ }
+ if (agg_types.size() != num_needed) {
+ return Status::Invalid("Not enough arguments specified to aggregate functions.");
+ }
+ }
+ DCHECK_EQ(aggs_argument_types.size(), aggregates.size());
+ if (i != arguments.size()) {
+ return Status::Invalid("Aggregate functions expect exactly ", i, " arguments, but ",
+ arguments.size(), " were specified.");
+ }
- ARROW_ASSIGN_OR_RAISE(kernels, GetKernels(ctx, aggregates, argument_types));
+ ARROW_ASSIGN_OR_RAISE(kernels, GetKernels(ctx, aggregates, aggs_argument_types));
states.resize(task_group->parallelism());
for (auto& state : states) {
- ARROW_ASSIGN_OR_RAISE(state, InitKernels(kernels, ctx, aggregates, argument_types));
+ ARROW_ASSIGN_OR_RAISE(state,
+ InitKernels(kernels, ctx, aggregates, aggs_argument_types));
}
- ARROW_ASSIGN_OR_RAISE(
- out_fields, ResolveKernels(aggregates, kernels, states[0], ctx, argument_types));
+ ARROW_ASSIGN_OR_RAISE(out_fields, ResolveKernels(aggregates, kernels, states[0], ctx,
+ aggs_argument_types));
RETURN_NOT_OK(argument_iterator.Init(args_batch, ctx->exec_chunksize()));
}
// Construct Groupers
- ARROW_ASSIGN_OR_RAISE(ExecBatch keys_batch, ExecBatch::Make(keys));
+ ARROW_ASSIGN_OR_RAISE(ExecBatch keys_batch, ExecBatch::Make(keys, length));
auto key_types = keys_batch.GetTypes();
std::vector<std::unique_ptr<Grouper>> groupers(task_group->parallelism());
@@ -164,6 +208,10 @@ Result<Datum> GroupBy(const std::vector<Datum>& arguments, const std::vector<Dat
ExecSpan key_batch, argument_batch;
while ((arguments.empty() || argument_iterator.Next(&argument_batch)) &&
key_iterator.Next(&key_batch)) {
+ if (arguments.empty()) {
+ // A value-less argument_batch should still have a valid length
+ argument_batch.length = key_batch.length;
+ }
if (key_batch.length == 0) continue;
task_group->Append([&, key_batch, argument_batch] {
@@ -181,13 +229,23 @@ Result<Datum> GroupBy(const std::vector<Datum>& arguments, const std::vector<Dat
ARROW_ASSIGN_OR_RAISE(Datum id_batch, grouper->Consume(key_batch));
// consume group ids with HashAggregateKernels
- for (size_t i = 0; i < kernels.size(); ++i) {
+ for (size_t k = 0, arg_idx = 0; k < kernels.size(); ++k) {
+ const auto* kernel = kernels[k];
KernelContext batch_ctx{ctx};
- batch_ctx.SetState(states[thread_index][i].get());
- ExecSpan kernel_batch({argument_batch[i], *id_batch.array()},
- argument_batch.length);
- RETURN_NOT_OK(kernels[i]->resize(&batch_ctx, grouper->num_groups()));
- RETURN_NOT_OK(kernels[i]->consume(&batch_ctx, kernel_batch));
+ batch_ctx.SetState(states[thread_index][k].get());
+
+ const size_t kernel_num_args = kernel->signature->in_types().size();
+ DCHECK_GT(kernel_num_args, 0);
+
+ std::vector<ExecValue> kernel_args;
+ for (size_t i = 0; i + 1 < kernel_num_args; i++, arg_idx++) {
+ kernel_args.push_back(argument_batch[arg_idx]);
+ }
+ kernel_args.emplace_back(*id_batch.array());
+
+ ExecSpan kernel_batch(std::move(kernel_args), argument_batch.length);
+ RETURN_NOT_OK(kernel->resize(&batch_ctx, grouper->num_groups()));
+ RETURN_NOT_OK(kernel->consume(&batch_ctx, kernel_batch));
}
return Status::OK();
@@ -215,7 +273,7 @@ Result<Datum> GroupBy(const std::vector<Datum>& arguments, const std::vector<Dat
}
// Finalize output
- ArrayDataVector out_data(arguments.size() + keys.size());
+ ArrayDataVector out_data(kernels.size() + keys.size());
auto it = out_data.begin();
for (size_t idx = 0; idx < kernels.size(); ++idx) {
@@ -231,8 +289,8 @@ Result<Datum> GroupBy(const std::vector<Datum>& arguments, const std::vector<Dat
*it++ = key.array();
}
- int64_t length = out_data[0]->length;
- return ArrayData::Make(struct_(std::move(out_fields)), length,
+ const int64_t out_length = out_data[0]->length;
+ return ArrayData::Make(struct_(std::move(out_fields)), out_length,
{/*null_bitmap=*/nullptr}, std::move(out_data),
/*null_count=*/0);
}
diff --git a/cpp/src/arrow/compute/exec/aggregate.h b/cpp/src/arrow/compute/exec/aggregate.h
index 72990f3b6e..027449f02a 100644
--- a/cpp/src/arrow/compute/exec/aggregate.h
+++ b/cpp/src/arrow/compute/exec/aggregate.h
@@ -42,17 +42,18 @@ Result<Datum> GroupBy(const std::vector<Datum>& arguments, const std::vector<Dat
Result<std::vector<const HashAggregateKernel*>> GetKernels(
ExecContext* ctx, const std::vector<Aggregate>& aggregates,
- const std::vector<TypeHolder>& in_types);
+ const std::vector<std::vector<TypeHolder>>& in_types);
Result<std::vector<std::unique_ptr<KernelState>>> InitKernels(
const std::vector<const HashAggregateKernel*>& kernels, ExecContext* ctx,
- const std::vector<Aggregate>& aggregates, const std::vector<TypeHolder>& in_types);
+ const std::vector<Aggregate>& aggregates,
+ const std::vector<std::vector<TypeHolder>>& in_types);
Result<FieldVector> ResolveKernels(
const std::vector<Aggregate>& aggregates,
const std::vector<const HashAggregateKernel*>& kernels,
const std::vector<std::unique_ptr<KernelState>>& states, ExecContext* ctx,
- const std::vector<TypeHolder>& in_types);
+ const std::vector<std::vector<TypeHolder>>& in_types);
} // namespace internal
} // namespace compute
diff --git a/cpp/src/arrow/compute/exec/aggregate_node.cc b/cpp/src/arrow/compute/exec/aggregate_node.cc
index b4726b4fbc..ac0375f2ce 100644
--- a/cpp/src/arrow/compute/exec/aggregate_node.cc
+++ b/cpp/src/arrow/compute/exec/aggregate_node.cc
@@ -46,12 +46,21 @@ namespace {
void AggregatesToString(std::stringstream* ss, const Schema& input_schema,
const std::vector<Aggregate>& aggs,
- const std::vector<int>& target_field_ids, int indent = 0) {
+ const std::vector<std::vector<int>>& target_fieldsets,
+ int indent = 0) {
*ss << "aggregates=[" << std::endl;
for (size_t i = 0; i < aggs.size(); i++) {
for (int j = 0; j < indent; ++j) *ss << " ";
- *ss << '\t' << aggs[i].function << '('
- << input_schema.field(target_field_ids[i])->name();
+ *ss << '\t' << aggs[i].function << '(';
+ const auto& target = target_fieldsets[i];
+ if (target.size() == 0) {
+ *ss << "*";
+ } else {
+ *ss << input_schema.field(target[0])->name();
+ for (size_t k = 1; k < target.size(); k++) {
+ *ss << ", " << input_schema.field(target[k])->name();
+ }
+ }
if (aggs[i].options) {
*ss << ", " << aggs[i].options->ToString();
}
@@ -65,12 +74,13 @@ class ScalarAggregateNode : public ExecNode, public TracedNode<ScalarAggregateNo
public:
ScalarAggregateNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
std::shared_ptr<Schema> output_schema,
- std::vector<int> target_field_ids, std::vector<Aggregate> aggs,
+ std::vector<std::vector<int>> target_fieldsets,
+ std::vector<Aggregate> aggs,
std::vector<const ScalarAggregateKernel*> kernels,
std::vector<std::vector<std::unique_ptr<KernelState>>> states)
: ExecNode(plan, std::move(inputs), {"target"},
/*output_schema=*/std::move(output_schema)),
- target_field_ids_(std::move(target_field_ids)),
+ target_fieldsets_(std::move(target_fieldsets)),
aggs_(std::move(aggs)),
kernels_(std::move(kernels)),
states_(std::move(states)) {}
@@ -88,13 +98,14 @@ class ScalarAggregateNode : public ExecNode, public TracedNode<ScalarAggregateNo
std::vector<const ScalarAggregateKernel*> kernels(aggregates.size());
std::vector<std::vector<std::unique_ptr<KernelState>>> states(kernels.size());
FieldVector fields(kernels.size());
- std::vector<int> target_field_ids(kernels.size());
+ std::vector<std::vector<int>> target_fieldsets(kernels.size());
for (size_t i = 0; i < kernels.size(); ++i) {
- ARROW_ASSIGN_OR_RAISE(
- auto match,
- FieldRef(aggregate_options.aggregates[i].target).FindOne(input_schema));
- target_field_ids[i] = match[0];
+ const auto& target_fieldset = aggregate_options.aggregates[i].target;
+ for (const auto& target : target_fieldset) {
+ ARROW_ASSIGN_OR_RAISE(auto match, FieldRef(target).FindOne(input_schema));
+ target_fieldsets[i].push_back(match[0]);
+ }
ARROW_ASSIGN_OR_RAISE(
auto function, exec_ctx->func_registry()->GetFunction(aggregates[i].function));
@@ -104,34 +115,37 @@ class ScalarAggregateNode : public ExecNode, public TracedNode<ScalarAggregateNo
aggregates[i].function);
}
- TypeHolder in_type(input_schema.field(target_field_ids[i])->type().get());
- ARROW_ASSIGN_OR_RAISE(const Kernel* kernel, function->DispatchExact({in_type}));
+ std::vector<TypeHolder> in_types;
+ for (const auto& target : target_fieldsets[i]) {
+ in_types.emplace_back(input_schema.field(target)->type().get());
+ }
+ ARROW_ASSIGN_OR_RAISE(const Kernel* kernel, function->DispatchExact(in_types));
kernels[i] = static_cast<const ScalarAggregateKernel*>(kernel);
if (aggregates[i].options == nullptr) {
- aggregates[i].options = function->default_options()->Copy();
+ DCHECK(!function->doc().options_required);
+ const auto* default_options = function->default_options();
+ if (default_options) {
+ aggregates[i].options = default_options->Copy();
+ }
}
KernelContext kernel_ctx{exec_ctx};
states[i].resize(plan->query_context()->max_concurrency());
- RETURN_NOT_OK(Kernel::InitAll(&kernel_ctx,
- KernelInitArgs{kernels[i],
- {
- in_type,
- },
- aggregates[i].options.get()},
- &states[i]));
+ RETURN_NOT_OK(Kernel::InitAll(
+ &kernel_ctx, KernelInitArgs{kernels[i], in_types, aggregates[i].options.get()},
+ &states[i]));
// pick one to resolve the kernel signature
kernel_ctx.SetState(states[i][0].get());
ARROW_ASSIGN_OR_RAISE(auto out_type, kernels[i]->signature->out_type().Resolve(
- &kernel_ctx, {in_type}));
+ &kernel_ctx, in_types));
fields[i] = field(aggregate_options.aggregates[i].name, out_type.GetSharedPtr());
}
return plan->EmplaceNode<ScalarAggregateNode>(
- plan, std::move(inputs), schema(std::move(fields)), std::move(target_field_ids),
+ plan, std::move(inputs), schema(std::move(fields)), std::move(target_fieldsets),
std::move(aggregates), std::move(kernels), std::move(states));
}
@@ -148,8 +162,12 @@ class ScalarAggregateNode : public ExecNode, public TracedNode<ScalarAggregateNo
KernelContext batch_ctx{plan()->query_context()->exec_context()};
batch_ctx.SetState(states_[i][thread_index].get());
- ExecSpan single_column_batch{{batch.values[target_field_ids_[i]]}, batch.length};
- RETURN_NOT_OK(kernels_[i]->consume(&batch_ctx, single_column_batch));
+ std::vector<ExecValue> column_values;
+ for (const int field : target_fieldsets_[i]) {
+ column_values.push_back(batch.values[field]);
+ }
+ ExecSpan column_batch{std::move(column_values), batch.length};
+ RETURN_NOT_OK(kernels_[i]->consume(&batch_ctx, column_batch));
}
return Status::OK();
}
@@ -197,7 +215,7 @@ class ScalarAggregateNode : public ExecNode, public TracedNode<ScalarAggregateNo
std::string ToStringExtra(int indent = 0) const override {
std::stringstream ss;
const auto input_schema = inputs_[0]->output_schema();
- AggregatesToString(&ss, *input_schema, aggs_, target_field_ids_);
+ AggregatesToString(&ss, *input_schema, aggs_, target_fieldsets_);
return ss.str();
}
@@ -223,7 +241,7 @@ class ScalarAggregateNode : public ExecNode, public TracedNode<ScalarAggregateNo
return output_->InputReceived(this, std::move(batch));
}
- const std::vector<int> target_field_ids_;
+ const std::vector<std::vector<int>> target_fieldsets_;
const std::vector<Aggregate> aggs_;
const std::vector<const ScalarAggregateKernel*> kernels_;
@@ -235,12 +253,13 @@ class ScalarAggregateNode : public ExecNode, public TracedNode<ScalarAggregateNo
class GroupByNode : public ExecNode, public TracedNode<GroupByNode> {
public:
GroupByNode(ExecNode* input, std::shared_ptr<Schema> output_schema,
- std::vector<int> key_field_ids, std::vector<int> agg_src_field_ids,
+ std::vector<int> key_field_ids,
+ std::vector<std::vector<int>> agg_src_fieldsets,
std::vector<Aggregate> aggs,
std::vector<const HashAggregateKernel*> agg_kernels)
: ExecNode(input->plan(), {input}, {"groupby"}, std::move(output_schema)),
key_field_ids_(std::move(key_field_ids)),
- agg_src_field_ids_(std::move(agg_src_field_ids)),
+ agg_src_fieldsets_(std::move(agg_src_fieldsets)),
aggs_(std::move(aggs)),
agg_kernels_(std::move(agg_kernels)) {}
@@ -272,17 +291,21 @@ class GroupByNode : public ExecNode, public TracedNode<GroupByNode> {
}
// Find input field indices for aggregates
- std::vector<int> agg_src_field_ids(aggs.size());
+ std::vector<std::vector<int>> agg_src_fieldsets(aggs.size());
for (size_t i = 0; i < aggs.size(); ++i) {
- ARROW_ASSIGN_OR_RAISE(auto match, aggs[i].target.FindOne(*input_schema));
- agg_src_field_ids[i] = match[0];
+ const auto& target_fieldset = aggs[i].target;
+ for (const auto& target : target_fieldset) {
+ ARROW_ASSIGN_OR_RAISE(auto match, target.FindOne(*input_schema));
+ agg_src_fieldsets[i].push_back(match[0]);
+ }
}
// Build vector of aggregate source field data types
- std::vector<TypeHolder> agg_src_types(aggs.size());
+ std::vector<std::vector<TypeHolder>> agg_src_types(aggs.size());
for (size_t i = 0; i < aggs.size(); ++i) {
- auto agg_src_field_id = agg_src_field_ids[i];
- agg_src_types[i] = input_schema->field(agg_src_field_id)->type().get();
+ for (const auto& agg_src_field_id : agg_src_fieldsets[i]) {
+ agg_src_types[i].push_back(input_schema->field(agg_src_field_id)->type().get());
+ }
}
auto ctx = plan->query_context()->exec_context();
@@ -314,7 +337,7 @@ class GroupByNode : public ExecNode, public TracedNode<GroupByNode> {
return input->plan()->EmplaceNode<GroupByNode>(
input, schema(std::move(output_fields)), std::move(key_field_ids),
- std::move(agg_src_field_ids), std::move(aggs), std::move(agg_kernels));
+ std::move(agg_src_fieldsets), std::move(aggs), std::move(agg_kernels));
}
const char* kind_name() const override { return "GroupByNode"; }
@@ -351,8 +374,12 @@ class GroupByNode : public ExecNode, public TracedNode<GroupByNode> {
KernelContext kernel_ctx{ctx};
kernel_ctx.SetState(state->agg_states[i].get());
- ExecSpan agg_batch({batch[agg_src_field_ids_[i]], ExecValue(*id_batch.array())},
- batch.length);
+ std::vector<ExecValue> column_values;
+ for (const int field : agg_src_fieldsets_[i]) {
+ column_values.push_back(batch[field]);
+ }
+ column_values.emplace_back(*id_batch.array());
+ ExecSpan agg_batch(std::move(column_values), batch.length);
RETURN_NOT_OK(agg_kernels_[i]->resize(&kernel_ctx, state->grouper->num_groups()));
RETURN_NOT_OK(agg_kernels_[i]->consume(&kernel_ctx, agg_batch));
}
@@ -506,7 +533,7 @@ class GroupByNode : public ExecNode, public TracedNode<GroupByNode> {
ss << '"' << input_schema->field(key_field_ids_[i])->name() << '"';
}
ss << "], ";
- AggregatesToString(&ss, *input_schema, aggs_, agg_src_field_ids_, indent);
+ AggregatesToString(&ss, *input_schema, aggs_, agg_src_fieldsets_, indent);
return ss.str();
}
@@ -539,10 +566,11 @@ class GroupByNode : public ExecNode, public TracedNode<GroupByNode> {
state->grouper, Grouper::Make(key_types, plan_->query_context()->exec_context()));
// Build vector of aggregate source field data types
- std::vector<TypeHolder> agg_src_types(agg_kernels_.size());
+ std::vector<std::vector<TypeHolder>> agg_src_types(agg_kernels_.size());
for (size_t i = 0; i < agg_kernels_.size(); ++i) {
- auto agg_src_field_id = agg_src_field_ids_[i];
- agg_src_types[i] = input_schema->field(agg_src_field_id)->type().get();
+ for (const auto& field_id : agg_src_fieldsets_[i]) {
+ agg_src_types[i].emplace_back(input_schema->field(field_id)->type().get());
+ }
}
ARROW_ASSIGN_OR_RAISE(
@@ -565,7 +593,7 @@ class GroupByNode : public ExecNode, public TracedNode<GroupByNode> {
int output_task_group_id_;
const std::vector<int> key_field_ids_;
- const std::vector<int> agg_src_field_ids_;
+ const std::vector<std::vector<int>> agg_src_fieldsets_;
const std::vector<Aggregate> aggs_;
const std::vector<const HashAggregateKernel*> agg_kernels_;
diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc
index b0c0237e93..497b719625 100644
--- a/cpp/src/arrow/compute/exec/plan_test.cc
+++ b/cpp/src/arrow/compute/exec/plan_test.cc
@@ -476,7 +476,9 @@ TEST(ExecPlan, ToString) {
/*aggregates=*/{
{"hash_sum", nullptr, "multiply(i32, 2)", "sum(multiply(i32, 2))"},
{"hash_count", options, "multiply(i32, 2)",
- "count(multiply(i32, 2))"}},
+ "count(multiply(i32, 2))"},
+ {"hash_count_all", "count(*)"},
+ },
/*keys=*/{"bool"}}},
{"filter", FilterNodeOptions{greater(field_ref("sum(multiply(i32, 2))"),
literal(10))}},
@@ -493,6 +495,7 @@ custom_sink_label:OrderBySinkNode{by={sort_keys=[FieldRef.Name(sum(multiply(i32,
:GroupByNode{keys=["bool"], aggregates=[
hash_sum(multiply(i32, 2)),
hash_count(multiply(i32, 2), {mode=NON_NULL}),
+ hash_count_all(*),
]}
:ProjectNode{projection=[bool, multiply(i32, 2)]}
:FilterNode{filter=(i32 >= 0)}
@@ -512,20 +515,23 @@ custom_sink_label:OrderBySinkNode{by={sort_keys=[FieldRef.Name(sum(multiply(i32,
rhs.label = "rhs";
union_node.inputs.emplace_back(lhs);
union_node.inputs.emplace_back(rhs);
- ASSERT_OK(
- Declaration::Sequence(
- {
- union_node,
- {"aggregate", AggregateNodeOptions{
- /*aggregates=*/{{"count", options, "i32", "count(i32)"}},
- /*keys=*/{}}},
- {"sink", SinkNodeOptions{&sink_gen}},
- })
- .AddToPlan(plan.get()));
+ ASSERT_OK(Declaration::Sequence(
+ {
+ union_node,
+ {"aggregate",
+ AggregateNodeOptions{/*aggregates=*/{
+ {"count", options, "i32", "count(i32)"},
+ {"count_all", "count(*)"},
+ },
+ /*keys=*/{}}},
+ {"sink", SinkNodeOptions{&sink_gen}},
+ })
+ .AddToPlan(plan.get()));
EXPECT_EQ(plan->ToString(), R"a(ExecPlan with 5 nodes:
:SinkNode{}
:ScalarAggregateNode{aggregates=[
count(i32, {mode=NON_NULL}),
+ count_all(*),
]}
:UnionNode{}
rhs:SourceNode{}
@@ -1249,6 +1255,7 @@ TEST(ExecPlanExecution, ScalarSourceScalarAggSink) {
/*aggregates=*/{{"all", nullptr, "b", "all(b)"},
{"any", nullptr, "b", "any(b)"},
{"count", nullptr, "a", "count(a)"},
+ {"count_all", "count(*)"},
{"mean", nullptr, "a", "mean(a)"},
{"product", nullptr, "a", "product(a)"},
{"stddev", nullptr, "a", "stddev(a)"},
@@ -1258,17 +1265,43 @@ TEST(ExecPlanExecution, ScalarSourceScalarAggSink) {
auto exp_batches = {
ExecBatchFromJSON(
- {boolean(), boolean(), int64(), float64(), int64(), float64(), int64(),
+ {boolean(), boolean(), int64(), int64(), float64(), int64(), float64(), int64(),
float64(), float64()},
{ArgShape::SCALAR, ArgShape::SCALAR, ArgShape::SCALAR, ArgShape::SCALAR,
- ArgShape::SCALAR, ArgShape::SCALAR, ArgShape::SCALAR, ArgShape::ARRAY,
- ArgShape::SCALAR},
- R"([[false, true, 6, 5.5, 26250, 0.7637626158259734, 33, 5.0, 0.5833333333333334]])"),
+ ArgShape::SCALAR, ArgShape::SCALAR, ArgShape::SCALAR, ArgShape::SCALAR,
+ ArgShape::ARRAY, ArgShape::SCALAR},
+ R"([[false, true, 6, 6, 5.5, 26250, 0.7637626158259734, 33, 5.0, 0.5833333333333334]])"),
};
ASSERT_OK_AND_ASSIGN(auto result, DeclarationToExecBatches(std::move(plan)));
AssertExecBatchesEqualIgnoringOrder(result.schema, result.batches, exp_batches);
}
+TEST(ExecPlanExecution, ScalarSourceStandaloneNullaryScalarAggSink) {  // count_all as the only aggregate
+ BatchesWithSchema scalar_data;
+ scalar_data.batches = {
+ ExecBatchFromJSON({int32(), boolean()}, {ArgShape::SCALAR, ArgShape::SCALAR},  // scalar-shaped batch
+ "[[5, null], [5, false], [5, false]]"),
+ ExecBatchFromJSON({int32(), boolean()}, "[[5, true], [null, false], [7, true]]")};
+ scalar_data.schema = schema({
+ field("a", int32()),
+ field("b", boolean()),
+ });
+
+ Declaration plan = Declaration::Sequence(
+ {{"source",
+ SourceNodeOptions{scalar_data.schema, scalar_data.gen(/*parallel=*/false,
+ /*slow=*/false)}},
+ {"aggregate", AggregateNodeOptions{/*aggregates=*/{
+ {"count_all", "count(*)"},
+ }}}});
+ ASSERT_OK_AND_ASSIGN(BatchesWithCommonSchema actual_batches,
+ DeclarationToExecBatches(std::move(plan)));
+
+ auto expected = ExecBatchFromJSON({int64()}, {ArgShape::SCALAR}, R"([[6]])");  // 3 + 3 rows, nulls included
+ AssertExecBatchesEqualIgnoringOrder(actual_batches.schema, actual_batches.batches,
+ {expected});
+}
+
TEST(ExecPlanExecution, ScalarSourceGroupedSum) {
// ARROW-14630: ensure grouped aggregation with a scalar key/array input doesn't
// error
diff --git a/cpp/src/arrow/compute/exec/test_util.cc b/cpp/src/arrow/compute/exec/test_util.cc
index 4ac3fec55a..a1487c9b34 100644
--- a/cpp/src/arrow/compute/exec/test_util.cc
+++ b/cpp/src/arrow/compute/exec/test_util.cc
@@ -514,7 +514,19 @@ static inline void PrintToImpl(const std::string& factory_name,
*os << "function=" << agg.function << "<";
if (agg.options) PrintTo(*agg.options, os);
*os << ">,";
- *os << "target=" << agg.target.ToString() << ",";
+ *os << "target=";
+ if (agg.target.size() == 0) {
+ *os << "*";
+ } else if (agg.target.size() == 1) {
+ *os << agg.target[0].ToString();
+ } else {
+ *os << "(" << agg.target[0].ToString();
+ for (size_t i = 1; i < agg.target.size(); i++) {
+ *os << "," << agg.target[i].ToString();
+ }
+ *os << ")";
+ }
+ *os << ",";
*os << "name=" << agg.name;
}
*os << "},";
diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.cc b/cpp/src/arrow/compute/kernels/aggregate_basic.cc
index c2ea04d492..2f9c970759 100644
--- a/cpp/src/arrow/compute/kernels/aggregate_basic.cc
+++ b/cpp/src/arrow/compute/kernels/aggregate_basic.cc
@@ -67,7 +67,28 @@ void AddAggKernel(std::shared_ptr<KernelSignature> sig, KernelInit init,
namespace {
// ----------------------------------------------------------------------
-// Count implementation
+// Count implementations
+
+struct CountAllImpl : public ScalarAggregator {  // state for the nullary "count_all" aggregate
+ Status Consume(KernelContext*, const ExecSpan& batch) override {
+ this->count += batch.length;  // only the row count is used; values are never read
+ return Status::OK();
+ }
+
+ Status MergeFrom(KernelContext*, KernelState&& src) override {
+ const auto& other_state = checked_cast<const CountAllImpl&>(src);
+ this->count += other_state.count;  // partial row counts simply add up
+ return Status::OK();
+ }
+
+ Status Finalize(KernelContext* ctx, Datum* out) override {
+ const auto& state = checked_cast<const CountAllImpl&>(*ctx->state());
+ *out = Datum(state.count);  // kernel is registered with an int64 output type
+ return Status::OK();
+ }
+
+ int64_t count = 0;  // rows consumed (plus merged partial counts) so far
+};
struct CountImpl : public ScalarAggregator {
explicit CountImpl(CountOptions options) : options(std::move(options)) {}
@@ -118,6 +139,11 @@ struct CountImpl : public ScalarAggregator {
int64_t nulls = 0;
};
+Result<std::unique_ptr<KernelState>> CountAllInit(KernelContext*,
+ const KernelInitArgs& args) {  // count_all has no options; args is unused
+ return std::make_unique<CountAllImpl>();
+}
+
Result<std::unique_ptr<KernelState>> CountInit(KernelContext*,
const KernelInitArgs& args) {
return std::make_unique<CountImpl>(static_cast<const CountOptions&>(*args.options));
@@ -825,6 +851,9 @@ void AddMinMaxKernels(KernelInit init,
namespace {
+const FunctionDoc count_all_doc{  // doc for the nullary "count_all" function
+ "Count the number of rows", "This version of count takes no arguments.", {}};
+
const FunctionDoc count_doc{"Count the number of null / non-null values",
("By default, only non-null values are counted.\n"
"This can be changed through CountOptions."),
@@ -907,8 +936,15 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
static auto default_scalar_aggregate_options = ScalarAggregateOptions::Defaults();
static auto default_count_options = CountOptions::Defaults();
- auto func = std::make_shared<ScalarAggregateFunction>(
- "count", Arity::Unary(), count_doc, &default_count_options);
+ auto func = std::make_shared<ScalarAggregateFunction>("count_all", Arity::Nullary(),
+ count_all_doc, NULLPTR);
+
+ // Takes no input (counts all rows), outputs int64 scalar
+ AddAggKernel(KernelSignature::Make({}, int64()), CountAllInit, func.get());
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+
+ func = std::make_shared<ScalarAggregateFunction>("count", Arity::Unary(), count_doc,
+ &default_count_options);
// Takes any input, outputs int64 scalar
InputType any_input;
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
index fe2b4af205..cec6390f92 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
@@ -108,12 +108,11 @@ Result<TypeHolder> ResolveGroupOutputType(KernelContext* ctx,
return checked_cast<GroupedAggregator*>(ctx->state())->out_type();
}
-HashAggregateKernel MakeKernel(InputType argument_type, KernelInit init) {
+HashAggregateKernel MakeKernel(std::shared_ptr<KernelSignature> signature,
+ KernelInit init) {  // hash kernel with a caller-supplied signature
HashAggregateKernel kernel;
kernel.init = std::move(init);
- kernel.signature =
- KernelSignature::Make({std::move(argument_type), InputType(Type::UINT32)},
- OutputType(ResolveGroupOutputType));
+ kernel.signature = std::move(signature);
kernel.resize = HashAggregateResize;
kernel.consume = HashAggregateConsume;
kernel.merge = HashAggregateMerge;
@@ -121,6 +120,19 @@ HashAggregateKernel MakeKernel(InputType argument_type, KernelInit init) {
return kernel;
}
+HashAggregateKernel MakeKernel(InputType argument_type, KernelInit init) {  // (argument, uint32 group ids)
+ return MakeKernel(
+ KernelSignature::Make({std::move(argument_type), InputType(Type::UINT32)},
+ OutputType(ResolveGroupOutputType)),
+ std::move(init));
+}
+
+HashAggregateKernel MakeUnaryKernel(KernelInit init) {  // only input: uint32 group ids (e.g. hash_count_all)
+ return MakeKernel(KernelSignature::Make({InputType(Type::UINT32)},
+ OutputType(ResolveGroupOutputType)),
+ std::move(init));
+}
+
Status AddHashAggKernels(
const std::vector<std::shared_ptr<DataType>>& types,
Result<HashAggregateKernel> make_kernel(const std::shared_ptr<DataType>&),
@@ -223,6 +235,53 @@ void VisitGroupedValuesNonNull(const ExecSpan& batch, ConsumeValue&& valid_func)
// ----------------------------------------------------------------------
// Count implementation
+// Nullary-count implementation -- COUNT(*): per-group row counts, nulls included.
+struct GroupedCountAllImpl : public GroupedAggregator {
+ Status Init(ExecContext* ctx, const KernelInitArgs& args) override {
+ counts_ = BufferBuilder(ctx->memory_pool());
+ return Status::OK();
+ }
+
+ Status Resize(int64_t new_num_groups) override {
+ auto added_groups = new_num_groups - num_groups_;
+ num_groups_ = new_num_groups;
+ return counts_.Append(added_groups * sizeof(int64_t), 0);  // new groups start at 0
+ }
+
+ Status Merge(GroupedAggregator&& raw_other,
+ const ArrayData& group_id_mapping) override {
+ auto other = checked_cast<GroupedCountAllImpl*>(&raw_other);
+
+ auto counts = reinterpret_cast<int64_t*>(counts_.mutable_data());
+ auto other_counts = reinterpret_cast<const int64_t*>(other->counts_.data());
+
+ auto g = group_id_mapping.GetValues<uint32_t>(1);  // other's group index -> our group id
+ for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) {
+ counts[*g] += other_counts[other_g];
+ }
+ return Status::OK();
+ }
+
+ Status Consume(const ExecSpan& batch) override {
+ auto counts = reinterpret_cast<int64_t*>(counts_.mutable_data());
+ auto g_begin = batch[0].array.GetValues<uint32_t>(1);  // sole argument: group ids
+ for (auto g_itr = g_begin, end = g_itr + batch.length; g_itr != end; g_itr++) {
+ counts[*g_itr] += 1;  // no value column to inspect: every row counts
+ }
+ return Status::OK();
+ }
+
+ Result<Datum> Finalize() override {
+ ARROW_ASSIGN_OR_RAISE(auto counts, counts_.Finish());
+ return std::make_shared<Int64Array>(num_groups_, std::move(counts));  // no validity buffer passed
+ }
+
+ std::shared_ptr<DataType> out_type() const override { return int64(); }
+
+ int64_t num_groups_ = 0;  // number of groups sized into counts_ so far
+ BufferBuilder counts_;  // one int64 counter per group
+};
+
struct GroupedCountImpl : public GroupedAggregator {
Status Init(ExecContext* ctx, const KernelInitArgs& args) override {
options_ = checked_cast<const CountOptions&>(*args.options);
@@ -2670,11 +2729,15 @@ struct GroupedListFactory {
namespace {
const FunctionDoc hash_count_doc{
"Count the number of null / non-null values in each group",
- ("By default, non-null values are counted.\n"
+ ("By default, only non-null values are counted.\n"
"This can be changed through ScalarAggregateOptions."),
{"array", "group_id_array"},
"CountOptions"};
+const FunctionDoc hash_count_all_doc{"Count the number of rows in each group",
+ ("Not caring about the values of any column."),
+ {"group_id_array"}};  // sole argument: the group ids
+
const FunctionDoc hash_sum_doc{"Sum values in each group",
("Null values are ignored."),
{"array", "group_id_array"},
@@ -2791,6 +2854,15 @@ void RegisterHashAggregateBasic(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunction(std::move(func)));
}
+ {
+ auto func = std::make_shared<HashAggregateFunction>("hash_count_all", Arity::Unary(),
+ hash_count_all_doc, NULLPTR);
+
+ DCHECK_OK(func->AddKernel(MakeUnaryKernel(HashAggregateInit<GroupedCountAllImpl>)));
+ auto status = registry->AddFunction(std::move(func));
+ DCHECK_OK(status);
+ }
+
{
auto func = std::make_shared<HashAggregateFunction>(
"hash_sum", Arity::Binary(), hash_sum_doc, &default_scalar_aggregate_options);
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
index 3bc2a03ddb..2fb3a28baf 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
@@ -3259,10 +3259,10 @@ TEST(GroupBy, CountAndSum) {
[null, 3]
])");
- std::shared_ptr<CountOptions> count_options;
- auto count_nulls = std::make_shared<CountOptions>(CountOptions::ONLY_NULL);
- auto count_all = std::make_shared<CountOptions>(CountOptions::ALL);
- auto min_count =
+ std::shared_ptr<CountOptions> count_opts;
+ auto count_nulls_opts = std::make_shared<CountOptions>(CountOptions::ONLY_NULL);
+ auto count_all_opts = std::make_shared<CountOptions>(CountOptions::ALL);
+ auto min_count_opts =
std::make_shared<ScalarAggregateOptions>(/*skip_nulls=*/true, /*min_count=*/3);
ASSERT_OK_AND_ASSIGN(
Datum aggregated_and_grouped,
@@ -3280,12 +3280,13 @@ TEST(GroupBy, CountAndSum) {
batch->GetColumnByName("key"),
},
{
- {"hash_count", count_options, "agg_0", "hash_count"},
- {"hash_count", count_nulls, "agg_1", "hash_count"},
- {"hash_count", count_all, "agg_2", "hash_count"},
- {"hash_sum", nullptr, "agg_3", "hash_sum"},
- {"hash_sum", min_count, "agg_4", "hash_sum"},
- {"hash_sum", nullptr, "agg_5", "hash_sum"},
+ {"hash_count", count_opts, "agg_0", "hash_count"},
+ {"hash_count", count_nulls_opts, "agg_1", "hash_count"},
+ {"hash_count", count_all_opts, "agg_2", "hash_count"},
+ {"hash_count_all", "hash_count_all"},
+ {"hash_sum", "agg_3", "hash_sum"},
+ {"hash_sum", min_count_opts, "agg_4", "hash_sum"},
+ {"hash_sum", "agg_5", "hash_sum"},
}));
AssertDatumsEqual(
@@ -3293,6 +3294,7 @@ TEST(GroupBy, CountAndSum) {
field("hash_count", int64()),
field("hash_count", int64()),
field("hash_count", int64()),
+ field("hash_count_all", int64()),
// NB: summing a float32 array results in float64 sums
field("hash_sum", float64()),
field("hash_sum", float64()),
@@ -3300,15 +3302,56 @@ TEST(GroupBy, CountAndSum) {
field("key_0", int64()),
}),
R"([
- [2, 1, 3, 4.25, null, 3, 1],
- [3, 0, 3, -0.125, -0.125, 6, 2],
- [0, 2, 2, null, null, 6, 3],
- [2, 0, 2, 4.75, null, null, null]
+ [2, 1, 3, 3, 4.25, null, 3, 1],
+ [3, 0, 3, 3, -0.125, -0.125, 6, 2],
+ [0, 2, 2, 2, null, null, 6, 3],
+ [2, 0, 2, 2, 4.75, null, null, null]
])"),
aggregated_and_grouped,
/*verbose=*/true);
}
+TEST(GroupBy, StandAloneNullaryCount) {  // hash_count_all with no other aggregate
+ auto batch = RecordBatchFromJSON(
+ schema({field("argument", float64()), field("key", int64())}), R"([
+ [1.0, 1],
+ [null, 1],
+ [0.0, 2],
+ [null, 3],
+ [4.0, null],
+ [3.25, 1],
+ [0.125, 2],
+ [-0.25, 2],
+ [0.75, null],
+ [null, 3]
+ ])");
+
+ ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+ internal::GroupBy(
+ // zero arguments for aggregations because only the
+ // nullary hash_count_all aggregation is present
+ {},
+ {
+ batch->GetColumnByName("key"),  // the only grouping key
+ },
+ {
+ {"hash_count_all", "hash_count_all"},
+ }));
+
+ AssertDatumsEqual(ArrayFromJSON(struct_({  // rows per key; the null key forms its own group
+ field("hash_count_all", int64()),
+ field("key_0", int64()),
+ }),
+ R"([
+ [3, 1],
+ [3, 2],
+ [2, 3],
+ [2, null]
+ ])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+}
+
TEST(GroupBy, Product) {
auto batch = RecordBatchFromJSON(
schema({field("argument", float64()), field("key", int64())}), R"([
diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc
index 8f437ab928..cea87071c4 100644
--- a/cpp/src/arrow/engine/substrait/extension_set.cc
+++ b/cpp/src/arrow/engine/substrait/extension_set.cc
@@ -878,21 +878,37 @@ ExtensionIdRegistry::SubstraitCallToArrow DecodeConcatMapping() {
ExtensionIdRegistry::SubstraitAggregateToArrow DecodeBasicAggregate(
const std::string& arrow_function_name) {
return [arrow_function_name](const SubstraitCall& call) -> Result<compute::Aggregate> {
- if (call.size() != 1) {
- return Status::NotImplemented(
- "Only unary aggregate functions are currently supported");
- }
- ARROW_ASSIGN_OR_RAISE(compute::Expression arg, call.GetValueArg(0));
- const FieldRef* arg_ref = arg.field_ref();
- if (!arg_ref) {
- return Status::Invalid("Expected an aggregate call ", call.id().uri, "#",
- call.id().name, " to have a direct reference");
- }
- std::string fixed_arrow_func = arrow_function_name;
+ std::string fixed_arrow_func;  // Arrow function name, "hash_"-prefixed when grouped
if (call.is_hash()) {
- fixed_arrow_func = "hash_" + arrow_function_name;
+ fixed_arrow_func = "hash_";
}
- return compute::Aggregate{std::move(fixed_arrow_func), nullptr, *arg_ref, ""};
+
+ switch (call.size()) {
+ case 0: {
+ if (call.id().name == "count") {  // only count has a nullary form: COUNT(*)
+ fixed_arrow_func += "count_all";
+ return compute::Aggregate{std::move(fixed_arrow_func), ""};
+ }
+ return Status::Invalid("Expected aggregate call ", call.id().uri, "#",
+ call.id().name, " to have at least one argument");
+ }
+ case 1: {
+ fixed_arrow_func += arrow_function_name;
+
+ ARROW_ASSIGN_OR_RAISE(compute::Expression arg, call.GetValueArg(0));
+ const FieldRef* arg_ref = arg.field_ref();
+ if (!arg_ref) {
+ return Status::Invalid("Expected an aggregate call ", call.id().uri, "#",
+ call.id().name, " to have a direct reference");
+ }
+
+ return compute::Aggregate{std::move(fixed_arrow_func), *arg_ref, ""};
+ }
+ default:  // two or more arguments
+ break;
+ }
+ return Status::NotImplemented(
+ "Only nullary and unary aggregate functions are currently supported");
};
}
diff --git a/cpp/src/arrow/engine/substrait/function_test.cc b/cpp/src/arrow/engine/substrait/function_test.cc
index eca033fb0c..80f4af27b6 100644
--- a/cpp/src/arrow/engine/substrait/function_test.cc
+++ b/cpp/src/arrow/engine/substrait/function_test.cc
@@ -526,6 +526,8 @@ struct AggregateTestCase {
std::string group_outputs;
// The data type of the outputs
std::shared_ptr<DataType> output_type;
+ // The aggregation takes zero columns as input
+ bool nullary = false;
};
std::shared_ptr<Table> GetInputTableForAggregateCase(const AggregateTestCase& test_case) {
@@ -560,8 +562,10 @@ std::shared_ptr<compute::ExecPlan> PlanFromAggregateCase(
}
EXPECT_OK_AND_ASSIGN(
std::shared_ptr<Buffer> substrait,
- internal::CreateScanAggSubstrait(test_case.function_id, input_table, key_idxs,
- /*arg_idx=*/1, *test_case.output_type));
+ internal::CreateScanAggSubstrait(
+ test_case.function_id, input_table, key_idxs,
+ /*arg_idxs=*/test_case.nullary ? std::vector<int>{} : std::vector<int>{1},
+ *test_case.output_type));
std::shared_ptr<compute::SinkNodeConsumer> consumer =
std::make_shared<compute::TableSinkNodeConsumer>(output_table,
default_memory_pool());
@@ -664,7 +668,15 @@ TEST(FunctionMapping, AggregateCases) {
{int8()},
"[3]",
"[2, 1]",
- int64()}};
+ int64()},
+ {{kSubstraitAggregateGenericFunctionsUri, "count"},
+ {"[1, null, 30]"},
+ {int8()},
+ "[3]",
+ "[2, 1]",
+ int64(),
+ /*nullary=*/true},
+ };
CheckAggregateCases(test_cases);
}
diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc
index 95b0ef16ba..7fed34da63 100644
--- a/cpp/src/arrow/engine/substrait/relation_internal.cc
+++ b/cpp/src/arrow/engine/substrait/relation_internal.cc
@@ -623,11 +623,11 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
}
}
- int measure_size = aggregate.measures_size();
+ const int measure_size = aggregate.measures_size();
std::vector<compute::Aggregate> aggregates;
aggregates.reserve(measure_size);
// store aggregate fields to be used when output schema is created
- std::vector<int> agg_src_field_ids(measure_size);
+ std::vector<std::vector<int>> agg_src_fieldsets(measure_size);
for (int measure_id = 0; measure_id < measure_size; measure_id++) {
const auto& agg_measure = aggregate.measures(measure_id);
if (agg_measure.has_measure()) {
@@ -635,9 +635,9 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
return Status::NotImplemented("Aggregate filters are not supported.");
}
const auto& agg_func = agg_measure.measure();
- ARROW_ASSIGN_OR_RAISE(
- SubstraitCall aggregate_call,
- FromProto(agg_func, !keys.empty(), ext_set, conversion_options));
+ ARROW_ASSIGN_OR_RAISE(SubstraitCall aggregate_call,
+ FromProto(agg_func, /*is_hash=*/!keys.empty(), ext_set,
+ conversion_options));
ExtensionIdRegistry::SubstraitAggregateToArrow converter;
if (aggregate_call.id().uri.empty() || aggregate_call.id().uri[0] == '/') {
ARROW_ASSIGN_OR_RAISE(
@@ -651,9 +651,11 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
ARROW_ASSIGN_OR_RAISE(compute::Aggregate arrow_agg, converter(aggregate_call));
// find aggregate field ids from schema
- const auto field_ref = arrow_agg.target;
- ARROW_ASSIGN_OR_RAISE(auto match, field_ref.FindOne(*input_schema));
- agg_src_field_ids[measure_id] = match[0];
+ const auto& target = arrow_agg.target;
+ for (const auto& field_ref : target) {
+ ARROW_ASSIGN_OR_RAISE(auto match, field_ref.FindOne(*input_schema));
+ agg_src_fieldsets[measure_id].push_back(match[0]);
+ }
aggregates.push_back(std::move(arrow_agg));
} else {
@@ -661,14 +663,16 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
}
}
FieldVector output_fields;
- output_fields.reserve(key_field_ids.size() + agg_src_field_ids.size());
+ output_fields.reserve(key_field_ids.size() + measure_size);
// extract aggregate fields to output schema
- for (int id = 0; id < static_cast<int>(agg_src_field_ids.size()); id++) {
- output_fields.emplace_back(input_schema->field(agg_src_field_ids[id]));
+ for (const auto& agg_src_fieldset : agg_src_fieldsets) {
+ for (int field : agg_src_fieldset) {
+ output_fields.emplace_back(input_schema->field(field));
+ }
}
// extract key fields to output schema
- for (int id = 0; id < static_cast<int>(key_field_ids.size()); id++) {
- output_fields.emplace_back(input_schema->field(key_field_ids[id]));
+ for (int key_field_id : key_field_ids) {
+ output_fields.emplace_back(input_schema->field(key_field_id));
}
std::shared_ptr<Schema> aggregate_schema = schema(std::move(output_fields));
diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc
index e17267d4f5..665713d5a0 100644
--- a/cpp/src/arrow/engine/substrait/serde_test.cc
+++ b/cpp/src/arrow/engine/substrait/serde_test.cc
@@ -1883,7 +1883,7 @@ TEST(Substrait, AggregateInvalidAggFuncArgs) {
})",
/*ignore_unknown_fields=*/false));
- ASSERT_RAISES(NotImplemented, DeserializePlans(*buf, [] { return kNullConsumer; }));
+ ASSERT_RAISES(Invalid, DeserializePlans(*buf, [] { return kNullConsumer; }));
}
TEST(Substrait, AggregateWithFilter) {
diff --git a/cpp/src/arrow/engine/substrait/test_plan_builder.cc b/cpp/src/arrow/engine/substrait/test_plan_builder.cc
index 62f4361a61..f38f7ece9a 100644
--- a/cpp/src/arrow/engine/substrait/test_plan_builder.cc
+++ b/cpp/src/arrow/engine/substrait/test_plan_builder.cc
@@ -118,7 +118,7 @@ Result<std::unique_ptr<substrait::ProjectRel>> CreateProject(
Result<std::unique_ptr<substrait::AggregateRel>> CreateAgg(Id function_id,
const std::vector<int>& keys,
- int arg_idx,
+ std::vector<int> arg_idxs,
const DataType& output_type,
ExtensionSet* ext_set) {
auto agg = std::make_unique<substrait::AggregateRel>();
@@ -137,10 +137,12 @@ Result<std::unique_ptr<substrait::AggregateRel>> CreateAgg(Id function_id,
agg_func->set_function_reference(function_anchor);
- substrait::FunctionArgument* arg = agg_func->add_arguments();
- auto arg_expr = std::make_unique<substrait::Expression>();
- CreateDirectReference(arg_idx, arg_expr.get());
- arg->set_allocated_value(arg_expr.release());
+ for (int arg_idx : arg_idxs) {
+ substrait::FunctionArgument* arg = agg_func->add_arguments();
+ auto arg_expr = std::make_unique<substrait::Expression>();
+ CreateDirectReference(arg_idx, arg_expr.get());
+ arg->set_allocated_value(arg_expr.release());
+ }
agg_func->set_phase(substrait::AggregationPhase::AGGREGATION_PHASE_INITIAL_TO_RESULT);
agg_func->set_invocation(
@@ -206,13 +208,15 @@ Result<std::shared_ptr<Buffer>> CreateScanProjectSubstrait(
Result<std::shared_ptr<Buffer>> CreateScanAggSubstrait(
Id function_id, const std::shared_ptr<Table>& input_table,
- const std::vector<int>& key_idxs, int arg_idx, const DataType& output_type) {
+ const std::vector<int>& key_idxs, const std::vector<int>& arg_idxs,
+ const DataType& output_type) {
ExtensionSet ext_set;
ARROW_ASSIGN_OR_RAISE(std::unique_ptr<substrait::ReadRel> read,
CreateRead(*input_table, &ext_set));
- ARROW_ASSIGN_OR_RAISE(std::unique_ptr<substrait::AggregateRel> agg,
- CreateAgg(function_id, key_idxs, arg_idx, output_type, &ext_set));
+ ARROW_ASSIGN_OR_RAISE(
+ std::unique_ptr<substrait::AggregateRel> agg,
+ CreateAgg(function_id, key_idxs, arg_idxs, output_type, &ext_set));
auto read_rel = std::make_unique<substrait::Rel>();
read_rel->set_allocated_read(read.release());
diff --git a/cpp/src/arrow/engine/substrait/test_plan_builder.h b/cpp/src/arrow/engine/substrait/test_plan_builder.h
index 8af156ea61..94c03daaa7 100644
--- a/cpp/src/arrow/engine/substrait/test_plan_builder.h
+++ b/cpp/src/arrow/engine/substrait/test_plan_builder.h
@@ -64,11 +64,12 @@ ARROW_ENGINE_EXPORT Result<std::shared_ptr<Buffer>> CreateScanProjectSubstrait(
/// \brief Create a scan->aggregate->sink plan for tests
///
/// The plan will create an aggregate with one grouping set (defined by
-/// key_idxs) and one measure. The measure will be a unary function
-/// defined by `function_id` and a direct reference to `arg_idx`.
+/// key_idxs) and one measure. The measure will be a function
+/// defined by `function_id` and direct references to `arg_idxs` (may be empty).
ARROW_ENGINE_EXPORT Result<std::shared_ptr<Buffer>> CreateScanAggSubstrait(
Id function_id, const std::shared_ptr<Table>& input_table,
- const std::vector<int>& key_idxs, int arg_idx, const DataType& output_type);
+ const std::vector<int>& key_idxs, const std::vector<int>& arg_idxs,
+ const DataType& output_type);
} // namespace internal
} // namespace engine
diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst
index 4205cce1c0..5bd3d659a3 100644
--- a/docs/source/cpp/compute.rst
+++ b/docs/source/cpp/compute.rst
@@ -197,43 +197,45 @@ Aggregations
Scalar aggregations operate on a (chunked) array or scalar value and reduce
the input to a single output value.
-+--------------------+-------+------------------+------------------------+----------------------------------+-------+
-| Function name | Arity | Input types | Output type | Options class | Notes |
-+====================+=======+==================+========================+==================================+=======+
-| all | Unary | Boolean | Scalar Boolean | :struct:`ScalarAggregateOptions` | \(1) |
-+--------------------+-------+------------------+------------------------+----------------------------------+-------+
-| any | Unary | Boolean | Scalar Boolean | :struct:`ScalarAggregateOptions` | \(1) |
-+--------------------+-------+------------------+------------------------+----------------------------------+-------+
-| approximate_median | Unary | Numeric | Scalar Float64 | :struct:`ScalarAggregateOptions` | |
-+--------------------+-------+------------------+------------------------+----------------------------------+-------+
-| count | Unary | Any | Scalar Int64 | :struct:`CountOptions` | \(2) |
-+--------------------+-------+------------------+------------------------+----------------------------------+-------+
-| count_distinct | Unary | Non-nested types | Scalar Int64 | :struct:`CountOptions` | \(2) |
-+--------------------+-------+------------------+------------------------+----------------------------------+-------+
-| index | Unary | Any | Scalar Int64 | :struct:`IndexOptions` | \(3) |
-+--------------------+-------+------------------+------------------------+----------------------------------+-------+
-| max | Unary | Non-nested types | Scalar Input type | :struct:`ScalarAggregateOptions` | |
-+--------------------+-------+------------------+------------------------+----------------------------------+-------+
-| mean | Unary | Numeric | Scalar Decimal/Float64 | :struct:`ScalarAggregateOptions` | \(4) |
-+--------------------+-------+------------------+------------------------+----------------------------------+-------+
-| min | Unary | Non-nested types | Scalar Input type | :struct:`ScalarAggregateOptions` | |
-+--------------------+-------+------------------+------------------------+----------------------------------+-------+
-| min_max | Unary | Non-nested types | Scalar Struct | :struct:`ScalarAggregateOptions` | \(5) |
-+--------------------+-------+------------------+------------------------+----------------------------------+-------+
-| mode | Unary | Numeric | Struct | :struct:`ModeOptions` | \(6) |
-+--------------------+-------+------------------+------------------------+----------------------------------+-------+
-| product | Unary | Numeric | Scalar Numeric | :struct:`ScalarAggregateOptions` | \(7) |
-+--------------------+-------+------------------+------------------------+----------------------------------+-------+
-| quantile | Unary | Numeric | Scalar Numeric | :struct:`QuantileOptions` | \(8) |
-+--------------------+-------+------------------+------------------------+----------------------------------+-------+
-| stddev | Unary | Numeric | Scalar Float64 | :struct:`VarianceOptions` | \(9) |
-+--------------------+-------+------------------+------------------------+----------------------------------+-------+
-| sum | Unary | Numeric | Scalar Numeric | :struct:`ScalarAggregateOptions` | \(7) |
-+--------------------+-------+------------------+------------------------+----------------------------------+-------+
-| tdigest | Unary | Numeric | Float64 | :struct:`TDigestOptions` | \(10) |
-+--------------------+-------+------------------+------------------------+----------------------------------+-------+
-| variance | Unary | Numeric | Scalar Float64 | :struct:`VarianceOptions` | \(9) |
-+--------------------+-------+------------------+------------------------+----------------------------------+-------+
++--------------------+---------+------------------+------------------------+----------------------------------+-------+
+| Function name | Arity | Input types | Output type | Options class | Notes |
++====================+=========+==================+========================+==================================+=======+
+| all | Unary | Boolean | Scalar Boolean | :struct:`ScalarAggregateOptions` | \(1) |
++--------------------+---------+------------------+------------------------+----------------------------------+-------+
+| any | Unary | Boolean | Scalar Boolean | :struct:`ScalarAggregateOptions` | \(1) |
++--------------------+---------+------------------+------------------------+----------------------------------+-------+
+| approximate_median | Unary | Numeric | Scalar Float64 | :struct:`ScalarAggregateOptions` | |
++--------------------+---------+------------------+------------------------+----------------------------------+-------+
+| count | Unary | Any | Scalar Int64 | :struct:`CountOptions` | \(2) |
++--------------------+---------+------------------+------------------------+----------------------------------+-------+
+| count_all | Nullary | | Scalar Int64 | | |
++--------------------+---------+------------------+------------------------+----------------------------------+-------+
+| count_distinct | Unary | Non-nested types | Scalar Int64 | :struct:`CountOptions` | \(2) |
++--------------------+---------+------------------+------------------------+----------------------------------+-------+
+| index | Unary | Any | Scalar Int64 | :struct:`IndexOptions` | \(3) |
++--------------------+---------+------------------+------------------------+----------------------------------+-------+
+| max | Unary | Non-nested types | Scalar Input type | :struct:`ScalarAggregateOptions` | |
++--------------------+---------+------------------+------------------------+----------------------------------+-------+
+| mean | Unary | Numeric | Scalar Decimal/Float64 | :struct:`ScalarAggregateOptions` | \(4) |
++--------------------+---------+------------------+------------------------+----------------------------------+-------+
+| min | Unary | Non-nested types | Scalar Input type | :struct:`ScalarAggregateOptions` | |
++--------------------+---------+------------------+------------------------+----------------------------------+-------+
+| min_max | Unary | Non-nested types | Scalar Struct | :struct:`ScalarAggregateOptions` | \(5) |
++--------------------+---------+------------------+------------------------+----------------------------------+-------+
+| mode | Unary | Numeric | Struct | :struct:`ModeOptions` | \(6) |
++--------------------+---------+------------------+------------------------+----------------------------------+-------+
+| product | Unary | Numeric | Scalar Numeric | :struct:`ScalarAggregateOptions` | \(7) |
++--------------------+---------+------------------+------------------------+----------------------------------+-------+
+| quantile | Unary | Numeric | Scalar Numeric | :struct:`QuantileOptions` | \(8) |
++--------------------+---------+------------------+------------------------+----------------------------------+-------+
+| stddev | Unary | Numeric | Scalar Float64 | :struct:`VarianceOptions` | \(9) |
++--------------------+---------+------------------+------------------------+----------------------------------+-------+
+| sum | Unary | Numeric | Scalar Numeric | :struct:`ScalarAggregateOptions` | \(7) |
++--------------------+---------+------------------+------------------------+----------------------------------+-------+
+| tdigest | Unary | Numeric | Float64 | :struct:`TDigestOptions` | \(10) |
++--------------------+---------+------------------+------------------------+----------------------------------+-------+
+| variance | Unary | Numeric | Scalar Float64 | :struct:`VarianceOptions` | \(9) |
++--------------------+---------+------------------+------------------------+----------------------------------+-------+
* \(1) If null values are taken into account, by setting the
ScalarAggregateOptions parameter skip_nulls = false, then `Kleene logic`_
@@ -321,43 +323,45 @@ The supported aggregation functions are as follows. All function names are
prefixed with ``hash_``, which differentiates them from their scalar
equivalents above and reflects how they are implemented internally.
-+-------------------------+-------+------------------------------------+------------------------+----------------------------------+-----------+
-| Function name | Arity | Input types | Output type | Options class | Notes |
-+=========================+=======+====================================+========================+==================================+===========+
-| hash_all | Unary | Boolean | Boolean | :struct:`ScalarAggregateOptions` | \(1) |
-+-------------------------+-------+------------------------------------+------------------------+----------------------------------+-----------+
-| hash_any | Unary | Boolean | Boolean | :struct:`ScalarAggregateOptions` | \(1) |
-+-------------------------+-------+------------------------------------+------------------------+----------------------------------+-----------+
-| hash_approximate_median | Unary | Numeric | Float64 | :struct:`ScalarAggregateOptions` | |
-+-------------------------+-------+------------------------------------+------------------------+----------------------------------+-----------+
-| hash_count | Unary | Any | Int64 | :struct:`CountOptions` | \(2) |
-+-------------------------+-------+------------------------------------+------------------------+----------------------------------+-----------+
-| hash_count_distinct | Unary | Any | Int64 | :struct:`CountOptions` | \(2) |
-+-------------------------+-------+------------------------------------+------------------------+----------------------------------+-----------+
-| hash_distinct | Unary | Any | List of input type | :struct:`CountOptions` | \(2) \(3) |
-+-------------------------+-------+------------------------------------+------------------------+----------------------------------+-----------+
-| hash_list | Unary | Any | List of input type | | \(3) |
-+-------------------------+-------+------------------------------------+------------------------+----------------------------------+-----------+
-| hash_max | Unary | Non-nested, non-binary/string-like | Input type | :struct:`ScalarAggregateOptions` | |
-+-------------------------+-------+------------------------------------+------------------------+----------------------------------+-----------+
-| hash_mean | Unary | Numeric | Decimal/Float64 | :struct:`ScalarAggregateOptions` | \(4) |
-+-------------------------+-------+------------------------------------+------------------------+----------------------------------+-----------+
-| hash_min | Unary | Non-nested, non-binary/string-like | Input type | :struct:`ScalarAggregateOptions` | |
-+-------------------------+-------+------------------------------------+------------------------+----------------------------------+-----------+
-| hash_min_max | Unary | Non-nested types | Struct | :struct:`ScalarAggregateOptions` | \(5) |
-+-------------------------+-------+------------------------------------+------------------------+----------------------------------+-----------+
-| hash_one | Unary | Any | Input type | | \(6) |
-+-------------------------+-------+------------------------------------+------------------------+----------------------------------+-----------+
-| hash_product | Unary | Numeric | Numeric | :struct:`ScalarAggregateOptions` | \(7) |
-+-------------------------+-------+------------------------------------+------------------------+----------------------------------+-----------+
-| hash_stddev | Unary | Numeric | Float64 | :struct:`VarianceOptions` | \(8) |
-+-------------------------+-------+------------------------------------+------------------------+----------------------------------+-----------+
-| hash_sum | Unary | Numeric | Numeric | :struct:`ScalarAggregateOptions` | \(7) |
-+-------------------------+-------+------------------------------------+------------------------+----------------------------------+-----------+
-| hash_tdigest | Unary | Numeric | FixedSizeList[Float64] | :struct:`TDigestOptions` | \(9) |
-+-------------------------+-------+------------------------------------+------------------------+----------------------------------+-----------+
-| hash_variance | Unary | Numeric | Float64 | :struct:`VarianceOptions` | \(8) |
-+-------------------------+-------+------------------------------------+------------------------+----------------------------------+-----------+
++-------------------------+---------+------------------------------------+------------------------+----------------------------------+-----------+
+| Function name           | Arity   | Input types                        | Output type            | Options class                    | Notes     |
++=========================+=========+====================================+========================+==================================+===========+
+| hash_all                | Unary   | Boolean                            | Boolean                | :struct:`ScalarAggregateOptions` | \(1)      |
++-------------------------+---------+------------------------------------+------------------------+----------------------------------+-----------+
+| hash_any                | Unary   | Boolean                            | Boolean                | :struct:`ScalarAggregateOptions` | \(1)      |
++-------------------------+---------+------------------------------------+------------------------+----------------------------------+-----------+
+| hash_approximate_median | Unary   | Numeric                            | Float64                | :struct:`ScalarAggregateOptions` |           |
++-------------------------+---------+------------------------------------+------------------------+----------------------------------+-----------+
+| hash_count              | Unary   | Any                                | Int64                  | :struct:`CountOptions`           | \(2)      |
++-------------------------+---------+------------------------------------+------------------------+----------------------------------+-----------+
+| hash_count_all          | Nullary |                                    | Int64                  |                                  |           |
++-------------------------+---------+------------------------------------+------------------------+----------------------------------+-----------+
+| hash_count_distinct     | Unary   | Any                                | Int64                  | :struct:`CountOptions`           | \(2)      |
++-------------------------+---------+------------------------------------+------------------------+----------------------------------+-----------+
+| hash_distinct           | Unary   | Any                                | List of input type     | :struct:`CountOptions`           | \(2) \(3) |
++-------------------------+---------+------------------------------------+------------------------+----------------------------------+-----------+
+| hash_list               | Unary   | Any                                | List of input type     |                                  | \(3)      |
++-------------------------+---------+------------------------------------+------------------------+----------------------------------+-----------+
+| hash_max                | Unary   | Non-nested, non-binary/string-like | Input type             | :struct:`ScalarAggregateOptions` |           |
++-------------------------+---------+------------------------------------+------------------------+----------------------------------+-----------+
+| hash_mean               | Unary   | Numeric                            | Decimal/Float64        | :struct:`ScalarAggregateOptions` | \(4)      |
++-------------------------+---------+------------------------------------+------------------------+----------------------------------+-----------+
+| hash_min                | Unary   | Non-nested, non-binary/string-like | Input type             | :struct:`ScalarAggregateOptions` |           |
++-------------------------+---------+------------------------------------+------------------------+----------------------------------+-----------+
+| hash_min_max            | Unary   | Non-nested types                   | Struct                 | :struct:`ScalarAggregateOptions` | \(5)      |
++-------------------------+---------+------------------------------------+------------------------+----------------------------------+-----------+
+| hash_one                | Unary   | Any                                | Input type             |                                  | \(6)      |
++-------------------------+---------+------------------------------------+------------------------+----------------------------------+-----------+
+| hash_product            | Unary   | Numeric                            | Numeric                | :struct:`ScalarAggregateOptions` | \(7)      |
++-------------------------+---------+------------------------------------+------------------------+----------------------------------+-----------+
+| hash_stddev             | Unary   | Numeric                            | Float64                | :struct:`VarianceOptions`        | \(8)      |
++-------------------------+---------+------------------------------------+------------------------+----------------------------------+-----------+
+| hash_sum                | Unary   | Numeric                            | Numeric                | :struct:`ScalarAggregateOptions` | \(7)      |
++-------------------------+---------+------------------------------------+------------------------+----------------------------------+-----------+
+| hash_tdigest            | Unary   | Numeric                            | FixedSizeList[Float64] | :struct:`TDigestOptions`         | \(9)      |
++-------------------------+---------+------------------------------------+------------------------+----------------------------------+-----------+
+| hash_variance           | Unary   | Numeric                            | Float64                | :struct:`VarianceOptions`        | \(8)      |
++-------------------------+---------+------------------------------------+------------------------+----------------------------------+-----------+
* \(1) If null values are taken into account, by setting the
:member:`ScalarAggregateOptions::skip_nulls` to false, then `Kleene logic`_
diff --git a/docs/source/python/api/compute.rst b/docs/source/python/api/compute.rst
index 1b5d7b1e05..2c0170d08b 100644
--- a/docs/source/python/api/compute.rst
+++ b/docs/source/python/api/compute.rst
@@ -45,6 +45,10 @@ Aggregations
tdigest
variance
+..
+ Nullary aggregate functions (count_all) aren't exposed in pyarrow.compute,
+ so they aren't listed here.
+
Cumulative Functions
--------------------
diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx
index 283f532837..1337408aeb 100644
--- a/python/pyarrow/_compute.pyx
+++ b/python/pyarrow/_compute.pyx
@@ -2214,13 +2214,19 @@ def _group_by(args, keys, aggregations):
_pack_compute_args(args, &c_args)
_pack_compute_args(keys, &c_keys)
- for aggr_func_name, aggr_opts in aggregations:
+ # reference into the flattened list of arguments for the aggregations
+ field_ref = 0
+ for aggr_arg_names, aggr_func_name, aggr_opts in aggregations:
c_aggr.function = tobytes(aggr_func_name)
if aggr_opts is not None:
c_aggr.options = (<FunctionOptions?>aggr_opts).wrapped
else:
c_aggr.options = <shared_ptr[CFunctionOptions]>nullptr
- c_aggregations.push_back(c_aggr)
+ for _ in aggr_arg_names:
+ c_aggr.target.push_back(CFieldRef(<int> field_ref))
+ field_ref += 1
+
+ c_aggregations.push_back(move(c_aggr))
with nogil:
result = GetResultValue(
diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py
index f455b81411..7e182ff9c4 100644
--- a/python/pyarrow/compute.py
+++ b/python/pyarrow/compute.py
@@ -319,6 +319,10 @@ def _make_global_functions():
# Hash aggregate functions are not callable,
# so let's not expose them at module level.
continue
+ if func.kind == "scalar_aggregate" and func.arity == 0:
+ # Nullary scalar aggregate functions are not callable
+ # directly so let's not expose them at module level.
+ continue
assert name not in g, name
g[cpp_name] = g[name] = _wrap_function(name, func)
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index 82358beea1..80a087740f 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -2492,6 +2492,8 @@ cdef extern from "arrow/compute/exec/aggregate.h" namespace \
cdef cppclass CAggregate "arrow::compute::Aggregate":
c_string function
shared_ptr[CFunctionOptions] options
+ vector[CFieldRef] target
+ c_string name
CResult[CDatum] GroupBy(const vector[CDatum]& arguments,
const vector[CDatum]& keys,
diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi
index 47f37a53ad..318c4d323b 100644
--- a/python/pyarrow/table.pxi
+++ b/python/pyarrow/table.pxi
@@ -5379,9 +5379,13 @@ class TableGroupBy:
----------
aggregations : list[tuple(str, str)] or \
list[tuple(str, str, FunctionOptions)]
- List of tuples made of aggregation column names followed
- by function names and optionally aggregation function options.
+ List of tuples, where each tuple is one aggregation specification
+ and consists of: aggregation column name(s) followed
+ by function name and optionally aggregation function options.
Pass empty list to get a single row for each group.
+ The column name can be a string, an empty list or a list of
+ column names, for unary, nullary and n-ary aggregation functions
+ respectively.
Returns
-------
@@ -5402,36 +5406,46 @@ list[tuple(str, str, FunctionOptions)]
----
values_sum: [[3,7,5]]
keys: [["a","b","c"]]
+ >>> t.group_by("keys").aggregate([([], "count_all")])
+ pyarrow.Table
+ count_all: int64
+ keys: string
+ ----
+ count_all: [[2,2,1]]
+ keys: [["a","b","c"]]
>>> t.group_by("keys").aggregate([])
pyarrow.Table
keys: string
----
keys: [["a","b","c"]]
"""
- columns = [a[0] for a in aggregations]
- aggrfuncs = [
- (a[1], a[2]) if len(a) > 2 else (a[1], None)
- for a in aggregations
- ]
-
group_by_aggrs = []
- for aggr in aggrfuncs:
- if not aggr[0].startswith("hash_"):
- aggr = ("hash_" + aggr[0], aggr[1])
- group_by_aggrs.append(aggr)
+ for aggr in aggregations:
+ if len(aggr) == 2:
+ target, func = aggr
+ opt = None
+ else:
+ target, func, opt = aggr
+ if not isinstance(target, (list, tuple)):
+ target = [target]
+ if not func.startswith("hash_"):
+ func = "hash_" + func
+ group_by_aggrs.append((target, func, opt))
# Build unique names for aggregation result columns
# so that it's obvious what they refer to.
- column_names = [
- aggr_name.replace("hash", col_name)
- for col_name, (aggr_name, _) in zip(columns, group_by_aggrs)
+ out_column_names = [
+ aggr_name.replace("hash", "_".join(target))
+ if len(target) > 0 else aggr_name.replace("hash_", "")
+ for target, aggr_name, _ in group_by_aggrs
] + self.keys
+ flat_cols = [c for aggr in group_by_aggrs for c in aggr[0]]
result = _pc()._group_by(
- [self._table[c] for c in columns],
+ [self._table[c] for c in flat_cols],
[self._table[k] for k in self.keys],
group_by_aggrs
)
t = Table.from_batches([RecordBatch.from_struct_array(result)])
- return t.rename_columns(column_names)
+ return t.rename_columns(out_column_names)
diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py
index 04e2dacc48..d710b7aac6 100644
--- a/python/pyarrow/tests/test_table.py
+++ b/python/pyarrow/tests/test_table.py
@@ -1995,6 +1995,23 @@ def test_table_group_by():
"values_sum": [3, 3, 4, 5]
}
+ # Test many arguments
+ r = table.group_by("keys").aggregate([
+ ("values", "max"),
+ ("bigvalues", "sum"),
+ ("bigvalues", "max"),
+ ([], "count_all"),
+ ("values", "sum")
+ ])
+ assert sorted_by_keys(r.to_pydict()) == {
+ "keys": ["a", "b", "c"],
+ "values_max": [2, 4, 5],
+ "bigvalues_sum": [30, 70, 50],
+ "bigvalues_max": [20, 40, 50],
+ "count_all": [2, 2, 1],
+ "values_sum": [3, 7, 5]
+ }
+
table_with_nulls = pa.table([
pa.array(["a", "a", "a"]),
pa.array([1, None, None])
@@ -2024,6 +2041,24 @@ def test_table_group_by():
"values_count": [1]
}
+ r = table_with_nulls.group_by(["keys"]).aggregate([
+ ([], "count_all"), # nullary count that takes no parameters
+ ("values", "count", pc.CountOptions(mode="only_valid"))
+ ])
+ assert r.to_pydict() == {
+ "keys": ["a"],
+ "count_all": [3],
+ "values_count": [1]
+ }
+
+ r = table_with_nulls.group_by(["keys"]).aggregate([
+ ([], "count_all")
+ ])
+ assert r.to_pydict() == {
+ "keys": ["a"],
+ "count_all": [3]
+ }
+
def test_table_to_recordbatchreader():
table = pa.Table.from_pydict({'x': [1, 2, 3]})