You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by vi...@apache.org on 2022/11/04 08:07:14 UTC

[arrow-rs] branch master updated: Check overflow when casting integer to decimal (#3009)

This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 766f69f71 Check overflow when casting integer to decimal (#3009)
766f69f71 is described below

commit 766f69f715faa619077cc5458aef955b627af715
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Fri Nov 4 01:07:08 2022 -0700

    Check overflow when casting integer to decimal (#3009)
    
    * Check overflow when casting integer to decimal
    
    * Trigger Build
    
    * Combine cast_integer_to_decimal functions of decimal128 and decimal256
    
    * Fix clippy
    
    * Trigger Build
    
    * Use PREFIX way.
---
 arrow-array/src/types.rs          |   5 ++
 arrow/src/compute/kernels/cast.rs | 130 +++++++++++++++++++++++++++-----------
 arrow/src/datatypes/native.rs     |  23 +++++++
 3 files changed, 122 insertions(+), 36 deletions(-)

diff --git a/arrow-array/src/types.rs b/arrow-array/src/types.rs
index e6197eed1..7c7a5c811 100644
--- a/arrow-array/src/types.rs
+++ b/arrow-array/src/types.rs
@@ -495,6 +495,9 @@ pub trait DecimalType:
     const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType;
     const DEFAULT_TYPE: DataType;
 
+    /// "Decimal128" or "Decimal256", for use in error messages
+    const PREFIX: &'static str;
+
     /// Formats the decimal value with the provided precision and scale
     fn format_decimal(value: Self::Native, precision: u8, scale: u8) -> String;
 
@@ -516,6 +519,7 @@ impl DecimalType for Decimal128Type {
     const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType = DataType::Decimal128;
     const DEFAULT_TYPE: DataType =
         DataType::Decimal128(DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE);
+    const PREFIX: &'static str = "Decimal128";
 
     fn format_decimal(value: Self::Native, precision: u8, scale: u8) -> String {
         format_decimal_str(&value.to_string(), precision as usize, scale as usize)
@@ -543,6 +547,7 @@ impl DecimalType for Decimal256Type {
     const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType = DataType::Decimal256;
     const DEFAULT_TYPE: DataType =
         DataType::Decimal256(DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE);
+    const PREFIX: &'static str = "Decimal256";
 
     fn format_decimal(value: Self::Native, precision: u8, scale: u8) -> String {
         format_decimal_str(&value.to_string(), precision as usize, scale as usize)
diff --git a/arrow/src/compute/kernels/cast.rs b/arrow/src/compute/kernels/cast.rs
index 4ad8dd99e..b1e744d26 100644
--- a/arrow/src/compute/kernels/cast.rs
+++ b/arrow/src/compute/kernels/cast.rs
@@ -309,41 +309,43 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result<ArrayRef> {
     cast_with_options(array, to_type, &DEFAULT_CAST_OPTIONS)
 }
 
-fn cast_integer_to_decimal128<T: ArrowNumericType>(
-    array: &PrimitiveArray<T>,
-    precision: u8,
-    scale: u8,
-) -> Result<ArrayRef>
-where
-    <T as ArrowPrimitiveType>::Native: AsPrimitive<i128>,
-{
-    let mul: i128 = 10_i128.pow(scale as u32);
-
-    unary::<T, _, Decimal128Type>(array, |v| v.as_() * mul)
-        .with_precision_and_scale(precision, scale)
-        .map(|a| Arc::new(a) as ArrayRef)
-}
-
-fn cast_integer_to_decimal256<T: ArrowNumericType>(
+fn cast_integer_to_decimal<
+    T: ArrowNumericType,
+    D: DecimalType + ArrowPrimitiveType<Native = M>,
+    M,
+>(
     array: &PrimitiveArray<T>,
     precision: u8,
     scale: u8,
+    base: M,
+    cast_options: &CastOptions,
 ) -> Result<ArrayRef>
 where
-    <T as ArrowPrimitiveType>::Native: AsPrimitive<i256>,
+    <T as ArrowPrimitiveType>::Native: AsPrimitive<M>,
+    M: ArrowNativeTypeOp,
 {
-    let mul: i256 = i256::from_i128(10_i128)
-        .checked_pow(scale as u32)
-        .ok_or_else(|| {
-            ArrowError::CastError(format!(
-                "Cannot cast to Decimal256({}, {}). The scale causes overflow.",
-                precision, scale
-            ))
-        })?;
+    let mul: M = base.pow_checked(scale as u32).map_err(|_| {
+        ArrowError::CastError(format!(
+            "Cannot cast to {:?}({}, {}). The scale causes overflow.",
+            D::PREFIX,
+            precision,
+            scale,
+        ))
+    })?;
 
-    unary::<T, _, Decimal256Type>(array, |v| v.as_().wrapping_mul(mul))
-        .with_precision_and_scale(precision, scale)
-        .map(|a| Arc::new(a) as ArrayRef)
+    if cast_options.safe {
+        let iter = array
+            .iter()
+            .map(|v| v.and_then(|v| v.as_().mul_checked(mul).ok()));
+        let casted_array = unsafe { PrimitiveArray::<D>::from_trusted_len_iter(iter) };
+        casted_array
+            .with_precision_and_scale(precision, scale)
+            .map(|a| Arc::new(a) as ArrayRef)
+    } else {
+        try_unary::<T, _, D>(array, |v| v.as_().mul_checked(mul))
+            .and_then(|a| a.with_precision_and_scale(precision, scale))
+            .map(|a| Arc::new(a) as ArrayRef)
+    }
 }
 
 fn cast_floating_point_to_decimal128<T: ArrowNumericType>(
@@ -562,25 +564,33 @@ pub fn cast_with_options(
             // cast data to decimal
             match from_type {
                 // TODO now just support signed numeric to decimal, support decimal to numeric later
-                Int8 => cast_integer_to_decimal128(
+                Int8 => cast_integer_to_decimal::<_, Decimal128Type, _>(
                     as_primitive_array::<Int8Type>(array),
                     *precision,
                     *scale,
+                    10_i128,
+                    cast_options,
                 ),
-                Int16 => cast_integer_to_decimal128(
+                Int16 => cast_integer_to_decimal::<_, Decimal128Type, _>(
                     as_primitive_array::<Int16Type>(array),
                     *precision,
                     *scale,
+                    10_i128,
+                    cast_options,
                 ),
-                Int32 => cast_integer_to_decimal128(
+                Int32 => cast_integer_to_decimal::<_, Decimal128Type, _>(
                     as_primitive_array::<Int32Type>(array),
                     *precision,
                     *scale,
+                    10_i128,
+                    cast_options,
                 ),
-                Int64 => cast_integer_to_decimal128(
+                Int64 => cast_integer_to_decimal::<_, Decimal128Type, _>(
                     as_primitive_array::<Int64Type>(array),
                     *precision,
                     *scale,
+                    10_i128,
+                    cast_options,
                 ),
                 Float32 => cast_floating_point_to_decimal128(
                     as_primitive_array::<Float32Type>(array),
@@ -603,25 +613,33 @@ pub fn cast_with_options(
             // cast data to decimal
             match from_type {
                 // TODO now just support signed numeric to decimal, support decimal to numeric later
-                Int8 => cast_integer_to_decimal256(
+                Int8 => cast_integer_to_decimal::<_, Decimal256Type, _>(
                     as_primitive_array::<Int8Type>(array),
                     *precision,
                     *scale,
+                    i256::from_i128(10_i128),
+                    cast_options,
                 ),
-                Int16 => cast_integer_to_decimal256(
+                Int16 => cast_integer_to_decimal::<_, Decimal256Type, _>(
                     as_primitive_array::<Int16Type>(array),
                     *precision,
                     *scale,
+                    i256::from_i128(10_i128),
+                    cast_options,
                 ),
-                Int32 => cast_integer_to_decimal256(
+                Int32 => cast_integer_to_decimal::<_, Decimal256Type, _>(
                     as_primitive_array::<Int32Type>(array),
                     *precision,
                     *scale,
+                    i256::from_i128(10_i128),
+                    cast_options,
                 ),
-                Int64 => cast_integer_to_decimal256(
+                Int64 => cast_integer_to_decimal::<_, Decimal256Type, _>(
                     as_primitive_array::<Int64Type>(array),
                     *precision,
                     *scale,
+                    i256::from_i128(10_i128),
+                    cast_options,
                 ),
                 Float32 => cast_floating_point_to_decimal256(
                     as_primitive_array::<Float32Type>(array),
@@ -6049,4 +6067,44 @@ mod tests {
             ]
         );
     }
+
+    #[test]
+    fn test_cast_numeric_to_decimal128_overflow() {
+        let array = Int64Array::from(vec![i64::MAX]);
+        let array = Arc::new(array) as ArrayRef;
+        let casted_array = cast_with_options(
+            &array,
+            &DataType::Decimal128(38, 30),
+            &CastOptions { safe: true },
+        );
+        assert!(casted_array.is_ok());
+        assert!(casted_array.unwrap().is_null(0));
+
+        let casted_array = cast_with_options(
+            &array,
+            &DataType::Decimal128(38, 30),
+            &CastOptions { safe: false },
+        );
+        assert!(casted_array.is_err());
+    }
+
+    #[test]
+    fn test_cast_numeric_to_decimal256_overflow() {
+        let array = Int64Array::from(vec![i64::MAX]);
+        let array = Arc::new(array) as ArrayRef;
+        let casted_array = cast_with_options(
+            &array,
+            &DataType::Decimal256(76, 76),
+            &CastOptions { safe: true },
+        );
+        assert!(casted_array.is_ok());
+        assert!(casted_array.unwrap().is_null(0));
+
+        let casted_array = cast_with_options(
+            &array,
+            &DataType::Decimal256(76, 76),
+            &CastOptions { safe: false },
+        );
+        assert!(casted_array.is_err());
+    }
 }
diff --git a/arrow/src/datatypes/native.rs b/arrow/src/datatypes/native.rs
index bbdec14b4..28ef877a2 100644
--- a/arrow/src/datatypes/native.rs
+++ b/arrow/src/datatypes/native.rs
@@ -19,6 +19,7 @@ use crate::error::{ArrowError, Result};
 pub use arrow_array::ArrowPrimitiveType;
 pub use arrow_buffer::{i256, ArrowNativeType, ToByteSlice};
 use half::f16;
+use num::complex::ComplexFloat;
 
 /// Trait for [`ArrowNativeType`] that adds checked and unchecked arithmetic operations,
 /// and totally ordered comparison operations
@@ -68,6 +69,10 @@ pub trait ArrowNativeTypeOp: ArrowNativeType {
 
     fn neg_wrapping(self) -> Self;
 
+    fn pow_checked(self, exp: u32) -> Result<Self>;
+
+    fn pow_wrapping(self, exp: u32) -> Self;
+
     fn is_zero(self) -> bool;
 
     fn is_eq(self, rhs: Self) -> bool;
@@ -171,6 +176,16 @@ macro_rules! native_type_op {
                 })
             }
 
+            fn pow_checked(self, exp: u32) -> Result<Self> {
+                self.checked_pow(exp).ok_or_else(|| {
+                    ArrowError::ComputeError(format!("Overflow happened on: {:?}", self))
+                })
+            }
+
+            fn pow_wrapping(self, exp: u32) -> Self {
+                self.wrapping_pow(exp)
+            }
+
             fn neg_wrapping(self) -> Self {
                 self.wrapping_neg()
             }
@@ -279,6 +294,14 @@ macro_rules! native_type_float_op {
                 -self
             }
 
+            fn pow_checked(self, exp: u32) -> Result<Self> {
+                Ok(self.powi(exp as i32))
+            }
+
+            fn pow_wrapping(self, exp: u32) -> Self {
+                self.powi(exp as i32)
+            }
+
             fn is_zero(self) -> bool {
                 self == $zero
             }