You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ag...@apache.org on 2023/12/28 20:10:52 UTC

(arrow-datafusion-python) branch main updated: feat: udaf: enable multiple column input (#546)

This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 76d7fcf  feat: udaf: enable multiple column input (#546)
76d7fcf is described below

commit 76d7fcffdd9d8664a003b76754033ebad4a15847
Author: Dan Lovell <dl...@gmail.com>
AuthorDate: Thu Dec 28 15:10:46 2023 -0500

    feat: udaf: enable multiple column input (#546)
---
 datafusion/__init__.py | 2 ++
 src/udaf.rs            | 4 ++--
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/datafusion/__init__.py b/datafusion/__init__.py
index c854f3f..df53b39 100644
--- a/datafusion/__init__.py
+++ b/datafusion/__init__.py
@@ -213,6 +213,8 @@ def udaf(accum, input_type, return_type, state_type, volatility, name=None):
         )
     if name is None:
         name = accum.__qualname__.lower()
+    if isinstance(input_type, pa.lib.DataType):
+        input_type = [input_type]
     return AggregateUDF(
         name=name,
         accumulator=accum,
diff --git a/src/udaf.rs b/src/udaf.rs
index 5c43b67..0e7a8de 100644
--- a/src/udaf.rs
+++ b/src/udaf.rs
@@ -148,14 +148,14 @@ impl PyAggregateUDF {
     fn new(
         name: &str,
         accumulator: PyObject,
-        input_type: PyArrowType<DataType>,
+        input_type: PyArrowType<Vec<DataType>>,
         return_type: PyArrowType<DataType>,
         state_type: PyArrowType<Vec<DataType>>,
         volatility: &str,
     ) -> PyResult<Self> {
         let function = create_udaf(
             name,
-            vec![input_type.0],
+            input_type.0,
             Arc::new(return_type.0),
             parse_volatility(volatility)?,
             to_rust_accumulator(accumulator),