You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by su...@apache.org on 2019/03/06 00:45:35 UTC

[arrow] branch master updated: ARROW-4749: [Rust] Return Result for RecordBatch::new()

This is an automated email from the ASF dual-hosted git repository.

sunchao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new faeb309  ARROW-4749: [Rust] Return Result for RecordBatch::new()
faeb309 is described below

commit faeb3092c40a753147180065cf1e043426735ec7
Author: Neville Dipale <ne...@gmail.com>
AuthorDate: Tue Mar 5 16:45:24 2019 -0800

    ARROW-4749: [Rust] Return Result for RecordBatch::new()
    
    Adds more validation between schemas and columns, returning an error
    (instead of panicking) when the number of columns, their lengths, or
    their data types do not match the schema.
    
    Author: Neville Dipale <ne...@gmail.com>
    Author: Andy Grove <an...@gmail.com>
    
    Closes #3800 from nevi-me/ARROW-4749 and squashes the following commits:
    
    586395db <Neville Dipale> RecordBatch::try -> RecordBatch::try_new
    fbdde0fc <Neville Dipale> fix csv writer tests
    d92aeb63 <Andy Grove> fix aggr schema
    5dd6bd01 <Neville Dipale> rebase against master, update record batch in in-memory source
    0c402f5f <Neville Dipale> ARROW-4749:  Return Result for RecordBatch::new()
---
 rust/arrow/benches/csv_writer.rs            |  2 +-
 rust/arrow/examples/dynamic_types.rs        | 10 ++--
 rust/arrow/src/csv/reader.rs                |  5 +-
 rust/arrow/src/csv/writer.rs                | 14 ++---
 rust/arrow/src/error.rs                     |  1 +
 rust/arrow/src/json/reader.rs               |  5 +-
 rust/arrow/src/record_batch.rs              | 79 +++++++++++++++++++++++------
 rust/datafusion/src/datasource/memory.rs    | 53 +++++++++++--------
 rust/datafusion/src/execution/aggregate.rs  | 12 +++--
 rust/datafusion/src/execution/context.rs    |  9 +++-
 rust/datafusion/src/execution/filter.rs     |  2 +-
 rust/datafusion/src/execution/limit.rs      |  2 +-
 rust/datafusion/src/execution/projection.rs |  2 +-
 13 files changed, 139 insertions(+), 57 deletions(-)

diff --git a/rust/arrow/benches/csv_writer.rs b/rust/arrow/benches/csv_writer.rs
index ec3bc5a..49b1eed 100644
--- a/rust/arrow/benches/csv_writer.rs
+++ b/rust/arrow/benches/csv_writer.rs
@@ -48,7 +48,7 @@ fn record_batches_to_csv() {
     let c3 = PrimitiveArray::<UInt32Type>::from(vec![3, 2, 1]);
     let c4 = PrimitiveArray::<BooleanType>::from(vec![Some(true), Some(false), None]);
 
-    let b = RecordBatch::new(
+    let b = RecordBatch::try_new(
         Arc::new(schema),
         vec![Arc::new(c1), Arc::new(c2), Arc::new(c3), Arc::new(c4)],
     );
diff --git a/rust/arrow/examples/dynamic_types.rs b/rust/arrow/examples/dynamic_types.rs
index 8e6bb5d..2f361f4 100644
--- a/rust/arrow/examples/dynamic_types.rs
+++ b/rust/arrow/examples/dynamic_types.rs
@@ -22,9 +22,10 @@ extern crate arrow;
 
 use arrow::array::*;
 use arrow::datatypes::*;
+use arrow::error::Result;
 use arrow::record_batch::*;
 
-fn main() {
+fn main() -> Result<()> {
     // define schema
     let schema = Schema::new(vec![
         Field::new("id", DataType::Int32, false),
@@ -58,9 +59,10 @@ fn main() {
     ]);
 
     // build a record batch
-    let batch = RecordBatch::new(Arc::new(schema), vec![Arc::new(id), Arc::new(nested)]);
+    let batch =
+        RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id), Arc::new(nested)])?;
 
-    process(&batch);
+    Ok(process(&batch))
 }
 
 /// Create a new batch by performing a projection of id, nested.c
@@ -88,7 +90,7 @@ fn process(batch: &RecordBatch) {
         Field::new("sum", DataType::Float64, false),
     ]);
 
