Posted to commits@arrow.apache.org by al...@apache.org on 2023/06/15 21:30:47 UTC

[arrow-datafusion] branch main updated: Allow `AggregateUDF` to define retractable batch, implement sliding window functions (#6671)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 84e49771b7 Allow `AggregateUDF` to define retractable batch, implement sliding window functions (#6671)
84e49771b7 is described below

commit 84e49771b7403b3d313d8493b61d2d58dcdd7514
Author: Andrew Lamb <an...@nerdnetworks.org>
AuthorDate: Thu Jun 15 17:30:41 2023 -0400

    Allow `AggregateUDF` to define retractable batch, implement sliding window functions (#6671)
    
    * feat: support sliding window accumulators
    
    Rationale:
    
    The default implementation of the `Accumulator` trait returns an error
    for the `retract_batch` API.
    
    * Allow AggregateUDF to define retractable batch
    
    * Return an error rather than wrong results when an aggregate without retract_batch is used as a sliding accumulator
    
    ---------
    
    Co-authored-by: Stuart Carnie <st...@gmail.com>
---
 datafusion/core/src/physical_plan/udaf.rs        |  57 ++++++++-
 datafusion/core/src/physical_plan/windows/mod.rs |  65 ++++++----
 datafusion/core/tests/user_defined_aggregates.rs | 144 +++++++++++++++++------
 datafusion/expr/src/accumulator.rs               |  19 ++-
 datafusion/expr/src/udaf.rs                      |  13 +-
 5 files changed, 234 insertions(+), 64 deletions(-)
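
To illustrate the mechanics this change builds on, here is a minimal, dependency-free Rust sketch (illustrative only; `SlidingSum` is a made-up type, not DataFusion's `Accumulator` trait) of an accumulator that can both add values entering a window frame and retract values leaving it:

```rust
/// Hypothetical running sum that supports retraction (sketch only).
struct SlidingSum {
    sum: i64,
}

impl SlidingSum {
    fn new() -> Self {
        Self { sum: 0 }
    }

    /// Add values that entered the window frame (akin to `update_batch`).
    fn update_batch(&mut self, values: &[i64]) {
        self.sum += values.iter().sum::<i64>();
    }

    /// Remove values that left the window frame (akin to `retract_batch`).
    fn retract_batch(&mut self, values: &[i64]) {
        self.sum -= values.iter().sum::<i64>();
    }

    fn evaluate(&self) -> i64 {
        self.sum
    }
}

fn main() {
    let mut acc = SlidingSum::new();
    acc.update_batch(&[2, 3]); // frame covers rows [0, 2)
    assert_eq!(acc.evaluate(), 5);
    acc.update_batch(&[1]);    // row 2 enters: frame is now [0, 3)
    acc.retract_batch(&[2]);   // row 0 leaves: frame is now [1, 3)
    assert_eq!(acc.evaluate(), 4); // 3 + 1
}
```

Without the retract step, the only correct alternative is to rebuild the accumulator state from scratch for every frame, which is what makes `retract_batch` essential for sliding frames.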

diff --git a/datafusion/core/src/physical_plan/udaf.rs b/datafusion/core/src/physical_plan/udaf.rs
index d9f52eba77..bca9eb8782 100644
--- a/datafusion/core/src/physical_plan/udaf.rs
+++ b/datafusion/core/src/physical_plan/udaf.rs
@@ -28,7 +28,7 @@ use arrow::{
 
 use super::{expressions::format_state_name, Accumulator, AggregateExpr};
 use crate::physical_plan::PhysicalExpr;
-use datafusion_common::Result;
+use datafusion_common::{DataFusionError, Result};
 pub use datafusion_expr::AggregateUDF;
 
 use datafusion_physical_expr::aggregate::utils::down_cast_any_ref;
@@ -106,6 +106,61 @@ impl AggregateExpr for AggregateFunctionExpr {
         (self.fun.accumulator)(&self.data_type)
     }
 
+    fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+        let accumulator = (self.fun.accumulator)(&self.data_type)?;
+
+        // Accumulators with window frame start bounds other than
+        // `UNBOUNDED PRECEDING`, such as `1 PRECEDING`, currently
+        // need to implement the `retract_batch` method in order
+        // to run correctly in DataFusion.
+        //
+        // If `retract_batch` is not implemented, there is no way
+        // to calculate the result correctly. For example, the query
+        //
+        // ```sql
+        // SELECT
+        //  SUM(a) OVER(ORDER BY a ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) AS sum_a
+        // FROM
+        //  t
+        // ```
+        //
+        // 1. The first sum value will be the sum of rows in `[0, 1)`,
+        //
+        // 2. The second sum value will be the sum of rows in `[0, 2)`,
+        //
+        // 3. The third sum value will be the sum of rows in `[1, 3)`, etc.
+        //
+        // Since the accumulator keeps the running sum:
+        //
+        // 1. For the first sum, we add the values in `[0, 1)` to the state
+        //
+        // 2. For the second sum, we add the values in `[1, 2)`
+        // (`[0, 1)` is already in the state sum, hence the running sum
+        // covers the `[0, 2)` range)
+        //
+        // 3. For the third sum, we add the values in `[2, 3)`
+        // (`[0, 2)` is already in the state sum). We also need to
+        // retract the values in `[0, 1)`; this way we obtain the sum
+        // over `[1, 3)`, which is indeed the appropriate range.
+        //
+        // When we use `UNBOUNDED PRECEDING` in the query, the starting
+        // index of the desired range is always 0, and hence the
+        // `retract_batch` method will never be called. In this case
+        // implementing `retract_batch` is not a requirement.
+        //
+        // This approach is a bit different from the built-in window
+        // function approach: built-in window functions (when they use
+        // a window frame) receive the entire desired range during evaluation.
+        if !accumulator.supports_retract_batch() {
+            return Err(DataFusionError::NotImplemented(format!(
+                "Aggregate can not be used as a sliding accumulator because \
+                     `retract_batch` is not implemented: {}",
+                self.name
+            )));
+        }
+        Ok(accumulator)
+    }
+
     fn name(&self) -> &str {
         &self.name
     }
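
To make the range arithmetic in the comment above concrete, the following standalone sketch (illustrative only; the driver loop is simplified, and the values mirror the test table further down) walks a `ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING` frame, adding rows that enter the frame and retracting rows that leave it:

```rust
// Drive a running sum the way a sliding-window accumulator is driven.
fn main() {
    let rows = [2_i64, 3, 1, 5, 5]; // column values in ORDER BY order
    let mut sum = 0_i64;            // the accumulator's running state
    let mut cur_start = 0_usize;
    let mut cur_end = 0_usize;
    for i in 0..rows.len() {
        let start = i.saturating_sub(1);   // 1 PRECEDING (inclusive)
        let end = (i + 2).min(rows.len()); // 1 FOLLOWING (exclusive)
        // update_batch: rows newly entering the frame
        sum += rows[cur_end..end].iter().sum::<i64>();
        // retract_batch: rows that have left the frame
        sum -= rows[cur_start..start].iter().sum::<i64>();
        cur_start = start;
        cur_end = end;
        println!("row {i}: frame [{start}, {end}) -> {sum}");
    }
    // Prints 5, 6, 9, 11, 10: each row's sum over its own frame,
    // computed without ever rebuilding the state from scratch.
}
```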
diff --git a/datafusion/core/src/physical_plan/windows/mod.rs b/datafusion/core/src/physical_plan/windows/mod.rs
index a43ada82ee..0cd6a746dd 100644
--- a/datafusion/core/src/physical_plan/windows/mod.rs
+++ b/datafusion/core/src/physical_plan/windows/mod.rs
@@ -33,8 +33,9 @@ use datafusion_expr::{
     window_function::{BuiltInWindowFunction, WindowFunction},
     WindowFrame,
 };
-use datafusion_physical_expr::window::{
-    BuiltInWindowFunctionExpr, SlidingAggregateWindowExpr,
+use datafusion_physical_expr::{
+    window::{BuiltInWindowFunctionExpr, SlidingAggregateWindowExpr},
+    AggregateExpr,
 };
 use std::borrow::Borrow;
 use std::convert::TryInto;
@@ -68,21 +69,12 @@ pub fn create_window_expr(
         WindowFunction::AggregateFunction(fun) => {
             let aggregate =
                 aggregates::create_aggregate_expr(fun, false, args, input_schema, name)?;
-            if !window_frame.start_bound.is_unbounded() {
-                Arc::new(SlidingAggregateWindowExpr::new(
-                    aggregate,
-                    partition_by,
-                    order_by,
-                    window_frame,
-                ))
-            } else {
-                Arc::new(PlainAggregateWindowExpr::new(
-                    aggregate,
-                    partition_by,
-                    order_by,
-                    window_frame,
-                ))
-            }
+            window_expr_from_aggregate_expr(
+                partition_by,
+                order_by,
+                window_frame,
+                aggregate,
+            )
         }
         WindowFunction::BuiltInWindowFunction(fun) => Arc::new(BuiltInWindowExpr::new(
             create_built_in_window_expr(fun, args, input_schema, name)?,
@@ -90,13 +82,44 @@ pub fn create_window_expr(
             order_by,
             window_frame,
         )),
-        WindowFunction::AggregateUDF(fun) => Arc::new(PlainAggregateWindowExpr::new(
-            udaf::create_aggregate_expr(fun.as_ref(), args, input_schema, name)?,
+        WindowFunction::AggregateUDF(fun) => {
+            let aggregate =
+                udaf::create_aggregate_expr(fun.as_ref(), args, input_schema, name)?;
+            window_expr_from_aggregate_expr(
+                partition_by,
+                order_by,
+                window_frame,
+                aggregate,
+            )
+        }
+    })
+}
+
+/// Creates an appropriate [`WindowExpr`] based on the window frame and the aggregate expression.
+fn window_expr_from_aggregate_expr(
+    partition_by: &[Arc<dyn PhysicalExpr>],
+    order_by: &[PhysicalSortExpr],
+    window_frame: Arc<WindowFrame>,
+    aggregate: Arc<dyn AggregateExpr>,
+) -> Arc<dyn WindowExpr> {
+    // Is the window frame potentially unbounded in size?
+    let unbounded_window = window_frame.start_bound.is_unbounded();
+
+    if !unbounded_window {
+        Arc::new(SlidingAggregateWindowExpr::new(
+            aggregate,
             partition_by,
             order_by,
             window_frame,
-        )),
-    })
+        ))
+    } else {
+        Arc::new(PlainAggregateWindowExpr::new(
+            aggregate,
+            partition_by,
+            order_by,
+            window_frame,
+        ))
+    }
 }
 
 fn get_scalar_value_from_args(
diff --git a/datafusion/core/tests/user_defined_aggregates.rs b/datafusion/core/tests/user_defined_aggregates.rs
index 7c95b9a2d4..4202b9bea9 100644
--- a/datafusion/core/tests/user_defined_aggregates.rs
+++ b/datafusion/core/tests/user_defined_aggregates.rs
@@ -40,13 +40,32 @@ use datafusion::{
     prelude::SessionContext,
     scalar::ScalarValue,
 };
-use datafusion_common::cast::as_primitive_array;
+use datafusion_common::{assert_contains, cast::as_primitive_array, DataFusionError};
+
+/// Test to show the contents of the setup
+#[tokio::test]
+async fn test_setup() {
+    let TestContext { ctx, test_state: _ } = TestContext::new();
+    let sql = "SELECT * from t order by time";
+    let expected = vec![
+        "+-------+----------------------------+",
+        "| value | time                       |",
+        "+-------+----------------------------+",
+        "| 2.0   | 1970-01-01T00:00:00.000002 |",
+        "| 3.0   | 1970-01-01T00:00:00.000003 |",
+        "| 1.0   | 1970-01-01T00:00:00.000004 |",
+        "| 5.0   | 1970-01-01T00:00:00.000005 |",
+        "| 5.0   | 1970-01-01T00:00:00.000005 |",
+        "+-------+----------------------------+",
+    ];
+    assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap());
+}
 
 /// Basic user defined aggregate
 #[tokio::test]
 async fn test_udaf() {
-    let TestContext { ctx, counters } = TestContext::new();
-    assert!(!counters.update_batch());
+    let TestContext { ctx, test_state } = TestContext::new();
+    assert!(!test_state.update_batch());
     let sql = "SELECT time_sum(time) from t";
     let expected = vec![
         "+----------------------------+",
@@ -55,16 +74,16 @@ async fn test_udaf() {
         "| 1970-01-01T00:00:00.000019 |",
         "+----------------------------+",
     ];
-    assert_batches_eq!(expected, &execute(&ctx, sql).await);
+    assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap());
     // normal aggregates call update_batch
-    assert!(counters.update_batch());
-    assert!(!counters.retract_batch());
+    assert!(test_state.update_batch());
+    assert!(!test_state.retract_batch());
 }
 
 /// User defined aggregate used as a window function
 #[tokio::test]
 async fn test_udaf_as_window() {
-    let TestContext { ctx, counters } = TestContext::new();
+    let TestContext { ctx, test_state } = TestContext::new();
     let sql = "SELECT time_sum(time) OVER() as time_sum from t";
     let expected = vec![
         "+----------------------------+",
@@ -77,16 +96,16 @@ async fn test_udaf_as_window() {
         "| 1970-01-01T00:00:00.000019 |",
         "+----------------------------+",
     ];
-    assert_batches_eq!(expected, &execute(&ctx, sql).await);
+    assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap());
     // aggregate over the entire window function call update_batch
-    assert!(counters.update_batch());
-    assert!(!counters.retract_batch());
+    assert!(test_state.update_batch());
+    assert!(!test_state.retract_batch());
 }
 
 /// User defined aggregate used as a window function with a window frame
 #[tokio::test]
 async fn test_udaf_as_window_with_frame() {
-    let TestContext { ctx, counters } = TestContext::new();
+    let TestContext { ctx, test_state } = TestContext::new();
     let sql = "SELECT time_sum(time) OVER(ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as time_sum from t";
     let expected = vec![
         "+----------------------------+",
@@ -94,22 +113,34 @@ async fn test_udaf_as_window_with_frame() {
         "+----------------------------+",
         "| 1970-01-01T00:00:00.000005 |",
         "| 1970-01-01T00:00:00.000009 |",
+        "| 1970-01-01T00:00:00.000012 |",
         "| 1970-01-01T00:00:00.000014 |",
-        "| 1970-01-01T00:00:00.000019 |",
-        "| 1970-01-01T00:00:00.000019 |",
+        "| 1970-01-01T00:00:00.000010 |",
         "+----------------------------+",
     ];
-    assert_batches_eq!(expected, &execute(&ctx, sql).await);
+    assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap());
     // user defined aggregates with window frame should be calling retract batch
-    // but doesn't yet: https://github.com/apache/arrow-datafusion/issues/6611
-    assert!(counters.update_batch());
-    assert!(!counters.retract_batch());
+    assert!(test_state.update_batch());
+    assert!(test_state.retract_batch());
+}
+
+/// Ensure that a user defined aggregate used as a window function with a
+/// window frame, but that does not implement retract_batch, returns an error
+#[tokio::test]
+async fn test_udaf_as_window_with_frame_without_retract_batch() {
+    let test_state = Arc::new(TestState::new().with_error_on_retract_batch());
+
+    let TestContext { ctx, test_state: _ } = TestContext::new_with_test_state(test_state);
+    let sql = "SELECT time_sum(time) OVER(ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as time_sum from t";
+    // Note: if this query ever starts working, this test should be updated to verify the results
+    let err = execute(&ctx, sql).await.unwrap_err();
+    assert_contains!(err.to_string(), "This feature is not implemented: Aggregate cannot be used as a sliding accumulator because `retract_batch` is not implemented: AggregateUDF { name: \"time_sum\"");
 }
 
 /// Basic query with a udaf returning a structure
 #[tokio::test]
 async fn test_udaf_returning_struct() {
-    let TestContext { ctx, counters: _ } = TestContext::new();
+    let TestContext { ctx, test_state: _ } = TestContext::new();
     let sql = "SELECT first(value, time) from t";
     let expected = vec![
         "+------------------------------------------------+",
@@ -118,13 +149,13 @@ async fn test_udaf_returning_struct() {
         "| {value: 2.0, time: 1970-01-01T00:00:00.000002} |",
         "+------------------------------------------------+",
     ];
-    assert_batches_eq!(expected, &execute(&ctx, sql).await);
+    assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap());
 }
 
 /// Demonstrate extracting the fields from a structure using a subquery
 #[tokio::test]
 async fn test_udaf_returning_struct_subquery() {
-    let TestContext { ctx, counters: _ } = TestContext::new();
+    let TestContext { ctx, test_state: _ } = TestContext::new();
     let sql = "select sq.first['value'], sq.first['time'] from (SELECT first(value, time) as first from t) as sq";
     let expected = vec![
         "+-----------------+----------------------------+",
@@ -133,11 +164,11 @@ async fn test_udaf_returning_struct_subquery() {
         "| 2.0             | 1970-01-01T00:00:00.000002 |",
         "+-----------------+----------------------------+",
     ];
-    assert_batches_eq!(expected, &execute(&ctx, sql).await);
+    assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap());
 }
 
-async fn execute(ctx: &SessionContext, sql: &str) -> Vec<RecordBatch> {
-    ctx.sql(sql).await.unwrap().collect().await.unwrap()
+async fn execute(ctx: &SessionContext, sql: &str) -> Result<Vec<RecordBatch>> {
+    ctx.sql(sql).await?.collect().await
 }
 
 /// Returns a context with a table "t" and the "first" and "time_sum"
@@ -155,13 +186,16 @@ async fn execute(ctx: &SessionContext, sql: &str) -> Vec<RecordBatch> {
 /// ```
 struct TestContext {
     ctx: SessionContext,
-    counters: Arc<TestCounters>,
+    test_state: Arc<TestState>,
 }
 
 impl TestContext {
     fn new() -> Self {
-        let counters = Arc::new(TestCounters::new());
+        let test_state = Arc::new(TestState::new());
+        Self::new_with_test_state(test_state)
+    }
 
+    fn new_with_test_state(test_state: Arc<TestState>) -> Self {
         let value = Float64Array::from(vec![3.0, 2.0, 1.0, 5.0, 5.0]);
         let time = TimestampNanosecondArray::from(vec![3000, 2000, 4000, 5000, 5000]);
 
@@ -178,21 +212,24 @@ impl TestContext {
         // Tell DataFusion about the "first" function
         FirstSelector::register(&mut ctx);
         // Tell DataFusion about the "time_sum" function
-        TimeSum::register(&mut ctx, Arc::clone(&counters));
+        TimeSum::register(&mut ctx, Arc::clone(&test_state));
 
-        Self { ctx, counters }
+        Self { ctx, test_state }
     }
 }
 
 #[derive(Debug, Default)]
-struct TestCounters {
+struct TestState {
     /// was update_batch called?
     update_batch: AtomicBool,
     /// was retract_batch called?
     retract_batch: AtomicBool,
+    /// should the UDAF return an error if `retract_batch` is called? Can
+    /// only be configured at construction time.
+    error_on_retract_batch: bool,
 }
 
-impl TestCounters {
+impl TestState {
     fn new() -> Self {
         Default::default()
     }
@@ -202,10 +239,31 @@ impl TestCounters {
         self.update_batch.load(Ordering::SeqCst)
     }
 
+    /// Set the `update_batch` flag
+    fn set_update_batch(&self) {
+        self.update_batch.store(true, Ordering::SeqCst)
+    }
+
     /// Has `retract_batch` been called?
     fn retract_batch(&self) -> bool {
         self.retract_batch.load(Ordering::SeqCst)
     }
+
+    /// Set the `retract_batch` flag
+    fn set_retract_batch(&self) {
+        self.retract_batch.store(true, Ordering::SeqCst)
+    }
+
+    /// Is this state configured to return an error on retract batch?
+    fn error_on_retract_batch(&self) -> bool {
+        self.error_on_retract_batch
+    }
+
+    /// Configure the test to return an error on retract batch
+    fn with_error_on_retract_batch(mut self) -> Self {
+        self.error_on_retract_batch = true;
+        self
+    }
 }
 
 /// Models a user defined aggregate function that computes a sum
@@ -213,15 +271,15 @@ impl TestCounters {
 #[derive(Debug)]
 struct TimeSum {
     sum: i64,
-    counters: Arc<TestCounters>,
+    test_state: Arc<TestState>,
 }
 
 impl TimeSum {
-    fn new(counters: Arc<TestCounters>) -> Self {
-        Self { sum: 0, counters }
+    fn new(test_state: Arc<TestState>) -> Self {
+        Self { sum: 0, test_state }
     }
 
-    fn register(ctx: &mut SessionContext, counters: Arc<TestCounters>) {
+    fn register(ctx: &mut SessionContext, test_state: Arc<TestState>) {
         let timestamp_type = DataType::Timestamp(TimeUnit::Nanosecond, None);
 
         // Returns the same type as its input
@@ -237,8 +295,9 @@ impl TimeSum {
 
         let signature = Signature::exact(vec![timestamp_type], volatility);
 
+        let captured_state = Arc::clone(&test_state);
         let accumulator: AccumulatorFunctionImplementation =
-            Arc::new(move |_| Ok(Box::new(Self::new(Arc::clone(&counters)))));
+            Arc::new(move |_| Ok(Box::new(Self::new(Arc::clone(&captured_state)))));
 
         let name = "time_sum";
 
@@ -256,12 +315,13 @@ impl Accumulator for TimeSum {
     }
 
     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
-        self.counters.update_batch.store(true, Ordering::SeqCst);
+        self.test_state.set_update_batch();
         assert_eq!(values.len(), 1);
         let arr = &values[0];
         let arr = arr.as_primitive::<TimestampNanosecondType>();
 
         for v in arr.values().iter() {
+            println!("Adding {v}");
             self.sum += v;
         }
         Ok(())
@@ -273,6 +333,7 @@ impl Accumulator for TimeSum {
     }
 
     fn evaluate(&self) -> Result<ScalarValue> {
+        println!("Evaluating to {}", self.sum);
         Ok(ScalarValue::TimestampNanosecond(Some(self.sum), None))
     }
 
@@ -282,16 +343,27 @@ impl Accumulator for TimeSum {
     }
 
     fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
-        self.counters.retract_batch.store(true, Ordering::SeqCst);
+        if self.test_state.error_on_retract_batch() {
+            return Err(DataFusionError::Execution(
+                "Error in Retract Batch".to_string(),
+            ));
+        }
+
+        self.test_state.set_retract_batch();
         assert_eq!(values.len(), 1);
         let arr = &values[0];
         let arr = arr.as_primitive::<TimestampNanosecondType>();
 
         for v in arr.values().iter() {
+            println!("Retracting {v}");
             self.sum -= v;
         }
         Ok(())
     }
+
+    fn supports_retract_batch(&self) -> bool {
+        !self.test_state.error_on_retract_batch()
+    }
 }
 
 /// Models a specialized timeseries aggregate function
diff --git a/datafusion/expr/src/accumulator.rs b/datafusion/expr/src/accumulator.rs
index 7e941d0cff..c448ed4235 100644
--- a/datafusion/expr/src/accumulator.rs
+++ b/datafusion/expr/src/accumulator.rs
@@ -21,12 +21,15 @@ use arrow::array::ArrayRef;
 use datafusion_common::{DataFusionError, Result, ScalarValue};
 use std::fmt::Debug;
 
-/// An accumulator represents a stateful object that lives throughout the evaluation of multiple rows and
-/// generically accumulates values.
+/// Accumulates an aggregate's state.
+///
+/// `Accumulator`s are stateful objects that live throughout the
+/// evaluation of multiple rows and aggregate multiple values together
+/// into a final aggregate value.
 ///
 /// An accumulator knows how to:
 /// * update its state from inputs via `update_batch`
-/// * retract an update to its state from given inputs via `retract_batch`
+/// * (optionally) retract an update to its state from given inputs via `retract_batch`
 /// * convert its internal state to a vector of aggregate values
 /// * update its state from multiple accumulators' states via `merge_batch`
 /// * compute the final value from its internal state via `evaluate`
@@ -68,6 +71,16 @@ pub trait Accumulator: Send + Sync + Debug {
         ))
     }
 
+    /// Does the accumulator support incrementally updating its value
+    /// by *removing* values?
+    ///
+    /// If this function returns true, [`Self::retract_batch`] will be
+    /// called for sliding window functions, such as queries with an
+    /// `OVER (ROWS BETWEEN 1 PRECEDING AND 2 FOLLOWING)` clause.
+    fn supports_retract_batch(&self) -> bool {
+        false
+    }
+
     /// Updates the accumulator's state from an `Array` containing one
     /// or more intermediate values.
     ///
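
Given the trait methods shown in this diff, an accumulator might opt in to retraction roughly as follows (a sketch under the signatures visible in this era of the codebase; `RetractableSum` is hypothetical, and exact trait signatures may differ in later DataFusion versions):

```rust
use arrow::array::{ArrayRef, AsArray};
use arrow::datatypes::Int64Type;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::Accumulator;

/// Hypothetical retractable sum over an Int64 column (nulls ignored).
#[derive(Debug)]
struct RetractableSum {
    sum: i64,
}

impl Accumulator for RetractableSum {
    fn state(&self) -> Result<Vec<ScalarValue>> {
        Ok(vec![ScalarValue::Int64(Some(self.sum))])
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let arr = values[0].as_primitive::<Int64Type>();
        self.sum += arr.values().iter().sum::<i64>();
        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        // The intermediate state is itself a sum, so merging is updating.
        self.update_batch(states)
    }

    fn evaluate(&self) -> Result<ScalarValue> {
        Ok(ScalarValue::Int64(Some(self.sum)))
    }

    fn size(&self) -> usize {
        std::mem::size_of_val(self)
    }

    // Opting in: the default implementation returns false, in which
    // case `create_sliding_accumulator` now returns a NotImplemented
    // error instead of producing wrong results.
    fn supports_retract_batch(&self) -> bool {
        true
    }

    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let arr = values[0].as_primitive::<Int64Type>();
        self.sum -= arr.values().iter().sum::<i64>();
        Ok(())
    }
}
```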
diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs
index 6c3690e283..1b455a0985 100644
--- a/datafusion/expr/src/udaf.rs
+++ b/datafusion/expr/src/udaf.rs
@@ -24,13 +24,20 @@ use crate::{
 use std::fmt::{self, Debug, Formatter};
 use std::sync::Arc;
 
-/// Logical representation of a user-defined aggregate function (UDAF)
-/// A UDAF is different from a UDF in that it is stateful across batches.
+/// Logical representation of a user-defined aggregate function (UDAF).
+///
+/// A UDAF is different from a user-defined scalar function (UDF) in
+/// that it is stateful across batches. UDAFs can be used as normal
+/// aggregate functions as well as window functions (the `OVER` clause).
+///
+/// For more information, please see [the examples]
+///
+/// [the examples]: https://github.com/apache/arrow-datafusion/tree/main/datafusion-examples#single-process
 #[derive(Clone)]
 pub struct AggregateUDF {
     /// name
     pub name: String,
-    /// signature
+    /// Signature (input arguments)
     pub signature: Signature,
     /// Return type
     pub return_type: ReturnTypeFunction,