Posted to commits@arrow.apache.org by xu...@apache.org on 2022/12/04 17:46:15 UTC

[arrow-datafusion] branch master updated: Refactor code for `insert` in sqllogictest (#4503)

This is an automated email from the ASF dual-hosted git repository.

xudong963 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 4f7b0038c Refactor code for `insert` in sqllogictest (#4503)
4f7b0038c is described below

commit 4f7b0038cb66c94f8a5ce91de7dfefb5409172c1
Author: xudong.w <wx...@gmail.com>
AuthorDate: Mon Dec 5 01:46:10 2022 +0800

    Refactor code for `insert` in sqllogictest (#4503)
    
    * Refactor code in insert for sqllogictest
    
    * refactor error handle
    
    * fix clippy
---
 datafusion/core/src/execution/context.rs           | 24 ++++++----------
 datafusion/core/tests/sqllogictests/src/error.rs   | 12 --------
 .../core/tests/sqllogictests/src/insert/mod.rs     | 33 ++++++++--------------
 datafusion/core/tests/sqllogictests/src/main.rs    | 12 +++++---
 4 files changed, 27 insertions(+), 54 deletions(-)

diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs
index 9115b2efb..a9eb42bab 100644
--- a/datafusion/core/src/execution/context.rs
+++ b/datafusion/core/src/execution/context.rs
@@ -940,22 +940,14 @@ impl SessionContext {
         table_ref: impl Into<TableReference<'a>>,
     ) -> Result<Arc<DataFrame>> {
         let table_ref = table_ref.into();
-        let schema = self.state.read().schema_for_ref(table_ref)?;
-        match schema.table(table_ref.table()) {
-            Some(ref provider) => {
-                let plan = LogicalPlanBuilder::scan(
-                    table_ref.table(),
-                    provider_as_source(Arc::clone(provider)),
-                    None,
-                )?
-                .build()?;
-                Ok(Arc::new(DataFrame::new(self.state.clone(), &plan)))
-            }
-            _ => Err(DataFusionError::Plan(format!(
-                "No table named '{}'",
-                table_ref.table()
-            ))),
-        }
+        let provider = self.table_provider(table_ref)?;
+        let plan = LogicalPlanBuilder::scan(
+            table_ref.table(),
+            provider_as_source(Arc::clone(&provider)),
+            None,
+        )?
+        .build()?;
+        Ok(Arc::new(DataFrame::new(self.state.clone(), &plan)))
     }
 
     /// Return a [`TabelProvider`] for the specified table.
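
Editorial note: the context.rs hunk above is a straightforward deduplication. The Option handling and the "No table named '...'" message move behind the existing `table_provider` accessor, so `table` keeps only the plan-building part and propagates lookup failures with `?`. The same shape in a self-contained sketch, where `Registry`, `lookup` and `open` are hypothetical stand-ins for `SessionContext`, `table_provider` and the `LogicalPlanBuilder::scan` call:

    use std::collections::HashMap;

    #[derive(Debug)]
    struct PlanError(String);

    struct Registry {
        // Stand-in for the table providers registered with SessionContext.
        tables: HashMap<String, Vec<i64>>,
    }

    impl Registry {
        // Counterpart of `table_provider`: the lookup-or-error logic lives in
        // exactly one place and owns the "No table named ..." message.
        fn lookup(&self, name: &str) -> Result<&Vec<i64>, PlanError> {
            self.tables
                .get(name)
                .ok_or_else(|| PlanError(format!("No table named '{}'", name)))
        }

        // Counterpart of `table` after the refactor: delegate, propagate the
        // error with `?`, and keep only the code that uses the provider.
        fn open(&self, name: &str) -> Result<Vec<i64>, PlanError> {
            let provider = self.lookup(name)?;
            Ok(provider.clone()) // stands in for building the scan plan
        }
    }

    fn main() {
        let mut tables = HashMap::new();
        tables.insert("t".to_string(), vec![1, 2, 3]);
        let registry = Registry { tables };
        assert!(registry.open("t").is_ok());
        assert!(registry.open("missing").is_err());
    }

Call-site behaviour stays the same as long as `table_provider` reports a missing table with an equivalent error; what disappears is the duplicated lookup logic.
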
diff --git a/datafusion/core/tests/sqllogictests/src/error.rs b/datafusion/core/tests/sqllogictests/src/error.rs
index 8ac482141..5324e8f88 100644
--- a/datafusion/core/tests/sqllogictests/src/error.rs
+++ b/datafusion/core/tests/sqllogictests/src/error.rs
@@ -32,12 +32,6 @@ pub enum DFSqlLogicTestError {
     DataFusion(DataFusionError),
     /// Error returned when SQL is syntactically incorrect.
     Sql(ParserError),
-    /// Error returned on a branch that we know it is possible
-    /// but to which we still have no implementation for.
-    /// Often, these errors are tracked in our issue tracker.
-    NotImplemented(String),
-    /// Error returned from DFSqlLogicTest inner
-    Internal(String),
 }
 
 impl From<TestError> for DFSqlLogicTestError {
@@ -70,12 +64,6 @@ impl Display for DFSqlLogicTestError {
                 write!(f, "DataFusion error: {}", error)
             }
             DFSqlLogicTestError::Sql(error) => write!(f, "SQL Parser error: {}", error),
-            DFSqlLogicTestError::NotImplemented(error) => {
-                write!(f, "This feature is not implemented yet: {}", error)
-            }
-            DFSqlLogicTestError::Internal(error) => {
-                write!(f, "Internal error: {}", error)
-            }
         }
     }
 }
diff --git a/datafusion/core/tests/sqllogictests/src/insert/mod.rs b/datafusion/core/tests/sqllogictests/src/insert/mod.rs
index 100fa1184..025015f5f 100644
--- a/datafusion/core/tests/sqllogictests/src/insert/mod.rs
+++ b/datafusion/core/tests/sqllogictests/src/insert/mod.rs
@@ -17,28 +17,26 @@
 
 mod util;
 
-use crate::error::{DFSqlLogicTestError, Result};
+use crate::error::Result;
 use crate::insert::util::LogicTestContextProvider;
 use datafusion::datasource::MemTable;
 use datafusion::prelude::SessionContext;
 use datafusion_common::{DFSchema, DataFusionError};
 use datafusion_expr::Expr as DFExpr;
-use datafusion_sql::parser::{DFParser, Statement};
 use datafusion_sql::planner::SqlToRel;
 use sqlparser::ast::{Expr, SetExpr, Statement as SQLStatement};
 use std::collections::HashMap;
 
-pub async fn insert(ctx: &SessionContext, sql: String) -> Result<String> {
+pub async fn insert(ctx: &SessionContext, insert_stmt: &SQLStatement) -> Result<String> {
     // First, use sqlparser to get table name and insert values
-    let mut table_name = "".to_string();
-    let mut insert_values: Vec<Vec<Expr>> = vec![];
-    if let Statement::Statement(statement) = &DFParser::parse_sql(&sql)?[0] {
-        if let SQLStatement::Insert {
+    let table_name;
+    let insert_values: Vec<Vec<Expr>>;
+    match insert_stmt {
+        SQLStatement::Insert {
             table_name: name,
             source,
             ..
-        } = &**statement
-        {
+        } => {
             // Todo: check columns match table schema
             table_name = name.to_string();
             match &*source.body {
@@ -46,17 +44,12 @@ pub async fn insert(ctx: &SessionContext, sql: String) -> Result<String> {
                     insert_values = values.0.clone();
                 }
                 _ => {
-                    return Err(DFSqlLogicTestError::NotImplemented(
-                        "Only support insert values".to_string(),
-                    ));
+                    // Directly panic: make it easy to find the location of the error.
+                    panic!()
                 }
             }
         }
-    } else {
-        return Err(DFSqlLogicTestError::Internal(format!(
-            "{:?} not an insert statement",
-            sql
-        )));
+        _ => unreachable!(),
     }
 
     // Second, get table by table name
@@ -65,11 +58,7 @@ pub async fn insert(ctx: &SessionContext, sql: String) -> Result<String> {
     let table_batches = table_provider
         .as_any()
         .downcast_ref::<MemTable>()
-        .ok_or_else(|| {
-            DFSqlLogicTestError::NotImplemented(
-                "only support use memory table in logictest".to_string(),
-            )
-        })?
+        .unwrap()
         .get_batches();
 
     // Third, transfer insert values to `RecordBatch`
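
Editorial note: two things happen in insert/mod.rs. First, the helper now receives the statement already parsed by its caller, so the `DFParser` round-trip and the "not an insert statement" branch disappear. Second, unsupported shapes in this test-only code now panic (`panic!`, `unreachable!`, `unwrap`) instead of flowing through the removed `NotImplemented`/`Internal` variants, so a failure points straight at the offending location. A minimal illustration of that second choice, with a hypothetical `Source` enum standing in for sqlparser's `SetExpr`:

    // Hypothetical stand-in for the slice of sqlparser's AST the helper reads.
    enum Source {
        Values(Vec<Vec<i64>>),
        Query(String),
    }

    // Test-harness style: an unsupported input is a bug in the test setup, so a
    // panic with an exact location is more useful than an error value that the
    // runner reports far away from the cause.
    fn extract_values(source: &Source) -> Vec<Vec<i64>> {
        match source {
            Source::Values(rows) => rows.clone(),
            // Directly panic: easy to find the location of the unsupported case.
            Source::Query(_) => panic!("sqllogictest insert only supports VALUES"),
        }
    }

    fn main() {
        let rows = extract_values(&Source::Values(vec![vec![1, 2], vec![3, 4]]));
        assert_eq!(rows.len(), 2);
    }

The same reasoning applies to the `unwrap()` on the `MemTable` downcast above: the sqllogictest setup only registers memory tables, so a failed downcast is a harness bug rather than a condition to report.
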
diff --git a/datafusion/core/tests/sqllogictests/src/main.rs b/datafusion/core/tests/sqllogictests/src/main.rs
index 813a180e8..5fa261b36 100644
--- a/datafusion/core/tests/sqllogictests/src/main.rs
+++ b/datafusion/core/tests/sqllogictests/src/main.rs
@@ -19,7 +19,9 @@ use async_trait::async_trait;
 use datafusion::arrow::csv::WriterBuilder;
 use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::prelude::{SessionConfig, SessionContext};
+use datafusion_sql::parser::{DFParser, Statement};
 use normalize::normalize_batch;
+use sqlparser::ast::Statement as SQLStatement;
 use std::path::Path;
 use std::time::Duration;
 
@@ -144,10 +146,12 @@ fn format_batches(batches: Vec<RecordBatch>) -> Result<String> {
 async fn run_query(ctx: &SessionContext, sql: impl Into<String>) -> Result<String> {
     let sql = sql.into();
     // Check if the sql is `insert`
-    if sql.trim_start().to_lowercase().starts_with("insert") {
-        // Process the insert statement
-        insert(ctx, sql).await?;
-        return Ok("".to_string());
+    if let Ok(statements) = DFParser::parse_sql(&sql) {
+        if let Statement::Statement(statement) = &statements[0] {
+            if let SQLStatement::Insert { .. } = &**statement {
+                return insert(ctx, statement).await;
+            }
+        }
     }
     let df = ctx.sql(sql.as_str()).await.unwrap();
     let results: Vec<RecordBatch> = df.collect().await.unwrap();
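
Editorial note: the main.rs hunk stops sniffing for the string "insert" and instead parses the SQL once, matches on the Insert AST node, and hands the already-parsed statement straight to `insert()`. The idea of dispatching on the parsed form rather than the raw text can be shown against the sqlparser crate alone; in the sketch below `is_insert` is a hypothetical helper, not part of the patch, and the exact shape of the `Insert` variant differs between sqlparser versions (the `{ .. }` pattern matches either form):

    use sqlparser::ast::Statement;
    use sqlparser::dialect::GenericDialect;
    use sqlparser::parser::Parser;

    // True only when the first statement parses as an INSERT; unlike a text
    // prefix check, this also tolerates leading comments and guarantees the
    // statement is well-formed SQL before it is handed to the insert path.
    fn is_insert(sql: &str) -> bool {
        match Parser::parse_sql(&GenericDialect {}, sql) {
            Ok(statements) => matches!(statements.first(), Some(Statement::Insert { .. })),
            Err(_) => false,
        }
    }

    fn main() {
        assert!(is_insert("INSERT INTO t VALUES (1)"));
        // A leading comment defeats starts_with("insert"); the AST check does not.
        assert!(is_insert("/* setup */ insert into t values (2)"));
        assert!(!is_insert("SELECT 1"));
    }

In the patched run_query the parse result is then reused, which is the other benefit of this approach: `insert()` no longer has to parse the same SQL a second time.
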