You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ag...@apache.org on 2019/04/01 22:35:12 UTC

[arrow] branch master updated: ARROW-4596: [Rust] [DataFusion] Implement COUNT

This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new f9e21ae  ARROW-4596: [Rust] [DataFusion] Implement COUNT
f9e21ae is described below

commit f9e21ae16ff77e890240ded713d1adfefe8649ba
Author: Andy Grove <an...@gmail.com>
AuthorDate: Mon Apr 1 16:34:59 2019 -0600

    ARROW-4596: [Rust] [DataFusion] Implement COUNT
    
    This builds on https://github.com/apache/arrow/pull/4035 to implement COUNT. I based this on work done by @LukeMathWalker in https://github.com/apache/arrow/pull/4020
    
    Author: Andy Grove <an...@gmail.com>
    
    Closes #4036 from andygrove/count and squashes the following commits:
    
    612c963a <Andy Grove> Address PR feedback
    84716528 <Andy Grove> Address PR feedback
    cc55c273 <Andy Grove> revert change
    f2f92769 <Andy Grove> Fix typo in comments
    d63b98a1 <Andy Grove> Fix bug in SQL query planner where COUNT(1) was being translated incorrectly
    611305b0 <Andy Grove> Add better documentation for AggregateFunction
    f580a146 <Andy Grove> Implement COUNT
    29a7778a <Andy Grove> Implement COUNT
---
 rust/datafusion/src/execution/aggregate.rs  | 243 +++++++++++++++++++++++-----
 rust/datafusion/src/execution/expression.rs |   4 +-
 rust/datafusion/src/sql/planner.rs          |  18 ++-
 rust/datafusion/tests/sql.rs                |  10 ++
 4 files changed, 230 insertions(+), 45 deletions(-)

diff --git a/rust/datafusion/src/execution/aggregate.rs b/rust/datafusion/src/execution/aggregate.rs
index 84fe925..7068c6a 100644
--- a/rust/datafusion/src/execution/aggregate.rs
+++ b/rust/datafusion/src/execution/aggregate.rs
@@ -80,16 +80,37 @@ enum GroupByScalar {
     Utf8(String),
 }
 
-/// Common trait for all aggregation functions
+/// Aggregate function that can accept individual values and compute an aggregate
 trait AggregateFunction {
     /// Get the function name (used for debugging)
     fn name(&self) -> &str;
-    fn accumulate_scalar(&mut self, value: &Option<ScalarValue>) -> Result<()>;
-    fn result(&self) -> &Option<ScalarValue>;
+
+    /// Update the current aggregate value based on a new value. A value of `None` represents a
+    /// null value. If rollup is false, then this aggregate function instance is being used
+    /// to aggregate individual values within a RecordBatch. If rollup is true then the aggregate
+    /// function instance is being used to combine the aggregates for multiple batches.
+    /// For some aggregate operations, such as `min`, `max`, and `sum`, the logic is the same
+    /// regardless of whether rollup is true or false. For example, `min` can be implemented as
+    /// `min(min(batch1), min(batch2), ..)`. However for `count` the logic is
+    /// `sum(count(batch1), count(batch2), ..)`.
+    fn accumulate_scalar(
+        &mut self,
+        value: &Option<ScalarValue>,
+        rollup: bool,
+    ) -> Result<()>;
+
+    /// Return the result of the aggregate function after all values have been processed
+    /// by calls to `accumulate_scalar`.
+    fn result(&self) -> Option<ScalarValue>;
+
+    /// Get the data type of the result of the aggregate function. For some operations,
+    /// such as `min`, `max`, and `sum`, the data type will be the same as the data type
+    /// of the argument. For other aggregates, such as `count`, the data type is independent
+    /// of the data type of the input.
     fn data_type(&self) -> &DataType;
 }
 
