You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by li...@apache.org on 2023/04/24 03:15:50 UTC

[arrow-rs] branch master updated: optimize cast for same decimal type and same scale (#4088)

This is an automated email from the ASF dual-hosted git repository.

liukun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 609986418 optimize cast for same decimal type and same scale (#4088)
609986418 is described below

commit 6099864180b9a42472bd0a0a17d16a1b612902f4
Author: Kun Liu <li...@apache.org>
AuthorDate: Mon Apr 24 11:15:44 2023 +0800

    optimize cast for same decimal type and same scale (#4088)
---
 arrow-cast/src/cast.rs        | 288 ++++++++++++++++++++++++++++--------------
 arrow/benches/cast_kernels.rs |   7 +
 2 files changed, 201 insertions(+), 94 deletions(-)

diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs
index bc37174b9..61a296e99 100644
--- a/arrow-cast/src/cast.rs
+++ b/arrow-cast/src/cast.rs
@@ -851,7 +851,7 @@ pub fn cast_with_options(
             cast_primitive_to_list::<i64>(array, to, to_type, cast_options)
         }
         (Decimal128(_, s1), Decimal128(p2, s2)) => {
-            cast_decimal_to_decimal::<Decimal128Type, Decimal128Type>(
+            cast_decimal_to_decimal_same_type::<Decimal128Type>(
                 array.as_primitive(),
                 *s1,
                 *p2,
@@ -860,7 +860,7 @@ pub fn cast_with_options(
             )
         }
         (Decimal256(_, s1), Decimal256(p2, s2)) => {
-            cast_decimal_to_decimal::<Decimal256Type, Decimal256Type>(
+            cast_decimal_to_decimal_same_type::<Decimal256Type>(
                 array.as_primitive(),
                 *s1,
                 *p2,
@@ -1292,16 +1292,16 @@ pub fn cast_with_options(
                 cast_string_to_time64nanosecond::<i32>(array, cast_options)
             }
             Timestamp(TimeUnit::Second, to_tz) => {
-                cast_string_to_timestamp::<i32, TimestampSecondType>(array, to_tz,cast_options)
+                cast_string_to_timestamp::<i32, TimestampSecondType>(array, to_tz, cast_options)
             }
             Timestamp(TimeUnit::Millisecond, to_tz) => {
                 cast_string_to_timestamp::<i32, TimestampMillisecondType>(array, to_tz, cast_options)
             }
             Timestamp(TimeUnit::Microsecond, to_tz) => {
-                cast_string_to_timestamp::<i32, TimestampMicrosecondType>(array, to_tz,cast_options)
+                cast_string_to_timestamp::<i32, TimestampMicrosecondType>(array, to_tz, cast_options)
             }
             Timestamp(TimeUnit::Nanosecond, to_tz) => {
-                cast_string_to_timestamp::<i32, TimestampNanosecondType>(array, to_tz,cast_options)
+                cast_string_to_timestamp::<i32, TimestampNanosecondType>(array, to_tz, cast_options)
             }
             Interval(IntervalUnit::YearMonth) => {
                 cast_string_to_year_month_interval::<i32>(array, cast_options)
@@ -1385,7 +1385,7 @@ pub fn cast_with_options(
                 cast_byte_container::<BinaryType, LargeBinaryType>(array)
             }
             FixedSizeBinary(size) => {
-                cast_binary_to_fixed_size_binary::<i32>(array,*size, cast_options)
+                cast_binary_to_fixed_size_binary::<i32>(array, *size, cast_options)
             }
             _ => Err(ArrowError::CastError(format!(
                 "Casting from {from_type:?} to {to_type:?} not supported",
@@ -1876,12 +1876,12 @@ pub fn cast_with_options(
                         })
                 }
                 false => {
-                            array.as_primitive::<TimestampSecondType>().try_unary::<_, Date64Type, _>(
-                                |x| {
-                                    x.mul_checked(MILLISECONDS)
-                                },
-                            )?
-                        }
+                    array.as_primitive::<TimestampSecondType>().try_unary::<_, Date64Type, _>(
+                        |x| {
+                            x.mul_checked(MILLISECONDS)
+                        },
+                    )?
+                }
             },
         )),
         (Timestamp(TimeUnit::Millisecond, _), Date64) => {
@@ -1922,10 +1922,10 @@ pub fn cast_with_options(
             Ok(Arc::new(
                 array.as_primitive::<TimestampMillisecondType>()
                     .try_unary::<_, Time64MicrosecondType, ArrowError>(|x| {
-                    Ok(time_to_time64us(as_time_res_with_timezone::<
-                        TimestampMillisecondType,
-                    >(x, tz)?))
-                })?,
+                        Ok(time_to_time64us(as_time_res_with_timezone::<
+                            TimestampMillisecondType,
+                        >(x, tz)?))
+                    })?,
             ))
         }
         (Timestamp(TimeUnit::Millisecond, tz), Time64(TimeUnit::Nanosecond)) => {
@@ -1933,10 +1933,10 @@ pub fn cast_with_options(
             Ok(Arc::new(
                 array.as_primitive::<TimestampMillisecondType>()
                     .try_unary::<_, Time64NanosecondType, ArrowError>(|x| {
-                    Ok(time_to_time64ns(as_time_res_with_timezone::<
-                        TimestampMillisecondType,
-                    >(x, tz)?))
-                })?,
+                        Ok(time_to_time64ns(as_time_res_with_timezone::<
+                            TimestampMillisecondType,
+                        >(x, tz)?))
+                    })?,
             ))
         }
         (Timestamp(TimeUnit::Microsecond, tz), Time64(TimeUnit::Microsecond)) => {
@@ -1944,10 +1944,10 @@ pub fn cast_with_options(
             Ok(Arc::new(
                 array.as_primitive::<TimestampMicrosecondType>()
                     .try_unary::<_, Time64MicrosecondType, ArrowError>(|x| {
-                    Ok(time_to_time64us(as_time_res_with_timezone::<
-                        TimestampMicrosecondType,
-                    >(x, tz)?))
-                })?,
+                        Ok(time_to_time64us(as_time_res_with_timezone::<
+                            TimestampMicrosecondType,
+                        >(x, tz)?))
+                    })?,
             ))
         }
         (Timestamp(TimeUnit::Microsecond, tz), Time64(TimeUnit::Nanosecond)) => {
@@ -1955,10 +1955,10 @@ pub fn cast_with_options(
             Ok(Arc::new(
                 array.as_primitive::<TimestampMicrosecondType>()
                     .try_unary::<_, Time64NanosecondType, ArrowError>(|x| {
-                    Ok(time_to_time64ns(as_time_res_with_timezone::<
-                        TimestampMicrosecondType,
-                    >(x, tz)?))
-                })?,
+                        Ok(time_to_time64ns(as_time_res_with_timezone::<
+                            TimestampMicrosecondType,
+                        >(x, tz)?))
+                    })?,
             ))
         }
         (Timestamp(TimeUnit::Nanosecond, tz), Time64(TimeUnit::Microsecond)) => {
@@ -1966,10 +1966,10 @@ pub fn cast_with_options(
             Ok(Arc::new(
                 array.as_primitive::<TimestampNanosecondType>()
                     .try_unary::<_, Time64MicrosecondType, ArrowError>(|x| {
-                    Ok(time_to_time64us(as_time_res_with_timezone::<
-                        TimestampNanosecondType,
-                    >(x, tz)?))
-                })?,
+                        Ok(time_to_time64us(as_time_res_with_timezone::<
+                            TimestampNanosecondType,
+                        >(x, tz)?))
+                    })?,
             ))
         }
         (Timestamp(TimeUnit::Nanosecond, tz), Time64(TimeUnit::Nanosecond)) => {
@@ -1977,10 +1977,10 @@ pub fn cast_with_options(
             Ok(Arc::new(
                 array.as_primitive::<TimestampNanosecondType>()
                     .try_unary::<_, Time64NanosecondType, ArrowError>(|x| {
-                    Ok(time_to_time64ns(as_time_res_with_timezone::<
-                        TimestampNanosecondType,
-                    >(x, tz)?))
-                })?,
+                        Ok(time_to_time64ns(as_time_res_with_timezone::<
+                            TimestampNanosecondType,
+                        >(x, tz)?))
+                    })?,
             ))
         }
         (Timestamp(TimeUnit::Second, tz), Time32(TimeUnit::Second)) => {
@@ -2021,10 +2021,10 @@ pub fn cast_with_options(
             Ok(Arc::new(
                 array.as_primitive::<TimestampMillisecondType>()
                     .try_unary::<_, Time32MillisecondType, ArrowError>(|x| {
-                    Ok(time_to_time32ms(as_time_res_with_timezone::<
-                        TimestampMillisecondType,
-                    >(x, tz)?))
-                })?,
+                        Ok(time_to_time32ms(as_time_res_with_timezone::<
+                            TimestampMillisecondType,
+                        >(x, tz)?))
+                    })?,
             ))
         }
         (Timestamp(TimeUnit::Microsecond, tz), Time32(TimeUnit::Second)) => {
@@ -2043,10 +2043,10 @@ pub fn cast_with_options(
             Ok(Arc::new(
                 array.as_primitive::<TimestampMicrosecondType>()
                     .try_unary::<_, Time32MillisecondType, ArrowError>(|x| {
-                    Ok(time_to_time32ms(as_time_res_with_timezone::<
-                        TimestampMicrosecondType,
-                    >(x, tz)?))
-                })?,
+                        Ok(time_to_time32ms(as_time_res_with_timezone::<
+                            TimestampMicrosecondType,
+                        >(x, tz)?))
+                    })?,
             ))
         }
         (Timestamp(TimeUnit::Nanosecond, tz), Time32(TimeUnit::Second)) => {
@@ -2065,10 +2065,10 @@ pub fn cast_with_options(
             Ok(Arc::new(
                 array.as_primitive::<TimestampNanosecondType>()
                     .try_unary::<_, Time32MillisecondType, ArrowError>(|x| {
-                    Ok(time_to_time32ms(as_time_res_with_timezone::<
-                        TimestampNanosecondType,
-                    >(x, tz)?))
-                })?,
+                        Ok(time_to_time32ms(as_time_res_with_timezone::<
+                            TimestampNanosecondType,
+                        >(x, tz)?))
+                    })?,
             ))
         }
 
