You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by vi...@apache.org on 2022/11/08 08:58:39 UTC

[arrow-rs] branch master updated: Cast decimal256 to signed integer (#3040)

This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new a950b52ec Cast decimal256 to signed integer (#3040)
a950b52ec is described below

commit a950b52ec83e5ac14e147f9605f871ba6bd06ee0
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Tue Nov 8 00:58:34 2022 -0800

    Cast decimal256 to signed integer (#3040)
    
    * Cast decimal256 to signed integer
    
    * Use ToPrimitive
    
    * Add CastOptions
---
 arrow-buffer/src/bigint.rs |  87 ++++++++++++++++--
 arrow-cast/src/cast.rs     | 216 +++++++++++++++++++++++++++++++++++++--------
 2 files changed, 261 insertions(+), 42 deletions(-)

diff --git a/arrow-buffer/src/bigint.rs b/arrow-buffer/src/bigint.rs
index 8dd57d2c4..be02c2857 100644
--- a/arrow-buffer/src/bigint.rs
+++ b/arrow-buffer/src/bigint.rs
@@ -16,7 +16,7 @@
 // under the License.
 
 use num::cast::AsPrimitive;
-use num::{BigInt, FromPrimitive};
+use num::{BigInt, FromPrimitive, ToPrimitive};
 use std::cmp::Ordering;
 
 /// A signed 256-bit integer
@@ -388,13 +388,20 @@ impl i256 {

 /// Temporary workaround due to lack of stable const array slicing
 /// See <https://github.com/rust-lang/rust/issues/90091>
-const fn split_array(vals: [u8; 32]) -> ([u8; 16], [u8; 16]) {
-    let mut a = [0; 16];
-    let mut b = [0; 16];
+///
+/// Splits `vals` into its first `M` bytes and its last `M` bytes.
+/// Panics unless `N == 2 * M`.
+const fn split_array<const N: usize, const M: usize>(
+    vals: [u8; N],
+) -> ([u8; M], [u8; M]) {
+    // Reject mismatched sizes instead of silently ignoring trailing bytes.
+    assert!(N == 2 * M, "split_array: N must equal 2 * M");
+    let mut a = [0; M];
+    let mut b = [0; M];
     let mut i = 0;
-    while i != 16 {
+    while i != M {
         a[i] = vals[i];
-        b[i] = vals[i + 16];
+        b[i] = vals[i + M];
         i += 1;
     }
     (a, b)
@@ -478,6 +480,44 @@ define_as_primitive!(i16);
 define_as_primitive!(i32);
 define_as_primitive!(i64);

+impl ToPrimitive for i256 {
+    /// Converts to `i64`, returning `None` if the value is outside `i64`'s range.
+    ///
+    /// The value fits only if `high` is pure sign extension of `low` (so the
+    /// value is representable as an `i128`) and that `i128` fits in an `i64`.
+    fn to_i64(&self) -> Option<i64> {
+        let as_i128 = self.low as i128;
+
+        let high_negative = self.high < 0;
+        let low_negative = as_i128 < 0;
+        let high_valid = self.high == -1 || self.high == 0;
+
+        if high_negative == low_negative && high_valid {
+            // The value is exactly `as_i128`; its range check also rejects
+            // values such as 2^64 whose bits 64..=127 are significant.
+            as_i128.to_i64()
+        } else {
+            None
+        }
+    }
+
+    /// Converts to `u64`, returning `None` if the value is negative or
+    /// greater than `u64::MAX`.
+    fn to_u64(&self) -> Option<u64> {
+        let as_i128 = self.low as i128;
+
+        let high_negative = self.high < 0;
+        let low_negative = as_i128 < 0;
+        let high_valid = self.high == -1 || self.high == 0;
+
+        if high_negative == low_negative && high_valid {
+            self.low.to_u64()
+        } else {
+            None
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -676,4 +716,48 @@ mod tests {
             test_ops(i256::from_le_bytes(l), i256::from_le_bytes(r))
         }
     }
+
+    #[test]
+    fn test_i256_to_primitive() {
+        let a = i256::MAX;
+        assert!(a.to_i64().is_none());
+        assert!(a.to_u64().is_none());
+
+        let a = i256::from_i128(i128::MAX);
+        assert!(a.to_i64().is_none());
+        assert!(a.to_u64().is_none());
+
+        let a = i256::from_i128(i64::MAX as i128);
+        assert_eq!(a.to_i64().unwrap(), i64::MAX);
+        assert_eq!(a.to_u64().unwrap(), i64::MAX as u64);
+
+        let a = i256::from_i128(i64::MAX as i128 + 1);
+        assert!(a.to_i64().is_none());
+        assert_eq!(a.to_u64().unwrap(), i64::MAX as u64 + 1);
+
+        let a = i256::MIN;
+        assert!(a.to_i64().is_none());
+        assert!(a.to_u64().is_none());
+
+        let a = i256::from_i128(i128::MIN);
+        assert!(a.to_i64().is_none());
+        assert!(a.to_u64().is_none());
+
+        let a = i256::from_i128(i64::MIN as i128);
+        assert_eq!(a.to_i64().unwrap(), i64::MIN);
+        assert!(a.to_u64().is_none());
+
+        let a = i256::from_i128(i64::MIN as i128 - 1);
+        assert!(a.to_i64().is_none());
+        assert!(a.to_u64().is_none());
+
+        // In i128 range but outside i64 range; bits 64..=127 must not be ignored.
+        let a = i256::from_i128(1_i128 << 64);
+        assert!(a.to_i64().is_none());
+        assert!(a.to_u64().is_none());
+
+        let a = i256::from_i128(-(1_i128 << 64) - 1);
+        assert!(a.to_i64().is_none());
+        assert!(a.to_u64().is_none());
+    }
 }
diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs
index e394426bd..1cc814730 100644
--- a/arrow-cast/src/cast.rs
+++ b/arrow-cast/src/cast.rs
@@ -81,7 +81,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
         (Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64, Decimal128(_, _)) |
         (Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64, Decimal256(_, _)) |
         // decimal to signed numeric
-        (Decimal128(_, _), Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64)
+        (Decimal128(_, _), Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64) |
+        (Decimal256(_, _), Null | Int8 | Int16 | Int32 | Int64)
         | (
             Null,
             Boolean
@@ -433,34 +434,84 @@ fn cast_reinterpret_arrays<
     ))
 }

-// cast the decimal array to integer array
-macro_rules! cast_decimal_to_integer {
-    ($ARRAY:expr, $SCALE : ident, $VALUE_BUILDER: ident, $NATIVE_TYPE : ident, $DATA_TYPE : expr) => {{
-        let array = $ARRAY.as_any().downcast_ref::<Decimal128Array>().unwrap();
-        let mut value_builder = $VALUE_BUILDER::with_capacity(array.len());
-        let div: i128 = 10_i128.pow(*$SCALE as u32);
-        let min_bound = ($NATIVE_TYPE::MIN) as i128;
-        let max_bound = ($NATIVE_TYPE::MAX) as i128;
+/// Casts a decimal array of type `D` to a primitive integer array of type `T`.
+///
+/// Each non-null value is integer-divided by `base^scale` and the quotient is
+/// converted to `T::Native`. If either step fails (e.g. the quotient does not
+/// fit in `T`), the slot becomes null when `cast_options.safe` is set and an
+/// error is returned otherwise. Null slots stay null.
+///
+/// # Errors
+///
+/// Returns a `CastError` if `base^scale` overflows `D::Native`. In unsafe mode
+/// (`cast_options.safe == false`), also returns an error for the first value
+/// that cannot be cast.
+///
+/// # Panics
+///
+/// Panics if `array` is not a `PrimitiveArray<D>`.
+fn cast_decimal_to_integer<D, T>(
+    array: &ArrayRef,
+    base: D::Native,
+    scale: u8,
+    cast_options: &CastOptions,
+) -> Result<ArrayRef, ArrowError>
+where
+    T: ArrowPrimitiveType,
+    <T as ArrowPrimitiveType>::Native: NumCast,
+    D: DecimalType + ArrowPrimitiveType,
+    <D as ArrowPrimitiveType>::Native: ArrowNativeTypeOp + ToPrimitive,
+{
+    let array = array.as_any().downcast_ref::<PrimitiveArray<D>>().unwrap();
+
+    let div: D::Native = base.pow_checked(scale as u32).map_err(|_| {
+        ArrowError::CastError(format!(
+            "Cannot cast {} to {}. The scale {} causes overflow.",
+            D::PREFIX,
+            T::DATA_TYPE,
+            scale,
+        ))
+    })?;
+
+    let mut value_builder = PrimitiveBuilder::<T>::with_capacity(array.len());
+
+    // Safe mode: any value that cannot be represented in `T` becomes null.
+    if cast_options.safe {
         for i in 0..array.len() {
             if array.is_null(i) {
                 value_builder.append_null();
             } else {
-                let v = array.value(i) / div;
-                // check the overflow
-                // For example: Decimal(128,10,0) as i8
-                // 128 is out of range i8
-                if v <= max_bound && v >= min_bound {
-                    value_builder.append_value(v as $NATIVE_TYPE);
-                } else {
-                    return Err(ArrowError::CastError(format!(
-                        "value of {} is out of range {}",
-                        v, $DATA_TYPE
-                    )));
-                }
+                let v = array
+                    .value(i)
+                    .div_checked(div)
+                    .ok()
+                    .and_then(<T::Native as NumCast>::from::<D::Native>);
+
+                value_builder.append_option(v);
             }
         }
-        Ok(Arc::new(value_builder.finish()))
-    }};
+    } else {
+        // Unsafe mode: the first value that cannot be represented aborts the cast.
+        for i in 0..array.len() {
+            if array.is_null(i) {
+                value_builder.append_null();
+            } else {
+                let v = array.value(i).div_checked(div)?;
+
+                let value =
+                    <T::Native as NumCast>::from::<D::Native>(v).ok_or_else(|| {
+                        ArrowError::CastError(format!(
+                            "value of {:?} is out of range {}",
+                            v,
+                            T::DATA_TYPE
+                        ))
+                    })?;
+
+                value_builder.append_value(value);
+            }
+        }
+    }
+    Ok(Arc::new(value_builder.finish()))
 }

 // cast the decimal array to floating-point array
@@ -576,18 +608,30 @@ pub fn cast_with_options(
         (Decimal128(_, scale), _) => {
             // cast decimal to other type
             match to_type {
-                Int8 => {
-                    cast_decimal_to_integer!(array, scale, Int8Builder, i8, Int8)
-                }
-                Int16 => {
-                    cast_decimal_to_integer!(array, scale, Int16Builder, i16, Int16)
-                }
-                Int32 => {
-                    cast_decimal_to_integer!(array, scale, Int32Builder, i32, Int32)
-                }
-                Int64 => {
-                    cast_decimal_to_integer!(array, scale, Int64Builder, i64, Int64)
-                }
+                Int8 => cast_decimal_to_integer::<Decimal128Type, Int8Type>(
+                    array,
+                    10_i128, // base: each value is divided by 10^scale
+                    *scale,
+                    cast_options,
+                ),
+                Int16 => cast_decimal_to_integer::<Decimal128Type, Int16Type>(
+                    array,
+                    10_i128,
+                    *scale,
+                    cast_options,
+                ),
+                Int32 => cast_decimal_to_integer::<Decimal128Type, Int32Type>(
+                    array,
+                    10_i128,
+                    *scale,
+                    cast_options,
+                ),
+                Int64 => cast_decimal_to_integer::<Decimal128Type, Int64Type>(
+                    array,
+                    10_i128,
+                    *scale,
+                    cast_options,
+                ),
                 Float32 => {
                     cast_decimal_to_float!(array, scale, Float32Builder, f32)
                 }
@@ -601,6 +645,40 @@ pub fn cast_with_options(
                 ))),
             }
         }
+        (Decimal256(_, scale), _) => {
+            // cast decimal256 to Int8..Int64 or Null; other targets are unsupported
+            match to_type {
+                Int8 => cast_decimal_to_integer::<Decimal256Type, Int8Type>(
+                    array,
+                    i256::from_i128(10_i128),
+                    *scale,
+                    cast_options,
+                ),
+                Int16 => cast_decimal_to_integer::<Decimal256Type, Int16Type>(
+                    array,
+                    i256::from_i128(10_i128),
+                    *scale,
+                    cast_options,
+                ),
+                Int32 => cast_decimal_to_integer::<Decimal256Type, Int32Type>(
+                    array,
+                    i256::from_i128(10_i128),
+                    *scale,
+                    cast_options,
+                ),
+                Int64 => cast_decimal_to_integer::<Decimal256Type, Int64Type>(
+                    array,
+                    i256::from_i128(10_i128),
+                    *scale,
+                    cast_options,
+                ),
+                Null => Ok(new_null_array(to_type, array.len())),
+                _ => Err(ArrowError::CastError(format!(
+                    "Casting from {:?} to {:?} not supported",
+                    from_type, to_type
+                ))),
+            }
+        }
         (_, Decimal128(precision, scale)) => {
             // cast data to decimal
             match from_type {
@@ -3154,12 +3232,18 @@ mod tests {
         let value_array: Vec<Option<i128>> = vec![Some(24400)];
         let decimal_array = create_decimal_array(value_array, 38, 2).unwrap();
         let array = Arc::new(decimal_array) as ArrayRef;
-        let casted_array = cast(&array, &DataType::Int8);
+        let casted_array =
+            cast_with_options(&array, &DataType::Int8, &CastOptions { safe: false });
         assert_eq!(
             "Cast error: value of 244 is out of range Int8".to_string(),
             casted_array.unwrap_err().to_string()
         );
 
+        let casted_array =
+            cast_with_options(&array, &DataType::Int8, &CastOptions { safe: true });
+        assert!(casted_array.is_ok());
+        assert!(casted_array.unwrap().is_null(0));
+
         // loss the precision: convert decimal to f32、f64
         // f32
         // 112345678_f32 and 112345679_f32 are same, so the 112345679_f32 will lose precision.
@@ -3218,6 +3302,66 @@ mod tests {
         );
     }
 
+    #[test]
+    fn test_cast_decimal256_to_numeric() {
+        let decimal_type = DataType::Decimal256(38, 2);
+        // negative test: casting to unsigned integers is not supported
+        assert!(!can_cast_types(&decimal_type, &DataType::UInt8));
+        let value_array: Vec<Option<i256>> = vec![
+            Some(i256::from_i128(125)),
+            Some(i256::from_i128(225)),
+            Some(i256::from_i128(325)),
+            None,
+            Some(i256::from_i128(525)),
+        ];
+        let decimal_array = create_decimal256_array(value_array, 38, 2).unwrap();
+        let array = Arc::new(decimal_array) as ArrayRef;
+        // i8 (1.25, 2.25, 3.25, 5.25 at scale 2 become 1, 2, 3, 5)
+        generate_cast_test_case!(
+            &array,
+            Int8Array,
+            &DataType::Int8,
+            vec![Some(1_i8), Some(2_i8), Some(3_i8), None, Some(5_i8)]
+        );
+        // i16
+        generate_cast_test_case!(
+            &array,
+            Int16Array,
+            &DataType::Int16,
+            vec![Some(1_i16), Some(2_i16), Some(3_i16), None, Some(5_i16)]
+        );
+        // i32
+        generate_cast_test_case!(
+            &array,
+            Int32Array,
+            &DataType::Int32,
+            vec![Some(1_i32), Some(2_i32), Some(3_i32), None, Some(5_i32)]
+        );
+        // i64
+        generate_cast_test_case!(
+            &array,
+            Int64Array,
+            &DataType::Int64,
+            vec![Some(1_i64), Some(2_i64), Some(3_i64), None, Some(5_i64)]
+        );
+
+        // overflow: 244 > i8::MAX; unsafe casts fail, safe casts produce null
+        let value_array: Vec<Option<i256>> = vec![Some(i256::from_i128(24400))];
+        let decimal_array = create_decimal256_array(value_array, 38, 2).unwrap();
+        let array = Arc::new(decimal_array) as ArrayRef;
+        let casted_array =
+            cast_with_options(&array, &DataType::Int8, &CastOptions { safe: false });
+        assert_eq!(
+            "Cast error: value of 244 is out of range Int8".to_string(),
+            casted_array.unwrap_err().to_string()
+        );
+
+        let casted_array =
+            cast_with_options(&array, &DataType::Int8, &CastOptions { safe: true });
+        assert!(casted_array.is_ok());
+        assert!(casted_array.unwrap().is_null(0));
+    }
+
     #[test]
     #[cfg(not(feature = "force_validate"))]
     fn test_cast_numeric_to_decimal128() {