You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/01/08 09:51:18 UTC

[arrow-datafusion] branch master updated: Update variance/stddev to work with single values (#4847)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 3d75bb843 Update variance/stddev to work with single values (#4847)
3d75bb843 is described below

commit 3d75bb84375c83f5fd4bd4d0182c4b0be9d71d4e
Author: Jon Mease <jo...@gmail.com>
AuthorDate: Sun Jan 8 04:51:12 2023 -0500

    Update variance/stddev to work with single values (#4847)
    
    * Update variance/stddev to work with single values
    
    Following PostgreSQL semantics:
     - var/stddev (sample variance/standard deviation) of a single element is NULL
     - var_pop/stddev_pop (population variance/standard deviation) of a single element is 0
    
    * Fix tests
    
    * Replace matches! with if let
    
    * fix test_stddev_1_input test
---
 .../tests/sqllogictests/test_files/aggregate.slt   | 12 +++++++++
 datafusion/physical-expr/src/aggregate/stddev.rs   |  9 +++----
 datafusion/physical-expr/src/aggregate/variance.rs | 30 +++++++++++-----------
 3 files changed, 31 insertions(+), 20 deletions(-)

diff --git a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
index 7a1b012b8..611fd75ef 100644
--- a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
@@ -1091,3 +1091,15 @@ query U
 SELECT ARRAY_AGG([1]);
 ----
 [[1]]
+
+# variance_single_value
+query RRRR
+select var(sq.column1), var_pop(sq.column1), stddev(sq.column1), stddev_pop(sq.column1) from (values (1.0)) as sq;
+----
+NULL 0 NULL 0
+
+# variance_two_values
+query RRRR
+select var(sq.column1), var_pop(sq.column1), stddev(sq.column1), stddev_pop(sq.column1) from (values (1.0), (3.0)) as sq;
+----
+2 1 1.4142135623730951 1
diff --git a/datafusion/physical-expr/src/aggregate/stddev.rs b/datafusion/physical-expr/src/aggregate/stddev.rs
index 05dc56cff..4c9e46644 100644
--- a/datafusion/physical-expr/src/aggregate/stddev.rs
+++ b/datafusion/physical-expr/src/aggregate/stddev.rs
@@ -307,8 +307,8 @@ mod tests {
             "bla".to_string(),
             DataType::Float64,
         ));
-        let actual = aggregate(&batch, agg);
-        assert!(actual.is_err());
+        let actual = aggregate(&batch, agg).unwrap();
+        assert_eq!(actual, ScalarValue::Float64(None));
 
         Ok(())
     }
@@ -341,9 +341,8 @@ mod tests {
             "bla".to_string(),
             DataType::Float64,
         ));
-        let actual = aggregate(&batch, agg);
-        assert!(actual.is_err());
-
+        let actual = aggregate(&batch, agg).unwrap();
+        assert_eq!(actual, ScalarValue::Float64(None));
         Ok(())
     }
 
diff --git a/datafusion/physical-expr/src/aggregate/variance.rs b/datafusion/physical-expr/src/aggregate/variance.rs
index d1ccea7e1..289513744 100644
--- a/datafusion/physical-expr/src/aggregate/variance.rs
+++ b/datafusion/physical-expr/src/aggregate/variance.rs
@@ -286,17 +286,17 @@ impl Accumulator for VarianceAccumulator {
             }
         };
 
-        if count <= 1 {
-            return Err(DataFusionError::Internal(
-                "At least two values are needed to calculate variance".to_string(),
-            ));
-        }
-
-        if self.count == 0 {
-            Ok(ScalarValue::Float64(None))
-        } else {
-            Ok(ScalarValue::Float64(Some(self.m2 / count as f64)))
-        }
+        Ok(ScalarValue::Float64(match self.count {
+            0 => None,
+            1 => {
+                if let StatsType::Population = self.stats_type {
+                    Some(0.0)
+                } else {
+                    None
+                }
+            }
+            _ => Some(self.m2 / count as f64),
+        }))
     }
 
     fn size(&self) -> usize {
@@ -382,8 +382,8 @@ mod tests {
             "bla".to_string(),
             DataType::Float64,
         ));
-        let actual = aggregate(&batch, agg);
-        assert!(actual.is_err());
+        let actual = aggregate(&batch, agg).unwrap();
+        assert_eq!(actual, ScalarValue::Float64(None));
 
         Ok(())
     }
@@ -416,8 +416,8 @@ mod tests {
             "bla".to_string(),
             DataType::Float64,
         ));
-        let actual = aggregate(&batch, agg);
-        assert!(actual.is_err());
+        let actual = aggregate(&batch, agg).unwrap();
+        assert_eq!(actual, ScalarValue::Float64(None));
 
         Ok(())
     }