You are viewing a plain-text version of this content; the link to its canonical version was not preserved in this copy.
Posted to commits@arrow.apache.org by tu...@apache.org on 2022/12/03 15:39:30 UTC

[arrow-rs] branch master updated: Get the round result for decimal to a decimal with smaller scale (#3224)

This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new cb4170b50 Get the round result for decimal to a decimal with smaller scale  (#3224)
cb4170b50 is described below

commit cb4170b50a54c466897afc83583f01dca23544c0
Author: Kun Liu <li...@apache.org>
AuthorDate: Sat Dec 3 23:39:25 2022 +0800

    Get the round result for decimal to a decimal with smaller scale  (#3224)
    
    * support cast decimal for round when the option is false
    
    * fix conflict after merge
    
    * fix error case
    
    * change to wrapping api
---
 arrow-cast/src/cast.rs | 143 ++++++++++++++++++++++++++++++++++++++-----------
 1 file changed, 111 insertions(+), 32 deletions(-)

diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs
index be767f137..8d28a6cc7 100644
--- a/arrow-cast/src/cast.rs
+++ b/arrow-cast/src/cast.rs
@@ -2164,6 +2164,7 @@ fn cast_decimal_to_decimal<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usize>(
         if BYTE_WIDTH1 == 16 {
             let array = array.as_any().downcast_ref::<Decimal128Array>().unwrap();
             if BYTE_WIDTH2 == 16 {
+                // the div must be greater or equal than 10
                 let div = 10_i128
                     .pow_checked((input_scale - output_scale) as u32)
                     .map_err(|_| {
@@ -2172,10 +2173,23 @@ fn cast_decimal_to_decimal<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usize>(
                             *output_scale,
                         ))
                     })?;
+                let half = div / 2;
+                let neg_half = -half;
 
                 array
                     .try_unary::<_, Decimal128Type, _>(|v| {
-                        v.checked_div(div).ok_or_else(|| {
+                        // cast to smaller scale, need to round the result
+                        // the div must be gt_eq 10, we don't need to check the overflow for the `div`/`mod` operation
+                        let d = v.wrapping_div(div);
+                        let r = v.wrapping_rem(div);
+                        if v >= 0 && r >= half {
+                            d.checked_add(1)
+                        } else if v < 0 && r <= neg_half {
+                            d.checked_sub(1)
+                        } else {
+                            Some(d)
+                        }
+                        .ok_or_else(|| {
                             ArrowError::CastError(format!(
                                 "Cannot cast to {:?}({}, {}). Overflowing on {:?}",
                                 Decimal128Type::PREFIX,
@@ -2199,9 +2213,23 @@ fn cast_decimal_to_decimal<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usize>(
                         ))
                     })?;
 
+                let half = div / i256::from_i128(2_i128);
+                let neg_half = -half;
+
                 array
                     .try_unary::<_, Decimal256Type, _>(|v| {
-                        i256::from_i128(v).checked_div(div).ok_or_else(|| {
+                        // the div must be gt_eq 10, we don't need to check the overflow for the `div`/`mod` operation
+                        let v = i256::from_i128(v);
+                        let d = v.wrapping_div(div);
+                        let r = v.wrapping_rem(div);
+                        if v >= i256::ZERO && r >= half {
+                            d.checked_add(i256::ONE)
+                        } else if v < i256::ZERO && r <= neg_half {
+                            d.checked_sub(i256::ONE)
+                        } else {
+                            Some(d)
+                        }
+                        .ok_or_else(|| {
                             ArrowError::CastError(format!(
                                 "Cannot cast to {:?}({}, {}). Overflowing on {:?}",
                                 Decimal256Type::PREFIX,
@@ -2226,10 +2254,21 @@ fn cast_decimal_to_decimal<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usize>(
                         *output_scale,
                     ))
                 })?;
+            let half = div / i256::from_i128(2_i128);
+            let neg_half = -half;
             if BYTE_WIDTH2 == 16 {
                 array
                     .try_unary::<_, Decimal128Type, _>(|v| {
-                        v.checked_div(div).ok_or_else(|| {
+                        // the div must be gt_eq 10, we don't need to check the overflow for the `div`/`mod` operation
+                        let d = v.wrapping_div(div);
+                        let r = v.wrapping_rem(div);
+                        if v >= i256::ZERO && r >= half {
+                            d.checked_add(i256::ONE)
+                        } else if v < i256::ZERO && r <= neg_half {
+                            d.checked_sub(i256::ONE)
+                        } else {
+                            Some(d)
+                        }.ok_or_else(|| {
                             ArrowError::CastError(format!(
                                 "Cannot cast to {:?}({}, {}). Overflowing on {:?}",
                                 Decimal128Type::PREFIX,
@@ -2250,7 +2289,17 @@ fn cast_decimal_to_decimal<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usize>(
             } else {
                 array
                     .try_unary::<_, Decimal256Type, _>(|v| {
-                        v.checked_div(div).ok_or_else(|| {
+                        // the div must be gt_eq 10, we don't need to check the overflow for the `div`/`mod` operation
+                        let d = v.wrapping_div(div);
+                        let r = v.wrapping_rem(div);
+                        if v >= i256::ZERO && r >= half {
+                            d.checked_add(i256::ONE)
+                        } else if v < i256::ZERO && r <= neg_half {
+                            d.checked_sub(i256::ONE)
+                        } else {
+                            Some(d)
+                        }
+                        .ok_or_else(|| {
                             ArrowError::CastError(format!(
                                 "Cannot cast to {:?}({}, {}). Overflowing on {:?}",
                                 Decimal256Type::PREFIX,
@@ -3621,6 +3670,26 @@ mod tests {
                     }
                 }
             }
+
+            let cast_option = CastOptions { safe: false };
+            let casted_array_with_option =
+                cast_with_options($INPUT_ARRAY, $OUTPUT_TYPE, &cast_option).unwrap();
+            let result_array = casted_array_with_option
+                .as_any()
+                .downcast_ref::<$OUTPUT_TYPE_ARRAY>()
+                .unwrap();
+            assert_eq!($OUTPUT_TYPE, result_array.data_type());
+            assert_eq!(result_array.len(), $OUTPUT_VALUES.len());
+            for (i, x) in $OUTPUT_VALUES.iter().enumerate() {
+                match x {
+                    Some(x) => {
+                        assert_eq!(result_array.value(i), *x);
+                    }
+                    None => {
+                        assert!(result_array.is_null(i));
+                    }
+                }
+            }
         };
     }
 
@@ -3647,6 +3716,44 @@ mod tests {
     }
 
     #[test]
+    #[cfg(not(feature = "force_validate"))]
+    #[should_panic(
+        expected = "5789604461865809771178549250434395392663499233282028201972879200395656481997 cannot be casted to 128-bit integer for Decimal128"
+    )]
+    fn test_cast_decimal_to_decimal_round_with_error() {
+        // decimal256 to decimal128 overflow
+        let array = vec![
+            Some(i256::from_i128(1123454)),
+            Some(i256::from_i128(2123456)),
+            Some(i256::from_i128(-3123453)),
+            Some(i256::from_i128(-3123456)),
+            None,
+            Some(i256::MAX),
+            Some(i256::MIN),
+        ];
+        let input_decimal_array = create_decimal256_array(array, 76, 4).unwrap();
+        let array = Arc::new(input_decimal_array) as ArrayRef;
+        let input_type = DataType::Decimal256(76, 4);
+        let output_type = DataType::Decimal128(20, 3);
+        assert!(can_cast_types(&input_type, &output_type));
+        generate_cast_test_case!(
+            &array,
+            Decimal128Array,
+            &output_type,
+            vec![
+                Some(112345_i128),
+                Some(212346_i128),
+                Some(-312345_i128),
+                Some(-312346_i128),
+                None,
+                None,
+                None,
+            ]
+        );
+    }
+
+    #[test]
+    #[cfg(not(feature = "force_validate"))]
     fn test_cast_decimal_to_decimal_round() {
         let array = vec![
             Some(1123454),
@@ -3734,34 +3841,6 @@ mod tests {
                 None
             ]
         );
-
-        // decimal256 to decimal128 overflow
-        let array = vec![
-            Some(i256::from_i128(1123454)),
-            Some(i256::from_i128(2123456)),
-            Some(i256::from_i128(-3123453)),
-            Some(i256::from_i128(-3123456)),
-            None,
-            Some(i256::MAX),
-            Some(i256::MIN),
-        ];
-        let input_decimal_array = create_decimal256_array(array, 76, 4).unwrap();
-        let array = Arc::new(input_decimal_array) as ArrayRef;
-        assert!(can_cast_types(&input_type, &output_type));
-        generate_cast_test_case!(
-            &array,
-            Decimal128Array,
-            &output_type,
-            vec![
-                Some(112345_i128),
-                Some(212346_i128),
-                Some(-312345_i128),
-                Some(-312346_i128),
-                None,
-                None,
-                None
-            ]
-        );
     }
 
     #[test]