You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by tu...@apache.org on 2023/09/25 09:46:01 UTC

[arrow-rs] branch master updated: fix: add missing precision overflow checking for `cast_string_to_decimal` (#4830)

This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 7e7ac153c6 fix: add missing precision overflow checking for `cast_string_to_decimal` (#4830)
7e7ac153c6 is described below

commit 7e7ac153c69a0b227ae11e0caf0f00b04b85cd23
Author: Jonah Gao <jo...@gmail.com>
AuthorDate: Mon Sep 25 17:45:55 2023 +0800

    fix: add missing precision overflow checking for `cast_string_to_decimal` (#4830)
    
    * fix: add missing precision overflow checking for `cast_string_to_decimal`
    
    * Add test_cast_string_to_decimal256_precision_overflow
---
 arrow-cast/src/cast.rs | 75 +++++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 68 insertions(+), 7 deletions(-)

diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs
index 7b8e6144bb..e7727565c9 100644
--- a/arrow-cast/src/cast.rs
+++ b/arrow-cast/src/cast.rs
@@ -2801,6 +2801,11 @@ where
     if cast_options.safe {
         let iter = from.iter().map(|v| {
             v.and_then(|v| parse_string_to_decimal_native::<T>(v, scale as usize).ok())
+                .and_then(|v| {
+                    T::validate_decimal_precision(v, precision)
+                        .is_ok()
+                        .then_some(v)
+                })
         });
         // Benefit:
         //     20% performance improvement
@@ -2815,13 +2820,17 @@ where
             .iter()
             .map(|v| {
                 v.map(|v| {
-                    parse_string_to_decimal_native::<T>(v, scale as usize).map_err(|_| {
-                        ArrowError::CastError(format!(
-                            "Cannot cast string '{}' to value of {:?} type",
-                            v,
-                            T::DATA_TYPE,
-                        ))
-                    })
+                    parse_string_to_decimal_native::<T>(v, scale as usize)
+                        .map_err(|_| {
+                            ArrowError::CastError(format!(
+                                "Cannot cast string '{}' to value of {:?} type",
+                                v,
+                                T::DATA_TYPE,
+                            ))
+                        })
+                        .and_then(|v| {
+                            T::validate_decimal_precision(v, precision).map(|_| v)
+                        })
                 })
                 .transpose()
             })
@@ -8152,6 +8161,32 @@ mod tests {
         );
     }
 
+    #[test]
+    fn test_cast_string_to_decimal128_precision_overflow() {
+        let array = StringArray::from(vec!["1000".to_string()]);
+        let array = Arc::new(array) as ArrayRef;
+        let casted_array = cast_with_options(
+            &array,
+            &DataType::Decimal128(10, 8),
+            &CastOptions {
+                safe: true,
+                format_options: FormatOptions::default(),
+            },
+        );
+        assert!(casted_array.is_ok());
+        assert!(casted_array.unwrap().is_null(0));
+
+        let err = cast_with_options(
+            &array,
+            &DataType::Decimal128(10, 8),
+            &CastOptions {
+                safe: false,
+                format_options: FormatOptions::default(),
+            },
+        );
+        assert_eq!("Invalid argument error: 100000000000 is too large to store in a Decimal128 of precision 10. Max is 9999999999", err.unwrap_err().to_string());
+    }
+
     #[test]
     fn test_cast_utf8_to_decimal128_overflow() {
         let overflow_str_array = StringArray::from(vec![
@@ -8209,6 +8244,32 @@ mod tests {
         assert!(decimal_arr.is_null(6));
     }
 
+    #[test]
+    fn test_cast_string_to_decimal256_precision_overflow() {
+        let array = StringArray::from(vec!["1000".to_string()]);
+        let array = Arc::new(array) as ArrayRef;
+        let casted_array = cast_with_options(
+            &array,
+            &DataType::Decimal256(10, 8),
+            &CastOptions {
+                safe: true,
+                format_options: FormatOptions::default(),
+            },
+        );
+        assert!(casted_array.is_ok());
+        assert!(casted_array.unwrap().is_null(0));
+
+        let err = cast_with_options(
+            &array,
+            &DataType::Decimal256(10, 8),
+            &CastOptions {
+                safe: false,
+                format_options: FormatOptions::default(),
+            },
+        );
+        assert_eq!("Invalid argument error: 100000000000 is too large to store in a Decimal256 of precision 10. Max is 9999999999", err.unwrap_err().to_string());
+    }
+
     #[test]
     fn test_cast_utf8_to_decimal256_overflow() {
         let overflow_str_array = StringArray::from(vec![