You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by we...@apache.org on 2023/04/26 16:52:45 UTC

[arrow] branch main updated: GH-35247: [C++] Add Arrow Substrait support for stddev/variance (#35249)

This is an automated email from the ASF dual-hosted git repository.

westonpace pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9feee48a7c GH-35247: [C++] Add Arrow Substrait support for stddev/variance (#35249)
9feee48a7c is described below

commit 9feee48a7ceffecd9b1fe34293cc005a7958638f
Author: rtpsw <rt...@hotmail.com>
AuthorDate: Wed Apr 26 19:52:35 2023 +0300

    GH-35247: [C++] Add Arrow Substrait support for stddev/variance (#35249)
    
    See #35247
    * Closes: #35247
    
    Authored-by: Yaron Gvili <rt...@hotmail.com>
    Signed-off-by: Weston Pace <we...@gmail.com>
---
 .../arrow/engine/substrait/expression_internal.cc  |   3 +
 cpp/src/arrow/engine/substrait/extension_set.cc    |  31 +++++-
 cpp/src/arrow/engine/substrait/function_test.cc    |  13 +++
 cpp/src/arrow/engine/substrait/serde_test.cc       | 112 +++++++++++++++++++++
 4 files changed, 157 insertions(+), 2 deletions(-)

diff --git a/cpp/src/arrow/engine/substrait/expression_internal.cc b/cpp/src/arrow/engine/substrait/expression_internal.cc
index 722dec2a30..5e214bdda4 100644
--- a/cpp/src/arrow/engine/substrait/expression_internal.cc
+++ b/cpp/src/arrow/engine/substrait/expression_internal.cc
@@ -269,6 +269,9 @@ Result<SubstraitCall> FromProto(const substrait::AggregateFunction& func, bool i
     ARROW_RETURN_NOT_OK(DecodeArg(func.arguments(i), static_cast<uint32_t>(i), &call,
                                   ext_set, conversion_options));
   }
+  for (int i = 0; i < func.options_size(); i++) {
+    ARROW_RETURN_NOT_OK(DecodeOption(func.options(i), &call));
+  }
   return std::move(call);
 }
 
diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc
index 6db36e6a7a..d0959005f1 100644
--- a/cpp/src/arrow/engine/substrait/extension_set.cc
+++ b/cpp/src/arrow/engine/substrait/extension_set.cc
@@ -944,6 +944,30 @@ ExtensionIdRegistry::SubstraitAggregateToArrow DecodeBasicAggregate(
                                call.id().name, " to have at least one argument");
       }
       case 1: {
+        std::shared_ptr<compute::FunctionOptions> options = nullptr;
+        if (arrow_function_name == "stddev" || arrow_function_name == "variance") {
+          // See the following URL for the spec of stddev and variance:
+          // https://github.com/substrait-io/substrait/blob/
+          // 73228b4112d79eb1011af0ebb41753ce23ca180c/
+          // extensions/functions_arithmetic.yaml#L1240
+          auto maybe_dist = call.GetOption("distribution");
+          if (maybe_dist) {
+            auto& prefs = **maybe_dist;
+            if (prefs.size() != 1) {
+              return Status::Invalid("expected a single preference for ",
+                                     arrow_function_name, " but got ", prefs.size());
+            }
+            int ddof;
+            if (prefs[0] == "POPULATION") {
+              ddof = 0;  // population statistic: divide by N
+            } else if (prefs[0] == "SAMPLE") {
+              ddof = 1;  // sample statistic: divide by N - 1
+            } else {
+              return Status::Invalid("unknown distribution preference ", prefs[0]);
+            }
+            options = std::make_shared<compute::VarianceOptions>(ddof);
+          }
+        }
         fixed_arrow_func += arrow_function_name;

         ARROW_ASSIGN_OR_RAISE(compute::Expression arg, call.GetValueArg(0));
@@ -953,7 +977,8 @@ ExtensionIdRegistry::SubstraitAggregateToArrow DecodeBasicAggregate(
                                  call.id().name, " to have a direct reference");
         }

-        return compute::Aggregate{std::move(fixed_arrow_func), *arg_ref, ""};
+        return compute::Aggregate{std::move(fixed_arrow_func), std::move(options),
+                                  *arg_ref, ""};
       }
       default:
         break;
@@ -1069,12 +1094,14 @@ struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl {
         DecodeOptionlessBasicMapping("is_valid", /*max_args=*/1)));

     // --------------- Substrait -> Arrow Aggregates --------------
