You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by tu...@apache.org on 2022/11/30 22:56:23 UTC
[arrow-rs] branch master updated: Add binary_mut and try_binary_mut (#3144)
This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 961e114af Add binary_mut and try_binary_mut (#3144)
961e114af is described below
commit 961e114af0bd74d31dfcaa30e91f9929a6e6d719
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Wed Nov 30 14:56:18 2022 -0800
Add binary_mut and try_binary_mut (#3144)
* Add add_mut
* Add try_binary_mut
* Add test
* Change result type
* Remove _mut kernels
* Fix clippy
---
arrow/src/compute/kernels/arithmetic.rs | 31 ++++-
arrow/src/compute/kernels/arity.rs | 216 ++++++++++++++++++++++++++++++++
2 files changed, 246 insertions(+), 1 deletion(-)
diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs
index f9deada53..c57e27095 100644
--- a/arrow/src/compute/kernels/arithmetic.rs
+++ b/arrow/src/compute/kernels/arithmetic.rs
@@ -1624,7 +1624,7 @@ where
mod tests {
use super::*;
use crate::array::Int32Array;
- use crate::compute::{try_unary_mut, unary_mut};
+ use crate::compute::{binary_mut, try_binary_mut, try_unary_mut, unary_mut};
use crate::datatypes::{Date64Type, Int32Type, Int8Type};
use arrow_buffer::i256;
use chrono::NaiveDate;
@@ -3100,6 +3100,35 @@ mod tests {
assert_eq!(result.null_count(), 13);
}
+ #[test]
+ fn test_primitive_array_add_mut_by_binary_mut() { // nulls in `b` propagate to the result
+ let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
+ let b = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]);
+
+ let c = binary_mut(a, &b, |a, b| a.add_wrapping(b))
+ .unwrap() // outer: `a` was mutable, so it was not handed back as Err(a)
+ .unwrap(); // inner: no ArrowError (lengths match)
+ let expected = Int32Array::from(vec![Some(16), None, Some(12), None, Some(6)]);
+ assert_eq!(c, expected);
+ }
+
+ #[test]
+ fn test_primitive_add_mut_wrapping_overflow_by_try_binary_mut() { // binary_mut wraps; try_binary_mut reports overflow
+ let a = Int32Array::from(vec![i32::MAX, i32::MIN]);
+ let b = Int32Array::from(vec![1, 1]);
+
+ let wrapped = binary_mut(a, &b, |a, b| a.add_wrapping(b))
+ .unwrap()
+ .unwrap();
+ let expected = Int32Array::from(vec![-2147483648, -2147483647]); // i32::MAX + 1 wraps to i32::MIN; i32::MIN + 1
+ assert_eq!(expected, wrapped);
+
+ let a = Int32Array::from(vec![i32::MAX, i32::MIN]);
+ let b = Int32Array::from(vec![1, 1]);
+ let overflow = try_binary_mut(a, &b, |a, b| a.add_checked(b)); // add_checked fails on i32::MAX + 1
+ let _ = overflow.unwrap().expect_err("overflow should be detected");
+ }
+
#[test]
fn test_primitive_add_scalar_by_unary_mut() {
let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
diff --git a/arrow/src/compute/kernels/arity.rs b/arrow/src/compute/kernels/arity.rs
index 946d15e9e..d0f18cf58 100644
--- a/arrow/src/compute/kernels/arity.rs
+++ b/arrow/src/compute/kernels/arity.rs
@@ -232,6 +232,75 @@ where
Ok(unsafe { build_primitive_array(len, buffer, null_count, null_buffer) })
}
+/// Given two arrays of length `len`, calls `op(a[i], b[i])` for `i` in `0..len`, mutating
+/// the mutable [`PrimitiveArray`] `a` in place. If any index is null in either `a` or `b`,
+/// the corresponding index in the result will also be null.
+///
+/// A mutable primitive array is one whose buffers are not shared with other arrays, so
+/// this writes the results into `a`'s existing values buffer instead of allocating a new one.
+///
+/// Like [`unary`] the provided function is evaluated for every index, ignoring validity. This
+/// is beneficial when the cost of the operation is low compared to the cost of branching, and
+/// especially when the operation can be vectorised; however, it requires `op` to be infallible
+/// for all possible values of its inputs, including the values stored in null slots.
+///
+/// # Errors
+///
+/// Returns `Ok(Err(ArrowError::ComputeError))` if the arrays have different lengths (`a` is
+/// dropped in that case), and `Err(a)`, handing the original array back unmodified, if `a`
+/// is not a mutable primitive array.
+pub fn binary_mut<T, F>(
+ a: PrimitiveArray<T>,
+ b: &PrimitiveArray<T>,
+ op: F,
+) -> std::result::Result<
+ std::result::Result<PrimitiveArray<T>, ArrowError>,
+ PrimitiveArray<T>,
+>
+where
+ T: ArrowPrimitiveType,
+ F: Fn(T::Native, T::Native) -> T::Native,
+{
+ if a.len() != b.len() { // `a` is dropped here, not handed back
+ return Ok(Err(ArrowError::ComputeError(
+ "Cannot perform binary operation on arrays of different length".to_string(),
+ )));
+ }
+
+ if a.is_empty() { // nothing to compute: return a fresh empty array
+ return Ok(Ok(PrimitiveArray::from(ArrayData::new_empty(
+ &T::DATA_TYPE,
+ ))));
+ }
+
+ let len = a.len();
+
+ let null_buffer = combine_option_bitmap(&[a.data(), b.data()], len).unwrap(); // result is null where either input is null
+ let null_count = null_buffer
+ .as_ref()
+ .map(|x| len - x.count_set_bits_offset(0, len))
+ .unwrap_or_default();
+
+ let mut builder = a.into_builder()?; // yields Err(a) if `a`'s buffers are shared
+
+ builder
+ .values_slice_mut()
+ .iter_mut()
+ .zip(b.values())
+ .for_each(|(l, r)| *l = op(*l, *r)); // evaluated for every slot, null or not
+
+ let array_builder = builder
+ .finish()
+ .data()
+ .clone()
+ .into_builder()
+ .null_bit_buffer(null_buffer) // replace with the combined validity of `a` and `b`
+ .null_count(null_count);
+
+ let array_data = unsafe { array_builder.build_unchecked() }; // SAFETY: values come from a valid array of length `len`; only the validity bitmap was swapped for one combined over `len` slots, with `null_count` counted from it
+ Ok(Ok(PrimitiveArray::<T>::from(array_data)))
+}
+
/// Applies the provided fallible binary operation across `a` and `b`, returning any error,
/// and collecting the results into a [`PrimitiveArray`]. If any index is null in either `a`
/// or `b`, the corresponding index in the result will also be null
@@ -289,6 +358,83 @@ where
}
}
+/// Applies the provided fallible binary operation across `a` and `b` by mutating the mutable
+/// [`PrimitiveArray`] `a` with the results, returning any error. If any index is null in
+/// either `a` or `b`, the corresponding index in the result will also be null.
+///
+/// Like [`try_unary`], the function is only evaluated for non-null indices.
+///
+/// A mutable primitive array is one whose buffers are not shared with other arrays, so
+/// this writes the results into `a`'s existing values buffer instead of allocating a new one.
+///
+/// # Errors
+///
+/// Returns `Ok(Err(_))` if the arrays have different lengths or if `op` returns an error;
+/// in both cases `a` is consumed and not handed back (on an `op` error, results computed
+/// so far are discarded). Returns `Err(a)`, handing the original array back unmodified,
+/// if `a` is not a mutable primitive array.
+pub fn try_binary_mut<T, F>(
+ a: PrimitiveArray<T>,
+ b: &PrimitiveArray<T>,
+ op: F,
+) -> std::result::Result<
+ std::result::Result<PrimitiveArray<T>, ArrowError>,
+ PrimitiveArray<T>,
+>
+where
+ T: ArrowPrimitiveType,
+ F: Fn(T::Native, T::Native) -> Result<T::Native>,
+{
+ if a.len() != b.len() {
+ return Ok(Err(ArrowError::ComputeError(
+ "Cannot perform binary operation on arrays of different length".to_string(),
+ )));
+ }
+ let len = a.len();
+
+ if a.is_empty() {
+ return Ok(Ok(PrimitiveArray::from(ArrayData::new_empty(
+ &T::DATA_TYPE,
+ ))));
+ }
+
+ if a.null_count() == 0 && b.null_count() == 0 { // fast path: no validity bitmap to consult
+ try_binary_no_nulls_mut(len, a, b, op)
+ } else {
+ let null_buffer = combine_option_bitmap(&[a.data(), b.data()], len).unwrap();
+ let null_count = null_buffer
+ .as_ref()
+ .map(|x| len - x.count_set_bits_offset(0, len))
+ .unwrap_or_default();
+
+ let mut builder = a.into_builder()?; // yields Err(a) if `a`'s buffers are shared
+
+ let slice = builder.values_slice_mut();
+
+ match try_for_each_valid_idx(len, 0, null_count, null_buffer.as_deref(), |idx| {
+ unsafe { // SAFETY: idx < len, and a.len() == b.len() == len was checked above
+ *slice.get_unchecked_mut(idx) =
+ op(*slice.get_unchecked(idx), b.value_unchecked(idx))?
+ };
+ Ok::<_, ArrowError>(())
+ }) {
+ Ok(_) => {}
+ Err(err) => return Ok(Err(err)), // `a` was consumed; partial results are dropped
+ };
+
+ let array_builder = builder
+ .finish()
+ .data()
+ .clone()
+ .into_builder()
+ .null_bit_buffer(null_buffer)
+ .null_count(null_count);
+
+ let array_data = unsafe { array_builder.build_unchecked() }; // SAFETY: only the validity bitmap was swapped for one combined over `len` slots, with `null_count` counted from it
+ Ok(Ok(PrimitiveArray::<T>::from(array_data)))
+ }
+}
+
/// This intentional inline(never) attribute helps LLVM optimize the loop.
#[inline(never)]
fn try_binary_no_nulls<A: ArrayAccessor, B: ArrayAccessor, F, O>(
@@ -310,6 +456,35 @@ where
Ok(unsafe { build_primitive_array(len, buffer.into(), 0, None) })
}
+/// Null-free path of [`try_binary_mut`]: applies `op` in place to the first `len` slots of `a`, stopping at the first error.
+#[inline(never)] // intentional: helps LLVM optimize the loop
+fn try_binary_no_nulls_mut<T, F>(
+ len: usize, // must not exceed a.len() or b.len(); indexing below is unchecked
+ a: PrimitiveArray<T>,
+ b: &PrimitiveArray<T>,
+ op: F,
+) -> std::result::Result<
+ std::result::Result<PrimitiveArray<T>, ArrowError>,
+ PrimitiveArray<T>,
+>
+where
+ T: ArrowPrimitiveType,
+ F: Fn(T::Native, T::Native) -> Result<T::Native>,
+{
+ let mut builder = a.into_builder()?; // yields Err(a) if `a`'s buffers are shared
+ let slice = builder.values_slice_mut();
+
+ for idx in 0..len {
+ unsafe { // SAFETY: idx < len, which the caller guarantees is within both arrays
+ match op(*slice.get_unchecked(idx), b.value_unchecked(idx)) {
+ Ok(value) => *slice.get_unchecked_mut(idx) = value,
+ Err(err) => return Ok(Err(err)),
+ };
+ };
+ }
+ Ok(Ok(builder.finish())) // inputs carry no nulls, so no validity merge is needed
+}
+
#[inline(never)]
fn try_binary_opt_no_nulls<A: ArrayAccessor, B: ArrayAccessor, F, O>(
len: usize,
@@ -385,6 +560,7 @@ mod tests {
use super::*;
use crate::array::{as_primitive_array, Float64Array, PrimitiveDictionaryBuilder};
use crate::datatypes::{Float64Type, Int32Type, Int8Type};
+ use arrow_array::Int32Array;
#[test]
fn test_unary_f64_slice() {
@@ -444,4 +620,44 @@ mod tests {
&expected
);
}
+
+ #[test]
+ fn test_binary_mut() {
+ let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
+ let b = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]);
+ let c = binary_mut(a, &b, |l, r| l + r).unwrap().unwrap();
+
+ let expected = Int32Array::from(vec![Some(16), None, Some(12), None, Some(6)]);
+ assert_eq!(c, expected);
+ }
+
+ #[test]
+ fn test_try_binary_mut() {
+ let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
+ let b = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]);
+ let c = try_binary_mut(a, &b, |l, r| Ok(l + r)).unwrap().unwrap();
+
+ let expected = Int32Array::from(vec![Some(16), None, Some(12), None, Some(6)]);
+ assert_eq!(c, expected);
+
+ let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
+ let b = Int32Array::from(vec![1, 2, 3, 4, 5]);
+ let c = try_binary_mut(a, &b, |l, r| Ok(l + r)).unwrap().unwrap();
+ let expected = Int32Array::from(vec![16, 16, 12, 12, 6]);
+ assert_eq!(c, expected);
+
+ let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
+ let b = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]);
+ let _ = try_binary_mut(a, &b, |l, r| {
+ if l == 1 {
+ Err(ArrowError::InvalidArgumentError(
+ "got error".parse().unwrap(),
+ ))
+ } else {
+ Ok(l + r)
+ }
+ })
+ .unwrap()
+ .expect_err("should got error");
+ }
}