You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/06/15 21:30:47 UTC
[arrow-datafusion] branch main updated: Allow `AggregateUDF` to define retractable batch , implement sliding window functions (#6671)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 84e49771b7 Allow `AggregateUDF` to define retractable batch , implement sliding window functions (#6671)
84e49771b7 is described below
commit 84e49771b7403b3d313d8493b61d2d58dcdd7514
Author: Andrew Lamb <an...@nerdnetworks.org>
AuthorDate: Thu Jun 15 17:30:41 2023 -0400
Allow `AggregateUDF` to define retractable batch , implement sliding window functions (#6671)
* feat: support sliding window accumulators
Rationale:
The default implementation of the `Accumulator` trait returns an error
for the `retract_batch` API.
* Allow AggregateUDF to define retractable batch
* Return error rather than wrong results when aggregate without retract_batch is used as a sliding accumulator
---------
Co-authored-by: Stuart Carnie <st...@gmail.com>
---
datafusion/core/src/physical_plan/udaf.rs | 57 ++++++++-
datafusion/core/src/physical_plan/windows/mod.rs | 65 ++++++----
datafusion/core/tests/user_defined_aggregates.rs | 144 +++++++++++++++++------
datafusion/expr/src/accumulator.rs | 19 ++-
datafusion/expr/src/udaf.rs | 13 +-
5 files changed, 234 insertions(+), 64 deletions(-)
diff --git a/datafusion/core/src/physical_plan/udaf.rs b/datafusion/core/src/physical_plan/udaf.rs
index d9f52eba77..bca9eb8782 100644
--- a/datafusion/core/src/physical_plan/udaf.rs
+++ b/datafusion/core/src/physical_plan/udaf.rs
@@ -28,7 +28,7 @@ use arrow::{
use super::{expressions::format_state_name, Accumulator, AggregateExpr};
use crate::physical_plan::PhysicalExpr;
-use datafusion_common::Result;
+use datafusion_common::{DataFusionError, Result};
pub use datafusion_expr::AggregateUDF;
use datafusion_physical_expr::aggregate::utils::down_cast_any_ref;
@@ -106,6 +106,61 @@ impl AggregateExpr for AggregateFunctionExpr {
(self.fun.accumulator)(&self.data_type)
}
+ fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+ let accumulator = (self.fun.accumulator)(&self.data_type)?;
+
+ // Accumulators whose window frame start is something other
+ // than `UNBOUNDED PRECEDING`, such as `1 PRECEDING`, need to
+ // implement the `retract_batch` method in order to run correctly
+ // in DataFusion at present.
+ //
+ // If `retract_batch` is not implemented, there is no way
+ // to calculate the result correctly. For example, the query
+ //
+ // ```sql
+ // SELECT
+ // SUM(a) OVER(ORDER BY a ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS sum_a
+ // FROM
+ // t
+ // ```
+ //
+ // 1. First sum value will be the sum of rows between `[0, 1)`,
+ //
+ // 2. Second sum value will be the sum of rows between `[0, 2)`
+ //
+ // 3. Third sum value will be the sum of rows between `[1, 3)`, etc.
+ //
+ // Since the accumulator keeps the running sum:
+ //
+ // 1. First sum we add to the state sum value between `[0, 1)`
+ //
+ // 2. Second sum we add to the state sum value between `[1, 2)`
+ // (`[0, 1)` is already in the state sum, hence running sum will
+ // cover `[0, 2)` range)
+ //
+ // 3. Third sum we add to the state sum value between `[2, 3)`
+ // (`[0, 2)` is already in the state sum). We also need to
+ // retract the values between `[0, 1)`; this way we obtain the sum
+ // between `[1, 3)`, which is indeed the appropriate range.
+ //
+ // When we use `UNBOUNDED PRECEDING` in the query starting
+ // index will always be 0 for the desired range, and hence the
+ // `retract_batch` method will not be called. In this case
+ // having retract_batch is not a requirement.
+ //
+ // This approach is a bit different from the window function
+ // approach. Window functions (when they use a window frame)
+ // get the entire desired range during evaluation.
+ if !accumulator.supports_retract_batch() {
+ return Err(DataFusionError::NotImplemented(format!(
+ "Aggregate can not be used as a sliding accumulator because \
+ `retract_batch` is not implemented: {}",
+ self.name
+ )));
+ }
+ Ok(accumulator)
+ }
+
fn name(&self) -> &str {
&self.name
}
diff --git a/datafusion/core/src/physical_plan/windows/mod.rs b/datafusion/core/src/physical_plan/windows/mod.rs
index a43ada82ee..0cd6a746dd 100644
--- a/datafusion/core/src/physical_plan/windows/mod.rs
+++ b/datafusion/core/src/physical_plan/windows/mod.rs
@@ -33,8 +33,9 @@ use datafusion_expr::{
window_function::{BuiltInWindowFunction, WindowFunction},
WindowFrame,
};
-use datafusion_physical_expr::window::{
- BuiltInWindowFunctionExpr, SlidingAggregateWindowExpr,
+use datafusion_physical_expr::{
+ window::{BuiltInWindowFunctionExpr, SlidingAggregateWindowExpr},
+ AggregateExpr,
};
use std::borrow::Borrow;
use std::convert::TryInto;
@@ -68,21 +69,12 @@ pub fn create_window_expr(
WindowFunction::AggregateFunction(fun) => {
let aggregate =
aggregates::create_aggregate_expr(fun, false, args, input_schema, name)?;
- if !window_frame.start_bound.is_unbounded() {
- Arc::new(SlidingAggregateWindowExpr::new(
- aggregate,
- partition_by,
- order_by,
- window_frame,
- ))
- } else {
- Arc::new(PlainAggregateWindowExpr::new(
- aggregate,
- partition_by,
- order_by,
- window_frame,
- ))
- }
+ window_expr_from_aggregate_expr(
+ partition_by,
+ order_by,
+ window_frame,
+ aggregate,
+ )
}
WindowFunction::BuiltInWindowFunction(fun) => Arc::new(BuiltInWindowExpr::new(
create_built_in_window_expr(fun, args, input_schema, name)?,
@@ -90,13 +82,44 @@ pub fn create_window_expr(
order_by,
window_frame,
)),
- WindowFunction::AggregateUDF(fun) => Arc::new(PlainAggregateWindowExpr::new(
- udaf::create_aggregate_expr(fun.as_ref(), args, input_schema, name)?,
+ WindowFunction::AggregateUDF(fun) => {
+ let aggregate =
+ udaf::create_aggregate_expr(fun.as_ref(), args, input_schema, name)?;
+ window_expr_from_aggregate_expr(
+ partition_by,
+ order_by,
+ window_frame,
+ aggregate,
+ )
+ }
+ })
+}
+
+/// Creates an appropriate [`WindowExpr`] for `aggregate` based on the window frame
+fn window_expr_from_aggregate_expr(
+ partition_by: &[Arc<dyn PhysicalExpr>],
+ order_by: &[PhysicalSortExpr],
+ window_frame: Arc<WindowFrame>,
+ aggregate: Arc<dyn AggregateExpr>,
+) -> Arc<dyn WindowExpr> {
+ // Is there a potentially unlimited sized window frame?
+ let unbounded_window = window_frame.start_bound.is_unbounded();
+
+ if !unbounded_window {
+ Arc::new(SlidingAggregateWindowExpr::new(
+ aggregate,
partition_by,
order_by,
window_frame,
- )),
- })
+ ))
+ } else {
+ Arc::new(PlainAggregateWindowExpr::new(
+ aggregate,
+ partition_by,
+ order_by,
+ window_frame,
+ ))
+ }
}
fn get_scalar_value_from_args(
diff --git a/datafusion/core/tests/user_defined_aggregates.rs b/datafusion/core/tests/user_defined_aggregates.rs
index 7c95b9a2d4..4202b9bea9 100644
--- a/datafusion/core/tests/user_defined_aggregates.rs
+++ b/datafusion/core/tests/user_defined_aggregates.rs
@@ -40,13 +40,32 @@ use datafusion::{
prelude::SessionContext,
scalar::ScalarValue,
};
-use datafusion_common::cast::as_primitive_array;
+use datafusion_common::{assert_contains, cast::as_primitive_array, DataFusionError};
+
+/// Test to show the contents of the setup
+#[tokio::test]
+async fn test_setup() {
+ let TestContext { ctx, test_state: _ } = TestContext::new();
+ let sql = "SELECT * from t order by time";
+ let expected = vec![
+ "+-------+----------------------------+",
+ "| value | time |",
+ "+-------+----------------------------+",
+ "| 2.0 | 1970-01-01T00:00:00.000002 |",
+ "| 3.0 | 1970-01-01T00:00:00.000003 |",
+ "| 1.0 | 1970-01-01T00:00:00.000004 |",
+ "| 5.0 | 1970-01-01T00:00:00.000005 |",
+ "| 5.0 | 1970-01-01T00:00:00.000005 |",
+ "+-------+----------------------------+",
+ ];
+ assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap());
+}
/// Basic user defined aggregate
#[tokio::test]
async fn test_udaf() {
- let TestContext { ctx, counters } = TestContext::new();
- assert!(!counters.update_batch());
+ let TestContext { ctx, test_state } = TestContext::new();
+ assert!(!test_state.update_batch());
let sql = "SELECT time_sum(time) from t";
let expected = vec![
"+----------------------------+",
@@ -55,16 +74,16 @@ async fn test_udaf() {
"| 1970-01-01T00:00:00.000019 |",
"+----------------------------+",
];
- assert_batches_eq!(expected, &execute(&ctx, sql).await);
+ assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap());
// normal aggregates call update_batch
- assert!(counters.update_batch());
- assert!(!counters.retract_batch());
+ assert!(test_state.update_batch());
+ assert!(!test_state.retract_batch());
}
/// User defined aggregate used as a window function
#[tokio::test]
async fn test_udaf_as_window() {
- let TestContext { ctx, counters } = TestContext::new();
+ let TestContext { ctx, test_state } = TestContext::new();
let sql = "SELECT time_sum(time) OVER() as time_sum from t";
let expected = vec![
"+----------------------------+",
@@ -77,16 +96,16 @@ async fn test_udaf_as_window() {
"| 1970-01-01T00:00:00.000019 |",
"+----------------------------+",
];
- assert_batches_eq!(expected, &execute(&ctx, sql).await);
+ assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap());
// aggregate over the entire window function call update_batch
- assert!(counters.update_batch());
- assert!(!counters.retract_batch());
+ assert!(test_state.update_batch());
+ assert!(!test_state.retract_batch());
}
/// User defined aggregate used as a window function with a window frame
#[tokio::test]
async fn test_udaf_as_window_with_frame() {
- let TestContext { ctx, counters } = TestContext::new();
+ let TestContext { ctx, test_state } = TestContext::new();
let sql = "SELECT time_sum(time) OVER(ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as time_sum from t";
let expected = vec![
"+----------------------------+",
@@ -94,22 +113,34 @@ async fn test_udaf_as_window_with_frame() {
"+----------------------------+",
"| 1970-01-01T00:00:00.000005 |",
"| 1970-01-01T00:00:00.000009 |",
+ "| 1970-01-01T00:00:00.000012 |",
"| 1970-01-01T00:00:00.000014 |",
- "| 1970-01-01T00:00:00.000019 |",
- "| 1970-01-01T00:00:00.000019 |",
+ "| 1970-01-01T00:00:00.000010 |",
"+----------------------------+",
];
- assert_batches_eq!(expected, &execute(&ctx, sql).await);
+ assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap());
// user defined aggregates with window frame should be calling retract batch
- // but doesn't yet: https://github.com/apache/arrow-datafusion/issues/6611
- assert!(counters.update_batch());
- assert!(!counters.retract_batch());
+ assert!(test_state.update_batch());
+ assert!(test_state.retract_batch());
+}
+
+/// Ensure that a user defined aggregate used as a window function with a window
+/// frame, but that does not implement `retract_batch`, returns an error
+#[tokio::test]
+async fn test_udaf_as_window_with_frame_without_retract_batch() {
+ let test_state = Arc::new(TestState::new().with_error_on_retract_batch());
+
+ let TestContext { ctx, test_state: _ } = TestContext::new_with_test_state(test_state);
+ let sql = "SELECT time_sum(time) OVER(ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as time_sum from t";
+ // Note: if this query ever starts working, `unwrap_err` below will panic and this test must be updated
+ let err = execute(&ctx, sql).await.unwrap_err();
+ assert_contains!(err.to_string(), "This feature is not implemented: Aggregate can not be used as a sliding accumulator because `retract_batch` is not implemented: AggregateUDF { name: \"time_sum\"");
}
/// Basic query for with a udaf returning a structure
#[tokio::test]
async fn test_udaf_returning_struct() {
- let TestContext { ctx, counters: _ } = TestContext::new();
+ let TestContext { ctx, test_state: _ } = TestContext::new();
let sql = "SELECT first(value, time) from t";
let expected = vec![
"+------------------------------------------------+",
@@ -118,13 +149,13 @@ async fn test_udaf_returning_struct() {
"| {value: 2.0, time: 1970-01-01T00:00:00.000002} |",
"+------------------------------------------------+",
];
- assert_batches_eq!(expected, &execute(&ctx, sql).await);
+ assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap());
}
/// Demonstrate extracting the fields from a structure using a subquery
#[tokio::test]
async fn test_udaf_returning_struct_subquery() {
- let TestContext { ctx, counters: _ } = TestContext::new();
+ let TestContext { ctx, test_state: _ } = TestContext::new();
let sql = "select sq.first['value'], sq.first['time'] from (SELECT first(value, time) as first from t) as sq";
let expected = vec![
"+-----------------+----------------------------+",
@@ -133,11 +164,11 @@ async fn test_udaf_returning_struct_subquery() {
"| 2.0 | 1970-01-01T00:00:00.000002 |",
"+-----------------+----------------------------+",
];
- assert_batches_eq!(expected, &execute(&ctx, sql).await);
+ assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap());
}
-async fn execute(ctx: &SessionContext, sql: &str) -> Vec<RecordBatch> {
- ctx.sql(sql).await.unwrap().collect().await.unwrap()
+async fn execute(ctx: &SessionContext, sql: &str) -> Result<Vec<RecordBatch>> {
+ ctx.sql(sql).await?.collect().await
}
/// Returns an context with a table "t" and the "first" and "time_sum"
@@ -155,13 +186,16 @@ async fn execute(ctx: &SessionContext, sql: &str) -> Vec<RecordBatch> {
/// ```
struct TestContext {
ctx: SessionContext,
- counters: Arc<TestCounters>,
+ test_state: Arc<TestState>,
}
impl TestContext {
fn new() -> Self {
- let counters = Arc::new(TestCounters::new());
+ let test_state = Arc::new(TestState::new());
+ Self::new_with_test_state(test_state)
+ }
+ fn new_with_test_state(test_state: Arc<TestState>) -> Self {
let value = Float64Array::from(vec![3.0, 2.0, 1.0, 5.0, 5.0]);
let time = TimestampNanosecondArray::from(vec![3000, 2000, 4000, 5000, 5000]);
@@ -178,21 +212,24 @@ impl TestContext {
// Tell DataFusion about the "first" function
FirstSelector::register(&mut ctx);
// Tell DataFusion about the "time_sum" function
- TimeSum::register(&mut ctx, Arc::clone(&counters));
+ TimeSum::register(&mut ctx, Arc::clone(&test_state));
- Self { ctx, counters }
+ Self { ctx, test_state }
}
}
#[derive(Debug, Default)]
-struct TestCounters {
+struct TestState {
/// was update_batch called?
update_batch: AtomicBool,
/// was retract_batch called?
retract_batch: AtomicBool,
+ /// should the udaf throw an error if retract batch is called? Can
+ /// only be configured at construction time.
+ error_on_retract_batch: bool,
}
-impl TestCounters {
+impl TestState {
fn new() -> Self {
Default::default()
}
@@ -202,10 +239,31 @@ impl TestCounters {
self.update_batch.load(Ordering::SeqCst)
}
+ /// Set the `update_batch` flag
+ fn set_update_batch(&self) {
+ self.update_batch.store(true, Ordering::SeqCst)
+ }
+
/// Has `retract_batch` been called?
fn retract_batch(&self) -> bool {
self.retract_batch.load(Ordering::SeqCst)
}
+
+ /// set the `retract_batch` flag
+ fn set_retract_batch(&self) {
+ self.retract_batch.store(true, Ordering::SeqCst)
+ }
+
+ /// Is this state configured to return an error on retract batch?
+ fn error_on_retract_batch(&self) -> bool {
+ self.error_on_retract_batch
+ }
+
+ /// Configure the test to return error on retract batch
+ fn with_error_on_retract_batch(mut self) -> Self {
+ self.error_on_retract_batch = true;
+ self
+ }
}
/// Models a user defined aggregate function that computes the a sum
@@ -213,15 +271,15 @@ impl TestCounters {
#[derive(Debug)]
struct TimeSum {
sum: i64,
- counters: Arc<TestCounters>,
+ test_state: Arc<TestState>,
}
impl TimeSum {
- fn new(counters: Arc<TestCounters>) -> Self {
- Self { sum: 0, counters }
+ fn new(test_state: Arc<TestState>) -> Self {
+ Self { sum: 0, test_state }
}
- fn register(ctx: &mut SessionContext, counters: Arc<TestCounters>) {
+ fn register(ctx: &mut SessionContext, test_state: Arc<TestState>) {
let timestamp_type = DataType::Timestamp(TimeUnit::Nanosecond, None);
// Returns the same type as its input
@@ -237,8 +295,9 @@ impl TimeSum {
let signature = Signature::exact(vec![timestamp_type], volatility);
+ let captured_state = Arc::clone(&test_state);
let accumulator: AccumulatorFunctionImplementation =
- Arc::new(move |_| Ok(Box::new(Self::new(Arc::clone(&counters)))));
+ Arc::new(move |_| Ok(Box::new(Self::new(Arc::clone(&captured_state)))));
let name = "time_sum";
@@ -256,12 +315,13 @@ impl Accumulator for TimeSum {
}
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
- self.counters.update_batch.store(true, Ordering::SeqCst);
+ self.test_state.set_update_batch();
assert_eq!(values.len(), 1);
let arr = &values[0];
let arr = arr.as_primitive::<TimestampNanosecondType>();
for v in arr.values().iter() {
+ println!("Adding {v}");
self.sum += v;
}
Ok(())
@@ -273,6 +333,7 @@ impl Accumulator for TimeSum {
}
fn evaluate(&self) -> Result<ScalarValue> {
+ println!("Evaluating to {}", self.sum);
Ok(ScalarValue::TimestampNanosecond(Some(self.sum), None))
}
@@ -282,16 +343,27 @@ impl Accumulator for TimeSum {
}
fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
- self.counters.retract_batch.store(true, Ordering::SeqCst);
+ if self.test_state.error_on_retract_batch() {
+ return Err(DataFusionError::Execution(
+ "Error in Retract Batch".to_string(),
+ ));
+ }
+
+ self.test_state.set_retract_batch();
assert_eq!(values.len(), 1);
let arr = &values[0];
let arr = arr.as_primitive::<TimestampNanosecondType>();
for v in arr.values().iter() {
+ println!("Retracting {v}");
self.sum -= v;
}
Ok(())
}
+
+ fn supports_retract_batch(&self) -> bool {
+ !self.test_state.error_on_retract_batch()
+ }
}
/// Models a specialized timeseries aggregate function
diff --git a/datafusion/expr/src/accumulator.rs b/datafusion/expr/src/accumulator.rs
index 7e941d0cff..c448ed4235 100644
--- a/datafusion/expr/src/accumulator.rs
+++ b/datafusion/expr/src/accumulator.rs
@@ -21,12 +21,15 @@ use arrow::array::ArrayRef;
use datafusion_common::{DataFusionError, Result, ScalarValue};
use std::fmt::Debug;
-/// An accumulator represents a stateful object that lives throughout the evaluation of multiple rows and
-/// generically accumulates values.
+/// Accumulates an aggregate's state.
+///
+/// `Accumulator`s are stateful objects that live throughout the
+/// evaluation of multiple rows and aggregate multiple values together
+/// into a final output aggregate.
///
/// An accumulator knows how to:
/// * update its state from inputs via `update_batch`
-/// * retract an update to its state from given inputs via `retract_batch`
+/// * (optionally) retract an update to its state from given inputs via `retract_batch`
/// * convert its internal state to a vector of aggregate values
/// * update its state from multiple accumulators' states via `merge_batch`
/// * compute the final value from its internal state via `evaluate`
@@ -68,6 +71,16 @@ pub trait Accumulator: Send + Sync + Debug {
))
}
+ /// Does the accumulator support incrementally updating its value
+ /// by *removing* values?
+ ///
+ /// If this function returns true, [`Self::retract_batch`] will be
+ /// called for sliding window functions such as queries with an
+ /// `OVER (ROWS BETWEEN 1 PRECEDING AND 2 FOLLOWING)`
+ fn supports_retract_batch(&self) -> bool {
+ false
+ }
+
/// Updates the accumulator's state from an `Array` containing one
/// or more intermediate values.
///
diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs
index 6c3690e283..1b455a0985 100644
--- a/datafusion/expr/src/udaf.rs
+++ b/datafusion/expr/src/udaf.rs
@@ -24,13 +24,20 @@ use crate::{
use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;
-/// Logical representation of a user-defined aggregate function (UDAF)
-/// A UDAF is different from a UDF in that it is stateful across batches.
+/// Logical representation of a user-defined aggregate function (UDAF).
+///
+/// A UDAF is different from a user-defined scalar function (UDF) in
+/// that it is stateful across batches. UDAFs can be used as normal
+/// aggregate functions as well as window functions (the `OVER` clause)
+///
+/// For more information, please see [the examples]
+///
+/// [the examples]: https://github.com/apache/arrow-datafusion/tree/main/datafusion-examples#single-process
#[derive(Clone)]
pub struct AggregateUDF {
/// name
pub name: String,
- /// signature
+ /// Signature (input arguments)
pub signature: Signature,
/// Return type
pub return_type: ReturnTypeFunction,