You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/04/19 14:08:40 UTC

[arrow-datafusion] branch master updated: support array with scalar arithmetic operation for decimal data type (#2233)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 7334481f0 support array with scalar arithmetic operation for decimal data type (#2233)
7334481f0 is described below

commit 7334481f0dba36bc8d67005e468e00f69a269724
Author: Kun Liu <li...@apache.org>
AuthorDate: Tue Apr 19 22:08:33 2022 +0800

    support array with scalar arithmetic operation for decimal data type (#2233)
    
    * support array with scalar arithmetic operation for decimal data type
    
    * add SQL-level test for arithmetic operations
    
    * add SQL-level test for decimal arithmetic operations
---
 datafusion/core/tests/sql/decimal.rs               | 300 +++++++++++++++++++++
 datafusion/physical-expr/src/expressions/binary.rs |  82 ++++++
 2 files changed, 382 insertions(+)

diff --git a/datafusion/core/tests/sql/decimal.rs b/datafusion/core/tests/sql/decimal.rs
index 1e7920935..2153495a6 100644
--- a/datafusion/core/tests/sql/decimal.rs
+++ b/datafusion/core/tests/sql/decimal.rs
@@ -363,6 +363,306 @@ async fn decimal_logic_op() -> Result<()> {
     Ok(())
 }
 
+/// Checks the result data type and values of decimal add, subtract, multiply,
+/// divide and modulo, each against a scalar and against the decimal column c5.
+#[tokio::test]
+async fn decimal_arithmetic_op() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_decimal_csv_table_by_sql(&ctx).await;
+    // add
+    let sql = "select c1+1 from decimal_simple"; // add scalar
+    let actual = execute_to_batches(&ctx, sql).await;
+    // array decimal(10,6) + scalar decimal(20,0) => decimal(27,6)
+    assert_eq!(
+        &DataType::Decimal(27, 6),
+        actual[0].schema().field(0).data_type()
+    );
+    let expected = vec![
+        "+---------------------------------+",
+        "| decimal_simple.c1 Plus Int64(1) |",
+        "+---------------------------------+",
+        "| 1.000010                        |",
+        "| 1.000020                        |",
+        "| 1.000020                        |",
+        "| 1.000030                        |",
+        "| 1.000030                        |",
+        "| 1.000030                        |",
+        "| 1.000040                        |",
+        "| 1.000040                        |",
+        "| 1.000040                        |",
+        "| 1.000040                        |",
+        "| 1.000050                        |",
+        "| 1.000050                        |",
+        "| 1.000050                        |",
+        "| 1.000050                        |",
+        "| 1.000050                        |",
+        "+---------------------------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+    // array decimal(10,6) + array decimal(12,7) => decimal(13,7)
+    let sql = "select c1+c5 from decimal_simple";
+    let actual = execute_to_batches(&ctx, sql).await;
+    assert_eq!(
+        &DataType::Decimal(13, 7),
+        actual[0].schema().field(0).data_type()
+    );
+    let expected = vec![
+        "+------------------------------------------+",
+        "| decimal_simple.c1 Plus decimal_simple.c5 |",
+        "+------------------------------------------+",
+        "| 0.0000240                                |",
+        "| 0.0000450                                |",
+        "| 0.0000390                                |",
+        "| 0.0000620                                |",
+        "| 0.0000650                                |",
+        "| 0.0000410                                |",
+        "| 0.0000840                                |",
+        "| 0.0000800                                |",
+        "| 0.0000800                                |",
+        "| 0.0000840                                |",
+        "| 0.0001020                                |",
+        "| 0.0001280                                |",
+        "| 0.0000830                                |",
+        "| 0.0001180                                |",
+        "| 0.0001500                                |",
+        "+------------------------------------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+    // subtract
+    let sql = "select c1-1 from decimal_simple";
+    let actual = execute_to_batches(&ctx, sql).await;
+    assert_eq!(
+        &DataType::Decimal(27, 6),
+        actual[0].schema().field(0).data_type()
+    );
+    let expected = vec![
+        "+----------------------------------+",
+        "| decimal_simple.c1 Minus Int64(1) |",
+        "+----------------------------------+",
+        "| -0.999990                        |",
+        "| -0.999980                        |",
+        "| -0.999980                        |",
+        "| -0.999970                        |",
+        "| -0.999970                        |",
+        "| -0.999970                        |",
+        "| -0.999960                        |",
+        "| -0.999960                        |",
+        "| -0.999960                        |",
+        "| -0.999960                        |",
+        "| -0.999950                        |",
+        "| -0.999950                        |",
+        "| -0.999950                        |",
+        "| -0.999950                        |",
+        "| -0.999950                        |",
+        "+----------------------------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
+    let sql = "select c1-c5 from decimal_simple";
+    let actual = execute_to_batches(&ctx, sql).await;
+    assert_eq!(
+        &DataType::Decimal(13, 7),
+        actual[0].schema().field(0).data_type()
+    );
+    let expected = vec![
+        "+-------------------------------------------+",
+        "| decimal_simple.c1 Minus decimal_simple.c5 |",
+        "+-------------------------------------------+",
+        "| -0.0000040                                |",
+        "| -0.0000050                                |",
+        "| 0.0000010                                 |",
+        "| -0.0000020                                |",
+        "| -0.0000050                                |",
+        "| 0.0000190                                 |",
+        "| -0.0000040                                |",
+        "| 0.0000000                                 |",
+        "| 0.0000000                                 |",
+        "| -0.0000040                                |",
+        "| -0.0000020                                |",
+        "| -0.0000280                                |",
+        "| 0.0000170                                 |",
+        "| -0.0000180                                |",
+        "| -0.0000500                                |",
+        "+-------------------------------------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+    // multiply
+    let sql = "select c1*20 from decimal_simple";
+    let actual = execute_to_batches(&ctx, sql).await;
+    assert_eq!(
+        &DataType::Decimal(31, 6),
+        actual[0].schema().field(0).data_type()
+    );
+    let expected = vec![
+        "+--------------------------------------+",
+        "| decimal_simple.c1 Multiply Int64(20) |",
+        "+--------------------------------------+",
+        "| 0.000200                             |",
+        "| 0.000400                             |",
+        "| 0.000400                             |",
+        "| 0.000600                             |",
+        "| 0.000600                             |",
+        "| 0.000600                             |",
+        "| 0.000800                             |",
+        "| 0.000800                             |",
+        "| 0.000800                             |",
+        "| 0.000800                             |",
+        "| 0.001000                             |",
+        "| 0.001000                             |",
+        "| 0.001000                             |",
+        "| 0.001000                             |",
+        "| 0.001000                             |",
+        "+--------------------------------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
+    let sql = "select c1*c5 from decimal_simple";
+    let actual = execute_to_batches(&ctx, sql).await;
+    assert_eq!(
+        &DataType::Decimal(23, 13),
+        actual[0].schema().field(0).data_type()
+    );
+    let expected = vec![
+        "+----------------------------------------------+",
+        "| decimal_simple.c1 Multiply decimal_simple.c5 |",
+        "+----------------------------------------------+",
+        "| 0.0000000001400                              |",
+        "| 0.0000000005000                              |",
+        "| 0.0000000003800                              |",
+        "| 0.0000000009600                              |",
+        "| 0.0000000010500                              |",
+        "| 0.0000000003300                              |",
+        "| 0.0000000017600                              |",
+        "| 0.0000000016000                              |",
+        "| 0.0000000016000                              |",
+        "| 0.0000000017600                              |",
+        "| 0.0000000026000                              |",
+        "| 0.0000000039000                              |",
+        "| 0.0000000016500                              |",
+        "| 0.0000000034000                              |",
+        "| 0.0000000050000                              |",
+        "+----------------------------------------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+    // divide
+    let sql = "select c1/cast(0.00001 as decimal(5,5)) from decimal_simple";
+    let actual = execute_to_batches(&ctx, sql).await;
+    assert_eq!(
+        &DataType::Decimal(21, 12),
+        actual[0].schema().field(0).data_type()
+    );
+    let expected = vec![
+        "+-------------------------------------------------------------+",
+        "| decimal_simple.c1 / CAST(Float64(0.00001) AS Decimal(5, 5)) |",
+        "+-------------------------------------------------------------+",
+        "| 1.000000000000                                              |",
+        "| 2.000000000000                                              |",
+        "| 2.000000000000                                              |",
+        "| 3.000000000000                                              |",
+        "| 3.000000000000                                              |",
+        "| 3.000000000000                                              |",
+        "| 4.000000000000                                              |",
+        "| 4.000000000000                                              |",
+        "| 4.000000000000                                              |",
+        "| 4.000000000000                                              |",
+        "| 5.000000000000                                              |",
+        "| 5.000000000000                                              |",
+        "| 5.000000000000                                              |",
+        "| 5.000000000000                                              |",
+        "| 5.000000000000                                              |",
+        "+-------------------------------------------------------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
+    let sql = "select c1/c5 from decimal_simple";
+    let actual = execute_to_batches(&ctx, sql).await;
+    assert_eq!(
+        &DataType::Decimal(30, 19),
+        actual[0].schema().field(0).data_type()
+    );
+    let expected = vec![
+        "+--------------------------------------------+",
+        "| decimal_simple.c1 Divide decimal_simple.c5 |",
+        "+--------------------------------------------+",
+        "| 0.7142857142857143296                      |",
+        "| 0.8000000000000000000                      |",
+        "| 1.0526315789473683456                      |",
+        "| 0.9375000000000000000                      |",
+        "| 0.8571428571428571136                      |",
+        "| 2.7272727272727269376                      |",
+        "| 0.9090909090909090816                      |",
+        "| 1.0000000000000000000                      |",
+        "| 1.0000000000000000000                      |",
+        "| 0.9090909090909090816                      |",
+        "| 0.9615384615384614912                      |",
+        "| 0.6410256410256410624                      |",
+        "| 1.5151515151515152384                      |",
+        "| 0.7352941176470588416                      |",
+        "| 0.5000000000000000000                      |",
+        "+--------------------------------------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
+    // modulo
+    let sql = "select c5%cast(0.00001 as decimal(5,5)) from decimal_simple";
+    let actual = execute_to_batches(&ctx, sql).await;
+    assert_eq!(
+        &DataType::Decimal(7, 7),
+        actual[0].schema().field(0).data_type()
+    );
+    let expected = vec![
+        "+-------------------------------------------------------------+",
+        "| decimal_simple.c5 % CAST(Float64(0.00001) AS Decimal(5, 5)) |",
+        "+-------------------------------------------------------------+",
+        "| 0.0000040                                                   |",
+        "| 0.0000050                                                   |",
+        "| 0.0000090                                                   |",
+        "| 0.0000020                                                   |",
+        "| 0.0000050                                                   |",
+        "| 0.0000010                                                   |",
+        "| 0.0000040                                                   |",
+        "| 0.0000000                                                   |",
+        "| 0.0000000                                                   |",
+        "| 0.0000040                                                   |",
+        "| 0.0000020                                                   |",
+        "| 0.0000080                                                   |",
+        "| 0.0000030                                                   |",
+        "| 0.0000080                                                   |",
+        "| 0.0000000                                                   |",
+        "+-------------------------------------------------------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
+    let sql = "select c1%c5 from decimal_simple";
+    let actual = execute_to_batches(&ctx, sql).await;
+    assert_eq!(
+        &DataType::Decimal(11, 7),
+        actual[0].schema().field(0).data_type()
+    );
+    let expected = vec![
+        "+--------------------------------------------+",
+        "| decimal_simple.c1 Modulo decimal_simple.c5 |",
+        "+--------------------------------------------+",
+        "| 0.0000100                                  |",
+        "| 0.0000200                                  |",
+        "| 0.0000010                                  |",
+        "| 0.0000300                                  |",
+        "| 0.0000300                                  |",
+        "| 0.0000080                                  |",
+        "| 0.0000400                                  |",
+        "| 0.0000000                                  |",
+        "| 0.0000000                                  |",
+        "| 0.0000400                                  |",
+        "| 0.0000500                                  |",
+        "| 0.0000500                                  |",
+        "| 0.0000170                                  |",
+        "| 0.0000500                                  |",
+        "| 0.0000500                                  |",
+        "+--------------------------------------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
+    Ok(())
+}
+
 #[tokio::test]
 async fn decimal_sort() -> Result<()> {
     let ctx = SessionContext::new();
diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs
index a155fa5ed..6dafb43f9 100644
--- a/datafusion/physical-expr/src/expressions/binary.rs
+++ b/datafusion/physical-expr/src/expressions/binary.rs
@@ -279,18 +279,49 @@ where
         .collect()
 }
 
+/// Applies `op` to each element of `left` paired with the scalar `right`.
+///
+/// `right` is the raw `i128` representation of the scalar. Null slots in
+/// `left` stay null and `op` is not called for them. Callers set the result's
+/// precision and scale afterwards with `with_precision_and_scale`.
+///
+/// # Errors
+///
+/// Returns the first error produced by `op`; collection stops at that element.
+fn arith_decimal_scalar<F>(
+    left: &DecimalArray,
+    right: i128,
+    op: F,
+) -> Result<DecimalArray>
+where
+    F: Fn(i128, i128) -> Result<i128>,
+{
+    left.iter()
+        .map(|slot| slot.map(|value| op(value, right)).transpose())
+        .collect()
+}
+
 /// Adds two decimal arrays element by element.
 ///
 /// Works on the raw `i128` values, so both inputs are presumably at the same
 /// scale — TODO(review): confirm callers align scales. The result keeps
 /// `left`'s precision and scale.
 fn add_decimal(left: &DecimalArray, right: &DecimalArray) -> Result<DecimalArray> {
     let array = arith_decimal(left, right, |left, right| Ok(left + right))?
         .with_precision_and_scale(left.precision(), left.scale())?;
     Ok(array)
 }
 
+/// Adds the raw scalar `right` to every non-null element of `left`.
+///
+/// `right` is presumably already at `left`'s scale — TODO(review): confirm the
+/// caller rescales the scalar. The result keeps `left`'s precision and scale.
+/// The `i128` addition is unchecked, so overflow panics when overflow checks
+/// are enabled (e.g. debug builds).
+fn add_decimal_scalar(left: &DecimalArray, right: i128) -> Result<DecimalArray> {
+    let sums = arith_decimal_scalar(left, right, |lhs, rhs| Ok(lhs + rhs))?;
+    let array = sums.with_precision_and_scale(left.precision(), left.scale())?;
+    Ok(array)
+}
+
 /// Subtracts `right` from `left` element by element.
 ///
 /// Works on the raw `i128` values, so both inputs are presumably at the same
 /// scale — TODO(review): confirm callers align scales. The result keeps
 /// `left`'s precision and scale.
 fn subtract_decimal(left: &DecimalArray, right: &DecimalArray) -> Result<DecimalArray> {
     let array = arith_decimal(left, right, |left, right| Ok(left - right))?
         .with_precision_and_scale(left.precision(), left.scale())?;
     Ok(array)
 }
 
+/// Subtracts the raw scalar `right` from every non-null element of `left`.
+///
+/// `right` is presumably already at `left`'s scale — TODO(review): confirm the
+/// caller rescales the scalar. The result keeps `left`'s precision and scale.
+/// The `i128` subtraction is unchecked, so overflow panics when overflow checks
+/// are enabled (e.g. debug builds).
+fn subtract_decimal_scalar(left: &DecimalArray, right: i128) -> Result<DecimalArray> {
+    let differences = arith_decimal_scalar(left, right, |lhs, rhs| Ok(lhs - rhs))?;
+    let array = differences.with_precision_and_scale(left.precision(), left.scale())?;
+    Ok(array)
+}
+
 fn multiply_decimal(left: &DecimalArray, right: &DecimalArray) -> Result<DecimalArray> {
     let divide = 10_i128.pow(left.scale() as u32);
     let array = arith_decimal(left, right, |left, right| Ok(left * right / divide))?
@@ -298,6 +329,14 @@ fn multiply_decimal(left: &DecimalArray, right: &DecimalArray) -> Result<Decimal
     Ok(array)
 }
 
+/// Multiplies every non-null element of `left` by the raw scalar `right`.
+///
+/// The raw product is divided by `10^scale` of `left`, which treats `right` as
+/// a decimal at `left`'s scale; the result keeps `left`'s precision and scale.
+/// The division truncates toward zero, and the intermediate `i128` product is
+/// unchecked.
+fn multiply_decimal_scalar(left: &DecimalArray, right: i128) -> Result<DecimalArray> {
+    let scale_factor = 10_i128.pow(left.scale() as u32);
+    let products =
+        arith_decimal_scalar(left, right, |lhs, rhs| Ok(lhs * rhs / scale_factor))?;
+    let array = products.with_precision_and_scale(left.precision(), left.scale())?;
+    Ok(array)
+}
+
 fn divide_decimal(left: &DecimalArray, right: &DecimalArray) -> Result<DecimalArray> {
     let mul = 10_f64.powi(left.scale() as i32);
     let array = arith_decimal(left, right, |left, right| {
@@ -310,6 +349,18 @@ fn divide_decimal(left: &DecimalArray, right: &DecimalArray) -> Result<DecimalAr
     Ok(array)
 }
 
+/// Divides every non-null element of `left` by the raw scalar `right`.
+///
+/// `right` is treated as a decimal at `left`'s scale: the `f64` quotient of the
+/// raw values is rescaled by `10^scale` of `left`, and the result keeps
+/// `left`'s precision and scale. Because the quotient goes through `f64`, large
+/// or high-precision values can lose precision.
+///
+/// # Errors
+///
+/// Returns `DivideByZero` when `right` is zero, the same error
+/// `modulus_decimal_scalar` reports.
+fn divide_decimal_scalar(left: &DecimalArray, right: i128) -> Result<DecimalArray> {
+    // Without this guard `x / 0.0` is +/-inf (or NaN for 0 / 0), and the
+    // saturating `as i128` cast below would silently turn that into
+    // i128::MAX, i128::MIN or 0 instead of reporting an error.
+    if right == 0 {
+        return Err(DataFusionError::ArrowError(DivideByZero));
+    }
+    let mul = 10_f64.powi(left.scale() as i32);
+    let array = arith_decimal_scalar(left, right, |left, right| {
+        let l_value = left as f64;
+        let r_value = right as f64;
+        let result = ((l_value / r_value) * mul) as i128;
+        Ok(result)
+    })?
+    .with_precision_and_scale(left.precision(), left.scale())?;
+    Ok(array)
+}
+
 fn modulus_decimal(left: &DecimalArray, right: &DecimalArray) -> Result<DecimalArray> {
     let array = arith_decimal(left, right, |left, right| {
         if right == 0 {
@@ -322,6 +373,15 @@ fn modulus_decimal(left: &DecimalArray, right: &DecimalArray) -> Result<DecimalA
     Ok(array)
 }
 
+/// Computes `left % right` for every non-null element of `left`.
+///
+/// `right` is presumably already at `left`'s scale — TODO(review): confirm the
+/// caller rescales the scalar. The result keeps `left`'s precision and scale.
+///
+/// # Errors
+///
+/// Returns `DivideByZero` when `right` is zero, even if `left` is empty or
+/// entirely null.
+fn modulus_decimal_scalar(left: &DecimalArray, right: i128) -> Result<DecimalArray> {
+    if right != 0 {
+        let remainders = arith_decimal_scalar(left, right, |lhs, rhs| Ok(lhs % rhs))?;
+        let array =
+            remainders.with_precision_and_scale(left.precision(), left.scale())?;
+        Ok(array)
+    } else {
+        Err(DataFusionError::ArrowError(DivideByZero))
+    }
+}
+
 /// The binary_bitwise_array_op macro only evaluates for integer types
 /// like int64, int32.
 /// It is used to do bitwise operation.
@@ -777,6 +837,7 @@ macro_rules! binary_primitive_array_op {
 macro_rules! binary_primitive_array_op_scalar {
     ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
         let result: Result<Arc<dyn Array>> = match $LEFT.data_type() {
+            DataType::Decimal(_,_) => compute_decimal_op_scalar!($LEFT, $RIGHT, $OP, DecimalArray),
             DataType::Int8 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int8Array),
             DataType::Int16 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int16Array),
             DataType::Int32 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int32Array),
@@ -2916,14 +2977,25 @@ mod tests {
         let expect =
             create_decimal_array(&[Some(246), None, Some(245), Some(247)], 25, 3)?;
         assert_eq!(expect, result);
+        let result = add_decimal_scalar(&left_decimal_array, 10)?;
+        let expect =
+            create_decimal_array(&[Some(133), None, Some(132), Some(134)], 25, 3)?;
+        assert_eq!(expect, result);
         // subtract
         let result = subtract_decimal(&left_decimal_array, &right_decimal_array)?;
         let expect = create_decimal_array(&[Some(0), None, Some(-1), Some(1)], 25, 3)?;
         assert_eq!(expect, result);
+        let result = subtract_decimal_scalar(&left_decimal_array, 10)?;
+        let expect =
+            create_decimal_array(&[Some(113), None, Some(112), Some(114)], 25, 3)?;
+        assert_eq!(expect, result);
         // multiply
         let result = multiply_decimal(&left_decimal_array, &right_decimal_array)?;
         let expect = create_decimal_array(&[Some(15), None, Some(15), Some(15)], 25, 3)?;
         assert_eq!(expect, result);
+        let result = multiply_decimal_scalar(&left_decimal_array, 10)?;
+        let expect = create_decimal_array(&[Some(1), None, Some(1), Some(1)], 25, 3)?;
+        assert_eq!(expect, result);
         // divide
         let left_decimal_array = create_decimal_array(
             &[Some(1234567), None, Some(1234567), Some(1234567)],
@@ -2939,10 +3011,20 @@ mod tests {
             3,
         )?;
         assert_eq!(expect, result);
+        let result = divide_decimal_scalar(&left_decimal_array, 10)?;
+        let expect = create_decimal_array(
+            &[Some(123456700), None, Some(123456700), Some(123456700)],
+            25,
+            3,
+        )?;
+        assert_eq!(expect, result);
         // modulus
         let result = modulus_decimal(&left_decimal_array, &right_decimal_array)?;
         let expect = create_decimal_array(&[Some(7), None, Some(37), Some(16)], 25, 3)?;
         assert_eq!(expect, result);
+        let result = modulus_decimal_scalar(&left_decimal_array, 10)?;
+        let expect = create_decimal_array(&[Some(7), None, Some(7), Some(7)], 25, 3)?;
+        assert_eq!(expect, result);
 
         Ok(())
     }