You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by tu...@apache.org on 2023/01/10 13:59:08 UTC

[arrow-rs] branch master updated: Refactoring build_compare for decimal and using downcast_primitive (#3484)

This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new e8cc351af Refactoring build_compare for decimal and using downcast_primitive (#3484)
e8cc351af is described below

commit e8cc351af662515f7ff9e25b6eb1e609f89b6bc8
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Tue Jan 10 05:59:01 2023 -0800

    Refactoring build_compare for decimal and using downcast_primitive (#3484)
    
    * Refactor build_compare for decimal and add dict support
    
    * Simplify code using downcast_primitive
---
 arrow-ord/src/ord.rs | 162 +++++++++++++++++++++++++--------------------------
 1 file changed, 80 insertions(+), 82 deletions(-)

diff --git a/arrow-ord/src/ord.rs b/arrow-ord/src/ord.rs
index 6122f9cb3..b7737c6de 100644
--- a/arrow-ord/src/ord.rs
+++ b/arrow-ord/src/ord.rs
@@ -153,6 +153,12 @@ where
     })
 }
 
+macro_rules! cmp_dict_primitive_helper { // callback handed to `downcast_primitive!` in the dictionary arm of `build_compare`
+    ($t:ty, $key_type_lhs:expr, $left:expr, $right:expr) => { // $t: Arrow primitive type selected for the dictionary's value type
+        cmp_dict_primitive::<$t>($key_type_lhs, $left, $right)? // `?` propagates the error out of the function the macro expands in
+    };
+}
+
 /// returns a comparison function that compares two values at two different positions
 /// between the two arrays.
 /// The arrays' types must be equal.