-    let _ = RecordBatch::new(
+    let _ = RecordBatch::try_new(
         Arc::new(projected_schema),
         vec![
             id.clone(), // NOTE: this is cloning the Arc not the array data
diff --git a/rust/arrow/src/csv/reader.rs b/rust/arrow/src/csv/reader.rs
index a511b93..85b2ccd 100644
--- a/rust/arrow/src/csv/reader.rs
+++ b/rust/arrow/src/csv/reader.rs
@@ -329,7 +329,10 @@ impl<R: Read> Reader<R> {
         let projected_schema = Arc::new(Schema::new(projected_fields));
 
         match arrays {
-            Ok(arr) => Ok(Some(RecordBatch::new(projected_schema, arr))),
+            Ok(arr) => match RecordBatch::try_new(projected_schema, arr) {
+                Ok(batch) => Ok(Some(batch)),
+                Err(e) => Err(e),
+            },
             Err(e) => Err(e),
         }
     }
diff --git a/rust/arrow/src/csv/writer.rs b/rust/arrow/src/csv/writer.rs
index bf1e582..945fb71 100644
--- a/rust/arrow/src/csv/writer.rs
+++ b/rust/arrow/src/csv/writer.rs
@@ -50,10 +50,10 @@
 //! let c3 = PrimitiveArray::<UInt32Type>::from(vec![3, 2, 1]);
 //! let c4 = PrimitiveArray::<BooleanType>::from(vec![Some(true), Some(false), None]);
 //!
-//! let batch = RecordBatch::new(
+//! let batch = RecordBatch::try_new(
 //!     Arc::new(schema),
 //!     vec![Arc::new(c1), Arc::new(c2), Arc::new(c3), Arc::new(c4)],
-//! );
+//! ).unwrap();
 //!
 //! let file = get_temp_file("out.csv", &[]);
 //!
@@ -287,10 +287,11 @@ mod tests {
         let c3 = PrimitiveArray::<UInt32Type>::from(vec![3, 2, 1]);
         let c4 = PrimitiveArray::<BooleanType>::from(vec![Some(true), Some(false), None]);
 
-        let batch = RecordBatch::new(
+        let batch = RecordBatch::try_new(
             Arc::new(schema),
             vec![Arc::new(c1), Arc::new(c2), Arc::new(c3), Arc::new(c4)],
-        );
+        )
+        .unwrap();
 
         let file = get_temp_file("columns.csv", &[]);
 
@@ -331,10 +332,11 @@ mod tests {
         let c3 = PrimitiveArray::<UInt32Type>::from(vec![3, 2, 1]);
         let c4 = PrimitiveArray::<BooleanType>::from(vec![Some(true), Some(false), None]);
 
-        let batch = RecordBatch::new(
+        let batch = RecordBatch::try_new(
             Arc::new(schema),
             vec![Arc::new(c1), Arc::new(c2), Arc::new(c3), Arc::new(c4)],
-        );
+        )
+        .unwrap();
 
         let file = get_temp_file("custom_options.csv", &[]);
 
diff --git a/rust/arrow/src/error.rs b/rust/arrow/src/error.rs
index 96ed944..2f758d4 100644
--- a/rust/arrow/src/error.rs
+++ b/rust/arrow/src/error.rs
@@ -30,6 +30,7 @@ pub enum ArrowError {
     CsvError(String),
     JsonError(String),
     IoError(String),
+    InvalidArgumentError(String),
 }
 
 impl From<::std::io::Error> for ArrowError {
diff --git a/rust/arrow/src/json/reader.rs b/rust/arrow/src/json/reader.rs
index 1495492..8bdbf89 100644
--- a/rust/arrow/src/json/reader.rs
+++ b/rust/arrow/src/json/reader.rs
@@ -487,7 +487,10 @@ impl<R: Read> Reader<R> {
         let projected_schema = Arc::new(Schema::new(projected_fields));
 
         match arrays {
-            Ok(arr) => Ok(Some(RecordBatch::new(projected_schema, arr))),
+            Ok(arr) => match RecordBatch::try_new(projected_schema, arr) {
+                Ok(batch) => Ok(Some(batch)),
+                Err(e) => Err(e),
+            },
             Err(e) => Err(e),
         }
     }
diff --git a/rust/arrow/src/record_batch.rs b/rust/arrow/src/record_batch.rs
index e3da628..62f93b8 100644
--- a/rust/arrow/src/record_batch.rs
+++ b/rust/arrow/src/record_batch.rs
@@ -25,6 +25,7 @@ use std::sync::Arc;
 
 use crate::array::*;
 use crate::datatypes::*;
+use crate::error::{ArrowError, Result};
 
 /// A batch of column-oriented data
 #[derive(Clone)]
@@ -34,36 +35,61 @@ pub struct RecordBatch {
 }
 
 impl RecordBatch {
-    pub fn new(schema: Arc<Schema>, columns: Vec<ArrayRef>) -> Self {
-        // assert that there are some columns
-        assert!(
-            columns.len() > 0,
-            "at least one column must be defined to create a record batch"
-        );
-        // assert that all columns have the same row count
+    /// Creates a `RecordBatch` from a schema and columns
+    ///
+    /// Expects the following:
+    ///  * the vec of columns to not be empty
+    ///  * the schema and column data types to have equal lengths and match
+    ///  * each array in columns to have the same length
+    pub fn try_new(schema: Arc<Schema>, columns: Vec<ArrayRef>) -> Result<Self> {
+        // check that there are some columns
+        if columns.is_empty() {
+            return Err(ArrowError::InvalidArgumentError(
+                "at least one column must be defined to create a record batch"
+                    .to_string(),
+            ));
+        }
+        // check that number of fields in schema match column length
+        if schema.fields().len() != columns.len() {
+            return Err(ArrowError::InvalidArgumentError(
+                "number of columns must match number of fields in schema".to_string(),
+            ));
+        }
+        // check that all columns have the same row count, and match the schema
         let len = columns[0].data().len();
-        for i in 1..columns.len() {
-            assert_eq!(
-                len,
-                columns[i].len(),
-                "all columns in a record batch must have the same length"
-            );
+        for i in 0..columns.len() {
+            if columns[i].len() != len {
+                return Err(ArrowError::InvalidArgumentError(
+                    "all columns in a record batch must have the same length".to_string(),
+                ));
+            }
+            if columns[i].data_type() != schema.field(i).data_type() {
+                return Err(ArrowError::InvalidArgumentError(format!(
+                    "column types must match schema types, expected {:?} but found {:?} at column index {}", 
+                    schema.field(i).data_type(),
+                    columns[i].data_type(),
+                    i)));
+            }
         }
-        RecordBatch { schema, columns }
+        Ok(RecordBatch { schema, columns })
     }
 
+    /// Returns the schema of the record batch
     pub fn schema(&self) -> &Arc<Schema> {
         &self.schema
     }
 
+    /// Number of columns in the record batch
     pub fn num_columns(&self) -> usize {
         self.columns.len()
     }
 
+    /// Number of rows in each column
     pub fn num_rows(&self) -> usize {
         self.columns[0].data().len()
     }
 
+    /// Get a reference to a column's array by index
     pub fn column(&self, i: usize) -> &ArrayRef {
         &self.columns[i]
     }
@@ -103,7 +129,8 @@ mod tests {
         let b = BinaryArray::from(array_data);
 
         let record_batch =
-            RecordBatch::new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]);
+            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])
+                .unwrap();
 
         assert_eq!(5, record_batch.num_rows());
         assert_eq!(2, record_batch.num_columns());
@@ -112,4 +139,26 @@ mod tests {
         assert_eq!(5, record_batch.column(0).data().len());
         assert_eq!(5, record_batch.column(1).data().len());
     }
+
+    #[test]
+    fn create_record_batch_schema_mismatch() {
+        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
+
+        let a = Int64Array::from(vec![1, 2, 3, 4, 5]);
+
+        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]);
+        assert!(!batch.is_ok());
+    }
+
+    #[test]
+    fn create_record_batch_record_mismatch() {
+        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
+
+        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
+        let b = Int32Array::from(vec![1, 2, 3, 4, 5]);
+
+        let batch =
+            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]);
+        assert!(!batch.is_ok());
+    }
 }
diff --git a/rust/datafusion/src/datasource/memory.rs b/rust/datafusion/src/datasource/memory.rs
index 3367393..5168ae9 100644
--- a/rust/datafusion/src/datasource/memory.rs
+++ b/rust/datafusion/src/datasource/memory.rs
@@ -102,20 +102,25 @@ impl Table for MemTable {
 
         let projected_schema = Arc::new(Schema::new(projected_columns?));
 
-        Ok(Rc::new(RefCell::new(MemBatchIterator {
-            schema: projected_schema.clone(),
-            index: 0,
-            batches: self
-                .batches
-                .iter()
-                .map(|batch| {
-                    RecordBatch::new(
-                        projected_schema.clone(),
-                        columns.iter().map(|i| batch.column(*i).clone()).collect(),
-                    )
-                })
-                .collect(),
-        })))
+        let batches = self
+            .batches
+            .iter()
+            .map(|batch| {
+                RecordBatch::try_new(
+                    projected_schema.clone(),
+                    columns.iter().map(|i| batch.column(*i).clone()).collect(),
+                )
+            })
+            .collect();
+
+        match batches {
+            Ok(batches) => Ok(Rc::new(RefCell::new(MemBatchIterator {
+                schema: projected_schema.clone(),
+                index: 0,
+                batches,
+            }))),
+            Err(e) => Err(ExecutionError::ArrowError(e)),
+        }
     }
 }
 
