You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by vi...@apache.org on 2022/11/23 18:33:52 UTC

[arrow-rs] branch master updated: Support decimal negative scale (#3152)

This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 78ab0ef3f Support decimal negative scale (#3152)
78ab0ef3f is described below

commit 78ab0ef3f6f422fd4b79a29504f0274220aaf74b
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Wed Nov 23 10:33:47 2022 -0800

    Support decimal negative scale (#3152)
    
    * Support decimal negative scale
    
    * Fix casting from numeric to negative scale decimal
    
    * Fix clippy
---
 arrow-array/src/array/primitive_array.rs |  15 +++--
 arrow-array/src/types.rs                 |  33 +++++-----
 arrow-cast/src/cast.rs                   | 102 ++++++++++++++++++++++++++-----
 arrow-csv/src/reader.rs                  |   6 +-
 arrow-data/src/decimal.rs                |   6 +-
 arrow-schema/src/datatype.rs             |   4 +-
 arrow-select/src/take.rs                 |   6 +-
 arrow/benches/cast_kernels.rs            |   4 +-
 arrow/src/datatypes/ffi.rs               |   4 +-
 arrow/tests/array_transform.rs           |   2 +-
 10 files changed, 131 insertions(+), 51 deletions(-)

diff --git a/arrow-array/src/array/primitive_array.rs b/arrow-array/src/array/primitive_array.rs
index f34c899e2..bd68b9698 100644
--- a/arrow-array/src/array/primitive_array.rs
+++ b/arrow-array/src/array/primitive_array.rs
@@ -1003,7 +1003,7 @@ impl<T: DecimalType + ArrowPrimitiveType> PrimitiveArray<T> {
     pub fn with_precision_and_scale(
         self,
         precision: u8,
-        scale: u8,
+        scale: i8,
     ) -> Result<Self, ArrowError>
     where
         Self: Sized,
@@ -1024,7 +1024,7 @@ impl<T: DecimalType + ArrowPrimitiveType> PrimitiveArray<T> {
     fn validate_precision_scale(
         &self,
         precision: u8,
-        scale: u8,
+        scale: i8,
     ) -> Result<(), ArrowError> {
         if precision == 0 {
             return Err(ArrowError::InvalidArgumentError(format!(
@@ -1046,7 +1046,14 @@ impl<T: DecimalType + ArrowPrimitiveType> PrimitiveArray<T> {
                 T::MAX_SCALE
             )));
         }
-        if scale > precision {
+        if scale < -T::MAX_SCALE {
+            return Err(ArrowError::InvalidArgumentError(format!(
+                "scale {} is smaller than min {}",
+                scale,
+                -Decimal128Type::MAX_SCALE
+            )));
+        }
+        if scale > 0 && scale as u8 > precision {
             return Err(ArrowError::InvalidArgumentError(format!(
                 "scale {} is greater than precision {}",
                 scale, precision
@@ -1102,7 +1109,7 @@ impl<T: DecimalType + ArrowPrimitiveType> PrimitiveArray<T> {
     }
 
     /// Returns the decimal scale of this array
-    pub fn scale(&self) -> u8 {
+    pub fn scale(&self) -> i8 {
         match T::BYTE_LENGTH {
             16 => {
                 if let DataType::Decimal128(_, s) = self.data().data_type() {
diff --git a/arrow-array/src/types.rs b/arrow-array/src/types.rs
index dd4d1ba42..40d262e8e 100644
--- a/arrow-array/src/types.rs
+++ b/arrow-array/src/types.rs
@@ -491,15 +491,15 @@ pub trait DecimalType:
 {
     const BYTE_LENGTH: usize;
     const MAX_PRECISION: u8;
-    const MAX_SCALE: u8;
-    const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType;
+    const MAX_SCALE: i8;
+    const TYPE_CONSTRUCTOR: fn(u8, i8) -> DataType;
     const DEFAULT_TYPE: DataType;
 
     /// "Decimal128" or "Decimal256", for use in error messages
     const PREFIX: &'static str;
 
     /// Formats the decimal value with the provided precision and scale
-    fn format_decimal(value: Self::Native, precision: u8, scale: u8) -> String;
+    fn format_decimal(value: Self::Native, precision: u8, scale: i8) -> String;
 
     /// Validates that `value` contains no more than `precision` decimal digits
     fn validate_decimal_precision(
@@ -515,14 +515,14 @@ pub struct Decimal128Type {}
 impl DecimalType for Decimal128Type {
     const BYTE_LENGTH: usize = 16;
     const MAX_PRECISION: u8 = DECIMAL128_MAX_PRECISION;
-    const MAX_SCALE: u8 = DECIMAL128_MAX_SCALE;
-    const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType = DataType::Decimal128;
+    const MAX_SCALE: i8 = DECIMAL128_MAX_SCALE;
+    const TYPE_CONSTRUCTOR: fn(u8, i8) -> DataType = DataType::Decimal128;
     const DEFAULT_TYPE: DataType =
         DataType::Decimal128(DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE);
     const PREFIX: &'static str = "Decimal128";
 
-    fn format_decimal(value: Self::Native, precision: u8, scale: u8) -> String {
-        format_decimal_str(&value.to_string(), precision as usize, scale as usize)
+    fn format_decimal(value: Self::Native, precision: u8, scale: i8) -> String {
+        format_decimal_str(&value.to_string(), precision as usize, scale)
     }
 
     fn validate_decimal_precision(num: i128, precision: u8) -> Result<(), ArrowError> {
@@ -543,14 +543,14 @@ pub struct Decimal256Type {}
 impl DecimalType for Decimal256Type {
     const BYTE_LENGTH: usize = 32;
     const MAX_PRECISION: u8 = DECIMAL256_MAX_PRECISION;
-    const MAX_SCALE: u8 = DECIMAL256_MAX_SCALE;
-    const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType = DataType::Decimal256;
+    const MAX_SCALE: i8 = DECIMAL256_MAX_SCALE;
+    const TYPE_CONSTRUCTOR: fn(u8, i8) -> DataType = DataType::Decimal256;
     const DEFAULT_TYPE: DataType =
         DataType::Decimal256(DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE);
     const PREFIX: &'static str = "Decimal256";
 
-    fn format_decimal(value: Self::Native, precision: u8, scale: u8) -> String {
-        format_decimal_str(&value.to_string(), precision as usize, scale as usize)
+    fn format_decimal(value: Self::Native, precision: u8, scale: i8) -> String {
+        format_decimal_str(&value.to_string(), precision as usize, scale)
     }
 
     fn validate_decimal_precision(num: i256, precision: u8) -> Result<(), ArrowError> {
@@ -564,7 +564,7 @@ impl ArrowPrimitiveType for Decimal256Type {
     const DATA_TYPE: DataType = <Self as DecimalType>::DEFAULT_TYPE;
 }
 
-fn format_decimal_str(value_str: &str, precision: usize, scale: usize) -> String {
+fn format_decimal_str(value_str: &str, precision: usize, scale: i8) -> String {
     let (sign, rest) = match value_str.strip_prefix('-') {
         Some(stripped) => ("-", stripped),
         None => ("", value_str),
@@ -574,13 +574,16 @@ fn format_decimal_str(value_str: &str, precision: usize, scale: usize) -> String
 
     if scale == 0 {
         value_str.to_string()
-    } else if rest.len() > scale {
+    } else if scale < 0 {
+        let padding = value_str.len() + scale.unsigned_abs() as usize;
+        format!("{:0<width$}", value_str, width = padding)
+    } else if rest.len() > scale as usize {
         // Decimal separator is in the middle of the string
-        let (whole, decimal) = value_str.split_at(value_str.len() - scale);
+        let (whole, decimal) = value_str.split_at(value_str.len() - scale as usize);
         format!("{}.{}", whole, decimal)
     } else {
         // String has to be padded
-        format!("{}0.{:0>width$}", sign, rest, width = scale)
+        format!("{}0.{:0>width$}", sign, rest, width = scale as usize)
     }
 }
 
diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs
index 3bf97cf7a..61be2171b 100644
--- a/arrow-cast/src/cast.rs
+++ b/arrow-cast/src/cast.rs
@@ -319,7 +319,7 @@ fn cast_integer_to_decimal<
 >(
     array: &PrimitiveArray<T>,
     precision: u8,
-    scale: u8,
+    scale: i8,
     base: M,
     cast_options: &CastOptions,
 ) -> Result<ArrayRef, ArrowError>
@@ -327,7 +327,7 @@ where
     <T as ArrowPrimitiveType>::Native: AsPrimitive<M>,
     M: ArrowNativeTypeOp,
 {
-    let mul: M = base.pow_checked(scale as u32).map_err(|_| {
+    let mul_or_div: M = base.pow_checked(scale.unsigned_abs() as u32).map_err(|_| {
         ArrowError::CastError(format!(
             "Cannot cast to {:?}({}, {}). The scale causes overflow.",
             D::PREFIX,
@@ -336,14 +336,26 @@ where
         ))
     })?;
 
-    if cast_options.safe {
+    if scale < 0 {
+        if cast_options.safe {
+            array
+                .unary_opt::<_, D>(|v| v.as_().div_checked(mul_or_div).ok())
+                .with_precision_and_scale(precision, scale)
+                .map(|a| Arc::new(a) as ArrayRef)
+        } else {
+            array
+                .try_unary::<_, D, _>(|v| v.as_().div_checked(mul_or_div))
+                .and_then(|a| a.with_precision_and_scale(precision, scale))
+                .map(|a| Arc::new(a) as ArrayRef)
+        }
+    } else if cast_options.safe {
         array
-            .unary_opt::<_, D>(|v| v.as_().mul_checked(mul).ok())
+            .unary_opt::<_, D>(|v| v.as_().mul_checked(mul_or_div).ok())
             .with_precision_and_scale(precision, scale)
             .map(|a| Arc::new(a) as ArrayRef)
     } else {
         array
-            .try_unary::<_, D, _>(|v| v.as_().mul_checked(mul))
+            .try_unary::<_, D, _>(|v| v.as_().mul_checked(mul_or_div))
             .and_then(|a| a.with_precision_and_scale(precision, scale))
             .map(|a| Arc::new(a) as ArrayRef)
     }
@@ -352,7 +364,7 @@ where
 fn cast_floating_point_to_decimal128<T: ArrowPrimitiveType>(
     array: &PrimitiveArray<T>,
     precision: u8,
-    scale: u8,
+    scale: i8,
     cast_options: &CastOptions,
 ) -> Result<ArrayRef, ArrowError>
 where
@@ -391,7 +403,7 @@ where
 fn cast_floating_point_to_decimal256<T: ArrowPrimitiveType>(
     array: &PrimitiveArray<T>,
     precision: u8,
-    scale: u8,
+    scale: i8,
     cast_options: &CastOptions,
 ) -> Result<ArrayRef, ArrowError>
 where
@@ -437,7 +449,7 @@ fn cast_reinterpret_arrays<
 fn cast_decimal_to_integer<D, T>(
     array: &ArrayRef,
     base: D::Native,
-    scale: u8,
+    scale: i8,
     cast_options: &CastOptions,
 ) -> Result<ArrayRef, ArrowError>
 where
@@ -1921,9 +1933,9 @@ fn cast_decimal_to_decimal_with_option<
     const BYTE_WIDTH2: usize,
 >(
     array: &ArrayRef,
-    input_scale: &u8,
+    input_scale: &i8,
     output_precision: &u8,
-    output_scale: &u8,
+    output_scale: &i8,
     cast_options: &CastOptions,
 ) -> Result<ArrayRef, ArrowError> {
     if cast_options.safe {
@@ -1947,9 +1959,9 @@ fn cast_decimal_to_decimal_with_option<
 /// the array values when cast failures happen.
 fn cast_decimal_to_decimal_safe<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usize>(
     array: &ArrayRef,
-    input_scale: &u8,
+    input_scale: &i8,
     output_precision: &u8,
-    output_scale: &u8,
+    output_scale: &i8,
 ) -> Result<ArrayRef, ArrowError> {
     if input_scale > output_scale {
         // For example, input_scale is 4 and output_scale is 3;
@@ -2062,9 +2074,9 @@ fn cast_decimal_to_decimal_safe<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usi
 /// cast failure happens.
 fn cast_decimal_to_decimal<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usize>(
     array: &ArrayRef,
-    input_scale: &u8,
+    input_scale: &i8,
     output_precision: &u8,
-    output_scale: &u8,
+    output_scale: &i8,
 ) -> Result<ArrayRef, ArrowError> {
     if input_scale > output_scale {
         // For example, input_scale is 4 and output_scale is 3;
@@ -3540,7 +3552,7 @@ mod tests {
     fn create_decimal_array(
         array: Vec<Option<i128>>,
         precision: u8,
-        scale: u8,
+        scale: i8,
     ) -> Result<Decimal128Array, ArrowError> {
         array
             .into_iter()
@@ -3551,7 +3563,7 @@ mod tests {
     fn create_decimal256_array(
         array: Vec<Option<i256>>,
         precision: u8,
-        scale: u8,
+        scale: i8,
     ) -> Result<Decimal256Array, ArrowError> {
         array
             .into_iter()
@@ -7206,4 +7218,62 @@ mod tests {
             err
         );
     }
+
+    #[test]
+    fn test_cast_decimal128_to_decimal128_negative_scale() {
+        let input_type = DataType::Decimal128(20, 0);
+        let output_type = DataType::Decimal128(20, -1);
+        assert!(can_cast_types(&input_type, &output_type));
+        let array = vec![Some(1123456), Some(2123456), Some(3123456), None];
+        let input_decimal_array = create_decimal_array(array, 20, 0).unwrap();
+        let array = Arc::new(input_decimal_array) as ArrayRef;
+        generate_cast_test_case!(
+            &array,
+            Decimal128Array,
+            &output_type,
+            vec![
+                Some(112345_i128),
+                Some(212345_i128),
+                Some(312345_i128),
+                None
+            ]
+        );
+
+        let casted_array = cast(&array, &output_type).unwrap();
+        let decimal_arr = as_primitive_array::<Decimal128Type>(&casted_array);
+
+        assert_eq!("1123450", decimal_arr.value_as_string(0));
+        assert_eq!("2123450", decimal_arr.value_as_string(1));
+        assert_eq!("3123450", decimal_arr.value_as_string(2));
+    }
+
+    #[test]
+    fn test_cast_numeric_to_decimal128_negative() {
+        let decimal_type = DataType::Decimal128(38, -1);
+        let array = Arc::new(Int32Array::from(vec![
+            Some(1123456),
+            Some(2123456),
+            Some(3123456),
+        ])) as ArrayRef;
+
+        let casted_array = cast(&array, &decimal_type).unwrap();
+        let decimal_arr = as_primitive_array::<Decimal128Type>(&casted_array);
+
+        assert_eq!("1123450", decimal_arr.value_as_string(0));
+        assert_eq!("2123450", decimal_arr.value_as_string(1));
+        assert_eq!("3123450", decimal_arr.value_as_string(2));
+
+        let array = Arc::new(Float32Array::from(vec![
+            Some(1123.456),
+            Some(2123.456),
+            Some(3123.456),
+        ])) as ArrayRef;
+
+        let casted_array = cast(&array, &decimal_type).unwrap();
+        let decimal_arr = as_primitive_array::<Decimal128Type>(&casted_array);
+
+        assert_eq!("1120", decimal_arr.value_as_string(0));
+        assert_eq!("2120", decimal_arr.value_as_string(1));
+        assert_eq!("3120", decimal_arr.value_as_string(2));
+    }
 }
diff --git a/arrow-csv/src/reader.rs b/arrow-csv/src/reader.rs
index 4200e9329..6432fb1b8 100644
--- a/arrow-csv/src/reader.rs
+++ b/arrow-csv/src/reader.rs
@@ -721,7 +721,7 @@ fn build_decimal_array(
     rows: &[StringRecord],
     col_idx: usize,
     precision: u8,
-    scale: u8,
+    scale: i8,
 ) -> Result<ArrayRef, ArrowError> {
     let mut decimal_builder = Decimal128Builder::with_capacity(rows.len());
     for row in rows {
@@ -762,13 +762,13 @@ fn build_decimal_array(
 fn parse_decimal_with_parameter(
     s: &str,
     precision: u8,
-    scale: u8,
+    scale: i8,
 ) -> Result<i128, ArrowError> {
     if PARSE_DECIMAL_RE.is_match(s) {
         let mut offset = s.len();
         let len = s.len();
         let mut base = 1;
-        let scale_usize = usize::from(scale);
+        let scale_usize = usize::from(scale as u8);
 
         // handle the value after the '.' and meet the scale
         let delimiter_position = s.find('.');
diff --git a/arrow-data/src/decimal.rs b/arrow-data/src/decimal.rs
index a6a087749..7011c4085 100644
--- a/arrow-data/src/decimal.rs
+++ b/arrow-data/src/decimal.rs
@@ -728,17 +728,17 @@ pub const MIN_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [
 pub const DECIMAL128_MAX_PRECISION: u8 = 38;
 
 /// The maximum scale for [arrow_schema::DataType::Decimal128] values
-pub const DECIMAL128_MAX_SCALE: u8 = 38;
+pub const DECIMAL128_MAX_SCALE: i8 = 38;
 
 /// The maximum precision for [arrow_schema::DataType::Decimal256] values
 pub const DECIMAL256_MAX_PRECISION: u8 = 76;
 
 /// The maximum scale for [arrow_schema::DataType::Decimal256] values
-pub const DECIMAL256_MAX_SCALE: u8 = 76;
+pub const DECIMAL256_MAX_SCALE: i8 = 76;
 
 /// The default scale for [arrow_schema::DataType::Decimal128] and
 /// [arrow_schema::DataType::Decimal256] values
-pub const DECIMAL_DEFAULT_SCALE: u8 = 10;
+pub const DECIMAL_DEFAULT_SCALE: i8 = 10;
 
 /// Validates that the specified `i128` value can be properly
 /// interpreted as a Decimal number with precision `precision`
diff --git a/arrow-schema/src/datatype.rs b/arrow-schema/src/datatype.rs
index cf85902e4..f74e2a24b 100644
--- a/arrow-schema/src/datatype.rs
+++ b/arrow-schema/src/datatype.rs
@@ -190,14 +190,14 @@ pub enum DataType {
     /// * scale is the number of digits past the decimal
     ///
     /// For example the number 123.45 has precision 5 and scale 2.
-    Decimal128(u8, u8),
+    Decimal128(u8, i8),
     /// Exact 256-bit width decimal value with precision and scale
     ///
     /// * precision is the total number of digits
     /// * scale is the number of digits past the decimal
     ///
     /// For example the number 123.45 has precision 5 and scale 2.
-    Decimal256(u8, u8),
+    Decimal256(u8, i8),
     /// A Map is a logical nested type that is represented as
     ///
     /// `List<entries: Struct<key: K, value: V>>`
diff --git a/arrow-select/src/take.rs b/arrow-select/src/take.rs
index d498ae487..857b6e323 100644
--- a/arrow-select/src/take.rs
+++ b/arrow-select/src/take.rs
@@ -914,7 +914,7 @@ mod tests {
         options: Option<TakeOptions>,
         expected_data: Vec<Option<i128>>,
         precision: &u8,
-        scale: &u8,
+        scale: &i8,
     ) -> Result<(), ArrowError> {
         let output = data
             .into_iter()
@@ -1032,7 +1032,7 @@ mod tests {
     fn test_take_decimal128_non_null_indices() {
         let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
         let precision: u8 = 10;
-        let scale: u8 = 5;
+        let scale: i8 = 5;
         test_take_decimal_arrays(
             vec![None, Some(3), Some(5), Some(2), Some(3), None],
             &index,
@@ -1048,7 +1048,7 @@ mod tests {
     fn test_take_decimal128() {
         let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
         let precision: u8 = 10;
-        let scale: u8 = 5;
+        let scale: i8 = 5;
         test_take_decimal_arrays(
             vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
             &index,
diff --git a/arrow/benches/cast_kernels.rs b/arrow/benches/cast_kernels.rs
index e93c78608..7ef4d1d7e 100644
--- a/arrow/benches/cast_kernels.rs
+++ b/arrow/benches/cast_kernels.rs
@@ -84,7 +84,7 @@ fn build_utf8_date_time_array(size: usize, with_nulls: bool) -> ArrayRef {
     Arc::new(builder.finish())
 }
 
-fn build_decimal128_array(size: usize, precision: u8, scale: u8) -> ArrayRef {
+fn build_decimal128_array(size: usize, precision: u8, scale: i8) -> ArrayRef {
     let mut rng = seedable_rng();
     let mut builder = Decimal128Builder::with_capacity(size);
 
@@ -99,7 +99,7 @@ fn build_decimal128_array(size: usize, precision: u8, scale: u8) -> ArrayRef {
     )
 }
 
-fn build_decimal256_array(size: usize, precision: u8, scale: u8) -> ArrayRef {
+fn build_decimal256_array(size: usize, precision: u8, scale: i8) -> ArrayRef {
     let mut rng = seedable_rng();
     let mut builder = Decimal256Builder::with_capacity(size);
     let mut bytes = [0; 32];
diff --git a/arrow/src/datatypes/ffi.rs b/arrow/src/datatypes/ffi.rs
index ef303dfdd..41addf24f 100644
--- a/arrow/src/datatypes/ffi.rs
+++ b/arrow/src/datatypes/ffi.rs
@@ -103,7 +103,7 @@ impl TryFrom<&FFI_ArrowSchema> for DataType {
                                         "The decimal type requires an integer precision".to_string(),
                                     )
                                 })?;
-                                let parsed_scale = scale.parse::<u8>().map_err(|_| {
+                                let parsed_scale = scale.parse::<i8>().map_err(|_| {
                                     ArrowError::CDataInterface(
                                         "The decimal type requires an integer scale".to_string(),
                                     )
@@ -119,7 +119,7 @@ impl TryFrom<&FFI_ArrowSchema> for DataType {
                                         "The decimal type requires an integer precision".to_string(),
                                     )
                                 })?;
-                                let parsed_scale = scale.parse::<u8>().map_err(|_| {
+                                let parsed_scale = scale.parse::<i8>().map_err(|_| {
                                     ArrowError::CDataInterface(
                                         "The decimal type requires an integer scale".to_string(),
                                     )
diff --git a/arrow/tests/array_transform.rs b/arrow/tests/array_transform.rs
index 42f9ab277..3c08a592d 100644
--- a/arrow/tests/array_transform.rs
+++ b/arrow/tests/array_transform.rs
@@ -31,7 +31,7 @@ use std::sync::Arc;
 fn create_decimal_array(
     array: Vec<Option<i128>>,
     precision: u8,
-    scale: u8,
+    scale: i8,
 ) -> Decimal128Array {
     array
         .into_iter()