You are viewing a plain text version of this content. The canonical link to the original archived message was not preserved in this plain-text copy.
Posted to commits@arrow.apache.org by tu...@apache.org on 2023/11/29 21:49:43 UTC
(arrow-rs) branch master updated: Support nested schema projection (#5148) (#5149)
This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 6d4b8bbad9 Support nested schema projection (#5148) (#5149)
6d4b8bbad9 is described below
commit 6d4b8bbad95c7e4fec0c4f1fb755ad7a1c542983
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Wed Nov 29 21:49:38 2023 +0000
Support nested schema projection (#5148) (#5149)
* Support nested schema projection
* Tweak doc
* Review feedback
---
arrow-schema/src/fields.rs | 232 ++++++++++++++++++++++++++++++++++++++++++++-
1 file changed, 231 insertions(+), 1 deletion(-)
diff --git a/arrow-schema/src/fields.rs b/arrow-schema/src/fields.rs
index f90632455f..400f42c59c 100644
--- a/arrow-schema/src/fields.rs
+++ b/arrow-schema/src/fields.rs
@@ -15,10 +15,11 @@
// specific language governing permissions and limitations
// under the License.
-use crate::{ArrowError, Field, FieldRef, SchemaBuilder};
use std::ops::Deref;
use std::sync::Arc;
+use crate::{ArrowError, DataType, Field, FieldRef, SchemaBuilder};
+
/// A cheaply cloneable, owned slice of [`FieldRef`]
///
/// Similar to `Arc<Vec<FieldRef>>` or `Arc<[FieldRef]>`
@@ -99,6 +100,108 @@ impl Fields {
.all(|(a, b)| Arc::ptr_eq(a, b) || a.contains(b))
}
+    /// Returns a copy of this [`Fields`] containing only those [`FieldRef`] passing a predicate
+    ///
+    /// Performs a depth-first, in-order scan of [`Fields`], invoking `filter` exactly once for
+    /// every leaf field, i.e. every [`FieldRef`] that contains no child [`FieldRef`]. `filter`
+    /// is passed the 0-based index of the leaf in this traversal order (the number of leaves
+    /// visited before it) together with the leaf itself. Only leaves for which `filter`
+    /// returned `true` are included in the result. Every leaf is visited regardless of what
+    /// `filter` returned for earlier leaves, so indices are always contiguous.
+    ///
+    /// Nested types are rebuilt around their surviving children:
+    ///
+    /// * [`DataType::Struct`] and [`DataType::Union`] keep only the children that survive, and
+    ///   are removed entirely if none do. Union type ids of surviving variants are preserved.
+    /// * [`DataType::List`], [`DataType::LargeList`], [`DataType::FixedSizeList`] and
+    ///   [`DataType::Map`] are removed if their single child is removed. The entries struct of
+    ///   a [`DataType::Map`] is filtered like any other struct, so selecting only one of its
+    ///   key / value leaves yields a `Map` whose entries struct has a single field.
+    /// * [`DataType::Dictionary`] and [`DataType::RunEndEncoded`] are transparent: the scan
+    ///   descends into their value type, leaving the key type / run-ends field unchanged. A
+    ///   dictionary or run-end encoded field whose value type has no children is itself the
+    ///   leaf passed to `filter`.
+    ///
+    /// This can therefore be used to select a subset of fields from nested types
+    /// such as [`DataType::Struct`] or [`DataType::List`].
+    ///
+    /// ```
+    /// # use arrow_schema::{DataType, Field, Fields};
+    /// let fields = Fields::from(vec![
+    ///     Field::new("a", DataType::Int32, true), // Leaf 0
+    ///     Field::new("b", DataType::Struct(Fields::from(vec![
+    ///         Field::new("c", DataType::Float32, false), // Leaf 1
+    ///         Field::new("d", DataType::Float64, false), // Leaf 2
+    ///         Field::new("e", DataType::Struct(Fields::from(vec![
+    ///             Field::new("f", DataType::Int32, false),   // Leaf 3
+    ///             Field::new("g", DataType::Float16, false), // Leaf 4
+    ///         ])), true),
+    ///     ])), false)
+    /// ]);
+    /// let filtered = fields.filter_leaves(|idx, _| [0, 2, 3, 4].contains(&idx));
+    /// let expected = Fields::from(vec![
+    ///     Field::new("a", DataType::Int32, true),
+    ///     Field::new("b", DataType::Struct(Fields::from(vec![
+    ///         Field::new("d", DataType::Float64, false),
+    ///         Field::new("e", DataType::Struct(Fields::from(vec![
+    ///             Field::new("f", DataType::Int32, false),
+    ///             Field::new("g", DataType::Float16, false),
+    ///         ])), true),
+    ///     ])), false)
+    /// ]);
+    /// assert_eq!(filtered, expected);
+    /// ```
+    pub fn filter_leaves<F: FnMut(usize, &FieldRef) -> bool>(&self, mut filter: F) -> Self {
+        // Recursively filters `f`, returning `None` if neither it nor any of its descendants
+        // survive. `filter` is invoked once per leaf, in depth-first order.
+        fn filter_field<F: FnMut(&FieldRef) -> bool>(
+            f: &FieldRef,
+            filter: &mut F,
+        ) -> Option<FieldRef> {
+            use DataType::*;
+
+            // Look through encodings wrapping a value type: dictionary keys and run-ends
+            // must be integers, so only the value type can contain nested fields
+            let v = match f.data_type() {
+                Dictionary(_, v) => v.as_ref(),
+                RunEndEncoded(_, v) => v.data_type(),
+                d => d,
+            };
+            let d = match v {
+                List(child) => List(filter_field(child, filter)?),
+                LargeList(child) => LargeList(filter_field(child, filter)?),
+                Map(child, ordered) => Map(filter_field(child, filter)?, *ordered),
+                FixedSizeList(child, size) => FixedSizeList(filter_field(child, filter)?, *size),
+                Struct(fields) => {
+                    // `filter_map` visits every child even when earlier ones are dropped,
+                    // which keeps the leaf numbering contiguous
+                    let filtered: Fields = fields
+                        .iter()
+                        .filter_map(|f| filter_field(f, filter))
+                        .collect();
+
+                    if filtered.is_empty() {
+                        return None;
+                    }
+
+                    Struct(filtered)
+                }
+                Union(fields, mode) => {
+                    // Surviving variants keep their original type ids
+                    let filtered: UnionFields = fields
+                        .iter()
+                        .filter_map(|(id, f)| Some((id, filter_field(f, filter)?)))
+                        .collect();
+
+                    if filtered.is_empty() {
+                        return None;
+                    }
+
+                    Union(filtered, *mode)
+                }
+                // No children: `f` itself (possibly dictionary / run-end encoded) is the leaf
+                _ => return filter(f).then(|| Arc::clone(f)),
+            };
+            // Re-apply the encoding peeled off above around the filtered value type
+            let d = match f.data_type() {
+                Dictionary(key, _) => Dictionary(key.clone(), Box::new(d)),
+                RunEndEncoded(run_ends, values) => RunEndEncoded(
+                    run_ends.clone(),
+                    Arc::new(values.as_ref().clone().with_data_type(d)),
+                ),
+                _ => d,
+            };
+            Some(Arc::new(f.as_ref().clone().with_data_type(d)))
+        }
+
+        // Adapts the user predicate by supplying the running 0-based leaf index
+        let mut leaf_idx = 0;
+        let mut counting_filter = |f: &FieldRef| {
+            let keep = filter(leaf_idx, f);
+            leaf_idx += 1;
+            keep
+        };
+
+        self.0
+            .iter()
+            .filter_map(|f| filter_field(f, &mut counting_filter))
+            .collect()
+    }
+
/// Remove a field by index and return it.
///
/// # Panic
@@ -307,3 +410,130 @@ impl FromIterator<(i8, FieldRef)> for UnionFields {
Self(iter.into_iter().collect())
}
}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::UnionMode;
+
+    #[test]
+    fn test_filter() {
+        let floats = Fields::from(vec![
+            Field::new("a", DataType::Float32, false),
+            Field::new("b", DataType::Float32, false),
+        ]);
+        // Leaf indices assigned by `filter_leaves` (depth-first order):
+        //   0      a
+        //   1, 2   floats.a, floats.b
+        //   3      b
+        //   4      c            (dictionary of primitives is itself a leaf)
+        //   5, 6   d values .a, .b
+        //   7, 8   e item .a, .b
+        //   9      f item
+        //   10, 11 g entries .keys, .values
+        //   12, 13 h field1, field3
+        //   14, 15 i values .a, .b
+        let fields = Fields::from(vec![
+            Field::new("a", DataType::Int32, true),
+            Field::new("floats", DataType::Struct(floats.clone()), true),
+            Field::new("b", DataType::Int16, true),
+            Field::new(
+                "c",
+                DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
+                false,
+            ),
+            Field::new(
+                "d",
+                DataType::Dictionary(
+                    Box::new(DataType::Int32),
+                    Box::new(DataType::Struct(floats.clone())),
+                ),
+                false,
+            ),
+            Field::new_list(
+                "e",
+                Field::new("floats", DataType::Struct(floats.clone()), true),
+                true,
+            ),
+            Field::new(
+                "f",
+                DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, false)), 3),
+                false,
+            ),
+            Field::new_map(
+                "g",
+                "entries",
+                Field::new("keys", DataType::LargeUtf8, false),
+                Field::new("values", DataType::Int32, true),
+                false,
+                false,
+            ),
+            Field::new(
+                "h",
+                DataType::Union(
+                    UnionFields::new(
+                        vec![1, 3],
+                        vec![
+                            Field::new("field1", DataType::UInt8, false),
+                            Field::new("field3", DataType::Utf8, false),
+                        ],
+                    ),
+                    UnionMode::Dense,
+                ),
+                true,
+            ),
+            Field::new(
+                "i",
+                DataType::RunEndEncoded(
+                    Arc::new(Field::new("run_ends", DataType::Int32, false)),
+                    Arc::new(Field::new("values", DataType::Struct(floats.clone()), true)),
+                ),
+                false,
+            ),
+        ]);
+
+        // The predicate sees every leaf exactly once, with contiguous indices
+        let mut seen = Vec::new();
+        let r = fields.filter_leaves(|idx, _| {
+            seen.push(idx);
+            false
+        });
+        assert!(r.is_empty());
+        assert_eq!(seen, (0..16).collect::<Vec<usize>>());
+
+        let floats_a = DataType::Struct(vec![floats[0].clone()].into());
+
+        let r = fields.filter_leaves(|idx, _| idx == 0 || idx == 1);
+        assert_eq!(r.len(), 2);
+        assert_eq!(r[0], fields[0]);
+        assert_eq!(r[1].data_type(), &floats_a);
+
+        let r = fields.filter_leaves(|_, f| f.name() == "a");
+        assert_eq!(r.len(), 5);
+        assert_eq!(r[0], fields[0]);
+        assert_eq!(r[1].data_type(), &floats_a);
+        assert_eq!(
+            r[2].data_type(),
+            &DataType::Dictionary(Box::new(DataType::Int32), Box::new(floats_a.clone()))
+        );
+        assert_eq!(
+            r[3].as_ref(),
+            &Field::new_list("e", Field::new("floats", floats_a.clone(), true), true)
+        );
+        assert_eq!(
+            r[4].as_ref(),
+            &Field::new(
+                "i",
+                DataType::RunEndEncoded(
+                    Arc::new(Field::new("run_ends", DataType::Int32, false)),
+                    Arc::new(Field::new("values", floats_a.clone(), true)),
+                ),
+                false,
+            )
+        );
+
+        // Non-leaf (nested) fields are never passed to the predicate
+        let r = fields.filter_leaves(|_, f| f.name() == "floats");
+        assert_eq!(r.len(), 0);
+
+        // A dictionary with non-nested values is passed to the predicate as a single leaf
+        let r = fields.filter_leaves(|idx, _| idx == 4);
+        assert_eq!(r.len(), 1);
+        assert_eq!(r[0], fields[3]);
+
+        let r = fields.filter_leaves(|idx, _| idx == 9);
+        assert_eq!(r.len(), 1);
+        assert_eq!(r[0], fields[6]);
+
+        let r = fields.filter_leaves(|idx, _| idx == 10 || idx == 11);
+        assert_eq!(r.len(), 1);
+        assert_eq!(r[0], fields[7]);
+
+        let union = DataType::Union(
+            UnionFields::new(vec![1], vec![Field::new("field1", DataType::UInt8, false)]),
+            UnionMode::Dense,
+        );
+
+        let r = fields.filter_leaves(|idx, _| idx == 12);
+        assert_eq!(r.len(), 1);
+        assert_eq!(r[0].data_type(), &union);
+
+        // Surviving union variants keep their original type id
+        let union_field3 = DataType::Union(
+            UnionFields::new(vec![3], vec![Field::new("field3", DataType::Utf8, false)]),
+            UnionMode::Dense,
+        );
+
+        let r = fields.filter_leaves(|idx, _| idx == 13);
+        assert_eq!(r.len(), 1);
+        assert_eq!(r[0].data_type(), &union_field3);
+
+        let r = fields.filter_leaves(|idx, _| idx == 14 || idx == 15);
+        assert_eq!(r.len(), 1);
+        assert_eq!(r[0], fields[9]);
+    }
+}