@@ -2222,20 +2222,17 @@ impl DecimalCast for i256 {
     }
 }
 
-fn cast_decimal_to_decimal<I, O>(
-    array: &PrimitiveArray<I>,
-    input_scale: i8,
+fn cast_decimal_to_decimal_error<I, O>(
     output_precision: u8,
     output_scale: i8,
-    cast_options: &CastOptions,
-) -> Result<ArrayRef, ArrowError>
+) -> impl Fn(<I as ArrowPrimitiveType>::Native) -> ArrowError
 where
     I: DecimalType,
     O: DecimalType,
     I::Native: DecimalCast + ArrowNativeTypeOp,
     O::Native: DecimalCast + ArrowNativeTypeOp,
 {
-    let error = |x| {
+    move |x: I::Native| {
         ArrowError::CastError(format!(
             "Cannot cast to {}({}, {}). Overflowing on {:?}",
             O::PREFIX,
@@ -2243,45 +2240,148 @@ where
             output_scale,
             x
         ))
-    };
+    }
+}
 
-    let array: PrimitiveArray<O> = if input_scale > output_scale {
-        let div = I::Native::from_decimal(10_i128)
-            .unwrap()
-            .pow_checked((input_scale - output_scale) as u32)?;
+fn convert_to_smaller_scale_decimal<I, O>(
+    array: &PrimitiveArray<I>,
+    input_scale: i8,
+    output_precision: u8,
+    output_scale: i8,
+    cast_options: &CastOptions,
+) -> Result<PrimitiveArray<O>, ArrowError>
+where
+    I: DecimalType,
+    O: DecimalType,
+    I::Native: DecimalCast + ArrowNativeTypeOp,
+    O::Native: DecimalCast + ArrowNativeTypeOp,
+{
+    let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale);
+    let div = I::Native::from_decimal(10_i128)
+        .unwrap()
+        .pow_checked((input_scale - output_scale) as u32)?;
 
-        let half = div.div_wrapping(I::Native::from_usize(2).unwrap());
-        let half_neg = half.neg_wrapping();
+    let half = div.div_wrapping(I::Native::from_usize(2).unwrap());
+    let half_neg = half.neg_wrapping();
 
-        let f = |x: I::Native| {
-            // div is >= 10 and so this cannot overflow
-            let d = x.div_wrapping(div);
-            let r = x.mod_wrapping(div);
+    let f = |x: I::Native| {
+        // div is >= 10 and so this cannot overflow
+        let d = x.div_wrapping(div);
+        let r = x.mod_wrapping(div);
 
-            // Round result
-            let adjusted = match x >= I::Native::ZERO {
-                true if r >= half => d.add_wrapping(I::Native::ONE),
-                false if r <= half_neg => d.sub_wrapping(I::Native::ONE),
-                _ => d,
-            };
-            O::Native::from_decimal(adjusted)
+        // Round result
+        let adjusted = match x >= I::Native::ZERO {
+            true if r >= half => d.add_wrapping(I::Native::ONE),
+            false if r <= half_neg => d.sub_wrapping(I::Native::ONE),
+            _ => d,
         };
+        O::Native::from_decimal(adjusted)
+    };
 
-        match cast_options.safe {
-            true => array.unary_opt(f),
-            false => array.try_unary(|x| f(x).ok_or_else(|| error(x)))?,
-        }
-    } else {
-        let mul = O::Native::from_decimal(10_i128)
-            .unwrap()
-            .pow_checked((output_scale - input_scale) as u32)?;
+    Ok(match cast_options.safe {
+        true => array.unary_opt(f),
+        false => array.try_unary(|x| f(x).ok_or_else(|| error(x)))?,
+    })
+}
 
-        let f = |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok());
+fn convert_to_bigger_or_equal_scale_decimal<I, O>(
+    array: &PrimitiveArray<I>,
+    input_scale: i8,
+    output_precision: u8,
+    output_scale: i8,
+    cast_options: &CastOptions,
+) -> Result<PrimitiveArray<O>, ArrowError>
+where
+    I: DecimalType,
+    O: DecimalType,
+    I::Native: DecimalCast + ArrowNativeTypeOp,
+    O::Native: DecimalCast + ArrowNativeTypeOp,
+{
+    let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale);
+    let mul = O::Native::from_decimal(10_i128)
+        .unwrap()
+        .pow_checked((output_scale - input_scale) as u32)?;
 
