You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by tu...@apache.org on 2022/11/28 21:16:49 UTC
[arrow-rs] branch master updated: Add try_unary_mut (#3134)
This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 5d84746cf Add try_unary_mut (#3134)
5d84746cf is described below
commit 5d84746cfdfe3ae9a2678f10b4dbb2e9385dc479
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Mon Nov 28 13:16:43 2022 -0800
Add try_unary_mut (#3134)
* Add add_scalar_mut and add_scalar_checked_mut
* Update slice related functions for completeness.
* Change result type
* Update API doc
* Remove _mut arithmetic kernels
* For review
---
arrow-array/src/array/primitive_array.rs | 38 +++++++++++++++++++++++
arrow-array/src/builder/boolean_buffer_builder.rs | 5 +++
arrow-array/src/builder/null_buffer_builder.rs | 4 +++
arrow-array/src/builder/primitive_builder.rs | 18 +++++++++++
arrow/src/compute/kernels/arithmetic.rs | 23 ++++++++++++++
arrow/src/compute/kernels/arity.rs | 27 ++++++++++++++++
6 files changed, 115 insertions(+)
diff --git a/arrow-array/src/array/primitive_array.rs b/arrow-array/src/array/primitive_array.rs
index e3d14e79d..036ef0cdd 100644
--- a/arrow-array/src/array/primitive_array.rs
+++ b/arrow-array/src/array/primitive_array.rs
@@ -505,6 +505,44 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
})
}
+ /// Applies a unary and fallible function to all valid values in a mutable primitive array.
+ /// Mutable primitive array means that the buffer is not shared with other arrays.
+ /// As a result, this mutates the buffer directly without allocating a new buffer.
+ ///
+ /// This is unlike [`Self::unary_mut`], which applies an infallible function to all rows
+ /// regardless of validity; that is often significantly faster and should be preferred
+ /// if `op` is infallible.
+ ///
+ /// Returns `Err` wrapping the input array if its buffer is shared with other arrays and
+ /// so cannot be mutated in place. Returns `Ok(Err(e))` if `op` fails on a valid value,
+ /// where `e` is the error returned by `op`; the partially updated values are discarded.
+ ///
+ /// Note: LLVM is currently unable to effectively vectorize fallible operations
+ pub fn try_unary_mut<F, E>(
+ self,
+ op: F,
+ ) -> Result<Result<PrimitiveArray<T>, E>, PrimitiveArray<T>>
+ where
+ F: Fn(T::Native) -> Result<T::Native, E>,
+ {
+ let len = self.len();
+ let null_count = self.null_count();
+ let mut builder = self.into_builder()?;
+
+ let (slice, null_buffer) = builder.slices_mut();
+
+ let res = try_for_each_valid_idx(len, 0, null_count, null_buffer.as_deref(), |idx| {
+ // SAFETY: relies on `try_for_each_valid_idx` yielding only `idx < len`, and on `slice.len() == len`
+ unsafe { *slice.get_unchecked_mut(idx) = op(*slice.get_unchecked(idx))? };
+ Ok::<_, E>(())
+ });
+ if let Err(err) = res {
+ return Ok(Err(err));
+ }
+
+ Ok(Ok(builder.finish()))
+ }
+
/// Applies a unary and nullable function to all valid values in a primitive array
///
/// This is unlike [`Self::unary`] which will apply an infallible function to all rows
diff --git a/arrow-array/src/builder/boolean_buffer_builder.rs b/arrow-array/src/builder/boolean_buffer_builder.rs
index 4f8638ee7..7d86f74f6 100644
--- a/arrow-array/src/builder/boolean_buffer_builder.rs
+++ b/arrow-array/src/builder/boolean_buffer_builder.rs
@@ -168,6 +168,11 @@ impl BooleanBufferBuilder {
self.buffer.as_slice()
}
+ /// Returns the packed bits as a mutable slice
+ pub fn as_slice_mut(&mut self) -> &mut [u8] {
+ self.buffer.as_slice_mut()
+ }
+
/// Creates a [`Buffer`]
#[inline]
pub fn finish(&mut self) -> Buffer {
diff --git a/arrow-array/src/builder/null_buffer_builder.rs b/arrow-array/src/builder/null_buffer_builder.rs
index b3c788fe5..0061f70c7 100644
--- a/arrow-array/src/builder/null_buffer_builder.rs
+++ b/arrow-array/src/builder/null_buffer_builder.rs
@@ -154,6 +154,10 @@ impl NullBufferBuilder {
self.bitmap_builder = Some(b);
}
}
+
+ /// Returns the packed validity bits as a mutable slice, or `None` if no bitmap builder is present
+ pub fn as_slice_mut(&mut self) -> Option<&mut [u8]> {
+ self.bitmap_builder.as_mut().map(|b| b.as_slice_mut())
+ }
}
impl NullBufferBuilder {
diff --git a/arrow-array/src/builder/primitive_builder.rs b/arrow-array/src/builder/primitive_builder.rs
index ef420dcbc..fa1dc3ad1 100644
--- a/arrow-array/src/builder/primitive_builder.rs
+++ b/arrow-array/src/builder/primitive_builder.rs
@@ -285,6 +285,24 @@ impl<T: ArrowPrimitiveType> PrimitiveBuilder<T> {
pub fn values_slice_mut(&mut self) -> &mut [T::Native] {
self.values_builder.as_slice_mut()
}
+
+ /// Returns the current null (validity) buffer as a slice, or `None` if there is none
+ pub fn validity_slice(&self) -> Option<&[u8]> {
+ self.null_buffer_builder.as_slice()
+ }
+
+ /// Returns the current null (validity) buffer as a mutable slice, or `None` if there is none
+ pub fn validity_slice_mut(&mut self) -> Option<&mut [u8]> {
+ self.null_buffer_builder.as_slice_mut()
+ }
+
+ /// Returns the current values buffer and null buffer (if any) as mutable slices
+ pub fn slices_mut(&mut self) -> (&mut [T::Native], Option<&mut [u8]>) {
+ (
+ self.values_builder.as_slice_mut(),
+ self.null_buffer_builder.as_slice_mut(),
+ )
+ }
}
#[cfg(test)]
diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs
index a99a90204..f9deada53 100644
--- a/arrow/src/compute/kernels/arithmetic.rs
+++ b/arrow/src/compute/kernels/arithmetic.rs
@@ -1624,6 +1624,7 @@ where
mod tests {
use super::*;
use crate::array::Int32Array;
+ use crate::compute::{try_unary_mut, unary_mut};
use crate::datatypes::{Date64Type, Int32Type, Int8Type};
use arrow_buffer::i256;
use chrono::NaiveDate;
@@ -3098,4 +3099,26 @@ mod tests {
assert_eq!(result.len(), 13);
assert_eq!(result.null_count(), 13);
}
+
+ #[test]
+ fn test_primitive_add_scalar_by_unary_mut() {
+ let scalar = 3;
+ // `input` is not shared with any other array, so `unary_mut` is expected to return `Ok`
+ let input = Int32Array::from(vec![15, 14, 9, 8, 1]);
+ let actual = unary_mut(input, |v| v.add_wrapping(scalar)).unwrap();
+ assert_eq!(actual, Int32Array::from(vec![18, 17, 12, 11, 4]));
+ }
+
+ #[test]
+ fn test_primitive_add_scalar_overflow_by_try_unary_mut() {
+ // Wrapping addition rolls over at the i32 boundary.
+ let input = Int32Array::from(vec![i32::MAX, i32::MIN]);
+ let wrapped = unary_mut(input, |v| v.add_wrapping(1)).unwrap();
+ assert_eq!(Int32Array::from(vec![i32::MIN, i32::MIN + 1]), wrapped);
+
+ // Checked addition reports the overflow through the inner `Err`.
+ let input = Int32Array::from(vec![i32::MAX, i32::MIN]);
+ let checked = try_unary_mut(input, |v| v.add_checked(1));
+ assert!(matches!(checked, Ok(Err(_))), "overflow should be detected");
+ }
}
diff --git a/arrow/src/compute/kernels/arity.rs b/arrow/src/compute/kernels/arity.rs
index c99d2b727..946d15e9e 100644
--- a/arrow/src/compute/kernels/arity.rs
+++ b/arrow/src/compute/kernels/arity.rs
@@ -58,6 +58,18 @@ where
array.unary(op)
}
+/// Applies the infallible function `op` to the values of `array`, consuming it.
+///
+/// Thin wrapper around [`PrimitiveArray::unary_mut`]; see that method for the full
+/// semantics, including when `Err` is returned.
+pub fn unary_mut<I, F>(
+ array: PrimitiveArray<I>,
+ op: F,
+) -> std::result::Result<PrimitiveArray<I>, PrimitiveArray<I>>
+where
+ I: ArrowPrimitiveType,
+ F: Fn(I::Native) -> I::Native,
+{
+ PrimitiveArray::unary_mut(array, op)
+}
+
/// See [`PrimitiveArray::try_unary`]
pub fn try_unary<I, F, O>(array: &PrimitiveArray<I>, op: F) -> Result<PrimitiveArray<O>>
where
@@ -68,6 +80,21 @@ where
array.try_unary(op)
}
+/// Applies the fallible function `op` to every valid value of `array`, consuming it.
+///
+/// Thin wrapper around [`PrimitiveArray::try_unary_mut`]: the outer `Err` hands back the
+/// input array when it cannot be mutated in place, and the inner `Err` carries the error
+/// produced by `op`.
+pub fn try_unary_mut<I, F>(
+ array: PrimitiveArray<I>,
+ op: F,
+) -> std::result::Result<
+ std::result::Result<PrimitiveArray<I>, ArrowError>,
+ PrimitiveArray<I>,
+>
+where
+ I: ArrowPrimitiveType,
+ F: Fn(I::Native) -> Result<I::Native>,
+{
+ PrimitiveArray::try_unary_mut(array, op)
+}
+
/// A helper function that applies an infallible unary function to a dictionary array with primitive value type.
fn unary_dict<K, F, T>(array: &DictionaryArray<K>, op: F) -> Result<ArrayRef>
where