You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by tu...@apache.org on 2023/01/10 12:55:13 UTC

[arrow-rs] branch master updated: Fix IPCWriter for Sliced BooleanArray (#3498)

This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new cada9ba33 Fix IPCWriter for Sliced BooleanArray (#3498)
cada9ba33 is described below

commit cada9ba33803a48a3145ab333fe1cf6410999d89
Author: Marco Neumann <ma...@crepererum.net>
AuthorDate: Tue Jan 10 13:55:07 2023 +0100

    Fix IPCWriter for Sliced BooleanArray (#3498)
    
    * fix: bool IPC
    
    Fixes #3496.
    
    * refactor: simplify code
    
    * refactor: `assert!` -> `assert_eq!`
---
 arrow-ipc/src/writer.rs | 73 ++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 72 insertions(+), 1 deletion(-)

diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs
index ed5e53a95..d7cc83aab 100644
--- a/arrow-ipc/src/writer.rs
+++ b/arrow-ipc/src/writer.rs
@@ -1202,7 +1202,7 @@ fn write_array_data(
         )
     {
         // Truncate values
-        assert!(array_data.buffers().len() == 1);
+        assert_eq!(array_data.buffers().len(), 1);
 
         let buffer = &array_data.buffers()[0];
         let layout = layout(data_type);
@@ -1231,6 +1231,14 @@ fn write_array_data(
                 compression_codec,
             )?;
         }
+    } else if matches!(data_type, DataType::Boolean) {
+        // Bools are special because the payload (= 1 bit) is smaller than the physical container elements (= bytes).
+        // The array data may not start at the physical boundary of the underlying buffer, so we need to shift bits around.
+        assert_eq!(array_data.buffers().len(), 1);
+
+        let buffer = &array_data.buffers()[0];
+        let buffer = buffer.bit_slice(array_data.offset(), array_data.len());
+        offset = write_buffer(&buffer, buffers, arrow_data, offset, compression_codec)?;
     } else {
         for buffer in array_data.buffers() {
             offset =
@@ -1312,6 +1320,7 @@ fn pad_to_8(len: u32) -> usize {
 mod tests {
     use super::*;
 
+    use std::io::Cursor;
     use std::io::Seek;
     use std::sync::Arc;
 
@@ -1926,4 +1935,66 @@ mod tests {
             read_array.iter().collect::<Vec<_>>()
         );
     }
+
+    #[test]
+    fn encode_bools_slice() {
+        // Test case for https://github.com/apache/arrow-rs/issues/3496
+        assert_bool_roundtrip([true, false], 1, 1);
+
+        // slice somewhere in the middle
+        assert_bool_roundtrip(
+            [
+                true, false, true, true, false, false, true, true, true, false, false,
+                false, true, true, true, true, false, false, false, false, true, true,
+                true, true, true, false, false, false, false, false,
+            ],
+            13,
+            17,
+        );
+
+        // start at byte boundary, end in the middle
+        assert_bool_roundtrip(
+            [
+                true, false, true, true, false, false, true, true, true, false, false,
+                false,
+            ],
+            8,
+            2,
+        );
+
+        // start and stop at byte boundary
+        assert_bool_roundtrip(
+            [
+                true, false, true, true, false, false, true, true, true, false, false,
+                false, true, true, true, true, true, false, false, false, false, false,
+            ],
+            8,
+            8,
+        );
+    }
+
+    fn assert_bool_roundtrip<const N: usize>(
+        bools: [bool; N],
+        offset: usize,
+        length: usize,
+    ) {
+        let val_bool_field = Field::new("val", DataType::Boolean, false);
+
+        let schema = Arc::new(Schema::new(vec![val_bool_field]));
+
+        let bools = BooleanArray::from(bools.to_vec());
+
+        let batch =
+            RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(bools)]).unwrap();
+        let batch = batch.slice(offset, length);
+
+        let mut writer = StreamWriter::try_new(Vec::<u8>::new(), &schema).unwrap();
+        writer.write(&batch).unwrap();
+        writer.finish().unwrap();
+        let data = writer.into_inner().unwrap();
+
+        let mut reader = StreamReader::try_new(Cursor::new(data), None).unwrap();
+        let batch2 = reader.next().unwrap().unwrap();
+        assert_eq!(batch, batch2);
+    }
 }