You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by li...@apache.org on 2023/04/24 03:15:50 UTC
[arrow-rs] branch master updated: optimize cast for same decimal type and same scale (#4088)
This is an automated email from the ASF dual-hosted git repository.
liukun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 609986418 optimize cast for same decimal type and same scale (#4088)
609986418 is described below
commit 6099864180b9a42472bd0a0a17d16a1b612902f4
Author: Kun Liu <li...@apache.org>
AuthorDate: Mon Apr 24 11:15:44 2023 +0800
optimize cast for same decimal type and same scale (#4088)
---
arrow-cast/src/cast.rs | 288 ++++++++++++++++++++++++++++--------------
arrow/benches/cast_kernels.rs | 7 +
2 files changed, 201 insertions(+), 94 deletions(-)
diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs
index bc37174b9..61a296e99 100644
--- a/arrow-cast/src/cast.rs
+++ b/arrow-cast/src/cast.rs
@@ -851,7 +851,7 @@ pub fn cast_with_options(
cast_primitive_to_list::<i64>(array, to, to_type, cast_options)
}
(Decimal128(_, s1), Decimal128(p2, s2)) => {
- cast_decimal_to_decimal::<Decimal128Type, Decimal128Type>(
+ cast_decimal_to_decimal_same_type::<Decimal128Type>(
array.as_primitive(),
*s1,
*p2,
@@ -860,7 +860,7 @@ pub fn cast_with_options(
)
}
(Decimal256(_, s1), Decimal256(p2, s2)) => {
- cast_decimal_to_decimal::<Decimal256Type, Decimal256Type>(
+ cast_decimal_to_decimal_same_type::<Decimal256Type>(
array.as_primitive(),
*s1,
*p2,
@@ -1292,16 +1292,16 @@ pub fn cast_with_options(
cast_string_to_time64nanosecond::<i32>(array, cast_options)
}
Timestamp(TimeUnit::Second, to_tz) => {
- cast_string_to_timestamp::<i32, TimestampSecondType>(array, to_tz,cast_options)
+ cast_string_to_timestamp::<i32, TimestampSecondType>(array, to_tz, cast_options)
}
Timestamp(TimeUnit::Millisecond, to_tz) => {
cast_string_to_timestamp::<i32, TimestampMillisecondType>(array, to_tz, cast_options)
}
Timestamp(TimeUnit::Microsecond, to_tz) => {
- cast_string_to_timestamp::<i32, TimestampMicrosecondType>(array, to_tz,cast_options)
+ cast_string_to_timestamp::<i32, TimestampMicrosecondType>(array, to_tz, cast_options)
}
Timestamp(TimeUnit::Nanosecond, to_tz) => {
- cast_string_to_timestamp::<i32, TimestampNanosecondType>(array, to_tz,cast_options)
+ cast_string_to_timestamp::<i32, TimestampNanosecondType>(array, to_tz, cast_options)
}
Interval(IntervalUnit::YearMonth) => {
cast_string_to_year_month_interval::<i32>(array, cast_options)
@@ -1385,7 +1385,7 @@ pub fn cast_with_options(
cast_byte_container::<BinaryType, LargeBinaryType>(array)
}
FixedSizeBinary(size) => {
- cast_binary_to_fixed_size_binary::<i32>(array,*size, cast_options)
+ cast_binary_to_fixed_size_binary::<i32>(array, *size, cast_options)
}
_ => Err(ArrowError::CastError(format!(
"Casting from {from_type:?} to {to_type:?} not supported",
@@ -1876,12 +1876,12 @@ pub fn cast_with_options(
})
}
false => {
- array.as_primitive::<TimestampSecondType>().try_unary::<_, Date64Type, _>(
- |x| {
- x.mul_checked(MILLISECONDS)
- },
- )?
- }
+ array.as_primitive::<TimestampSecondType>().try_unary::<_, Date64Type, _>(
+ |x| {
+ x.mul_checked(MILLISECONDS)
+ },
+ )?
+ }
},
)),
(Timestamp(TimeUnit::Millisecond, _), Date64) => {
@@ -1922,10 +1922,10 @@ pub fn cast_with_options(
Ok(Arc::new(
array.as_primitive::<TimestampMillisecondType>()
.try_unary::<_, Time64MicrosecondType, ArrowError>(|x| {
- Ok(time_to_time64us(as_time_res_with_timezone::<
- TimestampMillisecondType,
- >(x, tz)?))
- })?,
+ Ok(time_to_time64us(as_time_res_with_timezone::<
+ TimestampMillisecondType,
+ >(x, tz)?))
+ })?,
))
}
(Timestamp(TimeUnit::Millisecond, tz), Time64(TimeUnit::Nanosecond)) => {
@@ -1933,10 +1933,10 @@ pub fn cast_with_options(
Ok(Arc::new(
array.as_primitive::<TimestampMillisecondType>()
.try_unary::<_, Time64NanosecondType, ArrowError>(|x| {
- Ok(time_to_time64ns(as_time_res_with_timezone::<
- TimestampMillisecondType,
- >(x, tz)?))
- })?,
+ Ok(time_to_time64ns(as_time_res_with_timezone::<
+ TimestampMillisecondType,
+ >(x, tz)?))
+ })?,
))
}
(Timestamp(TimeUnit::Microsecond, tz), Time64(TimeUnit::Microsecond)) => {
@@ -1944,10 +1944,10 @@ pub fn cast_with_options(
Ok(Arc::new(
array.as_primitive::<TimestampMicrosecondType>()
.try_unary::<_, Time64MicrosecondType, ArrowError>(|x| {
- Ok(time_to_time64us(as_time_res_with_timezone::<
- TimestampMicrosecondType,
- >(x, tz)?))
- })?,
+ Ok(time_to_time64us(as_time_res_with_timezone::<
+ TimestampMicrosecondType,
+ >(x, tz)?))
+ })?,
))
}
(Timestamp(TimeUnit::Microsecond, tz), Time64(TimeUnit::Nanosecond)) => {
@@ -1955,10 +1955,10 @@ pub fn cast_with_options(
Ok(Arc::new(
array.as_primitive::<TimestampMicrosecondType>()
.try_unary::<_, Time64NanosecondType, ArrowError>(|x| {
- Ok(time_to_time64ns(as_time_res_with_timezone::<
- TimestampMicrosecondType,
- >(x, tz)?))
- })?,
+ Ok(time_to_time64ns(as_time_res_with_timezone::<
+ TimestampMicrosecondType,
+ >(x, tz)?))
+ })?,
))
}
(Timestamp(TimeUnit::Nanosecond, tz), Time64(TimeUnit::Microsecond)) => {
@@ -1966,10 +1966,10 @@ pub fn cast_with_options(
Ok(Arc::new(
array.as_primitive::<TimestampNanosecondType>()
.try_unary::<_, Time64MicrosecondType, ArrowError>(|x| {
- Ok(time_to_time64us(as_time_res_with_timezone::<
- TimestampNanosecondType,
- >(x, tz)?))
- })?,
+ Ok(time_to_time64us(as_time_res_with_timezone::<
+ TimestampNanosecondType,
+ >(x, tz)?))
+ })?,
))
}
(Timestamp(TimeUnit::Nanosecond, tz), Time64(TimeUnit::Nanosecond)) => {
@@ -1977,10 +1977,10 @@ pub fn cast_with_options(
Ok(Arc::new(
array.as_primitive::<TimestampNanosecondType>()
.try_unary::<_, Time64NanosecondType, ArrowError>(|x| {
- Ok(time_to_time64ns(as_time_res_with_timezone::<
- TimestampNanosecondType,
- >(x, tz)?))
- })?,
+ Ok(time_to_time64ns(as_time_res_with_timezone::<
+ TimestampNanosecondType,
+ >(x, tz)?))
+ })?,
))
}
(Timestamp(TimeUnit::Second, tz), Time32(TimeUnit::Second)) => {
@@ -2021,10 +2021,10 @@ pub fn cast_with_options(
Ok(Arc::new(
array.as_primitive::<TimestampMillisecondType>()
.try_unary::<_, Time32MillisecondType, ArrowError>(|x| {
- Ok(time_to_time32ms(as_time_res_with_timezone::<
- TimestampMillisecondType,
- >(x, tz)?))
- })?,
+ Ok(time_to_time32ms(as_time_res_with_timezone::<
+ TimestampMillisecondType,
+ >(x, tz)?))
+ })?,
))
}
(Timestamp(TimeUnit::Microsecond, tz), Time32(TimeUnit::Second)) => {
@@ -2043,10 +2043,10 @@ pub fn cast_with_options(
Ok(Arc::new(
array.as_primitive::<TimestampMicrosecondType>()
.try_unary::<_, Time32MillisecondType, ArrowError>(|x| {
- Ok(time_to_time32ms(as_time_res_with_timezone::<
- TimestampMicrosecondType,
- >(x, tz)?))
- })?,
+ Ok(time_to_time32ms(as_time_res_with_timezone::<
+ TimestampMicrosecondType,
+ >(x, tz)?))
+ })?,
))
}
(Timestamp(TimeUnit::Nanosecond, tz), Time32(TimeUnit::Second)) => {
@@ -2065,10 +2065,10 @@ pub fn cast_with_options(
Ok(Arc::new(
array.as_primitive::<TimestampNanosecondType>()
.try_unary::<_, Time32MillisecondType, ArrowError>(|x| {
- Ok(time_to_time32ms(as_time_res_with_timezone::<
- TimestampNanosecondType,
- >(x, tz)?))
- })?,
+ Ok(time_to_time32ms(as_time_res_with_timezone::<
+ TimestampNanosecondType,
+ >(x, tz)?))
+ })?,
))
}
@@ -2222,20 +2222,17 @@ impl DecimalCast for i256 {
}
}
-fn cast_decimal_to_decimal<I, O>(
- array: &PrimitiveArray<I>,
- input_scale: i8,
+fn cast_decimal_to_decimal_error<I, O>(
output_precision: u8,
output_scale: i8,
- cast_options: &CastOptions,
-) -> Result<ArrayRef, ArrowError>
+) -> impl Fn(<I as ArrowPrimitiveType>::Native) -> ArrowError
where
I: DecimalType,
O: DecimalType,
I::Native: DecimalCast + ArrowNativeTypeOp,
O::Native: DecimalCast + ArrowNativeTypeOp,
{
- let error = |x| {
+ move |x: I::Native| {
ArrowError::CastError(format!(
"Cannot cast to {}({}, {}). Overflowing on {:?}",
O::PREFIX,
@@ -2243,45 +2240,148 @@ where
output_scale,
x
))
- };
+ }
+}
- let array: PrimitiveArray<O> = if input_scale > output_scale {
- let div = I::Native::from_decimal(10_i128)
- .unwrap()
- .pow_checked((input_scale - output_scale) as u32)?;
+fn convert_to_smaller_scale_decimal<I, O>(
+ array: &PrimitiveArray<I>,
+ input_scale: i8,
+ output_precision: u8,
+ output_scale: i8,
+ cast_options: &CastOptions,
+) -> Result<PrimitiveArray<O>, ArrowError>
+where
+ I: DecimalType,
+ O: DecimalType,
+ I::Native: DecimalCast + ArrowNativeTypeOp,
+ O::Native: DecimalCast + ArrowNativeTypeOp,
+{
+ let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale);
+ let div = I::Native::from_decimal(10_i128)
+ .unwrap()
+ .pow_checked((input_scale - output_scale) as u32)?;
- let half = div.div_wrapping(I::Native::from_usize(2).unwrap());
- let half_neg = half.neg_wrapping();
+ let half = div.div_wrapping(I::Native::from_usize(2).unwrap());
+ let half_neg = half.neg_wrapping();
- let f = |x: I::Native| {
- // div is >= 10 and so this cannot overflow
- let d = x.div_wrapping(div);
- let r = x.mod_wrapping(div);
+ let f = |x: I::Native| {
+ // div is >= 10 and so this cannot overflow
+ let d = x.div_wrapping(div);
+ let r = x.mod_wrapping(div);
- // Round result
- let adjusted = match x >= I::Native::ZERO {
- true if r >= half => d.add_wrapping(I::Native::ONE),
- false if r <= half_neg => d.sub_wrapping(I::Native::ONE),
- _ => d,
- };
- O::Native::from_decimal(adjusted)
+ // Round result
+ let adjusted = match x >= I::Native::ZERO {
+ true if r >= half => d.add_wrapping(I::Native::ONE),
+ false if r <= half_neg => d.sub_wrapping(I::Native::ONE),
+ _ => d,
};
+ O::Native::from_decimal(adjusted)
+ };
- match cast_options.safe {
- true => array.unary_opt(f),
- false => array.try_unary(|x| f(x).ok_or_else(|| error(x)))?,
- }
- } else {
- let mul = O::Native::from_decimal(10_i128)
- .unwrap()
- .pow_checked((output_scale - input_scale) as u32)?;
+ Ok(match cast_options.safe {
+ true => array.unary_opt(f),
+ false => array.try_unary(|x| f(x).ok_or_else(|| error(x)))?,
+ })
+}
- let f = |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok());
+fn convert_to_bigger_or_equal_scale_decimal<I, O>(
+ array: &PrimitiveArray<I>,
+ input_scale: i8,
+ output_precision: u8,
+ output_scale: i8,
+ cast_options: &CastOptions,
+) -> Result<PrimitiveArray<O>, ArrowError>
+where
+ I: DecimalType,
+ O: DecimalType,
+ I::Native: DecimalCast + ArrowNativeTypeOp,
+ O::Native: DecimalCast + ArrowNativeTypeOp,
+{
+ let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale);
+ let mul = O::Native::from_decimal(10_i128)
+ .unwrap()
+ .pow_checked((output_scale - input_scale) as u32)?;
- match cast_options.safe {
- true => array.unary_opt(f),
- false => array.try_unary(|x| f(x).ok_or_else(|| error(x)))?,
+ let f = |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok());
+
+ Ok(match cast_options.safe {
+ true => array.unary_opt(f),
+ false => array.try_unary(|x| f(x).ok_or_else(|| error(x)))?,
+ })
+}
+
+// Only supports casts where the input and output decimal types are the same
+fn cast_decimal_to_decimal_same_type<T>(
+ array: &PrimitiveArray<T>,
+ input_scale: i8,
+ output_precision: u8,
+ output_scale: i8,
+ cast_options: &CastOptions,
+) -> Result<ArrayRef, ArrowError>
+where
+ T: DecimalType,
+ T::Native: DecimalCast + ArrowNativeTypeOp,
+{
+ let array: PrimitiveArray<T> = match input_scale.cmp(&output_scale) {
+ Ordering::Equal => {
+ // the scale doesn't change, so the native values don't need to be changed
+ array.clone()
}
+ Ordering::Greater => convert_to_smaller_scale_decimal::<T, T>(
+ array,
+ input_scale,
+ output_precision,
+ output_scale,
+ cast_options,
+ )?,
+ Ordering::Less => {
+ // input_scale < output_scale
+ convert_to_bigger_or_equal_scale_decimal::<T, T>(
+ array,
+ input_scale,
+ output_precision,
+ output_scale,
+ cast_options,
+ )?
+ }
+ };
+
+ Ok(Arc::new(array.with_precision_and_scale(
+ output_precision,
+ output_scale,
+ )?))
+}
+
+// Supports casts between two (possibly different) decimal types
+fn cast_decimal_to_decimal<I, O>(
+ array: &PrimitiveArray<I>,
+ input_scale: i8,
+ output_precision: u8,
+ output_scale: i8,
+ cast_options: &CastOptions,
+) -> Result<ArrayRef, ArrowError>
+where
+ I: DecimalType,
+ O: DecimalType,
+ I::Native: DecimalCast + ArrowNativeTypeOp,
+ O::Native: DecimalCast + ArrowNativeTypeOp,
+{
+ let array: PrimitiveArray<O> = if input_scale > output_scale {
+ convert_to_smaller_scale_decimal::<I, O>(
+ array,
+ input_scale,
+ output_precision,
+ output_scale,
+ cast_options,
+ )?
+ } else {
+ convert_to_bigger_or_equal_scale_decimal::<I, O>(
+ array,
+ input_scale,
+ output_precision,
+ output_scale,
+ cast_options,
+ )?
};
Ok(Arc::new(array.with_precision_and_scale(
@@ -7821,7 +7921,7 @@ mod tests {
Decimal128Type::format_decimal(
parse_string_to_decimal_native::<Decimal128Type>("12345", 2).unwrap(),
38,
- 2
+ 2,
),
"12345.00"
);
@@ -7829,7 +7929,7 @@ mod tests {
Decimal128Type::format_decimal(
parse_string_to_decimal_native::<Decimal128Type>("0.12345", 2).unwrap(),
38,
- 2
+ 2,
),
"0.12"
);
@@ -7837,7 +7937,7 @@ mod tests {
Decimal128Type::format_decimal(
parse_string_to_decimal_native::<Decimal128Type>(".12345", 2).unwrap(),
38,
- 2
+ 2,
),
"0.12"
);
@@ -7845,7 +7945,7 @@ mod tests {
Decimal128Type::format_decimal(
parse_string_to_decimal_native::<Decimal128Type>(".1265", 2).unwrap(),
38,
- 2
+ 2,
),
"0.13"
);
@@ -7853,7 +7953,7 @@ mod tests {
Decimal128Type::format_decimal(
parse_string_to_decimal_native::<Decimal128Type>(".1265", 2).unwrap(),
38,
- 2
+ 2,
),
"0.13"
);
@@ -7862,7 +7962,7 @@ mod tests {
Decimal256Type::format_decimal(
parse_string_to_decimal_native::<Decimal256Type>("123.45", 3).unwrap(),
38,
- 3
+ 3,
),
"123.450"
);
@@ -7870,7 +7970,7 @@ mod tests {
Decimal256Type::format_decimal(
parse_string_to_decimal_native::<Decimal256Type>("12345", 3).unwrap(),
38,
- 3
+ 3,
),
"12345.000"
);
@@ -7878,7 +7978,7 @@ mod tests {
Decimal256Type::format_decimal(
parse_string_to_decimal_native::<Decimal256Type>("0.12345", 3).unwrap(),
38,
- 3
+ 3,
),
"0.123"
);
@@ -7886,7 +7986,7 @@ mod tests {
Decimal256Type::format_decimal(
parse_string_to_decimal_native::<Decimal256Type>(".12345", 3).unwrap(),
38,
- 3
+ 3,
),
"0.123"
);
@@ -7894,7 +7994,7 @@ mod tests {
Decimal256Type::format_decimal(
parse_string_to_decimal_native::<Decimal256Type>(".1265", 3).unwrap(),
38,
- 3
+ 3,
),
"0.127"
);
diff --git a/arrow/benches/cast_kernels.rs b/arrow/benches/cast_kernels.rs
index 7ef4d1d7e..933ddd4a0 100644
--- a/arrow/benches/cast_kernels.rs
+++ b/arrow/benches/cast_kernels.rs
@@ -230,6 +230,13 @@ fn add_benchmark(c: &mut Criterion) {
c.bench_function("cast decimal256 to decimal256 512", |b| {
b.iter(|| cast_array(&decimal256_array, DataType::Decimal256(50, 5)))
});
+
+ c.bench_function("cast decimal128 to decimal128 512 with same scale", |b| {
+ b.iter(|| cast_array(&decimal128_array, DataType::Decimal128(30, 3)))
+ });
+ c.bench_function("cast decimal256 to decimal256 512 with same scale", |b| {
+ b.iter(|| cast_array(&decimal256_array, DataType::Decimal256(60, 3)))
+ });
}
criterion_group!(benches, add_benchmark);