You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2022/06/03 01:30:28 UTC

[GitHub] [arrow] westonpace commented on a diff in pull request #13252: ARROW-16677: [C++] Support nesting of function registries

westonpace commented on code in PR #13252:
URL: https://github.com/apache/arrow/pull/13252#discussion_r888518121


##########
cpp/src/arrow/compute/registry.cc:
##########
@@ -34,7 +34,20 @@ namespace compute {
 
 class FunctionRegistry::FunctionRegistryImpl {
  public:
-  Status AddFunction(std::shared_ptr<Function> function, bool allow_overwrite) {
+  virtual ~FunctionRegistryImpl() {}
+
+ private:

Review Comment:
   Instead of switching back and forth between `public:` and `private:` sections, can we group all the public functions together and all the private functions together?



##########
cpp/src/arrow/compute/registry.cc:
##########
@@ -48,23 +61,56 @@ class FunctionRegistry::FunctionRegistryImpl {
     if (it != name_to_function_.end() && !allow_overwrite) {
       return Status::KeyError("Already have a function registered with name: ", name);
     }
-    name_to_function_[name] = std::move(function);
+    add(name, std::move(function));

Review Comment:
   ```suggestion
       if (do_add) {
         name_to_function_[name] = std::move(function);
       }
   ```
   If we change the `add` callback to a `bool do_add` parameter, I think this is a bit easier to follow. If performance is a concern (I don't think it would be, as this isn't really a critical section), you could use a template:
   
   ```
   template <bool kDoAdd>
   Status DoAddFunction(std::shared_ptr<Function> function, bool allow_overwrite) {
     ...
     if (kDoAdd) {
       name_to_function_[name] = std::move(function);
     }
   }
   virtual Status CanAddFunction(std::shared_ptr<Function> function,
                                 bool allow_overwrite) {
     return DoAddFunction<false>(function, allow_overwrite);
   }
   ...
   
   ```



##########
cpp/src/arrow/compute/registry_test.cc:
##########
@@ -27,37 +27,44 @@
 #include "arrow/status.h"
 #include "arrow/testing/gtest_util.h"
 #include "arrow/util/macros.h"
+#include "arrow/util/make_unique.h"
 
 namespace arrow {
 namespace compute {
 
-class TestRegistry : public ::testing::Test {
- public:
-  void SetUp() { registry_ = FunctionRegistry::Make(); }
+using MakeFunctionRegistry = std::function<std::unique_ptr<FunctionRegistry>()>;
+using GetNumFunctions = std::function<int()>;
+using GetFunctionNames = std::function<std::vector<std::string>()>;
+using TestRegistryParams =
+    std::tuple<MakeFunctionRegistry, GetNumFunctions, GetFunctionNames, std::string>;
 
- protected:
-  std::unique_ptr<FunctionRegistry> registry_;
-};
+struct TestRegistry : public ::testing::TestWithParam<TestRegistryParams> {};
 
-TEST_F(TestRegistry, CreateBuiltInRegistry) {
+TEST(TestRegistry, CreateBuiltInRegistry) {

Review Comment:
   Should this be `TEST_P`? (I might be wrong here; I can't remember whether it's OK to mix parameterized and non-parameterized tests in the same test suite.)



##########
cpp/src/arrow/compute/registry_test.cc:
##########
@@ -27,37 +27,44 @@
 #include "arrow/status.h"
 #include "arrow/testing/gtest_util.h"
 #include "arrow/util/macros.h"
+#include "arrow/util/make_unique.h"
 
 namespace arrow {
 namespace compute {
 
-class TestRegistry : public ::testing::Test {
- public:
-  void SetUp() { registry_ = FunctionRegistry::Make(); }
+using MakeFunctionRegistry = std::function<std::unique_ptr<FunctionRegistry>()>;
+using GetNumFunctions = std::function<int()>;
+using GetFunctionNames = std::function<std::vector<std::string>()>;
+using TestRegistryParams =
+    std::tuple<MakeFunctionRegistry, GetNumFunctions, GetFunctionNames, std::string>;
 
- protected:
-  std::unique_ptr<FunctionRegistry> registry_;
-};
+struct TestRegistry : public ::testing::TestWithParam<TestRegistryParams> {};
 
-TEST_F(TestRegistry, CreateBuiltInRegistry) {
+TEST(TestRegistry, CreateBuiltInRegistry) {
   // This does DCHECK_OK internally for now so this will fail in debug builds
   // if there is a problem initializing the global function registry
   FunctionRegistry* registry = GetFunctionRegistry();
   ARROW_UNUSED(registry);
 }
 
-TEST_F(TestRegistry, Basics) {
-  ASSERT_EQ(0, registry_->num_functions());
+TEST_P(TestRegistry, Basics) {
+  auto registry_factory = std::get<0>(GetParam());
+  auto registry_ = registry_factory();
+  auto get_num_funcs = std::get<1>(GetParam());
+  int n_funcs = get_num_funcs();
+  auto get_func_names = std::get<2>(GetParam());
+  std::vector<std::string> func_names = get_func_names();
+  ASSERT_EQ(n_funcs + 0, registry_->num_functions());

Review Comment:
   ```suggestion
     ASSERT_EQ(n_funcs, registry_->num_functions());
   ```



##########
cpp/src/arrow/compute/registry.h:
##########
@@ -84,6 +106,11 @@ class ARROW_EXPORT FunctionRegistry {
   // Use PIMPL pattern to not have std::unordered_map here
   class FunctionRegistryImpl;
   std::unique_ptr<FunctionRegistryImpl> impl_;
+
+  explicit FunctionRegistry(FunctionRegistryImpl* impl);
+
+  class NestedFunctionRegistryImpl;
+  friend class NestedFunctionRegistryImpl;

Review Comment:
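   I don't think this friend declaration is needed. Since C++11, a nested class is a member of the enclosing class, so `NestedFunctionRegistryImpl` already has access to `FunctionRegistry`'s private members.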



##########
cpp/src/arrow/compute/registry.cc:
##########
@@ -48,23 +61,56 @@ class FunctionRegistry::FunctionRegistryImpl {
     if (it != name_to_function_.end() && !allow_overwrite) {
       return Status::KeyError("Already have a function registered with name: ", name);
     }
-    name_to_function_[name] = std::move(function);
+    add(name, std::move(function));
     return Status::OK();
   }
 
-  Status AddAlias(const std::string& target_name, const std::string& source_name) {
+ public:
+  virtual Status CanAddFunction(std::shared_ptr<Function> function,
+                                bool allow_overwrite) {
+    return DoAddFunction(function, allow_overwrite, kFuncAddNoOp);
+  }
+
+  virtual Status AddFunction(std::shared_ptr<Function> function, bool allow_overwrite) {
+    return DoAddFunction(function, allow_overwrite, kFuncAddDo);
+  }
+
+ private:
+  Status DoAddAlias(const std::string& target_name, const std::string& source_name,
+                    FuncAdd add) {
     std::lock_guard<std::mutex> mutation_guard(lock_);
 
-    auto it = name_to_function_.find(source_name);
-    if (it == name_to_function_.end()) {
+    auto func_res = GetFunction(source_name);  // must not acquire the mutex
+    if (!func_res.ok()) {
       return Status::KeyError("No function registered with name: ", source_name);
     }
-    name_to_function_[target_name] = it->second;
+    add(target_name, func_res.ValueOrDie());
     return Status::OK();
   }
 
-  Status AddFunctionOptionsType(const FunctionOptionsType* options_type,
-                                bool allow_overwrite = false) {
+ public:
+  virtual Status CanAddAlias(const std::string& target_name,
+                             const std::string& source_name) {
+    return DoAddAlias(target_name, source_name, kFuncAddNoOp);
+  }
+
+  virtual Status AddAlias(const std::string& target_name,
+                          const std::string& source_name) {
+    return DoAddAlias(target_name, source_name, kFuncAddDo);
+  }
+
+ private:
+  using FuncOptTypeAdd = std::function<void(const FunctionOptionsType* options_type)>;
+
+  const FuncOptTypeAdd kFuncOptTypeAddNoOp = [](const FunctionOptionsType* options_type) {
+  };
+  const FuncOptTypeAdd kFuncOptTypeAddDo =
+      [this](const FunctionOptionsType* options_type) {
+        name_to_options_type_[options_type->type_name()] = options_type;
+      };

Review Comment:
   Same thing here: replace the `FuncOptTypeAdd` callbacks (`kFuncOptTypeAddNoOp` / `kFuncOptTypeAddDo`) with a boolean.



##########
cpp/src/arrow/compute/registry.cc:
##########
@@ -48,23 +61,56 @@ class FunctionRegistry::FunctionRegistryImpl {
     if (it != name_to_function_.end() && !allow_overwrite) {
       return Status::KeyError("Already have a function registered with name: ", name);
     }
-    name_to_function_[name] = std::move(function);
+    add(name, std::move(function));
     return Status::OK();
   }
 
-  Status AddAlias(const std::string& target_name, const std::string& source_name) {
+ public:
+  virtual Status CanAddFunction(std::shared_ptr<Function> function,
+                                bool allow_overwrite) {
+    return DoAddFunction(function, allow_overwrite, kFuncAddNoOp);
+  }
+
+  virtual Status AddFunction(std::shared_ptr<Function> function, bool allow_overwrite) {
+    return DoAddFunction(function, allow_overwrite, kFuncAddDo);
+  }
+
+ private:
+  Status DoAddAlias(const std::string& target_name, const std::string& source_name,
+                    FuncAdd add) {
     std::lock_guard<std::mutex> mutation_guard(lock_);
 
-    auto it = name_to_function_.find(source_name);
-    if (it == name_to_function_.end()) {
+    auto func_res = GetFunction(source_name);  // must not acquire the mutex
+    if (!func_res.ok()) {
       return Status::KeyError("No function registered with name: ", source_name);
     }

Review Comment:
   ```suggestion
       ARROW_ASSIGN_OR_RAISE(auto func_res, GetFunction(source_name));
   ```



##########
cpp/src/arrow/compute/registry.cc:
##########
@@ -103,32 +160,140 @@ class FunctionRegistry::FunctionRegistryImpl {
     return it->second;
   }
 
-  int num_functions() const { return static_cast<int>(name_to_function_.size()); }
+  virtual int num_functions() const { return static_cast<int>(name_to_function_.size()); }
 
  private:
   std::mutex lock_;
   std::unordered_map<std::string, std::shared_ptr<Function>> name_to_function_;
   std::unordered_map<std::string, const FunctionOptionsType*> name_to_options_type_;
 };
 
+class FunctionRegistry::NestedFunctionRegistryImpl
+    : public FunctionRegistry::FunctionRegistryImpl {

Review Comment:
   It's a little bit odd to combine the pimpl pattern with inheritance. If we want this inheritance chain, we might be better off converting `FunctionRegistry` into a pure virtual class, like we have with `ExtensionIdRegistry`.
   
   On the other hand, would it make sense to simply change the base `FunctionRegistryImpl` to always have a parent pointer that may be `nullptr` (i.e. no parent == `nullptr`)?
   
   Then the lookup can just be...
   
   ```
     Result<std::shared_ptr<Function>> GetFunction(const std::string& name) const override {
       auto it = name_to_function_.find(name);
       if (it == name_to_function_.end()) {
         if (parent_ == nullptr) {
           return Status::KeyError("No function registered with name: ", name);
         }
         return parent_->GetFunction(name);
       }
       return it->second;
     }
   ```



##########
cpp/src/arrow/compute/registry_test.cc:
##########
@@ -85,5 +95,137 @@ TEST_F(TestRegistry, Basics) {
   ASSERT_EQ(func, f2);
 }
 
+INSTANTIATE_TEST_SUITE_P(
+    TestRegistry, TestRegistry,
+    testing::Values(
+        std::make_tuple(
+            static_cast<MakeFunctionRegistry>([]() { return FunctionRegistry::Make(); }),
+            []() { return 0; }, []() { return std::vector<std::string>{}; }, "default"),
+        std::make_tuple(
+            static_cast<MakeFunctionRegistry>([]() {
+              return FunctionRegistry::Make(GetFunctionRegistry());
+            }),
+            []() { return GetFunctionRegistry()->num_functions(); },
+            []() { return GetFunctionRegistry()->GetFunctionNames(); }, "nested")));
+
+TEST(TestRegistry, RegisterTempFunctions) {
+  auto default_registry = GetFunctionRegistry();
+  constexpr int rounds = 3;
+  for (int i = 0; i < rounds; i++) {
+    auto registry = FunctionRegistry::Make(default_registry);
+    for (std::string func_name : {"f1", "f2"}) {
+      std::shared_ptr<Function> func = std::make_shared<ScalarFunction>(
+          func_name, Arity::Unary(), /*doc=*/FunctionDoc::Empty());
+      ASSERT_OK(registry->CanAddFunction(func));
+      ASSERT_OK(registry->AddFunction(func));
+      ASSERT_RAISES(KeyError, registry->CanAddFunction(func));
+      ASSERT_RAISES(KeyError, registry->AddFunction(func));
+      ASSERT_OK(default_registry->CanAddFunction(func));
+    }
+  }
+}
+
+TEST(TestRegistry, RegisterTempAliases) {
+  auto default_registry = GetFunctionRegistry();
+  std::vector<std::string> func_names = default_registry->GetFunctionNames();
+  constexpr int rounds = 3;
+  for (int i = 0; i < rounds; i++) {
+    auto registry = FunctionRegistry::Make(default_registry);
+    for (std::string func_name : func_names) {
+      std::string alias_name = "alias_of_" + func_name;
+      std::shared_ptr<Function> func = std::make_shared<ScalarFunction>(
+          func_name, Arity::Unary(), /*doc=*/FunctionDoc::Empty());
+      ASSERT_RAISES(KeyError, registry->GetFunction(alias_name));
+      ASSERT_OK(registry->CanAddAlias(alias_name, func_name));
+      ASSERT_OK(registry->AddAlias(alias_name, func_name));
+      ASSERT_OK(registry->GetFunction(alias_name));
+      ASSERT_OK(default_registry->GetFunction(func_name));
+      ASSERT_RAISES(KeyError, default_registry->GetFunction(alias_name));
+    }
+  }
+}
+
+template <int N>

Review Comment:
   Instead of `N` can we call this `kTypeNameSuffix`?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org