You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/12/12 22:36:46 UTC
[arrow-datafusion] branch master updated: feat: user-defined aggregate function(UDAF) as window function (#4553)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 0f0d33526 feat: user-defined aggregate function(UDAF) as window function (#4553)
0f0d33526 is described below
commit 0f0d335266200a281ea730843d782e59285ce60f
Author: LFC <ba...@gmail.com>
AuthorDate: Tue Dec 13 06:36:41 2022 +0800
feat: user-defined aggregate function(UDAF) as window function (#4553)
* feat: user-defined aggregate function(UDAF) as window function
* fix: resolve PR comments
Co-authored-by: luofucong <lu...@greptime.com>
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
---
datafusion/core/src/physical_plan/windows/mod.rs | 84 +++++++++++++++-
datafusion/expr/src/window_function.rs | 101 ++++++++++---------
datafusion/proto/src/to_proto.rs | 7 ++
datafusion/sql/Cargo.toml | 3 +
datafusion/sql/src/planner.rs | 117 +++++++++++++++++++----
5 files changed, 246 insertions(+), 66 deletions(-)
diff --git a/datafusion/core/src/physical_plan/windows/mod.rs b/datafusion/core/src/physical_plan/windows/mod.rs
index 0f837e581..5cd0a1a9c 100644
--- a/datafusion/core/src/physical_plan/windows/mod.rs
+++ b/datafusion/core/src/physical_plan/windows/mod.rs
@@ -25,7 +25,7 @@ use crate::physical_plan::{
PhysicalSortExpr, RowNumber,
},
type_coercion::coerce,
- PhysicalExpr,
+ udaf, PhysicalExpr,
};
use crate::scalar::ScalarValue;
use arrow::datatypes::Schema;
@@ -67,6 +67,12 @@ pub fn create_window_expr(
order_by,
window_frame,
)),
+ WindowFunction::AggregateUDF(fun) => Arc::new(AggregateWindowExpr::new(
+ udaf::create_aggregate_expr(fun.as_ref(), args, input_schema, name)?,
+ partition_by,
+ order_by,
+ window_frame,
+ )),
})
}
@@ -172,6 +178,7 @@ mod tests {
use arrow::datatypes::{DataType, Field, SchemaRef};
use arrow::record_batch::RecordBatch;
use datafusion_common::cast::as_primitive_array;
+ use datafusion_expr::{create_udaf, Accumulator, AggregateState, Volatility};
use futures::FutureExt;
fn create_test_schema(partitions: usize) -> Result<(Arc<CsvExec>, SchemaRef)> {
@@ -180,6 +187,81 @@ mod tests {
Ok((csv, schema))
}
+ #[tokio::test]
+ async fn window_function_with_udaf() -> Result<()> {
+ #[derive(Debug)]
+ struct MyCount(i64);
+
+ impl Accumulator for MyCount {
+ fn state(&self) -> Result<Vec<AggregateState>> {
+ Ok(vec![AggregateState::Scalar(ScalarValue::Int64(Some(
+ self.0,
+ )))])
+ }
+
+ fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ let array = &values[0];
+ self.0 += (array.len() - array.data().null_count()) as i64;
+ Ok(())
+ }
+
+ fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+ let counts: &Int64Array = arrow::array::as_primitive_array(&states[0]);
+ if let Some(c) = &arrow::compute::sum(counts) {
+ self.0 += *c;
+ }
+ Ok(())
+ }
+
+ fn evaluate(&self) -> Result<ScalarValue> {
+ Ok(ScalarValue::Int64(Some(self.0)))
+ }
+
+ fn size(&self) -> usize {
+ std::mem::size_of_val(self)
+ }
+ }
+
+ let my_count = create_udaf(
+ "my_count",
+ DataType::Int64,
+ Arc::new(DataType::Int64),
+ Volatility::Immutable,
+ Arc::new(|_| Ok(Box::new(MyCount(0)))),
+ Arc::new(vec![DataType::Int64]),
+ );
+
+ let session_ctx = SessionContext::new();
+ let task_ctx = session_ctx.task_ctx();
+ let (input, schema) = create_test_schema(1)?;
+
+ let window_exec = Arc::new(WindowAggExec::try_new(
+ vec![create_window_expr(
+ &WindowFunction::AggregateUDF(Arc::new(my_count)),
+ "my_count".to_owned(),
+ &[col("c3", &schema)?],
+ &[],
+ &[],
+ Arc::new(WindowFrame::new(false)),
+ schema.as_ref(),
+ )?],
+ input,
+ schema,
+ vec![],
+ None,
+ )?);
+
+ let result: Vec<RecordBatch> = collect(window_exec, task_ctx).await?;
+ assert_eq!(result.len(), 1);
+
+ let columns = result[0].columns();
+
+ let count: &Int64Array = as_primitive_array(&columns[0])?;
+ assert_eq!(count.value(0), 100);
+ assert_eq!(count.value(99), 100);
+ Ok(())
+ }
+
#[tokio::test]
async fn window_function() -> Result<()> {
let session_ctx = SessionContext::new();
diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs
index c37653ab0..038091ac7 100644
--- a/datafusion/expr/src/window_function.rs
+++ b/datafusion/expr/src/window_function.rs
@@ -23,9 +23,10 @@
use crate::aggregate_function::AggregateFunction;
use crate::type_coercion::functions::data_types;
-use crate::{aggregate_function, Signature, TypeSignature, Volatility};
+use crate::{aggregate_function, AggregateUDF, Signature, TypeSignature, Volatility};
use arrow::datatypes::DataType;
use datafusion_common::{DataFusionError, Result};
+use std::sync::Arc;
use std::{fmt, str::FromStr};
/// WindowFunction
@@ -35,24 +36,18 @@ pub enum WindowFunction {
AggregateFunction(AggregateFunction),
/// window function that leverages a built-in window function
BuiltInWindowFunction(BuiltInWindowFunction),
+ AggregateUDF(Arc<AggregateUDF>),
}
-impl FromStr for WindowFunction {
- type Err = DataFusionError;
- fn from_str(name: &str) -> Result<WindowFunction> {
- let name = name.to_lowercase();
- if let Ok(aggregate) = AggregateFunction::from_str(name.as_str()) {
- Ok(WindowFunction::AggregateFunction(aggregate))
- } else if let Ok(built_in_function) =
- BuiltInWindowFunction::from_str(name.as_str())
- {
- Ok(WindowFunction::BuiltInWindowFunction(built_in_function))
- } else {
- Err(DataFusionError::Plan(format!(
- "There is no window function named {}",
- name
- )))
- }
+/// Find DataFusion's built-in window function by name.
+pub fn find_df_window_func(name: &str) -> Option<WindowFunction> {
+ let name = name.to_lowercase();
+ if let Ok(aggregate) = AggregateFunction::from_str(name.as_str()) {
+ Some(WindowFunction::AggregateFunction(aggregate))
+ } else if let Ok(built_in_function) = BuiltInWindowFunction::from_str(name.as_str()) {
+ Some(WindowFunction::BuiltInWindowFunction(built_in_function))
+ } else {
+ None
}
}
@@ -79,6 +74,7 @@ impl fmt::Display for WindowFunction {
match self {
WindowFunction::AggregateFunction(fun) => fun.fmt(f),
WindowFunction::BuiltInWindowFunction(fun) => fun.fmt(f),
+ WindowFunction::AggregateUDF(fun) => std::fmt::Debug::fmt(fun, f),
}
}
}
@@ -153,6 +149,9 @@ pub fn return_type(
WindowFunction::BuiltInWindowFunction(fun) => {
return_type_for_built_in(fun, input_expr_types)
}
+ WindowFunction::AggregateUDF(fun) => {
+ Ok((*(fun.return_type)(input_expr_types)?).clone())
+ }
}
}
@@ -188,6 +187,7 @@ pub fn signature(fun: &WindowFunction) -> Signature {
match fun {
WindowFunction::AggregateFunction(fun) => aggregate_function::signature(fun),
WindowFunction::BuiltInWindowFunction(fun) => signature_for_built_in(fun),
+ WindowFunction::AggregateUDF(fun) => fun.signature.clone(),
}
}
@@ -221,11 +221,10 @@ pub fn signature_for_built_in(fun: &BuiltInWindowFunction) -> Signature {
#[cfg(test)]
mod tests {
use super::*;
- use std::str::FromStr;
#[test]
fn test_count_return_type() -> Result<()> {
- let fun = WindowFunction::from_str("count")?;
+ let fun = find_df_window_func("count").unwrap();
let observed = return_type(&fun, &[DataType::Utf8])?;
assert_eq!(DataType::Int64, observed);
@@ -237,7 +236,7 @@ mod tests {
#[test]
fn test_first_value_return_type() -> Result<()> {
- let fun = WindowFunction::from_str("first_value")?;
+ let fun = find_df_window_func("first_value").unwrap();
let observed = return_type(&fun, &[DataType::Utf8])?;
assert_eq!(DataType::Utf8, observed);
@@ -249,7 +248,7 @@ mod tests {
#[test]
fn test_last_value_return_type() -> Result<()> {
- let fun = WindowFunction::from_str("last_value")?;
+ let fun = find_df_window_func("last_value").unwrap();
let observed = return_type(&fun, &[DataType::Utf8])?;
assert_eq!(DataType::Utf8, observed);
@@ -261,7 +260,7 @@ mod tests {
#[test]
fn test_lead_return_type() -> Result<()> {
- let fun = WindowFunction::from_str("lead")?;
+ let fun = find_df_window_func("lead").unwrap();
let observed = return_type(&fun, &[DataType::Utf8])?;
assert_eq!(DataType::Utf8, observed);
@@ -273,7 +272,7 @@ mod tests {
#[test]
fn test_lag_return_type() -> Result<()> {
- let fun = WindowFunction::from_str("lag")?;
+ let fun = find_df_window_func("lag").unwrap();
let observed = return_type(&fun, &[DataType::Utf8])?;
assert_eq!(DataType::Utf8, observed);
@@ -285,7 +284,7 @@ mod tests {
#[test]
fn test_nth_value_return_type() -> Result<()> {
- let fun = WindowFunction::from_str("nth_value")?;
+ let fun = find_df_window_func("nth_value").unwrap();
let observed = return_type(&fun, &[DataType::Utf8, DataType::UInt64])?;
assert_eq!(DataType::Utf8, observed);
@@ -297,7 +296,7 @@ mod tests {
#[test]
fn test_percent_rank_return_type() -> Result<()> {
- let fun = WindowFunction::from_str("percent_rank")?;
+ let fun = find_df_window_func("percent_rank").unwrap();
let observed = return_type(&fun, &[])?;
assert_eq!(DataType::Float64, observed);
@@ -306,7 +305,7 @@ mod tests {
#[test]
fn test_cume_dist_return_type() -> Result<()> {
- let fun = WindowFunction::from_str("cume_dist")?;
+ let fun = find_df_window_func("cume_dist").unwrap();
let observed = return_type(&fun, &[])?;
assert_eq!(DataType::Float64, observed);
@@ -334,8 +333,8 @@ mod tests {
"sum",
];
for name in names {
- let fun = WindowFunction::from_str(name)?;
- let fun2 = WindowFunction::from_str(name.to_uppercase().as_str())?;
+ let fun = find_df_window_func(name).unwrap();
+ let fun2 = find_df_window_func(name.to_uppercase().as_str()).unwrap();
assert_eq!(fun, fun2);
assert_eq!(fun.to_string(), name.to_uppercase());
}
@@ -343,39 +342,49 @@ mod tests {
}
#[test]
- fn test_window_function_from_str() -> Result<()> {
+ fn test_find_df_window_function() {
assert_eq!(
- WindowFunction::from_str("max")?,
- WindowFunction::AggregateFunction(AggregateFunction::Max)
+ find_df_window_func("max"),
+ Some(WindowFunction::AggregateFunction(AggregateFunction::Max))
);
assert_eq!(
- WindowFunction::from_str("min")?,
- WindowFunction::AggregateFunction(AggregateFunction::Min)
+ find_df_window_func("min"),
+ Some(WindowFunction::AggregateFunction(AggregateFunction::Min))
);
assert_eq!(
- WindowFunction::from_str("avg")?,
- WindowFunction::AggregateFunction(AggregateFunction::Avg)
+ find_df_window_func("avg"),
+ Some(WindowFunction::AggregateFunction(AggregateFunction::Avg))
);
assert_eq!(
- WindowFunction::from_str("cume_dist")?,
- WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::CumeDist)
+ find_df_window_func("cume_dist"),
+ Some(WindowFunction::BuiltInWindowFunction(
+ BuiltInWindowFunction::CumeDist
+ ))
);
assert_eq!(
- WindowFunction::from_str("first_value")?,
- WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::FirstValue)
+ find_df_window_func("first_value"),
+ Some(WindowFunction::BuiltInWindowFunction(
+ BuiltInWindowFunction::FirstValue
+ ))
);
assert_eq!(
- WindowFunction::from_str("LAST_value")?,
- WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::LastValue)
+ find_df_window_func("LAST_value"),
+ Some(WindowFunction::BuiltInWindowFunction(
+ BuiltInWindowFunction::LastValue
+ ))
);
assert_eq!(
- WindowFunction::from_str("LAG")?,
- WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lag)
+ find_df_window_func("LAG"),
+ Some(WindowFunction::BuiltInWindowFunction(
+ BuiltInWindowFunction::Lag
+ ))
);
assert_eq!(
- WindowFunction::from_str("LEAD")?,
- WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lead)
+ find_df_window_func("LEAD"),
+ Some(WindowFunction::BuiltInWindowFunction(
+ BuiltInWindowFunction::Lead
+ ))
);
- Ok(())
+ assert_eq!(find_df_window_func("not_exist"), None)
}
}
diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs
index 4c280f7b0..fdbcd060e 100644
--- a/datafusion/proto/src/to_proto.rs
+++ b/datafusion/proto/src/to_proto.rs
@@ -61,6 +61,8 @@ pub enum Error {
InvalidTimeUnit(TimeUnit),
UnsupportedScalarFunction(BuiltinScalarFunction),
+
+ NotImplemented(String),
}
impl std::error::Error for Error {}
@@ -99,6 +101,9 @@ impl std::fmt::Display for Error {
Self::UnsupportedScalarFunction(function) => {
write!(f, "Unsupported scalar function {:?}", function)
}
+ Self::NotImplemented(s) => {
+ write!(f, "Not implemented: {}", s)
+ }
}
}
}
@@ -546,6 +551,8 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
protobuf::BuiltInWindowFunction::from(fun).into(),
)
}
+ // TODO: Tracked in https://github.com/apache/arrow-datafusion/issues/4584
+ WindowFunction::AggregateUDF(_) => return Err(Error::NotImplemented("UDAF as window function in proto".to_string()))
};
let arg_expr: Option<Box<Self>> = if !args.is_empty() {
let arg = &args[0];
diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml
index dd19e6aff..bc6add89a 100644
--- a/datafusion/sql/Cargo.toml
+++ b/datafusion/sql/Cargo.toml
@@ -42,3 +42,6 @@ datafusion-common = { path = "../common", version = "15.0.0" }
datafusion-expr = { path = "../expr", version = "15.0.0" }
log = "^0.4"
sqlparser = "0.28"
+
+[dev-dependencies]
+datafusion = { path = "../core" }
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index 45eee8e42..c78032a02 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -68,7 +68,8 @@ use datafusion_expr::{
GetIndexedField, Operator, ScalarUDF, SubqueryAlias, WindowFrame, WindowFrameUnits,
};
use datafusion_expr::{
- window_function::WindowFunction, BuiltinScalarFunction, TableSource,
+ window_function::{self, WindowFunction},
+ BuiltinScalarFunction, TableSource,
};
use crate::parser::{CreateExternalTable, DescribeTable, Statement as DFStatement};
@@ -2356,8 +2357,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
} else {
WindowFrame::new(!order_by.is_empty())
};
- let fun = WindowFunction::from_str(&name)?;
- match fun {
+ let fun = self.find_window_func(&name)?;
+ let expr = match fun {
WindowFunction::AggregateFunction(
aggregate_fun,
) => {
@@ -2367,7 +2368,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
schema,
)?;
- return Ok(Expr::WindowFunction {
+ Expr::WindowFunction {
fun: WindowFunction::AggregateFunction(
aggregate_fun,
),
@@ -2375,22 +2376,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
partition_by,
order_by,
window_frame,
- });
+ }
}
- WindowFunction::BuiltInWindowFunction(
- window_fun,
- ) => {
- return Ok(Expr::WindowFunction {
- fun: WindowFunction::BuiltInWindowFunction(
- window_fun,
- ),
+ _ => {
+ Expr::WindowFunction {
+ fun,
args: self.function_args_to_expr(function.args, schema)?,
partition_by,
order_by,
window_frame,
- });
+ }
}
- }
+ };
+ return Ok(expr);
}
// next, aggregate built-ins
@@ -2454,6 +2452,21 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
}
+ fn find_window_func(&self, name: &str) -> Result<WindowFunction> {
+ window_function::find_df_window_func(name)
+ .or_else(|| {
+ self.schema_provider
+ .get_aggregate_meta(name)
+ .map(WindowFunction::AggregateUDF)
+ })
+ .ok_or_else(|| {
+ DataFusionError::Plan(format!(
+ "There is no window function named {}",
+ name
+ ))
+ })
+ }
+
fn parse_exists_subquery(
&self,
subquery: Query,
@@ -3288,11 +3301,14 @@ fn ensure_any_column_reference_is_unambiguous(
#[cfg(test)]
mod tests {
+ use datafusion::arrow::array::ArrayRef;
+ use datafusion::prelude::SessionContext;
use std::any::Any;
use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect};
use datafusion_common::assert_contains;
+ use datafusion_expr::{create_udaf, Accumulator, AggregateState, Volatility};
use super::*;
@@ -5304,6 +5320,64 @@ mod tests {
quick_test(sql, expected);
}
+ #[test]
+ fn udaf_as_window_func() -> Result<()> {
+ #[derive(Debug)]
+ struct MyAccumulator;
+
+ impl Accumulator for MyAccumulator {
+ fn state(&self) -> Result<Vec<AggregateState>> {
+ unimplemented!()
+ }
+
+ fn update_batch(&mut self, _: &[ArrayRef]) -> Result<()> {
+ unimplemented!()
+ }
+
+ fn merge_batch(&mut self, _: &[ArrayRef]) -> Result<()> {
+ unimplemented!()
+ }
+
+ fn evaluate(&self) -> Result<ScalarValue> {
+ unimplemented!()
+ }
+
+ fn size(&self) -> usize {
+ unimplemented!()
+ }
+ }
+
+ let my_acc = create_udaf(
+ "my_acc",
+ DataType::Int32,
+ Arc::new(DataType::Int32),
+ Volatility::Immutable,
+ Arc::new(|_| Ok(Box::new(MyAccumulator))),
+ Arc::new(vec![DataType::Int32]),
+ );
+
+ let mut context = SessionContext::new();
+ context.register_table(
+ TableReference::Bare { table: "my_table" },
+ Arc::new(datafusion::datasource::empty::EmptyTable::new(Arc::new(
+ Schema::new(vec![
+ Field::new("a", DataType::UInt32, false),
+ Field::new("b", DataType::Int32, false),
+ ]),
+ ))),
+ )?;
+ context.register_udaf(my_acc);
+
+ let sql = "SELECT a, MY_ACC(b) OVER(PARTITION BY a) FROM my_table";
+ let expected = r#"Projection: my_table.a, AggregateUDF { name: "my_acc", signature: Signature { type_signature: Exact([Int32]), volatility: Immutable }, fun: "<FUNC>" }(my_table.b) PARTITION BY [my_table.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
+ WindowAggr: windowExpr=[[AggregateUDF { name: "my_acc", signature: Signature { type_signature: Exact([Int32]), volatility: Immutable }, fun: "<FUNC>" }(my_table.b) PARTITION BY [my_table.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]
+ TableScan: my_table"#;
+
+ let plan = context.create_logical_plan(sql)?;
+ assert_eq!(format!("{:?}", plan), expected);
+ Ok(())
+ }
+
#[test]
fn select_typed_date_string() {
let sql = "SELECT date '2020-12-10' AS date";
@@ -5345,7 +5419,8 @@ mod tests {
sql: &str,
dialect: &dyn Dialect,
) -> Result<LogicalPlan> {
- let planner = SqlToRel::new(&MockContextProvider {});
+ let context = MockContextProvider::default();
+ let planner = SqlToRel::new(&context);
let result = DFParser::parse_sql_with_dialect(sql, dialect);
let mut ast = result?;
planner.statement_to_plan(ast.pop_front().unwrap())
@@ -5356,7 +5431,8 @@ mod tests {
dialect: &dyn Dialect,
options: ParserOptions,
) -> Result<LogicalPlan> {
- let planner = SqlToRel::new_with_options(&MockContextProvider {}, options);
+ let context = MockContextProvider::default();
+ let planner = SqlToRel::new_with_options(&context, options);
let result = DFParser::parse_sql_with_dialect(sql, dialect);
let mut ast = result?;
planner.statement_to_plan(ast.pop_front().unwrap())
@@ -5405,7 +5481,10 @@ mod tests {
plan
}
- struct MockContextProvider {}
+ #[derive(Default)]
+ struct MockContextProvider {
+ udafs: HashMap<String, Arc<AggregateUDF>>,
+ }
impl ContextProvider for MockContextProvider {
fn get_table_provider(
@@ -5491,8 +5570,8 @@ mod tests {
unimplemented!()
}
- fn get_aggregate_meta(&self, _name: &str) -> Option<Arc<AggregateUDF>> {
- unimplemented!()
+ fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
+ self.udafs.get(name).map(Arc::clone)
}
fn get_variable_type(&self, _: &[String]) -> Option<DataType> {