You are viewing a plain-text version of this content; the hyperlink to the canonical version was not preserved in this plain-text copy.
Posted to commits@arrow.apache.org by vi...@apache.org on 2022/10/30 01:09:18 UTC

[arrow-rs] branch master updated: Add decimal comparison kernel support (#2978)

This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new bcce9dd4f Add decimal comparison kernel support (#2978)
bcce9dd4f is described below

commit bcce9dd4fcf211cb8e0355f3e32bd67931b6c9fa
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Sat Oct 29 18:09:13 2022 -0700

    Add decimal comparison kernel support (#2978)
---
 arrow-array/src/cast.rs                 |  15 +++++
 arrow/src/compute/kernels/comparison.rs | 104 ++++++++++++++++++++++++++++++--
 2 files changed, 113 insertions(+), 6 deletions(-)

diff --git a/arrow-array/src/cast.rs b/arrow-array/src/cast.rs
index e4e290501..4436dc77c 100644
--- a/arrow-array/src/cast.rs
+++ b/arrow-array/src/cast.rs
@@ -461,6 +461,7 @@ array_downcast_fn!(as_decimal_array, Decimal128Array);
 
 #[cfg(test)]
 mod tests {
+    use arrow_buffer::i256;
     use std::sync::Arc;
 
     use super::*;
@@ -496,4 +497,18 @@ mod tests {
         let array: ArrayRef = Arc::new(array);
         assert!(!as_string_array(&array).is_empty())
     }
+
+    #[test]
+    fn test_decimal128array() {
+        let a = Decimal128Array::from_iter_values([1, 2, 4, 5]);
+        assert!(!as_primitive_array::<Decimal128Type>(&a).is_empty());
+    }
+
+    #[test]
+    fn test_decimal256array() {
+        let a = Decimal256Array::from_iter_values(
+            [1, 2, 4, 5].into_iter().map(i256::from_i128),
+        );
+        assert!(!as_primitive_array::<Decimal256Type>(&a).is_empty());
+    }
 }
diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs
index 94e7f9660..4d8248a8d 100644
--- a/arrow/src/compute/kernels/comparison.rs
+++ b/arrow/src/compute/kernels/comparison.rs
@@ -27,12 +27,13 @@ use crate::array::*;
 use crate::buffer::{buffer_unary_not, Buffer, MutableBuffer};
 use crate::compute::util::combine_option_bitmap;
 use crate::datatypes::{
-    ArrowNativeTypeOp, ArrowNumericType, DataType, Date32Type, Date64Type, Float32Type,
-    Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTimeType,
-    IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, Time32MillisecondType,
-    Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimeUnit,
-    TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
-    TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
+    ArrowNativeTypeOp, ArrowNumericType, DataType, Date32Type, Date64Type,
+    Decimal128Type, Decimal256Type, Float32Type, Float64Type, Int16Type, Int32Type,
+    Int64Type, Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit,
+    IntervalYearMonthType, Time32MillisecondType, Time32SecondType,
+    Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
+    TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type,
+    UInt32Type, UInt64Type, UInt8Type,
 };
 #[allow(unused_imports)]
 use crate::downcast_dictionary_array;
@@ -2257,6 +2258,12 @@ macro_rules! typed_compares {
             (DataType::Float64, DataType::Float64) => {
                 cmp_primitive_array::<Float64Type, _>($LEFT, $RIGHT, $OP_FLOAT)
             }
+            (DataType::Decimal128(_, s1), DataType::Decimal128(_, s2)) if s1 == s2 => {
+                cmp_primitive_array::<Decimal128Type, _>($LEFT, $RIGHT, $OP)
+            }
+            (DataType::Decimal256(_, s1), DataType::Decimal256(_, s2)) if s1 == s2 => {
+                cmp_primitive_array::<Decimal256Type, _>($LEFT, $RIGHT, $OP)
+            }
             (DataType::Utf8, DataType::Utf8) => {
                 compare_op(as_string_array($LEFT), as_string_array($RIGHT), $OP)
             }
@@ -3348,6 +3355,7 @@ fn new_all_set_buffer(len: usize) -> Buffer {
 #[rustfmt::skip::macros(vec)]
 #[cfg(test)]
 mod tests {
+    use arrow_buffer::i256;
     use std::sync::Arc;
 
     use super::*;
@@ -6644,4 +6652,88 @@ mod tests {
             BooleanArray::from(vec![Some(true), None, None, Some(true)])
         );
     }
+
+    #[test]
+    fn test_decimal128() {
+        let a = Decimal128Array::from_iter_values([1, 2, 4, 5]);
+        let b = Decimal128Array::from_iter_values([7, -3, 4, 3]);
+        let e = BooleanArray::from(vec![false, false, true, false]);
+        let r = eq(&a, &b).unwrap();
+        assert_eq!(e, r);
+
+        let r = eq_dyn(&a, &b).unwrap();
+        assert_eq!(e, r);
+
+        let e = BooleanArray::from(vec![true, false, false, false]);
+        let r = lt(&a, &b).unwrap();
+        assert_eq!(e, r);
+
+        let r = lt_dyn(&a, &b).unwrap();
+        assert_eq!(e, r);
+
+        let e = BooleanArray::from(vec![true, false, true, false]);
+        let r = lt_eq(&a, &b).unwrap();
+        assert_eq!(e, r);
+
+        let r = lt_eq_dyn(&a, &b).unwrap();
+        assert_eq!(e, r);
+
+        let e = BooleanArray::from(vec![false, true, false, true]);
+        let r = gt(&a, &b).unwrap();
+        assert_eq!(e, r);
+
+        let r = gt_dyn(&a, &b).unwrap();
+        assert_eq!(e, r);
+
+        let e = BooleanArray::from(vec![false, true, true, true]);
+        let r = gt_eq(&a, &b).unwrap();
+        assert_eq!(e, r);
+
+        let r = gt_eq_dyn(&a, &b).unwrap();
+        assert_eq!(e, r);
+    }
+
+    #[test]
+    fn test_decimal256() {
+        let a = Decimal256Array::from_iter_values(
+            [1, 2, 4, 5].into_iter().map(i256::from_i128),
+        );
+        let b = Decimal256Array::from_iter_values(
+            [7, -3, 4, 3].into_iter().map(i256::from_i128),
+        );
+        let e = BooleanArray::from(vec![false, false, true, false]);
+        let r = eq(&a, &b).unwrap();
+        assert_eq!(e, r);
+
+        let r = eq_dyn(&a, &b).unwrap();
+        assert_eq!(e, r);
+
+        let e = BooleanArray::from(vec![true, false, false, false]);
+        let r = lt(&a, &b).unwrap();
+        assert_eq!(e, r);
+
+        let r = lt_dyn(&a, &b).unwrap();
+        assert_eq!(e, r);
+
+        let e = BooleanArray::from(vec![true, false, true, false]);
+        let r = lt_eq(&a, &b).unwrap();
+        assert_eq!(e, r);
+
+        let r = lt_eq_dyn(&a, &b).unwrap();
+        assert_eq!(e, r);
+
+        let e = BooleanArray::from(vec![false, true, false, true]);
+        let r = gt(&a, &b).unwrap();
+        assert_eq!(e, r);
+
+        let r = gt_dyn(&a, &b).unwrap();
+        assert_eq!(e, r);
+
+        let e = BooleanArray::from(vec![false, true, true, true]);
+        let r = gt_eq(&a, &b).unwrap();
+        assert_eq!(e, r);
+
+        let r = gt_eq_dyn(&a, &b).unwrap();
+        assert_eq!(e, r);
+    }
 }