You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by tu...@apache.org on 2022/11/30 22:56:23 UTC

[arrow-rs] branch master updated: Add binary_mut and try_binary_mut (#3144)

This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 961e114af Add binary_mut and try_binary_mut (#3144)
961e114af is described below

commit 961e114af0bd74d31dfcaa30e91f9929a6e6d719
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Wed Nov 30 14:56:18 2022 -0800

    Add binary_mut and try_binary_mut (#3144)
    
    * Add add_mut
    
    * Add try_binary_mut
    
    * Add test
    
    * Change result type
    
    * Remove _mut kernels
    
    * Fix clippy
---
 arrow/src/compute/kernels/arithmetic.rs |  31 ++++-
 arrow/src/compute/kernels/arity.rs      | 216 ++++++++++++++++++++++++++++++++
 2 files changed, 246 insertions(+), 1 deletion(-)

diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs
index f9deada53..c57e27095 100644
--- a/arrow/src/compute/kernels/arithmetic.rs
+++ b/arrow/src/compute/kernels/arithmetic.rs
@@ -1624,7 +1624,7 @@ where
 mod tests {
     use super::*;
     use crate::array::Int32Array;
-    use crate::compute::{try_unary_mut, unary_mut};
+    use crate::compute::{binary_mut, try_binary_mut, try_unary_mut, unary_mut};
     use crate::datatypes::{Date64Type, Int32Type, Int8Type};
     use arrow_buffer::i256;
     use chrono::NaiveDate;
@@ -3100,6 +3100,35 @@ mod tests {
         assert_eq!(result.null_count(), 13);
     }
 
+    #[test]
+    fn test_primitive_array_add_mut_by_binary_mut() {
+        // `lhs` is consumed and updated in place; the nulls come from `rhs` only.
+        let lhs = Int32Array::from(vec![15, 14, 9, 8, 1]);
+        let rhs = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]);
+        let expected = Int32Array::from(vec![Some(16), None, Some(12), None, Some(6)]);
+
+        let result = binary_mut(lhs, &rhs, |l, r| l.add_wrapping(r));
+        let sum = result.unwrap().unwrap();
+        assert_eq!(sum, expected);
+    }
+
+    #[test]
+    fn test_primitive_add_mut_wrapping_overflow_by_try_binary_mut() {
+        let ones = Int32Array::from(vec![1, 1]);
+
+        // Wrapping addition: `i32::MAX + 1` wraps around to `i32::MIN`.
+        let lhs = Int32Array::from(vec![i32::MAX, i32::MIN]);
+        let wrapped = binary_mut(lhs, &ones, |l, r| l.add_wrapping(r));
+        let wrapped = wrapped.unwrap().unwrap();
+        let expected = Int32Array::from(vec![i32::MIN, i32::MIN + 1]);
+        assert_eq!(expected, wrapped);
+
+        // Checked addition: the same overflow must surface as an error instead.
+        let lhs = Int32Array::from(vec![i32::MAX, i32::MIN]);
+        let checked = try_binary_mut(lhs, &ones, |l, r| l.add_checked(r));
+        let _ = checked.unwrap().expect_err("overflow should be detected");
+    }
+
     #[test]
     fn test_primitive_add_scalar_by_unary_mut() {
         let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
diff --git a/arrow/src/compute/kernels/arity.rs b/arrow/src/compute/kernels/arity.rs
index 946d15e9e..d0f18cf58 100644
--- a/arrow/src/compute/kernels/arity.rs
+++ b/arrow/src/compute/kernels/arity.rs
@@ -232,6 +232,75 @@ where
     Ok(unsafe { build_primitive_array(len, buffer, null_count, null_buffer) })
 }
 
+/// Given two arrays of length `len`, calls `op(a[i], b[i])` for `i` in `0..len` and writes
+/// each result back into the mutable [`PrimitiveArray`] `a`. If any index is null in either
+/// `a` or `b`, the corresponding index in the result will also be null.
+///
+/// A mutable primitive array is one whose buffer is not shared with other arrays. As a
+/// result, this mutates the buffer directly without allocating a new buffer.
+///
+/// Like [`unary`] the provided function is evaluated for every index, ignoring validity. This
+/// is beneficial when the cost of the operation is low compared to the cost of branching, and
+/// especially when the operation can be vectorised; however, it requires `op` to be infallible
+/// for all possible values of its inputs.
+///
+/// # Errors
+///
+/// * `Ok(Err(_))` with an [`ArrowError::ComputeError`] if the arrays have different lengths.
+/// * `Err(a)`, handing back the original [`PrimitiveArray`] `a`, if it is not a mutable
+///   primitive array.
+pub fn binary_mut<T, F>(
+    a: PrimitiveArray<T>,
+    b: &PrimitiveArray<T>,
+    op: F,
+) -> std::result::Result<
+    std::result::Result<PrimitiveArray<T>, ArrowError>,
+    PrimitiveArray<T>,
+>
+where
+    T: ArrowPrimitiveType,
+    F: Fn(T::Native, T::Native) -> T::Native,
+{
+    if a.len() != b.len() {
+        return Ok(Err(ArrowError::ComputeError(
+            "Cannot perform binary operation on arrays of different length".to_string(),
+        )));
+    }
+
+    if a.is_empty() {
+        return Ok(Ok(PrimitiveArray::from(ArrayData::new_empty(
+            &T::DATA_TYPE,
+        ))));
+    }
+
+    let len = a.len();
+    // Combine validity while `a` can still be borrowed; `into_builder` below consumes it.
+    let null_buffer = combine_option_bitmap(&[a.data(), b.data()], len).unwrap();
+    let null_count = null_buffer
+        .as_ref()
+        .map(|x| len - x.count_set_bits_offset(0, len))
+        .unwrap_or_default();
+
+    let mut builder = a.into_builder()?;
+    // Apply `op` to every slot, null or not, overwriting the left-hand value in place.
+    builder
+        .values_slice_mut()
+        .iter_mut()
+        .zip(b.values())
+        .for_each(|(l, r)| *l = op(*l, *r));
+    // Re-attach the combined null buffer and its null count to the rebuilt array data.
+    let array_builder = builder
+        .finish()
+        .data()
+        .clone()
+        .into_builder()
+        .null_bit_buffer(null_buffer)
+        .null_count(null_count);
+    // NOTE(review): unchecked build; presumably fine as only validity was replaced. Confirm.
+    let array_data = unsafe { array_builder.build_unchecked() };
+    Ok(Ok(PrimitiveArray::<T>::from(array_data)))
+}
+
 /// Applies the provided fallible binary operation across `a` and `b`, returning any error,
 /// and collecting the results into a [`PrimitiveArray`]. If any index is null in either `a`
 /// or `b`, the corresponding index in the result will also be null
@@ -289,6 +358,83 @@ where
     }
 }
 
