You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/12/10 19:38:14 UTC
[arrow-datafusion] branch master updated: support decimal for min/max agg (#1407)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new e89da30 support decimal for min/max agg (#1407)
e89da30 is described below
commit e89da30828960f98eb8f28f37c1d4af8f9319653
Author: Kun Liu <li...@apache.org>
AuthorDate: Sat Dec 11 03:38:11 2021 +0800
support decimal for min/max agg (#1407)
* support decimal for min/max agg
* add table/sql test for decimal min/max agg
* change decimal test case
---
datafusion/src/execution/context.rs | 40 ++++
.../src/physical_plan/expressions/min_max.rs | 254 ++++++++++++++++++++-
datafusion/src/test/mod.rs | 23 +-
3 files changed, 306 insertions(+), 11 deletions(-)
diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs
index 59d6f44..d7c536e 100644
--- a/datafusion/src/execution/context.rs
+++ b/datafusion/src/execution/context.rs
@@ -1843,6 +1843,46 @@ mod tests {
}
#[tokio::test]
+ async fn aggregate_decimal_min() -> Result<()> {
+ let mut ctx = ExecutionContext::new();
+ ctx.register_table("d_table", test::table_with_decimal())
+ .unwrap();
+
+ let result = plan_and_collect(&mut ctx, "select min(c1) from d_table")
+ .await
+ .unwrap();
+ let expected = vec![
+ "+-----------------+",
+ "| MIN(d_table.c1) |",
+ "+-----------------+",
+ "| -100.009 |",
+ "+-----------------+",
+ ];
+ assert_batches_sorted_eq!(expected, &result);
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn aggregate_decimal_max() -> Result<()> {
+ let mut ctx = ExecutionContext::new();
+ ctx.register_table("d_table", test::table_with_decimal())
+ .unwrap();
+
+ let result = plan_and_collect(&mut ctx, "select max(c1) from d_table")
+ .await
+ .unwrap();
+ let expected = vec![
+ "+-----------------+",
+ "| MAX(d_table.c1) |",
+ "+-----------------+",
+ "| 110.009 |",
+ "+-----------------+",
+ ];
+ assert_batches_sorted_eq!(expected, &result);
+ Ok(())
+ }
+
+ #[tokio::test]
async fn aggregate() -> Result<()> {
let results = execute("SELECT SUM(c1), SUM(c2) FROM test", 4).await?;
assert_eq!(results.len(), 1);
diff --git a/datafusion/src/physical_plan/expressions/min_max.rs b/datafusion/src/physical_plan/expressions/min_max.rs
index 9e5b1e0..2f61881 100644
--- a/datafusion/src/physical_plan/expressions/min_max.rs
+++ b/datafusion/src/physical_plan/expressions/min_max.rs
@@ -37,6 +37,8 @@ use arrow::{
};
use super::format_state_name;
+use crate::arrow::array::Array;
+use arrow::array::DecimalArray;
// Min/max aggregation can take Dictionary encode input but always produces unpacked
// (aka non Dictionary) output. We need to adjust the output data type to reflect this.
@@ -129,11 +131,49 @@ macro_rules! typed_min_max_batch {
}};
}
+// TODO implement this in arrow-rs with simd
+// https://github.com/apache/arrow-rs/issues/1010
+// Statically-typed version of min/max(array) -> ScalarValue for decimal types.
+macro_rules! typed_min_max_batch_decimal128 {
+ ($VALUES:expr, $PRECISION:ident, $SCALE:ident, $OP:ident) => {{
+ let null_count = $VALUES.null_count();
+ if null_count == $VALUES.len() {
+ ScalarValue::Decimal128(None, *$PRECISION, *$SCALE)
+ } else {
+ let array = $VALUES.as_any().downcast_ref::<DecimalArray>().unwrap();
+ if null_count == 0 {
+ // there is no null value
+ let mut result = array.value(0);
+ for i in 1..array.len() {
+ result = result.$OP(array.value(i));
+ }
+ ScalarValue::Decimal128(Some(result), *$PRECISION, *$SCALE)
+ } else {
+ let mut result = 0_i128;
+ let mut has_value = false;
+ for i in 0..array.len() {
+ if !has_value && array.is_valid(i) {
+ has_value = true;
+ result = array.value(i);
+ }
+ if array.is_valid(i) {
+ result = result.$OP(array.value(i));
+ }
+ }
+ ScalarValue::Decimal128(Some(result), *$PRECISION, *$SCALE)
+ }
+ }
+ }};
+}
+
// Statically-typed version of min/max(array) -> ScalarValue for non-string types.
// this is a macro to support both operations (min and max).
macro_rules! min_max_batch {
($VALUES:expr, $OP:ident) => {{
match $VALUES.data_type() {
+ DataType::Decimal(precision, scale) => {
+ typed_min_max_batch_decimal128!($VALUES, precision, scale, $OP)
+ }
// all types that have a natural order
DataType::Float64 => {
typed_min_max_batch!($VALUES, Float64Array, Float64, $OP)
@@ -208,6 +248,20 @@ fn max_batch(values: &ArrayRef) -> Result<ScalarValue> {
_ => min_max_batch!(values, max),
})
}
+macro_rules! typed_min_max_decimal {
+ ($VALUE:expr, $DELTA:expr, $PRECISION:expr, $SCALE:expr, $SCALAR:ident, $OP:ident) => {{
+ ScalarValue::$SCALAR(
+ match ($VALUE, $DELTA) {
+ (None, None) => None,
+ (Some(a), None) => Some(a.clone()),
+ (None, Some(b)) => Some(b.clone()),
+ (Some(a), Some(b)) => Some((*a).$OP(*b)),
+ },
+ $PRECISION.clone(),
+ $SCALE.clone(),
+ )
+ }};
+}
// min/max of two non-string scalar values.
macro_rules! typed_min_max {
@@ -237,6 +291,16 @@ macro_rules! typed_min_max_string {
macro_rules! min_max {
($VALUE:expr, $DELTA:expr, $OP:ident) => {{
Ok(match ($VALUE, $DELTA) {
+ (ScalarValue::Decimal128(lhsv,lhsp,lhss), ScalarValue::Decimal128(rhsv,rhsp,rhss)) => {
+ if lhsp.eq(rhsp) && lhss.eq(rhss) {
+ typed_min_max_decimal!(lhsv, rhsv, lhsp, lhss, Decimal128, $OP)
+ } else {
+ return Err(DataFusionError::Internal(format!(
+ "MIN/MAX is not expected to receive scalars of incompatible types {:?}",
+ (ScalarValue::Decimal128(*lhsv,*lhsp,*lhss),ScalarValue::Decimal128(*rhsv,*rhsp,*rhss))
+ )));
+ }
+ }
(ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => {
typed_min_max!(lhs, rhs, Float64, $OP)
}
@@ -411,6 +475,10 @@ impl AggregateExpr for Min {
))
}
+ fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+ Ok(Box::new(MinAccumulator::try_new(&self.data_type)?))
+ }
+
fn state_fields(&self) -> Result<Vec<Field>> {
Ok(vec![Field::new(
&format_state_name(&self.name, "min"),
@@ -423,10 +491,6 @@ impl AggregateExpr for Min {
vec![self.expr.clone()]
}
- fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
- Ok(Box::new(MinAccumulator::try_new(&self.data_type)?))
- }
-
fn name(&self) -> &str {
&self.name
}
@@ -452,6 +516,12 @@ impl Accumulator for MinAccumulator {
Ok(vec![self.min.clone()])
}
+ fn update(&mut self, values: &[ScalarValue]) -> Result<()> {
+ let value = &values[0];
+ self.min = min(&self.min, value)?;
+ Ok(())
+ }
+
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let values = &values[0];
let delta = &min_batch(values)?;
@@ -459,12 +529,6 @@ impl Accumulator for MinAccumulator {
Ok(())
}
- fn update(&mut self, values: &[ScalarValue]) -> Result<()> {
- let value = &values[0];
- self.min = min(&self.min, value)?;
- Ok(())
- }
-
fn merge(&mut self, states: &[ScalarValue]) -> Result<()> {
self.update(states)
}
@@ -483,11 +547,181 @@ mod tests {
use super::*;
use crate::physical_plan::expressions::col;
use crate::physical_plan::expressions::tests::aggregate;
+ use crate::scalar::ScalarValue::Decimal128;
use crate::{error::Result, generic_test_op};
+ use arrow::array::DecimalBuilder;
use arrow::datatypes::*;
use arrow::record_batch::RecordBatch;
#[test]
+ fn min_decimal() -> Result<()> {
+ // min
+ let left = ScalarValue::Decimal128(Some(123), 10, 2);
+ let right = ScalarValue::Decimal128(Some(124), 10, 2);
+ let result = min(&left, &right)?;
+ assert_eq!(result, left);
+
+ // min batch
+ let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
+ for i in 1..6 {
+ decimal_builder.append_value(i as i128)?;
+ }
+ let array: ArrayRef = Arc::new(decimal_builder.finish());
+
+ let result = min_batch(&array)?;
+ assert_eq!(result, ScalarValue::Decimal128(Some(1), 10, 0));
+ // min batch without values
+ let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
+ let array: ArrayRef = Arc::new(decimal_builder.finish());
+ let result = min_batch(&array)?;
+ assert_eq!(ScalarValue::Decimal128(None, 10, 0), result);
+
+ let mut decimal_builder = DecimalBuilder::new(0, 10, 0);
+ let array: ArrayRef = Arc::new(decimal_builder.finish());
+ let result = min_batch(&array)?;
+ assert_eq!(ScalarValue::Decimal128(None, 10, 0), result);
+
+ // min batch with agg
+ let mut decimal_builder = DecimalBuilder::new(6, 10, 0);
+ decimal_builder.append_null().unwrap();
+ for i in 1..6 {
+ decimal_builder.append_value(i as i128)?;
+ }
+ let array: ArrayRef = Arc::new(decimal_builder.finish());
+ generic_test_op!(
+ array,
+ DataType::Decimal(10, 0),
+ Min,
+ ScalarValue::Decimal128(Some(1), 10, 0),
+ DataType::Decimal(10, 0)
+ )
+ }
+
+ #[test]
+ fn min_decimal_all_nulls() -> Result<()> {
+ // min batch all nulls
+ let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
+ for _i in 1..6 {
+ decimal_builder.append_null()?;
+ }
+ let array: ArrayRef = Arc::new(decimal_builder.finish());
+ generic_test_op!(
+ array,
+ DataType::Decimal(10, 0),
+ Min,
+ ScalarValue::Decimal128(None, 10, 0),
+ DataType::Decimal(10, 0)
+ )
+ }
+
+ #[test]
+ fn min_decimal_with_nulls() -> Result<()> {
+ // min batch with nulls
+ let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
+ for i in 1..6 {
+ if i == 2 {
+ decimal_builder.append_null()?;
+ } else {
+ decimal_builder.append_value(i as i128)?;
+ }
+ }
+ let array: ArrayRef = Arc::new(decimal_builder.finish());
+ generic_test_op!(
+ array,
+ DataType::Decimal(10, 0),
+ Min,
+ ScalarValue::Decimal128(Some(1), 10, 0),
+ DataType::Decimal(10, 0)
+ )
+ }
+
+ #[test]
+ fn max_decimal() -> Result<()> {
+ // max
+ let left = ScalarValue::Decimal128(Some(123), 10, 2);
+ let right = ScalarValue::Decimal128(Some(124), 10, 2);
+ let result = max(&left, &right)?;
+ assert_eq!(result, right);
+
+ let right = ScalarValue::Decimal128(Some(124), 10, 3);
+ let result = max(&left, &right);
+ let expect = DataFusionError::Internal(format!(
+ "MIN/MAX is not expected to receive scalars of incompatible types {:?}",
+ (Decimal128(Some(123), 10, 2), Decimal128(Some(124), 10, 3))
+ ));
+ assert_eq!(expect.to_string(), result.unwrap_err().to_string());
+
+ // max batch
+ let mut decimal_builder = DecimalBuilder::new(5, 10, 5);
+ for i in 1..6 {
+ decimal_builder.append_value(i as i128)?;
+ }
+ let array: ArrayRef = Arc::new(decimal_builder.finish());
+ let result = max_batch(&array)?;
+ assert_eq!(result, ScalarValue::Decimal128(Some(5), 10, 5));
+ // max batch without values
+ let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
+ let array: ArrayRef = Arc::new(decimal_builder.finish());
+ let result = max_batch(&array)?;
+ assert_eq!(ScalarValue::Decimal128(None, 10, 0), result);
+
+ let mut decimal_builder = DecimalBuilder::new(0, 10, 0);
+ let array: ArrayRef = Arc::new(decimal_builder.finish());
+ let result = max_batch(&array)?;
+ assert_eq!(ScalarValue::Decimal128(None, 10, 0), result);
+ // max batch with agg
+ let mut decimal_builder = DecimalBuilder::new(6, 10, 0);
+ decimal_builder.append_null().unwrap();
+ for i in 1..6 {
+ decimal_builder.append_value(i as i128)?;
+ }
+ let array: ArrayRef = Arc::new(decimal_builder.finish());
+ generic_test_op!(
+ array,
+ DataType::Decimal(10, 0),
+ Max,
+ ScalarValue::Decimal128(Some(5), 10, 0),
+ DataType::Decimal(10, 0)
+ )
+ }
+
+ #[test]
+ fn max_decimal_with_nulls() -> Result<()> {
+ let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
+ for i in 1..6 {
+ if i == 2 {
+ decimal_builder.append_null()?;
+ } else {
+ decimal_builder.append_value(i as i128)?;
+ }
+ }
+ let array: ArrayRef = Arc::new(decimal_builder.finish());
+ generic_test_op!(
+ array,
+ DataType::Decimal(10, 0),
+ Max,
+ ScalarValue::Decimal128(Some(5), 10, 0),
+ DataType::Decimal(10, 0)
+ )
+ }
+
+ #[test]
+ fn max_decimal_all_nulls() -> Result<()> {
+ let mut decimal_builder = DecimalBuilder::new(5, 10, 0); // precision 10, scale 0
+ for _i in 1..6 {
+ decimal_builder.append_null()?; // five nulls, no actual values
+ }
+ let array: ArrayRef = Arc::new(decimal_builder.finish());
+ generic_test_op!(
+ array,
+ DataType::Decimal(10, 0),
+ Max, // was `Min`, which only duplicated min_decimal_all_nulls
+ ScalarValue::Decimal128(None, 10, 0), // max over only nulls stays null
+ DataType::Decimal(10, 0)
+ )
+ }
+
+ #[test]
fn max_i32() -> Result<()> {
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
generic_test_op!(
diff --git a/datafusion/src/test/mod.rs b/datafusion/src/test/mod.rs
index 16c1383..39c9de1 100644
--- a/datafusion/src/test/mod.rs
+++ b/datafusion/src/test/mod.rs
@@ -25,7 +25,7 @@ use array::{
Array, ArrayRef, StringArray, TimestampMicrosecondArray, TimestampMillisecondArray,
TimestampNanosecondArray, TimestampSecondArray,
};
-use arrow::array::{self, Int32Array};
+use arrow::array::{self, DecimalBuilder, Int32Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use futures::{Future, FutureExt};
@@ -192,6 +192,27 @@ pub fn table_with_timestamps() -> Arc<dyn TableProvider> {
Arc::new(MemTable::try_new(schema, partitions).unwrap())
}
+/// Return a new table which provides a single decimal column
+pub fn table_with_decimal() -> Arc<dyn TableProvider> {
+ let batch_decimal = make_decimal();
+ let schema = batch_decimal.schema();
+ let partitions = vec![vec![batch_decimal]];
+ Arc::new(MemTable::try_new(schema, partitions).unwrap())
+}
+
+fn make_decimal() -> RecordBatch {
+ let mut decimal_builder = DecimalBuilder::new(20, 10, 3);
+ for i in 110000..110010 {
+ decimal_builder.append_value(i as i128).unwrap();
+ }
+ for i in 100000..100010 {
+ decimal_builder.append_value(-i as i128).unwrap();
+ }
+ let array = decimal_builder.finish();
+ let schema = Schema::new(vec![Field::new("c1", array.data_type().clone(), true)]);
+ RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).unwrap()
+}
+
/// Return record batch with all of the supported timestamp types
/// values
///