-        match cast_options.safe {
-            true => array.unary_opt(f),
-            false => array.try_unary(|x| f(x).ok_or_else(|| error(x)))?,
+    let f = |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok());
+
+    Ok(match cast_options.safe {
+        true => array.unary_opt(f),
+        false => array.try_unary(|x| f(x).ok_or_else(|| error(x)))?,
+    })
+}
+
+// Only support one type of decimal cast operations
+fn cast_decimal_to_decimal_same_type<T>(
+    array: &PrimitiveArray<T>,
+    input_scale: i8,
+    output_precision: u8,
+    output_scale: i8,
+    cast_options: &CastOptions,
+) -> Result<ArrayRef, ArrowError>
+where
+    T: DecimalType,
+    T::Native: DecimalCast + ArrowNativeTypeOp,
+{
+    let array: PrimitiveArray<T> = match input_scale.cmp(&output_scale) {
+        Ordering::Equal => {
+            // the scale doesn't change, the native value don't need to be changed
+            array.clone()
         }
+        Ordering::Greater => convert_to_smaller_scale_decimal::<T, T>(
+            array,
+            input_scale,
+            output_precision,
+            output_scale,
+            cast_options,
+        )?,
+        Ordering::Less => {
+            // input_scale < output_scale
+            convert_to_bigger_or_equal_scale_decimal::<T, T>(
+                array,
+                input_scale,
+                output_precision,
+                output_scale,
+                cast_options,
+            )?
+        }
+    };
+
+    Ok(Arc::new(array.with_precision_and_scale(
+        output_precision,
+        output_scale,
+    )?))
+}
+
+// Support two different types of decimal cast operations
+fn cast_decimal_to_decimal<I, O>(
+    array: &PrimitiveArray<I>,
+    input_scale: i8,
+    output_precision: u8,
+    output_scale: i8,
+    cast_options: &CastOptions,
+) -> Result<ArrayRef, ArrowError>
+where
+    I: DecimalType,
+    O: DecimalType,
+    I::Native: DecimalCast + ArrowNativeTypeOp,
+    O::Native: DecimalCast + ArrowNativeTypeOp,
+{
+    let array: PrimitiveArray<O> = if input_scale > output_scale {
+        convert_to_smaller_scale_decimal::<I, O>(
+            array,
+            input_scale,
+            output_precision,
+            output_scale,
+            cast_options,
+        )?
+    } else {
+        convert_to_bigger_or_equal_scale_decimal::<I, O>(
+            array,
+            input_scale,
+            output_precision,
+            output_scale,
+            cast_options,
+        )?
     };
 
     Ok(Arc::new(array.with_precision_and_scale(
@@ -7821,7 +7921,7 @@ mod tests {
             Decimal128Type::format_decimal(
                 parse_string_to_decimal_native::<Decimal128Type>("12345", 2).unwrap(),
                 38,
-                2
+                2,
             ),
             "12345.00"
         );
@@ -7829,7 +7929,7 @@ mod tests {
             Decimal128Type::format_decimal(
                 parse_string_to_decimal_native::<Decimal128Type>("0.12345", 2).unwrap(),
                 38,
-                2
+                2,
             ),
             "0.12"
         );
@@ -7837,7 +7937,7 @@ mod tests {
             Decimal128Type::format_decimal(
                 parse_string_to_decimal_native::<Decimal128Type>(".12345", 2).unwrap(),
                 38,
-                2
+                2,
             ),
             "0.12"
         );
@@ -7845,7 +7945,7 @@ mod tests {
             Decimal128Type::format_decimal(
                 parse_string_to_decimal_native::<Decimal128Type>(".1265", 2).unwrap(),
                 38,
-                2
+                2,
             ),
             "0.13"
         );
@@ -7853,7 +7953,7 @@ mod tests {
             Decimal128Type::format_decimal(
                 parse_string_to_decimal_native::<Decimal128Type>(".1265", 2).unwrap(),
                 38,
-                2
+                2,
             ),
             "0.13"
         );
@@ -7862,7 +7962,7 @@ mod tests {
             Decimal256Type::format_decimal(
                 parse_string_to_decimal_native::<Decimal256Type>("123.45", 3).unwrap(),
                 38,
-                3
+                3,
             ),
             "123.450"
         );
@@ -7870,7 +7970,7 @@ mod tests {
             Decimal256Type::format_decimal(
                 parse_string_to_decimal_native::<Decimal256Type>("12345", 3).unwrap(),
                 38,
-                3
+                3,
             ),
             "12345.000"
         );
@@ -7878,7 +7978,7 @@ mod tests {
             Decimal256Type::format_decimal(
                 parse_string_to_decimal_native::<Decimal256Type>("0.12345", 3).unwrap(),
                 38,
-                3
+                3,
             ),
             "0.123"
         );
@@ -7886,7 +7986,7 @@ mod tests {
             Decimal256Type::format_decimal(
                 parse_string_to_decimal_native::<Decimal256Type>(".12345", 3).unwrap(),
                 38,
-                3
+                3,
             ),
             "0.123"
         );
@@ -7894,7 +7994,7 @@ mod tests {
             Decimal256Type::format_decimal(
                 parse_string_to_decimal_native::<Decimal256Type>(".1265", 3).unwrap(),
                 38,
-                3
+                3,
             ),
             "0.127"
         );
diff --git a/arrow/benches/cast_kernels.rs b/arrow/benches/cast_kernels.rs
index 7ef4d1d7e..933ddd4a0 100644
--- a/arrow/benches/cast_kernels.rs
+++ b/arrow/benches/cast_kernels.rs
@@ -230,6 +230,13 @@ fn add_benchmark(c: &mut Criterion) {
     c.bench_function("cast decimal256 to decimal256 512", |b| {
         b.iter(|| cast_array(&decimal256_array, DataType::Decimal256(50, 5)))
     });
+
+    c.bench_function("cast decimal128 to decimal128 512 with same scale", |b| {
+        b.iter(|| cast_array(&decimal128_array, DataType::Decimal128(30, 3)))
+    });
+    c.bench_function("cast decimal256 to decimal256 512 with same scale", |b| {
+        b.iter(|| cast_array(&decimal256_array, DataType::Decimal256(60, 3)))
+    });
 }
 
 criterion_group!(benches, add_benchmark);