Posted to commits@spark.apache.org by we...@apache.org on 2023/03/10 01:32:01 UTC

[spark] branch branch-3.4 updated: [SPARK-42702][SPARK-42623][SQL] Support parameterized query in subquery and CTE

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 80127286c5f [SPARK-42702][SPARK-42623][SQL] Support parameterized query in subquery and CTE
80127286c5f is described below

commit 80127286c5fd9cd472c868e0bf8ebcec4cf399dc
Author: Wenchen Fan <we...@databricks.com>
AuthorDate: Fri Mar 10 09:31:20 2023 +0800

    [SPARK-42702][SPARK-42623][SQL] Support parameterized query in subquery and CTE
    
    ### What changes were proposed in this pull request?
    
    This PR fixes a few issues with parameterized queries:
    1. Replace placeholders in CTEs and subqueries.
    2. Don't replace placeholders in non-DML commands: such commands (e.g. CREATE VIEW) may store the original SQL text, placeholders included, and we can't resolve that text later.
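
    For illustration, a minimal sketch of the now-supported usage, modeled on the new tests in ParametersSuite (the CTE and parameter names here are illustrative):

        // Named parameters now bind inside CTEs and subquery expressions.
        val df = spark.sql(
          """
            |WITH w1 AS (SELECT :p1 AS p)
            |SELECT p + (SELECT max(id) + :p2 FROM range(10)) FROM w1
            |""".stripMargin,
          Map("p1" -> "1", "p2" -> "2"))
        // => Row(12): p1 = 1 binds inside the CTE, p2 = 2 inside the subquery.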
    
    ### Why are the changes needed?
    
    This makes the parameterized query feature complete.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, this is a bug fix: parameter markers now bind inside subqueries and CTEs, and non-DML commands containing parameter markers now fail with an explicit error.
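
    For example, a parameter marker in a DDL command now fails analysis up front (a sketch mirroring the new ParametersSuite test):

        // Non-DML commands are rejected: commands like CREATE VIEW persist the
        // original SQL text, and the arguments would be lost when parsing it back.
        spark.sql("CREATE VIEW v AS SELECT :p AS p", Map("p" -> "1"))
        // => AnalysisException: [UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT]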
    
    ### How was this patch tested?
    
    New tests in `AnalysisSuite` and `ParametersSuite`.
    
    Closes #40333 from cloud-fan/parameter.
    
    Authored-by: Wenchen Fan <we...@databricks.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit a7807038d5be5e46634d5bf807dd12fa63546b33)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../sql/connect/planner/SparkConnectPlanner.scala  |  10 +-
 core/src/main/resources/error/error-classes.json   |   5 +
 .../spark/sql/catalyst/analysis/Analyzer.scala     |   1 +
 .../sql/catalyst/analysis/CheckAnalysis.scala      |   2 +-
 .../spark/sql/catalyst/analysis/parameters.scala   | 112 +++++++++++++++++++++
 .../sql/catalyst/expressions/parameters.scala      |  64 ------------
 .../spark/sql/catalyst/trees/TreePatterns.scala    |   1 +
 .../sql/catalyst/analysis/AnalysisSuite.scala      |  24 +++--
 .../sql/catalyst/parser/PlanParserSuite.scala      |   2 +-
 .../scala/org/apache/spark/sql/SparkSession.scala  |  12 ++-
 .../org/apache/spark/sql/ParametersSuite.scala     | 100 +++++++++++++++++-
 11 files changed, 247 insertions(+), 86 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 5dd0a7ea309..24717e07b00 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -32,7 +32,7 @@ import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult
 import org.apache.spark.connect.proto.Parse.ParseFormat
 import org.apache.spark.sql.{Column, Dataset, Encoders, SparkSession}
 import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier}
-import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, UnresolvedAlias, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, ParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
@@ -209,8 +209,12 @@ class SparkConnectPlanner(val session: SparkSession) {
   private def transformSql(sql: proto.SQL): LogicalPlan = {
     val args = sql.getArgsMap.asScala.toMap
     val parser = session.sessionState.sqlParser
-    val parsedArgs = args.mapValues(parser.parseExpression).toMap
-    Parameter.bind(parser.parsePlan(sql.getQuery), parsedArgs)
+    val parsedPlan = parser.parsePlan(sql.getQuery)
+    if (args.nonEmpty) {
+      ParameterizedQuery(parsedPlan, args.mapValues(parser.parseExpression).toMap)
+    } else {
+      parsedPlan
+    }
   }
 
   private def transformSubqueryAlias(alias: proto.SubqueryAlias): LogicalPlan = {
diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json
index 2780c98bfc6..7d16365c677 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -1714,6 +1714,11 @@
           "Pandas user defined aggregate function in the PIVOT clause."
         ]
       },
+      "PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT" : {
+        "message" : [
+          "Parameter markers in unexpected statement: <statement>. Parameter markers must only be used in a query, or DML statement."
+        ]
+      },
       "PIVOT_AFTER_GROUP_BY" : {
         "message" : [
           "PIVOT clause following a GROUP BY clause. Consider pushing the GROUP BY into a subquery."
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index d7cc34d6f15..e5d78b21f19 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -265,6 +265,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       // at the beginning of analysis.
       OptimizeUpdateFields,
       CTESubstitution,
+      BindParameters,
       WindowsSubstitution,
       EliminateUnions,
       SubstituteUnresolvedOrdinals),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 77948735dbe..3af15d2465a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -336,7 +336,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
           case p: Parameter =>
             p.failAnalysis(
               errorClass = "UNBOUND_SQL_PARAMETER",
-              messageParameters = Map("name" -> toSQLId(p.name)))
+              messageParameters = Map("name" -> p.name))
 
           case _ =>
         })
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala
new file mode 100644
index 00000000000..29c36300673
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, LeafExpression, Literal, SubqueryExpression, Unevaluable}
+import org.apache.spark.sql.catalyst.plans.logical.{Command, DeleteFromTable, InsertIntoStatement, LogicalPlan, MergeIntoTable, UnaryNode, UpdateTable}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.{PARAMETER, PARAMETERIZED_QUERY, TreePattern, UNRESOLVED_WITH}
+import org.apache.spark.sql.errors.QueryErrorsBase
+import org.apache.spark.sql.types.DataType
+
+/**
+ * The expression represents a named parameter that should be replaced by a literal.
+ *
+ * @param name The identifier of the parameter without the marker.
+ */
+case class Parameter(name: String) extends LeafExpression with Unevaluable {
+  override lazy val resolved: Boolean = false
+
+  private def unboundError(methodName: String): Nothing = {
+    throw SparkException.internalError(
+      s"Cannot call `$methodName()` of the unbound parameter `$name`.")
+  }
+  override def dataType: DataType = unboundError("dataType")
+  override def nullable: Boolean = unboundError("nullable")
+
+  final override val nodePatterns: Seq[TreePattern] = Seq(PARAMETER)
+}
+
+/**
+ * The logical plan representing a parameterized query. It will be removed during analysis after
+ * the parameters are bound.
+ */
+case class ParameterizedQuery(child: LogicalPlan, args: Map[String, Expression]) extends UnaryNode {
+  assert(args.nonEmpty)
+  override def output: Seq[Attribute] = Nil
+  override lazy val resolved = false
+  final override val nodePatterns: Seq[TreePattern] = Seq(PARAMETERIZED_QUERY)
+  override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
+    copy(child = newChild)
+}
+
+/**
+ * Finds all named parameters in `ParameterizedQuery` and replaces them with literals from the
+ * user-specified arguments.
+ */
+object BindParameters extends Rule[LogicalPlan] with QueryErrorsBase {
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    if (plan.containsPattern(PARAMETERIZED_QUERY)) {
+      // One unresolved plan can have at most one ParameterizedQuery.
+      val parameterizedQueries = plan.collect { case p: ParameterizedQuery => p }
+      assert(parameterizedQueries.length == 1)
+    }
+
+    plan.resolveOperatorsWithPruning(_.containsPattern(PARAMETERIZED_QUERY)) {
+      // We should wait for `CTESubstitution` to resolve CTE before binding parameters, as CTE
+      // relations are not children of `UnresolvedWith`.
+      case p @ ParameterizedQuery(child, args) if !child.containsPattern(UNRESOLVED_WITH) =>
+        // Some commands may store the original SQL text, like CREATE VIEW, GENERATED COLUMN, etc.
+        // We can't store the original SQL text with parameters, as we don't store the arguments and
+        // are not able to resolve it after parsing it back. Since parameterized queries are mostly
+        // used to avoid SQL injection in SELECT queries, we simply forbid non-DML commands here.
+        child match {
+          case _: InsertIntoStatement => // OK
+          case _: UpdateTable => // OK
+          case _: DeleteFromTable => // OK
+          case _: MergeIntoTable => // OK
+          case cmd: Command =>
+            child.failAnalysis(
+              errorClass = "UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT",
+              messageParameters = Map("statement" -> cmd.nodeName)
+            )
+          case _ => // OK
+        }
+
+        args.find(!_._2.isInstanceOf[Literal]).foreach { case (name, expr) =>
+          expr.failAnalysis(
+            errorClass = "INVALID_SQL_ARG",
+            messageParameters = Map("name" -> name))
+        }
+
+        def bind(p: LogicalPlan): LogicalPlan = {
+          p.resolveExpressionsWithPruning(_.containsPattern(PARAMETER)) {
+            case Parameter(name) if args.contains(name) =>
+              args(name)
+            case sub: SubqueryExpression => sub.withNewPlan(bind(sub.plan))
+          }
+        }
+        val res = bind(child)
+        res.copyTagsFrom(p)
+        res
+
+      case _ => plan
+    }
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/parameters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/parameters.scala
deleted file mode 100644
index fae2b9a1a9f..00000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/parameters.scala
+++ /dev/null
@@ -1,64 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import org.apache.spark.SparkException
-import org.apache.spark.sql.catalyst.analysis.AnalysisErrorAt
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.trees.TreePattern.{PARAMETER, TreePattern}
-import org.apache.spark.sql.errors.QueryErrorsBase
-import org.apache.spark.sql.types.DataType
-
-/**
- * The expression represents a named parameter that should be replaced by a literal.
- *
- * @param name The identifier of the parameter without the marker.
- */
-case class Parameter(name: String) extends LeafExpression with Unevaluable {
-  override lazy val resolved: Boolean = false
-
-  private def unboundError(methodName: String): Nothing = {
-    throw SparkException.internalError(
-      s"Cannot call `$methodName()` of the unbound parameter `$name`.")
-  }
-  override def dataType: DataType = unboundError("dataType")
-  override def nullable: Boolean = unboundError("nullable")
-
-  final override val nodePatterns: Seq[TreePattern] = Seq(PARAMETER)
-}
-
-
-/**
- * Finds all named parameters in the given plan and substitutes them by literals of `args` values.
- */
-object Parameter extends QueryErrorsBase {
-  def bind(plan: LogicalPlan, args: Map[String, Expression]): LogicalPlan = {
-    if (!args.isEmpty) {
-      args.filter(!_._2.isInstanceOf[Literal]).headOption.foreach { case (name, expr) =>
-        expr.failAnalysis(
-          errorClass = "INVALID_SQL_ARG",
-          messageParameters = Map("name" -> toSQLId(name)))
-      }
-      plan.transformAllExpressionsWithPruning(_.containsPattern(PARAMETER)) {
-        case Parameter(name) if args.contains(name) => args(name)
-      }
-    } else {
-      plan
-    }
-  }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index 48db1a4408d..ce853b5773c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -73,6 +73,7 @@ object TreePattern extends Enumeration  {
   val OR: Value = Value
   val OUTER_REFERENCE: Value = Value
   val PARAMETER: Value = Value
+  val PARAMETERIZED_QUERY: Value = Value
   val PIVOT: Value = Value
   val PLAN_EXPRESSION: Value = Value
   val PYTHON_UDF: Value = Value
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 0f26d3a2dc9..54ea4086c9b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -1346,17 +1346,21 @@ class AnalysisSuite extends AnalysisTest with Matchers {
   }
 
   test("SPARK-41271: bind named parameters to literals") {
-    comparePlans(
-      Parameter.bind(
-        plan = parsePlan("SELECT * FROM a LIMIT :limitA"),
-        args = Map("limitA" -> Literal(10))),
-      parsePlan("SELECT * FROM a LIMIT 10"))
+    CTERelationDef.curId.set(0)
+    val actual1 = ParameterizedQuery(
+      child = parsePlan("WITH a AS (SELECT 1 c) SELECT * FROM a LIMIT :limitA"),
+      args = Map("limitA" -> Literal(10))).analyze
+    CTERelationDef.curId.set(0)
+    val expected1 = parsePlan("WITH a AS (SELECT 1 c) SELECT * FROM a LIMIT 10").analyze
+    comparePlans(actual1, expected1)
     // Ignore unused arguments
-    comparePlans(
-      Parameter.bind(
-        plan = parsePlan("SELECT c FROM a WHERE c < :param2"),
-        args = Map("param1" -> Literal(10), "param2" -> Literal(20))),
-      parsePlan("SELECT c FROM a WHERE c < 20"))
+    CTERelationDef.curId.set(0)
+    val actual2 = ParameterizedQuery(
+      child = parsePlan("WITH a AS (SELECT 1 c) SELECT c FROM a WHERE c < :param2"),
+      args = Map("param1" -> Literal(10), "param2" -> Literal(20))).analyze
+    CTERelationDef.curId.set(0)
+    val expected2 = parsePlan("WITH a AS (SELECT 1 c) SELECT c FROM a WHERE c < 20").analyze
+    comparePlans(actual2, expected2)
   }
 
   test("SPARK-41489: type of filter expression should be a bool") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 6fc83d8c782..3b5a2401335 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.parser
 
 import org.apache.spark.SparkThrowable
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
-import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, RelationTimeTravel, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedStar, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction, UnresolvedTVFAliases}
+import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, Parameter, RelationTimeTravel, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedStar, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction, UnresolvedTVFAliases}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{PercentileCont, PercentileDisc}
 import org.apache.spark.sql.catalyst.plans._
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index adbe593ac56..066e609a6d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -35,9 +35,9 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
 import org.apache.spark.sql.catalog.Catalog
 import org.apache.spark.sql.catalyst._
-import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.analysis.{ParameterizedQuery, UnresolvedRelation}
 import org.apache.spark.sql.catalyst.encoders._
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Parameter}
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range}
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.connector.ExternalCommandRunner
@@ -623,8 +623,12 @@ class SparkSession private(
     val tracker = new QueryPlanningTracker
     val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) {
       val parser = sessionState.sqlParser
-      val parsedArgs = args.mapValues(parser.parseExpression).toMap
-      Parameter.bind(parser.parsePlan(sqlText), parsedArgs)
+      val parsedPlan = parser.parsePlan(sqlText)
+      if (args.nonEmpty) {
+        ParameterizedQuery(parsedPlan, args.mapValues(parser.parseExpression).toMap)
+      } else {
+        parsedPlan
+      }
     }
     Dataset.ofRows(self, plan, tracker)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
index 668a1e4ad7d..e6e5eb9fac4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
@@ -38,13 +38,107 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
       Row(true))
   }
 
+  test("parameter binding is case sensitive") {
+    checkAnswer(
+      spark.sql("SELECT :p, :P", Map("p" -> "1", "P" -> "2")),
+      Row(1, 2)
+    )
+
+    checkError(
+      exception = intercept[AnalysisException] {
+        spark.sql("select :P", Map("p" -> "1"))
+      },
+      errorClass = "UNBOUND_SQL_PARAMETER",
+      parameters = Map("name" -> "P"),
+      context = ExpectedContext(
+        fragment = ":P",
+        start = 7,
+        stop = 8))
+  }
+
+  test("parameters in CTE") {
+    val sqlText =
+      """
+        |WITH w1 AS (SELECT :p1 AS p)
+        |SELECT p + :p2 FROM w1
+        |""".stripMargin
+    val args = Map("p1" -> "1", "p2" -> "2")
+    checkAnswer(
+      spark.sql(sqlText, args),
+      Row(3))
+  }
+
+  test("parameters in nested CTE") {
+    val sqlText =
+      """
+        |WITH w1 AS
+        |  (WITH w2 AS (SELECT :p1 AS p) SELECT p + :p2 AS p2 FROM w2)
+        |SELECT p2 + :p3 FROM w1
+        |""".stripMargin
+    val args = Map("p1" -> "1", "p2" -> "2", "p3" -> "3")
+    checkAnswer(
+      spark.sql(sqlText, args),
+      Row(6))
+  }
+
+  test("parameters in subquery expression") {
+    val sqlText = "SELECT (SELECT max(id) + :p1 FROM range(10)) + :p2"
+    val args = Map("p1" -> "1", "p2" -> "2")
+    checkAnswer(
+      spark.sql(sqlText, args),
+      Row(12))
+  }
+
+  test("parameters in nested subquery expression") {
+    val sqlText = "SELECT (SELECT (SELECT max(id) + :p1 FROM range(10)) + :p2) + :p3"
+    val args = Map("p1" -> "1", "p2" -> "2", "p3" -> "3")
+    checkAnswer(
+      spark.sql(sqlText, args),
+      Row(15))
+  }
+
+  test("parameters in subquery expression inside CTE") {
+    val sqlText =
+      """
+        |WITH w1 AS (SELECT (SELECT max(id) + :p1 FROM range(10)) + :p2 AS p)
+        |SELECT p + :p3 FROM w1
+        |""".stripMargin
+    val args = Map("p1" -> "1", "p2" -> "2", "p3" -> "3")
+    checkAnswer(
+      spark.sql(sqlText, args),
+      Row(15))
+  }
+
+  test("parameters in INSERT") {
+    withTable("t") {
+      sql("CREATE TABLE t (col INT) USING json")
+      spark.sql("INSERT INTO t SELECT :p", Map("p" -> "1"))
+      checkAnswer(spark.table("t"), Row(1))
+    }
+  }
+
+  test("parameters not allowed in DDL commands") {
+    val sqlText = "CREATE VIEW v AS SELECT :p AS p"
+    val args = Map("p" -> "1")
+    checkError(
+      exception = intercept[AnalysisException] {
+        spark.sql(sqlText, args)
+      },
+      errorClass = "UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT",
+      parameters = Map("statement" -> "CreateView"),
+      context = ExpectedContext(
+        fragment = "CREATE VIEW v AS SELECT :p AS p",
+        start = 0,
+        stop = sqlText.length - 1))
+  }
+
   test("non-substituted parameters") {
     checkError(
       exception = intercept[AnalysisException] {
         spark.sql("select :abc, :def", Map("abc" -> "1"))
       },
       errorClass = "UNBOUND_SQL_PARAMETER",
-      parameters = Map("name" -> "`def`"),
+      parameters = Map("name" -> "def"),
       context = ExpectedContext(
         fragment = ":def",
         start = 13,
@@ -54,7 +148,7 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
         sql("select :abc").collect()
       },
       errorClass = "UNBOUND_SQL_PARAMETER",
-      parameters = Map("name" -> "`abc`"),
+      parameters = Map("name" -> "abc"),
       context = ExpectedContext(
         fragment = ":abc",
         start = 7,
@@ -68,7 +162,7 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
           spark.sql("SELECT :param1 FROM VALUES (1) AS t(col1)", Map("param1" -> arg))
         },
         errorClass = "INVALID_SQL_ARG",
-        parameters = Map("name" -> "`param1`"),
+        parameters = Map("name" -> "param1"),
         context = ExpectedContext(
           fragment = arg,
           start = 0,


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org