You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this plain-text rendering.)
Posted to commits@arrow.apache.org by al...@apache.org on 2023/01/10 06:37:46 UTC
[arrow-datafusion] branch master updated: Support using var/var_pop/stddev/stddev_pop in window expressions with custom frames (#4848)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 292eb954f Support using var/var_pop/stddev/stddev_pop in window expressions with custom frames (#4848)
292eb954f is described below
commit 292eb954fc0bad3a1febc597233ba26cb60bda3e
Author: Jon Mease <jo...@gmail.com>
AuthorDate: Tue Jan 10 01:37:41 2023 -0500
Support using var/var_pop/stddev/stddev_pop in window expressions with custom frames (#4848)
* Wire up retract_batch for Stddev/StddevPop/Variance/VariancePop so they can be used as sliding accumulators (create_sliding_accumulator) in window frames
* Add a test for Stddev/StddevPop/Variance/VariancePop with a custom window frame
---
datafusion/core/tests/sql/window.rs | 28 ++++++++++++++++++++++
datafusion/physical-expr/src/aggregate/stddev.rs | 12 ++++++++++
datafusion/physical-expr/src/aggregate/variance.rs | 10 ++++++++
3 files changed, 50 insertions(+)
diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs
index 0c3ecfa59..1167d57a4 100644
--- a/datafusion/core/tests/sql/window.rs
+++ b/datafusion/core/tests/sql/window.rs
@@ -524,6 +524,34 @@ async fn window_frame_rows_preceding() -> Result<()> {
Ok(())
}
+#[tokio::test]
+async fn window_frame_rows_preceding_stddev_variance() -> Result<()> {
+ let ctx = SessionContext::new();
+ register_aggregate_csv(&ctx).await?;
+ let sql = "SELECT \
+ VAR(c4) OVER(ORDER BY c4 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\
+ VAR_POP(c4) OVER(ORDER BY c4 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\
+ STDDEV(c4) OVER(ORDER BY c4 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\
+ STDDEV_POP(c4) OVER(ORDER BY c4 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING)\
+ FROM aggregate_test_100 \
+ ORDER BY c9 \
+ LIMIT 5";
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+---------------------------------+------------------------------------+-------------------------------+----------------------------------+",
+ "| VARIANCE(aggregate_test_100.c4) | VARIANCEPOP(aggregate_test_100.c4) | STDDEV(aggregate_test_100.c4) | STDDEVPOP(aggregate_test_100.c4) |",
+ "+---------------------------------+------------------------------------+-------------------------------+----------------------------------+",
+ "| 46721.33333333174 | 31147.555555554496 | 216.15118166073427 | 176.4867007894773 |",
+ "| 2639429.333333332 | 1759619.5555555548 | 1624.6320609089714 | 1326.5065229977404 |",
+ "| 746202.3333333324 | 497468.2222222216 | 863.8300372951455 | 705.3142719541563 |",
+ "| 768422.9999999981 | 512281.9999999988 | 876.5973990378925 | 715.7387791645767 |",
+ "| 66526.3333333288 | 44350.88888888587 | 257.9269922542594 | 210.5965073045749 |",
+ "+---------------------------------+------------------------------------+-------------------------------+----------------------------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+}
+
#[tokio::test]
async fn window_frame_rows_preceding_with_partition_unique_order_by() -> Result<()> {
let ctx = SessionContext::new();
diff --git a/datafusion/physical-expr/src/aggregate/stddev.rs b/datafusion/physical-expr/src/aggregate/stddev.rs
index 4c9e46644..dab84b14a 100644
--- a/datafusion/physical-expr/src/aggregate/stddev.rs
+++ b/datafusion/physical-expr/src/aggregate/stddev.rs
@@ -73,6 +73,10 @@ impl AggregateExpr for Stddev {
Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?))
}
+ fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+ Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?))
+ }
+
fn state_fields(&self) -> Result<Vec<Field>> {
Ok(vec![
Field::new(
@@ -128,6 +132,10 @@ impl AggregateExpr for StddevPop {
Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?))
}
+ fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+ Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?))
+ }
+
fn state_fields(&self) -> Result<Vec<Field>> {
Ok(vec![
Field::new(
@@ -184,6 +192,10 @@ impl Accumulator for StddevAccumulator {
self.variance.update_batch(values)
}
+ fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ self.variance.retract_batch(values)
+ }
+
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
self.variance.merge_batch(states)
}
diff --git a/datafusion/physical-expr/src/aggregate/variance.rs b/datafusion/physical-expr/src/aggregate/variance.rs
index 289513744..657103e43 100644
--- a/datafusion/physical-expr/src/aggregate/variance.rs
+++ b/datafusion/physical-expr/src/aggregate/variance.rs
@@ -79,6 +79,10 @@ impl AggregateExpr for Variance {
Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?))
}
+ fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+ Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?))
+ }
+
fn state_fields(&self) -> Result<Vec<Field>> {
Ok(vec![
Field::new(
@@ -136,6 +140,12 @@ impl AggregateExpr for VariancePop {
)?))
}
+ fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+ Ok(Box::new(VarianceAccumulator::try_new(
+ StatsType::Population,
+ )?))
+ }
+
fn state_fields(&self) -> Result<Vec<Field>> {
Ok(vec![
Field::new(