You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ks...@apache.org on 2022/07/26 09:57:22 UTC
[arrow] 03/03: ARROW-15591: [C++] Add support for aggregation to the Substrait consumer (#13130)
This is an automated email from the ASF dual-hosted git repository.
kszucs pushed a commit to branch maint-9.0.0
in repository https://gitbox.apache.org/repos/asf/arrow.git
commit 74a4a0244ed7ea5acc3512e0a9a93036844abc0e
Author: Vibhatha Lakmal Abeykoon <vi...@users.noreply.github.com>
AuthorDate: Tue Jul 26 05:29:15 2022 +0530
ARROW-15591: [C++] Add support for aggregation to the Substrait consumer (#13130)
This PR adds Substrait-to-Arrow aggregate integration, so that a Substrait plan containing an aggregate relation can be consumed by Acero.
Lead-authored-by: Vibhatha Abeykoon <vi...@gmail.com>
Co-authored-by: Vibhatha Lakmal Abeykoon <vi...@users.noreply.github.com>
Signed-off-by: Weston Pace <we...@gmail.com>
---
cpp/src/arrow/engine/substrait/extension_set.cc | 1 +
.../arrow/engine/substrait/relation_internal.cc | 72 ++++-
cpp/src/arrow/engine/substrait/serde_test.cc | 317 +++++++++++++++++++++
3 files changed, 389 insertions(+), 1 deletion(-)
diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc
index f60f6ac1cb..08eb6acc9c 100644
--- a/cpp/src/arrow/engine/substrait/extension_set.cc
+++ b/cpp/src/arrow/engine/substrait/extension_set.cc
@@ -445,6 +445,7 @@ struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl {
"add",
"equal",
"is_not_distinct_from",
+ "hash_count",
}) {
DCHECK_OK(RegisterFunction({kArrowExtTypesUri, name}, name.to_string()));
}
diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc
index 09ecb2f069..8f6cb0ce36 100644
--- a/cpp/src/arrow/engine/substrait/relation_internal.cc
+++ b/cpp/src/arrow/engine/substrait/relation_internal.cc
@@ -307,7 +307,7 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel,
callptr->function_name);
}
- // TODO: ARROW-166241 Add Suffix support for Substrait
+ // TODO: ARROW-16624 Add Suffix support for Substrait
const auto* left_keys = callptr->arguments[0].field_ref();
const auto* right_keys = callptr->arguments[1].field_ref();
if (!left_keys || !right_keys) {
@@ -323,6 +323,76 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel,
join_dec.inputs.emplace_back(std::move(right.declaration));
return DeclarationInfo{std::move(join_dec), num_columns};
}
+ case substrait::Rel::RelTypeCase::kAggregate: {
+ // Converts a substrait::AggregateRel into an "aggregate" declaration that is
+ // sequenced after the declaration produced for its input relation. At most
+ // one grouping is accepted, its expressions must be direct field references,
+ // and every measure must be a filter-free unary function over a direct
+ // field reference.
+ const auto& aggregate = rel.aggregate();
+ RETURN_NOT_OK(CheckRelCommon(aggregate));
+
+ if (!aggregate.has_input()) {
+ return Status::Invalid("substrait::AggregateRel with no input relation");
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto input, FromProto(aggregate.input(), ext_set));
+
+ if (aggregate.groupings_size() > 1) {
+ return Status::NotImplemented(
+ "Grouping sets not supported. AggregateRel::groupings may not have more "
+ "than one item");
+ }
+ std::vector<FieldRef> keys;
+ // `groupings` is a repeated field and may be empty. Indexing element 0
+ // unconditionally would read out of bounds, so keys are only decoded when a
+ // grouping is actually present; otherwise the key list stays empty.
+ if (aggregate.groupings_size() > 0) {
+ const auto& group = aggregate.groupings(0);
+ keys.reserve(group.grouping_expressions_size());
+ for (int exp_id = 0; exp_id < group.grouping_expressions_size(); exp_id++) {
+ ARROW_ASSIGN_OR_RAISE(auto expr,
+ FromProto(group.grouping_expressions(exp_id), ext_set));
+ const auto* field_ref = expr.field_ref();
+ if (!field_ref) {
+ return Status::Invalid(
+ "The grouping expression for an aggregate must be a direct reference.");
+ }
+ // The pointee is const, so this is (and always was) a copy.
+ keys.emplace_back(*field_ref);
+ }
+ }
+
+ const int measure_size = aggregate.measures_size();
+ std::vector<compute::Aggregate> aggregates;
+ aggregates.reserve(measure_size);
+ for (int measure_id = 0; measure_id < measure_size; measure_id++) {
+ const auto& agg_measure = aggregate.measures(measure_id);
+ if (!agg_measure.has_measure()) {
+ return Status::Invalid("substrait::AggregateFunction not provided");
+ }
+ if (agg_measure.has_filter()) {
+ return Status::NotImplemented("Aggregate filters are not supported.");
+ }
+ const auto& agg_func = agg_measure.measure();
+ if (agg_func.arguments_size() != 1) {
+ return Status::NotImplemented("Aggregate function must be a unary function.");
+ }
+ const int func_reference = agg_func.function_reference();
+ ARROW_ASSIGN_OR_RAISE(auto func_record, ext_set.DecodeFunction(func_reference));
+ // aggregate function name
+ auto func_name = std::string(func_record.id.name);
+ // aggregate target
+ const auto& subs_func_args = agg_func.arguments(0);
+ ARROW_ASSIGN_OR_RAISE(auto field_expr,
+ FromProto(subs_func_args.value(), ext_set));
+ const auto* target = field_expr.field_ref();
+ if (!target) {
+ return Status::Invalid(
+ "The input expression to an aggregate function must be a direct "
+ "reference.");
+ }
+ // The output field name is left empty.
+ aggregates.emplace_back(
+ compute::Aggregate{std::move(func_name), NULLPTR, *target, ""});
+ }
+
+ // Capture the count before the vectors are moved into the node options.
+ // NOTE(review): the reported column count covers only the aggregates; confirm
+ // whether the grouping keys should be counted as well.
+ const int num_aggregates = static_cast<int>(aggregates.size());
+ return DeclarationInfo{
+ compute::Declaration::Sequence(
+ {std::move(input.declaration),
+ {"aggregate",
+ compute::AggregateNodeOptions{std::move(aggregates), std::move(keys)}}}),
+ num_aggregates};
+ }
default:
break;
diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc
index e10082392d..8e5745d6df 100644
--- a/cpp/src/arrow/engine/substrait/serde_test.cc
+++ b/cpp/src/arrow/engine/substrait/serde_test.cc
@@ -1383,5 +1383,322 @@ TEST(Substrait, JoinPlanInvalidKeys) {
}
}
+// Deserializes a plan with one grouping key and one `hash_count` measure and
+// checks that the sink's input is an "aggregate" declaration carrying that
+// function with an empty output name.
+TEST(Substrait, AggregateBasic) {
+ ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({
+ "relations": [{
+ "rel": {
+ "aggregate": {
+ "input": {
+ "read": {
+ "base_schema": {
+ "names": ["A", "B", "C"],
+ "struct": {
+ "types": [{
+ "i32": {}
+ }, {
+ "i32": {}
+ }, {
+ "i32": {}
+ }]
+ }
+ },
+ "local_files": {
+ "items": [
+ {
+ "uri_file": "file:///tmp/dat.parquet",
+ "parquet": {}
+ }
+ ]
+ }
+ }
+ },
+ "groupings": [{
+ "groupingExpressions": [{
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 0
+ }
+ }
+ }
+ }]
+ }],
+ "measures": [{
+ "measure": {
+ "functionReference": 0,
+ "arguments": [{
+ "value": {
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 1
+ }
+ }
+ }
+ }
+ }],
+ "sorts": [],
+ "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
+ "outputType": {
+ "i64": {}
+ }
+ }
+ }]
+ }
+ }
+ }],
+ "extensionUris": [{
+ "extension_uri_anchor": 0,
+ "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"
+ }],
+ "extensions": [{
+ "extension_function": {
+ "extension_uri_reference": 0,
+ "function_anchor": 0,
+ "name": "hash_count"
+ }
+ }],
+ })"));
+
+ ASSERT_OK_AND_ASSIGN(auto sink_decls,
+ DeserializePlans(*buf, [] { return kNullConsumer; }));
+ // Fail cleanly instead of indexing past the end if the plan shape changes.
+ ASSERT_FALSE(sink_decls.empty());
+ ASSERT_FALSE(sink_decls[0].inputs.empty());
+ const auto& agg_decl = sink_decls[0].inputs[0];
+
+ const auto& agg_rel = agg_decl.get<compute::Declaration>();
+
+ const auto& agg_options =
+ checked_cast<const compute::AggregateNodeOptions&>(*agg_rel->options);
+
+ EXPECT_EQ(agg_rel->factory_name, "aggregate");
+ ASSERT_FALSE(agg_options.aggregates.empty());
+ EXPECT_EQ(agg_options.aggregates[0].name, "");
+ EXPECT_EQ(agg_options.aggregates[0].function, "hash_count");
+}
+
+// An aggregate relation that carries no input relation must be rejected as
+// Invalid when the plan is deserialized.
+TEST(Substrait, AggregateInvalidRel) {
+ const char* plan_json = R"({
+ "relations": [{
+ "rel": {
+ "aggregate": {
+ }
+ }
+ }],
+ "extensionUris": [{
+ "extension_uri_anchor": 0,
+ "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"
+ }],
+ "extensions": [{
+ "extension_function": {
+ "extension_uri_reference": 0,
+ "function_anchor": 0,
+ "name": "hash_count"
+ }
+ }],
+ })";
+ ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", plan_json));
+
+ auto consumer_factory = [] { return kNullConsumer; };
+ ASSERT_RAISES(Invalid, DeserializePlans(*buf, consumer_factory));
+}
+
+// A measure entry that does not provide an aggregate function must be rejected
+// as Invalid when the plan is deserialized.
+TEST(Substrait, AggregateInvalidFunction) {
+ const char* plan_json = R"({
+ "relations": [{
+ "rel": {
+ "aggregate": {
+ "input": {
+ "read": {
+ "base_schema": {
+ "names": ["A", "B", "C"],
+ "struct": {
+ "types": [{
+ "i32": {}
+ }, {
+ "i32": {}
+ }, {
+ "i32": {}
+ }]
+ }
+ },
+ "local_files": {
+ "items": [
+ {
+ "uri_file": "file:///tmp/dat.parquet",
+ "parquet": {}
+ }
+ ]
+ }
+ }
+ },
+ "groupings": [{
+ "groupingExpressions": [{
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 0
+ }
+ }
+ }
+ }]
+ }],
+ "measures": [{
+ }]
+ }
+ }
+ }],
+ "extensionUris": [{
+ "extension_uri_anchor": 0,
+ "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"
+ }],
+ "extensions": [{
+ "extension_function": {
+ "extension_uri_reference": 0,
+ "function_anchor": 0,
+ "name": "hash_count"
+ }
+ }],
+ })";
+ ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", plan_json));
+
+ auto consumer_factory = [] { return kNullConsumer; };
+ ASSERT_RAISES(Invalid, DeserializePlans(*buf, consumer_factory));
+}
+
+// A measure whose aggregate function is given no arguments is not unary and
+// must be reported as NotImplemented when the plan is deserialized.
+TEST(Substrait, AggregateInvalidAggFuncArgs) {
+ const char* plan_json = R"({
+ "relations": [{
+ "rel": {
+ "aggregate": {
+ "input": {
+ "read": {
+ "base_schema": {
+ "names": ["A", "B", "C"],
+ "struct": {
+ "types": [{
+ "i32": {}
+ }, {
+ "i32": {}
+ }, {
+ "i32": {}
+ }]
+ }
+ },
+ "local_files": {
+ "items": [
+ {
+ "uri_file": "file:///tmp/dat.parquet",
+ "parquet": {}
+ }
+ ]
+ }
+ }
+ },
+ "groupings": [{
+ "groupingExpressions": [{
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 0
+ }
+ }
+ }
+ }]
+ }],
+ "measures": [{
+ "measure": {
+ "functionReference": 0,
+ "args": [],
+ "sorts": [],
+ "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
+ "outputType": {
+ "i64": {}
+ }
+ }
+ }]
+ }
+ }
+ }],
+ "extensionUris": [{
+ "extension_uri_anchor": 0,
+ "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"
+ }],
+ "extensions": [{
+ "extension_function": {
+ "extension_uri_reference": 0,
+ "function_anchor": 0,
+ "name": "hash_count"
+ }
+ }],
+ })";
+ ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", plan_json));
+
+ auto consumer_factory = [] { return kNullConsumer; };
+ ASSERT_RAISES(NotImplemented, DeserializePlans(*buf, consumer_factory));
+}
+
+// A measure that carries a filter must be reported as NotImplemented. The
+// measure is otherwise well formed (one direct-reference argument, a registered
+// function), so the filter is the only unsupported element in the plan and the
+// expected error cannot come from the argument-count check.
+TEST(Substrait, AggregateWithFilter) {
+ ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({
+ "relations": [{
+ "rel": {
+ "aggregate": {
+ "input": {
+ "read": {
+ "base_schema": {
+ "names": ["A", "B", "C"],
+ "struct": {
+ "types": [{
+ "i32": {}
+ }, {
+ "i32": {}
+ }, {
+ "i32": {}
+ }]
+ }
+ },
+ "local_files": {
+ "items": [
+ {
+ "uri_file": "file:///tmp/dat.parquet",
+ "parquet": {}
+ }
+ ]
+ }
+ }
+ },
+ "groupings": [{
+ "groupingExpressions": [{
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 0
+ }
+ }
+ }
+ }]
+ }],
+ "measures": [{
+ "measure": {
+ "functionReference": 0,
+ "arguments": [{
+ "value": {
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 1
+ }
+ }
+ }
+ }
+ }],
+ "sorts": [],
+ "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
+ "outputType": {
+ "i64": {}
+ }
+ },
+ "filter": {
+ "literal": {
+ "boolean": true
+ }
+ }
+ }]
+ }
+ }
+ }],
+ "extensionUris": [{
+ "extension_uri_anchor": 0,
+ "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"
+ }],
+ "extensions": [{
+ "extension_function": {
+ "extension_uri_reference": 0,
+ "function_anchor": 0,
+ "name": "hash_count"
+ }
+ }],
+ })"));
+
+ ASSERT_RAISES(NotImplemented, DeserializePlans(*buf, [] { return kNullConsumer; }));
+}
+
} // namespace engine
} // namespace arrow