You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ag...@apache.org on 2023/12/28 20:10:52 UTC
(arrow-datafusion-python) branch main updated: feat: udaf: enable multiple column input (#546)
This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new 76d7fcf feat: udaf: enable multiple column input (#546)
76d7fcf is described below
commit 76d7fcffdd9d8664a003b76754033ebad4a15847
Author: Dan Lovell <dl...@gmail.com>
AuthorDate: Thu Dec 28 15:10:46 2023 -0500
feat: udaf: enable multiple column input (#546)
---
datafusion/__init__.py | 2 ++
src/udaf.rs | 4 ++--
2 files changed, 4 insertions(+), 2 deletions(-)
diff --git a/datafusion/__init__.py b/datafusion/__init__.py
index c854f3f..df53b39 100644
--- a/datafusion/__init__.py
+++ b/datafusion/__init__.py
@@ -213,6 +213,8 @@ def udaf(accum, input_type, return_type, state_type, volatility, name=None):
)
if name is None:
name = accum.__qualname__.lower()
+ if isinstance(input_type, pa.lib.DataType):
+ input_type = [input_type]
return AggregateUDF(
name=name,
accumulator=accum,
diff --git a/src/udaf.rs b/src/udaf.rs
index 5c43b67..0e7a8de 100644
--- a/src/udaf.rs
+++ b/src/udaf.rs
@@ -148,14 +148,14 @@ impl PyAggregateUDF {
fn new(
name: &str,
accumulator: PyObject,
- input_type: PyArrowType<DataType>,
+ input_type: PyArrowType<Vec<DataType>>,
return_type: PyArrowType<DataType>,
state_type: PyArrowType<Vec<DataType>>,
volatility: &str,
) -> PyResult<Self> {
let function = create_udaf(
name,
- vec![input_type.0],
+ input_type.0,
Arc::new(return_type.0),
parse_volatility(volatility)?,
to_rust_accumulator(accumulator),