You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by tu...@apache.org on 2023/01/18 12:48:53 UTC
[arrow-rs] branch master updated: Improve concat kernel capacity estimation (#3546)
This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 56dfad0b2 Improve concat kernel capacity estimation (#3546)
56dfad0b2 is described below
commit 56dfad0b2a03bc14f398a2998a68da2bc02fb7d2
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Wed Jan 18 12:48:47 2023 +0000
Improve concat kernel capacity estimation (#3546)
* Improve concat kernel capacity estimation
* Review feedback
* Format
---
arrow-select/src/concat.rs | 137 ++++++++++++++++++++++++++++++---------------
1 file changed, 93 insertions(+), 44 deletions(-)
diff --git a/arrow-select/src/concat.rs b/arrow-select/src/concat.rs
index 7e28f1695..cff8fd25b 100644
--- a/arrow-select/src/concat.rs
+++ b/arrow-select/src/concat.rs
@@ -30,24 +30,28 @@
//! assert_eq!(arr.len(), 3);
//! ```
+use arrow_array::types::*;
use arrow_array::*;
+use arrow_buffer::ArrowNativeType;
use arrow_data::transform::{Capacities, MutableArrayData};
-use arrow_data::ArrayData;
use arrow_schema::{ArrowError, DataType, SchemaRef};
-fn compute_str_values_length<Offset: OffsetSizeTrait>(arrays: &[&ArrayData]) -> usize {
- arrays
- .iter()
- .map(|&data| {
- // get the length of the value buffer
- let buf_len = data.buffers()[1].len();
- // find the offset of the buffer
- // this returns a slice of offsets, starting from the offset of the array
- // so we can take the first value
- let offset = data.buffer::<Offset>(0)[0];
- buf_len - offset.to_usize().unwrap()
- })
- .sum()
+fn binary_capacity<T: ByteArrayType>(arrays: &[&dyn Array]) -> Capacities {
+ let mut item_capacity = 0;
+ let mut bytes_capacity = 0;
+ for array in arrays {
+ let a = array
+ .as_any()
+ .downcast_ref::<GenericByteArray<T>>()
+ .unwrap();
+
+ // Guaranteed to always have at least one element
+ let offsets = a.value_offsets();
+ bytes_capacity += offsets[offsets.len() - 1].as_usize() - offsets[0].as_usize();
+ item_capacity += a.len()
+ }
+
+ Capacities::Binary(item_capacity, Some(bytes_capacity))
}
/// Concatenate multiple [Array] of the same type into a single [ArrayRef].
@@ -61,43 +65,27 @@ pub fn concat(arrays: &[&dyn Array]) -> Result<ArrayRef, ArrowError> {
return Ok(array.slice(0, array.len()));
}
- if arrays
- .iter()
- .any(|array| array.data_type() != arrays[0].data_type())
- {
+ let d = arrays[0].data_type();
+ if arrays.iter().skip(1).any(|array| array.data_type() != d) {
return Err(ArrowError::InvalidArgumentError(
"It is not possible to concatenate arrays of different data types."
.to_string(),
));
}
- let lengths = arrays.iter().map(|array| array.len()).collect::<Vec<_>>();
- let capacity = lengths.iter().sum();
-
- let arrays = arrays.iter().map(|a| a.data()).collect::<Vec<_>>();
-
- let mut mutable = match arrays[0].data_type() {
- DataType::Utf8 => {
- let str_values_size = compute_str_values_length::<i32>(&arrays);
- MutableArrayData::with_capacities(
- arrays,
- false,
- Capacities::Binary(capacity, Some(str_values_size)),
- )
- }
- DataType::LargeUtf8 => {
- let str_values_size = compute_str_values_length::<i64>(&arrays);
- MutableArrayData::with_capacities(
- arrays,
- false,
- Capacities::Binary(capacity, Some(str_values_size)),
- )
- }
- _ => MutableArrayData::new(arrays, false, capacity),
+ let capacity = match d {
+ DataType::Utf8 => binary_capacity::<Utf8Type>(arrays),
+ DataType::LargeUtf8 => binary_capacity::<LargeUtf8Type>(arrays),
+ DataType::Binary => binary_capacity::<BinaryType>(arrays),
+ DataType::LargeBinary => binary_capacity::<LargeBinaryType>(arrays),
+ _ => Capacities::Array(arrays.iter().map(|a| a.len()).sum()),
};
- for (i, len) in lengths.iter().enumerate() {
- mutable.extend(i, 0, *len)
+ let array_data = arrays.iter().map(|a| a.data()).collect::<Vec<_>>();
+ let mut mutable = MutableArrayData::with_capacities(array_data, false, capacity);
+
+ for (i, a) in arrays.iter().enumerate() {
+ mutable.extend(i, 0, a.len())
}
Ok(make_array(mutable.freeze()))
@@ -139,7 +127,6 @@ pub fn concat_batches<'a>(
#[cfg(test)]
mod tests {
use super::*;
- use arrow_array::types::*;
use arrow_schema::{Field, Schema};
use std::sync::Arc;
@@ -665,4 +652,66 @@ mod tests {
"Invalid argument error: batches[1] schema is different with argument schema.",
);
}
+
+ #[test]
+ fn concat_capacity() {
+ let a = Int32Array::from_iter_values(0..100);
+ let b = Int32Array::from_iter_values(10..20);
+ let a = concat(&[&a, &b]).unwrap();
+ let data = a.data();
+ assert_eq!(data.buffers()[0].len(), 440);
+ assert_eq!(data.buffers()[0].capacity(), 448); // Nearest multiple of 64
+
+ let a = concat(&[&a.slice(10, 20), &b]).unwrap();
+ let data = a.data();
+ assert_eq!(data.buffers()[0].len(), 120);
+ assert_eq!(data.buffers()[0].capacity(), 128); // Nearest multiple of 64
+
+ let a = StringArray::from_iter_values(std::iter::repeat("foo").take(100));
+ let b = StringArray::from(vec!["bingo", "bongo", "lorem", ""]);
+
+ let a = concat(&[&a, &b]).unwrap();
+ let data = a.data();
+ // (100 + 4 + 1) * size_of<i32>()
+ assert_eq!(data.buffers()[0].len(), 420);
+ assert_eq!(data.buffers()[0].capacity(), 448); // Nearest multiple of 64
+
+ // len("foo") * 100 + len("bingo") + len("bongo") + len("lorem")
+ assert_eq!(data.buffers()[1].len(), 315);
+ assert_eq!(data.buffers()[1].capacity(), 320); // Nearest multiple of 64
+
+ let a = concat(&[&a.slice(10, 40), &b]).unwrap();
+ let data = a.data();
+ // (40 + 4 + 1) * size_of<i32>()
+ assert_eq!(data.buffers()[0].len(), 180);
+ assert_eq!(data.buffers()[0].capacity(), 192); // Nearest multiple of 64
+
+ // len("foo") * 40 + len("bingo") + len("bongo") + len("lorem")
+ assert_eq!(data.buffers()[1].len(), 135);
+ assert_eq!(data.buffers()[1].capacity(), 192); // Nearest multiple of 64
+
+ let a = LargeBinaryArray::from_iter_values(std::iter::repeat(b"foo").take(100));
+ let b =
+ LargeBinaryArray::from_iter_values(std::iter::repeat(b"cupcakes").take(10));
+
+ let a = concat(&[&a, &b]).unwrap();
+ let data = a.data();
+ // (100 + 10 + 1) * size_of<i64>()
+ assert_eq!(data.buffers()[0].len(), 888);
+ assert_eq!(data.buffers()[0].capacity(), 896); // Nearest multiple of 64
+
+ // len("foo") * 100 + len("cupcakes") * 10
+ assert_eq!(data.buffers()[1].len(), 380);
+ assert_eq!(data.buffers()[1].capacity(), 384); // Nearest multiple of 64
+
+ let a = concat(&[&a.slice(10, 40), &b]).unwrap();
+ let data = a.data();
+ // (40 + 10 + 1) * size_of<i64>()
+ assert_eq!(data.buffers()[0].len(), 408);
+ assert_eq!(data.buffers()[0].capacity(), 448); // Nearest multiple of 64
+
+ // len("foo") * 40 + len("cupcakes") * 10
+ assert_eq!(data.buffers()[1].len(), 200);
+ assert_eq!(data.buffers()[1].capacity(), 256); // Nearest multiple of 64
+ }
}