You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/12/22 16:34:47 UTC
[arrow-rs] branch active_release updated: Add Schema::project and RecordBatch::project functions (#1033) (#1077)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch active_release
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/active_release by this push:
new e0abdb9 Add Schema::project and RecordBatch::project functions (#1033) (#1077)
e0abdb9 is described below
commit e0abdb9e62772a2f853974e68e744246e7f47569
Author: Andrew Lamb <an...@nerdnetworks.org>
AuthorDate: Wed Dec 22 11:34:39 2021 -0500
Add Schema::project and RecordBatch::project functions (#1033) (#1077)
* Allow Schema and RecordBatch to be projected onto specific columns, returning a new schema or record batch containing only those columns
* Address PR updates and add a test for out-of-range projection
* switch to &[usize]
* fix: clippy and fmt
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
Co-authored-by: Stephen Carman <hn...@users.noreply.github.com>
---
arrow/src/datatypes/schema.rs | 65 +++++++++++++++++++++++++++++++++++++++++++
arrow/src/record_batch.rs | 38 +++++++++++++++++++++++++
2 files changed, 103 insertions(+)
diff --git a/arrow/src/datatypes/schema.rs b/arrow/src/datatypes/schema.rs
index 22b1ceb..8c30144 100644
--- a/arrow/src/datatypes/schema.rs
+++ b/arrow/src/datatypes/schema.rs
@@ -87,6 +87,24 @@ impl Schema {
Self { fields, metadata }
}
+ /// Returns a new schema containing only the fields at the specified indices.
+ /// Metadata from the parent schema is carried over; an out-of-bounds index returns an error.
+ pub fn project(&self, indices: &[usize]) -> Result<Schema> {
+ let new_fields = indices
+ .iter()
+ .map(|i| {
+ self.fields.get(*i).cloned().ok_or_else(|| {
+ ArrowError::SchemaError(format!(
+ "project index {} out of bounds, max field {}",
+ i,
+ self.fields().len()
+ ))
+ })
+ })
+ .collect::<Result<Vec<_>>>()?;
+ Ok(Self::new_with_metadata(new_fields, self.metadata.clone()))
+ }
+
/// Merge schema into self if it is compatible. Struct fields will be merged recursively.
///
/// Example:
@@ -369,4 +387,51 @@ mod tests {
assert_eq!(schema, de_schema);
}
+
+ #[test]
+ fn test_projection() {
+ let mut metadata = HashMap::new();
+ metadata.insert("meta".to_string(), "data".to_string());
+
+ let schema = Schema::new_with_metadata(
+ vec![
+ Field::new("name", DataType::Utf8, false),
+ Field::new("address", DataType::Utf8, false),
+ Field::new("priority", DataType::UInt8, false),
+ ],
+ metadata,
+ );
+
+ let projected: Schema = schema.project(&[0, 2]).unwrap();
+
+ assert_eq!(projected.fields().len(), 2);
+ assert_eq!(projected.fields()[0].name(), "name");
+ assert_eq!(projected.fields()[1].name(), "priority");
+ assert_eq!(projected.metadata.get("meta").unwrap(), "data")
+ }
+
+ #[test]
+ fn test_oob_projection() {
+ let mut metadata = HashMap::new();
+ metadata.insert("meta".to_string(), "data".to_string());
+
+ let schema = Schema::new_with_metadata(
+ vec![
+ Field::new("name", DataType::Utf8, false),
+ Field::new("address", DataType::Utf8, false),
+ Field::new("priority", DataType::UInt8, false),
+ ],
+ metadata,
+ );
+
+ let projected: Result<Schema> = schema.project(&[0, 3]);
+
+ assert!(projected.is_err());
+ if let Err(e) = projected {
+ assert_eq!(
+ e.to_string(),
+ "Schema error: project index 3 out of bounds, max field 3".to_string()
+ )
+ }
+ }
}
diff --git a/arrow/src/record_batch.rs b/arrow/src/record_batch.rs
index b441f6c..9faba7d 100644
--- a/arrow/src/record_batch.rs
+++ b/arrow/src/record_batch.rs
@@ -175,6 +175,25 @@ impl RecordBatch {
self.schema.clone()
}
+ /// Returns a new record batch containing only the columns at the specified indices, with the schema projected to match
+ pub fn project(&self, indices: &[usize]) -> Result<RecordBatch> {
+ let projected_schema = self.schema.project(indices)?;
+ let batch_fields = indices
+ .iter()
+ .map(|f| {
+ self.columns.get(*f).cloned().ok_or_else(|| {
+ ArrowError::SchemaError(format!(
+ "project index {} out of bounds, max field {}",
+ f,
+ self.columns.len()
+ ))
+ })
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ RecordBatch::try_new(SchemaRef::new(projected_schema), batch_fields)
+ }
+
/// Returns the number of columns in the record batch.
///
/// # Example
@@ -900,4 +919,23 @@ mod tests {
assert_ne!(batch1, batch2);
}
+
+ #[test]
+ fn project() {
+ let a: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
+ let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
+ let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));
+
+ let record_batch = RecordBatch::try_from_iter(vec![
+ ("a", a.clone()),
+ ("b", b.clone()),
+ ("c", c.clone()),
+ ])
+ .expect("valid conversion");
+
+ let expected = RecordBatch::try_from_iter(vec![("a", a), ("c", c)])
+ .expect("valid conversion");
+
+ assert_eq!(expected, record_batch.project(&[0, 2]).unwrap());
+ }
}