You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/12/04 21:41:47 UTC
(arrow-datafusion) branch main updated: Support named query parameters (#8384)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 37bbd66543 Support named query parameters (#8384)
37bbd66543 is described below
commit 37bbd665439f8227971a3657a01205544694bed1
Author: Asura7969 <14...@qq.com>
AuthorDate: Tue Dec 5 05:41:40 2023 +0800
Support named query parameters (#8384)
* Minor: Improve the document format of JoinHashMap
* support named query parameters
* cargo fmt
* add `ParamValues` conversion
* improve doc
---
datafusion/common/src/lib.rs | 2 +
datafusion/common/src/param_value.rs | 149 +++++++++++++++++++++++++++++++
datafusion/core/src/dataframe/mod.rs | 30 ++++++-
datafusion/core/tests/sql/select.rs | 47 ++++++++++
datafusion/expr/src/expr.rs | 2 +-
datafusion/expr/src/logical_plan/plan.rs | 66 +++-----------
datafusion/sql/src/expr/value.rs | 7 +-
datafusion/sql/tests/sql_integration.rs | 27 +++---
8 files changed, 261 insertions(+), 69 deletions(-)
diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs
index 90fb4a8814..6df89624fc 100644
--- a/datafusion/common/src/lib.rs
+++ b/datafusion/common/src/lib.rs
@@ -20,6 +20,7 @@ mod dfschema;
mod error;
mod functional_dependencies;
mod join_type;
+mod param_value;
#[cfg(feature = "pyarrow")]
mod pyarrow;
mod schema_reference;
@@ -59,6 +60,7 @@ pub use functional_dependencies::{
Constraints, Dependency, FunctionalDependence, FunctionalDependencies,
};
pub use join_type::{JoinConstraint, JoinSide, JoinType};
+pub use param_value::ParamValues;
pub use scalar::{ScalarType, ScalarValue};
pub use schema_reference::{OwnedSchemaReference, SchemaReference};
pub use stats::{ColumnStatistics, Statistics};
diff --git a/datafusion/common/src/param_value.rs b/datafusion/common/src/param_value.rs
new file mode 100644
index 0000000000..253c312b66
--- /dev/null
+++ b/datafusion/common/src/param_value.rs
@@ -0,0 +1,149 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::error::{_internal_err, _plan_err};
+use crate::{DataFusionError, Result, ScalarValue};
+use arrow_schema::DataType;
+use std::collections::HashMap;
+
+/// Values to bind to the placeholders of a query, supplied by position or by name
+#[derive(Debug, Clone)]
+pub enum ParamValues {
+ /// Positional query parameters, e.g. `select * from test where a > $1 and b = $2`
+ LIST(Vec<ScalarValue>),
+ /// Named query parameters, e.g. `select * from test where a > $foo and b = $goo`
+ MAP(HashMap<String, ScalarValue>),
+}
+
+impl ParamValues {
+    /// Verify that the supplied values match the expected parameter types.
+    ///
+    /// For [`ParamValues::LIST`] the number of values must equal `expect.len()`
+    /// and each value's data type must equal the type at the same position.
+    /// [`ParamValues::MAP`] is always accepted.
+    ///
+    /// # Errors
+    /// Returns a planning error when the list length or any value type differs
+    /// from what `expect` declares.
+    pub fn verify(&self, expect: &[DataType]) -> Result<()> {
+        match self {
+            ParamValues::LIST(list) => {
+                // Verify if the number of params matches the number of values
+                if expect.len() != list.len() {
+                    return _plan_err!(
+                        "Expected {} parameters, got {}",
+                        expect.len(),
+                        list.len()
+                    );
+                }
+
+                // Verify if the types of the params matches the types of the values
+                for (i, (param_type, value)) in expect.iter().zip(list).enumerate() {
+                    if *param_type != value.data_type() {
+                        return _plan_err!(
+                            "Expected parameter of type {:?}, got {:?} at index {}",
+                            param_type,
+                            value.data_type(),
+                            i
+                        );
+                    }
+                }
+                Ok(())
+            }
+            // If it is a named query, variables can be reused,
+            // but the lengths are not necessarily equal
+            ParamValues::MAP(_) => Ok(()),
+        }
+    }
+
+    /// Look up the value bound to placeholder `id` and check it against `data_type`.
+    ///
+    /// `id` is the placeholder text including its one-byte prefix, e.g. `$1` for
+    /// [`ParamValues::LIST`] (1-based position) or `$name` for [`ParamValues::MAP`].
+    /// Returns a clone of the bound value.
+    ///
+    /// # Errors
+    /// * planning error if `id` is empty or denotes position zero;
+    /// * internal error if a positional id is not a number, if no value is bound
+    ///   to the placeholder, or if the value's type differs from `data_type`.
+    pub fn get_placeholders_with_values(
+        &self,
+        id: &str,
+        data_type: &Option<DataType>,
+    ) -> Result<ScalarValue> {
+        // Drop the one-byte prefix. `str::get` yields `None` for an empty id (or a
+        // multi-byte first character), where slicing with `id[1..]` would panic.
+        let key = match id.get(1..) {
+            Some(key) => key,
+            None => return _plan_err!("Empty placeholder id"),
+        };
+
+        let value = match self {
+            ParamValues::LIST(list) => {
+                // convert id (in format $1, $2, ..) to idx (0, 1, ..)
+                let position = key.parse::<usize>().map_err(|e| {
+                    DataFusionError::Internal(format!(
+                        "Failed to parse placeholder id: {e}"
+                    ))
+                })?;
+                // Position zero (`$0`, `$00`, ..) has no 0-based index; reject it
+                // instead of underflowing on `position - 1`.
+                let idx = match position.checked_sub(1) {
+                    Some(idx) => idx,
+                    None => return _plan_err!("Empty placeholder id"),
+                };
+                // value at the idx-th position should be the value for the placeholder
+                list.get(idx).ok_or_else(|| {
+                    DataFusionError::Internal(format!(
+                        "No value found for placeholder with id {id}"
+                    ))
+                })?
+            }
+            // value stored under the name (a, b, ..) should be the value for $a, $b, ..
+            ParamValues::MAP(map) => map.get(key).ok_or_else(|| {
+                DataFusionError::Internal(format!(
+                    "No value found for placeholder with name {id}"
+                ))
+            })?,
+        };
+
+        // check if the data type of the value matches the data type of the placeholder
+        if Some(value.data_type()) != *data_type {
+            return _internal_err!(
+                "Placeholder value type mismatch: expected {:?}, got {:?}",
+                data_type,
+                value.data_type()
+            );
+        }
+        Ok(value.clone())
+    }
+}
+
+/// Treats a plain vector of values as positional parameters ([`ParamValues::LIST`]).
+impl From<Vec<ScalarValue>> for ParamValues {
+    fn from(positional: Vec<ScalarValue>) -> Self {
+        ParamValues::LIST(positional)
+    }
+}
+
+/// Collects `(name, value)` pairs into named parameters ([`ParamValues::MAP`]).
+impl<K: Into<String>> From<Vec<(K, ScalarValue)>> for ParamValues {
+    fn from(pairs: Vec<(K, ScalarValue)>) -> Self {
+        let mut named: HashMap<String, ScalarValue> = HashMap::new();
+        for (name, scalar) in pairs {
+            named.insert(name.into(), scalar);
+        }
+        ParamValues::MAP(named)
+    }
+}
+
+/// Converts a map keyed by anything string-like into named parameters
+/// ([`ParamValues::MAP`]), turning every key into an owned `String`.
+impl<K: Into<String>> From<HashMap<K, ScalarValue>> for ParamValues {
+    fn from(source: HashMap<K, ScalarValue>) -> Self {
+        ParamValues::MAP(
+            source
+                .into_iter()
+                .map(|(name, scalar)| (name.into(), scalar))
+                .collect::<HashMap<String, ScalarValue>>(),
+        )
+    }
+}
diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs
index 89e82fa952..52b5157b73 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -32,11 +32,12 @@ use datafusion_common::file_options::csv_writer::CsvWriterOptions;
use datafusion_common::file_options::json_writer::JsonWriterOptions;
use datafusion_common::parsers::CompressionTypeVariant;
use datafusion_common::{
- DataFusionError, FileType, FileTypeWriterOptions, SchemaError, UnnestOptions,
+ DataFusionError, FileType, FileTypeWriterOptions, ParamValues, SchemaError,
+ UnnestOptions,
};
use datafusion_expr::dml::CopyOptions;
-use datafusion_common::{Column, DFSchema, ScalarValue};
+use datafusion_common::{Column, DFSchema};
use datafusion_expr::{
avg, count, is_null, max, median, min, stddev, utils::COUNT_STAR_EXPANSION,
TableProviderFilterPushDown, UNNAMED_TABLE,
@@ -1227,11 +1228,32 @@ impl DataFrame {
/// ],
/// &results
/// );
+ /// // Note you can also provide named parameters
+ /// let results = ctx
+ /// .sql("SELECT a FROM example WHERE b = $my_param")
+ /// .await?
+ /// // replace $my_param with value 2
+ /// // Note: a HashMap of names to values can be passed as well
+ /// .with_param_values(vec![
+ /// ("my_param", ScalarValue::from(2i64))
+ /// ])?
+ /// .collect()
+ /// .await?;
+ /// assert_batches_eq!(
+ /// &[
+ /// "+---+",
+ /// "| a |",
+ /// "+---+",
+ /// "| 1 |",
+ /// "+---+",
+ /// ],
+ /// &results
+ /// );
/// # Ok(())
/// # }
/// ```
- pub fn with_param_values(self, param_values: Vec<ScalarValue>) -> Result<Self> {
- let plan = self.plan.with_param_values(param_values)?;
+ pub fn with_param_values(self, query_values: impl Into<ParamValues>) -> Result<Self> {
+ let plan = self.plan.with_param_values(query_values)?;
Ok(Self::new(self.session_state, plan))
}
diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs
index 63f3e97930..cbdea9d729 100644
--- a/datafusion/core/tests/sql/select.rs
+++ b/datafusion/core/tests/sql/select.rs
@@ -525,6 +525,53 @@ async fn test_prepare_statement() -> Result<()> {
Ok(())
}
+/// Binds the named placeholders `$coo` and `$foo` through `with_param_values`
+/// and checks the rows returned by the filtered query.
+#[tokio::test]
+async fn test_named_query_parameters() -> Result<()> {
+    let tmp_dir = TempDir::new()?;
+    let partition_count = 4;
+    let ctx = partitioned_csv::create_ctx(&tmp_dir, partition_count).await?;
+
+    // sql to statement then to logical plan with parameters
+    // c1 defined as UInt32, c2 defined as UInt64
+    let results = ctx
+        .sql("SELECT c1, c2 FROM test WHERE c1 > $coo AND c1 < $foo")
+        .await?
+        // values are given out of order on purpose: they are matched by name
+        .with_param_values(vec![
+            ("foo", ScalarValue::UInt32(Some(3))),
+            ("coo", ScalarValue::UInt32(Some(0))),
+        ])?
+        .collect()
+        .await?;
+    // every cell is padded to the width of its column (2 characters)
+    let expected = vec![
+        "+----+----+",
+        "| c1 | c2 |",
+        "+----+----+",
+        "| 1  | 1  |",
+        "| 1  | 2  |",
+        "| 1  | 3  |",
+        "| 1  | 4  |",
+        "| 1  | 5  |",
+        "| 1  | 6  |",
+        "| 1  | 7  |",
+        "| 1  | 8  |",
+        "| 1  | 9  |",
+        "| 1  | 10 |",
+        "| 2  | 1  |",
+        "| 2  | 2  |",
+        "| 2  | 3  |",
+        "| 2  | 4  |",
+        "| 2  | 5  |",
+        "| 2  | 6  |",
+        "| 2  | 7  |",
+        "| 2  | 8  |",
+        "| 2  | 9  |",
+        "| 2  | 10 |",
+        "+----+----+",
+    ];
+    assert_batches_sorted_eq!(expected, &results);
+    Ok(())
+}
+
#[tokio::test]
async fn parallel_query_with_filter() -> Result<()> {
let tmp_dir = TempDir::new()?;
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index ee9b0ad6f9..6fa400454d 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -671,7 +671,7 @@ impl InSubquery {
}
}
-/// Placeholder, representing bind parameter values such as `$1`.
+/// Placeholder, representing bind parameter values such as `$1` or `$name`.
///
/// The type of these parameters is inferred using [`Expr::infer_placeholder_types`]
/// or can be specified directly using `PREPARE` statements.
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index 9bb47c7da0..fc8590294f 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -48,7 +48,7 @@ use datafusion_common::tree_node::{
use datafusion_common::{
aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints,
DFField, DFSchema, DFSchemaRef, DataFusionError, FunctionalDependencies,
- OwnedTableReference, Result, ScalarValue, UnnestOptions,
+ OwnedTableReference, ParamValues, Result, UnnestOptions,
};
// backwards compatibility
pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan};
@@ -993,32 +993,12 @@ impl LogicalPlan {
/// ```
pub fn with_param_values(
self,
- param_values: Vec<ScalarValue>,
+ param_values: impl Into<ParamValues>,
) -> Result<LogicalPlan> {
+ let param_values = param_values.into();
match self {
LogicalPlan::Prepare(prepare_lp) => {
- // Verify if the number of params matches the number of values
- if prepare_lp.data_types.len() != param_values.len() {
- return plan_err!(
- "Expected {} parameters, got {}",
- prepare_lp.data_types.len(),
- param_values.len()
- );
- }
-
- // Verify if the types of the params matches the types of the values
- let iter = prepare_lp.data_types.iter().zip(param_values.iter());
- for (i, (param_type, value)) in iter.enumerate() {
- if *param_type != value.data_type() {
- return plan_err!(
- "Expected parameter of type {:?}, got {:?} at index {}",
- param_type,
- value.data_type(),
- i
- );
- }
- }
-
+ param_values.verify(&prepare_lp.data_types)?;
let input_plan = prepare_lp.input;
input_plan.replace_params_with_values(&param_values)
}
@@ -1182,7 +1162,7 @@ impl LogicalPlan {
/// See [`Self::with_param_values`] for examples and usage
pub fn replace_params_with_values(
&self,
- param_values: &[ScalarValue],
+ param_values: &ParamValues,
) -> Result<LogicalPlan> {
let new_exprs = self
.expressions()
@@ -1239,36 +1219,15 @@ impl LogicalPlan {
/// corresponding values provided in the params_values
fn replace_placeholders_with_values(
expr: Expr,
- param_values: &[ScalarValue],
+ param_values: &ParamValues,
) -> Result<Expr> {
expr.transform(&|expr| {
match &expr {
Expr::Placeholder(Placeholder { id, data_type }) => {
- if id.is_empty() || id == "$0" {
- return plan_err!("Empty placeholder id");
- }
- // convert id (in format $1, $2, ..) to idx (0, 1, ..)
- let idx = id[1..].parse::<usize>().map_err(|e| {
- DataFusionError::Internal(format!(
- "Failed to parse placeholder id: {e}"
- ))
- })? - 1;
- // value at the idx-th position in param_values should be the value for the placeholder
- let value = param_values.get(idx).ok_or_else(|| {
- DataFusionError::Internal(format!(
- "No value found for placeholder with id {id}"
- ))
- })?;
- // check if the data type of the value matches the data type of the placeholder
- if Some(value.data_type()) != *data_type {
- return internal_err!(
- "Placeholder value type mismatch: expected {:?}, got {:?}",
- data_type,
- value.data_type()
- );
- }
+ let value =
+ param_values.get_placeholders_with_values(id, data_type)?;
// Replace the placeholder with the value
- Ok(Transformed::Yes(Expr::Literal(value.clone())))
+ Ok(Transformed::Yes(Expr::Literal(value)))
}
Expr::ScalarSubquery(qry) => {
let subquery =
@@ -2580,7 +2539,7 @@ mod tests {
use crate::{col, count, exists, in_subquery, lit, placeholder, GroupingSet};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::tree_node::TreeNodeVisitor;
- use datafusion_common::{not_impl_err, DFSchema, TableReference};
+ use datafusion_common::{not_impl_err, DFSchema, ScalarValue, TableReference};
use std::collections::HashMap;
fn employee_schema() -> Schema {
@@ -3028,7 +2987,8 @@ digraph {
.build()
.unwrap();
- plan.replace_params_with_values(&[42i32.into()])
+ let param_values = vec![ScalarValue::Int32(Some(42))];
+ plan.replace_params_with_values(&param_values.clone().into())
.expect_err("unexpectedly succeeded to replace an invalid placeholder");
// test $0 placeholder
@@ -3041,7 +3001,7 @@ digraph {
.build()
.unwrap();
- plan.replace_params_with_values(&[42i32.into()])
+ plan.replace_params_with_values(&param_values.into())
.expect_err("unexpectedly succeeded to replace an invalid placeholder");
}
diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs
index a3f29da488..708f7c6001 100644
--- a/datafusion/sql/src/expr/value.rs
+++ b/datafusion/sql/src/expr/value.rs
@@ -108,7 +108,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
Ok(index) => index - 1,
Err(_) => {
- return plan_err!("Invalid placeholder, not a number: {param}");
+ return if param_data_types.is_empty() {
+ Ok(Expr::Placeholder(Placeholder::new(param, None)))
+ } else {
+ // when PREPARE Statement, param_data_types length is always 0
+ plan_err!("Invalid placeholder, not a number: {param}")
+ };
}
};
// Check if the placeholder is in the parameter list
diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs
index d5b06bcf81..83bdb954b1 100644
--- a/datafusion/sql/tests/sql_integration.rs
+++ b/datafusion/sql/tests/sql_integration.rs
@@ -22,11 +22,11 @@ use std::{sync::Arc, vec};
use arrow_schema::*;
use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect};
-use datafusion_common::plan_err;
use datafusion_common::{
assert_contains, config::ConfigOptions, DataFusionError, Result, ScalarValue,
TableReference,
};
+use datafusion_common::{plan_err, ParamValues};
use datafusion_expr::{
logical_plan::{LogicalPlan, Prepare},
AggregateUDF, ScalarUDF, TableSource, WindowUDF,
@@ -471,6 +471,10 @@ Dml: op=[Insert Into] table=[test_decimal]
"INSERT INTO person (id, first_name, last_name) VALUES ($2, $4, $6)",
"Error during planning: Placeholder type could not be resolved"
)]
+#[case::placeholder_type_unresolved(
+ "INSERT INTO person (id, first_name, last_name) VALUES ($id, $first_name, $last_name)",
+ "Error during planning: Can't parse placeholder: $id"
+)]
#[test]
fn test_insert_schema_errors(#[case] sql: &str, #[case] error: &str) {
let err = logical_plan(sql).unwrap_err();
@@ -2674,7 +2678,7 @@ fn prepare_stmt_quick_test(
fn prepare_stmt_replace_params_quick_test(
plan: LogicalPlan,
- param_values: Vec<ScalarValue>,
+ param_values: impl Into<ParamValues>,
expected_plan: &str,
) -> LogicalPlan {
// replace params
@@ -3726,7 +3730,7 @@ fn test_prepare_statement_to_plan_no_param() {
///////////////////
// replace params with values
- let param_values = vec![];
+ let param_values: Vec<ScalarValue> = vec![];
let expected_plan = "Projection: person.id, person.age\
\n Filter: person.age = Int64(10)\
\n TableScan: person";
@@ -3740,7 +3744,7 @@ fn test_prepare_statement_to_plan_one_param_no_value_panic() {
let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = 10";
let plan = logical_plan(sql).unwrap();
// declare 1 param but provide 0
- let param_values = vec![];
+ let param_values: Vec<ScalarValue> = vec![];
assert_eq!(
plan.with_param_values(param_values)
.unwrap_err()
@@ -3853,7 +3857,7 @@ Projection: person.id, orders.order_id
assert_eq!(actual_types, expected_types);
// replace params with values
- let param_values = vec![ScalarValue::Int32(Some(10))];
+ let param_values = vec![ScalarValue::Int32(Some(10))].into();
let expected_plan = r#"
Projection: person.id, orders.order_id
Inner Join: Filter: person.id = orders.customer_id AND person.age = Int32(10)
@@ -3885,7 +3889,7 @@ Projection: person.id, person.age
assert_eq!(actual_types, expected_types);
// replace params with values
- let param_values = vec![ScalarValue::Int32(Some(10))];
+ let param_values = vec![ScalarValue::Int32(Some(10))].into();
let expected_plan = r#"
Projection: person.id, person.age
Filter: person.age = Int32(10)
@@ -3919,7 +3923,8 @@ Projection: person.id, person.age
assert_eq!(actual_types, expected_types);
// replace params with values
- let param_values = vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(30))];
+ let param_values =
+ vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(30))].into();
let expected_plan = r#"
Projection: person.id, person.age
Filter: person.age BETWEEN Int32(10) AND Int32(30)
@@ -3955,7 +3960,7 @@ Projection: person.id, person.age
assert_eq!(actual_types, expected_types);
// replace params with values
- let param_values = vec![ScalarValue::UInt32(Some(10))];
+ let param_values = vec![ScalarValue::UInt32(Some(10))].into();
let expected_plan = r#"
Projection: person.id, person.age
Filter: person.age = (<subquery>)
@@ -3995,7 +4000,8 @@ Dml: op=[Update] table=[person]
assert_eq!(actual_types, expected_types);
// replace params with values
- let param_values = vec![ScalarValue::Int32(Some(42)), ScalarValue::UInt32(Some(1))];
+ let param_values =
+ vec![ScalarValue::Int32(Some(42)), ScalarValue::UInt32(Some(1))].into();
let expected_plan = r#"
Dml: op=[Update] table=[person]
Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, Int32(42) AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀
@@ -4034,7 +4040,8 @@ fn test_prepare_statement_insert_infer() {
ScalarValue::UInt32(Some(1)),
ScalarValue::Utf8(Some("Alan".to_string())),
ScalarValue::Utf8(Some("Turing".to_string())),
- ];
+ ]
+ .into();
let expected_plan = "Dml: op=[Insert Into] table=[person]\
\n Projection: column1 AS id, column2 AS first_name, column3 AS last_name, \
CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, \