You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by vi...@apache.org on 2022/11/04 08:07:14 UTC
[arrow-rs] branch master updated: Check overflow when casting integer to decimal (#3009)
This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 766f69f71 Check overflow when casting integer to decimal (#3009)
766f69f71 is described below
commit 766f69f715faa619077cc5458aef955b627af715
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Fri Nov 4 01:07:08 2022 -0700
Check overflow when casting integer to decimal (#3009)
* Check overflow when casting integer to decimal
* Trigger Build
* Combine cast_integer_to_decimal functions of decimal128 and decimal256
* Fix clippy
* Trigger Build
* Use PREFIX way.
---
arrow-array/src/types.rs | 5 ++
arrow/src/compute/kernels/cast.rs | 130 +++++++++++++++++++++++++++-----------
arrow/src/datatypes/native.rs | 23 +++++++
3 files changed, 122 insertions(+), 36 deletions(-)
diff --git a/arrow-array/src/types.rs b/arrow-array/src/types.rs
index e6197eed1..7c7a5c811 100644
--- a/arrow-array/src/types.rs
+++ b/arrow-array/src/types.rs
@@ -495,6 +495,9 @@ pub trait DecimalType:
const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType;
const DEFAULT_TYPE: DataType;
+ /// "Decimal128" or "Decimal256", for use in error messages
+ const PREFIX: &'static str;
+
/// Formats the decimal value with the provided precision and scale
fn format_decimal(value: Self::Native, precision: u8, scale: u8) -> String;
@@ -516,6 +519,7 @@ impl DecimalType for Decimal128Type {
const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType = DataType::Decimal128;
const DEFAULT_TYPE: DataType =
DataType::Decimal128(DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE);
+ const PREFIX: &'static str = "Decimal128";
fn format_decimal(value: Self::Native, precision: u8, scale: u8) -> String {
format_decimal_str(&value.to_string(), precision as usize, scale as usize)
@@ -543,6 +547,7 @@ impl DecimalType for Decimal256Type {
const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType = DataType::Decimal256;
const DEFAULT_TYPE: DataType =
DataType::Decimal256(DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE);
+ const PREFIX: &'static str = "Decimal256";
fn format_decimal(value: Self::Native, precision: u8, scale: u8) -> String {
format_decimal_str(&value.to_string(), precision as usize, scale as usize)
diff --git a/arrow/src/compute/kernels/cast.rs b/arrow/src/compute/kernels/cast.rs
index 4ad8dd99e..b1e744d26 100644
--- a/arrow/src/compute/kernels/cast.rs
+++ b/arrow/src/compute/kernels/cast.rs
@@ -309,41 +309,43 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result<ArrayRef> {
cast_with_options(array, to_type, &DEFAULT_CAST_OPTIONS)
}
-fn cast_integer_to_decimal128<T: ArrowNumericType>(
- array: &PrimitiveArray<T>,
- precision: u8,
- scale: u8,
-) -> Result<ArrayRef>
-where
- <T as ArrowPrimitiveType>::Native: AsPrimitive<i128>,
-{
- let mul: i128 = 10_i128.pow(scale as u32);
-
- unary::<T, _, Decimal128Type>(array, |v| v.as_() * mul)
- .with_precision_and_scale(precision, scale)
- .map(|a| Arc::new(a) as ArrayRef)
-}
-
-fn cast_integer_to_decimal256<T: ArrowNumericType>(
+fn cast_integer_to_decimal<
+ T: ArrowNumericType,
+ D: DecimalType + ArrowPrimitiveType<Native = M>,
+ M,
+>(
array: &PrimitiveArray<T>,
precision: u8,
scale: u8,
+ base: M,
+ cast_options: &CastOptions,
) -> Result<ArrayRef>
where
- <T as ArrowPrimitiveType>::Native: AsPrimitive<i256>,
+ <T as ArrowPrimitiveType>::Native: AsPrimitive<M>,
+ M: ArrowNativeTypeOp,
{
- let mul: i256 = i256::from_i128(10_i128)
- .checked_pow(scale as u32)
- .ok_or_else(|| {
- ArrowError::CastError(format!(
- "Cannot cast to Decimal256({}, {}). The scale causes overflow.",
- precision, scale
- ))
- })?;
+ let mul: M = base.pow_checked(scale as u32).map_err(|_| {
+ ArrowError::CastError(format!(
+ "Cannot cast to {:?}({}, {}). The scale causes overflow.",
+ D::PREFIX,
+ precision,
+ scale,
+ ))
+ })?;
- unary::<T, _, Decimal256Type>(array, |v| v.as_().wrapping_mul(mul))
- .with_precision_and_scale(precision, scale)
- .map(|a| Arc::new(a) as ArrayRef)
+ if cast_options.safe {
+ let iter = array
+ .iter()
+ .map(|v| v.and_then(|v| v.as_().mul_checked(mul).ok()));
+ let casted_array = unsafe { PrimitiveArray::<D>::from_trusted_len_iter(iter) };
+ casted_array
+ .with_precision_and_scale(precision, scale)
+ .map(|a| Arc::new(a) as ArrayRef)
+ } else {
+ try_unary::<T, _, D>(array, |v| v.as_().mul_checked(mul))
+ .and_then(|a| a.with_precision_and_scale(precision, scale))
+ .map(|a| Arc::new(a) as ArrayRef)
+ }
}
fn cast_floating_point_to_decimal128<T: ArrowNumericType>(
@@ -562,25 +564,33 @@ pub fn cast_with_options(
// cast data to decimal
match from_type {
// TODO now just support signed numeric to decimal, support decimal to numeric later
- Int8 => cast_integer_to_decimal128(
+ Int8 => cast_integer_to_decimal::<_, Decimal128Type, _>(
as_primitive_array::<Int8Type>(array),
*precision,
*scale,
+ 10_i128,
+ cast_options,
),
- Int16 => cast_integer_to_decimal128(
+ Int16 => cast_integer_to_decimal::<_, Decimal128Type, _>(
as_primitive_array::<Int16Type>(array),
*precision,
*scale,
+ 10_i128,
+ cast_options,
),
- Int32 => cast_integer_to_decimal128(
+ Int32 => cast_integer_to_decimal::<_, Decimal128Type, _>(
as_primitive_array::<Int32Type>(array),
*precision,
*scale,
+ 10_i128,
+ cast_options,
),
- Int64 => cast_integer_to_decimal128(
+ Int64 => cast_integer_to_decimal::<_, Decimal128Type, _>(
as_primitive_array::<Int64Type>(array),
*precision,
*scale,
+ 10_i128,
+ cast_options,
),
Float32 => cast_floating_point_to_decimal128(
as_primitive_array::<Float32Type>(array),
@@ -603,25 +613,33 @@ pub fn cast_with_options(
// cast data to decimal
match from_type {
// TODO now just support signed numeric to decimal, support decimal to numeric later
- Int8 => cast_integer_to_decimal256(
+ Int8 => cast_integer_to_decimal::<_, Decimal256Type, _>(
as_primitive_array::<Int8Type>(array),
*precision,
*scale,
+ i256::from_i128(10_i128),
+ cast_options,
),
- Int16 => cast_integer_to_decimal256(
+ Int16 => cast_integer_to_decimal::<_, Decimal256Type, _>(
as_primitive_array::<Int16Type>(array),
*precision,
*scale,
+ i256::from_i128(10_i128),
+ cast_options,
),
- Int32 => cast_integer_to_decimal256(
+ Int32 => cast_integer_to_decimal::<_, Decimal256Type, _>(
as_primitive_array::<Int32Type>(array),
*precision,
*scale,
+ i256::from_i128(10_i128),
+ cast_options,
),
- Int64 => cast_integer_to_decimal256(
+ Int64 => cast_integer_to_decimal::<_, Decimal256Type, _>(
as_primitive_array::<Int64Type>(array),
*precision,
*scale,
+ i256::from_i128(10_i128),
+ cast_options,
),
Float32 => cast_floating_point_to_decimal256(
as_primitive_array::<Float32Type>(array),
@@ -6049,4 +6067,44 @@ mod tests {
]
);
}
+
+ #[test]
+ fn test_cast_numeric_to_decimal128_overflow() {
+ let array = Int64Array::from(vec![i64::MAX]);
+ let array = Arc::new(array) as ArrayRef;
+ let casted_array = cast_with_options(
+ &array,
+ &DataType::Decimal128(38, 30),
+ &CastOptions { safe: true },
+ );
+ assert!(casted_array.is_ok());
+ assert!(casted_array.unwrap().is_null(0));
+
+ let casted_array = cast_with_options(
+ &array,
+ &DataType::Decimal128(38, 30),
+ &CastOptions { safe: false },
+ );
+ assert!(casted_array.is_err());
+ }
+
+ #[test]
+ fn test_cast_numeric_to_decimal256_overflow() {
+ let array = Int64Array::from(vec![i64::MAX]);
+ let array = Arc::new(array) as ArrayRef;
+ let casted_array = cast_with_options(
+ &array,
+ &DataType::Decimal256(76, 76),
+ &CastOptions { safe: true },
+ );
+ assert!(casted_array.is_ok());
+ assert!(casted_array.unwrap().is_null(0));
+
+ let casted_array = cast_with_options(
+ &array,
+ &DataType::Decimal256(76, 76),
+ &CastOptions { safe: false },
+ );
+ assert!(casted_array.is_err());
+ }
}
diff --git a/arrow/src/datatypes/native.rs b/arrow/src/datatypes/native.rs
index bbdec14b4..28ef877a2 100644
--- a/arrow/src/datatypes/native.rs
+++ b/arrow/src/datatypes/native.rs
@@ -19,6 +19,7 @@ use crate::error::{ArrowError, Result};
pub use arrow_array::ArrowPrimitiveType;
pub use arrow_buffer::{i256, ArrowNativeType, ToByteSlice};
use half::f16;
+use num::complex::ComplexFloat;
/// Trait for [`ArrowNativeType`] that adds checked and unchecked arithmetic operations,
/// and totally ordered comparison operations
@@ -68,6 +69,10 @@ pub trait ArrowNativeTypeOp: ArrowNativeType {
fn neg_wrapping(self) -> Self;
+ fn pow_checked(self, exp: u32) -> Result<Self>;
+
+ fn pow_wrapping(self, exp: u32) -> Self;
+
fn is_zero(self) -> bool;
fn is_eq(self, rhs: Self) -> bool;
@@ -171,6 +176,16 @@ macro_rules! native_type_op {
})
}
+ fn pow_checked(self, exp: u32) -> Result<Self> {
+ self.checked_pow(exp).ok_or_else(|| {
+ ArrowError::ComputeError(format!("Overflow happened on: {:?}", self))
+ })
+ }
+
+ fn pow_wrapping(self, exp: u32) -> Self {
+ self.wrapping_pow(exp)
+ }
+
fn neg_wrapping(self) -> Self {
self.wrapping_neg()
}
@@ -279,6 +294,14 @@ macro_rules! native_type_float_op {
-self
}
+ fn pow_checked(self, exp: u32) -> Result<Self> {
+ Ok(self.powi(exp as i32))
+ }
+
+ fn pow_wrapping(self, exp: u32) -> Self {
+ self.powi(exp as i32)
+ }
+
fn is_zero(self) -> bool {
self == $zero
}