You are viewing a plain-text version of this content. The canonical link to it (originally a hyperlink labelled "here") was not preserved in this copy.
Posted to commits@arrow.apache.org by tu...@apache.org on 2022/07/25 15:15:12 UTC
[arrow-rs] branch master updated: Use ArrayAccessor in Comparison Kernels (#2157)
This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 4a47ab78b Use ArrayAccessor in Comparison Kernels (#2157)
4a47ab78b is described below
commit 4a47ab78b98b8b0c95d27bee3f468421757b85cd
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Mon Jul 25 08:15:06 2022 -0700
Use ArrayAccessor in Comparison Kernels (#2157)
* Use ArrayAccessor
* More
* Fix clippy
---
arrow/src/compute/kernels/comparison.rs | 521 ++++++++++----------------------
1 file changed, 162 insertions(+), 359 deletions(-)
diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs
index 5344e160c..7733ce67a 100644
--- a/arrow/src/compute/kernels/comparison.rs
+++ b/arrow/src/compute/kernels/comparison.rs
@@ -24,8 +24,7 @@
//!
use crate::array::*;
-use crate::buffer::{bitwise_bin_op_helper, buffer_unary_not, Buffer, MutableBuffer};
-use crate::compute::binary_boolean_kernel;
+use crate::buffer::{buffer_unary_not, Buffer, MutableBuffer};
use crate::compute::util::combine_option_bitmap;
use crate::datatypes::{
ArrowNativeType, ArrowNumericType, DataType, Date32Type, Date64Type, Float32Type,
@@ -37,171 +36,74 @@ use crate::datatypes::{
use crate::error::{ArrowError, Result};
use crate::util::bit_util;
use regex::{escape, Regex};
-use std::any::type_name;
use std::collections::HashMap;
-/// Helper function to perform boolean lambda function on values from two arrays, this
+/// Helper function to perform boolean lambda function on values from two array accessors, this
/// version does not attempt to use SIMD.
-macro_rules! compare_op {
- ($left: expr, $right:expr, $op:expr) => {{
- if $left.len() != $right.len() {
- return Err(ArrowError::ComputeError(
- "Cannot perform comparison operation on arrays of different length"
- .to_string(),
- ));
- }
-
- let null_bit_buffer =
- combine_option_bitmap(&[$left.data_ref(), $right.data_ref()], $left.len())?;
-
- // Safety:
- // `i < $left.len()` and $left.len() == $right.len()
- let comparison = (0..$left.len())
- .map(|i| unsafe { $op($left.value_unchecked(i), $right.value_unchecked(i)) });
- // same size as $left.len() and $right.len()
- let buffer = unsafe { MutableBuffer::from_trusted_len_iter_bool(comparison) };
-
- let data = unsafe {
- ArrayData::new_unchecked(
- DataType::Boolean,
- $left.len(),
- None,
- null_bit_buffer,
- 0,
- vec![Buffer::from(buffer)],
- vec![],
- )
- };
- Ok(BooleanArray::from(data))
- }};
-}
+fn compare_op<T: ArrayAccessor, F>(left: T, right: T, op: F) -> Result<BooleanArray>
+where
+ F: Fn(T::Item, T::Item) -> bool,
+{
+ if left.len() != right.len() {
+ return Err(ArrowError::ComputeError(
+ "Cannot perform comparison operation on arrays of different length"
+ .to_string(),
+ ));
+ }
-macro_rules! compare_op_primitive {
- ($left: expr, $right:expr, $op:expr) => {{
- if $left.len() != $right.len() {
- return Err(ArrowError::ComputeError(
- "Cannot perform comparison operation on arrays of different length"
- .to_string(),
- ));
- }
+ let null_bit_buffer =
+ combine_option_bitmap(&[left.data_ref(), right.data_ref()], left.len())?;
- let null_bit_buffer =
- combine_option_bitmap(&[$left.data_ref(), $right.data_ref()], $left.len())?;
-
- let mut values = MutableBuffer::from_len_zeroed(($left.len() + 7) / 8);
- let lhs_chunks_iter = $left.values().chunks_exact(8);
- let lhs_remainder = lhs_chunks_iter.remainder();
- let rhs_chunks_iter = $right.values().chunks_exact(8);
- let rhs_remainder = rhs_chunks_iter.remainder();
- let chunks = $left.len() / 8;
-
- values[..chunks]
- .iter_mut()
- .zip(lhs_chunks_iter)
- .zip(rhs_chunks_iter)
- .for_each(|((byte, lhs), rhs)| {
- lhs.iter()
- .zip(rhs.iter())
- .enumerate()
- .for_each(|(i, (&lhs, &rhs))| {
- *byte |= if $op(lhs, rhs) { 1 << i } else { 0 };
- });
- });
+ // Safety:
+ // `i < $left.len()` and $left.len() == $right.len()
+ let comparison = (0..left.len())
+ .map(|i| unsafe { op(left.value_unchecked(i), right.value_unchecked(i)) });
+ // same size as $left.len() and $right.len()
+ let buffer = unsafe { MutableBuffer::from_trusted_len_iter_bool(comparison) };
- if !lhs_remainder.is_empty() {
- let last = &mut values[chunks];
- lhs_remainder
- .iter()
- .zip(rhs_remainder.iter())
- .enumerate()
- .for_each(|(i, (&lhs, &rhs))| {
- *last |= if $op(lhs, rhs) { 1 << i } else { 0 };
- });
- };
- let data = unsafe {
- ArrayData::new_unchecked(
- DataType::Boolean,
- $left.len(),
- None,
- null_bit_buffer,
- 0,
- vec![Buffer::from(values)],
- vec![],
- )
- };
- Ok(BooleanArray::from(data))
- }};
+ let data = unsafe {
+ ArrayData::new_unchecked(
+ DataType::Boolean,
+ left.len(),
+ None,
+ null_bit_buffer,
+ 0,
+ vec![Buffer::from(buffer)],
+ vec![],
+ )
+ };
+ Ok(BooleanArray::from(data))
}
-macro_rules! compare_op_scalar {
- ($left:expr, $op:expr) => {{
- let null_bit_buffer = $left
- .data()
- .null_buffer()
- .map(|b| b.bit_slice($left.offset(), $left.len()));
-
- // Safety:
- // `i < $left.len()`
- let comparison =
- (0..$left.len()).map(|i| unsafe { $op($left.value_unchecked(i)) });
- // same as $left.len()
- let buffer = unsafe { MutableBuffer::from_trusted_len_iter_bool(comparison) };
-
- let data = unsafe {
- ArrayData::new_unchecked(
- DataType::Boolean,
- $left.len(),
- None,
- null_bit_buffer,
- 0,
- vec![Buffer::from(buffer)],
- vec![],
- )
- };
- Ok(BooleanArray::from(data))
- }};
-}
+/// Helper function to perform boolean lambda function on values from array accessor, this
+/// version does not attempt to use SIMD.
+fn compare_op_scalar<T: ArrayAccessor, F>(left: T, op: F) -> Result<BooleanArray>
+where
+ F: Fn(T::Item) -> bool,
+{
+ let null_bit_buffer = left
+ .data()
+ .null_buffer()
+ .map(|b| b.bit_slice(left.offset(), left.len()));
-macro_rules! compare_op_scalar_primitive {
- ($left: expr, $right:expr, $op:expr) => {{
- let null_bit_buffer = $left
- .data()
- .null_buffer()
- .map(|b| b.bit_slice($left.offset(), $left.len()));
-
- let mut values = MutableBuffer::from_len_zeroed(($left.len() + 7) / 8);
- let lhs_chunks_iter = $left.values().chunks_exact(8);
- let lhs_remainder = lhs_chunks_iter.remainder();
- let chunks = $left.len() / 8;
-
- values[..chunks]
- .iter_mut()
- .zip(lhs_chunks_iter)
- .for_each(|(byte, chunk)| {
- chunk.iter().enumerate().for_each(|(i, &c_i)| {
- *byte |= if $op(c_i, $right) { 1 << i } else { 0 };
- });
- });
- if !lhs_remainder.is_empty() {
- let last = &mut values[chunks];
- lhs_remainder.iter().enumerate().for_each(|(i, &lhs)| {
- *last |= if $op(lhs, $right) { 1 << i } else { 0 };
- });
- };
+ // Safety:
+ // `i < $left.len()`
+ let comparison = (0..left.len()).map(|i| unsafe { op(left.value_unchecked(i)) });
+ // same as $left.len()
+ let buffer = unsafe { MutableBuffer::from_trusted_len_iter_bool(comparison) };
- let data = unsafe {
- ArrayData::new_unchecked(
- DataType::Boolean,
- $left.len(),
- None,
- null_bit_buffer,
- 0,
- vec![Buffer::from(values)],
- vec![],
- )
- };
- Ok(BooleanArray::from(data))
- }};
+ let data = unsafe {
+ ArrayData::new_unchecked(
+ DataType::Boolean,
+ left.len(),
+ None,
+ null_bit_buffer,
+ 0,
+ vec![Buffer::from(buffer)],
+ vec![],
+ )
+ };
+ Ok(BooleanArray::from(data))
}
/// Evaluate `op(left, right)` for [`PrimitiveArray`]s using a specified
@@ -215,7 +117,7 @@ where
T: ArrowNumericType,
F: Fn(T::Native, T::Native) -> bool,
{
- compare_op_primitive!(left, right, op)
+ compare_op(left, right, op)
}
/// Evaluate `op(left, right)` for [`PrimitiveArray`] and scalar using
@@ -229,7 +131,7 @@ where
T: ArrowNumericType,
F: Fn(T::Native, T::Native) -> bool,
{
- compare_op_scalar_primitive!(left, right, op)
+ compare_op_scalar(left, |l| op(l, right))
}
fn is_like_pattern(c: char) -> bool {
@@ -769,7 +671,7 @@ pub fn eq_utf8<OffsetSize: OffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &GenericStringArray<OffsetSize>,
) -> Result<BooleanArray> {
- compare_op!(left, right, |a, b| a == b)
+ compare_op(left, right, |a, b| a == b)
}
/// Perform `left == right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar.
@@ -777,66 +679,37 @@ pub fn eq_utf8_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &str,
) -> Result<BooleanArray> {
- compare_op_scalar!(left, |a| a == right)
-}
-
-#[inline]
-fn binary_boolean_op<F>(
- left: &BooleanArray,
- right: &BooleanArray,
- op: F,
-) -> Result<BooleanArray>
-where
- F: Copy + Fn(u64, u64) -> u64,
-{
- binary_boolean_kernel(
- left,
- right,
- |left: &Buffer,
- left_offset_in_bits: usize,
- right: &Buffer,
- right_offset_in_bits: usize,
- len_in_bits: usize| {
- bitwise_bin_op_helper(
- left,
- left_offset_in_bits,
- right,
- right_offset_in_bits,
- len_in_bits,
- op,
- )
- },
- )
+ compare_op_scalar(left, |a| a == right)
}
/// Perform `left == right` operation on [`BooleanArray`]
pub fn eq_bool(left: &BooleanArray, right: &BooleanArray) -> Result<BooleanArray> {
- binary_boolean_op(left, right, |a, b| !(a ^ b))
+ compare_op(left, right, |a, b| !(a ^ b))
}
/// Perform `left != right` operation on [`BooleanArray`]
pub fn neq_bool(left: &BooleanArray, right: &BooleanArray) -> Result<BooleanArray> {
- binary_boolean_op(left, right, |a, b| (a ^ b))
+ compare_op(left, right, |a, b| (a ^ b))
}
/// Perform `left < right` operation on [`BooleanArray`]
pub fn lt_bool(left: &BooleanArray, right: &BooleanArray) -> Result<BooleanArray> {
- binary_boolean_op(left, right, |a, b| ((!a) & b))
+ compare_op(left, right, |a, b| ((!a) & b))
}
/// Perform `left <= right` operation on [`BooleanArray`]
pub fn lt_eq_bool(left: &BooleanArray, right: &BooleanArray) -> Result<BooleanArray> {
- binary_boolean_op(left, right, |a, b| !(a & (!b)))
+ compare_op(left, right, |a, b| !(a & (!b)))
}
/// Perform `left > right` operation on [`BooleanArray`]
pub fn gt_bool(left: &BooleanArray, right: &BooleanArray) -> Result<BooleanArray> {
- binary_boolean_op(left, right, |a, b| (a & (!b)))
+ compare_op(left, right, |a, b| (a & (!b)))
}
/// Perform `left >= right` operation on [`BooleanArray`]
pub fn gt_eq_bool(left: &BooleanArray, right: &BooleanArray) -> Result<BooleanArray> {
- binary_boolean_op(left, right, |a, b| !((!a) & b))
+ compare_op(left, right, |a, b| !((!a) & b))
}
/// Perform `left == right` operation on [`BooleanArray`] and a scalar
@@ -870,22 +743,22 @@ pub fn eq_bool_scalar(left: &BooleanArray, right: bool) -> Result<BooleanArray>
/// Perform `left < right` operation on [`BooleanArray`] and a scalar
pub fn lt_bool_scalar(left: &BooleanArray, right: bool) -> Result<BooleanArray> {
- compare_op_scalar!(left, |a: bool| !a & right)
+ compare_op_scalar(left, |a: bool| !a & right)
}
/// Perform `left <= right` operation on [`BooleanArray`] and a scalar
pub fn lt_eq_bool_scalar(left: &BooleanArray, right: bool) -> Result<BooleanArray> {
- compare_op_scalar!(left, |a| a <= right)
+ compare_op_scalar(left, |a| a <= right)
}
/// Perform `left > right` operation on [`BooleanArray`] and a scalar
pub fn gt_bool_scalar(left: &BooleanArray, right: bool) -> Result<BooleanArray> {
- compare_op_scalar!(left, |a: bool| a & !right)
+ compare_op_scalar(left, |a: bool| a & !right)
}
/// Perform `left >= right` operation on [`BooleanArray`] and a scalar
pub fn gt_eq_bool_scalar(left: &BooleanArray, right: bool) -> Result<BooleanArray> {
- compare_op_scalar!(left, |a| a >= right)
+ compare_op_scalar(left, |a| a >= right)
}
/// Perform `left != right` operation on [`BooleanArray`] and a scalar
@@ -898,7 +771,7 @@ pub fn eq_binary<OffsetSize: OffsetSizeTrait>(
left: &GenericBinaryArray<OffsetSize>,
right: &GenericBinaryArray<OffsetSize>,
) -> Result<BooleanArray> {
- compare_op!(left, right, |a, b| a == b)
+ compare_op(left, right, |a, b| a == b)
}
/// Perform `left == right` operation on [`BinaryArray`] / [`LargeBinaryArray`] and a scalar
@@ -906,7 +779,7 @@ pub fn eq_binary_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericBinaryArray<OffsetSize>,
right: &[u8],
) -> Result<BooleanArray> {
- compare_op_scalar!(left, |a| a == right)
+ compare_op_scalar(left, |a| a == right)
}
/// Perform `left != right` operation on [`BinaryArray`] / [`LargeBinaryArray`].
@@ -914,7 +787,7 @@ pub fn neq_binary<OffsetSize: OffsetSizeTrait>(
left: &GenericBinaryArray<OffsetSize>,
right: &GenericBinaryArray<OffsetSize>,
) -> Result<BooleanArray> {
- compare_op!(left, right, |a, b| a != b)
+ compare_op(left, right, |a, b| a != b)
}
/// Perform `left != right` operation on [`BinaryArray`] / [`LargeBinaryArray`] and a scalar.
@@ -922,7 +795,7 @@ pub fn neq_binary_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericBinaryArray<OffsetSize>,
right: &[u8],
) -> Result<BooleanArray> {
- compare_op_scalar!(left, |a| a != right)
+ compare_op_scalar(left, |a| a != right)
}
/// Perform `left < right` operation on [`BinaryArray`] / [`LargeBinaryArray`].
@@ -930,7 +803,7 @@ pub fn lt_binary<OffsetSize: OffsetSizeTrait>(
left: &GenericBinaryArray<OffsetSize>,
right: &GenericBinaryArray<OffsetSize>,
) -> Result<BooleanArray> {
- compare_op!(left, right, |a, b| a < b)
+ compare_op(left, right, |a, b| a < b)
}
/// Perform `left < right` operation on [`BinaryArray`] / [`LargeBinaryArray`] and a scalar.
@@ -938,7 +811,7 @@ pub fn lt_binary_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericBinaryArray<OffsetSize>,
right: &[u8],
) -> Result<BooleanArray> {
- compare_op_scalar!(left, |a| a < right)
+ compare_op_scalar(left, |a| a < right)
}
/// Perform `left <= right` operation on [`BinaryArray`] / [`LargeBinaryArray`].
@@ -946,7 +819,7 @@ pub fn lt_eq_binary<OffsetSize: OffsetSizeTrait>(
left: &GenericBinaryArray<OffsetSize>,
right: &GenericBinaryArray<OffsetSize>,
) -> Result<BooleanArray> {
- compare_op!(left, right, |a, b| a <= b)
+ compare_op(left, right, |a, b| a <= b)
}
/// Perform `left <= right` operation on [`BinaryArray`] / [`LargeBinaryArray`] and a scalar.
@@ -954,7 +827,7 @@ pub fn lt_eq_binary_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericBinaryArray<OffsetSize>,
right: &[u8],
) -> Result<BooleanArray> {
- compare_op_scalar!(left, |a| a <= right)
+ compare_op_scalar(left, |a| a <= right)
}
/// Perform `left > right` operation on [`BinaryArray`] / [`LargeBinaryArray`].
@@ -962,7 +835,7 @@ pub fn gt_binary<OffsetSize: OffsetSizeTrait>(
left: &GenericBinaryArray<OffsetSize>,
right: &GenericBinaryArray<OffsetSize>,
) -> Result<BooleanArray> {
- compare_op!(left, right, |a, b| a > b)
+ compare_op(left, right, |a, b| a > b)
}
/// Perform `left > right` operation on [`BinaryArray`] / [`LargeBinaryArray`] and a scalar.
@@ -970,7 +843,7 @@ pub fn gt_binary_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericBinaryArray<OffsetSize>,
right: &[u8],
) -> Result<BooleanArray> {
- compare_op_scalar!(left, |a| a > right)
+ compare_op_scalar(left, |a| a > right)
}
/// Perform `left >= right` operation on [`BinaryArray`] / [`LargeBinaryArray`].
@@ -978,7 +851,7 @@ pub fn gt_eq_binary<OffsetSize: OffsetSizeTrait>(
left: &GenericBinaryArray<OffsetSize>,
right: &GenericBinaryArray<OffsetSize>,
) -> Result<BooleanArray> {
- compare_op!(left, right, |a, b| a >= b)
+ compare_op(left, right, |a, b| a >= b)
}
/// Perform `left >= right` operation on [`BinaryArray`] / [`LargeBinaryArray`] and a scalar.
@@ -986,7 +859,7 @@ pub fn gt_eq_binary_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericBinaryArray<OffsetSize>,
right: &[u8],
) -> Result<BooleanArray> {
- compare_op_scalar!(left, |a| a >= right)
+ compare_op_scalar(left, |a| a >= right)
}
/// Perform `left != right` operation on [`StringArray`] / [`LargeStringArray`].
@@ -994,7 +867,7 @@ pub fn neq_utf8<OffsetSize: OffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &GenericStringArray<OffsetSize>,
) -> Result<BooleanArray> {
- compare_op!(left, right, |a, b| a != b)
+ compare_op(left, right, |a, b| a != b)
}
/// Perform `left != right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar.
@@ -1002,7 +875,7 @@ pub fn neq_utf8_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &str,
) -> Result<BooleanArray> {
- compare_op_scalar!(left, |a| a != right)
+ compare_op_scalar(left, |a| a != right)
}
/// Perform `left < right` operation on [`StringArray`] / [`LargeStringArray`].
@@ -1010,7 +883,7 @@ pub fn lt_utf8<OffsetSize: OffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &GenericStringArray<OffsetSize>,
) -> Result<BooleanArray> {
- compare_op!(left, right, |a, b| a < b)
+ compare_op(left, right, |a, b| a < b)
}
/// Perform `left < right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar.
@@ -1018,7 +891,7 @@ pub fn lt_utf8_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &str,
) -> Result<BooleanArray> {
- compare_op_scalar!(left, |a| a < right)
+ compare_op_scalar(left, |a| a < right)
}
/// Perform `left <= right` operation on [`StringArray`] / [`LargeStringArray`].
@@ -1026,7 +899,7 @@ pub fn lt_eq_utf8<OffsetSize: OffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &GenericStringArray<OffsetSize>,
) -> Result<BooleanArray> {
- compare_op!(left, right, |a, b| a <= b)
+ compare_op(left, right, |a, b| a <= b)
}
/// Perform `left <= right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar.
@@ -1034,7 +907,7 @@ pub fn lt_eq_utf8_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &str,
) -> Result<BooleanArray> {
- compare_op_scalar!(left, |a| a <= right)
+ compare_op_scalar(left, |a| a <= right)
}
/// Perform `left > right` operation on [`StringArray`] / [`LargeStringArray`].
@@ -1042,7 +915,7 @@ pub fn gt_utf8<OffsetSize: OffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &GenericStringArray<OffsetSize>,
) -> Result<BooleanArray> {
- compare_op!(left, right, |a, b| a > b)
+ compare_op(left, right, |a, b| a > b)
}
/// Perform `left > right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar.
@@ -1050,7 +923,7 @@ pub fn gt_utf8_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &str,
) -> Result<BooleanArray> {
- compare_op_scalar!(left, |a| a > right)
+ compare_op_scalar(left, |a| a > right)
}
/// Perform `left >= right` operation on [`StringArray`] / [`LargeStringArray`].
@@ -1058,7 +931,7 @@ pub fn gt_eq_utf8<OffsetSize: OffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &GenericStringArray<OffsetSize>,
) -> Result<BooleanArray> {
- compare_op!(left, right, |a, b| a >= b)
+ compare_op(left, right, |a, b| a >= b)
}
/// Perform `left >= right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar.
@@ -1066,7 +939,7 @@ pub fn gt_eq_utf8_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &str,
) -> Result<BooleanArray> {
- compare_op_scalar!(left, |a| a >= right)
+ compare_op_scalar(left, |a| a >= right)
}
/// Calls $RIGHT.$TY() (e.g. `right.to_i128()`) with a nice error message.
@@ -1931,177 +1804,107 @@ where
Ok(BooleanArray::from(data))
}
-macro_rules! typed_cmp {
- ($LEFT: expr, $RIGHT: expr, $T: ident, $OP: ident) => {{
- let left = $LEFT.as_any().downcast_ref::<$T>().ok_or_else(|| {
- ArrowError::CastError(format!(
- "Left array cannot be cast to {}",
- type_name::<$T>()
- ))
- })?;
- let right = $RIGHT.as_any().downcast_ref::<$T>().ok_or_else(|| {
- ArrowError::CastError(format!(
- "Right array cannot be cast to {}",
- type_name::<$T>(),
- ))
- })?;
- $OP(left, right)
- }};
- ($LEFT: expr, $RIGHT: expr, $T: ident, $OP: ident, $TT: tt) => {{
- let left = $LEFT.as_any().downcast_ref::<$T>().ok_or_else(|| {
- ArrowError::CastError(format!(
- "Left array cannot be cast to {}",
- type_name::<$T>()
- ))
- })?;
- let right = $RIGHT.as_any().downcast_ref::<$T>().ok_or_else(|| {
- ArrowError::CastError(format!(
- "Right array cannot be cast to {}",
- type_name::<$T>(),
- ))
- })?;
- $OP::<$TT>(left, right)
- }};
+fn cmp_primitive_array<T: ArrowNumericType, F>(
+ left: &dyn Array,
+ right: &dyn Array,
+ op: F,
+) -> Result<BooleanArray>
+where
+ F: Fn(T::Native, T::Native) -> bool,
+{
+ let left_array = as_primitive_array::<T>(left);
+ let right_array = as_primitive_array::<T>(right);
+ compare_op(left_array, right_array, op)
}
macro_rules! typed_compares {
- ($LEFT: expr, $RIGHT: expr, $OP_BOOL: ident, $OP_PRIM: ident, $OP_STR: ident, $OP_BINARY: ident) => {{
+ ($LEFT: expr, $RIGHT: expr, $OP_BOOL: expr, $OP: expr) => {{
match ($LEFT.data_type(), $RIGHT.data_type()) {
(DataType::Boolean, DataType::Boolean) => {
- typed_cmp!($LEFT, $RIGHT, BooleanArray, $OP_BOOL)
+ compare_op(as_boolean_array($LEFT), as_boolean_array($RIGHT), $OP_BOOL)
}
(DataType::Int8, DataType::Int8) => {
- typed_cmp!($LEFT, $RIGHT, Int8Array, $OP_PRIM, Int8Type)
+ cmp_primitive_array::<Int8Type, _>($LEFT, $RIGHT, $OP)
}
(DataType::Int16, DataType::Int16) => {
- typed_cmp!($LEFT, $RIGHT, Int16Array, $OP_PRIM, Int16Type)
+ cmp_primitive_array::<Int16Type, _>($LEFT, $RIGHT, $OP)
}
(DataType::Int32, DataType::Int32) => {
- typed_cmp!($LEFT, $RIGHT, Int32Array, $OP_PRIM, Int32Type)
+ cmp_primitive_array::<Int32Type, _>($LEFT, $RIGHT, $OP)
}
(DataType::Int64, DataType::Int64) => {
- typed_cmp!($LEFT, $RIGHT, Int64Array, $OP_PRIM, Int64Type)
+ cmp_primitive_array::<Int64Type, _>($LEFT, $RIGHT, $OP)
}
(DataType::UInt8, DataType::UInt8) => {
- typed_cmp!($LEFT, $RIGHT, UInt8Array, $OP_PRIM, UInt8Type)
+ cmp_primitive_array::<UInt8Type, _>($LEFT, $RIGHT, $OP)
}
(DataType::UInt16, DataType::UInt16) => {
- typed_cmp!($LEFT, $RIGHT, UInt16Array, $OP_PRIM, UInt16Type)
+ cmp_primitive_array::<UInt16Type, _>($LEFT, $RIGHT, $OP)
}
(DataType::UInt32, DataType::UInt32) => {
- typed_cmp!($LEFT, $RIGHT, UInt32Array, $OP_PRIM, UInt32Type)
+ cmp_primitive_array::<UInt32Type, _>($LEFT, $RIGHT, $OP)
}
(DataType::UInt64, DataType::UInt64) => {
- typed_cmp!($LEFT, $RIGHT, UInt64Array, $OP_PRIM, UInt64Type)
+ cmp_primitive_array::<UInt64Type, _>($LEFT, $RIGHT, $OP)
}
(DataType::Float32, DataType::Float32) => {
- typed_cmp!($LEFT, $RIGHT, Float32Array, $OP_PRIM, Float32Type)
+ cmp_primitive_array::<Float32Type, _>($LEFT, $RIGHT, $OP)
}
(DataType::Float64, DataType::Float64) => {
- typed_cmp!($LEFT, $RIGHT, Float64Array, $OP_PRIM, Float64Type)
+ cmp_primitive_array::<Float64Type, _>($LEFT, $RIGHT, $OP)
}
(DataType::Utf8, DataType::Utf8) => {
- typed_cmp!($LEFT, $RIGHT, StringArray, $OP_STR, i32)
- }
- (DataType::LargeUtf8, DataType::LargeUtf8) => {
- typed_cmp!($LEFT, $RIGHT, LargeStringArray, $OP_STR, i64)
- }
- (DataType::Binary, DataType::Binary) => {
- typed_cmp!($LEFT, $RIGHT, BinaryArray, $OP_BINARY, i32)
- }
- (DataType::LargeBinary, DataType::LargeBinary) => {
- typed_cmp!($LEFT, $RIGHT, LargeBinaryArray, $OP_BINARY, i64)
+ compare_op(as_string_array($LEFT), as_string_array($RIGHT), $OP)
}
+ (DataType::LargeUtf8, DataType::LargeUtf8) => compare_op(
+ as_largestring_array($LEFT),
+ as_largestring_array($RIGHT),
+ $OP,
+ ),
+ (DataType::Binary, DataType::Binary) => compare_op(
+ as_generic_binary_array::<i32>($LEFT),
+ as_generic_binary_array::<i32>($RIGHT),
+ $OP,
+ ),
+ (DataType::LargeBinary, DataType::LargeBinary) => compare_op(
+ as_generic_binary_array::<i64>($LEFT),
+ as_generic_binary_array::<i64>($RIGHT),
+ $OP,
+ ),
(
DataType::Timestamp(TimeUnit::Nanosecond, _),
DataType::Timestamp(TimeUnit::Nanosecond, _),
- ) => {
- typed_cmp!(
- $LEFT,
- $RIGHT,
- TimestampNanosecondArray,
- $OP_PRIM,
- TimestampNanosecondType
- )
- }
+ ) => cmp_primitive_array::<TimestampNanosecondType, _>($LEFT, $RIGHT, $OP),
(
DataType::Timestamp(TimeUnit::Microsecond, _),
DataType::Timestamp(TimeUnit::Microsecond, _),
- ) => {
- typed_cmp!(
- $LEFT,
- $RIGHT,
- TimestampMicrosecondArray,
- $OP_PRIM,
- TimestampMicrosecondType
- )
- }
+ ) => cmp_primitive_array::<TimestampMicrosecondType, _>($LEFT, $RIGHT, $OP),
(
DataType::Timestamp(TimeUnit::Millisecond, _),
DataType::Timestamp(TimeUnit::Millisecond, _),
- ) => {
- typed_cmp!(
- $LEFT,
- $RIGHT,
- TimestampMillisecondArray,
- $OP_PRIM,
- TimestampMillisecondType
- )
- }
+ ) => cmp_primitive_array::<TimestampMillisecondType, _>($LEFT, $RIGHT, $OP),
(
DataType::Timestamp(TimeUnit::Second, _),
DataType::Timestamp(TimeUnit::Second, _),
- ) => {
- typed_cmp!(
- $LEFT,
- $RIGHT,
- TimestampSecondArray,
- $OP_PRIM,
- TimestampSecondType
- )
- }
+ ) => cmp_primitive_array::<TimestampSecondType, _>($LEFT, $RIGHT, $OP),
(DataType::Date32, DataType::Date32) => {
- typed_cmp!($LEFT, $RIGHT, Date32Array, $OP_PRIM, Date32Type)
+ cmp_primitive_array::<Date32Type, _>($LEFT, $RIGHT, $OP)
}
(DataType::Date64, DataType::Date64) => {
- typed_cmp!($LEFT, $RIGHT, Date64Array, $OP_PRIM, Date64Type)
+ cmp_primitive_array::<Date64Type, _>($LEFT, $RIGHT, $OP)
}
(
DataType::Interval(IntervalUnit::YearMonth),
DataType::Interval(IntervalUnit::YearMonth),
- ) => {
- typed_cmp!(
- $LEFT,
- $RIGHT,
- IntervalYearMonthArray,
- $OP_PRIM,
- IntervalYearMonthType
- )
- }
+ ) => cmp_primitive_array::<IntervalYearMonthType, _>($LEFT, $RIGHT, $OP),
(
DataType::Interval(IntervalUnit::DayTime),
DataType::Interval(IntervalUnit::DayTime),
- ) => {
- typed_cmp!(
- $LEFT,
- $RIGHT,
- IntervalDayTimeArray,
- $OP_PRIM,
- IntervalDayTimeType
- )
- }
+ ) => cmp_primitive_array::<IntervalDayTimeType, _>($LEFT, $RIGHT, $OP),
(
DataType::Interval(IntervalUnit::MonthDayNano),
DataType::Interval(IntervalUnit::MonthDayNano),
- ) => {
- typed_cmp!(
- $LEFT,
- $RIGHT,
- IntervalMonthDayNanoArray,
- $OP_PRIM,
- IntervalMonthDayNanoType
- )
- }
+ ) => cmp_primitive_array::<IntervalMonthDayNanoType, _>($LEFT, $RIGHT, $OP),
(t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!(
"Comparing arrays of type {} is not yet implemented",
t1
@@ -2410,7 +2213,7 @@ pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
DataType::Dictionary(_, _) => {
typed_dict_compares!(left, right, |a, b| a == b, |a, b| a == b)
}
- _ => typed_compares!(left, right, eq_bool, eq, eq_utf8, eq_binary),
+ _ => typed_compares!(left, right, |a, b| !(a ^ b), |a, b| a == b),
}
}
@@ -2435,7 +2238,7 @@ pub fn neq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
DataType::Dictionary(_, _) => {
typed_dict_compares!(left, right, |a, b| a != b, |a, b| a != b)
}
- _ => typed_compares!(left, right, neq_bool, neq, neq_utf8, neq_binary),
+ _ => typed_compares!(left, right, |a, b| (a ^ b), |a, b| a != b),
}
}
@@ -2460,7 +2263,7 @@ pub fn lt_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
DataType::Dictionary(_, _) => {
typed_dict_compares!(left, right, |a, b| a < b, |a, b| a < b)
}
- _ => typed_compares!(left, right, lt_bool, lt, lt_utf8, lt_binary),
+ _ => typed_compares!(left, right, |a, b| ((!a) & b), |a, b| a < b),
}
}
@@ -2484,7 +2287,7 @@ pub fn lt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
DataType::Dictionary(_, _) => {
typed_dict_compares!(left, right, |a, b| a <= b, |a, b| a <= b)
}
- _ => typed_compares!(left, right, lt_eq_bool, lt_eq, lt_eq_utf8, lt_eq_binary),
+ _ => typed_compares!(left, right, |a, b| !(a & (!b)), |a, b| a <= b),
}
}
@@ -2508,7 +2311,7 @@ pub fn gt_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
DataType::Dictionary(_, _) => {
typed_dict_compares!(left, right, |a, b| a > b, |a, b| a > b)
}
- _ => typed_compares!(left, right, gt_bool, gt, gt_utf8, gt_binary),
+ _ => typed_compares!(left, right, |a, b| (a & (!b)), |a, b| a > b),
}
}
@@ -2531,7 +2334,7 @@ pub fn gt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
DataType::Dictionary(_, _) => {
typed_dict_compares!(left, right, |a, b| a >= b, |a, b| a >= b)
}
- _ => typed_compares!(left, right, gt_eq_bool, gt_eq, gt_eq_utf8, gt_eq_binary),
+ _ => typed_compares!(left, right, |a, b| !((!a) & b), |a, b| a >= b),
}
}
@@ -2543,7 +2346,7 @@ where
#[cfg(feature = "simd")]
return simd_compare_op(left, right, T::eq, |a, b| a == b);
#[cfg(not(feature = "simd"))]
- return compare_op!(left, right, |a, b| a == b);
+ return compare_op(left, right, |a, b| a == b);
}
/// Perform `left == right` operation on a [`PrimitiveArray`] and a scalar value.
@@ -2554,7 +2357,7 @@ where
#[cfg(feature = "simd")]
return simd_compare_op_scalar(left, right, T::eq, |a, b| a == b);
#[cfg(not(feature = "simd"))]
- return compare_op_scalar!(left, |a| a == right);
+ return compare_op_scalar(left, |a| a == right);
}
/// Applies an unary and infallible comparison function to a primitive array.
@@ -2563,7 +2366,7 @@ where
T: ArrowNumericType,
F: Fn(T::Native) -> bool,
{
- return compare_op_scalar!(left, op);
+ compare_op_scalar(left, op)
}
/// Perform `left != right` operation on two [`PrimitiveArray`]s.
@@ -2574,7 +2377,7 @@ where
#[cfg(feature = "simd")]
return simd_compare_op(left, right, T::ne, |a, b| a != b);
#[cfg(not(feature = "simd"))]
- return compare_op!(left, right, |a, b| a != b);
+ return compare_op(left, right, |a, b| a != b);
}
/// Perform `left != right` operation on a [`PrimitiveArray`] and a scalar value.
@@ -2585,7 +2388,7 @@ where
#[cfg(feature = "simd")]
return simd_compare_op_scalar(left, right, T::ne, |a, b| a != b);
#[cfg(not(feature = "simd"))]
- return compare_op_scalar!(left, |a| a != right);
+ return compare_op_scalar(left, |a| a != right);
}
/// Perform `left < right` operation on two [`PrimitiveArray`]s. Null values are less than non-null
@@ -2597,7 +2400,7 @@ where
#[cfg(feature = "simd")]
return simd_compare_op(left, right, T::lt, |a, b| a < b);
#[cfg(not(feature = "simd"))]
- return compare_op!(left, right, |a, b| a < b);
+ return compare_op(left, right, |a, b| a < b);
}
/// Perform `left < right` operation on a [`PrimitiveArray`] and a scalar value.
@@ -2609,7 +2412,7 @@ where
#[cfg(feature = "simd")]
return simd_compare_op_scalar(left, right, T::lt, |a, b| a < b);
#[cfg(not(feature = "simd"))]
- return compare_op_scalar!(left, |a| a < right);
+ return compare_op_scalar(left, |a| a < right);
}
/// Perform `left <= right` operation on two [`PrimitiveArray`]s. Null values are less than non-null
@@ -2624,7 +2427,7 @@ where
#[cfg(feature = "simd")]
return simd_compare_op(left, right, T::le, |a, b| a <= b);
#[cfg(not(feature = "simd"))]
- return compare_op!(left, right, |a, b| a <= b);
+ return compare_op(left, right, |a, b| a <= b);
}
/// Perform `left <= right` operation on a [`PrimitiveArray`] and a scalar value.
@@ -2636,7 +2439,7 @@ where
#[cfg(feature = "simd")]
return simd_compare_op_scalar(left, right, T::le, |a, b| a <= b);
#[cfg(not(feature = "simd"))]
- return compare_op_scalar!(left, |a| a <= right);
+ return compare_op_scalar(left, |a| a <= right);
}
/// Perform `left > right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null
@@ -2648,7 +2451,7 @@ where
#[cfg(feature = "simd")]
return simd_compare_op(left, right, T::gt, |a, b| a > b);
#[cfg(not(feature = "simd"))]
- return compare_op!(left, right, |a, b| a > b);
+ return compare_op(left, right, |a, b| a > b);
}
/// Perform `left > right` operation on a [`PrimitiveArray`] and a scalar value.
@@ -2660,7 +2463,7 @@ where
#[cfg(feature = "simd")]
return simd_compare_op_scalar(left, right, T::gt, |a, b| a > b);
#[cfg(not(feature = "simd"))]
- return compare_op_scalar!(left, |a| a > right);
+ return compare_op_scalar(left, |a| a > right);
}
/// Perform `left >= right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null
@@ -2675,7 +2478,7 @@ where
#[cfg(feature = "simd")]
return simd_compare_op(left, right, T::ge, |a, b| a >= b);
#[cfg(not(feature = "simd"))]
- return compare_op!(left, right, |a, b| a >= b);
+ return compare_op(left, right, |a, b| a >= b);
}
/// Perform `left >= right` operation on a [`PrimitiveArray`] and a scalar value.
@@ -2687,7 +2490,7 @@ where
#[cfg(feature = "simd")]
return simd_compare_op_scalar(left, right, T::ge, |a, b| a >= b);
#[cfg(not(feature = "simd"))]
- return compare_op_scalar!(left, |a| a >= right);
+ return compare_op_scalar(left, |a| a >= right);
}
/// Checks if a [`GenericListArray`] contains a value in the [`PrimitiveArray`]