You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/01/08 09:51:18 UTC
[arrow-datafusion] branch master updated: Update variance/stddev to work with single values (#4847)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 3d75bb843 Update variance/stddev to work with single values (#4847)
3d75bb843 is described below
commit 3d75bb84375c83f5fd4bd4d0182c4b0be9d71d4e
Author: Jon Mease <jo...@gmail.com>
AuthorDate: Sun Jan 8 04:51:12 2023 -0500
Update variance/stddev to work with single values (#4847)
* Update variance/stddev to work with single values
Following Postgres:
- var/stddev (sample statistics) of a single element is NULL
- var_pop/stddev_pop (population statistics) of a single element is 0
* Fix tests
* Replace matches! with if let
* fix test_stddev_1_input test
---
.../tests/sqllogictests/test_files/aggregate.slt | 12 +++++++++
datafusion/physical-expr/src/aggregate/stddev.rs | 9 +++----
datafusion/physical-expr/src/aggregate/variance.rs | 30 +++++++++++-----------
3 files changed, 31 insertions(+), 20 deletions(-)
diff --git a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
index 7a1b012b8..611fd75ef 100644
--- a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
@@ -1091,3 +1091,15 @@ query U
SELECT ARRAY_AGG([1]);
----
[[1]]
+
+# variance_single_value
+query RRRR
+select var(sq.column1), var_pop(sq.column1), stddev(sq.column1), stddev_pop(sq.column1) from (values (1.0)) as sq;
+----
+NULL 0 NULL 0
+
+# variance_two_values
+query RRRR
+select var(sq.column1), var_pop(sq.column1), stddev(sq.column1), stddev_pop(sq.column1) from (values (1.0), (3.0)) as sq;
+----
+2 1 1.4142135623730951 1
diff --git a/datafusion/physical-expr/src/aggregate/stddev.rs b/datafusion/physical-expr/src/aggregate/stddev.rs
index 05dc56cff..4c9e46644 100644
--- a/datafusion/physical-expr/src/aggregate/stddev.rs
+++ b/datafusion/physical-expr/src/aggregate/stddev.rs
@@ -307,8 +307,8 @@ mod tests {
"bla".to_string(),
DataType::Float64,
));
- let actual = aggregate(&batch, agg);
- assert!(actual.is_err());
+ let actual = aggregate(&batch, agg).unwrap();
+ assert_eq!(actual, ScalarValue::Float64(None));
Ok(())
}
@@ -341,9 +341,8 @@ mod tests {
"bla".to_string(),
DataType::Float64,
));
- let actual = aggregate(&batch, agg);
- assert!(actual.is_err());
-
+ let actual = aggregate(&batch, agg).unwrap();
+ assert_eq!(actual, ScalarValue::Float64(None));
Ok(())
}
diff --git a/datafusion/physical-expr/src/aggregate/variance.rs b/datafusion/physical-expr/src/aggregate/variance.rs
index d1ccea7e1..289513744 100644
--- a/datafusion/physical-expr/src/aggregate/variance.rs
+++ b/datafusion/physical-expr/src/aggregate/variance.rs
@@ -286,17 +286,17 @@ impl Accumulator for VarianceAccumulator {
}
};
- if count <= 1 {
- return Err(DataFusionError::Internal(
- "At least two values are needed to calculate variance".to_string(),
- ));
- }
-
- if self.count == 0 {
- Ok(ScalarValue::Float64(None))
- } else {
- Ok(ScalarValue::Float64(Some(self.m2 / count as f64)))
- }
+ Ok(ScalarValue::Float64(match self.count {
+ 0 => None,
+ 1 => {
+ if let StatsType::Population = self.stats_type {
+ Some(0.0)
+ } else {
+ None
+ }
+ }
+ _ => Some(self.m2 / count as f64),
+ }))
}
fn size(&self) -> usize {
@@ -382,8 +382,8 @@ mod tests {
"bla".to_string(),
DataType::Float64,
));
- let actual = aggregate(&batch, agg);
- assert!(actual.is_err());
+ let actual = aggregate(&batch, agg).unwrap();
+ assert_eq!(actual, ScalarValue::Float64(None));
Ok(())
}
@@ -416,8 +416,8 @@ mod tests {
"bla".to_string(),
DataType::Float64,
));
- let actual = aggregate(&batch, agg);
- assert!(actual.is_err());
+ let actual = aggregate(&batch, agg).unwrap();
+ assert_eq!(actual, ScalarValue::Float64(None));
Ok(())
}