You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by tu...@apache.org on 2022/11/28 21:16:49 UTC

[arrow-rs] branch master updated: Add try_unary_mut (#3134)

This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 5d84746cf Add try_unary_mut (#3134)
5d84746cf is described below

commit 5d84746cfdfe3ae9a2678f10b4dbb2e9385dc479
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Mon Nov 28 13:16:43 2022 -0800

    Add try_unary_mut (#3134)
    
    * Add add_scalar_mut and add_scalar_checked_mut
    
    * Update slice related functions for completeness.
    
    * Change result type
    
    * Update API doc
    
    * Remove _mut arithmetic kernels
    
    * For review
---
 arrow-array/src/array/primitive_array.rs          | 38 +++++++++++++++++++++++
 arrow-array/src/builder/boolean_buffer_builder.rs |  5 +++
 arrow-array/src/builder/null_buffer_builder.rs    |  4 +++
 arrow-array/src/builder/primitive_builder.rs      | 18 +++++++++++
 arrow/src/compute/kernels/arithmetic.rs           | 23 ++++++++++++++
 arrow/src/compute/kernels/arity.rs                | 27 ++++++++++++++++
 6 files changed, 115 insertions(+)

diff --git a/arrow-array/src/array/primitive_array.rs b/arrow-array/src/array/primitive_array.rs
index e3d14e79d..036ef0cdd 100644
--- a/arrow-array/src/array/primitive_array.rs
+++ b/arrow-array/src/array/primitive_array.rs
@@ -505,6 +505,44 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
         })
     }
 
+    /// Applies a unary, fallible function in place to all valid values of a primitive array.
+    /// The array's buffer must not be shared with any other array, so that it can be mutated
+    /// directly without allocating a new buffer.
+    ///
+    /// This is unlike [`Self::unary_mut`], which applies an infallible function to all rows
+    /// regardless of validity. In many cases `unary_mut` is significantly faster and should
+    /// be preferred if `op` is infallible.
+    ///
+    /// Returns `Err(self)` if the array cannot be converted into a builder because its buffer
+    /// is shared with other arrays. Returns `Ok(Err(e))` if `op` fails on a valid value; any
+    /// values already updated are discarded along with the array. Otherwise returns
+    /// `Ok(Ok(array))` holding the updated values.
+    ///
+    /// Note: LLVM is currently unable to effectively vectorize fallible operations
+    pub fn try_unary_mut<F, E>(
+        self,
+        op: F,
+    ) -> Result<Result<PrimitiveArray<T>, E>, PrimitiveArray<T>>
+    where
+        F: Fn(T::Native) -> Result<T::Native, E>,
+    {
+        let len = self.len();
+        let null_count = self.null_count();
+        let mut builder = self.into_builder()?;
+
+        let (slice, null_buffer) = builder.slices_mut();
+        // SAFETY: relies on every yielded idx being < len and on slice.len() == len (verify)
+        match try_for_each_valid_idx(len, 0, null_count, null_buffer.as_deref(), |idx| {
+            unsafe { *slice.get_unchecked_mut(idx) = op(*slice.get_unchecked(idx))? };
+            Ok::<_, E>(())
+        }) {
+            Ok(_) => {}
+            Err(err) => return Ok(Err(err)),
+        };
+
+        Ok(Ok(builder.finish()))
+    }
+
     /// Applies a unary and nullable function to all valid values in a primitive array
     ///
     /// This is unlike [`Self::unary`] which will apply an infallible function to all rows
diff --git a/arrow-array/src/builder/boolean_buffer_builder.rs b/arrow-array/src/builder/boolean_buffer_builder.rs
index 4f8638ee7..7d86f74f6 100644
--- a/arrow-array/src/builder/boolean_buffer_builder.rs
+++ b/arrow-array/src/builder/boolean_buffer_builder.rs
@@ -168,6 +168,11 @@ impl BooleanBufferBuilder {
         self.buffer.as_slice()
     }
 
+    /// Returns the packed bits as a mutable slice
+    pub fn as_slice_mut(&mut self) -> &mut [u8] {
+        self.buffer.as_slice_mut()
+    }
+
     /// Creates a [`Buffer`]
     #[inline]
     pub fn finish(&mut self) -> Buffer {
diff --git a/arrow-array/src/builder/null_buffer_builder.rs b/arrow-array/src/builder/null_buffer_builder.rs
index b3c788fe5..0061f70c7 100644
--- a/arrow-array/src/builder/null_buffer_builder.rs
+++ b/arrow-array/src/builder/null_buffer_builder.rs
@@ -154,6 +154,10 @@ impl NullBufferBuilder {
             self.bitmap_builder = Some(b);
         }
     }