-    for (const auto& fn_name : {"sum", "min", "max"}) {
+    for (const auto& fn_name : {"sum", "min", "max", "variance"}) {
       DCHECK_OK(AddSubstraitAggregateToArrow({kSubstraitArithmeticFunctionsUri, fn_name},
                                              DecodeBasicAggregate(fn_name)));
     }
     DCHECK_OK(AddSubstraitAggregateToArrow({kSubstraitArithmeticFunctionsUri, "avg"},
                                            DecodeBasicAggregate("mean")));
+    DCHECK_OK(AddSubstraitAggregateToArrow({kSubstraitArithmeticFunctionsUri, "std_dev"},
+                                           DecodeBasicAggregate("stddev")));
     DCHECK_OK(
         AddSubstraitAggregateToArrow({kSubstraitAggregateGenericFunctionsUri, "count"},
                                      DecodeBasicAggregate("count")));
diff --git a/cpp/src/arrow/engine/substrait/function_test.cc b/cpp/src/arrow/engine/substrait/function_test.cc
index 9164bf0a4b..f48e4c090d 100644
--- a/cpp/src/arrow/engine/substrait/function_test.cc
+++ b/cpp/src/arrow/engine/substrait/function_test.cc
@@ -650,6 +650,7 @@ void CheckWholeAggregateCase(const AggregateTestCase& test_case) {
 }

 void CheckGroupedAggregateCase(const AggregateTestCase& test_case) {
+  ARROW_SCOPED_TRACE("function = ", test_case.function_id.ToString());
   std::shared_ptr<Table> output_table;
   std::shared_ptr<acero::ExecPlan> plan =
       PlanFromAggregateCase(test_case, &output_table, /*with_keys=*/true);
@@ -720,6 +721,18 @@ TEST(FunctionMapping, AggregateCases) {
        "[2, 1]",
        int64(),
        /*nullary=*/true},
+      {{kSubstraitArithmeticFunctionsUri, "variance"},
+       "[1, 2, 3]",
+       float64(),
+       "[0.6666666666666666]",
+       "[0.25, 0]",
+       float64()},
+      {{kSubstraitArithmeticFunctionsUri, "std_dev"},
+       "[1, 2, 3]",
+       float64(),
+       "[0.816496580927726]",
+       "[0.5, 0]",
+       float64()},
   };
   CheckAggregateCases(test_cases);
 }
diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc
index 3a96618880..d89924b807 100644
--- a/cpp/src/arrow/engine/substrait/serde_test.cc
+++ b/cpp/src/arrow/engine/substrait/serde_test.cc
@@ -3466,6 +3466,118 @@ TEST(SubstraitRoundTrip, AggregateRel) {
                       /*include_columns=*/{}, conversion_options);
 }

+TEST(SubstraitRoundTrip, AggregateRelOptions) {
+  auto dummy_schema =
+      schema({field("A", int32()), field("B", int32()), field("C", int32())});
+
+  // creating a dummy dataset using a dummy table
+  auto input_table = TableFromJSON(dummy_schema, {R"([
+      [10, 1, 10],
+      [20, 2, 10],
+      [30, 3, 20],
+      [30, 1, 20],
+      [20, 2, 30],
+      [10, 3, 30]
+  ])"});
+
+  std::string substrait_json = R"({
+    "version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 },
+    "relations": [{
+      "rel": {
+        "aggregate": {
+          "input": {
+            "read": {
+              "base_schema": {
+                "names": ["A", "B", "C"],
+                "struct": {
+                  "types": [{
+                    "i32": {}
+                  }, {
+                    "i32": {}
+                  }, {
+                    "i32": {}
+                  }]
+                }
+              },
+              "namedTable" : {
+                "names": ["A"]
+              }
+            }
+          },
+          "groupings": [{
+            "groupingExpressions": [{
+              "selection": {
+                "directReference": {
+                  "structField": {
+                    "field": 0
+                  }
+                }
+              }
+            }]
+          }],
+          "measures": [{
+            "measure": {
+              "functionReference": 0,
+              "arguments": [{
+                "value": {
+                  "selection": {
+                    "directReference": {
+                      "structField": {
+                        "field": 2
+                      }
+                    }
+                  }
+                }
+              }],
+              "options": [{
+                "name": "distribution",
+                "preference": [
+                  "POPULATION"
+                ]
+              }],
+              "sorts": [],
+              "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
+              "invocation": "AGGREGATION_INVOCATION_ALL",
+              "outputType": {
+                "fp64": {}
+              }
+            }
+          }]
+        }
+      }
+    }],
+    "extensionUris": [{
+      "extension_uri_anchor": 0,
+      "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml"
+    }],
+    "extensions": [{
+      "extension_function": {
+        "extension_uri_reference": 0,
+        "function_anchor": 0,
+        "name": "variance"
+      }
+    }]
+  })";
+  // POPULATION => ddof=0 (divide by N); A=10 and A=20 each aggregate C={10, 30}.
+  ASSERT_OK_AND_ASSIGN(auto buf,
+                       internal::SubstraitFromJSON("Plan", substrait_json,
+                                                   /*ignore_unknown_fields=*/false));
+  auto output_schema = schema({field("keys", int32()), field("aggregates", float64())});
+  auto expected_table = TableFromJSON(output_schema, {R"([
+      [10, 100],
+      [20, 100],
+      [30, 0]
+  ])"});
+
+  NamedTableProvider table_provider = AlwaysProvideSameTable(std::move(input_table));
+
+  ConversionOptions conversion_options;
+  conversion_options.named_table_provider = std::move(table_provider);
+
+  CheckRoundTripResult(std::move(expected_table), buf,
+                       /*include_columns=*/{}, conversion_options);
+}
+
 TEST(SubstraitRoundTrip, AggregateRelEmit) {
   auto dummy_schema =
       schema({field("A", int32()), field("B", int32()), field("C", int32())});