You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by vi...@apache.org on 2023/04/30 07:49:06 UTC
[arrow-rs] branch master updated: Support fixed point multiplication for DictionaryArray of Decimals (#4136)
This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 08dc16c96 Support fixed point multiplication for DictionaryArray of Decimals (#4136)
08dc16c96 is described below
commit 08dc16c9645def758e59651bed55f5aa95e2f42e
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Sun Apr 30 00:49:00 2023 -0700
Support fixed point multiplication for DictionaryArray of Decimals (#4136)
* Add multiply_fixed_point_dyn
* Fix clippy
* For review
---
arrow-arith/src/arithmetic.rs | 252 +++++++++++++++++++++++++++++++++++++-----
1 file changed, 222 insertions(+), 30 deletions(-)
diff --git a/arrow-arith/src/arithmetic.rs b/arrow-arith/src/arithmetic.rs
index 7f5a08190..40ae3255b 100644
--- a/arrow-arith/src/arithmetic.rs
+++ b/arrow-arith/src/arithmetic.rs
@@ -1434,6 +1434,114 @@ pub fn multiply_dyn_checked(
}
}
+#[cfg(feature = "dyn_arith_dict")] // only compiled when the `dyn_arith_dict` feature is enabled
+fn get_precision_scale(dt: &DataType) -> Result<(u8, i8), ArrowError> { // returns `(precision, scale)` of a `Decimal128` data type
+ match dt {
+ DataType::Decimal128(precision, scale) => Ok((*precision, *scale)),
+ _ => Err(ArrowError::ComputeError( // every data type other than `Decimal128` is rejected
+ "Cannot get precision and scale from non-decimal type".to_string(),
+ )),
+ }
+}
+
+/// Returns `(precision, product_scale, divisor)` for multiplying two decimal types given as `(precision, scale)` pairs:
+/// the result precision, the scale of the raw product, and the divisor for fixed point multiplication. Errors if `required_scale > product_scale`.
+fn get_fixed_point_info(
+ left: (u8, i8),
+ right: (u8, i8),
+ required_scale: i8,
+) -> Result<(u8, i8, i256), ArrowError> {
+ let product_scale = left.1 + right.1; // scale of the raw product is the sum of the input scales
+ let precision = min(left.0 + right.0 + 1, DECIMAL128_MAX_PRECISION); // sum of precisions plus one, capped at the Decimal128 maximum
+
+ if required_scale > product_scale { // the divisor exponent below would be negative, so this is rejected
+ return Err(ArrowError::ComputeError(format!(
+ "Required scale {} is greater than product scale {}",
+ required_scale, product_scale
+ )));
+ }
+
+ let divisor =
+ i256::from_i128(10).pow_wrapping((product_scale - required_scale) as u32); // 10^(product_scale - required_scale); exponent is non-negative after the check above
+
+ Ok((precision, product_scale, divisor))
+}
+
+#[cfg(feature = "dyn_arith_dict")]
+/// Perform `left * right` on two `Decimal128` arrays, or on two dictionary arrays with `Decimal128` values. If either left or right value is
+/// null then the result is also null.
+///
+/// This performs decimal multiplication which allows precision loss if an exact representation
+/// is not possible for the result, according to the required scale. In that case, the result
+/// will be rounded to the required scale.
+///
+/// If the required scale is greater than the product scale (the sum of the input scales), or the inputs are any other combination of data types, an error is returned.
+///
+/// This doesn't detect overflow. Once overflowing, the result will wrap around.
+///
+/// It is implemented for compatibility with the precision-loss `multiply` function provided by
+/// other data processing engines. For multiplication with precision loss detection, use
+/// `multiply_dyn` or `multiply_dyn_checked` instead.
+pub fn multiply_fixed_point_dyn(
+ left: &dyn Array,
+ right: &dyn Array,
+ required_scale: i8,
+) -> Result<ArrayRef, ArrowError> {
+ match (left.data_type(), right.data_type()) {
+ (
+ DataType::Dictionary(_, lhs_value_type),
+ DataType::Dictionary(_, rhs_value_type),
+ ) if matches!(lhs_value_type.as_ref(), &DataType::Decimal128(_, _))
+ && matches!(rhs_value_type.as_ref(), &DataType::Decimal128(_, _)) => // both sides are dictionaries whose values are Decimal128
+ {
+ downcast_dictionary_array!(
+ left => match left.values().data_type() {
+ DataType::Decimal128(_, _) => {
+ let lhs_precision_scale = get_precision_scale(lhs_value_type.as_ref())?;
+ let rhs_precision_scale = get_precision_scale(rhs_value_type.as_ref())?;
+
+ let (precision, product_scale, divisor) = get_fixed_point_info(lhs_precision_scale, rhs_precision_scale, required_scale)?;
+
+ let right = as_dictionary_array::<_>(right); // NOTE(review): key type is inferred from `left`; presumably `right` must use the same key type — confirm
+
+ if required_scale == product_scale { // no rescaling needed: delegate to `multiply_dyn` and only re-tag precision/scale
+ let mul = multiply_dyn(left, right)?;
+ let array = mul.as_any().downcast_ref::<Decimal128Array>().unwrap(); // NOTE(review): assumes `multiply_dyn` returns a `Decimal128Array` for these inputs — confirm
+ let array = array.clone().with_precision_and_scale(precision, required_scale)?;
+ return Ok(Arc::new(array))
+ }
+
+ let array = math_op_dict::<_, Decimal128Type, _>(left, right, |a, b| {
+ let a = i256::from_i128(a); // widen both operands so the i128 x i128 product fits in i256
+ let b = i256::from_i128(b);
+
+ let mut mul = a.wrapping_mul(b);
+ mul = divide_and_round::<Decimal256Type>(mul, divisor); // rescale from product_scale to required_scale (rounding, per the helper's name)
+ mul.as_i128() // narrow back to i128; overflow is not detected here (see the doc comment)
+ }).and_then(|a| a.with_precision_and_scale(precision, required_scale))?;
+
+ Ok(Arc::new(array))
+ }
+ t => unreachable!("Unsupported dictionary value type {}", t), // excluded by the `matches!` guard on this arm
+ },
+ t => unreachable!("Unsupported data type {}", t), // `left` was matched as a dictionary type above
+ )
+ }
+ (DataType::Decimal128(_, _), DataType::Decimal128(_, _)) => { // plain decimal arrays: delegate to `multiply_fixed_point`
+ let left = left.as_any().downcast_ref::<Decimal128Array>().unwrap();
+ let right = right.as_any().downcast_ref::<Decimal128Array>().unwrap();
+
+ multiply_fixed_point(left, right, required_scale)
+ .map(|a| Arc::new(a) as ArrayRef)
+ }
+ (_, _) => Err(ArrowError::CastError(format!( // any other combination, including dictionary with non-dictionary, is rejected
+ "Unsupported data type {}, {}",
+ left.data_type(),
+ right.data_type()
+ ))),
+ }
+}
+
/// Perform `left * right` operation on two decimal arrays. If either left or right value is
/// null then the result is also null.
///
@@ -1451,27 +1559,17 @@ pub fn multiply_fixed_point_checked(
right: &PrimitiveArray<Decimal128Type>,
required_scale: i8,
) -> Result<PrimitiveArray<Decimal128Type>, ArrowError> {
- let product_scale = left.scale() + right.scale();
- let precision = min(
- left.precision() + right.precision() + 1,
- DECIMAL128_MAX_PRECISION,
- );
+ let (precision, product_scale, divisor) = get_fixed_point_info(
+ (left.precision(), left.scale()),
+ (right.precision(), right.scale()),
+ required_scale,
+ )?;
if required_scale == product_scale {
return multiply_checked(left, right)?
.with_precision_and_scale(precision, required_scale);
}
- if required_scale > product_scale {
- return Err(ArrowError::ComputeError(format!(
- "Required scale {} is greater than product scale {}",
- required_scale, product_scale
- )));
- }
-
- let divisor =
- i256::from_i128(10).pow_wrapping((product_scale - required_scale) as u32);
-
try_binary::<_, _, _, Decimal128Type>(left, right, |a, b| {
let a = i256::from_i128(a);
let b = i256::from_i128(b);
@@ -1505,27 +1603,17 @@ pub fn multiply_fixed_point(
right: &PrimitiveArray<Decimal128Type>,
required_scale: i8,
) -> Result<PrimitiveArray<Decimal128Type>, ArrowError> {
- let product_scale = left.scale() + right.scale();
- let precision = min(
- left.precision() + right.precision() + 1,
- DECIMAL128_MAX_PRECISION,
- );
+ let (precision, product_scale, divisor) = get_fixed_point_info(
+ (left.precision(), left.scale()),
+ (right.precision(), right.scale()),
+ required_scale,
+ )?;
if required_scale == product_scale {
return multiply(left, right)?
.with_precision_and_scale(precision, required_scale);
}
- if required_scale > product_scale {
- return Err(ArrowError::ComputeError(format!(
- "Required scale {} is greater than product scale {}",
- required_scale, product_scale
- )));
- }
-
- let divisor =
- i256::from_i128(10).pow_wrapping((product_scale - required_scale) as u32);
-
binary::<_, _, _, Decimal128Type>(left, right, |a, b| {
let a = i256::from_i128(a);
let b = i256::from_i128(b);
@@ -3910,6 +3998,110 @@ mod tests {
);
}
+ #[test]
+ #[cfg(feature = "dyn_arith_dict")]
+ fn test_decimal_multiply_fixed_point_dyn() {
+ // [123456789] stored at scale 18
+ let a = Decimal128Array::from(vec![123456789000000000000000000])
+ .with_precision_and_scale(38, 18)
+ .unwrap();
+
+ // [10] stored at scale 18
+ let b = Decimal128Array::from(vec![10000000000000000000])
+ .with_precision_and_scale(38, 18)
+ .unwrap();
+
+ // Avoid overflow by reducing the scale from the product scale 36 (18 + 18) to 28.
+ let result = multiply_fixed_point_dyn(&a, &b, 28).unwrap();
+ // [1234567890] stored at scale 28
+ let expected = Arc::new(
+ Decimal128Array::from(vec![12345678900000000000000000000000000000])
+ .with_precision_and_scale(38, 28)
+ .unwrap(),
+ ) as ArrayRef;
+
+ assert_eq!(&expected, &result);
+ assert_eq!(
+ result.as_primitive::<Decimal128Type>().value_as_string(0),
+ "1234567890.0000000000000000000000000000"
+ );
+
+ // Dictionary values for the left side: [123456789, 10]
+ let a = Decimal128Array::from(vec![
+ 123456789000000000000000000,
+ 10000000000000000000,
+ ])
+ .with_precision_and_scale(38, 18)
+ .unwrap();
+
+ // Dictionary values for the right side: [10, 123456789, 12]
+ let b = Decimal128Array::from(vec![
+ 10000000000000000000,
+ 123456789000000000000000000,
+ 12000000000000000000,
+ ])
+ .with_precision_and_scale(38, 18)
+ .unwrap();
+
+ let keys = Int8Array::from(vec![Some(0_i8), Some(1), Some(1), None]); // left rows: 123456789, 10, 10, null
+ let array1 = DictionaryArray::new(keys, Arc::new(a));
+ let keys = Int8Array::from(vec![Some(0_i8), Some(1), Some(2), None]); // right rows: 10, 123456789, 12, null
+ let array2 = DictionaryArray::new(keys, Arc::new(b));
+
+ let result = multiply_fixed_point_dyn(&array1, &array2, 28).unwrap();
+ let expected = Arc::new(
+ Decimal128Array::from(vec![
+ Some(12345678900000000000000000000000000000), // 123456789 * 10
+ Some(12345678900000000000000000000000000000), // 10 * 123456789
+ Some(1200000000000000000000000000000), // 10 * 12
+ None, // null keys on both sides give a null result
+ ])
+ .with_precision_and_scale(38, 28)
+ .unwrap(),
+ ) as ArrayRef;
+
+ assert_eq!(&expected, &result);
+ assert_eq!(
+ result.as_primitive::<Decimal128Type>().value_as_string(0),
+ "1234567890.0000000000000000000000000000"
+ );
+ assert_eq!(
+ result.as_primitive::<Decimal128Type>().value_as_string(1),
+ "1234567890.0000000000000000000000000000"
+ );
+ assert_eq!(
+ result.as_primitive::<Decimal128Type>().value_as_string(2),
+ "120.0000000000000000000000000000"
+ );
+
+ // Required scale equals the product scale, i.e. the sum of the input scales (2 + 2 = 4). Behavior is same as multiply_dyn.
+ let a = Decimal128Array::from(vec![123, 100])
+ .with_precision_and_scale(3, 2)
+ .unwrap();
+
+ let b = Decimal128Array::from(vec![100, 123, 120])
+ .with_precision_and_scale(3, 2)
+ .unwrap();
+
+ let keys = Int8Array::from(vec![Some(0_i8), Some(1), Some(1), None]);
+ let array1 = DictionaryArray::new(keys, Arc::new(a));
+ let keys = Int8Array::from(vec![Some(0_i8), Some(1), Some(2), None]);
+ let array2 = DictionaryArray::new(keys, Arc::new(b));
+
+ let result = multiply_fixed_point_dyn(&array1, &array2, 4).unwrap();
+ let expected = multiply_dyn(&array1, &array2).unwrap(); // reference values come from the plain dynamic multiply
+ let expected = Arc::new(
+ expected
+ .as_any()
+ .downcast_ref::<Decimal128Array>()
+ .unwrap()
+ .clone()
+ .with_precision_and_scale(7, 4) // precision 3 + 3 + 1 = 7, scale 2 + 2 = 4
+ .unwrap(),
+ ) as ArrayRef;
+ assert_eq!(&expected, &result);
+ }
+
#[test]
fn test_timestamp_second_add_interval() {
// timestamp second + interval year month