You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ko...@apache.org on 2022/10/20 21:43:53 UTC

[arrow] 06/13: ARROW-17460: [R] Don't warn if the new UDF I'm registering is the same as the existing one (#14436)

This is an automated email from the ASF dual-hosted git repository.

kou pushed a commit to branch maint-10.0.0
in repository https://gitbox.apache.org/repos/asf/arrow.git

commit ff80b30ea6424f21d15ca9511107a869a2f956b0
Author: Dewey Dunnington <de...@voltrondata.com>
AuthorDate: Tue Oct 18 13:58:31 2022 -0300

    ARROW-17460: [R] Don't warn if the new UDF I'm registering is the same as the existing one (#14436)
    
    This PR makes it so that you can do the following without a warning:
    
    ``` r
    library(arrow, warn.conflicts = FALSE)
    
    register_scalar_function(
      "times_32",
      function(context, x) x * 32L,
      in_type = list(int32(), int64(), float64()),
      out_type = function(in_types) in_types[[1]],
      auto_convert = TRUE
    )
    
    register_scalar_function(
      "times_32",
      function(context, x) x * 32L,
      in_type = list(int32(), int64(), float64()),
      out_type = function(in_types) in_types[[1]],
      auto_convert = TRUE
    )
    ```
    
    In fixing that, I also made an important discovery: `cpp11::function` does *not* protect the underlying `SEXP` from garbage collection. The two functions we used it for were only protected incidentally, because the execution environment of `register_scalar_function()` (which referenced them) was being saved when the dplyr binding was registered. Storing them as `cpp11::sexp`, which does protect its value, and wrapping them in `cpp11::function` only at call time fixes this.
    
    Authored-by: Dewey Dunnington <de...@voltrondata.com>
    Signed-off-by: Dewey Dunnington <de...@fishandwhistle.net>
---
 r/R/compute.R                       | 10 +++++++++-
 r/R/dplyr-funcs.R                   |  2 +-
 r/src/compute.cpp                   | 10 ++++++----
 r/src/recordbatchreader.cpp         |  4 ++--
 r/tests/testthat/test-dplyr-funcs.R |  3 +++
 5 files changed, 21 insertions(+), 8 deletions(-)

diff --git a/r/R/compute.R b/r/R/compute.R
index a144e7d678..1386728ac9 100644
--- a/r/R/compute.R
+++ b/r/R/compute.R
@@ -379,9 +379,17 @@ register_scalar_function <- function(name, fun, in_type, out_type,
   RegisterScalarUDF(name, scalar_function)
 
   # register with dplyr binding (enables its use in mutate(), filter(), etc.)
+  binding_fun <- function(...) build_expr(name, ...)
+
+  # inject the value of `name` into the expression to avoid saving this
+  # execution environment in the binding, which eliminates a warning when the
+  # same binding is registered twice
+  body(binding_fun) <- expr_substitute(body(binding_fun), sym("name"), name)
+  environment(binding_fun) <- asNamespace("arrow")
+
   register_binding(
     name,
-    function(...) build_expr(name, ...),
+    binding_fun,
     update_cache = TRUE
   )
 
diff --git a/r/R/dplyr-funcs.R b/r/R/dplyr-funcs.R
index e5f7657061..ee64a09918 100644
--- a/r/R/dplyr-funcs.R
+++ b/r/R/dplyr-funcs.R
@@ -75,7 +75,7 @@ register_binding <- function(fun_name,
   previous_fun <- registry[[unqualified_name]]
 
   # if the unqualified name exists in the registry, warn
-  if (!is.null(previous_fun)) {
+  if (!is.null(previous_fun) && !identical(fun, previous_fun)) {
     warn(
       paste0(
         "A \"",
diff --git a/r/src/compute.cpp b/r/src/compute.cpp
index 1ed949e729..0bfc517285 100644
--- a/r/src/compute.cpp
+++ b/r/src/compute.cpp
@@ -611,8 +611,8 @@ class RScalarUDFKernelState : public arrow::compute::KernelState {
   RScalarUDFKernelState(cpp11::sexp exec_func, cpp11::sexp resolver)
       : exec_func_(exec_func), resolver_(resolver) {}
 
-  cpp11::function exec_func_;
-  cpp11::function resolver_;
+  cpp11::sexp exec_func_;
+  cpp11::sexp resolver_;
 };
 
 arrow::Result<arrow::TypeHolder> ResolveScalarUDFOutputType(
@@ -630,7 +630,8 @@ arrow::Result<arrow::TypeHolder> ResolveScalarUDFOutputType(
               cpp11::to_r6<arrow::DataType>(input_types[i].GetSharedPtr());
         }
 
-        cpp11::sexp output_type_sexp = state->resolver_(input_types_sexp);
+        cpp11::sexp output_type_sexp =
+            cpp11::function(state->resolver_)(input_types_sexp);
         if (!Rf_inherits(output_type_sexp, "DataType")) {
           cpp11::stop(
               "Function specified as arrow_scalar_function() out_type argument must "
@@ -674,7 +675,8 @@ arrow::Status CallRScalarUDF(arrow::compute::KernelContext* context,
         cpp11::writable::list udf_context = {batch_length_sexp, output_type_sexp};
         udf_context.names() = {"batch_length", "output_type"};
 
-        cpp11::sexp func_result_sexp = state->exec_func_(udf_context, args_sexp);
+        cpp11::sexp func_result_sexp =
+            cpp11::function(state->exec_func_)(udf_context, args_sexp);
 
         if (Rf_inherits(func_result_sexp, "Array")) {
           auto array = cpp11::as_cpp<std::shared_ptr<arrow::Array>>(func_result_sexp);
diff --git a/r/src/recordbatchreader.cpp b/r/src/recordbatchreader.cpp
index d0c52acc41..8e9df12174 100644
--- a/r/src/recordbatchreader.cpp
+++ b/r/src/recordbatchreader.cpp
@@ -70,7 +70,7 @@ class RFunctionRecordBatchReader : public arrow::RecordBatchReader {
 
   arrow::Status ReadNext(std::shared_ptr<arrow::RecordBatch>* batch_out) {
     auto batch = SafeCallIntoR<std::shared_ptr<arrow::RecordBatch>>([&]() {
-      cpp11::sexp result_sexp = fun_();
+      cpp11::sexp result_sexp = cpp11::function(fun_)();
       if (result_sexp == R_NilValue) {
         return std::shared_ptr<arrow::RecordBatch>(nullptr);
       } else if (!Rf_inherits(result_sexp, "RecordBatch")) {
@@ -94,7 +94,7 @@ class RFunctionRecordBatchReader : public arrow::RecordBatchReader {
   }
 
  private:
-  cpp11::function fun_;
+  cpp11::sexp fun_;
   std::shared_ptr<arrow::Schema> schema_;
 };
 
diff --git a/r/tests/testthat/test-dplyr-funcs.R b/r/tests/testthat/test-dplyr-funcs.R
index 86f984dd32..48b74c9af4 100644
--- a/r/tests/testthat/test-dplyr-funcs.R
+++ b/r/tests/testthat/test-dplyr-funcs.R
@@ -35,6 +35,9 @@ test_that("register_binding()/unregister_binding() works", {
     register_binding("some.pkg2::some_fun", fun2, fake_registry),
     "A \"some_fun\" binding already exists in the registry and will be overwritten."
   )
+
+  # No warning when an identical function is re-registered
+  expect_silent(register_binding("some.pkg2::some_fun", fun2, fake_registry))
 })
 
 test_that("register_binding_agg() works", {