You are viewing a plain-text version of this content; the link to the canonical version is not included in this plain-text copy.
Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2022/06/02 15:23:58 UTC

[GitHub] [arrow] westonpace commented on a diff in pull request #13232: ARROW-16657: [C++] Support nesting of extension-id-registries

westonpace commented on code in PR #13232:
URL: https://github.com/apache/arrow/pull/13232#discussion_r888059559


##########
cpp/src/arrow/engine/substrait/extension_set.h:
##########
@@ -103,6 +106,19 @@ constexpr util::string_view kArrowExtTypesUri =
 /// Note: Function support is currently very minimal, see ARROW-15538
 ARROW_ENGINE_EXPORT ExtensionIdRegistry* default_extension_id_registry();
 
+/// \brief Makes a nested registry with a given parent.

Review Comment:
   ```suggestion
   /// \brief Make a nested registry with a given parent.
   ```



##########
cpp/src/arrow/engine/substrait/util.h:
##########
@@ -37,6 +37,10 @@ ARROW_ENGINE_EXPORT Result<std::shared_ptr<RecordBatchReader>> ExecuteSerialized
 ARROW_ENGINE_EXPORT Result<std::shared_ptr<Buffer>> SerializeJsonPlan(
     const std::string& substrait_json);
 
+/// \brief Makes a nested registry with the default registry as parent.

Review Comment:
   ```suggestion
   /// \brief Make a nested registry with the default registry as parent.
   ```



##########
cpp/src/arrow/engine/substrait/ext_test.cc:
##########
@@ -0,0 +1,263 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/engine/substrait/extension_set.h"
+#include "arrow/engine/substrait/util.h"
+
+#include <google/protobuf/descriptor.h>
+#include <google/protobuf/util/json_util.h>
+#include <google/protobuf/util/type_resolver_util.h>
+#include <gtest/gtest.h>
+
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/matchers.h"
+
+using testing::ElementsAre;
+using testing::Eq;
+using testing::HasSubstr;
+using testing::UnorderedElementsAre;
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace engine {
+
+// an extension-id-registry provider to be used as a test parameter
+//
+// we cannot pass a pointer to a nested registry as a test parameter because the
+// shared_ptr in which it is made would not be held and get destructed too early,
+// nor can we pass a shared_ptr to the default nested registry as a test parameter
+// because it is global and must never be cleaned up, so we pass a shared_ptr to a
+// provider that either owns or does not own the registry it provides, depending
+// on the case.
+struct ExtensionIdRegistryProvider {
+  virtual ExtensionIdRegistry* get() const = 0;
+};
+
+struct DefaultExtensionIdRegistryProvider : public ExtensionIdRegistryProvider {
+  virtual ~DefaultExtensionIdRegistryProvider() {}
+  ExtensionIdRegistry* get() const override { return default_extension_id_registry(); }
+};
+
+struct NestedExtensionIdRegistryProvider : public ExtensionIdRegistryProvider {
+  virtual ~NestedExtensionIdRegistryProvider() {}
+  std::shared_ptr<ExtensionIdRegistry> registry_ = substrait::MakeExtensionIdRegistry();
+  ExtensionIdRegistry* get() const override { return &*registry_; }
+};
+
+using Id = ExtensionIdRegistry::Id;
+
+bool operator==(const Id& id1, const Id& id2) {
+  return id1.uri == id2.uri && id1.name == id2.name;
+}
+
+bool operator!=(const Id& id1, const Id& id2) { return !(id1 == id2); }
+
+struct TypeName {
+  std::shared_ptr<DataType> type;
+  util::string_view name;
+};
+
+static const std::vector<TypeName> kTypeNames = {
+    TypeName{uint8(), "u8"},
+    TypeName{uint16(), "u16"},
+    TypeName{uint32(), "u32"},
+    TypeName{uint64(), "u64"},
+    TypeName{float16(), "fp16"},
+    TypeName{null(), "null"},
+    TypeName{month_interval(), "interval_month"},
+    TypeName{day_time_interval(), "interval_day_milli"},
+    TypeName{month_day_nano_interval(), "interval_month_day_nano"},
+};
+
+static const std::vector<util::string_view> kFunctionNames = {
+    "add",
+};
+
+static const std::vector<util::string_view> kTempFunctionNames = {
+    "temp_func_1",
+    "temp_func_2",
+};
+
+static const std::vector<TypeName> kTempTypeNames = {
+    TypeName{timestamp(TimeUnit::SECOND, "temp_tz_1"), "temp_type_1"},
+    TypeName{timestamp(TimeUnit::SECOND, "temp_tz_2"), "temp_type_2"},
+};
+
+using ExtensionIdRegistryParams =
+    std::tuple<std::shared_ptr<ExtensionIdRegistryProvider>, std::string>;
+
+struct ExtensionIdRegistryTest
+    : public testing::TestWithParam<ExtensionIdRegistryParams> {};
+
+TEST_P(ExtensionIdRegistryTest, GetTypes) {
+  auto provider = std::get<0>(GetParam());
+  auto registry = provider->get();
+
+  for (TypeName e : kTypeNames) {
+    auto id = Id{kArrowExtTypesUri, e.name};
+    for (auto typerec_opt : {registry->GetType(id), registry->GetType(*e.type)}) {
+      ASSERT_TRUE(typerec_opt);
+      auto typerec = typerec_opt.value();
+      ASSERT_EQ(id, typerec.id);
+      ASSERT_EQ(*e.type, *typerec.type);
+    }
+  }
+}
+
+TEST_P(ExtensionIdRegistryTest, ReregisterTypes) {
+  auto provider = std::get<0>(GetParam());
+  auto registry = provider->get();
+
+  for (TypeName e : kTypeNames) {
+    auto id = Id{kArrowExtTypesUri, e.name};
+    ASSERT_RAISES(Invalid, registry->CanRegisterType(id, e.type));
+    ASSERT_RAISES(Invalid, registry->RegisterType(id, e.type));
+  }
+}
+
+TEST_P(ExtensionIdRegistryTest, GetFunctions) {
+  auto provider = std::get<0>(GetParam());
+  auto registry = provider->get();
+
+  for (util::string_view name : kFunctionNames) {
+    auto id = Id{kArrowExtTypesUri, name};
+    for (auto funcrec_opt : {registry->GetFunction(id), registry->GetFunction(name)}) {
+      ASSERT_TRUE(funcrec_opt);
+      auto funcrec = funcrec_opt.value();
+      ASSERT_EQ(id, funcrec.id);
+      ASSERT_EQ(name, funcrec.function_name);
+    }
+  }

Review Comment:
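   Can you add failed-lookup checks here too, i.e. assert that `GetFunction` returns an empty optional for both an unregistered id and an unregistered function name?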



##########
cpp/src/arrow/engine/substrait/ext_test.cc:
##########
@@ -0,0 +1,263 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/engine/substrait/extension_set.h"
+#include "arrow/engine/substrait/util.h"
+
+#include <google/protobuf/descriptor.h>
+#include <google/protobuf/util/json_util.h>
+#include <google/protobuf/util/type_resolver_util.h>
+#include <gtest/gtest.h>
+
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/matchers.h"
+
+using testing::ElementsAre;
+using testing::Eq;
+using testing::HasSubstr;
+using testing::UnorderedElementsAre;
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace engine {
+
+// an extension-id-registry provider to be used as a test parameter
+//
+// we cannot pass a pointer to a nested registry as a test parameter because the
+// shared_ptr in which it is made would not be held and get destructed too early,
+// nor can we pass a shared_ptr to the default nested registry as a test parameter
+// because it is global and must never be cleaned up, so we pass a shared_ptr to a
+// provider that either owns or does not own the registry it provides, depending
+// on the case.
+struct ExtensionIdRegistryProvider {
+  virtual ExtensionIdRegistry* get() const = 0;
+};
+
+struct DefaultExtensionIdRegistryProvider : public ExtensionIdRegistryProvider {
+  virtual ~DefaultExtensionIdRegistryProvider() {}
+  ExtensionIdRegistry* get() const override { return default_extension_id_registry(); }
+};
+
+struct NestedExtensionIdRegistryProvider : public ExtensionIdRegistryProvider {
+  virtual ~NestedExtensionIdRegistryProvider() {}
+  std::shared_ptr<ExtensionIdRegistry> registry_ = substrait::MakeExtensionIdRegistry();
+  ExtensionIdRegistry* get() const override { return &*registry_; }
+};
+
+using Id = ExtensionIdRegistry::Id;
+
+bool operator==(const Id& id1, const Id& id2) {
+  return id1.uri == id2.uri && id1.name == id2.name;
+}
+
+bool operator!=(const Id& id1, const Id& id2) { return !(id1 == id2); }
+
+struct TypeName {
+  std::shared_ptr<DataType> type;
+  util::string_view name;
+};
+
+static const std::vector<TypeName> kTypeNames = {
+    TypeName{uint8(), "u8"},
+    TypeName{uint16(), "u16"},
+    TypeName{uint32(), "u32"},
+    TypeName{uint64(), "u64"},
+    TypeName{float16(), "fp16"},
+    TypeName{null(), "null"},
+    TypeName{month_interval(), "interval_month"},
+    TypeName{day_time_interval(), "interval_day_milli"},
+    TypeName{month_day_nano_interval(), "interval_month_day_nano"},
+};
+
+static const std::vector<util::string_view> kFunctionNames = {
+    "add",
+};
+
+static const std::vector<util::string_view> kTempFunctionNames = {
+    "temp_func_1",
+    "temp_func_2",
+};
+
+static const std::vector<TypeName> kTempTypeNames = {
+    TypeName{timestamp(TimeUnit::SECOND, "temp_tz_1"), "temp_type_1"},
+    TypeName{timestamp(TimeUnit::SECOND, "temp_tz_2"), "temp_type_2"},
+};
+
+using ExtensionIdRegistryParams =
+    std::tuple<std::shared_ptr<ExtensionIdRegistryProvider>, std::string>;
+
+struct ExtensionIdRegistryTest
+    : public testing::TestWithParam<ExtensionIdRegistryParams> {};
+
+TEST_P(ExtensionIdRegistryTest, GetTypes) {
+  auto provider = std::get<0>(GetParam());
+  auto registry = provider->get();
+
+  for (TypeName e : kTypeNames) {
+    auto id = Id{kArrowExtTypesUri, e.name};
+    for (auto typerec_opt : {registry->GetType(id), registry->GetType(*e.type)}) {
+      ASSERT_TRUE(typerec_opt);
+      auto typerec = typerec_opt.value();
+      ASSERT_EQ(id, typerec.id);
+      ASSERT_EQ(*e.type, *typerec.type);
+    }
+  }
+}

Review Comment:
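   Can you add failed-lookup checks here, e.g. (where `kNonExistentId` and `kNonExistentType` stand for an id and a type that are not registered):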
   
   ```
   ASSERT_FALSE(registry->GetType(kNonExistentId));
   ASSERT_FALSE(registry->GetType(kNonExistentType));
   ```



##########
cpp/src/arrow/engine/substrait/extension_set.cc:
##########
@@ -204,152 +204,259 @@ const int* GetIndex(const KeyToIndex& key_to_index, const Key& key) {
   return &it->second;
 }
 
-ExtensionIdRegistry* default_extension_id_registry() {
-  static struct Impl : ExtensionIdRegistry {
-    Impl() {
-      struct TypeName {
-        std::shared_ptr<DataType> type;
-        util::string_view name;
-      };
-
-      // The type (variation) mappings listed below need to be kept in sync
-      // with the YAML at substrait/format/extension_types.yaml manually;
-      // see ARROW-15535.
-      for (TypeName e : {
-               TypeName{uint8(), "u8"},
-               TypeName{uint16(), "u16"},
-               TypeName{uint32(), "u32"},
-               TypeName{uint64(), "u64"},
-               TypeName{float16(), "fp16"},
-           }) {
-        DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type)));
-      }
-
-      for (TypeName e : {
-               TypeName{null(), "null"},
-               TypeName{month_interval(), "interval_month"},
-               TypeName{day_time_interval(), "interval_day_milli"},
-               TypeName{month_day_nano_interval(), "interval_month_day_nano"},
-           }) {
-        DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type)));
-      }
-
-      // TODO: this is just a placeholder right now. We'll need a YAML file for
-      // all functions (and prototypes) that Arrow provides that are relevant
-      // for Substrait, and include mappings for all of them here. See
-      // ARROW-15535.
-      for (util::string_view name : {
-               "add",
-               "equal",
-               "is_not_distinct_from",
-           }) {
-        DCHECK_OK(RegisterFunction({kArrowExtTypesUri, name}, name.to_string()));
-      }
+namespace {
+
+struct ExtensionIdRegistryImpl : ExtensionIdRegistry {
+  virtual ~ExtensionIdRegistryImpl() {}
+
+  std::vector<util::string_view> Uris() const override {
+    return {uris_.begin(), uris_.end()};
+  }
+
+  util::optional<TypeRecord> GetType(const DataType& type) const override {
+    if (auto index = GetIndex(type_to_index_, &type)) {
+      return TypeRecord{type_ids_[*index], types_[*index]};
+    }
+    return {};
+  }
+
+  util::optional<TypeRecord> GetType(Id id) const override {
+    if (auto index = GetIndex(id_to_index_, id)) {
+      return TypeRecord{type_ids_[*index], types_[*index]};
+    }
+    return {};
+  }
+
+  Status CanRegisterType(Id id, const std::shared_ptr<DataType>& type) const override {
+    if (id_to_index_.find(id) != id_to_index_.end()) {
+      return Status::Invalid("Type id was already registered");
+    }
+    if (type_to_index_.find(&*type) != type_to_index_.end()) {
+      return Status::Invalid("Type was already registered");
+    }
+    return Status::OK();
+  }
+
+  Status RegisterType(Id id, std::shared_ptr<DataType> type) override {
+    DCHECK_EQ(type_ids_.size(), types_.size());
+
+    Id copied_id{*uris_.emplace(id.uri.to_string()).first,
+                 *names_.emplace(id.name.to_string()).first};
+
+    auto index = static_cast<int>(type_ids_.size());
+
+    auto it_success = id_to_index_.emplace(copied_id, index);
+
+    if (!it_success.second) {
+      return Status::Invalid("Type id was already registered");
+    }
+
+    if (!type_to_index_.emplace(type.get(), index).second) {
+      id_to_index_.erase(it_success.first);
+      return Status::Invalid("Type was already registered");
     }
 
-    std::vector<util::string_view> Uris() const override {
-      return {uris_.begin(), uris_.end()};
+    type_ids_.push_back(copied_id);
+    types_.push_back(std::move(type));
+    return Status::OK();
+  }
+
+  util::optional<FunctionRecord> GetFunction(
+      util::string_view arrow_function_name) const override {
+    if (auto index = GetIndex(function_name_to_index_, arrow_function_name)) {
+      return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]};
     }
+    return {};
+  }
 
-    util::optional<TypeRecord> GetType(const DataType& type) const override {
-      if (auto index = GetIndex(type_to_index_, &type)) {
-        return TypeRecord{type_ids_[*index], types_[*index]};
-      }
-      return {};
+  util::optional<FunctionRecord> GetFunction(Id id) const override {
+    if (auto index = GetIndex(function_id_to_index_, id)) {
+      return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]};
     }
+    return {};
+  }
 
-    util::optional<TypeRecord> GetType(Id id) const override {
-      if (auto index = GetIndex(id_to_index_, id)) {
-        return TypeRecord{type_ids_[*index], types_[*index]};
-      }
-      return {};
+  Status CanRegisterFunction(Id id,
+                             const std::string& arrow_function_name) const override {
+    if (function_id_to_index_.find(id) != function_id_to_index_.end()) {
+      return Status::Invalid("Function id was already registered");
+    }
+    if (function_name_to_index_.find(arrow_function_name) !=
+        function_name_to_index_.end()) {
+      return Status::Invalid("Function name was already registered");
     }
+    return Status::OK();
+  }
 
-    Status RegisterType(Id id, std::shared_ptr<DataType> type) override {
-      DCHECK_EQ(type_ids_.size(), types_.size());
+  Status RegisterFunction(Id id, std::string arrow_function_name) override {
+    DCHECK_EQ(function_ids_.size(), function_name_ptrs_.size());
 
-      Id copied_id{*uris_.emplace(id.uri.to_string()).first,
-                   *names_.emplace(id.name.to_string()).first};
+    Id copied_id{*uris_.emplace(id.uri.to_string()).first,
+                 *names_.emplace(id.name.to_string()).first};
 
-      auto index = static_cast<int>(type_ids_.size());
+    const std::string& copied_function_name{
+        *function_names_.emplace(std::move(arrow_function_name)).first};
 
-      auto it_success = id_to_index_.emplace(copied_id, index);
+    auto index = static_cast<int>(function_ids_.size());
 
-      if (!it_success.second) {
-        return Status::Invalid("Type id was already registered");
-      }
+    auto it_success = function_id_to_index_.emplace(copied_id, index);
 
-      if (!type_to_index_.emplace(type.get(), index).second) {
-        id_to_index_.erase(it_success.first);
-        return Status::Invalid("Type was already registered");
-      }
+    if (!it_success.second) {
+      return Status::Invalid("Function id was already registered");
+    }
 
-      type_ids_.push_back(copied_id);
-      types_.push_back(std::move(type));
-      return Status::OK();
+    if (!function_name_to_index_.emplace(copied_function_name, index).second) {
+      function_id_to_index_.erase(it_success.first);
+      return Status::Invalid("Function name was already registered");
     }
 
-    util::optional<FunctionRecord> GetFunction(
-        util::string_view arrow_function_name) const override {
-      if (auto index = GetIndex(function_name_to_index_, arrow_function_name)) {
-        return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]};
-      }
-      return {};
+    function_name_ptrs_.push_back(&copied_function_name);
+    function_ids_.push_back(copied_id);
+    return Status::OK();
+  }
+
+  // owning storage of uris, names, (arrow::)function_names, types
+  //    note that storing strings like this is safe since references into an
+  //    unordered_set are not invalidated on insertion
+  std::unordered_set<std::string> uris_, names_, function_names_;
+  DataTypeVector types_;
+
+  // non-owning lookup helpers
+  std::vector<Id> type_ids_, function_ids_;
+  std::unordered_map<Id, int, IdHashEq, IdHashEq> id_to_index_;
+  std::unordered_map<const DataType*, int, TypePtrHashEq, TypePtrHashEq> type_to_index_;
+
+  std::vector<const std::string*> function_name_ptrs_;
+  std::unordered_map<Id, int, IdHashEq, IdHashEq> function_id_to_index_;
+  std::unordered_map<util::string_view, int, ::arrow::internal::StringViewHash>
+      function_name_to_index_;
+};
+
+struct NestedExtensionIdRegistryImpl : ExtensionIdRegistryImpl {
+  explicit NestedExtensionIdRegistryImpl(const ExtensionIdRegistry* parent)
+      : parent_(parent) {}
+
+  virtual ~NestedExtensionIdRegistryImpl() {}
+
+  std::vector<util::string_view> Uris() const override {
+    std::vector<util::string_view> uris = parent_->Uris();
+    std::unordered_set<util::string_view> uri_set;
+    uri_set.insert(uris.begin(), uris.end());
+    uri_set.insert(uris_.begin(), uris_.end());
+    return std::vector<util::string_view>(uris);
+  }
+
+  util::optional<TypeRecord> GetType(const DataType& type) const override {
+    auto type_opt = ExtensionIdRegistryImpl::GetType(type);
+    if (type_opt) {
+      return type_opt;
     }
+    return parent_->GetType(type);
+  }
 
-    util::optional<FunctionRecord> GetFunction(Id id) const override {
-      if (auto index = GetIndex(function_id_to_index_, id)) {
-        return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]};
-      }
-      return {};
+  util::optional<TypeRecord> GetType(Id id) const override {
+    auto type_opt = ExtensionIdRegistryImpl::GetType(id);
+    if (type_opt) {
+      return type_opt;
     }
+    return parent_->GetType(id);
+  }
 
-    Status RegisterFunction(Id id, std::string arrow_function_name) override {
-      DCHECK_EQ(function_ids_.size(), function_name_ptrs_.size());
+  Status CanRegisterType(Id id, const std::shared_ptr<DataType>& type) const override {
+    return parent_->CanRegisterType(id, type) &
+           ExtensionIdRegistryImpl::CanRegisterType(id, type);
+  }
 
-      Id copied_id{*uris_.emplace(id.uri.to_string()).first,
-                   *names_.emplace(id.name.to_string()).first};
+  Status RegisterType(Id id, std::shared_ptr<DataType> type) override {
+    return parent_->CanRegisterType(id, type) &
+           ExtensionIdRegistryImpl::RegisterType(id, type);

Review Comment:
   ```suggestion
              ExtensionIdRegistryImpl::RegisterType(id, std::move(type));
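   ```

   Since `type` is taken by value, it can be moved into the base-class call. Note, however, that the two operands of `Status::operator&` are evaluated in an unspecified order. With this change, `type` may therefore already have been moved from when `parent_->CanRegisterType(id, type)` reads it. `&` also does not short-circuit, so the local registration happens even when the parent rejects the id or type. Checking the parent first avoids both problems, e.g. `ARROW_RETURN_NOT_OK(parent_->CanRegisterType(id, type)); return ExtensionIdRegistryImpl::RegisterType(id, std::move(type));`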


##########
cpp/src/arrow/engine/substrait/extension_set.cc:
##########
@@ -204,152 +204,259 @@ const int* GetIndex(const KeyToIndex& key_to_index, const Key& key) {
   return &it->second;
 }
 
-ExtensionIdRegistry* default_extension_id_registry() {
-  static struct Impl : ExtensionIdRegistry {
-    Impl() {
-      struct TypeName {
-        std::shared_ptr<DataType> type;
-        util::string_view name;
-      };
-
-      // The type (variation) mappings listed below need to be kept in sync
-      // with the YAML at substrait/format/extension_types.yaml manually;
-      // see ARROW-15535.
-      for (TypeName e : {
-               TypeName{uint8(), "u8"},
-               TypeName{uint16(), "u16"},
-               TypeName{uint32(), "u32"},
-               TypeName{uint64(), "u64"},
-               TypeName{float16(), "fp16"},
-           }) {
-        DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type)));
-      }
-
-      for (TypeName e : {
-               TypeName{null(), "null"},
-               TypeName{month_interval(), "interval_month"},
-               TypeName{day_time_interval(), "interval_day_milli"},
-               TypeName{month_day_nano_interval(), "interval_month_day_nano"},
-           }) {
-        DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type)));
-      }
-
-      // TODO: this is just a placeholder right now. We'll need a YAML file for
-      // all functions (and prototypes) that Arrow provides that are relevant
-      // for Substrait, and include mappings for all of them here. See
-      // ARROW-15535.
-      for (util::string_view name : {
-               "add",
-               "equal",
-               "is_not_distinct_from",
-           }) {
-        DCHECK_OK(RegisterFunction({kArrowExtTypesUri, name}, name.to_string()));
-      }
+namespace {
+
+struct ExtensionIdRegistryImpl : ExtensionIdRegistry {
+  virtual ~ExtensionIdRegistryImpl() {}
+
+  std::vector<util::string_view> Uris() const override {
+    return {uris_.begin(), uris_.end()};
+  }
+
+  util::optional<TypeRecord> GetType(const DataType& type) const override {
+    if (auto index = GetIndex(type_to_index_, &type)) {
+      return TypeRecord{type_ids_[*index], types_[*index]};
+    }
+    return {};
+  }
+
+  util::optional<TypeRecord> GetType(Id id) const override {
+    if (auto index = GetIndex(id_to_index_, id)) {
+      return TypeRecord{type_ids_[*index], types_[*index]};
+    }
+    return {};
+  }
+
+  Status CanRegisterType(Id id, const std::shared_ptr<DataType>& type) const override {
+    if (id_to_index_.find(id) != id_to_index_.end()) {
+      return Status::Invalid("Type id was already registered");
+    }
+    if (type_to_index_.find(&*type) != type_to_index_.end()) {
+      return Status::Invalid("Type was already registered");
+    }
+    return Status::OK();
+  }
+
+  Status RegisterType(Id id, std::shared_ptr<DataType> type) override {
+    DCHECK_EQ(type_ids_.size(), types_.size());
+
+    Id copied_id{*uris_.emplace(id.uri.to_string()).first,
+                 *names_.emplace(id.name.to_string()).first};
+
+    auto index = static_cast<int>(type_ids_.size());
+
+    auto it_success = id_to_index_.emplace(copied_id, index);
+
+    if (!it_success.second) {
+      return Status::Invalid("Type id was already registered");
+    }
+
+    if (!type_to_index_.emplace(type.get(), index).second) {
+      id_to_index_.erase(it_success.first);
+      return Status::Invalid("Type was already registered");
     }
 
-    std::vector<util::string_view> Uris() const override {
-      return {uris_.begin(), uris_.end()};
+    type_ids_.push_back(copied_id);
+    types_.push_back(std::move(type));
+    return Status::OK();
+  }
+
+  util::optional<FunctionRecord> GetFunction(
+      util::string_view arrow_function_name) const override {
+    if (auto index = GetIndex(function_name_to_index_, arrow_function_name)) {
+      return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]};
     }
+    return {};
+  }
 
-    util::optional<TypeRecord> GetType(const DataType& type) const override {
-      if (auto index = GetIndex(type_to_index_, &type)) {
-        return TypeRecord{type_ids_[*index], types_[*index]};
-      }
-      return {};
+  util::optional<FunctionRecord> GetFunction(Id id) const override {
+    if (auto index = GetIndex(function_id_to_index_, id)) {
+      return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]};
     }
+    return {};
+  }
 
-    util::optional<TypeRecord> GetType(Id id) const override {
-      if (auto index = GetIndex(id_to_index_, id)) {
-        return TypeRecord{type_ids_[*index], types_[*index]};
-      }
-      return {};
+  Status CanRegisterFunction(Id id,
+                             const std::string& arrow_function_name) const override {
+    if (function_id_to_index_.find(id) != function_id_to_index_.end()) {
+      return Status::Invalid("Function id was already registered");
+    }
+    if (function_name_to_index_.find(arrow_function_name) !=
+        function_name_to_index_.end()) {
+      return Status::Invalid("Function name was already registered");
     }
+    return Status::OK();
+  }
 
-    Status RegisterType(Id id, std::shared_ptr<DataType> type) override {
-      DCHECK_EQ(type_ids_.size(), types_.size());
+  Status RegisterFunction(Id id, std::string arrow_function_name) override {
+    DCHECK_EQ(function_ids_.size(), function_name_ptrs_.size());
 
-      Id copied_id{*uris_.emplace(id.uri.to_string()).first,
-                   *names_.emplace(id.name.to_string()).first};
+    Id copied_id{*uris_.emplace(id.uri.to_string()).first,
+                 *names_.emplace(id.name.to_string()).first};
 
-      auto index = static_cast<int>(type_ids_.size());
+    const std::string& copied_function_name{
+        *function_names_.emplace(std::move(arrow_function_name)).first};
 
-      auto it_success = id_to_index_.emplace(copied_id, index);
+    auto index = static_cast<int>(function_ids_.size());
 
-      if (!it_success.second) {
-        return Status::Invalid("Type id was already registered");
-      }
+    auto it_success = function_id_to_index_.emplace(copied_id, index);
 
-      if (!type_to_index_.emplace(type.get(), index).second) {
-        id_to_index_.erase(it_success.first);
-        return Status::Invalid("Type was already registered");
-      }
+    if (!it_success.second) {
+      return Status::Invalid("Function id was already registered");
+    }
 
-      type_ids_.push_back(copied_id);
-      types_.push_back(std::move(type));
-      return Status::OK();
+    if (!function_name_to_index_.emplace(copied_function_name, index).second) {
+      function_id_to_index_.erase(it_success.first);
+      return Status::Invalid("Function name was already registered");
     }
 
-    util::optional<FunctionRecord> GetFunction(
-        util::string_view arrow_function_name) const override {
-      if (auto index = GetIndex(function_name_to_index_, arrow_function_name)) {
-        return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]};
-      }
-      return {};
+    function_name_ptrs_.push_back(&copied_function_name);
+    function_ids_.push_back(copied_id);
+    return Status::OK();
+  }
+
+  // owning storage of uris, names, (arrow::)function_names, types
+  //    note that storing strings like this is safe since references into an
+  //    unordered_set are not invalidated on insertion
+  std::unordered_set<std::string> uris_, names_, function_names_;
+  DataTypeVector types_;
+
+  // non-owning lookup helpers
+  std::vector<Id> type_ids_, function_ids_;
+  std::unordered_map<Id, int, IdHashEq, IdHashEq> id_to_index_;
+  std::unordered_map<const DataType*, int, TypePtrHashEq, TypePtrHashEq> type_to_index_;
+
+  std::vector<const std::string*> function_name_ptrs_;
+  std::unordered_map<Id, int, IdHashEq, IdHashEq> function_id_to_index_;
+  std::unordered_map<util::string_view, int, ::arrow::internal::StringViewHash>
+      function_name_to_index_;
+};
+
+struct NestedExtensionIdRegistryImpl : ExtensionIdRegistryImpl {
+  explicit NestedExtensionIdRegistryImpl(const ExtensionIdRegistry* parent)
+      : parent_(parent) {}
+
+  virtual ~NestedExtensionIdRegistryImpl() {}
+
+  std::vector<util::string_view> Uris() const override {
+    std::vector<util::string_view> uris = parent_->Uris();
+    std::unordered_set<util::string_view> uri_set;
+    uri_set.insert(uris.begin(), uris.end());
+    uri_set.insert(uris_.begin(), uris_.end());
+    return std::vector<util::string_view>(uris);
+  }
+
+  util::optional<TypeRecord> GetType(const DataType& type) const override {
+    auto type_opt = ExtensionIdRegistryImpl::GetType(type);
+    if (type_opt) {
+      return type_opt;
     }
+    return parent_->GetType(type);
+  }
 
-    util::optional<FunctionRecord> GetFunction(Id id) const override {
-      if (auto index = GetIndex(function_id_to_index_, id)) {
-        return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]};
-      }
-      return {};
+  util::optional<TypeRecord> GetType(Id id) const override {
+    auto type_opt = ExtensionIdRegistryImpl::GetType(id);
+    if (type_opt) {
+      return type_opt;
     }
+    return parent_->GetType(id);
+  }
 
-    Status RegisterFunction(Id id, std::string arrow_function_name) override {
-      DCHECK_EQ(function_ids_.size(), function_name_ptrs_.size());
+  Status CanRegisterType(Id id, const std::shared_ptr<DataType>& type) const override {
+    return parent_->CanRegisterType(id, type) &
+           ExtensionIdRegistryImpl::CanRegisterType(id, type);
+  }
 
-      Id copied_id{*uris_.emplace(id.uri.to_string()).first,
-                   *names_.emplace(id.name.to_string()).first};
+  Status RegisterType(Id id, std::shared_ptr<DataType> type) override {
+    return parent_->CanRegisterType(id, type) &
+           ExtensionIdRegistryImpl::RegisterType(id, type);
+  }
 
-      const std::string& copied_function_name{
-          *function_names_.emplace(std::move(arrow_function_name)).first};
+  util::optional<FunctionRecord> GetFunction(
+      util::string_view arrow_function_name) const override {
+    auto func_opt = ExtensionIdRegistryImpl::GetFunction(arrow_function_name);
+    if (func_opt) {
+      return func_opt;
+    }
+    return parent_->GetFunction(arrow_function_name);
+  }
 
-      auto index = static_cast<int>(function_ids_.size());
+  util::optional<FunctionRecord> GetFunction(Id id) const override {
+    auto func_opt = ExtensionIdRegistryImpl::GetFunction(id);
+    if (func_opt) {
+      return func_opt;
+    }
+    return parent_->GetFunction(id);
+  }
 
-      auto it_success = function_id_to_index_.emplace(copied_id, index);
+  Status CanRegisterFunction(Id id,
+                             const std::string& arrow_function_name) const override {
+    return parent_->CanRegisterFunction(id, arrow_function_name) &
+           ExtensionIdRegistryImpl::CanRegisterFunction(id, arrow_function_name);
+  }
 
-      if (!it_success.second) {
-        return Status::Invalid("Function id was already registered");
-      }
+  Status RegisterFunction(Id id, std::string arrow_function_name) override {
+    return parent_->CanRegisterFunction(id, arrow_function_name) &
+           ExtensionIdRegistryImpl::RegisterFunction(id, arrow_function_name);

Review Comment:
   ```suggestion
              ExtensionIdRegistryImpl::RegisterFunction(id, std::move(arrow_function_name));
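   ```

   Since `arrow_function_name` is taken by value, it can be moved into the base-class call. Note, however, that the two operands of `Status::operator&` are evaluated in an unspecified order. With this change, `arrow_function_name` may therefore already have been moved from when `parent_->CanRegisterFunction(id, arrow_function_name)` reads it. `&` also does not short-circuit, so the local registration happens even when the parent rejects the id or name. Checking the parent first avoids both problems, e.g. `ARROW_RETURN_NOT_OK(parent_->CanRegisterFunction(id, arrow_function_name)); return ExtensionIdRegistryImpl::RegisterFunction(id, std::move(arrow_function_name));`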


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org