Posted to commits@arrow.apache.org by ak...@apache.org on 2023/04/28 08:26:45 UTC

[arrow-datafusion] branch main updated: MemoryExec INSERT INTO refactor to use ExecutionPlan (#6049)

This is an automated email from the ASF dual-hosted git repository.

akurmustafa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new d5a8c934fa MemoryExec INSERT INTO refactor to use ExecutionPlan (#6049)
d5a8c934fa is described below

commit d5a8c934fa818e52cc021ad250cb1044e29ff9df
Author: Metehan Yıldırım <10...@users.noreply.github.com>
AuthorDate: Fri Apr 28 11:26:38 2023 +0300

    MemoryExec INSERT INTO refactor to use ExecutionPlan (#6049)
    
    * MemoryExec insert into refactor
    
    * Merge leftovers
    
    * Set target partition
    
    * Comment and formatting improvements
    
    * Comments on state.
    
    * Leftover comments
    
    * After merge corrections
    
    * Correction after merge
    
    ---------
    
    Co-authored-by: Mehmet Ozan Kabak <oz...@gmail.com>
---
 datafusion/core/src/datasource/datasource.rs |   4 +-
 datafusion/core/src/datasource/memory.rs     | 283 ++++++------
 datafusion/core/src/execution/context.rs     |  21 +-
 datafusion/core/src/physical_plan/memory.rs  | 646 ++++++++++++++++++++++++++-
 datafusion/core/src/physical_plan/planner.rs |  19 +-
 5 files changed, 800 insertions(+), 173 deletions(-)

diff --git a/datafusion/core/src/datasource/datasource.rs b/datafusion/core/src/datasource/datasource.rs
index 8db075a30a..4560b3820c 100644
--- a/datafusion/core/src/datasource/datasource.rs
+++ b/datafusion/core/src/datasource/datasource.rs
@@ -102,8 +102,8 @@ pub trait TableProvider: Sync + Send {
     async fn insert_into(
         &self,
         _state: &SessionState,
-        _input: &LogicalPlan,
-    ) -> Result<()> {
+        _input: Arc<dyn ExecutionPlan>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
         let msg = "Insertion not implemented for this table".to_owned();
         Err(DataFusionError::NotImplemented(msg))
     }
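
With this signature change, `insert_into` no longer executes its input eagerly: a provider receives the already-planned input as an `Arc<dyn ExecutionPlan>` and returns a sink plan that performs the write when it is executed. A minimal sketch of a custom provider adapting to the new trait; the `MyTable` and `MySink` names are hypothetical and not part of this commit:

    #[async_trait]
    impl TableProvider for MyTable {
        // schema(), table_type() and scan() omitted for brevity.
        async fn insert_into(
            &self,
            _state: &SessionState,
            input: Arc<dyn ExecutionPlan>,
        ) -> Result<Arc<dyn ExecutionPlan>> {
            // Validate that the incoming plan produces this table's schema,
            // then return a sink plan; nothing is executed here.
            if !input.schema().eq(&self.schema()) {
                return Err(DataFusionError::Plan(
                    "Inserting query must have the same schema with the table.".to_string(),
                ));
            }
            // MySink is a hypothetical ExecutionPlan that writes its input's
            // batches when executed (compare MemoryWriteExec below).
            Ok(Arc::new(MySink::new(input)))
        }
    }
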
diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs
index ca083aebe3..f41f8cb1bd 100644
--- a/datafusion/core/src/datasource/memory.rs
+++ b/datafusion/core/src/datasource/memory.rs
@@ -24,20 +24,22 @@ use std::sync::Arc;
 use arrow::datatypes::SchemaRef;
 use arrow::record_batch::RecordBatch;
 use async_trait::async_trait;
-use datafusion_expr::LogicalPlan;
 use tokio::sync::RwLock;
 
 use crate::datasource::{TableProvider, TableType};
 use crate::error::{DataFusionError, Result};
 use crate::execution::context::SessionState;
 use crate::logical_expr::Expr;
-use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec;
+use crate::physical_plan::common;
 use crate::physical_plan::common::AbortOnDropSingle;
 use crate::physical_plan::memory::MemoryExec;
+use crate::physical_plan::memory::MemoryWriteExec;
 use crate::physical_plan::ExecutionPlan;
-use crate::physical_plan::{collect_partitioned, common};
 use crate::physical_plan::{repartition::RepartitionExec, Partitioning};
 
+/// Type alias for partition data
+pub type PartitionData = Arc<RwLock<Vec<RecordBatch>>>;
+
 /// In-memory data source for presenting a `Vec<RecordBatch>` as a
 /// data source that can be queried by DataFusion. This allows data to
 /// be pre-loaded into memory and then repeatedly queried without
@@ -45,7 +47,7 @@ use crate::physical_plan::{repartition::RepartitionExec, Partitioning};
 #[derive(Debug)]
 pub struct MemTable {
     schema: SchemaRef,
-    batches: Arc<RwLock<Vec<Vec<RecordBatch>>>>,
+    pub(crate) batches: Vec<PartitionData>,
 }
 
 impl MemTable {
@@ -58,7 +60,10 @@ impl MemTable {
         {
             Ok(Self {
                 schema,
-                batches: Arc::new(RwLock::new(partitions)),
+                batches: partitions
+                    .into_iter()
+                    .map(|e| Arc::new(RwLock::new(e)))
+                    .collect::<Vec<_>>(),
             })
         } else {
             Err(DataFusionError::Plan(
@@ -147,71 +152,62 @@ impl TableProvider for MemTable {
         _filters: &[Expr],
         _limit: Option<usize>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        let batches = &self.batches.read().await;
-        Ok(Arc::new(MemoryExec::try_new(
-            batches,
+        let mut partitions = vec![];
+        for arc_inner_vec in self.batches.iter() {
+            let inner_vec = arc_inner_vec.read().await;
+            partitions.push(inner_vec.clone())
+        }
+        Ok(Arc::new(MemoryExec::try_new_owned_data(
+            partitions,
             self.schema(),
             projection.cloned(),
         )?))
     }
 
-    /// Inserts the execution results of a given [LogicalPlan] into this [MemTable].
-    /// The `LogicalPlan` must have the same schema as this `MemTable`.
+    /// Inserts the execution results of a given [`ExecutionPlan`] into this [`MemTable`].
+    /// The [`ExecutionPlan`] must have the same schema as this [`MemTable`].
     ///
     /// # Arguments
     ///
-    /// * `state` - The [SessionState] containing the context for executing the plan.
-    /// * `input` - The [LogicalPlan] to execute and insert.
+    /// * `state` - The [`SessionState`] containing the context for executing the plan.
+    /// * `input` - The [`ExecutionPlan`] to execute and insert.
     ///
     /// # Returns
     ///
     /// * A `Result` indicating success or failure.
-    async fn insert_into(&self, state: &SessionState, input: &LogicalPlan) -> Result<()> {
+    async fn insert_into(
+        &self,
+        _state: &SessionState,
+        input: Arc<dyn ExecutionPlan>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
         // Create a physical plan from the logical plan.
-        let plan = state.create_physical_plan(input).await?;
-
         // Check that the schema of the plan matches the schema of this table.
-        if !plan.schema().eq(&self.schema) {
+        if !input.schema().eq(&self.schema) {
             return Err(DataFusionError::Plan(
                 "Inserting query must have the same schema with the table.".to_string(),
             ));
         }
 
-        // Get the number of partitions in the plan and the table.
-        let plan_partition_count = plan.output_partitioning().partition_count();
-        let table_partition_count = self.batches.read().await.len();
+        if self.batches.is_empty() {
+            return Err(DataFusionError::Plan(
+                "The table must have partitions.".to_string(),
+            ));
+        }
 
-        // Adjust the plan as necessary to match the number of partitions in the table.
-        let plan: Arc<dyn ExecutionPlan> = if plan_partition_count
-            == table_partition_count
-            || table_partition_count == 0
-        {
-            plan
-        } else if table_partition_count == 1 {
-            // If the table has only one partition, coalesce the partitions in the plan.
-            Arc::new(CoalescePartitionsExec::new(plan))
-        } else {
-            // Otherwise, repartition the plan using a round-robin partitioning scheme.
+        let input = if self.batches.len() > 1 {
             Arc::new(RepartitionExec::try_new(
-                plan,
-                Partitioning::RoundRobinBatch(table_partition_count),
+                input,
+                Partitioning::RoundRobinBatch(self.batches.len()),
             )?)
-        };
-
-        let results = collect_partitioned(plan, state.task_ctx()).await?;
-
-        // Write the results into the table.
-        let mut all_batches = self.batches.write().await;
-
-        if all_batches.is_empty() {
-            *all_batches = results
         } else {
-            for (batches, result) in all_batches.iter_mut().zip(results.into_iter()) {
-                batches.extend(result);
-            }
-        }
+            input
+        };
 
-        Ok(())
+        Ok(Arc::new(MemoryWriteExec::try_new(
+            input,
+            self.batches.clone(),
+            self.schema.clone(),
+        )?))
     }
 }
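
In practice the new flow is: plan the INSERT's input, hand it to `insert_into`, and execute the returned sink plan. A condensed sketch of that sequence, assuming a `SessionContext` named `session_ctx`, a `schema`, and a `source` obtained via `provider_as_source` (exactly as the `experiment` helper in the tests below does):

    // Register an initially empty MemTable, build an insert plan over a scan
    // of the source, and run it; the write happens when the sink plan executes.
    let table = Arc::new(MemTable::try_new(schema.clone(), vec![vec![]])?);
    session_ctx.register_table("t", table.clone())?;
    let scan = LogicalPlanBuilder::scan("source", source, None)?.build()?;
    let insert = LogicalPlanBuilder::insert_into(scan, "t", &schema)?.build()?;
    let plan = session_ctx.state().create_physical_plan(&insert).await?;
    // The sink plan yields no record batches; the data lands in `table.batches`.
    assert!(collect(plan, session_ctx.task_ctx()).await?.is_empty());
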
 
@@ -220,6 +216,7 @@ mod tests {
     use super::*;
     use crate::datasource::provider_as_source;
     use crate::from_slice::FromSlice;
+    use crate::physical_plan::collect;
     use crate::prelude::SessionContext;
     use arrow::array::Int32Array;
     use arrow::datatypes::{DataType, Field, Schema};
@@ -455,21 +452,48 @@ mod tests {
         Ok(())
     }
 
-    fn create_mem_table_scan(
+    async fn experiment(
         schema: SchemaRef,
-        data: Vec<Vec<RecordBatch>>,
-    ) -> Result<Arc<LogicalPlan>> {
-        // Convert the table into a provider so that it can be used in a query
-        let provider = provider_as_source(Arc::new(MemTable::try_new(schema, data)?));
-        // Create a table scan logical plan to read from the table
-        Ok(Arc::new(
-            LogicalPlanBuilder::scan("source", provider, None)?.build()?,
-        ))
-    }
-
-    fn create_initial_ctx() -> Result<(SessionContext, SchemaRef, RecordBatch)> {
+        initial_data: Vec<Vec<RecordBatch>>,
+        inserted_data: Vec<Vec<RecordBatch>>,
+    ) -> Result<Vec<Vec<RecordBatch>>> {
         // Create a new session context
         let session_ctx = SessionContext::new();
+        // Create and register the initial table with the provided schema and data
+        let initial_table = Arc::new(MemTable::try_new(schema.clone(), initial_data)?);
+        session_ctx.register_table("t", initial_table.clone())?;
+        // Create and register the source table with the provided schema and inserted data
+        let source_table = Arc::new(MemTable::try_new(schema.clone(), inserted_data)?);
+        session_ctx.register_table("source", source_table.clone())?;
+        // Convert the source table into a provider so that it can be used in a query
+        let source = provider_as_source(source_table);
+        // Create a table scan logical plan to read from the source table
+        let scan_plan = LogicalPlanBuilder::scan("source", source, None)?.build()?;
+        // Create an insert plan to insert the source data into the initial table
+        let insert_into_table =
+            LogicalPlanBuilder::insert_into(scan_plan, "t", &schema)?.build()?;
+        // Create a physical plan from the insert plan
+        let plan = session_ctx
+            .state()
+            .create_physical_plan(&insert_into_table)
+            .await?;
+
+        // Execute the physical plan and collect the results
+        let res = collect(plan, session_ctx.task_ctx()).await?;
+        // Ensure the result is empty after the insert operation
+        assert!(res.is_empty());
+        // Read the data from the initial table and store it in a vector of partitions
+        let mut partitions = vec![];
+        for partition in initial_table.batches.iter() {
+            let part = partition.read().await.clone();
+            partitions.push(part);
+        }
+        Ok(partitions)
+    }
+
+    // Test inserting a single batch of data into a single partition
+    #[tokio::test]
+    async fn test_insert_into_single_partition() -> Result<()> {
         // Create a new schema with one field called "a" of type Int32
         let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
 
@@ -478,111 +502,84 @@ mod tests {
             schema.clone(),
             vec![Arc::new(Int32Array::from_slice([1, 2, 3]))],
         )?;
-        Ok((session_ctx, schema, batch))
-    }
-
-    #[tokio::test]
-    async fn test_insert_into_single_partition() -> Result<()> {
-        let (session_ctx, schema, batch) = create_initial_ctx()?;
-        let initial_table = Arc::new(MemTable::try_new(
-            schema.clone(),
-            vec![vec![batch.clone()]],
-        )?);
-        // Create a table scan logical plan to read from the table
-        let single_partition_table_scan =
-            create_mem_table_scan(schema.clone(), vec![vec![batch.clone()]])?;
-        // Insert the data from the provider into the table
-        initial_table
-            .insert_into(&session_ctx.state(), &single_partition_table_scan)
-            .await?;
+        // Run the experiment and obtain the resulting data in the table
+        let resulting_data_in_table =
+            experiment(schema, vec![vec![batch.clone()]], vec![vec![batch.clone()]])
+                .await?;
         // Ensure that the table now contains two batches of data in the same partition
-        assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 2);
-
-        // Create a new provider with 2 partitions
-        let multi_partition_table_scan = create_mem_table_scan(
-            schema.clone(),
-            vec![vec![batch.clone()], vec![batch]],
-        )?;
-
-        // Insert the data from the provider into the table. We expect coalescing partitions.
-        initial_table
-            .insert_into(&session_ctx.state(), &multi_partition_table_scan)
-            .await?;
-        // Ensure that the table now contains 4 batches of data with only 1 partition
-        assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 4);
-        assert_eq!(initial_table.batches.read().await.len(), 1);
+        assert_eq!(resulting_data_in_table[0].len(), 2);
         Ok(())
     }
 
+    // Test inserting multiple batches of data into a single partition
     #[tokio::test]
-    async fn test_insert_into_multiple_partition() -> Result<()> {
-        let (session_ctx, schema, batch) = create_initial_ctx()?;
-        // create a memory table with two partitions, each having one batch with the same data
-        let initial_table = Arc::new(MemTable::try_new(
-            schema.clone(),
-            vec![vec![batch.clone()], vec![batch.clone()]],
-        )?);
+    async fn test_insert_into_single_partition_with_multi_partition() -> Result<()> {
+        // Create a new schema with one field called "a" of type Int32
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
 
-        // scan a data source provider from a memory table with a single partition
-        let single_partition_table_scan = create_mem_table_scan(
+        // Create a new batch of data to insert into the table
+        let batch = RecordBatch::try_new(
             schema.clone(),
-            vec![vec![batch.clone(), batch.clone()]],
+            vec![Arc::new(Int32Array::from_slice([1, 2, 3]))],
         )?;
-
-        // insert the data from the 1 partition data source provider into the initial table
-        initial_table
-            .insert_into(&session_ctx.state(), &single_partition_table_scan)
-            .await?;
-
-        // We expect round robin repartition here, each partition gets 1 batch.
-        assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 2);
-        assert_eq!(initial_table.batches.read().await.get(1).unwrap().len(), 2);
-
-        // scan a data source provider from a memory table with 2 partition
-        let multi_partition_table_scan = create_mem_table_scan(
-            schema.clone(),
+        // Run the experiment and obtain the resulting data in the table
+        let resulting_data_in_table = experiment(
+            schema,
+            vec![vec![batch.clone()]],
             vec![vec![batch.clone()], vec![batch]],
-        )?;
-        // We expect one-to-one partition mapping.
-        initial_table
-            .insert_into(&session_ctx.state(), &multi_partition_table_scan)
-            .await?;
-        // Ensure that the table now contains 3 batches of data with 2 partitions.
-        assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 3);
-        assert_eq!(initial_table.batches.read().await.get(1).unwrap().len(), 3);
+        )
+        .await?;
+        // Ensure that the table now contains three batches of data in the same partition
+        assert_eq!(resulting_data_in_table[0].len(), 3);
         Ok(())
     }
 
+    // Test inserting multiple batches of data into multiple partitions
     #[tokio::test]
-    async fn test_insert_into_empty_table() -> Result<()> {
-        let (session_ctx, schema, batch) = create_initial_ctx()?;
-        // create empty memory table
-        let initial_table = Arc::new(MemTable::try_new(schema.clone(), vec![])?);
+    async fn test_insert_into_multi_partition_with_multi_partition() -> Result<()> {
+        // Create a new schema with one field called "a" of type Int32
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
 
-        // scan a data source provider from a memory table with a single partition
-        let single_partition_table_scan = create_mem_table_scan(
+        // Create a new batch of data to insert into the table
+        let batch = RecordBatch::try_new(
             schema.clone(),
-            vec![vec![batch.clone(), batch.clone()]],
+            vec![Arc::new(Int32Array::from_slice([1, 2, 3]))],
         )?;
+        // Run the experiment and obtain the resulting data in the table
+        let resulting_data_in_table = experiment(
+            schema,
+            vec![vec![batch.clone()], vec![batch.clone()]],
+            vec![
+                vec![batch.clone(), batch.clone()],
+                vec![batch.clone(), batch],
+            ],
+        )
+        .await?;
+        // Ensure that each partition in the table now contains three batches of data
+        assert_eq!(resulting_data_in_table[0].len(), 3);
+        assert_eq!(resulting_data_in_table[1].len(), 3);
+        Ok(())
+    }
 
-        // insert the data from the 1 partition data source provider into the initial table
-        initial_table
-            .insert_into(&session_ctx.state(), &single_partition_table_scan)
-            .await?;
-
-        assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 2);
+    #[tokio::test]
+    async fn test_insert_from_empty_table() -> Result<()> {
+        // Create a new schema with one field called "a" of type Int32
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
 
-        // scan a data source provider from a memory table with 2 partition
-        let single_partition_table_scan = create_mem_table_scan(
+        // Create a new batch of data to insert into the table
+        let batch = RecordBatch::try_new(
             schema.clone(),
-            vec![vec![batch.clone()], vec![batch]],
+            vec![Arc::new(Int32Array::from_slice([1, 2, 3]))],
         )?;
-        // We expect coalesce partitions here.
-        initial_table
-            .insert_into(&session_ctx.state(), &single_partition_table_scan)
-            .await?;
-        // Ensure that the table now contains 3 batches of data with 2 partitions.
-        assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 4);
+        // Run the experiment and obtain the resulting data in the table
+        let resulting_data_in_table = experiment(
+            schema,
+            vec![vec![batch.clone(), batch.clone()]],
+            vec![vec![]],
+        )
+        .await?;
+        // Ensure that the table now contains two batches of data in the same partition
+        assert_eq!(resulting_data_in_table[0].len(), 2);
         Ok(())
     }
 }
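
The `PartitionData` alias (`Arc<RwLock<Vec<RecordBatch>>>`) introduced above is what lets scans and the write sink share the same buffers: `scan` clones each partition's contents under a read lock, while the sink streams append under a write lock. A small self-contained sketch of that access pattern with tokio's `RwLock`, using `i32` as a stand-in for `RecordBatch`:

    use std::sync::Arc;
    use tokio::sync::RwLock;

    // Stand-in for PartitionData = Arc<RwLock<Vec<RecordBatch>>>.
    type Partition = Arc<RwLock<Vec<i32>>>;

    async fn demo() {
        let partition: Partition = Arc::new(RwLock::new(vec![]));
        // Writer side (as in the sink streams): append under a write lock.
        partition.write().await.push(1);
        // Reader side (as in MemTable::scan): clone the data under a read lock.
        let snapshot = partition.read().await.clone();
        assert_eq!(snapshot, vec![1]);
    }
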
diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs
index bb6d58fb90..dce8a1c424 100644
--- a/datafusion/core/src/execution/context.rs
+++ b/datafusion/core/src/execution/context.rs
@@ -33,7 +33,7 @@ use crate::{
 };
 use datafusion_expr::{
     logical_plan::{DdlStatement, Statement},
-    DescribeTable, DmlStatement, StringifiedPlan, WriteOp,
+    DescribeTable, StringifiedPlan,
 };
 pub use datafusion_physical_expr::execution_props::ExecutionProps;
 use datafusion_physical_expr::var_provider::is_system_variables;
@@ -369,23 +369,6 @@ impl SessionContext {
     /// Execute the [`LogicalPlan`], return a [`DataFrame`]
     pub async fn execute_logical_plan(&self, plan: LogicalPlan) -> Result<DataFrame> {
         match plan {
-            LogicalPlan::Dml(DmlStatement {
-                table_name,
-                op: WriteOp::Insert,
-                input,
-                ..
-            }) => {
-                if self.table_exist(&table_name)? {
-                    let name = table_name.table();
-                    let provider = self.table_provider(name).await?;
-                    provider.insert_into(&self.state(), &input).await?;
-                } else {
-                    return Err(DataFusionError::Execution(format!(
-                        "Table '{table_name}' does not exist"
-                    )));
-                }
-                self.return_empty_dataframe()
-            }
             LogicalPlan::Ddl(ddl) => match ddl {
                 DdlStatement::CreateExternalTable(cmd) => {
                     self.create_external_table(&cmd).await
@@ -1475,7 +1458,7 @@ impl SessionState {
             .resolve(&catalog.default_catalog, &catalog.default_schema)
     }
 
-    fn schema_for_ref<'a>(
+    pub(crate) fn schema_for_ref<'a>(
         &'a self,
         table_ref: impl Into<TableReference<'a>>,
     ) -> Result<Arc<dyn SchemaProvider>> {
diff --git a/datafusion/core/src/physical_plan/memory.rs b/datafusion/core/src/physical_plan/memory.rs
index f0cd48fa4f..12a37c65c8 100644
--- a/datafusion/core/src/physical_plan/memory.rs
+++ b/datafusion/core/src/physical_plan/memory.rs
@@ -17,11 +17,6 @@
 
 //! Execution plan for reading in-memory batches of data
 
-use core::fmt;
-use std::any::Any;
-use std::sync::Arc;
-use std::task::{Context, Poll};
-
 use super::expressions::PhysicalSortExpr;
 use super::{
     common, project_schema, DisplayFormatType, ExecutionPlan, Partitioning,
@@ -30,10 +25,20 @@ use super::{
 use crate::error::Result;
 use arrow::datatypes::SchemaRef;
 use arrow::record_batch::RecordBatch;
+use core::fmt;
+use futures::FutureExt;
+use futures::StreamExt;
+use std::any::Any;
+use std::sync::Arc;
+use std::task::{Context, Poll};
 
+use crate::datasource::memory::PartitionData;
 use crate::execution::context::TaskContext;
+use crate::physical_plan::Distribution;
 use datafusion_common::DataFusionError;
-use futures::Stream;
+use futures::{ready, Stream};
+use std::mem;
+use tokio::sync::{OwnedRwLockWriteGuard, RwLock};
 
 /// Execution plan for reading in-memory batches of data
 pub struct MemoryExec {
@@ -150,6 +155,23 @@ impl MemoryExec {
         })
     }
 
+    /// Create a new execution plan for reading in-memory record batches
+    /// The provided `schema` should not have the projection applied.
+    pub fn try_new_owned_data(
+        partitions: Vec<Vec<RecordBatch>>,
+        schema: SchemaRef,
+        projection: Option<Vec<usize>>,
+    ) -> Result<Self> {
+        let projected_schema = project_schema(&schema, projection.as_ref())?;
+        Ok(Self {
+            partitions,
+            schema,
+            projected_schema,
+            projection,
+            sort_information: None,
+        })
+    }
+
     /// Set sort information
     pub fn with_sort_information(
         mut self,
@@ -223,15 +245,365 @@ impl RecordBatchStream for MemoryStream {
     }
 }
 
+/// Execution plan for writing record batches to an in-memory table.
+pub struct MemoryWriteExec {
+    /// Input plan that produces the record batches to be written.
+    input: Arc<dyn ExecutionPlan>,
+    /// Reference to the MemTable's partition data.
+    batches: Vec<PartitionData>,
+    /// Schema describing the structure of the data.
+    schema: SchemaRef,
+}
+
+impl fmt::Debug for MemoryWriteExec {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(f, "schema: {:?}", self.schema)
+    }
+}
+
+impl ExecutionPlan for MemoryWriteExec {
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    /// Get the schema for this execution plan
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+
+    fn output_partitioning(&self) -> Partitioning {
+        Partitioning::UnknownPartitioning(
+            self.input.output_partitioning().partition_count(),
+        )
+    }
+
+    fn benefits_from_input_partitioning(&self) -> bool {
+        false
+    }
+
+    fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
+        self.input.output_ordering()
+    }
+
+    fn required_input_distribution(&self) -> Vec<Distribution> {
+        // If the partition count of the MemTable is one, we want to require SinglePartition
+        // since it induces better plans in the plan optimizer.
+        if self.batches.len() == 1 {
+            vec![Distribution::SinglePartition]
+        } else {
+            vec![Distribution::UnspecifiedDistribution]
+        }
+    }
+
+    fn maintains_input_order(&self) -> Vec<bool> {
+        // In theory, if the MemTable partition count equals the input plan's output partition count,
+        // the execution plan can preserve the order inside the partitions.
+        vec![self.batches.len() == self.input.output_partitioning().partition_count()]
+    }
+
+    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
+        vec![self.input.clone()]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        Ok(Arc::new(MemoryWriteExec::try_new(
+            children[0].clone(),
+            self.batches.clone(),
+            self.schema.clone(),
+        )?))
+    }
+
+    /// Execute the plan and return a stream of record batches for the specified partition.
+    /// Depending on the number of input partitions and MemTable partitions, it will choose
+    /// either a lock-minimizing or a fully locked implementation.
+    fn execute(
+        &self,
+        partition: usize,
+        context: Arc<TaskContext>,
+    ) -> Result<SendableRecordBatchStream> {
+        let batch_count = self.batches.len();
+        let data = self.input.execute(partition, context)?;
+        if batch_count >= self.input.output_partitioning().partition_count() {
+            // If there are at least as many MemTable partitions as input partitions, use a
+            // lightweight one-to-one implementation that locks each table partition only once.
+            let table_partition = self.batches[partition].clone();
+            Ok(Box::pin(MemorySinkOneToOneStream::try_new(
+                table_partition,
+                data,
+                self.schema.clone(),
+            )?))
+        } else {
+            // Otherwise, use the locked implementation.
+            let table_partition = self.batches[partition % batch_count].clone();
+            Ok(Box::pin(MemorySinkStream::try_new(
+                table_partition,
+                data,
+                self.schema.clone(),
+            )?))
+        }
+    }
+
+    fn fmt_as(
+        &self,
+        t: DisplayFormatType,
+        f: &mut std::fmt::Formatter,
+    ) -> std::fmt::Result {
+        match t {
+            DisplayFormatType::Default => {
+                write!(
+                    f,
+                    "MemoryWriteExec: partitions={}, input_partition={}",
+                    self.batches.len(),
+                    self.input.output_partitioning().partition_count()
+                )
+            }
+        }
+    }
+
+    fn statistics(&self) -> Statistics {
+        Statistics::default()
+    }
+}
+
+impl MemoryWriteExec {
+    /// Create a new execution plan for writing record batches to the given
+    /// table partitions. The `schema` describes the structure of the data.
+    pub fn try_new(
+        plan: Arc<dyn ExecutionPlan>,
+        batches: Vec<Arc<RwLock<Vec<RecordBatch>>>>,
+        schema: SchemaRef,
+    ) -> Result<Self> {
+        Ok(Self {
+            input: plan,
+            batches,
+            schema,
+        })
+    }
+}
+
+/// This object encodes the different states of the [`MemorySinkStream`] when
+/// processing record batches.
+enum MemorySinkStreamState {
+    /// The stream is pulling data from the input.
+    Pull,
+    /// The stream is writing data to the table partition.
+    Write { maybe_batch: Option<RecordBatch> },
+}
+
+/// A stream that saves record batches in memory-backed storage.
+/// It works even when multiple input partitions map to the same table
+/// partition; it achieves buffer exclusivity by locking before each write.
+struct MemorySinkStream {
+    /// Stream of record batches to be inserted into the memory table.
+    data: SendableRecordBatchStream,
+    /// Memory table partition that stores the record batches.
+    table_partition: PartitionData,
+    /// Schema representing the structure of the data.
+    schema: SchemaRef,
+    /// State of the iterator when processing multiple polls.
+    state: MemorySinkStreamState,
+}
+
+impl MemorySinkStream {
+    /// Create a new `MemorySinkStream` with the provided parameters.
+    pub fn try_new(
+        table_partition: PartitionData,
+        data: SendableRecordBatchStream,
+        schema: SchemaRef,
+    ) -> Result<Self> {
+        Ok(Self {
+            table_partition,
+            data,
+            schema,
+            state: MemorySinkStreamState::Pull,
+        })
+    }
+
+    /// Implementation of the `poll_next` method. Continuously polls the record
+    /// batch stream, switching between the Pull and Write states. In case of
+    /// an error, returns the error immediately.
+    fn poll_next_impl(
+        &mut self,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Option<Result<RecordBatch>>> {
+        loop {
+            match &mut self.state {
+                MemorySinkStreamState::Pull => {
+                    // Pull data from the input stream.
+                    if let Some(result) = ready!(self.data.as_mut().poll_next(cx)) {
+                        match result {
+                            Ok(batch) => {
+                                // Switch to the Write state with the received batch.
+                                self.state = MemorySinkStreamState::Write {
+                                    maybe_batch: Some(batch),
+                                }
+                            }
+                            Err(e) => return Poll::Ready(Some(Err(e))), // Return the error immediately.
+                        }
+                    } else {
+                        return Poll::Ready(None); // If the input stream is exhausted, return None.
+                    }
+                }
+                MemorySinkStreamState::Write { maybe_batch } => {
+                    // Acquire a write lock on the table partition.
+                    let mut partition =
+                        ready!(self.table_partition.write().boxed().poll_unpin(cx));
+                    if let Some(b) = mem::take(maybe_batch) {
+                        partition.push(b); // Insert the batch into the table partition.
+                    }
+                    self.state = MemorySinkStreamState::Pull; // Switch back to the Pull state.
+                }
+            }
+        }
+    }
+}
+
+impl Stream for MemorySinkStream {
+    type Item = Result<RecordBatch>;
+
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        self.poll_next_impl(cx)
+    }
+}
+
+impl RecordBatchStream for MemorySinkStream {
+    /// Get the schema
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+}
+
+/// This object encodes the different states of the [`MemorySinkOneToOneStream`]
+/// when processing record batches.
+enum MemorySinkOneToOneStreamState {
+    /// The `Acquire` variant represents the state where the [`MemorySinkOneToOneStream`]
+    /// is waiting to acquire the write lock on the shared partition to store the record batches.
+    Acquire,
+
+    /// The `Pull` variant represents the state where the [`MemorySinkOneToOneStream`] has
+    /// acquired the write lock on the shared partition and can pull record batches from
+    /// the input stream to store in the partition.
+    Pull {
+        /// The `partition` field contains an [`OwnedRwLockWriteGuard`] which wraps the
+        /// shared partition, providing exclusive write access to the underlying `Vec<RecordBatch>`.
+        partition: OwnedRwLockWriteGuard<Vec<RecordBatch>>,
+    },
+}
+
+/// A stream that saves record batches in memory-backed storage.
+/// Assumes that every table partition has at most one corresponding input
+/// partition, so it locks the table partition only once.
+struct MemorySinkOneToOneStream {
+    /// Stream of record batches to be inserted into the memory table.
+    data: SendableRecordBatchStream,
+    /// Memory table partition that stores the record batches.
+    table_partition: PartitionData,
+    /// Schema representing the structure of the data.
+    schema: SchemaRef,
+    /// State of the iterator when processing multiple polls.
+    state: MemorySinkOneToOneStreamState,
+}
+
+impl MemorySinkOneToOneStream {
+    /// Create a new `MemorySinkOneToOneStream` with the provided parameters.
+    pub fn try_new(
+        table_partition: Arc<RwLock<Vec<RecordBatch>>>,
+        data: SendableRecordBatchStream,
+        schema: SchemaRef,
+    ) -> Result<Self> {
+        Ok(Self {
+            table_partition,
+            data,
+            schema,
+            state: MemorySinkOneToOneStreamState::Acquire,
+        })
+    }
+
+    /// Implementation of the `poll_next` method. Continuously polls the record
+    /// batch stream and pushes batches into the corresponding table partition,
+    /// whose lock is acquired only once. In case of an error, returns the
+    /// error immediately.
+    fn poll_next_impl(
+        &mut self,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Option<Result<RecordBatch>>> {
+        loop {
+            match &mut self.state {
+                MemorySinkOneToOneStreamState::Acquire => {
+                    // Acquire a write lock on the table partition.
+                    self.state = MemorySinkOneToOneStreamState::Pull {
+                        partition: ready!(self
+                            .table_partition
+                            .clone()
+                            .write_owned()
+                            .boxed()
+                            .poll_unpin(cx)),
+                    };
+                }
+                MemorySinkOneToOneStreamState::Pull { partition } => {
+                    // Iterate over the batches in the input data stream.
+                    while let Some(result) = ready!(self.data.poll_next_unpin(cx)) {
+                        match result {
+                            Ok(batch) => {
+                                partition.push(batch);
+                            } // Insert the batch into the table partition.
+                            Err(e) => return Poll::Ready(Some(Err(e))), // Return the error immediately.
+                        }
+                    }
+                    // If the input stream is exhausted, return None to indicate the end of the stream.
+                    return Poll::Ready(None);
+                }
+            }
+        }
+    }
+}
+
+impl Stream for MemorySinkOneToOneStream {
+    type Item = Result<RecordBatch>;
+
+    /// Poll the stream for the next item.
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        self.poll_next_impl(cx)
+    }
+}
+
+impl RecordBatchStream for MemorySinkOneToOneStream {
+    /// Get the schema of the record batches in the stream.
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::datasource::streaming::PartitionStream;
+    use crate::datasource::{MemTable, TableProvider};
     use crate::from_slice::FromSlice;
+    use crate::physical_plan::stream::RecordBatchStreamAdapter;
+    use crate::physical_plan::streaming::StreamingTableExec;
     use crate::physical_plan::ColumnStatistics;
-    use crate::prelude::SessionContext;
+    use crate::physical_plan::{collect, displayable, SendableRecordBatchStream};
+    use crate::prelude::{CsvReadOptions, SessionContext};
+    use crate::test_util;
     use arrow::array::Int32Array;
-    use arrow::datatypes::{DataType, Field, Schema};
+    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+    use arrow::record_batch::RecordBatch;
+    use datafusion_common::Result;
+    use datafusion_execution::config::SessionConfig;
+    use datafusion_execution::TaskContext;
     use futures::StreamExt;
+    use std::sync::Arc;
 
     fn mock_data() -> Result<(SchemaRef, RecordBatch)> {
         let schema = Arc::new(Schema::new(vec![
@@ -340,4 +712,262 @@ mod tests {
 
         Ok(())
     }
+
+    #[tokio::test]
+    async fn test_insert_into() -> Result<()> {
+        // Create session context
+        let config = SessionConfig::new().with_target_partitions(8);
+        let ctx = SessionContext::with_config(config);
+        let testdata = test_util::arrow_test_data();
+        let schema = test_util::aggr_test_schema();
+        ctx.register_csv(
+            "aggregate_test_100",
+            &format!("{testdata}/csv/aggregate_test_100.csv"),
+            CsvReadOptions::new().schema(&schema),
+        )
+        .await?;
+        ctx.sql(
+            "CREATE TABLE table_without_values(field1 BIGINT NULL, field2 BIGINT NULL)",
+        )
+        .await?;
+
+        let sql = "INSERT INTO table_without_values SELECT
+                SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),
+                COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING)
+                FROM aggregate_test_100
+                ORDER by c1
+            ";
+        let msg = format!("Creating logical plan for '{sql}'");
+        let dataframe = ctx.sql(sql).await.expect(&msg);
+        let physical_plan = dataframe.create_physical_plan().await?;
+        let formatted = displayable(physical_plan.as_ref()).indent().to_string();
+        let expected = {
+            vec![
+                "MemoryWriteExec: partitions=1, input_partition=1",
+                "  ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@0 as field1, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@1 as field2]",
+                "    SortPreservingMergeExec: [c1@2 ASC NULLS LAST]",
+                "      ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as SUM(aggregate_test_100.c4), COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as COUNT(UInt8(1)), c1@0 as c1]",
+                "        BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4): Ok(Field { name: \"SUM(aggregate_test_100.c4)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Ro [...]
+                "          SortExec: expr=[c1@0 ASC NULLS LAST,c9@2 ASC NULLS LAST]",
+                "            CoalesceBatchesExec: target_batch_size=8192",
+                "              RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 8), input_partitions=8",
+                "                RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1",
+            ]
+        };
+
+        let actual: Vec<&str> = formatted.trim().lines().collect();
+        let actual_len = actual.len();
+        let actual_trim_last = &actual[..actual_len - 1];
+        assert_eq!(
+            expected, actual_trim_last,
+            "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+        );
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_insert_into_as_select_multi_partitioned() -> Result<()> {
+        // Create session context
+        let config = SessionConfig::new().with_target_partitions(8);
+        let ctx = SessionContext::with_config(config);
+        let testdata = test_util::arrow_test_data();
+        let schema = test_util::aggr_test_schema();
+        ctx.register_csv(
+            "aggregate_test_100",
+            &format!("{testdata}/csv/aggregate_test_100.csv"),
+            CsvReadOptions::new().schema(&schema),
+        )
+        .await?;
+        ctx.sql(
+            "CREATE TABLE table_without_values(field1 BIGINT NULL, field2 BIGINT NULL)",
+        )
+        .await?;
+
+        let sql = "INSERT INTO table_without_values SELECT
+                SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a1,
+                COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a2
+                FROM aggregate_test_100";
+        let msg = format!("Creating logical plan for '{sql}'");
+        let dataframe = ctx.sql(sql).await.expect(&msg);
+        let physical_plan = dataframe.create_physical_plan().await?;
+        let formatted = displayable(physical_plan.as_ref()).indent().to_string();
+        let expected = {
+            vec![
+                "MemoryWriteExec: partitions=1, input_partition=1",
+                "  CoalescePartitionsExec",
+                "    ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as field1, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as field2]",
+                "      BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4): Ok(Field { name: \"SUM(aggregate_test_100.c4)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows [...]
+                "        SortExec: expr=[c1@0 ASC NULLS LAST,c9@2 ASC NULLS LAST]",
+                "          CoalesceBatchesExec: target_batch_size=8192",
+                "            RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 8), input_partitions=8",
+                "              RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1",
+            ]
+        };
+
+        let actual: Vec<&str> = formatted.trim().lines().collect();
+        let actual_len = actual.len();
+        let actual_trim_last = &actual[..actual_len - 1];
+        assert_eq!(
+            expected, actual_trim_last,
+            "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+        );
+        Ok(())
+    }
+
+    // TODO: The generated plan is suboptimal since SortExec is in global state.
+    #[tokio::test]
+    async fn test_insert_into_as_select_single_partition() -> Result<()> {
+        // Create session context
+        let config = SessionConfig::new().with_target_partitions(8);
+        let ctx = SessionContext::with_config(config);
+        let testdata = test_util::arrow_test_data();
+        let schema = test_util::aggr_test_schema();
+        ctx.register_csv(
+            "aggregate_test_100",
+            &format!("{testdata}/csv/aggregate_test_100.csv"),
+            CsvReadOptions::new().schema(&schema),
+        )
+        .await?;
+        ctx.sql("CREATE TABLE table_without_values AS SELECT
+                SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a1,
+                COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a2
+                FROM aggregate_test_100")
+            .await?;
+
+        let sql = "INSERT INTO table_without_values SELECT
+                SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a1,
+                COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a2
+                FROM aggregate_test_100
+                ORDER BY c1";
+        let msg = format!("Creating logical plan for '{sql}'");
+        let dataframe = ctx.sql(sql).await.expect(&msg);
+        let physical_plan = dataframe.create_physical_plan().await?;
+        let formatted = displayable(physical_plan.as_ref()).indent().to_string();
+        let expected = {
+            vec![
+                "MemoryWriteExec: partitions=8, input_partition=8",
+                "  RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1",
+                "    ProjectionExec: expr=[a1@0 as a1, a2@1 as a2]",
+                "      SortPreservingMergeExec: [c1@2 ASC NULLS LAST]",
+                "        ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as a1, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as a2, c1@0 as c1]",
+                "          BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4): Ok(Field { name: \"SUM(aggregate_test_100.c4)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units:  [...]
+                "            SortExec: expr=[c1@0 ASC NULLS LAST,c9@2 ASC NULLS LAST]",
+                "              CoalesceBatchesExec: target_batch_size=8192",
+                "                RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 8), input_partitions=8",
+                "                  RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1",
+            ]
+        };
+
+        let actual: Vec<&str> = formatted.trim().lines().collect();
+        let actual_len = actual.len();
+        let actual_trim_last = &actual[..actual_len - 1];
+        assert_eq!(
+            expected, actual_trim_last,
+            "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+        );
+        Ok(())
+    }
+
+    // DummyPartition is a simple implementation of the PartitionStream trait.
+    // It produces a stream of record batches with a fixed schema and the same content.
+    struct DummyPartition {
+        schema: SchemaRef,
+        batch: RecordBatch,
+        num_batches: usize,
+    }
+
+    impl PartitionStream for DummyPartition {
+        // Return a reference to the schema of this partition.
+        fn schema(&self) -> &SchemaRef {
+            &self.schema
+        }
+
+        // Execute the partition stream, producing a stream of record batches.
+        fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
+            let batches = itertools::repeat_n(self.batch.clone(), self.num_batches);
+            Box::pin(RecordBatchStreamAdapter::new(
+                self.schema.clone(),
+                futures::stream::iter(batches).map(Ok),
+            ))
+        }
+    }
+
+    // Test the one-to-one (lock-minimizing) mode by inserting a large number of batches into a table.
+    #[tokio::test]
+    async fn test_one_to_one_mode() -> Result<()> {
+        let num_batches = 10000;
+        // Create a new session context
+        let session_ctx = SessionContext::new();
+        // Create a new schema with one field called "a" of type Int32
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
+
+        // Create a new batch of data to insert into the table
+        let batch = RecordBatch::try_new(
+            schema.clone(),
+            vec![Arc::new(Int32Array::from_slice([1, 2, 3]))],
+        )?;
+        let initial_table = Arc::new(MemTable::try_new(schema.clone(), vec![vec![]])?);
+
+        let single_partition = Arc::new(DummyPartition {
+            schema: schema.clone(),
+            batch,
+            num_batches,
+        });
+        let input = Arc::new(StreamingTableExec::try_new(
+            schema.clone(),
+            vec![single_partition],
+            None,
+            false,
+        )?);
+        let plan = initial_table
+            .insert_into(&session_ctx.state(), input)
+            .await?;
+        let res = collect(plan, session_ctx.task_ctx()).await?;
+        assert!(res.is_empty());
+        // Ensure that the table partition now contains all of the inserted batches
+        assert_eq!(initial_table.batches[0].read().await.len(), num_batches);
+        Ok(())
+    }
+
+    // Test the locked mode by inserting a large number of batches into a table. It tests
+    // the case where the table partition count is not equal to the input's output partition count.
+    #[tokio::test]
+    async fn test_locked_mode() -> Result<()> {
+        let num_batches = 10000;
+        // Create a new session context
+        let session_ctx = SessionContext::new();
+        // Create a new schema with one field called "a" of type Int32
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
+
+        // Create a new batch of data to insert into the table
+        let batch = RecordBatch::try_new(
+            schema.clone(),
+            vec![Arc::new(Int32Array::from_slice([1, 2, 3]))],
+        )?;
+        let initial_table = Arc::new(MemTable::try_new(schema.clone(), vec![vec![]])?);
+
+        let single_partition = Arc::new(DummyPartition {
+            schema: schema.clone(),
+            batch,
+            num_batches,
+        });
+        let input = Arc::new(StreamingTableExec::try_new(
+            schema.clone(),
+            vec![
+                single_partition.clone(),
+                single_partition.clone(),
+                single_partition,
+            ],
+            None,
+            false,
+        )?);
+        let plan = initial_table
+            .insert_into(&session_ctx.state(), input)
+            .await?;
+        let res = collect(plan, session_ctx.task_ctx()).await?;
+        assert!(res.is_empty());
+        // Ensure that the table partition now contains the batches from all three input partitions
+        assert_eq!(initial_table.batches[0].read().await.len(), num_batches * 3);
+        Ok(())
+    }
 }
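
The two sink streams above differ only in how often they take the partition lock, and `MemoryWriteExec::execute` picks between them from the partition counts: when the table has at least as many partitions as the input, each input partition writes into its own table partition and the lock is acquired once (`MemorySinkOneToOneStream`); otherwise several input partitions share a table partition selected by `partition % batch_count`, and every batch write re-acquires the lock (`MemorySinkStream`). A hypothetical illustration of that modulo mapping, not part of the commit:

    // With 3 table partitions and 8 input partitions the locked path is taken,
    // and input partitions fan into table partitions as follows.
    let table_partition_count = 3;
    for input_partition in 0..8 {
        let target = input_partition % table_partition_count;
        println!("input partition {input_partition} -> table partition {target}");
    }
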
diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs
index 782dcf1335..7f68d5a39a 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -67,7 +67,7 @@ use datafusion_expr::expr::{
 };
 use datafusion_expr::expr_rewriter::unnormalize_cols;
 use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary;
-use datafusion_expr::{logical_plan, StringifiedPlan};
+use datafusion_expr::{logical_plan, DmlStatement, StringifiedPlan, WriteOp};
 use datafusion_expr::{WindowFrame, WindowFrameBound};
 use datafusion_optimizer::utils::unalias;
 use datafusion_physical_expr::expressions::Literal;
@@ -489,6 +489,23 @@ impl DefaultPhysicalPlanner {
                     let unaliased: Vec<Expr> = filters.into_iter().map(unalias).collect();
                     source.scan(session_state, projection.as_ref(), &unaliased, *fetch).await
                 }
+                LogicalPlan::Dml(DmlStatement {
+                    table_name,
+                    op: WriteOp::Insert,
+                    input,
+                    ..
+                }) => {
+                    let name = table_name.table();
+                    let schema = session_state.schema_for_ref(table_name)?;
+                    if let Some(provider) = schema.table(name).await {
+                        let input_exec = self.create_initial_plan(input, session_state).await?;
+                        provider.insert_into(session_state, input_exec).await
+                    } else {
+                        return Err(DataFusionError::Execution(format!(
+                            "Table '{table_name}' does not exist"
+                        )));
+                    }
+                }
                 LogicalPlan::Values(Values {
                     values,
                     schema,
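
With the planner change above, INSERT INTO is handled during physical planning instead of in `SessionContext::execute_logical_plan`: the DML node's input is planned first, then passed to the provider's `insert_into`, which returns the sink plan (a `MemoryWriteExec` for `MemTable`). A condensed sketch of observing this through SQL, modeled on the tests shown earlier and assuming the same `aggregate_test_100` / `table_without_values` setup:

    let sql = "INSERT INTO table_without_values SELECT
            SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),
            COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING)
            FROM aggregate_test_100";
    let physical_plan = ctx.sql(sql).await?.create_physical_plan().await?;
    // The root of the physical plan is the sink returned by MemTable::insert_into.
    let formatted = displayable(physical_plan.as_ref()).indent().to_string();
    assert!(formatted.starts_with("MemoryWriteExec"));
    // Executing the plan performs the write; the sink yields no batches.
    assert!(collect(physical_plan, ctx.task_ctx()).await?.is_empty());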