You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by vi...@apache.org on 2022/10/06 17:45:28 UTC
[arrow-rs] branch master updated: Add NaN handling in dyn scalar comparison kernels (#2830)
This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new c93ce3956 Add NaN handling in dyn scalar comparison kernels (#2830)
c93ce3956 is described below
commit c93ce39567b73c56f11d4e731de102c532e3654d
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Thu Oct 6 10:45:23 2022 -0700
Add NaN handling in dyn scalar comparison kernels (#2830)
* Add NaN handling in dyn scalar comparison kernels
* Use trait
* Fix clippy
* Prepare for simd and non-simd tests
* Restore flight protocol files.
* Use ArrowNativeTypeOp as type bound
* Remove num::ToPrimitive addition
---
arrow/benches/comparison_kernels.rs | 74 ++++++---------
arrow/src/compute/kernels/comparison.rs | 153 ++++++++++++++++++++++++++++----
arrow/src/datatypes/native.rs | 60 ++++++++++++-
3 files changed, 217 insertions(+), 70 deletions(-)
diff --git a/arrow/benches/comparison_kernels.rs b/arrow/benches/comparison_kernels.rs
index 4ad139b87..e2afa99fb 100644
--- a/arrow/benches/comparison_kernels.rs
+++ b/arrow/benches/comparison_kernels.rs
@@ -33,13 +33,6 @@ where
eq(criterion::black_box(arr_a), criterion::black_box(arr_b)).unwrap();
}
-fn bench_eq_scalar<T>(arr_a: &PrimitiveArray<T>, value_b: T::Native)
-where
- T: ArrowNumericType,
-{
- eq_scalar(criterion::black_box(arr_a), criterion::black_box(value_b)).unwrap();
-}
-
fn bench_neq<T>(arr_a: &PrimitiveArray<T>, arr_b: &PrimitiveArray<T>)
where
T: ArrowNumericType,
@@ -47,13 +40,6 @@ where
neq(criterion::black_box(arr_a), criterion::black_box(arr_b)).unwrap();
}
-fn bench_neq_scalar<T>(arr_a: &PrimitiveArray<T>, value_b: T::Native)
-where
- T: ArrowNumericType,
-{
- neq_scalar(criterion::black_box(arr_a), criterion::black_box(value_b)).unwrap();
-}
-
fn bench_lt<T>(arr_a: &PrimitiveArray<T>, arr_b: &PrimitiveArray<T>)
where
T: ArrowNumericType,
@@ -61,13 +47,6 @@ where
lt(criterion::black_box(arr_a), criterion::black_box(arr_b)).unwrap();
}
-fn bench_lt_scalar<T>(arr_a: &PrimitiveArray<T>, value_b: T::Native)
-where
- T: ArrowNumericType,
-{
- lt_scalar(criterion::black_box(arr_a), criterion::black_box(value_b)).unwrap();
-}
-
fn bench_lt_eq<T>(arr_a: &PrimitiveArray<T>, arr_b: &PrimitiveArray<T>)
where
T: ArrowNumericType,
@@ -75,13 +54,6 @@ where
lt_eq(criterion::black_box(arr_a), criterion::black_box(arr_b)).unwrap();
}
-fn bench_lt_eq_scalar<T>(arr_a: &PrimitiveArray<T>, value_b: T::Native)
-where
- T: ArrowNumericType,
-{
- lt_eq_scalar(criterion::black_box(arr_a), criterion::black_box(value_b)).unwrap();
-}
-
fn bench_gt<T>(arr_a: &PrimitiveArray<T>, arr_b: &PrimitiveArray<T>)
where
T: ArrowNumericType,
@@ -89,13 +61,6 @@ where
gt(criterion::black_box(arr_a), criterion::black_box(arr_b)).unwrap();
}
-fn bench_gt_scalar<T>(arr_a: &PrimitiveArray<T>, value_b: T::Native)
-where
- T: ArrowNumericType,
-{
- gt_scalar(criterion::black_box(arr_a), criterion::black_box(value_b)).unwrap();
-}
-
fn bench_gt_eq<T>(arr_a: &PrimitiveArray<T>, arr_b: &PrimitiveArray<T>)
where
T: ArrowNumericType,
@@ -103,13 +68,6 @@ where
gt_eq(criterion::black_box(arr_a), criterion::black_box(arr_b)).unwrap();
}
-fn bench_gt_eq_scalar<T>(arr_a: &PrimitiveArray<T>, value_b: T::Native)
-where
- T: ArrowNumericType,
-{
- gt_eq_scalar(criterion::black_box(arr_a), criterion::black_box(value_b)).unwrap();
-}
-
fn bench_like_utf8_scalar(arr_a: &StringArray, value_b: &str) {
like_utf8_scalar(criterion::black_box(arr_a), criterion::black_box(value_b)).unwrap();
}
@@ -164,39 +122,57 @@ fn add_benchmark(c: &mut Criterion) {
c.bench_function("eq Float32", |b| b.iter(|| bench_eq(&arr_a, &arr_b)));
c.bench_function("eq scalar Float32", |b| {
- b.iter(|| bench_eq_scalar(&arr_a, 1.0))
+ b.iter(|| {
+ eq_scalar(criterion::black_box(&arr_a), criterion::black_box(1.0)).unwrap()
+ })
});
c.bench_function("neq Float32", |b| b.iter(|| bench_neq(&arr_a, &arr_b)));
c.bench_function("neq scalar Float32", |b| {
- b.iter(|| bench_neq_scalar(&arr_a, 1.0))
+ b.iter(|| {
+ neq_scalar(criterion::black_box(&arr_a), criterion::black_box(1.0)).unwrap()
+ })
});
c.bench_function("lt Float32", |b| b.iter(|| bench_lt(&arr_a, &arr_b)));
c.bench_function("lt scalar Float32", |b| {
- b.iter(|| bench_lt_scalar(&arr_a, 1.0))
+ b.iter(|| {
+ lt_scalar(criterion::black_box(&arr_a), criterion::black_box(1.0)).unwrap()
+ })
});
c.bench_function("lt_eq Float32", |b| b.iter(|| bench_lt_eq(&arr_a, &arr_b)));
c.bench_function("lt_eq scalar Float32", |b| {
- b.iter(|| bench_lt_eq_scalar(&arr_a, 1.0))
+ b.iter(|| {
+ lt_eq_scalar(criterion::black_box(&arr_a), criterion::black_box(1.0)).unwrap()
+ })
});
c.bench_function("gt Float32", |b| b.iter(|| bench_gt(&arr_a, &arr_b)));
c.bench_function("gt scalar Float32", |b| {
- b.iter(|| bench_gt_scalar(&arr_a, 1.0))
+ b.iter(|| {
+ gt_scalar(criterion::black_box(&arr_a), criterion::black_box(1.0)).unwrap()
+ })
});
c.bench_function("gt_eq Float32", |b| b.iter(|| bench_gt_eq(&arr_a, &arr_b)));
c.bench_function("gt_eq scalar Float32", |b| {
- b.iter(|| bench_gt_eq_scalar(&arr_a, 1.0))
+ b.iter(|| {
+ gt_eq_scalar(criterion::black_box(&arr_a), criterion::black_box(1.0)).unwrap()
+ })
});
c.bench_function("eq MonthDayNano", |b| {
b.iter(|| bench_eq(&arr_month_day_nano_a, &arr_month_day_nano_b))
});
c.bench_function("eq scalar MonthDayNano", |b| {
- b.iter(|| bench_eq_scalar(&arr_month_day_nano_a, 123))
+ b.iter(|| {
+ eq_scalar(
+ criterion::black_box(&arr_month_day_nano_a),
+ criterion::black_box(123),
+ )
+ .unwrap()
+ })
});
c.bench_function("like_utf8 scalar equals", |b| {
diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs
index 1ea433150..d1d1e470e 100644
--- a/arrow/src/compute/kernels/comparison.rs
+++ b/arrow/src/compute/kernels/comparison.rs
@@ -27,12 +27,13 @@ use crate::array::*;
use crate::buffer::{buffer_unary_not, Buffer, MutableBuffer};
use crate::compute::util::combine_option_bitmap;
use crate::datatypes::{
- ArrowNativeType, ArrowNumericType, DataType, Date32Type, Date64Type, Float32Type,
- Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTimeType,
- IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, Time32MillisecondType,
- Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimeUnit,
- TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
- TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
+ native_op::ArrowNativeTypeOp, ArrowNativeType, ArrowNumericType, DataType,
+ Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type,
+ Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit,
+ IntervalYearMonthType, Time32MillisecondType, Time32SecondType,
+ Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
+ TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type,
+ UInt32Type, UInt64Type, UInt8Type,
};
#[allow(unused_imports)]
use crate::downcast_dictionary_array;
@@ -1328,7 +1329,12 @@ macro_rules! dyn_compare_utf8_scalar {
}
/// Perform `left == right` operation on an array and a numeric scalar
-/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive values
+/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive values.
+///
+/// If `simd` feature flag is not enabled:
+/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
+/// Please refer to `f32::total_cmp` and `f64::total_cmp`.
pub fn eq_dyn_scalar<T>(left: &dyn Array, right: T) -> Result<BooleanArray>
where
T: num::ToPrimitive + std::fmt::Debug,
@@ -1342,7 +1348,12 @@ where
}
/// Perform `left < right` operation on an array and a numeric scalar
-/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive values
+/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive values.
+///
+/// If `simd` feature flag is not enabled:
+/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
+/// Please refer to `f32::total_cmp` and `f64::total_cmp`.
pub fn lt_dyn_scalar<T>(left: &dyn Array, right: T) -> Result<BooleanArray>
where
T: num::ToPrimitive + std::fmt::Debug,
@@ -1356,7 +1367,12 @@ where
}
/// Perform `left <= right` operation on an array and a numeric scalar
-/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive values
+/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive values.
+///
+/// If `simd` feature flag is not enabled:
+/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
+/// Please refer to `f32::total_cmp` and `f64::total_cmp`.
pub fn lt_eq_dyn_scalar<T>(left: &dyn Array, right: T) -> Result<BooleanArray>
where
T: num::ToPrimitive + std::fmt::Debug,
@@ -1370,7 +1386,12 @@ where
}
/// Perform `left > right` operation on an array and a numeric scalar
-/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive values
+/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive values.
+///
+/// If `simd` feature flag is not enabled:
+/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
+/// Please refer to `f32::total_cmp` and `f64::total_cmp`.
pub fn gt_dyn_scalar<T>(left: &dyn Array, right: T) -> Result<BooleanArray>
where
T: num::ToPrimitive + std::fmt::Debug,
@@ -1384,7 +1405,12 @@ where
}
/// Perform `left >= right` operation on an array and a numeric scalar
-/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive values
+/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive values.
+///
+/// If `simd` feature flag is not enabled:
+/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
+/// Please refer to `f32::total_cmp` and `f64::total_cmp`.
pub fn gt_eq_dyn_scalar<T>(left: &dyn Array, right: T) -> Result<BooleanArray>
where
T: num::ToPrimitive + std::fmt::Debug,
@@ -1398,7 +1424,12 @@ where
}
/// Perform `left != right` operation on an array and a numeric scalar
-/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive values
+/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive values.
+///
+/// If `simd` feature flag is not enabled:
+/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
+/// Please refer to `f32::total_cmp` and `f64::total_cmp`.
pub fn neq_dyn_scalar<T>(left: &dyn Array, right: T) -> Result<BooleanArray>
where
T: num::ToPrimitive + std::fmt::Debug,
@@ -3016,14 +3047,20 @@ where
}
/// Perform `left == right` operation on a [`PrimitiveArray`] and a scalar value.
+///
+/// If `simd` feature flag is not enabled:
+/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
+/// Please refer to `f32::total_cmp` and `f64::total_cmp`.
pub fn eq_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
where
T: ArrowNumericType,
+ T::Native: ArrowNativeTypeOp,
{
#[cfg(feature = "simd")]
return simd_compare_op_scalar(left, right, T::eq, |a, b| a == b);
#[cfg(not(feature = "simd"))]
- return compare_op_scalar(left, |a| a == right);
+ return compare_op_scalar(left, |a| a.is_eq(right));
}
/// Applies an unary and infallible comparison function to a primitive array.
@@ -3047,14 +3084,20 @@ where
}
/// Perform `left != right` operation on a [`PrimitiveArray`] and a scalar value.
+///
+/// If `simd` feature flag is not enabled:
+/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
+/// Please refer to `f32::total_cmp` and `f64::total_cmp`.
pub fn neq_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
where
T: ArrowNumericType,
+ T::Native: ArrowNativeTypeOp,
{
#[cfg(feature = "simd")]
return simd_compare_op_scalar(left, right, T::ne, |a, b| a != b);
#[cfg(not(feature = "simd"))]
- return compare_op_scalar(left, |a| a != right);
+ return compare_op_scalar(left, |a| a.is_ne(right));
}
/// Perform `left < right` operation on two [`PrimitiveArray`]s. Null values are less than non-null
@@ -3071,14 +3114,20 @@ where
/// Perform `left < right` operation on a [`PrimitiveArray`] and a scalar value.
/// Null values are less than non-null values.
+///
+/// If `simd` feature flag is not enabled:
+/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
+/// Please refer to `f32::total_cmp` and `f64::total_cmp`.
pub fn lt_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
where
T: ArrowNumericType,
+ T::Native: ArrowNativeTypeOp,
{
#[cfg(feature = "simd")]
return simd_compare_op_scalar(left, right, T::lt, |a, b| a < b);
#[cfg(not(feature = "simd"))]
- return compare_op_scalar(left, |a| a < right);
+ return compare_op_scalar(left, |a| a.is_lt(right));
}
/// Perform `left <= right` operation on two [`PrimitiveArray`]s. Null values are less than non-null
@@ -3098,14 +3147,20 @@ where
/// Perform `left <= right` operation on a [`PrimitiveArray`] and a scalar value.
/// Null values are less than non-null values.
+///
+/// If `simd` feature flag is not enabled:
+/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
+/// Please refer to `f32::total_cmp` and `f64::total_cmp`.
pub fn lt_eq_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
where
T: ArrowNumericType,
+ T::Native: ArrowNativeTypeOp,
{
#[cfg(feature = "simd")]
return simd_compare_op_scalar(left, right, T::le, |a, b| a <= b);
#[cfg(not(feature = "simd"))]
- return compare_op_scalar(left, |a| a <= right);
+ return compare_op_scalar(left, |a| a.is_le(right));
}
/// Perform `left > right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null
@@ -3122,14 +3177,20 @@ where
/// Perform `left > right` operation on a [`PrimitiveArray`] and a scalar value.
/// Non-null values are greater than null values.
+///
+/// If `simd` feature flag is not enabled:
+/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
+/// Please refer to `f32::total_cmp` and `f64::total_cmp`.
pub fn gt_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
where
T: ArrowNumericType,
+ T::Native: ArrowNativeTypeOp,
{
#[cfg(feature = "simd")]
return simd_compare_op_scalar(left, right, T::gt, |a, b| a > b);
#[cfg(not(feature = "simd"))]
- return compare_op_scalar(left, |a| a > right);
+ return compare_op_scalar(left, |a| a.is_gt(right));
}
/// Perform `left >= right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null
@@ -3149,14 +3210,20 @@ where
/// Perform `left >= right` operation on a [`PrimitiveArray`] and a scalar value.
/// Non-null values are greater than null values.
+///
+/// If `simd` feature flag is not enabled:
+/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
+/// Please refer to `f32::total_cmp` and `f64::total_cmp`.
pub fn gt_eq_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
where
T: ArrowNumericType,
+ T::Native: ArrowNativeTypeOp,
{
#[cfg(feature = "simd")]
return simd_compare_op_scalar(left, right, T::ge, |a, b| a >= b);
#[cfg(not(feature = "simd"))]
- return compare_op_scalar(left, |a| a >= right);
+ return compare_op_scalar(left, |a| a.is_ge(right));
}
/// Checks if a [`GenericListArray`] contains a value in the [`PrimitiveArray`]
@@ -5848,28 +5915,48 @@ mod tests {
.into_iter()
.map(Some)
.collect();
+ #[cfg(feature = "simd")]
let expected = BooleanArray::from(
vec![Some(false), Some(false), Some(false), Some(false), Some(false)],
);
+ #[cfg(not(feature = "simd"))]
+ let expected = BooleanArray::from(
+ vec![Some(true), Some(false), Some(false), Some(false), Some(false)],
+ );
assert_eq!(eq_dyn_scalar(&array, f32::NAN).unwrap(), expected);
+ #[cfg(feature = "simd")]
let expected = BooleanArray::from(
vec![Some(true), Some(true), Some(true), Some(true), Some(true)],
);
+ #[cfg(not(feature = "simd"))]
+ let expected = BooleanArray::from(
+ vec![Some(false), Some(true), Some(true), Some(true), Some(true)],
+ );
assert_eq!(neq_dyn_scalar(&array, f32::NAN).unwrap(), expected);
let array: Float64Array = vec![f64::NAN, 7.0, 8.0, 8.0, 10.0]
.into_iter()
.map(Some)
.collect();
+ #[cfg(feature = "simd")]
let expected = BooleanArray::from(
vec![Some(false), Some(false), Some(false), Some(false), Some(false)],
);
+ #[cfg(not(feature = "simd"))]
+ let expected = BooleanArray::from(
+ vec![Some(true), Some(false), Some(false), Some(false), Some(false)],
+ );
assert_eq!(eq_dyn_scalar(&array, f64::NAN).unwrap(), expected);
+ #[cfg(feature = "simd")]
let expected = BooleanArray::from(
vec![Some(true), Some(true), Some(true), Some(true), Some(true)],
);
+ #[cfg(not(feature = "simd"))]
+ let expected = BooleanArray::from(
+ vec![Some(false), Some(true), Some(true), Some(true), Some(true)],
+ );
assert_eq!(neq_dyn_scalar(&array, f64::NAN).unwrap(), expected);
}
@@ -5879,28 +5966,48 @@ mod tests {
.into_iter()
.map(Some)
.collect();
+ #[cfg(feature = "simd")]
let expected = BooleanArray::from(
vec![Some(false), Some(false), Some(false), Some(false), Some(false)],
);
+ #[cfg(not(feature = "simd"))]
+ let expected = BooleanArray::from(
+ vec![Some(false), Some(true), Some(true), Some(true), Some(true)],
+ );
assert_eq!(lt_dyn_scalar(&array, f32::NAN).unwrap(), expected);
+ #[cfg(feature = "simd")]
let expected = BooleanArray::from(
vec![Some(false), Some(false), Some(false), Some(false), Some(false)],
);
+ #[cfg(not(feature = "simd"))]
+ let expected = BooleanArray::from(
+ vec![Some(true), Some(true), Some(true), Some(true), Some(true)],
+ );
assert_eq!(lt_eq_dyn_scalar(&array, f32::NAN).unwrap(), expected);
let array: Float64Array = vec![f64::NAN, 7.0, 8.0, 8.0, 10.0]
.into_iter()
.map(Some)
.collect();
+ #[cfg(feature = "simd")]
let expected = BooleanArray::from(
vec![Some(false), Some(false), Some(false), Some(false), Some(false)],
);
+ #[cfg(not(feature = "simd"))]
+ let expected = BooleanArray::from(
+ vec![Some(false), Some(true), Some(true), Some(true), Some(true)],
+ );
assert_eq!(lt_dyn_scalar(&array, f64::NAN).unwrap(), expected);
+ #[cfg(feature = "simd")]
let expected = BooleanArray::from(
vec![Some(false), Some(false), Some(false), Some(false), Some(false)],
);
+ #[cfg(not(feature = "simd"))]
+ let expected = BooleanArray::from(
+ vec![Some(true), Some(true), Some(true), Some(true), Some(true)],
+ );
assert_eq!(lt_eq_dyn_scalar(&array, f64::NAN).unwrap(), expected);
}
@@ -5915,9 +6022,14 @@ mod tests {
);
assert_eq!(gt_dyn_scalar(&array, f32::NAN).unwrap(), expected);
+ #[cfg(feature = "simd")]
let expected = BooleanArray::from(
vec![Some(false), Some(false), Some(false), Some(false), Some(false)],
);
+ #[cfg(not(feature = "simd"))]
+ let expected = BooleanArray::from(
+ vec![Some(true), Some(false), Some(false), Some(false), Some(false)],
+ );
assert_eq!(gt_eq_dyn_scalar(&array, f32::NAN).unwrap(), expected);
let array: Float64Array = vec![f64::NAN, 7.0, 8.0, 8.0, 10.0]
@@ -5929,9 +6041,14 @@ mod tests {
);
assert_eq!(gt_dyn_scalar(&array, f64::NAN).unwrap(), expected);
+ #[cfg(feature = "simd")]
let expected = BooleanArray::from(
vec![Some(false), Some(false), Some(false), Some(false), Some(false)],
);
+ #[cfg(not(feature = "simd"))]
+ let expected = BooleanArray::from(
+ vec![Some(true), Some(false), Some(false), Some(false), Some(false)],
+ );
assert_eq!(gt_eq_dyn_scalar(&array, f64::NAN).unwrap(), expected);
}
diff --git a/arrow/src/datatypes/native.rs b/arrow/src/datatypes/native.rs
index 654b93950..374d0b950 100644
--- a/arrow/src/datatypes/native.rs
+++ b/arrow/src/datatypes/native.rs
@@ -94,6 +94,30 @@ pub(crate) mod native_op {
fn mod_wrapping(self, rhs: Self) -> Self {
self % rhs
}
+
+ fn is_eq(self, rhs: Self) -> bool {
+ self == rhs
+ }
+
+ fn is_ne(self, rhs: Self) -> bool {
+ self != rhs
+ }
+
+ fn is_lt(self, rhs: Self) -> bool {
+ self < rhs
+ }
+
+ fn is_le(self, rhs: Self) -> bool {
+ self <= rhs
+ }
+
+ fn is_gt(self, rhs: Self) -> bool {
+ self > rhs
+ }
+
+ fn is_ge(self, rhs: Self) -> bool {
+ self >= rhs
+ }
}
}
@@ -186,6 +210,36 @@ native_type_op!(u16);
native_type_op!(u32);
native_type_op!(u64);
-impl native_op::ArrowNativeTypeOp for f16 {}
-impl native_op::ArrowNativeTypeOp for f32 {}
-impl native_op::ArrowNativeTypeOp for f64 {}
+macro_rules! native_type_float_op {
+ ($t:tt) => {
+ impl native_op::ArrowNativeTypeOp for $t {
+ fn is_eq(self, rhs: Self) -> bool {
+ self.total_cmp(&rhs).is_eq()
+ }
+
+ fn is_ne(self, rhs: Self) -> bool {
+ self.total_cmp(&rhs).is_ne()
+ }
+
+ fn is_lt(self, rhs: Self) -> bool {
+ self.total_cmp(&rhs).is_lt()
+ }
+
+ fn is_le(self, rhs: Self) -> bool {
+ self.total_cmp(&rhs).is_le()
+ }
+
+ fn is_gt(self, rhs: Self) -> bool {
+ self.total_cmp(&rhs).is_gt()
+ }
+
+ fn is_ge(self, rhs: Self) -> bool {
+ self.total_cmp(&rhs).is_ge()
+ }
+ }
+ };
+}
+
+native_type_float_op!(f16);
+native_type_float_op!(f32);
+native_type_float_op!(f64);