Posted to commits@arrow.apache.org by al...@apache.org on 2021/01/01 11:06:02 UTC

[arrow] branch master updated: ARROW-10990: [Rust] Refactor simd comparison kernels to avoid out of bounds reads

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 2d28778  ARROW-10990: [Rust] Refactor simd comparison kernels to avoid out of bounds reads
2d28778 is described below

commit 2d28778fe4c7a647ffe45f27de7be1d2b2b9a86a
Author: Jörn Horstmann <gi...@jhorstmann.net>
AuthorDate: Fri Jan 1 06:05:12 2021 -0500

    ARROW-10990: [Rust] Refactor simd comparison kernels to avoid out of bounds reads
    
    - [x] Adjust tests so the input data is bigger than one vector lane
    - [x] Remove `value_slice` function when ARROW-10989 gets merged
    
    Closes #8975 from jhorstmann/ARROW-10990-compare-kernels-out-of-bounds
    
    Authored-by: Jörn Horstmann <gi...@jhorstmann.net>
    Signed-off-by: Andrew Lamb <an...@nerdnetworks.org>
---
 rust/arrow/src/compute/kernels/comparison.rs | 492 ++++++++++++++++-----------
 rust/arrow/src/datatypes.rs                  |  21 +-
 2 files changed, 302 insertions(+), 211 deletions(-)
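
Before the diff itself, a minimal sketch of the pattern the commit switches to may help: full vector-width chunks are taken with `chunks_exact`, and the remainder that does not fill a whole vector is compared with an ordinary scalar closure, so nothing is ever loaded from past the end of the value buffer. This is a simplification, not the Arrow kernel's API: `LANES` and `compare_chunked` are illustrative names, the real function works on `PrimitiveArray<T>` via the `ArrowNumericType` SIMD traits, and it packs results into a bit buffer rather than a `Vec<bool>`.

    // Illustrative sketch of the chunks_exact + scalar-remainder pattern.
    // LANES stands in for T::lanes(); the real kernel uses T::load/simd_op here.
    const LANES: usize = 8;

    fn compare_chunked(left: &[i64], right: &[i64], op: impl Fn(i64, i64) -> bool) -> Vec<bool> {
        assert_eq!(left.len(), right.len());
        let mut out = Vec::with_capacity(left.len());

        let mut left_chunks = left.chunks_exact(LANES);
        let mut right_chunks = right.chunks_exact(LANES);

        // Full chunks: always exactly LANES values, so a SIMD load stays in bounds.
        for (l, r) in left_chunks.by_ref().zip(right_chunks.by_ref()) {
            out.extend(l.iter().zip(r).map(|(a, b)| op(*a, *b)));
        }

        // Remainder: shorter than one vector, handled with plain scalar compares
        // instead of a SIMD load that would read past the end of the buffer.
        for (a, b) in left_chunks.remainder().iter().zip(right_chunks.remainder()) {
            out.push(op(*a, *b));
        }
        out
    }

In the committed kernel, the chunk loop body becomes `T::load` plus `simd_op`, each chunk's mask is written out as `lanes / 8` bytes via `T::mask_to_u64`, and the remainder is packed into a bitmask as shown in the fold in the diff below.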

diff --git a/rust/arrow/src/compute/kernels/comparison.rs b/rust/arrow/src/compute/kernels/comparison.rs
index 012e4f2..a55c951 100644
--- a/rust/arrow/src/compute/kernels/comparison.rs
+++ b/rust/arrow/src/compute/kernels/comparison.rs
@@ -371,16 +371,18 @@ pub fn gt_eq_utf8_scalar(left: &StringArray, right: &str) -> Result<BooleanArray
 /// Helper function to perform boolean lambda function on values from two arrays using
 /// SIMD.
 #[cfg(simd_x86)]
-fn simd_compare_op<T, F>(
+fn simd_compare_op<T, SIMD_OP, SCALAR_OP>(
     left: &PrimitiveArray<T>,
     right: &PrimitiveArray<T>,
-    op: F,
+    simd_op: SIMD_OP,
+    scalar_op: SCALAR_OP,
 ) -> Result<BooleanArray>
 where
     T: ArrowNumericType,
-    F: Fn(T::Simd, T::Simd) -> T::SimdMask,
+    SIMD_OP: Fn(T::Simd, T::Simd) -> T::SimdMask,
+    SCALAR_OP: Fn(T::Native, T::Native) -> bool,
 {
-    use std::mem;
+    use std::borrow::BorrowMut;
 
     let len = left.len();
     if len != right.len() {
@@ -393,34 +395,60 @@ where
     let null_bit_buffer = combine_option_bitmap(left.data_ref(), right.data_ref(), len)?;
 
     let lanes = T::lanes();
-    let mut result = MutableBuffer::new(left.len() * mem::size_of::<bool>());
-
-    let rem = len % lanes;
+    let buffer_size = bit_util::ceil(len, 8);
+    let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false);
 
-    for i in (0..len - rem).step_by(lanes) {
-        let simd_left = T::load(unsafe { left.value_slice(i, lanes) });
-        let simd_right = T::load(unsafe { right.value_slice(i, lanes) });
-        let simd_result = op(simd_left, simd_right);
-        T::bitmask(&simd_result, |b| {
-            result.extend_from_slice(b);
-        });
-    }
+    // this is currently the case for all our datatypes and allows us to always append full bytes
+    assert!(
+        lanes % 8 == 0,
+        "Number of vector lanes must be multiple of 8"
+    );
+    let mut left_chunks = left.values().chunks_exact(lanes);
+    let mut right_chunks = right.values().chunks_exact(lanes);
+
+    let result_remainder = left_chunks
+        .borrow_mut()
+        .zip(right_chunks.borrow_mut())
+        .fold(
+            result.typed_data_mut(),
+            |result_slice, (left_slice, right_slice)| {
+                let simd_left = T::load(left_slice);
+                let simd_right = T::load(right_slice);
+                let simd_result = simd_op(simd_left, simd_right);
+
+                let bitmask = T::mask_to_u64(&simd_result);
+                let bytes = bitmask.to_le_bytes();
+                &result_slice[0..lanes / 8].copy_from_slice(&bytes[0..lanes / 8]);
+
+                &mut result_slice[lanes / 8..]
+            },
+        );
 
-    if rem > 0 {
-        //Soundness
-        //	This is not sound because it can read past the end of PrimitiveArray buffer (lanes is always greater than rem), see ARROW-10990
-        let simd_left = T::load(unsafe { left.value_slice(len - rem, lanes) });
-        let simd_right = T::load(unsafe { right.value_slice(len - rem, lanes) });
-        let simd_result = op(simd_left, simd_right);
-        let rem_buffer_size = (rem as f32 / 8f32).ceil() as usize;
-        T::bitmask(&simd_result, |b| {
-            result.extend_from_slice(&b[0..rem_buffer_size]);
+    let left_remainder = left_chunks.remainder();
+    let right_remainder = right_chunks.remainder();
+
+    assert_eq!(left_remainder.len(), right_remainder.len());
+
+    let remainder_bitmask = left_remainder
+        .iter()
+        .zip(right_remainder.iter())
+        .enumerate()
+        .fold(0_u64, |mut mask, (i, (scalar_left, scalar_right))| {
+            let bit = if scalar_op(*scalar_left, *scalar_right) {
+                1_u64
+            } else {
+                0_u64
+            };
+            mask |= bit << i;
+            mask
         });
-    }
+    let remainder_mask_as_bytes =
+        &remainder_bitmask.to_le_bytes()[0..bit_util::ceil(left_remainder.len(), 8)];
+    result_remainder.copy_from_slice(remainder_mask_as_bytes);
 
     let data = ArrayData::new(
         DataType::Boolean,
-        left.len(),
+        len,
         None,
         null_bit_buffer,
         0,
@@ -433,48 +461,78 @@ where
 /// Helper function to perform boolean lambda function on values from an array and a scalar value using
 /// SIMD.
 #[cfg(simd_x86)]
-fn simd_compare_op_scalar<T, F>(
+fn simd_compare_op_scalar<T, SIMD_OP, SCALAR_OP>(
     left: &PrimitiveArray<T>,
     right: T::Native,
-    op: F,
+    simd_op: SIMD_OP,
+    scalar_op: SCALAR_OP,
 ) -> Result<BooleanArray>
 where
     T: ArrowNumericType,
-    F: Fn(T::Simd, T::Simd) -> T::SimdMask,
+    SIMD_OP: Fn(T::Simd, T::Simd) -> T::SimdMask,
+    SCALAR_OP: Fn(T::Native, T::Native) -> bool,
 {
-    use std::mem;
+    use std::borrow::BorrowMut;
 
     let len = left.len();
-    let null_bit_buffer = left.data().null_buffer().cloned();
+
     let lanes = T::lanes();
-    let mut result = MutableBuffer::new(left.len() * mem::size_of::<bool>());
+    let buffer_size = bit_util::ceil(len, 8);
+    let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false);
+
+    // this is currently the case for all our datatypes and allows us to always append full bytes
+    assert!(
+        lanes % 8 == 0,
+        "Number of vector lanes must be multiple of 8"
+    );
+    let mut left_chunks = left.values().chunks_exact(lanes);
     let simd_right = T::init(right);
 
-    let rem = len % lanes;
+    let result_remainder = left_chunks.borrow_mut().fold(
+        result.typed_data_mut(),
+        |result_slice, left_slice| {
+            let simd_left = T::load(left_slice);
+            let simd_result = simd_op(simd_left, simd_right);
 
-    for i in (0..len - rem).step_by(lanes) {
-        let simd_left = T::load(unsafe { left.value_slice(i, lanes) });
-        let simd_result = op(simd_left, simd_right);
-        T::bitmask(&simd_result, |b| {
-            result.extend_from_slice(b);
-        });
-    }
+            let bitmask = T::mask_to_u64(&simd_result);
+            let bytes = bitmask.to_le_bytes();
+            &result_slice[0..lanes / 8].copy_from_slice(&bytes[0..lanes / 8]);
 
-    if rem > 0 {
-        //Soundness
-        //	This is not sound because it can read past the end of PrimitiveArray buffer (lanes is always greater than rem), see ARROW-10990
-        let simd_left = T::load(unsafe { left.value_slice(len - rem, lanes) });
-        let simd_result = op(simd_left, simd_right);
-        let rem_buffer_size = (rem as f32 / 8f32).ceil() as usize;
-        T::bitmask(&simd_result, |b| {
-            result.extend_from_slice(&b[0..rem_buffer_size]);
-        });
-    }
+            &mut result_slice[lanes / 8..]
+        },
+    );
+
+    let left_remainder = left_chunks.remainder();
+
+    let remainder_bitmask =
+        left_remainder
+            .iter()
+            .enumerate()
+            .fold(0_u64, |mut mask, (i, scalar_left)| {
+                let bit = if scalar_op(*scalar_left, right) {
+                    1_u64
+                } else {
+                    0_u64
+                };
+                mask |= bit << i;
+                mask
+            });
+    let remainder_mask_as_bytes =
+        &remainder_bitmask.to_le_bytes()[0..bit_util::ceil(left_remainder.len(), 8)];
+    result_remainder.copy_from_slice(remainder_mask_as_bytes);
+
+    let null_bit_buffer = left
+        .data_ref()
+        .null_buffer()
+        .map(|b| b.bit_slice(left.offset(), left.len()));
+
+    // null count is the same as in the input since the right side of the scalar comparison cannot be null
+    let null_count = left.null_count();
 
     let data = ArrayData::new(
         DataType::Boolean,
-        left.len(),
-        None,
+        len,
+        Some(null_count),
         null_bit_buffer,
         0,
         vec![result.freeze()],
@@ -489,7 +547,7 @@ where
     T: ArrowNumericType,
 {
     #[cfg(simd_x86)]
-    return simd_compare_op(left, right, T::eq);
+    return simd_compare_op(left, right, T::eq, |a, b| a == b);
     #[cfg(not(simd_x86))]
     return compare_op!(left, right, |a, b| a == b);
 }
@@ -500,7 +558,7 @@ where
     T: ArrowNumericType,
 {
     #[cfg(simd_x86)]
-    return simd_compare_op_scalar(left, right, T::eq);
+    return simd_compare_op_scalar(left, right, T::eq, |a, b| a == b);
     #[cfg(not(simd_x86))]
     return compare_op_scalar!(left, right, |a, b| a == b);
 }
@@ -511,7 +569,7 @@ where
     T: ArrowNumericType,
 {
     #[cfg(simd_x86)]
-    return simd_compare_op(left, right, T::ne);
+    return simd_compare_op(left, right, T::ne, |a, b| a != b);
     #[cfg(not(simd_x86))]
     return compare_op!(left, right, |a, b| a != b);
 }
@@ -522,7 +580,7 @@ where
     T: ArrowNumericType,
 {
     #[cfg(simd_x86)]
-    return simd_compare_op_scalar(left, right, T::ne);
+    return simd_compare_op_scalar(left, right, T::ne, |a, b| a != b);
     #[cfg(not(simd_x86))]
     return compare_op_scalar!(left, right, |a, b| a != b);
 }
@@ -534,7 +592,7 @@ where
     T: ArrowNumericType,
 {
     #[cfg(simd_x86)]
-    return simd_compare_op(left, right, T::lt);
+    return simd_compare_op(left, right, T::lt, |a, b| a < b);
     #[cfg(not(simd_x86))]
     return compare_op!(left, right, |a, b| a < b);
 }
@@ -546,7 +604,7 @@ where
     T: ArrowNumericType,
 {
     #[cfg(simd_x86)]
-    return simd_compare_op_scalar(left, right, T::lt);
+    return simd_compare_op_scalar(left, right, T::lt, |a, b| a < b);
     #[cfg(not(simd_x86))]
     return compare_op_scalar!(left, right, |a, b| a < b);
 }
@@ -561,7 +619,7 @@ where
     T: ArrowNumericType,
 {
     #[cfg(simd_x86)]
-    return simd_compare_op(left, right, T::le);
+    return simd_compare_op(left, right, T::le, |a, b| a <= b);
     #[cfg(not(simd_x86))]
     return compare_op!(left, right, |a, b| a <= b);
 }
@@ -573,7 +631,7 @@ where
     T: ArrowNumericType,
 {
     #[cfg(simd_x86)]
-    return simd_compare_op_scalar(left, right, T::le);
+    return simd_compare_op_scalar(left, right, T::le, |a, b| a <= b);
     #[cfg(not(simd_x86))]
     return compare_op_scalar!(left, right, |a, b| a <= b);
 }
@@ -585,7 +643,7 @@ where
     T: ArrowNumericType,
 {
     #[cfg(simd_x86)]
-    return simd_compare_op(left, right, T::gt);
+    return simd_compare_op(left, right, T::gt, |a, b| a > b);
     #[cfg(not(simd_x86))]
     return compare_op!(left, right, |a, b| a > b);
 }
@@ -597,7 +655,7 @@ where
     T: ArrowNumericType,
 {
     #[cfg(simd_x86)]
-    return simd_compare_op_scalar(left, right, T::gt);
+    return simd_compare_op_scalar(left, right, T::gt, |a, b| a > b);
     #[cfg(not(simd_x86))]
     return compare_op_scalar!(left, right, |a, b| a > b);
 }
@@ -612,7 +670,7 @@ where
     T: ArrowNumericType,
 {
     #[cfg(simd_x86)]
-    return simd_compare_op(left, right, T::ge);
+    return simd_compare_op(left, right, T::ge, |a, b| a >= b);
     #[cfg(not(simd_x86))]
     return compare_op!(left, right, |a, b| a >= b);
 }
@@ -624,7 +682,7 @@ where
     T: ArrowNumericType,
 {
     #[cfg(simd_x86)]
-    return simd_compare_op_scalar(left, right, T::ge);
+    return simd_compare_op_scalar(left, right, T::ge, |a, b| a >= b);
     #[cfg(not(simd_x86))]
     return compare_op_scalar!(left, right, |a, b| a >= b);
 }
@@ -752,33 +810,57 @@ fn new_all_set_buffer(len: usize) -> Buffer {
     buffer.freeze()
 }
 
+// disable wrapping inside literal vectors used for test data and assertions
+#[rustfmt::skip::macros(vec)]
 #[cfg(test)]
 mod tests {
     use super::*;
     use crate::datatypes::{Int8Type, ToByteSlice};
-    use crate::{array::Int32Array, datatypes::Field};
+    use crate::{array::Int32Array, array::Int64Array, datatypes::Field};
+
+    /// Evaluate `KERNEL` with two vectors as inputs and assert against the expected output.
+    /// `A_VEC` and `B_VEC` can be of type `Vec<i64>` or `Vec<Option<i64>>`.
+    /// `EXPECTED` can be either `Vec<bool>` or `Vec<Option<bool>>`.
+    /// The main reason for this macro is that inputs and outputs align nicely after `cargo fmt`.
+    macro_rules! cmp_i64 {
+        ($KERNEL:ident, $A_VEC:expr, $B_VEC:expr, $EXPECTED:expr) => {
+            let a = Int64Array::from($A_VEC);
+            let b = Int64Array::from($B_VEC);
+            let c = $KERNEL(&a, &b).unwrap();
+            assert_eq!(BooleanArray::from($EXPECTED), c);
+        };
+    }
+
+    /// Evaluate `KERNEL` with one vector and one scalar as inputs and assert against the expected output.
+    /// `A_VEC` can be of type `Vec<i64>` or `Vec<Option<i64>>`.
+    /// `EXPECTED` can be either `Vec<bool>` or `Vec<Option<bool>>`.
+    /// The main reason for this macro is that inputs and outputs align nicely after `cargo fmt`.
+    macro_rules! cmp_i64_scalar {
+        ($KERNEL:ident, $A_VEC:expr, $B:literal, $EXPECTED:expr) => {
+            let a = Int64Array::from($A_VEC);
+            let c = $KERNEL(&a, $B).unwrap();
+            assert_eq!(BooleanArray::from($EXPECTED), c);
+        };
+    }
 
     #[test]
     fn test_primitive_array_eq() {
-        let a = Int32Array::from(vec![8, 8, 8, 8, 8]);
-        let b = Int32Array::from(vec![6, 7, 8, 9, 10]);
-        let c = eq(&a, &b).unwrap();
-        assert_eq!(false, c.value(0));
-        assert_eq!(false, c.value(1));
-        assert_eq!(true, c.value(2));
-        assert_eq!(false, c.value(3));
-        assert_eq!(false, c.value(4));
+        cmp_i64!(
+            eq,
+            vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
+            vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
+            vec![false, false, true, false, false, false, false, true, false, false]
+        );
     }
 
     #[test]
     fn test_primitive_array_eq_scalar() {
-        let a = Int32Array::from(vec![6, 7, 8, 9, 10]);
-        let c = eq_scalar(&a, 8).unwrap();
-        assert_eq!(false, c.value(0));
-        assert_eq!(false, c.value(1));
-        assert_eq!(true, c.value(2));
-        assert_eq!(false, c.value(3));
-        assert_eq!(false, c.value(4));
+        cmp_i64_scalar!(
+            eq_scalar,
+            vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
+            8,
+            vec![false, false, true, false, false, false, false, true, false, false]
+        );
     }
 
     #[test]
@@ -797,193 +879,205 @@ mod tests {
 
     #[test]
     fn test_primitive_array_neq() {
-        let a = Int32Array::from(vec![8, 8, 8, 8, 8]);
-        let b = Int32Array::from(vec![6, 7, 8, 9, 10]);
-        let c = neq(&a, &b).unwrap();
-        assert_eq!(true, c.value(0));
-        assert_eq!(true, c.value(1));
-        assert_eq!(false, c.value(2));
-        assert_eq!(true, c.value(3));
-        assert_eq!(true, c.value(4));
+        cmp_i64!(
+            neq,
+            vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
+            vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
+            vec![true, true, false, true, true, true, true, false, true, true]
+        );
     }
 
     #[test]
     fn test_primitive_array_neq_scalar() {
-        let a = Int32Array::from(vec![6, 7, 8, 9, 10]);
-        let c = neq_scalar(&a, 8).unwrap();
-        assert_eq!(true, c.value(0));
-        assert_eq!(true, c.value(1));
-        assert_eq!(false, c.value(2));
-        assert_eq!(true, c.value(3));
-        assert_eq!(true, c.value(4));
+        cmp_i64_scalar!(
+            neq_scalar,
+            vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
+            8,
+            vec![true, true, false, true, true, true, true, false, true, true]
+        );
     }
 
     #[test]
     fn test_primitive_array_lt() {
-        let a = Int32Array::from(vec![8, 8, 8, 8, 8]);
-        let b = Int32Array::from(vec![6, 7, 8, 9, 10]);
-        let c = lt(&a, &b).unwrap();
-        assert_eq!(false, c.value(0));
-        assert_eq!(false, c.value(1));
-        assert_eq!(false, c.value(2));
-        assert_eq!(true, c.value(3));
-        assert_eq!(true, c.value(4));
+        cmp_i64!(
+            lt,
+            vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
+            vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
+            vec![false, false, false, true, true, false, false, false, true, true]
+        );
     }
 
     #[test]
     fn test_primitive_array_lt_scalar() {
-        let a = Int32Array::from(vec![6, 7, 8, 9, 10]);
-        let c = lt_scalar(&a, 8).unwrap();
-        assert_eq!(true, c.value(0));
-        assert_eq!(true, c.value(1));
-        assert_eq!(false, c.value(2));
-        assert_eq!(false, c.value(3));
-        assert_eq!(false, c.value(4));
+        cmp_i64_scalar!(
+            lt_scalar,
+            vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
+            8,
+            vec![true, true, false, false, false, true, true, false, false, false]
+        );
     }
 
     #[test]
     fn test_primitive_array_lt_nulls() {
-        let a = Int32Array::from(vec![None, None, Some(1)]);
-        let b = Int32Array::from(vec![None, Some(1), None]);
-        let c = lt(&a, &b).unwrap();
-        assert_eq!(false, c.value(0));
-        assert_eq!(true, c.value(1));
-        assert_eq!(false, c.value(2));
+        cmp_i64!(
+            lt,
+            vec![None, None, Some(1), Some(1), None, None, Some(2), Some(2),],
+            vec![None, Some(1), None, Some(1), None, Some(3), None, Some(3),],
+            vec![None, None, None, Some(false), None, None, None, Some(true)]
+        );
     }
 
     #[test]
     fn test_primitive_array_lt_scalar_nulls() {
-        let a = Int32Array::from(vec![None, Some(1), Some(2)]);
-        let c = lt_scalar(&a, 2).unwrap();
-        assert_eq!(true, c.value(0));
-        assert_eq!(true, c.value(1));
-        assert_eq!(false, c.value(2));
+        cmp_i64_scalar!(
+            lt_scalar,
+            vec![None, Some(1), Some(2), Some(3), None, Some(1), Some(2), Some(3), Some(2), None],
+            2,
+            vec![None, Some(true), Some(false), Some(false), None, Some(true), Some(false), Some(false), Some(false), None]
+        );
     }
 
     #[test]
     fn test_primitive_array_lt_eq() {
-        let a = Int32Array::from(vec![8, 8, 8, 8, 8]);
-        let b = Int32Array::from(vec![6, 7, 8, 9, 10]);
-        let c = lt_eq(&a, &b).unwrap();
-        assert_eq!(false, c.value(0));
-        assert_eq!(false, c.value(1));
-        assert_eq!(true, c.value(2));
-        assert_eq!(true, c.value(3));
-        assert_eq!(true, c.value(4));
+        cmp_i64!(
+            lt_eq,
+            vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
+            vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
+            vec![false, false, true, true, true, false, false, true, true, true]
+        );
     }
 
     #[test]
     fn test_primitive_array_lt_eq_scalar() {
-        let a = Int32Array::from(vec![6, 7, 8, 9, 10]);
-        let c = lt_eq_scalar(&a, 8).unwrap();
-        assert_eq!(true, c.value(0));
-        assert_eq!(true, c.value(1));
-        assert_eq!(true, c.value(2));
-        assert_eq!(false, c.value(3));
-        assert_eq!(false, c.value(4));
+        cmp_i64_scalar!(
+            lt_eq_scalar,
+            vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
+            8,
+            vec![true, true, true, false, false, true, true, true, false, false]
+        );
     }
 
     #[test]
     fn test_primitive_array_lt_eq_nulls() {
-        let a = Int32Array::from(vec![None, None, Some(1)]);
-        let b = Int32Array::from(vec![None, Some(1), None]);
-        let c = lt_eq(&a, &b).unwrap();
-        assert_eq!(true, c.value(0));
-        assert_eq!(true, c.value(1));
-        assert_eq!(false, c.value(2));
+        cmp_i64!(
+            lt_eq,
+            vec![None, None, Some(1), None, None, Some(1), None, None, Some(1)],
+            vec![None, Some(1), Some(0), None, Some(1), Some(2), None, None, Some(3)],
+            vec![None, None, Some(false), None, None, Some(true), None, None, Some(true)]
+        );
     }
 
     #[test]
     fn test_primitive_array_lt_eq_scalar_nulls() {
-        let a = Int32Array::from(vec![None, Some(1), Some(2)]);
-        let c = lt_eq_scalar(&a, 1).unwrap();
-        assert_eq!(true, c.value(0));
-        assert_eq!(true, c.value(1));
-        assert_eq!(false, c.value(2));
+        cmp_i64_scalar!(
+            lt_eq_scalar,
+            vec![None, Some(1), Some(2), None, Some(1), Some(2), None, Some(1), Some(2)],
+            1,
+            vec![None, Some(true), Some(false), None, Some(true), Some(false), None, Some(true), Some(false)]
+        );
     }
 
     #[test]
     fn test_primitive_array_gt() {
-        let a = Int32Array::from(vec![8, 8, 8, 8, 8]);
-        let b = Int32Array::from(vec![6, 7, 8, 9, 10]);
-        let c = gt(&a, &b).unwrap();
-        assert_eq!(true, c.value(0));
-        assert_eq!(true, c.value(1));
-        assert_eq!(false, c.value(2));
-        assert_eq!(false, c.value(3));
-        assert_eq!(false, c.value(4));
+        cmp_i64!(
+            gt,
+            vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
+            vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
+            vec![true, true, false, false, false, true, true, false, false, false]
+        );
     }
 
     #[test]
     fn test_primitive_array_gt_scalar() {
-        let a = Int32Array::from(vec![6, 7, 8, 9, 10]);
-        let c = gt_scalar(&a, 8).unwrap();
-        assert_eq!(false, c.value(0));
-        assert_eq!(false, c.value(1));
-        assert_eq!(false, c.value(2));
-        assert_eq!(true, c.value(3));
-        assert_eq!(true, c.value(4));
+        cmp_i64_scalar!(
+            gt_scalar,
+            vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
+            8,
+            vec![false, false, false, true, true, false, false, false, true, true]
+        );
     }
 
     #[test]
     fn test_primitive_array_gt_nulls() {
-        let a = Int32Array::from(vec![None, None, Some(1)]);
-        let b = Int32Array::from(vec![None, Some(1), None]);
-        let c = gt(&a, &b).unwrap();
-        assert_eq!(false, c.value(0));
-        assert_eq!(false, c.value(1));
-        assert_eq!(true, c.value(2));
+        cmp_i64!(
+            gt,
+            vec![None, None, Some(1), None, None, Some(2), None, None, Some(3)],
+            vec![None, Some(1), Some(1), None, Some(1), Some(1), None, Some(1), Some(1)],
+            vec![None, None, Some(false), None, None, Some(true), None, None, Some(true)]
+        );
     }
 
     #[test]
     fn test_primitive_array_gt_scalar_nulls() {
-        let a = Int32Array::from(vec![None, Some(1), Some(2)]);
-        let c = gt_scalar(&a, 1).unwrap();
-        assert_eq!(false, c.value(0));
-        assert_eq!(false, c.value(1));
-        assert_eq!(true, c.value(2));
+        cmp_i64_scalar!(
+            gt_scalar,
+            vec![None, Some(1), Some(2), None, Some(1), Some(2), None, Some(1), Some(2)],
+            1,
+            vec![None, Some(false), Some(true), None, Some(false), Some(true), None, Some(false), Some(true)]
+        );
     }
 
     #[test]
     fn test_primitive_array_gt_eq() {
-        let a = Int32Array::from(vec![8, 8, 8, 8, 8]);
-        let b = Int32Array::from(vec![6, 7, 8, 9, 10]);
-        let c = gt_eq(&a, &b).unwrap();
-        assert_eq!(true, c.value(0));
-        assert_eq!(true, c.value(1));
-        assert_eq!(true, c.value(2));
-        assert_eq!(false, c.value(3));
-        assert_eq!(false, c.value(4));
+        cmp_i64!(
+            gt_eq,
+            vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
+            vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
+            vec![true, true, true, false, false, true, true, true, false, false]
+        );
     }
 
     #[test]
     fn test_primitive_array_gt_eq_scalar() {
-        let a = Int32Array::from(vec![6, 7, 8, 9, 10]);
-        let c = gt_eq_scalar(&a, 8).unwrap();
-        assert_eq!(false, c.value(0));
-        assert_eq!(false, c.value(1));
-        assert_eq!(true, c.value(2));
-        assert_eq!(true, c.value(3));
-        assert_eq!(true, c.value(4));
+        cmp_i64_scalar!(
+            gt_eq_scalar,
+            vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
+            8,
+            vec![false, false, true, true, true, false, false, true, true, true]
+        );
     }
 
     #[test]
     fn test_primitive_array_gt_eq_nulls() {
-        let a = Int32Array::from(vec![None, None, Some(1)]);
-        let b = Int32Array::from(vec![None, Some(1), None]);
-        let c = gt_eq(&a, &b).unwrap();
-        assert_eq!(true, c.value(0));
-        assert_eq!(false, c.value(1));
-        assert_eq!(true, c.value(2));
+        cmp_i64!(
+            gt_eq,
+            vec![None, None, Some(1), None, Some(1), Some(2), None, None, Some(1)],
+            vec![None, Some(1), None, None, Some(1), Some(1), None, Some(2), Some(2)],
+            vec![None, None, None, None, Some(true), Some(true), None, None, Some(false)]
+        );
     }
 
     #[test]
     fn test_primitive_array_gt_eq_scalar_nulls() {
-        let a = Int32Array::from(vec![None, Some(1), Some(2)]);
-        let c = gt_eq_scalar(&a, 1).unwrap();
-        assert_eq!(false, c.value(0));
-        assert_eq!(true, c.value(1));
-        assert_eq!(true, c.value(2));
+        cmp_i64_scalar!(
+            gt_eq_scalar,
+            vec![None, Some(1), Some(2), None, Some(2), Some(3), None, Some(3), Some(4)],
+            2,
+            vec![None, Some(false), Some(true), None, Some(true), Some(true), None, Some(true), Some(true)]
+        );
+    }
+
+    #[test]
+    fn test_primitive_array_compare_slice() {
+        let a: Int32Array = (0..100).map(Some).collect();
+        let a = a.slice(50, 50);
+        let a = a.as_any().downcast_ref::<Int32Array>().unwrap();
+        let b: Int32Array = (100..200).map(Some).collect();
+        let b = b.slice(50, 50);
+        let b = b.as_any().downcast_ref::<Int32Array>().unwrap();
+        let actual = lt(&a, &b).unwrap();
+        let expected: BooleanArray = (0..50).map(|_| Some(true)).collect();
+        assert_eq!(expected, actual);
+    }
+
+    #[test]
+    fn test_primitive_array_compare_scalar_slice() {
+        let a: Int32Array = (0..100).map(Some).collect();
+        let a = a.slice(50, 50);
+        let a = a.as_any().downcast_ref::<Int32Array>().unwrap();
+        let actual = lt_scalar(&a, 200).unwrap();
+        let expected: BooleanArray = (0..50).map(|_| Some(true)).collect();
+        assert_eq!(expected, actual);
     }
 
     #[test]
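
The added tests exercise both the full-chunk path and the remainder path. To make the remainder handling concrete, here is a hedged, standalone sketch of the bit packing it performs: scalar comparison results are folded into a u64 with lane i mapped to bit i, and only ceil(len / 8) little-endian bytes are copied out. The function name and plain-slice signature are illustrative; the committed code does this inline inside `simd_compare_op`.

    fn remainder_bitmask_bytes(
        left: &[i64],
        right: &[i64],
        op: impl Fn(i64, i64) -> bool,
    ) -> Vec<u8> {
        // The remainder of a chunks_exact(lanes) iteration is always shorter
        // than one vector, so it fits into a single u64 bitmask.
        assert!(left.len() == right.len() && left.len() <= 64);
        let mask = left
            .iter()
            .zip(right)
            .enumerate()
            .fold(0_u64, |mask, (i, (a, b))| mask | ((op(*a, *b) as u64) << i));
        // Copy only the bytes that actually contain remainder bits, mirroring
        // bit_util::ceil(len, 8) in the kernel.
        let num_bytes = (left.len() + 7) / 8;
        mask.to_le_bytes()[..num_bytes].to_vec()
    }
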
diff --git a/rust/arrow/src/datatypes.rs b/rust/arrow/src/datatypes.rs
index 125adc4..939cabf 100644
--- a/rust/arrow/src/datatypes.rs
+++ b/rust/arrow/src/datatypes.rs
@@ -542,14 +542,13 @@ where
     /// The number of bits used corresponds to the number of lanes of this type
     fn mask_from_u64(mask: u64) -> Self::SimdMask;
 
+    /// Creates a bitmask from the given SIMD mask.
+    /// Each bit corresponds to one vector lane, starting with the least-significant bit.
+    fn mask_to_u64(mask: &Self::SimdMask) -> u64;
+
     /// Gets the value of a single lane in a SIMD mask
     fn mask_get(mask: &Self::SimdMask, idx: usize) -> bool;
 
-    /// Gets the bitmask for a SimdMask as a byte slice and passes it to the closure used as the action parameter
-    fn bitmask<T>(mask: &Self::SimdMask, action: T)
-    where
-        T: FnMut(&[u8]);
-
     /// Sets the value of a single lane of a SIMD mask
     fn mask_set(mask: Self::SimdMask, idx: usize, value: bool) -> Self::SimdMask;
 
@@ -715,15 +714,13 @@ macro_rules! make_numeric_type {
             }
 
             #[inline]
-            fn mask_get(mask: &Self::SimdMask, idx: usize) -> bool {
-                unsafe { mask.extract_unchecked(idx) }
+            fn mask_to_u64(mask: &Self::SimdMask) -> u64 {
+                mask.bitmask() as u64
             }
 
-            fn bitmask<T>(mask: &Self::SimdMask, mut action: T)
-            where
-                T: FnMut(&[u8]),
-            {
-                action(mask.bitmask().to_byte_slice());
+            #[inline]
+            fn mask_get(mask: &Self::SimdMask, idx: usize) -> bool {
+                unsafe { mask.extract_unchecked(idx) }
             }
 
             #[inline]
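
The datatypes.rs change replaces the old `bitmask` callback, which handed the mask to a closure as a byte slice, with `mask_to_u64`, whose lane-to-bit convention the comparison kernels rely on. A hedged, scalar stand-in for that convention, using a hypothetical 8-lane mask type rather than packed_simd, might look like this; the generated impl in the macro simply forwards to `mask.bitmask() as u64`.

    // Hypothetical 8-lane mask used only to illustrate the convention;
    // the generated trait impl forwards to packed_simd's mask.bitmask() instead.
    struct Mask8([bool; 8]);

    fn mask_to_u64(mask: &Mask8) -> u64 {
        mask.0
            .iter()
            .enumerate()
            .fold(0_u64, |acc, (lane, &set)| acc | ((set as u64) << lane))
    }

    fn main() {
        // Lanes 0, 2 and 3 are set, so bits 0, 2 and 3 are set: 0b0000_1101.
        let mask = Mask8([true, false, true, true, false, false, false, false]);
        assert_eq!(mask_to_u64(&mask), 0b0000_1101);
    }

Because `T::lanes()` is a multiple of 8 for every Arrow numeric type (the kernels assert this), each full chunk contributes exactly `lanes / 8` whole little-endian bytes of this bitmask to the result buffer.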