@@ -155,14 +160,15 @@ mod tests {
             Field::new("c", DataType::Int32, false),
         ]));
 
-        let batch = RecordBatch::new(
+        let batch = RecordBatch::try_new(
             schema.clone(),
             vec![
                 Arc::new(Int32Array::from(vec![1, 2, 3])),
                 Arc::new(Int32Array::from(vec![4, 5, 6])),
                 Arc::new(Int32Array::from(vec![7, 8, 9])),
             ],
-        );
+        )
+        .unwrap();
 
         let provider = MemTable::new(schema, vec![batch]).unwrap();
 
@@ -183,14 +189,15 @@ mod tests {
             Field::new("c", DataType::Int32, false),
         ]));
 
-        let batch = RecordBatch::new(
+        let batch = RecordBatch::try_new(
             schema.clone(),
             vec![
                 Arc::new(Int32Array::from(vec![1, 2, 3])),
                 Arc::new(Int32Array::from(vec![4, 5, 6])),
                 Arc::new(Int32Array::from(vec![7, 8, 9])),
             ],
-        );
+        )
+        .unwrap();
 
         let provider = MemTable::new(schema, vec![batch]).unwrap();
 
@@ -208,14 +215,15 @@ mod tests {
             Field::new("c", DataType::Int32, false),
         ]));
 
