You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by we...@apache.org on 2023/04/26 16:52:45 UTC
[arrow] branch main updated: GH-35247: [C++] Add Arrow Substrait support for stddev/variance (#35249)
This is an automated email from the ASF dual-hosted git repository.
westonpace pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 9feee48a7c GH-35247: [C++] Add Arrow Substrait support for stddev/variance (#35249)
9feee48a7c is described below
commit 9feee48a7ceffecd9b1fe34293cc005a7958638f
Author: rtpsw <rt...@hotmail.com>
AuthorDate: Wed Apr 26 19:52:35 2023 +0300
GH-35247: [C++] Add Arrow Substrait support for stddev/variance (#35249)
See #35247
* Closes: #35247
Authored-by: Yaron Gvili <rt...@hotmail.com>
Signed-off-by: Weston Pace <we...@gmail.com>
---
.../arrow/engine/substrait/expression_internal.cc | 3 +
cpp/src/arrow/engine/substrait/extension_set.cc | 31 +++++-
cpp/src/arrow/engine/substrait/function_test.cc | 13 +++
cpp/src/arrow/engine/substrait/serde_test.cc | 112 +++++++++++++++++++++
4 files changed, 157 insertions(+), 2 deletions(-)
diff --git a/cpp/src/arrow/engine/substrait/expression_internal.cc b/cpp/src/arrow/engine/substrait/expression_internal.cc
index 722dec2a30..5e214bdda4 100644
--- a/cpp/src/arrow/engine/substrait/expression_internal.cc
+++ b/cpp/src/arrow/engine/substrait/expression_internal.cc
@@ -269,6 +269,9 @@ Result<SubstraitCall> FromProto(const substrait::AggregateFunction& func, bool i
ARROW_RETURN_NOT_OK(DecodeArg(func.arguments(i), static_cast<uint32_t>(i), &call,
ext_set, conversion_options));
}
+ for (int i = 0; i < func.options_size(); i++) {  // named options, e.g. "distribution"
+ ARROW_RETURN_NOT_OK(DecodeOption(func.options(i), &call));
+ }
return std::move(call);
}
diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc
index 6db36e6a7a..d0959005f1 100644
--- a/cpp/src/arrow/engine/substrait/extension_set.cc
+++ b/cpp/src/arrow/engine/substrait/extension_set.cc
@@ -944,6 +944,30 @@ ExtensionIdRegistry::SubstraitAggregateToArrow DecodeBasicAggregate(
call.id().name, " to have at least one argument");
}
case 1: {
+ std::shared_ptr<compute::FunctionOptions> options = nullptr;
+ if (arrow_function_name == "stddev" || arrow_function_name == "variance") {
+ // See the following URL for the spec of stddev and variance:
+ // https://github.com/substrait-io/substrait/blob/
+ // 73228b4112d79eb1011af0ebb41753ce23ca180c/
+ // extensions/functions_arithmetic.yaml#L1240
+ auto maybe_dist = call.GetOption("distribution");
+ if (maybe_dist) {
+ auto& prefs = **maybe_dist;
+ if (prefs.size() != 1) {
+ return Status::Invalid("expected a single preference for ",
+ arrow_function_name, " but got ", prefs.size());
+ }
+ int ddof;  // Arrow's VarianceOptions divides by N - ddof
+ if (prefs[0] == "POPULATION") {
+ ddof = 0;  // population variance divides by N
+ } else if (prefs[0] == "SAMPLE") {
+ ddof = 1;  // sample (Bessel-corrected) variance divides by N - 1
+ } else {
+ return Status::Invalid("unknown distribution preference ", prefs[0]);
+ }
+ options = std::make_shared<compute::VarianceOptions>(ddof);
+ }
+ }
fixed_arrow_func += arrow_function_name;
ARROW_ASSIGN_OR_RAISE(compute::Expression arg, call.GetValueArg(0));
@@ -953,7 +977,8 @@ ExtensionIdRegistry::SubstraitAggregateToArrow DecodeBasicAggregate(
call.id().name, " to have a direct reference");
}
- return compute::Aggregate{std::move(fixed_arrow_func), *arg_ref, ""};
+ return compute::Aggregate{std::move(fixed_arrow_func), std::move(options),
+ *arg_ref, ""};
}
default:
break;
@@ -1069,12 +1094,14 @@ struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl {
DecodeOptionlessBasicMapping("is_valid", /*max_args=*/1)));
// --------------- Substrait -> Arrow Aggregates --------------
- for (const auto& fn_name : {"sum", "min", "max"}) {
+ for (const auto& fn_name : {"sum", "min", "max", "variance"}) {  // same name on both sides
DCHECK_OK(AddSubstraitAggregateToArrow({kSubstraitArithmeticFunctionsUri, fn_name},
DecodeBasicAggregate(fn_name)));
}
DCHECK_OK(AddSubstraitAggregateToArrow({kSubstraitArithmeticFunctionsUri, "avg"},
DecodeBasicAggregate("mean")));
+ DCHECK_OK(AddSubstraitAggregateToArrow({kSubstraitArithmeticFunctionsUri, "std_dev"},
+ DecodeBasicAggregate("stddev")));
DCHECK_OK(
AddSubstraitAggregateToArrow({kSubstraitAggregateGenericFunctionsUri, "count"},
DecodeBasicAggregate("count")));
diff --git a/cpp/src/arrow/engine/substrait/function_test.cc b/cpp/src/arrow/engine/substrait/function_test.cc
index 9164bf0a4b..f48e4c090d 100644
--- a/cpp/src/arrow/engine/substrait/function_test.cc
+++ b/cpp/src/arrow/engine/substrait/function_test.cc
@@ -650,6 +650,7 @@ void CheckWholeAggregateCase(const AggregateTestCase& test_case) {
}
void CheckGroupedAggregateCase(const AggregateTestCase& test_case) {
+ ARROW_SCOPED_TRACE("function = ", test_case.function_id.ToString());
std::shared_ptr<Table> output_table;
std::shared_ptr<acero::ExecPlan> plan =
PlanFromAggregateCase(test_case, &output_table, /*with_keys=*/true);
@@ -720,6 +721,18 @@ TEST(FunctionMapping, AggregateCases) {
"[2, 1]",
int64(),
/*nullary=*/true},
+ {{kSubstraitArithmeticFunctionsUri, "variance"},
+ "[1, 2, 3]",
+ float64(),
+ "[0.6666666666666666]",  // squared deviations sum to 2, divided by N = 3
+ "[0.25, 0]",
+ float64()},
+ {{kSubstraitArithmeticFunctionsUri, "std_dev"},
+ "[1, 2, 3]",
+ float64(),
+ "[0.816496580927726]",  // sqrt(2 / 3)
+ "[0.5, 0]",
+ float64()},
};
CheckAggregateCases(test_cases);
}
diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc
index 3a96618880..d89924b807 100644
--- a/cpp/src/arrow/engine/substrait/serde_test.cc
+++ b/cpp/src/arrow/engine/substrait/serde_test.cc
@@ -3466,6 +3466,118 @@ TEST(SubstraitRoundTrip, AggregateRel) {
/*include_columns=*/{}, conversion_options);
}
+TEST(SubstraitRoundTrip, AggregateRelOptions) {
+ auto dummy_schema =
+ schema({field("A", int32()), field("B", int32()), field("C", int32())});
+
+ // creating a dummy dataset using a dummy table
+ auto input_table = TableFromJSON(dummy_schema, {R"([
+ [10, 1, 10],
+ [20, 2, 10],
+ [30, 3, 20],
+ [30, 1, 20],
+ [20, 2, 30],
+ [10, 3, 30]
+ ])"});
+
+ std::string substrait_json = R"({
+ "version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 },
+ "relations": [{
+ "rel": {
+ "aggregate": {
+ "input": {
+ "read": {
+ "base_schema": {
+ "names": ["A", "B", "C"],
+ "struct": {
+ "types": [{
+ "i32": {}
+ }, {
+ "i32": {}
+ }, {
+ "i32": {}
+ }]
+ }
+ },
+ "namedTable" : {
+ "names": ["A"]
+ }
+ }
+ },
+ "groupings": [{
+ "groupingExpressions": [{
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 0
+ }
+ }
+ }
+ }]
+ }],
+ "measures": [{
+ "measure": {
+ "functionReference": 0,
+ "arguments": [{
+ "value": {
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 2
+ }
+ }
+ }
+ }
+ }],
+ "options": [{
+ "name": "distribution",
+ "preference": [
+ "POPULATION"
+ ]
+ }],
+ "sorts": [],
+ "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
+ "invocation": "AGGREGATION_INVOCATION_ALL",
+ "outputType": {
+ "fp64": {}
+ }
+ }
+ }]
+ }
+ }
+ }],
+ "extensionUris": [{
+ "extension_uri_anchor": 0,
+ "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml"
+ }],
+ "extensions": [{
+ "extension_function": {
+ "extension_uri_reference": 0,
+ "function_anchor": 0,
+ "name": "variance"
+ }
+ }]
+ })";
+
+ ASSERT_OK_AND_ASSIGN(auto buf,
+ internal::SubstraitFromJSON("Plan", substrait_json,
+ /*ignore_unknown_fields=*/false));
+ auto output_schema = schema({field("keys", int32()), field("aggregates", float64())});
+ auto expected_table = TableFromJSON(output_schema, {R"([
+ [10, 100],
+ [20, 100],
+ [30, 0]
+ ])"});
+ // POPULATION => ddof=0 (divide by N): key 10 has C = {10, 30}, so 200 / 2 = 100.
+ NamedTableProvider table_provider = AlwaysProvideSameTable(std::move(input_table));
+
+ ConversionOptions conversion_options;
+ conversion_options.named_table_provider = std::move(table_provider);
+
+ CheckRoundTripResult(std::move(expected_table), buf,
+ /*include_columns=*/{}, conversion_options);
+}
+
TEST(SubstraitRoundTrip, AggregateRelEmit) {
auto dummy_schema =
schema({field("A", int32()), field("B", int32()), field("C", int32())});