You are viewing a plain text version of this content. The canonical link to the original was not preserved in this plain text copy.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/06/27 20:43:51 UTC
[arrow-datafusion] branch main updated: Fix inserting into a table with non-nullable columns (#6722)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new c165f48df6 Fix inserting into a table with non-nullable columns (#6722)
c165f48df6 is described below
commit c165f48df69f42614bb25c19e35559283ec55eb9
Author: Jonah Gao <jo...@gmail.com>
AuthorDate: Wed Jun 28 04:43:46 2023 +0800
Fix inserting into a table with non-nullable columns (#6722)
* Fix inserting into a table with non-nullable columns
* Implement equivalent_names_and_types method for Schema
* Simplify StreamExt's intra-document link
* Improve check_batch
---
datafusion/common/src/dfschema.rs | 95 ++++++++++++++++-----
datafusion/common/src/lib.rs | 2 +-
datafusion/core/src/datasource/file_format/csv.rs | 3 +-
datafusion/core/src/datasource/listing/table.rs | 6 +-
datafusion/core/src/datasource/memory.rs | 5 +-
datafusion/core/src/physical_plan/insert.rs | 99 +++++++++++++++++++---
.../core/tests/sqllogictests/test_files/insert.slt | 30 +++++++
7 files changed, 203 insertions(+), 37 deletions(-)
diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs
index 292c19886b..c490852c6e 100644
--- a/datafusion/common/src/dfschema.rs
+++ b/datafusion/common/src/dfschema.rs
@@ -729,6 +729,34 @@ impl From<Field> for DFField {
}
}
+/// DataFusion-specific extensions to [`Schema`].
+pub trait SchemaExt {
+ /// This is a specialized version of Eq that ignores differences
+ /// in nullability and metadata.
+ ///
+ /// It works the same as [`DFSchema::equivalent_names_and_types`].
+ fn equivalent_names_and_types(&self, other: &Self) -> bool;
+}
+
+impl SchemaExt for Schema {
+ fn equivalent_names_and_types(&self, other: &Self) -> bool {
+ if self.fields().len() != other.fields().len() {
+ return false;
+ }
+
+ self.fields()
+ .iter()
+ .zip(other.fields().iter())
+ .all(|(f1, f2)| {
+ f1.name() == f2.name()
+ && DFSchema::datatype_is_semantically_equal(
+ f1.data_type(),
+ f2.data_type(),
+ )
+ })
+ }
+}
+
#[cfg(test)]
mod tests {
use crate::assert_contains;
@@ -995,7 +1023,8 @@ mod tests {
TestCase {
fields1: vec![&field1_i16_t],
fields2: vec![&field1_i16_t],
- expected: true,
+ expected_dfschema: true,
+ expected_arrow: true,
}
.run();
@@ -1003,7 +1032,8 @@ mod tests {
TestCase {
fields1: vec![&field1_i16_t_meta],
fields2: vec![&field1_i16_t],
- expected: true,
+ expected_dfschema: true,
+ expected_arrow: true,
}
.run();
@@ -1011,7 +1041,8 @@ mod tests {
TestCase {
fields1: vec![&field1_i16_t],
fields2: vec![&field2_i16_t],
- expected: false,
+ expected_dfschema: false,
+ expected_arrow: false,
}
.run();
@@ -1019,7 +1050,8 @@ mod tests {
TestCase {
fields1: vec![&field1_i16_t],
fields2: vec![&field1_i32_t],
- expected: false,
+ expected_dfschema: false,
+ expected_arrow: false,
}
.run();
@@ -1027,7 +1059,8 @@ mod tests {
TestCase {
fields1: vec![&field1_i16_t],
fields2: vec![&field1_i16_f],
- expected: true,
+ expected_dfschema: true,
+ expected_arrow: true,
}
.run();
@@ -1035,7 +1068,8 @@ mod tests {
TestCase {
fields1: vec![&field1_i16_t],
fields2: vec![&field1_i16_t_qualified],
- expected: false,
+ expected_dfschema: false,
+ expected_arrow: true,
}
.run();
@@ -1043,7 +1077,8 @@ mod tests {
TestCase {
fields1: vec![&field2_i16_t, &field1_i16_t],
fields2: vec![&field2_i16_t, &field3_i16_t],
- expected: false,
+ expected_dfschema: false,
+ expected_arrow: false,
}
.run();
@@ -1051,7 +1086,8 @@ mod tests {
TestCase {
fields1: vec![&field1_i16_t, &field2_i16_t],
fields2: vec![&field1_i16_t],
- expected: false,
+ expected_dfschema: false,
+ expected_arrow: false,
}
.run();
@@ -1059,7 +1095,8 @@ mod tests {
TestCase {
fields1: vec![&field_dict_t],
fields2: vec![&field_dict_t],
- expected: true,
+ expected_dfschema: true,
+ expected_arrow: true,
}
.run();
@@ -1067,7 +1104,8 @@ mod tests {
TestCase {
fields1: vec![&field_dict_t],
fields2: vec![&field_dict_f],
- expected: true,
+ expected_dfschema: true,
+ expected_arrow: true,
}
.run();
@@ -1075,7 +1113,8 @@ mod tests {
TestCase {
fields1: vec![&field_dict_t],
fields2: vec![&field1_i16_t],
- expected: false,
+ expected_dfschema: false,
+ expected_arrow: false,
}
.run();
@@ -1083,7 +1122,8 @@ mod tests {
TestCase {
fields1: vec![&list_t],
fields2: vec![&list_f],
- expected: true,
+ expected_dfschema: true,
+ expected_arrow: true,
}
.run();
@@ -1091,7 +1131,8 @@ mod tests {
TestCase {
fields1: vec![&list_t],
fields2: vec![&list_f_name],
- expected: false,
+ expected_dfschema: false,
+ expected_arrow: false,
}
.run();
@@ -1099,7 +1140,8 @@ mod tests {
TestCase {
fields1: vec![&struct_t],
fields2: vec![&struct_f],
- expected: true,
+ expected_dfschema: true,
+ expected_arrow: true,
}
.run();
@@ -1107,7 +1149,8 @@ mod tests {
TestCase {
fields1: vec![&struct_t],
fields2: vec![&struct_f_meta],
- expected: true,
+ expected_dfschema: true,
+ expected_arrow: true,
}
.run();
@@ -1115,7 +1158,8 @@ mod tests {
TestCase {
fields1: vec![&struct_t],
fields2: vec![&struct_f_type],
- expected: false,
+ expected_dfschema: false,
+ expected_arrow: false,
}
.run();
@@ -1123,7 +1167,8 @@ mod tests {
struct TestCase<'a> {
fields1: Vec<&'a DFField>,
fields2: Vec<&'a DFField>,
- expected: bool,
+ expected_dfschema: bool,
+ expected_arrow: bool,
}
impl<'a> TestCase<'a> {
@@ -1133,13 +1178,25 @@ mod tests {
let schema2 = to_df_schema(self.fields2);
assert_eq!(
schema1.equivalent_names_and_types(&schema2),
- self.expected,
+ self.expected_dfschema,
"Comparison did not match expected: {}\n\n\
schema1:\n\n{:#?}\n\nschema2:\n\n{:#?}",
- self.expected,
+ self.expected_dfschema,
schema1,
schema2
);
+
+ let arrow_schema1 = Schema::from(schema1);
+ let arrow_schema2 = Schema::from(schema2);
+ assert_eq!(
+ arrow_schema1.equivalent_names_and_types(&arrow_schema2),
+ self.expected_arrow,
+ "Comparison did not match expected: {}\n\n\
+ arrow schema1:\n\n{:#?}\n\n arrow schema2:\n\n{:#?}",
+ self.expected_arrow,
+ arrow_schema1,
+ arrow_schema2
+ );
}
}
diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs
index e941e443df..63b4024579 100644
--- a/datafusion/common/src/lib.rs
+++ b/datafusion/common/src/lib.rs
@@ -36,7 +36,7 @@ pub mod tree_node;
pub mod utils;
pub use column::Column;
-pub use dfschema::{DFField, DFSchema, DFSchemaRef, ExprSchema, ToDFSchema};
+pub use dfschema::{DFField, DFSchema, DFSchemaRef, ExprSchema, SchemaExt, ToDFSchema};
pub use error::{
field_not_found, unqualified_field_not_found, DataFusionError, Result, SchemaError,
SharedResult,
diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs
index dd6d3cd7f7..c34c82a2c4 100644
--- a/datafusion/core/src/datasource/file_format/csv.rs
+++ b/datafusion/core/src/datasource/file_format/csv.rs
@@ -238,6 +238,7 @@ impl FileFormat for CsvFormat {
_state: &SessionState,
conf: FileSinkConfig,
) -> Result<Arc<dyn ExecutionPlan>> {
+ let sink_schema = conf.output_schema().clone();
let sink = Arc::new(CsvSink::new(
conf,
self.has_header,
@@ -245,7 +246,7 @@ impl FileFormat for CsvFormat {
self.file_compression_type.clone(),
));
- Ok(Arc::new(InsertExec::new(input, sink)) as _)
+ Ok(Arc::new(InsertExec::new(input, sink, sink_schema)) as _)
}
}
diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs
index 9beb88144d..c27dcaf391 100644
--- a/datafusion/core/src/datasource/listing/table.rs
+++ b/datafusion/core/src/datasource/listing/table.rs
@@ -25,7 +25,7 @@ use arrow::datatypes::{DataType, Field, SchemaBuilder, SchemaRef};
use arrow_schema::Schema;
use async_trait::async_trait;
use dashmap::DashMap;
-use datafusion_common::ToDFSchema;
+use datafusion_common::{SchemaExt, ToDFSchema};
use datafusion_expr::expr::Sort;
use datafusion_optimizer::utils::conjunction;
use datafusion_physical_expr::{create_physical_expr, LexOrdering, PhysicalSortExpr};
@@ -776,7 +776,7 @@ impl TableProvider for ListingTable {
input: Arc<dyn ExecutionPlan>,
) -> Result<Arc<dyn ExecutionPlan>> {
// Check that the schema of the plan matches the schema of this table.
- if !input.schema().eq(&self.schema()) {
+ if !self.schema().equivalent_names_and_types(&input.schema()) {
return Err(DataFusionError::Plan(
// Return an error if schema of the input query does not match with the table schema.
"Inserting query must have the same schema with the table.".to_string(),
@@ -816,7 +816,7 @@ impl TableProvider for ListingTable {
let config = FileSinkConfig {
object_store_url: self.table_paths()[0].object_store(),
file_groups,
- output_schema: input.schema(),
+ output_schema: self.schema(),
table_partition_cols: self.options.table_partition_cols.clone(),
writer_mode: crate::datasource::file_format::FileWriterMode::Append,
};
diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs
index b7cb013eba..784aa2aff2 100644
--- a/datafusion/core/src/datasource/memory.rs
+++ b/datafusion/core/src/datasource/memory.rs
@@ -26,6 +26,7 @@ use std::sync::Arc;
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
+use datafusion_common::SchemaExt;
use datafusion_execution::TaskContext;
use tokio::sync::RwLock;
@@ -189,13 +190,13 @@ impl TableProvider for MemTable {
) -> Result<Arc<dyn ExecutionPlan>> {
// Create a physical plan from the logical plan.
// Check that the schema of the plan matches the schema of this table.
- if !input.schema().eq(&self.schema) {
+ if !self.schema().equivalent_names_and_types(&input.schema()) {
return Err(DataFusionError::Plan(
"Inserting query must have the same schema with the table.".to_string(),
));
}
let sink = Arc::new(MemSink::new(self.batches.clone()));
- Ok(Arc::new(InsertExec::new(input, sink)))
+ Ok(Arc::new(InsertExec::new(input, sink, self.schema.clone())))
}
}
diff --git a/datafusion/core/src/physical_plan/insert.rs b/datafusion/core/src/physical_plan/insert.rs
index 4742e1617e..15f77914d4 100644
--- a/datafusion/core/src/physical_plan/insert.rs
+++ b/datafusion/core/src/physical_plan/insert.rs
@@ -68,25 +68,72 @@ pub trait DataSink: DisplayAs + Debug + Send + Sync {
pub struct InsertExec {
/// Input plan that produces the record batches to be written.
input: Arc<dyn ExecutionPlan>,
- /// Sink to whic to write
+ /// Sink to which to write
sink: Arc<dyn DataSink>,
- /// Schema describing the structure of the data.
- schema: SchemaRef,
+ /// Schema of the sink for validating the input data
+ sink_schema: SchemaRef,
+ /// Schema describing the structure of the output data.
+ count_schema: SchemaRef,
}
impl fmt::Debug for InsertExec {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
- write!(f, "InsertExec schema: {:?}", self.schema)
+ write!(f, "InsertExec schema: {:?}", self.count_schema)
}
}
impl InsertExec {
/// Create a plan to write to `sink`
- pub fn new(input: Arc<dyn ExecutionPlan>, sink: Arc<dyn DataSink>) -> Self {
+ pub fn new(
+ input: Arc<dyn ExecutionPlan>,
+ sink: Arc<dyn DataSink>,
+ sink_schema: SchemaRef,
+ ) -> Self {
Self {
input,
sink,
- schema: make_count_schema(),
+ sink_schema,
+ count_schema: make_count_schema(),
+ }
+ }
+
+ fn make_input_stream(
+ &self,
+ partition: usize,
+ context: Arc<TaskContext>,
+ ) -> Result<SendableRecordBatchStream> {
+ let input_stream = self.input.execute(partition, context)?;
+
+ debug_assert_eq!(
+ self.sink_schema.fields().len(),
+ self.input.schema().fields().len()
+ );
+
+ // Find input columns that may violate the not null constraint.
+ let risky_columns: Vec<_> = self
+ .sink_schema
+ .fields()
+ .iter()
+ .zip(self.input.schema().fields().iter())
+ .enumerate()
+ .filter_map(|(i, (sink_field, input_field))| {
+ if !sink_field.is_nullable() && input_field.is_nullable() {
+ Some(i)
+ } else {
+ None
+ }
+ })
+ .collect();
+
+ if risky_columns.is_empty() {
+ Ok(input_stream)
+ } else {
+ // Check not null constraint on the input stream
+ Ok(Box::pin(RecordBatchStreamAdapter::new(
+ self.sink_schema.clone(),
+ input_stream
+ .map(move |batch| check_not_null_contraits(batch?, &risky_columns)),
+ )))
}
}
}
@@ -99,7 +146,7 @@ impl ExecutionPlan for InsertExec {
/// Get the schema for this execution plan
fn schema(&self) -> SchemaRef {
- self.schema.clone()
+ self.count_schema.clone()
}
fn output_partitioning(&self) -> Partitioning {
@@ -142,7 +189,8 @@ impl ExecutionPlan for InsertExec {
Ok(Arc::new(Self {
input: children[0].clone(),
sink: self.sink.clone(),
- schema: self.schema.clone(),
+ sink_schema: self.sink_schema.clone(),
+ count_schema: self.count_schema.clone(),
}))
}
@@ -168,8 +216,9 @@ impl ExecutionPlan for InsertExec {
)));
}
- let data = self.input.execute(0, context.clone())?;
- let schema = self.schema.clone();
+ let data = self.make_input_stream(0, context.clone())?;
+
+ let count_schema = self.count_schema.clone();
let sink = self.sink.clone();
let stream = futures::stream::once(async move {
@@ -177,7 +226,10 @@ impl ExecutionPlan for InsertExec {
})
.boxed();
- Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
+ Ok(Box::pin(RecordBatchStreamAdapter::new(
+ count_schema,
+ stream,
+ )))
}
fn fmt_as(
@@ -221,3 +273,28 @@ fn make_count_schema() -> SchemaRef {
false,
)]))
}
+
+fn check_not_null_contraits(
+ batch: RecordBatch,
+ column_indices: &Vec<usize>,
+) -> Result<RecordBatch> {
+ for i in column_indices {
+ let index = *i;
+ if batch.num_columns() <= index {
+ return Err(DataFusionError::Execution(format!(
+ "Invalid batch column count {} expected > {}",
+ batch.num_columns(),
+ index
+ )));
+ }
+
+ if batch.column(index).null_count() > 0 {
+ return Err(DataFusionError::Execution(format!(
+ "Invalid batch column at '{}' has null but schema specifies non-nullable",
+ index
+ )));
+ }
+ }
+
+ Ok(batch)
+}
diff --git a/datafusion/core/tests/sqllogictests/test_files/insert.slt b/datafusion/core/tests/sqllogictests/test_files/insert.slt
index 2a04eaed32..c710859a7b 100644
--- a/datafusion/core/tests/sqllogictests/test_files/insert.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/insert.slt
@@ -270,3 +270,33 @@ select * from table_without_values;
statement ok
drop table table_without_values;
+
+
+# test insert with non-nullable column
+statement ok
+CREATE TABLE table_without_values(field1 BIGINT NOT NULL, field2 BIGINT NULL);
+
+query II
+insert into table_without_values values(1, 100);
+----
+1
+
+query II
+insert into table_without_values values(2, NULL);
+----
+1
+
+statement error Execution error: Invalid batch column at '0' has null but schema specifies non-nullable
+insert into table_without_values values(NULL, 300);
+
+statement error Execution error: Invalid batch column at '0' has null but schema specifies non-nullable
+insert into table_without_values values(3, 300), (NULL, 400);
+
+query II rowsort
+select * from table_without_values;
+----
+1 100
+2 NULL
+
+statement ok
+drop table table_without_values;
\ No newline at end of file