You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by vi...@apache.org on 2023/05/02 07:10:38 UTC

[arrow-datafusion] branch main updated: Handle ScalarValue::Dictionary in add_to_row and update_avg_to_row (#6175)

This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 263bfcb28c Handle ScalarValue::Dictionary in add_to_row and update_avg_to_row (#6175)
263bfcb28c is described below

commit 263bfcb28cf111cb7aabd00713f34b32f08a9d74
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Tue May 2 00:10:32 2023 -0700

    Handle ScalarValue::Dictionary in add_to_row and update_avg_to_row (#6175)
    
    * Handle ScalarValue::Dictionary in add_to_row and update_avg_to_row
    
    * Add test
    
    * Add test
    
    * Improve tests
---
 datafusion/physical-expr/src/aggregate/sum.rs | 82 ++++++++++++++++++++++++++-
 1 file changed, 81 insertions(+), 1 deletion(-)

diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs
index e08726e465..1c70dc67be 100644
--- a/datafusion/physical-expr/src/aggregate/sum.rs
+++ b/datafusion/physical-expr/src/aggregate/sum.rs
@@ -275,6 +275,10 @@ pub(crate) fn add_to_row(
         ScalarValue::Decimal128(rhs, _, _) => {
             sum_row!(index, accessor, rhs, i128)
         }
+        ScalarValue::Dictionary(_, value) => {
+            let value = value.as_ref();
+            return add_to_row(index, accessor, value);
+        }
         _ => {
             let msg =
                 format!("Row sum updater is not expected to receive a scalar {s:?}");
@@ -308,6 +312,10 @@ pub(crate) fn update_avg_to_row(
         ScalarValue::Decimal128(rhs, _, _) => {
             avg_row!(index, accessor, rhs, i128)
         }
+        ScalarValue::Dictionary(_, value) => {
+            let value = value.as_ref();
+            return update_avg_to_row(index, accessor, value);
+        }
         _ => {
             let msg =
                 format!("Row avg updater is not expected to receive a scalar {s:?}");
@@ -419,11 +427,12 @@ impl RowAccumulator for SumRowAccumulator {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::expressions::col;
     use crate::expressions::tests::aggregate;
+    use crate::expressions::{col, Avg};
     use crate::generic_test_op;
     use arrow::datatypes::*;
     use arrow::record_batch::RecordBatch;
+    use arrow_array::DictionaryArray;
     use datafusion_common::Result;
 
     #[test]
@@ -546,4 +555,75 @@ mod tests {
             Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]));
         generic_test_op!(a, DataType::Float64, Sum, ScalarValue::from(15_f64))
     }
+
+    fn row_aggregate( // Test helper: feeds the selected rows of `array`, one scalar at a time, into a row accumulator created from `agg`.
+        array: &ArrayRef,
+        agg: Arc<dyn AggregateExpr>,
+        row_accessor: &mut RowAccessor,
+        row_indexs: Vec<usize>, // indices of the rows of `array` to aggregate, visited in order
+    ) -> Result<ScalarValue> {
+        let mut accum = agg.create_row_accumulator(0)?; // NOTE(review): 0 is presumably the index of the first state field in the row — confirm against `create_row_accumulator`
+
+        for row_index in row_indexs {
+            let scalar_value = ScalarValue::try_from_array(array, row_index)?; // presumably a `ScalarValue::Dictionary` when `array` is dictionary-encoded, the case this commit targets — confirm
+            accum.update_scalar(&scalar_value, row_accessor)?;
+        }
+        accum.evaluate(row_accessor) // the final aggregate is read back through the same accessor
+    }
+
+    #[test]
+    fn sum_dictionary_f64() -> Result<()> { // Sums the first three rows of a dictionary-encoded Float64 array through the row accumulator.
+        let keys = Int32Array::from(vec![2, 3, 1, 0, 1]); // rows 0..=2 use keys 2, 3, 1
+        let values = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64]));
+
+        let a: ArrayRef = Arc::new(DictionaryArray::try_new(keys, values).unwrap());
+
+        let row_schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);
+        let mut row_accessor = RowAccessor::new(&row_schema);
+        let mut buffer: Vec<u8> = vec![0; 16]; // NOTE(review): 16 zeroed bytes presumably cover the null bits plus one 8-byte f64 slot — confirm against the row layout
+        row_accessor.point_to(0, &mut buffer);
+
+        let expected = ScalarValue::from(9_f64); // values[2] + values[3] + values[1] = 3 + 4 + 2
+
+        let agg = Arc::new(Sum::new(
+            col("a", &row_schema)?,
+            "bla".to_string(),
+            expected.get_datatype(), // data type argument is taken from the expected scalar
+        ));
+
+        let actual = row_aggregate(&a, agg, &mut row_accessor, vec![0, 1, 2])?; // only rows 0, 1 and 2 are aggregated
+        assert_eq!(expected, actual);
+
+        Ok(())
+    }
+
+    #[test]
+    fn avg_dictionary_f64() -> Result<()> { // Averages the first three rows of a dictionary-encoded Float64 array through the row accumulator.
+        let keys = Int32Array::from(vec![2, 1, 1, 3, 0]); // rows 0..=2 use keys 2, 1, 1
+        let values = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64]));
+
+        let a: ArrayRef = Arc::new(DictionaryArray::try_new(keys, values).unwrap());
+
+        let row_schema = Schema::new(vec![ // two-field row: a UInt64 "count" and a Float64 "a" (presumably the running sum — confirm against the avg row accumulator)
+            Field::new("count", DataType::UInt64, true),
+            Field::new("a", DataType::Float64, true),
+        ]);
+        let mut row_accessor = RowAccessor::new(&row_schema);
+        let mut buffer: Vec<u8> = vec![0; 24]; // NOTE(review): 24 zeroed bytes presumably cover the null bits plus two 8-byte slots — confirm against the row layout
+        row_accessor.point_to(0, &mut buffer);
+
+        let expected = ScalarValue::from(2.3333333333333335_f64); // (3 + 2 + 2) / 3
+
+        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]); // separate single-column schema, used only to resolve `col("a")` below
+        let agg = Arc::new(Avg::new(
+            col("a", &schema)?,
+            "bla".to_string(),
+            expected.get_datatype(), // data type argument is taken from the expected scalar
+        ));
+
+        let actual = row_aggregate(&a, agg, &mut row_accessor, vec![0, 1, 2])?; // only rows 0, 1 and 2 are aggregated
+        assert_eq!(expected, actual);
+
+        Ok(())
+    }
 }