You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/02/06 11:25:58 UTC

[arrow] branch master updated: ARROW-11436: [Rust] Improved from_iter for primitive arrays (-20-30% for cast)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 65f8026  ARROW-11436: [Rust] Improved from_iter for primitive arrays (-20-30% for cast)
65f8026 is described below

commit 65f8026d7260bcd3f21077e36206aaee4b8900e2
Author: Jorge C. Leitao <jo...@gmail.com>
AuthorDate: Sat Feb 6 06:24:59 2021 -0500

    ARROW-11436: [Rust] Improved from_iter for primitive arrays (-20-30% for cast)
    
    This PR refactors `PrimitiveArray::from_iter` to support iterators that do not report an upper size bound, as `std::iter::FromIterator` requires. This makes `from_iter` slower.
    
    To compensate, this PR introduces a new method, `unsafe fn from_trusted_len_iter`, which creates an array from an iterator of `Option<T>` with a trusted length. It is 20-30% faster than the `from_iter` currently implemented in master.
    
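    This PR uses `from_trusted_len_iter` to speed up casting of `primitive -> primitive` and `bool -> primitive` by 20-30%. It is also used for `utf8 -> primitive` casts, where the gain is smaller (about 5% for `utf8 -> f32` in the benchmarks below).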
    
    Note that the added complexity (new functions, and being unable to rely on `collect`) arises because two features we would need are still unstable:
    
    1. `unsafe trait TrustedLen`
    2. specialization
    
    If we had both features, we could implement `TrustedLen` for all our iterators and then use specialization to provide a specialized `FromIterator` implementation for them. Once these features are stabilized, we can simplify our API. Users will be able to call `collect` on an iterator, and the compiler will recognize that the iterator is `TrustedLen` and pick the specialized implementation. :)
    
    ```
    Switched to branch 'iter_p'
       Compiling arrow v4.0.0-SNAPSHOT (/Users/jorgecarleitao/projects/arrow/rust/arrow)
        Finished bench [optimized] target(s) in 1m 11s
         Running /Users/jorgecarleitao/projects/arrow/rust/target/release/deps/cast_kernels-9a3b6d213a9f7a9a
    Gnuplot not found, using plotters backend
    cast int32 to int32 512 time:   [25.384 ns 25.415 ns 25.449 ns]
                            change: [-2.1332% -1.4846% -0.7930%] (p = 0.00 < 0.05)
                            Change within noise threshold.
    Found 11 outliers among 100 measurements (11.00%)
      3 (3.00%) high mild
      8 (8.00%) high severe
    
    cast int32 to uint32 512
                            time:   [2.1526 us 2.1576 us 2.1629 us]
                            change: [-24.823% -24.221% -23.610%] (p = 0.00 < 0.05)
                            Performance has improved.
    Found 9 outliers among 100 measurements (9.00%)
      1 (1.00%) low mild
      4 (4.00%) high mild
      4 (4.00%) high severe
    
    cast int32 to float32 512
                            time:   [2.4012 us 2.4083 us 2.4170 us]
                            change: [-20.381% -19.079% -17.860%] (p = 0.00 < 0.05)
                            Performance has improved.
    Found 7 outliers among 100 measurements (7.00%)
      1 (1.00%) low mild
      3 (3.00%) high mild
      3 (3.00%) high severe
    
    cast int32 to float64 512
                            time:   [2.4544 us 2.4608 us 2.4680 us]
                            change: [-24.689% -23.471% -22.441%] (p = 0.00 < 0.05)
                            Performance has improved.
    Found 8 outliers among 100 measurements (8.00%)
      4 (4.00%) high mild
      4 (4.00%) high severe
    
    cast int32 to int64 512 time:   [2.2424 us 2.2532 us 2.2663 us]
                            change: [-28.692% -28.008% -27.316%] (p = 0.00 < 0.05)
                            Performance has improved.
    Found 7 outliers among 100 measurements (7.00%)
      3 (3.00%) high mild
      4 (4.00%) high severe
    
    cast float32 to int32 512
                            time:   [2.4966 us 2.5063 us 2.5176 us]
                            change: [-26.820% -26.281% -25.755%] (p = 0.00 < 0.05)
                            Performance has improved.
    Found 6 outliers among 100 measurements (6.00%)
      1 (1.00%) low mild
      2 (2.00%) high mild
      3 (3.00%) high severe
    
    cast float64 to float32 512
                            time:   [2.4857 us 2.4954 us 2.5070 us]
                            change: [-22.439% -21.818% -21.222%] (p = 0.00 < 0.05)
                            Performance has improved.
    Found 6 outliers among 100 measurements (6.00%)
      1 (1.00%) low mild
      3 (3.00%) high mild
      2 (2.00%) high severe
    
    cast float64 to uint64 512
                            time:   [2.8313 us 2.8369 us 2.8427 us]
                            change: [-32.996% -32.605% -32.263%] (p = 0.00 < 0.05)
                            Performance has improved.
    Found 9 outliers among 100 measurements (9.00%)
      6 (6.00%) high mild
      3 (3.00%) high severe
    
    cast int64 to int32 512 time:   [2.2000 us 2.2073 us 2.2154 us]
                            change: [-32.106% -31.271% -30.265%] (p = 0.00 < 0.05)
                            Performance has improved.
    Found 11 outliers among 100 measurements (11.00%)
      1 (1.00%) low mild
      5 (5.00%) high mild
      5 (5.00%) high severe
    
    cast date64 to date32 512
                            time:   [1.0772 us 1.0815 us 1.0866 us]
                            change: [+1.3054% +3.6485% +6.8389%] (p = 0.01 < 0.05)
                            Performance has regressed.
    Found 14 outliers among 100 measurements (14.00%)
      2 (2.00%) high mild
      12 (12.00%) high severe
    
    cast date32 to date64 512
                            time:   [838.20 ns 840.35 ns 842.74 ns]
                            change: [-1.2203% -0.3511% +0.5828%] (p = 0.47 > 0.05)
                            No change in performance detected.
    Found 10 outliers among 100 measurements (10.00%)
      3 (3.00%) low mild
      2 (2.00%) high mild
      5 (5.00%) high severe
    
    cast time32s to time32ms 512
                            time:   [741.99 ns 745.21 ns 748.67 ns]
                            change: [-0.1386% +1.9748% +5.4043%] (p = 0.18 > 0.05)
                            No change in performance detected.
    Found 12 outliers among 100 measurements (12.00%)
      4 (4.00%) low mild
      2 (2.00%) high mild
      6 (6.00%) high severe
    
    cast time32s to time64us 512
                            time:   [4.2476 us 4.2596 us 4.2747 us]
                            change: [-19.580% -18.601% -17.667%] (p = 0.00 < 0.05)
                            Performance has improved.
    Found 11 outliers among 100 measurements (11.00%)
      2 (2.00%) low mild
      6 (6.00%) high mild
      3 (3.00%) high severe
    
    cast time64ns to time32s 512
                            time:   [4.9276 us 4.9371 us 4.9489 us]
                            change: [-0.2071% +0.6046% +1.5380%] (p = 0.17 > 0.05)
                            No change in performance detected.
    Found 8 outliers among 100 measurements (8.00%)
      1 (1.00%) low mild
      2 (2.00%) high mild
      5 (5.00%) high severe
    
    cast timestamp_ns to timestamp_s 512
                            time:   [25.938 ns 26.005 ns 26.079 ns]
                            change: [-2.0321% -1.2763% -0.4590%] (p = 0.00 < 0.05)
                            Change within noise threshold.
    Found 9 outliers among 100 measurements (9.00%)
      1 (1.00%) low mild
      3 (3.00%) high mild
      5 (5.00%) high severe
    
    cast timestamp_ms to timestamp_ns 512
                            time:   [2.0140 us 2.0187 us 2.0234 us]
                            change: [+0.4914% +1.5932% +2.6619%] (p = 0.00 < 0.05)
                            Change within noise threshold.
    Found 6 outliers among 100 measurements (6.00%)
      3 (3.00%) high mild
      3 (3.00%) high severe
    
    cast utf8 to f32        time:   [28.568 us 28.651 us 28.749 us]
                            change: [-5.4918% -4.8116% -4.1925%] (p = 0.00 < 0.05)
                            Performance has improved.
    Found 9 outliers among 100 measurements (9.00%)
      1 (1.00%) low mild
      5 (5.00%) high mild
      3 (3.00%) high severe
    
    cast i64 to string 512  time:   [50.182 us 50.270 us 50.366 us]
                            change: [-2.7854% -1.4471% +0.2627%] (p = 0.05 > 0.05)
                            No change in performance detected.
    Found 13 outliers among 100 measurements (13.00%)
      7 (7.00%) high mild
      6 (6.00%) high severe
    
    cast f32 to string 512  time:   [54.687 us 54.833 us 54.983 us]
                            change: [+2.6122% +3.3716% +4.2074%] (p = 0.00 < 0.05)
                            Performance has regressed.
    Found 5 outliers among 100 measurements (5.00%)
      1 (1.00%) high mild
      4 (4.00%) high severe
    
    cast timestamp_ms to i64 512
                            time:   [424.64 ns 426.27 ns 428.09 ns]
                            change: [+1.3146% +2.0623% +2.7867%] (p = 0.00 < 0.05)
                            Performance has regressed.
    Found 7 outliers among 100 measurements (7.00%)
      2 (2.00%) low mild
      2 (2.00%) high mild
      3 (3.00%) high severe
    
    cast utf8 to date32 512 time:   [45.104 us 45.202 us 45.312 us]
                            change: [-0.3430% +0.3110% +1.0109%] (p = 0.40 > 0.05)
                            No change in performance detected.
    Found 10 outliers among 100 measurements (10.00%)
      5 (5.00%) high mild
      5 (5.00%) high severe
    
    cast utf8 to date64 512 time:   [75.844 us 76.028 us 76.239 us]
                            change: [-10.720% -10.018% -9.3075%] (p = 0.00 < 0.05)
                            Performance has improved.
    Found 5 outliers among 100 measurements (5.00%)
      3 (3.00%) high mild
      2 (2.00%) high severe
    
    ```
    
    Closes #9370 from jorgecarleitao/iter_p
    
    Authored-by: Jorge C. Leitao <jo...@gmail.com>
    Signed-off-by: Andrew Lamb <an...@nerdnetworks.org>
---
 rust/arrow/src/array/array_primitive.rs | 67 ++++++++++++++++++---------
 rust/arrow/src/array/builder.rs         | 10 ++++
 rust/arrow/src/array/iterator.rs        |  1 +
 rust/arrow/src/buffer.rs                |  8 ++++
 rust/arrow/src/compute/kernels/cast.rs  | 57 +++++++++++++----------
 rust/arrow/src/util/mod.rs              |  3 ++
 rust/arrow/src/util/trusted_len.rs      | 82 +++++++++++++++++++++++++++++++++
 7 files changed, 182 insertions(+), 46 deletions(-)

diff --git a/rust/arrow/src/array/array_primitive.rs b/rust/arrow/src/array/array_primitive.rs
index 2056b88..0bde7bc 100644
--- a/rust/arrow/src/array/array_primitive.rs
+++ b/rust/arrow/src/array/array_primitive.rs
@@ -28,9 +28,12 @@ use chrono::prelude::*;
 use super::array::print_long_array;
 use super::raw_pointer::RawPtrBox;
 use super::*;
-use crate::buffer::{Buffer, MutableBuffer};
 use crate::temporal_conversions;
 use crate::util::bit_util;
+use crate::{
+    buffer::{Buffer, MutableBuffer},
+    util::trusted_len_unzip,
+};
 
 /// Number of seconds in a day
 const SECONDS_IN_DAY: i64 = 86_400;
@@ -267,41 +270,61 @@ impl<T: ArrowPrimitiveType, Ptr: Borrow<Option<<T as ArrowPrimitiveType>::Native
 {
     fn from_iter<I: IntoIterator<Item = Ptr>>(iter: I) -> Self {
         let iter = iter.into_iter();
-        let (_, data_len) = iter.size_hint();
-        let data_len = data_len.expect("Iterator must be sized"); // panic if no upper bound.
+        let (lower, _) = iter.size_hint();
 
-        let num_bytes = bit_util::ceil(data_len, 8);
-        let mut null_buf = MutableBuffer::from_len_zeroed(num_bytes);
-        let mut val_buf = MutableBuffer::new(
-            data_len * mem::size_of::<<T as ArrowPrimitiveType>::Native>(),
-        );
+        let mut null_buf = BooleanBufferBuilder::new(lower);
 
-        let null_slice = null_buf.as_slice_mut();
-        iter.enumerate().for_each(|(i, item)| {
-            if let Some(a) = item.borrow() {
-                bit_util::set_bit(null_slice, i);
-                val_buf.push(*a);
-            } else {
-                // this ensures that null items on the buffer are not arbitrary.
-                // This is important because falible operations can use null values (e.g. a vectorized "add")
-                // which may panic (e.g. overflow if the number on the slots happen to be very large).
-                val_buf.push(T::Native::default());
-            }
-        });
+        let buffer: Buffer = iter
+            .map(|item| {
+                if let Some(a) = item.borrow() {
+                    null_buf.append(true);
+                    *a
+                } else {
+                    null_buf.append(false);
+                    // this ensures that null items on the buffer are not arbitrary.
+                    // This is important because falible operations can use null values (e.g. a vectorized "add")
+                    // which may panic (e.g. overflow if the number on the slots happen to be very large).
+                    T::Native::default()
+                }
+            })
+            .collect();
 
         let data = ArrayData::new(
             T::DATA_TYPE,
-            data_len,
+            null_buf.len(),
             None,
             Some(null_buf.into()),
             0,
-            vec![val_buf.into()],
+            vec![buffer],
             vec![],
         );
         PrimitiveArray::from(Arc::new(data))
     }
 }
 
+impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
+    /// Creates a [`PrimitiveArray`] from an iterator of trusted length.
+    /// # Safety
+    /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html).
+    /// I.e. that `size_hint().1` correctly reports its length.
+    #[inline]
+    pub unsafe fn from_trusted_len_iter<I, P>(iter: I) -> Self
+    where
+        P: std::borrow::Borrow<Option<<T as ArrowPrimitiveType>::Native>>,
+        I: IntoIterator<Item = P>,
+    {
+        let iterator = iter.into_iter();
+        let (_, upper) = iterator.size_hint();
+        let len = upper.expect("trusted_len_unzip requires an upper limit");
+
+        let (null, buffer) = trusted_len_unzip(iterator);
+
+        let data =
+            ArrayData::new(T::DATA_TYPE, len, None, Some(null), 0, vec![buffer], vec![]);
+        PrimitiveArray::from(Arc::new(data))
+    }
+}
+
 // TODO: the macro is needed here because we'd get "conflicting implementations" error
 // otherwise with both `From<Vec<T::Native>>` and `From<Vec<Option<T::Native>>>`.
 // We should revisit this in future.
diff --git a/rust/arrow/src/array/builder.rs b/rust/arrow/src/array/builder.rs
index 61724da..ad5b39f 100644
--- a/rust/arrow/src/array/builder.rs
+++ b/rust/arrow/src/array/builder.rs
@@ -297,14 +297,17 @@ impl BooleanBufferBuilder {
         Self { buffer, len: 0 }
     }
 
+    #[inline]
     pub fn len(&self) -> usize {
         self.len
     }
 
+    #[inline]
     pub fn is_empty(&self) -> bool {
         self.len == 0
     }
 
+    #[inline]
     pub fn capacity(&self) -> usize {
         self.buffer.capacity() * 8
     }
@@ -372,6 +375,13 @@ impl BooleanBufferBuilder {
     }
 }
 
+impl From<BooleanBufferBuilder> for Buffer {
+    #[inline]
+    fn from(builder: BooleanBufferBuilder) -> Self {
+        builder.buffer.into()
+    }
+}
+
 /// Trait for dealing with different array builders at runtime
 pub trait ArrayBuilder: Any {
     /// Returns the number of array slots in the builder
diff --git a/rust/arrow/src/array/iterator.rs b/rust/arrow/src/array/iterator.rs
index ff1a830..cd891ba 100644
--- a/rust/arrow/src/array/iterator.rs
+++ b/rust/arrow/src/array/iterator.rs
@@ -46,6 +46,7 @@ impl<'a, T: ArrowPrimitiveType> PrimitiveIter<'a, T> {
 impl<'a, T: ArrowPrimitiveType> std::iter::Iterator for PrimitiveIter<'a, T> {
     type Item = Option<T::Native>;
 
+    #[inline]
     fn next(&mut self) -> Option<Self::Item> {
         if self.current == self.current_end {
             None
diff --git a/rust/arrow/src/buffer.rs b/rust/arrow/src/buffer.rs
index 024db5e..1f3ef58 100644
--- a/rust/arrow/src/buffer.rs
+++ b/rust/arrow/src/buffer.rs
@@ -990,6 +990,14 @@ impl MutableBuffer {
     pub fn extend_zeros(&mut self, additional: usize) {
         self.resize(self.len + additional, 0);
     }
+
+    /// # Safety
+    /// The caller must ensure that the buffer was properly initialized up to `len`.
+    #[inline]
+    pub(crate) unsafe fn set_len(&mut self, len: usize) {
+        assert!(len <= self.capacity());
+        self.len = len;
+    }
 }
 
 /// # Safety
diff --git a/rust/arrow/src/compute/kernels/cast.rs b/rust/arrow/src/compute/kernels/cast.rs
index c554225..3358e68 100644
--- a/rust/arrow/src/compute/kernels/cast.rs
+++ b/rust/arrow/src/compute/kernels/cast.rs
@@ -860,9 +860,12 @@ where
     T::Native: num::NumCast,
     R::Native: num::NumCast,
 {
-    from.iter()
-        .map(|v| v.and_then(num::cast::cast::<T::Native, R::Native>))
-        .collect()
+    let iter = from
+        .iter()
+        .map(|v| v.and_then(num::cast::cast::<T::Native, R::Native>));
+    // Soundness:
+    //  The iterator is trustedLen because it comes from an `PrimitiveArray`.
+    unsafe { PrimitiveArray::<R>::from_trusted_len_iter(iter) }
 }
 
 /// Cast numeric types to Utf8
@@ -905,15 +908,18 @@ where
     T: ArrowNumericType,
     <T as ArrowPrimitiveType>::Native: lexical_core::FromLexical,
 {
-    (0..from.len())
-        .map(|i| {
-            if from.is_null(i) {
-                None
-            } else {
-                lexical_core::parse(from.value(i).as_bytes()).ok()
-            }
-        })
-        .collect()
+    let iter = (0..from.len()).map(|i| {
+        if from.is_null(i) {
+            None
+        } else {
+            lexical_core::parse(from.value(i).as_bytes()).ok()
+        }
+    });
+    // Benefit:
+    //     20% performance improvement
+    // Soundness:
+    //     The iterator is trustedLen because it comes from an `StringArray`.
+    unsafe { PrimitiveArray::<T>::from_trusted_len_iter(iter) }
 }
 
 /// Cast numeric types to Boolean
@@ -968,18 +974,21 @@ where
     T: ArrowNumericType,
     T::Native: num::NumCast,
 {
-    (0..from.len())
-        .map(|i| {
-            if from.is_null(i) {
-                None
-            } else if from.value(i) {
-                // a workaround to cast a primitive to T::Native, infallible
-                num::cast::cast(1)
-            } else {
-                Some(T::default_value())
-            }
-        })
-        .collect()
+    let iter = (0..from.len()).map(|i| {
+        if from.is_null(i) {
+            None
+        } else if from.value(i) {
+            // a workaround to cast a primitive to T::Native, infallible
+            num::cast::cast(1)
+        } else {
+            Some(T::default_value())
+        }
+    });
+    // Benefit:
+    //     20% performance improvement
+    // Soundness:
+    //     The iterator is trustedLen because it comes from a Range
+    unsafe { PrimitiveArray::<T>::from_trusted_len_iter(iter) }
 }
 
 /// Attempts to cast an `ArrayDictionary` with index type K into
diff --git a/rust/arrow/src/util/mod.rs b/rust/arrow/src/util/mod.rs
index af9b458..c0b5a3e 100644
--- a/rust/arrow/src/util/mod.rs
+++ b/rust/arrow/src/util/mod.rs
@@ -25,3 +25,6 @@ pub mod pretty;
 pub(crate) mod serialization;
 pub mod string_writer;
 pub mod test_util;
+
+mod trusted_len;
+pub(crate) use trusted_len::trusted_len_unzip;
diff --git a/rust/arrow/src/util/trusted_len.rs b/rust/arrow/src/util/trusted_len.rs
new file mode 100644
index 0000000..84a6623
--- /dev/null
+++ b/rust/arrow/src/util/trusted_len.rs
@@ -0,0 +1,82 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use super::bit_util;
+use crate::{
+    buffer::{Buffer, MutableBuffer},
+    datatypes::ArrowNativeType,
+};
+
+/// Creates two [`Buffer`]s from an iterator of `Option`.
+/// The first buffer corresponds to a bitmap buffer, the second one
+/// corresponds to a values buffer.
+/// # Safety
+/// The caller must ensure that `iterator` is `TrustedLen`.
+#[inline]
+pub(crate) unsafe fn trusted_len_unzip<I, P, T>(iterator: I) -> (Buffer, Buffer)
+where
+    T: ArrowNativeType,
+    P: std::borrow::Borrow<Option<T>>,
+    I: Iterator<Item = P>,
+{
+    let (_, upper) = iterator.size_hint();
+    let upper = upper.expect("trusted_len_unzip requires an upper limit");
+    let len = upper * std::mem::size_of::<T>();
+
+    let mut null = MutableBuffer::from_len_zeroed(upper.saturating_add(7) / 8);
+    let mut buffer = MutableBuffer::new(len);
+
+    let dst_null = null.as_mut_ptr();
+    let mut dst = buffer.as_mut_ptr() as *mut T;
+    for (i, item) in iterator.enumerate() {
+        let item = item.borrow();
+        if let Some(item) = item {
+            std::ptr::write(dst, *item);
+            bit_util::set_bit_raw(dst_null, i);
+        } else {
+            std::ptr::write(dst, T::default());
+        }
+        dst = dst.add(1);
+    }
+    assert_eq!(
+        dst.offset_from(buffer.as_ptr() as *mut T) as usize,
+        upper,
+        "Trusted iterator length was not accurately reported"
+    );
+    buffer.set_len(len);
+    (null.into(), buffer.into())
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn trusted_len_unzip_good() {
+        let vec = vec![Some(1u32), None];
+        let (null, buffer) = unsafe { trusted_len_unzip(vec.iter()) };
+        assert_eq!(null.as_slice(), &[0b00000001]);
+        assert_eq!(buffer.as_slice(), &[1u8, 0, 0, 0, 0, 0, 0, 0]);
+    }
+
+    #[test]
+    #[should_panic(expected = "trusted_len_unzip requires an upper limit")]
+    fn trusted_len_unzip_panic() {
+        let iter = std::iter::repeat(Some(4i32));
+        unsafe { trusted_len_unzip(iter) };
+    }
+}