You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/04/19 14:08:40 UTC
[arrow-datafusion] branch master updated: support array with scalar arithmetic operation for decimal data type (#2233)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 7334481f0 support array with scalar arithmetic operation for decimal data type (#2233)
7334481f0 is described below
commit 7334481f0dba36bc8d67005e468e00f69a269724
Author: Kun Liu <li...@apache.org>
AuthorDate: Tue Apr 19 22:08:33 2022 +0800
support array with scalar arithmetic operation for decimal data type (#2233)
* support array with scalar arithmetic operation for decimal data type
* add sql level test for arithmetic op
* add sql level test for decimal for arithmetic operation
---
datafusion/core/tests/sql/decimal.rs | 300 +++++++++++++++++++++
datafusion/physical-expr/src/expressions/binary.rs | 82 ++++++
2 files changed, 382 insertions(+)
diff --git a/datafusion/core/tests/sql/decimal.rs b/datafusion/core/tests/sql/decimal.rs
index 1e7920935..2153495a6 100644
--- a/datafusion/core/tests/sql/decimal.rs
+++ b/datafusion/core/tests/sql/decimal.rs
@@ -363,6 +363,306 @@ async fn decimal_logic_op() -> Result<()> {
Ok(())
}
+#[tokio::test]
+async fn decimal_arithmetic_op() -> Result<()> {
+ let ctx = SessionContext::new();
+ register_decimal_csv_table_by_sql(&ctx).await;
+ // add
+ let sql = "select c1+1 from decimal_simple"; // add scalar
+ let actual = execute_to_batches(&ctx, sql).await;
+ // array decimal(10,6) + scalar decimal(20,0) => decimal(27,6)
+ assert_eq!(
+ &DataType::Decimal(27, 6),
+ actual[0].schema().field(0).data_type()
+ );
+ let expected = vec![
+ "+---------------------------------+",
+ "| decimal_simple.c1 Plus Int64(1) |",
+ "+---------------------------------+",
+ "| 1.000010 |",
+ "| 1.000020 |",
+ "| 1.000020 |",
+ "| 1.000030 |",
+ "| 1.000030 |",
+ "| 1.000030 |",
+ "| 1.000040 |",
+ "| 1.000040 |",
+ "| 1.000040 |",
+ "| 1.000040 |",
+ "| 1.000050 |",
+ "| 1.000050 |",
+ "| 1.000050 |",
+ "| 1.000050 |",
+ "| 1.000050 |",
+ "+---------------------------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+ // array decimal(10,6) + array decimal(12,7) => decimal(13,7)
+ let sql = "select c1+c5 from decimal_simple";
+ let actual = execute_to_batches(&ctx, sql).await;
+ assert_eq!(
+ &DataType::Decimal(13, 7),
+ actual[0].schema().field(0).data_type()
+ );
+ let expected = vec![
+ "+------------------------------------------+",
+ "| decimal_simple.c1 Plus decimal_simple.c5 |",
+ "+------------------------------------------+",
+ "| 0.0000240 |",
+ "| 0.0000450 |",
+ "| 0.0000390 |",
+ "| 0.0000620 |",
+ "| 0.0000650 |",
+ "| 0.0000410 |",
+ "| 0.0000840 |",
+ "| 0.0000800 |",
+ "| 0.0000800 |",
+ "| 0.0000840 |",
+ "| 0.0001020 |",
+ "| 0.0001280 |",
+ "| 0.0000830 |",
+ "| 0.0001180 |",
+ "| 0.0001500 |",
+ "+------------------------------------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+ // subtract
+ let sql = "select c1-1 from decimal_simple";
+ let actual = execute_to_batches(&ctx, sql).await;
+ assert_eq!(
+ &DataType::Decimal(27, 6),
+ actual[0].schema().field(0).data_type()
+ );
+ let expected = vec![
+ "+----------------------------------+",
+ "| decimal_simple.c1 Minus Int64(1) |",
+ "+----------------------------------+",
+ "| -0.999990 |",
+ "| -0.999980 |",
+ "| -0.999980 |",
+ "| -0.999970 |",
+ "| -0.999970 |",
+ "| -0.999970 |",
+ "| -0.999960 |",
+ "| -0.999960 |",
+ "| -0.999960 |",
+ "| -0.999960 |",
+ "| -0.999950 |",
+ "| -0.999950 |",
+ "| -0.999950 |",
+ "| -0.999950 |",
+ "| -0.999950 |",
+ "+----------------------------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+
+ let sql = "select c1-c5 from decimal_simple";
+ let actual = execute_to_batches(&ctx, sql).await;
+ assert_eq!(
+ &DataType::Decimal(13, 7),
+ actual[0].schema().field(0).data_type()
+ );
+ let expected = vec![
+ "+-------------------------------------------+",
+ "| decimal_simple.c1 Minus decimal_simple.c5 |",
+ "+-------------------------------------------+",
+ "| -0.0000040 |",
+ "| -0.0000050 |",
+ "| 0.0000010 |",
+ "| -0.0000020 |",
+ "| -0.0000050 |",
+ "| 0.0000190 |",
+ "| -0.0000040 |",
+ "| 0.0000000 |",
+ "| 0.0000000 |",
+ "| -0.0000040 |",
+ "| -0.0000020 |",
+ "| -0.0000280 |",
+ "| 0.0000170 |",
+ "| -0.0000180 |",
+ "| -0.0000500 |",
+ "+-------------------------------------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+ // multiply
+ let sql = "select c1*20 from decimal_simple";
+ let actual = execute_to_batches(&ctx, sql).await;
+ assert_eq!(
+ &DataType::Decimal(31, 6),
+ actual[0].schema().field(0).data_type()
+ );
+ let expected = vec![
+ "+--------------------------------------+",
+ "| decimal_simple.c1 Multiply Int64(20) |",
+ "+--------------------------------------+",
+ "| 0.000200 |",
+ "| 0.000400 |",
+ "| 0.000400 |",
+ "| 0.000600 |",
+ "| 0.000600 |",
+ "| 0.000600 |",
+ "| 0.000800 |",
+ "| 0.000800 |",
+ "| 0.000800 |",
+ "| 0.000800 |",
+ "| 0.001000 |",
+ "| 0.001000 |",
+ "| 0.001000 |",
+ "| 0.001000 |",
+ "| 0.001000 |",
+ "+--------------------------------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+
+ let sql = "select c1*c5 from decimal_simple";
+ let actual = execute_to_batches(&ctx, sql).await;
+ assert_eq!(
+ &DataType::Decimal(23, 13),
+ actual[0].schema().field(0).data_type()
+ );
+ let expected = vec![
+ "+----------------------------------------------+",
+ "| decimal_simple.c1 Multiply decimal_simple.c5 |",
+ "+----------------------------------------------+",
+ "| 0.0000000001400 |",
+ "| 0.0000000005000 |",
+ "| 0.0000000003800 |",
+ "| 0.0000000009600 |",
+ "| 0.0000000010500 |",
+ "| 0.0000000003300 |",
+ "| 0.0000000017600 |",
+ "| 0.0000000016000 |",
+ "| 0.0000000016000 |",
+ "| 0.0000000017600 |",
+ "| 0.0000000026000 |",
+ "| 0.0000000039000 |",
+ "| 0.0000000016500 |",
+ "| 0.0000000034000 |",
+ "| 0.0000000050000 |",
+ "+----------------------------------------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+ // divide
+ let sql = "select c1/cast(0.00001 as decimal(5,5)) from decimal_simple";
+ let actual = execute_to_batches(&ctx, sql).await;
+ assert_eq!(
+ &DataType::Decimal(21, 12),
+ actual[0].schema().field(0).data_type()
+ );
+ let expected = vec![
+ "+-------------------------------------------------------------+",
+ "| decimal_simple.c1 / CAST(Float64(0.00001) AS Decimal(5, 5)) |",
+ "+-------------------------------------------------------------+",
+ "| 1.000000000000 |",
+ "| 2.000000000000 |",
+ "| 2.000000000000 |",
+ "| 3.000000000000 |",
+ "| 3.000000000000 |",
+ "| 3.000000000000 |",
+ "| 4.000000000000 |",
+ "| 4.000000000000 |",
+ "| 4.000000000000 |",
+ "| 4.000000000000 |",
+ "| 5.000000000000 |",
+ "| 5.000000000000 |",
+ "| 5.000000000000 |",
+ "| 5.000000000000 |",
+ "| 5.000000000000 |",
+ "+-------------------------------------------------------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+
+ let sql = "select c1/c5 from decimal_simple";
+ let actual = execute_to_batches(&ctx, sql).await;
+ assert_eq!(
+ &DataType::Decimal(30, 19),
+ actual[0].schema().field(0).data_type()
+ );
+ let expected = vec![
+ "+--------------------------------------------+",
+ "| decimal_simple.c1 Divide decimal_simple.c5 |",
+ "+--------------------------------------------+",
+ "| 0.7142857142857143296 |",
+ "| 0.8000000000000000000 |",
+ "| 1.0526315789473683456 |",
+ "| 0.9375000000000000000 |",
+ "| 0.8571428571428571136 |",
+ "| 2.7272727272727269376 |",
+ "| 0.9090909090909090816 |",
+ "| 1.0000000000000000000 |",
+ "| 1.0000000000000000000 |",
+ "| 0.9090909090909090816 |",
+ "| 0.9615384615384614912 |",
+ "| 0.6410256410256410624 |",
+ "| 1.5151515151515152384 |",
+ "| 0.7352941176470588416 |",
+ "| 0.5000000000000000000 |",
+ "+--------------------------------------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+
+ // modulo
+ let sql = "select c5%cast(0.00001 as decimal(5,5)) from decimal_simple";
+ let actual = execute_to_batches(&ctx, sql).await;
+ assert_eq!(
+ &DataType::Decimal(7, 7),
+ actual[0].schema().field(0).data_type()
+ );
+ let expected = vec![
+ "+-------------------------------------------------------------+",
+ "| decimal_simple.c5 % CAST(Float64(0.00001) AS Decimal(5, 5)) |",
+ "+-------------------------------------------------------------+",
+ "| 0.0000040 |",
+ "| 0.0000050 |",
+ "| 0.0000090 |",
+ "| 0.0000020 |",
+ "| 0.0000050 |",
+ "| 0.0000010 |",
+ "| 0.0000040 |",
+ "| 0.0000000 |",
+ "| 0.0000000 |",
+ "| 0.0000040 |",
+ "| 0.0000020 |",
+ "| 0.0000080 |",
+ "| 0.0000030 |",
+ "| 0.0000080 |",
+ "| 0.0000000 |",
+ "+-------------------------------------------------------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+
+ let sql = "select c1%c5 from decimal_simple";
+ let actual = execute_to_batches(&ctx, sql).await;
+ assert_eq!(
+ &DataType::Decimal(11, 7),
+ actual[0].schema().field(0).data_type()
+ );
+ let expected = vec![
+ "+--------------------------------------------+",
+ "| decimal_simple.c1 Modulo decimal_simple.c5 |",
+ "+--------------------------------------------+",
+ "| 0.0000100 |",
+ "| 0.0000200 |",
+ "| 0.0000010 |",
+ "| 0.0000300 |",
+ "| 0.0000300 |",
+ "| 0.0000080 |",
+ "| 0.0000400 |",
+ "| 0.0000000 |",
+ "| 0.0000000 |",
+ "| 0.0000400 |",
+ "| 0.0000500 |",
+ "| 0.0000500 |",
+ "| 0.0000170 |",
+ "| 0.0000500 |",
+ "| 0.0000500 |",
+ "+--------------------------------------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+
+ Ok(())
+}
+
#[tokio::test]
async fn decimal_sort() -> Result<()> {
let ctx = SessionContext::new();
diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs
index a155fa5ed..6dafb43f9 100644
--- a/datafusion/physical-expr/src/expressions/binary.rs
+++ b/datafusion/physical-expr/src/expressions/binary.rs
@@ -279,18 +279,49 @@ where
.collect()
}
+/// Applies `op` to each non-null value of `left` together with the scalar `right`.
+///
+/// Null slots in `left` are carried through as nulls without calling `op`. If `op`
+/// returns an error for any value, that error is returned instead of an array.
+fn arith_decimal_scalar<F>(
+    left: &DecimalArray,
+    right: i128,
+    op: F,
+) -> Result<DecimalArray>
+where
+    F: Fn(i128, i128) -> Result<i128>,
+{
+    left.iter()
+        // `Option<Result<_>>` -> `Result<Option<_>>`: nulls survive, errors propagate.
+        .map(|value| value.map(|value| op(value, right)).transpose())
+        .collect()
+}
+
/// Adds `right` to `left` through `arith_decimal`, then re-applies `left`'s
/// precision and scale to the result.
fn add_decimal(left: &DecimalArray, right: &DecimalArray) -> Result<DecimalArray> {
    let sum = arith_decimal(left, right, |lhs, rhs| Ok(lhs + rhs))?;
    let sum = sum.with_precision_and_scale(left.precision(), left.scale())?;
    Ok(sum)
}
+/// Adds the scalar `right` to every non-null value of `left`.
+///
+/// The result keeps `left`'s precision and scale.
+fn add_decimal_scalar(left: &DecimalArray, right: i128) -> Result<DecimalArray> {
+    let sum = arith_decimal_scalar(left, right, |value, addend| Ok(value + addend))?;
+    let sum = sum.with_precision_and_scale(left.precision(), left.scale())?;
+    Ok(sum)
+}
+
/// Subtracts `right` from `left` through `arith_decimal`, then re-applies `left`'s
/// precision and scale to the result.
fn subtract_decimal(left: &DecimalArray, right: &DecimalArray) -> Result<DecimalArray> {
    let difference = arith_decimal(left, right, |lhs, rhs| Ok(lhs - rhs))?;
    let difference =
        difference.with_precision_and_scale(left.precision(), left.scale())?;
    Ok(difference)
}
+/// Subtracts the scalar `right` from every non-null value of `left`.
+///
+/// The result keeps `left`'s precision and scale.
+fn subtract_decimal_scalar(left: &DecimalArray, right: i128) -> Result<DecimalArray> {
+    let difference =
+        arith_decimal_scalar(left, right, |value, subtrahend| Ok(value - subtrahend))?;
+    let difference =
+        difference.with_precision_and_scale(left.precision(), left.scale())?;
+    Ok(difference)
+}
+
fn multiply_decimal(left: &DecimalArray, right: &DecimalArray) -> Result<DecimalArray> {
let divide = 10_i128.pow(left.scale() as u32);
let array = arith_decimal(left, right, |left, right| Ok(left * right / divide))?
@@ -298,6 +329,14 @@ fn multiply_decimal(left: &DecimalArray, right: &DecimalArray) -> Result<Decimal
Ok(array)
}
+/// Multiplies every non-null value of `left` by the scalar `right`.
+///
+/// The raw product is divided by `10^left.scale()` before being stored, and the
+/// result keeps `left`'s precision and scale.
+fn multiply_decimal_scalar(left: &DecimalArray, right: i128) -> Result<DecimalArray> {
+    let rescale = 10_i128.pow(left.scale() as u32);
+    let product =
+        arith_decimal_scalar(left, right, |value, factor| Ok(value * factor / rescale))?;
+    let product = product.with_precision_and_scale(left.precision(), left.scale())?;
+    Ok(product)
+}
+
fn divide_decimal(left: &DecimalArray, right: &DecimalArray) -> Result<DecimalArray> {
let mul = 10_f64.powi(left.scale() as i32);
let array = arith_decimal(left, right, |left, right| {
@@ -310,6 +349,18 @@ fn divide_decimal(left: &DecimalArray, right: &DecimalArray) -> Result<DecimalAr
Ok(array)
}
+/// Divides every non-null value of `left` by the scalar `right`.
+///
+/// The quotient is computed in `f64`, multiplied by `10^left.scale()` and cast back
+/// to `i128`; the result keeps `left`'s precision and scale.
+///
+/// # Errors
+///
+/// Returns `DataFusionError::ArrowError(DivideByZero)` when `right` is zero, the same
+/// way `modulus_decimal_scalar` does. Without this guard the `f64` division would
+/// produce infinity or NaN, which the `as i128` cast silently turns into a saturated
+/// value or 0.
+fn divide_decimal_scalar(left: &DecimalArray, right: i128) -> Result<DecimalArray> {
+    if right == 0 {
+        return Err(DataFusionError::ArrowError(DivideByZero));
+    }
+    let mul = 10_f64.powi(left.scale() as i32);
+    // The divisor is identical for every row, so convert it to `f64` only once.
+    let divisor = right as f64;
+    let array = arith_decimal_scalar(left, right, |value, _| {
+        Ok(((value as f64 / divisor) * mul) as i128)
+    })?
+    .with_precision_and_scale(left.precision(), left.scale())?;
+    Ok(array)
+}
+
fn modulus_decimal(left: &DecimalArray, right: &DecimalArray) -> Result<DecimalArray> {
let array = arith_decimal(left, right, |left, right| {
if right == 0 {
@@ -322,6 +373,15 @@ fn modulus_decimal(left: &DecimalArray, right: &DecimalArray) -> Result<DecimalA
Ok(array)
}
+/// Computes `value % right` for every non-null value of `left`.
+///
+/// The result keeps `left`'s precision and scale.
+///
+/// # Errors
+///
+/// Returns `DataFusionError::ArrowError(DivideByZero)` when `right` is zero.
+fn modulus_decimal_scalar(left: &DecimalArray, right: i128) -> Result<DecimalArray> {
+    match right {
+        // Reject a zero divisor up front, before touching any row.
+        0 => Err(DataFusionError::ArrowError(DivideByZero)),
+        divisor => {
+            let remainder =
+                arith_decimal_scalar(left, divisor, |value, divisor| Ok(value % divisor))?;
+            let remainder =
+                remainder.with_precision_and_scale(left.precision(), left.scale())?;
+            Ok(remainder)
+        }
+    }
+}
+
/// The binary_bitwise_array_op macro only evaluates for integer types
/// like int64, int32.
/// It is used to do bitwise operation.
@@ -777,6 +837,7 @@ macro_rules! binary_primitive_array_op {
macro_rules! binary_primitive_array_op_scalar {
($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
let result: Result<Arc<dyn Array>> = match $LEFT.data_type() {
+ DataType::Decimal(_,_) => compute_decimal_op_scalar!($LEFT, $RIGHT, $OP, DecimalArray),
DataType::Int8 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int8Array),
DataType::Int16 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int16Array),
DataType::Int32 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int32Array),
@@ -2916,14 +2977,25 @@ mod tests {
let expect =
create_decimal_array(&[Some(246), None, Some(245), Some(247)], 25, 3)?;
assert_eq!(expect, result);
+ let result = add_decimal_scalar(&left_decimal_array, 10)?;
+ let expect =
+ create_decimal_array(&[Some(133), None, Some(132), Some(134)], 25, 3)?;
+ assert_eq!(expect, result);
// subtract
let result = subtract_decimal(&left_decimal_array, &right_decimal_array)?;
let expect = create_decimal_array(&[Some(0), None, Some(-1), Some(1)], 25, 3)?;
assert_eq!(expect, result);
+ let result = subtract_decimal_scalar(&left_decimal_array, 10)?;
+ let expect =
+ create_decimal_array(&[Some(113), None, Some(112), Some(114)], 25, 3)?;
+ assert_eq!(expect, result);
// multiply
let result = multiply_decimal(&left_decimal_array, &right_decimal_array)?;
let expect = create_decimal_array(&[Some(15), None, Some(15), Some(15)], 25, 3)?;
assert_eq!(expect, result);
+ let result = multiply_decimal_scalar(&left_decimal_array, 10)?;
+ let expect = create_decimal_array(&[Some(1), None, Some(1), Some(1)], 25, 3)?;
+ assert_eq!(expect, result);
// divide
let left_decimal_array = create_decimal_array(
&[Some(1234567), None, Some(1234567), Some(1234567)],
@@ -2939,10 +3011,20 @@ mod tests {
3,
)?;
assert_eq!(expect, result);
+ let result = divide_decimal_scalar(&left_decimal_array, 10)?;
+ let expect = create_decimal_array(
+ &[Some(123456700), None, Some(123456700), Some(123456700)],
+ 25,
+ 3,
+ )?;
+ assert_eq!(expect, result);
// modulus
let result = modulus_decimal(&left_decimal_array, &right_decimal_array)?;
let expect = create_decimal_array(&[Some(7), None, Some(37), Some(16)], 25, 3)?;
assert_eq!(expect, result);
+ let result = modulus_decimal_scalar(&left_decimal_array, 10)?;
+ let expect = create_decimal_array(&[Some(7), None, Some(7), Some(7)], 25, 3)?;
+ assert_eq!(expect, result);
Ok(())
}