Posted to commits@spark.apache.org by ma...@apache.org on 2023/06/22 06:40:49 UTC

[spark] branch master updated: [SPARK-44066][SQL] Support positional parameters in Scala/Java `sql()`

This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1b4048bf62d [SPARK-44066][SQL] Support positional parameters in Scala/Java `sql()`
1b4048bf62d is described below

commit 1b4048bf62dddae7d324c4b12aa409a1bd456dc5
Author: Max Gekk <ma...@gmail.com>
AuthorDate: Thu Jun 22 09:40:30 2023 +0300

    [SPARK-44066][SQL] Support positional parameters in Scala/Java `sql()`
    
    ### What changes were proposed in this pull request?
    In the PR, I propose to extend the SparkSession API by overloading the `sql` method:
    ```scala
      def sql(sqlText: String, args: Array[_]): DataFrame
    ```
    which accepts an array of Java/Scala objects that can be converted to SQL literal expressions.
    
    The first argument `sqlText` can contain positional parameters (`?` markers) in the positions of constants such as literal values. A value can also be a `Column` of a literal expression, in which case it is taken as-is.
    
    For example:
    ```scala
      spark.sql(
        sqlText = "SELECT * FROM tbl WHERE date > ? LIMIT ?",
        args = Array(LocalDate.of(2023, 6, 15), 100))
    ```
    The new `sql()` method parses the input SQL statement and replaces the positional parameters with the given literal values.
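    
    For instance, a `Column` produced by `lit()` is already a literal expression and is bound verbatim. A usage sketch based on the new `ParametersSuite` tests:
    ```scala
      import java.time.LocalDate
      import org.apache.spark.sql.functions.lit

      // The lit(...) argument is taken as-is; the plain 100 is converted to a literal.
      spark.sql(
        sqlText = "SELECT * FROM tbl WHERE date > ? LIMIT ?",
        args = Array(lit(LocalDate.of(2023, 4, 1)), 100))
    ```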
    
    ### Why are the changes needed?
    1. To conform to the SQL standard and the JDBC/ODBC protocols.
    2. To improve the user experience with Spark SQL by:
        - Using Spark as a remote service (microservice).
        - Writing SQL code that powers reports, dashboards, charts, and other data presentation solutions that must account for criteria users can modify through an interface.
        - Building a generic integration layer based on the SQL API. The goal is to expose managed data to a wide application ecosystem with a microservice architecture. In such a setup, it is only natural to ask for modular and reusable SQL code that can be executed repeatedly with different parameter values.
    
    3. To achieve feature parity with other systems that support positional parameters.
    
    ### Does this PR introduce _any_ user-facing change?
    No, the changes extend the existing API.
    
    ### How was this patch tested?
    By running new tests:
    ```
    $ build/sbt "test:testOnly *AnalysisSuite"
    $ build/sbt "test:testOnly *PlanParserSuite"
    $ build/sbt "test:testOnly *ParametersSuite"
    ```
    and the affected test suites:
    ```
    $ build/sbt "sql/testOnly *QueryExecutionErrorsSuite"
    ```
    
    Closes #41568 from MaxGekk/parametrized-query-pos-param.
    
    Authored-by: Max Gekk <ma...@gmail.com>
    Signed-off-by: Max Gekk <ma...@gmail.com>
---
 .../CheckConnectJvmClientCompatibility.scala       |   2 +
 .../sql/connect/planner/SparkConnectPlanner.scala  |   4 +-
 .../spark/sql/catalyst/parser/SqlBaseLexer.g4      |   1 +
 .../spark/sql/catalyst/parser/SqlBaseParser.g4     |   5 +-
 .../spark/sql/catalyst/analysis/parameters.scala   |  95 ++++++--
 .../spark/sql/catalyst/parser/AstBuilder.scala     |  14 +-
 .../sql/catalyst/analysis/AnalysisSuite.scala      |  22 +-
 .../sql/catalyst/parser/PlanParserSuite.scala      |  25 +-
 .../scala/org/apache/spark/sql/SparkSession.scala  |  34 ++-
 .../apache/spark/sql/JavaSparkSessionSuite.java    |  28 +++
 .../org/apache/spark/sql/ParametersSuite.scala     | 265 +++++++++++++++++++--
 .../sql/errors/QueryExecutionErrorsSuite.scala     |  10 +-
 12 files changed, 448 insertions(+), 57 deletions(-)

diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 6b648fd152b..acc469672b4 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -227,6 +227,8 @@ object CheckConnectJvmClientCompatibility {
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.createDataset"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.executeCommand"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.this"),
+      // TODO(SPARK-44068): Support positional parameters in Scala connect client
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sql"),
 
       // RuntimeConfig
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.RuntimeConfig.this"),
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 6ee252d1a58..856d0f06ba4 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -41,7 +41,7 @@ import org.apache.spark.ml.{functions => MLFunctions}
 import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter, RelationalGroupedDataset, Row, SparkSession}
 import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro}
 import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier}
-import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, ParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
 import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
@@ -253,7 +253,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
     val parser = session.sessionState.sqlParser
     val parsedPlan = parser.parsePlan(sql.getQuery)
     if (!args.isEmpty) {
-      ParameterizedQuery(parsedPlan, args.asScala.mapValues(transformLiteral).toMap)
+      NameParameterizedQuery(parsedPlan, args.asScala.mapValues(transformLiteral).toMap)
     } else {
       parsedPlan
     }
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
index ecd5f5912fd..6c9b3a71266 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
@@ -445,6 +445,7 @@ COLON: ':';
 ARROW: '->';
 HENT_START: '/*+';
 HENT_END: '*/';
+QUESTION: '?';
 
 STRING_LITERAL
     : '\'' ( ~('\''|'\\') | ('\\' .) )* '\''
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index 240310a426d..d1e672e9472 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -952,9 +952,10 @@ literalType
 
 constant
     : NULL                                                                                     #nullLiteral
-    | COLON identifier                                                                         #parameterLiteral
+    | QUESTION                                                                                 #posParameterLiteral
+    | COLON identifier                                                                         #namedParameterLiteral
     | interval                                                                                 #intervalLiteral
-    | literalType stringLit                                                                     #typeConstructor
+    | literalType stringLit                                                                    #typeConstructor
     | number                                                                                   #numericLiteral
     | booleanValue                                                                             #booleanLiteral
     | stringLit+                                                                               #stringLiteral
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala
index a00f9cec92c..2e3cabce24a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala
@@ -25,12 +25,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{PARAMETER, PARAMETERIZED
 import org.apache.spark.sql.errors.QueryErrorsBase
 import org.apache.spark.sql.types.DataType
 
-/**
- * The expression represents a named parameter that should be replaced by a literal.
- *
- * @param name The identifier of the parameter without the marker.
- */
-case class Parameter(name: String) extends LeafExpression with Unevaluable {
+sealed trait Parameter extends LeafExpression with Unevaluable {
   override lazy val resolved: Boolean = false
 
   private def unboundError(methodName: String): Nothing = {
@@ -41,17 +36,56 @@ case class Parameter(name: String) extends LeafExpression with Unevaluable {
   override def nullable: Boolean = unboundError("nullable")
 
   final override val nodePatterns: Seq[TreePattern] = Seq(PARAMETER)
+
+  def name: String
+}
+
+/**
+ * The expression represents a named parameter that should be replaced by a literal.
+ *
+ * @param name The identifier of the parameter without the marker.
+ */
+case class NamedParameter(name: String) extends Parameter
+
+/**
+ * The expression represents a positional parameter that should be replaced by a literal.
+ *
+ * @param pos A unique position of the parameter in the SQL query text.
+ */
+case class PosParameter(pos: Int) extends Parameter {
+  override def name: String = s"_$pos"
 }
 
 /**
  * The logical plan representing a parameterized query. It will be removed during analysis after
 * the parameters are bound.
  */
-case class ParameterizedQuery(child: LogicalPlan, args: Map[String, Expression])
-  extends UnresolvedUnaryNode {
+abstract class ParameterizedQuery(child: LogicalPlan) extends UnresolvedUnaryNode {
+  final override val nodePatterns: Seq[TreePattern] = Seq(PARAMETERIZED_QUERY)
+}
 
+/**
+ * The logical plan representing a parameterized query with named parameters.
+ *
+ * @param child The parameterized logical plan.
+ * @param args The map of parameter names to their literal values.
+ */
+case class NameParameterizedQuery(child: LogicalPlan, args: Map[String, Expression])
+  extends ParameterizedQuery(child) {
+  assert(args.nonEmpty)
+  override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
+    copy(child = newChild)
+}
+
+/**
+ * The logical plan representing a parameterized query with positional parameters.
+ *
+ * @param child The parameterized logical plan.
+ * @param args The literal values of positional parameters.
+ */
+case class PosParameterizedQuery(child: LogicalPlan, args: Array[Expression])
+  extends ParameterizedQuery(child) {
   assert(args.nonEmpty)
-  final override val nodePatterns: Seq[TreePattern] = Seq(PARAMETERIZED_QUERY)
   override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
     copy(child = newChild)
 }
@@ -61,6 +95,20 @@ case class ParameterizedQuery(child: LogicalPlan, args: Map[String, Expression])
  * user-specified arguments.
  */
 object BindParameters extends Rule[LogicalPlan] with QueryErrorsBase {
+  private def checkArgs(args: Iterable[(String, Expression)]): Unit = {
+    args.find(!_._2.isInstanceOf[Literal]).foreach { case (name, expr) =>
+      expr.failAnalysis(
+        errorClass = "INVALID_SQL_ARG",
+        messageParameters = Map("name" -> name))
+    }
+  }
+
+  private def bind(p: LogicalPlan)(f: PartialFunction[Expression, Expression]): LogicalPlan = {
+    p.resolveExpressionsWithPruning(_.containsPattern(PARAMETER)) (f orElse {
+      case sub: SubqueryExpression => sub.withNewPlan(bind(sub.plan)(f))
+    })
+  }
+
   override def apply(plan: LogicalPlan): LogicalPlan = {
     if (plan.containsPattern(PARAMETERIZED_QUERY)) {
       // One unresolved plan can have at most one ParameterizedQuery.
@@ -71,23 +119,22 @@ object BindParameters extends Rule[LogicalPlan] with QueryErrorsBase {
     plan.resolveOperatorsWithPruning(_.containsPattern(PARAMETERIZED_QUERY)) {
       // We should wait for `CTESubstitution` to resolve CTE before binding parameters, as CTE
       // relations are not children of `UnresolvedWith`.
-      case p @ ParameterizedQuery(child, args) if !child.containsPattern(UNRESOLVED_WITH) =>
-        args.find(!_._2.isInstanceOf[Literal]).foreach { case (name, expr) =>
-          expr.failAnalysis(
-            errorClass = "INVALID_SQL_ARG",
-            messageParameters = Map("name" -> name))
-        }
+      case p @ NameParameterizedQuery(child, args) if !child.containsPattern(UNRESOLVED_WITH) =>
+        checkArgs(args)
+        bind(child) { case NamedParameter(name) if args.contains(name) => args(name) }
+
+      case p @ PosParameterizedQuery(child, args) if !child.containsPattern(UNRESOLVED_WITH) =>
+        val indexedArgs = args.zipWithIndex
+        checkArgs(indexedArgs.map(arg => (s"_${arg._2}", arg._1)))
+
+        val positions = scala.collection.mutable.Set.empty[Int]
+        bind(child) { case p @ PosParameter(pos) => positions.add(pos); p }
+        val posToIndex = positions.toSeq.sorted.zipWithIndex.toMap
 
-        def bind(p: LogicalPlan): LogicalPlan = {
-          p.resolveExpressionsWithPruning(_.containsPattern(PARAMETER)) {
-            case Parameter(name) if args.contains(name) =>
-              args(name)
-            case sub: SubqueryExpression => sub.withNewPlan(bind(sub.plan))
-          }
+        bind(child) {
+          case PosParameter(pos) if posToIndex.contains(pos) && args.size > posToIndex(pos) =>
+            args(posToIndex(pos))
         }
-        val res = bind(child)
-        res.copyTagsFrom(p)
-        res
 
       case _ => plan
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 07721424a86..ca62de12e7b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -5113,7 +5113,17 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
   /**
    * Create a named parameter which represents a literal with a non-bound value and unknown type.
    * */
-  override def visitParameterLiteral(ctx: ParameterLiteralContext): Expression = withOrigin(ctx) {
-    Parameter(ctx.identifier().getText)
+  override def visitNamedParameterLiteral(
+      ctx: NamedParameterLiteralContext): Expression = withOrigin(ctx) {
+    NamedParameter(ctx.identifier().getText)
+  }
+
+  /**
+   * Create a positional parameter which represents a literal
+   * with a non-bound value and unknown type.
+   * */
+  override def visitPosParameterLiteral(
+      ctx: PosParameterLiteralContext): Expression = withOrigin(ctx) {
+    PosParameter(ctx.QUESTION().getSymbol.getStartIndex)
   }
 }
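
Note that the position stored in `PosParameter` is the zero-based start offset of the `?` token in the query text, which is exactly what the parser tests below assert. A quick check in plain Scala:
```scala
// "SELECT ?" puts the marker at offset 7, hence PosParameter(7) in PlanParserSuite;
// "SELECT * FROM a LIMIT ?" puts it at offset 22, hence PosParameter(22).
println("SELECT ?".indexOf('?'))                 // 7
println("SELECT * FROM a LIMIT ?".indexOf('?'))  // 22
```
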
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 1e844e22bec..dae42453f0d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -1370,7 +1370,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
 
   test("SPARK-41271: bind named parameters to literals") {
     CTERelationDef.curId.set(0)
-    val actual1 = ParameterizedQuery(
+    val actual1 = NameParameterizedQuery(
       child = parsePlan("WITH a AS (SELECT 1 c) SELECT * FROM a LIMIT :limitA"),
       args = Map("limitA" -> Literal(10))).analyze
     CTERelationDef.curId.set(0)
@@ -1378,7 +1378,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
     comparePlans(actual1, expected1)
     // Ignore unused arguments
     CTERelationDef.curId.set(0)
-    val actual2 = ParameterizedQuery(
+    val actual2 = NameParameterizedQuery(
       child = parsePlan("WITH a AS (SELECT 1 c) SELECT c FROM a WHERE c < :param2"),
       args = Map("param1" -> Literal(10), "param2" -> Literal(20))).analyze
     CTERelationDef.curId.set(0)
@@ -1386,6 +1386,24 @@ class AnalysisSuite extends AnalysisTest with Matchers {
     comparePlans(actual2, expected2)
   }
 
+  test("SPARK-44066: bind positional parameters to literals") {
+    CTERelationDef.curId.set(0)
+    val actual1 = PosParameterizedQuery(
+      child = parsePlan("WITH a AS (SELECT 1 c) SELECT * FROM a LIMIT ?"),
+      args = Array(Literal(10))).analyze
+    CTERelationDef.curId.set(0)
+    val expected1 = parsePlan("WITH a AS (SELECT 1 c) SELECT * FROM a LIMIT 10").analyze
+    comparePlans(actual1, expected1)
+    // Ignore unused arguments
+    CTERelationDef.curId.set(0)
+    val actual2 = PosParameterizedQuery(
+      child = parsePlan("WITH a AS (SELECT 1 c) SELECT c FROM a WHERE c < ?"),
+      args = Array(Literal(20), Literal(10))).analyze
+    CTERelationDef.curId.set(0)
+    val expected2 = parsePlan("WITH a AS (SELECT 1 c) SELECT c FROM a WHERE c < 20").analyze
+    comparePlans(actual2, expected2)
+  }
+
   test("SPARK-41489: type of filter expression should be a bool") {
     assertAnalysisErrorClass(parsePlan(
       s"""
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 5a28ef847dc..ded8aaf7430 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.parser
 
 import org.apache.spark.SparkThrowable
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
-import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, Parameter, RelationTimeTravel, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedStar, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction, UnresolvedTVFAliases}
+import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, NamedParameter, PosParameter, RelationTimeTravel, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedStar, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction, UnresolvedTVFAliases}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{PercentileCont, PercentileDisc}
 import org.apache.spark.sql.catalyst.plans._
@@ -1630,18 +1630,18 @@ class PlanParserSuite extends AnalysisTest {
   test("SPARK-41271: parsing of named parameters") {
     comparePlans(
       parsePlan("SELECT :param_1"),
-      Project(UnresolvedAlias(Parameter("param_1"), None) :: Nil, OneRowRelation()))
+      Project(UnresolvedAlias(NamedParameter("param_1"), None) :: Nil, OneRowRelation()))
     comparePlans(
       parsePlan("SELECT abs(:1Abc)"),
       Project(UnresolvedAlias(
         UnresolvedFunction(
           "abs" :: Nil,
-          Parameter("1Abc") :: Nil,
+          NamedParameter("1Abc") :: Nil,
           isDistinct = false), None) :: Nil,
         OneRowRelation()))
     comparePlans(
       parsePlan("SELECT * FROM a LIMIT :limitA"),
-      table("a").select(star()).limit(Parameter("limitA")))
+      table("a").select(star()).limit(NamedParameter("limitA")))
     // Invalid empty name and invalid symbol in a name
     checkError(
       exception = parseException(s"SELECT :-"),
@@ -1661,4 +1661,21 @@ class PlanParserSuite extends AnalysisTest {
         Seq(Literal("abc")) :: Nil).as("tbl").select($"interval")
     )
   }
+
+  test("SPARK-44066: parsing of positional parameters") {
+    comparePlans(
+      parsePlan("SELECT ?"),
+      Project(UnresolvedAlias(PosParameter(7), None) :: Nil, OneRowRelation()))
+    comparePlans(
+      parsePlan("SELECT abs(?)"),
+      Project(UnresolvedAlias(
+        UnresolvedFunction(
+          "abs" :: Nil,
+          PosParameter(11) :: Nil,
+          isDistinct = false), None) :: Nil,
+        OneRowRelation()))
+    comparePlans(
+      parsePlan("SELECT * FROM a LIMIT ?"),
+      table("a").select(star()).limit(PosParameter(22)))
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 642006fb8dc..2a1c2474bc6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -35,7 +35,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
 import org.apache.spark.sql.catalog.Catalog
 import org.apache.spark.sql.catalyst._
-import org.apache.spark.sql.catalyst.analysis.{ParameterizedQuery, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.analysis.{NameParameterizedQuery, PosParameterizedQuery, UnresolvedRelation}
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range}
@@ -609,6 +609,36 @@ class SparkSession private(
    |  Everything else  |
    * ----------------- */
 
+  /**
+   * Executes a SQL query substituting positional parameters by the given arguments,
+   * returning the result as a `DataFrame`.
+   * This API eagerly runs DDL/DML commands, but not for SELECT queries.
+   *
+   * @param sqlText A SQL statement with positional parameters to execute.
+   * @param args An array of Java/Scala objects that can be converted to
+   *             SQL literal expressions. See
+   *             <a href="https://spark.apache.org/docs/latest/sql-ref-datatypes.html">
+   *             Supported Data Types</a> for supported value types in Scala/Java.
+   *             For example, 1, "Steven", LocalDate.of(2023, 4, 2).
+   *             A value can be also a `Column` of literal expression, in that case
+   *             it is taken as is.
+   *
+   * @since 3.5.0
+   */
+  @Experimental
+  def sql(sqlText: String, args: Array[_]): DataFrame = withActive {
+    val tracker = new QueryPlanningTracker
+    val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) {
+      val parsedPlan = sessionState.sqlParser.parsePlan(sqlText)
+      if (args.nonEmpty) {
+        PosParameterizedQuery(parsedPlan, args.map(lit(_).expr))
+      } else {
+        parsedPlan
+      }
+    }
+    Dataset.ofRows(self, plan, tracker)
+  }
+
   /**
    * Executes a SQL query substituting named parameters by the given arguments,
    * returning the result as a `DataFrame`.
@@ -632,7 +662,7 @@ class SparkSession private(
     val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) {
       val parsedPlan = sessionState.sqlParser.parsePlan(sqlText)
       if (args.nonEmpty) {
-        ParameterizedQuery(parsedPlan, args.mapValues(lit(_).expr).toMap)
+        NameParameterizedQuery(parsedPlan, args.mapValues(lit(_).expr).toMap)
       } else {
         parsedPlan
       }
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaSparkSessionSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaSparkSessionSuite.java
index b1df377936d..0d6d773d930 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaSparkSessionSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaSparkSessionSuite.java
@@ -18,11 +18,13 @@
 package test.org.apache.spark.sql;
 
 import org.apache.spark.sql.*;
+import org.apache.spark.sql.test.TestSparkSession;
 import org.junit.After;
 import org.junit.Assert;
 import org.junit.Test;
 
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 
 public class JavaSparkSessionSuite {
@@ -54,4 +56,30 @@ public class JavaSparkSessionSuite {
       Assert.assertEquals(spark.conf().get(e.getKey()), e.getValue().toString());
     }
   }
+
+  @Test
+  public void testPositionalParameters() {
+    spark = new TestSparkSession();
+
+    int[] emptyArgs = {};
+    List<Row> collected1 = spark.sql("select 'abc'", emptyArgs).collectAsList();
+    Assert.assertEquals("abc", collected1.get(0).getString(0));
+
+    Object[] singleArg = new String[] { "abc" };
+    List<Row> collected2 = spark.sql("select ?", singleArg).collectAsList();
+    Assert.assertEquals("abc", collected2.get(0).getString(0));
+
+    int[] args = new int[] { 1, 2, 3 };
+    List<Row> collected3 = spark.sql("select ?, ?, ?", args).collectAsList();
+    Row r0 = collected3.get(0);
+    Assert.assertEquals(1, r0.getInt(0));
+    Assert.assertEquals(2, r0.getInt(1));
+    Assert.assertEquals(3, r0.getInt(2));
+
+    Object[] mixedArgs = new Object[] { 1, "abc" };
+    List<Row> collected4 = spark.sql("select ?, ?", mixedArgs).collectAsList();
+    Row r1 = collected4.get(0);
+    Assert.assertEquals(1, r1.getInt(0));
+    Assert.assertEquals("abc", r1.getString(1));
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
index 985d0373c4f..725956e259b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.test.SharedSparkSession
 
 class ParametersSuite extends QueryTest with SharedSparkSession {
 
-  test("bind parameters") {
+  test("bind named parameters") {
     val sqlText =
       """
         |SELECT id, id % :div as c0
@@ -42,6 +42,23 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
       Row(true))
   }
 
+  test("bind positional parameters") {
+    val sqlText =
+      """
+        |SELECT id, id % ? as c0
+        |FROM VALUES (0), (1), (2), (3), (4), (5), (6), (7), (8), (9) AS t(id)
+        |WHERE id < ?
+        |""".stripMargin
+    val args = Array(3, 4L)
+    checkAnswer(
+      spark.sql(sqlText, args),
+      Row(0, 0) :: Row(1, 1) :: Row(2, 2) :: Row(3, 0) :: Nil)
+
+    checkAnswer(
+      spark.sql("""SELECT contains('Spark \'SQL\'', ?)""", Array("SQL")),
+      Row(true))
+  }
+
   test("parameter binding is case sensitive") {
     checkAnswer(
       spark.sql("SELECT :p, :P", Map("p" -> 1, "P" -> 2)),
@@ -60,7 +77,7 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
         stop = 8))
   }
 
-  test("parameters in CTE") {
+  test("named parameters in CTE") {
     val sqlText =
       """
         |WITH w1 AS (SELECT :p1 AS p)
@@ -72,7 +89,19 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
       Row(3))
   }
 
-  test("parameters in nested CTE") {
+  test("positional parameters in CTE") {
+    val sqlText =
+      """
+        |WITH w1 AS (SELECT ? AS p)
+        |SELECT p + ? FROM w1
+        |""".stripMargin
+    val args = Array(1, 2)
+    checkAnswer(
+      spark.sql(sqlText, args),
+      Row(3))
+  }
+
+  test("named parameters in nested CTE") {
     val sqlText =
       """
         |WITH w1 AS
@@ -85,7 +114,20 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
       Row(6))
   }
 
-  test("parameters in subquery expression") {
+  test("positional parameters in nested CTE") {
+    val sqlText =
+      """
+        |WITH w1 AS
+        |  (WITH w2 AS (SELECT ? AS p) SELECT p + ? AS p2 FROM w2)
+        |SELECT p2 + ? FROM w1
+        |""".stripMargin
+    val args = Array(1, 2, 3)
+    checkAnswer(
+      spark.sql(sqlText, args),
+      Row(6))
+  }
+
+  test("named parameters in subquery expression") {
     val sqlText = "SELECT (SELECT max(id) + :p1 FROM range(10)) + :p2"
     val args = Map("p1" -> 1, "p2" -> 2)
     checkAnswer(
@@ -93,7 +135,15 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
       Row(12))
   }
 
-  test("parameters in nested subquery expression") {
+  test("positional parameters in subquery expression") {
+    val sqlText = "SELECT (SELECT max(id) + ? FROM range(10)) + ?"
+    val args = Array(1, 2)
+    checkAnswer(
+      spark.sql(sqlText, args),
+      Row(12))
+  }
+
+  test("named parameters in nested subquery expression") {
     val sqlText = "SELECT (SELECT (SELECT max(id) + :p1 FROM range(10)) + :p2) + :p3"
     val args = Map("p1" -> 1, "p2" -> 2, "p3" -> 3)
     checkAnswer(
@@ -101,7 +151,15 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
       Row(15))
   }
 
-  test("parameters in subquery expression inside CTE") {
+  test("positional parameters in nested subquery expression") {
+    val sqlText = "SELECT (SELECT (SELECT max(id) + ? FROM range(10)) + ?) + ?"
+    val args = Array(1, 2, 3)
+    checkAnswer(
+      spark.sql(sqlText, args),
+      Row(15))
+  }
+
+  test("named parameters in subquery expression inside CTE") {
     val sqlText =
       """
         |WITH w1 AS (SELECT (SELECT max(id) + :p1 FROM range(10)) + :p2 AS p)
@@ -113,7 +171,19 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
       Row(15))
   }
 
-  test("parameter in identifier clause") {
+  test("positional parameters in subquery expression inside CTE") {
+    val sqlText =
+      """
+        |WITH w1 AS (SELECT (SELECT max(id) + ? FROM range(10)) + ? AS p)
+        |SELECT p + ? FROM w1
+        |""".stripMargin
+    val args = Array(1, 2, 3)
+    checkAnswer(
+      spark.sql(sqlText, args),
+      Row(15))
+  }
+
+  test("named parameter in identifier clause") {
     val sqlText =
       "SELECT IDENTIFIER('T.' || :p1 || '1') FROM VALUES(1) T(c1)"
     val args = Map("p1" -> "c")
@@ -122,7 +192,16 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
       Row(1))
   }
 
-  test("parameter in identifier clause in DDL and utility commands") {
+  test("positional parameter in identifier clause") {
+    val sqlText =
+      "SELECT IDENTIFIER('T.' || ? || '1') FROM VALUES(1) T(c1)"
+    val args = Array("c")
+    checkAnswer(
+      spark.sql(sqlText, args),
+      Row(1))
+  }
+
+  test("named parameter in identifier clause in DDL and utility commands") {
     spark.sql("CREATE VIEW IDENTIFIER(:p1)(c1) AS SELECT 1", args = Map("p1" -> "v"))
     spark.sql("ALTER VIEW IDENTIFIER(:p1) AS SELECT 2 AS c1", args = Map("p1" -> "v"))
     checkAnswer(
@@ -131,7 +210,16 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
     spark.sql("DROP VIEW IDENTIFIER(:p1)", args = Map("p1" -> "v"))
   }
 
-  test("parameters in INSERT") {
+  test("positional parameter in identifier clause in DDL and utility commands") {
+    spark.sql("CREATE VIEW IDENTIFIER(?)(c1) AS SELECT 1", args = Array("v"))
+    spark.sql("ALTER VIEW IDENTIFIER(?) AS SELECT 2 AS c1", args = Array("v"))
+    checkAnswer(
+      spark.sql("SHOW COLUMNS FROM IDENTIFIER(?)", args = Array("v")),
+      Row("c1"))
+    spark.sql("DROP VIEW IDENTIFIER(?)", args = Array("v"))
+  }
+
+  test("named parameters in INSERT") {
     withTable("t") {
       sql("CREATE TABLE t (col INT) USING json")
       spark.sql("INSERT INTO t SELECT :p", Map("p" -> 1))
@@ -139,7 +227,15 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
     }
   }
 
-  test("parameters not allowed in view body ") {
+  test("positional parameters in INSERT") {
+    withTable("t") {
+      sql("CREATE TABLE t (col INT) USING json")
+      spark.sql("INSERT INTO t SELECT ?", Array(1))
+      checkAnswer(spark.table("t"), Row(1))
+    }
+  }
+
+  test("named parameters not allowed in view body ") {
     val sqlText = "CREATE VIEW v AS SELECT :p AS p"
     val args = Map("p" -> 1)
     checkError(
@@ -154,7 +250,22 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
         stop = sqlText.length - 1))
   }
 
-  test("parameters not allowed in view body - WITH and scalar subquery") {
+  test("positional parameters not allowed in view body ") {
+    val sqlText = "CREATE VIEW v AS SELECT ? AS p"
+    val args = Array(1)
+    checkError(
+      exception = intercept[AnalysisException] {
+        spark.sql(sqlText, args)
+      },
+      errorClass = "UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT",
+      parameters = Map("statement" -> "CREATE VIEW body"),
+      context = ExpectedContext(
+        fragment = sqlText,
+        start = 0,
+        stop = sqlText.length - 1))
+  }
+
+  test("named parameters not allowed in view body - WITH and scalar subquery") {
     val sqlText = "CREATE VIEW v AS WITH cte(a) AS (SELECT (SELECT :p) AS a)  SELECT a FROM cte"
     val args = Map("p" -> 1)
     checkError(
@@ -169,7 +280,22 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
         stop = sqlText.length - 1))
   }
 
-  test("parameters not allowed in view body - nested WITH and EXIST") {
+  test("positional parameters not allowed in view body - WITH and scalar subquery") {
+    val sqlText = "CREATE VIEW v AS WITH cte(a) AS (SELECT (SELECT ?) AS a)  SELECT a FROM cte"
+    val args = Array(1)
+    checkError(
+      exception = intercept[AnalysisException] {
+        spark.sql(sqlText, args)
+      },
+      errorClass = "UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT",
+      parameters = Map("statement" -> "CREATE VIEW body"),
+      context = ExpectedContext(
+        fragment = sqlText,
+        start = 0,
+        stop = sqlText.length - 1))
+  }
+
+  test("named parameters not allowed in view body - nested WITH and EXIST") {
     val sqlText =
       """CREATE VIEW v AS
         |SELECT a as a
@@ -188,7 +314,26 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
         stop = sqlText.length - 1))
   }
 
-  test("non-substituted parameters") {
+  test("positional parameters not allowed in view body - nested WITH and EXIST") {
+    val sqlText =
+      """CREATE VIEW v AS
+        |SELECT a as a
+        |FROM (WITH cte(a) AS (SELECT CASE WHEN EXISTS(SELECT ?) THEN 1 END AS a)
+        |SELECT a FROM cte)""".stripMargin
+    val args = Array(1)
+    checkError(
+      exception = intercept[AnalysisException] {
+        spark.sql(sqlText, args)
+      },
+      errorClass = "UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT",
+      parameters = Map("statement" -> "CREATE VIEW body"),
+      context = ExpectedContext(
+        fragment = sqlText,
+        start = 0,
+        stop = sqlText.length - 1))
+  }
+
+  test("non-substituted named parameters") {
     checkError(
       exception = intercept[AnalysisException] {
         spark.sql("select :abc, :def", Map("abc" -> 1))
@@ -211,7 +356,30 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
         stop = 10))
   }
 
-  test("literal argument of `sql()`") {
+  test("non-substituted positional parameters") {
+    checkError(
+      exception = intercept[AnalysisException] {
+        spark.sql("select ?, ?", Array(1))
+      },
+      errorClass = "UNBOUND_SQL_PARAMETER",
+      parameters = Map("name" -> "_10"),
+      context = ExpectedContext(
+        fragment = "?",
+        start = 10,
+        stop = 10))
+    checkError(
+      exception = intercept[AnalysisException] {
+        sql("select ?").collect()
+      },
+      errorClass = "UNBOUND_SQL_PARAMETER",
+      parameters = Map("name" -> "_7"),
+      context = ExpectedContext(
+        fragment = "?",
+        start = 7,
+        stop = 7))
+  }
+
+  test("literal argument of named parameter in `sql()`") {
     val sqlText =
       """SELECT s FROM VALUES ('Jeff /*__*/ Green'), ('E\'Twaun Moore'), ('Vander Blue') AS t(s)
         |WHERE s = :player_name""".stripMargin
@@ -249,4 +417,73 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
           .toInstant) :: Nil)
     }
   }
+
+  test("literal argument of positional parameter in `sql()`") {
+    val sqlText =
+      """SELECT s FROM VALUES ('Jeff /*__*/ Green'), ('E\'Twaun Moore'), ('Vander Blue') AS t(s)
+        |WHERE s = ?""".stripMargin
+    checkAnswer(
+      spark.sql(sqlText, args = Array(lit("E'Twaun Moore"))),
+      Row("E'Twaun Moore") :: Nil)
+    checkAnswer(
+      spark.sql(sqlText, args = Array(lit("Vander Blue--comment"))),
+      Nil)
+    checkAnswer(
+      spark.sql(sqlText, args = Array(lit("Jeff /*__*/ Green"))),
+      Row("Jeff /*__*/ Green") :: Nil)
+
+    withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") {
+      checkAnswer(
+        spark.sql(
+          sqlText = """
+                      |SELECT d
+                      |FROM VALUES (DATE'1970-01-01'), (DATE'2023-12-31') AS t(d)
+                      |WHERE d < ?
+                      |""".stripMargin,
+          args = Array(lit(LocalDate.of(2023, 4, 1)))),
+        Row(LocalDate.of(1970, 1, 1)) :: Nil)
+      checkAnswer(
+        spark.sql(
+          sqlText = """
+                      |SELECT d
+                      |FROM VALUES (TIMESTAMP_LTZ'1970-01-01 01:02:03 Europe/Amsterdam'),
+                      |            (TIMESTAMP_LTZ'2023-12-31 04:05:06 America/Los_Angeles') AS t(d)
+                      |WHERE d < ?
+                      |""".stripMargin,
+          args = Array(lit(Instant.parse("2023-04-01T00:00:00Z")))),
+        Row(LocalDateTime.of(1970, 1, 1, 1, 2, 3)
+          .atZone(ZoneId.of("Europe/Amsterdam"))
+          .toInstant) :: Nil)
+    }
+  }
+
+  test("unused positional arguments") {
+    checkAnswer(
+      spark.sql("SELECT ?, ?", Array(1, "abc", 3.14f)),
+      Row(1, "abc"))
+  }
+
+  test("mixing of positional and named parameters") {
+    checkError(
+      exception = intercept[AnalysisException] {
+        spark.sql("select :param1, ?", Map("param1" -> 1))
+      },
+      errorClass = "UNBOUND_SQL_PARAMETER",
+      parameters = Map("name" -> "_16"),
+      context = ExpectedContext(
+        fragment = "?",
+        start = 16,
+        stop = 16))
+
+    checkError(
+      exception = intercept[AnalysisException] {
+        spark.sql("select :param1, ?", Array(1))
+      },
+      errorClass = "UNBOUND_SQL_PARAMETER",
+      parameters = Map("name" -> "param1"),
+      context = ExpectedContext(
+        fragment = ":param1",
+        start = 7,
+        stop = 13))
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
index 61b3610e64e..8f47b06d855 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
@@ -29,7 +29,7 @@ import org.mockito.Mockito.{mock, spy, when}
 import org.apache.spark._
 import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, QueryTest, Row, SaveMode}
 import org.apache.spark.sql.catalyst.FunctionIdentifier
-import org.apache.spark.sql.catalyst.analysis.{Parameter, UnresolvedGenerator}
+import org.apache.spark.sql.catalyst.analysis.{NamedParameter, UnresolvedGenerator}
 import org.apache.spark.sql.catalyst.expressions.{Grouping, Literal, RowNumber}
 import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
@@ -866,26 +866,26 @@ class QueryExecutionErrorsSuite
 
   test("INTERNAL_ERROR: Calling eval on Unevaluable expression") {
     val e = intercept[SparkException] {
-      Parameter("foo").eval()
+      NamedParameter("foo").eval()
     }
     checkError(
       exception = e,
       errorClass = "INTERNAL_ERROR",
-      parameters = Map("message" -> "Cannot evaluate expression: parameter(foo)"),
+      parameters = Map("message" -> "Cannot evaluate expression: namedparameter(foo)"),
       sqlState = "XX000")
   }
 
   test("INTERNAL_ERROR: Calling doGenCode on unresolved") {
     val e = intercept[SparkException] {
       val ctx = new CodegenContext
-      Grouping(Parameter("foo")).genCode(ctx)
+      Grouping(NamedParameter("foo")).genCode(ctx)
     }
     checkError(
       exception = e,
       errorClass = "INTERNAL_ERROR",
       parameters = Map(
         "message" -> ("Cannot generate code for expression: " +
-          "grouping(parameter(foo))")),
+          "grouping(namedparameter(foo))")),
       sqlState = "XX000")
   }
 

