You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ag...@apache.org on 2019/04/01 22:35:12 UTC
[arrow] branch master updated: ARROW-4596: [Rust] [DataFusion] Implement COUNT
This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new f9e21ae ARROW-4596: [Rust] [DataFusion] Implement COUNT
f9e21ae is described below
commit f9e21ae16ff77e890240ded713d1adfefe8649ba
Author: Andy Grove <an...@gmail.com>
AuthorDate: Mon Apr 1 16:34:59 2019 -0600
ARROW-4596: [Rust] [DataFusion] Implement COUNT
This builds on https://github.com/apache/arrow/pull/4035 to implement COUNT. I based this on work done by @LukeMathWalker in https://github.com/apache/arrow/pull/4020
Author: Andy Grove <an...@gmail.com>
Closes #4036 from andygrove/count and squashes the following commits:
612c963a <Andy Grove> Address PR feedback
84716528 <Andy Grove> Address PR feedback
cc55c273 <Andy Grove> revert change
f2f92769 <Andy Grove> Fix typo in comments
d63b98a1 <Andy Grove> Fix bug in SQL query planner where COUNT(1) was being translated incorrectly
611305b0 <Andy Grove> Add better documentation for AggregateFunction
f580a146 <Andy Grove> Implement COUNT
29a7778a <Andy Grove> Implement COUNT
---
rust/datafusion/src/execution/aggregate.rs | 243 +++++++++++++++++++++++-----
rust/datafusion/src/execution/expression.rs | 4 +-
rust/datafusion/src/sql/planner.rs | 18 ++-
rust/datafusion/tests/sql.rs | 10 ++
4 files changed, 230 insertions(+), 45 deletions(-)
diff --git a/rust/datafusion/src/execution/aggregate.rs b/rust/datafusion/src/execution/aggregate.rs
index 84fe925..7068c6a 100644
--- a/rust/datafusion/src/execution/aggregate.rs
+++ b/rust/datafusion/src/execution/aggregate.rs
@@ -80,16 +80,37 @@ enum GroupByScalar {
Utf8(String),
}
-/// Common trait for all aggregation functions
+/// Aggregate function that can accept individual values and compute an aggregate
trait AggregateFunction {
/// Get the function name (used for debugging)
fn name(&self) -> &str;
- fn accumulate_scalar(&mut self, value: &Option<ScalarValue>) -> Result<()>;
- fn result(&self) -> &Option<ScalarValue>;
+
+ /// Update the current aggregate value based on a new value. A value of `None` represents a
+ /// null value. If rollup is false, then this aggregate function instance is being used
+ /// to aggregate individual values within a RecordBatch. If rollup is true then the aggregate
+ /// function instance is being used to combine the aggregates for multiple batches.
+ /// For some aggregate operations, such as `min`, `max`, and `sum`, the logic is the same
+ /// regardless of whether rollup is true or false. For example, `min` can be implemented as
+ /// `min(min(batch1), min(batch2), ..)`. However for `count` the logic is
+ /// `sum(count(batch1), count(batch2), ..)`.
+ fn accumulate_scalar(
+ &mut self,
+ value: &Option<ScalarValue>,
+ rollup: bool,
+ ) -> Result<()>;
+
+ /// Return the result of the aggregate function after all values have been processed
+ /// by calls to `accumulate_scalar`.
+ fn result(&self) -> Option<ScalarValue>;
+
+ /// Get the data type of the result of the aggregate function. For some operations,
+ /// such as `min`, `max`, and `sum`, the data type will be the same as the data type
+ /// of the argument. For other aggregates, such as `count`, the data type is independent
+ /// of the data type of the input.
fn data_type(&self) -> &DataType;
}
-/// Implemntation of MIN aggregate function
+/// Implementation of MIN aggregate function
#[derive(Debug)]
struct MinFunction {
data_type: DataType,
@@ -110,7 +131,11 @@ impl AggregateFunction for MinFunction {
"min"
}
- fn accumulate_scalar(&mut self, value: &Option<ScalarValue>) -> Result<()> {
+ fn accumulate_scalar(
+ &mut self,
+ value: &Option<ScalarValue>,
+ _rollup: bool,
+ ) -> Result<()> {
if self.value.is_none() {
self.value = value.clone();
} else if value.is_some() {
@@ -155,8 +180,8 @@ impl AggregateFunction for MinFunction {
Ok(())
}
- fn result(&self) -> &Option<ScalarValue> {
- &self.value
+ fn result(&self) -> Option<ScalarValue> {
+ self.value.clone()
}
fn data_type(&self) -> &DataType {
@@ -164,7 +189,7 @@ impl AggregateFunction for MinFunction {
}
}
-/// Implemntation of MAX aggregate function
+/// Implementation of MAX aggregate function
#[derive(Debug)]
struct MaxFunction {
data_type: DataType,
@@ -185,7 +210,11 @@ impl AggregateFunction for MaxFunction {
"max"
}
- fn accumulate_scalar(&mut self, value: &Option<ScalarValue>) -> Result<()> {
+ fn accumulate_scalar(
+ &mut self,
+ value: &Option<ScalarValue>,
+ _rollup: bool,
+ ) -> Result<()> {
if self.value.is_none() {
self.value = value.clone();
} else if value.is_some() {
@@ -230,8 +259,8 @@ impl AggregateFunction for MaxFunction {
Ok(())
}
- fn result(&self) -> &Option<ScalarValue> {
- &self.value
+ fn result(&self) -> Option<ScalarValue> {
+ self.value.clone()
}
fn data_type(&self) -> &DataType {
@@ -239,7 +268,7 @@ impl AggregateFunction for MaxFunction {
}
}
-/// Implemntation of SUM aggregate function
+/// Implementation of SUM aggregate function
#[derive(Debug)]
struct SumFunction {
data_type: DataType,
@@ -260,7 +289,11 @@ impl AggregateFunction for SumFunction {
"sum"
}
- fn accumulate_scalar(&mut self, value: &Option<ScalarValue>) -> Result<()> {
+ fn accumulate_scalar(
+ &mut self,
+ value: &Option<ScalarValue>,
+ _rollup: bool,
+ ) -> Result<()> {
if self.value.is_none() {
self.value = value.clone();
} else if value.is_some() {
@@ -305,8 +338,8 @@ impl AggregateFunction for SumFunction {
Ok(())
}
- fn result(&self) -> &Option<ScalarValue> {
- &self.value
+ fn result(&self) -> Option<ScalarValue> {
+ self.value.clone()
}
fn data_type(&self) -> &DataType {
@@ -314,14 +347,73 @@ impl AggregateFunction for SumFunction {
}
}
+/// Implementation of COUNT aggregate function
+#[derive(Debug)]
+struct CountFunction {
+ value: Option<u64>,
+}
+
+impl CountFunction {
+ fn new() -> Self {
+ Self { value: None }
+ }
+}
+
+impl AggregateFunction for CountFunction {
+ fn name(&self) -> &str {
+ "count"
+ }
+
+ fn accumulate_scalar(
+ &mut self,
+ value: &Option<ScalarValue>,
+ rollup: bool,
+ ) -> Result<()> {
+ if rollup {
+ // in rollup mode, the counts are added together
+ if let Some(ScalarValue::UInt64(n)) = value {
+ self.value = match self.value {
+ Some(cur_value) => Some(cur_value + *n),
+ _ => Some(*n),
+ }
+ }
+ } else {
+ // count the value if it is not null
+ if value.is_some() {
+ self.value = match self.value {
+ Some(cur_value) => Some(cur_value + 1),
+ _ => Some(1),
+ }
+ }
+ }
+ Ok(())
+ }
+
+ fn result(&self) -> Option<ScalarValue> {
+ match self.value {
+ Some(n) => Some(ScalarValue::UInt64(n)),
+ None => None,
+ }
+ }
+
+ fn data_type(&self) -> &DataType {
+ &DataType::UInt64
+ }
+}
+
struct AccumulatorSet {
aggr_values: Vec<Rc<RefCell<AggregateFunction>>>,
}
impl AccumulatorSet {
- fn accumulate_scalar(&mut self, i: usize, value: Option<ScalarValue>) -> Result<()> {
+ fn accumulate_scalar(
+ &mut self,
+ i: usize,
+ value: Option<ScalarValue>,
+ rollup: bool,
+ ) -> Result<()> {
let mut accumulator = self.aggr_values[i].borrow_mut();
- accumulator.accumulate_scalar(&value)
+ accumulator.accumulate_scalar(&value, rollup)
}
fn values(&self) -> Vec<Option<ScalarValue>> {
@@ -357,6 +449,8 @@ fn create_accumulators(
Ok(Rc::new(RefCell::new(SumFunction::new(e.data_type())))
as Rc<RefCell<AggregateFunction>>)
}
+ AggregateType::Count => Ok(Rc::new(RefCell::new(CountFunction::new()))
+ as Rc<RefCell<AggregateFunction>>),
_ => Err(ExecutionError::ExecutionError(
"unsupported aggregate function".to_string(),
)),
@@ -366,8 +460,8 @@ fn create_accumulators(
Ok(AccumulatorSet { aggr_values })
}
-fn array_min(array: ArrayRef, dt: &DataType) -> Result<Option<ScalarValue>> {
- match dt {
+fn array_min(array: ArrayRef) -> Result<Option<ScalarValue>> {
+ match array.data_type() {
DataType::UInt8 => {
match compute::min(array.as_any().downcast_ref::<UInt8Array>().unwrap()) {
Some(n) => Ok(Some(ScalarValue::UInt8(n))),
@@ -434,8 +528,8 @@ fn array_min(array: ArrayRef, dt: &DataType) -> Result<Option<ScalarValue>> {
}
}
-fn array_max(array: ArrayRef, dt: &DataType) -> Result<Option<ScalarValue>> {
- match dt {
+fn array_max(array: ArrayRef) -> Result<Option<ScalarValue>> {
+ match array.data_type() {
DataType::UInt8 => {
match compute::max(array.as_any().downcast_ref::<UInt8Array>().unwrap()) {
Some(n) => Ok(Some(ScalarValue::UInt8(n))),
@@ -502,8 +596,8 @@ fn array_max(array: ArrayRef, dt: &DataType) -> Result<Option<ScalarValue>> {
}
}
-fn array_sum(array: ArrayRef, dt: &DataType) -> Result<Option<ScalarValue>> {
- match dt {
+fn array_sum(array: ArrayRef) -> Result<Option<ScalarValue>> {
+ match array.data_type() {
DataType::UInt8 => {
match compute::sum(array.as_any().downcast_ref::<UInt8Array>().unwrap()) {
Some(n) => Ok(Some(ScalarValue::UInt8(n))),
@@ -570,6 +664,12 @@ fn array_sum(array: ArrayRef, dt: &DataType) -> Result<Option<ScalarValue>> {
}
}
+fn array_count(array: ArrayRef) -> Result<Option<ScalarValue>> {
+ Ok(Some(ScalarValue::UInt64(
+ (array.len() - array.null_count()) as u64,
+ )))
+}
+
fn update_accumulators(
batch: &RecordBatch,
row: usize,
@@ -579,9 +679,9 @@ fn update_accumulators(
// update the accumulators
for j in 0..accumulator_set.aggr_values.len() {
// evaluate the argument to the aggregate function
- let array = aggr_expr[j].invoke(batch)?;
+ let array = aggr_expr[j].evaluate_arg(batch)?;
- let value: Option<ScalarValue> = match aggr_expr[j].data_type() {
+ let value: Option<ScalarValue> = match array.data_type() {
DataType::UInt8 => {
let z = array.as_any().downcast_ref::<UInt8Array>().unwrap();
Some(ScalarValue::UInt8(z.value(row)))
@@ -630,7 +730,7 @@ fn update_accumulators(
}
};
- accumulator_set.accumulate_scalar(j, value)?;
+ accumulator_set.accumulate_scalar(j, value, false)?;
}
Ok(())
}
@@ -660,7 +760,7 @@ macro_rules! array_from_scalar {
let mut err = false;
match $ACCUM.result() {
Some(ScalarValue::$TY(n)) => {
- b.append_value(*n)?;
+ b.append_value(n)?;
}
None => {
b.append_null()?;
@@ -732,19 +832,19 @@ impl AggregateRelation {
while let Some(batch) = self.input.borrow_mut().next()? {
for i in 0..aggr_expr_count {
// evaluate the argument to the aggregate function
- let array = self.aggr_expr[i].invoke(&batch)?;
-
- let t = self.aggr_expr[i].data_type();
-
+ let array = self.aggr_expr[i].evaluate_arg(&batch)?;
match self.aggr_expr[i].aggr_type() {
AggregateType::Min => {
- accumulator_set.accumulate_scalar(i, array_min(array, &t)?)?
+ accumulator_set.accumulate_scalar(i, array_min(array)?, true)?
}
AggregateType::Max => {
- accumulator_set.accumulate_scalar(i, array_max(array, &t)?)?
+ accumulator_set.accumulate_scalar(i, array_max(array)?, true)?
}
AggregateType::Sum => {
- accumulator_set.accumulate_scalar(i, array_sum(array, &t)?)?
+ accumulator_set.accumulate_scalar(i, array_sum(array)?, true)?
+ }
+ AggregateType::Count => {
+ accumulator_set.accumulate_scalar(i, array_count(array)?, true)?
}
_ => {
return Err(ExecutionError::NotImplemented(
@@ -1071,6 +1171,41 @@ mod tests {
}
#[test]
+ fn count() {
+ let schema = aggr_test_schema();
+ let relation = load_csv("../../testing/data/csv/aggregate_test_100.csv", &schema);
+ let context = ExecutionContext::new();
+
+ let aggr_expr = vec![expression::compile_aggregate_expr(
+ &context,
+ &Expr::AggregateFunction {
+ name: String::from("count"),
+ args: vec![Expr::Column(11)],
+ return_type: DataType::UInt64,
+ },
+ &schema,
+ )
+ .unwrap()];
+
+ let aggr_schema = Arc::new(Schema::new(vec![Field::new(
+ "count",
+ DataType::UInt64,
+ false,
+ )]));
+
+ let mut projection =
+ AggregateRelation::new(aggr_schema, relation, vec![], aggr_expr);
+ let batch = projection.next().unwrap().unwrap();
+ assert_eq!(1, batch.num_columns());
+ let count = batch
+ .column(0)
+ .as_any()
+ .downcast_ref::<UInt64Array>()
+ .unwrap();
+ assert_eq!(100, count.value(0));
+ }
+
+ #[test]
fn max_f64_group_by_string() {
let schema = aggr_test_schema();
let testdata = env::var("ARROW_TEST_DATA").expect("ARROW_TEST_DATA not defined");
@@ -1108,7 +1243,7 @@ mod tests {
}
#[test]
- fn test_min_max_sum_f64_group_by_uint32() {
+ fn test_min_max_sum_count_f64_group_by_uint32() {
let schema = aggr_test_schema();
let testdata = env::var("ARROW_TEST_DATA").expect("ARROW_TEST_DATA not defined");
let relation =
@@ -1152,21 +1287,33 @@ mod tests {
)
.unwrap();
+ let count_expr = expression::compile_aggregate_expr(
+ &context,
+ &Expr::AggregateFunction {
+ name: String::from("count"),
+ args: vec![Expr::Column(11)],
+ return_type: DataType::UInt64,
+ },
+ &schema,
+ )
+ .unwrap();
+
let aggr_schema = Arc::new(Schema::new(vec![
Field::new("c2", DataType::UInt32, false),
Field::new("min", DataType::Float64, false),
Field::new("max", DataType::Float64, false),
Field::new("sum", DataType::Float64, false),
+ Field::new("count", DataType::UInt64, false),
]));
let mut projection = AggregateRelation::new(
aggr_schema,
relation,
vec![group_by_expr],
- vec![min_expr, max_expr, sum_expr],
+ vec![min_expr, max_expr, sum_expr, count_expr],
);
let batch = projection.next().unwrap().unwrap();
- assert_eq!(4, batch.num_columns());
+ assert_eq!(5, batch.num_columns());
assert_eq!(5, batch.num_rows());
let a = batch
@@ -1189,21 +1336,41 @@ mod tests {
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
+ let count = batch
+ .column(4)
+ .as_any()
+ .downcast_ref::<UInt64Array>()
+ .unwrap();
assert_eq!(4, a.value(0));
assert_eq!(0.02182578039211991, min.value(0));
assert_eq!(0.9237877978193884, max.value(0));
assert_eq!(9.253864188402662, sum.value(0));
+ assert_eq!(23, count.value(0));
assert_eq!(2, a.value(1));
assert_eq!(0.16301110515739792, min.value(1));
assert_eq!(0.991517828651004, max.value(1));
assert_eq!(14.400412325480858, sum.value(1));
+ assert_eq!(22, count.value(1));
assert_eq!(5, a.value(2));
assert_eq!(0.01479305307777301, min.value(2));
assert_eq!(0.9723580396501548, max.value(2));
assert_eq!(6.037181692266781, sum.value(2));
+ assert_eq!(14, count.value(2));
+
+ assert_eq!(3, a.value(3));
+ assert_eq!(0.047343434291126085, min.value(3));
+ assert_eq!(0.9293883502480845, max.value(3));
+ assert_eq!(9.966125219358322, sum.value(3));
+ assert_eq!(19, count.value(3));
+
+ assert_eq!(1, a.value(4));
+ assert_eq!(0.05636955101974106, min.value(4));
+ assert_eq!(0.9965400387585364, max.value(4));
+ assert_eq!(11.239667565763519, sum.value(4));
+ assert_eq!(22, count.value(4)); // row 4, not a repeat of row 2: 100 rows in total minus the 23 + 22 + 14 + 19 counted in rows 0-3
}
fn aggr_test_schema() -> Arc<Schema> {
@@ -1225,7 +1392,7 @@ mod tests {
}
fn load_csv(filename: &str, schema: &Arc<Schema>) -> Rc<RefCell<Relation>> {
- let ds = CsvBatchIterator::new(filename, schema.clone(), true, &None, 1024);
+ let ds = CsvBatchIterator::new(filename, schema.clone(), true, &None, 10);
Rc::new(RefCell::new(DataSourceRelation::new(Arc::new(Mutex::new(
ds,
)))))
diff --git a/rust/datafusion/src/execution/expression.rs b/rust/datafusion/src/execution/expression.rs
index 45adaac..74081d5 100644
--- a/rust/datafusion/src/execution/expression.rs
+++ b/rust/datafusion/src/execution/expression.rs
@@ -91,8 +91,8 @@ impl CompiledAggregateExpression {
&self.t
}
- /// invoke the compiled expression
- pub fn invoke(&self, batch: &RecordBatch) -> Result<ArrayRef> {
+ /// invoke the compiled expression for the input to the aggregate function
+ pub fn evaluate_arg(&self, batch: &RecordBatch) -> Result<ArrayRef> {
self.args[0].invoke(batch)
}
}
diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs
index 50cb77d..71ad507 100644
--- a/rust/datafusion/src/sql/planner.rs
+++ b/rust/datafusion/src/sql/planner.rs
@@ -315,12 +315,12 @@ impl SqlToRel {
let rex_args = args
.iter()
.map(|a| match a {
- // this feels hacky but translate COUNT(1)/COUNT(*) to
- // COUNT(first_column)
- ASTNode::SQLValue(sqlparser::sqlast::Value::Long(1)) => {
- Ok(Expr::Column(0))
+ ASTNode::SQLValue(sqlparser::sqlast::Value::Long(_)) => {
+ Ok(Expr::Literal(ScalarValue::UInt8(1)))
}
- ASTNode::SQLWildcard => Ok(Expr::Column(0)),
+ ASTNode::SQLWildcard => {
+ Ok(Expr::Literal(ScalarValue::UInt8(1)))
+ },
_ => self.sql_to_rex(a, schema),
})
.collect::<Result<Vec<Expr>>>()?;
@@ -482,6 +482,14 @@ mod tests {
#[test]
fn select_count_one() {
let sql = "SELECT COUNT(1) FROM person";
+ let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\
+ \n TableScan: person projection=None";
+ quick_test(sql, expected);
+ }
+
+ #[test]
+ fn select_count_column() {
+ let sql = "SELECT COUNT(id) FROM person";
let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(#0)]]\
\n TableScan: person projection=None";
quick_test(sql, expected);
diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs
index 7b04609..e21e9c7 100644
--- a/rust/datafusion/tests/sql.rs
+++ b/rust/datafusion/tests/sql.rs
@@ -47,6 +47,16 @@ fn parquet_query() {
}
#[test]
+fn csv_count_star() {
+ let mut ctx = ExecutionContext::new();
+ register_aggregate_csv(&mut ctx);
+ let sql = "SELECT COUNT(*), COUNT(1), COUNT(c1) FROM aggregate_test_100";
+ let actual = execute(&mut ctx, sql);
+ let expected = "100\t100\t100\n".to_string();
+ assert_eq!(expected, actual);
+}
+
+#[test]
fn csv_query_with_predicate() {
let mut ctx = ExecutionContext::new();
register_aggregate_csv(&mut ctx);