You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ne...@apache.org on 2022/06/05 09:00:48 UTC

[arrow-rs] branch master updated: Read and skip validity buffer of UnionType Array for V4 ipc message (#1789)

This is an automated email from the ASF dual-hosted git repository.

nevime pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 73d552a7c Read and skip validity buffer of UnionType Array for V4 ipc message (#1789)
73d552a7c is described below

commit 73d552a7cc794d0e3eaa3e5333e5bc1c98deeb45
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Sun Jun 5 02:00:44 2022 -0700

    Read and skip validity buffer of UnionType Array for V4 ipc message (#1789)
    
    * Read validity buffer for V4 ipc message
    
    * Add unit test
    
    * Fix clippy
---
 arrow-flight/src/utils.rs                          |  1 +
 arrow/src/ipc/reader.rs                            | 31 ++++++++++++--
 arrow/src/ipc/writer.rs                            | 48 ++++++++++++++++++++++
 .../flight_client_scenarios/integration_test.rs    |  1 +
 .../flight_server_scenarios/integration_test.rs    | 10 ++++-
 5 files changed, 86 insertions(+), 5 deletions(-)

diff --git a/arrow-flight/src/utils.rs b/arrow-flight/src/utils.rs
index 77526917f..dda3fc7fe 100644
--- a/arrow-flight/src/utils.rs
+++ b/arrow-flight/src/utils.rs
@@ -71,6 +71,7 @@ pub fn flight_data_to_arrow_batch(
                 schema,
                 dictionaries_by_id,
                 None,
+                &message.version(),
             )
         })?
 }
