You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/01/10 06:37:46 UTC

[arrow-datafusion] branch master updated: Support using var/var_pop/stddev/stddev_pop in window expressions with custom frames (#4848)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 292eb954f Support using var/var_pop/stddev/stddev_pop in window expressions with custom frames (#4848)
292eb954f is described below

commit 292eb954fc0bad3a1febc597233ba26cb60bda3e
Author: Jon Mease <jo...@gmail.com>
AuthorDate: Tue Jan 10 01:37:41 2023 -0500

    Support using var/var_pop/stddev/stddev_pop in window expressions with custom frames (#4848)
    
    * Wire up retract_batch for Stddev/StddevPop/Variance/VariancePop to support custom window frames
    
    * Add test for Stddev/StddevPop/Variance/VariancePop with window frame
---
 datafusion/core/tests/sql/window.rs                | 28 ++++++++++++++++++++++
 datafusion/physical-expr/src/aggregate/stddev.rs   | 12 ++++++++++
 datafusion/physical-expr/src/aggregate/variance.rs | 10 ++++++++
 3 files changed, 50 insertions(+)

diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs
index 0c3ecfa59..1167d57a4 100644
--- a/datafusion/core/tests/sql/window.rs
+++ b/datafusion/core/tests/sql/window.rs
@@ -524,6 +524,34 @@ async fn window_frame_rows_preceding() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn window_frame_rows_preceding_stddev_variance() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_aggregate_csv(&ctx).await?;
+    let sql = "SELECT \
+               VAR(c4) OVER(ORDER BY c4 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\
+               VAR_POP(c4) OVER(ORDER BY c4 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\
+               STDDEV(c4) OVER(ORDER BY c4 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\
+               STDDEV_POP(c4) OVER(ORDER BY c4 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING)\
+               FROM aggregate_test_100 \
+               ORDER BY c9 \
+               LIMIT 5";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+---------------------------------+------------------------------------+-------------------------------+----------------------------------+",
+        "| VARIANCE(aggregate_test_100.c4) | VARIANCEPOP(aggregate_test_100.c4) | STDDEV(aggregate_test_100.c4) | STDDEVPOP(aggregate_test_100.c4) |",
+        "+---------------------------------+------------------------------------+-------------------------------+----------------------------------+",
+        "| 46721.33333333174               | 31147.555555554496                 | 216.15118166073427            | 176.4867007894773                |",
+        "| 2639429.333333332               | 1759619.5555555548                 | 1624.6320609089714            | 1326.5065229977404               |",
+        "| 746202.3333333324               | 497468.2222222216                  | 863.8300372951455             | 705.3142719541563                |",
+        "| 768422.9999999981               | 512281.9999999988                  | 876.5973990378925             | 715.7387791645767                |",
+        "| 66526.3333333288                | 44350.88888888587                  | 257.9269922542594             | 210.5965073045749                |",
+        "+---------------------------------+------------------------------------+-------------------------------+----------------------------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}
+
 #[tokio::test]
 async fn window_frame_rows_preceding_with_partition_unique_order_by() -> Result<()> {
     let ctx = SessionContext::new();
diff --git a/datafusion/physical-expr/src/aggregate/stddev.rs b/datafusion/physical-expr/src/aggregate/stddev.rs
index 4c9e46644..dab84b14a 100644
--- a/datafusion/physical-expr/src/aggregate/stddev.rs
+++ b/datafusion/physical-expr/src/aggregate/stddev.rs
@@ -73,6 +73,10 @@ impl AggregateExpr for Stddev {
         Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?))
     }
 
+    fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+        Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?))
+    }
+
     fn state_fields(&self) -> Result<Vec<Field>> {
         Ok(vec![
             Field::new(
@@ -128,6 +132,10 @@ impl AggregateExpr for StddevPop {
         Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?))
     }
 
+    fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+        Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?))
+    }
+
     fn state_fields(&self) -> Result<Vec<Field>> {
         Ok(vec![
             Field::new(
@@ -184,6 +192,10 @@ impl Accumulator for StddevAccumulator {
         self.variance.update_batch(values)
     }
 
+    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        self.variance.retract_batch(values)
+    }
+
     fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
         self.variance.merge_batch(states)
     }
diff --git a/datafusion/physical-expr/src/aggregate/variance.rs b/datafusion/physical-expr/src/aggregate/variance.rs
index 289513744..657103e43 100644
--- a/datafusion/physical-expr/src/aggregate/variance.rs
+++ b/datafusion/physical-expr/src/aggregate/variance.rs
@@ -79,6 +79,10 @@ impl AggregateExpr for Variance {
         Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?))
     }
 
+    fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+        Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?))
+    }
+
     fn state_fields(&self) -> Result<Vec<Field>> {
         Ok(vec![
             Field::new(
@@ -136,6 +140,12 @@ impl AggregateExpr for VariancePop {
         )?))
     }
 
+    fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+        Ok(Box::new(VarianceAccumulator::try_new(
+            StatsType::Population,
+        )?))
+    }
+
     fn state_fields(&self) -> Result<Vec<Field>> {
         Ok(vec![
             Field::new(