You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/12/05 16:03:27 UTC

[arrow-datafusion] branch master updated: Remove interior mutability of `MemTable` (#4514)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 0fca6d5de Remove interior mutability of `MemTable` (#4514)
0fca6d5de is described below

commit 0fca6d5de60c40d422816cd656919eaf6bcb8442
Author: xudong.w <wx...@gmail.com>
AuthorDate: Tue Dec 6 00:03:20 2022 +0800

    Remove interior mutability of `MemTable` (#4514)
    
    * Remove interior mutability of MemTable
    
    * remove insert_batches
---
 datafusion/core/src/datasource/memory.rs           | 13 +++-------
 datafusion/core/tests/sqllogictests/src/error.rs   | 10 ++++++++
 .../core/tests/sqllogictests/src/insert/mod.rs     | 28 ++++++++++++----------
 3 files changed, 28 insertions(+), 23 deletions(-)

diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs
index a80c4b94d..632ef8d28 100644
--- a/datafusion/core/src/datasource/memory.rs
+++ b/datafusion/core/src/datasource/memory.rs
@@ -26,7 +26,6 @@ use std::sync::Arc;
 use arrow::datatypes::SchemaRef;
 use arrow::record_batch::RecordBatch;
 use async_trait::async_trait;
-use parking_lot::RwLock;
 
 use crate::datasource::{TableProvider, TableType};
 use crate::error::{DataFusionError, Result};
@@ -41,7 +40,7 @@ use crate::physical_plan::{repartition::RepartitionExec, Partitioning};
 #[derive(Debug)]
 pub struct MemTable {
     schema: SchemaRef,
-    batches: Arc<RwLock<Vec<Vec<RecordBatch>>>>,
+    batches: Vec<Vec<RecordBatch>>,
 }
 
 impl MemTable {
@@ -54,7 +53,7 @@ impl MemTable {
         {
             Ok(Self {
                 schema,
-                batches: Arc::new(RwLock::new(partitions)),
+                batches: partitions,
             })
         } else {
             Err(DataFusionError::Plan(
@@ -118,11 +117,6 @@ impl MemTable {
         }
         MemTable::try_new(schema.clone(), data)
     }
-
-    /// Get record batches in MemTable
-    pub fn get_batches(&self) -> Arc<RwLock<Vec<Vec<RecordBatch>>>> {
-        self.batches.clone()
-    }
 }
 
 #[async_trait]
@@ -146,9 +140,8 @@ impl TableProvider for MemTable {
         _filters: &[Expr],
         _limit: Option<usize>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        let batches = self.batches.read();
         Ok(Arc::new(MemoryExec::try_new(
-            &(*batches).clone(),
+            &self.batches.clone(),
             self.schema(),
             projection.cloned(),
         )?))
diff --git a/datafusion/core/tests/sqllogictests/src/error.rs b/datafusion/core/tests/sqllogictests/src/error.rs
index 5324e8f88..0b073870d 100644
--- a/datafusion/core/tests/sqllogictests/src/error.rs
+++ b/datafusion/core/tests/sqllogictests/src/error.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use arrow::error::ArrowError;
 use datafusion_common::DataFusionError;
 use sqllogictest::TestError;
 use sqlparser::parser::ParserError;
@@ -32,6 +33,8 @@ pub enum DFSqlLogicTestError {
     DataFusion(DataFusionError),
     /// Error returned when SQL is syntactically incorrect.
     Sql(ParserError),
+    /// Error from arrow-rs
+    Arrow(ArrowError),
 }
 
 impl From<TestError> for DFSqlLogicTestError {
@@ -52,6 +55,12 @@ impl From<ParserError> for DFSqlLogicTestError {
     }
 }
 
+impl From<ArrowError> for DFSqlLogicTestError {
+    fn from(value: ArrowError) -> Self {
+        DFSqlLogicTestError::Arrow(value)
+    }
+}
+
 impl Display for DFSqlLogicTestError {
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
         match self {
@@ -64,6 +73,7 @@ impl Display for DFSqlLogicTestError {
                 write!(f, "DataFusion error: {}", error)
             }
             DFSqlLogicTestError::Sql(error) => write!(f, "SQL Parser error: {}", error),
+            DFSqlLogicTestError::Arrow(error) => write!(f, "Arrow error: {}", error),
         }
     }
 }
diff --git a/datafusion/core/tests/sqllogictests/src/insert/mod.rs b/datafusion/core/tests/sqllogictests/src/insert/mod.rs
index 025015f5f..a8f24a051 100644
--- a/datafusion/core/tests/sqllogictests/src/insert/mod.rs
+++ b/datafusion/core/tests/sqllogictests/src/insert/mod.rs
@@ -19,6 +19,7 @@ mod util;
 
 use crate::error::Result;
 use crate::insert::util::LogicTestContextProvider;
+use arrow::record_batch::RecordBatch;
 use datafusion::datasource::MemTable;
 use datafusion::prelude::SessionContext;
 use datafusion_common::{DFSchema, DataFusionError};
@@ -26,6 +27,7 @@ use datafusion_expr::Expr as DFExpr;
 use datafusion_sql::planner::SqlToRel;
 use sqlparser::ast::{Expr, SetExpr, Statement as SQLStatement};
 use std::collections::HashMap;
+use std::sync::Arc;
 
 pub async fn insert(ctx: &SessionContext, insert_stmt: &SQLStatement) -> Result<String> {
     // First, use sqlparser to get table name and insert values
@@ -52,19 +54,14 @@ pub async fn insert(ctx: &SessionContext, insert_stmt: &SQLStatement) -> Result<
         _ => unreachable!(),
     }
 
-    // Second, get table by table name
-    // Here we assume table must be in memory table.
-    let table_provider = ctx.table_provider(table_name.as_str())?;
-    let table_batches = table_provider
-        .as_any()
-        .downcast_ref::<MemTable>()
-        .unwrap()
-        .get_batches();
+    // Second, get batches in table and destroy the old table
+    let mut origin_batches = ctx.table(table_name.as_str())?.collect().await?;
+    let schema = ctx.table_provider(table_name.as_str())?.schema();
+    ctx.deregister_table(table_name.as_str())?;
 
     // Third, transfer insert values to `RecordBatch`
     // Attention: schema info can be ignored. (insert values don't contain schema info)
     let sql_to_rel = SqlToRel::new(&LogicTestContextProvider {});
-    let mut insert_batches = Vec::with_capacity(insert_values.len());
     for row in insert_values.into_iter() {
         let logical_exprs = row
             .into_iter()
@@ -74,12 +71,17 @@ pub async fn insert(ctx: &SessionContext, insert_stmt: &SQLStatement) -> Result<
             .collect::<std::result::Result<Vec<DFExpr>, DataFusionError>>()?;
         // Directly use `select` to get `RecordBatch`
         let dataframe = ctx.read_empty()?;
-        insert_batches.push(dataframe.select(logical_exprs)?.collect().await?)
+        origin_batches.extend(dataframe.select(logical_exprs)?.collect().await?)
     }
 
-    // Final, append the `RecordBatch` to memtable's batches
-    let mut table_batches = table_batches.write();
-    table_batches.extend(insert_batches);
+    // Replace new batches schema to old schema
+    for batch in origin_batches.iter_mut() {
+        *batch = RecordBatch::try_new(schema.clone(), batch.columns().to_vec())?;
+    }
+
+    // Final, create new memtable with same schema.
+    let new_provider = MemTable::try_new(schema, vec![origin_batches])?;
+    ctx.register_table(table_name.as_str(), Arc::new(new_provider))?;
 
     Ok("".to_string())
 }