Posted to commits@arrow.apache.org by al...@apache.org on 2021/02/06 11:25:58 UTC
[arrow] branch master updated: ARROW-11436: [Rust] Improved from_iter for primitive arrays (-20-30% for cast)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 65f8026 ARROW-11436: [Rust] Improved from_iter for primitive arrays (-20-30% for cast)
65f8026 is described below
commit 65f8026d7260bcd3f21077e36206aaee4b8900e2
Author: Jorge C. Leitao <jo...@gmail.com>
AuthorDate: Sat Feb 6 06:24:59 2021 -0500
ARROW-11436: [Rust] Improved from_iter for primitive arrays (-20-30% for cast)
This PR refactors `PrimitiveArray::from_iter` to support iterators that do not report an upper bound on their length, as the standard `FromIterator` trait expects. This makes `from_iter` slower.
To compensate, this PR introduces a new `unsafe` method, `from_trusted_len_iter`, which creates an array from an iterator of `Option<T>` that has a trusted length. This is 20-30% faster than using the `from_iter` implemented in master.
This PR uses `from_trusted_len_iter` to speed up casting of `primitive -> primitive` and `bool -> primitive` by 20-30%.
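As a rough illustration of the idea behind `from_trusted_len_iter`, here is a safe sketch of the unzip step, written against `std` only. The name `unzip_options` and the use of `Vec` in place of Arrow's `MutableBuffer` are assumptions made for this example; the helper added by this PR writes through raw pointers instead, which is why it is `unsafe`.

```rust
/// Hypothetical safe analogue of the unzip step: splits an iterator of
/// `Option<T>` into a validity bitmap (least-significant bit first, one bit
/// per slot) and a vector of values, where null slots hold `T::default()`.
fn unzip_options<T, I>(iter: I) -> (Vec<u8>, Vec<T>)
where
    T: Copy + Default,
    I: Iterator<Item = Option<T>>,
{
    // Both outputs are sized once, up front, from the reported upper bound.
    let (_, upper) = iter.size_hint();
    let upper = upper.expect("an upper bound is required");

    let mut validity = vec![0u8; (upper + 7) / 8];
    let mut values = Vec::with_capacity(upper);

    for (i, item) in iter.enumerate() {
        match item {
            Some(v) => {
                // Mark slot `i` as valid.
                validity[i / 8] |= 1 << (i % 8);
                values.push(v);
            }
            // Null slots get a deterministic value rather than garbage.
            None => values.push(T::default()),
        }
    }
    // A safe version can still check that the reported length was honest.
    assert_eq!(values.len(), upper, "iterator length was not accurately reported");
    (validity, values)
}
```

Calling `unzip_options(vec![Some(1u32), None, Some(3)].into_iter())` yields a one-byte bitmap with bits 0 and 2 set, and the values `[1, 0, 3]`. In this safe form an iterator that under-reports its length only causes a panic or a reallocation; the pointer-based version cannot tolerate that, so the length guarantee becomes part of its safety contract.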
Note that the added complexity (of having new functions and being unable to rely on `collect`) arises from two `unstable` features that we can't use:
1. `unsafe trait TrustedLen`
2. specialization
If we had both features, we could implement `TrustedLen` for all our iterators, and then use specialization to offer a specialized `FromIterator` implementation for them. Once these features are stabilized, we can simplify our API by allowing users to call `collect` on an iterator and let the compiler conclude that the iterator is `TrustedLen` and therefore should use the specialized implementation. :)
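To make the `TrustedLen` requirement concrete, the following sketch uses only stable `std` APIs to show why an upper bound from `size_hint` cannot be trusted for every iterator. The function names `size_hints` and `filtered_len` are made up for this example.

```rust
/// Returns the `size_hint` of an adapter chain whose length is known exactly,
/// followed by the `size_hint` of one where only an upper bound is known.
fn size_hints() -> ((usize, Option<usize>), (usize, Option<usize>)) {
    // `Range` and `Map` preserve an exact length: lower == upper.
    let exact = (0u32..4).map(|x| x * 2).size_hint();
    // `Filter` cannot know how many items will pass the predicate, so it
    // reports a lower bound of 0 and keeps the inner upper bound.
    let bounded = (0u32..4).filter(|x| x % 2 == 0).size_hint();
    (exact, bounded)
}

/// Counts the items the filtered iterator actually yields.
fn filtered_len() -> usize {
    (0u32..4).filter(|x| x % 2 == 0).count()
}
```

The first hint is `(4, Some(4))`, while the second is `(0, Some(4))` even though only two items are produced. An allocation sized from that upper bound would be larger than needed, and code that assumes exactly `upper` items were written would read uninitialized memory, which is the kind of mistake the `unsafe` marker on `from_trusted_len_iter` asks callers to rule out.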
```
Switched to branch 'iter_p'
Compiling arrow v4.0.0-SNAPSHOT (/Users/jorgecarleitao/projects/arrow/rust/arrow)
Finished bench [optimized] target(s) in 1m 11s
Running /Users/jorgecarleitao/projects/arrow/rust/target/release/deps/cast_kernels-9a3b6d213a9f7a9a
Gnuplot not found, using plotters backend
cast int32 to int32 512 time: [25.384 ns 25.415 ns 25.449 ns]
change: [-2.1332% -1.4846% -0.7930%] (p = 0.00 < 0.05)
Change within noise threshold.
Found 11 outliers among 100 measurements (11.00%)
3 (3.00%) high mild
8 (8.00%) high severe
cast int32 to uint32 512
time: [2.1526 us 2.1576 us 2.1629 us]
change: [-24.823% -24.221% -23.610%] (p = 0.00 < 0.05)
Performance has improved.
Found 9 outliers among 100 measurements (9.00%)
1 (1.00%) low mild
4 (4.00%) high mild
4 (4.00%) high severe
cast int32 to float32 512
time: [2.4012 us 2.4083 us 2.4170 us]
change: [-20.381% -19.079% -17.860%] (p = 0.00 < 0.05)
Performance has improved.
Found 7 outliers among 100 measurements (7.00%)
1 (1.00%) low mild
3 (3.00%) high mild
3 (3.00%) high severe
cast int32 to float64 512
time: [2.4544 us 2.4608 us 2.4680 us]
change: [-24.689% -23.471% -22.441%] (p = 0.00 < 0.05)
Performance has improved.
Found 8 outliers among 100 measurements (8.00%)
4 (4.00%) high mild
4 (4.00%) high severe
cast int32 to int64 512 time: [2.2424 us 2.2532 us 2.2663 us]
change: [-28.692% -28.008% -27.316%] (p = 0.00 < 0.05)
Performance has improved.
Found 7 outliers among 100 measurements (7.00%)
3 (3.00%) high mild
4 (4.00%) high severe
cast float32 to int32 512
time: [2.4966 us 2.5063 us 2.5176 us]
change: [-26.820% -26.281% -25.755%] (p = 0.00 < 0.05)
Performance has improved.
Found 6 outliers among 100 measurements (6.00%)
1 (1.00%) low mild
2 (2.00%) high mild
3 (3.00%) high severe
cast float64 to float32 512
time: [2.4857 us 2.4954 us 2.5070 us]
change: [-22.439% -21.818% -21.222%] (p = 0.00 < 0.05)
Performance has improved.
Found 6 outliers among 100 measurements (6.00%)
1 (1.00%) low mild
3 (3.00%) high mild
2 (2.00%) high severe
cast float64 to uint64 512
time: [2.8313 us 2.8369 us 2.8427 us]
change: [-32.996% -32.605% -32.263%] (p = 0.00 < 0.05)
Performance has improved.
Found 9 outliers among 100 measurements (9.00%)
6 (6.00%) high mild
3 (3.00%) high severe
cast int64 to int32 512 time: [2.2000 us 2.2073 us 2.2154 us]
change: [-32.106% -31.271% -30.265%] (p = 0.00 < 0.05)
Performance has improved.
Found 11 outliers among 100 measurements (11.00%)
1 (1.00%) low mild
5 (5.00%) high mild
5 (5.00%) high severe
cast date64 to date32 512
time: [1.0772 us 1.0815 us 1.0866 us]
change: [+1.3054% +3.6485% +6.8389%] (p = 0.01 < 0.05)
Performance has regressed.
Found 14 outliers among 100 measurements (14.00%)
2 (2.00%) high mild
12 (12.00%) high severe
cast date32 to date64 512
time: [838.20 ns 840.35 ns 842.74 ns]
change: [-1.2203% -0.3511% +0.5828%] (p = 0.47 > 0.05)
No change in performance detected.
Found 10 outliers among 100 measurements (10.00%)
3 (3.00%) low mild
2 (2.00%) high mild
5 (5.00%) high severe
cast time32s to time32ms 512
time: [741.99 ns 745.21 ns 748.67 ns]
change: [-0.1386% +1.9748% +5.4043%] (p = 0.18 > 0.05)
No change in performance detected.
Found 12 outliers among 100 measurements (12.00%)
4 (4.00%) low mild
2 (2.00%) high mild
6 (6.00%) high severe
cast time32s to time64us 512
time: [4.2476 us 4.2596 us 4.2747 us]
change: [-19.580% -18.601% -17.667%] (p = 0.00 < 0.05)
Performance has improved.
Found 11 outliers among 100 measurements (11.00%)
2 (2.00%) low mild
6 (6.00%) high mild
3 (3.00%) high severe
cast time64ns to time32s 512
time: [4.9276 us 4.9371 us 4.9489 us]
change: [-0.2071% +0.6046% +1.5380%] (p = 0.17 > 0.05)
No change in performance detected.
Found 8 outliers among 100 measurements (8.00%)
1 (1.00%) low mild
2 (2.00%) high mild
5 (5.00%) high severe
cast timestamp_ns to timestamp_s 512
time: [25.938 ns 26.005 ns 26.079 ns]
change: [-2.0321% -1.2763% -0.4590%] (p = 0.00 < 0.05)
Change within noise threshold.
Found 9 outliers among 100 measurements (9.00%)
1 (1.00%) low mild
3 (3.00%) high mild
5 (5.00%) high severe
cast timestamp_ms to timestamp_ns 512
time: [2.0140 us 2.0187 us 2.0234 us]
change: [+0.4914% +1.5932% +2.6619%] (p = 0.00 < 0.05)
Change within noise threshold.
Found 6 outliers among 100 measurements (6.00%)
3 (3.00%) high mild
3 (3.00%) high severe
cast utf8 to f32 time: [28.568 us 28.651 us 28.749 us]
change: [-5.4918% -4.8116% -4.1925%] (p = 0.00 < 0.05)
Performance has improved.
Found 9 outliers among 100 measurements (9.00%)
1 (1.00%) low mild
5 (5.00%) high mild
3 (3.00%) high severe
cast i64 to string 512 time: [50.182 us 50.270 us 50.366 us]
change: [-2.7854% -1.4471% +0.2627%] (p = 0.05 > 0.05)
No change in performance detected.
Found 13 outliers among 100 measurements (13.00%)
7 (7.00%) high mild
6 (6.00%) high severe
cast f32 to string 512 time: [54.687 us 54.833 us 54.983 us]
change: [+2.6122% +3.3716% +4.2074%] (p = 0.00 < 0.05)
Performance has regressed.
Found 5 outliers among 100 measurements (5.00%)
1 (1.00%) high mild
4 (4.00%) high severe
cast timestamp_ms to i64 512
time: [424.64 ns 426.27 ns 428.09 ns]
change: [+1.3146% +2.0623% +2.7867%] (p = 0.00 < 0.05)
Performance has regressed.
Found 7 outliers among 100 measurements (7.00%)
2 (2.00%) low mild
2 (2.00%) high mild
3 (3.00%) high severe
cast utf8 to date32 512 time: [45.104 us 45.202 us 45.312 us]
change: [-0.3430% +0.3110% +1.0109%] (p = 0.40 > 0.05)
No change in performance detected.
Found 10 outliers among 100 measurements (10.00%)
5 (5.00%) high mild
5 (5.00%) high severe
cast utf8 to date64 512 time: [75.844 us 76.028 us 76.239 us]
change: [-10.720% -10.018% -9.3075%] (p = 0.00 < 0.05)
Performance has improved.
Found 5 outliers among 100 measurements (5.00%)
3 (3.00%) high mild
2 (2.00%) high severe
```
Closes #9370 from jorgecarleitao/iter_p
Authored-by: Jorge C. Leitao <jo...@gmail.com>
Signed-off-by: Andrew Lamb <an...@nerdnetworks.org>
---
rust/arrow/src/array/array_primitive.rs | 67 ++++++++++++++++++---------
rust/arrow/src/array/builder.rs | 10 ++++
rust/arrow/src/array/iterator.rs | 1 +
rust/arrow/src/buffer.rs | 8 ++++
rust/arrow/src/compute/kernels/cast.rs | 57 +++++++++++++----------
rust/arrow/src/util/mod.rs | 3 ++
rust/arrow/src/util/trusted_len.rs | 82 +++++++++++++++++++++++++++++++++
7 files changed, 182 insertions(+), 46 deletions(-)
diff --git a/rust/arrow/src/array/array_primitive.rs b/rust/arrow/src/array/array_primitive.rs
index 2056b88..0bde7bc 100644
--- a/rust/arrow/src/array/array_primitive.rs
+++ b/rust/arrow/src/array/array_primitive.rs
@@ -28,9 +28,12 @@ use chrono::prelude::*;
use super::array::print_long_array;
use super::raw_pointer::RawPtrBox;
use super::*;
-use crate::buffer::{Buffer, MutableBuffer};
use crate::temporal_conversions;
use crate::util::bit_util;
+use crate::{
+ buffer::{Buffer, MutableBuffer},
+ util::trusted_len_unzip,
+};
/// Number of seconds in a day
const SECONDS_IN_DAY: i64 = 86_400;
@@ -267,41 +270,61 @@ impl<T: ArrowPrimitiveType, Ptr: Borrow<Option<<T as ArrowPrimitiveType>::Native
{
fn from_iter<I: IntoIterator<Item = Ptr>>(iter: I) -> Self {
let iter = iter.into_iter();
- let (_, data_len) = iter.size_hint();
- let data_len = data_len.expect("Iterator must be sized"); // panic if no upper bound.
+ let (lower, _) = iter.size_hint();
- let num_bytes = bit_util::ceil(data_len, 8);
- let mut null_buf = MutableBuffer::from_len_zeroed(num_bytes);
- let mut val_buf = MutableBuffer::new(
- data_len * mem::size_of::<<T as ArrowPrimitiveType>::Native>(),
- );
+ let mut null_buf = BooleanBufferBuilder::new(lower);
- let null_slice = null_buf.as_slice_mut();
- iter.enumerate().for_each(|(i, item)| {
- if let Some(a) = item.borrow() {
- bit_util::set_bit(null_slice, i);
- val_buf.push(*a);
- } else {
- // this ensures that null items on the buffer are not arbitrary.
- // This is important because falible operations can use null values (e.g. a vectorized "add")
- // which may panic (e.g. overflow if the number on the slots happen to be very large).
- val_buf.push(T::Native::default());
- }
- });
+ let buffer: Buffer = iter
+ .map(|item| {
+ if let Some(a) = item.borrow() {
+ null_buf.append(true);
+ *a
+ } else {
+ null_buf.append(false);
+ // this ensures that null items on the buffer are not arbitrary.
+ // This is important because falible operations can use null values (e.g. a vectorized "add")
+ // which may panic (e.g. overflow if the number on the slots happen to be very large).
+ T::Native::default()
+ }
+ })
+ .collect();
let data = ArrayData::new(
T::DATA_TYPE,
- data_len,
+ null_buf.len(),
None,
Some(null_buf.into()),
0,
- vec![val_buf.into()],
+ vec![buffer],
vec![],
);
PrimitiveArray::from(Arc::new(data))
}
}
+impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
+ /// Creates a [`PrimitiveArray`] from an iterator of trusted length.
+ /// # Safety
+ /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html).
+ /// I.e. that `size_hint().1` correctly reports its length.
+ #[inline]
+ pub unsafe fn from_trusted_len_iter<I, P>(iter: I) -> Self
+ where
+ P: std::borrow::Borrow<Option<<T as ArrowPrimitiveType>::Native>>,
+ I: IntoIterator<Item = P>,
+ {
+ let iterator = iter.into_iter();
+ let (_, upper) = iterator.size_hint();
+ let len = upper.expect("trusted_len_unzip requires an upper limit");
+
+ let (null, buffer) = trusted_len_unzip(iterator);
+
+ let data =
+ ArrayData::new(T::DATA_TYPE, len, None, Some(null), 0, vec![buffer], vec![]);
+ PrimitiveArray::from(Arc::new(data))
+ }
+}
+
// TODO: the macro is needed here because we'd get "conflicting implementations" error
// otherwise with both `From<Vec<T::Native>>` and `From<Vec<Option<T::Native>>>`.
// We should revisit this in future.
diff --git a/rust/arrow/src/array/builder.rs b/rust/arrow/src/array/builder.rs
index 61724da..ad5b39f 100644
--- a/rust/arrow/src/array/builder.rs
+++ b/rust/arrow/src/array/builder.rs
@@ -297,14 +297,17 @@ impl BooleanBufferBuilder {
Self { buffer, len: 0 }
}
+ #[inline]
pub fn len(&self) -> usize {
self.len
}
+ #[inline]
pub fn is_empty(&self) -> bool {
self.len == 0
}
+ #[inline]
pub fn capacity(&self) -> usize {
self.buffer.capacity() * 8
}
@@ -372,6 +375,13 @@ impl BooleanBufferBuilder {
}
}
+impl From<BooleanBufferBuilder> for Buffer {
+ #[inline]
+ fn from(builder: BooleanBufferBuilder) -> Self {
+ builder.buffer.into()
+ }
+}
+
/// Trait for dealing with different array builders at runtime
pub trait ArrayBuilder: Any {
/// Returns the number of array slots in the builder
diff --git a/rust/arrow/src/array/iterator.rs b/rust/arrow/src/array/iterator.rs
index ff1a830..cd891ba 100644
--- a/rust/arrow/src/array/iterator.rs
+++ b/rust/arrow/src/array/iterator.rs
@@ -46,6 +46,7 @@ impl<'a, T: ArrowPrimitiveType> PrimitiveIter<'a, T> {
impl<'a, T: ArrowPrimitiveType> std::iter::Iterator for PrimitiveIter<'a, T> {
type Item = Option<T::Native>;
+ #[inline]
fn next(&mut self) -> Option<Self::Item> {
if self.current == self.current_end {
None
diff --git a/rust/arrow/src/buffer.rs b/rust/arrow/src/buffer.rs
index 024db5e..1f3ef58 100644
--- a/rust/arrow/src/buffer.rs
+++ b/rust/arrow/src/buffer.rs
@@ -990,6 +990,14 @@ impl MutableBuffer {
pub fn extend_zeros(&mut self, additional: usize) {
self.resize(self.len + additional, 0);
}
+
+ /// # Safety
+ /// The caller must ensure that the buffer was properly initialized up to `len`.
+ #[inline]
+ pub(crate) unsafe fn set_len(&mut self, len: usize) {
+ assert!(len <= self.capacity());
+ self.len = len;
+ }
}
/// # Safety
diff --git a/rust/arrow/src/compute/kernels/cast.rs b/rust/arrow/src/compute/kernels/cast.rs
index c554225..3358e68 100644
--- a/rust/arrow/src/compute/kernels/cast.rs
+++ b/rust/arrow/src/compute/kernels/cast.rs
@@ -860,9 +860,12 @@ where
T::Native: num::NumCast,
R::Native: num::NumCast,
{
- from.iter()
- .map(|v| v.and_then(num::cast::cast::<T::Native, R::Native>))
- .collect()
+ let iter = from
+ .iter()
+ .map(|v| v.and_then(num::cast::cast::<T::Native, R::Native>));
+ // Soundness:
+ // The iterator is trustedLen because it comes from an `PrimitiveArray`.
+ unsafe { PrimitiveArray::<R>::from_trusted_len_iter(iter) }
}
/// Cast numeric types to Utf8
@@ -905,15 +908,18 @@ where
T: ArrowNumericType,
<T as ArrowPrimitiveType>::Native: lexical_core::FromLexical,
{
- (0..from.len())
- .map(|i| {
- if from.is_null(i) {
- None
- } else {
- lexical_core::parse(from.value(i).as_bytes()).ok()
- }
- })
- .collect()
+ let iter = (0..from.len()).map(|i| {
+ if from.is_null(i) {
+ None
+ } else {
+ lexical_core::parse(from.value(i).as_bytes()).ok()
+ }
+ });
+ // Benefit:
+ // 20% performance improvement
+ // Soundness:
+ // The iterator is trustedLen because it comes from an `StringArray`.
+ unsafe { PrimitiveArray::<T>::from_trusted_len_iter(iter) }
}
/// Cast numeric types to Boolean
@@ -968,18 +974,21 @@ where
T: ArrowNumericType,
T::Native: num::NumCast,
{
- (0..from.len())
- .map(|i| {
- if from.is_null(i) {
- None
- } else if from.value(i) {
- // a workaround to cast a primitive to T::Native, infallible
- num::cast::cast(1)
- } else {
- Some(T::default_value())
- }
- })
- .collect()
+ let iter = (0..from.len()).map(|i| {
+ if from.is_null(i) {
+ None
+ } else if from.value(i) {
+ // a workaround to cast a primitive to T::Native, infallible
+ num::cast::cast(1)
+ } else {
+ Some(T::default_value())
+ }
+ });
+ // Benefit:
+ // 20% performance improvement
+ // Soundness:
+ // The iterator is trustedLen because it comes from a Range
+ unsafe { PrimitiveArray::<T>::from_trusted_len_iter(iter) }
}
/// Attempts to cast an `ArrayDictionary` with index type K into
diff --git a/rust/arrow/src/util/mod.rs b/rust/arrow/src/util/mod.rs
index af9b458..c0b5a3e 100644
--- a/rust/arrow/src/util/mod.rs
+++ b/rust/arrow/src/util/mod.rs
@@ -25,3 +25,6 @@ pub mod pretty;
pub(crate) mod serialization;
pub mod string_writer;
pub mod test_util;
+
+mod trusted_len;
+pub(crate) use trusted_len::trusted_len_unzip;
diff --git a/rust/arrow/src/util/trusted_len.rs b/rust/arrow/src/util/trusted_len.rs
new file mode 100644
index 0000000..84a6623
--- /dev/null
+++ b/rust/arrow/src/util/trusted_len.rs
@@ -0,0 +1,82 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use super::bit_util;
+use crate::{
+ buffer::{Buffer, MutableBuffer},
+ datatypes::ArrowNativeType,
+};
+
+/// Creates two [`Buffer`]s from an iterator of `Option`.
+/// The first buffer corresponds to a bitmap buffer, the second one
+/// corresponds to a values buffer.
+/// # Safety
+/// The caller must ensure that `iterator` is `TrustedLen`.
+#[inline]
+pub(crate) unsafe fn trusted_len_unzip<I, P, T>(iterator: I) -> (Buffer, Buffer)
+where
+ T: ArrowNativeType,
+ P: std::borrow::Borrow<Option<T>>,
+ I: Iterator<Item = P>,
+{
+ let (_, upper) = iterator.size_hint();
+ let upper = upper.expect("trusted_len_unzip requires an upper limit");
+ let len = upper * std::mem::size_of::<T>();
+
+ let mut null = MutableBuffer::from_len_zeroed(upper.saturating_add(7) / 8);
+ let mut buffer = MutableBuffer::new(len);
+
+ let dst_null = null.as_mut_ptr();
+ let mut dst = buffer.as_mut_ptr() as *mut T;
+ for (i, item) in iterator.enumerate() {
+ let item = item.borrow();
+ if let Some(item) = item {
+ std::ptr::write(dst, *item);
+ bit_util::set_bit_raw(dst_null, i);
+ } else {
+ std::ptr::write(dst, T::default());
+ }
+ dst = dst.add(1);
+ }
+ assert_eq!(
+ dst.offset_from(buffer.as_ptr() as *mut T) as usize,
+ upper,
+ "Trusted iterator length was not accurately reported"
+ );
+ buffer.set_len(len);
+ (null.into(), buffer.into())
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn trusted_len_unzip_good() {
+ let vec = vec![Some(1u32), None];
+ let (null, buffer) = unsafe { trusted_len_unzip(vec.iter()) };
+ assert_eq!(null.as_slice(), &[0b00000001]);
+ assert_eq!(buffer.as_slice(), &[1u8, 0, 0, 0, 0, 0, 0, 0]);
+ }
+
+ #[test]
+ #[should_panic(expected = "trusted_len_unzip requires an upper limit")]
+ fn trusted_len_unzip_panic() {
+ let iter = std::iter::repeat(Some(4i32));
+ unsafe { trusted_len_unzip(iter) };
+ }
+}