You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by vi...@apache.org on 2023/05/02 07:10:38 UTC
[arrow-datafusion] branch main updated: Handle ScalarValue::Dictionary in add_to_row and update_avg_to_row (#6175)
This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 263bfcb28c Handle ScalarValue::Dictionary in add_to_row and update_avg_to_row (#6175)
263bfcb28c is described below
commit 263bfcb28cf111cb7aabd00713f34b32f08a9d74
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Tue May 2 00:10:32 2023 -0700
Handle ScalarValue::Dictionary in add_to_row and update_avg_to_row (#6175)
* Handle ScalarValue::Dictionary in add_to_row and update_avg_to_row
* Add test
* Add test
* Improve tests
---
datafusion/physical-expr/src/aggregate/sum.rs | 82 ++++++++++++++++++++++++++-
1 file changed, 81 insertions(+), 1 deletion(-)
diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs
index e08726e465..1c70dc67be 100644
--- a/datafusion/physical-expr/src/aggregate/sum.rs
+++ b/datafusion/physical-expr/src/aggregate/sum.rs
@@ -275,6 +275,10 @@ pub(crate) fn add_to_row(
ScalarValue::Decimal128(rhs, _, _) => {
sum_row!(index, accessor, rhs, i128)
}
+ ScalarValue::Dictionary(_, value) => {
+ let value = value.as_ref();
+ return add_to_row(index, accessor, value);
+ }
_ => {
let msg =
format!("Row sum updater is not expected to receive a scalar {s:?}");
@@ -308,6 +312,10 @@ pub(crate) fn update_avg_to_row(
ScalarValue::Decimal128(rhs, _, _) => {
avg_row!(index, accessor, rhs, i128)
}
+ ScalarValue::Dictionary(_, value) => {
+ let value = value.as_ref();
+ return update_avg_to_row(index, accessor, value);
+ }
_ => {
let msg =
format!("Row avg updater is not expected to receive a scalar {s:?}");
@@ -419,11 +427,12 @@ impl RowAccumulator for SumRowAccumulator {
#[cfg(test)]
mod tests {
use super::*;
- use crate::expressions::col;
use crate::expressions::tests::aggregate;
+ use crate::expressions::{col, Avg};
use crate::generic_test_op;
use arrow::datatypes::*;
use arrow::record_batch::RecordBatch;
+ use arrow_array::DictionaryArray;
use datafusion_common::Result;
#[test]
@@ -546,4 +555,75 @@ mod tests {
Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]));
generic_test_op!(a, DataType::Float64, Sum, ScalarValue::from(15_f64))
}
+
+ /// Test helper that drives `agg` through the row-accumulator code path.
+ ///
+ /// A row accumulator is created from `agg`; then, for every index in
+ /// `row_indexs` (in order), the element of `array` at that index is turned
+ /// into a `ScalarValue` and handed to the accumulator together with
+ /// `row_accessor`. The accumulator is finally evaluated against
+ /// `row_accessor` and that result is returned. The first failing step
+ /// short-circuits and is returned as the error.
+ fn row_aggregate(
+     array: &ArrayRef,
+     agg: Arc<dyn AggregateExpr>,
+     row_accessor: &mut RowAccessor,
+     row_indexs: Vec<usize>,
+ ) -> Result<ScalarValue> {
+     let mut accumulator = agg.create_row_accumulator(0)?;
+
+     // Feed the selected rows one by one, stopping at the first error.
+     row_indexs.into_iter().try_for_each(|idx| -> Result<()> {
+         let row_value = ScalarValue::try_from_array(array, idx)?;
+         accumulator.update_scalar(&row_value, row_accessor)?;
+         Ok(())
+     })?;
+
+     accumulator.evaluate(row_accessor)
+ }
+
+ /// Row-based `Sum` over a dictionary-encoded Float64 array must add the
+ /// dictionary values the keys refer to.
+ #[test]
+ fn sum_dictionary_f64() -> Result<()> {
+     // Dictionary input: rows 0, 1 and 2 hold keys 2, 3 and 1, i.e. the
+     // values 3.0, 4.0 and 2.0.
+     let dict_values = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64]));
+     let dict_keys = Int32Array::from(vec![2, 3, 1, 0, 1]);
+     let a: ArrayRef = Arc::new(DictionaryArray::try_new(dict_keys, dict_values).unwrap());
+
+     // 3.0 + 4.0 + 2.0
+     let expected = ScalarValue::from(9_f64);
+
+     // The same single-column schema is used both to resolve column "a"
+     // and to describe the row the accumulator writes into.
+     let row_schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);
+     let sum_expr = Arc::new(Sum::new(
+         col("a", &row_schema)?,
+         "bla".to_string(),
+         expected.get_datatype(),
+     ));
+
+     // Zero-initialised backing bytes for the row the accessor points at.
+     let mut row_bytes: Vec<u8> = vec![0; 16];
+     let mut row_accessor = RowAccessor::new(&row_schema);
+     row_accessor.point_to(0, &mut row_bytes);
+
+     // Only the first three rows of the array are aggregated.
+     let actual = row_aggregate(&a, sum_expr, &mut row_accessor, vec![0, 1, 2])?;
+     assert_eq!(expected, actual);
+
+     Ok(())
+ }
+
+ /// Row-based `Avg` over a dictionary-encoded Float64 array must average the
+ /// dictionary values the keys refer to.
+ #[test]
+ fn avg_dictionary_f64() -> Result<()> {
+     // Dictionary input: rows 0, 1 and 2 hold keys 2, 1 and 1, i.e. the
+     // values 3.0, 2.0 and 2.0.
+     let dict_values = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64]));
+     let dict_keys = Int32Array::from(vec![2, 1, 1, 3, 0]);
+     let a: ArrayRef = Arc::new(DictionaryArray::try_new(dict_keys, dict_values).unwrap());
+
+     // (3.0 + 2.0 + 2.0) / 3
+     let expected = ScalarValue::from(2.3333333333333335_f64);
+
+     // The expression resolves column "a" against the one-column input schema.
+     let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);
+     let avg_expr = Arc::new(Avg::new(
+         col("a", &schema)?,
+         "bla".to_string(),
+         expected.get_datatype(),
+     ));
+
+     // The row handed to the accumulator has two fields: a UInt64 "count"
+     // followed by the Float64 "a".
+     let row_schema = Schema::new(vec![
+         Field::new("count", DataType::UInt64, true),
+         Field::new("a", DataType::Float64, true),
+     ]);
+     // Zero-initialised backing bytes for the row the accessor points at.
+     let mut row_bytes: Vec<u8> = vec![0; 24];
+     let mut row_accessor = RowAccessor::new(&row_schema);
+     row_accessor.point_to(0, &mut row_bytes);
+
+     // Only the first three rows of the array are aggregated.
+     let actual = row_aggregate(&a, avg_expr, &mut row_accessor, vec![0, 1, 2])?;
+     assert_eq!(expected, actual);
+
+     Ok(())
+ }
}