You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text version).
Posted to commits@arrow.apache.org by ag...@apache.org on 2023/05/13 13:56:46 UTC

[arrow-datafusion] branch main updated: refine decimal multiply, avoid cast to wider type (#6331)

This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 063f99fd61 refine decimal multiply, avoid cast to wider type (#6331)
063f99fd61 is described below

commit 063f99fd61aa8f2854dc138d16864ae485a7116e
Author: Ken, Wang <mi...@gmail.com>
AuthorDate: Sat May 13 21:56:41 2023 +0800

    refine decimal multiply, avoid cast to wider type (#6331)
    
    * refine decimal multiply, avoid cast to wider type
    
    * fix clippy
    
    * fix fmt
---
 datafusion/expr/src/type_coercion/binary.rs        | 21 ++++++++++-----------
 .../src/expressions/binary/kernels_arrow.rs        | 22 +++-------------------
 2 files changed, 13 insertions(+), 30 deletions(-)

diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs
index fa912b777a..962434d652 100644
--- a/datafusion/expr/src/type_coercion/binary.rs
+++ b/datafusion/expr/src/type_coercion/binary.rs
@@ -514,7 +514,7 @@ pub fn coercion_decimal_mathematics_type(
                 left_decimal_type,
                 right_decimal_type,
             ),
-            Operator::Multiply | Operator::Divide | Operator::Modulo => {
+            Operator::Divide | Operator::Modulo => {
                 get_wider_decimal_type(left_decimal_type, right_decimal_type)
             }
             _ => None,
@@ -946,7 +946,7 @@ mod tests {
             &left_decimal_type,
             &right_decimal_type,
         );
-        assert_eq!(DataType::Decimal128(20, 4), result.unwrap());
+        assert_eq!(None, result);
         let result =
             decimal_op_mathematics_type(&op, &left_decimal_type, &right_decimal_type);
         assert_eq!(DataType::Decimal128(31, 7), result.unwrap());
@@ -1232,7 +1232,7 @@ mod tests {
         mathematics_op: Operator,
         expected_lhs_type: Option<DataType>,
         expected_rhs_type: Option<DataType>,
-        expected_coerced_type: DataType,
+        expected_coerced_type: Option<DataType>,
         expected_output_type: DataType,
     ) {
         // The coerced types for lhs and rhs, if any of them is not decimal
@@ -1245,8 +1245,7 @@ mod tests {
 
         // The coerced type of decimal math expression, applied during expression evaluation
         let coerced_type =
-            coercion_decimal_mathematics_type(&mathematics_op, &lhs_type, &rhs_type)
-                .unwrap();
+            coercion_decimal_mathematics_type(&mathematics_op, &lhs_type, &rhs_type);
         assert_eq!(coerced_type, expected_coerced_type);
 
         // The output type of decimal math expression
@@ -1263,7 +1262,7 @@ mod tests {
             Operator::Plus,
             None,
             None,
-            DataType::Decimal128(11, 2),
+            Some(DataType::Decimal128(11, 2)),
             DataType::Decimal128(11, 2),
         );
 
@@ -1273,7 +1272,7 @@ mod tests {
             Operator::Plus,
             Some(DataType::Decimal128(10, 0)),
             None,
-            DataType::Decimal128(13, 2),
+            Some(DataType::Decimal128(13, 2)),
             DataType::Decimal128(13, 2),
         );
 
@@ -1283,7 +1282,7 @@ mod tests {
             Operator::Minus,
             Some(DataType::Decimal128(10, 0)),
             None,
-            DataType::Decimal128(13, 2),
+            Some(DataType::Decimal128(13, 2)),
             DataType::Decimal128(13, 2),
         );
 
@@ -1293,7 +1292,7 @@ mod tests {
             Operator::Multiply,
             Some(DataType::Decimal128(10, 0)),
             None,
-            DataType::Decimal128(12, 2),
+            None,
             DataType::Decimal128(21, 2),
         );
 
@@ -1303,7 +1302,7 @@ mod tests {
             Operator::Divide,
             Some(DataType::Decimal128(10, 0)),
             None,
-            DataType::Decimal128(12, 2),
+            Some(DataType::Decimal128(12, 2)),
             DataType::Decimal128(23, 11),
         );
 
@@ -1313,7 +1312,7 @@ mod tests {
             Operator::Modulo,
             Some(DataType::Decimal128(10, 0)),
             None,
-            DataType::Decimal128(12, 2),
+            Some(DataType::Decimal128(12, 2)),
             DataType::Decimal128(10, 2),
         );
 
diff --git a/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs b/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs
index 2852d617bf..90fca17157 100644
--- a/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs
+++ b/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs
@@ -35,9 +35,7 @@ use arrow_schema::DataType;
 use datafusion_common::cast::{as_date32_array, as_date64_array, as_decimal128_array};
 use datafusion_common::scalar::{date32_add, date64_add};
 use datafusion_common::{DataFusionError, Result, ScalarValue};
-use datafusion_expr::type_coercion::binary::decimal_op_mathematics_type;
 use datafusion_expr::ColumnarValue;
-use datafusion_expr::Operator;
 use std::cmp::min;
 use std::sync::Arc;
 
@@ -469,24 +467,8 @@ pub(crate) fn multiply_decimal_dyn_scalar(
     result_type: &DataType,
 ) -> Result<ArrayRef> {
     let (precision, scale) = get_precision_scale(result_type)?;
-
-    let op_type = decimal_op_mathematics_type(
-        &Operator::Multiply,
-        left.data_type(),
-        left.data_type(),
-    )
-    .unwrap();
-    let (_, op_scale) = get_precision_scale(&op_type)?;
-
     let array = multiply_scalar_dyn::<Decimal128Type>(left, right)?;
-
-    if op_scale > scale {
-        let div = 10_i128.pow((op_scale - scale) as u32);
-        let array = divide_scalar_dyn::<Decimal128Type>(&array, div)?;
-        decimal_array_with_precision_scale(array, precision, scale)
-    } else {
-        decimal_array_with_precision_scale(array, precision, scale)
-    }
+    decimal_array_with_precision_scale(array, precision, scale)
 }
 
 pub(crate) fn divide_decimal_dyn_scalar(
@@ -703,6 +685,8 @@ pub(crate) fn modulus_decimal_dyn_scalar(
 #[cfg(test)]
 mod tests {
     use super::*;
+    use datafusion_expr::type_coercion::binary::decimal_op_mathematics_type;
+    use datafusion_expr::Operator;
 
     fn create_decimal_array(
         array: &[Option<i128>],