You are viewing a plain-text version of this content; the link to its canonical version was not preserved in this copy.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/04/15 13:05:27 UTC

[arrow-rs] branch master updated: Read/write nested dictionary in ipc stream reader/writer (#1566)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new bb6535865 Read/write nested dictionary in ipc stream reader/writer (#1566)
bb6535865 is described below

commit bb65358653f11bc597e266c9f1ae565612dcb321
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Fri Apr 15 06:05:23 2022 -0700

    Read/write nested dictionary in ipc stream reader/writer (#1566)
    
    * Read dictionary inside dictionary
    
    * Fix clippy
---
 arrow/src/datatypes/field.rs | 13 ++++++++++-
 arrow/src/ipc/reader.rs      | 39 +++++++++++++++++++++++++++++++++
 arrow/src/ipc/writer.rs      | 52 +++++++++++++++++++++++++++++++++++++++-----
 3 files changed, 97 insertions(+), 7 deletions(-)

diff --git a/arrow/src/datatypes/field.rs b/arrow/src/datatypes/field.rs
index 2509edbd0..c841216a5 100644
--- a/arrow/src/datatypes/field.rs
+++ b/arrow/src/datatypes/field.rs
@@ -116,7 +116,15 @@ impl Field {
     /// Returns a (flattened) vector containing all fields contained within this field (including it self)
     pub(crate) fn fields(&self) -> Vec<&Field> {
         let mut collected_fields = vec![self];
-        match &self.data_type {
+        collected_fields.append(&mut self._fields(&self.data_type));
+
+        collected_fields
+    }
+
+    fn _fields<'a>(&'a self, dt: &'a DataType) -> Vec<&Field> {
+        let mut collected_fields = vec![];
+
+        match dt {
             DataType::Struct(fields) | DataType::Union(fields, _) => {
                 collected_fields.extend(fields.iter().flat_map(|f| f.fields()))
             }
@@ -124,6 +132,9 @@ impl Field {
             | DataType::LargeList(field)
             | DataType::FixedSizeList(field, _)
             | DataType::Map(field, _) => collected_fields.push(field),
+            DataType::Dictionary(_, value_field) => {
+                collected_fields.append(&mut self._fields(value_field.as_ref()))
+            }
             _ => (),
         }
 
diff --git a/arrow/src/ipc/reader.rs b/arrow/src/ipc/reader.rs
index 1134ed425..fe10fe548 100644
--- a/arrow/src/ipc/reader.rs
+++ b/arrow/src/ipc/reader.rs
@@ -1019,6 +1019,7 @@ mod tests {
 
     use flate2::read::GzDecoder;
 
+    use crate::datatypes::Int8Type;
     use crate::{datatypes, util::integration_util::*};
 
     #[test]
@@ -1441,4 +1442,42 @@ mod tests {
         let output_batch = roundtrip_ipc_stream(&input_batch);
         assert_eq!(input_batch, output_batch);
     }
+
+    #[test]
+    fn test_roundtrip_stream_nested_dict_dict() {
+        let values = StringArray::from_iter_values(["a", "b", "c"]);
+        let keys = Int8Array::from_iter_values([0, 0, 1, 2, 0, 1]);
+        let dict_array = DictionaryArray::<Int8Type>::try_new(&keys, &values).unwrap();
+        let dict_data = dict_array.data();
+
+        let value_offsets = Buffer::from_slice_ref(&[0, 2, 4, 6]);
+
+        let list_data_type = DataType::List(Box::new(Field::new_dict(
+            "item",
+            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
+            false,
+            1,
+            false,
+        )));
+        let list_data = ArrayData::builder(list_data_type)
+            .len(3)
+            .add_buffer(value_offsets)
+            .add_child_data(dict_data.clone())
+            .build()
+            .unwrap();
+        let list_array = ListArray::from(list_data);
+
+        let dict_dict_array =
+            DictionaryArray::<Int8Type>::try_new(&keys, &list_array).unwrap();
+
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "f1",
+            dict_dict_array.data_type().clone(),
+            false,
+        )]));
+        let input_batch =
+            RecordBatch::try_new(schema, vec![Arc::new(dict_dict_array)]).unwrap();
+        let output_batch = roundtrip_ipc_stream(&input_batch);
+        assert_eq!(input_batch, output_batch);
+    }
 }
diff --git a/arrow/src/ipc/writer.rs b/arrow/src/ipc/writer.rs
index 33d40ce36..a5b35f364 100644
--- a/arrow/src/ipc/writer.rs
+++ b/arrow/src/ipc/writer.rs
@@ -25,7 +25,9 @@ use std::io::{BufWriter, Write};
 
 use flatbuffers::FlatBufferBuilder;
 
-use crate::array::{as_struct_array, as_union_array, ArrayData, ArrayRef};
+use crate::array::{
+    as_list_array, as_struct_array, as_union_array, make_array, ArrayData, ArrayRef,
+};
 use crate::buffer::{Buffer, MutableBuffer};
 use crate::datatypes::*;
 use crate::error::{ArrowError, Result};
@@ -137,15 +139,14 @@ impl IpcDataGenerator {
         }
     }
 
