You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/12/10 19:38:14 UTC

[arrow-datafusion] branch master updated: support decimal for min/max agg (#1407)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new e89da30  support decimal for min/max agg (#1407)
e89da30 is described below

commit e89da30828960f98eb8f28f37c1d4af8f9319653
Author: Kun Liu <li...@apache.org>
AuthorDate: Sat Dec 11 03:38:11 2021 +0800

    support decimal for min/max agg (#1407)
    
    * support decimal for min/max agg
    
    * add table/sql test for decimal min/max agg
    
    * change decimal test case
---
 datafusion/src/execution/context.rs                |  40 ++++
 .../src/physical_plan/expressions/min_max.rs       | 254 ++++++++++++++++++++-
 datafusion/src/test/mod.rs                         |  23 +-
 3 files changed, 306 insertions(+), 11 deletions(-)

diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs
index 59d6f44..d7c536e 100644
--- a/datafusion/src/execution/context.rs
+++ b/datafusion/src/execution/context.rs
@@ -1843,6 +1843,46 @@ mod tests {
     }
 
     #[tokio::test]
+    async fn aggregate_decimal_min() -> Result<()> { // end-to-end SQL MIN over a decimal column
+        let mut ctx = ExecutionContext::new();
+        ctx.register_table("d_table", test::table_with_decimal()) // single column c1, precision 10, scale 3
+            .unwrap();
+
+        let result = plan_and_collect(&mut ctx, "select min(c1) from d_table")
+            .await
+            .unwrap();
+        let expected = vec![
+            "+-----------------+",
+            "| MIN(d_table.c1) |",
+            "+-----------------+",
+            "| -100.009        |", // smallest fixture value: raw -100009 at scale 3
+            "+-----------------+",
+        ];
+        assert_batches_sorted_eq!(expected, &result);
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn aggregate_decimal_max() -> Result<()> { // end-to-end SQL MAX over a decimal column
+        let mut ctx = ExecutionContext::new();
+        ctx.register_table("d_table", test::table_with_decimal()) // single column c1, precision 10, scale 3
+            .unwrap();
+
+        let result = plan_and_collect(&mut ctx, "select max(c1) from d_table")
+            .await
+            .unwrap();
+        let expected = vec![
+            "+-----------------+",
+            "| MAX(d_table.c1) |",
+            "+-----------------+",
+            "| 110.009         |", // largest fixture value: raw 110009 at scale 3
+            "+-----------------+",
+        ];
+        assert_batches_sorted_eq!(expected, &result);
+        Ok(())
+    }
+
+    #[tokio::test]
     async fn aggregate() -> Result<()> {
         let results = execute("SELECT SUM(c1), SUM(c2) FROM test", 4).await?;
         assert_eq!(results.len(), 1);
diff --git a/datafusion/src/physical_plan/expressions/min_max.rs b/datafusion/src/physical_plan/expressions/min_max.rs
index 9e5b1e0..2f61881 100644
--- a/datafusion/src/physical_plan/expressions/min_max.rs
+++ b/datafusion/src/physical_plan/expressions/min_max.rs
@@ -37,6 +37,8 @@ use arrow::{
 };
 
 use super::format_state_name;
+use crate::arrow::array::Array;
+use arrow::array::DecimalArray;
 
 // Min/max aggregation can take Dictionary encode input but always produces unpacked
 // (aka non Dictionary) output. We need to adjust the output data type to reflect this.
@@ -129,11 +131,49 @@ macro_rules! typed_min_max_batch {
     }};
 }
 
+// TODO implement this in arrow-rs with simd
+// https://github.com/apache/arrow-rs/issues/1010
+// min/max(decimal array) -> ScalarValue::Decimal128; nulls are skipped, all-null or empty input gives NULL.
+macro_rules! typed_min_max_batch_decimal128 {
+    ($VALUES:expr, $PRECISION:ident, $SCALE:ident, $OP:ident) => {{ // $PRECISION/$SCALE are bound by reference at the call site (hence *); $OP is an i128 method name such as min/max
+        let null_count = $VALUES.null_count();
+        if null_count == $VALUES.len() { // also taken for an empty array (0 == 0)
+            ScalarValue::Decimal128(None, *$PRECISION, *$SCALE)
+        } else {
+            let array = $VALUES.as_any().downcast_ref::<DecimalArray>().unwrap(); // caller matched DataType::Decimal, so the downcast is expected to succeed
+            if null_count == 0 {
+                // no nulls: seed with the first value and fold in the rest
+                let mut result = array.value(0);
+                for i in 1..array.len() {
+                    result = result.$OP(array.value(i));
+                }
+                ScalarValue::Decimal128(Some(result), *$PRECISION, *$SCALE)
+            } else {
+                let mut result = 0_i128; // placeholder; overwritten by the first valid value below
+                let mut has_value = false;
+                for i in 0..array.len() {
+                    if !has_value && array.is_valid(i) { // seed from the first non-null slot
+                        has_value = true;
+                        result = array.value(i);
+                    }
+                    if array.is_valid(i) { // the seed is also folded with itself, which is a no-op for min/max
+                        result = result.$OP(array.value(i));
+                    }
+                }
+                ScalarValue::Decimal128(Some(result), *$PRECISION, *$SCALE)
+            }
+        }
+    }};
+}
+
 // Statically-typed version of min/max(array) -> ScalarValue  for non-string types.
 // this is a macro to support both operations (min and max).
 macro_rules! min_max_batch {
     ($VALUES:expr, $OP:ident) => {{
         match $VALUES.data_type() {
+            DataType::Decimal(precision, scale) => {
+                typed_min_max_batch_decimal128!($VALUES, precision, scale, $OP)
+            }
             // all types that have a natural order
             DataType::Float64 => {
                 typed_min_max_batch!($VALUES, Float64Array, Float64, $OP)
@@ -208,6 +248,20 @@ fn max_batch(values: &ArrayRef) -> Result<ScalarValue> {
         _ => min_max_batch!(values, max),
     })
 }
+macro_rules! typed_min_max_decimal { // min/max of two decimal scalar payloads; NULL only when both sides are NULL
+    ($VALUE:expr, $DELTA:expr, $PRECISION:expr, $SCALE:expr, $SCALAR:ident, $OP:ident) => {{
+        ScalarValue::$SCALAR(
+            match ($VALUE, $DELTA) { // a and b bind references here, hence the derefs in the last arm
+                (None, None) => None,
+                (Some(a), None) => Some(a.clone()),
+                (None, Some(b)) => Some(b.clone()),
+                (Some(a), Some(b)) => Some((*a).$OP(*b)),
+            },
+            $PRECISION.clone(), // taken from one operand; the min_max! caller checks both sides' precision/scale are equal first
+            $SCALE.clone(),
+        )
+    }};
+}
 
 // min/max of two non-string scalar values.
 macro_rules! typed_min_max {
@@ -237,6 +291,16 @@ macro_rules! typed_min_max_string {
 macro_rules! min_max {
     ($VALUE:expr, $DELTA:expr, $OP:ident) => {{
         Ok(match ($VALUE, $DELTA) {
+            (ScalarValue::Decimal128(lhsv,lhsp,lhss), ScalarValue::Decimal128(rhsv,rhsp,rhss)) => {
+                if lhsp.eq(rhsp) && lhss.eq(rhss) {
+                    typed_min_max_decimal!(lhsv, rhsv, lhsp, lhss, Decimal128, $OP)
+                } else {
+                    return Err(DataFusionError::Internal(format!(
+                    "MIN/MAX is not expected to receive scalars of incompatible types {:?}",
+                    (ScalarValue::Decimal128(*lhsv,*lhsp,*lhss),ScalarValue::Decimal128(*rhsv,*rhsp,*rhss))
+                )));
+                }
+            }
             (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => {
                 typed_min_max!(lhs, rhs, Float64, $OP)
             }
@@ -411,6 +475,10 @@ impl AggregateExpr for Min {
         ))
     }
 
+    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> { // fresh MinAccumulator for this expression's data type
+        Ok(Box::new(MinAccumulator::try_new(&self.data_type)?)) // try_new errors propagate via ?
+    }
+
     fn state_fields(&self) -> Result<Vec<Field>> {
         Ok(vec![Field::new(
             &format_state_name(&self.name, "min"),
@@ -423,10 +491,6 @@ impl AggregateExpr for Min {
         vec![self.expr.clone()]
     }
 
-    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
-        Ok(Box::new(MinAccumulator::try_new(&self.data_type)?))
-    }
-
     fn name(&self) -> &str {
         &self.name
     }
@@ -452,6 +516,12 @@ impl Accumulator for MinAccumulator {
         Ok(vec![self.min.clone()])
     }
 
+    fn update(&mut self, values: &[ScalarValue]) -> Result<()> { // fold one scalar (first input) into the running minimum; merge() delegates here too
+        let value = &values[0]; // panics if values is empty
+        self.min = min(&self.min, value)?; // errors from min() propagate via ?
+        Ok(())
+    }
+
     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
         let values = &values[0];
         let delta = &min_batch(values)?;
@@ -459,12 +529,6 @@ impl Accumulator for MinAccumulator {
         Ok(())
     }
 
-    fn update(&mut self, values: &[ScalarValue]) -> Result<()> {
-        let value = &values[0];
-        self.min = min(&self.min, value)?;
-        Ok(())
-    }
-
     fn merge(&mut self, states: &[ScalarValue]) -> Result<()> {
         self.update(states)
     }
@@ -483,11 +547,181 @@ mod tests {
     use super::*;
     use crate::physical_plan::expressions::col;
     use crate::physical_plan::expressions::tests::aggregate;
+    use crate::scalar::ScalarValue::Decimal128;
     use crate::{error::Result, generic_test_op};
+    use arrow::array::DecimalBuilder;
     use arrow::datatypes::*;
     use arrow::record_batch::RecordBatch;
 
     #[test]
+    fn min_decimal() -> Result<()> {
+        // min of two scalars with equal precision/scale returns the smaller one
+        let left = ScalarValue::Decimal128(Some(123), 10, 2);
+        let right = ScalarValue::Decimal128(Some(124), 10, 2);
+        let result = min(&left, &right)?;
+        assert_eq!(result, left);
+
+        // min batch over 1..=5 with no nulls
+        let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
+        for i in 1..6 {
+            decimal_builder.append_value(i as i128)?;
+        }
+        let array: ArrayRef = Arc::new(decimal_builder.finish());
+
+        let result = min_batch(&array)?;
+        assert_eq!(result, ScalarValue::Decimal128(Some(1), 10, 0));
+        // min batch on an empty array (nothing appended) yields a NULL that keeps precision/scale
+        let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
+        let array: ArrayRef = Arc::new(decimal_builder.finish());
+        let result = min_batch(&array)?;
+        assert_eq!(ScalarValue::Decimal128(None, 10, 0), result);
+
+        let mut decimal_builder = DecimalBuilder::new(0, 10, 0); // same check, builder created with 0 instead of 5
+        let array: ArrayRef = Arc::new(decimal_builder.finish());
+        let result = min_batch(&array)?;
+        assert_eq!(ScalarValue::Decimal128(None, 10, 0), result);
+
+        // full aggregate via generic_test_op!, input has a leading null
+        let mut decimal_builder = DecimalBuilder::new(6, 10, 0);
+        decimal_builder.append_null().unwrap();
+        for i in 1..6 {
+            decimal_builder.append_value(i as i128)?;
+        }
+        let array: ArrayRef = Arc::new(decimal_builder.finish());
+        generic_test_op!(
+            array,
+            DataType::Decimal(10, 0),
+            Min,
+            ScalarValue::Decimal128(Some(1), 10, 0),
+            DataType::Decimal(10, 0)
+        )
+    }
+
+    #[test]
+    fn min_decimal_all_nulls() -> Result<()> {
+        // MIN over five nulls yields a NULL that keeps precision/scale (10, 0)
+        let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
+        for _i in 1..6 {
+            decimal_builder.append_null()?;
+        }
+        let array: ArrayRef = Arc::new(decimal_builder.finish());
+        generic_test_op!(
+            array,
+            DataType::Decimal(10, 0),
+            Min,
+            ScalarValue::Decimal128(None, 10, 0),
+            DataType::Decimal(10, 0)
+        )
+    }
+
+    #[test]
+    fn min_decimal_with_nulls() -> Result<()> {
+        // MIN over [1, NULL, 3, 4, 5]: the null is skipped and the result is 1
+        let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
+        for i in 1..6 {
+            if i == 2 {
+                decimal_builder.append_null()?;
+            } else {
+                decimal_builder.append_value(i as i128)?;
+            }
+        }
+        let array: ArrayRef = Arc::new(decimal_builder.finish());
+        generic_test_op!(
+            array,
+            DataType::Decimal(10, 0),
+            Min,
+            ScalarValue::Decimal128(Some(1), 10, 0),
+            DataType::Decimal(10, 0)
+        )
+    }
+
+    #[test]
+    fn max_decimal() -> Result<()> {
+        // max of two scalars with equal precision/scale returns the larger one
+        let left = ScalarValue::Decimal128(Some(123), 10, 2);
+        let right = ScalarValue::Decimal128(Some(124), 10, 2);
+        let result = max(&left, &right)?;
+        assert_eq!(result, right);
+
+        let right = ScalarValue::Decimal128(Some(124), 10, 3); // scale 3 vs 2: must be rejected with an Internal error, not compared
+        let result = max(&left, &right);
+        let expect = DataFusionError::Internal(format!(
+            "MIN/MAX is not expected to receive scalars of incompatible types {:?}",
+            (Decimal128(Some(123), 10, 2), Decimal128(Some(124), 10, 3))
+        ));
+        assert_eq!(expect.to_string(), result.unwrap_err().to_string());
+
+        // max batch over 1..=5 at scale 5; the result keeps the array's precision/scale
+        let mut decimal_builder = DecimalBuilder::new(5, 10, 5);
+        for i in 1..6 {
+            decimal_builder.append_value(i as i128)?;
+        }
+        let array: ArrayRef = Arc::new(decimal_builder.finish());
+        let result = max_batch(&array)?;
+        assert_eq!(result, ScalarValue::Decimal128(Some(5), 10, 5));
+        // max batch on an empty array (nothing appended) yields a NULL that keeps precision/scale
+        let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
+        let array: ArrayRef = Arc::new(decimal_builder.finish());
+        let result = max_batch(&array)?;
+        assert_eq!(ScalarValue::Decimal128(None, 10, 0), result);
+
+        let mut decimal_builder = DecimalBuilder::new(0, 10, 0); // same check, builder created with 0 instead of 5
+        let array: ArrayRef = Arc::new(decimal_builder.finish());
+        let result = max_batch(&array)?;
+        assert_eq!(ScalarValue::Decimal128(None, 10, 0), result);
+        // full aggregate via generic_test_op!, input has a leading null
+        let mut decimal_builder = DecimalBuilder::new(6, 10, 0);
+        decimal_builder.append_null().unwrap();
+        for i in 1..6 {
+            decimal_builder.append_value(i as i128)?;
+        }
+        let array: ArrayRef = Arc::new(decimal_builder.finish());
+        generic_test_op!(
+            array,
+            DataType::Decimal(10, 0),
+            Max,
+            ScalarValue::Decimal128(Some(5), 10, 0),
+            DataType::Decimal(10, 0)
+        )
+    }
+
+    #[test]
+    fn max_decimal_with_nulls() -> Result<()> { // MAX over [1, NULL, 3, 4, 5]: the null is skipped and the result is 5
+        let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
+        for i in 1..6 {
+            if i == 2 { // replace the value 2 with a NULL
+                decimal_builder.append_null()?;
+            } else {
+                decimal_builder.append_value(i as i128)?;
+            }
+        }
+        let array: ArrayRef = Arc::new(decimal_builder.finish());
+        generic_test_op!(
+            array,
+            DataType::Decimal(10, 0),
+            Max,
+            ScalarValue::Decimal128(Some(5), 10, 0),
+            DataType::Decimal(10, 0)
+        )
+    }
+
+    #[test]
+    fn max_decimal_all_nulls() -> Result<()> { // MAX over five nulls yields a NULL that keeps precision/scale (10, 0)
+        let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
+        for _i in 1..6 {
+            decimal_builder.append_null()?;
+        }
+        let array: ArrayRef = Arc::new(decimal_builder.finish());
+        generic_test_op!(
+            array,
+            DataType::Decimal(10, 0),
+            Max, // this test must exercise Max; min_decimal_all_nulls already covers Min
+            ScalarValue::Decimal128(None, 10, 0),
+            DataType::Decimal(10, 0)
+        )
+    }
+
+    #[test]
     fn max_i32() -> Result<()> {
         let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
         generic_test_op!(
diff --git a/datafusion/src/test/mod.rs b/datafusion/src/test/mod.rs
index 16c1383..39c9de1 100644
--- a/datafusion/src/test/mod.rs
+++ b/datafusion/src/test/mod.rs
@@ -25,7 +25,7 @@ use array::{
     Array, ArrayRef, StringArray, TimestampMicrosecondArray, TimestampMillisecondArray,
     TimestampNanosecondArray, TimestampSecondArray,
 };
-use arrow::array::{self, Int32Array};
+use arrow::array::{self, DecimalBuilder, Int32Array};
 use arrow::datatypes::{DataType, Field, Schema};
 use arrow::record_batch::RecordBatch;
 use futures::{Future, FutureExt};
@@ -192,6 +192,27 @@ pub fn table_with_timestamps() -> Arc<dyn TableProvider> {
     Arc::new(MemTable::try_new(schema, partitions).unwrap())
 }
 
+/// Return a `MemTable` with a single decimal column `c1` (built by `make_decimal`), stored as one partition holding one batch
+pub fn table_with_decimal() -> Arc<dyn TableProvider> {
+    let batch_decimal = make_decimal();
+    let schema = batch_decimal.schema();
+    let partitions = vec![vec![batch_decimal]]; // one partition containing one record batch
+    Arc::new(MemTable::try_new(schema, partitions).unwrap())
+}
+
+fn make_decimal() -> RecordBatch { // 20-row batch with one column "c1", decimal precision 10, scale 3
+    let mut decimal_builder = DecimalBuilder::new(20, 10, 3);
+    for i in 110000..110010 { // raw 110000..=110009, i.e. 110.000 ..= 110.009
+        decimal_builder.append_value(i as i128).unwrap();
+    }
+    for i in 100000..100010 { // raw -100000..=-100009, i.e. -100.000 ..= -100.009
+        decimal_builder.append_value(-i as i128).unwrap(); // parsed as (-i) as i128: unary minus binds tighter than `as`
+    }
+    let array = decimal_builder.finish();
+    let schema = Schema::new(vec![Field::new("c1", array.data_type().clone(), true)]);
+    RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).unwrap()
+}
+
 /// Return  record batch with all of the supported timestamp types
 /// values
 ///