You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/12/22 16:34:47 UTC

[arrow-rs] branch active_release updated: Add Schema::project and RecordBatch::project functions (#1033) (#1077)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch active_release
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/active_release by this push:
     new e0abdb9  Add Schema::project and RecordBatch::project functions  (#1033) (#1077)
e0abdb9 is described below

commit e0abdb9e62772a2f853974e68e744246e7f47569
Author: Andrew Lamb <an...@nerdnetworks.org>
AuthorDate: Wed Dec 22 11:34:39 2021 -0500

    Add Schema::project and RecordBatch::project functions  (#1033) (#1077)
    
    * Allow Schema and RecordBatch to be projected onto specific columns, returning a new schema (or record batch) containing only those columns
    
    * Address PR review feedback and add a test for out-of-range projection
    
    * switch to &[usize]
    
    * fix: clippy and fmt
    
    Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
    
    Co-authored-by: Stephen Carman <hn...@users.noreply.github.com>
---
 arrow/src/datatypes/schema.rs | 65 +++++++++++++++++++++++++++++++++++++++++++
 arrow/src/record_batch.rs     | 38 +++++++++++++++++++++++++
 2 files changed, 103 insertions(+)

diff --git a/arrow/src/datatypes/schema.rs b/arrow/src/datatypes/schema.rs
index 22b1ceb..8c30144 100644
--- a/arrow/src/datatypes/schema.rs
+++ b/arrow/src/datatypes/schema.rs
@@ -87,6 +87,24 @@ impl Schema {
         Self { fields, metadata }
     }
 
+    /// Returns a new schema holding only the fields at `indices`, in that order, and
+    /// a copy of this schema's metadata. Fails with `SchemaError` on an out-of-bounds index.
+    pub fn project(&self, indices: &[usize]) -> Result<Schema> {
+        let mut picked = Vec::with_capacity(indices.len());
+        for &idx in indices {
+            let field = self.fields.get(idx).ok_or_else(|| {
+                // Note: the message reports the field *count*, not the largest valid index.
+                ArrowError::SchemaError(format!(
+                    "project index {} out of bounds, max field {}",
+                    idx,
+                    self.fields().len()
+                ))
+            })?;
+            picked.push(field.clone());
+        }
+        Ok(Self::new_with_metadata(picked, self.metadata.clone()))
+    }
+
     /// Merge schema into self if it is compatible. Struct fields will be merged recursively.
     ///
     /// Example:
@@ -369,4 +387,51 @@ mod tests {
 
         assert_eq!(schema, de_schema);
     }
+
+    #[test]
+    fn test_projection() {
+        let metadata: HashMap<String, String> = vec![("meta".to_string(), "data".to_string())]
+            .into_iter()
+            .collect();
+        let fields = vec![
+            Field::new("name", DataType::Utf8, false),
+            Field::new("address", DataType::Utf8, false),
+            Field::new("priority", DataType::UInt8, false),
+        ];
+        let schema = Schema::new_with_metadata(fields, metadata);
+
+        // Keep the first and last fields; "address" must be dropped.
+        let projected = schema.project(&[0, 2]).unwrap();
+        let names: Vec<String> =
+            projected.fields().iter().map(|f| f.name().to_string()).collect();
+        assert_eq!(names, vec!["name", "priority"]);
+
+        // Schema-level metadata is carried over to the projection.
+        assert_eq!(projected.metadata.get("meta").map(String::as_str), Some("data"));
+    }
+
+    #[test]
+    fn test_oob_projection() {
+        let mut metadata = HashMap::new();
+        metadata.insert("meta".to_string(), "data".to_string());
+
+        let schema = Schema::new_with_metadata(
+            vec![
+                Field::new("name", DataType::Utf8, false),
+                Field::new("address", DataType::Utf8, false),
+                Field::new("priority", DataType::UInt8, false),
+            ],
+            metadata,
+        );
+
+        // Index 3 is the first out-of-range index for a 3-field schema: it must come
+        // back as an error (not a panic), and the message must name the bad index.
+        let err = schema
+            .project(&[0, 3])
+            .expect_err("projecting index 3 of a 3-field schema must fail");
+        assert_eq!(
+            err.to_string(),
+            "Schema error: project index 3 out of bounds, max field 3"
+        );
+    }
 }
diff --git a/arrow/src/record_batch.rs b/arrow/src/record_batch.rs
index b441f6c..9faba7d 100644
--- a/arrow/src/record_batch.rs
+++ b/arrow/src/record_batch.rs
@@ -175,6 +175,25 @@ impl RecordBatch {
         self.schema.clone()
     }
 
+    /// Returns a new [`RecordBatch`] containing only the columns at `indices`, in that
+    /// order, alongside the schema projected the same way via `Schema::project`.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`ArrowError::SchemaError`] if any index is out of bounds.
+    pub fn project(&self, indices: &[usize]) -> Result<RecordBatch> {
+        // `Schema::project` rejects every out-of-bounds index, and a batch holds exactly
+        // one column per schema field, so indexing `self.columns` below cannot panic.
+        let projected_schema = self.schema.project(indices)?;
+
+        let batch_fields = indices
+            .iter()
+            .map(|&i| self.columns[i].clone())
+            .collect::<Vec<_>>();
+
+        RecordBatch::try_new(SchemaRef::new(projected_schema), batch_fields)
+    }
+
     /// Returns the number of columns in the record batch.
     ///
     /// # Example
@@ -900,4 +919,23 @@ mod tests {
 
         assert_ne!(batch1, batch2);
     }
+
+    #[test]
+    fn project() {
+        let ints: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
+        let first_strs: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
+        let second_strs: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));
+
+        let full = RecordBatch::try_from_iter(vec![
+            ("a", Arc::clone(&ints)),
+            ("b", first_strs),
+            ("c", Arc::clone(&second_strs)),
+        ])
+        .expect("valid conversion");
+        // Keeping columns 0 and 2 should drop "b" and leave "a" followed by "c".
+        let projected = full.project(&[0, 2]).unwrap();
+        let expected = RecordBatch::try_from_iter(vec![("a", ints), ("c", second_strs)])
+            .expect("valid conversion");
+        assert_eq!(expected, projected);
+    }
 }