diff --git a/arrow/src/ipc/reader.rs b/arrow/src/ipc/reader.rs
index 03a960c4c..868098327 100644
--- a/arrow/src/ipc/reader.rs
+++ b/arrow/src/ipc/reader.rs
@@ -52,6 +52,7 @@ fn read_buffer(buf: &ipc::Buffer, a_data: &[u8]) -> Buffer {
 ///     - check if the bit width of non-64-bit numbers is 64, and
 ///     - read the buffer as 64-bit (signed integer or float), and
 ///     - cast the 64-bit array to the appropriate data type
+#[allow(clippy::too_many_arguments)]
 fn create_array(
     nodes: &[ipc::FieldNode],
     field: &Field,
@@ -60,6 +61,7 @@ fn create_array(
     dictionaries_by_id: &HashMap<i64, ArrayRef>,
     mut node_index: usize,
     mut buffer_index: usize,
+    metadata: &ipc::MetadataVersion,
 ) -> Result<(ArrayRef, usize, usize)> {
     use DataType::*;
     let data_type = field.data_type();
@@ -106,6 +108,7 @@ fn create_array(
                 dictionaries_by_id,
                 node_index,
                 buffer_index,
+                metadata,
             )?;
             node_index = triple.1;
             buffer_index = triple.2;
@@ -128,6 +131,7 @@ fn create_array(
                 dictionaries_by_id,
                 node_index,
                 buffer_index,
+                metadata,
             )?;
             node_index = triple.1;
             buffer_index = triple.2;
@@ -153,6 +157,7 @@ fn create_array(
                     dictionaries_by_id,
                     node_index,
                     buffer_index,
+                    metadata,
                 )?;
                 node_index = triple.1;
                 buffer_index = triple.2;
@@ -201,6 +206,13 @@ fn create_array(
 
             let len = union_node.length() as usize;
 
+            // In V4, union arrays have a validity bitmap, which is read and skipped here.
+            // In V5 and later, union arrays have no validity bitmap.
+            if metadata < &ipc::MetadataVersion::V5 {
+                read_buffer(&buffers[buffer_index], data);
+                buffer_index += 1;
+            }
+
             let type_ids: Buffer =
                 read_buffer(&buffers[buffer_index], data)[..len].into();
 
@@ -226,6 +238,7 @@ fn create_array(
                     dictionaries_by_id,
                     node_index,
                     buffer_index,
+                    metadata,
                 )?;
 
                 node_index = triple.1;
@@ -582,6 +595,7 @@ pub fn read_record_batch(
     schema: SchemaRef,
     dictionaries_by_id: &HashMap<i64, ArrayRef>,
     projection: Option<&[usize]>,
+    metadata: &ipc::MetadataVersion,
 ) -> Result<RecordBatch> {
     let buffers = batch.buffers().ok_or_else(|| {
         ArrowError::IoError("Unable to get buffers from IPC RecordBatch".to_string())
@@ -607,6 +621,7 @@ pub fn read_record_batch(
                     dictionaries_by_id,
                     node_index,
                     buffer_index,
+                    metadata,
                 )?;
                 node_index = triple.1;
                 buffer_index = triple.2;
@@ -640,6 +655,7 @@ pub fn read_record_batch(
                 dictionaries_by_id,
                 node_index,
                 buffer_index,
+                metadata,
             )?;
             node_index = triple.1;
             buffer_index = triple.2;
@@ -656,6 +672,7 @@ pub fn read_dictionary(
     batch: ipc::DictionaryBatch,
     schema: &Schema,
     dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
+    metadata: &ipc::MetadataVersion,
 ) -> Result<()> {
     if batch.isDelta() {
         return Err(ArrowError::IoError(
@@ -686,6 +703,7 @@ pub fn read_dictionary(
                 Arc::new(schema),
                 dictionaries_by_id,
                 None,
+                metadata,
             )?;
             Some(record_batch.column(0).clone())
         }
@@ -816,7 +834,13 @@ impl<R: Read + Seek> FileReader<R> {
                         ))?;
                         reader.read_exact(&mut buf)?;
 
-                        read_dictionary(&buf, batch, &schema, &mut dictionaries_by_id)?;
+                        read_dictionary(
+                            &buf,
+                            batch,
+                            &schema,
+                            &mut dictionaries_by_id,
+                            &message.version(),
+                        )?;
                     }
                     t => {
                         return Err(ArrowError::IoError(format!(
@@ -925,6 +949,7 @@ impl<R: Read + Seek> FileReader<R> {
                     self.schema(),
                     &self.dictionaries_by_id,
                     self.projection.as_ref().map(|x| x.0.as_ref()),
+                    &message.version()
 
                 ).map(Some)
             }
@@ -1099,7 +1124,7 @@ impl<R: Read> StreamReader<R> {
                 let mut buf = vec![0; message.bodyLength() as usize];
                 self.reader.read_exact(&mut buf)?;
 
-                read_record_batch(&buf, batch, self.schema(), &self.dictionaries_by_id, self.projection.as_ref().map(|x| x.0.as_ref())).map(Some)
+                read_record_batch(&buf, batch, self.schema(), &self.dictionaries_by_id, self.projection.as_ref().map(|x| x.0.as_ref()), &message.version()).map(Some)
             }
             ipc::MessageHeader::DictionaryBatch => {
                 let batch = message.header_as_dictionary_batch().ok_or_else(|| {
@@ -1112,7 +1137,7 @@ impl<R: Read> StreamReader<R> {
                 self.reader.read_exact(&mut buf)?;
 
                 read_dictionary(
-                    &buf, batch, &self.schema, &mut self.dictionaries_by_id
+                    &buf, batch, &self.schema, &mut self.dictionaries_by_id, &message.version()
                 )?;
 
                 // read the next message until we encounter a RecordBatch
diff --git a/arrow/src/ipc/writer.rs b/arrow/src/ipc/writer.rs
index c42c0fd97..70e07acae 100644
--- a/arrow/src/ipc/writer.rs
+++ b/arrow/src/ipc/writer.rs
@@ -1385,4 +1385,52 @@ mod tests {
         // Dictionary with id 2 should have been written to the dict tracker
         assert!(dict_tracker.written.contains_key(&2));
     }
+    /// Regression test for V4 IPC messages: reads the union stream generated by Arrow 0.17.1, rewrites it with `StreamWriter`, and checks that the original and rewritten streams decode to equal batches.
+    #[test]
+    fn read_union_017() {
+        let testdata = crate::util::test_util::arrow_test_data();
+        let version = "0.17.1"; // only used to name the rewritten file; the input paths below hard-code 0.17.1
+        let data_file = File::open(format!(
+            "{}/arrow-ipc-stream/integration/0.17.1/generated_union.stream",
+            testdata,
+        ))
+        .unwrap();
+
+        let reader = StreamReader::try_new(data_file, None).unwrap();
+
+        // Read every batch and rewrite the stream under target/debug/testdata (File::create fails if that directory is missing); the inner scope drops the writer before the file is reopened below.
+        {
+            let file = File::create(format!(
+                "target/debug/testdata/{}-generated_union.stream",
+                version
+            ))
+            .unwrap();
+            let mut writer = StreamWriter::try_new(file, &reader.schema()).unwrap();
+            reader.for_each(|batch| {
+                writer.write(&batch.unwrap()).unwrap();
+            });
+            writer.finish().unwrap();
+        }
+
+        // Re-open the rewritten file and a fresh reader over the original, then compare batch by batch. NOTE(review): zip stops at the shorter stream, so a differing batch count would go unnoticed.
+        let file = File::open(format!(
+            "target/debug/testdata/{}-generated_union.stream",
+            version
+        ))
+        .unwrap();
+        let rewrite_reader = StreamReader::try_new(file, None).unwrap();
+
+        let data_file = File::open(format!(
+            "{}/arrow-ipc-stream/integration/0.17.1/generated_union.stream",
+            testdata,
+        ))
+        .unwrap();
+        let reader = StreamReader::try_new(data_file, None).unwrap();
+
+        reader.into_iter().zip(rewrite_reader.into_iter()).for_each(
+            |(batch1, batch2)| {
+                assert_eq!(batch1.unwrap(), batch2.unwrap());
+            },
+        );
+    }
 }
diff --git a/integration-testing/src/flight_client_scenarios/integration_test.rs b/integration-testing/src/flight_client_scenarios/integration_test.rs
index 4158a7352..62fe2b85d 100644
--- a/integration-testing/src/flight_client_scenarios/integration_test.rs
+++ b/integration-testing/src/flight_client_scenarios/integration_test.rs
@@ -270,6 +270,7 @@ async fn receive_batch_flight_data(
                 .expect("Error parsing dictionary"),
             &schema,
             dictionaries_by_id,
+            &message.version(),
         )
         .expect("Error reading dictionary");
 
diff --git a/integration-testing/src/flight_server_scenarios/integration_test.rs b/integration-testing/src/flight_server_scenarios/integration_test.rs
index 52086aade..7ad3d18eb 100644
--- a/integration-testing/src/flight_server_scenarios/integration_test.rs
+++ b/integration-testing/src/flight_server_scenarios/integration_test.rs
@@ -296,6 +296,7 @@ async fn record_batch_from_message(
         schema_ref,
         dictionaries_by_id,
         None,
+        &message.version(),
     );
 
     arrow_batch_result.map_err(|e| {
@@ -313,8 +314,13 @@ async fn dictionary_from_message(
         Status::internal("Could not parse message header as dictionary batch")
     })?;
 
-    let dictionary_batch_result =
-        reader::read_dictionary(data_body, ipc_batch, &schema_ref, dictionaries_by_id);
+    let dictionary_batch_result = reader::read_dictionary(
+        data_body,
+        ipc_batch,
+        &schema_ref,
+        dictionaries_by_id,
+        &message.version(),
+    );
     dictionary_batch_result.map_err(|e| {
         Status::internal(format!("Could not convert to Dictionary: {:?}", e))
     })