You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ko...@apache.org on 2022/10/20 21:43:53 UTC
[arrow] 06/13: ARROW-17460: [R] Don't warn if the new UDF I'm registering is the same as the existing one (#14436)
This is an automated email from the ASF dual-hosted git repository.
kou pushed a commit to branch maint-10.0.0
in repository https://gitbox.apache.org/repos/asf/arrow.git
commit ff80b30ea6424f21d15ca9511107a869a2f956b0
Author: Dewey Dunnington <de...@voltrondata.com>
AuthorDate: Tue Oct 18 13:58:31 2022 -0300
ARROW-17460: [R] Don't warn if the new UDF I'm registering is the same as the existing one (#14436)
This PR makes it so that you can do the following without a warning:
``` r
library(arrow, warn.conflicts = FALSE)
register_scalar_function(
"times_32",
function(context, x) x * 32L,
in_type = list(int32(), int64(), float64()),
out_type = function(in_types) in_types[[1]],
auto_convert = TRUE
)
register_scalar_function(
"times_32",
function(context, x) x * 32L,
in_type = list(int32(), int64(), float64()),
out_type = function(in_types) in_types[[1]],
auto_convert = TRUE
)
```
In fixing that, I also made an important discovery: `cpp11::function` does *not* protect the underlying `SEXP` from garbage collection! The two functions we used it for were only protected by proxy, because the execution environment of `register_scalar_function()` was being saved when the binding was registered. This patch stores them as `cpp11::sexp` (which does protect the object) and wraps them in `cpp11::function` only at call time.
Authored-by: Dewey Dunnington <de...@voltrondata.com>
Signed-off-by: Dewey Dunnington <de...@fishandwhistle.net>
---
r/R/compute.R | 10 +++++++++-
r/R/dplyr-funcs.R | 2 +-
r/src/compute.cpp | 10 ++++++----
r/src/recordbatchreader.cpp | 4 ++--
r/tests/testthat/test-dplyr-funcs.R | 3 +++
5 files changed, 21 insertions(+), 8 deletions(-)
diff --git a/r/R/compute.R b/r/R/compute.R
index a144e7d678..1386728ac9 100644
--- a/r/R/compute.R
+++ b/r/R/compute.R
@@ -379,9 +379,17 @@ register_scalar_function <- function(name, fun, in_type, out_type,
RegisterScalarUDF(name, scalar_function)
# register with dplyr binding (enables its use in mutate(), filter(), etc.)
+ binding_fun <- function(...) build_expr(name, ...)
+
+ # inject the value of `name` into the expression to avoid saving this
+ # execution environment in the binding, which eliminates a warning when the
+ # same binding is registered twice
+ body(binding_fun) <- expr_substitute(body(binding_fun), sym("name"), name)
+ environment(binding_fun) <- asNamespace("arrow")
+
register_binding(
name,
- function(...) build_expr(name, ...),
+ binding_fun,
update_cache = TRUE
)
diff --git a/r/R/dplyr-funcs.R b/r/R/dplyr-funcs.R
index e5f7657061..ee64a09918 100644
--- a/r/R/dplyr-funcs.R
+++ b/r/R/dplyr-funcs.R
@@ -75,7 +75,7 @@ register_binding <- function(fun_name,
previous_fun <- registry[[unqualified_name]]
# if the unqualified name exists in the registry, warn
- if (!is.null(previous_fun)) {
+ if (!is.null(previous_fun) && !identical(fun, previous_fun)) {
warn(
paste0(
"A \"",
diff --git a/r/src/compute.cpp b/r/src/compute.cpp
index 1ed949e729..0bfc517285 100644
--- a/r/src/compute.cpp
+++ b/r/src/compute.cpp
@@ -611,8 +611,8 @@ class RScalarUDFKernelState : public arrow::compute::KernelState {
RScalarUDFKernelState(cpp11::sexp exec_func, cpp11::sexp resolver)
: exec_func_(exec_func), resolver_(resolver) {}
- cpp11::function exec_func_;
- cpp11::function resolver_;
+ cpp11::sexp exec_func_;
+ cpp11::sexp resolver_;
};
arrow::Result<arrow::TypeHolder> ResolveScalarUDFOutputType(
@@ -630,7 +630,8 @@ arrow::Result<arrow::TypeHolder> ResolveScalarUDFOutputType(
cpp11::to_r6<arrow::DataType>(input_types[i].GetSharedPtr());
}
- cpp11::sexp output_type_sexp = state->resolver_(input_types_sexp);
+ cpp11::sexp output_type_sexp =
+ cpp11::function(state->resolver_)(input_types_sexp);
if (!Rf_inherits(output_type_sexp, "DataType")) {
cpp11::stop(
"Function specified as arrow_scalar_function() out_type argument must "
@@ -674,7 +675,8 @@ arrow::Status CallRScalarUDF(arrow::compute::KernelContext* context,
cpp11::writable::list udf_context = {batch_length_sexp, output_type_sexp};
udf_context.names() = {"batch_length", "output_type"};
- cpp11::sexp func_result_sexp = state->exec_func_(udf_context, args_sexp);
+ cpp11::sexp func_result_sexp =
+ cpp11::function(state->exec_func_)(udf_context, args_sexp);
if (Rf_inherits(func_result_sexp, "Array")) {
auto array = cpp11::as_cpp<std::shared_ptr<arrow::Array>>(func_result_sexp);
diff --git a/r/src/recordbatchreader.cpp b/r/src/recordbatchreader.cpp
index d0c52acc41..8e9df12174 100644
--- a/r/src/recordbatchreader.cpp
+++ b/r/src/recordbatchreader.cpp
@@ -70,7 +70,7 @@ class RFunctionRecordBatchReader : public arrow::RecordBatchReader {
arrow::Status ReadNext(std::shared_ptr<arrow::RecordBatch>* batch_out) {
auto batch = SafeCallIntoR<std::shared_ptr<arrow::RecordBatch>>([&]() {
- cpp11::sexp result_sexp = fun_();
+ cpp11::sexp result_sexp = cpp11::function(fun_)();
if (result_sexp == R_NilValue) {
return std::shared_ptr<arrow::RecordBatch>(nullptr);
} else if (!Rf_inherits(result_sexp, "RecordBatch")) {
@@ -94,7 +94,7 @@ class RFunctionRecordBatchReader : public arrow::RecordBatchReader {
}
private:
- cpp11::function fun_;
+ cpp11::sexp fun_;
std::shared_ptr<arrow::Schema> schema_;
};
diff --git a/r/tests/testthat/test-dplyr-funcs.R b/r/tests/testthat/test-dplyr-funcs.R
index 86f984dd32..48b74c9af4 100644
--- a/r/tests/testthat/test-dplyr-funcs.R
+++ b/r/tests/testthat/test-dplyr-funcs.R
@@ -35,6 +35,9 @@ test_that("register_binding()/unregister_binding() works", {
register_binding("some.pkg2::some_fun", fun2, fake_registry),
"A \"some_fun\" binding already exists in the registry and will be overwritten."
)
+
+ # No warning when an identical function is re-registered
+ expect_silent(register_binding("some.pkg2::some_fun", fun2, fake_registry))
})
test_that("register_binding_agg() works", {