Posted to commits@arrow.apache.org by tu...@apache.org on 2022/07/14 21:16:03 UTC

[arrow-rs] branch master updated: Truncate IPC record batch (#2040)

This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 86543a4b8 Truncate IPC record batch (#2040)
86543a4b8 is described below

commit 86543a4b835f898308bd28a33d0da9d2ea7a9b35
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Thu Jul 14 14:15:57 2022 -0700

    Truncate IPC record batch (#2040)
    
    * Truncate IPC record batch
    
    * Fix clippy
    
    * Fix clippy
    
    * Check deserialized record batch
    
    * For review
    
    * Truncate DictionaryArray. Fix null buffer truncation. Add null value test.
    
    * Use BufferBuilder
    
    * Add option to IpcWriteOptions
    
    * For review
    
    * Remove truncate option. Revise test for review.
    
    * Trigger Build
---
 arrow/src/array/mod.rs          |   1 +
 arrow/src/datatypes/datatype.rs |  15 ++
 arrow/src/ipc/writer.rs         | 293 ++++++++++++++++++++++++++++++++++++++--
 3 files changed, 301 insertions(+), 8 deletions(-)
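
For context, this change teaches the IPC writers to truncate buffers when
writing a sliced RecordBatch: validity bitmaps are bit-sliced, value offsets
are rebased to zero, and only the value bytes the slice covers are copied. A
minimal sketch of the user-visible effect, mirroring the serialize helper in
the new tests below (the helper is local to the tests, not part of the arrow
API):

    use std::sync::Arc;
    use arrow::array::Int32Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::ipc::writer::StreamWriter;
    use arrow::record_batch::RecordBatch;

    // Local helper: run a batch through the IPC StreamWriter, return the bytes.
    fn serialize(batch: &RecordBatch) -> Vec<u8> {
        let mut writer = StreamWriter::try_new(Vec::new(), &batch.schema()).unwrap();
        writer.write(batch).unwrap();
        writer.finish().unwrap();
        writer.into_inner().unwrap()
    }

    fn main() {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let big = RecordBatch::try_new(
            schema,
            vec![Arc::new(Int32Array::from_iter_values(0..65536))],
        )
        .unwrap();

        // Previously a slice serialized the full backing buffers; with this
        // commit only the sliced region (plus IPC padding) is written.
        let small = big.slice(2, 5);
        assert!(serialize(&small).len() < serialize(&big).len());
    }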

diff --git a/arrow/src/array/mod.rs b/arrow/src/array/mod.rs
index 60703c53b..2f025f11c 100644
--- a/arrow/src/array/mod.rs
+++ b/arrow/src/array/mod.rs
@@ -176,6 +176,7 @@ pub(crate) use self::data::layout;
 pub use self::data::ArrayData;
 pub use self::data::ArrayDataBuilder;
 pub use self::data::ArrayDataRef;
+pub(crate) use self::data::BufferSpec;
 
 pub use self::array_binary::BinaryArray;
 pub use self::array_binary::FixedSizeBinaryArray;
diff --git a/arrow/src/datatypes/datatype.rs b/arrow/src/datatypes/datatype.rs
index f1c468926..d65915bd7 100644
--- a/arrow/src/datatypes/datatype.rs
+++ b/arrow/src/datatypes/datatype.rs
@@ -721,6 +721,21 @@ impl DataType {
         )
     }
 
+    /// Returns true if this type is temporal (Date*, Time*, Timestamp, Duration, or Interval).
+    pub fn is_temporal(t: &DataType) -> bool {
+        use DataType::*;
+        matches!(
+            t,
+            Date32
+                | Date64
+                | Timestamp(_, _)
+                | Time32(_)
+                | Time64(_)
+                | Duration(_)
+                | Interval(_)
+        )
+    }
+
     /// Returns true if this type is valid as a dictionary key
     /// (e.g. [`super::ArrowDictionaryKeyType`]
     pub fn is_dictionary_key_type(t: &DataType) -> bool {
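
As a quick illustration of the new predicate (a hedged example, not part of
the commit's tests):

    use arrow::datatypes::{DataType, TimeUnit};

    // Date, time, timestamp, duration, and interval types are temporal.
    assert!(DataType::is_temporal(&DataType::Date32));
    assert!(DataType::is_temporal(&DataType::Timestamp(TimeUnit::Millisecond, None)));
    assert!(!DataType::is_temporal(&DataType::Int32));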
diff --git a/arrow/src/ipc/writer.rs b/arrow/src/ipc/writer.rs
index 9551c4f17..ed713e586 100644
--- a/arrow/src/ipc/writer.rs
+++ b/arrow/src/ipc/writer.rs
@@ -20,6 +20,7 @@
 //! The `FileWriter` and `StreamWriter` have similar interfaces,
 //! however the `FileWriter` expects a reader that supports `Seek`ing
 
+use std::cmp::min;
 use std::collections::HashMap;
 use std::io::{BufWriter, Write};
 
@@ -27,7 +28,9 @@ use flatbuffers::FlatBufferBuilder;
 
 use crate::array::{
     as_large_list_array, as_list_array, as_map_array, as_struct_array, as_union_array,
-    make_array, Array, ArrayData, ArrayRef, FixedSizeListArray,
+    layout, make_array, Array, ArrayData, ArrayRef, BinaryArray, BufferBuilder,
+    BufferSpec, FixedSizeListArray, GenericBinaryArray, GenericStringArray,
+    LargeBinaryArray, LargeStringArray, OffsetSizeTrait, StringArray,
 };
 use crate::buffer::{Buffer, MutableBuffer};
 use crate::datatypes::*;
@@ -861,6 +864,106 @@ fn has_validity_bitmap(data_type: &DataType, write_options: &IpcWriteOptions) ->
     }
 }
 
+/// Returns whether the buffer needs to be truncated before being written out.
+#[inline]
+fn buffer_need_truncate(
+    array_offset: usize,
+    buffer: &Buffer,
+    spec: &BufferSpec,
+    min_length: usize,
+) -> bool {
+    spec != &BufferSpec::AlwaysNull && (array_offset != 0 || min_length < buffer.len())
+}
+
+/// Returns the byte width of a `BufferSpec::FixedWidth` spec, or 0 for any other spec.
+#[inline]
+fn get_buffer_element_width(spec: &BufferSpec) -> usize {
+    match spec {
+        BufferSpec::FixedWidth { byte_width } => *byte_width,
+        _ => 0,
+    }
+}
+
+/// Returns the total number of value bytes spanned by a base binary array (Binary, LargeBinary, Utf8, LargeUtf8).
+fn get_binary_buffer_len(array_data: &ArrayData) -> usize {
+    if array_data.is_empty() {
+        return 0;
+    }
+    match array_data.data_type() {
+        DataType::Binary => {
+            let array: BinaryArray = array_data.clone().into();
+            let offsets = array.value_offsets();
+            (offsets[array_data.len()] - offsets[0]) as usize
+        }
+        DataType::LargeBinary => {
+            let array: LargeBinaryArray = array_data.clone().into();
+            let offsets = array.value_offsets();
+            (offsets[array_data.len()] - offsets[0]) as usize
+        }
+        DataType::Utf8 => {
+            let array: StringArray = array_data.clone().into();
+            let offsets = array.value_offsets();
+            (offsets[array_data.len()] - offsets[0]) as usize
+        }
+        DataType::LargeUtf8 => {
+            let array: LargeStringArray = array_data.clone().into();
+            let offsets = array.value_offsets();
+            (offsets[array_data.len()] - offsets[0]) as usize
+        }
+        _ => unreachable!(),
+    }
+}
+
+/// Rebases the value offsets of the given ArrayData so that they start at zero.
+fn get_zero_based_value_offsets<OffsetSize: OffsetSizeTrait>(
+    array_data: &ArrayData,
+) -> Buffer {
+    match array_data.data_type() {
+        DataType::Binary | DataType::LargeBinary => {
+            let array: GenericBinaryArray<OffsetSize> = array_data.clone().into();
+            let offsets = array.value_offsets();
+            let start_offset = offsets[0];
+
+            let mut builder = BufferBuilder::<OffsetSize>::new(array_data.len() + 1);
+            for x in offsets {
+                builder.append(*x - start_offset);
+            }
+
+            builder.finish()
+        }
+        DataType::Utf8 | DataType::LargeUtf8 => {
+            let array: GenericStringArray<OffsetSize> = array_data.clone().into();
+            let offsets = array.value_offsets();
+            let start_offset = offsets[0];
+
+            let mut builder = BufferBuilder::<OffsetSize>::new(array_data.len() + 1);
+            for x in offsets {
+                builder.append(*x - start_offset);
+            }
+
+            builder.finish()
+        }
+        _ => unreachable!(),
+    }
+}
+
+/// Returns the first value offset of a base binary array.
+fn get_buffer_offset<OffsetSize: OffsetSizeTrait>(array_data: &ArrayData) -> OffsetSize {
+    match array_data.data_type() {
+        DataType::Binary | DataType::LargeBinary => {
+            let array: GenericBinaryArray<OffsetSize> = array_data.clone().into();
+            let offsets = array.value_offsets();
+            offsets[0]
+        }
+        DataType::Utf8 | DataType::LargeUtf8 => {
+            let array: GenericStringArray<OffsetSize> = array_data.clone().into();
+            let offsets = array.value_offsets();
+            offsets[0]
+        }
+        _ => unreachable!(),
+    }
+}
+
 /// Write array data to a vector of bytes
 #[allow(clippy::too_many_arguments)]
 fn write_array_data(
@@ -891,15 +994,80 @@ fn write_array_data(
                 let buffer = buffer.with_bitset(num_bytes, true);
                 buffer.into()
             }
-            Some(buffer) => buffer.clone(),
+            Some(buffer) => buffer.bit_slice(array_data.offset(), array_data.len()),
         };
 
-        offset = write_buffer(&null_buffer, buffers, arrow_data, offset);
+        offset = write_buffer(null_buffer.as_slice(), buffers, arrow_data, offset);
     }
 
-    array_data.buffers().iter().for_each(|buffer| {
-        offset = write_buffer(buffer, buffers, arrow_data, offset);
-    });
+    let data_type = array_data.data_type();
+    if matches!(
+        data_type,
+        DataType::Binary | DataType::LargeBinary | DataType::Utf8 | DataType::LargeUtf8
+    ) {
+        let total_bytes = get_binary_buffer_len(array_data);
+        let value_buffer = &array_data.buffers()[1];
+        if buffer_need_truncate(
+            array_data.offset(),
+            value_buffer,
+            &BufferSpec::VariableWidth,
+            total_bytes,
+        ) {
+            // Rebase offsets and truncate values
+            let (new_offsets, byte_offset) =
+                if matches!(data_type, DataType::Binary | DataType::Utf8) {
+                    (
+                        get_zero_based_value_offsets::<i32>(array_data),
+                        get_buffer_offset::<i32>(array_data) as usize,
+                    )
+                } else {
+                    (
+                        get_zero_based_value_offsets::<i64>(array_data),
+                        get_buffer_offset::<i64>(array_data) as usize,
+                    )
+                };
+
+            offset = write_buffer(new_offsets.as_slice(), buffers, arrow_data, offset);
+
+            let buffer_length = min(total_bytes, value_buffer.len() - byte_offset);
+            let buffer_slice =
+                &value_buffer.as_slice()[byte_offset..(byte_offset + buffer_length)];
+            offset = write_buffer(buffer_slice, buffers, arrow_data, offset);
+        } else {
+            array_data.buffers().iter().for_each(|buffer| {
+                offset = write_buffer(buffer.as_slice(), buffers, arrow_data, offset);
+            });
+        }
+    } else if DataType::is_numeric(data_type)
+        || DataType::is_temporal(data_type)
+        || matches!(
+            array_data.data_type(),
+            DataType::FixedSizeBinary(_) | DataType::Dictionary(_, _)
+        )
+    {
+        // Truncate values
+        assert!(array_data.buffers().len() == 1);
+
+        let buffer = &array_data.buffers()[0];
+        let layout = layout(data_type);
+        let spec = &layout.buffers[0];
+
+        let byte_width = get_buffer_element_width(spec);
+        let min_length = array_data.len() * byte_width;
+        if buffer_need_truncate(array_data.offset(), buffer, spec, min_length) {
+            let byte_offset = array_data.offset() * byte_width;
+            let buffer_length = min(min_length, buffer.len() - byte_offset);
+            let buffer_slice =
+                &buffer.as_slice()[byte_offset..(byte_offset + buffer_length)];
+            offset = write_buffer(buffer_slice, buffers, arrow_data, offset);
+        } else {
+            offset = write_buffer(buffer.as_slice(), buffers, arrow_data, offset);
+        }
+    } else {
+        array_data.buffers().iter().for_each(|buffer| {
+            offset = write_buffer(buffer, buffers, arrow_data, offset);
+        });
+    }
 
     if !matches!(array_data.data_type(), DataType::Dictionary(_, _)) {
         // recursively write out nested structures
@@ -923,7 +1091,7 @@ fn write_array_data(
 
 /// Write a buffer to a vector of bytes, and add its ipc::Buffer to a vector
 fn write_buffer(
-    buffer: &Buffer,
+    buffer: &[u8],
     buffers: &mut Vec<ipc::Buffer>,
     arrow_data: &mut Vec<u8>,
     offset: i64,
@@ -933,7 +1101,7 @@ fn write_buffer(
     let total_len: i64 = (len + pad_len) as i64;
     // assert_eq!(len % 8, 0, "Buffer width not a multiple of 8 bytes");
     buffers.push(ipc::Buffer::new(offset, total_len));
-    arrow_data.extend_from_slice(buffer.as_slice());
+    arrow_data.extend_from_slice(buffer);
     arrow_data.extend_from_slice(&vec![0u8; pad_len][..]);
     offset + total_len
 }
@@ -1507,4 +1675,113 @@ mod tests {
             IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap(),
         );
     }
+
+    fn serialize(record: &RecordBatch) -> Vec<u8> {
+        let buffer: Vec<u8> = Vec::new();
+        let mut stream_writer = StreamWriter::try_new(buffer, &record.schema()).unwrap();
+        stream_writer.write(record).unwrap();
+        stream_writer.finish().unwrap();
+        stream_writer.into_inner().unwrap()
+    }
+
+    fn deserialize(bytes: Vec<u8>) -> RecordBatch {
+        let mut stream_reader =
+            ipc::reader::StreamReader::try_new(std::io::Cursor::new(bytes), None)
+                .unwrap();
+        stream_reader.next().unwrap().unwrap()
+    }
+
+    #[test]
+    fn truncate_ipc_record_batch() {
+        fn create_batch(rows: usize) -> RecordBatch {
+            let schema = Schema::new(vec![
+                Field::new("a", DataType::Int32, false),
+                Field::new("b", DataType::Utf8, false),
+            ]);
+
+            let a = Int32Array::from_iter_values(0..rows as i32);
+            let b = StringArray::from_iter_values((0..rows).map(|i| i.to_string()));
+
+            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])
+                .unwrap()
+        }
+
+        let big_record_batch = create_batch(65536);
+
+        let length = 5;
+        let small_record_batch = create_batch(length);
+
+        let offset = 2;
+        let record_batch_slice = big_record_batch.slice(offset, length);
+        assert!(
+            serialize(&big_record_batch).len() > serialize(&small_record_batch).len()
+        );
+        assert_eq!(
+            serialize(&small_record_batch).len(),
+            serialize(&record_batch_slice).len()
+        );
+
+        assert_eq!(
+            deserialize(serialize(&record_batch_slice)),
+            record_batch_slice
+        );
+    }
+
+    #[test]
+    fn truncate_ipc_record_batch_with_nulls() {
+        fn create_batch() -> RecordBatch {
+            let schema = Schema::new(vec![
+                Field::new("a", DataType::Int32, true),
+                Field::new("b", DataType::Utf8, true),
+            ]);
+
+            let a = Int32Array::from(vec![Some(1), None, Some(1), None, Some(1)]);
+            let b = StringArray::from(vec![None, Some("a"), Some("a"), None, Some("a")]);
+
+            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])
+                .unwrap()
+        }
+
+        let record_batch = create_batch();
+        let record_batch_slice = record_batch.slice(1, 2);
+        let deserialized_batch = deserialize(serialize(&record_batch_slice));
+
+        assert!(serialize(&record_batch).len() > serialize(&record_batch_slice).len());
+
+        assert!(deserialized_batch.column(0).is_null(0));
+        assert!(deserialized_batch.column(0).is_valid(1));
+        assert!(deserialized_batch.column(1).is_valid(0));
+        assert!(deserialized_batch.column(1).is_valid(1));
+
+        assert_eq!(record_batch_slice, deserialized_batch);
+    }
+
+    #[test]
+    fn truncate_ipc_dictionary_array() {
+        fn create_batch() -> RecordBatch {
+            let values: StringArray = [Some("foo"), Some("bar"), Some("baz")]
+                .into_iter()
+                .collect();
+            let keys: Int32Array =
+                [Some(0), Some(2), None, Some(1)].into_iter().collect();
+
+            let array = DictionaryArray::<Int32Type>::try_new(&keys, &values).unwrap();
+
+            let schema =
+                Schema::new(vec![Field::new("dict", array.data_type().clone(), true)]);
+
+            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).unwrap()
+        }
+
+        let record_batch = create_batch();
+        let record_batch_slice = record_batch.slice(1, 2);
+        let deserialized_batch = deserialize(serialize(&record_batch_slice));
+
+        assert!(serialize(&record_batch).len() > serialize(&record_batch_slice).len());
+
+        assert!(deserialized_batch.column(0).is_valid(0));
+        assert!(deserialized_batch.column(0).is_null(1));
+
+        assert_eq!(record_batch_slice, deserialized_batch);
+    }
 }
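
To make the offset rebasing concrete: when a binary or string array is sliced,
its visible value offsets no longer start at zero, so the writer subtracts the
first offset from each one and copies only the value bytes the slice covers. An
illustrative standalone sketch of that step (the commit does the same thing
with BufferBuilder inside get_zero_based_value_offsets):

    // Rebase sliced value offsets so the written buffer starts at zero.
    fn rebase(offsets: &[i32]) -> Vec<i32> {
        let start = offsets[0];
        offsets.iter().map(|o| o - start).collect()
    }

    fn main() {
        // A slice whose visible offsets are [3, 5, 8] is written as [0, 2, 5],
        // and only bytes 3..8 of the parent value buffer are copied.
        assert_eq!(rebase(&[3, 5, 8]), vec![0, 2, 5]);
    }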