Posted to commits@arrow.apache.org by ks...@apache.org on 2019/04/10 10:14:41 UTC

[arrow] branch master updated: ARROW-5038: [Rust] [DataFusion] Implement AVG aggregate function

This is an automated email from the ASF dual-hosted git repository.

kszucs pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 646624b  ARROW-5038: [Rust] [DataFusion] Implement AVG aggregate function
646624b is described below

commit 646624bc7524840077203f0e1fc0bf14e67244fa
Author: Zhiyuan Zheng <zh...@hotmail.com>
AuthorDate: Wed Apr 10 12:14:24 2019 +0200

    ARROW-5038: [Rust] [DataFusion] Implement AVG aggregate function
    
    Implement the AVG aggregate function for DataFusion.
    
    Split the `accumulate_scalar` function into `accumulate_scalar` and `accumulate_batch`, handling single-value updates and batch updates respectively.
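    
    A minimal standalone sketch of the sum/count approach and of the scalar/batch
    split (simplified types for illustration only; `f64` and `&[Option<f64>]` stand
    in for DataFusion's `ScalarValue` and `ArrayRef`, and `Accumulator` is not the
    actual trait name):
    
        /// Simplified stand-in for the aggregate-function trait.
        trait Accumulator {
            /// Fold a single (possibly null) value into the running aggregate.
            fn accumulate_scalar(&mut self, value: &Option<f64>);
            /// Fold a whole batch of values into the running aggregate.
            fn accumulate_batch(&mut self, values: &[Option<f64>]);
            fn result(&self) -> Option<f64>;
        }
    
        /// AVG expressed as SUM / COUNT, mirroring the approach in this commit.
        #[derive(Default)]
        struct Avg {
            sum: f64,
            count: u64,
        }
    
        impl Accumulator for Avg {
            fn accumulate_scalar(&mut self, value: &Option<f64>) {
                // Nulls (None) are ignored, matching SQL AVG semantics.
                if let Some(v) = *value {
                    self.sum += v;
                    self.count += 1;
                }
            }
    
            fn accumulate_batch(&mut self, values: &[Option<f64>]) {
                // For AVG the per-batch logic reduces to repeated scalar updates;
                // the real implementation delegates to array_sum / array_count.
                for v in values {
                    self.accumulate_scalar(v);
                }
            }
    
            fn result(&self) -> Option<f64> {
                if self.count == 0 {
                    None
                } else {
                    Some(self.sum / self.count as f64)
                }
            }
        }
    
        fn main() {
            let mut avg = Avg::default();
            avg.accumulate_batch(&[Some(1.0), None, Some(3.0)]);
            avg.accumulate_scalar(&Some(5.0));
            assert_eq!(avg.result(), Some(3.0));
        }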
    
    I am new to Arrow, so please review this carefully, @andygrove.
    
    Author: Zhiyuan Zheng <zh...@hotmail.com>
    
    Closes #4120 from zhzy0077/master and squashes the following commits:
    
    26e85bc46 <Zhiyuan Zheng> Fix format issue.
    da1f5c897 <Zhiyuan Zheng> Add multi-batch AVG test.
    54e13b267 <Zhiyuan Zheng> Apply the fmt changes.
    beb27f789 <Zhiyuan Zheng> Support for AVG function.
---
 rust/datafusion/src/execution/aggregate.rs  | 234 +++++++++++++++++++---------
 rust/datafusion/src/execution/expression.rs |   1 +
 rust/datafusion/tests/sql.rs                |  95 ++++++++++-
 3 files changed, 257 insertions(+), 73 deletions(-)

diff --git a/rust/datafusion/src/execution/aggregate.rs b/rust/datafusion/src/execution/aggregate.rs
index 7068c6a..4399bef 100644
--- a/rust/datafusion/src/execution/aggregate.rs
+++ b/rust/datafusion/src/execution/aggregate.rs
@@ -86,21 +86,14 @@ trait AggregateFunction {
     fn name(&self) -> &str;
 
     /// Update the current aggregate value based on a new value. A value of `None` represents a
-    /// null value. If rollup is false, then this aggregate function instance is being used
-    /// to aggregate individual values within a RecordBatch. If rollup is true then the aggregate
-    /// function instance is being used to combine the aggregates for multiple batches.
-    /// For some aggregate operations, such as `min`, `max`, and `sum`, the logic is the same
-    /// regardless of whether rollup is true or false. For example, `min` can be implemented as
-    /// `min(min(batch1), min(batch2), ..)`. However for `count` the logic is
-    /// `sum(count(batch1), count(batch2), ..)`.
-    fn accumulate_scalar(
-        &mut self,
-        value: &Option<ScalarValue>,
-        rollup: bool,
-    ) -> Result<()>;
+    /// null value.
+    fn accumulate_scalar(&mut self, value: &Option<ScalarValue>) -> Result<()>;
+
+    /// Update the current aggregate value based on an array.
+    fn accumulate_batch(&mut self, array: ArrayRef) -> Result<()>;
 
     /// Return the result of the aggregate function after all values have been processed
-    /// by calls to `acccumulate_scalar`.
+    /// by calls to `accumulate_scalar`.
     fn result(&self) -> Option<ScalarValue>;
 
     /// Get the data type of the result of the aggregate function. For some operations,
@@ -131,11 +124,7 @@ impl AggregateFunction for MinFunction {
         "min"
     }
 
-    fn accumulate_scalar(
-        &mut self,
-        value: &Option<ScalarValue>,
-        _rollup: bool,
-    ) -> Result<()> {
+    fn accumulate_scalar(&mut self, value: &Option<ScalarValue>) -> Result<()> {
         if self.value.is_none() {
             self.value = value.clone();
         } else if value.is_some() {
@@ -180,6 +169,12 @@ impl AggregateFunction for MinFunction {
         Ok(())
     }
 
+    fn accumulate_batch(&mut self, array: ArrayRef) -> Result<()> {
+        let accumulated_value = array_min(array)?;
+
+        self.accumulate_scalar(&accumulated_value)
+    }
+
     fn result(&self) -> Option<ScalarValue> {
         self.value.clone()
     }
@@ -210,11 +205,7 @@ impl AggregateFunction for MaxFunction {
         "max"
     }
 
-    fn accumulate_scalar(
-        &mut self,
-        value: &Option<ScalarValue>,
-        _rollup: bool,
-    ) -> Result<()> {
+    fn accumulate_scalar(&mut self, value: &Option<ScalarValue>) -> Result<()> {
         if self.value.is_none() {
             self.value = value.clone();
         } else if value.is_some() {
@@ -259,6 +250,12 @@ impl AggregateFunction for MaxFunction {
         Ok(())
     }
 
+    fn accumulate_batch(&mut self, array: ArrayRef) -> Result<()> {
+        let accumulated_value = array_max(array)?;
+
+        self.accumulate_scalar(&accumulated_value)
+    }
+
     fn result(&self) -> Option<ScalarValue> {
         self.value.clone()
     }
@@ -289,11 +286,7 @@ impl AggregateFunction for SumFunction {
         "sum"
     }
 
-    fn accumulate_scalar(
-        &mut self,
-        value: &Option<ScalarValue>,
-        _rollup: bool,
-    ) -> Result<()> {
+    fn accumulate_scalar(&mut self, value: &Option<ScalarValue>) -> Result<()> {
         if self.value.is_none() {
             self.value = value.clone();
         } else if value.is_some() {
@@ -338,6 +331,12 @@ impl AggregateFunction for SumFunction {
         Ok(())
     }
 
+    fn accumulate_batch(&mut self, array: ArrayRef) -> Result<()> {
+        let accumulated_value = array_sum(array)?;
+
+        self.accumulate_scalar(&accumulated_value)
+    }
+
     fn result(&self) -> Option<ScalarValue> {
         self.value.clone()
     }
@@ -347,6 +346,79 @@ impl AggregateFunction for SumFunction {
     }
 }
 
+/// Implementation of AVG aggregate function
+#[derive(Debug)]
+struct AvgFunction {
+    data_type: DataType,
+    sum_value: SumFunction,
+    count_value: CountFunction,
+}
+
+impl AvgFunction {
+    fn new(data_type: &DataType) -> Self {
+        Self {
+            data_type: DataType::Float64,
+            sum_value: SumFunction::new(data_type),
+            count_value: CountFunction::new(),
+        }
+    }
+}
+
+impl AggregateFunction for AvgFunction {
+    fn name(&self) -> &str {
+        "avg"
+    }
+
+    fn accumulate_scalar(&mut self, value: &Option<ScalarValue>) -> Result<()> {
+        self.sum_value.accumulate_scalar(value)?;
+        self.count_value.accumulate_scalar(value)?;
+
+        Ok(())
+    }
+
+    fn accumulate_batch(&mut self, array: ArrayRef) -> Result<()> {
+        self.sum_value.accumulate_batch(array.clone())?;
+        self.count_value.accumulate_batch(array)?;
+
+        Ok(())
+    }
+
+    fn result(&self) -> Option<ScalarValue> {
+        let sum = match self.sum_value.result() {
+            Some(ScalarValue::UInt8(a)) => a as f64,
+            Some(ScalarValue::UInt16(a)) => a as f64,
+            Some(ScalarValue::UInt32(a)) => a as f64,
+            Some(ScalarValue::UInt64(a)) => a as f64,
+            Some(ScalarValue::Int8(a)) => a as f64,
+            Some(ScalarValue::Int16(a)) => a as f64,
+            Some(ScalarValue::Int32(a)) => a as f64,
+            Some(ScalarValue::Int64(a)) => a as f64,
+            Some(ScalarValue::Float32(a)) => a as f64,
+            Some(ScalarValue::Float64(a)) => a as f64,
+            Some(ScalarValue::Null) => {
+                return Some(ScalarValue::Null);
+            }
+            _ => {
+                return None;
+            }
+        };
+        let count = match self.count_value.result() {
+            Some(ScalarValue::UInt64(a)) => a as f64,
+            Some(ScalarValue::Null) => {
+                return Some(ScalarValue::Null);
+            }
+            _ => {
+                return None;
+            }
+        };
+        Some(ScalarValue::Float64(sum / count))
+    }
+
+    fn data_type(&self) -> &DataType {
+        &self.data_type
+    }
+}
+
 /// Implementation of COUNT aggregate function
 #[derive(Debug)]
 struct CountFunction {
@@ -364,28 +436,27 @@ impl AggregateFunction for CountFunction {
         "count"
     }
 
-    fn accumulate_scalar(
-        &mut self,
-        value: &Option<ScalarValue>,
-        rollup: bool,
-    ) -> Result<()> {
-        if rollup {
-            // in rollup mode, the counts are added together
-            if let Some(ScalarValue::UInt64(n)) = value {
-                self.value = match self.value {
-                    Some(cur_value) => Some(cur_value + *n),
-                    _ => Some(*n),
-                }
-            }
-        } else {
-            // count the value if it is not null
-            if value.is_some() {
-                self.value = match self.value {
-                    Some(cur_value) => Some(cur_value + 1),
-                    _ => Some(1),
-                }
+    fn accumulate_scalar(&mut self, value: &Option<ScalarValue>) -> Result<()> {
+        if value.is_some() {
+            self.value = match self.value {
+                Some(cur_value) => Some(cur_value + 1),
+                None => Some(1),
             }
         }
+
+        Ok(())
+    }
+
+    fn accumulate_batch(&mut self, array: ArrayRef) -> Result<()> {
+        let accumulated_value = array_count(array)?;
+
+        if let Some(ScalarValue::UInt64(n)) = accumulated_value {
+            self.value = match self.value {
+                Some(cur_value) => Some(cur_value + n),
+                None => Some(n),
+            }
+        };
+
         Ok(())
     }
 
@@ -406,14 +477,14 @@ struct AccumulatorSet {
 }
 
 impl AccumulatorSet {
-    fn accumulate_scalar(
-        &mut self,
-        i: usize,
-        value: Option<ScalarValue>,
-        rollup: bool,
-    ) -> Result<()> {
+    fn accumulate_scalar(&mut self, i: usize, value: Option<ScalarValue>) -> Result<()> {
+        let mut accumulator = self.aggr_values[i].borrow_mut();
+        accumulator.accumulate_scalar(&value)
+    }
+
+    fn accumulate_batch(&mut self, i: usize, array: ArrayRef) -> Result<()> {
         let mut accumulator = self.aggr_values[i].borrow_mut();
-        accumulator.accumulate_scalar(&value, rollup)
+        accumulator.accumulate_batch(array)
     }
 
     fn values(&self) -> Vec<Option<ScalarValue>> {
@@ -449,6 +520,10 @@ fn create_accumulators(
                 Ok(Rc::new(RefCell::new(SumFunction::new(e.data_type())))
                     as Rc<RefCell<AggregateFunction>>)
             }
+            AggregateType::Avg => {
+                Ok(Rc::new(RefCell::new(AvgFunction::new(e.data_type())))
+                    as Rc<RefCell<AggregateFunction>>)
+            }
             AggregateType::Count => Ok(Rc::new(RefCell::new(CountFunction::new()))
                 as Rc<RefCell<AggregateFunction>>),
             _ => Err(ExecutionError::ExecutionError(
@@ -730,7 +805,7 @@ fn update_accumulators(
             }
         };
 
-        accumulator_set.accumulate_scalar(j, value, false)?;
+        accumulator_set.accumulate_scalar(j, value)?;
     }
     Ok(())
 }
@@ -834,18 +909,11 @@ impl AggregateRelation {
                 // evaluate the argument to the aggregate function
                 let array = self.aggr_expr[i].evaluate_arg(&batch)?;
                 match self.aggr_expr[i].aggr_type() {
-                    AggregateType::Min => {
-                        accumulator_set.accumulate_scalar(i, array_min(array)?, true)?
-                    }
-                    AggregateType::Max => {
-                        accumulator_set.accumulate_scalar(i, array_max(array)?, true)?
-                    }
-                    AggregateType::Sum => {
-                        accumulator_set.accumulate_scalar(i, array_sum(array)?, true)?
-                    }
-                    AggregateType::Count => {
-                        accumulator_set.accumulate_scalar(i, array_count(array)?, true)?
-                    }
+                    AggregateType::Min => accumulator_set.accumulate_batch(i, array)?,
+                    AggregateType::Max => accumulator_set.accumulate_batch(i, array)?,
+                    AggregateType::Sum => accumulator_set.accumulate_batch(i, array)?,
+                    AggregateType::Count => accumulator_set.accumulate_batch(i, array)?,
+                    AggregateType::Avg => accumulator_set.accumulate_batch(i, array)?,
                     _ => {
                         return Err(ExecutionError::NotImplemented(
                             "Unsupported aggregate function".to_string(),
@@ -1243,7 +1311,7 @@ mod tests {
     }
 
     #[test]
-    fn test_min_max_sum_count_f64_group_by_uint32() {
+    fn test_min_max_sum_count_avg_f64_group_by_uint32() {
         let schema = aggr_test_schema();
         let testdata = env::var("ARROW_TEST_DATA").expect("ARROW_TEST_DATA not defined");
         let relation =
@@ -1298,22 +1366,34 @@ mod tests {
         )
         .unwrap();
 
+        let avg_expr = expression::compile_aggregate_expr(
+            &context,
+            &Expr::AggregateFunction {
+                name: String::from("avg"),
+                args: vec![Expr::Column(11)],
+                return_type: DataType::Float64,
+            },
+            &schema,
+        )
+        .unwrap();
+
         let aggr_schema = Arc::new(Schema::new(vec![
             Field::new("c2", DataType::UInt32, false),
             Field::new("min", DataType::Float64, false),
             Field::new("max", DataType::Float64, false),
             Field::new("sum", DataType::Float64, false),
             Field::new("count", DataType::UInt64, false),
+            Field::new("avg", DataType::Float64, false),
         ]));
 
         let mut projection = AggregateRelation::new(
             aggr_schema,
             relation,
             vec![group_by_expr],
-            vec![min_expr, max_expr, sum_expr, count_expr],
+            vec![min_expr, max_expr, sum_expr, count_expr, avg_expr],
         );
         let batch = projection.next().unwrap().unwrap();
-        assert_eq!(5, batch.num_columns());
+        assert_eq!(6, batch.num_columns());
         assert_eq!(5, batch.num_rows());
 
         let a = batch
@@ -1341,36 +1421,46 @@ mod tests {
             .as_any()
             .downcast_ref::<UInt64Array>()
             .unwrap();
+        let avg = batch
+            .column(5)
+            .as_any()
+            .downcast_ref::<Float64Array>()
+            .unwrap();
 
         assert_eq!(4, a.value(0));
         assert_eq!(0.02182578039211991, min.value(0));
         assert_eq!(0.9237877978193884, max.value(0));
         assert_eq!(9.253864188402662, sum.value(0));
         assert_eq!(23, count.value(0));
+        assert_eq!(0.40234192123489837, avg.value(0));
 
         assert_eq!(2, a.value(1));
         assert_eq!(0.16301110515739792, min.value(1));
         assert_eq!(0.991517828651004, max.value(1));
         assert_eq!(14.400412325480858, sum.value(1));
         assert_eq!(22, count.value(1));
+        assert_eq!(0.6545641966127662, avg.value(1));
 
         assert_eq!(5, a.value(2));
         assert_eq!(0.01479305307777301, min.value(2));
         assert_eq!(0.9723580396501548, max.value(2));
         assert_eq!(6.037181692266781, sum.value(2));
         assert_eq!(14, count.value(2));
+        assert_eq!(0.4312272637333415, avg.value(2));
 
         assert_eq!(3, a.value(3));
         assert_eq!(0.047343434291126085, min.value(3));
         assert_eq!(0.9293883502480845, max.value(3));
         assert_eq!(9.966125219358322, sum.value(3));
         assert_eq!(19, count.value(3));
+        assert_eq!(0.5245329062820169, avg.value(3));
 
         assert_eq!(1, a.value(4));
         assert_eq!(0.05636955101974106, min.value(4));
         assert_eq!(0.9965400387585364, max.value(4));
         assert_eq!(11.239667565763519, sum.value(4));
-        assert_eq!(14, count.value(2));
+        assert_eq!(22, count.value(4));
+        assert_eq!(0.5108939802619781, avg.value(4));
     }
 
     fn aggr_test_schema() -> Arc<Schema> {
diff --git a/rust/datafusion/src/execution/expression.rs b/rust/datafusion/src/execution/expression.rs
index 74081d5..6282e77 100644
--- a/rust/datafusion/src/execution/expression.rs
+++ b/rust/datafusion/src/execution/expression.rs
@@ -121,6 +121,7 @@ pub(super) fn compile_aggregate_expr(
                 "max" => Ok(AggregateType::Max),
                 "count" => Ok(AggregateType::Count),
                 "sum" => Ok(AggregateType::Sum),
+                "avg" => Ok(AggregateType::Avg),
                 _ => Err(ExecutionError::General(format!(
                     "Unsupported aggregate function '{}'",
                     name
diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs
index e21e9c7..81c52c9 100644
--- a/rust/datafusion/tests/sql.rs
+++ b/rust/datafusion/tests/sql.rs
@@ -78,6 +78,98 @@ fn csv_query_group_by_int_min_max() {
 }
 
 #[test]
+fn csv_query_avg() {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx);
+    //TODO add ORDER BY once supported, to make this test deterministic
+    let sql = "SELECT avg(c12) FROM aggregate_test_100";
+    let actual = execute(&mut ctx, sql);
+    let expected = "0.5089725099127211\n".to_string();
+    assert_eq!(expected, actual);
+}
+
+#[test]
+fn csv_query_group_by_avg() {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx);
+    //TODO add ORDER BY once supported, to make this test deterministic
+    let sql = "SELECT c1, avg(c12) FROM aggregate_test_100 GROUP BY c1";
+    let actual = execute(&mut ctx, sql);
+    let expected = "\"d\"\t0.48855379387549824\n\"c\"\t0.6600456536439784\n\"b\"\t0.41040709263815384\n\"a\"\t0.48754517466109415\n\"e\"\t0.48600669271341534\n".to_string();
+    assert_eq!(expected, actual);
+}
+
+#[test]
+fn csv_query_avg_multi_batch() {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx);
+    //TODO add ORDER BY once supported, to make this test deterministic
+    let sql = "SELECT avg(c12) FROM aggregate_test_100";
+    let plan = ctx.create_logical_plan(&sql).unwrap();
+    let results = ctx.execute(&plan, 4).unwrap();
+    let mut relation = results.borrow_mut();
+    let batch = relation.next().unwrap().unwrap();
+    let column = batch.column(0);
+    let array = column.as_any().downcast_ref::<Float64Array>().unwrap();
+    let actual = array.value(0);
+    let expected = 0.5089725;
+    // Due to floating-point accuracy, different batch sizes can lead to slightly different answers.
+    assert!((expected - actual).abs() < 0.01);
+}
+
+#[test]
+fn csv_query_group_by_avg_multi_batch() {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx);
+    //TODO add ORDER BY once supported, to make this test deterministic
+    let sql = "SELECT c1, avg(c12) FROM aggregate_test_100 GROUP BY c1";
+    let plan = ctx.create_logical_plan(&sql).unwrap();
+    let results = ctx.execute(&plan, 4).unwrap();
+    let mut relation = results.borrow_mut();
+    let mut actual_vec = Vec::new();
+    while let Some(batch) = relation.next().unwrap() {
+        let column = batch.column(1);
+        let array = column.as_any().downcast_ref::<Float64Array>().unwrap();
+
+        for row_index in 0..batch.num_rows() {
+            actual_vec.push(array.value(row_index));
+        }
+    }
+
+    let expect_vec = vec![0.48855379, 0.66004565, 0.41040709, 0.48754517];
+
+    actual_vec
+        .iter()
+        .zip(expect_vec.iter())
+        .for_each(|(actual, expect)| {
+            // Due to floating-point accuracy, different batch sizes can lead to slightly different answers.
+            assert!((expect - actual).abs() < 0.01);
+        });
+}
+
+#[test]
+fn csv_query_count() {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx);
+    //TODO add ORDER BY once supported, to make this test deterministic
+    let sql = "SELECT count(c12) FROM aggregate_test_100";
+    let actual = execute(&mut ctx, sql);
+    let expected = "100\n".to_string();
+    assert_eq!(expected, actual);
+}
+
+#[test]
+fn csv_query_group_by_int_count() {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx);
+    //TODO add ORDER BY once supported, to make this test deterministic
+    let sql = "SELECT count(c12) FROM aggregate_test_100 GROUP BY c1";
+    let actual = execute(&mut ctx, sql);
+    let expected = "\"d\"\t18\n\"c\"\t21\n\"b\"\t19\n\"a\"\t21\n\"e\"\t21\n".to_string();
+    assert_eq!(expected, actual);
+}
+
+#[test]
 fn csv_query_group_by_string_min_max() {
     let mut ctx = ExecutionContext::new();
     register_aggregate_csv(&mut ctx);
@@ -149,7 +241,8 @@ fn csv_query_limit_zero() {
     assert_eq!(expected, actual);
 }
 
-//TODO Uncomment the following test when ORDER BY is implemented to be able to test ORDER BY + LIMIT
+//TODO Uncomment the following test when ORDER BY is implemented to be able to test ORDER
+// BY + LIMIT
 /*
 #[test]
 fn csv_query_limit_with_order_by() {