You are viewing a plain-text version of this content. Its canonical link was a hyperlink that is not preserved in this text version.
Posted to commits@arrow.apache.org by av...@apache.org on 2023/01/23 17:20:27 UTC

[arrow-datafusion] branch master updated: Infer values for inserts (#4977)

This is an automated email from the ASF dual-hosted git repository.

avantgardner pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 5d4038a84 Infer values for inserts (#4977)
5d4038a84 is described below

commit 5d4038a8463a575328bedbc22b32456f5dcd562c
Author: Brent Gardner <bg...@squarelabs.net>
AuthorDate: Mon Jan 23 10:20:21 2023 -0700

    Infer values for inserts (#4977)
    
    * Infer values for updates
    
    Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
---
 datafusion/sql/src/statement.rs          | 33 ++++++++++++++-
 datafusion/sql/tests/integration_test.rs | 69 ++++++++++++++++++++++++++++++++
 2 files changed, 100 insertions(+), 2 deletions(-)

diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs
index 616619a1b..5b7949d2d 100644
--- a/datafusion/sql/src/statement.rs
+++ b/datafusion/sql/src/statement.rs
@@ -40,7 +40,7 @@ use datafusion_expr::{
 };
 use sqlparser::ast;
 use sqlparser::ast::{
-    Assignment, Expr as SQLExpr, Expr, Ident, ObjectName, ObjectType, Query,
+    Assignment, Expr as SQLExpr, Expr, Ident, ObjectName, ObjectType, Query, SetExpr,
     ShowCreateObject, ShowStatementFilter, Statement, TableFactor, TableWithJoins,
     UnaryOperator, Value,
 };
@@ -762,8 +762,37 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         let arrow_schema = (*provider.schema()).clone();
         let table_schema = Arc::new(DFSchema::try_from(arrow_schema)?);
 
+        // infer types for Values clause... other types should be resolvable the regular way
+        let mut prepare_param_data_types = BTreeMap::new();
+        if let SetExpr::Values(ast::Values { rows, .. }) = (*source.body).clone() {
+            for row in rows.iter() {
+                for (idx, val) in row.iter().enumerate() {
+                    if let ast::Expr::Value(Value::Placeholder(name)) = val {
+                        let name =
+                            name.replace('$', "").parse::<usize>().map_err(|_| {
+                                DataFusionError::Plan(format!(
+                                    "Can't parse placeholder: {name}"
+                                ))
+                            })? - 1;
+                        let col = columns.get(idx).ok_or_else(|| {
+                            DataFusionError::Plan(format!(
+                                "Placeholder ${} refers to a non existent column",
+                                idx + 1
+                            ))
+                        })?;
+                        let field =
+                            table_schema.field_with_name(None, col.value.as_str())?;
+                        let dt = field.field().data_type().clone();
+                        let _ = prepare_param_data_types.insert(name, dt);
+                    }
+                }
+            }
+        }
+        let prepare_param_data_types = prepare_param_data_types.into_values().collect();
+
         // Projection
-        let mut planner_context = PlannerContext::new();
+        let mut planner_context =
+            PlannerContext::new_with_prepare_param_data_types(prepare_param_data_types);
         let source = self.query_to_plan(*source, &mut planner_context)?;
         if columns.len() != source.schema().fields().len() {
             Err(DataFusionError::Plan(
diff --git a/datafusion/sql/tests/integration_test.rs b/datafusion/sql/tests/integration_test.rs
index e93ec8712..c771a3ec5 100644
--- a/datafusion/sql/tests/integration_test.rs
+++ b/datafusion/sql/tests/integration_test.rs
@@ -3390,6 +3390,75 @@ Dml: op=[Update] table=[person]
     prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan);
 }
 
+#[test]
+fn test_prepare_statement_insert_infer() {
+    let sql = "insert into person (id, first_name, last_name) values ($1, $2, $3)";
+
+    let expected_plan = r#"
+Dml: op=[Insert] table=[person]
+  Projection: column1 AS id, column2 AS first_name, column3 AS last_name
+    Values: ($1, $2, $3)
+        "#
+    .trim();
+
+    let expected_dt = "[Int32]";
+    let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+
+    let actual_types = plan.get_parameter_types().unwrap();
+    let expected_types = HashMap::from([
+        ("$1".to_string(), Some(DataType::UInt32)),
+        ("$2".to_string(), Some(DataType::Utf8)),
+        ("$3".to_string(), Some(DataType::Utf8)),
+    ]);
+    assert_eq!(actual_types, expected_types);
+
+    // replace params with values
+    let param_values = vec![
+        ScalarValue::UInt32(Some(1)),
+        ScalarValue::Utf8(Some("Alan".to_string())),
+        ScalarValue::Utf8(Some("Turing".to_string())),
+    ];
+    let expected_plan = r#"
+Dml: op=[Insert] table=[person]
+  Projection: column1 AS id, column2 AS first_name, column3 AS last_name
+    Values: (UInt32(1), Utf8("Alan"), Utf8("Turing"))
+        "#
+    .trim();
+    let plan = plan.replace_params_with_values(&param_values).unwrap();
+
+    prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan);
+}
+
+#[test]
+#[should_panic(expected = "Placeholder $4 refers to a non existent column")]
+fn test_prepare_statement_insert_infer_gt() {
+    let sql = "insert into person (id, first_name, last_name) values ($1, $2, $3, $4)";
+
+    let expected_plan = r#""#.trim();
+    let expected_dt = "[Int32]";
+    let _ = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+}
+
+#[test]
+#[should_panic(expected = "value: Plan(\"Column count doesn't match insert query!\")")]
+fn test_prepare_statement_insert_infer_lt() {
+    let sql = "insert into person (id, first_name, last_name) values ($1, $2)";
+
+    let expected_plan = r#""#.trim();
+    let expected_dt = "[Int32]";
+    let _ = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+}
+
+#[test]
+#[should_panic(expected = "value: Plan(\"Placeholder type could not be resolved\")")]
+fn test_prepare_statement_insert_infer_gap() {
+    let sql = "insert into person (id, first_name, last_name) values ($2, $4, $6)";
+
+    let expected_plan = r#""#.trim();
+    let expected_dt = "[Int32]";
+    let _ = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+}
+
 #[test]
 fn test_prepare_statement_to_plan_one_param() {
     let sql = "PREPARE my_plan(INT) AS SELECT id, age  FROM person WHERE age = $1";