-    fn encode_dictionaries(
+    fn _encode_dictionaries(
         &self,
-        field: &Field,
         column: &ArrayRef,
         encoded_dictionaries: &mut Vec<EncodedData>,
         dictionary_tracker: &mut DictionaryTracker,
         write_options: &IpcWriteOptions,
     ) -> Result<()> {
-        // TODO: Handle other nested types (map, list, etc)
+        // TODO: Handle other nested types (map, etc)
         match column.data_type() {
             DataType::Struct(fields) => {
                 let s = as_struct_array(column);
@@ -159,6 +160,16 @@ impl IpcDataGenerator {
                     )?;
                 }
             }
+            DataType::List(field) => {
+                let list = as_list_array(column);
+                self.encode_dictionaries(
+                    field,
+                    &list.values(),
+                    encoded_dictionaries,
+                    dictionary_tracker,
+                    write_options,
+                )?;
+            }
             DataType::Union(fields, _) => {
                 let union = as_union_array(column);
                 for (field, ref column) in fields
@@ -175,6 +186,21 @@ impl IpcDataGenerator {
                     )?;
                 }
             }
+            _ => (),
+        }
+
+        Ok(())
+    }
+
+    fn encode_dictionaries(
+        &self,
+        field: &Field,
+        column: &ArrayRef,
+        encoded_dictionaries: &mut Vec<EncodedData>,
+        dictionary_tracker: &mut DictionaryTracker,
+        write_options: &IpcWriteOptions,
+    ) -> Result<()> {
+        match column.data_type() {
             DataType::Dictionary(_key_type, _value_type) => {
                 let dict_id = field
                     .dict_id()
@@ -182,6 +208,15 @@ impl IpcDataGenerator {
                 let dict_data = column.data();
                 let dict_values = &dict_data.child_data()[0];
 
+                let values = make_array(dict_data.child_data()[0].clone());
+
+                self._encode_dictionaries(
+                    &values,
+                    encoded_dictionaries,
+                    dictionary_tracker,
+                    write_options,
+                )?;
+
                 let emit = dictionary_tracker.insert(dict_id, column)?;
 
                 if emit {
@@ -192,7 +227,12 @@ impl IpcDataGenerator {
                     ));
                 }
             }
-            _ => (),
+            _ => self._encode_dictionaries(
+                column,
+                encoded_dictionaries,
+                dictionary_tracker,
+                write_options,
+            )?,
         }
 
         Ok(())
@@ -205,7 +245,7 @@ impl IpcDataGenerator {
         write_options: &IpcWriteOptions,
     ) -> Result<(Vec<EncodedData>, EncodedData)> {
         let schema = batch.schema();
-        let mut encoded_dictionaries = Vec::with_capacity(schema.fields().len());
+        let mut encoded_dictionaries = Vec::with_capacity(schema.all_fields().len());
 
         for (i, field) in schema.fields().iter().enumerate() {
             let column = batch.column(i);