You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ic...@apache.org on 2023/06/08 18:12:56 UTC

[arrow] branch main updated: GH-35515: [C++][Python] Add non decomposable aggregation UDF (#35514)

This is an automated email from the ASF dual-hosted git repository.

icexelloss pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 8b5919d886 GH-35515: [C++][Python] Add non decomposable aggregation UDF (#35514)
8b5919d886 is described below

commit 8b5919d886125c3dae9dd5484f7e9e45ae8580d3
Author: Li Jin <ic...@gmail.com>
AuthorDate: Thu Jun 8 14:12:49 2023 -0400

    GH-35515: [C++][Python] Add non decomposable aggregation UDF (#35514)
    
    
    
    ### Rationale for this change
    Non-decomposable aggregation is aggregation that cannot be split into consume/merge/finalize steps. This is often the case when the logic is written with external Python libraries (numpy, pandas, statsmodels, etc.) that either cannot be decomposed or are not worth the effort to decompose (these are often one-off functions rather than reusable ones). This PR implements support for non-decomposable aggregation UDFs.
    
    The major issue with non-decomposable UDFs is that the UDF needs to see all of the data at once, unlike scalar UDFs, which only need to see one batch at a time. On its own this makes a non-decomposable UDF not very useful, since it is the same as collecting all the data into a pd.DataFrame and applying the UDF to it. However, one very useful application of non-decomposable UDFs is segmented aggregation. To recap, segmented aggregation works on ordered data and passes one logical chunk at a time (e.g., all data with the same date [...]
    
    ### What changes are included in this PR?
    This PR is currently WIP and not ready for review.
    
    So far I have implemented the minimal amount of code to make a basic test work, but it still needs cleanup, error handling, etc.
    
    * [x] First round of self review
    * [x] Second round of self review
    * [x] Implement and test unary
    * [x] Implement and test varargs
    * [x] Implement and test Acero support with segmented aggregation
    
    ### Are these changes tested?
    Added new tests that call the aggregation through both compute and Acero.
    
    The compute tests call the aggregation on the full array. The Acero test calls the aggregation with segmented aggregation.
    
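    ### Are there any user-facing changes?
    Yes. A new experimental API, `pyarrow.compute.register_aggregate_function`, registers non-decomposable aggregate UDFs, and `pyarrow.compute.ScalarUdfContext` is renamed to `UdfContext`.
    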
    * Closes: #35515
    
    Lead-authored-by: Li Jin <ic...@gmail.com>
    Co-authored-by: Weston Pace <we...@gmail.com>
    Signed-off-by: Li Jin <ic...@gmail.com>
---
 cpp/src/arrow/engine/substrait/extension_set.cc |  27 ++--
 python/pyarrow/_compute.pxd                     |   6 +-
 python/pyarrow/_compute.pyx                     | 161 ++++++++++++++------
 python/pyarrow/compute.py                       |   3 +-
 python/pyarrow/conftest.py                      |  56 +++++++
 python/pyarrow/includes/libarrow.pxd            |   8 +-
 python/pyarrow/src/arrow/python/udf.cc          | 188 +++++++++++++++++++++++-
 python/pyarrow/src/arrow/python/udf.h           |  11 +-
 python/pyarrow/tests/test_substrait.py          | 156 +++++++++++++++++++-
 python/pyarrow/tests/test_udf.py                | 113 +++++++++++++-
 10 files changed, 652 insertions(+), 77 deletions(-)

diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc
index 5501889d7a..d89248383b 100644
--- a/cpp/src/arrow/engine/substrait/extension_set.cc
+++ b/cpp/src/arrow/engine/substrait/extension_set.cc
@@ -954,7 +954,9 @@ ExtensionIdRegistry::SubstraitAggregateToArrow DecodeBasicAggregate(
         return Status::Invalid("Expected aggregate call ", call.id().uri, "#",
                                call.id().name, " to have at least one argument");
       }
-      case 1: {
+      default: {
+        // Handles all arity > 0
+
         std::shared_ptr<compute::FunctionOptions> options = nullptr;
         if (arrow_function_name == "stddev" || arrow_function_name == "variance") {
           // See the following URL for the spec of stddev and variance:
@@ -981,21 +983,22 @@ ExtensionIdRegistry::SubstraitAggregateToArrow DecodeBasicAggregate(
         }
         fixed_arrow_func += arrow_function_name;
 
-        ARROW_ASSIGN_OR_RAISE(compute::Expression arg, call.GetValueArg(0));
-        const FieldRef* arg_ref = arg.field_ref();
-        if (!arg_ref) {
-          return Status::Invalid("Expected an aggregate call ", call.id().uri, "#",
-                                 call.id().name, " to have a direct reference");
+        std::vector<FieldRef> target;
+        for (int i = 0; i < call.size(); i++) {
+          ARROW_ASSIGN_OR_RAISE(compute::Expression arg, call.GetValueArg(i));
+          const FieldRef* arg_ref = arg.field_ref();
+          if (!arg_ref) {
+            return Status::Invalid("Expected an aggregate call ", call.id().uri, "#",
+                                   call.id().name, " to have a direct reference");
+          }
+          // Copy arg_ref here because field_ref() return const FieldRef*
+          target.emplace_back(*arg_ref);
         }
-
         return compute::Aggregate{std::move(fixed_arrow_func),
-                                  options ? std::move(options) : nullptr, *arg_ref, ""};
+                                  options ? std::move(options) : nullptr,
+                                  std::move(target), ""};
       }
-      default:
-        break;
     }
-    return Status::NotImplemented(
-        "Only nullary and unary aggregate functions are currently supported");
   };
 }
 
diff --git a/python/pyarrow/_compute.pxd b/python/pyarrow/_compute.pxd
index 2dc0de2d0b..29b37da3ac 100644
--- a/python/pyarrow/_compute.pxd
+++ b/python/pyarrow/_compute.pxd
@@ -21,11 +21,11 @@ from pyarrow.lib cimport *
 from pyarrow.includes.common cimport *
 from pyarrow.includes.libarrow cimport *
 
-cdef class ScalarUdfContext(_Weakrefable):
+cdef class UdfContext(_Weakrefable):
     cdef:
-        CScalarUdfContext c_context
+        CUdfContext c_context
 
-    cdef void init(self, const CScalarUdfContext& c_context)
+    cdef void init(self, const CUdfContext& c_context)
 
 
 cdef class FunctionOptions(_Weakrefable):
diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx
index a5db5be551..eaf9d1dfb6 100644
--- a/python/pyarrow/_compute.pyx
+++ b/python/pyarrow/_compute.pyx
@@ -2559,7 +2559,7 @@ cdef CExpression _bind(Expression filter, Schema schema) except *:
         deref(pyarrow_unwrap_schema(schema).get())))
 
 
-cdef class ScalarUdfContext:
+cdef class UdfContext:
     """
     Per-invocation function context/state.
 
@@ -2571,7 +2571,7 @@ cdef class ScalarUdfContext:
         raise TypeError("Do not call {}'s constructor directly"
                         .format(self.__class__.__name__))
 
-    cdef void init(self, const CScalarUdfContext &c_context):
+    cdef void init(self, const CUdfContext &c_context):
         self.c_context = c_context
 
     @property
@@ -2620,26 +2620,26 @@ cdef inline CFunctionDoc _make_function_doc(dict func_doc) except *:
     return f_doc
 
 
-cdef object box_scalar_udf_context(const CScalarUdfContext& c_context):
-    cdef ScalarUdfContext context = ScalarUdfContext.__new__(ScalarUdfContext)
+cdef object box_udf_context(const CUdfContext& c_context):
+    cdef UdfContext context = UdfContext.__new__(UdfContext)
     context.init(c_context)
     return context
 
 
-cdef _udf_callback(user_function, const CScalarUdfContext& c_context, inputs):
+cdef _udf_callback(user_function, const CUdfContext& c_context, inputs):
     """
-    Helper callback function used to wrap the ScalarUdfContext from Python to C++
+    Helper callback function used to wrap the UdfContext from Python to C++
     execution.
     """
-    context = box_scalar_udf_context(c_context)
+    context = box_udf_context(c_context)
     return user_function(context, *inputs)
 
 
-def _get_scalar_udf_context(memory_pool, batch_length):
-    cdef CScalarUdfContext c_context
+def _get_udf_context(memory_pool, batch_length):
+    cdef CUdfContext c_context
     c_context.pool = maybe_unbox_memory_pool(memory_pool)
     c_context.batch_length = batch_length
-    context = box_scalar_udf_context(c_context)
+    context = box_udf_context(c_context)
     return context
 
 
@@ -2665,11 +2665,19 @@ cdef get_register_tabular_function():
     return reg
 
 
+cdef get_register_aggregate_function():
+    cdef RegisterUdf reg = RegisterUdf.__new__(RegisterUdf)
+    reg.register_func = RegisterAggregateFunction
+    return reg
+
+
 def register_scalar_function(func, function_name, function_doc, in_types, out_type,
                              func_registry=None):
     """
     Register a user-defined scalar function.
 
+    This API is EXPERIMENTAL.
+
     A scalar function is a function that executes elementwise
     operations on arrays or scalars, i.e. a scalar function must
     be computed row-by-row with no state where each output row
@@ -2684,17 +2692,18 @@ def register_scalar_function(func, function_name, function_doc, in_types, out_ty
     func : callable
         A callable implementing the user-defined function.
         The first argument is the context argument of type
-        ScalarUdfContext.
+        UdfContext.
         Then, it must take arguments equal to the number of
         in_types defined. It must return an Array or Scalar
         matching the out_type. It must return a Scalar if
         all arguments are scalar, else it must return an Array.
 
         To define a varargs function, pass a callable that takes
-        varargs. The last in_type will be the type of all varargs
+        *args. The last in_type will be the type of all varargs
         arguments.
     function_name : str
-        Name of the function. This name must be globally unique.
+        Name of the function. There should only be one function
+        registered with this name in the function registry.
     function_doc : dict
         A dictionary object with keys "summary" (str),
         and "description" (str).
@@ -2738,9 +2747,86 @@ def register_scalar_function(func, function_name, function_doc, in_types, out_ty
       21
     ]
     """
-    return _register_scalar_like_function(get_register_scalar_function(),
-                                          func, function_name, function_doc, in_types,
-                                          out_type, func_registry)
+    return _register_user_defined_function(get_register_scalar_function(),
+                                           func, function_name, function_doc, in_types,
+                                           out_type, func_registry)
+
+
+def register_aggregate_function(func, function_name, function_doc, in_types, out_type,
+                                func_registry=None):
+    """
+    Register a user-defined non-decomposable aggregate function.
+
+    This API is EXPERIMENTAL.
+
+    A non-decomposable aggregation function is a function that executes
+    aggregate operations on the whole data that it is aggregating.
+    In other words, non-decomposable aggregate function cannot be
+    split into consume/merge/finalize steps.
+
+    This is often used with ordered or segmented aggregation where groups
+    can be emit before accumulating all of the input data.
+
+    Parameters
+    ----------
+    func : callable
+        A callable implementing the user-defined function.
+        The first argument is the context argument of type
+        UdfContext.
+        Then, it must take arguments equal to the number of
+        in_types defined. It must return a Scalar matching the
+        out_type.
+        To define a varargs function, pass a callable that takes
+        *args. The in_type needs to match in type of inputs when
+        the function gets called.
+    function_name : str
+        Name of the function. This name must be unique, i.e.,
+        there should only be one function registered with
+        this name in the function registry.
+    function_doc : dict
+        A dictionary object with keys "summary" (str),
+        and "description" (str).
+    in_types : Dict[str, DataType]
+        A dictionary mapping function argument names to
+        their respective DataType.
+        The argument names will be used to generate
+        documentation for the function. The number of
+        arguments specified here determines the function
+        arity.
+    out_type : DataType
+        Output type of the function.
+    func_registry : FunctionRegistry
+        Optional function registry to use instead of the default global one.
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> import pyarrow as pa
+    >>> import pyarrow.compute as pc
+    >>>
+    >>> func_doc = {}
+    >>> func_doc["summary"] = "simple median udf"
+    >>> func_doc["description"] = "compute median"
+    >>>
+    >>> def compute_median(ctx, array):
+    ...     return pa.scalar(np.median(array))
+    >>>
+    >>> func_name = "py_compute_median"
+    >>> in_types = {"array": pa.int64()}
+    >>> out_type = pa.float64()
+    >>> pc.register_aggregate_function(compute_median, func_name, func_doc,
+    ...                   in_types, out_type)
+    >>>
+    >>> func = pc.get_function(func_name)
+    >>> func.name
+    'py_compute_median'
+    >>> answer = pc.call_function(func_name, [pa.array([20, 40])])
+    >>> answer
+    <pyarrow.DoubleScalar: 30.0>
+    """
+    return _register_user_defined_function(get_register_aggregate_function(),
+                                           func, function_name, function_doc, in_types,
+                                           out_type, func_registry)
 
 
 def register_tabular_function(func, function_name, function_doc, in_types, out_type,
@@ -2748,8 +2834,10 @@ def register_tabular_function(func, function_name, function_doc, in_types, out_t
     """
     Register a user-defined tabular function.
 
+    This API is EXPERIMENTAL.
+
     A tabular function is one accepting a context argument of type
-    ScalarUdfContext and returning a generator of struct arrays.
+    UdfContext and returning a generator of struct arrays.
     The in_types argument must be empty and the out_type argument
     specifies a schema. Each struct array must have field types
     correspoding to the schema.
@@ -2759,11 +2847,12 @@ def register_tabular_function(func, function_name, function_doc, in_types, out_t
     func : callable
         A callable implementing the user-defined function.
         The only argument is the context argument of type
-        ScalarUdfContext. It must return a callable that
+        UdfContext. It must return a callable that
         returns on each invocation a StructArray matching
         the out_type, where an empty array indicates end.
     function_name : str
-        Name of the function. This name must be globally unique.
+        Name of the function. There should only be one function
+        registered with this name in the function registry.
     function_doc : dict
         A dictionary object with keys "summary" (str),
         and "description" (str).
@@ -2783,46 +2872,34 @@ def register_tabular_function(func, function_name, function_doc, in_types, out_t
         with nogil:
             c_type = <shared_ptr[CDataType]>make_shared[CStructType](deref(c_schema).fields())
         out_type = pyarrow_wrap_data_type(c_type)
-    return _register_scalar_like_function(get_register_tabular_function(),
-                                          func, function_name, function_doc, in_types,
-                                          out_type, func_registry)
+    return _register_user_defined_function(get_register_tabular_function(),
+                                           func, function_name, function_doc, in_types,
+                                           out_type, func_registry)
 
 
-def _register_scalar_like_function(register_func, func, function_name, function_doc, in_types,
-                                   out_type, func_registry=None):
+def _register_user_defined_function(register_func, func, function_name, function_doc, in_types,
+                                    out_type, func_registry=None):
     """
-    Register a user-defined scalar-like function.
+    Register a user-defined function.
 
-    A scalar-like function is a callable accepting a first
-    context argument of type ScalarUdfContext as well as
-    possibly additional Arrow arguments, and returning a
-    an Arrow result appropriate for the kind of function.
-    A scalar function and a tabular function are examples
-    for scalar-like functions.
-    This function is normally not called directly but via
-    register_scalar_function or register_tabular_function.
+    This method itself doesn't care about the type of the UDF
+    (i.e., scalar vs tabular vs aggregate)
 
     Parameters
     ----------
     register_func: object
-        An object holding a CRegisterUdf in a "register_func" attribute. Use
-        get_register_scalar_function() for a scalar function and
-        get_register_tabular_function() for a tabular function.
+        An object holding a CRegisterUdf in a "register_func" attribute.
     func : callable
         A callable implementing the user-defined function.
-        See register_scalar_function and
-        register_tabular_function for details.
-
     function_name : str
-        Name of the function. This name must be globally unique.
+        Name of the function. There should only be one function
+        registered with this name in the function registry.
     function_doc : dict
         A dictionary object with keys "summary" (str),
         and "description" (str).
     in_types : Dict[str, DataType]
         A dictionary mapping function argument names to
         their respective DataType.
-        See register_scalar_function and
-        register_tabular_function for details.
     out_type : DataType
         Output type of the function.
     func_registry : FunctionRegistry
diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py
index e299d44c04..e92f093547 100644
--- a/python/pyarrow/compute.py
+++ b/python/pyarrow/compute.py
@@ -84,7 +84,8 @@ from pyarrow._compute import (  # noqa
     call_tabular_function,
     register_scalar_function,
     register_tabular_function,
-    ScalarUdfContext,
+    register_aggregate_function,
+    UdfContext,
     # Expressions
     Expression,
 )
diff --git a/python/pyarrow/conftest.py b/python/pyarrow/conftest.py
index ef09393cfb..f32cbf01ef 100644
--- a/python/pyarrow/conftest.py
+++ b/python/pyarrow/conftest.py
@@ -278,3 +278,59 @@ def unary_func_fixture():
                                 {"array": pa.int64()},
                                 pa.int64())
     return unary_function, func_name
+
+
+@pytest.fixture(scope="session")
+def unary_agg_func_fixture():
+    """
+    Register a unary aggregate function
+    """
+    from pyarrow import compute as pc
+    import numpy as np
+
+    def func(ctx, x):
+        return pa.scalar(np.nanmean(x))
+
+    func_name = "y=avg(x)"
+    func_doc = {"summary": "y=avg(x)",
+                "description": "find mean of x"}
+
+    pc.register_aggregate_function(func,
+                                   func_name,
+                                   func_doc,
+                                   {
+                                       "x": pa.float64(),
+                                   },
+                                   pa.float64()
+                                   )
+    return func, func_name
+
+
+@pytest.fixture(scope="session")
+def varargs_agg_func_fixture():
+    """
+    Register a unary aggregate function
+    """
+    from pyarrow import compute as pc
+    import numpy as np
+
+    def func(ctx, *args):
+        sum = 0.0
+        for arg in args:
+            sum += np.nanmean(arg)
+        return pa.scalar(sum)
+
+    func_name = "y=sum_mean(x...)"
+    func_doc = {"summary": "Varargs aggregate",
+                "description": "Varargs aggregate"}
+
+    pc.register_aggregate_function(func,
+                                   func_name,
+                                   func_doc,
+                                   {
+                                       "x": pa.int64(),
+                                       "y": pa.float64()
+                                   },
+                                   pa.float64()
+                                   )
+    return func, func_name
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index 3190877ea0..86f21f4b52 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -2775,7 +2775,7 @@ cdef extern from "arrow/util/byte_size.h" namespace "arrow::util" nogil:
     int64_t TotalBufferSize(const CRecordBatch& record_batch)
     int64_t TotalBufferSize(const CTable& table)
 
-ctypedef PyObject* CallbackUdf(object user_function, const CScalarUdfContext& context, object inputs)
+ctypedef PyObject* CallbackUdf(object user_function, const CUdfContext& context, object inputs)
 
 
 cdef extern from "arrow/api.h" namespace "arrow" nogil:
@@ -2786,7 +2786,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
 
 
 cdef extern from "arrow/python/udf.h" namespace "arrow::py" nogil:
-    cdef cppclass CScalarUdfContext" arrow::py::ScalarUdfContext":
+    cdef cppclass CUdfContext" arrow::py::UdfContext":
         CMemoryPool *pool
         int64_t batch_length
 
@@ -2805,5 +2805,9 @@ cdef extern from "arrow/python/udf.h" namespace "arrow::py" nogil:
                                     function[CallbackUdf] wrapper, const CUdfOptions& options,
                                     CFunctionRegistry* registry)
 
+    CStatus RegisterAggregateFunction(PyObject* function,
+                                      function[CallbackUdf] wrapper, const CUdfOptions& options,
+                                      CFunctionRegistry* registry)
+
     CResult[shared_ptr[CRecordBatchReader]] CallTabularFunction(
         const c_string& func_name, const vector[CDatum]& args, CFunctionRegistry* registry)
diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc
index 7d63adb835..06c116af82 100644
--- a/python/pyarrow/src/arrow/python/udf.cc
+++ b/python/pyarrow/src/arrow/python/udf.cc
@@ -16,14 +16,16 @@
 // under the License.
 
 #include "arrow/python/udf.h"
+#include "arrow/compute/api_aggregate.h"
 #include "arrow/compute/function.h"
 #include "arrow/compute/kernel.h"
 #include "arrow/python/common.h"
+#include "arrow/table.h"
 #include "arrow/util/checked_cast.h"
 
 namespace arrow {
+using internal::checked_cast;
 namespace py {
-
 namespace {
 
 struct PythonUdfKernelState : public compute::KernelState {
@@ -65,6 +67,26 @@ struct PythonUdfKernelInit {
   std::shared_ptr<OwnedRefNoGIL> function;
 };
 
+struct ScalarUdfAggregator : public compute::KernelState {
+  virtual Status Consume(compute::KernelContext* ctx, const compute::ExecSpan& batch) = 0;
+  virtual Status MergeFrom(compute::KernelContext* ctx, compute::KernelState&& src) = 0;
+  virtual Status Finalize(compute::KernelContext* ctx, Datum* out) = 0;
+};
+
+arrow::Status AggregateUdfConsume(compute::KernelContext* ctx,
+                                  const compute::ExecSpan& batch) {
+  return checked_cast<ScalarUdfAggregator*>(ctx->state())->Consume(ctx, batch);
+}
+
+arrow::Status AggregateUdfMerge(compute::KernelContext* ctx, compute::KernelState&& src,
+                                compute::KernelState* dst) {
+  return checked_cast<ScalarUdfAggregator*>(dst)->MergeFrom(ctx, std::move(src));
+}
+
+arrow::Status AggregateUdfFinalize(compute::KernelContext* ctx, arrow::Datum* out) {
+  return checked_cast<ScalarUdfAggregator*>(ctx->state())->Finalize(ctx, out);
+}
+
 struct PythonTableUdfKernelInit {
   PythonTableUdfKernelInit(std::shared_ptr<OwnedRefNoGIL> function_maker,
                            UdfWrapperCallback cb)
@@ -82,12 +104,12 @@ struct PythonTableUdfKernelInit {
 
   Result<std::unique_ptr<compute::KernelState>> operator()(
       compute::KernelContext* ctx, const compute::KernelInitArgs&) {
-    ScalarUdfContext scalar_udf_context{ctx->memory_pool(), /*batch_length=*/0};
+    UdfContext udf_context{ctx->memory_pool(), /*batch_length=*/0};
     std::unique_ptr<OwnedRefNoGIL> function;
-    RETURN_NOT_OK(SafeCallIntoPython([this, &scalar_udf_context, &function] {
+    RETURN_NOT_OK(SafeCallIntoPython([this, &udf_context, &function] {
       OwnedRef empty_tuple(PyTuple_New(0));
       function = std::make_unique<OwnedRefNoGIL>(
-          cb(function_maker->obj(), scalar_udf_context, empty_tuple.obj()));
+          cb(function_maker->obj(), udf_context, empty_tuple.obj()));
       RETURN_NOT_OK(CheckPyError());
       return Status::OK();
     }));
@@ -101,6 +123,105 @@ struct PythonTableUdfKernelInit {
   UdfWrapperCallback cb;
 };
 
+struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator {
+  PythonUdfScalarAggregatorImpl(UdfWrapperCallback agg_cb,
+                                std::shared_ptr<OwnedRefNoGIL> agg_function,
+                                std::vector<std::shared_ptr<DataType>> input_types,
+                                std::shared_ptr<DataType> output_type)
+      : agg_cb(std::move(agg_cb)),
+        agg_function(agg_function),
+        output_type(std::move(output_type)) {
+    Py_INCREF(agg_function->obj());
+    std::vector<std::shared_ptr<Field>> fields;
+    for (size_t i = 0; i < input_types.size(); i++) {
+      fields.push_back(field("", input_types[i]));
+    }
+    input_schema = schema(std::move(fields));
+  };
+
+  ~PythonUdfScalarAggregatorImpl() override {
+    if (_Py_IsFinalizing()) {
+      agg_function->detach();
+    }
+  }
+
+  Status Consume(compute::KernelContext* ctx, const compute::ExecSpan& batch) override {
+    ARROW_ASSIGN_OR_RAISE(
+        auto rb, batch.ToExecBatch().ToRecordBatch(input_schema, ctx->memory_pool()));
+    values.push_back(std::move(rb));
+    return Status::OK();
+  }
+
+  Status MergeFrom(compute::KernelContext* ctx, compute::KernelState&& src) override {
+    auto& other_values = checked_cast<PythonUdfScalarAggregatorImpl&>(src).values;
+    values.insert(values.end(), std::make_move_iterator(other_values.begin()),
+                  std::make_move_iterator(other_values.end()));
+
+    other_values.erase(other_values.begin(), other_values.end());
+    return Status::OK();
+  }
+
+  Status Finalize(compute::KernelContext* ctx, Datum* out) override {
+    auto state =
+        arrow::internal::checked_cast<PythonUdfScalarAggregatorImpl*>(ctx->state());
+    std::shared_ptr<OwnedRefNoGIL>& function = state->agg_function;
+    const int num_args = input_schema->num_fields();
+
+    // Note: The way that batches are concatenated together
+    // would result in using double amount of the memory.
+    // This is OK for now because non decomposable aggregate
+    // UDF is supposed to be used with segmented aggregation
+    // where the size of the segment is more or less constant
+    // so doubling that is not a big deal. This can be also
+    // improved in the future to use more efficient way to
+    // concatenate.
+    ARROW_ASSIGN_OR_RAISE(auto table,
+                          arrow::Table::FromRecordBatches(input_schema, values));
+    ARROW_ASSIGN_OR_RAISE(table, table->CombineChunks(ctx->memory_pool()));
+    UdfContext udf_context{ctx->memory_pool(), table->num_rows()};
+
+    if (table->num_rows() == 0) {
+      return Status::Invalid("Finalized is called with empty inputs");
+    }
+
+    RETURN_NOT_OK(SafeCallIntoPython([&] {
+      std::unique_ptr<OwnedRef> result;
+      OwnedRef arg_tuple(PyTuple_New(num_args));
+      RETURN_NOT_OK(CheckPyError());
+
+      for (int arg_id = 0; arg_id < num_args; arg_id++) {
+        // Since we combined chunks there is only one chunk
+        std::shared_ptr<Array> c_data = table->column(arg_id)->chunk(0);
+        PyObject* data = wrap_array(c_data);
+        PyTuple_SetItem(arg_tuple.obj(), arg_id, data);
+      }
+      result = std::make_unique<OwnedRef>(
+          agg_cb(function->obj(), udf_context, arg_tuple.obj()));
+      RETURN_NOT_OK(CheckPyError());
+      // unwrapping the output for expected output type
+      if (is_scalar(result->obj())) {
+        ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> val, unwrap_scalar(result->obj()));
+        if (*output_type != *val->type) {
+          return Status::TypeError("Expected output datatype ", output_type->ToString(),
+                                   ", but function returned datatype ",
+                                   val->type->ToString());
+        }
+        out->value = std::move(val);
+        return Status::OK();
+      }
+      return Status::TypeError("Unexpected output type: ",
+                               Py_TYPE(result->obj())->tp_name, " (expected Scalar)");
+    }));
+    return Status::OK();
+  }
+
+  UdfWrapperCallback agg_cb;
+  std::vector<std::shared_ptr<RecordBatch>> values;
+  std::shared_ptr<OwnedRefNoGIL> agg_function;
+  std::shared_ptr<Schema> input_schema;
+  std::shared_ptr<DataType> output_type;
+};
+
 struct PythonUdf : public PythonUdfKernelState {
   PythonUdf(std::shared_ptr<OwnedRefNoGIL> function, UdfWrapperCallback cb,
             std::vector<TypeHolder> input_types, compute::OutputType output_type)
@@ -130,7 +251,7 @@ struct PythonUdf : public PythonUdfKernelState {
     auto state = arrow::internal::checked_cast<PythonUdfKernelState*>(ctx->state());
     std::shared_ptr<OwnedRefNoGIL>& function = state->function;
     const int num_args = batch.num_values();
-    ScalarUdfContext scalar_udf_context{ctx->memory_pool(), batch.length};
+    UdfContext udf_context{ctx->memory_pool(), batch.length};
 
     OwnedRef arg_tuple(PyTuple_New(num_args));
     RETURN_NOT_OK(CheckPyError());
@@ -146,7 +267,7 @@ struct PythonUdf : public PythonUdfKernelState {
       }
     }
 
-    OwnedRef result(cb(function->obj(), scalar_udf_context, arg_tuple.obj()));
+    OwnedRef result(cb(function->obj(), udf_context, arg_tuple.obj()));
     RETURN_NOT_OK(CheckPyError());
     // unwrapping the output for expected output type
     if (is_array(result.obj())) {
@@ -234,6 +355,61 @@ Status RegisterTabularFunction(PyObject* user_function, UdfWrapperCallback wrapp
       wrapper, options, registry);
 }
 
+Status AddAggKernel(std::shared_ptr<compute::KernelSignature> sig,
+                    compute::KernelInit init, compute::ScalarAggregateFunction* func) {
+  compute::ScalarAggregateKernel kernel(std::move(sig), std::move(init),
+                                        AggregateUdfConsume, AggregateUdfMerge,
+                                        AggregateUdfFinalize, /*ordered=*/false);
+  RETURN_NOT_OK(func->AddKernel(std::move(kernel)));
+  return Status::OK();
+}
+
+Status RegisterAggregateFunction(PyObject* agg_function, UdfWrapperCallback agg_wrapper,
+                                 const UdfOptions& options,
+                                 compute::FunctionRegistry* registry) {
+  if (!PyCallable_Check(agg_function)) {
+    return Status::TypeError("Expected a callable Python object.");
+  }
+
+  if (registry == NULLPTR) {
+    registry = compute::GetFunctionRegistry();
+  }
+
+  // Py_INCREF here so that once a function is registered
+  // its refcount gets increased by 1 and doesn't get gced
+  // if all existing refs are gone
+  Py_INCREF(agg_function);
+
+  static auto default_scalar_aggregate_options =
+      compute::ScalarAggregateOptions::Defaults();
+  auto aggregate_func = std::make_shared<compute::ScalarAggregateFunction>(
+      options.func_name, options.arity, options.func_doc,
+      &default_scalar_aggregate_options);
+
+  std::vector<compute::InputType> input_types;
+  for (const auto& in_dtype : options.input_types) {
+    input_types.emplace_back(in_dtype);
+  }
+  compute::OutputType output_type(options.output_type);
+
+  compute::KernelInit init = [agg_wrapper, agg_function, options](
+                                 compute::KernelContext* ctx,
+                                 const compute::KernelInitArgs& args)
+      -> Result<std::unique_ptr<compute::KernelState>> {
+    return std::make_unique<PythonUdfScalarAggregatorImpl>(
+        agg_wrapper, std::make_shared<OwnedRefNoGIL>(agg_function), options.input_types,
+        options.output_type);
+  };
+
+  RETURN_NOT_OK(AddAggKernel(
+      compute::KernelSignature::Make(std::move(input_types), std::move(output_type),
+                                     options.arity.is_varargs),
+      init, aggregate_func.get()));
+
+  RETURN_NOT_OK(registry->AddFunction(std::move(aggregate_func)));
+  return Status::OK();
+}
+
 Result<std::shared_ptr<RecordBatchReader>> CallTabularFunction(
     const std::string& func_name, const std::vector<Datum>& args,
     compute::FunctionRegistry* registry) {
diff --git a/python/pyarrow/src/arrow/python/udf.h b/python/pyarrow/src/arrow/python/udf.h
index b3dcc9ccf4..682cbb2ffe 100644
--- a/python/pyarrow/src/arrow/python/udf.h
+++ b/python/pyarrow/src/arrow/python/udf.h
@@ -43,14 +43,14 @@ struct ARROW_PYTHON_EXPORT UdfOptions {
   std::shared_ptr<DataType> output_type;
 };
 
-/// \brief A context passed as the first argument of scalar UDF functions.
-struct ARROW_PYTHON_EXPORT ScalarUdfContext {
+/// \brief Context passed as the first argument of scalar, tabular and aggregate UDFs.
+struct ARROW_PYTHON_EXPORT UdfContext {
   MemoryPool* pool;
   int64_t batch_length;
 };
 
 using UdfWrapperCallback = std::function<PyObject*(
-    PyObject* user_function, const ScalarUdfContext& context, PyObject* inputs)>;
+    PyObject* user_function, const UdfContext& context, PyObject* inputs)>;
 
 /// \brief register a Scalar user-defined-function from Python
 Status ARROW_PYTHON_EXPORT RegisterScalarFunction(
@@ -62,6 +62,11 @@ Status ARROW_PYTHON_EXPORT RegisterTabularFunction(
     PyObject* user_function, UdfWrapperCallback wrapper, const UdfOptions& options,
     compute::FunctionRegistry* registry = NULLPTR);
 
+/// \brief register a non-decomposable aggregate user-defined-function from Python
+Status ARROW_PYTHON_EXPORT RegisterAggregateFunction(
+    PyObject* user_function, UdfWrapperCallback wrapper, const UdfOptions& options,
+    compute::FunctionRegistry* registry = NULLPTR);
+
 Result<std::shared_ptr<RecordBatchReader>> ARROW_PYTHON_EXPORT
 CallTabularFunction(const std::string& func_name, const std::vector<Datum>& args,
                     compute::FunctionRegistry* registry = NULLPTR);
diff --git a/python/pyarrow/tests/test_substrait.py b/python/pyarrow/tests/test_substrait.py
index d0da517ea7..34faaa157a 100644
--- a/python/pyarrow/tests/test_substrait.py
+++ b/python/pyarrow/tests/test_substrait.py
@@ -34,9 +34,9 @@ except ImportError:
 pytestmark = [pytest.mark.dataset, pytest.mark.substrait]
 
 
-def mock_scalar_udf_context(batch_length=10):
-    from pyarrow._compute import _get_scalar_udf_context
-    return _get_scalar_udf_context(pa.default_memory_pool(), batch_length)
+def mock_udf_context(batch_length=10):
+    from pyarrow import _compute
+    return _compute._get_udf_context(pa.default_memory_pool(), batch_length)
 
 
 def _write_dummy_data_to_disk(tmpdir, file_name, table):
@@ -442,7 +442,7 @@ def test_udf_via_substrait(unary_func_fixture, use_threads):
 
     function, name = unary_func_fixture
     expected_tb = test_table.add_column(1, 'y', function(
-        mock_scalar_udf_context(10), test_table['x']))
+        mock_udf_context(10), test_table['x']))
     assert res_tb == expected_tb
 
 
@@ -605,3 +605,151 @@ def test_output_field_names(use_threads):
     expected = pa.Table.from_pydict({"out": [1, 2, 3]})
 
     assert res_tb == expected
+
+
+def test_aggregate_udf_basic(varargs_agg_func_fixture):
+
+    test_table = pa.Table.from_pydict(
+        {"k": [1, 1, 2, 2], "v1": [1, 2, 3, 4],
+         "v2": [1.0, 1.0, 1.0, 1.0]}
+    )
+
+    def table_provider(names, _):
+        return test_table
+
+    substrait_query = b"""
+{
+  "extensionUris": [
+    {
+      "extensionUriAnchor": 1,
+      "uri": "urn:arrow:substrait_simple_extension_function"
+    }
+  ],
+  "extensions": [
+    {
+      "extensionFunction": {
+        "extensionUriReference": 1,
+        "functionAnchor": 1,
+        "name": "y=sum_mean(x...)"
+      }
+    }
+  ],
+  "relations": [
+    {
+      "root": {
+        "input": {
+          "extensionSingle": {
+            "common": {
+              "emit": {
+                "outputMapping": [
+                  0,
+                  1
+                ]
+              }
+            },
+            "input": {
+              "read": {
+                "baseSchema": {
+                  "names": [
+                    "k",
+                    "v1",
+                    "v2"
+                  ],
+                  "struct": {
+                    "types": [
+                      {
+                        "i64": {
+                          "nullability": "NULLABILITY_REQUIRED"
+                        }
+                      },
+                      {
+                        "i64": {
+                          "nullability": "NULLABILITY_NULLABLE"
+                        }
+                      },
+                      {
+                        "fp64": {
+                          "nullability": "NULLABILITY_NULLABLE"
+                        }
+                      }
+                    ],
+                    "nullability": "NULLABILITY_REQUIRED"
+                  }
+                },
+                "namedTable": {
+                  "names": ["t1"]
+                }
+              }
+            },
+            "detail": {
+              "@type": "/arrow.substrait_ext.SegmentedAggregateRel",
+              "segmentKeys": [
+                {
+                  "directReference": {
+                    "structField": {}
+                  },
+                  "rootReference": {}
+                }
+              ],
+              "measures": [
+                {
+                  "measure": {
+                    "functionReference": 1,
+                    "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
+                    "outputType": {
+                      "fp64": {
+                        "nullability": "NULLABILITY_NULLABLE"
+                      }
+                    },
+                    "arguments": [
+                      {
+                        "value": {
+                          "selection": {
+                            "directReference": {
+                              "structField": {
+                                "field": 1
+                              }
+                            },
+                            "rootReference": {}
+                          }
+                        }
+                      },
+                      {
+                        "value": {
+                          "selection": {
+                            "directReference": {
+                              "structField": {
+                                "field": 2
+                              }
+                            },
+                            "rootReference": {}
+                          }
+                        }
+                      }
+                    ]
+                  }
+                }
+              ]
+            }
+          }
+        },
+        "names": [
+          "k",
+          "v_avg"
+        ]
+      }
+    }
+  ]
+}
+"""
+    buf = pa._substrait._parse_json_plan(substrait_query)
+    reader = pa.substrait.run_query(
+        buf, table_provider=table_provider, use_threads=False)
+    res_tb = reader.read_all()
+
+    expected_tb = pa.Table.from_pydict({
+        'k': [1, 2],
+        'v_avg': [2.5, 4.5]  # == mean(v1) + mean(v2) within each k segment
+    })
+
+    assert res_tb == expected_tb
diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py
index 0f336555f7..c0cfd3d26e 100644
--- a/python/pyarrow/tests/test_udf.py
+++ b/python/pyarrow/tests/test_udf.py
@@ -24,21 +24,82 @@ from pyarrow import compute as pc
 # UDFs are all tested with a dataset scan
 pytestmark = pytest.mark.dataset
 
+# For convenience, most of the tests here don't care about UDF func docs
+empty_udf_doc = {"summary": "", "description": ""}
+
 try:
     import pyarrow.dataset as ds
 except ImportError:
     ds = None
 
 
-def mock_scalar_udf_context(batch_length=10):
-    from pyarrow._compute import _get_scalar_udf_context
-    return _get_scalar_udf_context(pa.default_memory_pool(), batch_length)
+def mock_udf_context(batch_length=10):
+    from pyarrow import _compute
+    return _compute._get_udf_context(pa.default_memory_pool(), batch_length)
 
 
 class MyError(RuntimeError):
     pass
 
 
+@pytest.fixture(scope="session")
+def exception_agg_func_fixture():
+    def func(ctx, x):
+        # Always fails: exercises propagation of Python errors out of the UDF
+        raise RuntimeError("Oops")
+
+    func_name = "y=exception_len(x)"
+    func_doc = empty_udf_doc
+
+    pc.register_aggregate_function(func,
+                                   func_name,
+                                   func_doc,
+                                   {
+                                       "x": pa.int64(),
+                                   },
+                                   pa.int64()
+                                   )
+    return func, func_name
+
+
+@pytest.fixture(scope="session")
+def wrong_output_dtype_agg_func_fixture():
+    def func(ctx, x):
+        return pa.scalar(len(x), pa.int32())  # declared output type is int64
+
+    func_name = "y=wrong_output_dtype(x)"
+    func_doc = empty_udf_doc
+
+    pc.register_aggregate_function(func,
+                                   func_name,
+                                   func_doc,
+                                   {
+                                       "x": pa.int64(),
+                                   },
+                                   pa.int64()
+                                   )
+    return func, func_name
+
+
+@pytest.fixture(scope="session")
+def wrong_output_type_agg_func_fixture():
+    def func(ctx, x):
+        return len(x)  # a plain Python int, not a pyarrow Scalar
+
+    func_name = "y=wrong_output_type(x)"
+    func_doc = empty_udf_doc
+
+    pc.register_aggregate_function(func,
+                                   func_name,
+                                   func_doc,
+                                   {
+                                       "x": pa.int64(),
+                                   },
+                                   pa.int64()
+                                   )
+    return func, func_name
+
+
 @pytest.fixture(scope="session")
 def binary_func_fixture():
     """
@@ -228,11 +289,11 @@ def check_scalar_function(func_fixture,
         if all_scalar:
             batch_length = 1
 
-    expected_output = function(mock_scalar_udf_context(batch_length), *inputs)
     func = pc.get_function(name)
     assert func.name == name
 
     result = pc.call_function(name, inputs, length=batch_length)
+    expected_output = function(mock_udf_context(batch_length), *inputs)
     assert result == expected_output
     # At the moment there is an issue when handling nullary functions.
     # See: ARROW-15286 and ARROW-16290.
@@ -593,3 +654,47 @@ def test_udt_datasource1_generator():
 def test_udt_datasource1_exception():
     with pytest.raises(RuntimeError, match='datasource1_exception'):
         _test_datasource1_udt(datasource1_exception)
+
+
+def test_agg_basic(unary_agg_func_fixture):
+    data = pa.array([10.0, 20.0, 30.0, 40.0, 50.0], pa.float64())
+    # "y=avg(x)" is expected to reduce the array to its mean
+    actual = pc.call_function("y=avg(x)", [data])
+    assert actual == pa.scalar(30.0)
+
+
+def test_agg_empty(unary_agg_func_fixture):
+    no_rows = pa.array([], pa.float64())
+    # aggregating zero rows must raise instead of producing a value
+    with pytest.raises(pa.ArrowInvalid, match='empty inputs'):
+        pc.call_function("y=avg(x)", [no_rows])
+
+
+def test_agg_wrong_output_dtype(wrong_output_dtype_agg_func_fixture):
+    _, name = wrong_output_dtype_agg_func_fixture
+    with pytest.raises(pa.ArrowTypeError, match="output datatype"):
+        pc.call_function(name, [pa.array([10, 20, 30, 40, 50], pa.int64())])
+
+
+def test_agg_wrong_output_type(wrong_output_type_agg_func_fixture):
+    _, name = wrong_output_type_agg_func_fixture
+    with pytest.raises(pa.ArrowTypeError, match="output type"):
+        pc.call_function(name, [pa.array([10, 20, 30, 40, 50], pa.int64())])
+
+
+def test_agg_varargs(varargs_agg_func_fixture):
+    ints = pa.array([10, 20, 30, 40, 50], pa.int64())
+    floats = pa.array([1.0, 2.0, 3.0, 4.0, 5.0], pa.float64())
+
+    # "y=sum_mean(x...)" presumably sums the per-argument means:
+    # mean(ints) + mean(floats) == 30 + 3
+    result = pc.call_function("y=sum_mean(x...)", [ints, floats])
+
+    assert result == pa.scalar(33.0)
+
+
+def test_agg_exception(exception_agg_func_fixture):
+    _, name = exception_agg_func_fixture
+    data = pa.array([10, 20, 30, 40, 50, 60], pa.int64())
+    with pytest.raises(RuntimeError, match='Oops'):
+        pc.call_function(name, [data])