You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by tu...@apache.org on 2022/11/03 22:19:48 UTC

[arrow-rs] branch master updated: Round instead of Truncate while casting float to decimal (#3000)

This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 61cf6f75c Round instead of Truncate while casting float to decimal (#3000)
61cf6f75c is described below

commit 61cf6f75c4e03ca950f750cb2fdba4adee534372
Author: Wei-Ting Kuo <wa...@gmail.com>
AuthorDate: Fri Nov 4 06:19:42 2022 +0800

    Round instead of Truncate while casting float to decimal (#3000)
    
    * add .round() before casting to integer
    
    * add more test cases
    
    * update test cases
    
    * add doc
    
    * Format
    
    Co-authored-by: Raphael Taylor-Davies <r....@googlemail.com>
---
 arrow/src/compute/kernels/cast.rs | 103 +++++++++++++++++++++++++++++---------
 1 file changed, 79 insertions(+), 24 deletions(-)

diff --git a/arrow/src/compute/kernels/cast.rs b/arrow/src/compute/kernels/cast.rs
index 4c724b640..4ad8dd99e 100644
--- a/arrow/src/compute/kernels/cast.rs
+++ b/arrow/src/compute/kernels/cast.rs
@@ -297,6 +297,9 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
 /// * Time32 and Time64: precision lost when going to higher interval
 /// * Timestamp and Date{32|64}: precision lost when going to higher interval
 /// * Temporal to/from backing primitive: zero-copy with data type change
+/// * Casting from `float32/float64` to `Decimal(precision, scale)` rounds to the `scale` decimals
+///   (e.g. casting 6.4999 to Decimal(10, 1) becomes 6.5). This is a breaking change from `26.0.0`:
+///   it used to truncate instead of rounding (e.g. producing 6.4 instead)
 ///
 /// Unsupported Casts
 /// * To or from `StructArray`
@@ -353,7 +356,7 @@ where
 {
     let mul = 10_f64.powi(scale as i32);
 
-    unary::<T, _, Decimal128Type>(array, |v| (v.as_() * mul) as i128)
+    unary::<T, _, Decimal128Type>(array, |v| (v.as_() * mul).round() as i128)
         .with_precision_and_scale(precision, scale)
         .map(|a| Arc::new(a) as ArrayRef)
 }
@@ -368,9 +371,11 @@ where
 {
     let mul = 10_f64.powi(scale as i32);
 
-    unary::<T, _, Decimal256Type>(array, |v| i256::from_i128((v.as_() * mul) as i128))
-        .with_precision_and_scale(precision, scale)
-        .map(|a| Arc::new(a) as ArrayRef)
+    unary::<T, _, Decimal256Type>(array, |v| {
+        i256::from_i128((v.as_() * mul).round() as i128)
+    })
+    .with_precision_and_scale(precision, scale)
+    .map(|a| Arc::new(a) as ArrayRef)
 }
 
 /// Cast the primitive array using [`PrimitiveArray::reinterpret_cast`]
@@ -3192,8 +3197,8 @@ mod tests {
             Some(2.2),
             Some(4.4),
             None,
-            Some(1.123_456_7),
-            Some(1.123_456_7),
+            Some(1.123_456_4), // round down
+            Some(1.123_456_7), // round up
         ]);
         let array = Arc::new(array) as ArrayRef;
         generate_cast_test_case!(
@@ -3205,8 +3210,8 @@ mod tests {
                 Some(2200000_i128),
                 Some(4400000_i128),
                 None,
-                Some(1123456_i128),
-                Some(1123456_i128),
+                Some(1123456_i128), // round down
+                Some(1123457_i128), // round up
             ]
         );
 
@@ -3216,9 +3221,10 @@ mod tests {
             Some(2.2),
             Some(4.4),
             None,
-            Some(1.123_456_789_123_4),
-            Some(1.123_456_789_012_345_6),
-            Some(1.123_456_789_012_345_6),
+            Some(1.123_456_489_123_4),     // round down
+            Some(1.123_456_789_123_4),     // round up
+            Some(1.123_456_489_012_345_6), // round down
+            Some(1.123_456_789_012_345_6), // round up
         ]);
         let array = Arc::new(array) as ArrayRef;
         generate_cast_test_case!(
@@ -3230,9 +3236,10 @@ mod tests {
                 Some(2200000_i128),
                 Some(4400000_i128),
                 None,
-                Some(1123456_i128),
-                Some(1123456_i128),
-                Some(1123456_i128),
+                Some(1123456_i128), // round down
+                Some(1123457_i128), // round up
+                Some(1123456_i128), // round down
+                Some(1123457_i128), // round up
             ]
         );
     }
