You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/01/01 11:06:02 UTC
[arrow] branch master updated: ARROW-10990: [Rust] Refactor simd
comparison kernels to avoid out of bounds reads
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 2d28778 ARROW-10990: [Rust] Refactor simd comparison kernels to avoid out of bounds reads
2d28778 is described below
commit 2d28778fe4c7a647ffe45f27de7be1d2b2b9a86a
Author: Jörn Horstmann <gi...@jhorstmann.net>
AuthorDate: Fri Jan 1 06:05:12 2021 -0500
ARROW-10990: [Rust] Refactor simd comparison kernels to avoid out of bounds reads
- [x] Adjust tests so the input data is bigger than one vector lane
- [x] Remove `value_slice` function when ARROW-10989 gets merged
Closes #8975 from jhorstmann/ARROW-10990-compare-kernels-out-of-bounds
Authored-by: Jörn Horstmann <gi...@jhorstmann.net>
Signed-off-by: Andrew Lamb <an...@nerdnetworks.org>
---
rust/arrow/src/compute/kernels/comparison.rs | 492 ++++++++++++++++-----------
rust/arrow/src/datatypes.rs | 21 +-
2 files changed, 302 insertions(+), 211 deletions(-)
diff --git a/rust/arrow/src/compute/kernels/comparison.rs b/rust/arrow/src/compute/kernels/comparison.rs
index 012e4f2..a55c951 100644
--- a/rust/arrow/src/compute/kernels/comparison.rs
+++ b/rust/arrow/src/compute/kernels/comparison.rs
@@ -371,16 +371,18 @@ pub fn gt_eq_utf8_scalar(left: &StringArray, right: &str) -> Result<BooleanArray
/// Helper function to perform boolean lambda function on values from two arrays using
/// SIMD.
#[cfg(simd_x86)]
-fn simd_compare_op<T, F>(
+fn simd_compare_op<T, SIMD_OP, SCALAR_OP>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
- op: F,
+ simd_op: SIMD_OP,
+ scalar_op: SCALAR_OP,
) -> Result<BooleanArray>
where
T: ArrowNumericType,
- F: Fn(T::Simd, T::Simd) -> T::SimdMask,
+ SIMD_OP: Fn(T::Simd, T::Simd) -> T::SimdMask,
+ SCALAR_OP: Fn(T::Native, T::Native) -> bool,
{
- use std::mem;
+ use std::borrow::BorrowMut;
let len = left.len();
if len != right.len() {
@@ -393,34 +395,60 @@ where
let null_bit_buffer = combine_option_bitmap(left.data_ref(), right.data_ref(), len)?;
let lanes = T::lanes();
- let mut result = MutableBuffer::new(left.len() * mem::size_of::<bool>());
-
- let rem = len % lanes;
+ let buffer_size = bit_util::ceil(len, 8);
+ let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false);
- for i in (0..len - rem).step_by(lanes) {
- let simd_left = T::load(unsafe { left.value_slice(i, lanes) });
- let simd_right = T::load(unsafe { right.value_slice(i, lanes) });
- let simd_result = op(simd_left, simd_right);
- T::bitmask(&simd_result, |b| {
- result.extend_from_slice(b);
- });
- }
+ // this is currently the case for all our datatypes and allows us to always append full bytes
+ assert!(
+ lanes % 8 == 0,
+ "Number of vector lanes must be multiple of 8"
+ );
+ let mut left_chunks = left.values().chunks_exact(lanes);
+ let mut right_chunks = right.values().chunks_exact(lanes);
+
+ let result_remainder = left_chunks
+ .borrow_mut()
+ .zip(right_chunks.borrow_mut())
+ .fold(
+ result.typed_data_mut(),
+ |result_slice, (left_slice, right_slice)| {
+ let simd_left = T::load(left_slice);
+ let simd_right = T::load(right_slice);
+ let simd_result = simd_op(simd_left, simd_right);
+
+ let bitmask = T::mask_to_u64(&simd_result);
+ let bytes = bitmask.to_le_bytes();
+ &result_slice[0..lanes / 8].copy_from_slice(&bytes[0..lanes / 8]);
+
+ &mut result_slice[lanes / 8..]
+ },
+ );
- if rem > 0 {
- //Soundness
- // This is not sound because it can read past the end of PrimitiveArray buffer (lanes is always greater than rem), see ARROW-10990
- let simd_left = T::load(unsafe { left.value_slice(len - rem, lanes) });
- let simd_right = T::load(unsafe { right.value_slice(len - rem, lanes) });
- let simd_result = op(simd_left, simd_right);
- let rem_buffer_size = (rem as f32 / 8f32).ceil() as usize;
- T::bitmask(&simd_result, |b| {
- result.extend_from_slice(&b[0..rem_buffer_size]);
+ let left_remainder = left_chunks.remainder();
+ let right_remainder = right_chunks.remainder();
+
+ assert_eq!(left_remainder.len(), right_remainder.len());
+
+ let remainder_bitmask = left_remainder
+ .iter()
+ .zip(right_remainder.iter())
+ .enumerate()
+ .fold(0_u64, |mut mask, (i, (scalar_left, scalar_right))| {
+ let bit = if scalar_op(*scalar_left, *scalar_right) {
+ 1_u64
+ } else {
+ 0_u64
+ };
+ mask |= bit << i;
+ mask
});
- }
+ let remainder_mask_as_bytes =
+ &remainder_bitmask.to_le_bytes()[0..bit_util::ceil(left_remainder.len(), 8)];
+ result_remainder.copy_from_slice(remainder_mask_as_bytes);
let data = ArrayData::new(
DataType::Boolean,
- left.len(),
+ len,
None,
null_bit_buffer,
0,
@@ -433,48 +461,78 @@ where
/// Helper function to perform boolean lambda function on values from an array and a scalar value using
/// SIMD.
#[cfg(simd_x86)]
-fn simd_compare_op_scalar<T, F>(
+fn simd_compare_op_scalar<T, SIMD_OP, SCALAR_OP>(
left: &PrimitiveArray<T>,
right: T::Native,
- op: F,
+ simd_op: SIMD_OP,
+ scalar_op: SCALAR_OP,
) -> Result<BooleanArray>
where
T: ArrowNumericType,
- F: Fn(T::Simd, T::Simd) -> T::SimdMask,
+ SIMD_OP: Fn(T::Simd, T::Simd) -> T::SimdMask,
+ SCALAR_OP: Fn(T::Native, T::Native) -> bool,
{
- use std::mem;
+ use std::borrow::BorrowMut;
let len = left.len();
- let null_bit_buffer = left.data().null_buffer().cloned();
+
let lanes = T::lanes();
- let mut result = MutableBuffer::new(left.len() * mem::size_of::<bool>());
+ let buffer_size = bit_util::ceil(len, 8);
+ let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false);
+
+ // this is currently the case for all our datatypes and allows us to always append full bytes
+ assert!(
+ lanes % 8 == 0,
+ "Number of vector lanes must be multiple of 8"
+ );
+ let mut left_chunks = left.values().chunks_exact(lanes);
let simd_right = T::init(right);
- let rem = len % lanes;
+ let result_remainder = left_chunks.borrow_mut().fold(
+ result.typed_data_mut(),
+ |result_slice, left_slice| {
+ let simd_left = T::load(left_slice);
+ let simd_result = simd_op(simd_left, simd_right);
- for i in (0..len - rem).step_by(lanes) {
- let simd_left = T::load(unsafe { left.value_slice(i, lanes) });
- let simd_result = op(simd_left, simd_right);
- T::bitmask(&simd_result, |b| {
- result.extend_from_slice(b);
- });
- }
+ let bitmask = T::mask_to_u64(&simd_result);
+ let bytes = bitmask.to_le_bytes();
+ &result_slice[0..lanes / 8].copy_from_slice(&bytes[0..lanes / 8]);
- if rem > 0 {
- //Soundness
- // This is not sound because it can read past the end of PrimitiveArray buffer (lanes is always greater than rem), see ARROW-10990
- let simd_left = T::load(unsafe { left.value_slice(len - rem, lanes) });
- let simd_result = op(simd_left, simd_right);
- let rem_buffer_size = (rem as f32 / 8f32).ceil() as usize;
- T::bitmask(&simd_result, |b| {
- result.extend_from_slice(&b[0..rem_buffer_size]);
- });
- }
+ &mut result_slice[lanes / 8..]
+ },
+ );
+
+ let left_remainder = left_chunks.remainder();
+
+ let remainder_bitmask =
+ left_remainder
+ .iter()
+ .enumerate()
+ .fold(0_u64, |mut mask, (i, scalar_left)| {
+ let bit = if scalar_op(*scalar_left, right) {
+ 1_u64
+ } else {
+ 0_u64
+ };
+ mask |= bit << i;
+ mask
+ });
+ let remainder_mask_as_bytes =
+ &remainder_bitmask.to_le_bytes()[0..bit_util::ceil(left_remainder.len(), 8)];
+ result_remainder.copy_from_slice(remainder_mask_as_bytes);
+
+ let null_bit_buffer = left
+ .data_ref()
+ .null_buffer()
+ .map(|b| b.bit_slice(left.offset(), left.len()));
+
+ // null count is the same as in the input since the right side of the scalar comparison cannot be null
+ let null_count = left.null_count();
let data = ArrayData::new(
DataType::Boolean,
- left.len(),
- None,
+ len,
+ Some(null_count),
null_bit_buffer,
0,
vec![result.freeze()],
@@ -489,7 +547,7 @@ where
T: ArrowNumericType,
{
#[cfg(simd_x86)]
- return simd_compare_op(left, right, T::eq);
+ return simd_compare_op(left, right, T::eq, |a, b| a == b);
#[cfg(not(simd_x86))]
return compare_op!(left, right, |a, b| a == b);
}
@@ -500,7 +558,7 @@ where
T: ArrowNumericType,
{
#[cfg(simd_x86)]
- return simd_compare_op_scalar(left, right, T::eq);
+ return simd_compare_op_scalar(left, right, T::eq, |a, b| a == b);
#[cfg(not(simd_x86))]
return compare_op_scalar!(left, right, |a, b| a == b);
}
@@ -511,7 +569,7 @@ where
T: ArrowNumericType,
{
#[cfg(simd_x86)]
- return simd_compare_op(left, right, T::ne);
+ return simd_compare_op(left, right, T::ne, |a, b| a != b);
#[cfg(not(simd_x86))]
return compare_op!(left, right, |a, b| a != b);
}
@@ -522,7 +580,7 @@ where
T: ArrowNumericType,
{
#[cfg(simd_x86)]
- return simd_compare_op_scalar(left, right, T::ne);
+ return simd_compare_op_scalar(left, right, T::ne, |a, b| a != b);
#[cfg(not(simd_x86))]
return compare_op_scalar!(left, right, |a, b| a != b);
}
@@ -534,7 +592,7 @@ where
T: ArrowNumericType,
{
#[cfg(simd_x86)]
- return simd_compare_op(left, right, T::lt);
+ return simd_compare_op(left, right, T::lt, |a, b| a < b);
#[cfg(not(simd_x86))]
return compare_op!(left, right, |a, b| a < b);
}
@@ -546,7 +604,7 @@ where
T: ArrowNumericType,
{
#[cfg(simd_x86)]
- return simd_compare_op_scalar(left, right, T::lt);
+ return simd_compare_op_scalar(left, right, T::lt, |a, b| a < b);
#[cfg(not(simd_x86))]
return compare_op_scalar!(left, right, |a, b| a < b);
}
@@ -561,7 +619,7 @@ where
T: ArrowNumericType,
{
#[cfg(simd_x86)]
- return simd_compare_op(left, right, T::le);
+ return simd_compare_op(left, right, T::le, |a, b| a <= b);
#[cfg(not(simd_x86))]
return compare_op!(left, right, |a, b| a <= b);
}
@@ -573,7 +631,7 @@ where
T: ArrowNumericType,
{
#[cfg(simd_x86)]
- return simd_compare_op_scalar(left, right, T::le);
+ return simd_compare_op_scalar(left, right, T::le, |a, b| a <= b);
#[cfg(not(simd_x86))]
return compare_op_scalar!(left, right, |a, b| a <= b);
}
@@ -585,7 +643,7 @@ where
T: ArrowNumericType,
{
#[cfg(simd_x86)]
- return simd_compare_op(left, right, T::gt);
+ return simd_compare_op(left, right, T::gt, |a, b| a > b);
#[cfg(not(simd_x86))]
return compare_op!(left, right, |a, b| a > b);
}
@@ -597,7 +655,7 @@ where
T: ArrowNumericType,
{
#[cfg(simd_x86)]
- return simd_compare_op_scalar(left, right, T::gt);
+ return simd_compare_op_scalar(left, right, T::gt, |a, b| a > b);
#[cfg(not(simd_x86))]
return compare_op_scalar!(left, right, |a, b| a > b);
}
@@ -612,7 +670,7 @@ where
T: ArrowNumericType,
{
#[cfg(simd_x86)]
- return simd_compare_op(left, right, T::ge);
+ return simd_compare_op(left, right, T::ge, |a, b| a >= b);
#[cfg(not(simd_x86))]
return compare_op!(left, right, |a, b| a >= b);
}
@@ -624,7 +682,7 @@ where
T: ArrowNumericType,
{
#[cfg(simd_x86)]
- return simd_compare_op_scalar(left, right, T::ge);
+ return simd_compare_op_scalar(left, right, T::ge, |a, b| a >= b);
#[cfg(not(simd_x86))]
return compare_op_scalar!(left, right, |a, b| a >= b);
}
@@ -752,33 +810,57 @@ fn new_all_set_buffer(len: usize) -> Buffer {
buffer.freeze()
}
+// disable wrapping inside literal vectors used for test data and assertions
+#[rustfmt::skip::macros(vec)]
#[cfg(test)]
mod tests {
use super::*;
use crate::datatypes::{Int8Type, ToByteSlice};
- use crate::{array::Int32Array, datatypes::Field};
+ use crate::{array::Int32Array, array::Int64Array, datatypes::Field};
+
+ /// Evaluate `KERNEL` with two vectors as inputs and assert against the expected output.
+ /// `A_VEC` and `B_VEC` can be of type `Vec<i64>` or `Vec<Option<i64>>`.
+ /// `EXPECTED` can be either `Vec<bool>` or `Vec<Option<bool>>`.
+ /// The main reason for this macro is that inputs and outputs align nicely after `cargo fmt`.
+ macro_rules! cmp_i64 {
+ ($KERNEL:ident, $A_VEC:expr, $B_VEC:expr, $EXPECTED:expr) => {
+ let a = Int64Array::from($A_VEC);
+ let b = Int64Array::from($B_VEC);
+ let c = $KERNEL(&a, &b).unwrap();
+ assert_eq!(BooleanArray::from($EXPECTED), c);
+ };
+ }
+
+ /// Evaluate `KERNEL` with one vector and one scalar as inputs and assert against the expected output.
+ /// `A_VEC` can be of type `Vec<i64>` or `Vec<Option<i64>>`.
+ /// `EXPECTED` can be either `Vec<bool>` or `Vec<Option<bool>>`.
+ /// The main reason for this macro is that inputs and outputs align nicely after `cargo fmt`.
+ macro_rules! cmp_i64_scalar {
+ ($KERNEL:ident, $A_VEC:expr, $B:literal, $EXPECTED:expr) => {
+ let a = Int64Array::from($A_VEC);
+ let c = $KERNEL(&a, $B).unwrap();
+ assert_eq!(BooleanArray::from($EXPECTED), c);
+ };
+ }
#[test]
fn test_primitive_array_eq() {
- let a = Int32Array::from(vec![8, 8, 8, 8, 8]);
- let b = Int32Array::from(vec![6, 7, 8, 9, 10]);
- let c = eq(&a, &b).unwrap();
- assert_eq!(false, c.value(0));
- assert_eq!(false, c.value(1));
- assert_eq!(true, c.value(2));
- assert_eq!(false, c.value(3));
- assert_eq!(false, c.value(4));
+ cmp_i64!(
+ eq,
+ vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
+ vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
+ vec![false, false, true, false, false, false, false, true, false, false]
+ );
}
#[test]
fn test_primitive_array_eq_scalar() {
- let a = Int32Array::from(vec![6, 7, 8, 9, 10]);
- let c = eq_scalar(&a, 8).unwrap();
- assert_eq!(false, c.value(0));
- assert_eq!(false, c.value(1));
- assert_eq!(true, c.value(2));
- assert_eq!(false, c.value(3));
- assert_eq!(false, c.value(4));
+ cmp_i64_scalar!(
+ eq_scalar,
+ vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
+ 8,
+ vec![false, false, true, false, false, false, false, true, false, false]
+ );
}
#[test]
@@ -797,193 +879,205 @@ mod tests {
#[test]
fn test_primitive_array_neq() {
- let a = Int32Array::from(vec![8, 8, 8, 8, 8]);
- let b = Int32Array::from(vec![6, 7, 8, 9, 10]);
- let c = neq(&a, &b).unwrap();
- assert_eq!(true, c.value(0));
- assert_eq!(true, c.value(1));
- assert_eq!(false, c.value(2));
- assert_eq!(true, c.value(3));
- assert_eq!(true, c.value(4));
+ cmp_i64!(
+ neq,
+ vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
+ vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
+ vec![true, true, false, true, true, true, true, false, true, true]
+ );
}
#[test]
fn test_primitive_array_neq_scalar() {
- let a = Int32Array::from(vec![6, 7, 8, 9, 10]);
- let c = neq_scalar(&a, 8).unwrap();
- assert_eq!(true, c.value(0));
- assert_eq!(true, c.value(1));
- assert_eq!(false, c.value(2));
- assert_eq!(true, c.value(3));
- assert_eq!(true, c.value(4));
+ cmp_i64_scalar!(
+ neq_scalar,
+ vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
+ 8,
+ vec![true, true, false, true, true, true, true, false, true, true]
+ );
}
#[test]
fn test_primitive_array_lt() {
- let a = Int32Array::from(vec![8, 8, 8, 8, 8]);
- let b = Int32Array::from(vec![6, 7, 8, 9, 10]);
- let c = lt(&a, &b).unwrap();
- assert_eq!(false, c.value(0));
- assert_eq!(false, c.value(1));
- assert_eq!(false, c.value(2));
- assert_eq!(true, c.value(3));
- assert_eq!(true, c.value(4));
+ cmp_i64!(
+ lt,
+ vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
+ vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
+ vec![false, false, false, true, true, false, false, false, true, true]
+ );
}
#[test]
fn test_primitive_array_lt_scalar() {
- let a = Int32Array::from(vec![6, 7, 8, 9, 10]);
- let c = lt_scalar(&a, 8).unwrap();
- assert_eq!(true, c.value(0));
- assert_eq!(true, c.value(1));
- assert_eq!(false, c.value(2));
- assert_eq!(false, c.value(3));
- assert_eq!(false, c.value(4));
+ cmp_i64_scalar!(
+ lt_scalar,
+ vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
+ 8,
+ vec![true, true, false, false, false, true, true, false, false, false]
+ );
}
#[test]
fn test_primitive_array_lt_nulls() {
- let a = Int32Array::from(vec![None, None, Some(1)]);
- let b = Int32Array::from(vec![None, Some(1), None]);
- let c = lt(&a, &b).unwrap();
- assert_eq!(false, c.value(0));
- assert_eq!(true, c.value(1));
- assert_eq!(false, c.value(2));
+ cmp_i64!(
+ lt,
+ vec![None, None, Some(1), Some(1), None, None, Some(2), Some(2),],
+ vec![None, Some(1), None, Some(1), None, Some(3), None, Some(3),],
+ vec![None, None, None, Some(false), None, None, None, Some(true)]
+ );
}
#[test]
fn test_primitive_array_lt_scalar_nulls() {
- let a = Int32Array::from(vec![None, Some(1), Some(2)]);
- let c = lt_scalar(&a, 2).unwrap();
- assert_eq!(true, c.value(0));
- assert_eq!(true, c.value(1));
- assert_eq!(false, c.value(2));
+ cmp_i64_scalar!(
+ lt_scalar,
+ vec![None, Some(1), Some(2), Some(3), None, Some(1), Some(2), Some(3), Some(2), None],
+ 2,
+ vec![None, Some(true), Some(false), Some(false), None, Some(true), Some(false), Some(false), Some(false), None]
+ );
}
#[test]
fn test_primitive_array_lt_eq() {
- let a = Int32Array::from(vec![8, 8, 8, 8, 8]);
- let b = Int32Array::from(vec![6, 7, 8, 9, 10]);
- let c = lt_eq(&a, &b).unwrap();
- assert_eq!(false, c.value(0));
- assert_eq!(false, c.value(1));
- assert_eq!(true, c.value(2));
- assert_eq!(true, c.value(3));
- assert_eq!(true, c.value(4));
+ cmp_i64!(
+ lt_eq,
+ vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
+ vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
+ vec![false, false, true, true, true, false, false, true, true, true]
+ );
}
#[test]
fn test_primitive_array_lt_eq_scalar() {
- let a = Int32Array::from(vec![6, 7, 8, 9, 10]);
- let c = lt_eq_scalar(&a, 8).unwrap();
- assert_eq!(true, c.value(0));
- assert_eq!(true, c.value(1));
- assert_eq!(true, c.value(2));
- assert_eq!(false, c.value(3));
- assert_eq!(false, c.value(4));
+ cmp_i64_scalar!(
+ lt_eq_scalar,
+ vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
+ 8,
+ vec![true, true, true, false, false, true, true, true, false, false]
+ );
}
#[test]
fn test_primitive_array_lt_eq_nulls() {
- let a = Int32Array::from(vec![None, None, Some(1)]);
- let b = Int32Array::from(vec![None, Some(1), None]);
- let c = lt_eq(&a, &b).unwrap();
- assert_eq!(true, c.value(0));
- assert_eq!(true, c.value(1));
- assert_eq!(false, c.value(2));
+ cmp_i64!(
+ lt_eq,
+ vec![None, None, Some(1), None, None, Some(1), None, None, Some(1)],
+ vec![None, Some(1), Some(0), None, Some(1), Some(2), None, None, Some(3)],
+ vec![None, None, Some(false), None, None, Some(true), None, None, Some(true)]
+ );
}
#[test]
fn test_primitive_array_lt_eq_scalar_nulls() {
- let a = Int32Array::from(vec![None, Some(1), Some(2)]);
- let c = lt_eq_scalar(&a, 1).unwrap();
- assert_eq!(true, c.value(0));
- assert_eq!(true, c.value(1));
- assert_eq!(false, c.value(2));
+ cmp_i64_scalar!(
+ lt_eq_scalar,
+ vec![None, Some(1), Some(2), None, Some(1), Some(2), None, Some(1), Some(2)],
+ 1,
+ vec![None, Some(true), Some(false), None, Some(true), Some(false), None, Some(true), Some(false)]
+ );
}
#[test]
fn test_primitive_array_gt() {
- let a = Int32Array::from(vec![8, 8, 8, 8, 8]);
- let b = Int32Array::from(vec![6, 7, 8, 9, 10]);
- let c = gt(&a, &b).unwrap();
- assert_eq!(true, c.value(0));
- assert_eq!(true, c.value(1));
- assert_eq!(false, c.value(2));
- assert_eq!(false, c.value(3));
- assert_eq!(false, c.value(4));
+ cmp_i64!(
+ gt,
+ vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
+ vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
+ vec![true, true, false, false, false, true, true, false, false, false]
+ );
}
#[test]
fn test_primitive_array_gt_scalar() {
- let a = Int32Array::from(vec![6, 7, 8, 9, 10]);
- let c = gt_scalar(&a, 8).unwrap();
- assert_eq!(false, c.value(0));
- assert_eq!(false, c.value(1));
- assert_eq!(false, c.value(2));
- assert_eq!(true, c.value(3));
- assert_eq!(true, c.value(4));
+ cmp_i64_scalar!(
+ gt_scalar,
+ vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
+ 8,
+ vec![false, false, false, true, true, false, false, false, true, true]
+ );
}
#[test]
fn test_primitive_array_gt_nulls() {
- let a = Int32Array::from(vec![None, None, Some(1)]);
- let b = Int32Array::from(vec![None, Some(1), None]);
- let c = gt(&a, &b).unwrap();
- assert_eq!(false, c.value(0));
- assert_eq!(false, c.value(1));
- assert_eq!(true, c.value(2));
+ cmp_i64!(
+ gt,
+ vec![None, None, Some(1), None, None, Some(2), None, None, Some(3)],
+ vec![None, Some(1), Some(1), None, Some(1), Some(1), None, Some(1), Some(1)],
+ vec![None, None, Some(false), None, None, Some(true), None, None, Some(true)]
+ );
}
#[test]
fn test_primitive_array_gt_scalar_nulls() {
- let a = Int32Array::from(vec![None, Some(1), Some(2)]);
- let c = gt_scalar(&a, 1).unwrap();
- assert_eq!(false, c.value(0));
- assert_eq!(false, c.value(1));
- assert_eq!(true, c.value(2));
+ cmp_i64_scalar!(
+ gt_scalar,
+ vec![None, Some(1), Some(2), None, Some(1), Some(2), None, Some(1), Some(2)],
+ 1,
+ vec![None, Some(false), Some(true), None, Some(false), Some(true), None, Some(false), Some(true)]
+ );
}
#[test]
fn test_primitive_array_gt_eq() {
- let a = Int32Array::from(vec![8, 8, 8, 8, 8]);
- let b = Int32Array::from(vec![6, 7, 8, 9, 10]);
- let c = gt_eq(&a, &b).unwrap();
- assert_eq!(true, c.value(0));
- assert_eq!(true, c.value(1));
- assert_eq!(true, c.value(2));
- assert_eq!(false, c.value(3));
- assert_eq!(false, c.value(4));
+ cmp_i64!(
+ gt_eq,
+ vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
+ vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
+ vec![true, true, true, false, false, true, true, true, false, false]
+ );
}
#[test]
fn test_primitive_array_gt_eq_scalar() {
- let a = Int32Array::from(vec![6, 7, 8, 9, 10]);
- let c = gt_eq_scalar(&a, 8).unwrap();
- assert_eq!(false, c.value(0));
- assert_eq!(false, c.value(1));
- assert_eq!(true, c.value(2));
- assert_eq!(true, c.value(3));
- assert_eq!(true, c.value(4));
+ cmp_i64_scalar!(
+ gt_eq_scalar,
+ vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
+ 8,
+ vec![false, false, true, true, true, false, false, true, true, true]
+ );
}
#[test]
fn test_primitive_array_gt_eq_nulls() {
- let a = Int32Array::from(vec![None, None, Some(1)]);
- let b = Int32Array::from(vec![None, Some(1), None]);
- let c = gt_eq(&a, &b).unwrap();
- assert_eq!(true, c.value(0));
- assert_eq!(false, c.value(1));
- assert_eq!(true, c.value(2));
+ cmp_i64!(
+ gt_eq,
+ vec![None, None, Some(1), None, Some(1), Some(2), None, None, Some(1)],
+ vec![None, Some(1), None, None, Some(1), Some(1), None, Some(2), Some(2)],
+ vec![None, None, None, None, Some(true), Some(true), None, None, Some(false)]
+ );
}
#[test]
fn test_primitive_array_gt_eq_scalar_nulls() {
- let a = Int32Array::from(vec![None, Some(1), Some(2)]);
- let c = gt_eq_scalar(&a, 1).unwrap();
- assert_eq!(false, c.value(0));
- assert_eq!(true, c.value(1));
- assert_eq!(true, c.value(2));
+ cmp_i64_scalar!(
+ gt_eq_scalar,
+ vec![None, Some(1), Some(2), None, Some(2), Some(3), None, Some(3), Some(4)],
+ 2,
+ vec![None, Some(false), Some(true), None, Some(true), Some(true), None, Some(true), Some(true)]
+ );
+ }
+
+ #[test]
+ fn test_primitive_array_compare_slice() {
+ let a: Int32Array = (0..100).map(Some).collect();
+ let a = a.slice(50, 50);
+ let a = a.as_any().downcast_ref::<Int32Array>().unwrap();
+ let b: Int32Array = (100..200).map(Some).collect();
+ let b = b.slice(50, 50);
+ let b = b.as_any().downcast_ref::<Int32Array>().unwrap();
+ let actual = lt(&a, &b).unwrap();
+ let expected: BooleanArray = (0..50).map(|_| Some(true)).collect();
+ assert_eq!(expected, actual);
+ }
+
+ #[test]
+ fn test_primitive_array_compare_scalar_slice() {
+ let a: Int32Array = (0..100).map(Some).collect();
+ let a = a.slice(50, 50);
+ let a = a.as_any().downcast_ref::<Int32Array>().unwrap();
+ let actual = lt_scalar(&a, 200).unwrap();
+ let expected: BooleanArray = (0..50).map(|_| Some(true)).collect();
+ assert_eq!(expected, actual);
}
#[test]
diff --git a/rust/arrow/src/datatypes.rs b/rust/arrow/src/datatypes.rs
index 125adc4..939cabf 100644
--- a/rust/arrow/src/datatypes.rs
+++ b/rust/arrow/src/datatypes.rs
@@ -542,14 +542,13 @@ where
/// The number of bits used corresponds to the number of lanes of this type
fn mask_from_u64(mask: u64) -> Self::SimdMask;
+ /// Creates a bitmask from the given SIMD mask.
+ /// Each bit corresponds to one vector lane, starting with the least-significant bit.
+ fn mask_to_u64(mask: &Self::SimdMask) -> u64;
+
/// Gets the value of a single lane in a SIMD mask
fn mask_get(mask: &Self::SimdMask, idx: usize) -> bool;
- /// Gets the bitmask for a SimdMask as a byte slice and passes it to the closure used as the action parameter
- fn bitmask<T>(mask: &Self::SimdMask, action: T)
- where
- T: FnMut(&[u8]);
-
/// Sets the value of a single lane of a SIMD mask
fn mask_set(mask: Self::SimdMask, idx: usize, value: bool) -> Self::SimdMask;
@@ -715,15 +714,13 @@ macro_rules! make_numeric_type {
}
#[inline]
- fn mask_get(mask: &Self::SimdMask, idx: usize) -> bool {
- unsafe { mask.extract_unchecked(idx) }
+ fn mask_to_u64(mask: &Self::SimdMask) -> u64 {
+ mask.bitmask() as u64
}
- fn bitmask<T>(mask: &Self::SimdMask, mut action: T)
- where
- T: FnMut(&[u8]),
- {
- action(mask.bitmask().to_byte_slice());
+ #[inline]
+ fn mask_get(mask: &Self::SimdMask, idx: usize) -> bool {
+ unsafe { mask.extract_unchecked(idx) }
}
#[inline]