+/// Applies the provided fallible binary operation across `a` and `b` by mutating the mutable
+/// [`PrimitiveArray`] `a` with the results, returning any error. If any index is null in
+/// either `a` or `b`, the corresponding index in the result will also be null.
+///
+/// Like [`try_unary`] the function is only evaluated for non-null indices.
+///
+/// Mutable primitive array means that the buffer is not shared with other arrays.
+/// As a result, this mutates the buffer directly without allocating a new buffer.
+///
+/// # Errors
+///
+/// * `Ok(Err(_))` if the arrays have different lengths, or with an error returned by `op`;
+///   in the latter case `a` has already been consumed and is not handed back.
+/// * `Err(a)`, handing back the original [`PrimitiveArray`] `a`, if it is not a mutable
+///   primitive array.
+pub fn try_binary_mut<T, F>(
+    a: PrimitiveArray<T>,
+    b: &PrimitiveArray<T>,
+    op: F,
+) -> std::result::Result<
+    std::result::Result<PrimitiveArray<T>, ArrowError>,
+    PrimitiveArray<T>,
+>
+where
+    T: ArrowPrimitiveType,
+    F: Fn(T::Native, T::Native) -> Result<T::Native>,
+{
+    if a.len() != b.len() {
+        return Ok(Err(ArrowError::ComputeError(
+            "Cannot perform binary operation on arrays of different length".to_string(),
+        )));
+    }
+    let len = a.len();
+
+    if a.is_empty() {
+        return Ok(Ok(PrimitiveArray::from(ArrayData::new_empty(
+            &T::DATA_TYPE,
+        ))));
+    }
+    // With no nulls on either side every index is valid, so take the bitmap-free path.
+    if a.null_count() == 0 && b.null_count() == 0 {
+        try_binary_no_nulls_mut(len, a, b, op)
+    } else {
+        let null_buffer = combine_option_bitmap(&[a.data(), b.data()], len).unwrap();
+        let null_count = null_buffer
+            .as_ref()
+            .map(|x| len - x.count_set_bits_offset(0, len))
+            .unwrap_or_default();
+
+        let mut builder = a.into_builder()?;
+        // Results are written straight back into the values slice taken from `a`.
+        let slice = builder.values_slice_mut();
+        // NOTE(review): unchecked indexing assumes the helper only yields `idx < len`.
+        match try_for_each_valid_idx(len, 0, null_count, null_buffer.as_deref(), |idx| {
+            unsafe {
+                *slice.get_unchecked_mut(idx) =
+                    op(*slice.get_unchecked(idx), b.value_unchecked(idx))?
+            };
+            Ok::<_, ArrowError>(())
+        }) {
+            Ok(_) => {}
+            Err(err) => return Ok(Err(err)),
+        };
+        // Rebuild the array around the mutated values with the combined validity.
+        let array_builder = builder
+            .finish()
+            .data()
+            .clone()
+            .into_builder()
+            .null_bit_buffer(null_buffer)
+            .null_count(null_count);
+
+        let array_data = unsafe { array_builder.build_unchecked() };
+        Ok(Ok(PrimitiveArray::<T>::from(array_data)))
+    }
+}
+
 /// This intentional inline(never) attribute helps LLVM optimize the loop.
 #[inline(never)]
 fn try_binary_no_nulls<A: ArrayAccessor, B: ArrayAccessor, F, O>(
@@ -310,6 +456,35 @@ where
     Ok(unsafe { build_primitive_array(len, buffer.into(), 0, None) })
 }
 