+    /// Returns the validity bitmap as a mutable slice, or `None` if `bitmap_builder` is unset
+    pub fn as_slice_mut(&mut self) -> Option<&mut [u8]> {
+        self.bitmap_builder.as_mut().map(|b| b.as_slice_mut())
+    }
 }
 
 impl NullBufferBuilder {
diff --git a/arrow-array/src/builder/primitive_builder.rs b/arrow-array/src/builder/primitive_builder.rs
index ef420dcbc..fa1dc3ad1 100644
--- a/arrow-array/src/builder/primitive_builder.rs
+++ b/arrow-array/src/builder/primitive_builder.rs
@@ -285,6 +285,24 @@ impl<T: ArrowPrimitiveType> PrimitiveBuilder<T> {
     pub fn values_slice_mut(&mut self) -> &mut [T::Native] {
         self.values_builder.as_slice_mut()
     }
+
+    /// Returns the current validity (null) buffer as a slice, if present
+    pub fn validity_slice(&self) -> Option<&[u8]> {
+        self.null_buffer_builder.as_slice()
+    }
+
+    /// Returns the current validity (null) buffer as a mutable slice, if present
+    pub fn validity_slice_mut(&mut self) -> Option<&mut [u8]> {
+        self.null_buffer_builder.as_slice_mut()
+    }
+
+    /// Returns the current values buffer and validity (null) buffer as mutable slices
+    pub fn slices_mut(&mut self) -> (&mut [T::Native], Option<&mut [u8]>) {
+        (
+            self.values_builder.as_slice_mut(),
+            self.null_buffer_builder.as_slice_mut(),
+        )
+    }
 }
 
 #[cfg(test)]
diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs
index a99a90204..f9deada53 100644
--- a/arrow/src/compute/kernels/arithmetic.rs
+++ b/arrow/src/compute/kernels/arithmetic.rs
@@ -1624,6 +1624,7 @@ where
 mod tests {
     use super::*;
     use crate::array::Int32Array;
+    use crate::compute::{try_unary_mut, unary_mut};
     use crate::datatypes::{Date64Type, Int32Type, Int8Type};
     use arrow_buffer::i256;
     use chrono::NaiveDate;
@@ -3098,4 +3099,26 @@ mod tests {
         assert_eq!(result.len(), 13);
         assert_eq!(result.null_count(), 13);
     }
+
+    #[test]
+    fn test_primitive_add_scalar_by_unary_mut() {
+        let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
+        let b = 3;
+        let c = unary_mut(a, |value| value.add_wrapping(b)).unwrap(); // Ok: `a` is unshared
+        let expected = Int32Array::from(vec![18, 17, 12, 11, 4]);
+        assert_eq!(c, expected);
+    }
+
+    #[test]
+    fn test_primitive_add_scalar_overflow_by_try_unary_mut() {
+        let a = Int32Array::from(vec![i32::MAX, i32::MIN]);
+
+        let wrapped = unary_mut(a, |value| value.add_wrapping(1)).unwrap(); // MAX wraps to MIN
+        let expected = Int32Array::from(vec![-2147483648, -2147483647]);
+        assert_eq!(expected, wrapped);
+
+        let a = Int32Array::from(vec![i32::MAX, i32::MIN]); // `a` was moved above
+        let overflow = try_unary_mut(a, |value| value.add_checked(1)); // MAX + 1 overflows
+        let _ = overflow.unwrap().expect_err("overflow should be detected"); // Ok(Err(_))
+    }
 }
diff --git a/arrow/src/compute/kernels/arity.rs b/arrow/src/compute/kernels/arity.rs
index c99d2b727..946d15e9e 100644
--- a/arrow/src/compute/kernels/arity.rs
+++ b/arrow/src/compute/kernels/arity.rs
@@ -58,6 +58,18 @@ where
     array.unary(op)
 }
 
+/// Applies `op` in place to all values of `array`; see [`PrimitiveArray::unary_mut`]
+pub fn unary_mut<I, F>(
+    array: PrimitiveArray<I>,
+    op: F,
+) -> std::result::Result<PrimitiveArray<I>, PrimitiveArray<I>>
+where
+    I: ArrowPrimitiveType,
+    F: Fn(I::Native) -> I::Native,
+{
+    array.unary_mut(op)
+}
+
 /// See [`PrimitiveArray::try_unary`]
 pub fn try_unary<I, F, O>(array: &PrimitiveArray<I>, op: F) -> Result<PrimitiveArray<O>>
 where
@@ -68,6 +80,21 @@ where
     array.try_unary(op)
 }
 
+/// Applies fallible `op` in place to all valid values; see [`PrimitiveArray::try_unary_mut`]
+pub fn try_unary_mut<I, F>(
+    array: PrimitiveArray<I>,
+    op: F,
+) -> std::result::Result<
+    std::result::Result<PrimitiveArray<I>, ArrowError>,
+    PrimitiveArray<I>,
+>
+where
+    I: ArrowPrimitiveType,
+    F: Fn(I::Native) -> Result<I::Native>,
+{
+    array.try_unary_mut(op)
+}
+
 /// A helper function that applies an infallible unary function to a dictionary array with primitive value type.
 fn unary_dict<K, F, T>(array: &DictionaryArray<K>, op: F) -> Result<ArrayRef>
 where