Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2022/07/11 12:57:52 UTC

[GitHub] [arrow] pitrou commented on a diff in pull request #13397: ARROW-16444: [R] Implement user-defined scalar functions in R bindings

pitrou commented on code in PR #13397:
URL: https://github.com/apache/arrow/pull/13397#discussion_r917871545


##########
r/R/compute.R:
##########
@@ -307,3 +307,157 @@ cast_options <- function(safe = TRUE, ...) {
   )
   modifyList(opts, list(...))
 }
+
+#' Register user-defined functions
+#'
+#' These functions support calling R code from query engine execution
+#' (i.e., a [dplyr::mutate()] or [dplyr::filter()] on a [Table] or [Dataset]).
+#' Use [arrow_scalar_function()] to define an R function that accepts and
+#' returns R objects; use [arrow_advanced_scalar_function()] to define a
+#' lower-level function that operates directly on Arrow objects.
+#'
+#' @param name The function name to be used in the dplyr bindings
+#' @param scalar_function An object created with [arrow_scalar_function()]
+#'   or [arrow_advanced_scalar_function()].
+#' @param in_type A [DataType] of the input type or a [schema()]
+#'   for functions with more than one argument. This signature will be used
+#'   to determine if this function is appropriate for a given set of arguments.
+#'   If this function is appropriate for more than one signature, pass a
+#'   `list()` of the above.
+#' @param out_type A [DataType] of the output type or a function accepting
+#'   a single argument (`types`), which is a `list()` of [DataType]s. If a
+#'   function it must return a [DataType].
+#' @param fun An R function or rlang-style lambda expression. This function
+#'   will be called with R objects as arguments and must return an object
+#'   that can be converted to an [Array] using [as_arrow_array()]. Function
+#'   authors must take care to return an array castable to the output data
+#'   type specified by `out_type`.
+#' @param advanced_fun An R function or rlang-style lambda expression. This
+#'   function will be called with exactly two arguments: `kernel_context`,
+#'   which is a `list()` of objects giving information about the
+#'   execution context and `args`, which is a list of [Array] or [Scalar]
+#'   objects corresponding to the input arguments.
+#'
+#' @return
+#'   - `register_user_defined_function()`: `NULL`, invisibly
+#'   - `arrow_scalar_function()`: returns an object of class
+#'     "arrow_advanced_scalar_function" that can be passed to
+#'     `register_user_defined_function()`.
+#' @export
+#'
+#' @examplesIf .Machine$sizeof.pointer >= 8

Review Comment:
   Why this restriction?



##########
r/R/dplyr-funcs.R:
##########
@@ -50,14 +50,17 @@ NULL
 #'   - `fun`: string function name
 #'   - `data`: `Expression` (these are all currently a single field)
 #'   - `options`: list of function options, as passed to call_function
+#' @param update_cache Update .cache$functions at the time of registration.

Review Comment:
   Newbie question, but what exactly is this cache? "`.cache$functions`" isn't very informative...



##########
r/src/compute.cpp:
##########
@@ -574,3 +576,171 @@ SEXP compute__CallFunction(std::string func_name, cpp11::list args, cpp11::list
 std::vector<std::string> compute__GetFunctionNames() {
   return arrow::compute::GetFunctionRegistry()->GetFunctionNames();
 }
+
+class RScalarUDFKernelState : public arrow::compute::KernelState {
+ public:
+  RScalarUDFKernelState(cpp11::sexp exec_func, cpp11::sexp resolver)
+      : exec_func_(exec_func), resolver_(resolver) {}
+
+  cpp11::function exec_func_;
+  cpp11::function resolver_;
+};
+
+arrow::Result<arrow::TypeHolder> ResolveScalarUDFOutputType(
+    arrow::compute::KernelContext* context,
+    const std::vector<arrow::TypeHolder>& input_types) {
+  return SafeCallIntoR<arrow::TypeHolder>(
+      [&]() -> arrow::TypeHolder {
+        auto kernel =
+            reinterpret_cast<const arrow::compute::ScalarKernel*>(context->kernel());
+        auto state = std::dynamic_pointer_cast<RScalarUDFKernelState>(kernel->data);
+
+        cpp11::writable::list input_types_sexp(input_types.size());
+        for (size_t i = 0; i < input_types.size(); i++) {
+          input_types_sexp[i] =
+              cpp11::to_r6<arrow::DataType>(input_types[i].GetSharedPtr());
+        }
+
+        cpp11::sexp output_type_sexp = state->resolver_(input_types_sexp);
+        if (!Rf_inherits(output_type_sexp, "DataType")) {
+          cpp11::stop("arrow_scalar_function resolver must return a DataType");
+        }
+
+        return arrow::TypeHolder(
+            cpp11::as_cpp<std::shared_ptr<arrow::DataType>>(output_type_sexp));
+      },
+      "resolve scalar user-defined function output data type");
+}
+
+arrow::Status CallRScalarUDF(arrow::compute::KernelContext* context,
+                             const arrow::compute::ExecSpan& span,
+                             arrow::compute::ExecResult* result) {
+  if (result->is_array_span()) {
+    return arrow::Status::NotImplemented("ArraySpan result from R scalar UDF");
+  }
+
+  return SafeCallIntoRVoid(
+      [&]() {
+        auto kernel =
+            reinterpret_cast<const arrow::compute::ScalarKernel*>(context->kernel());
+        auto state = std::dynamic_pointer_cast<RScalarUDFKernelState>(kernel->data);
+
+        cpp11::writable::list args_sexp(span.num_values());
+
+        for (int i = 0; i < span.num_values(); i++) {
+          const arrow::compute::ExecValue& exec_val = span[i];
+          if (exec_val.is_array()) {
+            std::shared_ptr<arrow::Array> array = exec_val.array.ToArray();
+            args_sexp[i] = cpp11::to_r6<arrow::Array>(array);

Review Comment:
   Certainly a matter of taste, but FTR I generally find it more readable to avoid "obvious" variable definitions. Of course I'm not maintaining this code so you may freely disregard :-)
   ```suggestion
               args_sexp[i] = cpp11::to_r6<arrow::Array>(exec_val.array.ToArray());
   ```



##########
r/src/compute.cpp:
##########
@@ -574,3 +576,171 @@ SEXP compute__CallFunction(std::string func_name, cpp11::list args, cpp11::list
 std::vector<std::string> compute__GetFunctionNames() {
   return arrow::compute::GetFunctionRegistry()->GetFunctionNames();
 }
+
+class RScalarUDFKernelState : public arrow::compute::KernelState {
+ public:
+  RScalarUDFKernelState(cpp11::sexp exec_func, cpp11::sexp resolver)
+      : exec_func_(exec_func), resolver_(resolver) {}
+
+  cpp11::function exec_func_;
+  cpp11::function resolver_;
+};
+
+arrow::Result<arrow::TypeHolder> ResolveScalarUDFOutputType(
+    arrow::compute::KernelContext* context,
+    const std::vector<arrow::TypeHolder>& input_types) {
+  return SafeCallIntoR<arrow::TypeHolder>(
+      [&]() -> arrow::TypeHolder {
+        auto kernel =
+            reinterpret_cast<const arrow::compute::ScalarKernel*>(context->kernel());
+        auto state = std::dynamic_pointer_cast<RScalarUDFKernelState>(kernel->data);
+
+        cpp11::writable::list input_types_sexp(input_types.size());
+        for (size_t i = 0; i < input_types.size(); i++) {
+          input_types_sexp[i] =
+              cpp11::to_r6<arrow::DataType>(input_types[i].GetSharedPtr());
+        }
+
+        cpp11::sexp output_type_sexp = state->resolver_(input_types_sexp);
+        if (!Rf_inherits(output_type_sexp, "DataType")) {
+          cpp11::stop("arrow_scalar_function resolver must return a DataType");
+        }
+
+        return arrow::TypeHolder(
+            cpp11::as_cpp<std::shared_ptr<arrow::DataType>>(output_type_sexp));
+      },
+      "resolve scalar user-defined function output data type");
+}
+
+arrow::Status CallRScalarUDF(arrow::compute::KernelContext* context,
+                             const arrow::compute::ExecSpan& span,
+                             arrow::compute::ExecResult* result) {
+  if (result->is_array_span()) {
+    return arrow::Status::NotImplemented("ArraySpan result from R scalar UDF");
+  }
+
+  return SafeCallIntoRVoid(
+      [&]() {
+        auto kernel =
+            reinterpret_cast<const arrow::compute::ScalarKernel*>(context->kernel());
+        auto state = std::dynamic_pointer_cast<RScalarUDFKernelState>(kernel->data);
+
+        cpp11::writable::list args_sexp(span.num_values());
+
+        for (int i = 0; i < span.num_values(); i++) {
+          const arrow::compute::ExecValue& exec_val = span[i];
+          if (exec_val.is_array()) {
+            std::shared_ptr<arrow::Array> array = exec_val.array.ToArray();
+            args_sexp[i] = cpp11::to_r6<arrow::Array>(array);
+          } else if (exec_val.is_scalar()) {
+            std::shared_ptr<arrow::Scalar> scalar = exec_val.scalar->GetSharedPtr();
+            args_sexp[i] = cpp11::to_r6<arrow::Scalar>(scalar);
+          }
+        }
+
+        cpp11::sexp batch_length_sexp = cpp11::as_sexp(span.length);
+
+        std::shared_ptr<arrow::DataType> output_type = result->type()->GetSharedPtr();
+        cpp11::sexp output_type_sexp = cpp11::to_r6<arrow::DataType>(output_type);
+        cpp11::writable::list udf_context = {batch_length_sexp, output_type_sexp};
+        udf_context.names() = {"batch_length", "output_type"};
+
+        cpp11::sexp func_result_sexp = state->exec_func_(udf_context, args_sexp);
+
+        if (Rf_inherits(func_result_sexp, "Array")) {
+          auto array = cpp11::as_cpp<std::shared_ptr<arrow::Array>>(func_result_sexp);
+
+          // handle an Array result of the wrong type
+          if (!result->type()->Equals(array->type())) {
+            arrow::Datum out = ValueOrStop(arrow::compute::Cast(array, result->type()));
+            std::shared_ptr<arrow::Array> out_array = out.make_array();
+            array.swap(out_array);
+          }
+
+          result->value = std::move(array->data());
+        } else if (Rf_inherits(func_result_sexp, "Scalar")) {
+          auto scalar = cpp11::as_cpp<std::shared_ptr<arrow::Scalar>>(func_result_sexp);
+
+          // handle a Scalar result of the wrong type
+          if (!result->type()->Equals(scalar->type)) {
+            arrow::Datum out = ValueOrStop(arrow::compute::Cast(scalar, result->type()));
+            std::shared_ptr<arrow::Scalar> out_scalar = out.scalar();
+            scalar.swap(out_scalar);
+          }
+
+          auto array = ValueOrStop(
+              arrow::MakeArrayFromScalar(*scalar, span.length, context->memory_pool()));
+          result->value = std::move(array->data());
+        } else {
+          cpp11::stop("arrow_scalar_function must return an Array or Scalar");
+        }
+      },
+      "execute scalar user-defined function");
+}
+
+// [[arrow::export]]
+void RegisterScalarUDF(std::string name, cpp11::sexp func_sexp) {
+  cpp11::list in_type_r(func_sexp.attr("in_type"));
+  cpp11::list out_type_r(func_sexp.attr("out_type"));
+  R_xlen_t n_kernels = in_type_r.size();
+
+  if (n_kernels == 0) {
+    cpp11::stop("Can't register user-defined function with zero kernels");
+  }
+
+  // compute the Arity from the list of input kernels
+  std::vector<int64_t> n_args(n_kernels);
+  for (R_xlen_t i = 0; i < n_kernels; i++) {
+    auto in_types = cpp11::as_cpp<std::shared_ptr<arrow::Schema>>(in_type_r[i]);
+    n_args[i] = in_types->num_fields();
+  }
+
+  const int64_t min_args = *std::min_element(n_args.begin(), n_args.end());
+  const int64_t max_args = *std::max_element(n_args.begin(), n_args.end());

Review Comment:
   I'm afraid I don't understand this. Is this assuming that different implementations of a function may take differing numbers of arguments, and that the function's arity is the union of all possible kernel arities? If so, this seems rather weird and clumsy IMHO.
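   For what it's worth, the code that follows rejects any mismatch (`min_args != max_args`), so the min/max computation reduces to "all kernels must have the same number of arguments". A minimal standalone sketch of that equivalent check, using a hypothetical `CommonArity()` helper and plain integers standing in for each schema's `num_fields()`:

   ```cpp
   #include <algorithm>
   #include <cstdint>
   #include <optional>
   #include <vector>

   // Hypothetical helper: each element of n_args is the number of fields in one
   // kernel's input schema (what RegisterScalarUDF() collects into n_args).
   // Requiring min == max is the same as requiring every element to equal the
   // first one, which is what this checks.
   std::optional<int64_t> CommonArity(const std::vector<int64_t>& n_args) {
     if (n_args.empty()) {
       return std::nullopt;  // zero kernels: nothing to register
     }
     const int64_t first = n_args.front();
     const bool all_equal =
         std::all_of(n_args.begin(), n_args.end(),
                     [first](int64_t n) { return n == first; });
     if (!all_equal) {
       return std::nullopt;  // kernels disagree: "variable number of arguments"
     }
     return first;
   }
   ```

   `CommonArity({3, 3, 3})` yields 3, while `CommonArity({2, 3})` yields `std::nullopt`, which corresponds to the error path in the PR.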



##########
r/src/compute.cpp:
##########
@@ -574,3 +576,171 @@ SEXP compute__CallFunction(std::string func_name, cpp11::list args, cpp11::list
 std::vector<std::string> compute__GetFunctionNames() {
   return arrow::compute::GetFunctionRegistry()->GetFunctionNames();
 }
+
+class RScalarUDFKernelState : public arrow::compute::KernelState {
+ public:
+  RScalarUDFKernelState(cpp11::sexp exec_func, cpp11::sexp resolver)
+      : exec_func_(exec_func), resolver_(resolver) {}
+
+  cpp11::function exec_func_;
+  cpp11::function resolver_;
+};
+
+arrow::Result<arrow::TypeHolder> ResolveScalarUDFOutputType(
+    arrow::compute::KernelContext* context,
+    const std::vector<arrow::TypeHolder>& input_types) {
+  return SafeCallIntoR<arrow::TypeHolder>(
+      [&]() -> arrow::TypeHolder {
+        auto kernel =
+            reinterpret_cast<const arrow::compute::ScalarKernel*>(context->kernel());
+        auto state = std::dynamic_pointer_cast<RScalarUDFKernelState>(kernel->data);
+
+        cpp11::writable::list input_types_sexp(input_types.size());
+        for (size_t i = 0; i < input_types.size(); i++) {
+          input_types_sexp[i] =
+              cpp11::to_r6<arrow::DataType>(input_types[i].GetSharedPtr());
+        }
+
+        cpp11::sexp output_type_sexp = state->resolver_(input_types_sexp);
+        if (!Rf_inherits(output_type_sexp, "DataType")) {
+          cpp11::stop("arrow_scalar_function resolver must return a DataType");
+        }
+
+        return arrow::TypeHolder(
+            cpp11::as_cpp<std::shared_ptr<arrow::DataType>>(output_type_sexp));
+      },
+      "resolve scalar user-defined function output data type");
+}
+
+arrow::Status CallRScalarUDF(arrow::compute::KernelContext* context,
+                             const arrow::compute::ExecSpan& span,
+                             arrow::compute::ExecResult* result) {
+  if (result->is_array_span()) {
+    return arrow::Status::NotImplemented("ArraySpan result from R scalar UDF");
+  }
+
+  return SafeCallIntoRVoid(
+      [&]() {
+        auto kernel =
+            reinterpret_cast<const arrow::compute::ScalarKernel*>(context->kernel());
+        auto state = std::dynamic_pointer_cast<RScalarUDFKernelState>(kernel->data);
+
+        cpp11::writable::list args_sexp(span.num_values());
+
+        for (int i = 0; i < span.num_values(); i++) {
+          const arrow::compute::ExecValue& exec_val = span[i];
+          if (exec_val.is_array()) {
+            std::shared_ptr<arrow::Array> array = exec_val.array.ToArray();
+            args_sexp[i] = cpp11::to_r6<arrow::Array>(array);
+          } else if (exec_val.is_scalar()) {
+            std::shared_ptr<arrow::Scalar> scalar = exec_val.scalar->GetSharedPtr();
+            args_sexp[i] = cpp11::to_r6<arrow::Scalar>(scalar);
+          }
+        }
+
+        cpp11::sexp batch_length_sexp = cpp11::as_sexp(span.length);
+
+        std::shared_ptr<arrow::DataType> output_type = result->type()->GetSharedPtr();
+        cpp11::sexp output_type_sexp = cpp11::to_r6<arrow::DataType>(output_type);
+        cpp11::writable::list udf_context = {batch_length_sexp, output_type_sexp};
+        udf_context.names() = {"batch_length", "output_type"};
+
+        cpp11::sexp func_result_sexp = state->exec_func_(udf_context, args_sexp);
+
+        if (Rf_inherits(func_result_sexp, "Array")) {
+          auto array = cpp11::as_cpp<std::shared_ptr<arrow::Array>>(func_result_sexp);
+
+          // handle an Array result of the wrong type
+          if (!result->type()->Equals(array->type())) {
+            arrow::Datum out = ValueOrStop(arrow::compute::Cast(array, result->type()));
+            std::shared_ptr<arrow::Array> out_array = out.make_array();
+            array.swap(out_array);
+          }
+
+          result->value = std::move(array->data());
+        } else if (Rf_inherits(func_result_sexp, "Scalar")) {
+          auto scalar = cpp11::as_cpp<std::shared_ptr<arrow::Scalar>>(func_result_sexp);
+
+          // handle a Scalar result of the wrong type
+          if (!result->type()->Equals(scalar->type)) {
+            arrow::Datum out = ValueOrStop(arrow::compute::Cast(scalar, result->type()));
+            std::shared_ptr<arrow::Scalar> out_scalar = out.scalar();
+            scalar.swap(out_scalar);
+          }
+
+          auto array = ValueOrStop(
+              arrow::MakeArrayFromScalar(*scalar, span.length, context->memory_pool()));
+          result->value = std::move(array->data());
+        } else {
+          cpp11::stop("arrow_scalar_function must return an Array or Scalar");
+        }
+      },
+      "execute scalar user-defined function");
+}
+
+// [[arrow::export]]
+void RegisterScalarUDF(std::string name, cpp11::sexp func_sexp) {
+  cpp11::list in_type_r(func_sexp.attr("in_type"));
+  cpp11::list out_type_r(func_sexp.attr("out_type"));
+  R_xlen_t n_kernels = in_type_r.size();
+
+  if (n_kernels == 0) {
+    cpp11::stop("Can't register user-defined function with zero kernels");
+  }
+
+  // compute the Arity from the list of input kernels
+  std::vector<int64_t> n_args(n_kernels);
+  for (R_xlen_t i = 0; i < n_kernels; i++) {
+    auto in_types = cpp11::as_cpp<std::shared_ptr<arrow::Schema>>(in_type_r[i]);
+    n_args[i] = in_types->num_fields();
+  }
+
+  const int64_t min_args = *std::min_element(n_args.begin(), n_args.end());
+  const int64_t max_args = *std::max_element(n_args.begin(), n_args.end());
+
+  // We can't currently handle variable numbers of arguments in a user-defined
+  // function and we don't have a mechanism for the user to specify a variable
+  // number of arguments at the end of a signature.
+  if (min_args != max_args) {
+    cpp11::stop(
+        "User-defined function with a variable number of arguments is not supported");
+  }
+
+  arrow::compute::Arity arity(min_args, false);
+
+  // The function documentation isn't currently accessible from R but is required
+  // for the C++ function constructor.

Review Comment:
   Is it possible to pass that documentation (at least the title and description) as plain strings?



##########
r/R/compute.R:
##########
@@ -307,3 +307,157 @@ cast_options <- function(safe = TRUE, ...) {
   )
   modifyList(opts, list(...))
 }
+
+#' Register user-defined functions
+#'
+#' These functions support calling R code from query engine execution
+#' (i.e., a [dplyr::mutate()] or [dplyr::filter()] on a [Table] or [Dataset]).
+#' Use [arrow_scalar_function()] to define an R function that accepts and
+#' returns R objects; use [arrow_advanced_scalar_function()] to define a
+#' lower-level function that operates directly on Arrow objects.
+#'
+#' @param name The function name to be used in the dplyr bindings
+#' @param scalar_function An object created with [arrow_scalar_function()]
+#'   or [arrow_advanced_scalar_function()].
+#' @param in_type A [DataType] of the input type or a [schema()]
+#'   for functions with more than one argument. This signature will be used
+#'   to determine if this function is appropriate for a given set of arguments.
+#'   If this function is appropriate for more than one signature, pass a
+#'   `list()` of the above.
+#' @param out_type A [DataType] of the output type or a function accepting
+#'   a single argument (`types`), which is a `list()` of [DataType]s. If a
+#'   function it must return a [DataType].
+#' @param fun An R function or rlang-style lambda expression. This function
+#'   will be called with R objects as arguments and must return an object
+#'   that can be converted to an [Array] using [as_arrow_array()]. Function
+#'   authors must take care to return an array castable to the output data
+#'   type specified by `out_type`.
+#' @param advanced_fun An R function or rlang-style lambda expression. This
+#'   function will be called with exactly two arguments: `kernel_context`,
+#'   which is a `list()` of objects giving information about the
+#'   execution context and `args`, which is a list of [Array] or [Scalar]
+#'   objects corresponding to the input arguments.
+#'
+#' @return
+#'   - `register_user_defined_function()`: `NULL`, invisibly
+#'   - `arrow_scalar_function()`: returns an object of class
+#'     "arrow_advanced_scalar_function" that can be passed to
+#'     `register_user_defined_function()`.
+#' @export
+#'
+#' @examplesIf .Machine$sizeof.pointer >= 8
+#' fun_wrapper <- arrow_scalar_function(
+#'   function(x, y, z) x + y + z,
+#'   schema(x = float64(), y = float64(), z = float64()),
+#'   float64()
+#' )
+#' register_user_defined_function(fun_wrapper, "example_add3")
+#'
+#' call_function(
+#'   "example_add3",
+#'   Scalar$create(1),
+#'   Scalar$create(2),
+#'   Array$create(3)
+#' )
+#'
+#' # use arrow_advanced_scalar_function() for a lower-level interface
+#' advanced_fun_wrapper <- arrow_advanced_scalar_function(
+#'   function(context, args) {

Review Comment:
   Is it useful for `args` to be packed into a list here? Or perhaps R doesn't allow creating varargs functions?



##########
r/src/safe-call-into-r.h:
##########
@@ -93,7 +100,8 @@ MainRThread& GetMainRThread();
 // a SEXP (use cpp11::as_cpp<T> to convert it to a C++ type inside
 // `fun`).
 template <typename T>
-arrow::Future<T> SafeCallIntoRAsync(std::function<arrow::Result<T>(void)> fun) {
+arrow::Future<T> SafeCallIntoRAsync(std::function<arrow::Result<T>(void)> fun,
+                                    std::string reason = "unspecified") {

Review Comment:
   Or perhaps the `SetError(e.token)` ensures that the error is kept around? But then `SafeCallIntoR` should not be called directly?



##########
r/src/safe-call-into-r.h:
##########
@@ -20,18 +20,25 @@
 
 #include "./arrow_types.h"
 
+#include <arrow/io/interfaces.h>
 #include <arrow/util/future.h>
 #include <arrow/util/thread_pool.h>
 
 #include <functional>
 #include <thread>
 
 // Unwind protection was added in R 3.5 and some calls here use it
-// and crash R in older versions (ARROW-16201). We use this define
-// to make sure we don't crash on R 3.4 and lower.
+// and crash R in older versions (ARROW-16201). Crashes also occur
+// on 32-bit R builds on R 3.6 and lower.
+static inline bool CanSafeCallIntoR() {
 #if defined(HAS_UNWIND_PROTECT)
-#define HAS_SAFE_CALL_INTO_R
+  cpp11::function on_old_windows_fun = cpp11::package("arrow")["on_old_windows"];
+  bool on_old_windows = on_old_windows_fun();

Review Comment:
   Just out of curiosity, is it costly to call into the R interpreter like this every time `CanSafeCallIntoR()` is called? Should the result perhaps be cached in a local static variable?
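   A minimal sketch of the local-static caching suggested above, in plain C++. `QueryInterpreterIsOldWindows()` is a hypothetical stand-in for the cpp11 call to the arrow package's `on_old_windows()`, and the real function would still need its `HAS_UNWIND_PROTECT` guard. It assumes the first call happens on the main R thread, since that first call still enters the interpreter:

   ```cpp
   #include <atomic>

   // Hypothetical stand-in for the expensive part of CanSafeCallIntoR(): the
   // cpp11 call into the R interpreter. The counter only makes caching observable.
   std::atomic<int> g_interpreter_calls{0};

   bool QueryInterpreterIsOldWindows() {
     ++g_interpreter_calls;
     return false;  // pretend we are not on an old Windows build
   }

   // The function-local static is initialized exactly once, on the first call;
   // since C++11 that initialization is thread-safe. Later calls only read it.
   bool CanSafeCallIntoRCached() {
     static const bool can_call = !QueryInterpreterIsOldWindows();
     return can_call;
   }
   ```

   Repeated calls then cost a single load instead of a round trip into R.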



##########
r/src/compute.cpp:
##########
@@ -574,3 +576,171 @@ SEXP compute__CallFunction(std::string func_name, cpp11::list args, cpp11::list
 std::vector<std::string> compute__GetFunctionNames() {
   return arrow::compute::GetFunctionRegistry()->GetFunctionNames();
 }
+
+class RScalarUDFKernelState : public arrow::compute::KernelState {
+ public:
+  RScalarUDFKernelState(cpp11::sexp exec_func, cpp11::sexp resolver)
+      : exec_func_(exec_func), resolver_(resolver) {}
+
+  cpp11::function exec_func_;
+  cpp11::function resolver_;
+};
+
+arrow::Result<arrow::TypeHolder> ResolveScalarUDFOutputType(
+    arrow::compute::KernelContext* context,
+    const std::vector<arrow::TypeHolder>& input_types) {
+  return SafeCallIntoR<arrow::TypeHolder>(
+      [&]() -> arrow::TypeHolder {
+        auto kernel =
+            reinterpret_cast<const arrow::compute::ScalarKernel*>(context->kernel());
+        auto state = std::dynamic_pointer_cast<RScalarUDFKernelState>(kernel->data);
+
+        cpp11::writable::list input_types_sexp(input_types.size());
+        for (size_t i = 0; i < input_types.size(); i++) {
+          input_types_sexp[i] =
+              cpp11::to_r6<arrow::DataType>(input_types[i].GetSharedPtr());
+        }
+
+        cpp11::sexp output_type_sexp = state->resolver_(input_types_sexp);
+        if (!Rf_inherits(output_type_sexp, "DataType")) {
+          cpp11::stop("arrow_scalar_function resolver must return a DataType");
+        }
+
+        return arrow::TypeHolder(
+            cpp11::as_cpp<std::shared_ptr<arrow::DataType>>(output_type_sexp));
+      },
+      "resolve scalar user-defined function output data type");
+}
+
+arrow::Status CallRScalarUDF(arrow::compute::KernelContext* context,
+                             const arrow::compute::ExecSpan& span,
+                             arrow::compute::ExecResult* result) {
+  if (result->is_array_span()) {
+    return arrow::Status::NotImplemented("ArraySpan result from R scalar UDF");
+  }
+
+  return SafeCallIntoRVoid(
+      [&]() {
+        auto kernel =
+            reinterpret_cast<const arrow::compute::ScalarKernel*>(context->kernel());
+        auto state = std::dynamic_pointer_cast<RScalarUDFKernelState>(kernel->data);
+
+        cpp11::writable::list args_sexp(span.num_values());
+
+        for (int i = 0; i < span.num_values(); i++) {
+          const arrow::compute::ExecValue& exec_val = span[i];
+          if (exec_val.is_array()) {
+            std::shared_ptr<arrow::Array> array = exec_val.array.ToArray();
+            args_sexp[i] = cpp11::to_r6<arrow::Array>(array);
+          } else if (exec_val.is_scalar()) {
+            std::shared_ptr<arrow::Scalar> scalar = exec_val.scalar->GetSharedPtr();
+            args_sexp[i] = cpp11::to_r6<arrow::Scalar>(scalar);
+          }
+        }
+
+        cpp11::sexp batch_length_sexp = cpp11::as_sexp(span.length);
+
+        std::shared_ptr<arrow::DataType> output_type = result->type()->GetSharedPtr();
+        cpp11::sexp output_type_sexp = cpp11::to_r6<arrow::DataType>(output_type);
+        cpp11::writable::list udf_context = {batch_length_sexp, output_type_sexp};
+        udf_context.names() = {"batch_length", "output_type"};
+
+        cpp11::sexp func_result_sexp = state->exec_func_(udf_context, args_sexp);
+
+        if (Rf_inherits(func_result_sexp, "Array")) {
+          auto array = cpp11::as_cpp<std::shared_ptr<arrow::Array>>(func_result_sexp);
+
+          // handle an Array result of the wrong type
+          if (!result->type()->Equals(array->type())) {
+            arrow::Datum out = ValueOrStop(arrow::compute::Cast(array, result->type()));
+            std::shared_ptr<arrow::Array> out_array = out.make_array();
+            array.swap(out_array);
+          }
+
+          result->value = std::move(array->data());
+        } else if (Rf_inherits(func_result_sexp, "Scalar")) {
+          auto scalar = cpp11::as_cpp<std::shared_ptr<arrow::Scalar>>(func_result_sexp);
+
+          // handle a Scalar result of the wrong type
+          if (!result->type()->Equals(scalar->type)) {
+            arrow::Datum out = ValueOrStop(arrow::compute::Cast(scalar, result->type()));
+            std::shared_ptr<arrow::Scalar> out_scalar = out.scalar();
+            scalar.swap(out_scalar);
+          }
+
+          auto array = ValueOrStop(
+              arrow::MakeArrayFromScalar(*scalar, span.length, context->memory_pool()));
+          result->value = std::move(array->data());
+        } else {
+          cpp11::stop("arrow_scalar_function must return an Array or Scalar");
+        }
+      },
+      "execute scalar user-defined function");
+}
+
+// [[arrow::export]]
+void RegisterScalarUDF(std::string name, cpp11::sexp func_sexp) {
+  cpp11::list in_type_r(func_sexp.attr("in_type"));
+  cpp11::list out_type_r(func_sexp.attr("out_type"));
+  R_xlen_t n_kernels = in_type_r.size();
+
+  if (n_kernels == 0) {
+    cpp11::stop("Can't register user-defined function with zero kernels");

Review Comment:
   I don't understand: does this function accept multiple kernels at once? I don't see any example thereof. Am I missing something?
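   Judging from the roxygen docs for `in_type` earlier in this PR ("If this function is appropriate for more than one signature, pass a `list()` of the above"), each element of such a list appears to become one kernel of the same function. A simplified, hypothetical model of that layout, where strings stand in for `arrow::DataType` and dispatch only considers exact matches (any implicit-cast handling is ignored):

   ```cpp
   #include <string>
   #include <vector>

   // Hypothetical, simplified model: one function owning several kernels.
   // Type names stand in for arrow::DataType; each kernel has one signature.
   struct KernelSketch {
     std::vector<std::string> in_types;  // one entry per argument
     std::string out_type;
   };

   struct FunctionSketch {
     std::string name;
     std::vector<KernelSketch> kernels;  // e.g. one per element of in_type

     // Exact-match dispatch: the first kernel whose signature equals the
     // argument types, or nullptr if none applies (no implicit casts here).
     const KernelSketch* DispatchExact(
         const std::vector<std::string>& arg_types) const {
       for (const KernelSketch& kernel : kernels) {
         if (kernel.in_types == arg_types) {
           return &kernel;
         }
       }
       return nullptr;
     }
   };
   ```

   Under this model, registering `example_add3` with both a float64 and an int32 signature gives one function with two kernels, and the argument types at call time pick between them.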



##########
r/src/safe-call-into-r.h:
##########
@@ -93,7 +100,8 @@ MainRThread& GetMainRThread();
 // a SEXP (use cpp11::as_cpp<T> to convert it to a C++ type inside
 // `fun`).
 template <typename T>
-arrow::Future<T> SafeCallIntoRAsync(std::function<arrow::Result<T>(void)> fun) {
+arrow::Future<T> SafeCallIntoRAsync(std::function<arrow::Result<T>(void)> fun,
+                                    std::string reason = "unspecified") {

Review Comment:
   Not related to this PR, but is it possible to properly propagate the R error below instead of some placeholder `UnknownError` (other languages such as Python and Java try to do that)?
   
   One possibility is to use `StatusDetail` to keep as much R information as desired (which may also allow round-tripping later on, i.e. re-throwing the R error to the R interpreter from an outer level).
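   A standalone sketch of the `StatusDetail` idea. Arrow's `Status` can carry a `std::shared_ptr<StatusDetail>` (exposed via `Status::detail()`); the classes below are hypothetical stand-ins with the same shape, not Arrow's, and `RErrorDetailSketch` is an illustrative name. It only shows the mechanism: attach the R error information where the error is caught, then recover it with a `dynamic_cast` at an outer level:

   ```cpp
   #include <memory>
   #include <string>
   #include <utility>

   // Hypothetical stand-ins with the same shape as arrow::StatusDetail and
   // arrow::Status (NOT Arrow's classes): a status may carry a typed detail.
   class StatusDetailSketch {
    public:
     virtual ~StatusDetailSketch() = default;
     virtual const char* type_id() const = 0;
     virtual std::string ToString() const = 0;
   };

   struct StatusSketch {
     std::string message;
     std::shared_ptr<StatusDetailSketch> detail;  // null when there is none
   };

   // Hypothetical detail keeping what the R side knows about the error.
   // A real version could instead hold (and protect) the R condition object
   // itself, so that an outer level can re-throw it unchanged.
   class RErrorDetailSketch : public StatusDetailSketch {
    public:
     RErrorDetailSketch(std::string r_class, std::string r_message)
         : r_class_(std::move(r_class)), r_message_(std::move(r_message)) {}

     const char* type_id() const override { return "r::RErrorDetailSketch"; }
     std::string ToString() const override { return r_class_ + ": " + r_message_; }

     const std::string& r_class() const { return r_class_; }
     const std::string& r_message() const { return r_message_; }

    private:
     std::string r_class_;
     std::string r_message_;
   };

   // Inner level (where the R error is caught): attach the R information
   // instead of returning a placeholder error.
   StatusSketch FromRError(const std::string& r_class, const std::string& r_message) {
     return {"R error: " + r_message,
             std::make_shared<RErrorDetailSketch>(r_class, r_message)};
   }

   // Outer level: recover the R-specific detail, or nullptr if the status did
   // not originate from R.
   const RErrorDetailSketch* GetRErrorDetail(const StatusSketch& status) {
     return dynamic_cast<const RErrorDetailSketch*>(status.detail.get());
   }
   ```

   An outer caller that gets a non-null detail back can re-raise the error with its original R class and message instead of a generic one.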


