You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/12/04 21:41:47 UTC

(arrow-datafusion) branch main updated: Support named query parameters (#8384)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 37bbd66543 Support named query parameters (#8384)
37bbd66543 is described below

commit 37bbd665439f8227971a3657a01205544694bed1
Author: Asura7969 <14...@qq.com>
AuthorDate: Tue Dec 5 05:41:40 2023 +0800

    Support named query parameters (#8384)
    
    * Minor: Improve the document format of JoinHashMap
    
    * support named query parameters
    
    * cargo fmt
    
    * add `ParamValues` conversion
    
    * improve doc
---
 datafusion/common/src/lib.rs             |   2 +
 datafusion/common/src/param_value.rs     | 149 +++++++++++++++++++++++++++++++
 datafusion/core/src/dataframe/mod.rs     |  30 ++++++-
 datafusion/core/tests/sql/select.rs      |  47 ++++++++++
 datafusion/expr/src/expr.rs              |   2 +-
 datafusion/expr/src/logical_plan/plan.rs |  66 +++-----------
 datafusion/sql/src/expr/value.rs         |   7 +-
 datafusion/sql/tests/sql_integration.rs  |  27 +++---
 8 files changed, 261 insertions(+), 69 deletions(-)

diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs
index 90fb4a8814..6df89624fc 100644
--- a/datafusion/common/src/lib.rs
+++ b/datafusion/common/src/lib.rs
@@ -20,6 +20,7 @@ mod dfschema;
 mod error;
 mod functional_dependencies;
 mod join_type;
+mod param_value;
 #[cfg(feature = "pyarrow")]
 mod pyarrow;
 mod schema_reference;
@@ -59,6 +60,7 @@ pub use functional_dependencies::{
     Constraints, Dependency, FunctionalDependence, FunctionalDependencies,
 };
 pub use join_type::{JoinConstraint, JoinSide, JoinType};
+pub use param_value::ParamValues;
 pub use scalar::{ScalarType, ScalarValue};
 pub use schema_reference::{OwnedSchemaReference, SchemaReference};
 pub use stats::{ColumnStatistics, Statistics};
diff --git a/datafusion/common/src/param_value.rs b/datafusion/common/src/param_value.rs
new file mode 100644
index 0000000000..253c312b66
--- /dev/null
+++ b/datafusion/common/src/param_value.rs
@@ -0,0 +1,149 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::error::{_internal_err, _plan_err};
+use crate::{DataFusionError, Result, ScalarValue};
+use arrow_schema::DataType;
+use std::collections::HashMap;
+
+/// Values bound to query placeholders, either by position or by name.
+#[derive(Debug, Clone, PartialEq)]
+pub enum ParamValues {
+    /// Positional values: element `i` binds `$i+1`, e.g. `WHERE a > $1 AND b = $2`
+    LIST(Vec<ScalarValue>),
+    /// Named values keyed without the `$`, e.g. `WHERE a > $foo AND b = $goo`
+    MAP(HashMap<String, ScalarValue>),
+}
+
+impl ParamValues {
+    /// Checks these values against the expected parameter types `expect`
+    /// (e.g. the types declared by a `PREPARE` statement).
+    ///
+    /// For [`ParamValues::LIST`], the number of values must equal `expect.len()`
+    /// and value `i` must have type `expect[i]`; otherwise a planning error is
+    /// returned. [`ParamValues::MAP`] is not checked: a name may be referenced
+    /// several times, so the number of values need not match.
+    pub fn verify(&self, expect: &[DataType]) -> Result<()> {
+        let list = match self {
+            ParamValues::LIST(list) => list,
+            ParamValues::MAP(_) => return Ok(()),
+        };
+
+        // Verify that the number of values matches the number of params
+        if expect.len() != list.len() {
+            return _plan_err!(
+                "Expected {} parameters, got {}",
+                expect.len(),
+                list.len()
+            );
+        }
+        // Verify that each value has the type of its param
+        for (i, (param_type, value)) in expect.iter().zip(list).enumerate() {
+            if *param_type != value.data_type() {
+                return _plan_err!(
+                    "Expected parameter of type {:?}, got {:?} at index {}",
+                    param_type,
+                    value.data_type(),
+                    i
+                );
+            }
+        }
+        Ok(())
+    }
+
+    /// Returns the value bound to placeholder `id` (e.g. `$1` or `$name`) after
+    /// checking that its type equals the placeholder's `data_type`.
+    ///
+    /// For [`ParamValues::LIST`], `$n` selects the n-th value (1-based); for
+    /// [`ParamValues::MAP`], `id` without its leading `$` is the key. Errors if
+    /// a LIST id is empty, `$0` or not a number, if no value is bound to `id`,
+    /// or if the value's type differs from `data_type`.
+    pub fn get_placeholders_with_values(
+        &self,
+        id: &str,
+        data_type: &Option<DataType>,
+    ) -> Result<ScalarValue> {
+        // Drop the leading `$`. Unlike `&id[1..]`, `get` cannot panic on an
+        // empty id (or a multi-byte first char); it yields "" instead.
+        let key = id.get(1..).unwrap_or_default();
+        let value = match self {
+            ParamValues::LIST(list) => {
+                if id.is_empty() {
+                    return _plan_err!("Empty placeholder id");
+                }
+                // convert id (in format $1, $2, ..) to position (1, 2, ..)
+                let position = key.parse::<usize>().map_err(|e| {
+                    DataFusionError::Internal(format!(
+                        "Failed to parse placeholder id: {e}"
+                    ))
+                })?;
+                // `$0` (and `$00`, ...) has no slot; rejecting 0 here also
+                // keeps `position - 1` from underflowing.
+                if position == 0 {
+                    return _plan_err!("Empty placeholder id");
+                }
+                list.get(position - 1).ok_or_else(|| {
+                    DataFusionError::Internal(format!(
+                        "No value found for placeholder with id {id}"
+                    ))
+                })?
+            }
+            ParamValues::MAP(map) => map.get(key).ok_or_else(|| {
+                DataFusionError::Internal(format!(
+                    "No value found for placeholder with name {id}"
+                ))
+            })?,
+        };
+        // check that the data type of the value matches the placeholder's
+        if Some(value.data_type()) != *data_type {
+            return _internal_err!(
+                "Placeholder value type mismatch: expected {:?}, got {:?}",
+                data_type,
+                value.data_type()
+            );
+        }
+        Ok(value.clone())
+    }
+}
+
+impl From<Vec<ScalarValue>> for ParamValues {
+    fn from(values: Vec<ScalarValue>) -> Self {
+        ParamValues::LIST(values) // element `i` binds placeholder `$i+1`
+    }
+}
+
+impl<K> From<Vec<(K, ScalarValue)>> for ParamValues
+where
+    K: Into<String>,
+{
+    /// Builds [`ParamValues::MAP`] from `(name, value)` pairs; a later pair
+    /// overwrites an earlier one with the same name.
+    fn from(pairs: Vec<(K, ScalarValue)>) -> Self {
+        Self::MAP(pairs.into_iter().map(|(k, v)| (k.into(), v)).collect())
+    }
+}
+
+impl<K> From<HashMap<K, ScalarValue>> for ParamValues
+where
+    K: Into<String>,
+{
+    /// Builds [`ParamValues::MAP`] from a map of placeholder names (without
+    /// the leading `$`) to their values.
+    fn from(map: HashMap<K, ScalarValue>) -> Self {
+        Self::MAP(map.into_iter().map(|(k, v)| (k.into(), v)).collect())
+    }
+}
diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs
index 89e82fa952..52b5157b73 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -32,11 +32,12 @@ use datafusion_common::file_options::csv_writer::CsvWriterOptions;
 use datafusion_common::file_options::json_writer::JsonWriterOptions;
 use datafusion_common::parsers::CompressionTypeVariant;
 use datafusion_common::{
-    DataFusionError, FileType, FileTypeWriterOptions, SchemaError, UnnestOptions,
+    DataFusionError, FileType, FileTypeWriterOptions, ParamValues, SchemaError,
+    UnnestOptions,
 };
 use datafusion_expr::dml::CopyOptions;
 
-use datafusion_common::{Column, DFSchema, ScalarValue};
+use datafusion_common::{Column, DFSchema};
 use datafusion_expr::{
     avg, count, is_null, max, median, min, stddev, utils::COUNT_STAR_EXPANSION,
     TableProviderFilterPushDown, UNNAMED_TABLE,
@@ -1227,11 +1228,32 @@ impl DataFrame {
     ///  ],
     ///  &results
     /// );
+    /// // Note you can also provide named parameters
+    /// let results = ctx
+    ///   .sql("SELECT a FROM example WHERE b = $my_param")
+    ///   .await?
+    ///    // replace $my_param with value 2
+    ///    // Note you can also use a HashMap as well
+    ///   .with_param_values(vec![
+    ///       ("my_param", ScalarValue::from(2i64))
+    ///    ])?
+    ///   .collect()
+    ///   .await?;
+    /// assert_batches_eq!(
+    ///  &[
+    ///    "+---+",
+    ///    "| a |",
+    ///    "+---+",
+    ///    "| 1 |",
+    ///    "+---+",
+    ///  ],
+    ///  &results
+    /// );
     /// # Ok(())
     /// # }
     /// ```
-    pub fn with_param_values(self, param_values: Vec<ScalarValue>) -> Result<Self> {
-        let plan = self.plan.with_param_values(param_values)?;
+    pub fn with_param_values(self, query_values: impl Into<ParamValues>) -> Result<Self> {
+        let plan = self.plan.with_param_values(query_values)?;
         Ok(Self::new(self.session_state, plan))
     }
 
diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs
index 63f3e97930..cbdea9d729 100644
--- a/datafusion/core/tests/sql/select.rs
+++ b/datafusion/core/tests/sql/select.rs
@@ -525,6 +525,53 @@ async fn test_prepare_statement() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn test_named_query_parameters() -> Result<()> {
+    // 4 partitions; c1 is UInt32 and c2 is UInt64
+    let tmp_dir = TempDir::new()?;
+    let ctx = partitioned_csv::create_ctx(&tmp_dir, 4).await?;
+
+    // $coo and $foo are bound by name, so the order of the pairs is irrelevant
+    let df = ctx
+        .sql("SELECT c1, c2 FROM test WHERE c1 > $coo AND c1 < $foo")
+        .await?;
+    let named_params = vec![
+        ("foo", ScalarValue::UInt32(Some(3))),
+        ("coo", ScalarValue::UInt32(Some(0))),
+    ];
+    let batches = df.with_param_values(named_params)?.collect().await?;
+    assert_batches_sorted_eq!(
+        vec![
+            "+----+----+",
+            "| c1 | c2 |",
+            "+----+----+",
+            "| 1  | 1  |",
+            "| 1  | 2  |",
+            "| 1  | 3  |",
+            "| 1  | 4  |",
+            "| 1  | 5  |",
+            "| 1  | 6  |",
+            "| 1  | 7  |",
+            "| 1  | 8  |",
+            "| 1  | 9  |",
+            "| 1  | 10 |",
+            "| 2  | 1  |",
+            "| 2  | 2  |",
+            "| 2  | 3  |",
+            "| 2  | 4  |",
+            "| 2  | 5  |",
+            "| 2  | 6  |",
+            "| 2  | 7  |",
+            "| 2  | 8  |",
+            "| 2  | 9  |",
+            "| 2  | 10 |",
+            "+----+----+",
+        ],
+        &batches
+    );
+    Ok(())
+}
+
 #[tokio::test]
 async fn parallel_query_with_filter() -> Result<()> {
     let tmp_dir = TempDir::new()?;
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index ee9b0ad6f9..6fa400454d 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -671,7 +671,7 @@ impl InSubquery {
     }
 }
 
-/// Placeholder, representing bind parameter values such as `$1`.
+/// Placeholder, representing bind parameter values such as `$1` or `$name`.
 ///
 /// The type of these parameters is inferred using [`Expr::infer_placeholder_types`]
 /// or can be specified directly using `PREPARE` statements.
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index 9bb47c7da0..fc8590294f 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -48,7 +48,7 @@ use datafusion_common::tree_node::{
 use datafusion_common::{
     aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints,
     DFField, DFSchema, DFSchemaRef, DataFusionError, FunctionalDependencies,
-    OwnedTableReference, Result, ScalarValue, UnnestOptions,
+    OwnedTableReference, ParamValues, Result, UnnestOptions,
 };
 // backwards compatibility
 pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan};
@@ -993,32 +993,12 @@ impl LogicalPlan {
     /// ```
     pub fn with_param_values(
         self,
-        param_values: Vec<ScalarValue>,
+        param_values: impl Into<ParamValues>,
     ) -> Result<LogicalPlan> {
+        let param_values = param_values.into();
         match self {
             LogicalPlan::Prepare(prepare_lp) => {
-                // Verify if the number of params matches the number of values
-                if prepare_lp.data_types.len() != param_values.len() {
-                    return plan_err!(
-                        "Expected {} parameters, got {}",
-                        prepare_lp.data_types.len(),
-                        param_values.len()
-                    );
-                }
-
-                // Verify if the types of the params matches the types of the values
-                let iter = prepare_lp.data_types.iter().zip(param_values.iter());
-                for (i, (param_type, value)) in iter.enumerate() {
-                    if *param_type != value.data_type() {
-                        return plan_err!(
-                            "Expected parameter of type {:?}, got {:?} at index {}",
-                            param_type,
-                            value.data_type(),
-                            i
-                        );
-                    }
-                }
-
+                param_values.verify(&prepare_lp.data_types)?;
                 let input_plan = prepare_lp.input;
                 input_plan.replace_params_with_values(&param_values)
             }
@@ -1182,7 +1162,7 @@ impl LogicalPlan {
     /// See [`Self::with_param_values`] for examples and usage
     pub fn replace_params_with_values(
         &self,
-        param_values: &[ScalarValue],
+        param_values: &ParamValues,
     ) -> Result<LogicalPlan> {
         let new_exprs = self
             .expressions()
@@ -1239,36 +1219,15 @@ impl LogicalPlan {
     /// corresponding values provided in the params_values
     fn replace_placeholders_with_values(
         expr: Expr,
-        param_values: &[ScalarValue],
+        param_values: &ParamValues,
     ) -> Result<Expr> {
         expr.transform(&|expr| {
             match &expr {
                 Expr::Placeholder(Placeholder { id, data_type }) => {
-                    if id.is_empty() || id == "$0" {
-                        return plan_err!("Empty placeholder id");
-                    }
-                    // convert id (in format $1, $2, ..) to idx (0, 1, ..)
-                    let idx = id[1..].parse::<usize>().map_err(|e| {
-                        DataFusionError::Internal(format!(
-                            "Failed to parse placeholder id: {e}"
-                        ))
-                    })? - 1;
-                    // value at the idx-th position in param_values should be the value for the placeholder
-                    let value = param_values.get(idx).ok_or_else(|| {
-                        DataFusionError::Internal(format!(
-                            "No value found for placeholder with id {id}"
-                        ))
-                    })?;
-                    // check if the data type of the value matches the data type of the placeholder
-                    if Some(value.data_type()) != *data_type {
-                        return internal_err!(
-                            "Placeholder value type mismatch: expected {:?}, got {:?}",
-                            data_type,
-                            value.data_type()
-                        );
-                    }
+                    let value =
+                        param_values.get_placeholders_with_values(id, data_type)?;
                     // Replace the placeholder with the value
-                    Ok(Transformed::Yes(Expr::Literal(value.clone())))
+                    Ok(Transformed::Yes(Expr::Literal(value)))
                 }
                 Expr::ScalarSubquery(qry) => {
                     let subquery =
@@ -2580,7 +2539,7 @@ mod tests {
     use crate::{col, count, exists, in_subquery, lit, placeholder, GroupingSet};
     use arrow::datatypes::{DataType, Field, Schema};
     use datafusion_common::tree_node::TreeNodeVisitor;
-    use datafusion_common::{not_impl_err, DFSchema, TableReference};
+    use datafusion_common::{not_impl_err, DFSchema, ScalarValue, TableReference};
     use std::collections::HashMap;
 
     fn employee_schema() -> Schema {
@@ -3028,7 +2987,8 @@ digraph {
             .build()
             .unwrap();
 
-        plan.replace_params_with_values(&[42i32.into()])
+        let param_values = vec![ScalarValue::Int32(Some(42))];
+        plan.replace_params_with_values(&param_values.clone().into())
             .expect_err("unexpectedly succeeded to replace an invalid placeholder");
 
         // test $0 placeholder
@@ -3041,7 +3001,7 @@ digraph {
             .build()
             .unwrap();
 
-        plan.replace_params_with_values(&[42i32.into()])
+        plan.replace_params_with_values(&param_values.into())
             .expect_err("unexpectedly succeeded to replace an invalid placeholder");
     }
 
diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs
index a3f29da488..708f7c6001 100644
--- a/datafusion/sql/src/expr/value.rs
+++ b/datafusion/sql/src/expr/value.rs
@@ -108,7 +108,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             }
             Ok(index) => index - 1,
             Err(_) => {
-                return plan_err!("Invalid placeholder, not a number: {param}");
+                return if param_data_types.is_empty() {
+                    Ok(Expr::Placeholder(Placeholder::new(param, None)))
+                } else {
+                    // param_data_types is non-empty: only positional `$N` allowed
+                    plan_err!("Invalid placeholder, not a number: {param}")
+                };
             }
         };
         // Check if the placeholder is in the parameter list
diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs
index d5b06bcf81..83bdb954b1 100644
--- a/datafusion/sql/tests/sql_integration.rs
+++ b/datafusion/sql/tests/sql_integration.rs
@@ -22,11 +22,11 @@ use std::{sync::Arc, vec};
 use arrow_schema::*;
 use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect};
 
-use datafusion_common::plan_err;
 use datafusion_common::{
     assert_contains, config::ConfigOptions, DataFusionError, Result, ScalarValue,
     TableReference,
 };
+use datafusion_common::{plan_err, ParamValues};
 use datafusion_expr::{
     logical_plan::{LogicalPlan, Prepare},
     AggregateUDF, ScalarUDF, TableSource, WindowUDF,
@@ -471,6 +471,10 @@ Dml: op=[Insert Into] table=[test_decimal]
     "INSERT INTO person (id, first_name, last_name) VALUES ($2, $4, $6)",
     "Error during planning: Placeholder type could not be resolved"
 )]
+#[case::placeholder_type_unresolved(
+    "INSERT INTO person (id, first_name, last_name) VALUES ($id, $first_name, $last_name)",
+    "Error during planning: Can't parse placeholder: $id"
+)]
 #[test]
 fn test_insert_schema_errors(#[case] sql: &str, #[case] error: &str) {
     let err = logical_plan(sql).unwrap_err();
@@ -2674,7 +2678,7 @@ fn prepare_stmt_quick_test(
 
 fn prepare_stmt_replace_params_quick_test(
     plan: LogicalPlan,
-    param_values: Vec<ScalarValue>,
+    param_values: impl Into<ParamValues>,
     expected_plan: &str,
 ) -> LogicalPlan {
     // replace params
@@ -3726,7 +3730,7 @@ fn test_prepare_statement_to_plan_no_param() {
 
     ///////////////////
     // replace params with values
-    let param_values = vec![];
+    let param_values: Vec<ScalarValue> = vec![];
     let expected_plan = "Projection: person.id, person.age\
         \n  Filter: person.age = Int64(10)\
         \n    TableScan: person";
@@ -3740,7 +3744,7 @@ fn test_prepare_statement_to_plan_one_param_no_value_panic() {
     let sql = "PREPARE my_plan(INT) AS SELECT id, age  FROM person WHERE age = 10";
     let plan = logical_plan(sql).unwrap();
     // declare 1 param but provide 0
-    let param_values = vec![];
+    let param_values: Vec<ScalarValue> = vec![];
     assert_eq!(
         plan.with_param_values(param_values)
             .unwrap_err()
@@ -3853,7 +3857,7 @@ Projection: person.id, orders.order_id
     assert_eq!(actual_types, expected_types);
 
     // replace params with values
-    let param_values = vec![ScalarValue::Int32(Some(10))];
+    let param_values = vec![ScalarValue::Int32(Some(10))].into();
     let expected_plan = r#"
 Projection: person.id, orders.order_id
   Inner Join:  Filter: person.id = orders.customer_id AND person.age = Int32(10)
@@ -3885,7 +3889,7 @@ Projection: person.id, person.age
     assert_eq!(actual_types, expected_types);
 
     // replace params with values
-    let param_values = vec![ScalarValue::Int32(Some(10))];
+    let param_values = vec![ScalarValue::Int32(Some(10))].into();
     let expected_plan = r#"
 Projection: person.id, person.age
   Filter: person.age = Int32(10)
@@ -3919,7 +3923,8 @@ Projection: person.id, person.age
     assert_eq!(actual_types, expected_types);
 
     // replace params with values
-    let param_values = vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(30))];
+    let param_values =
+        vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(30))].into();
     let expected_plan = r#"
 Projection: person.id, person.age
   Filter: person.age BETWEEN Int32(10) AND Int32(30)
@@ -3955,7 +3960,7 @@ Projection: person.id, person.age
     assert_eq!(actual_types, expected_types);
 
     // replace params with values
-    let param_values = vec![ScalarValue::UInt32(Some(10))];
+    let param_values = vec![ScalarValue::UInt32(Some(10))].into();
     let expected_plan = r#"
 Projection: person.id, person.age
   Filter: person.age = (<subquery>)
@@ -3995,7 +4000,8 @@ Dml: op=[Update] table=[person]
     assert_eq!(actual_types, expected_types);
 
     // replace params with values
-    let param_values = vec![ScalarValue::Int32(Some(42)), ScalarValue::UInt32(Some(1))];
+    let param_values =
+        vec![ScalarValue::Int32(Some(42)), ScalarValue::UInt32(Some(1))].into();
     let expected_plan = r#"
 Dml: op=[Update] table=[person]
   Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, Int32(42) AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀
@@ -4034,7 +4040,8 @@ fn test_prepare_statement_insert_infer() {
         ScalarValue::UInt32(Some(1)),
         ScalarValue::Utf8(Some("Alan".to_string())),
         ScalarValue::Utf8(Some("Turing".to_string())),
-    ];
+    ]
+    .into();
     let expected_plan = "Dml: op=[Insert Into] table=[person]\
                         \n  Projection: column1 AS id, column2 AS first_name, column3 AS last_name, \
                                     CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, \