You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by vi...@apache.org on 2023/04/30 07:49:06 UTC

[arrow-rs] branch master updated: Support fixed point multiplication for DictionaryArray of Decimals (#4136)

This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 08dc16c96 Support fixed point multiplication for DictionaryArray of Decimals (#4136)
08dc16c96 is described below

commit 08dc16c9645def758e59651bed55f5aa95e2f42e
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Sun Apr 30 00:49:00 2023 -0700

    Support fixed point multiplication for DictionaryArray of Decimals (#4136)
    
    * Add multiply_fixed_point_dyn
    
    * Fix clippy
    
    * For review
---
 arrow-arith/src/arithmetic.rs | 252 +++++++++++++++++++++++++++++++++++++-----
 1 file changed, 222 insertions(+), 30 deletions(-)

diff --git a/arrow-arith/src/arithmetic.rs b/arrow-arith/src/arithmetic.rs
index 7f5a08190..40ae3255b 100644
--- a/arrow-arith/src/arithmetic.rs
+++ b/arrow-arith/src/arithmetic.rs
@@ -1434,6 +1434,114 @@ pub fn multiply_dyn_checked(
     }
 }
 
+#[cfg(feature = "dyn_arith_dict")]
+fn get_precision_scale(dt: &DataType) -> Result<(u8, i8), ArrowError> {
+    match dt {
+        DataType::Decimal128(precision, scale) => Ok((*precision, *scale)),
+        _ => Err(ArrowError::ComputeError(
+            "Cannot get precision and scale from non-decimal type".to_string(),
+        )),
+    }
+}
+
+/// Returns the precision and the product scale (the sum of the two input scales) of a multiplication
+/// of two decimal types, and the power-of-ten divisor that rescales the product to `required_scale`.
+fn get_fixed_point_info(
+    left: (u8, i8),
+    right: (u8, i8),
+    required_scale: i8,
+) -> Result<(u8, i8, i256), ArrowError> {
+    let product_scale = left.1 + right.1;
+    let precision = min(left.0 + right.0 + 1, DECIMAL128_MAX_PRECISION);
+
+    if required_scale > product_scale {
+        return Err(ArrowError::ComputeError(format!(
+            "Required scale {} is greater than product scale {}",
+            required_scale, product_scale
+        )));
+    }
+
+    let divisor =
+        i256::from_i128(10).pow_wrapping((product_scale - required_scale) as u32);
+
+    Ok((precision, product_scale, divisor))
+}
+
+#[cfg(feature = "dyn_arith_dict")]
+/// Perform `left * right` operation on two decimal arrays. If either left or right value is
+/// null then the result is also null.
+///
+/// This performs decimal multiplication which allows precision loss if an exact representation
+/// is not possible for the result, according to the required scale. In that case, the result
+/// will be rounded to the required scale.
+///
+/// If the required scale is greater than the product scale, an error is returned.
+///
+/// This doesn't detect overflow. Once overflowing, the result will wrap around.
+///
+/// It is implemented for compatibility with the precision-loss `multiply` function provided by
+/// other data processing engines. For multiplication with precision loss detection, use
+/// `multiply_dyn` or `multiply_dyn_checked` instead.
+pub fn multiply_fixed_point_dyn(
+    left: &dyn Array,
+    right: &dyn Array,
+    required_scale: i8,
+) -> Result<ArrayRef, ArrowError> {
+    match (left.data_type(), right.data_type()) {
+        (
+            DataType::Dictionary(_, lhs_value_type),
+            DataType::Dictionary(_, rhs_value_type),
+        ) if matches!(lhs_value_type.as_ref(), &DataType::Decimal128(_, _))
+            && matches!(rhs_value_type.as_ref(), &DataType::Decimal128(_, _)) =>
+        {
+            downcast_dictionary_array!(
+                left => match left.values().data_type() {
+                    DataType::Decimal128(_, _) => {
+                        let lhs_precision_scale = get_precision_scale(lhs_value_type.as_ref())?;
+                        let rhs_precision_scale = get_precision_scale(rhs_value_type.as_ref())?;
+
+                        let (precision, product_scale, divisor) = get_fixed_point_info(lhs_precision_scale, rhs_precision_scale, required_scale)?;
+
+                        let right = as_dictionary_array::<_>(right);
+
+                        if required_scale == product_scale {
+                            let mul = multiply_dyn(left, right)?;
+                            let array = mul.as_any().downcast_ref::<Decimal128Array>().unwrap();
+                            let array = array.clone().with_precision_and_scale(precision, required_scale)?;
+                            return Ok(Arc::new(array))
+                        }
+
+                        let array = math_op_dict::<_, Decimal128Type, _>(left, right, |a, b| {
+                            let a = i256::from_i128(a);
+                            let b = i256::from_i128(b);
+
+                            let mut mul = a.wrapping_mul(b);
+                            mul = divide_and_round::<Decimal256Type>(mul, divisor);
+                            mul.as_i128()
+                        }).and_then(|a| a.with_precision_and_scale(precision, required_scale))?;
+
+                        Ok(Arc::new(array))
+                    }
+                    t => unreachable!("Unsupported dictionary value type {}", t),
+                },
+                t => unreachable!("Unsupported data type {}", t),
+            )
+        }
+        (DataType::Decimal128(_, _), DataType::Decimal128(_, _)) => {
+            let left = left.as_any().downcast_ref::<Decimal128Array>().unwrap();
+            let right = right.as_any().downcast_ref::<Decimal128Array>().unwrap();
+
+            multiply_fixed_point(left, right, required_scale)
+                .map(|a| Arc::new(a) as ArrayRef)
+        }
+        (_, _) => Err(ArrowError::CastError(format!(
+            "Unsupported data type {}, {}",
+            left.data_type(),
+            right.data_type()
+        ))),
+    }
+}
+
 /// Perform `left * right` operation on two decimal arrays. If either left or right value is
 /// null then the result is also null.
 ///
@@ -1451,27 +1559,17 @@ pub fn multiply_fixed_point_checked(
     right: &PrimitiveArray<Decimal128Type>,
     required_scale: i8,
 ) -> Result<PrimitiveArray<Decimal128Type>, ArrowError> {
-    let product_scale = left.scale() + right.scale();
-    let precision = min(
-        left.precision() + right.precision() + 1,
-        DECIMAL128_MAX_PRECISION,
-    );
+    let (precision, product_scale, divisor) = get_fixed_point_info(
+        (left.precision(), left.scale()),
+        (right.precision(), right.scale()),
+        required_scale,
+    )?;
 
     if required_scale == product_scale {
         return multiply_checked(left, right)?
             .with_precision_and_scale(precision, required_scale);
     }
 
-    if required_scale > product_scale {
-        return Err(ArrowError::ComputeError(format!(
-            "Required scale {} is greater than product scale {}",
-            required_scale, product_scale
-        )));
-    }
-
-    let divisor =
-        i256::from_i128(10).pow_wrapping((product_scale - required_scale) as u32);
-
     try_binary::<_, _, _, Decimal128Type>(left, right, |a, b| {
         let a = i256::from_i128(a);
         let b = i256::from_i128(b);
@@ -1505,27 +1603,17 @@ pub fn multiply_fixed_point(
     right: &PrimitiveArray<Decimal128Type>,
     required_scale: i8,
 ) -> Result<PrimitiveArray<Decimal128Type>, ArrowError> {
-    let product_scale = left.scale() + right.scale();
-    let precision = min(
-        left.precision() + right.precision() + 1,
-        DECIMAL128_MAX_PRECISION,
-    );
+    let (precision, product_scale, divisor) = get_fixed_point_info(
+        (left.precision(), left.scale()),
+        (right.precision(), right.scale()),
+        required_scale,
+    )?;
 
     if required_scale == product_scale {
         return multiply(left, right)?
             .with_precision_and_scale(precision, required_scale);
     }
 
-    if required_scale > product_scale {
-        return Err(ArrowError::ComputeError(format!(
-            "Required scale {} is greater than product scale {}",
-            required_scale, product_scale
-        )));
-    }
-
-    let divisor =
-        i256::from_i128(10).pow_wrapping((product_scale - required_scale) as u32);
-
     binary::<_, _, _, Decimal128Type>(left, right, |a, b| {
         let a = i256::from_i128(a);
         let b = i256::from_i128(b);
@@ -3910,6 +3998,110 @@ mod tests {
         );
     }
 
+    #[test]
+    #[cfg(feature = "dyn_arith_dict")]
+    fn test_decimal_multiply_fixed_point_dyn() {
+        // [123456789]
+        let a = Decimal128Array::from(vec![123456789000000000000000000])
+            .with_precision_and_scale(38, 18)
+            .unwrap();
+
+        // [10]
+        let b = Decimal128Array::from(vec![10000000000000000000])
+            .with_precision_and_scale(38, 18)
+            .unwrap();
+
+        // Avoid overflow by reducing the scale.
+        let result = multiply_fixed_point_dyn(&a, &b, 28).unwrap();
+        // [1234567890]
+        let expected = Arc::new(
+            Decimal128Array::from(vec![12345678900000000000000000000000000000])
+                .with_precision_and_scale(38, 28)
+                .unwrap(),
+        ) as ArrayRef;
+
+        assert_eq!(&expected, &result);
+        assert_eq!(
+            result.as_primitive::<Decimal128Type>().value_as_string(0),
+            "1234567890.0000000000000000000000000000"
+        );
+
+        // [123456789, 10]
+        let a = Decimal128Array::from(vec![
+            123456789000000000000000000,
+            10000000000000000000,
+        ])
+        .with_precision_and_scale(38, 18)
+        .unwrap();
+
+        // [10, 123456789, 12]
+        let b = Decimal128Array::from(vec![
+            10000000000000000000,
+            123456789000000000000000000,
+            12000000000000000000,
+        ])
+        .with_precision_and_scale(38, 18)
+        .unwrap();
+
+        let keys = Int8Array::from(vec![Some(0_i8), Some(1), Some(1), None]);
+        let array1 = DictionaryArray::new(keys, Arc::new(a));
+        let keys = Int8Array::from(vec![Some(0_i8), Some(1), Some(2), None]);
+        let array2 = DictionaryArray::new(keys, Arc::new(b));
+
+        let result = multiply_fixed_point_dyn(&array1, &array2, 28).unwrap();
+        let expected = Arc::new(
+            Decimal128Array::from(vec![
+                Some(12345678900000000000000000000000000000),
+                Some(12345678900000000000000000000000000000),
+                Some(1200000000000000000000000000000),
+                None,
+            ])
+            .with_precision_and_scale(38, 28)
+            .unwrap(),
+        ) as ArrayRef;
+
+        assert_eq!(&expected, &result);
+        assert_eq!(
+            result.as_primitive::<Decimal128Type>().value_as_string(0),
+            "1234567890.0000000000000000000000000000"
+        );
+        assert_eq!(
+            result.as_primitive::<Decimal128Type>().value_as_string(1),
+            "1234567890.0000000000000000000000000000"
+        );
+        assert_eq!(
+            result.as_primitive::<Decimal128Type>().value_as_string(2),
+            "120.0000000000000000000000000000"
+        );
+
+        // Required scale is the same as the product scale (the sum of the input scales). Behavior is the same as multiply_dyn.
+        let a = Decimal128Array::from(vec![123, 100])
+            .with_precision_and_scale(3, 2)
+            .unwrap();
+
+        let b = Decimal128Array::from(vec![100, 123, 120])
+            .with_precision_and_scale(3, 2)
+            .unwrap();
+
+        let keys = Int8Array::from(vec![Some(0_i8), Some(1), Some(1), None]);
+        let array1 = DictionaryArray::new(keys, Arc::new(a));
+        let keys = Int8Array::from(vec![Some(0_i8), Some(1), Some(2), None]);
+        let array2 = DictionaryArray::new(keys, Arc::new(b));
+
+        let result = multiply_fixed_point_dyn(&array1, &array2, 4).unwrap();
+        let expected = multiply_dyn(&array1, &array2).unwrap();
+        let expected = Arc::new(
+            expected
+                .as_any()
+                .downcast_ref::<Decimal128Array>()
+                .unwrap()
+                .clone()
+                .with_precision_and_scale(7, 4)
+                .unwrap(),
+        ) as ArrayRef;
+        assert_eq!(&expected, &result);
+    }
+
     #[test]
     fn test_timestamp_second_add_interval() {
         // timestamp second + interval year month