Posted to commits@arrow.apache.org by al...@apache.org on 2023/06/27 20:43:51 UTC

[arrow-datafusion] branch main updated: Fix inserting into a table with non-nullable columns (#6722)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new c165f48df6 Fix inserting into a table with non-nullable columns (#6722)
c165f48df6 is described below

commit c165f48df69f42614bb25c19e35559283ec55eb9
Author: Jonah Gao <jo...@gmail.com>
AuthorDate: Wed Jun 28 04:43:46 2023 +0800

    Fix inserting into a table with non-nullable columns (#6722)
    
    * Fix inserting into a table with non-nullable columns
    
    * Implement equivalent_names_and_types method for Schema
    
    * Simplify StreamExt's intra-document link
    
    * Improve check_batch
---
 datafusion/common/src/dfschema.rs                  | 95 ++++++++++++++++-----
 datafusion/common/src/lib.rs                       |  2 +-
 datafusion/core/src/datasource/file_format/csv.rs  |  3 +-
 datafusion/core/src/datasource/listing/table.rs    |  6 +-
 datafusion/core/src/datasource/memory.rs           |  5 +-
 datafusion/core/src/physical_plan/insert.rs        | 99 +++++++++++++++++++---
 .../core/tests/sqllogictests/test_files/insert.slt | 30 +++++++
 7 files changed, 203 insertions(+), 37 deletions(-)
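
For context, the user-visible behavior can be sketched end to end (this sketch is not part of the commit; it assumes the `datafusion` and `tokio` crates at versions contemporary with this change, and uses illustrative table and column names). It mirrors the new insert.slt cases at the bottom of this diff:

    use datafusion::error::Result;
    use datafusion::prelude::SessionContext;

    #[tokio::main]
    async fn main() -> Result<()> {
        let ctx = SessionContext::new();
        ctx.sql("CREATE TABLE t (f1 BIGINT NOT NULL, f2 BIGINT NULL)")
            .await?
            .collect()
            .await?;

        // Accepted: NULL only appears in the nullable column.
        ctx.sql("INSERT INTO t VALUES (1, NULL)").await?.collect().await?;

        // Rejected at execution time: NULL bound for the NOT NULL column.
        let res = ctx.sql("INSERT INTO t VALUES (NULL, 2)").await?.collect().await;
        assert!(res.is_err());
        Ok(())
    }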

diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs
index 292c19886b..c490852c6e 100644
--- a/datafusion/common/src/dfschema.rs
+++ b/datafusion/common/src/dfschema.rs
@@ -729,6 +729,34 @@ impl From<Field> for DFField {
     }
 }
 
+/// DataFusion-specific extensions to [`Schema`].
+pub trait SchemaExt {
+    /// This is a specialized version of Eq that ignores differences
+    /// in nullability and metadata.
+    ///
+    /// It works the same as [`DFSchema::equivalent_names_and_types`].
+    fn equivalent_names_and_types(&self, other: &Self) -> bool;
+}
+
+impl SchemaExt for Schema {
+    fn equivalent_names_and_types(&self, other: &Self) -> bool {
+        if self.fields().len() != other.fields().len() {
+            return false;
+        }
+
+        self.fields()
+            .iter()
+            .zip(other.fields().iter())
+            .all(|(f1, f2)| {
+                f1.name() == f2.name()
+                    && DFSchema::datatype_is_semantically_equal(
+                        f1.data_type(),
+                        f2.data_type(),
+                    )
+            })
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use crate::assert_contains;
@@ -995,7 +1023,8 @@ mod tests {
         TestCase {
             fields1: vec![&field1_i16_t],
             fields2: vec![&field1_i16_t],
-            expected: true,
+            expected_dfschema: true,
+            expected_arrow: true,
         }
         .run();
 
@@ -1003,7 +1032,8 @@ mod tests {
         TestCase {
             fields1: vec![&field1_i16_t_meta],
             fields2: vec![&field1_i16_t],
-            expected: true,
+            expected_dfschema: true,
+            expected_arrow: true,
         }
         .run();
 
@@ -1011,7 +1041,8 @@ mod tests {
         TestCase {
             fields1: vec![&field1_i16_t],
             fields2: vec![&field2_i16_t],
-            expected: false,
+            expected_dfschema: false,
+            expected_arrow: false,
         }
         .run();
 
@@ -1019,7 +1050,8 @@ mod tests {
         TestCase {
             fields1: vec![&field1_i16_t],
             fields2: vec![&field1_i32_t],
-            expected: false,
+            expected_dfschema: false,
+            expected_arrow: false,
         }
         .run();
 
@@ -1027,7 +1059,8 @@ mod tests {
         TestCase {
             fields1: vec![&field1_i16_t],
             fields2: vec![&field1_i16_f],
-            expected: true,
+            expected_dfschema: true,
+            expected_arrow: true,
         }
         .run();
 
@@ -1035,7 +1068,8 @@ mod tests {
         TestCase {
             fields1: vec![&field1_i16_t],
             fields2: vec![&field1_i16_t_qualified],
-            expected: false,
+            expected_dfschema: false,
+            expected_arrow: true,
         }
         .run();
 
@@ -1043,7 +1077,8 @@ mod tests {
         TestCase {
             fields1: vec![&field2_i16_t, &field1_i16_t],
             fields2: vec![&field2_i16_t, &field3_i16_t],
-            expected: false,
+            expected_dfschema: false,
+            expected_arrow: false,
         }
         .run();
 
@@ -1051,7 +1086,8 @@ mod tests {
         TestCase {
             fields1: vec![&field1_i16_t, &field2_i16_t],
             fields2: vec![&field1_i16_t],
-            expected: false,
+            expected_dfschema: false,
+            expected_arrow: false,
         }
         .run();
 
@@ -1059,7 +1095,8 @@ mod tests {
         TestCase {
             fields1: vec![&field_dict_t],
             fields2: vec![&field_dict_t],
-            expected: true,
+            expected_dfschema: true,
+            expected_arrow: true,
         }
         .run();
 
@@ -1067,7 +1104,8 @@ mod tests {
         TestCase {
             fields1: vec![&field_dict_t],
             fields2: vec![&field_dict_f],
-            expected: true,
+            expected_dfschema: true,
+            expected_arrow: true,
         }
         .run();
 
@@ -1075,7 +1113,8 @@ mod tests {
         TestCase {
             fields1: vec![&field_dict_t],
             fields2: vec![&field1_i16_t],
-            expected: false,
+            expected_dfschema: false,
+            expected_arrow: false,
         }
         .run();
 
@@ -1083,7 +1122,8 @@ mod tests {
         TestCase {
             fields1: vec![&list_t],
             fields2: vec![&list_f],
-            expected: true,
+            expected_dfschema: true,
+            expected_arrow: true,
         }
         .run();
 
@@ -1091,7 +1131,8 @@ mod tests {
         TestCase {
             fields1: vec![&list_t],
             fields2: vec![&list_f_name],
-            expected: false,
+            expected_dfschema: false,
+            expected_arrow: false,
         }
         .run();
 
@@ -1099,7 +1140,8 @@ mod tests {
         TestCase {
             fields1: vec![&struct_t],
             fields2: vec![&struct_f],
-            expected: true,
+            expected_dfschema: true,
+            expected_arrow: true,
         }
         .run();
 
@@ -1107,7 +1149,8 @@ mod tests {
         TestCase {
             fields1: vec![&struct_t],
             fields2: vec![&struct_f_meta],
-            expected: true,
+            expected_dfschema: true,
+            expected_arrow: true,
         }
         .run();
 
@@ -1115,7 +1158,8 @@ mod tests {
         TestCase {
             fields1: vec![&struct_t],
             fields2: vec![&struct_f_type],
-            expected: false,
+            expected_dfschema: false,
+            expected_arrow: false,
         }
         .run();
 
@@ -1123,7 +1167,8 @@ mod tests {
         struct TestCase<'a> {
             fields1: Vec<&'a DFField>,
             fields2: Vec<&'a DFField>,
-            expected: bool,
+            expected_dfschema: bool,
+            expected_arrow: bool,
         }
 
         impl<'a> TestCase<'a> {
@@ -1133,13 +1178,25 @@ mod tests {
                 let schema2 = to_df_schema(self.fields2);
                 assert_eq!(
                     schema1.equivalent_names_and_types(&schema2),
-                    self.expected,
+                    self.expected_dfschema,
                     "Comparison did not match expected: {}\n\n\
                      schema1:\n\n{:#?}\n\nschema2:\n\n{:#?}",
-                    self.expected,
+                    self.expected_dfschema,
                     schema1,
                     schema2
                 );
+
+                let arrow_schema1 = Schema::from(schema1);
+                let arrow_schema2 = Schema::from(schema2);
+                assert_eq!(
+                    arrow_schema1.equivalent_names_and_types(&arrow_schema2),
+                    self.expected_arrow,
+                    "Comparison did not match expected: {}\n\n\
+                     arrow schema1:\n\n{:#?}\n\n arrow schema2:\n\n{:#?}",
+                    self.expected_arrow,
+                    arrow_schema1,
+                    arrow_schema2
+                );
             }
         }
 
diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs
index e941e443df..63b4024579 100644
--- a/datafusion/common/src/lib.rs
+++ b/datafusion/common/src/lib.rs
@@ -36,7 +36,7 @@ pub mod tree_node;
 pub mod utils;
 
 pub use column::Column;
-pub use dfschema::{DFField, DFSchema, DFSchemaRef, ExprSchema, ToDFSchema};
+pub use dfschema::{DFField, DFSchema, DFSchemaRef, ExprSchema, SchemaExt, ToDFSchema};
 pub use error::{
     field_not_found, unqualified_field_not_found, DataFusionError, Result, SchemaError,
     SharedResult,
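
The re-export above makes the new trait available as `datafusion_common::SchemaExt`. As a usage sketch (not part of the commit, with illustrative field names): two Arrow schemas that differ only in nullability are not `==`, but they are equivalent under the new comparison:

    use arrow_schema::{DataType, Field, Schema};
    use datafusion_common::SchemaExt;

    fn main() {
        let nullable = Schema::new(vec![Field::new("c1", DataType::Int64, true)]);
        let non_nullable = Schema::new(vec![Field::new("c1", DataType::Int64, false)]);

        // Plain equality is sensitive to nullability (and metadata)...
        assert_ne!(nullable, non_nullable);
        // ...while the new comparison checks only names and data types.
        assert!(nullable.equivalent_names_and_types(&non_nullable));
    }
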
diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs
index dd6d3cd7f7..c34c82a2c4 100644
--- a/datafusion/core/src/datasource/file_format/csv.rs
+++ b/datafusion/core/src/datasource/file_format/csv.rs
@@ -238,6 +238,7 @@ impl FileFormat for CsvFormat {
         _state: &SessionState,
         conf: FileSinkConfig,
     ) -> Result<Arc<dyn ExecutionPlan>> {
+        let sink_schema = conf.output_schema().clone();
         let sink = Arc::new(CsvSink::new(
             conf,
             self.has_header,
@@ -245,7 +246,7 @@ impl FileFormat for CsvFormat {
             self.file_compression_type.clone(),
         ));
 
-        Ok(Arc::new(InsertExec::new(input, sink)) as _)
+        Ok(Arc::new(InsertExec::new(input, sink, sink_schema)) as _)
     }
 }
 
diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs
index 9beb88144d..c27dcaf391 100644
--- a/datafusion/core/src/datasource/listing/table.rs
+++ b/datafusion/core/src/datasource/listing/table.rs
@@ -25,7 +25,7 @@ use arrow::datatypes::{DataType, Field, SchemaBuilder, SchemaRef};
 use arrow_schema::Schema;
 use async_trait::async_trait;
 use dashmap::DashMap;
-use datafusion_common::ToDFSchema;
+use datafusion_common::{SchemaExt, ToDFSchema};
 use datafusion_expr::expr::Sort;
 use datafusion_optimizer::utils::conjunction;
 use datafusion_physical_expr::{create_physical_expr, LexOrdering, PhysicalSortExpr};
@@ -776,7 +776,7 @@ impl TableProvider for ListingTable {
         input: Arc<dyn ExecutionPlan>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
         // Check that the schema of the plan matches the schema of this table.
-        if !input.schema().eq(&self.schema()) {
+        if !self.schema().equivalent_names_and_types(&input.schema()) {
             return Err(DataFusionError::Plan(
                 // Return an error if schema of the input query does not match with the table schema.
                 "Inserting query must have the same schema with the table.".to_string(),
@@ -816,7 +816,7 @@ impl TableProvider for ListingTable {
         let config = FileSinkConfig {
             object_store_url: self.table_paths()[0].object_store(),
             file_groups,
-            output_schema: input.schema(),
+            output_schema: self.schema(),
             table_partition_cols: self.options.table_partition_cols.clone(),
             writer_mode: crate::datasource::file_format::FileWriterMode::Append,
         };
diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs
index b7cb013eba..784aa2aff2 100644
--- a/datafusion/core/src/datasource/memory.rs
+++ b/datafusion/core/src/datasource/memory.rs
@@ -26,6 +26,7 @@ use std::sync::Arc;
 use arrow::datatypes::SchemaRef;
 use arrow::record_batch::RecordBatch;
 use async_trait::async_trait;
+use datafusion_common::SchemaExt;
 use datafusion_execution::TaskContext;
 use tokio::sync::RwLock;
 
@@ -189,13 +190,13 @@ impl TableProvider for MemTable {
     ) -> Result<Arc<dyn ExecutionPlan>> {
         // Create a physical plan from the logical plan.
         // Check that the schema of the plan matches the schema of this table.
-        if !input.schema().eq(&self.schema) {
+        if !self.schema().equivalent_names_and_types(&input.schema()) {
             return Err(DataFusionError::Plan(
                 "Inserting query must have the same schema with the table.".to_string(),
             ));
         }
         let sink = Arc::new(MemSink::new(self.batches.clone()));
-        Ok(Arc::new(InsertExec::new(input, sink)))
+        Ok(Arc::new(InsertExec::new(input, sink, self.schema.clone())))
     }
 }
 
diff --git a/datafusion/core/src/physical_plan/insert.rs b/datafusion/core/src/physical_plan/insert.rs
index 4742e1617e..15f77914d4 100644
--- a/datafusion/core/src/physical_plan/insert.rs
+++ b/datafusion/core/src/physical_plan/insert.rs
@@ -68,25 +68,72 @@ pub trait DataSink: DisplayAs + Debug + Send + Sync {
 pub struct InsertExec {
     /// Input plan that produces the record batches to be written.
     input: Arc<dyn ExecutionPlan>,
-    /// Sink to whic to write
+    /// Sink to which to write
     sink: Arc<dyn DataSink>,
-    /// Schema describing the structure of the data.
-    schema: SchemaRef,
+    /// Schema of the sink for validating the input data
+    sink_schema: SchemaRef,
+    /// Schema describing the structure of the output data.
+    count_schema: SchemaRef,
 }
 
 impl fmt::Debug for InsertExec {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
-        write!(f, "InsertExec schema: {:?}", self.schema)
+        write!(f, "InsertExec schema: {:?}", self.count_schema)
     }
 }
 
 impl InsertExec {
     /// Create a plan to write to `sink`
-    pub fn new(input: Arc<dyn ExecutionPlan>, sink: Arc<dyn DataSink>) -> Self {
+    pub fn new(
+        input: Arc<dyn ExecutionPlan>,
+        sink: Arc<dyn DataSink>,
+        sink_schema: SchemaRef,
+    ) -> Self {
         Self {
             input,
             sink,
-            schema: make_count_schema(),
+            sink_schema,
+            count_schema: make_count_schema(),
+        }
+    }
+
+    fn make_input_stream(
+        &self,
+        partition: usize,
+        context: Arc<TaskContext>,
+    ) -> Result<SendableRecordBatchStream> {
+        let input_stream = self.input.execute(partition, context)?;
+
+        debug_assert_eq!(
+            self.sink_schema.fields().len(),
+            self.input.schema().fields().len()
+        );
+
+        // Find input columns that may violate the not null constraint.
+        let risky_columns: Vec<_> = self
+            .sink_schema
+            .fields()
+            .iter()
+            .zip(self.input.schema().fields().iter())
+            .enumerate()
+            .filter_map(|(i, (sink_field, input_field))| {
+                if !sink_field.is_nullable() && input_field.is_nullable() {
+                    Some(i)
+                } else {
+                    None
+                }
+            })
+            .collect();
+
+        if risky_columns.is_empty() {
+            Ok(input_stream)
+        } else {
+            // Check not null constraint on the input stream
+            Ok(Box::pin(RecordBatchStreamAdapter::new(
+                self.sink_schema.clone(),
+                input_stream
+                    .map(move |batch| check_not_null_contraits(batch?, &risky_columns)),
+            )))
         }
     }
 }
@@ -99,7 +146,7 @@ impl ExecutionPlan for InsertExec {
 
     /// Get the schema for this execution plan
     fn schema(&self) -> SchemaRef {
-        self.schema.clone()
+        self.count_schema.clone()
     }
 
     fn output_partitioning(&self) -> Partitioning {
@@ -142,7 +189,8 @@ impl ExecutionPlan for InsertExec {
         Ok(Arc::new(Self {
             input: children[0].clone(),
             sink: self.sink.clone(),
-            schema: self.schema.clone(),
+            sink_schema: self.sink_schema.clone(),
+            count_schema: self.count_schema.clone(),
         }))
     }
 
@@ -168,8 +216,9 @@ impl ExecutionPlan for InsertExec {
             )));
         }
 
-        let data = self.input.execute(0, context.clone())?;
-        let schema = self.schema.clone();
+        let data = self.make_input_stream(0, context.clone())?;
+
+        let count_schema = self.count_schema.clone();
         let sink = self.sink.clone();
 
         let stream = futures::stream::once(async move {
@@ -177,7 +226,10 @@ impl ExecutionPlan for InsertExec {
         })
         .boxed();
 
-        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
+        Ok(Box::pin(RecordBatchStreamAdapter::new(
+            count_schema,
+            stream,
+        )))
     }
 
     fn fmt_as(
@@ -221,3 +273,28 @@ fn make_count_schema() -> SchemaRef {
         false,
     )]))
 }
+
+fn check_not_null_contraits(
+    batch: RecordBatch,
+    column_indices: &Vec<usize>,
+) -> Result<RecordBatch> {
+    for i in column_indices {
+        let index = *i;
+        if batch.num_columns() <= index {
+            return Err(DataFusionError::Execution(format!(
+                "Invalid batch column count {} expected > {}",
+                batch.num_columns(),
+                index
+            )));
+        }
+
+        if batch.column(index).null_count() > 0 {
+            return Err(DataFusionError::Execution(format!(
+                "Invalid batch column at '{}' has null but schema specifies non-nullable",
+                index
+            )));
+        }
+    }
+
+    Ok(batch)
+}
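
To make the two steps above concrete, here is a standalone sketch (not part of the commit) against only the `arrow` crate: `risky_columns` mirrors the zip-and-filter in `make_input_stream`, and `check_batch` mirrors the per-batch null check in `check_not_null_contraits`, with a plain `String` standing in for `DataFusionError`:

    use std::sync::Arc;

    use arrow::array::{ArrayRef, Int64Array};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;

    /// Indices that are non-nullable in the sink but nullable in the input.
    fn risky_columns(sink: &Schema, input: &Schema) -> Vec<usize> {
        sink.fields()
            .iter()
            .zip(input.fields().iter())
            .enumerate()
            .filter_map(|(i, (sink_field, input_field))| {
                if !sink_field.is_nullable() && input_field.is_nullable() {
                    Some(i)
                } else {
                    None
                }
            })
            .collect()
    }

    /// Reject a batch if any risky column actually contains a null.
    fn check_batch(batch: &RecordBatch, columns: &[usize]) -> Result<(), String> {
        for &i in columns {
            if batch.column(i).null_count() > 0 {
                return Err(format!(
                    "column {i} has nulls but the sink schema is non-nullable"
                ));
            }
        }
        Ok(())
    }

    fn main() {
        let sink = Schema::new(vec![Field::new("f1", DataType::Int64, false)]);
        let input = Schema::new(vec![Field::new("f1", DataType::Int64, true)]);
        let risky = risky_columns(&sink, &input);
        assert_eq!(risky, vec![0]);

        // A batch with a null in the risky column must be rejected.
        let col: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), None]));
        let batch = RecordBatch::try_new(Arc::new(input), vec![col]).unwrap();
        assert!(check_batch(&batch, &risky).is_err());
    }
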
diff --git a/datafusion/core/tests/sqllogictests/test_files/insert.slt b/datafusion/core/tests/sqllogictests/test_files/insert.slt
index 2a04eaed32..c710859a7b 100644
--- a/datafusion/core/tests/sqllogictests/test_files/insert.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/insert.slt
@@ -270,3 +270,33 @@ select * from table_without_values;
 
 statement ok
 drop table table_without_values;
+
+
+# test insert with non-nullable column
+statement ok
+CREATE TABLE table_without_values(field1 BIGINT NOT NULL, field2 BIGINT NULL);
+
+query II
+insert into table_without_values values(1, 100);
+----
+1
+
+query II
+insert into table_without_values values(2, NULL);
+----
+1
+
+statement error Execution error: Invalid batch column at '0' has null but schema specifies non-nullable
+insert into table_without_values values(NULL, 300);
+
+statement error Execution error: Invalid batch column at '0' has null but schema specifies non-nullable
+insert into table_without_values values(3, 300), (NULL, 400);
+
+query II rowsort
+select * from table_without_values;
+----
+1 100
+2 NULL
+
+statement ok
+drop table table_without_values;
\ No newline at end of file