You are viewing a plain-text version of this content. The hyperlink to the canonical version (labelled "here" on the original page) was not preserved in this copy.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/01/01 10:48:33 UTC

[arrow] branch master updated: ARROW-10656: [Rust] Allow schema validation to ignore field names and only check data types on new batch

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 118f462  ARROW-10656: [Rust] Allow schema validation to ignore field names and only check data types on new batch
118f462 is described below

commit 118f4622934409178cce97881c752474840571e4
Author: Neville Dipale <ne...@gmail.com>
AuthorDate: Fri Jan 1 05:47:30 2021 -0500

    ARROW-10656: [Rust] Allow schema validation to ignore field names and only check data types on new batch
    
    This adds the option to create a new record batch with less strict validation for list field names.
    The default behaviour is preserved.
    
    Closes #8988 from nevi-me/ARROW-10656
    
    Authored-by: Neville Dipale <ne...@gmail.com>
    Signed-off-by: Andrew Lamb <an...@nerdnetworks.org>
---
 rust/arrow/src/datatypes.rs    | 25 +++++++++++++
 rust/arrow/src/record_batch.rs | 81 +++++++++++++++++++++++++++++++++++-------
 2 files changed, 93 insertions(+), 13 deletions(-)

diff --git a/rust/arrow/src/datatypes.rs b/rust/arrow/src/datatypes.rs
index d2cf47e..125adc4 100644
--- a/rust/arrow/src/datatypes.rs
+++ b/rust/arrow/src/datatypes.rs
@@ -1246,6 +1246,31 @@ impl DataType {
                 | Float64
         )
     }
+
+    /// Compares the datatype with another, ignoring nested field names
+    /// and metadata
+    pub(crate) fn equals_datatype(&self, other: &DataType) -> bool {
+        match (&self, other) {
+            (DataType::List(a), DataType::List(b))
+            | (DataType::LargeList(a), DataType::LargeList(b)) => {
+                a.is_nullable() == b.is_nullable()
+                    && a.data_type().equals_datatype(b.data_type())
+            }
+            (DataType::FixedSizeList(a, a_size), DataType::FixedSizeList(b, b_size)) => {
+                a_size == b_size
+                    && a.is_nullable() == b.is_nullable()
+                    && a.data_type().equals_datatype(b.data_type())
+            }
+            (DataType::Struct(a), DataType::Struct(b)) => {
+                a.len() == b.len()
+                    && a.iter().zip(b).all(|(a, b)| {
+                        a.is_nullable() == b.is_nullable()
+                            && a.data_type().equals_datatype(b.data_type())
+                    })
+            }
+            _ => self == other,
+        }
+    }
 }
 
 impl Field {
diff --git a/rust/arrow/src/record_batch.rs b/rust/arrow/src/record_batch.rs
index b4aa97d..14731b6 100644
--- a/rust/arrow/src/record_batch.rs
+++ b/rust/arrow/src/record_batch.rs
@@ -75,6 +75,25 @@ impl RecordBatch {
     /// # }
     /// ```
     pub fn try_new(schema: SchemaRef, columns: Vec<ArrayRef>) -> Result<Self> {
+        let options = RecordBatchOptions::default();
+        Self::validate_new_batch(&schema, columns.as_slice(), &options)?;
+        Ok(RecordBatch { schema, columns })
+    }
+
+    pub fn try_new_with_options(
+        schema: SchemaRef,
+        columns: Vec<ArrayRef>,
+        options: &RecordBatchOptions,
+    ) -> Result<Self> {
+        Self::validate_new_batch(&schema, columns.as_slice(), options)?;
+        Ok(RecordBatch { schema, columns })
+    }
+
+    fn validate_new_batch(
+        schema: &SchemaRef,
+        columns: &[ArrayRef],
+        options: &RecordBatchOptions,
+    ) -> Result<()> {
         // check that there are some columns
         if columns.is_empty() {
             return Err(ArrowError::InvalidArgumentError(
@@ -93,22 +112,45 @@ impl RecordBatch {
         // check that all columns have the same row count, and match the schema
         let len = columns[0].data().len();
 
-        for (i, column) in columns.iter().enumerate() {
-            if column.len() != len {
-                return Err(ArrowError::InvalidArgumentError(
-                    "all columns in a record batch must have the same length".to_string(),
-                ));
+        // This is a bit repetitive, but it is better to check the condition outside the loop
+        if options.match_field_names {
+            for (i, column) in columns.iter().enumerate() {
+                if column.len() != len {
+                    return Err(ArrowError::InvalidArgumentError(
+                        "all columns in a record batch must have the same length"
+                            .to_string(),
+                    ));
+                }
+                if column.data_type() != schema.field(i).data_type() {
+                    return Err(ArrowError::InvalidArgumentError(format!(
+                        "column types must match schema types, expected {:?} but found {:?} at column index {}",
+                        schema.field(i).data_type(),
+                        column.data_type(),
+                        i)));
+                }
             }
-            // list types can have different names, but we only need the data types to be the same
-            if column.data_type() != schema.field(i).data_type() {
-                return Err(ArrowError::InvalidArgumentError(format!(
-                    "column types must match schema types, expected {:?} but found {:?} at column index {}",
-                    schema.field(i).data_type(),
-                    column.data_type(),
-                    i)));
+        } else {
+            for (i, column) in columns.iter().enumerate() {
+                if column.len() != len {
+                    return Err(ArrowError::InvalidArgumentError(
+                        "all columns in a record batch must have the same length"
+                            .to_string(),
+                    ));
+                }
+                if !column
+                    .data_type()
+                    .equals_datatype(schema.field(i).data_type())
+                {
+                    return Err(ArrowError::InvalidArgumentError(format!(
+                        "column types must match schema types, expected {:?} but found {:?} at column index {}",
+                        schema.field(i).data_type(),
+                        column.data_type(),
+                        i)));
+                }
             }
         }
-        Ok(RecordBatch { schema, columns })
+
+        Ok(())
     }
 
     /// Returns the [`Schema`](crate::datatypes::Schema) of the record batch.
@@ -187,6 +229,19 @@ impl RecordBatch {
     }
 }
 
+#[derive(Debug)]
+pub struct RecordBatchOptions {
+    pub match_field_names: bool,
+}
+
+impl Default for RecordBatchOptions {
+    fn default() -> Self {
+        Self {
+            match_field_names: true,
+        }
+    }
+}
+
 impl From<&StructArray> for RecordBatch {
     /// Create a record batch from struct array.
     ///