-/// Implemntation of MIN aggregate function
+/// Implementation of MIN aggregate function
 #[derive(Debug)]
 struct MinFunction {
     data_type: DataType,
@@ -110,7 +131,11 @@ impl AggregateFunction for MinFunction {
         "min"
     }
 
-    fn accumulate_scalar(&mut self, value: &Option<ScalarValue>) -> Result<()> {
+    fn accumulate_scalar(
+        &mut self,
+        value: &Option<ScalarValue>,
+        _rollup: bool,
+    ) -> Result<()> {
         if self.value.is_none() {
             self.value = value.clone();
         } else if value.is_some() {
@@ -155,8 +180,8 @@ impl AggregateFunction for MinFunction {
         Ok(())
     }
 
-    fn result(&self) -> &Option<ScalarValue> {
-        &self.value
+    fn result(&self) -> Option<ScalarValue> {
+        self.value.clone()
     }
 
     fn data_type(&self) -> &DataType {
@@ -164,7 +189,7 @@ impl AggregateFunction for MinFunction {
     }
 }
 
-/// Implemntation of MAX aggregate function
+/// Implementation of MAX aggregate function
 #[derive(Debug)]
 struct MaxFunction {
     data_type: DataType,
@@ -185,7 +210,11 @@ impl AggregateFunction for MaxFunction {
         "max"
     }
 
-    fn accumulate_scalar(&mut self, value: &Option<ScalarValue>) -> Result<()> {
+    fn accumulate_scalar(
+        &mut self,
+        value: &Option<ScalarValue>,
+        _rollup: bool,
+    ) -> Result<()> {
         if self.value.is_none() {
             self.value = value.clone();
         } else if value.is_some() {
@@ -230,8 +259,8 @@ impl AggregateFunction for MaxFunction {
         Ok(())
     }
 
-    fn result(&self) -> &Option<ScalarValue> {
-        &self.value
+    fn result(&self) -> Option<ScalarValue> {
+        self.value.clone()
     }
 
     fn data_type(&self) -> &DataType {
@@ -239,7 +268,7 @@ impl AggregateFunction for MaxFunction {
     }
 }
 
-/// Implemntation of SUM aggregate function
+/// Implementation of SUM aggregate function
 #[derive(Debug)]
 struct SumFunction {
     data_type: DataType,
@@ -260,7 +289,11 @@ impl AggregateFunction for SumFunction {
         "sum"
     }
 
-    fn accumulate_scalar(&mut self, value: &Option<ScalarValue>) -> Result<()> {
+    fn accumulate_scalar(
+        &mut self,
+        value: &Option<ScalarValue>,
+        _rollup: bool,
+    ) -> Result<()> {
         if self.value.is_none() {
             self.value = value.clone();
         } else if value.is_some() {
@@ -305,8 +338,8 @@ impl AggregateFunction for SumFunction {
         Ok(())
     }
 
-    fn result(&self) -> &Option<ScalarValue> {
-        &self.value
+    fn result(&self) -> Option<ScalarValue> {
+        self.value.clone()
     }
 
     fn data_type(&self) -> &DataType {
@@ -314,14 +347,73 @@ impl AggregateFunction for SumFunction {
     }
 }
 
+/// Implementation of COUNT aggregate function: counts non-null input values
+#[derive(Debug)]
+struct CountFunction {
+    value: Option<u64>,
+}
+
+impl CountFunction {
+    fn new() -> Self {
+        Self { value: None }
+    }
+}
+
+impl AggregateFunction for CountFunction {
+    fn name(&self) -> &str {
+        "count"
+    }
+
+    fn accumulate_scalar(
+        &mut self,
+        value: &Option<ScalarValue>,
+        rollup: bool,
+    ) -> Result<()> {
+        if rollup {
+            // rollup: `value` is a per-batch count, so add it to the total
+            if let Some(ScalarValue::UInt64(n)) = value {
+                self.value = match self.value {
+                    Some(cur_value) => Some(cur_value + *n),
+                    _ => Some(*n),
+                }
+            }
+        } else {
+            // count the value if it is not null
+            if value.is_some() {
+                self.value = match self.value {
+                    Some(cur_value) => Some(cur_value + 1),
+                    _ => Some(1),
+                }
+            }
+        }
+        Ok(())
+    }
+
+    fn result(&self) -> Option<ScalarValue> {
+        // COUNT is never NULL (SQL semantics): if no non-null values were
+        // accumulated, e.g. empty input or an all-null group, the result is 0
+        let count = self.value.unwrap_or(0);
+        Some(ScalarValue::UInt64(count))
+    }
+
+    fn data_type(&self) -> &DataType {
+        &DataType::UInt64
+    }
+}
+
 struct AccumulatorSet {
     aggr_values: Vec<Rc<RefCell<AggregateFunction>>>,
 }
 
 impl AccumulatorSet {
-    fn accumulate_scalar(&mut self, i: usize, value: Option<ScalarValue>) -> Result<()> {
+    fn accumulate_scalar(
+        &mut self,
+        i: usize, // index into `aggr_values`
+        value: Option<ScalarValue>, // `None` represents a null value
+        rollup: bool, // true when `value` is a per-batch partial aggregate
+    ) -> Result<()> {
         let mut accumulator = self.aggr_values[i].borrow_mut();
-        accumulator.accumulate_scalar(&value)
+        accumulator.accumulate_scalar(&value, rollup)
     }
 
     fn values(&self) -> Vec<Option<ScalarValue>> {
@@ -357,6 +449,8 @@ fn create_accumulators(
                 Ok(Rc::new(RefCell::new(SumFunction::new(e.data_type())))
                     as Rc<RefCell<AggregateFunction>>)
             }
+            AggregateType::Count => Ok(Rc::new(RefCell::new(CountFunction::new()))
+                as Rc<RefCell<AggregateFunction>>),
             _ => Err(ExecutionError::ExecutionError(
                 "unsupported aggregate function".to_string(),
             )),
@@ -366,8 +460,8 @@ fn create_accumulators(
     Ok(AccumulatorSet { aggr_values })
 }
 
-fn array_min(array: ArrayRef, dt: &DataType) -> Result<Option<ScalarValue>> {
-    match dt {
+fn array_min(array: ArrayRef) -> Result<Option<ScalarValue>> {
+    match array.data_type() {
         DataType::UInt8 => {
             match compute::min(array.as_any().downcast_ref::<UInt8Array>().unwrap()) {
                 Some(n) => Ok(Some(ScalarValue::UInt8(n))),
@@ -434,8 +528,8 @@ fn array_min(array: ArrayRef, dt: &DataType) -> Result<Option<ScalarValue>> {
     }
 }
 
-fn array_max(array: ArrayRef, dt: &DataType) -> Result<Option<ScalarValue>> {
-    match dt {
+fn array_max(array: ArrayRef) -> Result<Option<ScalarValue>> {
+    match array.data_type() {
         DataType::UInt8 => {
             match compute::max(array.as_any().downcast_ref::<UInt8Array>().unwrap()) {
                 Some(n) => Ok(Some(ScalarValue::UInt8(n))),
@@ -502,8 +596,8 @@ fn array_max(array: ArrayRef, dt: &DataType) -> Result<Option<ScalarValue>> {
     }
 }
 
-fn array_sum(array: ArrayRef, dt: &DataType) -> Result<Option<ScalarValue>> {
-    match dt {
+fn array_sum(array: ArrayRef) -> Result<Option<ScalarValue>> {
+    match array.data_type() {
         DataType::UInt8 => {
             match compute::sum(array.as_any().downcast_ref::<UInt8Array>().unwrap()) {
                 Some(n) => Ok(Some(ScalarValue::UInt8(n))),
@@ -570,6 +664,12 @@ fn array_sum(array: ArrayRef, dt: &DataType) -> Result<Option<ScalarValue>> {
     }
 }
 
+fn array_count(array: ArrayRef) -> Result<Option<ScalarValue>> {
+    // COUNT ignores nulls, so the batch contributes only its non-null slots
+    let non_null = array.len() - array.null_count();
+    Ok(Some(ScalarValue::UInt64(non_null as u64)))
+}
+
 fn update_accumulators(
     batch: &RecordBatch,
     row: usize,
@@ -579,9 +679,9 @@ fn update_accumulators(
     // update the accumulators
     for j in 0..accumulator_set.aggr_values.len() {
         // evaluate the argument to the aggregate function
-        let array = aggr_expr[j].invoke(batch)?;
+        let array = aggr_expr[j].evaluate_arg(batch)?;
 
-        let value: Option<ScalarValue> = match aggr_expr[j].data_type() {
+        let value: Option<ScalarValue> = match array.data_type() {
             DataType::UInt8 => {
                 let z = array.as_any().downcast_ref::<UInt8Array>().unwrap();
                 Some(ScalarValue::UInt8(z.value(row)))
@@ -630,7 +730,7 @@ fn update_accumulators(
             }
         };
 
-        accumulator_set.accumulate_scalar(j, value)?;
+        accumulator_set.accumulate_scalar(j, value, false)?;
     }
     Ok(())
 }
@@ -660,7 +760,7 @@ macro_rules! array_from_scalar {
         let mut err = false;
         match $ACCUM.result() {
             Some(ScalarValue::$TY(n)) => {
-                b.append_value(*n)?;
+                b.append_value(n)?;
             }
             None => {
                 b.append_null()?;
@@ -732,19 +832,19 @@ impl AggregateRelation {
         while let Some(batch) = self.input.borrow_mut().next()? {
             for i in 0..aggr_expr_count {
                 // evaluate the argument to the aggregate function
-                let array = self.aggr_expr[i].invoke(&batch)?;
-
-                let t = self.aggr_expr[i].data_type();
-
+                let array = self.aggr_expr[i].evaluate_arg(&batch)?;
                 match self.aggr_expr[i].aggr_type() {
                     AggregateType::Min => {
-                        accumulator_set.accumulate_scalar(i, array_min(array, &t)?)?
+                        accumulator_set.accumulate_scalar(i, array_min(array)?, true)?
                     }
                     AggregateType::Max => {
-                        accumulator_set.accumulate_scalar(i, array_max(array, &t)?)?
+                        accumulator_set.accumulate_scalar(i, array_max(array)?, true)?
                     }
                     AggregateType::Sum => {
-                        accumulator_set.accumulate_scalar(i, array_sum(array, &t)?)?
+                        accumulator_set.accumulate_scalar(i, array_sum(array)?, true)?
+                    }
+                    AggregateType::Count => {
+                        accumulator_set.accumulate_scalar(i, array_count(array)?, true)?
                     }
                     _ => {
                         return Err(ExecutionError::NotImplemented(
@@ -1071,6 +1171,41 @@ mod tests {
     }
 
     #[test]
+    fn count() {
+        let schema = aggr_test_schema();
+        let testdata = env::var("ARROW_TEST_DATA").expect("ARROW_TEST_DATA not defined");
+        let relation =
+            load_csv(&format!("{}/csv/aggregate_test_100.csv", testdata), &schema);
+        let context = ExecutionContext::new();
+        let aggr_expr = vec![expression::compile_aggregate_expr(
+            &context,
+            &Expr::AggregateFunction {
+                name: String::from("count"),
+                args: vec![Expr::Column(11)],
+                return_type: DataType::UInt64,
+            },
+            &schema,
+        )
+        .unwrap()];
+        let aggr_schema = Arc::new(Schema::new(vec![Field::new(
+            "count",
+            DataType::UInt64,
+            false,
+        )]));
+        // no GROUP BY expressions: COUNT over the whole file
+        let mut projection =
+            AggregateRelation::new(aggr_schema, relation, vec![], aggr_expr);
+        let batch = projection.next().unwrap().unwrap();
+        assert_eq!(1, batch.num_columns());
+        let count = batch
+            .column(0)
+            .as_any()
+            .downcast_ref::<UInt64Array>()
+            .unwrap();
+        assert_eq!(100, count.value(0));
+    }
+
+    #[test]
     fn max_f64_group_by_string() {
         let schema = aggr_test_schema();
         let testdata = env::var("ARROW_TEST_DATA").expect("ARROW_TEST_DATA not defined");
@@ -1108,7 +1243,7 @@ mod tests {
     }
 
     #[test]
-    fn test_min_max_sum_f64_group_by_uint32() {
+    fn test_min_max_sum_count_f64_group_by_uint32() {
         let schema = aggr_test_schema();
         let testdata = env::var("ARROW_TEST_DATA").expect("ARROW_TEST_DATA not defined");
         let relation =
@@ -1152,21 +1287,33 @@ mod tests {
         )
         .unwrap();
 
+        let count_expr = expression::compile_aggregate_expr(
+            &context,
+            &Expr::AggregateFunction {
+                name: String::from("count"),
+                args: vec![Expr::Column(11)],
+                return_type: DataType::UInt64,
+            },
+            &schema,
+        )
+        .unwrap();
+
         let aggr_schema = Arc::new(Schema::new(vec![
             Field::new("c2", DataType::UInt32, false),
             Field::new("min", DataType::Float64, false),
             Field::new("max", DataType::Float64, false),
             Field::new("sum", DataType::Float64, false),
+            Field::new("count", DataType::UInt64, false),
         ]));
 
         let mut projection = AggregateRelation::new(
             aggr_schema,
             relation,
             vec![group_by_expr],
-            vec![min_expr, max_expr, sum_expr],
+            vec![min_expr, max_expr, sum_expr, count_expr],
         );
         let batch = projection.next().unwrap().unwrap();
-        assert_eq!(4, batch.num_columns());
+        assert_eq!(5, batch.num_columns());
         assert_eq!(5, batch.num_rows());
 
         let a = batch
@@ -1189,21 +1336,41 @@ mod tests {
             .as_any()
             .downcast_ref::<Float64Array>()
             .unwrap();
+        let count = batch
+            .column(4)
+            .as_any()
+            .downcast_ref::<UInt64Array>()
+            .unwrap();
 
         assert_eq!(4, a.value(0));
         assert_eq!(0.02182578039211991, min.value(0));
         assert_eq!(0.9237877978193884, max.value(0));
         assert_eq!(9.253864188402662, sum.value(0));
+        assert_eq!(23, count.value(0));
 
         assert_eq!(2, a.value(1));
         assert_eq!(0.16301110515739792, min.value(1));
         assert_eq!(0.991517828651004, max.value(1));
         assert_eq!(14.400412325480858, sum.value(1));
+        assert_eq!(22, count.value(1));
 
         assert_eq!(5, a.value(2));
         assert_eq!(0.01479305307777301, min.value(2));
         assert_eq!(0.9723580396501548, max.value(2));
         assert_eq!(6.037181692266781, sum.value(2));
+        assert_eq!(14, count.value(2));
+
+        assert_eq!(3, a.value(3));
+        assert_eq!(0.047343434291126085, min.value(3));
+        assert_eq!(0.9293883502480845, max.value(3));
+        assert_eq!(9.966125219358322, sum.value(3));
+        assert_eq!(19, count.value(3));
+
+        assert_eq!(1, a.value(4));
+        assert_eq!(0.05636955101974106, min.value(4));
+        assert_eq!(0.9965400387585364, max.value(4));
+        assert_eq!(11.239667565763519, sum.value(4));
+        assert_eq!(22, count.value(4)); // 23 + 22 + 14 + 19 + 22 = 100 rows
     }
 
     fn aggr_test_schema() -> Arc<Schema> {
@@ -1225,7 +1392,7 @@ mod tests {
     }
 
     fn load_csv(filename: &str, schema: &Arc<Schema>) -> Rc<RefCell<Relation>> {
-        let ds = CsvBatchIterator::new(filename, schema.clone(), true, &None, 1024);
+        let ds = CsvBatchIterator::new(filename, schema.clone(), true, &None, 10);
         Rc::new(RefCell::new(DataSourceRelation::new(Arc::new(Mutex::new(
             ds,
         )))))
diff --git a/rust/datafusion/src/execution/expression.rs b/rust/datafusion/src/execution/expression.rs
index 45adaac..74081d5 100644
--- a/rust/datafusion/src/execution/expression.rs
+++ b/rust/datafusion/src/execution/expression.rs
@@ -91,8 +91,8 @@ impl CompiledAggregateExpression {
         &self.t
     }
 
-    /// invoke the compiled expression
-    pub fn invoke(&self, batch: &RecordBatch) -> Result<ArrayRef> {
+    /// Evaluate the aggregate function's argument against `batch` (only `args[0]` is used)
+    pub fn evaluate_arg(&self, batch: &RecordBatch) -> Result<ArrayRef> {
         self.args[0].invoke(batch)
     }
 }
diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs
index 50cb77d..71ad507 100644
--- a/rust/datafusion/src/sql/planner.rs
+++ b/rust/datafusion/src/sql/planner.rs
@@ -315,12 +315,12 @@ impl SqlToRel {
                         let rex_args = args
                             .iter()
                             .map(|a| match a {
-                                // this feels hacky but translate COUNT(1)/COUNT(*) to
-                                // COUNT(first_column)
-                                ASTNode::SQLValue(sqlparser::sqlast::Value::Long(1)) => {
-                                    Ok(Expr::Column(0))
+                                ASTNode::SQLValue(sqlparser::sqlast::Value::Long(_)) => {
+                                    Ok(Expr::Literal(ScalarValue::UInt8(1)))
                                 }
-                                ASTNode::SQLWildcard => Ok(Expr::Column(0)),
+                                ASTNode::SQLWildcard => {
+                                    Ok(Expr::Literal(ScalarValue::UInt8(1)))
+                                },
                                 _ => self.sql_to_rex(a, schema),
                             })
                             .collect::<Result<Vec<Expr>>>()?;
@@ -482,6 +482,14 @@ mod tests {
     #[test]
     fn select_count_one() {
         let sql = "SELECT COUNT(1) FROM person";
+        let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\
+                        \n  TableScan: person projection=None";
+        quick_test(sql, expected);
+    }
+
+    #[test]
+    fn select_count_column() {
+        let sql = "SELECT COUNT(id) FROM person";
         let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(#0)]]\
                         \n  TableScan: person projection=None";
         quick_test(sql, expected);
diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs
index 7b04609..e21e9c7 100644
--- a/rust/datafusion/tests/sql.rs
+++ b/rust/datafusion/tests/sql.rs
@@ -47,6 +47,16 @@ fn parquet_query() {
 }
 
 #[test]
+fn csv_count_star() {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx);
+    // COUNT(*), COUNT(1) and COUNT(c1) must all agree on the 100-row test file
+    let sql = "SELECT COUNT(*), COUNT(1), COUNT(c1) FROM aggregate_test_100";
+    let expected = "100\t100\t100\n".to_string();
+    assert_eq!(expected, execute(&mut ctx, sql));
+}
+
+#[test]
 fn csv_query_with_predicate() {
     let mut ctx = ExecutionContext::new();
     register_aggregate_csv(&mut ctx);