Posted to commits@arrow.apache.org by we...@apache.org on 2023/01/26 13:24:38 UTC

[arrow] branch master updated: GH-32916: [C++] [Python] User-defined tabular functions (#14682)

This is an automated email from the ASF dual-hosted git repository.

westonpace pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new a1d9b511da GH-32916: [C++] [Python] User-defined tabular functions (#14682)
a1d9b511da is described below

commit a1d9b511dab6fed7e0d074e8abc6dbc26f50a460
Author: rtpsw <rt...@hotmail.com>
AuthorDate: Thu Jan 26 15:24:28 2023 +0200

    GH-32916: [C++] [Python] User-defined tabular functions (#14682)
    
    See https://issues.apache.org/jira/browse/ARROW-17676
    * Closes: #32916
    
    Lead-authored-by: Yaron Gvili <rt...@hotmail.com>
    Co-authored-by: rtpsw <rt...@hotmail.com>
    Signed-off-by: Weston Pace <we...@gmail.com>
---
 cpp/src/arrow/type.cc                        |  10 ++
 cpp/src/arrow/type.h                         |   3 +
 python/pyarrow/_compute.pyx                  | 186 +++++++++++++++++++++-
 python/pyarrow/_dataset.pyx                  |  13 +-
 python/pyarrow/compute.py                    |   2 +
 python/pyarrow/includes/libarrow.pxd         |  25 ++-
 python/pyarrow/includes/libarrow_dataset.pxd |   7 -
 python/pyarrow/src/arrow/python/udf.cc       | 228 ++++++++++++++++++++++++---
 python/pyarrow/src/arrow/python/udf.h        |  21 ++-
 python/pyarrow/table.pxi                     |  14 ++
 python/pyarrow/tests/test_udf.py             | 123 ++++++++++++++-
 11 files changed, 566 insertions(+), 66 deletions(-)

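For readers skimming the diff: the user-facing additions are
pyarrow.compute.register_tabular_function and
pyarrow.compute.call_tabular_function. A minimal end-to-end sketch, based on
the docstrings and tests in this commit (the name "my_source" and its body
are illustrative, not part of the change):

    import pyarrow as pa
    import pyarrow.compute as pc

    schema = pa.schema([('x', pa.int64()), ('y', pa.int64())])

    def my_source(ctx):
        # Invoked once per kernel init with a ScalarUdfContext; must return
        # a callable producing one StructArray per invocation.
        batches = [pa.record_batch([pa.array([1, 2], pa.int64()),
                                    pa.array([3, 4], pa.int64())],
                                   schema=schema)]

        def next_batch(ctx):
            if batches:
                return batches.pop().to_struct_array()
            # An empty array signals the end of the stream.
            return pa.array([], type=pa.struct(list(schema)))

        return next_batch

    pc.register_tabular_function(
        my_source, "my_source",
        {"summary": "example source", "description": "emits one batch"},
        {},      # in_types must be an empty dict for tabular functions
        schema)  # out_type: a Schema or the corresponding flat struct type

    for batch in pc.call_tabular_function("my_source"):
        print(batch)  # one pyarrow.RecordBatch per struct array produced
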
diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc
index cc31735512..3514d0538f 100644
--- a/cpp/src/arrow/type.cc
+++ b/cpp/src/arrow/type.cc
@@ -452,6 +452,16 @@ std::string TypeHolder::ToString(const std::vector<TypeHolder>& types) {
   return ss.str();
 }
 
+std::vector<TypeHolder> TypeHolder::FromTypes(
+    const std::vector<std::shared_ptr<DataType>>& types) {
+  std::vector<TypeHolder> type_holders;
+  type_holders.reserve(types.size());
+  for (const auto& type : types) {
+    type_holders.emplace_back(type);
+  }
+  return type_holders;
+}
+
 // ----------------------------------------------------------------------
 
 FloatingPointType::Precision HalfFloatType::precision() const {
diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h
index 1e16806947..ee5bc6ef37 100644
--- a/cpp/src/arrow/type.h
+++ b/cpp/src/arrow/type.h
@@ -264,6 +264,9 @@ struct ARROW_EXPORT TypeHolder {
   }
 
   static std::string ToString(const std::vector<TypeHolder>&);
+
+  static std::vector<TypeHolder> FromTypes(
+      const std::vector<std::shared_ptr<DataType>>& types);
 };
 
 ARROW_EXPORT
diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx
index c75c5bf189..283f532837 100644
--- a/python/pyarrow/_compute.pyx
+++ b/python/pyarrow/_compute.pyx
@@ -36,6 +36,18 @@ import inspect
 import numpy as np
 
 
+def _forbid_instantiation(klass, subclasses_instead=True):
+    msg = '{} is an abstract class thus cannot be initialized.'.format(
+        klass.__name__
+    )
+    if subclasses_instead:
+        subclasses = [cls.__name__ for cls in klass.__subclasses__()]
+        msg += ' Use one of the subclasses instead: {}'.format(
+            ', '.join(subclasses)
+        )
+    raise TypeError(msg)
+
+
 cdef wrap_scalar_function(const shared_ptr[CFunction]& sp_func):
     """
     Wrap a C++ scalar Function in a ScalarFunction object.
@@ -2574,7 +2586,7 @@ cdef object box_scalar_udf_context(const CScalarUdfContext& c_context):
     return context
 
 
-cdef _scalar_udf_callback(user_function, const CScalarUdfContext& c_context, inputs):
+cdef _udf_callback(user_function, const CScalarUdfContext& c_context, inputs):
     """
     Helper callback function used to wrap the ScalarUdfContext from Python to C++
     execution.
@@ -2591,8 +2603,30 @@ def _get_scalar_udf_context(memory_pool, batch_length):
     return context
 
 
-def register_scalar_function(func, function_name, function_doc, in_types,
-                             out_type):
+ctypedef CStatus (*CRegisterUdf)(PyObject* function, function[CallbackUdf] wrapper,
+                                 const CUdfOptions& options, CFunctionRegistry* registry)
+
+cdef class RegisterUdf(_Weakrefable):
+    cdef CRegisterUdf register_func
+
+    cdef void init(self, const CRegisterUdf register_func):
+        self.register_func = register_func
+
+
+cdef get_register_scalar_function():
+    cdef RegisterUdf reg = RegisterUdf.__new__(RegisterUdf)
+    reg.register_func = RegisterScalarFunction
+    return reg
+
+
+cdef get_register_tabular_function():
+    cdef RegisterUdf reg = RegisterUdf.__new__(RegisterUdf)
+    reg.register_func = RegisterTabularFunction
+    return reg
+
+
+def register_scalar_function(func, function_name, function_doc, in_types, out_type,
+                             func_registry=None):
     """
     Register a user-defined scalar function.
 
@@ -2633,6 +2667,8 @@ def register_scalar_function(func, function_name, function_doc, in_types,
         arity.
     out_type : DataType
         Output type of the function.
+    func_registry : FunctionRegistry
+        Optional function registry to use instead of the default global one.
 
     Examples
     --------
@@ -2662,14 +2698,106 @@ def register_scalar_function(func, function_name, function_doc, in_types,
       21
     ]
     """
+    return _register_scalar_like_function(get_register_scalar_function(),
+                                          func, function_name, function_doc, in_types,
+                                          out_type, func_registry)
+
+
+def register_tabular_function(func, function_name, function_doc, in_types, out_type,
+                              func_registry=None):
+    """
+    Register a user-defined tabular function.
+
+    A tabular function is one accepting a context argument of type
+    ScalarUdfContext and returning a generator of struct arrays.
+    The in_types argument must be empty and the out_type argument
+    specifies a schema. Each struct array must have field types
+    corresponding to the schema.
+
+    Parameters
+    ----------
+    func : callable
+        A callable implementing the user-defined function.
+        The only argument is the context argument of type
+        ScalarUdfContext. It must return a callable that
+        returns a StructArray matching the out_type on each
+        invocation; an empty array indicates the end of output.
+    function_name : str
+        Name of the function. This name must be globally unique.
+    function_doc : dict
+        A dictionary object with keys "summary" (str),
+        and "description" (str).
+    in_types : Dict[str, DataType]
+        Must be an empty dictionary (reserved for future use).
+    out_type : Union[Schema, DataType]
+        Schema of the function's output, or a corresponding flat struct type.
+    func_registry : FunctionRegistry
+        Optional function registry to use instead of the default global one.
+    """
     cdef:
+        shared_ptr[CSchema] c_schema
+        shared_ptr[CDataType] c_type
+
+    if isinstance(out_type, Schema):
+        c_schema = pyarrow_unwrap_schema(out_type)
+        with nogil:
+            c_type = <shared_ptr[CDataType]>make_shared[CStructType](deref(c_schema).fields())
+        out_type = pyarrow_wrap_data_type(c_type)
+    return _register_scalar_like_function(get_register_tabular_function(),
+                                          func, function_name, function_doc, in_types,
+                                          out_type, func_registry)
+
+
+def _register_scalar_like_function(register_func, func, function_name, function_doc, in_types,
+                                   out_type, func_registry=None):
+    """
+    Register a user-defined scalar-like function.
+
+    A scalar-like function is a callable accepting a first
+    context argument of type ScalarUdfContext as well as
+    possibly additional Arrow arguments, and returning an
+    Arrow result appropriate for the kind of function.
+    Scalar functions and tabular functions are examples
+    of scalar-like functions.
+    This function is normally not called directly but via
+    register_scalar_function or register_tabular_function.
+
+    Parameters
+    ----------
+    register_func : object
+        An object holding a CRegisterUdf in a "register_func" attribute. Use
+        get_register_scalar_function() for a scalar function and
+        get_register_tabular_function() for a tabular function.
+    func : callable
+        A callable implementing the user-defined function.
+        See register_scalar_function and
+        register_tabular_function for details.
+
+    function_name : str
+        Name of the function. This name must be globally unique.
+    function_doc : dict
+        A dictionary object with keys "summary" (str),
+        and "description" (str).
+    in_types : Dict[str, DataType]
+        A dictionary mapping function argument names to
+        their respective DataType.
+        See register_scalar_function and
+        register_tabular_function for details.
+    out_type : DataType
+        Output type of the function.
+    func_registry : FunctionRegistry
+        Optional function registry to use instead of the default global one.
+    """
+    cdef:
+        CRegisterUdf c_register_func
         c_string c_func_name
         CArity c_arity
         CFunctionDoc c_func_doc
         vector[shared_ptr[CDataType]] c_in_types
         PyObject* c_function
         shared_ptr[CDataType] c_out_type
-        CScalarUdfOptions c_options
+        CUdfOptions c_options
+        CFunctionRegistry* c_func_registry
 
     if callable(func):
         c_function = <PyObject*>func
@@ -2711,5 +2839,51 @@ def register_scalar_function(func, function_name, function_doc, in_types,
     c_options.input_types = c_in_types
     c_options.output_type = c_out_type
 
-    check_status(RegisterScalarFunction(c_function,
-                                        <function[CallbackUdf]> &_scalar_udf_callback, c_options))
+    if func_registry is None:
+        c_func_registry = NULL
+    else:
+        c_func_registry = (<FunctionRegistry>func_registry).registry
+
+    c_register_func = (<RegisterUdf>register_func).register_func
+
+    check_status(c_register_func(c_function,
+                                 <function[CallbackUdf]> &_udf_callback,
+                                 c_options, c_func_registry))
+
+
+def call_tabular_function(function_name, args=None, func_registry=None):
+    """
+    Get a record batch iterator from a tabular function.
+
+    Parameters
+    ----------
+    function_name : str
+        Name of the function.
+    args : iterable
+        The arguments to pass to the function.  Accepted types depend
+        on the specific function.  Currently, only an empty args is supported.
+    func_registry : FunctionRegistry
+        Optional function registry to use instead of the default global one.
+    """
+    cdef:
+        c_string c_func_name
+        vector[CDatum] c_args
+        CFunctionRegistry* c_func_registry
+        shared_ptr[CRecordBatchReader] c_reader
+        RecordBatchReader reader
+
+    c_func_name = tobytes(function_name)
+    if func_registry is None:
+        c_func_registry = NULL
+    else:
+        c_func_registry = (<FunctionRegistry>func_registry).registry
+    if args is None:
+        args = []
+    _pack_compute_args(args, &c_args)
+
+    with nogil:
+        c_reader = GetResultValue(CallTabularFunction(
+            c_func_name, c_args, c_func_registry))
+    reader = RecordBatchReader.__new__(RecordBatchReader)
+    reader.reader = c_reader
+    return RecordBatchReader.from_batches(pyarrow_wrap_schema(deref(c_reader).schema()), reader)
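
Note from the docstrings above: out_type may be a Schema or the equivalent
flat struct type, and call_tabular_function returns an ordinary
RecordBatchReader. Consuming the whole stream at once is then just (a sketch,
assuming the illustrative "my_source" registered in the snippet near the top):

    reader = pc.call_tabular_function("my_source")
    table = reader.read_all()  # a pyarrow.Table with the declared schema
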
diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx
index 5f1610c384..38ff60f380 100644
--- a/python/pyarrow/_dataset.pyx
+++ b/python/pyarrow/_dataset.pyx
@@ -32,24 +32,13 @@ from pyarrow.lib cimport *
 from pyarrow.lib import ArrowTypeError, frombytes, tobytes, _pc
 from pyarrow.includes.libarrow_dataset cimport *
 from pyarrow._compute cimport Expression, _bind
+from pyarrow._compute import _forbid_instantiation
 from pyarrow._fs cimport FileSystem, FileInfo, FileSelector
 from pyarrow._csv cimport (
     ConvertOptions, ParseOptions, ReadOptions, WriteOptions)
 from pyarrow.util import _is_iterable, _is_path_like, _stringify_path
 
 
-def _forbid_instantiation(klass, subclasses_instead=True):
-    msg = '{} is an abstract class thus cannot be initialized.'.format(
-        klass.__name__
-    )
-    if subclasses_instead:
-        subclasses = [cls.__name__ for cls in klass.__subclasses__()]
-        msg += ' Use one of the subclasses instead: {}'.format(
-            ', '.join(subclasses)
-        )
-    raise TypeError(msg)
-
-
 _orc_fileformat = None
 _orc_imported = False
 
diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py
index 1ee6c40f42..f455b81411 100644
--- a/python/pyarrow/compute.py
+++ b/python/pyarrow/compute.py
@@ -80,7 +80,9 @@ from pyarrow._compute import (  # noqa
     list_functions,
     _group_by,
     # Udf
+    call_tabular_function,
     register_scalar_function,
+    register_tabular_function,
     ScalarUdfContext,
     # Expressions
     Expression,
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index 2d87971422..82358beea1 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -480,6 +480,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
         vector[shared_ptr[CField]] GetAllFieldsByName(const c_string& name)
         int GetFieldIndex(const c_string& name)
         vector[int] GetAllFieldIndices(const c_string& name)
+        const vector[shared_ptr[CField]] fields()
         int num_fields()
         c_string ToString()
 
@@ -800,6 +801,8 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
             const shared_ptr[CSchema]& schema, int64_t num_rows,
             const vector[shared_ptr[CArray]]& columns)
 
+        CResult[shared_ptr[CStructArray]] ToStructArray() const
+
         @staticmethod
         CResult[shared_ptr[CRecordBatch]] FromStructArray(
             const shared_ptr[CArray]& array)
@@ -2805,12 +2808,20 @@ cdef extern from "arrow/util/byte_size.h" namespace "arrow::util" nogil:
 
 ctypedef PyObject* CallbackUdf(object user_function, const CScalarUdfContext& context, object inputs)
 
-cdef extern from "arrow/python/udf.h" namespace "arrow::py":
+
+cdef extern from "arrow/api.h" namespace "arrow" nogil:
+
+    cdef cppclass CRecordBatchIterator "arrow::RecordBatchIterator"(
+            CIterator[shared_ptr[CRecordBatch]]):
+        pass
+
+
+cdef extern from "arrow/python/udf.h" namespace "arrow::py" nogil:
     cdef cppclass CScalarUdfContext" arrow::py::ScalarUdfContext":
         CMemoryPool *pool
         int64_t batch_length
 
-    cdef cppclass CScalarUdfOptions" arrow::py::ScalarUdfOptions":
+    cdef cppclass CUdfOptions" arrow::py::UdfOptions":
         c_string func_name
         CArity arity
         CFunctionDoc func_doc
@@ -2818,4 +2829,12 @@ cdef extern from "arrow/python/udf.h" namespace "arrow::py":
         shared_ptr[CDataType] output_type
 
     CStatus RegisterScalarFunction(PyObject* function,
-                                   function[CallbackUdf] wrapper, const CScalarUdfOptions& options)
+                                   function[CallbackUdf] wrapper, const CUdfOptions& options,
+                                   CFunctionRegistry* registry)
+
+    CStatus RegisterTabularFunction(PyObject* function,
+                                    function[CallbackUdf] wrapper, const CUdfOptions& options,
+                                    CFunctionRegistry* registry)
+
+    CResult[shared_ptr[CRecordBatchReader]] CallTabularFunction(
+        const c_string& func_name, const vector[CDatum]& args, CFunctionRegistry* registry)
diff --git a/python/pyarrow/includes/libarrow_dataset.pxd b/python/pyarrow/includes/libarrow_dataset.pxd
index b75eafcdee..1603797084 100644
--- a/python/pyarrow/includes/libarrow_dataset.pxd
+++ b/python/pyarrow/includes/libarrow_dataset.pxd
@@ -25,13 +25,6 @@ from pyarrow.includes.libarrow cimport *
 from pyarrow.includes.libarrow_fs cimport *
 
 
-cdef extern from "arrow/api.h" namespace "arrow" nogil:
-
-    cdef cppclass CRecordBatchIterator "arrow::RecordBatchIterator"(
-            CIterator[shared_ptr[CRecordBatch]]):
-        pass
-
-
 cdef extern from "arrow/dataset/plan.h" namespace "arrow::dataset::internal" nogil:
 
     cdef void Initialize()
diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc
index 81bf47c0ad..763bdf5a03 100644
--- a/python/pyarrow/src/arrow/python/udf.cc
+++ b/python/pyarrow/src/arrow/python/udf.cc
@@ -17,37 +17,121 @@
 
 #include "arrow/python/udf.h"
 #include "arrow/compute/function.h"
+#include "arrow/compute/kernel.h"
 #include "arrow/python/common.h"
+#include "arrow/util/checked_cast.h"
 
 namespace arrow {
-
-using compute::ExecResult;
-using compute::ExecSpan;
-
 namespace py {
 
 namespace {
 
-struct PythonUdf : public compute::KernelState {
-  ScalarUdfWrapperCallback cb;
+struct PythonUdfKernelState : public compute::KernelState {
+  explicit PythonUdfKernelState(std::shared_ptr<OwnedRefNoGIL> function)
+      : function(function) {
+    Py_INCREF(function->obj());
+  }
+
+  // function needs to be destroyed at process exit
+  // and Python may no longer be initialized.
+  ~PythonUdfKernelState() {
+    if (_Py_IsFinalizing()) {
+      function->detach();
+    }
+  }
+
   std::shared_ptr<OwnedRefNoGIL> function;
-  std::shared_ptr<DataType> output_type;
+};
 
-  PythonUdf(ScalarUdfWrapperCallback cb, std::shared_ptr<OwnedRefNoGIL> function,
-            const std::shared_ptr<DataType>& output_type)
-      : cb(cb), function(function), output_type(output_type) {}
+struct PythonUdfKernelInit {
+  explicit PythonUdfKernelInit(std::shared_ptr<OwnedRefNoGIL> function)
+      : function(function) {
+    Py_INCREF(function->obj());
+  }
 
   // function needs to be destroyed at process exit
   // and Python may no longer be initialized.
-  ~PythonUdf() {
+  ~PythonUdfKernelInit() {
     if (_Py_IsFinalizing()) {
       function->detach();
     }
   }
 
-  Status Exec(compute::KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
+  Result<std::unique_ptr<compute::KernelState>> operator()(
+      compute::KernelContext*, const compute::KernelInitArgs&) {
+    return std::make_unique<PythonUdfKernelState>(function);
+  }
+
+  std::shared_ptr<OwnedRefNoGIL> function;
+};
+
+struct PythonTableUdfKernelInit {
+  PythonTableUdfKernelInit(std::shared_ptr<OwnedRefNoGIL> function_maker,
+                           UdfWrapperCallback cb)
+      : function_maker(function_maker), cb(cb) {
+    Py_INCREF(function_maker->obj());
+  }
+
+  // function needs to be destroyed at process exit
+  // and Python may no longer be initialized.
+  ~PythonTableUdfKernelInit() {
+    if (_Py_IsFinalizing()) {
+      function_maker->detach();
+    }
+  }
+
+  Result<std::unique_ptr<compute::KernelState>> operator()(
+      compute::KernelContext* ctx, const compute::KernelInitArgs&) {
+    ScalarUdfContext scalar_udf_context{ctx->memory_pool(), /*batch_length=*/0};
+    std::unique_ptr<OwnedRefNoGIL> function;
+    RETURN_NOT_OK(SafeCallIntoPython([this, &scalar_udf_context, &function] {
+      OwnedRef empty_tuple(PyTuple_New(0));
+      function = std::make_unique<OwnedRefNoGIL>(
+          cb(function_maker->obj(), scalar_udf_context, empty_tuple.obj()));
+      RETURN_NOT_OK(CheckPyError());
+      return Status::OK();
+    }));
+    if (!PyCallable_Check(function->obj())) {
+      return Status::TypeError("Expected a callable Python object.");
+    }
+    return std::make_unique<PythonUdfKernelState>(
+        std::move(function));
+  }
+
+  std::shared_ptr<OwnedRefNoGIL> function_maker;
+  UdfWrapperCallback cb;
+};
+
+struct PythonUdf : public PythonUdfKernelState {
+  PythonUdf(std::shared_ptr<OwnedRefNoGIL> function, UdfWrapperCallback cb,
+            std::vector<TypeHolder> input_types, compute::OutputType output_type)
+      : PythonUdfKernelState(function),
+        cb(cb),
+        input_types(input_types),
+        output_type(output_type) {}
+
+  UdfWrapperCallback cb;
+  std::vector<TypeHolder> input_types;
+  compute::OutputType output_type;
+  TypeHolder resolved_type;
+
+  Result<TypeHolder> ResolveType(compute::KernelContext* ctx,
+                                 const std::vector<TypeHolder>& types) {
+    if (input_types == types) {
+      if (!resolved_type) {
+        ARROW_ASSIGN_OR_RAISE(resolved_type, output_type.Resolve(ctx, input_types));
+      }
+      return resolved_type;
+    }
+    return output_type.Resolve(ctx, types);
+  }
+
+  Status Exec(compute::KernelContext* ctx, const compute::ExecSpan& batch,
+              compute::ExecResult* out) {
+    auto state = arrow::internal::checked_cast<PythonUdfKernelState*>(ctx->state());
+    std::shared_ptr<OwnedRefNoGIL>& function = state->function;
     const int num_args = batch.num_values();
-    ScalarUdfContext udf_context{ctx->memory_pool(), batch.length};
+    ScalarUdfContext scalar_udf_context{ctx->memory_pool(), batch.length};
 
     OwnedRef arg_tuple(PyTuple_New(num_args));
     RETURN_NOT_OK(CheckPyError());
@@ -63,13 +147,17 @@ struct PythonUdf : public compute::KernelState {
       }
     }
 
-    OwnedRef result(cb(function->obj(), udf_context, arg_tuple.obj()));
+    OwnedRef result(cb(function->obj(), scalar_udf_context, arg_tuple.obj()));
     RETURN_NOT_OK(CheckPyError());
     // unwrapping the output for expected output type
     if (is_array(result.obj())) {
       ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Array> val, unwrap_array(result.obj()));
-      if (!output_type->Equals(*val->type())) {
-        return Status::TypeError("Expected output datatype ", output_type->ToString(),
+      ARROW_ASSIGN_OR_RAISE(TypeHolder type, ResolveType(ctx, batch.GetTypes()));
+      if (type.type == NULLPTR) {
+        return Status::TypeError("expected output datatype is null");
+      }
+      if (*type.type != *val->type()) {
+        return Status::TypeError("Expected output datatype ", type.type->ToString(),
                                  ", but function returned datatype ",
                                  val->type()->ToString());
       }
@@ -83,16 +171,15 @@ struct PythonUdf : public compute::KernelState {
   }
 };
 
-Status PythonUdfExec(compute::KernelContext* ctx, const ExecSpan& batch,
-                     ExecResult* out) {
+Status PythonUdfExec(compute::KernelContext* ctx, const compute::ExecSpan& batch,
+                     compute::ExecResult* out) {
   auto udf = static_cast<PythonUdf*>(ctx->kernel()->data.get());
   return SafeCallIntoPython([&]() -> Status { return udf->Exec(ctx, batch, out); });
 }
 
-}  // namespace
-
-Status RegisterScalarFunction(PyObject* user_function, ScalarUdfWrapperCallback wrapper,
-                              const ScalarUdfOptions& options) {
+Status RegisterUdf(PyObject* user_function, compute::KernelInit kernel_init,
+                   UdfWrapperCallback wrapper, const UdfOptions& options,
+                   compute::FunctionRegistry* registry) {
   if (!PyCallable_Check(user_function)) {
     return Status::TypeError("Expected a callable Python object.");
   }
@@ -105,21 +192,110 @@ Status RegisterScalarFunction(PyObject* user_function, ScalarUdfWrapperCallback
   }
   compute::OutputType output_type(options.output_type);
   auto udf_data = std::make_shared<PythonUdf>(
-      wrapper, std::make_shared<OwnedRefNoGIL>(user_function), options.output_type);
+      std::make_shared<OwnedRefNoGIL>(user_function), wrapper,
+      TypeHolder::FromTypes(options.input_types), options.output_type);
   compute::ScalarKernel kernel(
       compute::KernelSignature::Make(std::move(input_types), std::move(output_type),
                                      options.arity.is_varargs),
-      PythonUdfExec);
+      PythonUdfExec, kernel_init);
   kernel.data = std::move(udf_data);
 
   kernel.mem_allocation = compute::MemAllocation::NO_PREALLOCATE;
   kernel.null_handling = compute::NullHandling::COMPUTED_NO_PREALLOCATE;
   RETURN_NOT_OK(scalar_func->AddKernel(std::move(kernel)));
-  auto registry = compute::GetFunctionRegistry();
+  if (registry == NULLPTR) {
+    registry = compute::GetFunctionRegistry();
+  }
   RETURN_NOT_OK(registry->AddFunction(std::move(scalar_func)));
   return Status::OK();
 }
 
-}  // namespace py
+}  // namespace
 
+Status RegisterScalarFunction(PyObject* user_function, UdfWrapperCallback wrapper,
+                              const UdfOptions& options,
+                              compute::FunctionRegistry* registry) {
+  return RegisterUdf(
+      user_function,
+      PythonUdfKernelInit{std::make_shared<OwnedRefNoGIL>(user_function)}, wrapper,
+      options, registry);
+}
+
+Status RegisterTabularFunction(PyObject* user_function, UdfWrapperCallback wrapper,
+                               const UdfOptions& options,
+                               compute::FunctionRegistry* registry) {
+  if (options.arity.num_args != 0 || options.arity.is_varargs) {
+    return Status::NotImplemented("tabular function of non-null arity");
+  }
+  if (options.output_type->id() != Type::type::STRUCT) {
+    return Status::Invalid("tabular function with non-struct output");
+  }
+  return RegisterUdf(
+      user_function,
+      PythonTableUdfKernelInit{std::make_shared<OwnedRefNoGIL>(user_function), wrapper},
+      wrapper, options, registry);
+}
+
+Result<std::shared_ptr<RecordBatchReader>> CallTabularFunction(
+    const std::string& func_name, const std::vector<Datum>& args,
+    compute::FunctionRegistry* registry) {
+  if (args.size() != 0) {
+    return Status::NotImplemented("non-empty arguments to tabular function");
+  }
+  if (registry == NULLPTR) {
+    registry = compute::GetFunctionRegistry();
+  }
+  ARROW_ASSIGN_OR_RAISE(auto func, registry->GetFunction(func_name));
+  if (func->kind() != compute::Function::SCALAR) {
+    return Status::Invalid("tabular function of non-scalar kind");
+  }
+  auto arity = func->arity();
+  if (arity.num_args != 0 || arity.is_varargs) {
+    return Status::NotImplemented("tabular function of non-null arity");
+  }
+  auto kernels =
+      arrow::internal::checked_pointer_cast<compute::ScalarFunction>(func)->kernels();
+  if (kernels.size() != 1) {
+    return Status::NotImplemented("tabular function with non-single kernel");
+  }
+  const compute::ScalarKernel* kernel = kernels[0];
+  auto out_type = kernel->signature->out_type();
+  if (out_type.kind() != compute::OutputType::FIXED) {
+    return Status::Invalid("tabular kernel of non-fixed kind");
+  }
+  auto datatype = out_type.type();
+  if (datatype->id() != Type::type::STRUCT) {
+    return Status::Invalid("tabular kernel with non-struct output");
+  }
+  auto struct_type = arrow::internal::checked_cast<StructType*>(datatype.get());
+  auto schema = ::arrow::schema(struct_type->fields());
+  std::vector<TypeHolder> in_types;
+  ARROW_ASSIGN_OR_RAISE(auto func_exec,
+                        GetFunctionExecutor(func_name, in_types, NULLPTR, registry));
+  auto next_func =
+      [schema,
+       func_exec = std::move(func_exec)]() -> Result<std::shared_ptr<RecordBatch>> {
+    std::vector<Datum> args;
+    // passed_length of -1 or 0 with args.size() of 0 leads to an empty ExecSpanIterator
+    // in exec.cc and to never invoking the source function, so 1 is passed instead
+    // TODO: GH-33612: Support batch size in user-defined tabular functions
+    ARROW_ASSIGN_OR_RAISE(auto datum, func_exec->Execute(args, /*passed_length=*/1));
+    if (!datum.is_array()) {
+      return Status::Invalid("UDF result of non-array kind");
+    }
+    std::shared_ptr<Array> array = datum.make_array();
+    if (array->length() == 0) {
+      return IterationTraits<std::shared_ptr<RecordBatch>>::End();
+    }
+    ARROW_ASSIGN_OR_RAISE(auto batch, RecordBatch::FromStructArray(std::move(array)));
+    if (!schema->Equals(batch->schema())) {
+      return Status::Invalid("UDF result with shape not conforming to schema");
+    }
+    return std::move(batch);
+  };
+  return RecordBatchReader::MakeFromIterator(MakeFunctionIterator(std::move(next_func)),
+                                             schema);
+}
+
+}  // namespace py
 }  // namespace arrow
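
Since Exec and the tabular kernel init above both run the user code under
SafeCallIntoPython, a Python exception raised inside a tabular UDF propagates
unchanged to whoever drains the reader; test_udt_datasource1_exception below
relies on this. A hedged sketch with illustrative names:

    import pyarrow as pa
    import pyarrow.compute as pc

    def failing_source(ctx):
        def next_batch(ctx):
            raise RuntimeError("boom")
        return next_batch

    pc.register_tabular_function(
        failing_source, "failing_source",
        {"summary": "always fails", "description": "raises on first batch"},
        {}, pa.schema([('x', pa.int64())]))

    reader = pc.call_tabular_function("failing_source")
    # reader.read_all() raises RuntimeError("boom") from inside the UDF
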
diff --git a/python/pyarrow/src/arrow/python/udf.h b/python/pyarrow/src/arrow/python/udf.h
index 9a3666459f..cde97d9cb9 100644
--- a/python/pyarrow/src/arrow/python/udf.h
+++ b/python/pyarrow/src/arrow/python/udf.h
@@ -21,6 +21,8 @@
 #include "arrow/compute/function.h"
 #include "arrow/compute/registry.h"
 #include "arrow/python/platform.h"
+#include "arrow/record_batch.h"
+#include "arrow/util/iterator.h"
 
 #include "arrow/python/common.h"
 #include "arrow/python/pyarrow.h"
@@ -33,7 +35,7 @@ namespace py {
 // TODO(ARROW-16041): UDF Options are not exposed to the Python
 // users. This feature will be included when extending to provide advanced
 // options for the users.
-struct ARROW_PYTHON_EXPORT ScalarUdfOptions {
+struct ARROW_PYTHON_EXPORT UdfOptions {
   std::string func_name;
   compute::Arity arity;
   compute::FunctionDoc func_doc;
@@ -47,13 +49,22 @@ struct ARROW_PYTHON_EXPORT ScalarUdfContext {
   int64_t batch_length;
 };
 
-using ScalarUdfWrapperCallback = std::function<PyObject*(
+using UdfWrapperCallback = std::function<PyObject*(
     PyObject* user_function, const ScalarUdfContext& context, PyObject* inputs)>;
 
 /// \brief register a Scalar user-defined-function from Python
-Status ARROW_PYTHON_EXPORT RegisterScalarFunction(PyObject* user_function,
-                                                  ScalarUdfWrapperCallback wrapper,
-                                                  const ScalarUdfOptions& options);
+Status ARROW_PYTHON_EXPORT RegisterScalarFunction(
+    PyObject* user_function, UdfWrapperCallback wrapper,
+    const UdfOptions& options, compute::FunctionRegistry* registry = NULLPTR);
+
+/// \brief register a Table user-defined-function from Python
+Status ARROW_PYTHON_EXPORT RegisterTabularFunction(
+    PyObject* user_function, UdfWrapperCallback wrapper,
+    const UdfOptions& options, compute::FunctionRegistry* registry = NULLPTR);
+
+Result<std::shared_ptr<RecordBatchReader>> ARROW_PYTHON_EXPORT CallTabularFunction(
+    const std::string& func_name, const std::vector<Datum>& args,
+    compute::FunctionRegistry* registry = NULLPTR);
 
 }  // namespace py
 
diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi
index 35492d4f64..47f37a53ad 100644
--- a/python/pyarrow/table.pxi
+++ b/python/pyarrow/table.pxi
@@ -2550,6 +2550,20 @@ cdef class RecordBatch(_PandasConvertible):
                 CRecordBatch.FromStructArray(struct_array.sp_array))
         return pyarrow_wrap_batch(c_record_batch)
 
+    def to_struct_array(self):
+        """
+        Convert to a struct array.
+        """
+        cdef:
+            shared_ptr[CRecordBatch] c_record_batch
+            shared_ptr[CArray] c_array
+
+        c_record_batch = pyarrow_unwrap_batch(self)
+        with nogil:
+            c_array = GetResultValue(
+                <CResult[shared_ptr[CArray]]>deref(c_record_batch).ToStructArray())
+        return pyarrow_wrap_array(c_array)
+
     def _export_to_c(self, out_ptr, out_schema_ptr=0):
         """
         Export to a C ArrowArray struct, given its pointer.
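
The new RecordBatch.to_struct_array above is the inverse of the existing
RecordBatch.from_struct_array; a quick round-trip sketch:

    import pyarrow as pa

    batch = pa.record_batch([pa.array([1, 2]), pa.array(["a", "b"])],
                            names=["x", "y"])
    struct_array = batch.to_struct_array()
    assert struct_array.type == pa.struct([("x", pa.int64()),
                                           ("y", pa.string())])
    assert pa.RecordBatch.from_struct_array(struct_array).equals(batch)
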
diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py
index e711619582..6a67e0bae9 100644
--- a/python/pyarrow/tests/test_udf.py
+++ b/python/pyarrow/tests/test_udf.py
@@ -31,7 +31,7 @@ except ImportError:
     ds = None
 
 
-def mock_udf_context(batch_length=10):
+def mock_scalar_udf_context(batch_length=10):
     from pyarrow._compute import _get_scalar_udf_context
     return _get_scalar_udf_context(pa.default_memory_pool(), batch_length)
 
@@ -248,7 +248,7 @@ def check_scalar_function(func_fixture,
         if all_scalar:
             batch_length = 1
 
-    expected_output = function(mock_udf_context(batch_length), *inputs)
+    expected_output = function(mock_scalar_udf_context(batch_length), *inputs)
     func = pc.get_function(name)
     assert func.name == name
 
@@ -266,7 +266,7 @@ def check_scalar_function(func_fixture,
         assert result_table.column(0).chunks[0] == expected_output
 
 
-def test_scalar_udf_array_unary(unary_func_fixture):
+def test_udf_array_unary(unary_func_fixture):
     check_scalar_function(unary_func_fixture,
                           [
                               pa.array([10, 20], pa.int64())
@@ -274,7 +274,7 @@ def test_scalar_udf_array_unary(unary_func_fixture):
                           )
 
 
-def test_scalar_udf_array_binary(binary_func_fixture):
+def test_udf_array_binary(binary_func_fixture):
     check_scalar_function(binary_func_fixture,
                           [
                               pa.array([10, 20], pa.int64()),
@@ -283,7 +283,7 @@ def test_scalar_udf_array_binary(binary_func_fixture):
                           )
 
 
-def test_scalar_udf_array_ternary(ternary_func_fixture):
+def test_udf_array_ternary(ternary_func_fixture):
     check_scalar_function(ternary_func_fixture,
                           [
                               pa.array([10, 20], pa.int64()),
@@ -293,7 +293,7 @@ def test_scalar_udf_array_ternary(ternary_func_fixture):
                           )
 
 
-def test_scalar_udf_array_varargs(varargs_func_fixture):
+def test_udf_array_varargs(varargs_func_fixture):
     check_scalar_function(varargs_func_fixture,
                           [
                               pa.array([2, 3], pa.int64()),
@@ -464,7 +464,7 @@ def test_wrong_input_type_declaration():
                                     in_types, out_type)
 
 
-def test_udf_context(unary_func_fixture):
+def test_scalar_udf_context(unary_func_fixture):
     # Check the memory_pool argument is properly propagated
     proxy_pool = pa.proxy_memory_pool(pa.default_memory_pool())
     _, func_name = unary_func_fixture
@@ -504,3 +504,112 @@ def test_input_lifetime(unary_func_fixture):
     # Calling a UDF should not have kept `v` alive longer than required
     v = None
     assert proxy_pool.bytes_allocated() == 0
+
+
+def _record_batch_from_iters(schema, *iters):
+    arrays = [pa.array(list(v), type=schema[i].type)
+              for i, v in enumerate(iters)]
+    return pa.RecordBatch.from_arrays(arrays=arrays, schema=schema)
+
+
+def _record_batch_for_range(schema, n):
+    return _record_batch_from_iters(schema,
+                                    range(n, n + 10),
+                                    range(n + 1, n + 11))
+
+
+def make_udt_func(schema, batch_gen):
+    def udf_func(ctx):
+        class UDT:
+            def __init__(self):
+                self.caller = None
+
+            def __call__(self, ctx):
+                try:
+                    if self.caller is None:
+                        self.caller, ctx = batch_gen(ctx).send, None  # first send must be None (primes the generator)
+                    batch = self.caller(ctx)
+                except StopIteration:
+                    arrays = [pa.array([], type=field.type)
+                              for field in schema]
+                    batch = pa.RecordBatch.from_arrays(
+                        arrays=arrays, schema=schema)
+                return batch.to_struct_array()
+        return UDT()
+    return udf_func
+
+
+def datasource1_direct():
+    """A short dataset"""
+    schema = datasource1_schema()
+
+    class Generator:
+        def __init__(self):
+            self.n = 3
+
+        def __call__(self, ctx):
+            if self.n == 0:
+                batch = _record_batch_from_iters(schema, [], [])
+            else:
+                self.n -= 1
+                batch = _record_batch_for_range(schema, self.n)
+            return batch.to_struct_array()
+    return lambda ctx: Generator()
+
+
+def datasource1_generator():
+    schema = datasource1_schema()
+
+    def batch_gen(ctx):
+        for n in range(3, 0, -1):
+            # a resumed generator could receive the new ctx via "ctx = yield ..."
+            yield _record_batch_for_range(schema, n - 1)
+    return make_udt_func(schema, batch_gen)
+
+
+def datasource1_exception():
+    schema = datasource1_schema()
+
+    def batch_gen(ctx):
+        for n in range(3, 0, -1):
+            # a resumed generator could receive the new ctx via "ctx = yield ..."
+            yield _record_batch_for_range(schema, n - 1)
+        raise RuntimeError("datasource1_exception")
+    return make_udt_func(schema, batch_gen)
+
+
+def datasource1_schema():
+    return pa.schema([('', pa.int32()), ('', pa.int32())])
+
+
+def datasource1_args(func, func_name):
+    func_doc = {"summary": f"{func_name} UDT",
+                "description": "test {func_name} UDT"}
+    in_types = {}
+    out_type = pa.struct([("", pa.int32()), ("", pa.int32())])
+    return func, func_name, func_doc, in_types, out_type
+
+
+def _test_datasource1_udt(func_maker):
+    schema = datasource1_schema()
+    func = func_maker()
+    func_name = func_maker.__name__
+    func_args = datasource1_args(func, func_name)
+    pc.register_tabular_function(*func_args)
+    n = 3
+    for item in pc.call_tabular_function(func_name):
+        n -= 1
+        assert item == _record_batch_for_range(schema, n)
+
+
+def test_udt_datasource1_direct():
+    _test_datasource1_udt(datasource1_direct)
+
+
+def test_udt_datasource1_generator():
+    _test_datasource1_udt(datasource1_generator)
+
+
+def test_udt_datasource1_exception():
+    with pytest.raises(RuntimeError, match='datasource1_exception'):
+        _test_datasource1_udt(datasource1_exception)
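
For completeness, a sketch of how the test helpers above compose for a new
source (the names my_gen_source/test_udt_my_gen_source are illustrative, not
part of this commit):

    def my_gen_source():
        schema = datasource1_schema()

        def batch_gen(ctx):
            for n in (1, 0):
                yield _record_batch_for_range(schema, n)
        return make_udt_func(schema, batch_gen)

    def test_udt_my_gen_source():
        schema = datasource1_schema()
        pc.register_tabular_function(
            *datasource1_args(my_gen_source(), "my_gen_source"))
        got = list(pc.call_tabular_function("my_gen_source"))
        assert got == [_record_batch_for_range(schema, 1),
                       _record_batch_for_range(schema, 0)]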