You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/11/09 11:33:52 UTC

[arrow-rs] 01/01: feat(ipc): add support for deserializing messages with nested dictionary fields (#923)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch cherry_pick_e20d3faf
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git

commit 183c1ddff7e6c306f8f8bf1da5a43598c8aab8c1
Author: Helgi Kristvin Sigurbjarnarson <he...@lacework.net>
AuthorDate: Mon Nov 8 13:32:33 2021 -0800

    feat(ipc): add support for deserializing messages with nested dictionary fields (#923)
    
    * feat(ipc): read a message containing nested dictionary fields
    
    * Apply suggestions from code review
    
    Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
    
    * address lints
    
    Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
---
 arrow/src/datatypes/field.rs  | 88 +++++++++++++++++++++++++++++++++++++++++++
 arrow/src/datatypes/schema.rs |  9 ++++-
 arrow/src/ipc/reader.rs       | 38 +++++++++++++++++--
 3 files changed, 131 insertions(+), 4 deletions(-)

diff --git a/arrow/src/datatypes/field.rs b/arrow/src/datatypes/field.rs
index 497dbb3..4ed0661 100644
--- a/arrow/src/datatypes/field.rs
+++ b/arrow/src/datatypes/field.rs
@@ -107,6 +107,36 @@ impl Field {
         self.nullable
     }
 
+    /// Returns a (flattened) vector containing this field itself followed by every field
+    /// nested beneath it, recursing through struct, union, list and map children
+    pub(crate) fn fields(&self) -> Vec<&Field> {
+        let mut collected_fields = vec![self];
+        match &self.data_type {
+            DataType::Struct(fields) | DataType::Union(fields) => {
+                collected_fields.extend(fields.iter().flat_map(|f| f.fields()))
+            }
+            DataType::List(field)
+            | DataType::LargeList(field)
+            | DataType::FixedSizeList(field, _)
+            | DataType::Map(field, _) => collected_fields.extend(field.fields()),
+            _ => (),
+        }
+        collected_fields
+    }
+
+    /// Returns every `Field` in `self.fields()` (this field and the nested fields it reports)
+    /// that has a `DataType::Dictionary` data type and whose `dict_id` equals `id`
+    #[inline]
+    pub(crate) fn fields_with_dict_id(&self, id: i64) -> Vec<&Field> {
+        let mut matching = self.fields();
+        // Drop everything that is not a dictionary field carrying the requested id.
+        matching.retain(|candidate| {
+            candidate.dict_id == id
+                && matches!(candidate.data_type(), DataType::Dictionary(_, _))
+        });
+        matching
+    }
+
     /// Returns the dictionary ID, if this is a dictionary type.
     #[inline]
     pub const fn dict_id(&self) -> Option<i64> {
@@ -572,3 +602,61 @@ impl std::fmt::Display for Field {
         write!(f, "{:?}", self)
     }
 }
+
+#[cfg(test)]
+mod test {
+    use super::{DataType, Field};
+    /// Checks that the fields `fields_with_dict_id` returns for ids 10 and 20 equal dict1 / dict2.
+    #[test]
+    fn test_fields_with_dict_id() {
+        let dict1 = Field::new_dict(
+            "dict1",
+            DataType::Dictionary(DataType::Utf8.into(), DataType::Int32.into()),
+            false,
+            10,
+            false,
+        );
+        let dict2 = Field::new_dict(
+            "dict2",
+            DataType::Dictionary(DataType::Int32.into(), DataType::Int8.into()),
+            false,
+            20,
+            false,
+        );
+        // NOTE(review): the outer name string swaps dict1/dict2 relative to the nesting built below.
+        let field = Field::new(
+            "struct<dict1, list[struct<dict2, list[struct<dict1]>]>",
+            DataType::Struct(vec![
+                dict1.clone(),
+                Field::new(
+                    "list[struct<dict1, list[struct<dict2>]>]",
+                    DataType::List(Box::new(Field::new(
+                        "struct<dict1, list[struct<dict2>]>",
+                        DataType::Struct(vec![
+                            dict1.clone(),
+                            Field::new(
+                                "list[struct<dict2>]",
+                                DataType::List(Box::new(Field::new(
+                                    "struct<dict2>",
+                                    DataType::Struct(vec![dict2.clone()]),
+                                    false,
+                                ))),
+                                false,
+                            ),
+                        ]),
+                        false,
+                    ))),
+                    false,
+                ),
+            ]),
+            false,
+        );
+        // NOTE(review): these loops assert nothing when a lookup returns no fields; match counts are unchecked.
+        for field in field.fields_with_dict_id(10) {
+            assert_eq!(dict1, *field);
+        }
+        for field in field.fields_with_dict_id(20) {
+            assert_eq!(dict2, *field);
+        }
+    }
+}
diff --git a/arrow/src/datatypes/schema.rs b/arrow/src/datatypes/schema.rs
index cfc0744..cc8ddbd 100644
--- a/arrow/src/datatypes/schema.rs
+++ b/arrow/src/datatypes/schema.rs
@@ -159,6 +159,12 @@ impl Schema {
         &self.fields
     }
 
+    /// Collects, for each top-level field in order, the references reported by `Field::fields`
+    #[inline]
+    pub(crate) fn all_fields(&self) -> Vec<&Field> {
+        self.fields.iter().flat_map(|top| top.fields()).collect()
+    }
+
     /// Returns an immutable reference of a specific `Field` instance selected using an
     /// offset within the internal `fields` vector.
     pub fn field(&self, i: usize) -> &Field {
@@ -175,7 +181,8 @@ impl Schema {
     pub fn fields_with_dict_id(&self, dict_id: i64) -> Vec<&Field> {
         self.fields
             .iter()
-            .filter(|f| f.dict_id() == Some(dict_id))
+            .map(|f| f.fields_with_dict_id(dict_id))
+            .flatten()
             .collect()
     }
 
diff --git a/arrow/src/ipc/reader.rs b/arrow/src/ipc/reader.rs
index e925e2a..5bc76d0 100644
--- a/arrow/src/ipc/reader.rs
+++ b/arrow/src/ipc/reader.rs
@@ -495,7 +495,7 @@ pub fn read_dictionary(
     // in the reader. Note that a dictionary batch may be shared between many fields.
     // We don't currently record the isOrdered field. This could be general
     // attributes of arrays.
-    for (i, field) in schema.fields().iter().enumerate() {
+    for (i, field) in schema.all_fields().iter().enumerate() {
         if field.dict_id() == Some(id) {
             // Add (possibly multiple) array refs to the dictionaries array.
             dictionaries_by_field[i] = Some(dictionary_values.clone());
@@ -582,7 +582,7 @@ impl<R: Read + Seek> FileReader<R> {
         let schema = ipc::convert::fb_to_schema(ipc_schema);
 
         // Create an array of optional dictionary value arrays, one per field.
-        let mut dictionaries_by_field = vec![None; schema.fields().len()];
+        let mut dictionaries_by_field = vec![None; schema.all_fields().len()];
         for block in footer.dictionaries().unwrap() {
             // read length from end of offset
             let mut message_size: [u8; 4] = [0; 4];
@@ -923,7 +923,7 @@ mod tests {
 
     use flate2::read::GzDecoder;
 
-    use crate::util::integration_util::*;
+    use crate::{datatypes, util::integration_util::*};
 
     #[test]
     fn read_generated_files_014() {
@@ -1149,6 +1149,38 @@ mod tests {
         })
     }
 
+    #[test]
+    fn test_roundtrip_nested_dict() {
+        let inner: DictionaryArray<datatypes::Int32Type> =
+            vec!["a", "b", "a"].into_iter().collect();
+        // Make the dictionary array the only child of a struct array.
+        let array = Arc::new(inner) as ArrayRef;
+
+        let dctfield = Field::new("dict", array.data_type().clone(), false);
+
+        let s = StructArray::from(vec![(dctfield, array)]);
+        let struct_array = Arc::new(s) as ArrayRef;
+        // Single-column schema whose column is the struct that holds the dictionary field.
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "struct",
+            struct_array.data_type().clone(),
+            false,
+        )]));
+
+        let batch = RecordBatch::try_new(schema.clone(), vec![struct_array]).unwrap();
+        // Write the batch into an in-memory buffer with the IPC `FileWriter`.
+        let mut buf = Vec::new();
+        let mut writer = ipc::writer::FileWriter::try_new(&mut buf, &schema).unwrap();
+        writer.write(&batch).unwrap();
+        writer.finish().unwrap();
+        drop(writer); // ends the writer's `&mut buf` borrow before `buf` is moved into the cursor
+        // Read everything back and compare the first batch read with the batch written.
+        let reader = ipc::reader::FileReader::try_new(std::io::Cursor::new(buf)).unwrap();
+        let batch2: std::result::Result<Vec<_>, _> = reader.collect();
+
+        assert_eq!(batch, batch2.unwrap()[0]);
+    }
+
     /// Read gzipped JSON file
     fn read_gzip_json(version: &str, path: &str) -> ArrowJson {
         let testdata = crate::util::test_util::arrow_test_data();