You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by we...@apache.org on 2022/06/02 20:33:05 UTC
[arrow] branch master updated: ARROW-16657: [C++] Support nesting of extension-id-registries (#13232)
This is an automated email from the ASF dual-hosted git repository.
westonpace pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new dc36e9e2d5 ARROW-16657: [C++] Support nesting of extension-id-registries (#13232)
dc36e9e2d5 is described below
commit dc36e9e2d5c6af519a9dc279c3beded3b315c4d1
Author: rtpsw <rt...@hotmail.com>
AuthorDate: Thu Jun 2 23:32:58 2022 +0300
ARROW-16657: [C++] Support nesting of extension-id-registries (#13232)
Replacing https://github.com/apache/arrow/pull/13214
Lead-authored-by: Yaron Gvili <rt...@hotmail.com>
Co-authored-by: rtpsw <rt...@hotmail.com>
Signed-off-by: Weston Pace <we...@gmail.com>
---
cpp/src/arrow/engine/CMakeLists.txt | 1 +
cpp/src/arrow/engine/substrait/ext_test.cc | 271 +++++++++++++++++++
cpp/src/arrow/engine/substrait/extension_set.cc | 335 ++++++++++++++++--------
cpp/src/arrow/engine/substrait/extension_set.h | 22 +-
cpp/src/arrow/engine/substrait/plan_internal.cc | 2 +-
cpp/src/arrow/engine/substrait/plan_internal.h | 2 +-
cpp/src/arrow/engine/substrait/util.cc | 4 +
cpp/src/arrow/engine/substrait/util.h | 4 +
8 files changed, 522 insertions(+), 119 deletions(-)
diff --git a/cpp/src/arrow/engine/CMakeLists.txt b/cpp/src/arrow/engine/CMakeLists.txt
index ea9797ea1d..8edd22900e 100644
--- a/cpp/src/arrow/engine/CMakeLists.txt
+++ b/cpp/src/arrow/engine/CMakeLists.txt
@@ -66,6 +66,7 @@ endif()
add_arrow_test(substrait_test
SOURCES
+ substrait/ext_test.cc
substrait/serde_test.cc
EXTRA_LINK_LIBS
${ARROW_SUBSTRAIT_TEST_LINK_LIBS}
diff --git a/cpp/src/arrow/engine/substrait/ext_test.cc b/cpp/src/arrow/engine/substrait/ext_test.cc
new file mode 100644
index 0000000000..8e41cb7c98
--- /dev/null
+++ b/cpp/src/arrow/engine/substrait/ext_test.cc
@@ -0,0 +1,271 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/engine/substrait/extension_set.h"
+#include "arrow/engine/substrait/util.h"
+
+#include <google/protobuf/descriptor.h>
+#include <google/protobuf/util/json_util.h>
+#include <google/protobuf/util/type_resolver_util.h>
+#include <gtest/gtest.h>
+
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/matchers.h"
+
+using testing::ElementsAre;
+using testing::Eq;
+using testing::HasSubstr;
+using testing::UnorderedElementsAre;
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace engine {
+
+// an extension-id-registry provider to be used as a test parameter
+//
+// we cannot pass a pointer to a nested registry as a test parameter because the
+// shared_ptr in which it is made would not be held and get destructed too early,
+// nor can we pass a shared_ptr to the default nested registry as a test parameter
+// because it is global and must never be cleaned up, so we pass a shared_ptr to a
+// provider that either owns or does not own the registry it provides, depending
+// on the case.
+struct ExtensionIdRegistryProvider {
+  // Providers are held (and released) through pointers to this base type, so the
+  // destructor must be virtual for derived state to be destroyed correctly.
+  virtual ~ExtensionIdRegistryProvider() = default;
+
+  // Returns the registry under test; ownership stays with the provider (or with
+  // the library, for the default registry).
+  virtual ExtensionIdRegistry* get() const = 0;
+};
+
+// Provider that hands out the default registry, which it does not own.
+struct DefaultExtensionIdRegistryProvider : public ExtensionIdRegistryProvider {
+  virtual ~DefaultExtensionIdRegistryProvider() = default;
+
+  ExtensionIdRegistry* get() const override {
+    ExtensionIdRegistry* const global_registry = default_extension_id_registry();
+    return global_registry;
+  }
+};
+
+// Provider that creates a nested registry and keeps it alive for its own lifetime.
+struct NestedExtensionIdRegistryProvider : public ExtensionIdRegistryProvider {
+  std::shared_ptr<ExtensionIdRegistry> registry_ = substrait::MakeExtensionIdRegistry();
+
+  virtual ~NestedExtensionIdRegistryProvider() = default;
+
+  ExtensionIdRegistry* get() const override { return registry_.get(); }
+};
+
+using Id = ExtensionIdRegistry::Id;
+
+// Two ids compare equal when both their URI and their name match.
+bool operator==(const Id& id1, const Id& id2) {
+  if (id1.uri != id2.uri) {
+    return false;
+  }
+  return id1.name == id2.name;
+}
+
+// Inequality for ids, defined as the negation of operator==.
+bool operator!=(const Id& id1, const Id& id2) {
+  const bool same = (id1 == id2);
+  return !same;
+}
+
+// Pairs an Arrow data type with the name that the tests below combine with
+// kArrowExtTypesUri to form the type's extension Id.
+struct TypeName {
+ std::shared_ptr<DataType> type;
+ util::string_view name;
+};
+
+static const std::vector<TypeName> kTypeNames = {
+ TypeName{uint8(), "u8"},
+ TypeName{uint16(), "u16"},
+ TypeName{uint32(), "u32"},
+ TypeName{uint64(), "u64"},
+ TypeName{float16(), "fp16"},
+ TypeName{null(), "null"},
+ TypeName{month_interval(), "interval_month"},
+ TypeName{day_time_interval(), "interval_day_milli"},
+ TypeName{month_day_nano_interval(), "interval_month_day_nano"},
+};
+
+static const std::vector<util::string_view> kFunctionNames = {
+ "add",
+};
+
+static const std::vector<util::string_view> kTempFunctionNames = {
+ "temp_func_1",
+ "temp_func_2",
+};
+
+static const std::vector<TypeName> kTempTypeNames = {
+ TypeName{timestamp(TimeUnit::SECOND, "temp_tz_1"), "temp_type_1"},
+ TypeName{timestamp(TimeUnit::SECOND, "temp_tz_2"), "temp_type_2"},
+};
+
+static Id kNonExistentId{kArrowExtTypesUri, "non_existent"};
+static TypeName kNonExistentTypeName{timestamp(TimeUnit::SECOND, "non_existent_tz_1"),
+ "non_existent_type_1"};
+
+using ExtensionIdRegistryParams =
+ std::tuple<std::shared_ptr<ExtensionIdRegistryProvider>, std::string>;
+
+struct ExtensionIdRegistryTest
+ : public testing::TestWithParam<ExtensionIdRegistryParams> {};
+
+// Every type in kTypeNames must be found both by id and by data type, and the
+// non-existent id/type must not be found.
+TEST_P(ExtensionIdRegistryTest, GetTypes) {
+  ExtensionIdRegistry* registry = std::get<0>(GetParam())->get();
+
+  for (const TypeName& entry : kTypeNames) {
+    const Id expected_id{kArrowExtTypesUri, entry.name};
+    auto found_by_id = registry->GetType(expected_id);
+    auto found_by_type = registry->GetType(*entry.type);
+    for (auto found : {found_by_id, found_by_type}) {
+      ASSERT_TRUE(found);
+      auto record = found.value();
+      ASSERT_EQ(expected_id, record.id);
+      ASSERT_EQ(*entry.type, *record.type);
+    }
+  }
+  ASSERT_FALSE(registry->GetType(kNonExistentId));
+  ASSERT_FALSE(registry->GetType(*kNonExistentTypeName.type));
+}
+
+// The types in kTypeNames are already registered, so both the dry-run check and
+// the actual registration must be rejected as Invalid.
+TEST_P(ExtensionIdRegistryTest, ReregisterTypes) {
+  ExtensionIdRegistry* registry = std::get<0>(GetParam())->get();
+
+  for (const TypeName& entry : kTypeNames) {
+    const Id id{kArrowExtTypesUri, entry.name};
+    ASSERT_RAISES(Invalid, registry->CanRegisterType(id, entry.type));
+    ASSERT_RAISES(Invalid, registry->RegisterType(id, entry.type));
+  }
+}
+
+// Every function in kFunctionNames must be found both by id and by Arrow
+// function name, and an unknown function must not be found by either lookup.
+TEST_P(ExtensionIdRegistryTest, GetFunctions) {
+  auto provider = std::get<0>(GetParam());
+  auto registry = provider->get();
+
+  for (util::string_view name : kFunctionNames) {
+    auto id = Id{kArrowExtTypesUri, name};
+    for (auto funcrec_opt : {registry->GetFunction(id), registry->GetFunction(name)}) {
+      ASSERT_TRUE(funcrec_opt);
+      auto funcrec = funcrec_opt.value();
+      ASSERT_EQ(id, funcrec.id);
+      ASSERT_EQ(name, funcrec.function_name);
+    }
+  }
+  // Check the function lookups here (not the type lookups, which GetTypes
+  // already covers) so the miss path of GetFunction is actually exercised.
+  ASSERT_FALSE(registry->GetFunction(kNonExistentId));
+  ASSERT_FALSE(registry->GetFunction(kNonExistentId.name));
+}
+
+// The functions in kFunctionNames are already registered, so both the dry-run
+// check and the actual registration must be rejected as Invalid.
+TEST_P(ExtensionIdRegistryTest, ReregisterFunctions) {
+  ExtensionIdRegistry* registry = std::get<0>(GetParam())->get();
+
+  for (const util::string_view function_name : kFunctionNames) {
+    const Id id{kArrowExtTypesUri, function_name};
+    ASSERT_RAISES(Invalid, registry->CanRegisterFunction(id, function_name.to_string()));
+    ASSERT_RAISES(Invalid, registry->RegisterFunction(id, function_name.to_string()));
+  }
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ Substrait, ExtensionIdRegistryTest,
+ testing::Values(
+ std::make_tuple(std::make_shared<DefaultExtensionIdRegistryProvider>(),
+ "default"),
+ std::make_tuple(std::make_shared<NestedExtensionIdRegistryProvider>(),
+ "nested")));
+
+// A fresh nested registry is made on every round, so the temporary types can be
+// registered anew each time; the default registry must keep accepting them,
+// i.e. it is never affected by registrations made in a nested registry.
+TEST(ExtensionIdRegistryTest, RegisterTempTypes) {
+  ExtensionIdRegistry* default_registry = default_extension_id_registry();
+  constexpr int kRounds = 3;
+  for (int round = 0; round < kRounds; ++round) {
+    std::shared_ptr<ExtensionIdRegistry> registry = substrait::MakeExtensionIdRegistry();
+
+    for (const TypeName& entry : kTempTypeNames) {
+      const Id id{kArrowExtTypesUri, entry.name};
+      ASSERT_OK(registry->CanRegisterType(id, entry.type));
+      ASSERT_OK(registry->RegisterType(id, entry.type));
+      ASSERT_RAISES(Invalid, registry->CanRegisterType(id, entry.type));
+      ASSERT_RAISES(Invalid, registry->RegisterType(id, entry.type));
+      ASSERT_OK(default_registry->CanRegisterType(id, entry.type));
+    }
+  }
+}
+
+// A fresh nested registry is made on every round, so the temporary functions
+// can be registered anew each time; the default registry must keep accepting
+// them, i.e. it is never affected by registrations made in a nested registry.
+TEST(ExtensionIdRegistryTest, RegisterTempFunctions) {
+  ExtensionIdRegistry* default_registry = default_extension_id_registry();
+  constexpr int kRounds = 3;
+  for (int round = 0; round < kRounds; ++round) {
+    std::shared_ptr<ExtensionIdRegistry> registry = substrait::MakeExtensionIdRegistry();
+
+    for (const util::string_view function_name : kTempFunctionNames) {
+      const Id id{kArrowExtTypesUri, function_name};
+      ASSERT_OK(registry->CanRegisterFunction(id, function_name.to_string()));
+      ASSERT_OK(registry->RegisterFunction(id, function_name.to_string()));
+      ASSERT_RAISES(Invalid,
+                    registry->CanRegisterFunction(id, function_name.to_string()));
+      ASSERT_RAISES(Invalid, registry->RegisterFunction(id, function_name.to_string()));
+      ASSERT_OK(default_registry->CanRegisterFunction(id, function_name.to_string()));
+    }
+  }
+}
+
+// Two levels of nesting: an outer registry nested in the default one, and an
+// inner registry nested in the outer one. Registrations at either level must be
+// rejected on repetition and must never leak into the default registry.
+TEST(ExtensionIdRegistryTest, RegisterNestedTypes) {
+  const TypeName& outer_entry = kTempTypeNames[0];
+  const TypeName& inner_entry = kTempTypeNames[1];
+  std::shared_ptr<DataType> type1 = outer_entry.type;
+  std::shared_ptr<DataType> type2 = inner_entry.type;
+  const Id id1{kArrowExtTypesUri, outer_entry.name};
+  const Id id2{kArrowExtTypesUri, inner_entry.name};
+
+  ExtensionIdRegistry* default_registry = default_extension_id_registry();
+  constexpr int kRounds = 3;
+  for (int outer_round = 0; outer_round < kRounds; ++outer_round) {
+    auto outer_registry = nested_extension_id_registry(default_registry);
+
+    ASSERT_OK(outer_registry->CanRegisterType(id1, type1));
+    ASSERT_OK(outer_registry->RegisterType(id1, type1));
+
+    for (int inner_round = 0; inner_round < kRounds; ++inner_round) {
+      auto inner_registry = nested_extension_id_registry(outer_registry.get());
+
+      ASSERT_OK(inner_registry->CanRegisterType(id2, type2));
+      ASSERT_OK(inner_registry->RegisterType(id2, type2));
+      ASSERT_RAISES(Invalid, inner_registry->CanRegisterType(id2, type2));
+      ASSERT_RAISES(Invalid, inner_registry->RegisterType(id2, type2));
+      ASSERT_OK(default_registry->CanRegisterType(id2, type2));
+    }
+
+    ASSERT_RAISES(Invalid, outer_registry->CanRegisterType(id1, type1));
+    ASSERT_RAISES(Invalid, outer_registry->RegisterType(id1, type1));
+    ASSERT_OK(default_registry->CanRegisterType(id1, type1));
+  }
+}
+
+// Two levels of nesting: registry1 is nested in the default registry and
+// registry2 is nested in registry1, mirroring RegisterNestedTypes. Registrations
+// at either level must be rejected on repetition and must never leak into the
+// default registry.
+TEST(ExtensionIdRegistryTest, RegisterNestedFunctions) {
+  util::string_view name1 = kTempFunctionNames[0];
+  util::string_view name2 = kTempFunctionNames[1];
+  auto id1 = Id{kArrowExtTypesUri, name1};
+  auto id2 = Id{kArrowExtTypesUri, name2};
+
+  auto default_registry = default_extension_id_registry();
+  constexpr int rounds = 3;
+  for (int i = 0; i < rounds; i++) {
+    auto registry1 = nested_extension_id_registry(default_registry);
+
+    ASSERT_OK(registry1->CanRegisterFunction(id1, name1.to_string()));
+    ASSERT_OK(registry1->RegisterFunction(id1, name1.to_string()));
+
+    for (int j = 0; j < rounds; j++) {
+      // Nest inside registry1 (not directly inside the default registry) so the
+      // lookup/registration chain really spans two nested levels.
+      auto registry2 = nested_extension_id_registry(registry1.get());
+
+      ASSERT_OK(registry2->CanRegisterFunction(id2, name2.to_string()));
+      ASSERT_OK(registry2->RegisterFunction(id2, name2.to_string()));
+      ASSERT_RAISES(Invalid, registry2->CanRegisterFunction(id2, name2.to_string()));
+      ASSERT_RAISES(Invalid, registry2->RegisterFunction(id2, name2.to_string()));
+      ASSERT_OK(default_registry->CanRegisterFunction(id2, name2.to_string()));
+    }
+
+    ASSERT_RAISES(Invalid, registry1->CanRegisterFunction(id1, name1.to_string()));
+    ASSERT_RAISES(Invalid, registry1->RegisterFunction(id1, name1.to_string()));
+    ASSERT_OK(default_registry->CanRegisterFunction(id1, name1.to_string()));
+  }
+}
+
+} // namespace engine
+} // namespace arrow
diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc
index cd85678a72..a30c740b18 100644
--- a/cpp/src/arrow/engine/substrait/extension_set.cc
+++ b/cpp/src/arrow/engine/substrait/extension_set.cc
@@ -57,7 +57,7 @@ bool ExtensionIdRegistry::IdHashEq::operator()(ExtensionIdRegistry::Id l,
// A builder used when creating a Substrait plan from an Arrow execution plan. In
// that situation we do not have a set of anchor values already defined so we keep
// a map of what Ids we have seen.
-ExtensionSet::ExtensionSet(ExtensionIdRegistry* registry) : registry_(registry) {}
+ExtensionSet::ExtensionSet(const ExtensionIdRegistry* registry) : registry_(registry) {}
Status ExtensionSet::CheckHasUri(util::string_view uri) {
auto it =
@@ -96,7 +96,7 @@ Status ExtensionSet::AddUri(Id id) {
Result<ExtensionSet> ExtensionSet::Make(
std::unordered_map<uint32_t, util::string_view> uris,
std::unordered_map<uint32_t, Id> type_ids,
- std::unordered_map<uint32_t, Id> function_ids, ExtensionIdRegistry* registry) {
+ std::unordered_map<uint32_t, Id> function_ids, const ExtensionIdRegistry* registry) {
ExtensionSet set;
set.registry_ = registry;
@@ -204,152 +204,259 @@ const int* GetIndex(const KeyToIndex& key_to_index, const Key& key) {
return &it->second;
}
-ExtensionIdRegistry* default_extension_id_registry() {
- static struct Impl : ExtensionIdRegistry {
- Impl() {
- struct TypeName {
- std::shared_ptr<DataType> type;
- util::string_view name;
- };
-
- // The type (variation) mappings listed below need to be kept in sync
- // with the YAML at substrait/format/extension_types.yaml manually;
- // see ARROW-15535.
- for (TypeName e : {
- TypeName{uint8(), "u8"},
- TypeName{uint16(), "u16"},
- TypeName{uint32(), "u32"},
- TypeName{uint64(), "u64"},
- TypeName{float16(), "fp16"},
- }) {
- DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type)));
- }
-
- for (TypeName e : {
- TypeName{null(), "null"},
- TypeName{month_interval(), "interval_month"},
- TypeName{day_time_interval(), "interval_day_milli"},
- TypeName{month_day_nano_interval(), "interval_month_day_nano"},
- }) {
- DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type)));
- }
-
- // TODO: this is just a placeholder right now. We'll need a YAML file for
- // all functions (and prototypes) that Arrow provides that are relevant
- // for Substrait, and include mappings for all of them here. See
- // ARROW-15535.
- for (util::string_view name : {
- "add",
- "equal",
- "is_not_distinct_from",
- }) {
- DCHECK_OK(RegisterFunction({kArrowExtTypesUri, name}, name.to_string()));
- }
+namespace {
+
+struct ExtensionIdRegistryImpl : ExtensionIdRegistry {
+ virtual ~ExtensionIdRegistryImpl() {}
+
+ std::vector<util::string_view> Uris() const override {
+ return {uris_.begin(), uris_.end()};
+ }
+
+  // Looks up the record of a registered data type; empty if it is not registered.
+  util::optional<TypeRecord> GetType(const DataType& type) const override {
+    const int* position = GetIndex(type_to_index_, &type);
+    if (position == nullptr) {
+      return {};
+    }
+    return TypeRecord{type_ids_[*position], types_[*position]};
+  }
+
+  // Looks up the record of a type by its extension id; empty if it is not registered.
+  util::optional<TypeRecord> GetType(Id id) const override {
+    const int* position = GetIndex(id_to_index_, id);
+    if (position == nullptr) {
+      return {};
+    }
+    return TypeRecord{type_ids_[*position], types_[*position]};
+  }
+
+  // Dry-run of RegisterType: reports Invalid when either the id or the data type
+  // is already present in this registry, without modifying anything.
+  Status CanRegisterType(Id id, const std::shared_ptr<DataType>& type) const override {
+    const bool id_taken = id_to_index_.count(id) != 0;
+    if (id_taken) {
+      return Status::Invalid("Type id was already registered");
+    }
+    const bool type_taken = type_to_index_.count(type.get()) != 0;
+    if (type_taken) {
+      return Status::Invalid("Type was already registered");
+    }
+    return Status::OK();
+  }
+
+ Status RegisterType(Id id, std::shared_ptr<DataType> type) override {
+ DCHECK_EQ(type_ids_.size(), types_.size());
+
+ Id copied_id{*uris_.emplace(id.uri.to_string()).first,
+ *names_.emplace(id.name.to_string()).first};
+
+ auto index = static_cast<int>(type_ids_.size());
+
+ auto it_success = id_to_index_.emplace(copied_id, index);
+
+ if (!it_success.second) {
+ return Status::Invalid("Type id was already registered");
+ }
+
+ if (!type_to_index_.emplace(type.get(), index).second) {
+ id_to_index_.erase(it_success.first);
+ return Status::Invalid("Type was already registered");
}
- std::vector<util::string_view> Uris() const override {
- return {uris_.begin(), uris_.end()};
+ type_ids_.push_back(copied_id);
+ types_.push_back(std::move(type));
+ return Status::OK();
+ }
+
+ util::optional<FunctionRecord> GetFunction(
+ util::string_view arrow_function_name) const override {
+ if (auto index = GetIndex(function_name_to_index_, arrow_function_name)) {
+ return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]};
}
+ return {};
+ }
- util::optional<TypeRecord> GetType(const DataType& type) const override {
- if (auto index = GetIndex(type_to_index_, &type)) {
- return TypeRecord{type_ids_[*index], types_[*index]};
- }
- return {};
+ util::optional<FunctionRecord> GetFunction(Id id) const override {
+ if (auto index = GetIndex(function_id_to_index_, id)) {
+ return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]};
}
+ return {};
+ }
- util::optional<TypeRecord> GetType(Id id) const override {
- if (auto index = GetIndex(id_to_index_, id)) {
- return TypeRecord{type_ids_[*index], types_[*index]};
- }
- return {};
+ Status CanRegisterFunction(Id id,
+ const std::string& arrow_function_name) const override {
+ if (function_id_to_index_.find(id) != function_id_to_index_.end()) {
+ return Status::Invalid("Function id was already registered");
+ }
+ if (function_name_to_index_.find(arrow_function_name) !=
+ function_name_to_index_.end()) {
+ return Status::Invalid("Function name was already registered");
}
+ return Status::OK();
+ }
- Status RegisterType(Id id, std::shared_ptr<DataType> type) override {
- DCHECK_EQ(type_ids_.size(), types_.size());
+ Status RegisterFunction(Id id, std::string arrow_function_name) override {
+ DCHECK_EQ(function_ids_.size(), function_name_ptrs_.size());
- Id copied_id{*uris_.emplace(id.uri.to_string()).first,
- *names_.emplace(id.name.to_string()).first};
+ Id copied_id{*uris_.emplace(id.uri.to_string()).first,
+ *names_.emplace(id.name.to_string()).first};
- auto index = static_cast<int>(type_ids_.size());
+ const std::string& copied_function_name{
+ *function_names_.emplace(std::move(arrow_function_name)).first};
- auto it_success = id_to_index_.emplace(copied_id, index);
+ auto index = static_cast<int>(function_ids_.size());
- if (!it_success.second) {
- return Status::Invalid("Type id was already registered");
- }
+ auto it_success = function_id_to_index_.emplace(copied_id, index);
- if (!type_to_index_.emplace(type.get(), index).second) {
- id_to_index_.erase(it_success.first);
- return Status::Invalid("Type was already registered");
- }
+ if (!it_success.second) {
+ return Status::Invalid("Function id was already registered");
+ }
- type_ids_.push_back(copied_id);
- types_.push_back(std::move(type));
- return Status::OK();
+ if (!function_name_to_index_.emplace(copied_function_name, index).second) {
+ function_id_to_index_.erase(it_success.first);
+ return Status::Invalid("Function name was already registered");
}
- util::optional<FunctionRecord> GetFunction(
- util::string_view arrow_function_name) const override {
- if (auto index = GetIndex(function_name_to_index_, arrow_function_name)) {
- return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]};
- }
- return {};
+ function_name_ptrs_.push_back(&copied_function_name);
+ function_ids_.push_back(copied_id);
+ return Status::OK();
+ }
+
+ // owning storage of uris, names, (arrow::)function_names, types
+ // note that storing strings like this is safe since references into an
+ // unordered_set are not invalidated on insertion
+ std::unordered_set<std::string> uris_, names_, function_names_;
+ DataTypeVector types_;
+
+ // non-owning lookup helpers
+ std::vector<Id> type_ids_, function_ids_;
+ std::unordered_map<Id, int, IdHashEq, IdHashEq> id_to_index_;
+ std::unordered_map<const DataType*, int, TypePtrHashEq, TypePtrHashEq> type_to_index_;
+
+ std::vector<const std::string*> function_name_ptrs_;
+ std::unordered_map<Id, int, IdHashEq, IdHashEq> function_id_to_index_;
+ std::unordered_map<util::string_view, int, ::arrow::internal::StringViewHash>
+ function_name_to_index_;
+};
+
+struct NestedExtensionIdRegistryImpl : ExtensionIdRegistryImpl {
+ explicit NestedExtensionIdRegistryImpl(const ExtensionIdRegistry* parent)
+ : parent_(parent) {}
+
+ virtual ~NestedExtensionIdRegistryImpl() {}
+
+  // Returns the union of the parent's URIs and the URIs registered in this
+  // nested registry, without duplicates. The result is produced from an
+  // unordered set, so its order is unspecified.
+  std::vector<util::string_view> Uris() const override {
+    std::vector<util::string_view> parent_uris = parent_->Uris();
+    std::unordered_set<util::string_view> uri_set;
+    uri_set.insert(parent_uris.begin(), parent_uris.end());
+    // The views inserted here point into uris_, which owns the strings.
+    uri_set.insert(uris_.begin(), uris_.end());
+    // Return the merged set; returning the parent's list alone would drop every
+    // URI that was registered only in this nested registry.
+    return std::vector<util::string_view>(uri_set.begin(), uri_set.end());
+  }
+
+  // Looks the data type up in this registry first and falls back to the parent.
+  util::optional<TypeRecord> GetType(const DataType& type) const override {
+    if (auto local_record = ExtensionIdRegistryImpl::GetType(type)) {
+      return local_record;
+    }
+    return parent_->GetType(type);
+  }
- util::optional<FunctionRecord> GetFunction(Id id) const override {
- if (auto index = GetIndex(function_id_to_index_, id)) {
- return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]};
- }
- return {};
+ util::optional<TypeRecord> GetType(Id id) const override {
+ auto type_opt = ExtensionIdRegistryImpl::GetType(id);
+ if (type_opt) {
+ return type_opt;
}
+ return parent_->GetType(id);
+ }
- Status RegisterFunction(Id id, std::string arrow_function_name) override {
- DCHECK_EQ(function_ids_.size(), function_name_ptrs_.size());
+  // A type can be registered here only if neither the parent nor this registry
+  // already knows its id or data type; the parent's error is reported first.
+  Status CanRegisterType(Id id, const std::shared_ptr<DataType>& type) const override {
+    const Status parent_status = parent_->CanRegisterType(id, type);
+    const Status local_status = ExtensionIdRegistryImpl::CanRegisterType(id, type);
+    return parent_status & local_status;
+  }
- Id copied_id{*uris_.emplace(id.uri.to_string()).first,
- *names_.emplace(id.name.to_string()).first};
+  // Registers the type in this nested registry, provided the parent would accept
+  // it as well. Returns the parent's error if the parent rejects it, otherwise
+  // the result of the local registration.
+  //
+  // The parent check is done first and returned early on failure: combining the
+  // two calls with Status::operator& evaluates both operands, so a registration
+  // rejected by the parent would still have been recorded locally.
+  Status RegisterType(Id id, std::shared_ptr<DataType> type) override {
+    Status parent_status = parent_->CanRegisterType(id, type);
+    if (!parent_status.ok()) {
+      return parent_status;
+    }
+    return ExtensionIdRegistryImpl::RegisterType(id, std::move(type));
+  }
- const std::string& copied_function_name{
- *function_names_.emplace(std::move(arrow_function_name)).first};
+  // Looks the Arrow function name up in this registry first and falls back to
+  // the parent.
+  util::optional<FunctionRecord> GetFunction(
+      util::string_view arrow_function_name) const override {
+    if (auto local_record = ExtensionIdRegistryImpl::GetFunction(arrow_function_name)) {
+      return local_record;
+    }
+    return parent_->GetFunction(arrow_function_name);
+  }
- auto index = static_cast<int>(function_ids_.size());
+  // Looks the function id up in this registry first and falls back to the parent.
+  util::optional<FunctionRecord> GetFunction(Id id) const override {
+    if (auto local_record = ExtensionIdRegistryImpl::GetFunction(id)) {
+      return local_record;
+    }
+    return parent_->GetFunction(id);
+  }
- auto it_success = function_id_to_index_.emplace(copied_id, index);
+  // A function can be registered here only if neither the parent nor this
+  // registry already knows its id or name; the parent's error is reported first.
+  Status CanRegisterFunction(Id id,
+                             const std::string& arrow_function_name) const override {
+    const Status parent_status = parent_->CanRegisterFunction(id, arrow_function_name);
+    const Status local_status =
+        ExtensionIdRegistryImpl::CanRegisterFunction(id, arrow_function_name);
+    return parent_status & local_status;
+  }
- if (!it_success.second) {
- return Status::Invalid("Function id was already registered");
- }
+  // Registers the function in this nested registry, provided the parent would
+  // accept it as well. Returns the parent's error if the parent rejects it,
+  // otherwise the result of the local registration.
+  //
+  // The parent check is done first and returned early on failure: combining the
+  // two calls with Status::operator& evaluates both operands, so a registration
+  // rejected by the parent would still have been recorded locally.
+  Status RegisterFunction(Id id, std::string arrow_function_name) override {
+    Status parent_status = parent_->CanRegisterFunction(id, arrow_function_name);
+    if (!parent_status.ok()) {
+      return parent_status;
+    }
+    return ExtensionIdRegistryImpl::RegisterFunction(id, std::move(arrow_function_name));
+  }
- if (!function_name_to_index_.emplace(copied_function_name, index).second) {
- function_id_to_index_.erase(it_success.first);
- return Status::Invalid("Function name was already registered");
- }
+ const ExtensionIdRegistry* parent_;
+};
- function_name_ptrs_.push_back(&copied_function_name);
- function_ids_.push_back(copied_id);
- return Status::OK();
+struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl {
+ DefaultExtensionIdRegistry() {
+ struct TypeName {
+ std::shared_ptr<DataType> type;
+ util::string_view name;
+ };
+
+ // The type (variation) mappings listed below need to be kept in sync
+ // with the YAML at substrait/format/extension_types.yaml manually;
+ // see ARROW-15535.
+ for (TypeName e : {
+ TypeName{uint8(), "u8"},
+ TypeName{uint16(), "u16"},
+ TypeName{uint32(), "u32"},
+ TypeName{uint64(), "u64"},
+ TypeName{float16(), "fp16"},
+ }) {
+ DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type)));
}
- // owning storage of uris, names, (arrow::)function_names, types
- // note that storing strings like this is safe since references into an
- // unordered_set are not invalidated on insertion
- std::unordered_set<std::string> uris_, names_, function_names_;
- DataTypeVector types_;
+ for (TypeName e : {
+ TypeName{null(), "null"},
+ TypeName{month_interval(), "interval_month"},
+ TypeName{day_time_interval(), "interval_day_milli"},
+ TypeName{month_day_nano_interval(), "interval_month_day_nano"},
+ }) {
+ DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type)));
+ }
- // non-owning lookup helpers
- std::vector<Id> type_ids_, function_ids_;
- std::unordered_map<Id, int, IdHashEq, IdHashEq> id_to_index_;
- std::unordered_map<const DataType*, int, TypePtrHashEq, TypePtrHashEq> type_to_index_;
+ // TODO: this is just a placeholder right now. We'll need a YAML file for
+ // all functions (and prototypes) that Arrow provides that are relevant
+ // for Substrait, and include mappings for all of them here. See
+ // ARROW-15535.
+ for (util::string_view name : {
+ "add",
+ "equal",
+ "is_not_distinct_from",
+ }) {
+ DCHECK_OK(RegisterFunction({kArrowExtTypesUri, name}, name.to_string()));
+ }
+ }
+};
- std::vector<const std::string*> function_name_ptrs_;
- std::unordered_map<Id, int, IdHashEq, IdHashEq> function_id_to_index_;
- std::unordered_map<util::string_view, int, ::arrow::internal::StringViewHash>
- function_name_to_index_;
- } impl_;
+} // namespace
+ExtensionIdRegistry* default_extension_id_registry() {
+ static DefaultExtensionIdRegistry impl_;
return &impl_;
}
+// Creates a registry layered on top of `parent`. Only the raw `parent` pointer
+// is stored (no ownership is taken), so the parent has to stay alive for as
+// long as the returned registry is used.
+std::shared_ptr<ExtensionIdRegistry> nested_extension_id_registry(
+ const ExtensionIdRegistry* parent) {
+ return std::make_shared<NestedExtensionIdRegistryImpl>(parent);
+}
+
} // namespace engine
} // namespace arrow
diff --git a/cpp/src/arrow/engine/substrait/extension_set.h b/cpp/src/arrow/engine/substrait/extension_set.h
index 55ea4d0232..de013015a7 100644
--- a/cpp/src/arrow/engine/substrait/extension_set.h
+++ b/cpp/src/arrow/engine/substrait/extension_set.h
@@ -70,6 +70,7 @@ class ARROW_ENGINE_EXPORT ExtensionIdRegistry {
};
virtual util::optional<TypeRecord> GetType(const DataType&) const = 0;
virtual util::optional<TypeRecord> GetType(Id) const = 0;
+ virtual Status CanRegisterType(Id, const std::shared_ptr<DataType>& type) const = 0;
virtual Status RegisterType(Id, std::shared_ptr<DataType>) = 0;
/// \brief A mapping between a Substrait ID and an Arrow function
@@ -91,6 +92,8 @@ class ARROW_ENGINE_EXPORT ExtensionIdRegistry {
virtual util::optional<FunctionRecord> GetFunction(Id) const = 0;
virtual util::optional<FunctionRecord> GetFunction(
util::string_view arrow_function_name) const = 0;
+ virtual Status CanRegisterFunction(Id,
+ const std::string& arrow_function_name) const = 0;
virtual Status RegisterFunction(Id, std::string arrow_function_name) = 0;
};
@@ -103,6 +106,19 @@ constexpr util::string_view kArrowExtTypesUri =
/// Note: Function support is currently very minimal, see ARROW-15538
ARROW_ENGINE_EXPORT ExtensionIdRegistry* default_extension_id_registry();
+/// \brief Make a nested registry with a given parent.
+///
+/// A nested registry supports registering types and functions in addition to those
+/// already registered in its parent registry. IDs and names used for lookup must not
+/// conflict with those of the parent. Normally, the given parent is the default
+/// registry.
+///
+/// The nested registry keeps a non-owning pointer to its parent, so the parent must
+/// outlive the nested registry.
+///
+/// One use case for a nested registry is for dynamic registration of functions defined
+/// within a Substrait plan while keeping these registrations specific to the plan. When
+/// the Substrait plan is disposed of, normally after its execution, the nested registry
+/// can be disposed of as well.
+ARROW_ENGINE_EXPORT std::shared_ptr<ExtensionIdRegistry> nested_extension_id_registry(
+ const ExtensionIdRegistry* parent);
+
/// \brief A set of extensions used within a plan
///
/// Each time an extension is used within a Substrait plan the extension
@@ -147,7 +163,7 @@ class ARROW_ENGINE_EXPORT ExtensionSet {
};
/// Construct an empty ExtensionSet to be populated during serialization.
- explicit ExtensionSet(ExtensionIdRegistry* = default_extension_id_registry());
+ explicit ExtensionSet(const ExtensionIdRegistry* = default_extension_id_registry());
ARROW_DEFAULT_MOVE_AND_ASSIGN(ExtensionSet);
/// Construct an ExtensionSet with explicit extension ids for efficient referencing
@@ -168,7 +184,7 @@ class ARROW_ENGINE_EXPORT ExtensionSet {
std::unordered_map<uint32_t, util::string_view> uris,
std::unordered_map<uint32_t, Id> type_ids,
std::unordered_map<uint32_t, Id> function_ids,
- ExtensionIdRegistry* = default_extension_id_registry());
+ const ExtensionIdRegistry* = default_extension_id_registry());
const std::unordered_map<uint32_t, util::string_view>& uris() const { return uris_; }
@@ -229,7 +245,7 @@ class ARROW_ENGINE_EXPORT ExtensionSet {
std::size_t num_functions() const { return functions_.size(); }
private:
- ExtensionIdRegistry* registry_;
+ const ExtensionIdRegistry* registry_;
// Map from anchor values to URI values referenced by this extension set
std::unordered_map<uint32_t, util::string_view> uris_;
diff --git a/cpp/src/arrow/engine/substrait/plan_internal.cc b/cpp/src/arrow/engine/substrait/plan_internal.cc
index cc20de6da6..fcee0b2188 100644
--- a/cpp/src/arrow/engine/substrait/plan_internal.cc
+++ b/cpp/src/arrow/engine/substrait/plan_internal.cc
@@ -91,7 +91,7 @@ Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan)
}
Result<ExtensionSet> GetExtensionSetFromPlan(const substrait::Plan& plan,
- ExtensionIdRegistry* registry) {
+ const ExtensionIdRegistry* registry) {
std::unordered_map<uint32_t, util::string_view> uris;
uris.reserve(plan.extension_uris_size());
for (const auto& uri : plan.extension_uris()) {
diff --git a/cpp/src/arrow/engine/substrait/plan_internal.h b/cpp/src/arrow/engine/substrait/plan_internal.h
index 281cab0c0f..dce23cdceb 100644
--- a/cpp/src/arrow/engine/substrait/plan_internal.h
+++ b/cpp/src/arrow/engine/substrait/plan_internal.h
@@ -49,7 +49,7 @@ Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan)
ARROW_ENGINE_EXPORT
Result<ExtensionSet> GetExtensionSetFromPlan(
const substrait::Plan& plan,
- ExtensionIdRegistry* registry = default_extension_id_registry());
+ const ExtensionIdRegistry* registry = default_extension_id_registry());
} // namespace engine
} // namespace arrow
diff --git a/cpp/src/arrow/engine/substrait/util.cc b/cpp/src/arrow/engine/substrait/util.cc
index bc2aa36856..2ae3771f3f 100644
--- a/cpp/src/arrow/engine/substrait/util.cc
+++ b/cpp/src/arrow/engine/substrait/util.cc
@@ -123,6 +123,10 @@ Result<std::shared_ptr<Buffer>> SerializeJsonPlan(const std::string& substrait_j
return engine::internal::SubstraitFromJSON("Plan", substrait_json);
}
+// Convenience factory: a new nested registry whose parent is the default registry.
+std::shared_ptr<ExtensionIdRegistry> MakeExtensionIdRegistry() {
+  ExtensionIdRegistry* const parent = default_extension_id_registry();
+  return nested_extension_id_registry(parent);
+}
+
} // namespace substrait
} // namespace engine
diff --git a/cpp/src/arrow/engine/substrait/util.h b/cpp/src/arrow/engine/substrait/util.h
index 860a459da2..f3e7d0fe73 100644
--- a/cpp/src/arrow/engine/substrait/util.h
+++ b/cpp/src/arrow/engine/substrait/util.h
@@ -37,6 +37,10 @@ ARROW_ENGINE_EXPORT Result<std::shared_ptr<RecordBatchReader>> ExecuteSerialized
ARROW_ENGINE_EXPORT Result<std::shared_ptr<Buffer>> SerializeJsonPlan(
const std::string& substrait_json);
+/// \brief Make a nested registry with the default registry as parent.
+/// See arrow::engine::nested_extension_id_registry for details.
+ARROW_ENGINE_EXPORT std::shared_ptr<ExtensionIdRegistry> MakeExtensionIdRegistry();
+
} // namespace substrait
} // namespace engine