You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by tu...@apache.org on 2023/01/18 12:48:53 UTC

[arrow-rs] branch master updated: Improve concat kernel capacity estimation (#3546)

This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 56dfad0b2 Improve concat kernel capacity estimation (#3546)
56dfad0b2 is described below

commit 56dfad0b2a03bc14f398a2998a68da2bc02fb7d2
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Wed Jan 18 12:48:47 2023 +0000

    Improve concat kernel capacity estimation (#3546)
    
    * Improve concat kernel capacity estimation
    
    * Review feedback
    
    * Format
---
 arrow-select/src/concat.rs | 137 ++++++++++++++++++++++++++++++---------------
 1 file changed, 93 insertions(+), 44 deletions(-)

diff --git a/arrow-select/src/concat.rs b/arrow-select/src/concat.rs
index 7e28f1695..cff8fd25b 100644
--- a/arrow-select/src/concat.rs
+++ b/arrow-select/src/concat.rs
@@ -30,24 +30,28 @@
 //! assert_eq!(arr.len(), 3);
 //! ```
 
+use arrow_array::types::*;
 use arrow_array::*;
+use arrow_buffer::ArrowNativeType;
 use arrow_data::transform::{Capacities, MutableArrayData};
-use arrow_data::ArrayData;
 use arrow_schema::{ArrowError, DataType, SchemaRef};
 
-fn compute_str_values_length<Offset: OffsetSizeTrait>(arrays: &[&ArrayData]) -> usize {
-    arrays
-        .iter()
-        .map(|&data| {
-            // get the length of the value buffer
-            let buf_len = data.buffers()[1].len();
-            // find the offset of the buffer
-            // this returns a slice of offsets, starting from the offset of the array
-            // so we can take the first value
-            let offset = data.buffer::<Offset>(0)[0];
-            buf_len - offset.to_usize().unwrap()
-        })
-        .sum()
+fn binary_capacity<T: ByteArrayType>(arrays: &[&dyn Array]) -> Capacities {
+    let mut item_capacity = 0;
+    let mut bytes_capacity = 0;
+    for array in arrays {
+        let a = array
+            .as_any()
+            .downcast_ref::<GenericByteArray<T>>()
+            .unwrap();
+
+        // Guaranteed to always have at least one element
+        let offsets = a.value_offsets();
+        bytes_capacity += offsets[offsets.len() - 1].as_usize() - offsets[0].as_usize();
+        item_capacity += a.len()
+    }
+
+    Capacities::Binary(item_capacity, Some(bytes_capacity))
 }
 
 /// Concatenate multiple [Array] of the same type into a single [ArrayRef].
@@ -61,43 +65,27 @@ pub fn concat(arrays: &[&dyn Array]) -> Result<ArrayRef, ArrowError> {
         return Ok(array.slice(0, array.len()));
     }
 
-    if arrays
-        .iter()
-        .any(|array| array.data_type() != arrays[0].data_type())
-    {
+    let d = arrays[0].data_type();
+    if arrays.iter().skip(1).any(|array| array.data_type() != d) {
         return Err(ArrowError::InvalidArgumentError(
             "It is not possible to concatenate arrays of different data types."
                 .to_string(),
         ));
     }
 
-    let lengths = arrays.iter().map(|array| array.len()).collect::<Vec<_>>();
-    let capacity = lengths.iter().sum();
-
-    let arrays = arrays.iter().map(|a| a.data()).collect::<Vec<_>>();
-
-    let mut mutable = match arrays[0].data_type() {
-        DataType::Utf8 => {
-            let str_values_size = compute_str_values_length::<i32>(&arrays);
-            MutableArrayData::with_capacities(
-                arrays,
-                false,
-                Capacities::Binary(capacity, Some(str_values_size)),
-            )
-        }
-        DataType::LargeUtf8 => {
-            let str_values_size = compute_str_values_length::<i64>(&arrays);
-            MutableArrayData::with_capacities(
-                arrays,
-                false,
-                Capacities::Binary(capacity, Some(str_values_size)),
-            )
-        }
-        _ => MutableArrayData::new(arrays, false, capacity),
+    let capacity = match d {
+        DataType::Utf8 => binary_capacity::<Utf8Type>(arrays),
+        DataType::LargeUtf8 => binary_capacity::<LargeUtf8Type>(arrays),
+        DataType::Binary => binary_capacity::<BinaryType>(arrays),
+        DataType::LargeBinary => binary_capacity::<LargeBinaryType>(arrays),
+        _ => Capacities::Array(arrays.iter().map(|a| a.len()).sum()),
     };
 
-    for (i, len) in lengths.iter().enumerate() {
-        mutable.extend(i, 0, *len)
+    let array_data = arrays.iter().map(|a| a.data()).collect::<Vec<_>>();
+    let mut mutable = MutableArrayData::with_capacities(array_data, false, capacity);
+
+    for (i, a) in arrays.iter().enumerate() {
+        mutable.extend(i, 0, a.len())
     }
 
     Ok(make_array(mutable.freeze()))
@@ -139,7 +127,6 @@ pub fn concat_batches<'a>(
 #[cfg(test)]
 mod tests {
     use super::*;
-    use arrow_array::types::*;
     use arrow_schema::{Field, Schema};
     use std::sync::Arc;
 
@@ -665,4 +652,66 @@ mod tests {
             "Invalid argument error: batches[1] schema is different with argument schema.",
         );
     }
+
+    #[test]
+    fn concat_capacity() {
+        let a = Int32Array::from_iter_values(0..100);
+        let b = Int32Array::from_iter_values(10..20);
+        let a = concat(&[&a, &b]).unwrap();
+        let data = a.data();
+        assert_eq!(data.buffers()[0].len(), 440);
+        assert_eq!(data.buffers()[0].capacity(), 448); // Nearest multiple of 64
+
+        let a = concat(&[&a.slice(10, 20), &b]).unwrap();
+        let data = a.data();
+        assert_eq!(data.buffers()[0].len(), 120);
+        assert_eq!(data.buffers()[0].capacity(), 128); // Nearest multiple of 64
+
+        let a = StringArray::from_iter_values(std::iter::repeat("foo").take(100));
+        let b = StringArray::from(vec!["bingo", "bongo", "lorem", ""]);
+
+        let a = concat(&[&a, &b]).unwrap();
+        let data = a.data();
+        // (100 + 4 + 1) * size_of<i32>()
+        assert_eq!(data.buffers()[0].len(), 420);
+        assert_eq!(data.buffers()[0].capacity(), 448); // Nearest multiple of 64
+
+        // len("foo") * 100 + len("bingo") + len("bongo") + len("lorem")
+        assert_eq!(data.buffers()[1].len(), 315);
+        assert_eq!(data.buffers()[1].capacity(), 320); // Nearest multiple of 64
+
+        let a = concat(&[&a.slice(10, 40), &b]).unwrap();
+        let data = a.data();
+        // (40 + 4 + 1) * size_of<i32>()
+        assert_eq!(data.buffers()[0].len(), 180);
+        assert_eq!(data.buffers()[0].capacity(), 192); // Nearest multiple of 64
+
+        // len("foo") * 40 + len("bingo") + len("bongo") + len("lorem")
+        assert_eq!(data.buffers()[1].len(), 135);
+        assert_eq!(data.buffers()[1].capacity(), 192); // Nearest multiple of 64
+
+        let a = LargeBinaryArray::from_iter_values(std::iter::repeat(b"foo").take(100));
+        let b =
+            LargeBinaryArray::from_iter_values(std::iter::repeat(b"cupcakes").take(10));
+
+        let a = concat(&[&a, &b]).unwrap();
+        let data = a.data();
+        // (100 + 10 + 1) * size_of<i64>()
+        assert_eq!(data.buffers()[0].len(), 888);
+        assert_eq!(data.buffers()[0].capacity(), 896); // Nearest multiple of 64
+
+        // len("foo") * 100 + len("cupcakes") * 10
+        assert_eq!(data.buffers()[1].len(), 380);
+        assert_eq!(data.buffers()[1].capacity(), 384); // Nearest multiple of 64
+
+        let a = concat(&[&a.slice(10, 40), &b]).unwrap();
+        let data = a.data();
+        // (40 + 10 + 1) * size_of<i64>()
+        assert_eq!(data.buffers()[0].len(), 408);
+        assert_eq!(data.buffers()[0].capacity(), 448); // Nearest multiple of 64
+
+        // len("foo") * 40 + len("cupcakes") * 10
+        assert_eq!(data.buffers()[1].len(), 200);
+        assert_eq!(data.buffers()[1].capacity(), 256); // Nearest multiple of 64
+    }
 }