You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by tu...@apache.org on 2023/11/29 21:49:43 UTC

(arrow-rs) branch master updated: Support nested schema projection (#5148) (#5149)

This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 6d4b8bbad9 Support nested schema projection (#5148) (#5149)
6d4b8bbad9 is described below

commit 6d4b8bbad95c7e4fec0c4f1fb755ad7a1c542983
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Wed Nov 29 21:49:38 2023 +0000

    Support nested schema projection (#5148) (#5149)
    
    * Support nested schema projection
    
    * Tweak doc
    
    * Review feedback
---
 arrow-schema/src/fields.rs | 232 ++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 231 insertions(+), 1 deletion(-)

diff --git a/arrow-schema/src/fields.rs b/arrow-schema/src/fields.rs
index f90632455f..400f42c59c 100644
--- a/arrow-schema/src/fields.rs
+++ b/arrow-schema/src/fields.rs
@@ -15,10 +15,11 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::{ArrowError, Field, FieldRef, SchemaBuilder};
 use std::ops::Deref;
 use std::sync::Arc;
 
+use crate::{ArrowError, DataType, Field, FieldRef, SchemaBuilder};
+
 /// A cheaply cloneable, owned slice of [`FieldRef`]
 ///
 /// Similar to `Arc<Vec<FieldRef>>` or `Arc<[FieldRef]>`
@@ -99,6 +100,108 @@ impl Fields {
                 .all(|(a, b)| Arc::ptr_eq(a, b) || a.contains(b))
     }
 
+    /// Returns a copy of this [`Fields`] pruned to the leaf [`FieldRef`]s passing a predicate
+    ///
+    /// Performs a depth-first scan of [`Fields`], in field order, invoking `filter` once per leaf
+    /// field (a [`FieldRef`] with no child [`FieldRef`]) along with the zero-based index of that
+    /// leaf among all leaves visited. Leaves for which `filter` returns `true` are kept; a nested
+    /// field is kept, rebuilt around its surviving children, only if any of its leaves is kept.
+    ///
+    /// This can therefore be used to select a subset of fields from nested types such as
+    /// [`DataType::Struct`] or [`DataType::List`]. Dictionary keys and run-ends are never leaves.
+    ///
+    /// ```
+    /// # use arrow_schema::{DataType, Field, Fields};
+    /// let fields = Fields::from(vec![
+    ///     Field::new("a", DataType::Int32, true), // Leaf 0
+    ///     Field::new("b", DataType::Struct(Fields::from(vec![
+    ///         Field::new("c", DataType::Float32, false), // Leaf 1
+    ///         Field::new("d", DataType::Float64, false), // Leaf 2
+    ///         Field::new("e", DataType::Struct(Fields::from(vec![
+    ///             Field::new("f", DataType::Int32, false),   // Leaf 3
+    ///             Field::new("g", DataType::Float16, false), // Leaf 4
+    ///         ])), true),
+    ///     ])), false)
+    /// ]);
+    /// let filtered = fields.filter_leaves(|idx, _| [0, 2, 3, 4].contains(&idx)); // drops "c"
+    /// let expected = Fields::from(vec![
+    ///     Field::new("a", DataType::Int32, true),
+    ///     Field::new("b", DataType::Struct(Fields::from(vec![
+    ///         Field::new("d", DataType::Float64, false),
+    ///         Field::new("e", DataType::Struct(Fields::from(vec![
+    ///             Field::new("f", DataType::Int32, false),
+    ///             Field::new("g", DataType::Float16, false),
+    ///         ])), true),
+    ///     ])), false)
+    /// ]);
+    /// assert_eq!(filtered, expected);
+    /// ```
+    pub fn filter_leaves<F: FnMut(usize, &FieldRef) -> bool>(&self, mut filter: F) -> Self {
+        fn filter_field<F: FnMut(&FieldRef) -> bool>( // None if no leaf under `f` is kept
+            f: &FieldRef,
+            filter: &mut F,
+        ) -> Option<FieldRef> {
+            use DataType::*;
+            // Look through dictionary / run-end encoding to the type that may have children
+            let v = match f.data_type() {
+                Dictionary(_, v) => v.as_ref(),       // Key must be integer
+                RunEndEncoded(_, v) => v.data_type(), // Run-ends must be integer
+                d => d,
+            };
+            let d = match v { // Rebuild the unwrapped type from its kept children
+                List(child) => List(filter_field(child, filter)?),
+                LargeList(child) => LargeList(filter_field(child, filter)?),
+                Map(child, ordered) => Map(filter_field(child, filter)?, *ordered), // NOTE(review): entries may keep only key or value; confirm that is acceptable
+                FixedSizeList(child, size) => FixedSizeList(filter_field(child, filter)?, *size),
+                Struct(fields) => {
+                    let filtered: Fields = fields
+                        .iter()
+                        .filter_map(|f| filter_field(f, filter))
+                        .collect();
+                    // A struct left with no children is dropped entirely
+                    if filtered.is_empty() {
+                        return None;
+                    }
+
+                    Struct(filtered)
+                }
+                Union(fields, mode) => {
+                    let filtered: UnionFields = fields
+                        .iter()
+                        .filter_map(|(id, f)| Some((id, filter_field(f, filter)?)))
+                        .collect();
+                    // Same for a union; surviving variants keep their original type ids
+                    if filtered.is_empty() {
+                        return None;
+                    }
+
+                    Union(filtered, *mode)
+                }
+                _ => return filter(f).then(|| f.clone()), // A leaf: ask the predicate about the (possibly wrapped) field
+            };
+            let d = match f.data_type() { // Re-apply any wrapper stripped above
+                Dictionary(k, _) => Dictionary(k.clone(), Box::new(d)),
+                RunEndEncoded(v, f) => { // `f` here is the values field, shadowing the outer one
+                    RunEndEncoded(v.clone(), Arc::new(f.as_ref().clone().with_data_type(d)))
+                }
+                _ => d,
+            };
+            Some(Arc::new(f.as_ref().clone().with_data_type(d))) // Only the data type changes
+        }
+        // Adapt the user predicate: supply the running leaf index to each call
+        let mut leaf_idx = 0;
+        let mut filter = |f: &FieldRef| {
+            let t = filter(leaf_idx, f);
+            leaf_idx += 1; // Incremented for every leaf, kept or not
+            t
+        };
+
+        self.0
+            .iter()
+            .filter_map(|f| filter_field(f, &mut filter))
+            .collect()
+    }
+
     /// Remove a field by index and return it.
     ///
     /// # Panic
@@ -307,3 +410,130 @@ impl FromIterator<(i8, FieldRef)> for UnionFields {
         Self(iter.into_iter().collect())
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::UnionMode;
+
+    #[test]
+    fn test_filter() { // Covers struct, dictionary, list, fixed-size list, map, union and REE
+        let floats = Fields::from(vec![
+            Field::new("a", DataType::Float32, false),
+            Field::new("b", DataType::Float32, false),
+        ]);
+        let fields = Fields::from(vec![
+            Field::new("a", DataType::Int32, true), // leaf 0
+            Field::new("floats", DataType::Struct(floats.clone()), true), // leaves 1-2
+            Field::new("b", DataType::Int16, true), // leaf 3
+            Field::new( // leaf 4: a dictionary of Utf8 is itself a leaf
+                "c",
+                DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
+                false,
+            ),
+            Field::new( // leaves 5-6: dictionary of struct
+                "d",
+                DataType::Dictionary(
+                    Box::new(DataType::Int32),
+                    Box::new(DataType::Struct(floats.clone())),
+                ),
+                false,
+            ),
+            Field::new_list( // leaves 7-8
+                "e",
+                Field::new("floats", DataType::Struct(floats.clone()), true),
+                true,
+            ),
+            Field::new( // leaf 9
+                "f",
+                DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, false)), 3),
+                false,
+            ),
+            Field::new_map( // leaves 10-11: keys, values
+                "g",
+                "entries",
+                Field::new("keys", DataType::LargeUtf8, false),
+                Field::new("values", DataType::Int32, true),
+                false,
+                false,
+            ),
+            Field::new( // leaves 12-13: union variants
+                "h",
+                DataType::Union(
+                    UnionFields::new(
+                        vec![1, 3],
+                        vec![
+                            Field::new("field1", DataType::UInt8, false),
+                            Field::new("field3", DataType::Utf8, false),
+                        ],
+                    ),
+                    UnionMode::Dense,
+                ),
+                true,
+            ),
+            Field::new( // leaves 14-15: run_ends is not a leaf
+                "i",
+                DataType::RunEndEncoded(
+                    Arc::new(Field::new("run_ends", DataType::Int32, false)),
+                    Arc::new(Field::new("values", DataType::Struct(floats.clone()), true)),
+                ),
+                false,
+            ),
+        ]);
+        // "floats" restricted to its first child, "a"
+        let floats_a = DataType::Struct(vec![floats[0].clone()].into());
+        // Leaves 0 and 1: the top-level "a" and floats.a
+        let r = fields.filter_leaves(|idx, _| idx == 0 || idx == 1);
+        assert_eq!(r.len(), 2);
+        assert_eq!(r[0], fields[0]);
+        assert_eq!(r[1].data_type(), &floats_a);
+        // Keep every leaf named "a", at any depth
+        let r = fields.filter_leaves(|_, f| f.name() == "a");
+        assert_eq!(r.len(), 5);
+        assert_eq!(r[0], fields[0]);
+        assert_eq!(r[1].data_type(), &floats_a);
+        assert_eq!(
+            r[2].data_type(),
+            &DataType::Dictionary(Box::new(DataType::Int32), Box::new(floats_a.clone()))
+        );
+        assert_eq!(
+            r[3].as_ref(),
+            &Field::new_list("e", Field::new("floats", floats_a.clone(), true), true)
+        );
+        assert_eq!(
+            r[4].as_ref(),
+            &Field::new(
+                "i",
+                DataType::RunEndEncoded(
+                    Arc::new(Field::new("run_ends", DataType::Int32, false)),
+                    Arc::new(Field::new("values", floats_a.clone(), true)),
+                ),
+                false,
+            )
+        );
+        // "floats" names only structs, never leaves, so nothing survives
+        let r = fields.filter_leaves(|_, f| f.name() == "floats");
+        assert_eq!(r.len(), 0);
+        // The fixed-size list "f" has a single leaf
+        let r = fields.filter_leaves(|idx, _| idx == 9);
+        assert_eq!(r.len(), 1);
+        assert_eq!(r[0], fields[6]);
+        // Keeping both map leaves reproduces "g" unchanged
+        let r = fields.filter_leaves(|idx, _| idx == 10 || idx == 11);
+        assert_eq!(r.len(), 1);
+        assert_eq!(r[0], fields[7]);
+        // Keeping only the first union variant preserves its type id
+        let union = DataType::Union(
+            UnionFields::new(vec![1], vec![Field::new("field1", DataType::UInt8, false)]),
+            UnionMode::Dense,
+        );
+
+        let r = fields.filter_leaves(|idx, _| idx == 12);
+        assert_eq!(r.len(), 1);
+        assert_eq!(r[0].data_type(), &union);
+        // Keeping both REE value leaves reproduces "i" unchanged
+        let r = fields.filter_leaves(|idx, _| idx == 14 || idx == 15);
+        assert_eq!(r.len(), 1);
+        assert_eq!(r[0], fields[9]);
+    }
+}