You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2023/03/10 01:32:01 UTC
[spark] branch branch-3.4 updated: [SPARK-42702][SPARK-42623][SQL] Support parameterized query in subquery and CTE
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new 80127286c5f [SPARK-42702][SPARK-42623][SQL] Support parameterized query in subquery and CTE
80127286c5f is described below
commit 80127286c5fd9cd472c868e0bf8ebcec4cf399dc
Author: Wenchen Fan <we...@databricks.com>
AuthorDate: Fri Mar 10 09:31:20 2023 +0800
[SPARK-42702][SPARK-42623][SQL] Support parameterized query in subquery and CTE
### What changes were proposed in this pull request?
This PR fixes a few issues of parameterized query:
1. replace placeholders in CTE/subqueries
2. don't replace placeholders in non-DML commands as it may store the original SQL text with placeholders and we can't resolve it later (e.g. CREATE VIEW).
### Why are the changes needed?
make the parameterized query feature complete
### Does this PR introduce _any_ user-facing change?
yes, bug fix
### How was this patch tested?
new tests
Closes #40333 from cloud-fan/parameter.
Authored-by: Wenchen Fan <we...@databricks.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
(cherry picked from commit a7807038d5be5e46634d5bf807dd12fa63546b33)
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../sql/connect/planner/SparkConnectPlanner.scala | 10 +-
core/src/main/resources/error/error-classes.json | 5 +
.../spark/sql/catalyst/analysis/Analyzer.scala | 1 +
.../sql/catalyst/analysis/CheckAnalysis.scala | 2 +-
.../spark/sql/catalyst/analysis/parameters.scala | 112 +++++++++++++++++++++
.../sql/catalyst/expressions/parameters.scala | 64 ------------
.../spark/sql/catalyst/trees/TreePatterns.scala | 1 +
.../sql/catalyst/analysis/AnalysisSuite.scala | 24 +++--
.../sql/catalyst/parser/PlanParserSuite.scala | 2 +-
.../scala/org/apache/spark/sql/SparkSession.scala | 12 ++-
.../org/apache/spark/sql/ParametersSuite.scala | 100 +++++++++++++++++-
11 files changed, 247 insertions(+), 86 deletions(-)
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 5dd0a7ea309..24717e07b00 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -32,7 +32,7 @@ import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult
import org.apache.spark.connect.proto.Parse.ParseFormat
import org.apache.spark.sql.{Column, Dataset, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier}
-import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, UnresolvedAlias, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, ParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
@@ -209,8 +209,12 @@ class SparkConnectPlanner(val session: SparkSession) {
private def transformSql(sql: proto.SQL): LogicalPlan = {
val args = sql.getArgsMap.asScala.toMap
val parser = session.sessionState.sqlParser
- val parsedArgs = args.mapValues(parser.parseExpression).toMap
- Parameter.bind(parser.parsePlan(sql.getQuery), parsedArgs)
+ val parsedPlan = parser.parsePlan(sql.getQuery)
+ if (args.nonEmpty) {
+ ParameterizedQuery(parsedPlan, args.mapValues(parser.parseExpression).toMap)
+ } else {
+ parsedPlan
+ }
}
private def transformSubqueryAlias(alias: proto.SubqueryAlias): LogicalPlan = {
diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json
index 2780c98bfc6..7d16365c677 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -1714,6 +1714,11 @@
"Pandas user defined aggregate function in the PIVOT clause."
]
},
+ "PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT" : {
+ "message" : [
+ "Parameter markers in unexpected statement: <statement>. Parameter markers must only be used in a query, or DML statement."
+ ]
+ },
"PIVOT_AFTER_GROUP_BY" : {
"message" : [
"PIVOT clause following a GROUP BY clause. Consider pushing the GROUP BY into a subquery."
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index d7cc34d6f15..e5d78b21f19 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -265,6 +265,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
// at the beginning of analysis.
OptimizeUpdateFields,
CTESubstitution,
+ BindParameters,
WindowsSubstitution,
EliminateUnions,
SubstituteUnresolvedOrdinals),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 77948735dbe..3af15d2465a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -336,7 +336,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
case p: Parameter =>
p.failAnalysis(
errorClass = "UNBOUND_SQL_PARAMETER",
- messageParameters = Map("name" -> toSQLId(p.name)))
+ messageParameters = Map("name" -> p.name))
case _ =>
})
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala
new file mode 100644
index 00000000000..29c36300673
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, LeafExpression, Literal, SubqueryExpression, Unevaluable}
+import org.apache.spark.sql.catalyst.plans.logical.{Command, DeleteFromTable, InsertIntoStatement, LogicalPlan, MergeIntoTable, UnaryNode, UpdateTable}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.{PARAMETER, PARAMETERIZED_QUERY, TreePattern, UNRESOLVED_WITH}
+import org.apache.spark.sql.errors.QueryErrorsBase
+import org.apache.spark.sql.types.DataType
+
+/**
+ * The expression represents a named parameter that should be replaced by a literal.
+ *
+ * @param name The identifier of the parameter without the marker.
+ */
+case class Parameter(name: String) extends LeafExpression with Unevaluable {
+ override lazy val resolved: Boolean = false
+
+ private def unboundError(methodName: String): Nothing = {
+ throw SparkException.internalError(
+ s"Cannot call `$methodName()` of the unbound parameter `$name`.")
+ }
+ override def dataType: DataType = unboundError("dataType")
+ override def nullable: Boolean = unboundError("nullable")
+
+ final override val nodePatterns: Seq[TreePattern] = Seq(PARAMETER)
+}
+
+/**
+ * The logical plan representing a parameterized query. It will be removed during analysis after
+ * the parameters are bound.
+ */
+case class ParameterizedQuery(child: LogicalPlan, args: Map[String, Expression]) extends UnaryNode {
+ assert(args.nonEmpty)
+ override def output: Seq[Attribute] = Nil
+ override lazy val resolved = false
+ final override val nodePatterns: Seq[TreePattern] = Seq(PARAMETERIZED_QUERY)
+ override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
+ copy(child = newChild)
+}
+
+/**
+ * Finds all named parameters in `ParameterizedQuery` and substitutes them by literals from the
+ * user-specified arguments.
+ */
+object BindParameters extends Rule[LogicalPlan] with QueryErrorsBase {
+ override def apply(plan: LogicalPlan): LogicalPlan = {
+ if (plan.containsPattern(PARAMETERIZED_QUERY)) {
+ // One unresolved plan can have at most one ParameterizedQuery.
+ val parameterizedQueries = plan.collect { case p: ParameterizedQuery => p }
+ assert(parameterizedQueries.length == 1)
+ }
+
+ plan.resolveOperatorsWithPruning(_.containsPattern(PARAMETERIZED_QUERY)) {
+ // We should wait for `CTESubstitution` to resolve CTE before binding parameters, as CTE
+ // relations are not children of `UnresolvedWith`.
+ case p @ ParameterizedQuery(child, args) if !child.containsPattern(UNRESOLVED_WITH) =>
+ // Some commands may store the original SQL text, like CREATE VIEW, GENERATED COLUMN, etc.
+ // We can't store the original SQL text with parameters, as we don't store the arguments and
+ // are not able to resolve it after parsing it back. Since parameterized query is mostly
+ // used to avoid SQL injection for SELECT queries, we simply forbid non-DML commands here.
+ child match {
+ case _: InsertIntoStatement => // OK
+ case _: UpdateTable => // OK
+ case _: DeleteFromTable => // OK
+ case _: MergeIntoTable => // OK
+ case cmd: Command =>
+ child.failAnalysis(
+ errorClass = "UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT",
+ messageParameters = Map("statement" -> cmd.nodeName)
+ )
+ case _ => // OK
+ }
+
+ args.find(!_._2.isInstanceOf[Literal]).foreach { case (name, expr) =>
+ expr.failAnalysis(
+ errorClass = "INVALID_SQL_ARG",
+ messageParameters = Map("name" -> name))
+ }
+
+ def bind(p: LogicalPlan): LogicalPlan = {
+ p.resolveExpressionsWithPruning(_.containsPattern(PARAMETER)) {
+ case Parameter(name) if args.contains(name) =>
+ args(name)
+ case sub: SubqueryExpression => sub.withNewPlan(bind(sub.plan))
+ }
+ }
+ val res = bind(child)
+ res.copyTagsFrom(p)
+ res
+
+ case _ => plan
+ }
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/parameters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/parameters.scala
deleted file mode 100644
index fae2b9a1a9f..00000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/parameters.scala
+++ /dev/null
@@ -1,64 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import org.apache.spark.SparkException
-import org.apache.spark.sql.catalyst.analysis.AnalysisErrorAt
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.trees.TreePattern.{PARAMETER, TreePattern}
-import org.apache.spark.sql.errors.QueryErrorsBase
-import org.apache.spark.sql.types.DataType
-
-/**
- * The expression represents a named parameter that should be replaced by a literal.
- *
- * @param name The identifier of the parameter without the marker.
- */
-case class Parameter(name: String) extends LeafExpression with Unevaluable {
- override lazy val resolved: Boolean = false
-
- private def unboundError(methodName: String): Nothing = {
- throw SparkException.internalError(
- s"Cannot call `$methodName()` of the unbound parameter `$name`.")
- }
- override def dataType: DataType = unboundError("dataType")
- override def nullable: Boolean = unboundError("nullable")
-
- final override val nodePatterns: Seq[TreePattern] = Seq(PARAMETER)
-}
-
-
-/**
- * Finds all named parameters in the given plan and substitutes them by literals of `args` values.
- */
-object Parameter extends QueryErrorsBase {
- def bind(plan: LogicalPlan, args: Map[String, Expression]): LogicalPlan = {
- if (!args.isEmpty) {
- args.filter(!_._2.isInstanceOf[Literal]).headOption.foreach { case (name, expr) =>
- expr.failAnalysis(
- errorClass = "INVALID_SQL_ARG",
- messageParameters = Map("name" -> toSQLId(name)))
- }
- plan.transformAllExpressionsWithPruning(_.containsPattern(PARAMETER)) {
- case Parameter(name) if args.contains(name) => args(name)
- }
- } else {
- plan
- }
- }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index 48db1a4408d..ce853b5773c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -73,6 +73,7 @@ object TreePattern extends Enumeration {
val OR: Value = Value
val OUTER_REFERENCE: Value = Value
val PARAMETER: Value = Value
+ val PARAMETERIZED_QUERY: Value = Value
val PIVOT: Value = Value
val PLAN_EXPRESSION: Value = Value
val PYTHON_UDF: Value = Value
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 0f26d3a2dc9..54ea4086c9b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -1346,17 +1346,21 @@ class AnalysisSuite extends AnalysisTest with Matchers {
}
test("SPARK-41271: bind named parameters to literals") {
- comparePlans(
- Parameter.bind(
- plan = parsePlan("SELECT * FROM a LIMIT :limitA"),
- args = Map("limitA" -> Literal(10))),
- parsePlan("SELECT * FROM a LIMIT 10"))
+ CTERelationDef.curId.set(0)
+ val actual1 = ParameterizedQuery(
+ child = parsePlan("WITH a AS (SELECT 1 c) SELECT * FROM a LIMIT :limitA"),
+ args = Map("limitA" -> Literal(10))).analyze
+ CTERelationDef.curId.set(0)
+ val expected1 = parsePlan("WITH a AS (SELECT 1 c) SELECT * FROM a LIMIT 10").analyze
+ comparePlans(actual1, expected1)
// Ignore unused arguments
- comparePlans(
- Parameter.bind(
- plan = parsePlan("SELECT c FROM a WHERE c < :param2"),
- args = Map("param1" -> Literal(10), "param2" -> Literal(20))),
- parsePlan("SELECT c FROM a WHERE c < 20"))
+ CTERelationDef.curId.set(0)
+ val actual2 = ParameterizedQuery(
+ child = parsePlan("WITH a AS (SELECT 1 c) SELECT c FROM a WHERE c < :param2"),
+ args = Map("param1" -> Literal(10), "param2" -> Literal(20))).analyze
+ CTERelationDef.curId.set(0)
+ val expected2 = parsePlan("WITH a AS (SELECT 1 c) SELECT c FROM a WHERE c < 20").analyze
+ comparePlans(actual2, expected2)
}
test("SPARK-41489: type of filter expression should be a bool") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 6fc83d8c782..3b5a2401335 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.parser
import org.apache.spark.SparkThrowable
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
-import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, RelationTimeTravel, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedStar, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction, UnresolvedTVFAliases}
+import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, Parameter, RelationTimeTravel, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedStar, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction, UnresolvedTVFAliases}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{PercentileCont, PercentileDisc}
import org.apache.spark.sql.catalyst.plans._
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index adbe593ac56..066e609a6d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -35,9 +35,9 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
import org.apache.spark.sql.catalog.Catalog
import org.apache.spark.sql.catalyst._
-import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.analysis.{ParameterizedQuery, UnresolvedRelation}
import org.apache.spark.sql.catalyst.encoders._
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Parameter}
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range}
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.connector.ExternalCommandRunner
@@ -623,8 +623,12 @@ class SparkSession private(
val tracker = new QueryPlanningTracker
val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) {
val parser = sessionState.sqlParser
- val parsedArgs = args.mapValues(parser.parseExpression).toMap
- Parameter.bind(parser.parsePlan(sqlText), parsedArgs)
+ val parsedPlan = parser.parsePlan(sqlText)
+ if (args.nonEmpty) {
+ ParameterizedQuery(parsedPlan, args.mapValues(parser.parseExpression).toMap)
+ } else {
+ parsedPlan
+ }
}
Dataset.ofRows(self, plan, tracker)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
index 668a1e4ad7d..e6e5eb9fac4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
@@ -38,13 +38,107 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
Row(true))
}
+ test("parameter binding is case sensitive") {
+ checkAnswer(
+ spark.sql("SELECT :p, :P", Map("p" -> "1", "P" -> "2")),
+ Row(1, 2)
+ )
+
+ checkError(
+ exception = intercept[AnalysisException] {
+ spark.sql("select :P", Map("p" -> "1"))
+ },
+ errorClass = "UNBOUND_SQL_PARAMETER",
+ parameters = Map("name" -> "P"),
+ context = ExpectedContext(
+ fragment = ":P",
+ start = 7,
+ stop = 8))
+ }
+
+ test("parameters in CTE") {
+ val sqlText =
+ """
+ |WITH w1 AS (SELECT :p1 AS p)
+ |SELECT p + :p2 FROM w1
+ |""".stripMargin
+ val args = Map("p1" -> "1", "p2" -> "2")
+ checkAnswer(
+ spark.sql(sqlText, args),
+ Row(3))
+ }
+
+ test("parameters in nested CTE") {
+ val sqlText =
+ """
+ |WITH w1 AS
+ | (WITH w2 AS (SELECT :p1 AS p) SELECT p + :p2 AS p2 FROM w2)
+ |SELECT p2 + :p3 FROM w1
+ |""".stripMargin
+ val args = Map("p1" -> "1", "p2" -> "2", "p3" -> "3")
+ checkAnswer(
+ spark.sql(sqlText, args),
+ Row(6))
+ }
+
+ test("parameters in subquery expression") {
+ val sqlText = "SELECT (SELECT max(id) + :p1 FROM range(10)) + :p2"
+ val args = Map("p1" -> "1", "p2" -> "2")
+ checkAnswer(
+ spark.sql(sqlText, args),
+ Row(12))
+ }
+
+ test("parameters in nested subquery expression") {
+ val sqlText = "SELECT (SELECT (SELECT max(id) + :p1 FROM range(10)) + :p2) + :p3"
+ val args = Map("p1" -> "1", "p2" -> "2", "p3" -> "3")
+ checkAnswer(
+ spark.sql(sqlText, args),
+ Row(15))
+ }
+
+ test("parameters in subquery expression inside CTE") {
+ val sqlText =
+ """
+ |WITH w1 AS (SELECT (SELECT max(id) + :p1 FROM range(10)) + :p2 AS p)
+ |SELECT p + :p3 FROM w1
+ |""".stripMargin
+ val args = Map("p1" -> "1", "p2" -> "2", "p3" -> "3")
+ checkAnswer(
+ spark.sql(sqlText, args),
+ Row(15))
+ }
+
+ test("parameters in INSERT") {
+ withTable("t") {
+ sql("CREATE TABLE t (col INT) USING json")
+ spark.sql("INSERT INTO t SELECT :p", Map("p" -> "1"))
+ checkAnswer(spark.table("t"), Row(1))
+ }
+ }
+
+ test("parameters not allowed in DDL commands") {
+ val sqlText = "CREATE VIEW v AS SELECT :p AS p"
+ val args = Map("p" -> "1")
+ checkError(
+ exception = intercept[AnalysisException] {
+ spark.sql(sqlText, args)
+ },
+ errorClass = "UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT",
+ parameters = Map("statement" -> "CreateView"),
+ context = ExpectedContext(
+ fragment = "CREATE VIEW v AS SELECT :p AS p",
+ start = 0,
+ stop = sqlText.length - 1))
+ }
+
test("non-substituted parameters") {
checkError(
exception = intercept[AnalysisException] {
spark.sql("select :abc, :def", Map("abc" -> "1"))
},
errorClass = "UNBOUND_SQL_PARAMETER",
- parameters = Map("name" -> "`def`"),
+ parameters = Map("name" -> "def"),
context = ExpectedContext(
fragment = ":def",
start = 13,
@@ -54,7 +148,7 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
sql("select :abc").collect()
},
errorClass = "UNBOUND_SQL_PARAMETER",
- parameters = Map("name" -> "`abc`"),
+ parameters = Map("name" -> "abc"),
context = ExpectedContext(
fragment = ":abc",
start = 7,
@@ -68,7 +162,7 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
spark.sql("SELECT :param1 FROM VALUES (1) AS t(col1)", Map("param1" -> arg))
},
errorClass = "INVALID_SQL_ARG",
- parameters = Map("name" -> "`param1`"),
+ parameters = Map("name" -> "param1"),
context = ExpectedContext(
fragment = arg,
start = 0,
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org