You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by su...@apache.org on 2019/03/06 00:45:35 UTC
[arrow] branch master updated: ARROW-4749: [Rust] Return Result for RecordBatch::new()
This is an automated email from the ASF dual-hosted git repository.
sunchao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new faeb309 ARROW-4749: [Rust] Return Result for RecordBatch::new()
faeb309 is described below
commit faeb3092c40a753147180065cf1e043426735ec7
Author: Neville Dipale <ne...@gmail.com>
AuthorDate: Tue Mar 5 16:45:24 2019 -0800
ARROW-4749: [Rust] Return Result for RecordBatch::new()
Adds more validation between schemas and columns,
returning an error when the column types do not match the schema.
Author: Neville Dipale <ne...@gmail.com>
Author: Andy Grove <an...@gmail.com>
Closes #3800 from nevi-me/ARROW-4749 and squashes the following commits:
586395db <Neville Dipale> RecordBatch::try -> RecordBatch::try_new
fbdde0fc <Neville Dipale> fix csv writer tests
d92aeb63 <Andy Grove> fix aggr schema
5dd6bd01 <Neville Dipale> rebase against master, update record batch in in-memory source
0c402f5f <Neville Dipale> ARROW-4749: Return Result for RecordBatch::new()
---
rust/arrow/benches/csv_writer.rs | 2 +-
rust/arrow/examples/dynamic_types.rs | 10 ++--
rust/arrow/src/csv/reader.rs | 5 +-
rust/arrow/src/csv/writer.rs | 14 ++---
rust/arrow/src/error.rs | 1 +
rust/arrow/src/json/reader.rs | 5 +-
rust/arrow/src/record_batch.rs | 79 +++++++++++++++++++++++------
rust/datafusion/src/datasource/memory.rs | 53 +++++++++++--------
rust/datafusion/src/execution/aggregate.rs | 12 +++--
rust/datafusion/src/execution/context.rs | 9 +++-
rust/datafusion/src/execution/filter.rs | 2 +-
rust/datafusion/src/execution/limit.rs | 2 +-
rust/datafusion/src/execution/projection.rs | 2 +-
13 files changed, 139 insertions(+), 57 deletions(-)
diff --git a/rust/arrow/benches/csv_writer.rs b/rust/arrow/benches/csv_writer.rs
index ec3bc5a..49b1eed 100644
--- a/rust/arrow/benches/csv_writer.rs
+++ b/rust/arrow/benches/csv_writer.rs
@@ -48,7 +48,7 @@ fn record_batches_to_csv() {
let c3 = PrimitiveArray::<UInt32Type>::from(vec![3, 2, 1]);
let c4 = PrimitiveArray::<BooleanType>::from(vec![Some(true), Some(false), None]);
- let b = RecordBatch::new(
+ let b = RecordBatch::try_new(
Arc::new(schema),
vec![Arc::new(c1), Arc::new(c2), Arc::new(c3), Arc::new(c4)],
);
diff --git a/rust/arrow/examples/dynamic_types.rs b/rust/arrow/examples/dynamic_types.rs
index 8e6bb5d..2f361f4 100644
--- a/rust/arrow/examples/dynamic_types.rs
+++ b/rust/arrow/examples/dynamic_types.rs
@@ -22,9 +22,10 @@ extern crate arrow;
use arrow::array::*;
use arrow::datatypes::*;
+use arrow::error::Result;
use arrow::record_batch::*;
-fn main() {
+fn main() -> Result<()> {
// define schema
let schema = Schema::new(vec![
Field::new("id", DataType::Int32, false),
@@ -58,9 +59,10 @@ fn main() {
]);
// build a record batch
- let batch = RecordBatch::new(Arc::new(schema), vec![Arc::new(id), Arc::new(nested)]);
+ let batch =
+ RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id), Arc::new(nested)])?;
- process(&batch);
+ Ok(process(&batch))
}
/// Create a new batch by performing a projection of id, nested.c
@@ -88,7 +90,7 @@ fn process(batch: &RecordBatch) {
Field::new("sum", DataType::Float64, false),
]);
- let _ = RecordBatch::new(
+ let _ = RecordBatch::try_new(
Arc::new(projected_schema),
vec![
id.clone(), // NOTE: this is cloning the Arc not the array data
diff --git a/rust/arrow/src/csv/reader.rs b/rust/arrow/src/csv/reader.rs
index a511b93..85b2ccd 100644
--- a/rust/arrow/src/csv/reader.rs
+++ b/rust/arrow/src/csv/reader.rs
@@ -329,7 +329,10 @@ impl<R: Read> Reader<R> {
let projected_schema = Arc::new(Schema::new(projected_fields));
match arrays {
- Ok(arr) => Ok(Some(RecordBatch::new(projected_schema, arr))),
+ Ok(arr) => match RecordBatch::try_new(projected_schema, arr) {
+ Ok(batch) => Ok(Some(batch)),
+ Err(e) => Err(e),
+ },
Err(e) => Err(e),
}
}
diff --git a/rust/arrow/src/csv/writer.rs b/rust/arrow/src/csv/writer.rs
index bf1e582..945fb71 100644
--- a/rust/arrow/src/csv/writer.rs
+++ b/rust/arrow/src/csv/writer.rs
@@ -50,10 +50,10 @@
//! let c3 = PrimitiveArray::<UInt32Type>::from(vec![3, 2, 1]);
//! let c4 = PrimitiveArray::<BooleanType>::from(vec![Some(true), Some(false), None]);
//!
-//! let batch = RecordBatch::new(
+//! let batch = RecordBatch::try_new(
//! Arc::new(schema),
//! vec![Arc::new(c1), Arc::new(c2), Arc::new(c3), Arc::new(c4)],
-//! );
+//! ).unwrap();
//!
//! let file = get_temp_file("out.csv", &[]);
//!
@@ -287,10 +287,11 @@ mod tests {
let c3 = PrimitiveArray::<UInt32Type>::from(vec![3, 2, 1]);
let c4 = PrimitiveArray::<BooleanType>::from(vec![Some(true), Some(false), None]);
- let batch = RecordBatch::new(
+ let batch = RecordBatch::try_new(
Arc::new(schema),
vec![Arc::new(c1), Arc::new(c2), Arc::new(c3), Arc::new(c4)],
- );
+ )
+ .unwrap();
let file = get_temp_file("columns.csv", &[]);
@@ -331,10 +332,11 @@ mod tests {
let c3 = PrimitiveArray::<UInt32Type>::from(vec![3, 2, 1]);
let c4 = PrimitiveArray::<BooleanType>::from(vec![Some(true), Some(false), None]);
- let batch = RecordBatch::new(
+ let batch = RecordBatch::try_new(
Arc::new(schema),
vec![Arc::new(c1), Arc::new(c2), Arc::new(c3), Arc::new(c4)],
- );
+ )
+ .unwrap();
let file = get_temp_file("custom_options.csv", &[]);
diff --git a/rust/arrow/src/error.rs b/rust/arrow/src/error.rs
index 96ed944..2f758d4 100644
--- a/rust/arrow/src/error.rs
+++ b/rust/arrow/src/error.rs
@@ -30,6 +30,7 @@ pub enum ArrowError {
CsvError(String),
JsonError(String),
IoError(String),
+ InvalidArgumentError(String),
}
impl From<::std::io::Error> for ArrowError {
diff --git a/rust/arrow/src/json/reader.rs b/rust/arrow/src/json/reader.rs
index 1495492..8bdbf89 100644
--- a/rust/arrow/src/json/reader.rs
+++ b/rust/arrow/src/json/reader.rs
@@ -487,7 +487,10 @@ impl<R: Read> Reader<R> {
let projected_schema = Arc::new(Schema::new(projected_fields));
match arrays {
- Ok(arr) => Ok(Some(RecordBatch::new(projected_schema, arr))),
+ Ok(arr) => match RecordBatch::try_new(projected_schema, arr) {
+ Ok(batch) => Ok(Some(batch)),
+ Err(e) => Err(e),
+ },
Err(e) => Err(e),
}
}
diff --git a/rust/arrow/src/record_batch.rs b/rust/arrow/src/record_batch.rs
index e3da628..62f93b8 100644
--- a/rust/arrow/src/record_batch.rs
+++ b/rust/arrow/src/record_batch.rs
@@ -25,6 +25,7 @@ use std::sync::Arc;
use crate::array::*;
use crate::datatypes::*;
+use crate::error::{ArrowError, Result};
/// A batch of column-oriented data
#[derive(Clone)]
@@ -34,36 +35,61 @@ pub struct RecordBatch {
}
impl RecordBatch {
- pub fn new(schema: Arc<Schema>, columns: Vec<ArrayRef>) -> Self {
- // assert that there are some columns
- assert!(
- columns.len() > 0,
- "at least one column must be defined to create a record batch"
- );
- // assert that all columns have the same row count
+ /// Creates a `RecordBatch` from a schema and columns
+ ///
+ /// Expects the following:
+ /// * the vec of columns to not be empty
+ /// * the schema and column data types to have equal lengths and match
+ /// * each array in columns to have the same length
+ pub fn try_new(schema: Arc<Schema>, columns: Vec<ArrayRef>) -> Result<Self> {
+ // check that there are some columns
+ if columns.is_empty() {
+ return Err(ArrowError::InvalidArgumentError(
+ "at least one column must be defined to create a record batch"
+ .to_string(),
+ ));
+ }
+ // check that number of fields in schema match column length
+ if schema.fields().len() != columns.len() {
+ return Err(ArrowError::InvalidArgumentError(
+ "number of columns must match number of fields in schema".to_string(),
+ ));
+ }
+ // check that all columns have the same row count, and match the schema
let len = columns[0].data().len();
- for i in 1..columns.len() {
- assert_eq!(
- len,
- columns[i].len(),
- "all columns in a record batch must have the same length"
- );
+ for i in 0..columns.len() {
+ if columns[i].len() != len {
+ return Err(ArrowError::InvalidArgumentError(
+ "all columns in a record batch must have the same length".to_string(),
+ ));
+ }
+ if columns[i].data_type() != schema.field(i).data_type() {
+ return Err(ArrowError::InvalidArgumentError(format!(
+ "column types must match schema types, expected {:?} but found {:?} at column index {}",
+ schema.field(i).data_type(),
+ columns[i].data_type(),
+ i)));
+ }
}
- RecordBatch { schema, columns }
+ Ok(RecordBatch { schema, columns })
}
+ /// Returns the schema of the record batch
pub fn schema(&self) -> &Arc<Schema> {
&self.schema
}
+ /// Number of columns in the record batch
pub fn num_columns(&self) -> usize {
self.columns.len()
}
+ /// Number of rows in each column
pub fn num_rows(&self) -> usize {
self.columns[0].data().len()
}
+ /// Get a reference to a column's array by index
pub fn column(&self, i: usize) -> &ArrayRef {
&self.columns[i]
}
@@ -103,7 +129,8 @@ mod tests {
let b = BinaryArray::from(array_data);
let record_batch =
- RecordBatch::new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]);
+ RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])
+ .unwrap();
assert_eq!(5, record_batch.num_rows());
assert_eq!(2, record_batch.num_columns());
@@ -112,4 +139,26 @@ mod tests {
assert_eq!(5, record_batch.column(0).data().len());
assert_eq!(5, record_batch.column(1).data().len());
}
+
+ #[test]
+ fn create_record_batch_schema_mismatch() {
+ let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
+
+ let a = Int64Array::from(vec![1, 2, 3, 4, 5]);
+
+ let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]);
+ assert!(!batch.is_ok());
+ }
+
+ #[test]
+ fn create_record_batch_record_mismatch() {
+ let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
+
+ let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
+ let b = Int32Array::from(vec![1, 2, 3, 4, 5]);
+
+ let batch =
+ RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]);
+ assert!(!batch.is_ok());
+ }
}
diff --git a/rust/datafusion/src/datasource/memory.rs b/rust/datafusion/src/datasource/memory.rs
index 3367393..5168ae9 100644
--- a/rust/datafusion/src/datasource/memory.rs
+++ b/rust/datafusion/src/datasource/memory.rs
@@ -102,20 +102,25 @@ impl Table for MemTable {
let projected_schema = Arc::new(Schema::new(projected_columns?));
- Ok(Rc::new(RefCell::new(MemBatchIterator {
- schema: projected_schema.clone(),
- index: 0,
- batches: self
- .batches
- .iter()
- .map(|batch| {
- RecordBatch::new(
- projected_schema.clone(),
- columns.iter().map(|i| batch.column(*i).clone()).collect(),
- )
- })
- .collect(),
- })))
+ let batches = self
+ .batches
+ .iter()
+ .map(|batch| {
+ RecordBatch::try_new(
+ projected_schema.clone(),
+ columns.iter().map(|i| batch.column(*i).clone()).collect(),
+ )
+ })
+ .collect();
+
+ match batches {
+ Ok(batches) => Ok(Rc::new(RefCell::new(MemBatchIterator {
+ schema: projected_schema.clone(),
+ index: 0,
+ batches,
+ }))),
+ Err(e) => Err(ExecutionError::ArrowError(e)),
+ }
}
}
@@ -155,14 +160,15 @@ mod tests {
Field::new("c", DataType::Int32, false),
]));
- let batch = RecordBatch::new(
+ let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from(vec![1, 2, 3])),
Arc::new(Int32Array::from(vec![4, 5, 6])),
Arc::new(Int32Array::from(vec![7, 8, 9])),
],
- );
+ )
+ .unwrap();
let provider = MemTable::new(schema, vec![batch]).unwrap();
@@ -183,14 +189,15 @@ mod tests {
Field::new("c", DataType::Int32, false),
]));
- let batch = RecordBatch::new(
+ let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from(vec![1, 2, 3])),
Arc::new(Int32Array::from(vec![4, 5, 6])),
Arc::new(Int32Array::from(vec![7, 8, 9])),
],
- );
+ )
+ .unwrap();
let provider = MemTable::new(schema, vec![batch]).unwrap();
@@ -208,14 +215,15 @@ mod tests {
Field::new("c", DataType::Int32, false),
]));
- let batch = RecordBatch::new(
+ let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from(vec![1, 2, 3])),
Arc::new(Int32Array::from(vec![4, 5, 6])),
Arc::new(Int32Array::from(vec![7, 8, 9])),
],
- );
+ )
+ .unwrap();
let provider = MemTable::new(schema, vec![batch]).unwrap();
@@ -243,14 +251,15 @@ mod tests {
Field::new("c", DataType::Int32, false),
]));
- let batch = RecordBatch::new(
+ let batch = RecordBatch::try_new(
schema1.clone(),
vec![
Arc::new(Int32Array::from(vec![1, 2, 3])),
Arc::new(Int32Array::from(vec![4, 5, 6])),
Arc::new(Int32Array::from(vec![7, 8, 9])),
],
- );
+ )
+ .unwrap();
match MemTable::new(schema2, vec![batch]) {
Err(ExecutionError::General(e)) => assert_eq!(
diff --git a/rust/datafusion/src/execution/aggregate.rs b/rust/datafusion/src/execution/aggregate.rs
index 87a9f11..f9eb3e2 100644
--- a/rust/datafusion/src/execution/aggregate.rs
+++ b/rust/datafusion/src/execution/aggregate.rs
@@ -800,7 +800,10 @@ impl AggregateRelation {
}
}
- Ok(Some(RecordBatch::new(self.schema.clone(), result_columns)))
+ Ok(Some(RecordBatch::try_new(
+ self.schema.clone(),
+ result_columns,
+ )?))
}
fn with_group_by(&mut self) -> Result<Option<RecordBatch>> {
@@ -1008,7 +1011,10 @@ impl AggregateRelation {
result_arrays.push(array?);
}
- Ok(Some(RecordBatch::new(self.schema.clone(), result_arrays)))
+ Ok(Some(RecordBatch::try_new(
+ self.schema.clone(),
+ result_arrays,
+ )?))
}
}
@@ -1136,7 +1142,7 @@ mod tests {
.unwrap();
let aggr_schema = Arc::new(Schema::new(vec![
- Field::new("c2", DataType::Int32, false),
+ Field::new("c2", DataType::UInt32, false),
Field::new("min", DataType::Float64, false),
Field::new("max", DataType::Float64, false),
Field::new("sum", DataType::Float64, false),
diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs
index b26b310..8d4a984 100644
--- a/rust/datafusion/src/execution/context.rs
+++ b/rust/datafusion/src/execution/context.rs
@@ -187,8 +187,15 @@ impl ExecutionContext {
.collect();
let compiled_aggr_expr = compiled_aggr_expr_result?;
+ let mut output_fields: Vec<Field> = vec![];
+ for expr in group_expr {
+ output_fields.push(expr_to_field(expr, input_schema.as_ref()));
+ }
+ for expr in aggr_expr {
+ output_fields.push(expr_to_field(expr, input_schema.as_ref()));
+ }
let rel = AggregateRelation::new(
- Arc::new(Schema::empty()), //(expr_to_field(&compiled_group_expr, &input_schema))),
+ Arc::new(Schema::new(output_fields)),
input_rel,
compiled_group_expr,
compiled_aggr_expr,
diff --git a/rust/datafusion/src/execution/filter.rs b/rust/datafusion/src/execution/filter.rs
index c4d1cec..3467926 100644
--- a/rust/datafusion/src/execution/filter.rs
+++ b/rust/datafusion/src/execution/filter.rs
@@ -71,7 +71,7 @@ impl Relation for FilterRelation {
.collect();
let filtered_batch: RecordBatch =
- RecordBatch::new(self.schema.clone(), filtered_columns?);
+ RecordBatch::try_new(self.schema.clone(), filtered_columns?)?;
Ok(Some(filtered_batch))
}
diff --git a/rust/datafusion/src/execution/limit.rs b/rust/datafusion/src/execution/limit.rs
index bfd8706..c58e4fd 100644
--- a/rust/datafusion/src/execution/limit.rs
+++ b/rust/datafusion/src/execution/limit.rs
@@ -66,7 +66,7 @@ impl Relation for LimitRelation {
.collect();
let limited_batch: RecordBatch =
- RecordBatch::new(self.schema.clone(), limited_columns?);
+ RecordBatch::try_new(self.schema.clone(), limited_columns?)?;
self.num_consumed_rows += capacity;
Ok(Some(limited_batch))
diff --git a/rust/datafusion/src/execution/projection.rs b/rust/datafusion/src/execution/projection.rs
index bcf3ebb..a02213a 100644
--- a/rust/datafusion/src/execution/projection.rs
+++ b/rust/datafusion/src/execution/projection.rs
@@ -64,7 +64,7 @@ impl Relation for ProjectRelation {
);
let projected_batch: RecordBatch =
- RecordBatch::new(Arc::new(schema), projected_columns?);
+ RecordBatch::try_new(Arc::new(schema), projected_columns?)?;
Ok(Some(projected_batch))
}