+/// Null-free core of [`try_binary_mut`]; the `inline(never)` helps LLVM optimize the loop.
+#[inline(never)]
+fn try_binary_no_nulls_mut<T, F>(
+    len: usize,
+    a: PrimitiveArray<T>,
+    b: &PrimitiveArray<T>,
+    op: F,
+) -> std::result::Result<
+    std::result::Result<PrimitiveArray<T>, ArrowError>,
+    PrimitiveArray<T>,
+>
+where
+    T: ArrowPrimitiveType,
+    F: Fn(T::Native, T::Native) -> Result<T::Native>,
+{
+    let mut builder = a.into_builder()?;
+    let slice = builder.values_slice_mut();
+    // SAFETY(review): `idx < len`; assumes both arrays hold at least `len` values.
+    for idx in 0..len {
+        unsafe {
+            match op(*slice.get_unchecked(idx), b.value_unchecked(idx)) {
+                Ok(value) => *slice.get_unchecked_mut(idx) = value,
+                Err(err) => return Ok(Err(err)), // stop at the first failing element
+            };
+        };
+    }
+    Ok(Ok(builder.finish()))
+}
+
 #[inline(never)]
 fn try_binary_opt_no_nulls<A: ArrayAccessor, B: ArrayAccessor, F, O>(
     len: usize,
@@ -385,6 +560,7 @@ mod tests {
     use super::*;
     use crate::array::{as_primitive_array, Float64Array, PrimitiveDictionaryBuilder};
     use crate::datatypes::{Float64Type, Int32Type, Int8Type};
+    use arrow_array::Int32Array;
 
     #[test]
     fn test_unary_f64_slice() {
@@ -444,4 +620,44 @@ mod tests {
             &expected
         );
     }
+
+    #[test]
+    fn test_binary_mut() {
+        // Nulls in the borrowed right-hand array must carry over into the result.
+        let lhs = Int32Array::from(vec![15, 14, 9, 8, 1]);
+        let rhs = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]);
+        let expected = Int32Array::from(vec![Some(16), None, Some(12), None, Some(6)]);
+        let sum = binary_mut(lhs, &rhs, |x, y| x + y).unwrap().unwrap();
+        assert_eq!(sum, expected);
+    }
+
+    #[test]
+    fn test_try_binary_mut() {
+        // Nullable right-hand side: exercises the validity-aware path.
+        let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
+        let b = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]);
+        let c = try_binary_mut(a, &b, |l, r| Ok(l + r)).unwrap().unwrap();
+        let expected = Int32Array::from(vec![Some(16), None, Some(12), None, Some(6)]);
+        assert_eq!(c, expected);
+
+        // No nulls on either side: exercises the `try_binary_no_nulls_mut` fast path.
+        let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
+        let b = Int32Array::from(vec![1, 2, 3, 4, 5]);
+        let c = try_binary_mut(a, &b, |l, r| Ok(l + r)).unwrap().unwrap();
+        let expected = Int32Array::from(vec![16, 16, 12, 12, 6]);
+        assert_eq!(c, expected);
+
+        // An error raised by `op` at a valid index must surface through the inner `Result`.
+        let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
+        let b = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]);
+        let _ = try_binary_mut(a, &b, |l, r| {
+            if l == 1 {
+                Err(ArrowError::InvalidArgumentError("got error".to_string()))
+            } else {
+                Ok(l + r)
+            }
+        })
+        .unwrap()
+        .expect_err("should got error");
+    }
 }