-        let batch = RecordBatch::new(
+        let batch = RecordBatch::try_new(
             schema.clone(),
             vec![
                 Arc::new(Int32Array::from(vec![1, 2, 3])),
                 Arc::new(Int32Array::from(vec![4, 5, 6])),
                 Arc::new(Int32Array::from(vec![7, 8, 9])),
             ],
-        );
+        )
+        .unwrap();
 
         let provider = MemTable::new(schema, vec![batch]).unwrap();
 
@@ -243,14 +251,15 @@ mod tests {
             Field::new("c", DataType::Int32, false),
         ]));
 
-        let batch = RecordBatch::new(
+        let batch = RecordBatch::try_new(
             schema1.clone(),
             vec![
                 Arc::new(Int32Array::from(vec![1, 2, 3])),
                 Arc::new(Int32Array::from(vec![4, 5, 6])),
                 Arc::new(Int32Array::from(vec![7, 8, 9])),
             ],
-        );
+        )
+        .unwrap();
 
         match MemTable::new(schema2, vec![batch]) {
             Err(ExecutionError::General(e)) => assert_eq!(
diff --git a/rust/datafusion/src/execution/aggregate.rs b/rust/datafusion/src/execution/aggregate.rs
index 87a9f11..f9eb3e2 100644
--- a/rust/datafusion/src/execution/aggregate.rs
+++ b/rust/datafusion/src/execution/aggregate.rs
@@ -800,7 +800,10 @@ impl AggregateRelation {
             }
         }
 
-        Ok(Some(RecordBatch::new(self.schema.clone(), result_columns)))
+        Ok(Some(RecordBatch::try_new(
+            self.schema.clone(),
+            result_columns,
+        )?))
     }
 
     fn with_group_by(&mut self) -> Result<Option<RecordBatch>> {
@@ -1008,7 +1011,10 @@ impl AggregateRelation {
             result_arrays.push(array?);
         }
 
-        Ok(Some(RecordBatch::new(self.schema.clone(), result_arrays)))
+        Ok(Some(RecordBatch::try_new(
+            self.schema.clone(),
+            result_arrays,
+        )?))
     }
 }
 