@@ -193,6 +199,12 @@ pub fn build_compare(
         (Int64, Int64) => compare_primitives::<Int64Type>(left, right),
         (Float32, Float32) => compare_float::<Float32Type>(left, right),
         (Float64, Float64) => compare_float::<Float64Type>(left, right),
+        (Decimal128(_, _), Decimal128(_, _)) => {
+            compare_primitives::<Decimal128Type>(left, right)
+        }
+        (Decimal256(_, _), Decimal256(_, _)) => {
+            compare_primitives::<Decimal256Type>(left, right)
+        }
         (Date32, Date32) => compare_primitives::<Date32Type>(left, right),
         (Date64, Date64) => compare_primitives::<Date64Type>(left, right),
         (Time32(Second), Time32(Second)) => {
@@ -253,83 +265,8 @@ pub fn build_compare(
             }
 
             let key_type_lhs = key_type_lhs.as_ref();
-
-            match value_type_lhs.as_ref() {
-                Int8 => cmp_dict_primitive::<Int8Type>(key_type_lhs, left, right)?,
-                Int16 => cmp_dict_primitive::<Int16Type>(key_type_lhs, left, right)?,
-                Int32 => cmp_dict_primitive::<Int32Type>(key_type_lhs, left, right)?,
-                Int64 => cmp_dict_primitive::<Int64Type>(key_type_lhs, left, right)?,
-                UInt8 => cmp_dict_primitive::<UInt8Type>(key_type_lhs, left, right)?,
-                UInt16 => cmp_dict_primitive::<UInt16Type>(key_type_lhs, left, right)?,
-                UInt32 => cmp_dict_primitive::<UInt32Type>(key_type_lhs, left, right)?,
-                UInt64 => cmp_dict_primitive::<UInt64Type>(key_type_lhs, left, right)?,
-                Float32 => cmp_dict_primitive::<Float32Type>(key_type_lhs, left, right)?,
-                Float64 => cmp_dict_primitive::<Float64Type>(key_type_lhs, left, right)?,
-                Date32 => cmp_dict_primitive::<Date32Type>(key_type_lhs, left, right)?,
-                Date64 => cmp_dict_primitive::<Date64Type>(key_type_lhs, left, right)?,
-                Time32(Second) => {
-                    cmp_dict_primitive::<Time32SecondType>(key_type_lhs, left, right)?
-                }
-                Time32(Millisecond) => cmp_dict_primitive::<Time32MillisecondType>(
-                    key_type_lhs,
-                    left,
-                    right,
-                )?,
-                Time64(Microsecond) => cmp_dict_primitive::<Time64MicrosecondType>(
-                    key_type_lhs,
-                    left,
-                    right,
-                )?,
-                Time64(Nanosecond) => {
-                    cmp_dict_primitive::<Time64NanosecondType>(key_type_lhs, left, right)?
-                }
-                Timestamp(Second, _) => {
-                    cmp_dict_primitive::<TimestampSecondType>(key_type_lhs, left, right)?
-                }
-                Timestamp(Millisecond, _) => cmp_dict_primitive::<
-                    TimestampMillisecondType,
-                >(key_type_lhs, left, right)?,
-                Timestamp(Microsecond, _) => cmp_dict_primitive::<
-                    TimestampMicrosecondType,
-                >(key_type_lhs, left, right)?,
-                Timestamp(Nanosecond, _) => {
-                    cmp_dict_primitive::<TimestampNanosecondType>(
-                        key_type_lhs,
-                        left,
-                        right,
-                    )?
-                }
-                Interval(YearMonth) => cmp_dict_primitive::<IntervalYearMonthType>(
-                    key_type_lhs,
-                    left,
-                    right,
-                )?,
-                Interval(DayTime) => {
-                    cmp_dict_primitive::<IntervalDayTimeType>(key_type_lhs, left, right)?
-                }
-                Interval(MonthDayNano) => cmp_dict_primitive::<IntervalMonthDayNanoType>(
-                    key_type_lhs,
-                    left,
-                    right,
-                )?,
-                Duration(Second) => {
-                    cmp_dict_primitive::<DurationSecondType>(key_type_lhs, left, right)?
-                }
-                Duration(Millisecond) => cmp_dict_primitive::<DurationMillisecondType>(
-                    key_type_lhs,
-                    left,
-                    right,
-                )?,
-                Duration(Microsecond) => cmp_dict_primitive::<DurationMicrosecondType>(
-                    key_type_lhs,
-                    left,
-                    right,
-                )?,
-                Duration(Nanosecond) => cmp_dict_primitive::<DurationNanosecondType>(
-                    key_type_lhs,
-                    left,
-                    right,
-                )?,
+            downcast_primitive! {
+                value_type_lhs.as_ref() => (cmp_dict_primitive_helper, key_type_lhs, left, right),
                 Utf8 => match key_type_lhs {
                     UInt8 => compare_dict_string::<UInt8Type>(left, right),
                     UInt16 => compare_dict_string::<UInt16Type>(left, right),
@@ -354,11 +291,6 @@ pub fn build_compare(
                 }
             }
         }
-        (Decimal128(_, _), Decimal128(_, _)) => {
-            let left: Decimal128Array = Decimal128Array::from(left.data().clone());
-            let right: Decimal128Array = Decimal128Array::from(right.data().clone());
-            Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
-        }
         (FixedSizeBinary(_), FixedSizeBinary(_)) => {
             let left: FixedSizeBinaryArray =
                 FixedSizeBinaryArray::from(left.data().clone());
@@ -380,6 +312,7 @@ pub fn build_compare(
 pub mod tests {
     use super::*;
     use arrow_array::{FixedSizeBinaryArray, Float64Array, Int32Array};
+    use arrow_buffer::i256;
     use std::cmp::Ordering;
 
     #[test]
@@ -464,6 +397,23 @@ pub mod tests {
         assert_eq!(Ordering::Greater, (cmp)(0, 2));
     }
 
+    #[test]
+    fn test_decimali256() { // build_compare on a Decimal256Array compared against itself
+        let array = vec![
+            Some(i256::from_i128(5_i128)),
+            Some(i256::from_i128(2_i128)),
+            Some(i256::from_i128(3_i128)),
+        ]
+        .into_iter()
+        .collect::<Decimal256Array>()
+        .with_precision_and_scale(53, 6)
+        .unwrap();
+
+        let cmp = build_compare(&array, &array).unwrap();
+        assert_eq!(Ordering::Less, (cmp)(1, 0)); // element 1 (2) vs element 0 (5)
+        assert_eq!(Ordering::Greater, (cmp)(0, 2)); // element 0 (5) vs element 2 (3)
+    }
+
     #[test]
     fn test_dict() {
         let data = vec!["a", "b", "c", "a", "a", "c", "c"];
@@ -584,4 +534,52 @@ pub mod tests {
         assert_eq!(Ordering::Greater, (cmp)(3, 1));
         assert_eq!(Ordering::Greater, (cmp)(3, 2));
     }
+
+    #[test]
+    fn test_decimal_dict() { // build_compare on two Int8-keyed dictionaries with Decimal128 values
+        let values = Decimal128Array::from(vec![1, 0, 2, 5]);
+        let keys = Int8Array::from_iter_values([0, 0, 1, 3]);
+        let array1 = DictionaryArray::<Int8Type>::try_new(&keys, &values).unwrap(); // keys resolve to 1, 1, 0, 5
+
+        let values = Decimal128Array::from(vec![2, 3, 4, 5]);
+        let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
+        let array2 = DictionaryArray::<Int8Type>::try_new(&keys, &values).unwrap(); // keys resolve to 2, 3, 3, 5
+
+        let cmp = build_compare(&array1, &array2).unwrap(); // cmp(i, j) compares array1[i] with array2[j]
+
+        assert_eq!(Ordering::Less, (cmp)(0, 0)); // 1 vs 2
+        assert_eq!(Ordering::Less, (cmp)(0, 3)); // 1 vs 5
+        assert_eq!(Ordering::Equal, (cmp)(3, 3)); // 5 vs 5
+        assert_eq!(Ordering::Greater, (cmp)(3, 1)); // 5 vs 3
+        assert_eq!(Ordering::Greater, (cmp)(3, 2)); // 5 vs 3
+    }
+
+    #[test]
+    fn test_decimal256_dict() { // build_compare on two Int8-keyed dictionaries with Decimal256 values
+        let values = Decimal256Array::from(vec![
+            i256::from_i128(1),
+            i256::from_i128(0),
+            i256::from_i128(2),
+            i256::from_i128(5),
+        ]);
+        let keys = Int8Array::from_iter_values([0, 0, 1, 3]);
+        let array1 = DictionaryArray::<Int8Type>::try_new(&keys, &values).unwrap(); // keys resolve to 1, 1, 0, 5
+
+        let values = Decimal256Array::from(vec![
+            i256::from_i128(2),
+            i256::from_i128(3),
+            i256::from_i128(4),
+            i256::from_i128(5),
+        ]);
+        let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
+        let array2 = DictionaryArray::<Int8Type>::try_new(&keys, &values).unwrap(); // keys resolve to 2, 3, 3, 5
+
+        let cmp = build_compare(&array1, &array2).unwrap(); // cmp(i, j) compares array1[i] with array2[j]
+
+        assert_eq!(Ordering::Less, (cmp)(0, 0)); // 1 vs 2
+        assert_eq!(Ordering::Less, (cmp)(0, 3)); // 1 vs 5
+        assert_eq!(Ordering::Equal, (cmp)(3, 3)); // 5 vs 5
+        assert_eq!(Ordering::Greater, (cmp)(3, 1)); // 5 vs 3
+        assert_eq!(Ordering::Greater, (cmp)(3, 2)); // 5 vs 3
+    }
 }