Posted to commits@arrow.apache.org by we...@apache.org on 2023/01/26 13:24:38 UTC
[arrow] branch master updated: GH-32916: [C++] [Python] User-defined tabular functions (#14682)
This is an automated email from the ASF dual-hosted git repository.
westonpace pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new a1d9b511da GH-32916: [C++] [Python] User-defined tabular functions (#14682)
a1d9b511da is described below
commit a1d9b511dab6fed7e0d074e8abc6dbc26f50a460
Author: rtpsw <rt...@hotmail.com>
AuthorDate: Thu Jan 26 15:24:28 2023 +0200
GH-32916: [C++] [Python] User-defined tabular functions (#14682)
See https://issues.apache.org/jira/browse/ARROW-17676
* Closes: #32916
Lead-authored-by: Yaron Gvili <rt...@hotmail.com>
Co-authored-by: rtpsw <rt...@hotmail.com>
Signed-off-by: Weston Pace <we...@gmail.com>
---
 cpp/src/arrow/type.cc                        |  10 ++
 cpp/src/arrow/type.h                         |   3 +
 python/pyarrow/_compute.pyx                  | 186 +++++++++++++++++++++-
 python/pyarrow/_dataset.pyx                  |  13 +-
 python/pyarrow/compute.py                    |   2 +
 python/pyarrow/includes/libarrow.pxd         |  25 ++-
 python/pyarrow/includes/libarrow_dataset.pxd |   7 -
 python/pyarrow/src/arrow/python/udf.cc       | 228 ++++++++++++++++++++++++---
 python/pyarrow/src/arrow/python/udf.h        |  21 ++-
 python/pyarrow/table.pxi                     |  14 ++
 python/pyarrow/tests/test_udf.py             | 123 ++++++++++++++-
 11 files changed, 566 insertions(+), 66 deletions(-)
diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc
index cc31735512..3514d0538f 100644
--- a/cpp/src/arrow/type.cc
+++ b/cpp/src/arrow/type.cc
@@ -452,6 +452,16 @@ std::string TypeHolder::ToString(const std::vector<TypeHolder>& types) {
return ss.str();
}
+std::vector<TypeHolder> TypeHolder::FromTypes(
+ const std::vector<std::shared_ptr<DataType>>& types) {
+ std::vector<TypeHolder> type_holders;
+ type_holders.reserve(types.size());
+ for (const auto& type : types) {
+ type_holders.emplace_back(type);
+ }
+ return type_holders;
+}
+
// ----------------------------------------------------------------------
FloatingPointType::Precision HalfFloatType::precision() const {
diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h
index 1e16806947..ee5bc6ef37 100644
--- a/cpp/src/arrow/type.h
+++ b/cpp/src/arrow/type.h
@@ -264,6 +264,9 @@ struct ARROW_EXPORT TypeHolder {
}
static std::string ToString(const std::vector<TypeHolder>&);
+
+ static std::vector<TypeHolder> FromTypes(
+ const std::vector<std::shared_ptr<DataType>>& types);
};
ARROW_EXPORT
diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx
index c75c5bf189..283f532837 100644
--- a/python/pyarrow/_compute.pyx
+++ b/python/pyarrow/_compute.pyx
@@ -36,6 +36,18 @@ import inspect
import numpy as np
+def _forbid_instantiation(klass, subclasses_instead=True):
+ msg = '{} is an abstract class thus cannot be initialized.'.format(
+ klass.__name__
+ )
+ if subclasses_instead:
+ subclasses = [cls.__name__ for cls in klass.__subclasses__()]
+ msg += ' Use one of the subclasses instead: {}'.format(
+ ', '.join(subclasses)
+ )
+ raise TypeError(msg)
+
+
cdef wrap_scalar_function(const shared_ptr[CFunction]& sp_func):
"""
Wrap a C++ scalar Function in a ScalarFunction object.
@@ -2574,7 +2586,7 @@ cdef object box_scalar_udf_context(const CScalarUdfContext& c_context):
return context
-cdef _scalar_udf_callback(user_function, const CScalarUdfContext& c_context, inputs):
+cdef _udf_callback(user_function, const CScalarUdfContext& c_context, inputs):
"""
Helper callback function that wraps the C++ CScalarUdfContext in a Python
ScalarUdfContext and invokes the user function.
@@ -2591,8 +2603,30 @@ def _get_scalar_udf_context(memory_pool, batch_length):
return context
-def register_scalar_function(func, function_name, function_doc, in_types,
- out_type):
+ctypedef CStatus (*CRegisterUdf)(PyObject* function, function[CallbackUdf] wrapper,
+ const CUdfOptions& options, CFunctionRegistry* registry)
+
+cdef class RegisterUdf(_Weakrefable):
+ cdef CRegisterUdf register_func
+
+ cdef void init(self, const CRegisterUdf register_func):
+ self.register_func = register_func
+
+
+cdef get_register_scalar_function():
+ cdef RegisterUdf reg = RegisterUdf.__new__(RegisterUdf)
+ reg.register_func = RegisterScalarFunction
+ return reg
+
+
+cdef get_register_tabular_function():
+ cdef RegisterUdf reg = RegisterUdf.__new__(RegisterUdf)
+ reg.register_func = RegisterTabularFunction
+ return reg
+
+
+def register_scalar_function(func, function_name, function_doc, in_types, out_type,
+ func_registry=None):
"""
Register a user-defined scalar function.
@@ -2633,6 +2667,8 @@ def register_scalar_function(func, function_name, function_doc, in_types,
arity.
out_type : DataType
Output type of the function.
+ func_registry : FunctionRegistry
+ Optional function registry to use instead of the default global one.
Examples
--------
@@ -2662,14 +2698,106 @@ def register_scalar_function(func, function_name, function_doc, in_types,
21
]
"""
+ return _register_scalar_like_function(get_register_scalar_function(),
+ func, function_name, function_doc, in_types,
+ out_type, func_registry)
+
+
+def register_tabular_function(func, function_name, function_doc, in_types, out_type,
+ func_registry=None):
+ """
+ Register a user-defined tabular function.
+
+ A tabular function is one accepting a context argument of type
+ ScalarUdfContext and returning a generator of struct arrays.
+ The in_types argument must be empty and the out_type argument
+ specifies a schema. Each struct array must have field types
+ corresponding to the schema.
+
+ Parameters
+ ----------
+ func : callable
+ A callable implementing the user-defined function.
+ The only argument is the context argument of type
+ ScalarUdfContext. It must return a callable that
+ returns on each invocation a StructArray matching
+ the out_type, where an empty array indicates the end.
+ function_name : str
+ Name of the function. This name must be globally unique.
+ function_doc : dict
+ A dictionary object with keys "summary" (str),
+ and "description" (str).
+ in_types : Dict[str, DataType]
+ Must be an empty dictionary (reserved for future use).
+ out_type : Union[Schema, DataType]
+ Schema of the function's output, or a corresponding flat struct type.
+ func_registry : FunctionRegistry
+ Optional function registry to use instead of the default global one.
+ """
cdef:
+ shared_ptr[CSchema] c_schema
+ shared_ptr[CDataType] c_type
+
+ if isinstance(out_type, Schema):
+ c_schema = pyarrow_unwrap_schema(out_type)
+ with nogil:
+ c_type = <shared_ptr[CDataType]>make_shared[CStructType](deref(c_schema).fields())
+ out_type = pyarrow_wrap_data_type(c_type)
+ return _register_scalar_like_function(get_register_tabular_function(),
+ func, function_name, function_doc, in_types,
+ out_type, func_registry)
+
+
+def _register_scalar_like_function(register_func, func, function_name, function_doc, in_types,
+ out_type, func_registry=None):
+ """
+ Register a user-defined scalar-like function.
+
+ A scalar-like function is a callable accepting a first
+ context argument of type ScalarUdfContext as well as
+ possibly additional Arrow arguments, and returning an
+ Arrow result appropriate for the kind of function.
+ Scalar functions and tabular functions are examples
+ of scalar-like functions.
+ This function is normally not called directly but via
+ register_scalar_function or register_tabular_function.
+
+ Parameters
+ ----------
+ register_func : object
+ An object holding a CRegisterUdf in a "register_func" attribute. Use
+ get_register_scalar_function() for a scalar function and
+ get_register_tabular_function() for a tabular function.
+ func : callable
+ A callable implementing the user-defined function.
+ See register_scalar_function and
+ register_tabular_function for details.
+
+ function_name : str
+ Name of the function. This name must be globally unique.
+ function_doc : dict
+ A dictionary object with keys "summary" (str),
+ and "description" (str).
+ in_types : Dict[str, DataType]
+ A dictionary mapping function argument names to
+ their respective DataType.
+ See register_scalar_function and
+ register_tabular_function for details.
+ out_type : DataType
+ Output type of the function.
+ func_registry : FunctionRegistry
+ Optional function registry to use instead of the default global one.
+ """
+ cdef:
+ CRegisterUdf c_register_func
c_string c_func_name
CArity c_arity
CFunctionDoc c_func_doc
vector[shared_ptr[CDataType]] c_in_types
PyObject* c_function
shared_ptr[CDataType] c_out_type
- CScalarUdfOptions c_options
+ CUdfOptions c_options
+ CFunctionRegistry* c_func_registry
if callable(func):
c_function = <PyObject*>func
@@ -2711,5 +2839,51 @@ def register_scalar_function(func, function_name, function_doc, in_types,
c_options.input_types = c_in_types
c_options.output_type = c_out_type
- check_status(RegisterScalarFunction(c_function,
- <function[CallbackUdf]> &_scalar_udf_callback, c_options))
+ if func_registry is None:
+ c_func_registry = NULL
+ else:
+ c_func_registry = (<FunctionRegistry>func_registry).registry
+
+ c_register_func = (<RegisterUdf>register_func).register_func
+
+ check_status(c_register_func(c_function,
+ <function[CallbackUdf]> &_udf_callback,
+ c_options, c_func_registry))
+
+
+def call_tabular_function(function_name, args=None, func_registry=None):
+ """
+ Get a record batch iterator from a tabular function.
+
+ Parameters
+ ----------
+ function_name : str
+ Name of the function.
+ args : iterable
+ The arguments to pass to the function. Accepted types depend
+ on the specific function. Currently, only empty args are supported.
+ func_registry : FunctionRegistry
+ Optional function registry to use instead of the default global one.
+ """
+ cdef:
+ c_string c_func_name
+ vector[CDatum] c_args
+ CFunctionRegistry* c_func_registry
+ shared_ptr[CRecordBatchReader] c_reader
+ RecordBatchReader reader
+
+ c_func_name = tobytes(function_name)
+ if func_registry is None:
+ c_func_registry = NULL
+ else:
+ c_func_registry = (<FunctionRegistry>func_registry).registry
+ if args is None:
+ args = []
+ _pack_compute_args(args, &c_args)
+
+ with nogil:
+ c_reader = GetResultValue(CallTabularFunction(
+ c_func_name, c_args, c_func_registry))
+ reader = RecordBatchReader.__new__(RecordBatchReader)
+ reader.reader = c_reader
+ return RecordBatchReader.from_batches(pyarrow_wrap_schema(deref(c_reader).schema()), reader)
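
A minimal usage sketch of the new tabular-function API added above, assuming
a pyarrow build that includes this patch; the function name "my_tabular" and
the column "x" are invented for the example:

    import pyarrow as pa
    import pyarrow.compute as pc

    schema = pa.schema([("x", pa.int32())])

    def tabular_factory(ctx):
        # The registered callable receives a ScalarUdfContext and must
        # return another callable; each invocation of that callable
        # produces a StructArray matching out_type, and an empty array
        # signals the end of the stream.
        batches = [
            pa.RecordBatch.from_arrays([pa.array([1, 2], pa.int32())],
                                       schema=schema),
            pa.RecordBatch.from_arrays([pa.array([], pa.int32())],
                                       schema=schema),
        ]
        it = iter(batches)
        return lambda ctx: next(it).to_struct_array()

    func_doc = {"summary": "example tabular UDF",
                "description": "emits one two-row batch, then ends"}
    # in_types must be an empty dict; out_type may be given as a Schema,
    # which is converted to the corresponding struct type on registration.
    pc.register_tabular_function(tabular_factory, "my_tabular", func_doc,
                                 {}, schema)

    for batch in pc.call_tabular_function("my_tabular"):
        print(batch)

Since function names must be globally unique, a given name can be registered
only once per process unless an explicit func_registry is passed.
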
diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx
index 5f1610c384..38ff60f380 100644
--- a/python/pyarrow/_dataset.pyx
+++ b/python/pyarrow/_dataset.pyx
@@ -32,24 +32,13 @@ from pyarrow.lib cimport *
from pyarrow.lib import ArrowTypeError, frombytes, tobytes, _pc
from pyarrow.includes.libarrow_dataset cimport *
from pyarrow._compute cimport Expression, _bind
+from pyarrow._compute import _forbid_instantiation
from pyarrow._fs cimport FileSystem, FileInfo, FileSelector
from pyarrow._csv cimport (
ConvertOptions, ParseOptions, ReadOptions, WriteOptions)
from pyarrow.util import _is_iterable, _is_path_like, _stringify_path
-def _forbid_instantiation(klass, subclasses_instead=True):
- msg = '{} is an abstract class thus cannot be initialized.'.format(
- klass.__name__
- )
- if subclasses_instead:
- subclasses = [cls.__name__ for cls in klass.__subclasses__()]
- msg += ' Use one of the subclasses instead: {}'.format(
- ', '.join(subclasses)
- )
- raise TypeError(msg)
-
-
_orc_fileformat = None
_orc_imported = False
diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py
index 1ee6c40f42..f455b81411 100644
--- a/python/pyarrow/compute.py
+++ b/python/pyarrow/compute.py
@@ -80,7 +80,9 @@ from pyarrow._compute import ( # noqa
list_functions,
_group_by,
# Udf
+ call_tabular_function,
register_scalar_function,
+ register_tabular_function,
ScalarUdfContext,
# Expressions
Expression,
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index 2d87971422..82358beea1 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -480,6 +480,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
vector[shared_ptr[CField]] GetAllFieldsByName(const c_string& name)
int GetFieldIndex(const c_string& name)
vector[int] GetAllFieldIndices(const c_string& name)
+ const vector[shared_ptr[CField]] fields()
int num_fields()
c_string ToString()
@@ -800,6 +801,8 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
const shared_ptr[CSchema]& schema, int64_t num_rows,
const vector[shared_ptr[CArray]]& columns)
+ CResult[shared_ptr[CStructArray]] ToStructArray() const
+
@staticmethod
CResult[shared_ptr[CRecordBatch]] FromStructArray(
const shared_ptr[CArray]& array)
@@ -2805,12 +2808,20 @@ cdef extern from "arrow/util/byte_size.h" namespace "arrow::util" nogil:
ctypedef PyObject* CallbackUdf(object user_function, const CScalarUdfContext& context, object inputs)
-cdef extern from "arrow/python/udf.h" namespace "arrow::py":
+
+cdef extern from "arrow/api.h" namespace "arrow" nogil:
+
+ cdef cppclass CRecordBatchIterator "arrow::RecordBatchIterator"(
+ CIterator[shared_ptr[CRecordBatch]]):
+ pass
+
+
+cdef extern from "arrow/python/udf.h" namespace "arrow::py" nogil:
cdef cppclass CScalarUdfContext" arrow::py::ScalarUdfContext":
CMemoryPool *pool
int64_t batch_length
- cdef cppclass CScalarUdfOptions" arrow::py::ScalarUdfOptions":
+ cdef cppclass CUdfOptions" arrow::py::UdfOptions":
c_string func_name
CArity arity
CFunctionDoc func_doc
@@ -2818,4 +2829,12 @@ cdef extern from "arrow/python/udf.h" namespace "arrow::py":
shared_ptr[CDataType] output_type
CStatus RegisterScalarFunction(PyObject* function,
- function[CallbackUdf] wrapper, const CScalarUdfOptions& options)
+ function[CallbackUdf] wrapper, const CUdfOptions& options,
+ CFunctionRegistry* registry)
+
+ CStatus RegisterTabularFunction(PyObject* function,
+ function[CallbackUdf] wrapper, const CUdfOptions& options,
+ CFunctionRegistry* registry)
+
+ CResult[shared_ptr[CRecordBatchReader]] CallTabularFunction(
+ const c_string& func_name, const vector[CDatum]& args, CFunctionRegistry* registry)
diff --git a/python/pyarrow/includes/libarrow_dataset.pxd b/python/pyarrow/includes/libarrow_dataset.pxd
index b75eafcdee..1603797084 100644
--- a/python/pyarrow/includes/libarrow_dataset.pxd
+++ b/python/pyarrow/includes/libarrow_dataset.pxd
@@ -25,13 +25,6 @@ from pyarrow.includes.libarrow cimport *
from pyarrow.includes.libarrow_fs cimport *
-cdef extern from "arrow/api.h" namespace "arrow" nogil:
-
- cdef cppclass CRecordBatchIterator "arrow::RecordBatchIterator"(
- CIterator[shared_ptr[CRecordBatch]]):
- pass
-
-
cdef extern from "arrow/dataset/plan.h" namespace "arrow::dataset::internal" nogil:
cdef void Initialize()
diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc
index 81bf47c0ad..763bdf5a03 100644
--- a/python/pyarrow/src/arrow/python/udf.cc
+++ b/python/pyarrow/src/arrow/python/udf.cc
@@ -17,37 +17,121 @@
#include "arrow/python/udf.h"
#include "arrow/compute/function.h"
+#include "arrow/compute/kernel.h"
#include "arrow/python/common.h"
+#include "arrow/util/checked_cast.h"
namespace arrow {
-
-using compute::ExecResult;
-using compute::ExecSpan;
-
namespace py {
namespace {
-struct PythonUdf : public compute::KernelState {
- ScalarUdfWrapperCallback cb;
+struct PythonUdfKernelState : public compute::KernelState {
+ explicit PythonUdfKernelState(std::shared_ptr<OwnedRefNoGIL> function)
+ : function(function) {
+ Py_INCREF(function->obj());
+ }
+
+ // function needs to be destroyed at process exit
+ // and Python may no longer be initialized.
+ ~PythonUdfKernelState() {
+ if (_Py_IsFinalizing()) {
+ function->detach();
+ }
+ }
+
std::shared_ptr<OwnedRefNoGIL> function;
- std::shared_ptr<DataType> output_type;
+};
- PythonUdf(ScalarUdfWrapperCallback cb, std::shared_ptr<OwnedRefNoGIL> function,
- const std::shared_ptr<DataType>& output_type)
- : cb(cb), function(function), output_type(output_type) {}
+struct PythonUdfKernelInit {
+ explicit PythonUdfKernelInit(std::shared_ptr<OwnedRefNoGIL> function)
+ : function(function) {
+ Py_INCREF(function->obj());
+ }
// function needs to be destroyed at process exit
// and Python may no longer be initialized.
- ~PythonUdf() {
+ ~PythonUdfKernelInit() {
if (_Py_IsFinalizing()) {
function->detach();
}
}
- Status Exec(compute::KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
+ Result<std::unique_ptr<compute::KernelState>> operator()(
+ compute::KernelContext*, const compute::KernelInitArgs&) {
+ return std::make_unique<PythonUdfKernelState>(function);
+ }
+
+ std::shared_ptr<OwnedRefNoGIL> function;
+};
+
+struct PythonTableUdfKernelInit {
+ PythonTableUdfKernelInit(std::shared_ptr<OwnedRefNoGIL> function_maker,
+ UdfWrapperCallback cb)
+ : function_maker(function_maker), cb(cb) {
+ Py_INCREF(function_maker->obj());
+ }
+
+ // function needs to be destroyed at process exit
+ // and Python may no longer be initialized.
+ ~PythonTableUdfKernelInit() {
+ if (_Py_IsFinalizing()) {
+ function_maker->detach();
+ }
+ }
+
+ Result<std::unique_ptr<compute::KernelState>> operator()(
+ compute::KernelContext* ctx, const compute::KernelInitArgs&) {
+ ScalarUdfContext scalar_udf_context{ctx->memory_pool(), /*batch_length=*/0};
+ std::unique_ptr<OwnedRefNoGIL> function;
+ RETURN_NOT_OK(SafeCallIntoPython([this, &scalar_udf_context, &function] {
+ OwnedRef empty_tuple(PyTuple_New(0));
+ function = std::make_unique<OwnedRefNoGIL>(
+ cb(function_maker->obj(), scalar_udf_context, empty_tuple.obj()));
+ RETURN_NOT_OK(CheckPyError());
+ return Status::OK();
+ }));
+ if (!PyCallable_Check(function->obj())) {
+ return Status::TypeError("Expected a callable Python object.");
+ }
+ return std::make_unique<PythonUdfKernelState>(
+ std::move(function));
+ }
+
+ std::shared_ptr<OwnedRefNoGIL> function_maker;
+ UdfWrapperCallback cb;
+};
+
+struct PythonUdf : public PythonUdfKernelState {
+ PythonUdf(std::shared_ptr<OwnedRefNoGIL> function, UdfWrapperCallback cb,
+ std::vector<TypeHolder> input_types, compute::OutputType output_type)
+ : PythonUdfKernelState(function),
+ cb(cb),
+ input_types(input_types),
+ output_type(output_type) {}
+
+ UdfWrapperCallback cb;
+ std::vector<TypeHolder> input_types;
+ compute::OutputType output_type;
+ TypeHolder resolved_type;
+
+ Result<TypeHolder> ResolveType(compute::KernelContext* ctx,
+ const std::vector<TypeHolder>& types) {
+ if (input_types == types) {
+ if (!resolved_type) {
+ ARROW_ASSIGN_OR_RAISE(resolved_type, output_type.Resolve(ctx, input_types));
+ }
+ return resolved_type;
+ }
+ return output_type.Resolve(ctx, types);
+ }
+
+ Status Exec(compute::KernelContext* ctx, const compute::ExecSpan& batch,
+ compute::ExecResult* out) {
+ auto state = arrow::internal::checked_cast<PythonUdfKernelState*>(ctx->state());
+ std::shared_ptr<OwnedRefNoGIL>& function = state->function;
const int num_args = batch.num_values();
- ScalarUdfContext udf_context{ctx->memory_pool(), batch.length};
+ ScalarUdfContext scalar_udf_context{ctx->memory_pool(), batch.length};
OwnedRef arg_tuple(PyTuple_New(num_args));
RETURN_NOT_OK(CheckPyError());
@@ -63,13 +147,17 @@ struct PythonUdf : public compute::KernelState {
}
}
- OwnedRef result(cb(function->obj(), udf_context, arg_tuple.obj()));
+ OwnedRef result(cb(function->obj(), scalar_udf_context, arg_tuple.obj()));
RETURN_NOT_OK(CheckPyError());
// unwrapping the output for expected output type
if (is_array(result.obj())) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Array> val, unwrap_array(result.obj()));
- if (!output_type->Equals(*val->type())) {
- return Status::TypeError("Expected output datatype ", output_type->ToString(),
+ ARROW_ASSIGN_OR_RAISE(TypeHolder type, ResolveType(ctx, batch.GetTypes()));
+ if (type.type == NULLPTR) {
+ return Status::TypeError("expected output datatype is null");
+ }
+ if (*type.type != *val->type()) {
+ return Status::TypeError("Expected output datatype ", type.type->ToString(),
", but function returned datatype ",
val->type()->ToString());
}
@@ -83,16 +171,15 @@ struct PythonUdf : public compute::KernelState {
}
};
-Status PythonUdfExec(compute::KernelContext* ctx, const ExecSpan& batch,
- ExecResult* out) {
+Status PythonUdfExec(compute::KernelContext* ctx, const compute::ExecSpan& batch,
+ compute::ExecResult* out) {
auto udf = static_cast<PythonUdf*>(ctx->kernel()->data.get());
return SafeCallIntoPython([&]() -> Status { return udf->Exec(ctx, batch, out); });
}
-} // namespace
-
-Status RegisterScalarFunction(PyObject* user_function, ScalarUdfWrapperCallback wrapper,
- const ScalarUdfOptions& options) {
+Status RegisterUdf(PyObject* user_function, compute::KernelInit kernel_init,
+ UdfWrapperCallback wrapper, const UdfOptions& options,
+ compute::FunctionRegistry* registry) {
if (!PyCallable_Check(user_function)) {
return Status::TypeError("Expected a callable Python object.");
}
@@ -105,21 +192,110 @@ Status RegisterScalarFunction(PyObject* user_function, ScalarUdfWrapperCallback
}
compute::OutputType output_type(options.output_type);
auto udf_data = std::make_shared<PythonUdf>(
- wrapper, std::make_shared<OwnedRefNoGIL>(user_function), options.output_type);
+ std::make_shared<OwnedRefNoGIL>(user_function), wrapper,
+ TypeHolder::FromTypes(options.input_types), options.output_type);
compute::ScalarKernel kernel(
compute::KernelSignature::Make(std::move(input_types), std::move(output_type),
options.arity.is_varargs),
- PythonUdfExec);
+ PythonUdfExec, kernel_init);
kernel.data = std::move(udf_data);
kernel.mem_allocation = compute::MemAllocation::NO_PREALLOCATE;
kernel.null_handling = compute::NullHandling::COMPUTED_NO_PREALLOCATE;
RETURN_NOT_OK(scalar_func->AddKernel(std::move(kernel)));
- auto registry = compute::GetFunctionRegistry();
+ if (registry == NULLPTR) {
+ registry = compute::GetFunctionRegistry();
+ }
RETURN_NOT_OK(registry->AddFunction(std::move(scalar_func)));
return Status::OK();
}
-} // namespace py
+} // namespace
+Status RegisterScalarFunction(PyObject* user_function, UdfWrapperCallback wrapper,
+ const UdfOptions& options,
+ compute::FunctionRegistry* registry) {
+ return RegisterUdf(
+ user_function,
+ PythonUdfKernelInit{std::make_shared<OwnedRefNoGIL>(user_function)}, wrapper,
+ options, registry);
+}
+
+Status RegisterTabularFunction(PyObject* user_function, UdfWrapperCallback wrapper,
+ const UdfOptions& options,
+ compute::FunctionRegistry* registry) {
+ if (options.arity.num_args != 0 || options.arity.is_varargs) {
+ return Status::NotImplemented("tabular function of non-null arity");
+ }
+ if (options.output_type->id() != Type::type::STRUCT) {
+ return Status::Invalid("tabular function with non-struct output");
+ }
+ return RegisterUdf(
+ user_function,
+ PythonTableUdfKernelInit{std::make_shared<OwnedRefNoGIL>(user_function), wrapper},
+ wrapper, options, registry);
+}
+
+Result<std::shared_ptr<RecordBatchReader>> CallTabularFunction(
+ const std::string& func_name, const std::vector<Datum>& args,
+ compute::FunctionRegistry* registry) {
+ if (args.size() != 0) {
+ return Status::NotImplemented("non-empty arguments to tabular function");
+ }
+ if (registry == NULLPTR) {
+ registry = compute::GetFunctionRegistry();
+ }
+ ARROW_ASSIGN_OR_RAISE(auto func, registry->GetFunction(func_name));
+ if (func->kind() != compute::Function::SCALAR) {
+ return Status::Invalid("tabular function of non-scalar kind");
+ }
+ auto arity = func->arity();
+ if (arity.num_args != 0 || arity.is_varargs) {
+ return Status::NotImplemented("tabular function of non-null arity");
+ }
+ auto kernels =
+ arrow::internal::checked_pointer_cast<compute::ScalarFunction>(func)->kernels();
+ if (kernels.size() != 1) {
+ return Status::NotImplemented("tabular function with non-single kernel");
+ }
+ const compute::ScalarKernel* kernel = kernels[0];
+ auto out_type = kernel->signature->out_type();
+ if (out_type.kind() != compute::OutputType::FIXED) {
+ return Status::Invalid("tabular kernel of non-fixed kind");
+ }
+ auto datatype = out_type.type();
+ if (datatype->id() != Type::type::STRUCT) {
+ return Status::Invalid("tabular kernel with non-struct output");
+ }
+ auto struct_type = arrow::internal::checked_cast<StructType*>(datatype.get());
+ auto schema = ::arrow::schema(struct_type->fields());
+ std::vector<TypeHolder> in_types;
+ ARROW_ASSIGN_OR_RAISE(auto func_exec,
+ GetFunctionExecutor(func_name, in_types, NULLPTR, registry));
+ auto next_func =
+ [schema,
+ func_exec = std::move(func_exec)]() -> Result<std::shared_ptr<RecordBatch>> {
+ std::vector<Datum> args;
+ // passed_length of -1 or 0 with args.size() of 0 leads to an empty ExecSpanIterator
+ // in exec.cc and to never invoking the source function, so 1 is passed instead
+ // TODO: GH-33612: Support batch size in user-defined tabular functions
+ ARROW_ASSIGN_OR_RAISE(auto datum, func_exec->Execute(args, /*passed_length=*/1));
+ if (!datum.is_array()) {
+ return Status::Invalid("UDF result of non-array kind");
+ }
+ std::shared_ptr<Array> array = datum.make_array();
+ if (array->length() == 0) {
+ return IterationTraits<std::shared_ptr<RecordBatch>>::End();
+ }
+ ARROW_ASSIGN_OR_RAISE(auto batch, RecordBatch::FromStructArray(std::move(array)));
+ if (!schema->Equals(batch->schema())) {
+ return Status::Invalid("UDF result with shape not conforming to schema");
+ }
+ return std::move(batch);
+ };
+ return RecordBatchReader::MakeFromIterator(MakeFunctionIterator(std::move(next_func)),
+ schema);
+}
+
+} // namespace py
} // namespace arrow
diff --git a/python/pyarrow/src/arrow/python/udf.h b/python/pyarrow/src/arrow/python/udf.h
index 9a3666459f..cde97d9cb9 100644
--- a/python/pyarrow/src/arrow/python/udf.h
+++ b/python/pyarrow/src/arrow/python/udf.h
@@ -21,6 +21,8 @@
#include "arrow/compute/function.h"
#include "arrow/compute/registry.h"
#include "arrow/python/platform.h"
+#include "arrow/record_batch.h"
+#include "arrow/util/iterator.h"
#include "arrow/python/common.h"
#include "arrow/python/pyarrow.h"
@@ -33,7 +35,7 @@ namespace py {
// TODO(ARROW-16041): UDF Options are not exposed to the Python
// users. This feature will be included when extending to provide advanced
// options for the users.
-struct ARROW_PYTHON_EXPORT ScalarUdfOptions {
+struct ARROW_PYTHON_EXPORT UdfOptions {
std::string func_name;
compute::Arity arity;
compute::FunctionDoc func_doc;
@@ -47,13 +49,22 @@ struct ARROW_PYTHON_EXPORT ScalarUdfContext {
int64_t batch_length;
};
-using ScalarUdfWrapperCallback = std::function<PyObject*(
+using UdfWrapperCallback = std::function<PyObject*(
PyObject* user_function, const ScalarUdfContext& context, PyObject* inputs)>;
/// \brief register a Scalar user-defined-function from Python
-Status ARROW_PYTHON_EXPORT RegisterScalarFunction(PyObject* user_function,
- ScalarUdfWrapperCallback wrapper,
- const ScalarUdfOptions& options);
+Status ARROW_PYTHON_EXPORT RegisterScalarFunction(
+ PyObject* user_function, UdfWrapperCallback wrapper,
+ const UdfOptions& options, compute::FunctionRegistry* registry = NULLPTR);
+
+/// \brief register a Table user-defined-function from Python
+Status ARROW_PYTHON_EXPORT RegisterTabularFunction(
+ PyObject* user_function, UdfWrapperCallback wrapper,
+ const UdfOptions& options, compute::FunctionRegistry* registry = NULLPTR);
+
+Result<std::shared_ptr<RecordBatchReader>> ARROW_PYTHON_EXPORT CallTabularFunction(
+ const std::string& func_name, const std::vector<Datum>& args,
+ compute::FunctionRegistry* registry = NULLPTR);
} // namespace py
diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi
index 35492d4f64..47f37a53ad 100644
--- a/python/pyarrow/table.pxi
+++ b/python/pyarrow/table.pxi
@@ -2550,6 +2550,20 @@ cdef class RecordBatch(_PandasConvertible):
CRecordBatch.FromStructArray(struct_array.sp_array))
return pyarrow_wrap_batch(c_record_batch)
+ def to_struct_array(self):
+ """
+ Convert to a struct array.
+ """
+ cdef:
+ shared_ptr[CRecordBatch] c_record_batch
+ shared_ptr[CArray] c_array
+
+ c_record_batch = pyarrow_unwrap_batch(self)
+ with nogil:
+ c_array = GetResultValue(
+ <CResult[shared_ptr[CArray]]>deref(c_record_batch).ToStructArray())
+ return pyarrow_wrap_array(c_array)
+
def _export_to_c(self, out_ptr, out_schema_ptr=0):
"""
Export to a C ArrowArray struct, given its pointer.
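
A quick round-trip sketch for the new to_struct_array method, assuming a
build with this patch; the column names "x" and "y" are illustrative, and
from_struct_array is the pre-existing inverse shown in the context above:

    import pyarrow as pa

    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2], pa.int32()), pa.array(["a", "b"])],
        names=["x", "y"])

    # Flatten the batch into a single StructArray with one field per column.
    struct_array = batch.to_struct_array()
    assert struct_array.type == pa.struct([("x", pa.int32()),
                                           ("y", pa.string())])

    # from_struct_array reverses the conversion.
    assert pa.RecordBatch.from_struct_array(struct_array).equals(batch)
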
diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py
index e711619582..6a67e0bae9 100644
--- a/python/pyarrow/tests/test_udf.py
+++ b/python/pyarrow/tests/test_udf.py
@@ -31,7 +31,7 @@ except ImportError:
ds = None
-def mock_udf_context(batch_length=10):
+def mock_scalar_udf_context(batch_length=10):
from pyarrow._compute import _get_scalar_udf_context
return _get_scalar_udf_context(pa.default_memory_pool(), batch_length)
@@ -248,7 +248,7 @@ def check_scalar_function(func_fixture,
if all_scalar:
batch_length = 1
- expected_output = function(mock_udf_context(batch_length), *inputs)
+ expected_output = function(mock_scalar_udf_context(batch_length), *inputs)
func = pc.get_function(name)
assert func.name == name
@@ -266,7 +266,7 @@ def check_scalar_function(func_fixture,
assert result_table.column(0).chunks[0] == expected_output
-def test_scalar_udf_array_unary(unary_func_fixture):
+def test_udf_array_unary(unary_func_fixture):
check_scalar_function(unary_func_fixture,
[
pa.array([10, 20], pa.int64())
@@ -274,7 +274,7 @@ def test_scalar_udf_array_unary(unary_func_fixture):
)
-def test_scalar_udf_array_binary(binary_func_fixture):
+def test_udf_array_binary(binary_func_fixture):
check_scalar_function(binary_func_fixture,
[
pa.array([10, 20], pa.int64()),
@@ -283,7 +283,7 @@ def test_scalar_udf_array_binary(binary_func_fixture):
)
-def test_scalar_udf_array_ternary(ternary_func_fixture):
+def test_udf_array_ternary(ternary_func_fixture):
check_scalar_function(ternary_func_fixture,
[
pa.array([10, 20], pa.int64()),
@@ -293,7 +293,7 @@ def test_scalar_udf_array_ternary(ternary_func_fixture):
)
-def test_scalar_udf_array_varargs(varargs_func_fixture):
+def test_udf_array_varargs(varargs_func_fixture):
check_scalar_function(varargs_func_fixture,
[
pa.array([2, 3], pa.int64()),
@@ -464,7 +464,7 @@ def test_wrong_input_type_declaration():
in_types, out_type)
-def test_udf_context(unary_func_fixture):
+def test_scalar_udf_context(unary_func_fixture):
# Check the memory_pool argument is properly propagated
proxy_pool = pa.proxy_memory_pool(pa.default_memory_pool())
_, func_name = unary_func_fixture
@@ -504,3 +504,112 @@ def test_input_lifetime(unary_func_fixture):
# Calling a UDF should not have kept `v` alive longer than required
v = None
assert proxy_pool.bytes_allocated() == 0
+
+
+def _record_batch_from_iters(schema, *iters):
+ arrays = [pa.array(list(v), type=schema[i].type)
+ for i, v in enumerate(iters)]
+ return pa.RecordBatch.from_arrays(arrays=arrays, schema=schema)
+
+
+def _record_batch_for_range(schema, n):
+ return _record_batch_from_iters(schema,
+ range(n, n + 10),
+ range(n + 1, n + 11))
+
+
+def make_udt_func(schema, batch_gen):
+ def udf_func(ctx):
+ class UDT:
+ def __init__(self):
+ self.caller = None
+
+ def __call__(self, ctx):
+ try:
+ if self.caller is None:
+ self.caller, ctx = batch_gen(ctx).send, None
+ batch = self.caller(ctx)
+ except StopIteration:
+ arrays = [pa.array([], type=field.type)
+ for field in schema]
+ batch = pa.RecordBatch.from_arrays(
+ arrays=arrays, schema=schema)
+ return batch.to_struct_array()
+ return UDT()
+ return udf_func
+
+
+def datasource1_direct():
+ """A short dataset"""
+ schema = datasource1_schema()
+
+ class Generator:
+ def __init__(self):
+ self.n = 3
+
+ def __call__(self, ctx):
+ if self.n == 0:
+ batch = _record_batch_from_iters(schema, [], [])
+ else:
+ self.n -= 1
+ batch = _record_batch_for_range(schema, self.n)
+ return batch.to_struct_array()
+ return lambda ctx: Generator()
+
+
+def datasource1_generator():
+ schema = datasource1_schema()
+
+ def batch_gen(ctx):
+ for n in range(3, 0, -1):
+ # ctx =
+ yield _record_batch_for_range(schema, n - 1)
+ return make_udt_func(schema, batch_gen)
+
+
+def datasource1_exception():
+ schema = datasource1_schema()
+
+ def batch_gen(ctx):
+ for n in range(3, 0, -1):
+ # ctx =
+ yield _record_batch_for_range(schema, n - 1)
+ raise RuntimeError("datasource1_exception")
+ return make_udt_func(schema, batch_gen)
+
+
+def datasource1_schema():
+ return pa.schema([('', pa.int32()), ('', pa.int32())])
+
+
+def datasource1_args(func, func_name):
+ func_doc = {"summary": f"{func_name} UDT",
+ "description": "test {func_name} UDT"}
+ in_types = {}
+ out_type = pa.struct([("", pa.int32()), ("", pa.int32())])
+ return func, func_name, func_doc, in_types, out_type
+
+
+def _test_datasource1_udt(func_maker):
+ schema = datasource1_schema()
+ func = func_maker()
+ func_name = func_maker.__name__
+ func_args = datasource1_args(func, func_name)
+ pc.register_tabular_function(*func_args)
+ n = 3
+ for item in pc.call_tabular_function(func_name):
+ n -= 1
+ assert item == _record_batch_for_range(schema, n)
+
+
+def test_udt_datasource1_direct():
+ _test_datasource1_udt(datasource1_direct)
+
+
+def test_udt_datasource1_generator():
+ _test_datasource1_udt(datasource1_generator)
+
+
+def test_udt_datasource1_exception():
+ with pytest.raises(RuntimeError, match='datasource1_exception'):
+ _test_datasource1_udt(datasource1_exception)