You are viewing a plain-text version of this content; the link to its canonical version was not preserved in this copy.
Posted to commits@arrow.apache.org by dh...@apache.org on 2023/07/03 07:57:38 UTC

[arrow-datafusion] 03/17: complete accumulator

This is an automated email from the ASF dual-hosted git repository.

dheres pushed a commit to branch hash_agg_spike
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git

commit 337353810df503d02245de02357fb1d6ba04f675
Author: Andrew Lamb <an...@nerdnetworks.org>
AuthorDate: Fri Jun 30 11:28:48 2023 -0400

    complete accumulator
---
 datafusion/physical-expr/src/aggregate/average.rs | 76 ++++++++++++++++++-----
 1 file changed, 61 insertions(+), 15 deletions(-)

diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs
index f81c704d8b..b23b555805 100644
--- a/datafusion/physical-expr/src/aggregate/average.rs
+++ b/datafusion/physical-expr/src/aggregate/average.rs
@@ -18,7 +18,7 @@
 //! Defines physical expressions that can evaluated at runtime during query execution
 
 use arrow::array::AsArray;
-use log::info;
+use log::debug;
 
 use std::any::Any;
 use std::convert::TryFrom;
@@ -45,6 +45,8 @@ use datafusion_common::{DataFusionError, Result};
 use datafusion_expr::Accumulator;
 use datafusion_row::accessor::RowAccessor;
 
+use super::utils::Decimal128Averager;
+
 /// AVG aggregate expression
 #[derive(Debug, Clone)]
 pub struct Avg {
@@ -161,16 +163,29 @@ impl AggregateExpr for Avg {
 
     fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
         // instantiate specialized accumulator
-        match self.sum_data_type {
-            DataType::Decimal128(_, _) => {
-                Ok(Box::new(AvgGroupsAccumulator::<Decimal128Type>::new(
+        match (&self.sum_data_type, &self.rt_data_type) {
+            (
+                DataType::Decimal128(_sum_precision, sum_scale),
+                DataType::Decimal128(target_precision, target_scale),
+            ) => {
+                let decimal_averager = Decimal128Averager::try_new(
+                    *sum_scale,
+                    *target_precision,
+                    *target_scale,
+                )?;
+
+                let avg_fn =
+                    move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128);
+
+                Ok(Box::new(AvgGroupsAccumulator::<Decimal128Type, _>::new(
                     &self.sum_data_type,
                     &self.rt_data_type,
+                    avg_fn,
                 )))
             }
             _ => Err(DataFusionError::NotImplemented(format!(
-                "AvgGroupsAccumulator for {}",
-                self.sum_data_type
+                "AvgGroupsAccumulator for ({} --> {})",
+                self.sum_data_type, self.rt_data_type,
             ))),
         }
     }
@@ -403,9 +418,13 @@ impl RowAccumulator for AvgRowAccumulator {
 }
 
 /// An accumulator to compute the average of PrimitiveArray<T>.
-/// Stores values as native types
+/// Stores values as native types, and does overflow checking
 #[derive(Debug)]
-struct AvgGroupsAccumulator<T: ArrowNumericType + Send> {
+struct AvgGroupsAccumulator<T, F>
+where
+    T: ArrowNumericType + Send,
+    F: Fn(T::Native, u64) -> Result<T::Native> + Send,
+{
     /// The type of the internal sum
     sum_data_type: DataType,
 
@@ -415,13 +434,20 @@ struct AvgGroupsAccumulator<T: ArrowNumericType + Send> {
     /// Count per group (use u64 to make UInt64Array)
     counts: Vec<u64>,
 
-    // Sums per group, stored as the native type
+    /// Sums per group, stored as the native type
     sums: Vec<T::Native>,
+
+    /// Function that computes the average (value / count)
+    avg_fn: F,
 }
 
-impl<T: ArrowNumericType + Send> AvgGroupsAccumulator<T> {
-    pub fn new(sum_data_type: &DataType, return_data_type: &DataType) -> Self {
-        info!(
+impl<T, F> AvgGroupsAccumulator<T, F>
+where
+    T: ArrowNumericType + Send,
+    F: Fn(T::Native, u64) -> Result<T::Native> + Send,
+{
+    pub fn new(sum_data_type: &DataType, return_data_type: &DataType, avg_fn: F) -> Self {
+        debug!(
             "AvgGroupsAccumulator ({}, sum type: {sum_data_type:?}) --> {return_data_type:?}",
             std::any::type_name::<T>()
         );
@@ -430,6 +456,7 @@ impl<T: ArrowNumericType + Send> AvgGroupsAccumulator<T> {
             sum_data_type: sum_data_type.clone(),
             counts: vec![],
             sums: vec![],
+            avg_fn,
         }
     }
 
@@ -500,7 +527,11 @@ impl<T: ArrowNumericType + Send> AvgGroupsAccumulator<T> {
     }
 }
 
-impl<T: ArrowNumericType + Send> GroupsAccumulator for AvgGroupsAccumulator<T> {
+impl<T, F> GroupsAccumulator for AvgGroupsAccumulator<T, F>
+where
+    T: ArrowNumericType + Send,
+    F: Fn(T::Native, u64) -> Result<T::Native> + Send,
+{
     fn update_batch(
         &mut self,
         values: &[ArrayRef],
@@ -549,7 +580,22 @@ impl<T: ArrowNumericType + Send> GroupsAccumulator for AvgGroupsAccumulator<T> {
     }
 
     fn evaluate(&mut self) -> Result<ArrayRef> {
-        todo!()
+        let counts = std::mem::take(&mut self.counts);
+        let sums = std::mem::take(&mut self.sums);
+
+        let averages: Vec<T::Native> = sums
+            .into_iter()
+            .zip(counts.into_iter())
+            .map(|(sum, count)| (self.avg_fn)(sum, count))
+            .collect::<Result<Vec<_>>>()?;
+
+        // TODO figure out how to do this without the iter / copy
+        let array = PrimitiveArray::<T>::from_iter_values(averages);
+
+        // fix up decimal precision and scale for decimals
+        let array = set_decimal_precision(&self.return_data_type, Arc::new(array))?;
+
+        Ok(array)
     }
 
     // return arrays for sums and counts
@@ -563,7 +609,7 @@ impl<T: ArrowNumericType + Send> GroupsAccumulator for AvgGroupsAccumulator<T> {
         // TODO figure out how to do this without the iter / copy
         let sums: PrimitiveArray<T> = PrimitiveArray::from_iter_values(sums);
 
-        // fix up decimal precision and scale
+        // fix up decimal precision and scale for decimals
         let sums = set_decimal_precision(&self.sum_data_type, Arc::new(sums))?;
 
         Ok(vec![