You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/12/09 12:04:33 UTC

[arrow-datafusion] branch master updated: Fix panic in median "AggregateState is not a scalar aggregate" (#4488)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 31bbe6c3e Fix panic in median "AggregateState is not a scalar aggregate" (#4488)
31bbe6c3e is described below

commit 31bbe6c3eec245127b005d8afc8766e2d06d811b
Author: Andrew Lamb <an...@nerdnetworks.org>
AuthorDate: Fri Dec 9 07:04:28 2022 -0500

    Fix panic in median "AggregateState is not a scalar aggregate" (#4488)
    
    * Fix panic in median "AggregateState is not a scalar aggregate"
    
    * Apply suggestions from code review
    
    Co-authored-by: Raphael Taylor-Davies <17...@users.noreply.github.com>
    
    Co-authored-by: Raphael Taylor-Davies <17...@users.noreply.github.com>
---
 datafusion/common/src/scalar.rs                    |   6 +-
 .../tests/sqllogictests/test_files/aggregate.slt   |  70 +++++-
 datafusion/physical-expr/src/aggregate/median.rs   | 238 +++++++++------------
 3 files changed, 173 insertions(+), 141 deletions(-)

diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index 449c75c25..b6a49fbf6 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -721,7 +721,7 @@ impl std::hash::Hash for ScalarValue {
 /// dictionary array
 #[inline]
 fn get_dict_value<K: ArrowDictionaryKeyType>(
-    array: &ArrayRef,
+    array: &dyn Array,
     index: usize,
 ) -> (&ArrayRef, Option<usize>) {
     let dict_array = as_dictionary_array::<K>(array).unwrap();
@@ -1963,7 +1963,7 @@ impl ScalarValue {
     }
 
     fn get_decimal_value_from_array(
-        array: &ArrayRef,
+        array: &dyn Array,
         index: usize,
         precision: u8,
         scale: i8,
@@ -1978,7 +1978,7 @@ impl ScalarValue {
     }
 
     /// Converts a value in `array` at `index` into a ScalarValue
-    pub fn try_from_array(array: &ArrayRef, index: usize) -> Result<Self> {
+    pub fn try_from_array(array: &dyn Array, index: usize) -> Result<Self> {
         // handle NULL value
         if !array.is_valid(index) {
             return array.data_type().try_into();
diff --git a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
index c0b747702..c1cefd70e 100644
--- a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
@@ -79,7 +79,7 @@ SELECT stddev_pop(c2) FROM aggregate_test_100
 1.3665650368716449
 
 # csv_query_stddev_2
-query R 
+query R
 SELECT stddev_pop(c6) FROM aggregate_test_100
 ----
 5.114326382039172e18
@@ -216,6 +216,70 @@ SELECT approx_median(a) FROM median_f64_nan
 ----
 NaN
 
+# median_multi
+# test case for https://github.com/apache/arrow-datafusion/issues/3105
+# has an intermediate grouping
+statement ok
+create table cpu (host string, usage float) as select * from (values
+('host0', 90.1),
+('host1', 90.2),
+('host1', 90.4)
+);
+
+query CI rowsort
+select host, median(usage) from cpu group by host;
+----
+host1 90.3
+host0 90.1
+
+query CI
+select median(usage) from cpu;
+----
+90.2
+
+
+statement ok
+drop table cpu;
+
+# median_multi_odd
+
+# data is not sorted and has an odd number of values per group
+statement ok
+create table cpu (host string, usage float) as select * from (values
+  ('host0', 90.2),
+  ('host1', 90.1),
+  ('host1', 90.5),
+  ('host0', 90.5),
+  ('host1', 90.0),
+  ('host1', 90.3),
+  ('host0', 87.9),
+  ('host1', 89.3)
+);
+
+query CI rowsort
+select host, median(usage) from cpu group by host;
+----
+host0 90.2
+host1 90.1
+
+
+statement ok
+drop table cpu;
+
+# median_multi_even
+# data is not sorted and has an even number of values per group
+statement ok
+create table cpu (host string, usage float) as select * from (values ('host0', 90.2), ('host1', 90.1), ('host1', 90.5), ('host0', 90.5), ('host1', 90.0), ('host1', 90.3), ('host1', 90.2), ('host1', 90.3));
+
+query CI rowsort
+select host, median(usage) from cpu group by host;
+----
+host1 90.25
+host0 90.35
+
+statement ok
+drop table cpu
+
 # csv_query_external_table_count
 query I
 SELECT COUNT(c12) FROM aggregate_test_100
@@ -818,7 +882,7 @@ select c2, sum(c3) sum_c3, avg(c3) avg_c3, max(c3) max_c3, min(c3) min_c3, count
 # SELECT array_agg(c13 ORDER BY c1) FROM aggregate_test_100;
 
 # csv_query_array_cube_agg_with_overflow
-query TIIRIII 
+query TIIRIII
 select c1, c2, sum(c3) sum_c3, avg(c3) avg_c3, max(c3) max_c3, min(c3) min_c3, count(c3) count_c3 from aggregate_test_100 group by CUBE (c1,c2) order by c1, c2
 ----
 a 1 -88  -17.6               83  -85  5
@@ -870,7 +934,7 @@ e   847  40.333333333333336  120 -95  21
 # query IIII
 # SELECT count(nanos), count(micros), count(millis), count(secs) FROM t
 # ----
-# 3 3 3 3 
+# 3 3 3 3
 
 # aggregate_timestamps_min
 # query TTTT
diff --git a/datafusion/physical-expr/src/aggregate/median.rs b/datafusion/physical-expr/src/aggregate/median.rs
index a04bd5369..abde3702f 100644
--- a/datafusion/physical-expr/src/aggregate/median.rs
+++ b/datafusion/physical-expr/src/aggregate/median.rs
@@ -19,13 +19,9 @@
 
 use crate::expressions::format_state_name;
 use crate::{AggregateExpr, PhysicalExpr};
-use arrow::array::{Array, ArrayRef, PrimitiveBuilder};
-use arrow::compute::sort;
-use arrow::datatypes::{
-    ArrowPrimitiveType, DataType, Field, Float32Type, Float64Type, Int16Type, Int32Type,
-    Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
-};
-use datafusion_common::cast::as_primitive_array;
+use arrow::array::{Array, ArrayRef, UInt32Array};
+use arrow::compute::sort_to_indices;
+use arrow::datatypes::{DataType, Field};
 use datafusion_common::{DataFusionError, Result, ScalarValue};
 use datafusion_expr::{Accumulator, AggregateState};
 use std::any::Any;
@@ -74,9 +70,13 @@ impl AggregateExpr for Median {
     }
 
     fn state_fields(&self) -> Result<Vec<Field>> {
+        // Intermediate state is a list of the elements collected so far
+        let field = Field::new("item", self.data_type.clone(), true);
+        let data_type = DataType::List(Box::new(field));
+
         Ok(vec![Field::new(
             &format_state_name(&self.name, "median"),
-            self.data_type.clone(),
+            data_type,
             true,
         )])
     }
@@ -91,158 +91,126 @@ impl AggregateExpr for Median {
 }
 
 #[derive(Debug)]
+/// The median accumulator accumulates the raw input values
+/// as `ScalarValue`s
+///
+/// The intermediate state is represented as a List of those scalars
 struct MedianAccumulator {
     data_type: DataType,
-    all_values: Vec<ArrayRef>,
-}
-
-macro_rules! median {
-    ($SELF:ident, $TY:ty, $SCALAR_TY:ident, $TWO:expr) => {{
-        let combined = combine_arrays::<$TY>($SELF.all_values.as_slice())?;
-        if combined.is_empty() {
-            return Ok(ScalarValue::Null);
-        }
-        let sorted = sort(&combined, None)?;
-        let array = as_primitive_array::<$TY>(&sorted)?;
-        let len = sorted.len();
-        let mid = len / 2;
-        if len % 2 == 0 {
-            Ok(ScalarValue::$SCALAR_TY(Some(
-                (array.value(mid - 1) + array.value(mid)) / $TWO,
-            )))
-        } else {
-            Ok(ScalarValue::$SCALAR_TY(Some(array.value(mid))))
-        }
-    }};
+    all_values: Vec<ScalarValue>,
 }
 
 impl Accumulator for MedianAccumulator {
     fn state(&self) -> Result<Vec<AggregateState>> {
-        let mut vec: Vec<AggregateState> = self
-            .all_values
-            .iter()
-            .map(|v| AggregateState::Array(v.clone()))
-            .collect();
-        if vec.is_empty() {
-            match self.data_type {
-                DataType::UInt8 => vec.push(empty_array::<UInt8Type>()),
-                DataType::UInt16 => vec.push(empty_array::<UInt16Type>()),
-                DataType::UInt32 => vec.push(empty_array::<UInt32Type>()),
-                DataType::UInt64 => vec.push(empty_array::<UInt64Type>()),
-                DataType::Int8 => vec.push(empty_array::<Int8Type>()),
-                DataType::Int16 => vec.push(empty_array::<Int16Type>()),
-                DataType::Int32 => vec.push(empty_array::<Int32Type>()),
-                DataType::Int64 => vec.push(empty_array::<Int64Type>()),
-                DataType::Float32 => vec.push(empty_array::<Float32Type>()),
-                DataType::Float64 => vec.push(empty_array::<Float64Type>()),
-                _ => {
-                    return Err(DataFusionError::Execution(
-                        "unsupported data type for median".to_string(),
-                    ))
-                }
-            }
-        }
-        Ok(vec)
+        let state =
+            ScalarValue::new_list(Some(self.all_values.clone()), self.data_type.clone());
+        Ok(vec![AggregateState::Scalar(state)])
     }
 
     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
-        let x = values[0].clone();
-        self.all_values.extend_from_slice(&[x]);
+        assert_eq!(values.len(), 1);
+        let array = &values[0];
+
+        assert_eq!(array.data_type(), &self.data_type);
+        self.all_values.reserve(self.all_values.len() + array.len());
+        for index in 0..array.len() {
+            self.all_values
+                .push(ScalarValue::try_from_array(array, index)?);
+        }
+
         Ok(())
     }
 
     fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
-        for array in states {
-            self.all_values.extend_from_slice(&[array.clone()]);
+        assert_eq!(states.len(), 1);
+
+        let array = &states[0];
+        assert!(matches!(array.data_type(), DataType::List(_)));
+        for index in 0..array.len() {
+            match ScalarValue::try_from_array(array, index)? {
+                ScalarValue::List(Some(mut values), _) => {
+                    self.all_values.append(&mut values);
+                }
+                ScalarValue::List(None, _) => {} // skip empty state
+                v => {
+                    return Err(DataFusionError::Internal(format!(
+                        "unexpected state in median. Expected DataType::List, got {:?}",
+                        v
+                    )))
+                }
+            }
         }
         Ok(())
     }
 
     fn evaluate(&self) -> Result<ScalarValue> {
-        match self.all_values[0].data_type() {
-            DataType::Int8 => median!(self, arrow::datatypes::Int8Type, Int8, 2),
-            DataType::Int16 => median!(self, arrow::datatypes::Int16Type, Int16, 2),
-            DataType::Int32 => median!(self, arrow::datatypes::Int32Type, Int32, 2),
-            DataType::Int64 => median!(self, arrow::datatypes::Int64Type, Int64, 2),
-            DataType::UInt8 => median!(self, arrow::datatypes::UInt8Type, UInt8, 2),
-            DataType::UInt16 => median!(self, arrow::datatypes::UInt16Type, UInt16, 2),
-            DataType::UInt32 => median!(self, arrow::datatypes::UInt32Type, UInt32, 2),
-            DataType::UInt64 => median!(self, arrow::datatypes::UInt64Type, UInt64, 2),
-            DataType::Float32 => {
-                median!(self, arrow::datatypes::Float32Type, Float32, 2_f32)
-            }
-            DataType::Float64 => {
-                median!(self, arrow::datatypes::Float64Type, Float64, 2_f64)
+        // Create an array of all the non null values and find the
+        // sorted indexes
+        let array = ScalarValue::iter_to_array(
+            self.all_values
+                .iter()
+                // ignore null values
+                .filter(|v| !v.is_null())
+                .cloned(),
+        )?;
+
+        // find the mid point
+        let len = array.len();
+        let mid = len / 2;
+
+        // only the smallest `mid + 1` values need to be sorted to reach the midpoint
+        let limit = Some(mid + 1);
+        let options = None;
+        let indices = sort_to_indices(&array, options, limit)?;
+
+        // pick the relevant indices in the original arrays
+        let result = if len >= 2 && len % 2 == 0 {
+            // even number of values, average the two mid points
+            let s1 = scalar_at_index(&array, &indices, mid - 1)?;
+            let s2 = scalar_at_index(&array, &indices, mid)?;
+            match s1.add(s2)? {
+                ScalarValue::Int8(Some(v)) => ScalarValue::Int8(Some(v / 2)),
+                ScalarValue::Int16(Some(v)) => ScalarValue::Int16(Some(v / 2)),
+                ScalarValue::Int32(Some(v)) => ScalarValue::Int32(Some(v / 2)),
+                ScalarValue::Int64(Some(v)) => ScalarValue::Int64(Some(v / 2)),
+                ScalarValue::UInt8(Some(v)) => ScalarValue::UInt8(Some(v / 2)),
+                ScalarValue::UInt16(Some(v)) => ScalarValue::UInt16(Some(v / 2)),
+                ScalarValue::UInt32(Some(v)) => ScalarValue::UInt32(Some(v / 2)),
+                ScalarValue::UInt64(Some(v)) => ScalarValue::UInt64(Some(v / 2)),
+                ScalarValue::Float32(Some(v)) => ScalarValue::Float32(Some(v / 2.0)),
+                ScalarValue::Float64(Some(v)) => ScalarValue::Float64(Some(v / 2.0)),
+                v => {
+                    return Err(DataFusionError::Internal(format!(
+                        "Unsupported type in MedianAccumulator: {:?}",
+                        v
+                    )))
+                }
             }
-            _ => Err(DataFusionError::Execution(
-                "unsupported data type for median".to_string(),
-            )),
-        }
+        } else {
+            // odd number of values, pick the middle one
+            scalar_at_index(&array, &indices, mid)?
+        };
+
+        Ok(result)
     }
 
     fn size(&self) -> usize {
-        std::mem::align_of_val(self)
-            + (std::mem::size_of::<ArrayRef>() * self.all_values.capacity())
-            + self
-                .all_values
-                .iter()
-                .map(|array_ref| {
-                    std::mem::size_of_val(array_ref.as_ref())
-                        + array_ref.get_array_memory_size()
-                })
-                .sum::<usize>()
+        std::mem::size_of_val(self) + ScalarValue::size_of_vec(&self.all_values)
+            - std::mem::size_of_val(&self.all_values)
             + self.data_type.size()
             - std::mem::size_of_val(&self.data_type)
     }
 }
 
-/// Create an empty array
-fn empty_array<T: ArrowPrimitiveType>() -> AggregateState {
-    AggregateState::Array(Arc::new(PrimitiveBuilder::<T>::with_capacity(0).finish()))
-}
-
-/// Combine all non-null values from provided arrays into a single array
-fn combine_arrays<T: ArrowPrimitiveType>(arrays: &[ArrayRef]) -> Result<ArrayRef> {
-    let len = arrays.iter().map(|a| a.len() - a.null_count()).sum();
-    let mut builder: PrimitiveBuilder<T> = PrimitiveBuilder::with_capacity(len);
-    for array in arrays {
-        let array = as_primitive_array::<T>(array)?;
-        for i in 0..array.len() {
-            if !array.is_null(i) {
-                builder.append_value(array.value(i));
-            }
-        }
-    }
-    Ok(Arc::new(builder.finish()))
-}
-
-#[cfg(test)]
-mod test {
-    use crate::aggregate::median::combine_arrays;
-    use arrow::array::{Int32Array, UInt32Array};
-    use arrow::datatypes::{Int32Type, UInt32Type};
-    use datafusion_common::Result;
-    use std::sync::Arc;
-
-    #[test]
-    fn combine_i32_array() -> Result<()> {
-        let a = Arc::new(Int32Array::from(vec![1, 2, 3]));
-        let b = combine_arrays::<Int32Type>(&[a.clone(), a])?;
-        assert_eq!(
-            "PrimitiveArray<Int32>\n[\n  1,\n  2,\n  3,\n  1,\n  2,\n  3,\n]",
-            format!("{:?}", b)
-        );
-        Ok(())
-    }
-
-    #[test]
-    fn combine_u32_array() -> Result<()> {
-        let a = Arc::new(UInt32Array::from(vec![1, 2, 3]));
-        let b = combine_arrays::<UInt32Type>(&[a.clone(), a])?;
-        assert_eq!(
-            "PrimitiveArray<UInt32>\n[\n  1,\n  2,\n  3,\n  1,\n  2,\n  3,\n]",
-            format!("{:?}", b)
-        );
-        Ok(())
-    }
+/// Returns `array[indices[indicies_index]]` as a `ScalarValue`
+fn scalar_at_index(
+    array: &dyn Array,
+    indices: &UInt32Array,
+    indicies_index: usize,
+) -> Result<ScalarValue> {
+    let array_index = indices
+        .value(indicies_index)
+        .try_into()
+        .expect("Convert uint32 to usize");
+    ScalarValue::try_from_array(array, array_index)
 }