You are viewing a plain-text version of this content; the hyperlink to its canonical version ("here") was not preserved in this plain-text copy.
Posted to commits@arrow.apache.org by vi...@apache.org on 2022/07/04 19:51:37 UTC
[arrow-rs] branch master updated: Add unary_cmp (#1991)
This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 932ffc5d1 Add unary_cmp (#1991)
932ffc5d1 is described below
commit 932ffc5d1360818c480de8dcf165772ad50d4359
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Mon Jul 4 12:51:33 2022 -0700
Add unary_cmp (#1991)
* Add unary_cmp
* Fix clippy
* Trigger Build
* Trigger Build
* Trigger Build
---
arrow/src/compute/kernels/comparison.rs | 69 +++++++++++++++++++++------------
1 file changed, 45 insertions(+), 24 deletions(-)
diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs
index 068b9dedf..0a6d60cea 100644
--- a/arrow/src/compute/kernels/comparison.rs
+++ b/arrow/src/compute/kernels/comparison.rs
@@ -134,7 +134,7 @@ macro_rules! compare_op_primitive {
}
macro_rules! compare_op_scalar {
- ($left:expr, $right:expr, $op:expr) => {{
+ ($left:expr, $op:expr) => {{
let null_bit_buffer = $left
.data()
.null_buffer()
@@ -143,7 +143,7 @@ macro_rules! compare_op_scalar {
// Safety:
// `i < $left.len()`
let comparison =
- (0..$left.len()).map(|i| unsafe { $op($left.value_unchecked(i), $right) });
+ (0..$left.len()).map(|i| unsafe { $op($left.value_unchecked(i)) });
// same as $left.len()
let buffer = unsafe { MutableBuffer::from_trusted_len_iter_bool(comparison) };
@@ -777,7 +777,7 @@ pub fn eq_utf8_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &str,
) -> Result<BooleanArray> {
- compare_op_scalar!(left, right, |a, b| a == b)
+ compare_op_scalar!(left, |a| a == right)
}
#[inline]
@@ -870,22 +870,22 @@ pub fn eq_bool_scalar(left: &BooleanArray, right: bool) -> Result<BooleanArray>
/// Perform `left < right` operation on [`BooleanArray`] and a scalar
pub fn lt_bool_scalar(left: &BooleanArray, right: bool) -> Result<BooleanArray> {
- compare_op_scalar!(left, right, |a: bool, b: bool| !a & b)
+ compare_op_scalar!(left, |a: bool| !a & right)
}
/// Perform `left <= right` operation on [`BooleanArray`] and a scalar
pub fn lt_eq_bool_scalar(left: &BooleanArray, right: bool) -> Result<BooleanArray> {
- compare_op_scalar!(left, right, |a, b| a <= b)
+ compare_op_scalar!(left, |a| a <= right)
}
/// Perform `left > right` operation on [`BooleanArray`] and a scalar
pub fn gt_bool_scalar(left: &BooleanArray, right: bool) -> Result<BooleanArray> {
- compare_op_scalar!(left, right, |a: bool, b: bool| a & !b)
+ compare_op_scalar!(left, |a: bool| a & !right)
}
/// Perform `left >= right` operation on [`BooleanArray`] and a scalar
pub fn gt_eq_bool_scalar(left: &BooleanArray, right: bool) -> Result<BooleanArray> {
- compare_op_scalar!(left, right, |a, b| a >= b)
+ compare_op_scalar!(left, |a| a >= right)
}
/// Perform `left != right` operation on [`BooleanArray`] and a scalar
@@ -906,7 +906,7 @@ pub fn eq_binary_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericBinaryArray<OffsetSize>,
right: &[u8],
) -> Result<BooleanArray> {
- compare_op_scalar!(left, right, |a, b| a == b)
+ compare_op_scalar!(left, |a| a == right)
}
/// Perform `left != right` operation on [`BinaryArray`] / [`LargeBinaryArray`].
@@ -922,7 +922,7 @@ pub fn neq_binary_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericBinaryArray<OffsetSize>,
right: &[u8],
) -> Result<BooleanArray> {
- compare_op_scalar!(left, right, |a, b| a != b)
+ compare_op_scalar!(left, |a| a != right)
}
/// Perform `left < right` operation on [`BinaryArray`] / [`LargeBinaryArray`].
@@ -938,7 +938,7 @@ pub fn lt_binary_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericBinaryArray<OffsetSize>,
right: &[u8],
) -> Result<BooleanArray> {
- compare_op_scalar!(left, right, |a, b| a < b)
+ compare_op_scalar!(left, |a| a < right)
}
/// Perform `left <= right` operation on [`BinaryArray`] / [`LargeBinaryArray`].
@@ -954,7 +954,7 @@ pub fn lt_eq_binary_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericBinaryArray<OffsetSize>,
right: &[u8],
) -> Result<BooleanArray> {
- compare_op_scalar!(left, right, |a, b| a <= b)
+ compare_op_scalar!(left, |a| a <= right)
}
/// Perform `left > right` operation on [`BinaryArray`] / [`LargeBinaryArray`].
@@ -970,7 +970,7 @@ pub fn gt_binary_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericBinaryArray<OffsetSize>,
right: &[u8],
) -> Result<BooleanArray> {
- compare_op_scalar!(left, right, |a, b| a > b)
+ compare_op_scalar!(left, |a| a > right)
}
/// Perform `left >= right` operation on [`BinaryArray`] / [`LargeBinaryArray`].
@@ -986,7 +986,7 @@ pub fn gt_eq_binary_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericBinaryArray<OffsetSize>,
right: &[u8],
) -> Result<BooleanArray> {
- compare_op_scalar!(left, right, |a, b| a >= b)
+ compare_op_scalar!(left, |a| a >= right)
}
/// Perform `left != right` operation on [`StringArray`] / [`LargeStringArray`].
@@ -1002,7 +1002,7 @@ pub fn neq_utf8_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &str,
) -> Result<BooleanArray> {
- compare_op_scalar!(left, right, |a, b| a != b)
+ compare_op_scalar!(left, |a| a != right)
}
/// Perform `left < right` operation on [`StringArray`] / [`LargeStringArray`].
@@ -1018,7 +1018,7 @@ pub fn lt_utf8_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &str,
) -> Result<BooleanArray> {
- compare_op_scalar!(left, right, |a, b| a < b)
+ compare_op_scalar!(left, |a| a < right)
}
/// Perform `left <= right` operation on [`StringArray`] / [`LargeStringArray`].
@@ -1034,7 +1034,7 @@ pub fn lt_eq_utf8_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &str,
) -> Result<BooleanArray> {
- compare_op_scalar!(left, right, |a, b| a <= b)
+ compare_op_scalar!(left, |a| a <= right)
}
/// Perform `left > right` operation on [`StringArray`] / [`LargeStringArray`].
@@ -1050,7 +1050,7 @@ pub fn gt_utf8_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &str,
) -> Result<BooleanArray> {
- compare_op_scalar!(left, right, |a, b| a > b)
+ compare_op_scalar!(left, |a| a > right)
}
/// Perform `left >= right` operation on [`StringArray`] / [`LargeStringArray`].
@@ -1066,7 +1066,7 @@ pub fn gt_eq_utf8_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &str,
) -> Result<BooleanArray> {
- compare_op_scalar!(left, right, |a, b| a >= b)
+ compare_op_scalar!(left, |a| a >= right)
}
/// Calls $RIGHT.$TY() (e.g. `right.to_i128()`) with a nice error message.
@@ -2554,7 +2554,16 @@ where
#[cfg(feature = "simd")]
return simd_compare_op_scalar(left, right, T::eq, |a, b| a == b);
#[cfg(not(feature = "simd"))]
- return compare_op_scalar!(left, right, |a, b| a == b);
+ return compare_op_scalar!(left, |a| a == right);
+}
+
+/// Applies a unary, infallible predicate to each slot of `left` (nulls too); nulls stay null.
+pub fn unary_cmp<T, F>(left: &PrimitiveArray<T>, op: F) -> Result<BooleanArray>
+where
+    T: ArrowNumericType,
+    F: Fn(T::Native) -> bool,
+{
+    compare_op_scalar!(left, op)
}
/// Perform `left != right` operation on two [`PrimitiveArray`]s.
@@ -2576,7 +2585,7 @@ where
#[cfg(feature = "simd")]
return simd_compare_op_scalar(left, right, T::ne, |a, b| a != b);
#[cfg(not(feature = "simd"))]
- return compare_op_scalar!(left, right, |a, b| a != b);
+ return compare_op_scalar!(left, |a| a != right);
}
/// Perform `left < right` operation on two [`PrimitiveArray`]s. Null values are less than non-null
@@ -2600,7 +2609,7 @@ where
#[cfg(feature = "simd")]
return simd_compare_op_scalar(left, right, T::lt, |a, b| a < b);
#[cfg(not(feature = "simd"))]
- return compare_op_scalar!(left, right, |a, b| a < b);
+ return compare_op_scalar!(left, |a| a < right);
}
/// Perform `left <= right` operation on two [`PrimitiveArray`]s. Null values are less than non-null
@@ -2627,7 +2636,7 @@ where
#[cfg(feature = "simd")]
return simd_compare_op_scalar(left, right, T::le, |a, b| a <= b);
#[cfg(not(feature = "simd"))]
- return compare_op_scalar!(left, right, |a, b| a <= b);
+ return compare_op_scalar!(left, |a| a <= right);
}
/// Perform `left > right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null
@@ -2651,7 +2660,7 @@ where
#[cfg(feature = "simd")]
return simd_compare_op_scalar(left, right, T::gt, |a, b| a > b);
#[cfg(not(feature = "simd"))]
- return compare_op_scalar!(left, right, |a, b| a > b);
+ return compare_op_scalar!(left, |a| a > right);
}
/// Perform `left >= right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null
@@ -2678,7 +2687,7 @@ where
#[cfg(feature = "simd")]
return simd_compare_op_scalar(left, right, T::ge, |a, b| a >= b);
#[cfg(not(feature = "simd"))]
- return compare_op_scalar!(left, right, |a, b| a >= b);
+ return compare_op_scalar!(left, |a| a >= right);
}
/// Checks if a [`GenericListArray`] contains a value in the [`PrimitiveArray`]
@@ -5047,4 +5056,16 @@ mod tests {
let result = gt_eq_dyn(&dict_array1, &dict_array2);
assert_eq!(result.unwrap(), BooleanArray::from(vec![false, true, true]));
}
+
+    #[test]
+    fn test_unary_cmp() {
+        let a = Int32Array::from(vec![Some(1), None, Some(2), Some(3)]);
+        let values = [1_i32, 3];
+
+        let a_eq = unary_cmp(&a, |v| values.contains(&v)).unwrap();
+        assert_eq!(
+            a_eq,
+            BooleanArray::from(vec![Some(true), None, Some(false), Some(true)])
+        );
+    }
}