You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by we...@apache.org on 2022/07/25 23:59:21 UTC

[arrow] branch master updated: ARROW-15591: [C++] Add support for aggregation to the Substrait consumer (#13130)

This is an automated email from the ASF dual-hosted git repository.

westonpace pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 87cefe80c7 ARROW-15591: [C++] Add support for aggregation to the Substrait consumer (#13130)
87cefe80c7 is described below

commit 87cefe80c7126a6bf91b809159f39c9aa2a8db61
Author: Vibhatha Lakmal Abeykoon <vi...@users.noreply.github.com>
AuthorDate: Tue Jul 26 05:29:15 2022 +0530

    ARROW-15591: [C++] Add support for aggregation to the Substrait consumer (#13130)
    
    This PR adds aggregation support to the Substrait consumer, so that a Substrait plan containing an aggregate relation can be consumed by Acero.
    
    Lead-authored-by: Vibhatha Abeykoon <vi...@gmail.com>
    Co-authored-by: Vibhatha Lakmal Abeykoon <vi...@users.noreply.github.com>
    Signed-off-by: Weston Pace <we...@gmail.com>
---
 cpp/src/arrow/engine/substrait/extension_set.cc    |   1 +
 .../arrow/engine/substrait/relation_internal.cc    |  72 ++++-
 cpp/src/arrow/engine/substrait/serde_test.cc       | 317 +++++++++++++++++++++
 3 files changed, 389 insertions(+), 1 deletion(-)

diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc
index f60f6ac1cb..08eb6acc9c 100644
--- a/cpp/src/arrow/engine/substrait/extension_set.cc
+++ b/cpp/src/arrow/engine/substrait/extension_set.cc
@@ -445,6 +445,7 @@ struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl {
              "add",
              "equal",
              "is_not_distinct_from",
+             "hash_count",
          }) {
       DCHECK_OK(RegisterFunction({kArrowExtTypesUri, name}, name.to_string()));
     }
diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc
index 09ecb2f069..8f6cb0ce36 100644
--- a/cpp/src/arrow/engine/substrait/relation_internal.cc
+++ b/cpp/src/arrow/engine/substrait/relation_internal.cc
@@ -307,7 +307,7 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel,
             callptr->function_name);
       }
 
-      // TODO: ARROW-166241 Add Suffix support for Substrait
+      // TODO: ARROW-16624 Add Suffix support for Substrait
       const auto* left_keys = callptr->arguments[0].field_ref();
       const auto* right_keys = callptr->arguments[1].field_ref();
       if (!left_keys || !right_keys) {
@@ -323,6 +323,76 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel,
       join_dec.inputs.emplace_back(std::move(right.declaration));
       return DeclarationInfo{std::move(join_dec), num_columns};
     }
+    case substrait::Rel::RelTypeCase::kAggregate: {
+      const auto& aggregate = rel.aggregate();
+      RETURN_NOT_OK(CheckRelCommon(aggregate));
+
+      if (!aggregate.has_input()) {
+        return Status::Invalid("substrait::AggregateRel with no input relation");
+      }
+
+      ARROW_ASSIGN_OR_RAISE(auto input, FromProto(aggregate.input(), ext_set));
+
+      if (aggregate.groupings_size() > 1) {
+        return Status::NotImplemented(
+            "Grouping sets not supported.  AggregateRel::groupings may not have more "
+            "than one item");
+      }
+      std::vector<FieldRef> keys;
+      auto group = aggregate.groupings(0);
+      keys.reserve(group.grouping_expressions_size());
+      for (int exp_id = 0; exp_id < group.grouping_expressions_size(); exp_id++) {
+        ARROW_ASSIGN_OR_RAISE(auto expr,
+                              FromProto(group.grouping_expressions(exp_id), ext_set));
+        const auto* field_ref = expr.field_ref();
+        if (field_ref) {
+          keys.emplace_back(std::move(*field_ref));
+        } else {
+          return Status::Invalid(
+              "The grouping expression for an aggregate must be a direct reference.");
+        }
+      }
+
+      int measure_size = aggregate.measures_size();
+      std::vector<compute::Aggregate> aggregates;
+      aggregates.reserve(measure_size);
+      for (int measure_id = 0; measure_id < measure_size; measure_id++) {
+        const auto& agg_measure = aggregate.measures(measure_id);
+        if (agg_measure.has_measure()) {
+          if (agg_measure.has_filter()) {
+            return Status::NotImplemented("Aggregate filters are not supported.");
+          }
+          const auto& agg_func = agg_measure.measure();
+          if (agg_func.arguments_size() != 1) {
+            return Status::NotImplemented("Aggregate function must be a unary function.");
+          }
+          int func_reference = agg_func.function_reference();
+          ARROW_ASSIGN_OR_RAISE(auto func_record, ext_set.DecodeFunction(func_reference));
+          // aggregate function name
+          auto func_name = std::string(func_record.id.name);
+          // aggregate target
+          auto subs_func_args = agg_func.arguments(0);
+          ARROW_ASSIGN_OR_RAISE(auto field_expr,
+                                FromProto(subs_func_args.value(), ext_set));
+          auto target = field_expr.field_ref();
+          if (!target) {
+            return Status::Invalid(
+                "The input expression to an aggregate function must be a direct "
+                "reference.");
+          }
+          aggregates.emplace_back(compute::Aggregate{std::move(func_name), NULLPTR,
+                                                     std::move(*target), std::move("")});
+        } else {
+          return Status::Invalid("substrait::AggregateFunction not provided");
+        }
+      }
+
+      return DeclarationInfo{
+          compute::Declaration::Sequence(
+              {std::move(input.declaration),
+               {"aggregate", compute::AggregateNodeOptions{aggregates, keys}}}),
+          static_cast<int>(aggregates.size())};
+    }
 
     default:
       break;
diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc
index e10082392d..8e5745d6df 100644
--- a/cpp/src/arrow/engine/substrait/serde_test.cc
+++ b/cpp/src/arrow/engine/substrait/serde_test.cc
@@ -1383,5 +1383,322 @@ TEST(Substrait, JoinPlanInvalidKeys) {
   }
 }
 
+TEST(Substrait, AggregateBasic) {
+  ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({
+    "relations": [{
+      "rel": {
+        "aggregate": {
+          "input": {
+            "read": {
+              "base_schema": {
+                "names": ["A", "B", "C"],
+                "struct": {
+                  "types": [{
+                    "i32": {}
+                  }, {
+                    "i32": {}
+                  }, {
+                    "i32": {}
+                  }]
+                }
+              },
+              "local_files": { 
+                "items": [
+                  {
+                    "uri_file": "file:///tmp/dat.parquet",
+                    "parquet": {}
+                  }
+                ]
+              }
+            }
+          },
+          "groupings": [{
+            "groupingExpressions": [{
+              "selection": {
+                "directReference": {
+                  "structField": {
+                    "field": 0
+                  }
+                }
+              }
+            }]
+          }],
+          "measures": [{
+            "measure": {
+              "functionReference": 0,
+              "arguments": [{
+                "value": {
+                  "selection": {
+                    "directReference": {
+                      "structField": {
+                        "field": 1
+                      }
+                    }
+                  }
+                }
+            }],
+              "sorts": [],
+              "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
+              "outputType": {
+                "i64": {}
+              }
+            }
+          }]
+        }
+      }
+    }],
+    "extensionUris": [{
+      "extension_uri_anchor": 0,
+      "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"
+    }],
+    "extensions": [{
+      "extension_function": {
+        "extension_uri_reference": 0,
+        "function_anchor": 0,
+        "name": "hash_count"
+      }
+    }],
+  })"));
+
+  auto sp_ext_id_reg = substrait::MakeExtensionIdRegistry();
+  ASSERT_OK_AND_ASSIGN(auto sink_decls,
+                       DeserializePlans(*buf, [] { return kNullConsumer; }));
+  auto agg_decl = sink_decls[0].inputs[0];
+
+  const auto& agg_rel = agg_decl.get<compute::Declaration>();
+
+  const auto& agg_options =
+      checked_cast<const compute::AggregateNodeOptions&>(*agg_rel->options);
+
+  EXPECT_EQ(agg_rel->factory_name, "aggregate");
+  EXPECT_EQ(agg_options.aggregates[0].name, "");
+  EXPECT_EQ(agg_options.aggregates[0].function, "hash_count");
+}
+
+TEST(Substrait, AggregateInvalidRel) {
+  ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({
+    "relations": [{
+      "rel": {
+        "aggregate": {
+        }
+      }
+    }],
+    "extensionUris": [{
+      "extension_uri_anchor": 0,
+      "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"
+    }],
+    "extensions": [{
+      "extension_function": {
+        "extension_uri_reference": 0,
+        "function_anchor": 0,
+        "name": "hash_count"
+      }
+    }],
+  })"));
+
+  ASSERT_RAISES(Invalid, DeserializePlans(*buf, [] { return kNullConsumer; }));
+}
+
+TEST(Substrait, AggregateInvalidFunction) {
+  ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({
+    "relations": [{
+      "rel": {
+        "aggregate": {
+          "input": {
+            "read": {
+              "base_schema": {
+                "names": ["A", "B", "C"],
+                "struct": {
+                  "types": [{
+                    "i32": {}
+                  }, {
+                    "i32": {}
+                  }, {
+                    "i32": {}
+                  }]
+                }
+              },
+              "local_files": { 
+                "items": [
+                  {
+                    "uri_file": "file:///tmp/dat.parquet",
+                    "parquet": {}
+                  }
+                ]
+              }
+            }
+          },
+          "groupings": [{
+            "groupingExpressions": [{
+              "selection": {
+                "directReference": {
+                  "structField": {
+                    "field": 0
+                  }
+                }
+              }
+            }]
+          }],
+          "measures": [{
+          }]
+        }
+      }
+    }],
+    "extensionUris": [{
+      "extension_uri_anchor": 0,
+      "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"
+    }],
+    "extensions": [{
+      "extension_function": {
+        "extension_uri_reference": 0,
+        "function_anchor": 0,
+        "name": "hash_count"
+      }
+    }],
+  })"));
+
+  ASSERT_RAISES(Invalid, DeserializePlans(*buf, [] { return kNullConsumer; }));
+}
+
+TEST(Substrait, AggregateInvalidAggFuncArgs) {
+  ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({
+    "relations": [{
+      "rel": {
+        "aggregate": {
+          "input": {
+            "read": {
+              "base_schema": {
+                "names": ["A", "B", "C"],
+                "struct": {
+                  "types": [{
+                    "i32": {}
+                  }, {
+                    "i32": {}
+                  }, {
+                    "i32": {}
+                  }]
+                }
+              },
+              "local_files": { 
+                "items": [
+                  {
+                    "uri_file": "file:///tmp/dat.parquet",
+                    "parquet": {}
+                  }
+                ]
+              }
+            }
+          },
+          "groupings": [{
+            "groupingExpressions": [{
+              "selection": {
+                "directReference": {
+                  "structField": {
+                    "field": 0
+                  }
+                }
+              }
+            }]
+          }],
+          "measures": [{
+            "measure": {
+              "functionReference": 0,
+              "args": [],
+              "sorts": [],
+              "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
+              "outputType": {
+                "i64": {}
+              }
+            }
+          }]
+        }
+      }
+    }],
+    "extensionUris": [{
+      "extension_uri_anchor": 0,
+      "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"
+    }],
+    "extensions": [{
+      "extension_function": {
+        "extension_uri_reference": 0,
+        "function_anchor": 0,
+        "name": "hash_count"
+      }
+    }],
+  })"));
+
+  ASSERT_RAISES(NotImplemented, DeserializePlans(*buf, [] { return kNullConsumer; }));
+}
+
+TEST(Substrait, AggregateWithFilter) {
+  ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({
+    "relations": [{
+      "rel": {
+        "aggregate": {
+          "input": {
+            "read": {
+              "base_schema": {
+                "names": ["A", "B", "C"],
+                "struct": {
+                  "types": [{
+                    "i32": {}
+                  }, {
+                    "i32": {}
+                  }, {
+                    "i32": {}
+                  }]
+                }
+              },
+              "local_files": { 
+                "items": [
+                  {
+                    "uri_file": "file:///tmp/dat.parquet",
+                    "parquet": {}
+                  }
+                ]
+              }
+            }
+          },
+          "groupings": [{
+            "groupingExpressions": [{
+              "selection": {
+                "directReference": {
+                  "structField": {
+                    "field": 0
+                  }
+                }
+              }
+            }]
+          }],
+          "measures": [{
+            "measure": {
+              "functionReference": 0,
+              "args": [],
+              "sorts": [],
+              "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
+              "outputType": {
+                "i64": {}
+              }
+            }
+          }]
+        }
+      }
+    }],
+    "extensionUris": [{
+      "extension_uri_anchor": 0,
+      "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"
+    }],
+    "extensions": [{
+      "extension_function": {
+        "extension_uri_reference": 0,
+        "function_anchor": 0,
+        "name": "equal"
+      }
+    }],
+  })"));
+
+  ASSERT_RAISES(NotImplemented, DeserializePlans(*buf, [] { return kNullConsumer; }));
+}
+
 }  // namespace engine
 }  // namespace arrow