You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ks...@apache.org on 2019/04/10 10:14:41 UTC
[arrow] branch master updated: ARROW-5038: [Rust] [DataFusion]
Implement AVG aggregate function
This is an automated email from the ASF dual-hosted git repository.
kszucs pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 646624b ARROW-5038: [Rust] [DataFusion] Implement AVG aggregate function
646624b is described below
commit 646624bc7524840077203f0e1fc0bf14e67244fa
Author: Zhiyuan Zheng <zh...@hotmail.com>
AuthorDate: Wed Apr 10 12:14:24 2019 +0200
ARROW-5038: [Rust] [DataFusion] Implement AVG aggregate function
Implement the AVG aggregate function for DataFusion.
Split the `accumulate_scalar` function into `accumulate_scalar` and `accumulate_batch`, for single-value updates and batch updates respectively.
I am new to Arrow, so please review this carefully, @andygrove.
Author: Zhiyuan Zheng <zh...@hotmail.com>
Closes #4120 from zhzy0077/master and squashes the following commits:
26e85bc46 <Zhiyuan Zheng> Fix format issue.
da1f5c897 <Zhiyuan Zheng> Add multi-batch AVG test.
54e13b267 <Zhiyuan Zheng> Apply the fmt changes.
beb27f789 <Zhiyuan Zheng> Support for AVG function.
---
rust/datafusion/src/execution/aggregate.rs | 234 +++++++++++++++++++---------
rust/datafusion/src/execution/expression.rs | 1 +
rust/datafusion/tests/sql.rs | 95 ++++++++++-
3 files changed, 257 insertions(+), 73 deletions(-)
diff --git a/rust/datafusion/src/execution/aggregate.rs b/rust/datafusion/src/execution/aggregate.rs
index 7068c6a..4399bef 100644
--- a/rust/datafusion/src/execution/aggregate.rs
+++ b/rust/datafusion/src/execution/aggregate.rs
@@ -86,21 +86,14 @@ trait AggregateFunction {
fn name(&self) -> &str;
/// Update the current aggregate value based on a new value. A value of `None` represents a
- /// null value. If rollup is false, then this aggregate function instance is being used
- /// to aggregate individual values within a RecordBatch. If rollup is true then the aggregate
- /// function instance is being used to combine the aggregates for multiple batches.
- /// For some aggregate operations, such as `min`, `max`, and `sum`, the logic is the same
- /// regardless of whether rollup is true or false. For example, `min` can be implemented as
- /// `min(min(batch1), min(batch2), ..)`. However for `count` the logic is
- /// `sum(count(batch1), count(batch2), ..)`.
- fn accumulate_scalar(
- &mut self,
- value: &Option<ScalarValue>,
- rollup: bool,
- ) -> Result<()>;
+ /// null value.
+ fn accumulate_scalar(&mut self, value: &Option<ScalarValue>) -> Result<()>;
+
+ /// Update the current aggregate value based on an array.
+ fn accumulate_batch(&mut self, array: ArrayRef) -> Result<()>;
/// Return the result of the aggregate function after all values have been processed
- /// by calls to `acccumulate_scalar`.
+ /// by calls to `accumulate_scalar`.
fn result(&self) -> Option<ScalarValue>;
/// Get the data type of the result of the aggregate function. For some operations,
@@ -131,11 +124,7 @@ impl AggregateFunction for MinFunction {
"min"
}
- fn accumulate_scalar(
- &mut self,
- value: &Option<ScalarValue>,
- _rollup: bool,
- ) -> Result<()> {
+ fn accumulate_scalar(&mut self, value: &Option<ScalarValue>) -> Result<()> {
if self.value.is_none() {
self.value = value.clone();
} else if value.is_some() {
@@ -180,6 +169,12 @@ impl AggregateFunction for MinFunction {
Ok(())
}
+ fn accumulate_batch(&mut self, array: ArrayRef) -> Result<()> {
+ let accumulated_value = array_min(array)?;
+
+ self.accumulate_scalar(&accumulated_value)
+ }
+
fn result(&self) -> Option<ScalarValue> {
self.value.clone()
}
@@ -210,11 +205,7 @@ impl AggregateFunction for MaxFunction {
"max"
}
- fn accumulate_scalar(
- &mut self,
- value: &Option<ScalarValue>,
- _rollup: bool,
- ) -> Result<()> {
+ fn accumulate_scalar(&mut self, value: &Option<ScalarValue>) -> Result<()> {
if self.value.is_none() {
self.value = value.clone();
} else if value.is_some() {
@@ -259,6 +250,12 @@ impl AggregateFunction for MaxFunction {
Ok(())
}
+ fn accumulate_batch(&mut self, array: ArrayRef) -> Result<()> {
+ let accumulated_value = array_max(array)?;
+
+ self.accumulate_scalar(&accumulated_value)
+ }
+
fn result(&self) -> Option<ScalarValue> {
self.value.clone()
}
@@ -289,11 +286,7 @@ impl AggregateFunction for SumFunction {
"sum"
}
- fn accumulate_scalar(
- &mut self,
- value: &Option<ScalarValue>,
- _rollup: bool,
- ) -> Result<()> {
+ fn accumulate_scalar(&mut self, value: &Option<ScalarValue>) -> Result<()> {
if self.value.is_none() {
self.value = value.clone();
} else if value.is_some() {
@@ -338,6 +331,12 @@ impl AggregateFunction for SumFunction {
Ok(())
}
+ fn accumulate_batch(&mut self, array: ArrayRef) -> Result<()> {
+ let accumulated_value = array_sum(array)?;
+
+ self.accumulate_scalar(&accumulated_value)
+ }
+
fn result(&self) -> Option<ScalarValue> {
self.value.clone()
}
@@ -347,6 +346,79 @@ impl AggregateFunction for SumFunction {
}
}
+/// Implementation of AVG aggregate function
+#[derive(Debug)]
+struct AvgFunction {
+ data_type: DataType,
+ sum_value: SumFunction,
+ count_value: CountFunction,
+}
+
+impl AvgFunction {
+ fn new(data_type: &DataType) -> Self {
+ Self {
+ data_type: DataType::Float64,
+ sum_value: SumFunction::new(data_type),
+ count_value: CountFunction::new(),
+ }
+ }
+}
+
+impl AggregateFunction for AvgFunction {
+ fn name(&self) -> &str {
+ "avg"
+ }
+
+ fn accumulate_scalar(&mut self, value: &Option<ScalarValue>) -> Result<()> {
+ self.sum_value.accumulate_scalar(value)?;
+ self.count_value.accumulate_scalar(value)?;
+
+ Ok(())
+ }
+
+ fn accumulate_batch(&mut self, array: ArrayRef) -> Result<()> {
+ self.sum_value.accumulate_batch(array.clone())?;
+ self.count_value.accumulate_batch(array)?;
+
+ Ok(())
+ }
+
+ fn result(&self) -> Option<ScalarValue> {
+ let sum = match self.sum_value.result() {
+ Some(ScalarValue::UInt8(a)) => a as f64,
+ Some(ScalarValue::UInt16(a)) => a as f64,
+ Some(ScalarValue::UInt32(a)) => a as f64,
+ Some(ScalarValue::UInt64(a)) => a as f64,
+ Some(ScalarValue::Int8(a)) => a as f64,
+ Some(ScalarValue::Int16(a)) => a as f64,
+ Some(ScalarValue::Int32(a)) => a as f64,
+ Some(ScalarValue::Int64(a)) => a as f64,
+ Some(ScalarValue::Float32(a)) => a as f64,
+ Some(ScalarValue::Float64(a)) => a as f64,
+ Some(ScalarValue::Null) => {
+ return Some(ScalarValue::Null);
+ }
+ _ => {
+ return None;
+ }
+ };
+ let count = match self.count_value.result() {
+ Some(ScalarValue::UInt64(a)) => a as f64,
+ Some(ScalarValue::Null) => {
+ return Some(ScalarValue::Null);
+ }
+ _ => {
+ return None;
+ }
+ };
+ Some(ScalarValue::Float64(sum / count))
+ }
+
+ fn data_type(&self) -> &DataType {
+ &self.data_type
+ }
+}
+
/// Implementation of COUNT aggregate function
#[derive(Debug)]
struct CountFunction {
@@ -364,28 +436,27 @@ impl AggregateFunction for CountFunction {
"count"
}
- fn accumulate_scalar(
- &mut self,
- value: &Option<ScalarValue>,
- rollup: bool,
- ) -> Result<()> {
- if rollup {
- // in rollup mode, the counts are added together
- if let Some(ScalarValue::UInt64(n)) = value {
- self.value = match self.value {
- Some(cur_value) => Some(cur_value + *n),
- _ => Some(*n),
- }
- }
- } else {
- // count the value if it is not null
- if value.is_some() {
- self.value = match self.value {
- Some(cur_value) => Some(cur_value + 1),
- _ => Some(1),
- }
+ fn accumulate_scalar(&mut self, value: &Option<ScalarValue>) -> Result<()> {
+ if value.is_some() {
+ self.value = match self.value {
+ Some(cur_value) => Some(cur_value + 1),
+ None => Some(1),
}
}
+
+ Ok(())
+ }
+
+ fn accumulate_batch(&mut self, array: ArrayRef) -> Result<()> {
+ let accumulated_value = array_count(array)?;
+
+ if let Some(ScalarValue::UInt64(n)) = accumulated_value {
+ self.value = match self.value {
+ Some(cur_value) => Some(cur_value + n),
+ None => Some(n),
+ }
+ };
+
Ok(())
}
@@ -406,14 +477,14 @@ struct AccumulatorSet {
}
impl AccumulatorSet {
- fn accumulate_scalar(
- &mut self,
- i: usize,
- value: Option<ScalarValue>,
- rollup: bool,
- ) -> Result<()> {
+ fn accumulate_scalar(&mut self, i: usize, value: Option<ScalarValue>) -> Result<()> {
+ let mut accumulator = self.aggr_values[i].borrow_mut();
+ accumulator.accumulate_scalar(&value)
+ }
+
+ fn accumulate_batch(&mut self, i: usize, array: ArrayRef) -> Result<()> {
let mut accumulator = self.aggr_values[i].borrow_mut();
- accumulator.accumulate_scalar(&value, rollup)
+ accumulator.accumulate_batch(array)
}
fn values(&self) -> Vec<Option<ScalarValue>> {
@@ -449,6 +520,10 @@ fn create_accumulators(
Ok(Rc::new(RefCell::new(SumFunction::new(e.data_type())))
as Rc<RefCell<AggregateFunction>>)
}
+ AggregateType::Avg => {
+ Ok(Rc::new(RefCell::new(AvgFunction::new(e.data_type())))
+ as Rc<RefCell<AggregateFunction>>)
+ }
AggregateType::Count => Ok(Rc::new(RefCell::new(CountFunction::new()))
as Rc<RefCell<AggregateFunction>>),
_ => Err(ExecutionError::ExecutionError(
@@ -730,7 +805,7 @@ fn update_accumulators(
}
};
- accumulator_set.accumulate_scalar(j, value, false)?;
+ accumulator_set.accumulate_scalar(j, value)?;
}
Ok(())
}
@@ -834,18 +909,11 @@ impl AggregateRelation {
// evaluate the argument to the aggregate function
let array = self.aggr_expr[i].evaluate_arg(&batch)?;
match self.aggr_expr[i].aggr_type() {
- AggregateType::Min => {
- accumulator_set.accumulate_scalar(i, array_min(array)?, true)?
- }
- AggregateType::Max => {
- accumulator_set.accumulate_scalar(i, array_max(array)?, true)?
- }
- AggregateType::Sum => {
- accumulator_set.accumulate_scalar(i, array_sum(array)?, true)?
- }
- AggregateType::Count => {
- accumulator_set.accumulate_scalar(i, array_count(array)?, true)?
- }
+ AggregateType::Min => accumulator_set.accumulate_batch(i, array)?,
+ AggregateType::Max => accumulator_set.accumulate_batch(i, array)?,
+ AggregateType::Sum => accumulator_set.accumulate_batch(i, array)?,
+ AggregateType::Count => accumulator_set.accumulate_batch(i, array)?,
+ AggregateType::Avg => accumulator_set.accumulate_batch(i, array)?,
_ => {
return Err(ExecutionError::NotImplemented(
"Unsupported aggregate function".to_string(),
@@ -1243,7 +1311,7 @@ mod tests {
}
#[test]
- fn test_min_max_sum_count_f64_group_by_uint32() {
+ fn test_min_max_sum_count_avg_f64_group_by_uint32() {
let schema = aggr_test_schema();
let testdata = env::var("ARROW_TEST_DATA").expect("ARROW_TEST_DATA not defined");
let relation =
@@ -1298,22 +1366,34 @@ mod tests {
)
.unwrap();
+ let avg_expr = expression::compile_aggregate_expr(
+ &context,
+ &Expr::AggregateFunction {
+ name: String::from("avg"),
+ args: vec![Expr::Column(11)],
+ return_type: DataType::Float64,
+ },
+ &schema,
+ )
+ .unwrap();
+
let aggr_schema = Arc::new(Schema::new(vec![
Field::new("c2", DataType::UInt32, false),
Field::new("min", DataType::Float64, false),
Field::new("max", DataType::Float64, false),
Field::new("sum", DataType::Float64, false),
Field::new("count", DataType::UInt64, false),
+ Field::new("avg", DataType::Float64, false),
]));
let mut projection = AggregateRelation::new(
aggr_schema,
relation,
vec![group_by_expr],
- vec![min_expr, max_expr, sum_expr, count_expr],
+ vec![min_expr, max_expr, sum_expr, count_expr, avg_expr],
);
let batch = projection.next().unwrap().unwrap();
- assert_eq!(5, batch.num_columns());
+ assert_eq!(6, batch.num_columns());
assert_eq!(5, batch.num_rows());
let a = batch
@@ -1341,36 +1421,46 @@ mod tests {
.as_any()
.downcast_ref::<UInt64Array>()
.unwrap();
+ let avg = batch
+ .column(5)
+ .as_any()
+ .downcast_ref::<Float64Array>()
+ .unwrap();
assert_eq!(4, a.value(0));
assert_eq!(0.02182578039211991, min.value(0));
assert_eq!(0.9237877978193884, max.value(0));
assert_eq!(9.253864188402662, sum.value(0));
assert_eq!(23, count.value(0));
+ assert_eq!(0.40234192123489837, avg.value(0));
assert_eq!(2, a.value(1));
assert_eq!(0.16301110515739792, min.value(1));
assert_eq!(0.991517828651004, max.value(1));
assert_eq!(14.400412325480858, sum.value(1));
assert_eq!(22, count.value(1));
+ assert_eq!(0.6545641966127662, avg.value(1));
assert_eq!(5, a.value(2));
assert_eq!(0.01479305307777301, min.value(2));
assert_eq!(0.9723580396501548, max.value(2));
assert_eq!(6.037181692266781, sum.value(2));
assert_eq!(14, count.value(2));
+ assert_eq!(0.4312272637333415, avg.value(2));
assert_eq!(3, a.value(3));
assert_eq!(0.047343434291126085, min.value(3));
assert_eq!(0.9293883502480845, max.value(3));
assert_eq!(9.966125219358322, sum.value(3));
assert_eq!(19, count.value(3));
+ assert_eq!(0.5245329062820169, avg.value(3));
assert_eq!(1, a.value(4));
assert_eq!(0.05636955101974106, min.value(4));
assert_eq!(0.9965400387585364, max.value(4));
assert_eq!(11.239667565763519, sum.value(4));
- assert_eq!(14, count.value(2));
+ assert_eq!(22, count.value(4));
+ assert_eq!(0.5108939802619781, avg.value(4));
}
fn aggr_test_schema() -> Arc<Schema> {
diff --git a/rust/datafusion/src/execution/expression.rs b/rust/datafusion/src/execution/expression.rs
index 74081d5..6282e77 100644
--- a/rust/datafusion/src/execution/expression.rs
+++ b/rust/datafusion/src/execution/expression.rs
@@ -121,6 +121,7 @@ pub(super) fn compile_aggregate_expr(
"max" => Ok(AggregateType::Max),
"count" => Ok(AggregateType::Count),
"sum" => Ok(AggregateType::Sum),
+ "avg" => Ok(AggregateType::Avg),
_ => Err(ExecutionError::General(format!(
"Unsupported aggregate function '{}'",
name
diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs
index e21e9c7..81c52c9 100644
--- a/rust/datafusion/tests/sql.rs
+++ b/rust/datafusion/tests/sql.rs
@@ -78,6 +78,98 @@ fn csv_query_group_by_int_min_max() {
}
#[test]
+fn csv_query_avg() {
+ let mut ctx = ExecutionContext::new();
+ register_aggregate_csv(&mut ctx);
+ //TODO add ORDER BY once supported, to make this test determistic
+ let sql = "SELECT avg(c12) FROM aggregate_test_100";
+ let actual = execute(&mut ctx, sql);
+ let expected = "0.5089725099127211\n".to_string();
+ assert_eq!(expected, actual);
+}
+
+#[test]
+fn csv_query_group_by_avg() {
+ let mut ctx = ExecutionContext::new();
+ register_aggregate_csv(&mut ctx);
+ //TODO add ORDER BY once supported, to make this test determistic
+ let sql = "SELECT c1, avg(c12) FROM aggregate_test_100 GROUP BY c1";
+ let actual = execute(&mut ctx, sql);
+ let expected = "\"d\"\t0.48855379387549824\n\"c\"\t0.6600456536439784\n\"b\"\t0.41040709263815384\n\"a\"\t0.48754517466109415\n\"e\"\t0.48600669271341534\n".to_string();
+ assert_eq!(expected, actual);
+}
+
+#[test]
+fn csv_query_avg_multi_batch() {
+ let mut ctx = ExecutionContext::new();
+ register_aggregate_csv(&mut ctx);
+ //TODO add ORDER BY once supported, to make this test determistic
+ let sql = "SELECT avg(c12) FROM aggregate_test_100";
+ let plan = ctx.create_logical_plan(&sql).unwrap();
+ let results = ctx.execute(&plan, 4).unwrap();
+ let mut relation = results.borrow_mut();
+ let batch = relation.next().unwrap().unwrap();
+ let column = batch.column(0);
+ let array = column.as_any().downcast_ref::<Float64Array>().unwrap();
+ let actual = array.value(0);
+ let expected = 0.5089725;
+ // Due to float number's accuracy, different batch size will lead to different answers.
+ assert!((expected - actual).abs() < 0.01);
+}
+
+#[test]
+fn csv_query_group_by_avg_multi_batch() {
+ let mut ctx = ExecutionContext::new();
+ register_aggregate_csv(&mut ctx);
+ //TODO add ORDER BY once supported, to make this test determistic
+ let sql = "SELECT c1, avg(c12) FROM aggregate_test_100 GROUP BY c1";
+ let plan = ctx.create_logical_plan(&sql).unwrap();
+ let results = ctx.execute(&plan, 4).unwrap();
+ let mut relation = results.borrow_mut();
+ let mut actual_vec = Vec::new();
+ while let Some(batch) = relation.next().unwrap() {
+ let column = batch.column(1);
+ let array = column.as_any().downcast_ref::<Float64Array>().unwrap();
+
+ for row_index in 0..batch.num_rows() {
+ actual_vec.push(array.value(row_index));
+ }
+ }
+
+ let expect_vec = vec![0.48855379, 0.66004565, 0.41040709, 0.48754517];
+
+ actual_vec
+ .iter()
+ .zip(expect_vec.iter())
+ .for_each(|(actual, expect)| {
+ // Due to float number's accuracy, different batch size will lead to different answers.
+ assert!((expect - actual).abs() < 0.01);
+ });
+}
+
+#[test]
+fn csv_query_count() {
+ let mut ctx = ExecutionContext::new();
+ register_aggregate_csv(&mut ctx);
+ //TODO add ORDER BY once supported, to make this test determistic
+ let sql = "SELECT count(c12) FROM aggregate_test_100";
+ let actual = execute(&mut ctx, sql);
+ let expected = "100\n".to_string();
+ assert_eq!(expected, actual);
+}
+
+#[test]
+fn csv_query_group_by_int_count() {
+ let mut ctx = ExecutionContext::new();
+ register_aggregate_csv(&mut ctx);
+ //TODO add ORDER BY once supported, to make this test determistic
+ let sql = "SELECT count(c12) FROM aggregate_test_100 GROUP BY c1";
+ let actual = execute(&mut ctx, sql);
+ let expected = "\"d\"\t18\n\"c\"\t21\n\"b\"\t19\n\"a\"\t21\n\"e\"\t21\n".to_string();
+ assert_eq!(expected, actual);
+}
+
+#[test]
fn csv_query_group_by_string_min_max() {
let mut ctx = ExecutionContext::new();
register_aggregate_csv(&mut ctx);
@@ -149,7 +241,8 @@ fn csv_query_limit_zero() {
assert_eq!(expected, actual);
}
-//TODO Uncomment the following test when ORDER BY is implemented to be able to test ORDER BY + LIMIT
+//TODO Uncomment the following test when ORDER BY is implemented to be able to test ORDER
+// BY + LIMIT
/*
#[test]
fn csv_query_limit_with_order_by() {