You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by vi...@apache.org on 2022/11/15 08:13:35 UTC

[arrow-rs] branch master updated: Check overflow while casting between decimal types (#3076)

This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 7d41e1c19 Check overflow while casting between decimal types (#3076)
7d41e1c19 is described below

commit 7d41e1c194b24238010e1a26c4864f535a4899eb
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Tue Nov 15 00:13:29 2022 -0800

    Check overflow while casting between decimal types (#3076)
---
 arrow-cast/src/cast.rs | 439 ++++++++++++++++++++++++++++++++++++++++---------
 1 file changed, 359 insertions(+), 80 deletions(-)

diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs
index d6dbf3061..79c23bfac 100644
--- a/arrow-cast/src/cast.rs
+++ b/arrow-cast/src/cast.rs
@@ -612,16 +612,16 @@ pub fn cast_with_options(
     }
     match (from_type, to_type) {
         (Decimal128(_, s1), Decimal128(p2, s2)) => {
-            cast_decimal_to_decimal::<16, 16>(array, s1, p2, s2)
+            cast_decimal_to_decimal_with_option::<16, 16>(array, s1, p2, s2, cast_options)
         }
         (Decimal256(_, s1), Decimal256(p2, s2)) => {
-            cast_decimal_to_decimal::<32, 32>(array, s1, p2, s2)
+            cast_decimal_to_decimal_with_option::<32, 32>(array, s1, p2, s2, cast_options)
         }
         (Decimal128(_, s1), Decimal256(p2, s2)) => {
-            cast_decimal_to_decimal::<16, 32>(array, s1, p2, s2)
+            cast_decimal_to_decimal_with_option::<16, 32>(array, s1, p2, s2, cast_options)
         }
         (Decimal256(_, s1), Decimal128(p2, s2)) => {
-            cast_decimal_to_decimal::<32, 16>(array, s1, p2, s2)
+            cast_decimal_to_decimal_with_option::<32, 16>(array, s1, p2, s2, cast_options)
         }
         (Decimal128(_, scale), _) => {
             // cast decimal to other type
@@ -1916,7 +1916,36 @@ const fn time_unit_multiple(unit: &TimeUnit) -> i64 {
 }
 
 /// Cast one type of decimal array to another type of decimal array
-fn cast_decimal_to_decimal<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usize>(
+fn cast_decimal_to_decimal_with_option<
+    const BYTE_WIDTH1: usize,
+    const BYTE_WIDTH2: usize,
+>(
+    array: &ArrayRef,
+    input_scale: &u8,
+    output_precision: &u8,
+    output_scale: &u8,
+    cast_options: &CastOptions,
+) -> Result<ArrayRef, ArrowError> {
+    // `safe` casts null out failing values; the strict path returns `Err`.
+    if !cast_options.safe {
+        return cast_decimal_to_decimal::<BYTE_WIDTH1, BYTE_WIDTH2>(
+            array,
+            input_scale,
+            output_precision,
+            output_scale,
+        );
+    }
+    cast_decimal_to_decimal_safe::<BYTE_WIDTH1, BYTE_WIDTH2>(
+        array,
+        input_scale,
+        output_precision,
+        output_scale,
+    )
+}
+
+/// Cast one type of decimal array to another type of decimal array, returning NULLs for
+/// the array values when cast failures happen.
+fn cast_decimal_to_decimal_safe<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usize>(
     array: &ArrayRef,
     input_scale: &u8,
     output_precision: &u8,
@@ -1928,54 +1957,50 @@ fn cast_decimal_to_decimal<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usize>(
         let div = 10_i128.pow((input_scale - output_scale) as u32);
         if BYTE_WIDTH1 == 16 {
             let array = array.as_any().downcast_ref::<Decimal128Array>().unwrap();
-            let iter = array.iter().map(|v| v.map(|v| v.wrapping_div(div)));
             if BYTE_WIDTH2 == 16 {
-                let output_array = iter
-                    .collect::<Decimal128Array>()
-                    .with_precision_and_scale(*output_precision, *output_scale)?;
-
-                Ok(Arc::new(output_array))
+                let iter = array
+                    .iter()
+                    .map(|v| v.and_then(|v| v.div_checked(div).ok()));
+                let casted_array = unsafe {
+                    PrimitiveArray::<Decimal128Type>::from_trusted_len_iter(iter)
+                };
+                casted_array
+                    .with_precision_and_scale(*output_precision, *output_scale)
+                    .map(|a| Arc::new(a) as ArrayRef)
             } else {
-                let output_array = iter
-                    .map(|v| v.map(i256::from_i128))
-                    .collect::<Decimal256Array>()
-                    .with_precision_and_scale(*output_precision, *output_scale)?;
-
-                Ok(Arc::new(output_array))
+                let iter = array.iter().map(|v| {
+                    v.and_then(|v| v.div_checked(div).ok().map(i256::from_i128))
+                });
+                let casted_array = unsafe {
+                    PrimitiveArray::<Decimal256Type>::from_trusted_len_iter(iter)
+                };
+                casted_array
+                    .with_precision_and_scale(*output_precision, *output_scale)
+                    .map(|a| Arc::new(a) as ArrayRef)
             }
         } else {
             let array = array.as_any().downcast_ref::<Decimal256Array>().unwrap();
             let div = i256::from_i128(div);
-            let iter = array.iter().map(|v| v.map(|v| v.wrapping_div(div)));
             if BYTE_WIDTH2 == 16 {
-                let values = iter
-                    .map(|v| {
-                        if v.is_none() {
-                            Ok(None)
-                        } else {
-                            v.as_ref().and_then(|v| v.to_i128())
-                                .ok_or_else(|| {
-                                    ArrowError::InvalidArgumentError(
-                                        format!("{:?} cannot be casted to 128-bit integer for Decimal128", v),
-                                    )
-                                })
-                                .map(Some)
-                        }
-                    })
-                    .collect::<Result<Vec<_>, _>>()?;
-
-                let output_array = values
-                    .into_iter()
-                    .collect::<Decimal128Array>()
-                    .with_precision_and_scale(*output_precision, *output_scale)?;
-
-                Ok(Arc::new(output_array))
+                let iter = array.iter().map(|v| {
+                    v.and_then(|v| v.div_checked(div).ok().and_then(|v| v.to_i128()))
+                });
+                let casted_array = unsafe {
+                    PrimitiveArray::<Decimal128Type>::from_trusted_len_iter(iter)
+                };
+                casted_array
+                    .with_precision_and_scale(*output_precision, *output_scale)
+                    .map(|a| Arc::new(a) as ArrayRef)
             } else {
-                let output_array = iter
-                    .collect::<Decimal256Array>()
-                    .with_precision_and_scale(*output_precision, *output_scale)?;
-
-                Ok(Arc::new(output_array))
+                let iter = array
+                    .iter()
+                    .map(|v| v.and_then(|v| v.div_checked(div).ok()));
+                let casted_array = unsafe {
+                    PrimitiveArray::<Decimal256Type>::from_trusted_len_iter(iter)
+                };
+                casted_array
+                    .with_precision_and_scale(*output_precision, *output_scale)
+                    .map(|a| Arc::new(a) as ArrayRef)
             }
         }
     } else {
@@ -1984,54 +2009,278 @@ fn cast_decimal_to_decimal<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usize>(
         let mul = 10_i128.pow((output_scale - input_scale) as u32);
         if BYTE_WIDTH1 == 16 {
             let array = array.as_any().downcast_ref::<Decimal128Array>().unwrap();
-            let iter = array.iter().map(|v| v.map(|v| v.wrapping_mul(mul)));
             if BYTE_WIDTH2 == 16 {
-                let output_array = iter
-                    .collect::<Decimal128Array>()
-                    .with_precision_and_scale(*output_precision, *output_scale)?;
+                let iter = array
+                    .iter()
+                    .map(|v| v.and_then(|v| v.mul_checked(mul).ok()));
+                let casted_array = unsafe {
+                    PrimitiveArray::<Decimal128Type>::from_trusted_len_iter(iter)
+                };
+                casted_array
+                    .with_precision_and_scale(*output_precision, *output_scale)
+                    .map(|a| Arc::new(a) as ArrayRef)
+            } else {
+                let iter = array.iter().map(|v| {
+                    v.and_then(|v| v.mul_checked(mul).ok().map(i256::from_i128))
+                });
+                let casted_array = unsafe {
+                    PrimitiveArray::<Decimal256Type>::from_trusted_len_iter(iter)
+                };
+                casted_array
+                    .with_precision_and_scale(*output_precision, *output_scale)
+                    .map(|a| Arc::new(a) as ArrayRef)
+            }
+        } else {
+            let array = array.as_any().downcast_ref::<Decimal256Array>().unwrap();
+            let mul = i256::from_i128(mul);
+            if BYTE_WIDTH2 == 16 {
+                let iter = array.iter().map(|v| {
+                    v.and_then(|v| v.mul_checked(mul).ok().and_then(|v| v.to_i128()))
+                });
+                let casted_array = unsafe {
+                    PrimitiveArray::<Decimal128Type>::from_trusted_len_iter(iter)
+                };
+                casted_array
+                    .with_precision_and_scale(*output_precision, *output_scale)
+                    .map(|a| Arc::new(a) as ArrayRef)
+            } else {
+                let iter = array
+                    .iter()
+                    .map(|v| v.and_then(|v| v.mul_checked(mul).ok()));
+                let casted_array = unsafe {
+                    PrimitiveArray::<Decimal256Type>::from_trusted_len_iter(iter)
+                };
+                casted_array
+                    .with_precision_and_scale(*output_precision, *output_scale)
+                    .map(|a| Arc::new(a) as ArrayRef)
+            }
+        }
+    }
+}
+
+/// Cast one type of decimal array to another type of decimal array, returning `Err` if
+/// a cast failure happens.
+fn cast_decimal_to_decimal<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usize>(
+    array: &ArrayRef,
+    input_scale: &u8,
+    output_precision: &u8,
+    output_scale: &u8,
+) -> Result<ArrayRef, ArrowError> {
+    if input_scale > output_scale {
+        // For example, input_scale is 4 and output_scale is 3;
+        // Original value is 11234_i128, and will be cast to 1123_i128.
+        if BYTE_WIDTH1 == 16 {
+            let array = array.as_any().downcast_ref::<Decimal128Array>().unwrap();
+            if BYTE_WIDTH2 == 16 {
+                let div = 10_i128
+                    .pow_checked((input_scale - output_scale) as u32)
+                    .map_err(|_| {
+                        ArrowError::CastError(format!(
+                            "Cannot cast. The scale {} causes overflow.",
+                            *output_scale,
+                        ))
+                    })?;
 
-                Ok(Arc::new(output_array))
+                array
+                    .try_unary::<_, Decimal128Type, _>(|v| {
+                        v.checked_div(div).ok_or_else(|| {
+                            ArrowError::CastError(format!(
+                                "Cannot cast to {:?}({}, {}). Overflowing on {:?}",
+                                Decimal128Type::PREFIX,
+                                *output_precision,
+                                *output_scale,
+                                v
+                            ))
+                        })
+                    })
+                    .and_then(|a| {
+                        a.with_precision_and_scale(*output_precision, *output_scale)
+                    })
+                    .map(|a| Arc::new(a) as ArrayRef)
             } else {
-                let output_array = iter
-                    .map(|v| v.map(i256::from_i128))
-                    .collect::<Decimal256Array>()
-                    .with_precision_and_scale(*output_precision, *output_scale)?;
+                let div = i256::from_i128(10_i128)
+                    .pow_checked((input_scale - output_scale) as u32)
+                    .map_err(|_| {
+                        ArrowError::CastError(format!(
+                            "Cannot cast. The scale {} causes overflow.",
+                            *output_scale,
+                        ))
+                    })?;
 
-                Ok(Arc::new(output_array))
+                array
+                    .try_unary::<_, Decimal256Type, _>(|v| {
+                        i256::from_i128(v).checked_div(div).ok_or_else(|| {
+                            ArrowError::CastError(format!(
+                                "Cannot cast to {:?}({}, {}). Overflowing on {:?}",
+                                Decimal256Type::PREFIX,
+                                *output_precision,
+                                *output_scale,
+                                v
+                            ))
+                        })
+                    })
+                    .and_then(|a| {
+                        a.with_precision_and_scale(*output_precision, *output_scale)
+                    })
+                    .map(|a| Arc::new(a) as ArrayRef)
             }
         } else {
             let array = array.as_any().downcast_ref::<Decimal256Array>().unwrap();
-            let mul = i256::from_i128(mul);
-            let iter = array.iter().map(|v| v.map(|v| v.wrapping_mul(mul)));
+            let div = i256::from_i128(10_i128)
+                .pow_checked((input_scale - output_scale) as u32)
+                .map_err(|_| {
+                    ArrowError::CastError(format!(
+                        "Cannot cast. The scale {} causes overflow.",
+                        *output_scale,
+                    ))
+                })?;
             if BYTE_WIDTH2 == 16 {
-                let values = iter
-                    .map(|v| {
-                        if v.is_none() {
-                            Ok(None)
-                        } else {
-                            v.as_ref().and_then(|v| v.to_i128())
-                                .ok_or_else(|| {
-                                    ArrowError::InvalidArgumentError(
-                                        format!("{:?} cannot be casted to 128-bit integer for Decimal128", v),
-                                    )
-                                })
-                                .map(Some)
-                        }
+                array
+                    .try_unary::<_, Decimal128Type, _>(|v| {
+                        v.checked_div(div).ok_or_else(|| {
+                            ArrowError::CastError(format!(
+                                "Cannot cast to {:?}({}, {}). Overflowing on {:?}",
+                                Decimal128Type::PREFIX,
+                                *output_precision,
+                                *output_scale,
+                                v
+                            ))
+                        }).and_then(|v| v.to_i128().ok_or_else(|| {
+                            ArrowError::InvalidArgumentError(
+                                format!("{:?} cannot be casted to 128-bit integer for Decimal128", v),
+                            )
+                        }))
+                    })
+                    .and_then(|a| {
+                        a.with_precision_and_scale(*output_precision, *output_scale)
                     })
-                    .collect::<Result<Vec<_>, _>>()?;
+                    .map(|a| Arc::new(a) as ArrayRef)
+            } else {
+                array
+                    .try_unary::<_, Decimal256Type, _>(|v| {
+                        v.checked_div(div).ok_or_else(|| {
+                            ArrowError::CastError(format!(
+                                "Cannot cast to {:?}({}, {}). Overflowing on {:?}",
+                                Decimal256Type::PREFIX,
+                                *output_precision,
+                                *output_scale,
+                                v
+                            ))
+                        })
+                    })
+                    .and_then(|a| {
+                        a.with_precision_and_scale(*output_precision, *output_scale)
+                    })
+                    .map(|a| Arc::new(a) as ArrayRef)
+            }
+        }
+    } else {
+        // For example, input_scale is 3 and output_scale is 4;
+        // Original value is 1123_i128, and will be cast to 11230_i128.
+        if BYTE_WIDTH1 == 16 {
+            let array = array.as_any().downcast_ref::<Decimal128Array>().unwrap();
 
-                let output_array = values
-                    .into_iter()
-                    .collect::<Decimal128Array>()
-                    .with_precision_and_scale(*output_precision, *output_scale)?;
+            if BYTE_WIDTH2 == 16 {
+                let mul = 10_i128
+                    .pow_checked((output_scale - input_scale) as u32)
+                    .map_err(|_| {
+                        ArrowError::CastError(format!(
+                            "Cannot cast. The scale {} causes overflow.",
+                            *output_scale,
+                        ))
+                    })?;
 
-                Ok(Arc::new(output_array))
+                array
+                    .try_unary::<_, Decimal128Type, _>(|v| {
+                        v.checked_mul(mul).ok_or_else(|| {
+                            ArrowError::CastError(format!(
+                                "Cannot cast to {:?}({}, {}). Overflowing on {:?}",
+                                Decimal128Type::PREFIX,
+                                *output_precision,
+                                *output_scale,
+                                v
+                            ))
+                        })
+                    })
+                    .and_then(|a| {
+                        a.with_precision_and_scale(*output_precision, *output_scale)
+                    })
+                    .map(|a| Arc::new(a) as ArrayRef)
             } else {
-                let output_array = iter
-                    .collect::<Decimal256Array>()
-                    .with_precision_and_scale(*output_precision, *output_scale)?;
+                let mul = i256::from_i128(10_i128)
+                    .pow_checked((output_scale - input_scale) as u32)
+                    .map_err(|_| {
+                        ArrowError::CastError(format!(
+                            "Cannot cast. The scale {} causes overflow.",
+                            *output_scale,
+                        ))
+                    })?;
 
-                Ok(Arc::new(output_array))
+                array
+                    .try_unary::<_, Decimal256Type, _>(|v| {
+                        i256::from_i128(v).checked_mul(mul).ok_or_else(|| {
+                            ArrowError::CastError(format!(
+                                "Cannot cast to {:?}({}, {}). Overflowing on {:?}",
+                                Decimal256Type::PREFIX,
+                                *output_precision,
+                                *output_scale,
+                                v
+                            ))
+                        })
+                    })
+                    .and_then(|a| {
+                        a.with_precision_and_scale(*output_precision, *output_scale)
+                    })
+                    .map(|a| Arc::new(a) as ArrayRef)
+            }
+        } else {
+            let array = array.as_any().downcast_ref::<Decimal256Array>().unwrap();
+            let mul = i256::from_i128(10_i128)
+                .pow_checked((output_scale - input_scale) as u32)
+                .map_err(|_| {
+                    ArrowError::CastError(format!(
+                        "Cannot cast. The scale {} causes overflow.",
+                        *output_scale,
+                    ))
+                })?;
+            if BYTE_WIDTH2 == 16 {
+                array
+                    .try_unary::<_, Decimal128Type, _>(|v| {
+                        v.checked_mul(mul).ok_or_else(|| {
+                            ArrowError::CastError(format!(
+                                "Cannot cast to {:?}({}, {}). Overflowing on {:?}",
+                                Decimal128Type::PREFIX,
+                                *output_precision,
+                                *output_scale,
+                                v
+                            ))
+                        }).and_then(|v| v.to_i128().ok_or_else(|| {
+                            ArrowError::InvalidArgumentError(
+                                format!("{:?} cannot be casted to 128-bit integer for Decimal128", v),
+                            )
+                        }))
+                    })
+                    .and_then(|a| {
+                        a.with_precision_and_scale(*output_precision, *output_scale)
+                    })
+                    .map(|a| Arc::new(a) as ArrayRef)
+            } else {
+                array
+                    .try_unary::<_, Decimal256Type, _>(|v| {
+                        v.checked_mul(mul).ok_or_else(|| {
+                            ArrowError::CastError(format!(
+                                "Cannot cast to {:?}({}, {}). Overflowing on {:?}",
+                                Decimal256Type::PREFIX,
+                                *output_precision,
+                                *output_scale,
+                                v
+                            ))
+                        })
+                    })
+                    .and_then(|a| {
+                        a.with_precision_and_scale(*output_precision, *output_scale)
+                    })
+                    .map(|a| Arc::new(a) as ArrayRef)
             }
         }
     }
@@ -3343,6 +3592,36 @@ mod tests {
                    err.unwrap_err().to_string());
     }
 
+    #[test]
+    fn test_cast_decimal128_to_decimal128_overflow() {
+        // Rescaling i128::MAX from scale 3 to scale 38 is expected to report an overflow.
+        let source_type = DataType::Decimal128(38, 3);
+        let target_type = DataType::Decimal128(38, 38);
+        assert!(can_cast_types(&source_type, &target_type));
+
+        let decimals = create_decimal_array(vec![Some(i128::MAX)], 38, 3).unwrap();
+        let input = Arc::new(decimals) as ArrayRef;
+        let strict = CastOptions { safe: false };
+        let err = cast_with_options(&input, &target_type, &strict).unwrap_err();
+        assert_eq!("Cast error: Cannot cast to \"Decimal128\"(38, 38). Overflowing on 170141183460469231731687303715884105727",
+                   err.to_string());
+    }
+
+    #[test]
+    fn test_cast_decimal128_to_decimal256_overflow() {
+        // Rescaling i128::MAX from scale 3 to scale 76 is expected to report an overflow.
+        let source_type = DataType::Decimal128(38, 3);
+        let target_type = DataType::Decimal256(76, 76);
+        assert!(can_cast_types(&source_type, &target_type));
+
+        let decimals = create_decimal_array(vec![Some(i128::MAX)], 38, 3).unwrap();
+        let input = Arc::new(decimals) as ArrayRef;
+        let strict = CastOptions { safe: false };
+        let err = cast_with_options(&input, &target_type, &strict).unwrap_err();
+        assert_eq!("Cast error: Cannot cast to \"Decimal256\"(76, 76). Overflowing on 170141183460469231731687303715884105727",
+                   err.to_string());
+    }
+
     #[test]
     fn test_cast_decimal128_to_decimal256() {
         let input_type = DataType::Decimal128(20, 3);