@@ -1136,7 +1142,7 @@ mod tests {
         .unwrap();
 
         let aggr_schema = Arc::new(Schema::new(vec![
-            Field::new("c2", DataType::Int32, false),
+            Field::new("c2", DataType::UInt32, false),
             Field::new("min", DataType::Float64, false),
             Field::new("max", DataType::Float64, false),
             Field::new("sum", DataType::Float64, false),
diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs
index b26b310..8d4a984 100644
--- a/rust/datafusion/src/execution/context.rs
+++ b/rust/datafusion/src/execution/context.rs
@@ -187,8 +187,15 @@ impl ExecutionContext {
                     .collect();
                 let compiled_aggr_expr = compiled_aggr_expr_result?;
 
+                let mut output_fields: Vec<Field> = vec![];
+                for expr in group_expr {
+                    output_fields.push(expr_to_field(expr, input_schema.as_ref()));
+                }
+                for expr in aggr_expr {
+                    output_fields.push(expr_to_field(expr, input_schema.as_ref()));
+                }
                 let rel = AggregateRelation::new(
-                    Arc::new(Schema::empty()), //(expr_to_field(&compiled_group_expr, &input_schema))),
+                    Arc::new(Schema::new(output_fields)),
                     input_rel,
                     compiled_group_expr,
                     compiled_aggr_expr,
diff --git a/rust/datafusion/src/execution/filter.rs b/rust/datafusion/src/execution/filter.rs
index c4d1cec..3467926 100644
--- a/rust/datafusion/src/execution/filter.rs
+++ b/rust/datafusion/src/execution/filter.rs
@@ -71,7 +71,7 @@ impl Relation for FilterRelation {
                             .collect();
 
                         let filtered_batch: RecordBatch =
-                            RecordBatch::new(self.schema.clone(), filtered_columns?);
+                            RecordBatch::try_new(self.schema.clone(), filtered_columns?)?;
 
                         Ok(Some(filtered_batch))
                     }
diff --git a/rust/datafusion/src/execution/limit.rs b/rust/datafusion/src/execution/limit.rs
index bfd8706..c58e4fd 100644
--- a/rust/datafusion/src/execution/limit.rs
+++ b/rust/datafusion/src/execution/limit.rs
@@ -66,7 +66,7 @@ impl Relation for LimitRelation {
                         .collect();
 
                     let limited_batch: RecordBatch =
-                        RecordBatch::new(self.schema.clone(), limited_columns?);
+                        RecordBatch::try_new(self.schema.clone(), limited_columns?)?;
                     self.num_consumed_rows += capacity;
 
                     Ok(Some(limited_batch))
diff --git a/rust/datafusion/src/execution/projection.rs b/rust/datafusion/src/execution/projection.rs
index bcf3ebb..a02213a 100644
--- a/rust/datafusion/src/execution/projection.rs
+++ b/rust/datafusion/src/execution/projection.rs
@@ -64,7 +64,7 @@ impl Relation for ProjectRelation {
                 );
 
                 let projected_batch: RecordBatch =
-                    RecordBatch::new(Arc::new(schema), projected_columns?);
+                    RecordBatch::try_new(Arc::new(schema), projected_columns?)?;
 
                 Ok(Some(projected_batch))
             }