You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by dh...@apache.org on 2023/07/03 07:57:38 UTC
[arrow-datafusion] 03/17: complete accumulator
This is an automated email from the ASF dual-hosted git repository.
dheres pushed a commit to branch hash_agg_spike
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
commit 337353810df503d02245de02357fb1d6ba04f675
Author: Andrew Lamb <an...@nerdnetworks.org>
AuthorDate: Fri Jun 30 11:28:48 2023 -0400
complete accumulator
---
datafusion/physical-expr/src/aggregate/average.rs | 76 ++++++++++++++++++-----
1 file changed, 61 insertions(+), 15 deletions(-)
diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs
index f81c704d8b..b23b555805 100644
--- a/datafusion/physical-expr/src/aggregate/average.rs
+++ b/datafusion/physical-expr/src/aggregate/average.rs
@@ -18,7 +18,7 @@
//! Defines physical expressions that can evaluated at runtime during query execution
use arrow::array::AsArray;
-use log::info;
+use log::debug;
use std::any::Any;
use std::convert::TryFrom;
@@ -45,6 +45,8 @@ use datafusion_common::{DataFusionError, Result};
use datafusion_expr::Accumulator;
use datafusion_row::accessor::RowAccessor;
+use super::utils::Decimal128Averager;
+
/// AVG aggregate expression
#[derive(Debug, Clone)]
pub struct Avg {
@@ -161,16 +163,29 @@ impl AggregateExpr for Avg {
fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
// instantiate specialized accumulator
- match self.sum_data_type {
- DataType::Decimal128(_, _) => {
- Ok(Box::new(AvgGroupsAccumulator::<Decimal128Type>::new(
+ match (&self.sum_data_type, &self.rt_data_type) {
+ (
+ DataType::Decimal128(_sum_precision, sum_scale),
+ DataType::Decimal128(target_precision, target_scale),
+ ) => {
+ let decimal_averager = Decimal128Averager::try_new(
+ *sum_scale,
+ *target_precision,
+ *target_scale,
+ )?;
+
+ let avg_fn =
+ move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128);
+
+ Ok(Box::new(AvgGroupsAccumulator::<Decimal128Type, _>::new(
&self.sum_data_type,
&self.rt_data_type,
+ avg_fn,
)))
}
_ => Err(DataFusionError::NotImplemented(format!(
- "AvgGroupsAccumulator for {}",
- self.sum_data_type
+ "AvgGroupsAccumulator for ({} --> {})",
+ self.sum_data_type, self.rt_data_type,
))),
}
}
@@ -403,9 +418,13 @@ impl RowAccumulator for AvgRowAccumulator {
}
/// An accumulator to compute the average of PrimitiveArray<T>.
-/// Stores values as native types
+/// Stores values as native types, and does overflow checking
#[derive(Debug)]
-struct AvgGroupsAccumulator<T: ArrowNumericType + Send> {
+struct AvgGroupsAccumulator<T, F>
+where
+ T: ArrowNumericType + Send,
+ F: Fn(T::Native, u64) -> Result<T::Native> + Send,
+{
/// The type of the internal sum
sum_data_type: DataType,
@@ -415,13 +434,20 @@ struct AvgGroupsAccumulator<T: ArrowNumericType + Send> {
/// Count per group (use u64 to make UInt64Array)
counts: Vec<u64>,
- // Sums per group, stored as the native type
+ /// Sums per group, stored as the native type
sums: Vec<T::Native>,
+
+ /// Function that computes the average (sum / count)
+ avg_fn: F,
}
-impl<T: ArrowNumericType + Send> AvgGroupsAccumulator<T> {
- pub fn new(sum_data_type: &DataType, return_data_type: &DataType) -> Self {
- info!(
+impl<T, F> AvgGroupsAccumulator<T, F>
+where
+ T: ArrowNumericType + Send,
+ F: Fn(T::Native, u64) -> Result<T::Native> + Send,
+{
+ pub fn new(sum_data_type: &DataType, return_data_type: &DataType, avg_fn: F) -> Self {
+ debug!(
"AvgGroupsAccumulator ({}, sum type: {sum_data_type:?}) --> {return_data_type:?}",
std::any::type_name::<T>()
);
@@ -430,6 +456,7 @@ impl<T: ArrowNumericType + Send> AvgGroupsAccumulator<T> {
sum_data_type: sum_data_type.clone(),
counts: vec![],
sums: vec![],
+ avg_fn,
}
}
@@ -500,7 +527,11 @@ impl<T: ArrowNumericType + Send> AvgGroupsAccumulator<T> {
}
}
-impl<T: ArrowNumericType + Send> GroupsAccumulator for AvgGroupsAccumulator<T> {
+impl<T, F> GroupsAccumulator for AvgGroupsAccumulator<T, F>
+where
+ T: ArrowNumericType + Send,
+ F: Fn(T::Native, u64) -> Result<T::Native> + Send,
+{
fn update_batch(
&mut self,
values: &[ArrayRef],
@@ -549,7 +580,22 @@ impl<T: ArrowNumericType + Send> GroupsAccumulator for AvgGroupsAccumulator<T> {
}
fn evaluate(&mut self) -> Result<ArrayRef> {
- todo!()
+ let counts = std::mem::take(&mut self.counts);
+ let sums = std::mem::take(&mut self.sums);
+
+ let averages: Vec<T::Native> = sums
+ .into_iter()
+ .zip(counts.into_iter())
+ .map(|(sum, count)| (self.avg_fn)(sum, count))
+ .collect::<Result<Vec<_>>>()?;
+
+ // TODO figure out how to do this without the iter / copy
+ let array = PrimitiveArray::<T>::from_iter_values(averages);
+
+ // fix up decimal precision and scale for decimals
+ let array = set_decimal_precision(&self.return_data_type, Arc::new(array))?;
+
+ Ok(array)
}
// return arrays for sums and counts
@@ -563,7 +609,7 @@ impl<T: ArrowNumericType + Send> GroupsAccumulator for AvgGroupsAccumulator<T> {
// TODO figure out how to do this without the iter / copy
let sums: PrimitiveArray<T> = PrimitiveArray::from_iter_values(sums);
- // fix up decimal precision and scale
+ // fix up decimal precision and scale for decimals
let sums = set_decimal_precision(&self.sum_data_type, Arc::new(sums))?;
Ok(vec![