You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/12/09 12:04:33 UTC
[arrow-datafusion] branch master updated: Fix panic in median "AggregateState is not a scalar aggregate" (#4488)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 31bbe6c3e Fix panic in median "AggregateState is not a scalar aggregate" (#4488)
31bbe6c3e is described below
commit 31bbe6c3eec245127b005d8afc8766e2d06d811b
Author: Andrew Lamb <an...@nerdnetworks.org>
AuthorDate: Fri Dec 9 07:04:28 2022 -0500
Fix panic in median "AggregateState is not a scalar aggregate" (#4488)
* Fix panic in median "AggregateState is not a scalar aggregate"
* Apply suggestions from code review
Co-authored-by: Raphael Taylor-Davies <17...@users.noreply.github.com>
Co-authored-by: Raphael Taylor-Davies <17...@users.noreply.github.com>
---
datafusion/common/src/scalar.rs | 6 +-
.../tests/sqllogictests/test_files/aggregate.slt | 70 +++++-
datafusion/physical-expr/src/aggregate/median.rs | 238 +++++++++------------
3 files changed, 173 insertions(+), 141 deletions(-)
diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index 449c75c25..b6a49fbf6 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -721,7 +721,7 @@ impl std::hash::Hash for ScalarValue {
/// dictionary array
#[inline]
fn get_dict_value<K: ArrowDictionaryKeyType>(
- array: &ArrayRef,
+ array: &dyn Array,
index: usize,
) -> (&ArrayRef, Option<usize>) {
let dict_array = as_dictionary_array::<K>(array).unwrap();
@@ -1963,7 +1963,7 @@ impl ScalarValue {
}
fn get_decimal_value_from_array(
- array: &ArrayRef,
+ array: &dyn Array,
index: usize,
precision: u8,
scale: i8,
@@ -1978,7 +1978,7 @@ impl ScalarValue {
}
/// Converts a value in `array` at `index` into a ScalarValue
- pub fn try_from_array(array: &ArrayRef, index: usize) -> Result<Self> {
+ pub fn try_from_array(array: &dyn Array, index: usize) -> Result<Self> {
// handle NULL value
if !array.is_valid(index) {
return array.data_type().try_into();
diff --git a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
index c0b747702..c1cefd70e 100644
--- a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
@@ -79,7 +79,7 @@ SELECT stddev_pop(c2) FROM aggregate_test_100
1.3665650368716449
# csv_query_stddev_2
-query R
+query R
SELECT stddev_pop(c6) FROM aggregate_test_100
----
5.114326382039172e18
@@ -216,6 +216,70 @@ SELECT approx_median(a) FROM median_f64_nan
----
NaN
+# median_multi
+# test case for https://github.com/apache/arrow-datafusion/issues/3105
+# has an intermediate grouping
+statement ok
+create table cpu (host string, usage float) as select * from (values
+('host0', 90.1),
+('host1', 90.2),
+('host1', 90.4)
+);
+
+query CI rowsort
+select host, median(usage) from cpu group by host;
+----
+host1 90.3
+host0 90.1
+
+query CI
+select median(usage) from cpu;
+----
+90.2
+
+
+statement ok
+drop table cpu;
+
+# median_multi_odd
+
+# data is not sorted and has an odd number of values per group
+statement ok
+create table cpu (host string, usage float) as select * from (values
+ ('host0', 90.2),
+ ('host1', 90.1),
+ ('host1', 90.5),
+ ('host0', 90.5),
+ ('host1', 90.0),
+ ('host1', 90.3),
+ ('host0', 87.9),
+ ('host1', 89.3)
+);
+
+query CI rowsort
+select host, median(usage) from cpu group by host;
+----
+host0 90.2
+host1 90.1
+
+
+statement ok
+drop table cpu;
+
+# median_multi_even
+# data is not sorted and has an even number of values per group
+statement ok
+create table cpu (host string, usage float) as select * from (values ('host0', 90.2), ('host1', 90.1), ('host1', 90.5), ('host0', 90.5), ('host1', 90.0), ('host1', 90.3), ('host1', 90.2), ('host1', 90.3));
+
+query CI rowsort
+select host, median(usage) from cpu group by host;
+----
+host1 90.25
+host0 90.35
+
+statement ok
+drop table cpu
+
# csv_query_external_table_count
query I
SELECT COUNT(c12) FROM aggregate_test_100
@@ -818,7 +882,7 @@ select c2, sum(c3) sum_c3, avg(c3) avg_c3, max(c3) max_c3, min(c3) min_c3, count
# SELECT array_agg(c13 ORDER BY c1) FROM aggregate_test_100;
# csv_query_array_cube_agg_with_overflow
-query TIIRIII
+query TIIRIII
select c1, c2, sum(c3) sum_c3, avg(c3) avg_c3, max(c3) max_c3, min(c3) min_c3, count(c3) count_c3 from aggregate_test_100 group by CUBE (c1,c2) order by c1, c2
----
a 1 -88 -17.6 83 -85 5
@@ -870,7 +934,7 @@ e 847 40.333333333333336 120 -95 21
# query IIII
# SELECT count(nanos), count(micros), count(millis), count(secs) FROM t
# ----
-# 3 3 3 3
+# 3 3 3 3
# aggregate_timestamps_min
# query TTTT
diff --git a/datafusion/physical-expr/src/aggregate/median.rs b/datafusion/physical-expr/src/aggregate/median.rs
index a04bd5369..abde3702f 100644
--- a/datafusion/physical-expr/src/aggregate/median.rs
+++ b/datafusion/physical-expr/src/aggregate/median.rs
@@ -19,13 +19,9 @@
use crate::expressions::format_state_name;
use crate::{AggregateExpr, PhysicalExpr};
-use arrow::array::{Array, ArrayRef, PrimitiveBuilder};
-use arrow::compute::sort;
-use arrow::datatypes::{
- ArrowPrimitiveType, DataType, Field, Float32Type, Float64Type, Int16Type, Int32Type,
- Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
-};
-use datafusion_common::cast::as_primitive_array;
+use arrow::array::{Array, ArrayRef, UInt32Array};
+use arrow::compute::sort_to_indices;
+use arrow::datatypes::{DataType, Field};
use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_expr::{Accumulator, AggregateState};
use std::any::Any;
@@ -74,9 +70,13 @@ impl AggregateExpr for Median {
}
fn state_fields(&self) -> Result<Vec<Field>> {
+ // Intermediate state is a list of the elements we have collected so far
+ let field = Field::new("item", self.data_type.clone(), true);
+ let data_type = DataType::List(Box::new(field));
+
Ok(vec![Field::new(
&format_state_name(&self.name, "median"),
- self.data_type.clone(),
+ data_type,
true,
)])
}
@@ -91,158 +91,126 @@ impl AggregateExpr for Median {
}
#[derive(Debug)]
+/// The median accumulator accumulates the raw input values
+/// as `ScalarValue`s
+///
+/// The intermediate state is represented as a List of those scalars
struct MedianAccumulator {
data_type: DataType,
- all_values: Vec<ArrayRef>,
-}
-
-macro_rules! median {
- ($SELF:ident, $TY:ty, $SCALAR_TY:ident, $TWO:expr) => {{
- let combined = combine_arrays::<$TY>($SELF.all_values.as_slice())?;
- if combined.is_empty() {
- return Ok(ScalarValue::Null);
- }
- let sorted = sort(&combined, None)?;
- let array = as_primitive_array::<$TY>(&sorted)?;
- let len = sorted.len();
- let mid = len / 2;
- if len % 2 == 0 {
- Ok(ScalarValue::$SCALAR_TY(Some(
- (array.value(mid - 1) + array.value(mid)) / $TWO,
- )))
- } else {
- Ok(ScalarValue::$SCALAR_TY(Some(array.value(mid))))
- }
- }};
+ all_values: Vec<ScalarValue>,
}
impl Accumulator for MedianAccumulator {
fn state(&self) -> Result<Vec<AggregateState>> {
- let mut vec: Vec<AggregateState> = self
- .all_values
- .iter()
- .map(|v| AggregateState::Array(v.clone()))
- .collect();
- if vec.is_empty() {
- match self.data_type {
- DataType::UInt8 => vec.push(empty_array::<UInt8Type>()),
- DataType::UInt16 => vec.push(empty_array::<UInt16Type>()),
- DataType::UInt32 => vec.push(empty_array::<UInt32Type>()),
- DataType::UInt64 => vec.push(empty_array::<UInt64Type>()),
- DataType::Int8 => vec.push(empty_array::<Int8Type>()),
- DataType::Int16 => vec.push(empty_array::<Int16Type>()),
- DataType::Int32 => vec.push(empty_array::<Int32Type>()),
- DataType::Int64 => vec.push(empty_array::<Int64Type>()),
- DataType::Float32 => vec.push(empty_array::<Float32Type>()),
- DataType::Float64 => vec.push(empty_array::<Float64Type>()),
- _ => {
- return Err(DataFusionError::Execution(
- "unsupported data type for median".to_string(),
- ))
- }
- }
- }
- Ok(vec)
+ let state =
+ ScalarValue::new_list(Some(self.all_values.clone()), self.data_type.clone());
+ Ok(vec![AggregateState::Scalar(state)])
}
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
- let x = values[0].clone();
- self.all_values.extend_from_slice(&[x]);
+ assert_eq!(values.len(), 1);
+ let array = &values[0];
+
+ assert_eq!(array.data_type(), &self.data_type);
+ self.all_values.reserve(self.all_values.len() + array.len());
+ for index in 0..array.len() {
+ self.all_values
+ .push(ScalarValue::try_from_array(array, index)?);
+ }
+
Ok(())
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
- for array in states {
- self.all_values.extend_from_slice(&[array.clone()]);
+ assert_eq!(states.len(), 1);
+
+ let array = &states[0];
+ assert!(matches!(array.data_type(), DataType::List(_)));
+ for index in 0..array.len() {
+ match ScalarValue::try_from_array(array, index)? {
+ ScalarValue::List(Some(mut values), _) => {
+ self.all_values.append(&mut values);
+ }
+ ScalarValue::List(None, _) => {} // skip empty state
+ v => {
+ return Err(DataFusionError::Internal(format!(
+ "unexpected state in median. Expected DataType::List, got {:?}",
+ v
+ )))
+ }
+ }
}
Ok(())
}
fn evaluate(&self) -> Result<ScalarValue> {
- match self.all_values[0].data_type() {
- DataType::Int8 => median!(self, arrow::datatypes::Int8Type, Int8, 2),
- DataType::Int16 => median!(self, arrow::datatypes::Int16Type, Int16, 2),
- DataType::Int32 => median!(self, arrow::datatypes::Int32Type, Int32, 2),
- DataType::Int64 => median!(self, arrow::datatypes::Int64Type, Int64, 2),
- DataType::UInt8 => median!(self, arrow::datatypes::UInt8Type, UInt8, 2),
- DataType::UInt16 => median!(self, arrow::datatypes::UInt16Type, UInt16, 2),
- DataType::UInt32 => median!(self, arrow::datatypes::UInt32Type, UInt32, 2),
- DataType::UInt64 => median!(self, arrow::datatypes::UInt64Type, UInt64, 2),
- DataType::Float32 => {
- median!(self, arrow::datatypes::Float32Type, Float32, 2_f32)
- }
- DataType::Float64 => {
- median!(self, arrow::datatypes::Float64Type, Float64, 2_f64)
+ // Create an array of all the non null values and find the
+ // sorted indexes
+ let array = ScalarValue::iter_to_array(
+ self.all_values
+ .iter()
+ // ignore null values
+ .filter(|v| !v.is_null())
+ .cloned(),
+ )?;
+
+ // find the mid point
+ let len = array.len();
+ let mid = len / 2;
+
+ // only sort the smallest `mid + 1` elements (enough to reach the midpoint)
+ let limit = Some(mid + 1);
+ let options = None;
+ let indices = sort_to_indices(&array, options, limit)?;
+
+ // look up the median value(s) in `array` via the sorted indices
+ let result = if len >= 2 && len % 2 == 0 {
+ // even number of values, average the two mid points
+ let s1 = scalar_at_index(&array, &indices, mid - 1)?;
+ let s2 = scalar_at_index(&array, &indices, mid)?;
+ match s1.add(s2)? {
+ ScalarValue::Int8(Some(v)) => ScalarValue::Int8(Some(v / 2)),
+ ScalarValue::Int16(Some(v)) => ScalarValue::Int16(Some(v / 2)),
+ ScalarValue::Int32(Some(v)) => ScalarValue::Int32(Some(v / 2)),
+ ScalarValue::Int64(Some(v)) => ScalarValue::Int64(Some(v / 2)),
+ ScalarValue::UInt8(Some(v)) => ScalarValue::UInt8(Some(v / 2)),
+ ScalarValue::UInt16(Some(v)) => ScalarValue::UInt16(Some(v / 2)),
+ ScalarValue::UInt32(Some(v)) => ScalarValue::UInt32(Some(v / 2)),
+ ScalarValue::UInt64(Some(v)) => ScalarValue::UInt64(Some(v / 2)),
+ ScalarValue::Float32(Some(v)) => ScalarValue::Float32(Some(v / 2.0)),
+ ScalarValue::Float64(Some(v)) => ScalarValue::Float64(Some(v / 2.0)),
+ v => {
+ return Err(DataFusionError::Internal(format!(
+ "Unsupported type in MedianAccumulator: {:?}",
+ v
+ )))
+ }
}
- _ => Err(DataFusionError::Execution(
- "unsupported data type for median".to_string(),
- )),
- }
+ } else {
+ // odd number of values, pick that one
+ scalar_at_index(&array, &indices, mid)?
+ };
+
+ Ok(result)
}
fn size(&self) -> usize {
- std::mem::align_of_val(self)
- + (std::mem::size_of::<ArrayRef>() * self.all_values.capacity())
- + self
- .all_values
- .iter()
- .map(|array_ref| {
- std::mem::size_of_val(array_ref.as_ref())
- + array_ref.get_array_memory_size()
- })
- .sum::<usize>()
+ std::mem::size_of_val(self) + ScalarValue::size_of_vec(&self.all_values)
+ - std::mem::size_of_val(&self.all_values)
+ self.data_type.size()
- std::mem::size_of_val(&self.data_type)
}
}
-/// Create an empty array
-fn empty_array<T: ArrowPrimitiveType>() -> AggregateState {
- AggregateState::Array(Arc::new(PrimitiveBuilder::<T>::with_capacity(0).finish()))
-}
-
-/// Combine all non-null values from provided arrays into a single array
-fn combine_arrays<T: ArrowPrimitiveType>(arrays: &[ArrayRef]) -> Result<ArrayRef> {
- let len = arrays.iter().map(|a| a.len() - a.null_count()).sum();
- let mut builder: PrimitiveBuilder<T> = PrimitiveBuilder::with_capacity(len);
- for array in arrays {
- let array = as_primitive_array::<T>(array)?;
- for i in 0..array.len() {
- if !array.is_null(i) {
- builder.append_value(array.value(i));
- }
- }
- }
- Ok(Arc::new(builder.finish()))
-}
-
-#[cfg(test)]
-mod test {
- use crate::aggregate::median::combine_arrays;
- use arrow::array::{Int32Array, UInt32Array};
- use arrow::datatypes::{Int32Type, UInt32Type};
- use datafusion_common::Result;
- use std::sync::Arc;
-
- #[test]
- fn combine_i32_array() -> Result<()> {
- let a = Arc::new(Int32Array::from(vec![1, 2, 3]));
- let b = combine_arrays::<Int32Type>(&[a.clone(), a])?;
- assert_eq!(
- "PrimitiveArray<Int32>\n[\n 1,\n 2,\n 3,\n 1,\n 2,\n 3,\n]",
- format!("{:?}", b)
- );
- Ok(())
- }
-
- #[test]
- fn combine_u32_array() -> Result<()> {
- let a = Arc::new(UInt32Array::from(vec![1, 2, 3]));
- let b = combine_arrays::<UInt32Type>(&[a.clone(), a])?;
- assert_eq!(
- "PrimitiveArray<UInt32>\n[\n 1,\n 2,\n 3,\n 1,\n 2,\n 3,\n]",
- format!("{:?}", b)
- );
- Ok(())
- }
+/// Returns `array[indices[indicies_index]]` as a `ScalarValue`
+fn scalar_at_index(
+ array: &dyn Array,
+ indices: &UInt32Array,
+ indicies_index: usize,
+) -> Result<ScalarValue> {
+ let array_index = indices
+ .value(indicies_index)
+ .try_into()
+ .expect("Convert uint32 to usize");
+ ScalarValue::try_from_array(array, array_index)
}