@@ -3307,8 +3314,8 @@ mod tests {
             Some(2.2),
             Some(4.4),
             None,
-            Some(1.123_456_7),
-            Some(1.123_456_7),
+            Some(1.123_456_4), // round down
+            Some(1.123_456_7), // round up
         ]);
         let array = Arc::new(array) as ArrayRef;
         generate_cast_test_case!(
@@ -3320,8 +3327,8 @@ mod tests {
                 Some(i256::from_i128(2200000_i128)),
                 Some(i256::from_i128(4400000_i128)),
                 None,
-                Some(i256::from_i128(1123456_i128)),
-                Some(i256::from_i128(1123456_i128)),
+                Some(i256::from_i128(1123456_i128)), // round down
+                Some(i256::from_i128(1123457_i128)), // round up
             ]
         );
 
@@ -3331,9 +3338,10 @@ mod tests {
             Some(2.2),
             Some(4.4),
             None,
-            Some(1.123_456_789_123_4),
-            Some(1.123_456_789_012_345_6),
-            Some(1.123_456_789_012_345_6),
+            Some(1.123_456_489_123_4),     // round down
+            Some(1.123_456_789_123_4),     // round up
+            Some(1.123_456_489_012_345_6), // round down
+            Some(1.123_456_789_012_345_6), // round up
         ]);
         let array = Arc::new(array) as ArrayRef;
         generate_cast_test_case!(
@@ -3345,9 +3353,10 @@ mod tests {
                 Some(i256::from_i128(2200000_i128)),
                 Some(i256::from_i128(4400000_i128)),
                 None,
-                Some(i256::from_i128(1123456_i128)),
-                Some(i256::from_i128(1123456_i128)),
-                Some(i256::from_i128(1123456_i128)),
+                Some(i256::from_i128(1123456_i128)), // round down
+                Some(i256::from_i128(1123457_i128)), // round up
+                Some(i256::from_i128(1123456_i128)), // round down
+                Some(i256::from_i128(1123457_i128)), // round up
             ]
         );
     }
@@ -5994,4 +6003,50 @@ mod tests {
             .collect::<Vec<_>>();
         assert_eq!(&out, &vec!["[0, 1, 2]", "[3, 4, 5]", "[6, 7]"]);
     }
+
+    #[test]
+    #[cfg(not(feature = "force_validate"))]
+    fn test_cast_f64_to_decimal128() {
+        // to reproduce https://github.com/apache/arrow-rs/issues/2997
+
+        let decimal_type = DataType::Decimal128(18, 2);
+        let array = Float64Array::from(vec![
+            Some(0.0699999999),
+            Some(0.0659999999),
+            Some(0.0650000000),
+            Some(0.0649999999),
+        ]);
+        let array = Arc::new(array) as ArrayRef;
+        generate_cast_test_case!(
+            &array,
+            Decimal128Array,
+            &decimal_type,
+            vec![
+                Some(7_i128), // round up
+                Some(7_i128), // round up
+                Some(7_i128), // round up
+                Some(6_i128), // round down
+            ]
+        );
+
+        let decimal_type = DataType::Decimal128(18, 3);
+        let array = Float64Array::from(vec![
+            Some(0.0699999999),
+            Some(0.0659999999),
+            Some(0.0650000000),
+            Some(0.0649999999),
+        ]);
+        let array = Arc::new(array) as ArrayRef;
+        generate_cast_test_case!(
+            &array,
+            Decimal128Array,
+            &decimal_type,
+            vec![
+                Some(70_i128), // round up
+                Some(66_i128), // round up
+                Some(65_i128), // round down
+                Some(65_i128), // round up
+            ]
+        );
+    }
 }