Posted to commits@arrow.apache.org by ic...@apache.org on 2023/06/29 14:05:04 UTC

[arrow] branch main updated: GH-36252: [Python] Add non decomposable hash aggregate UDF (#36253)

This is an automated email from the ASF dual-hosted git repository.

icexelloss pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new baf17a20a3 GH-36252: [Python] Add non decomposable hash aggregate UDF  (#36253)
baf17a20a3 is described below

commit baf17a20a351ba3782b621bbed4d2db4a41ae0e1
Author: Li Jin <ic...@gmail.com>
AuthorDate: Thu Jun 29 10:04:56 2023 -0400

    GH-36252: [Python] Add non decomposable hash aggregate UDF  (#36253)
    
    ### Rationale for this change
    
    In https://github.com/apache/arrow/issues/35515, I implemented a Scalar
    version of the non-decomposable aggregate UDF (Scalar as in
    SCALAR_AGGREGATE). This PR adds support for the Hash version of it (Hash
    as in HASH_AGGREGATE).
    
    With this PR, users can register an aggregate UDF once with `pc.register_aggregate_function` and use it as both a scalar aggregate function and a hash aggregate function.
    
    Example:
    
    ```
    def median(ctx, x):
        # ctx is the UdfContext that Arrow passes as the first argument
        return pa.scalar(np.nanmedian(x))
    
    pc.register_aggregate_function(func=median, function_name='median_udf', ...)
    
    table = ...
    table.group_by("id").aggregate([("v", 'median_udf')])
    ```
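    
    Because registration adds both kernels, the same name also works as a
    plain scalar aggregate. A minimal end-to-end sketch, mirroring the doctest
    added to `register_aggregate_function` below (the doc, type, and keyword
    arguments here are illustrative, not part of this commit):
    
    ```
    import numpy as np
    import pyarrow as pa
    import pyarrow.compute as pc
    
    def median(ctx, x):
        return pa.scalar(np.nanmedian(x))
    
    pc.register_aggregate_function(
        func=median,
        function_name='median_udf',
        function_doc={'summary': 'median', 'description': 'compute median'},
        in_types={'x': pa.float64()},
        out_type=pa.float64(),
    )
    
    # Scalar aggregate path: one result for the whole column
    pc.call_function('median_udf', [pa.array([1.0, 3.0, 5.0])])  # 3.0
    
    # Hash aggregate path: one result per group
    table = pa.table({'id': [1, 1, 2], 'v': [1.0, 3.0, 5.0]})
    table.group_by('id').aggregate([('v', 'median_udf')])
    ```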
    
    ### What changes are included in this PR?
    
    The main changes are:
    * In RegisterAggregateFunction (udf.cc), we now register the function both as a scalar aggregate function and a hash aggregate function (with a signature adjustment for the hash aggregate kernel, because the group aggregation node appends the group id column to the inputs)
    * Implemented PythonUdfHashAggregatorImpl, similar to PythonUdfScalarAggregatorImpl. In Consume, it accumulates both the input batches and the group id array. In Merge, it merges the accumulated batches and group ids (translating ids via the group_id_mapping). In Finalize, it applies the groupings to the accumulated batches to create one record batch per group, then applies the UDF over each group (see the sketch after this list).
    * Some code cleanup: `UdfWrapperCallback` objects are now named `cb` (previously `agg_cb` or `wrapper`), and the user-defined Python function is now just called `function` (previously `agg_function`)
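    
    For intuition, the Finalize step amounts to "sort rows by group id, slice
    one sub-batch per group, call the UDF on each slice". A hedged Python
    sketch of that idea using public pyarrow pieces (the C++ implementation
    uses Grouper::MakeGroupings and compute::Take instead; `finalize` and its
    arguments are hypothetical):
    
    ```
    import pyarrow.compute as pc
    
    def finalize(batch, group_ids, udf):
        # batch: pa.RecordBatch of accumulated input rows
        # group_ids: pa.Array of uint32 group ids, one per row
        # Sorting by id makes each group's rows contiguous, mirroring
        # ApplyGroupings in udf.cc.
        indices = pc.sort_indices(group_ids)
        sorted_batch = batch.take(indices)
        ids = group_ids.take(indices).to_pylist()
    
        results = []
        start = 0
        for i in range(1, len(ids) + 1):
            # A group ends where the id changes (or at the end of input)
            if i == len(ids) or ids[i] != ids[start]:
                group = sorted_batch.slice(start, i - start)
                # One Python call per group, over that group's columns only
                results.append(udf(*group.columns))
                start = i
        return results  # one scalar per group
    ```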
    
    For table.group_by().aggregate(...), the space complexity is O(n), where n is the size of the table, because a non-decomposable UDF requires accumulating all input rows before Finalize (and is therefore not very useful on its own). However, this is more useful in the segmented aggregation case, where the space complexity is O(s), where s is the size of a segment.
    
    ### Are these changes tested?
    Added new tests in test_udf.py (with table.group_by().aggregate()) and test_substrait.py (with segmented aggregation)
    
    ### Are there any user-facing changes?
    Yes. With this change, users can call a registered aggregate UDF with `table.group_by().aggregate()` or Acero's segmented aggregation.
    
    ### Checklist
    
    - [x] Self Review
    - [x] API Documentation
    
    * Closes: #36252
    
    Lead-authored-by: Li Jin <ic...@gmail.com>
    Co-authored-by: Weston Pace <we...@gmail.com>
    Signed-off-by: Li Jin <ic...@gmail.com>
---
 python/pyarrow/_compute.pyx            |  12 ++
 python/pyarrow/conftest.py             |  10 +-
 python/pyarrow/src/arrow/python/udf.cc | 348 +++++++++++++++++++++++++++++----
 python/pyarrow/tests/test_substrait.py | 174 ++++++++++++++++-
 python/pyarrow/tests/test_udf.py       | 117 ++++++++++-
 5 files changed, 602 insertions(+), 59 deletions(-)

diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx
index a33a09548d..7e32ba48fb 100644
--- a/python/pyarrow/_compute.pyx
+++ b/python/pyarrow/_compute.pyx
@@ -2799,6 +2799,9 @@ def register_aggregate_function(func, function_name, function_doc, in_types, out
     This is often used with ordered or segmented aggregation, where groups
     can be emitted before accumulating all of the input data.
 
+    Note that currently the size of any input column cannot exceed 2 GB
+    for a single segment (all groups combined).
+
     Parameters
     ----------
     func : callable
@@ -2855,6 +2858,15 @@ def register_aggregate_function(func, function_name, function_doc, in_types, out
     >>> answer = pc.call_function(func_name, [pa.array([20, 40])])
     >>> answer
     <pyarrow.DoubleScalar: 30.0>
+    >>> table = pa.table([pa.array([1, 1, 2, 2]), pa.array([10, 20, 30, 40])], names=['k', 'v'])
+    >>> result = table.group_by('k').aggregate([('v', 'py_compute_median')])
+    >>> result
+    pyarrow.Table
+    k: int64
+    v_py_compute_median: double
+    ----
+    k: [[1,2]]
+    v_py_compute_median: [[15,35]]
     """
     return _register_user_defined_function(get_register_aggregate_function(),
                                            func, function_name, function_doc, in_types,
diff --git a/python/pyarrow/conftest.py b/python/pyarrow/conftest.py
index f32cbf01ef..6f6807e907 100644
--- a/python/pyarrow/conftest.py
+++ b/python/pyarrow/conftest.py
@@ -20,6 +20,8 @@ import pyarrow as pa
 from pyarrow import Codec
 from pyarrow import fs
 
+import numpy as np
+
 groups = [
     'acero',
     'brotli',
@@ -283,15 +285,14 @@ def unary_func_fixture():
 @pytest.fixture(scope="session")
 def unary_agg_func_fixture():
     """
-    Register a unary aggregate function
+    Register a unary aggregate function (mean)
     """
     from pyarrow import compute as pc
-    import numpy as np
 
     def func(ctx, x):
         return pa.scalar(np.nanmean(x))
 
-    func_name = "y=avg(x)"
+    func_name = "mean_udf"
     func_doc = {"summary": "y=avg(x)",
                 "description": "find mean of x"}
 
@@ -312,7 +313,6 @@ def varargs_agg_func_fixture():
     Register a varargs aggregate function
     """
     from pyarrow import compute as pc
-    import numpy as np
 
     def func(ctx, *args):
         sum = 0.0
@@ -320,7 +320,7 @@ def varargs_agg_func_fixture():
             sum += np.nanmean(arg)
         return pa.scalar(sum)
 
-    func_name = "y=sum_mean(x...)"
+    func_name = "sum_mean"
     func_doc = {"summary": "Varargs aggregate",
                 "description": "Varargs aggregate"}
 
diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc
index 06c116af82..435c89f596 100644
--- a/python/pyarrow/src/arrow/python/udf.cc
+++ b/python/pyarrow/src/arrow/python/udf.cc
@@ -16,15 +16,25 @@
 // under the License.
 
 #include "arrow/python/udf.h"
+#include "arrow/array/builder_base.h"
+#include "arrow/buffer_builder.h"
 #include "arrow/compute/api_aggregate.h"
+#include "arrow/compute/api_vector.h"
 #include "arrow/compute/function.h"
 #include "arrow/compute/kernel.h"
+#include "arrow/compute/row/grouper.h"
 #include "arrow/python/common.h"
 #include "arrow/table.h"
 #include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
 
 namespace arrow {
+using compute::ExecSpan;
+using compute::Grouper;
+using compute::KernelContext;
+using compute::KernelState;
 using internal::checked_cast;
+
 namespace py {
 namespace {
 
@@ -73,6 +83,13 @@ struct ScalarUdfAggregator : public compute::KernelState {
   virtual Status Finalize(compute::KernelContext* ctx, Datum* out) = 0;
 };
 
+struct HashUdfAggregator : public compute::KernelState {
+  virtual Status Resize(KernelContext* ctx, int64_t size) = 0;
+  virtual Status Consume(KernelContext* ctx, const ExecSpan& batch) = 0;
+  virtual Status Merge(KernelContext* ctx, KernelState&& other, const ArrayData&) = 0;
+  virtual Status Finalize(KernelContext* ctx, Datum* out) = 0;
+};
+
 arrow::Status AggregateUdfConsume(compute::KernelContext* ctx,
                                   const compute::ExecSpan& batch) {
   return checked_cast<ScalarUdfAggregator*>(ctx->state())->Consume(ctx, batch);
@@ -87,6 +104,24 @@ arrow::Status AggregateUdfFinalize(compute::KernelContext* ctx, arrow::Datum* ou
   return checked_cast<ScalarUdfAggregator*>(ctx->state())->Finalize(ctx, out);
 }
 
+arrow::Status HashAggregateUdfResize(KernelContext* ctx, int64_t size) {
+  return checked_cast<HashUdfAggregator*>(ctx->state())->Resize(ctx, size);
+}
+
+arrow::Status HashAggregateUdfConsume(KernelContext* ctx, const ExecSpan& batch) {
+  return checked_cast<HashUdfAggregator*>(ctx->state())->Consume(ctx, batch);
+}
+
+arrow::Status HashAggregateUdfMerge(KernelContext* ctx, KernelState&& src,
+                                    const ArrayData& group_id_mapping) {
+  return checked_cast<HashUdfAggregator*>(ctx->state())
+      ->Merge(ctx, std::move(src), group_id_mapping);
+}
+
+arrow::Status HashAggregateUdfFinalize(KernelContext* ctx, Datum* out) {
+  return checked_cast<HashUdfAggregator*>(ctx->state())->Finalize(ctx, out);
+}
+
 struct PythonTableUdfKernelInit {
   PythonTableUdfKernelInit(std::shared_ptr<OwnedRefNoGIL> function_maker,
                            UdfWrapperCallback cb)
@@ -124,14 +159,12 @@ struct PythonTableUdfKernelInit {
 };
 
 struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator {
-  PythonUdfScalarAggregatorImpl(UdfWrapperCallback agg_cb,
-                                std::shared_ptr<OwnedRefNoGIL> agg_function,
+  PythonUdfScalarAggregatorImpl(std::shared_ptr<OwnedRefNoGIL> function,
+                                UdfWrapperCallback cb,
                                 std::vector<std::shared_ptr<DataType>> input_types,
                                 std::shared_ptr<DataType> output_type)
-      : agg_cb(std::move(agg_cb)),
-        agg_function(agg_function),
-        output_type(std::move(output_type)) {
-    Py_INCREF(agg_function->obj());
+      : function(function), cb(std::move(cb)), output_type(std::move(output_type)) {
+    Py_INCREF(function->obj());
     std::vector<std::shared_ptr<Field>> fields;
     for (size_t i = 0; i < input_types.size(); i++) {
       fields.push_back(field("", input_types[i]));
@@ -141,7 +174,7 @@ struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator {
 
   ~PythonUdfScalarAggregatorImpl() override {
     if (_Py_IsFinalizing()) {
-      agg_function->detach();
+      function->detach();
     }
   }
 
@@ -164,7 +197,6 @@ struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator {
   Status Finalize(compute::KernelContext* ctx, Datum* out) override {
     auto state =
         arrow::internal::checked_cast<PythonUdfScalarAggregatorImpl*>(ctx->state());
-    std::shared_ptr<OwnedRefNoGIL>& function = state->agg_function;
     const int num_args = input_schema->num_fields();
 
     // Note: The way that batches are concatenated together
@@ -195,8 +227,8 @@ struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator {
         PyObject* data = wrap_array(c_data);
         PyTuple_SetItem(arg_tuple.obj(), arg_id, data);
       }
-      result = std::make_unique<OwnedRef>(
-          agg_cb(function->obj(), udf_context, arg_tuple.obj()));
+      result =
+          std::make_unique<OwnedRef>(cb(function->obj(), udf_context, arg_tuple.obj()));
       RETURN_NOT_OK(CheckPyError());
       // unwrapping the output for expected output type
       if (is_scalar(result->obj())) {
@@ -215,9 +247,164 @@ struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator {
     return Status::OK();
   }
 
-  UdfWrapperCallback agg_cb;
+  std::shared_ptr<OwnedRefNoGIL> function;
+  UdfWrapperCallback cb;
   std::vector<std::shared_ptr<RecordBatch>> values;
-  std::shared_ptr<OwnedRefNoGIL> agg_function;
+  std::shared_ptr<Schema> input_schema;
+  std::shared_ptr<DataType> output_type;
+};
+
+struct PythonUdfHashAggregatorImpl : public HashUdfAggregator {
+  PythonUdfHashAggregatorImpl(std::shared_ptr<OwnedRefNoGIL> function,
+                              UdfWrapperCallback cb,
+                              std::vector<std::shared_ptr<DataType>> input_types,
+                              std::shared_ptr<DataType> output_type)
+      : function(function), cb(std::move(cb)), output_type(std::move(output_type)) {
+    Py_INCREF(function->obj());
+    std::vector<std::shared_ptr<Field>> fields;
+    fields.reserve(input_types.size());
+    for (size_t i = 0; i < input_types.size(); i++) {
+      fields.push_back(field("", input_types[i]));
+    }
+    input_schema = schema(std::move(fields));
+  };
+
+  ~PythonUdfHashAggregatorImpl() override {
+    if (_Py_IsFinalizing()) {
+      function->detach();
+    }
+  }
+
+  // Same as ApplyGroupings in partition.cc;
+  // replicated here to avoid complicating the dependencies
+  static Result<RecordBatchVector> ApplyGroupings(
+      const ListArray& groupings, const std::shared_ptr<RecordBatch>& batch) {
+    ARROW_ASSIGN_OR_RAISE(Datum sorted,
+                          compute::Take(batch, groupings.data()->child_data[0]));
+
+    const auto& sorted_batch = *sorted.record_batch();
+
+    RecordBatchVector out(static_cast<size_t>(groupings.length()));
+    for (size_t i = 0; i < out.size(); ++i) {
+      out[i] = sorted_batch.Slice(groupings.value_offset(i), groupings.value_length(i));
+    }
+
+    return out;
+  }
+
+  Status Resize(KernelContext* ctx, int64_t new_num_groups) {
+    // We only need to change num_groups in resize
+    // similar to other hash aggregate kernels
+    num_groups = new_num_groups;
+    return Status::OK();
+  }
+
+  Status Consume(KernelContext* ctx, const ExecSpan& batch) {
+    ARROW_ASSIGN_OR_RAISE(
+        std::shared_ptr<RecordBatch> rb,
+        batch.ToExecBatch().ToRecordBatch(input_schema, ctx->memory_pool()));
+
+    // This is similar to GroupedListImpl
+    // last array is the group id
+    const ArraySpan& groups_array_data = batch[batch.num_values() - 1].array;
+    DCHECK_EQ(groups_array_data.offset, 0);
+    int64_t batch_num_values = groups_array_data.length;
+    const auto* batch_groups = groups_array_data.GetValues<uint32_t>(1);
+    RETURN_NOT_OK(groups.Append(batch_groups, batch_num_values));
+    values.push_back(std::move(rb));
+    num_values += batch_num_values;
+    return Status::OK();
+  }
+  Status Merge(KernelContext* ctx, KernelState&& other_state,
+               const ArrayData& group_id_mapping) {
+    // This is similar to GroupedListImpl
+    auto& other = checked_cast<PythonUdfHashAggregatorImpl&>(other_state);
+    auto& other_values = other.values;
+    const uint32_t* other_raw_groups = other.groups.data();
+    values.insert(values.end(), std::make_move_iterator(other_values.begin()),
+                  std::make_move_iterator(other_values.end()));
+
+    auto g = group_id_mapping.GetValues<uint32_t>(1);
+    for (uint32_t other_g = 0; static_cast<int64_t>(other_g) < other.num_values;
+         ++other_g) {
+      // Different states can have different group_id mappings, so we
+      // need to translate the ids
+      RETURN_NOT_OK(groups.Append(g[other_raw_groups[other_g]]));
+    }
+
+    num_values += other.num_values;
+    return Status::OK();
+  }
+
+  Status Finalize(KernelContext* ctx, Datum* out) {
+    // Exclude the last column which is the group id
+    const int num_args = input_schema->num_fields() - 1;
+
+    ARROW_ASSIGN_OR_RAISE(auto groups_buffer, groups.Finish());
+    ARROW_ASSIGN_OR_RAISE(auto groupings,
+                          Grouper::MakeGroupings(UInt32Array(num_values, groups_buffer),
+                                                 static_cast<uint32_t>(num_groups)));
+
+    ARROW_ASSIGN_OR_RAISE(auto table,
+                          arrow::Table::FromRecordBatches(input_schema, values));
+    ARROW_ASSIGN_OR_RAISE(auto rb, table->CombineChunksToBatch(ctx->memory_pool()));
+    UdfContext udf_context{ctx->memory_pool(), table->num_rows()};
+
+    if (rb->num_rows() == 0) {
+      *out = Datum();
+      return Status::OK();
+    }
+
+    ARROW_ASSIGN_OR_RAISE(RecordBatchVector rbs, ApplyGroupings(*groupings, rb));
+
+    return SafeCallIntoPython([&] {
+      ARROW_ASSIGN_OR_RAISE(std::unique_ptr<ArrayBuilder> builder,
+                            MakeBuilder(output_type, ctx->memory_pool()));
+      for (auto& group_rb : rbs) {
+        std::unique_ptr<OwnedRef> result;
+        OwnedRef arg_tuple(PyTuple_New(num_args));
+        RETURN_NOT_OK(CheckPyError());
+
+        for (int arg_id = 0; arg_id < num_args; arg_id++) {
+          // Since we combined chunks there is only one chunk
+          std::shared_ptr<Array> c_data = group_rb->column(arg_id);
+          PyObject* data = wrap_array(c_data);
+          PyTuple_SetItem(arg_tuple.obj(), arg_id, data);
+        }
+
+        result =
+            std::make_unique<OwnedRef>(cb(function->obj(), udf_context, arg_tuple.obj()));
+        RETURN_NOT_OK(CheckPyError());
+
+        // unwrapping the output for expected output type
+        if (is_scalar(result->obj())) {
+          ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> val,
+                                unwrap_scalar(result->obj()));
+          if (*output_type != *val->type) {
+            return Status::TypeError("Expected output datatype ", output_type->ToString(),
+                                     ", but function returned datatype ",
+                                     val->type->ToString());
+          }
+          ARROW_RETURN_NOT_OK(builder->AppendScalar(std::move(*val)));
+        } else {
+          return Status::TypeError("Unexpected output type: ",
+                                   Py_TYPE(result->obj())->tp_name, " (expected Scalar)");
+        }
+      }
+      ARROW_ASSIGN_OR_RAISE(auto result, builder->Finish());
+      out->value = std::move(result->data());
+      return Status::OK();
+    });
+  }
+
+  std::shared_ptr<OwnedRefNoGIL> function;
+  UdfWrapperCallback cb;
+  // Accumulated input batches
+  std::vector<std::shared_ptr<RecordBatch>> values;
+  // Group ids - extracted from the last column from the batch
+  TypedBufferBuilder<uint32_t> groups;
+  int64_t num_groups = 0;
+  int64_t num_values = 0;
   std::shared_ptr<Schema> input_schema;
   std::shared_ptr<DataType> output_type;
 };
@@ -332,15 +519,15 @@ Status RegisterUdf(PyObject* user_function, compute::KernelInit kernel_init,
 
 }  // namespace
 
-Status RegisterScalarFunction(PyObject* user_function, UdfWrapperCallback wrapper,
+Status RegisterScalarFunction(PyObject* function, UdfWrapperCallback cb,
                               const UdfOptions& options,
                               compute::FunctionRegistry* registry) {
-  return RegisterUdf(user_function,
-                     PythonUdfKernelInit{std::make_shared<OwnedRefNoGIL>(user_function)},
-                     wrapper, options, registry);
+  return RegisterUdf(function,
+                     PythonUdfKernelInit{std::make_shared<OwnedRefNoGIL>(function)}, cb,
+                     options, registry);
 }
 
-Status RegisterTabularFunction(PyObject* user_function, UdfWrapperCallback wrapper,
+Status RegisterTabularFunction(PyObject* function, UdfWrapperCallback cb,
                                const UdfOptions& options,
                                compute::FunctionRegistry* registry) {
   if (options.arity.num_args != 0 || options.arity.is_varargs) {
@@ -350,24 +537,14 @@ Status RegisterTabularFunction(PyObject* user_function, UdfWrapperCallback wrapp
     return Status::Invalid("tabular function with non-struct output");
   }
   return RegisterUdf(
-      user_function,
-      PythonTableUdfKernelInit{std::make_shared<OwnedRefNoGIL>(user_function), wrapper},
-      wrapper, options, registry);
-}
-
-Status AddAggKernel(std::shared_ptr<compute::KernelSignature> sig,
-                    compute::KernelInit init, compute::ScalarAggregateFunction* func) {
-  compute::ScalarAggregateKernel kernel(std::move(sig), std::move(init),
-                                        AggregateUdfConsume, AggregateUdfMerge,
-                                        AggregateUdfFinalize, /*ordered=*/false);
-  RETURN_NOT_OK(func->AddKernel(std::move(kernel)));
-  return Status::OK();
+      function, PythonTableUdfKernelInit{std::make_shared<OwnedRefNoGIL>(function), cb},
+      cb, options, registry);
 }
 
-Status RegisterAggregateFunction(PyObject* agg_function, UdfWrapperCallback agg_wrapper,
-                                 const UdfOptions& options,
-                                 compute::FunctionRegistry* registry) {
-  if (!PyCallable_Check(agg_function)) {
+Status RegisterScalarAggregateFunction(PyObject* function, UdfWrapperCallback cb,
+                                       const UdfOptions& options,
+                                       compute::FunctionRegistry* registry) {
+  if (!PyCallable_Check(function)) {
     return Status::TypeError("Expected a callable Python object.");
   }
 
@@ -378,7 +555,7 @@ Status RegisterAggregateFunction(PyObject* agg_function, UdfWrapperCallback agg_
   // Py_INCREF here so that once a function is registered
   // its refcount gets increased by 1 and doesn't get gced
   // if all existing refs are gone
-  Py_INCREF(agg_function);
+  Py_INCREF(function);
 
   static auto default_scalar_aggregate_options =
       compute::ScalarAggregateOptions::Defaults();
@@ -392,24 +569,109 @@ Status RegisterAggregateFunction(PyObject* agg_function, UdfWrapperCallback agg_
   }
   compute::OutputType output_type(options.output_type);
 
-  compute::KernelInit init = [agg_wrapper, agg_function, options](
-                                 compute::KernelContext* ctx,
-                                 const compute::KernelInitArgs& args)
+  compute::KernelInit init = [cb, function, options](compute::KernelContext* ctx,
+                                                     const compute::KernelInitArgs& args)
       -> Result<std::unique_ptr<compute::KernelState>> {
     return std::make_unique<PythonUdfScalarAggregatorImpl>(
-        agg_wrapper, std::make_shared<OwnedRefNoGIL>(agg_function), options.input_types,
+        std::make_shared<OwnedRefNoGIL>(function), cb, options.input_types,
         options.output_type);
   };
 
-  RETURN_NOT_OK(AddAggKernel(
-      compute::KernelSignature::Make(std::move(input_types), std::move(output_type),
-                                     options.arity.is_varargs),
-      init, aggregate_func.get()));
-
+  auto sig = compute::KernelSignature::Make(
+      std::move(input_types), std::move(output_type), options.arity.is_varargs);
+  compute::ScalarAggregateKernel kernel(std::move(sig), std::move(init),
+                                        AggregateUdfConsume, AggregateUdfMerge,
+                                        AggregateUdfFinalize, /*ordered=*/false);
+  RETURN_NOT_OK(aggregate_func->AddKernel(std::move(kernel)));
   RETURN_NOT_OK(registry->AddFunction(std::move(aggregate_func)));
   return Status::OK();
 }
 
+/// \brief Create a new UdfOptions with adjustment for hash kernel
+/// \param options User provided udf options
+UdfOptions AdjustForHashAggregate(const UdfOptions& options) {
+  UdfOptions hash_options;
+  // Append hash_ before the function name to separate it from the scalar
+  // version
+  hash_options.func_name = "hash_" + options.func_name;
+  // Extend input types with group id. Group id is appended by the group
+  // aggregation node. Here we change both arity and input types
+  if (options.arity.is_varargs) {
+    hash_options.arity = options.arity;
+  } else {
+    hash_options.arity = compute::Arity(options.arity.num_args + 1, false);
+  }
+  // Changing the function doc shouldn't be necessary because the group id
+  // is not user visible; however, this is currently needed to pass
+  // function validation. The name group_id_array is consistent with
+  // hash kernels in hash_aggregate.cc
+  hash_options.func_doc = options.func_doc;
+  hash_options.func_doc.arg_names.emplace_back("group_id_array");
+  std::vector<std::shared_ptr<DataType>> input_dtypes = options.input_types;
+  input_dtypes.emplace_back(uint32());
+  hash_options.input_types = std::move(input_dtypes);
+  hash_options.output_type = options.output_type;
+  return hash_options;
+}
+
+Status RegisterHashAggregateFunction(PyObject* function, UdfWrapperCallback cb,
+                                     const UdfOptions& options,
+                                     compute::FunctionRegistry* registry) {
+  if (!PyCallable_Check(function)) {
+    return Status::TypeError("Expected a callable Python object.");
+  }
+
+  if (registry == NULLPTR) {
+    registry = compute::GetFunctionRegistry();
+  }
+
+  // Py_INCREF here so that once a function is registered
+  // its refcount gets increased by 1 and doesn't get gced
+  // if all existing refs are gone
+  Py_INCREF(function);
+  UdfOptions hash_options = AdjustForHashAggregate(options);
+
+  std::vector<compute::InputType> input_types;
+  for (const auto& in_dtype : hash_options.input_types) {
+    input_types.emplace_back(in_dtype);
+  }
+  compute::OutputType output_type(hash_options.output_type);
+
+  static auto default_hash_aggregate_options =
+      compute::ScalarAggregateOptions::Defaults();
+  auto hash_aggregate_func = std::make_shared<compute::HashAggregateFunction>(
+      hash_options.func_name, hash_options.arity, hash_options.func_doc,
+      &default_hash_aggregate_options);
+
+  compute::KernelInit init = [function, cb, hash_options](
+                                 compute::KernelContext* ctx,
+                                 const compute::KernelInitArgs& args)
+      -> Result<std::unique_ptr<compute::KernelState>> {
+    return std::make_unique<PythonUdfHashAggregatorImpl>(
+        std::make_shared<OwnedRefNoGIL>(function), cb, hash_options.input_types,
+        hash_options.output_type);
+  };
+
+  auto sig = compute::KernelSignature::Make(
+      std::move(input_types), std::move(output_type), hash_options.arity.is_varargs);
+
+  compute::HashAggregateKernel kernel(
+      std::move(sig), std::move(init), HashAggregateUdfResize, HashAggregateUdfConsume,
+      HashAggregateUdfMerge, HashAggregateUdfFinalize, /*ordered=*/false);
+  RETURN_NOT_OK(hash_aggregate_func->AddKernel(std::move(kernel)));
+  RETURN_NOT_OK(registry->AddFunction(std::move(hash_aggregate_func)));
+  return Status::OK();
+}
+
+Status RegisterAggregateFunction(PyObject* function, UdfWrapperCallback cb,
+                                 const UdfOptions& options,
+                                 compute::FunctionRegistry* registry) {
+  RETURN_NOT_OK(RegisterScalarAggregateFunction(function, cb, options, registry));
+  RETURN_NOT_OK(RegisterHashAggregateFunction(function, cb, options, registry));
+
+  return Status::OK();
+}
+
 Result<std::shared_ptr<RecordBatchReader>> CallTabularFunction(
     const std::string& func_name, const std::vector<Datum>& args,
     compute::FunctionRegistry* registry) {
diff --git a/python/pyarrow/tests/test_substrait.py b/python/pyarrow/tests/test_substrait.py
index 34faaa157a..93ecae7bfa 100644
--- a/python/pyarrow/tests/test_substrait.py
+++ b/python/pyarrow/tests/test_substrait.py
@@ -607,7 +607,7 @@ def test_output_field_names(use_threads):
     assert res_tb == expected
 
 
-def test_aggregate_udf_basic(varargs_agg_func_fixture):
+def test_scalar_aggregate_udf_basic(varargs_agg_func_fixture):
 
     test_table = pa.Table.from_pydict(
         {"k": [1, 1, 2, 2], "v1": [1, 2, 3, 4],
@@ -630,7 +630,7 @@ def test_aggregate_udf_basic(varargs_agg_func_fixture):
       "extensionFunction": {
         "extensionUriReference": 1,
         "functionAnchor": 1,
-        "name": "y=sum_mean(x...)"
+        "name": "sum_mean"
       }
     }
   ],
@@ -753,3 +753,173 @@ def test_aggregate_udf_basic(varargs_agg_func_fixture):
     })
 
     assert res_tb == expected_tb
+
+
+def test_hash_aggregate_udf_basic(varargs_agg_func_fixture):
+
+    test_table = pa.Table.from_pydict(
+        {"t": [1, 1, 1, 1, 2, 2, 2, 2],
+         "k": [1, 0, 0, 1, 0, 1, 0, 1],
+         "v1": [1, 2, 3, 4, 5, 6, 7, 8],
+         "v2": [1.0, 1.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0]}
+    )
+
+    def table_provider(names, _):
+        return test_table
+
+    substrait_query = b"""
+{
+  "extensionUris": [
+    {
+      "extensionUriAnchor": 1,
+      "uri": "urn:arrow:substrait_simple_extension_function"
+    }
+  ],
+  "extensions": [
+    {
+      "extensionFunction": {
+        "extensionUriReference": 1,
+        "functionAnchor": 1,
+        "name": "sum_mean"
+      }
+    }
+  ],
+  "relations": [
+    {
+      "root": {
+        "input": {
+          "extensionSingle": {
+            "common": {
+              "emit": {
+                "outputMapping": [
+                  0,
+                  1,
+                  2
+                ]
+              }
+            },
+            "input": {
+              "read": {
+                "baseSchema": {
+                  "names": [
+                    "t",
+                    "k",
+                    "v1",
+                    "v2",
+                  ],
+                  "struct": {
+                    "types": [
+                      {
+                        "i64": {
+                          "nullability": "NULLABILITY_REQUIRED"
+                        }
+                      },
+                      {
+                        "i64": {
+                          "nullability": "NULLABILITY_REQUIRED"
+                        }
+                      },
+                      {
+                        "i64": {
+                          "nullability": "NULLABILITY_NULLABLE"
+                        }
+                      },
+                      {
+                        "fp64": {
+                          "nullability": "NULLABILITY_NULLABLE"
+                        }
+                      }
+                    ],
+                    "nullability": "NULLABILITY_REQUIRED"
+                  }
+                },
+                "namedTable": {
+                  "names": ["t1"]
+                }
+              }
+            },
+            "detail": {
+              "@type": "/arrow.substrait_ext.SegmentedAggregateRel",
+              "groupingKeys": [
+                {
+                  "directReference": {
+                    "structField": {
+                      "field": 1
+                    }
+                  },
+                  "rootReference": {}
+                }
+              ],
+              "segmentKeys": [
+                {
+                  "directReference": {
+                    "structField": {}
+                  },
+                  "rootReference": {}
+                }
+              ],
+              "measures": [
+                {
+                  "measure": {
+                    "functionReference": 1,
+                    "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
+                    "outputType": {
+                      "fp64": {
+                        "nullability": "NULLABILITY_NULLABLE"
+                      }
+                    },
+                    "arguments": [
+                      {
+                        "value": {
+                          "selection": {
+                            "directReference": {
+                              "structField": {
+                                "field": 2
+                              }
+                            },
+                            "rootReference": {}
+                          }
+                        }
+                      },
+                      {
+                        "value": {
+                          "selection": {
+                            "directReference": {
+                              "structField": {
+                                "field": 3
+                              }
+                            },
+                            "rootReference": {}
+                          }
+                        }
+                      }
+                    ]
+                  }
+                }
+              ]
+            }
+          }
+        },
+        "names": [
+          "t",
+          "k",
+          "v_avg"
+        ]
+      }
+    }
+  ]
+}
+"""
+    buf = pa._substrait._parse_json_plan(substrait_query)
+    reader = pa.substrait.run_query(
+        buf, table_provider=table_provider, use_threads=False)
+    res_tb = reader.read_all()
+
+    expected_tb = pa.Table.from_pydict({
+        't': [1, 1, 2, 2],
+        'k': [1, 0, 0, 1],
+        'v_avg': [3.5, 3.5, 9.0, 11.0]
+    })
+
+    # Ordering of k is deterministic because this is running with serial execution
+    assert res_tb == expected_tb
diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py
index c0cfd3d26e..5631e19455 100644
--- a/python/pyarrow/tests/test_udf.py
+++ b/python/pyarrow/tests/test_udf.py
@@ -18,6 +18,8 @@
 
 import pytest
 
+import numpy as np
+
 import pyarrow as pa
 from pyarrow import compute as pc
 
@@ -42,6 +44,28 @@ class MyError(RuntimeError):
     pass
 
 
+@pytest.fixture(scope="session")
+def sum_agg_func_fixture():
+    """
+    Register a unary aggregate function (sum)
+    """
+    def func(ctx, x, *args):
+        return pa.scalar(np.nansum(x))
+
+    func_name = "sum_udf"
+    func_doc = empty_udf_doc
+
+    pc.register_aggregate_function(func,
+                                   func_name,
+                                   func_doc,
+                                   {
+                                       "x": pa.float64(),
+                                   },
+                                   pa.float64()
+                                   )
+    return func, func_name
+
+
 @pytest.fixture(scope="session")
 def exception_agg_func_fixture():
     def func(ctx, x):
@@ -656,45 +680,120 @@ def test_udt_datasource1_exception():
         _test_datasource1_udt(datasource1_exception)
 
 
-def test_agg_basic(unary_agg_func_fixture):
+def test_scalar_agg_basic(unary_agg_func_fixture):
     arr = pa.array([10.0, 20.0, 30.0, 40.0, 50.0], pa.float64())
-    result = pc.call_function("y=avg(x)", [arr])
+    result = pc.call_function("mean_udf", [arr])
     expected = pa.scalar(30.0)
     assert result == expected
 
 
-def test_agg_empty(unary_agg_func_fixture):
+def test_scalar_agg_empty(unary_agg_func_fixture):
     empty = pa.array([], pa.float64())
 
     with pytest.raises(pa.ArrowInvalid, match='empty inputs'):
-        pc.call_function("y=avg(x)", [empty])
+        pc.call_function("mean_udf", [empty])
 
 
-def test_agg_wrong_output_dtype(wrong_output_dtype_agg_func_fixture):
+def test_scalar_agg_wrong_output_dtype(wrong_output_dtype_agg_func_fixture):
     arr = pa.array([10, 20, 30, 40, 50], pa.int64())
     with pytest.raises(pa.ArrowTypeError, match="output datatype"):
         pc.call_function("y=wrong_output_dtype(x)", [arr])
 
 
-def test_agg_wrong_output_type(wrong_output_type_agg_func_fixture):
+def test_scalar_agg_wrong_output_type(wrong_output_type_agg_func_fixture):
     arr = pa.array([10, 20, 30, 40, 50], pa.int64())
     with pytest.raises(pa.ArrowTypeError, match="output type"):
         pc.call_function("y=wrong_output_type(x)", [arr])
 
 
-def test_agg_varargs(varargs_agg_func_fixture):
+def test_scalar_agg_varargs(varargs_agg_func_fixture):
     arr1 = pa.array([10, 20, 30, 40, 50], pa.int64())
     arr2 = pa.array([1.0, 2.0, 3.0, 4.0, 5.0], pa.float64())
 
     result = pc.call_function(
-        "y=sum_mean(x...)", [arr1, arr2]
+        "sum_mean", [arr1, arr2]
     )
     expected = pa.scalar(33.0)
     assert result == expected
 
 
-def test_agg_exception(exception_agg_func_fixture):
+def test_scalar_agg_exception(exception_agg_func_fixture):
     arr = pa.array([10, 20, 30, 40, 50, 60], pa.int64())
 
     with pytest.raises(RuntimeError, match='Oops'):
         pc.call_function("y=exception_len(x)", [arr])
+
+
+def test_hash_agg_basic(unary_agg_func_fixture):
+    arr1 = pa.array([10.0, 20.0, 30.0, 40.0, 50.0], pa.float64())
+    arr2 = pa.array([4, 2, 1, 2, 1], pa.int32())
+
+    arr3 = pa.array([60.0, 70.0, 80.0, 90.0, 100.0], pa.float64())
+    arr4 = pa.array([5, 1, 1, 4, 1], pa.int32())
+
+    table1 = pa.table([arr2, arr1], names=["id", "value"])
+    table2 = pa.table([arr4, arr3], names=["id", "value"])
+    table = pa.concat_tables([table1, table2])
+
+    result = table.group_by("id").aggregate([("value", "mean_udf")])
+    expected = table.group_by("id").aggregate(
+        [("value", "mean")]).rename_columns(['id', 'value_mean_udf'])
+
+    assert result.sort_by('id') == expected.sort_by('id')
+
+
+def test_hash_agg_empty(unary_agg_func_fixture):
+    arr1 = pa.array([], pa.float64())
+    arr2 = pa.array([], pa.int32())
+    table = pa.table([arr2, arr1], names=["id", "value"])
+
+    result = table.group_by("id").aggregate([("value", "mean_udf")])
+    expected = pa.table([pa.array([], pa.int32()), pa.array(
+        [], pa.float64())], names=['id', 'value_mean_udf'])
+
+    assert result == expected
+
+
+def test_hash_agg_wrong_output_dtype(wrong_output_dtype_agg_func_fixture):
+    arr1 = pa.array([10, 20, 30, 40, 50], pa.int64())
+    arr2 = pa.array([4, 2, 1, 2, 1], pa.int32())
+
+    table = pa.table([arr2, arr1], names=["id", "value"])
+    with pytest.raises(pa.ArrowTypeError, match="output datatype"):
+        table.group_by("id").aggregate([("value", "y=wrong_output_dtype(x)")])
+
+
+def test_hash_agg_wrong_output_type(wrong_output_type_agg_func_fixture):
+    arr1 = pa.array([10, 20, 30, 40, 50], pa.int64())
+    arr2 = pa.array([4, 2, 1, 2, 1], pa.int32())
+    table = pa.table([arr2, arr1], names=["id", "value"])
+
+    with pytest.raises(pa.ArrowTypeError, match="output type"):
+        table.group_by("id").aggregate([("value", "y=wrong_output_type(x)")])
+
+
+def test_hash_agg_exception(exception_agg_func_fixture):
+    arr1 = pa.array([10, 20, 30, 40, 50], pa.int64())
+    arr2 = pa.array([4, 2, 1, 2, 1], pa.int32())
+    table = pa.table([arr2, arr1], names=["id", "value"])
+
+    with pytest.raises(RuntimeError, match='Oops'):
+        table.group_by("id").aggregate([("value", "y=exception_len(x)")])
+
+
+def test_hash_agg_random(sum_agg_func_fixture):
+    """Test hash aggregate udf with randomly sampled data"""
+
+    value_num = 1000000
+    group_num = 1000
+
+    arr1 = pa.array(np.repeat(1, value_num), pa.float64())
+    arr2 = pa.array(np.random.choice(group_num, value_num), pa.int32())
+
+    table = pa.table([arr2, arr1], names=['id', 'value'])
+
+    result = table.group_by("id").aggregate([("value", "sum_udf")])
+    expected = table.group_by("id").aggregate(
+        [("value", "sum")]).rename_columns(['id', 'value_sum_udf'])
+
+    assert result.sort_by('id') == expected.sort_by('id')