You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/01/10 06:24:37 UTC
[arrow-datafusion] branch master updated: Implement retract_batch for AvgAccumulator (#4846)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 13fb42efe Implement retract_batch for AvgAccumulator (#4846)
13fb42efe is described below
commit 13fb42efec4b5ab7f9aa251f1705fdcf89057d23
Author: Jon Mease <jo...@gmail.com>
AuthorDate: Tue Jan 10 01:24:31 2023 -0500
Implement retract_batch for AvgAccumulator (#4846)
* Implement retract_batch for AvgAccumulator,
Add avg to custom window frame tests
* fmt
---
datafusion/core/tests/sql/window.rs | 38 ++++++++++++-----------
datafusion/physical-expr/src/aggregate/average.rs | 13 ++++++++
2 files changed, 33 insertions(+), 18 deletions(-)
diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs
index 5ca49cff2..0c3ecfa59 100644
--- a/datafusion/core/tests/sql/window.rs
+++ b/datafusion/core/tests/sql/window.rs
@@ -503,21 +503,22 @@ async fn window_frame_rows_preceding() -> Result<()> {
register_aggregate_csv(&ctx).await?;
let sql = "SELECT \
SUM(c4) OVER(ORDER BY c4 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\
+ AVG(c4) OVER(ORDER BY c4 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\
COUNT(*) OVER(ORDER BY c4 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING)\
FROM aggregate_test_100 \
ORDER BY c9 \
LIMIT 5";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
- "+----------------------------+-----------------+",
- "| SUM(aggregate_test_100.c4) | COUNT(UInt8(1)) |",
- "+----------------------------+-----------------+",
- "| -48302 | 3 |",
- "| 11243 | 3 |",
- "| -51311 | 3 |",
- "| -2391 | 3 |",
- "| 46756 | 3 |",
- "+----------------------------+-----------------+",
+ "+----------------------------+----------------------------+-----------------+",
+ "| SUM(aggregate_test_100.c4) | AVG(aggregate_test_100.c4) | COUNT(UInt8(1)) |",
+ "+----------------------------+----------------------------+-----------------+",
+ "| -48302 | -16100.666666666666 | 3 |",
+ "| 11243 | 3747.6666666666665 | 3 |",
+ "| -51311 | -17103.666666666668 | 3 |",
+ "| -2391 | -797 | 3 |",
+ "| 46756 | 15585.333333333334 | 3 |",
+ "+----------------------------+----------------------------+-----------------+",
];
assert_batches_eq!(expected, &actual);
Ok(())
@@ -529,21 +530,22 @@ async fn window_frame_rows_preceding_with_partition_unique_order_by() -> Result<
register_aggregate_csv(&ctx).await?;
let sql = "SELECT \
SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\
+ AVG(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\
COUNT(*) OVER(PARTITION BY c2 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING)\
FROM aggregate_test_100 \
ORDER BY c9 \
LIMIT 5";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
- "+----------------------------+-----------------+",
- "| SUM(aggregate_test_100.c4) | COUNT(UInt8(1)) |",
- "+----------------------------+-----------------+",
- "| -38611 | 2 |",
- "| 17547 | 2 |",
- "| -1301 | 2 |",
- "| 26638 | 3 |",
- "| 26861 | 3 |",
- "+----------------------------+-----------------+",
+ "+----------------------------+----------------------------+-----------------+",
+ "| SUM(aggregate_test_100.c4) | AVG(aggregate_test_100.c4) | COUNT(UInt8(1)) |",
+ "+----------------------------+----------------------------+-----------------+",
+ "| -38611 | -19305.5 | 2 |",
+ "| 17547 | 8773.5 | 2 |",
+ "| -1301 | -650.5 | 2 |",
+ "| 26638 | 13319 | 3 |",
+ "| 26861 | 8953.666666666666 | 3 |",
+ "+----------------------------+----------------------------+-----------------+",
];
assert_batches_eq!(expected, &actual);
Ok(())
diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs
index 12f84ca1f..216bd56af 100644
--- a/datafusion/physical-expr/src/aggregate/average.rs
+++ b/datafusion/physical-expr/src/aggregate/average.rs
@@ -25,6 +25,7 @@ use crate::aggregate::row_accumulator::{
is_row_accumulator_support_dtype, RowAccumulator,
};
use crate::aggregate::sum;
+use crate::aggregate::sum::sum_batch;
use crate::expressions::format_state_name;
use crate::{AggregateExpr, PhysicalExpr};
use arrow::compute;
@@ -119,6 +120,10 @@ impl AggregateExpr for Avg {
self.data_type.clone(),
)))
}
+
+ fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+ Ok(Box::new(AvgAccumulator::try_new(&self.data_type)?))
+ }
}
/// An accumulator to compute the average
@@ -154,6 +159,14 @@ impl Accumulator for AvgAccumulator {
Ok(())
}
+ fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ let values = &values[0];
+ self.count -= (values.len() - values.data().null_count()) as u64;
+ let delta = sum_batch(values, &self.sum.get_datatype())?;
+ self.sum = self.sum.sub(&delta)?;
+ Ok(())
+ }
+
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
let counts = downcast_value!(states[0], UInt64Array);
// counts are summed