You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/01/10 06:24:37 UTC

[arrow-datafusion] branch master updated: Implement retract_batch for AvgAccumulator (#4846)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 13fb42efe Implement retract_batch for AvgAccumulator (#4846)
13fb42efe is described below

commit 13fb42efec4b5ab7f9aa251f1705fdcf89057d23
Author: Jon Mease <jo...@gmail.com>
AuthorDate: Tue Jan 10 01:24:31 2023 -0500

    Implement retract_batch for AvgAccumulator (#4846)
    
    * Implement retract_batch for AvgAccumulator
    
    Add avg to custom window frame tests
    
    * fmt
---
 datafusion/core/tests/sql/window.rs               | 38 ++++++++++++-----------
 datafusion/physical-expr/src/aggregate/average.rs | 13 ++++++++
 2 files changed, 33 insertions(+), 18 deletions(-)

diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs
index 5ca49cff2..0c3ecfa59 100644
--- a/datafusion/core/tests/sql/window.rs
+++ b/datafusion/core/tests/sql/window.rs
@@ -503,21 +503,22 @@ async fn window_frame_rows_preceding() -> Result<()> {
     register_aggregate_csv(&ctx).await?;
     let sql = "SELECT \
                SUM(c4) OVER(ORDER BY c4 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\
+               AVG(c4) OVER(ORDER BY c4 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\
                COUNT(*) OVER(ORDER BY c4 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING)\
                FROM aggregate_test_100 \
                ORDER BY c9 \
                LIMIT 5";
     let actual = execute_to_batches(&ctx, sql).await;
     let expected = vec![
-        "+----------------------------+-----------------+",
-        "| SUM(aggregate_test_100.c4) | COUNT(UInt8(1)) |",
-        "+----------------------------+-----------------+",
-        "| -48302                     | 3               |",
-        "| 11243                      | 3               |",
-        "| -51311                     | 3               |",
-        "| -2391                      | 3               |",
-        "| 46756                      | 3               |",
-        "+----------------------------+-----------------+",
+        "+----------------------------+----------------------------+-----------------+",
+        "| SUM(aggregate_test_100.c4) | AVG(aggregate_test_100.c4) | COUNT(UInt8(1)) |",
+        "+----------------------------+----------------------------+-----------------+",
+        "| -48302                     | -16100.666666666666        | 3               |",
+        "| 11243                      | 3747.6666666666665         | 3               |",
+        "| -51311                     | -17103.666666666668        | 3               |",
+        "| -2391                      | -797                       | 3               |",
+        "| 46756                      | 15585.333333333334         | 3               |",
+        "+----------------------------+----------------------------+-----------------+",
     ];
     assert_batches_eq!(expected, &actual);
     Ok(())
@@ -529,21 +530,22 @@ async fn window_frame_rows_preceding_with_partition_unique_order_by() -> Result<
     register_aggregate_csv(&ctx).await?;
     let sql = "SELECT \
                SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\
+               AVG(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\
                COUNT(*) OVER(PARTITION BY c2 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING)\
                FROM aggregate_test_100 \
                ORDER BY c9 \
                LIMIT 5";
     let actual = execute_to_batches(&ctx, sql).await;
     let expected = vec![
-        "+----------------------------+-----------------+",
-        "| SUM(aggregate_test_100.c4) | COUNT(UInt8(1)) |",
-        "+----------------------------+-----------------+",
-        "| -38611                     | 2               |",
-        "| 17547                      | 2               |",
-        "| -1301                      | 2               |",
-        "| 26638                      | 3               |",
-        "| 26861                      | 3               |",
-        "+----------------------------+-----------------+",
+        "+----------------------------+----------------------------+-----------------+",
+        "| SUM(aggregate_test_100.c4) | AVG(aggregate_test_100.c4) | COUNT(UInt8(1)) |",
+        "+----------------------------+----------------------------+-----------------+",
+        "| -38611                     | -19305.5                   | 2               |",
+        "| 17547                      | 8773.5                     | 2               |",
+        "| -1301                      | -650.5                     | 2               |",
+        "| 26638                      | 13319                      | 3               |",
+        "| 26861                      | 8953.666666666666          | 3               |",
+        "+----------------------------+----------------------------+-----------------+",
     ];
     assert_batches_eq!(expected, &actual);
     Ok(())
diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs
index 12f84ca1f..216bd56af 100644
--- a/datafusion/physical-expr/src/aggregate/average.rs
+++ b/datafusion/physical-expr/src/aggregate/average.rs
@@ -25,6 +25,7 @@ use crate::aggregate::row_accumulator::{
     is_row_accumulator_support_dtype, RowAccumulator,
 };
 use crate::aggregate::sum;
+use crate::aggregate::sum::sum_batch;
 use crate::expressions::format_state_name;
 use crate::{AggregateExpr, PhysicalExpr};
 use arrow::compute;
@@ -119,6 +120,10 @@ impl AggregateExpr for Avg {
             self.data_type.clone(),
         )))
     }
+
+    fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+        Ok(Box::new(AvgAccumulator::try_new(&self.data_type)?))
+    }
 }
 
 /// An accumulator to compute the average
@@ -154,6 +159,14 @@ impl Accumulator for AvgAccumulator {
         Ok(())
     }
 
+    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let values = &values[0];
+        self.count -= (values.len() - values.data().null_count()) as u64;
+        let delta = sum_batch(values, &self.sum.get_datatype())?;
+        self.sum = self.sum.sub(&delta)?;
+        Ok(())
+    }
+
     fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
         let counts = downcast_value!(states[0], UInt64Array);
         // counts are summed