You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by vi...@apache.org on 2022/07/04 19:51:37 UTC

[arrow-rs] branch master updated: Add unary_cmp (#1991)

This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 932ffc5d1 Add unary_cmp (#1991)
932ffc5d1 is described below

commit 932ffc5d1360818c480de8dcf165772ad50d4359
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Mon Jul 4 12:51:33 2022 -0700

    Add unary_cmp (#1991)
    
    * Add unary_cmp
    
    * Fix clippy
    
    * Trigger Build
    
    * Trigger Build
    
    * Trigger Build
---
 arrow/src/compute/kernels/comparison.rs | 69 +++++++++++++++++++++------------
 1 file changed, 45 insertions(+), 24 deletions(-)

diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs
index 068b9dedf..0a6d60cea 100644
--- a/arrow/src/compute/kernels/comparison.rs
+++ b/arrow/src/compute/kernels/comparison.rs
@@ -134,7 +134,7 @@ macro_rules! compare_op_primitive {
 }
 
 macro_rules! compare_op_scalar {
-    ($left:expr, $right:expr, $op:expr) => {{
+    ($left:expr, $op:expr) => {{
         let null_bit_buffer = $left
             .data()
             .null_buffer()
@@ -143,7 +143,7 @@ macro_rules! compare_op_scalar {
         // Safety:
         // `i < $left.len()`
         let comparison =
-            (0..$left.len()).map(|i| unsafe { $op($left.value_unchecked(i), $right) });
+            (0..$left.len()).map(|i| unsafe { $op($left.value_unchecked(i)) });
         // same as $left.len()
         let buffer = unsafe { MutableBuffer::from_trusted_len_iter_bool(comparison) };
 
@@ -777,7 +777,7 @@ pub fn eq_utf8_scalar<OffsetSize: OffsetSizeTrait>(
     left: &GenericStringArray<OffsetSize>,
     right: &str,
 ) -> Result<BooleanArray> {
-    compare_op_scalar!(left, right, |a, b| a == b)
+    compare_op_scalar!(left, |a| a == right)
 }
 
 #[inline]
@@ -870,22 +870,22 @@ pub fn eq_bool_scalar(left: &BooleanArray, right: bool) -> Result<BooleanArray>
 
 /// Perform `left < right` operation on [`BooleanArray`] and a scalar
 pub fn lt_bool_scalar(left: &BooleanArray, right: bool) -> Result<BooleanArray> {
-    compare_op_scalar!(left, right, |a: bool, b: bool| !a & b)
+    compare_op_scalar!(left, |a: bool| !a & right)
 }
 
 /// Perform `left <= right` operation on [`BooleanArray`] and a scalar
 pub fn lt_eq_bool_scalar(left: &BooleanArray, right: bool) -> Result<BooleanArray> {
-    compare_op_scalar!(left, right, |a, b| a <= b)
+    compare_op_scalar!(left, |a| a <= right)
 }
 
 /// Perform `left > right` operation on [`BooleanArray`] and a scalar
 pub fn gt_bool_scalar(left: &BooleanArray, right: bool) -> Result<BooleanArray> {
-    compare_op_scalar!(left, right, |a: bool, b: bool| a & !b)
+    compare_op_scalar!(left, |a: bool| a & !right)
 }
 
 /// Perform `left >= right` operation on [`BooleanArray`] and a scalar
 pub fn gt_eq_bool_scalar(left: &BooleanArray, right: bool) -> Result<BooleanArray> {
-    compare_op_scalar!(left, right, |a, b| a >= b)
+    compare_op_scalar!(left, |a| a >= right)
 }
 
 /// Perform `left != right` operation on [`BooleanArray`] and a scalar
@@ -906,7 +906,7 @@ pub fn eq_binary_scalar<OffsetSize: OffsetSizeTrait>(
     left: &GenericBinaryArray<OffsetSize>,
     right: &[u8],
 ) -> Result<BooleanArray> {
-    compare_op_scalar!(left, right, |a, b| a == b)
+    compare_op_scalar!(left, |a| a == right)
 }
 
 /// Perform `left != right` operation on [`BinaryArray`] / [`LargeBinaryArray`].
@@ -922,7 +922,7 @@ pub fn neq_binary_scalar<OffsetSize: OffsetSizeTrait>(
     left: &GenericBinaryArray<OffsetSize>,
     right: &[u8],
 ) -> Result<BooleanArray> {
-    compare_op_scalar!(left, right, |a, b| a != b)
+    compare_op_scalar!(left, |a| a != right)
 }
 
 /// Perform `left < right` operation on [`BinaryArray`] / [`LargeBinaryArray`].
@@ -938,7 +938,7 @@ pub fn lt_binary_scalar<OffsetSize: OffsetSizeTrait>(
     left: &GenericBinaryArray<OffsetSize>,
     right: &[u8],
 ) -> Result<BooleanArray> {
-    compare_op_scalar!(left, right, |a, b| a < b)
+    compare_op_scalar!(left, |a| a < right)
 }
 
 /// Perform `left <= right` operation on [`BinaryArray`] / [`LargeBinaryArray`].
@@ -954,7 +954,7 @@ pub fn lt_eq_binary_scalar<OffsetSize: OffsetSizeTrait>(
     left: &GenericBinaryArray<OffsetSize>,
     right: &[u8],
 ) -> Result<BooleanArray> {
-    compare_op_scalar!(left, right, |a, b| a <= b)
+    compare_op_scalar!(left, |a| a <= right)
 }
 
 /// Perform `left > right` operation on [`BinaryArray`] / [`LargeBinaryArray`].
@@ -970,7 +970,7 @@ pub fn gt_binary_scalar<OffsetSize: OffsetSizeTrait>(
     left: &GenericBinaryArray<OffsetSize>,
     right: &[u8],
 ) -> Result<BooleanArray> {
-    compare_op_scalar!(left, right, |a, b| a > b)
+    compare_op_scalar!(left, |a| a > right)
 }
 
 /// Perform `left >= right` operation on [`BinaryArray`] / [`LargeBinaryArray`].
@@ -986,7 +986,7 @@ pub fn gt_eq_binary_scalar<OffsetSize: OffsetSizeTrait>(
     left: &GenericBinaryArray<OffsetSize>,
     right: &[u8],
 ) -> Result<BooleanArray> {
-    compare_op_scalar!(left, right, |a, b| a >= b)
+    compare_op_scalar!(left, |a| a >= right)
 }
 
 /// Perform `left != right` operation on [`StringArray`] / [`LargeStringArray`].
@@ -1002,7 +1002,7 @@ pub fn neq_utf8_scalar<OffsetSize: OffsetSizeTrait>(
     left: &GenericStringArray<OffsetSize>,
     right: &str,
 ) -> Result<BooleanArray> {
-    compare_op_scalar!(left, right, |a, b| a != b)
+    compare_op_scalar!(left, |a| a != right)
 }
 
 /// Perform `left < right` operation on [`StringArray`] / [`LargeStringArray`].
@@ -1018,7 +1018,7 @@ pub fn lt_utf8_scalar<OffsetSize: OffsetSizeTrait>(
     left: &GenericStringArray<OffsetSize>,
     right: &str,
 ) -> Result<BooleanArray> {
-    compare_op_scalar!(left, right, |a, b| a < b)
+    compare_op_scalar!(left, |a| a < right)
 }
 
 /// Perform `left <= right` operation on [`StringArray`] / [`LargeStringArray`].
@@ -1034,7 +1034,7 @@ pub fn lt_eq_utf8_scalar<OffsetSize: OffsetSizeTrait>(
     left: &GenericStringArray<OffsetSize>,
     right: &str,
 ) -> Result<BooleanArray> {
-    compare_op_scalar!(left, right, |a, b| a <= b)
+    compare_op_scalar!(left, |a| a <= right)
 }
 
 /// Perform `left > right` operation on [`StringArray`] / [`LargeStringArray`].
@@ -1050,7 +1050,7 @@ pub fn gt_utf8_scalar<OffsetSize: OffsetSizeTrait>(
     left: &GenericStringArray<OffsetSize>,
     right: &str,
 ) -> Result<BooleanArray> {
-    compare_op_scalar!(left, right, |a, b| a > b)
+    compare_op_scalar!(left, |a| a > right)
 }
 
 /// Perform `left >= right` operation on [`StringArray`] / [`LargeStringArray`].
@@ -1066,7 +1066,7 @@ pub fn gt_eq_utf8_scalar<OffsetSize: OffsetSizeTrait>(
     left: &GenericStringArray<OffsetSize>,
     right: &str,
 ) -> Result<BooleanArray> {
-    compare_op_scalar!(left, right, |a, b| a >= b)
+    compare_op_scalar!(left, |a| a >= right)
 }
 
 /// Calls $RIGHT.$TY() (e.g. `right.to_i128()`) with a nice error message.
@@ -2554,7 +2554,16 @@ where
     #[cfg(feature = "simd")]
     return simd_compare_op_scalar(left, right, T::eq, |a, b| a == b);
     #[cfg(not(feature = "simd"))]
-    return compare_op_scalar!(left, right, |a, b| a == b);
+    return compare_op_scalar!(left, |a| a == right);
+}
+
+/// Applies a unary and infallible comparison function to a primitive array.
+pub fn unary_cmp<T, F>(left: &PrimitiveArray<T>, op: F) -> Result<BooleanArray>
+where
+    T: ArrowNumericType,
+    F: Fn(T::Native) -> bool,
+{
+    return compare_op_scalar!(left, op);
 }
 
 /// Perform `left != right` operation on two [`PrimitiveArray`]s.
@@ -2576,7 +2585,7 @@ where
     #[cfg(feature = "simd")]
     return simd_compare_op_scalar(left, right, T::ne, |a, b| a != b);
     #[cfg(not(feature = "simd"))]
-    return compare_op_scalar!(left, right, |a, b| a != b);
+    return compare_op_scalar!(left, |a| a != right);
 }
 
 /// Perform `left < right` operation on two [`PrimitiveArray`]s. Null values are less than non-null
@@ -2600,7 +2609,7 @@ where
     #[cfg(feature = "simd")]
     return simd_compare_op_scalar(left, right, T::lt, |a, b| a < b);
     #[cfg(not(feature = "simd"))]
-    return compare_op_scalar!(left, right, |a, b| a < b);
+    return compare_op_scalar!(left, |a| a < right);
 }
 
 /// Perform `left <= right` operation on two [`PrimitiveArray`]s. Null values are less than non-null
@@ -2627,7 +2636,7 @@ where
     #[cfg(feature = "simd")]
     return simd_compare_op_scalar(left, right, T::le, |a, b| a <= b);
     #[cfg(not(feature = "simd"))]
-    return compare_op_scalar!(left, right, |a, b| a <= b);
+    return compare_op_scalar!(left, |a| a <= right);
 }
 
 /// Perform `left > right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null
@@ -2651,7 +2660,7 @@ where
     #[cfg(feature = "simd")]
     return simd_compare_op_scalar(left, right, T::gt, |a, b| a > b);
     #[cfg(not(feature = "simd"))]
-    return compare_op_scalar!(left, right, |a, b| a > b);
+    return compare_op_scalar!(left, |a| a > right);
 }
 
 /// Perform `left >= right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null
@@ -2678,7 +2687,7 @@ where
     #[cfg(feature = "simd")]
     return simd_compare_op_scalar(left, right, T::ge, |a, b| a >= b);
     #[cfg(not(feature = "simd"))]
-    return compare_op_scalar!(left, right, |a, b| a >= b);
+    return compare_op_scalar!(left, |a| a >= right);
 }
 
 /// Checks if a [`GenericListArray`] contains a value in the [`PrimitiveArray`]
@@ -5047,4 +5056,16 @@ mod tests {
         let result = gt_eq_dyn(&dict_array1, &dict_array2);
         assert_eq!(result.unwrap(), BooleanArray::from(vec![false, true, true]));
     }
+
+    #[test]
+    fn test_unary_cmp() {
+        let a = Int32Array::from(vec![Some(1), None, Some(2), Some(3)]);
+        let values = vec![1_i32, 3];
+
+        let a_eq = unary_cmp(&a, |a| values.contains(&a)).unwrap();
+        assert_eq!(
+            a_eq,
+            BooleanArray::from(vec![Some(true), None, Some(false), Some(true)])
+        );
+    }
 }