Posted to commits@arrow.apache.org by ak...@apache.org on 2023/04/28 08:26:45 UTC
[arrow-datafusion] branch main updated: MemoryExec INSERT INTO refactor to use ExecutionPlan (#6049)
This is an automated email from the ASF dual-hosted git repository.
akurmustafa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new d5a8c934fa MemoryExec INSERT INTO refactor to use ExecutionPlan (#6049)
d5a8c934fa is described below
commit d5a8c934fa818e52cc021ad250cb1044e29ff9df
Author: Metehan Yıldırım <10...@users.noreply.github.com>
AuthorDate: Fri Apr 28 11:26:38 2023 +0300
MemoryExec INSERT INTO refactor to use ExecutionPlan (#6049)
* MemoryExec insert into refactor
* Merge leftovers
* Set target partition
* Comment and formatting improvements
* Comments on state.
* Leftover comments
* After merge corrections
* Correction after merge
---------
Co-authored-by: Mehmet Ozan Kabak <oz...@gmail.com>
---
datafusion/core/src/datasource/datasource.rs | 4 +-
datafusion/core/src/datasource/memory.rs | 283 ++++++------
datafusion/core/src/execution/context.rs | 21 +-
datafusion/core/src/physical_plan/memory.rs | 646 ++++++++++++++++++++++++++-
datafusion/core/src/physical_plan/planner.rs | 19 +-
5 files changed, 800 insertions(+), 173 deletions(-)
diff --git a/datafusion/core/src/datasource/datasource.rs b/datafusion/core/src/datasource/datasource.rs
index 8db075a30a..4560b3820c 100644
--- a/datafusion/core/src/datasource/datasource.rs
+++ b/datafusion/core/src/datasource/datasource.rs
@@ -102,8 +102,8 @@ pub trait TableProvider: Sync + Send {
async fn insert_into(
&self,
_state: &SessionState,
- _input: &LogicalPlan,
- ) -> Result<()> {
+ _input: Arc<dyn ExecutionPlan>,
+ ) -> Result<Arc<dyn ExecutionPlan>> {
let msg = "Insertion not implemented for this table".to_owned();
Err(DataFusionError::NotImplemented(msg))
}
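The practical effect of this signature change: the caller hands the provider an already-planned physical input, and the provider answers with a plan that performs the write when executed. A minimal sketch of driving the new method on a MemTable by hand, assuming the datafusion and tokio crates at this commit (data values are illustrative):

    use std::sync::Arc;
    use arrow::array::Int32Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::{MemTable, TableProvider};
    use datafusion::physical_plan::{collect, memory::MemoryExec};
    use datafusion::prelude::SessionContext;

    #[tokio::main]
    async fn main() -> datafusion::error::Result<()> {
        let ctx = SessionContext::new();
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )?;

        // The table starts with a single partition holding one batch.
        let table = Arc::new(MemTable::try_new(schema.clone(), vec![vec![batch.clone()]])?);

        // The input is now a *physical* plan; here, a MemoryExec over one batch.
        let input = Arc::new(MemoryExec::try_new(&[vec![batch]], schema.clone(), None)?);

        // insert_into returns an ExecutionPlan (a MemoryWriteExec); nothing is
        // written until that plan is executed.
        let write_plan = table.insert_into(&ctx.state(), input).await?;
        let res = collect(write_plan, ctx.task_ctx()).await?;
        assert!(res.is_empty()); // write plans produce no output batches

        // Re-scanning shows the partition now holds both batches (6 rows).
        let scan = table.scan(&ctx.state(), None, &[], None).await?;
        let batches = collect(scan, ctx.task_ctx()).await?;
        assert_eq!(batches.iter().map(|b| b.num_rows()).sum::<usize>(), 6);
        Ok(())
    }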
diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs
index ca083aebe3..f41f8cb1bd 100644
--- a/datafusion/core/src/datasource/memory.rs
+++ b/datafusion/core/src/datasource/memory.rs
@@ -24,20 +24,22 @@ use std::sync::Arc;
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
-use datafusion_expr::LogicalPlan;
use tokio::sync::RwLock;
use crate::datasource::{TableProvider, TableType};
use crate::error::{DataFusionError, Result};
use crate::execution::context::SessionState;
use crate::logical_expr::Expr;
-use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec;
+use crate::physical_plan::common;
use crate::physical_plan::common::AbortOnDropSingle;
use crate::physical_plan::memory::MemoryExec;
+use crate::physical_plan::memory::MemoryWriteExec;
use crate::physical_plan::ExecutionPlan;
-use crate::physical_plan::{collect_partitioned, common};
use crate::physical_plan::{repartition::RepartitionExec, Partitioning};
+/// Type alias for partition data
+pub type PartitionData = Arc<RwLock<Vec<RecordBatch>>>;
+
/// In-memory data source for presenting a `Vec<RecordBatch>` as a
/// data source that can be queried by DataFusion. This allows data to
/// be pre-loaded into memory and then repeatedly queried without
@@ -45,7 +47,7 @@ use crate::physical_plan::{repartition::RepartitionExec, Partitioning};
#[derive(Debug)]
pub struct MemTable {
schema: SchemaRef,
- batches: Arc<RwLock<Vec<Vec<RecordBatch>>>>,
+ pub(crate) batches: Vec<PartitionData>,
}
impl MemTable {
@@ -58,7 +60,10 @@ impl MemTable {
{
Ok(Self {
schema,
- batches: Arc::new(RwLock::new(partitions)),
+ batches: partitions
+ .into_iter()
+ .map(|e| Arc::new(RwLock::new(e)))
+ .collect::<Vec<_>>(),
})
} else {
Err(DataFusionError::Plan(
@@ -147,71 +152,62 @@ impl TableProvider for MemTable {
_filters: &[Expr],
_limit: Option<usize>,
) -> Result<Arc<dyn ExecutionPlan>> {
- let batches = &self.batches.read().await;
- Ok(Arc::new(MemoryExec::try_new(
- batches,
+ let mut partitions = vec![];
+ for arc_inner_vec in self.batches.iter() {
+ let inner_vec = arc_inner_vec.read().await;
+ partitions.push(inner_vec.clone())
+ }
+ Ok(Arc::new(MemoryExec::try_new_owned_data(
+ partitions,
self.schema(),
projection.cloned(),
)?))
}
- /// Inserts the execution results of a given [LogicalPlan] into this [MemTable].
- /// The `LogicalPlan` must have the same schema as this `MemTable`.
+ /// Inserts the execution results of a given [`ExecutionPlan`] into this [`MemTable`].
+ /// The [`ExecutionPlan`] must have the same schema as this [`MemTable`].
///
/// # Arguments
///
- /// * `state` - The [SessionState] containing the context for executing the plan.
- /// * `input` - The [LogicalPlan] to execute and insert.
+ /// * `state` - The [`SessionState`] containing the context for executing the plan.
+ /// * `input` - The [`ExecutionPlan`] to execute and insert.
///
/// # Returns
///
/// * A `Result` indicating success or failure.
- async fn insert_into(&self, state: &SessionState, input: &LogicalPlan) -> Result<()> {
+ async fn insert_into(
+ &self,
+ _state: &SessionState,
+ input: Arc<dyn ExecutionPlan>,
+ ) -> Result<Arc<dyn ExecutionPlan>> {
// Create a physical plan from the logical plan.
- let plan = state.create_physical_plan(input).await?;
-
// Check that the schema of the plan matches the schema of this table.
- if !plan.schema().eq(&self.schema) {
+ if !input.schema().eq(&self.schema) {
return Err(DataFusionError::Plan(
"Inserting query must have the same schema with the table.".to_string(),
));
}
- // Get the number of partitions in the plan and the table.
- let plan_partition_count = plan.output_partitioning().partition_count();
- let table_partition_count = self.batches.read().await.len();
+ if self.batches.is_empty() {
+ return Err(DataFusionError::Plan(
+ "The table must have partitions.".to_string(),
+ ));
+ }
- // Adjust the plan as necessary to match the number of partitions in the table.
- let plan: Arc<dyn ExecutionPlan> = if plan_partition_count
- == table_partition_count
- || table_partition_count == 0
- {
- plan
- } else if table_partition_count == 1 {
- // If the table has only one partition, coalesce the partitions in the plan.
- Arc::new(CoalescePartitionsExec::new(plan))
- } else {
- // Otherwise, repartition the plan using a round-robin partitioning scheme.
+ let input = if self.batches.len() > 1 {
Arc::new(RepartitionExec::try_new(
- plan,
- Partitioning::RoundRobinBatch(table_partition_count),
+ input,
+ Partitioning::RoundRobinBatch(self.batches.len()),
)?)
- };
-
- let results = collect_partitioned(plan, state.task_ctx()).await?;
-
- // Write the results into the table.
- let mut all_batches = self.batches.write().await;
-
- if all_batches.is_empty() {
- *all_batches = results
} else {
- for (batches, result) in all_batches.iter_mut().zip(results.into_iter()) {
- batches.extend(result);
- }
- }
+ input
+ };
- Ok(())
+ Ok(Arc::new(MemoryWriteExec::try_new(
+ input,
+ self.batches.clone(),
+ self.schema.clone(),
+ )?))
}
}
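Note how the rewritten insert_into neither coalesces nor collects anymore: the only plan adjustment left is a round-robin RepartitionExec when the table has more than one partition. A standalone sketch of just that repartitioning step, under the same crate assumptions as the example above (four input batches spread across two output partitions):

    use std::sync::Arc;
    use arrow::array::Int32Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::physical_plan::repartition::RepartitionExec;
    use datafusion::physical_plan::{collect_partitioned, memory::MemoryExec, Partitioning};
    use datafusion::prelude::SessionContext;

    #[tokio::main]
    async fn main() -> datafusion::error::Result<()> {
        let ctx = SessionContext::new();
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )?;
        // One input partition holding four batches...
        let input = Arc::new(MemoryExec::try_new(
            &[vec![batch.clone(), batch.clone(), batch.clone(), batch]],
            schema,
            None,
        )?);
        // ...spread round-robin over two output partitions, exactly what
        // insert_into sets up when self.batches.len() > 1.
        let repartitioned = Arc::new(RepartitionExec::try_new(
            input,
            Partitioning::RoundRobinBatch(2),
        )?);
        let parts = collect_partitioned(repartitioned, ctx.task_ctx()).await?;
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0].len(), 2); // two batches per output partition
        assert_eq!(parts[1].len(), 2);
        Ok(())
    }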
@@ -220,6 +216,7 @@ mod tests {
use super::*;
use crate::datasource::provider_as_source;
use crate::from_slice::FromSlice;
+ use crate::physical_plan::collect;
use crate::prelude::SessionContext;
use arrow::array::Int32Array;
use arrow::datatypes::{DataType, Field, Schema};
@@ -455,21 +452,48 @@ mod tests {
Ok(())
}
- fn create_mem_table_scan(
+ async fn experiment(
schema: SchemaRef,
- data: Vec<Vec<RecordBatch>>,
- ) -> Result<Arc<LogicalPlan>> {
- // Convert the table into a provider so that it can be used in a query
- let provider = provider_as_source(Arc::new(MemTable::try_new(schema, data)?));
- // Create a table scan logical plan to read from the table
- Ok(Arc::new(
- LogicalPlanBuilder::scan("source", provider, None)?.build()?,
- ))
- }
-
- fn create_initial_ctx() -> Result<(SessionContext, SchemaRef, RecordBatch)> {
+ initial_data: Vec<Vec<RecordBatch>>,
+ inserted_data: Vec<Vec<RecordBatch>>,
+ ) -> Result<Vec<Vec<RecordBatch>>> {
// Create a new session context
let session_ctx = SessionContext::new();
+ // Create and register the initial table with the provided schema and data
+ let initial_table = Arc::new(MemTable::try_new(schema.clone(), initial_data)?);
+ session_ctx.register_table("t", initial_table.clone())?;
+ // Create and register the source table with the provided schema and inserted data
+ let source_table = Arc::new(MemTable::try_new(schema.clone(), inserted_data)?);
+ session_ctx.register_table("source", source_table.clone())?;
+ // Convert the source table into a provider so that it can be used in a query
+ let source = provider_as_source(source_table);
+ // Create a table scan logical plan to read from the source table
+ let scan_plan = LogicalPlanBuilder::scan("source", source, None)?.build()?;
+ // Create an insert plan to insert the source data into the initial table
+ let insert_into_table =
+ LogicalPlanBuilder::insert_into(scan_plan, "t", &schema)?.build()?;
+ // Create a physical plan from the insert plan
+ let plan = session_ctx
+ .state()
+ .create_physical_plan(&insert_into_table)
+ .await?;
+
+ // Execute the physical plan and collect the results
+ let res = collect(plan, session_ctx.task_ctx()).await?;
+ // Ensure the result is empty after the insert operation
+ assert!(res.is_empty());
+ // Read the data from the initial table and store it in a vector of partitions
+ let mut partitions = vec![];
+ for partition in initial_table.batches.iter() {
+ let part = partition.read().await.clone();
+ partitions.push(part);
+ }
+ Ok(partitions)
+ }
+
+ // Test inserting a single batch of data into a single partition
+ #[tokio::test]
+ async fn test_insert_into_single_partition() -> Result<()> {
// Create a new schema with one field called "a" of type Int32
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
@@ -478,111 +502,84 @@ mod tests {
schema.clone(),
vec![Arc::new(Int32Array::from_slice([1, 2, 3]))],
)?;
- Ok((session_ctx, schema, batch))
- }
-
- #[tokio::test]
- async fn test_insert_into_single_partition() -> Result<()> {
- let (session_ctx, schema, batch) = create_initial_ctx()?;
- let initial_table = Arc::new(MemTable::try_new(
- schema.clone(),
- vec![vec![batch.clone()]],
- )?);
- // Create a table scan logical plan to read from the table
- let single_partition_table_scan =
- create_mem_table_scan(schema.clone(), vec![vec![batch.clone()]])?;
- // Insert the data from the provider into the table
- initial_table
- .insert_into(&session_ctx.state(), &single_partition_table_scan)
- .await?;
+ // Run the experiment and obtain the resulting data in the table
+ let resulting_data_in_table =
+ experiment(schema, vec![vec![batch.clone()]], vec![vec![batch.clone()]])
+ .await?;
// Ensure that the table now contains two batches of data in the same partition
- assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 2);
-
- // Create a new provider with 2 partitions
- let multi_partition_table_scan = create_mem_table_scan(
- schema.clone(),
- vec![vec![batch.clone()], vec![batch]],
- )?;
-
- // Insert the data from the provider into the table. We expect coalescing partitions.
- initial_table
- .insert_into(&session_ctx.state(), &multi_partition_table_scan)
- .await?;
- // Ensure that the table now contains 4 batches of data with only 1 partition
- assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 4);
- assert_eq!(initial_table.batches.read().await.len(), 1);
+ assert_eq!(resulting_data_in_table[0].len(), 2);
Ok(())
}
+ // Test inserting multiple batches of data into a single partition
#[tokio::test]
- async fn test_insert_into_multiple_partition() -> Result<()> {
- let (session_ctx, schema, batch) = create_initial_ctx()?;
- // create a memory table with two partitions, each having one batch with the same data
- let initial_table = Arc::new(MemTable::try_new(
- schema.clone(),
- vec![vec![batch.clone()], vec![batch.clone()]],
- )?);
+ async fn test_insert_into_single_partition_with_multi_partition() -> Result<()> {
+ // Create a new schema with one field called "a" of type Int32
+ let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
- // scan a data source provider from a memory table with a single partition
- let single_partition_table_scan = create_mem_table_scan(
+ // Create a new batch of data to insert into the table
+ let batch = RecordBatch::try_new(
schema.clone(),
- vec![vec![batch.clone(), batch.clone()]],
+ vec![Arc::new(Int32Array::from_slice([1, 2, 3]))],
)?;
-
- // insert the data from the 1 partition data source provider into the initial table
- initial_table
- .insert_into(&session_ctx.state(), &single_partition_table_scan)
- .await?;
-
- // We expect round robin repartition here, each partition gets 1 batch.
- assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 2);
- assert_eq!(initial_table.batches.read().await.get(1).unwrap().len(), 2);
-
- // scan a data source provider from a memory table with 2 partition
- let multi_partition_table_scan = create_mem_table_scan(
- schema.clone(),
+ // Run the experiment and obtain the resulting data in the table
+ let resulting_data_in_table = experiment(
+ schema,
+ vec![vec![batch.clone()]],
vec![vec![batch.clone()], vec![batch]],
- )?;
- // We expect one-to-one partition mapping.
- initial_table
- .insert_into(&session_ctx.state(), &multi_partition_table_scan)
- .await?;
- // Ensure that the table now contains 3 batches of data with 2 partitions.
- assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 3);
- assert_eq!(initial_table.batches.read().await.get(1).unwrap().len(), 3);
+ )
+ .await?;
+ // Ensure that the table now contains three batches of data in the same partition
+ assert_eq!(resulting_data_in_table[0].len(), 3);
Ok(())
}
+ // Test inserting multiple batches of data into multiple partitions
#[tokio::test]
- async fn test_insert_into_empty_table() -> Result<()> {
- let (session_ctx, schema, batch) = create_initial_ctx()?;
- // create empty memory table
- let initial_table = Arc::new(MemTable::try_new(schema.clone(), vec![])?);
+ async fn test_insert_into_multi_partition_with_multi_partition() -> Result<()> {
+ // Create a new schema with one field called "a" of type Int32
+ let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
- // scan a data source provider from a memory table with a single partition
- let single_partition_table_scan = create_mem_table_scan(
+ // Create a new batch of data to insert into the table
+ let batch = RecordBatch::try_new(
schema.clone(),
- vec![vec![batch.clone(), batch.clone()]],
+ vec![Arc::new(Int32Array::from_slice([1, 2, 3]))],
)?;
+ // Run the experiment and obtain the resulting data in the table
+ let resulting_data_in_table = experiment(
+ schema,
+ vec![vec![batch.clone()], vec![batch.clone()]],
+ vec![
+ vec![batch.clone(), batch.clone()],
+ vec![batch.clone(), batch],
+ ],
+ )
+ .await?;
+ // Ensure that each partition in the table now contains three batches of data
+ assert_eq!(resulting_data_in_table[0].len(), 3);
+ assert_eq!(resulting_data_in_table[1].len(), 3);
+ Ok(())
+ }
- // insert the data from the 1 partition data source provider into the initial table
- initial_table
- .insert_into(&session_ctx.state(), &single_partition_table_scan)
- .await?;
-
- assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 2);
+ #[tokio::test]
+ async fn test_insert_from_empty_table() -> Result<()> {
+ // Create a new schema with one field called "a" of type Int32
+ let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
- // scan a data source provider from a memory table with 2 partition
- let single_partition_table_scan = create_mem_table_scan(
+ // Create a new batch of data to insert into the table
+ let batch = RecordBatch::try_new(
schema.clone(),
- vec![vec![batch.clone()], vec![batch]],
+ vec![Arc::new(Int32Array::from_slice([1, 2, 3]))],
)?;
- // We expect coalesce partitions here.
- initial_table
- .insert_into(&session_ctx.state(), &single_partition_table_scan)
- .await?;
- // Ensure that the table now contains 3 batches of data with 2 partitions.
- assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 4);
+ // Run the experiment and obtain the resulting data in the table
+ let resulting_data_in_table = experiment(
+ schema,
+ vec![vec![batch.clone(), batch.clone()]],
+ vec![vec![]],
+ )
+ .await?;
+ // Ensure that the table now contains two batches of data in the same partition
+ assert_eq!(resulting_data_in_table[0].len(), 2);
Ok(())
}
}
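The enabler for all of the above is the new PartitionData alias: each partition buffer is individually reference-counted and guarded by an async RwLock, so the table and any number of concurrently executing write streams can share it. A minimal sketch of that sharing, using only arrow and tokio (the local alias mirrors datafusion::datasource::memory::PartitionData):

    use std::sync::Arc;
    use arrow::array::Int32Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use tokio::sync::RwLock;

    // Mirrors the alias introduced in datafusion/core/src/datasource/memory.rs.
    type PartitionData = Arc<RwLock<Vec<RecordBatch>>>;

    #[tokio::main]
    async fn main() {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let batch =
            RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![1, 2, 3]))])
                .unwrap();

        // One partition buffer, two handles: the "table" keeps one and a
        // "writer" task gets the other; this is the sharing MemoryWriteExec
        // relies on instead of replacing the table's batches wholesale.
        let partition: PartitionData = Arc::new(RwLock::new(vec![]));
        let writer_handle = partition.clone();

        tokio::spawn(async move {
            // Equivalent to MemorySinkStream's Write state: lock, push, release.
            writer_handle.write().await.push(batch);
        })
        .await
        .unwrap();

        // Readers (e.g. a later scan) observe the appended batch.
        assert_eq!(partition.read().await.len(), 1);
    }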
diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs
index bb6d58fb90..dce8a1c424 100644
--- a/datafusion/core/src/execution/context.rs
+++ b/datafusion/core/src/execution/context.rs
@@ -33,7 +33,7 @@ use crate::{
};
use datafusion_expr::{
logical_plan::{DdlStatement, Statement},
- DescribeTable, DmlStatement, StringifiedPlan, WriteOp,
+ DescribeTable, StringifiedPlan,
};
pub use datafusion_physical_expr::execution_props::ExecutionProps;
use datafusion_physical_expr::var_provider::is_system_variables;
@@ -369,23 +369,6 @@ impl SessionContext {
/// Execute the [`LogicalPlan`], return a [`DataFrame`]
pub async fn execute_logical_plan(&self, plan: LogicalPlan) -> Result<DataFrame> {
match plan {
- LogicalPlan::Dml(DmlStatement {
- table_name,
- op: WriteOp::Insert,
- input,
- ..
- }) => {
- if self.table_exist(&table_name)? {
- let name = table_name.table();
- let provider = self.table_provider(name).await?;
- provider.insert_into(&self.state(), &input).await?;
- } else {
- return Err(DataFusionError::Execution(format!(
- "Table '{table_name}' does not exist"
- )));
- }
- self.return_empty_dataframe()
- }
LogicalPlan::Ddl(ddl) => match ddl {
DdlStatement::CreateExternalTable(cmd) => {
self.create_external_table(&cmd).await
@@ -1475,7 +1458,7 @@ impl SessionState {
.resolve(&catalog.default_catalog, &catalog.default_schema)
}
- fn schema_for_ref<'a>(
+ pub(crate) fn schema_for_ref<'a>(
&'a self,
table_ref: impl Into<TableReference<'a>>,
) -> Result<Arc<dyn SchemaProvider>> {
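With the DML branch removed, an INSERT statement flows through execute_logical_plan's default arm: planning returns an ordinary DataFrame, and the write only happens when the underlying physical plan (now rooted at a MemoryWriteExec) executes. An end-to-end sketch of that behavior, mirroring the experiment helper in the memory.rs tests above (the tables t and source are illustrative):

    use std::sync::Arc;
    use arrow::array::Int32Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use datafusion::prelude::SessionContext;

    #[tokio::main]
    async fn main() -> datafusion::error::Result<()> {
        let ctx = SessionContext::new();
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )?;
        ctx.register_table(
            "t",
            Arc::new(MemTable::try_new(schema.clone(), vec![vec![batch.clone()]])?),
        )?;
        ctx.register_table("source", Arc::new(MemTable::try_new(schema, vec![vec![batch]])?))?;

        // Planning the INSERT writes nothing by itself...
        let df = ctx.sql("INSERT INTO t SELECT a FROM source").await?;
        // ...the write happens here, during execution, and yields no batches.
        assert!(df.collect().await?.is_empty());

        // The target table now holds the source rows in addition to its own.
        let rows: usize = ctx
            .sql("SELECT a FROM t")
            .await?
            .collect()
            .await?
            .iter()
            .map(|b| b.num_rows())
            .sum();
        assert_eq!(rows, 6);
        Ok(())
    }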
diff --git a/datafusion/core/src/physical_plan/memory.rs b/datafusion/core/src/physical_plan/memory.rs
index f0cd48fa4f..12a37c65c8 100644
--- a/datafusion/core/src/physical_plan/memory.rs
+++ b/datafusion/core/src/physical_plan/memory.rs
@@ -17,11 +17,6 @@
//! Execution plan for reading in-memory batches of data
-use core::fmt;
-use std::any::Any;
-use std::sync::Arc;
-use std::task::{Context, Poll};
-
use super::expressions::PhysicalSortExpr;
use super::{
common, project_schema, DisplayFormatType, ExecutionPlan, Partitioning,
@@ -30,10 +25,20 @@ use super::{
use crate::error::Result;
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
+use core::fmt;
+use futures::FutureExt;
+use futures::StreamExt;
+use std::any::Any;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+use crate::datasource::memory::PartitionData;
use crate::execution::context::TaskContext;
+use crate::physical_plan::Distribution;
use datafusion_common::DataFusionError;
-use futures::Stream;
+use futures::{ready, Stream};
+use std::mem;
+use tokio::sync::{OwnedRwLockWriteGuard, RwLock};
/// Execution plan for reading in-memory batches of data
pub struct MemoryExec {
@@ -150,6 +155,23 @@ impl MemoryExec {
})
}
+ /// Create a new execution plan for reading in-memory record batches
+ /// The provided `schema` should not have the projection applied.
+ pub fn try_new_owned_data(
+ partitions: Vec<Vec<RecordBatch>>,
+ schema: SchemaRef,
+ projection: Option<Vec<usize>>,
+ ) -> Result<Self> {
+ let projected_schema = project_schema(&schema, projection.as_ref())?;
+ Ok(Self {
+ partitions,
+ schema,
+ projected_schema,
+ projection,
+ sort_information: None,
+ })
+ }
+
/// Set sort information
pub fn with_sort_information(
mut self,
@@ -223,15 +245,365 @@ impl RecordBatchStream for MemoryStream {
}
}
+/// Execution plan for writing record batches to an in-memory table.
+pub struct MemoryWriteExec {
+ /// Input plan that produces the record batches to be written.
+ input: Arc<dyn ExecutionPlan>,
+ /// Reference to the MemTable's partition data.
+ batches: Vec<PartitionData>,
+ /// Schema describing the structure of the data.
+ schema: SchemaRef,
+}
+
+impl fmt::Debug for MemoryWriteExec {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(f, "schema: {:?}", self.schema)
+ }
+}
+
+impl ExecutionPlan for MemoryWriteExec {
+ /// Return a reference to Any that can be used for downcasting
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ /// Get the schema for this execution plan
+ fn schema(&self) -> SchemaRef {
+ self.schema.clone()
+ }
+
+ fn output_partitioning(&self) -> Partitioning {
+ Partitioning::UnknownPartitioning(
+ self.input.output_partitioning().partition_count(),
+ )
+ }
+
+ fn benefits_from_input_partitioning(&self) -> bool {
+ false
+ }
+
+ fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
+ self.input.output_ordering()
+ }
+
+ fn required_input_distribution(&self) -> Vec<Distribution> {
+ // If the partition count of the MemTable is one, we want to require SinglePartition
+ // since it would induce better plans in the plan optimizer.
+ if self.batches.len() == 1 {
+ vec![Distribution::SinglePartition]
+ } else {
+ vec![Distribution::UnspecifiedDistribution]
+ }
+ }
+
+ fn maintains_input_order(&self) -> Vec<bool> {
+ // In theory, if the MemTable partition count equals the input plan's output partition
+ // count, the execution plan can preserve the order inside the partitions.
+ vec![self.batches.len() == self.input.output_partitioning().partition_count()]
+ }
+
+ fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
+ vec![self.input.clone()]
+ }
+
+ fn with_new_children(
+ self: Arc<Self>,
+ children: Vec<Arc<dyn ExecutionPlan>>,
+ ) -> Result<Arc<dyn ExecutionPlan>> {
+ Ok(Arc::new(MemoryWriteExec::try_new(
+ children[0].clone(),
+ self.batches.clone(),
+ self.schema.clone(),
+ )?))
+ }
+
+ /// Execute the plan and return a stream of record batches for the specified partition.
+ /// Depending on the relation between the input and MemTable partition counts, it
+ /// chooses either a single-lock (one-to-one) or a lock-per-batch implementation.
+ fn execute(
+ &self,
+ partition: usize,
+ context: Arc<TaskContext>,
+ ) -> Result<SendableRecordBatchStream> {
+ let batch_count = self.batches.len();
+ let data = self.input.execute(partition, context)?;
+ if batch_count >= self.input.output_partitioning().partition_count() {
+ // If the number of input partitions matches the number of MemTable partitions,
+ // use a lightweight implementation that doesn't utilize as many locks.
+ let table_partition = self.batches[partition].clone();
+ Ok(Box::pin(MemorySinkOneToOneStream::try_new(
+ table_partition,
+ data,
+ self.schema.clone(),
+ )?))
+ } else {
+ // Otherwise, use the locked implementation.
+ let table_partition = self.batches[partition % batch_count].clone();
+ Ok(Box::pin(MemorySinkStream::try_new(
+ table_partition,
+ data,
+ self.schema.clone(),
+ )?))
+ }
+ }
+
+ fn fmt_as(
+ &self,
+ t: DisplayFormatType,
+ f: &mut std::fmt::Formatter,
+ ) -> std::fmt::Result {
+ match t {
+ DisplayFormatType::Default => {
+ write!(
+ f,
+ "MemoryWriteExec: partitions={}, input_partition={}",
+ self.batches.len(),
+ self.input.output_partitioning().partition_count()
+ )
+ }
+ }
+ }
+
+ fn statistics(&self) -> Statistics {
+ Statistics::default()
+ }
+}
+
+impl MemoryWriteExec {
+ /// Create a new execution plan that writes its input's record batches into
+ /// the given in-memory table partitions.
+ pub fn try_new(
+ plan: Arc<dyn ExecutionPlan>,
+ batches: Vec<Arc<RwLock<Vec<RecordBatch>>>>,
+ schema: SchemaRef,
+ ) -> Result<Self> {
+ Ok(Self {
+ input: plan,
+ batches,
+ schema,
+ })
+ }
+}
+
+/// This object encodes the different states of the [`MemorySinkStream`] when
+/// processing record batches.
+enum MemorySinkStreamState {
+ /// The stream is pulling data from the input.
+ Pull,
+ /// The stream is writing data to the table partition.
+ Write { maybe_batch: Option<RecordBatch> },
+}
+
+/// A stream that saves record batches in memory-backed storage.
+/// It works even when multiple input partitions map to the same table
+/// partition, achieving buffer exclusivity by locking before each write.
+struct MemorySinkStream {
+ /// Stream of record batches to be inserted into the memory table.
+ data: SendableRecordBatchStream,
+ /// Memory table partition that stores the record batches.
+ table_partition: PartitionData,
+ /// Schema representing the structure of the data.
+ schema: SchemaRef,
+ /// State of the iterator when processing multiple polls.
+ state: MemorySinkStreamState,
+}
+
+impl MemorySinkStream {
+ /// Create a new `MemorySinkStream` with the provided parameters.
+ pub fn try_new(
+ table_partition: PartitionData,
+ data: SendableRecordBatchStream,
+ schema: SchemaRef,
+ ) -> Result<Self> {
+ Ok(Self {
+ table_partition,
+ data,
+ schema,
+ state: MemorySinkStreamState::Pull,
+ })
+ }
+
+ /// Implementation of the `poll_next` method. Continuously polls the record
+ /// batch stream, switching between the Pull and Write states. In case of
+ /// an error, returns the error immediately.
+ fn poll_next_impl(
+ &mut self,
+ cx: &mut std::task::Context<'_>,
+ ) -> Poll<Option<Result<RecordBatch>>> {
+ loop {
+ match &mut self.state {
+ MemorySinkStreamState::Pull => {
+ // Pull data from the input stream.
+ if let Some(result) = ready!(self.data.as_mut().poll_next(cx)) {
+ match result {
+ Ok(batch) => {
+ // Switch to the Write state with the received batch.
+ self.state = MemorySinkStreamState::Write {
+ maybe_batch: Some(batch),
+ }
+ }
+ Err(e) => return Poll::Ready(Some(Err(e))), // Return the error immediately.
+ }
+ } else {
+ return Poll::Ready(None); // If the input stream is exhausted, return None.
+ }
+ }
+ MemorySinkStreamState::Write { maybe_batch } => {
+ // Acquire a write lock on the table partition.
+ let mut partition =
+ ready!(self.table_partition.write().boxed().poll_unpin(cx));
+ if let Some(b) = mem::take(maybe_batch) {
+ partition.push(b); // Insert the batch into the table partition.
+ }
+ self.state = MemorySinkStreamState::Pull; // Switch back to the Pull state.
+ }
+ }
+ }
+ }
+}
+
+impl Stream for MemorySinkStream {
+ type Item = Result<RecordBatch>;
+
+ fn poll_next(
+ mut self: std::pin::Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ ) -> Poll<Option<Self::Item>> {
+ self.poll_next_impl(cx)
+ }
+}
+
+impl RecordBatchStream for MemorySinkStream {
+ /// Get the schema
+ fn schema(&self) -> SchemaRef {
+ self.schema.clone()
+ }
+}
+
+/// This object encodes the different states of the [`MemorySinkOneToOneStream`]
+/// when processing record batches.
+enum MemorySinkOneToOneStreamState {
+ /// The `Acquire` variant represents the state where the [`MemorySinkOneToOneStream`]
+ /// is waiting to acquire the write lock on the shared partition to store the record batches.
+ Acquire,
+
+ /// The `Pull` variant represents the state where the [`MemorySinkOneToOneStream`] has
+ /// acquired the write lock on the shared partition and can pull record batches from
+ /// the input stream to store in the partition.
+ Pull {
+ /// The `partition` field contains an [`OwnedRwLockWriteGuard`] which wraps the
+ /// shared partition, providing exclusive write access to the underlying `Vec<RecordBatch>`.
+ partition: OwnedRwLockWriteGuard<Vec<RecordBatch>>,
+ },
+}
+
+/// A stream that saves record batches in memory-backed storage.
+/// Assumes that every table partition has at most one corresponding input
+/// partition, so it locks the table partition only once.
+struct MemorySinkOneToOneStream {
+ /// Stream of record batches to be inserted into the memory table.
+ data: SendableRecordBatchStream,
+ /// Memory table partition that stores the record batches.
+ table_partition: PartitionData,
+ /// Schema representing the structure of the data.
+ schema: SchemaRef,
+ /// State of the iterator when processing multiple polls.
+ state: MemorySinkOneToOneStreamState,
+}
+
+impl MemorySinkOneToOneStream {
+ /// Create a new `MemorySinkOneToOneStream` with the provided parameters.
+ pub fn try_new(
+ table_partition: Arc<RwLock<Vec<RecordBatch>>>,
+ data: SendableRecordBatchStream,
+ schema: SchemaRef,
+ ) -> Result<Self> {
+ Ok(Self {
+ table_partition,
+ data,
+ schema,
+ state: MemorySinkOneToOneStreamState::Acquire,
+ })
+ }
+
+ /// Implementation of the `poll_next` method. Continuously polls the record
+ /// batch stream and pushes batches to their corresponding table partition,
+ /// whose lock is acquired only once. In case of an error, returns the
+ /// error immediately.
+ fn poll_next_impl(
+ &mut self,
+ cx: &mut std::task::Context<'_>,
+ ) -> Poll<Option<Result<RecordBatch>>> {
+ loop {
+ match &mut self.state {
+ MemorySinkOneToOneStreamState::Acquire => {
+ // Acquire a write lock on the table partition.
+ self.state = MemorySinkOneToOneStreamState::Pull {
+ partition: ready!(self
+ .table_partition
+ .clone()
+ .write_owned()
+ .boxed()
+ .poll_unpin(cx)),
+ };
+ }
+ MemorySinkOneToOneStreamState::Pull { partition } => {
+ // Iterate over the batches in the input data stream.
+ while let Some(result) = ready!(self.data.poll_next_unpin(cx)) {
+ match result {
+ Ok(batch) => {
+ partition.push(batch);
+ } // Insert the batch into the table partition.
+ Err(e) => return Poll::Ready(Some(Err(e))), // Return the error immediately.
+ }
+ }
+ // If the input stream is exhausted, return None to indicate the end of the stream.
+ return Poll::Ready(None);
+ }
+ }
+ }
+ }
+}
+
+impl Stream for MemorySinkOneToOneStream {
+ type Item = Result<RecordBatch>;
+
+ /// Poll the stream for the next item.
+ fn poll_next(
+ mut self: std::pin::Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ ) -> Poll<Option<Self::Item>> {
+ self.poll_next_impl(cx)
+ }
+}
+
+impl RecordBatchStream for MemorySinkOneToOneStream {
+ /// Get the schema of the record batches in the stream.
+ fn schema(&self) -> SchemaRef {
+ self.schema.clone()
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
+ use crate::datasource::streaming::PartitionStream;
+ use crate::datasource::{MemTable, TableProvider};
use crate::from_slice::FromSlice;
+ use crate::physical_plan::stream::RecordBatchStreamAdapter;
+ use crate::physical_plan::streaming::StreamingTableExec;
use crate::physical_plan::ColumnStatistics;
- use crate::prelude::SessionContext;
+ use crate::physical_plan::{collect, displayable, SendableRecordBatchStream};
+ use crate::prelude::{CsvReadOptions, SessionContext};
+ use crate::test_util;
use arrow::array::Int32Array;
- use arrow::datatypes::{DataType, Field, Schema};
+ use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+ use arrow::record_batch::RecordBatch;
+ use datafusion_common::Result;
+ use datafusion_execution::config::SessionConfig;
+ use datafusion_execution::TaskContext;
use futures::StreamExt;
+ use std::sync::Arc;
fn mock_data() -> Result<(SchemaRef, RecordBatch)> {
let schema = Arc::new(Schema::new(vec![
@@ -340,4 +712,262 @@ mod tests {
Ok(())
}
+
+ #[tokio::test]
+ async fn test_insert_into() -> Result<()> {
+ // Create session context
+ let config = SessionConfig::new().with_target_partitions(8);
+ let ctx = SessionContext::with_config(config);
+ let testdata = test_util::arrow_test_data();
+ let schema = test_util::aggr_test_schema();
+ ctx.register_csv(
+ "aggregate_test_100",
+ &format!("{testdata}/csv/aggregate_test_100.csv"),
+ CsvReadOptions::new().schema(&schema),
+ )
+ .await?;
+ ctx.sql(
+ "CREATE TABLE table_without_values(field1 BIGINT NULL, field2 BIGINT NULL)",
+ )
+ .await?;
+
+ let sql = "INSERT INTO table_without_values SELECT
+ SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),
+ COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING)
+ FROM aggregate_test_100
+ ORDER by c1
+ ";
+ let msg = format!("Creating logical plan for '{sql}'");
+ let dataframe = ctx.sql(sql).await.expect(&msg);
+ let physical_plan = dataframe.create_physical_plan().await?;
+ let formatted = displayable(physical_plan.as_ref()).indent().to_string();
+ let expected = {
+ vec![
+ "MemoryWriteExec: partitions=1, input_partition=1",
+ " ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@0 as field1, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@1 as field2]",
+ " SortPreservingMergeExec: [c1@2 ASC NULLS LAST]",
+ " ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as SUM(aggregate_test_100.c4), COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as COUNT(UInt8(1)), c1@0 as c1]",
+ " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4): Ok(Field { name: \"SUM(aggregate_test_100.c4)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Ro [...]
+ " SortExec: expr=[c1@0 ASC NULLS LAST,c9@2 ASC NULLS LAST]",
+ " CoalesceBatchesExec: target_batch_size=8192",
+ " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 8), input_partitions=8",
+ " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1",
+ ]
+ };
+
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ let actual_len = actual.len();
+ let actual_trim_last = &actual[..actual_len - 1];
+ assert_eq!(
+ expected, actual_trim_last,
+ "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+ );
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn test_insert_into_as_select_multi_partitioned() -> Result<()> {
+ // Create session context
+ let config = SessionConfig::new().with_target_partitions(8);
+ let ctx = SessionContext::with_config(config);
+ let testdata = test_util::arrow_test_data();
+ let schema = test_util::aggr_test_schema();
+ ctx.register_csv(
+ "aggregate_test_100",
+ &format!("{testdata}/csv/aggregate_test_100.csv"),
+ CsvReadOptions::new().schema(&schema),
+ )
+ .await?;
+ ctx.sql(
+ "CREATE TABLE table_without_values(field1 BIGINT NULL, field2 BIGINT NULL)",
+ )
+ .await?;
+
+ let sql = "INSERT INTO table_without_values SELECT
+ SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a1,
+ COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a2
+ FROM aggregate_test_100";
+ let msg = format!("Creating logical plan for '{sql}'");
+ let dataframe = ctx.sql(sql).await.expect(&msg);
+ let physical_plan = dataframe.create_physical_plan().await?;
+ let formatted = displayable(physical_plan.as_ref()).indent().to_string();
+ let expected = {
+ vec![
+ "MemoryWriteExec: partitions=1, input_partition=1",
+ " CoalescePartitionsExec",
+ " ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as field1, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as field2]",
+ " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4): Ok(Field { name: \"SUM(aggregate_test_100.c4)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows [...]
+ " SortExec: expr=[c1@0 ASC NULLS LAST,c9@2 ASC NULLS LAST]",
+ " CoalesceBatchesExec: target_batch_size=8192",
+ " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 8), input_partitions=8",
+ " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1",
+ ]
+ };
+
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ let actual_len = actual.len();
+ let actual_trim_last = &actual[..actual_len - 1];
+ assert_eq!(
+ expected, actual_trim_last,
+ "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+ );
+ Ok(())
+ }
+
+ // TODO: The generated plan is suboptimal since SortExec is in global state.
+ #[tokio::test]
+ async fn test_insert_into_as_select_single_partition() -> Result<()> {
+ // Create session context
+ let config = SessionConfig::new().with_target_partitions(8);
+ let ctx = SessionContext::with_config(config);
+ let testdata = test_util::arrow_test_data();
+ let schema = test_util::aggr_test_schema();
+ ctx.register_csv(
+ "aggregate_test_100",
+ &format!("{testdata}/csv/aggregate_test_100.csv"),
+ CsvReadOptions::new().schema(&schema),
+ )
+ .await?;
+ ctx.sql("CREATE TABLE table_without_values AS SELECT
+ SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a1,
+ COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a2
+ FROM aggregate_test_100")
+ .await?;
+
+ let sql = "INSERT INTO table_without_values SELECT
+ SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a1,
+ COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a2
+ FROM aggregate_test_100
+ ORDER BY c1";
+ let msg = format!("Creating logical plan for '{sql}'");
+ let dataframe = ctx.sql(sql).await.expect(&msg);
+ let physical_plan = dataframe.create_physical_plan().await?;
+ let formatted = displayable(physical_plan.as_ref()).indent().to_string();
+ let expected = {
+ vec![
+ "MemoryWriteExec: partitions=8, input_partition=8",
+ " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1",
+ " ProjectionExec: expr=[a1@0 as a1, a2@1 as a2]",
+ " SortPreservingMergeExec: [c1@2 ASC NULLS LAST]",
+ " ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as a1, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as a2, c1@0 as c1]",
+ " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4): Ok(Field { name: \"SUM(aggregate_test_100.c4)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: [...]
+ " SortExec: expr=[c1@0 ASC NULLS LAST,c9@2 ASC NULLS LAST]",
+ " CoalesceBatchesExec: target_batch_size=8192",
+ " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 8), input_partitions=8",
+ " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1",
+ ]
+ };
+
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ let actual_len = actual.len();
+ let actual_trim_last = &actual[..actual_len - 1];
+ assert_eq!(
+ expected, actual_trim_last,
+ "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+ );
+ Ok(())
+ }
+
+ // DummyPartition is a simple implementation of the PartitionStream trait.
+ // It produces a stream of record batches with a fixed schema and the same content.
+ struct DummyPartition {
+ schema: SchemaRef,
+ batch: RecordBatch,
+ num_batches: usize,
+ }
+
+ impl PartitionStream for DummyPartition {
+ // Return a reference to the schema of this partition.
+ fn schema(&self) -> &SchemaRef {
+ &self.schema
+ }
+
+ // Execute the partition stream, producing a stream of record batches.
+ fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
+ let batches = itertools::repeat_n(self.batch.clone(), self.num_batches);
+ Box::pin(RecordBatchStreamAdapter::new(
+ self.schema.clone(),
+ futures::stream::iter(batches).map(Ok),
+ ))
+ }
+ }
+
+ // Test the one-to-one (single lock acquisition) mode by inserting a large number of batches into a table.
+ #[tokio::test]
+ async fn test_one_to_one_mode() -> Result<()> {
+ let num_batches = 10000;
+ // Create a new session context
+ let session_ctx = SessionContext::new();
+ // Create a new schema with one field called "a" of type Int32
+ let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
+
+ // Create a new batch of data to insert into the table
+ let batch = RecordBatch::try_new(
+ schema.clone(),
+ vec![Arc::new(Int32Array::from_slice([1, 2, 3]))],
+ )?;
+ let initial_table = Arc::new(MemTable::try_new(schema.clone(), vec![vec![]])?);
+
+ let single_partition = Arc::new(DummyPartition {
+ schema: schema.clone(),
+ batch,
+ num_batches,
+ });
+ let input = Arc::new(StreamingTableExec::try_new(
+ schema.clone(),
+ vec![single_partition],
+ None,
+ false,
+ )?);
+ let plan = initial_table
+ .insert_into(&session_ctx.state(), input)
+ .await?;
+ let res = collect(plan, session_ctx.task_ctx()).await?;
+ assert!(res.is_empty());
+ // Ensure that the table's single partition now contains all the inserted batches
+ assert_eq!(initial_table.batches[0].read().await.len(), num_batches);
+ Ok(())
+ }
+
+ // Test the locked mode by inserting a large number of batches into a table. It covers
+ // the case where the table partition count differs from the input's output partition count.
+ #[tokio::test]
+ async fn test_locked_mode() -> Result<()> {
+ let num_batches = 10000;
+ // Create a new session context
+ let session_ctx = SessionContext::new();
+ // Create a new schema with one field called "a" of type Int32
+ let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
+
+ // Create a new batch of data to insert into the table
+ let batch = RecordBatch::try_new(
+ schema.clone(),
+ vec![Arc::new(Int32Array::from_slice([1, 2, 3]))],
+ )?;
+ let initial_table = Arc::new(MemTable::try_new(schema.clone(), vec![vec![]])?);
+
+ let single_partition = Arc::new(DummyPartition {
+ schema: schema.clone(),
+ batch,
+ num_batches,
+ });
+ let input = Arc::new(StreamingTableExec::try_new(
+ schema.clone(),
+ vec![
+ single_partition.clone(),
+ single_partition.clone(),
+ single_partition,
+ ],
+ None,
+ false,
+ )?);
+ let plan = initial_table
+ .insert_into(&session_ctx.state(), input)
+ .await?;
+ let res = collect(plan, session_ctx.task_ctx()).await?;
+ assert!(res.is_empty());
+ // Ensure that the table's single partition received the batches from all three input partitions
+ assert_eq!(initial_table.batches[0].read().await.len(), num_batches * 3);
+ Ok(())
+ }
}
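The one-to-one sink's efficiency comes from tokio's RwLock::write_owned, which yields an OwnedRwLockWriteGuard that can be stored inside the stream's state and held across polls, so the lock is taken once per partition instead of once per batch. Stripped of the hand-rolled poll_next state machine, the same idea reads as follows (a simplified sketch with i32 items standing in for record batches):

    use std::sync::Arc;
    use futures::StreamExt;
    use tokio::sync::RwLock;

    #[tokio::main]
    async fn main() {
        let partition: Arc<RwLock<Vec<i32>>> = Arc::new(RwLock::new(vec![]));
        let mut input = futures::stream::iter(0..5); // stands in for the input record batch stream

        // Acquire the write lock once (the Acquire state) and hold the owned
        // guard while pulling every item (the Pull state).
        let mut guard = partition.clone().write_owned().await;
        while let Some(item) = input.next().await {
            guard.push(item);
        }
        drop(guard); // release only after the input is exhausted

        assert_eq!(partition.read().await.len(), 5);
    }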
diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs
index 782dcf1335..7f68d5a39a 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -67,7 +67,7 @@ use datafusion_expr::expr::{
};
use datafusion_expr::expr_rewriter::unnormalize_cols;
use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary;
-use datafusion_expr::{logical_plan, StringifiedPlan};
+use datafusion_expr::{logical_plan, DmlStatement, StringifiedPlan, WriteOp};
use datafusion_expr::{WindowFrame, WindowFrameBound};
use datafusion_optimizer::utils::unalias;
use datafusion_physical_expr::expressions::Literal;
@@ -489,6 +489,23 @@ impl DefaultPhysicalPlanner {
let unaliased: Vec<Expr> = filters.into_iter().map(unalias).collect();
source.scan(session_state, projection.as_ref(), &unaliased, *fetch).await
}
+ LogicalPlan::Dml(DmlStatement {
+ table_name,
+ op: WriteOp::Insert,
+ input,
+ ..
+ }) => {
+ let name = table_name.table();
+ let schema = session_state.schema_for_ref(table_name)?;
+ if let Some(provider) = schema.table(name).await {
+ let input_exec = self.create_initial_plan(input, session_state).await?;
+ provider.insert_into(session_state, input_exec).await
+ } else {
+ return Err(DataFusionError::Execution(format!(
+ "Table '{table_name}' does not exist"
+ )));
+ }
+ }
LogicalPlan::Values(Values {
values,
schema,