You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2023/06/22 06:40:49 UTC
[spark] branch master updated: [SPARK-44066][SQL] Support positional parameters in Scala/Java `sql()`
This is an automated email from the ASF dual-hosted git repository.
maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 1b4048bf62d [SPARK-44066][SQL] Support positional parameters in Scala/Java `sql()`
1b4048bf62d is described below
commit 1b4048bf62dddae7d324c4b12aa409a1bd456dc5
Author: Max Gekk <ma...@gmail.com>
AuthorDate: Thu Jun 22 09:40:30 2023 +0300
[SPARK-44066][SQL] Support positional parameters in Scala/Java `sql()`
### What changes were proposed in this pull request?
In the PR, I propose to extend SparkSession API and override the `sql` method by:
```scala
def sql(sqlText: String, args: Array[_]): DataFrame
```
which accepts an array of Java/Scala objects that can be converted to SQL literal expressions.
And the first argument `sqlText` might have positional parameters in the positions of constants like literal values. A value can also be a `Column` of a literal expression, in which case it is taken as is.
For example:
```scala
spark.sql(
sqlText = "SELECT * FROM tbl WHERE date > ? LIMIT ?",
args = Array(LocalDate.of(2023, 6, 15), 100))
```
The new `sql()` method parses the input SQL statement and replaces the positional parameters by the literal values.
### Why are the changes needed?
1. To conform to the SQL standard and the JDBC/ODBC protocol.
2. To improve user experience with Spark SQL via
- Using Spark as remote service (microservice).
- Write SQL code that will power reports, dashboards, charts and other data presentation solutions that need to account for criteria modifiable by users through an interface.
- Build a generic integration layer based on the SQL API. The goal is to expose managed data to a wide application ecosystem with a microservice architecture. It is only natural in such a setup to ask for modular and reusable SQL code, that can be executed repeatedly with different parameter values.
3. To achieve feature parity with other systems that support positional parameters.
### Does this PR introduce _any_ user-facing change?
No, the changes extend the existing API.
### How was this patch tested?
By running new tests:
```
$ build/sbt "test:testOnly *AnalysisSuite"
$ build/sbt "test:testOnly *PlanParserSuite"
$ build/sbt "test:testOnly *ParametersSuite"
```
and the affected test suites:
```
$ build/sbt "sql/testOnly *QueryExecutionErrorsSuite"
```
Closes #41568 from MaxGekk/parametrized-query-pos-param.
Authored-by: Max Gekk <ma...@gmail.com>
Signed-off-by: Max Gekk <ma...@gmail.com>
---
.../CheckConnectJvmClientCompatibility.scala | 2 +
.../sql/connect/planner/SparkConnectPlanner.scala | 4 +-
.../spark/sql/catalyst/parser/SqlBaseLexer.g4 | 1 +
.../spark/sql/catalyst/parser/SqlBaseParser.g4 | 5 +-
.../spark/sql/catalyst/analysis/parameters.scala | 95 ++++++--
.../spark/sql/catalyst/parser/AstBuilder.scala | 14 +-
.../sql/catalyst/analysis/AnalysisSuite.scala | 22 +-
.../sql/catalyst/parser/PlanParserSuite.scala | 25 +-
.../scala/org/apache/spark/sql/SparkSession.scala | 34 ++-
.../apache/spark/sql/JavaSparkSessionSuite.java | 28 +++
.../org/apache/spark/sql/ParametersSuite.scala | 265 +++++++++++++++++++--
.../sql/errors/QueryExecutionErrorsSuite.scala | 10 +-
12 files changed, 448 insertions(+), 57 deletions(-)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 6b648fd152b..acc469672b4 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -227,6 +227,8 @@ object CheckConnectJvmClientCompatibility {
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.createDataset"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.executeCommand"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.this"),
+ // TODO(SPARK-44068): Support positional parameters in Scala connect client
+ ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sql"),
// RuntimeConfig
ProblemFilters.exclude[Problem]("org.apache.spark.sql.RuntimeConfig.this"),
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 6ee252d1a58..856d0f06ba4 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -41,7 +41,7 @@ import org.apache.spark.ml.{functions => MLFunctions}
import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter, RelationalGroupedDataset, Row, SparkSession}
import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro}
import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier}
-import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, ParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
@@ -253,7 +253,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
val parser = session.sessionState.sqlParser
val parsedPlan = parser.parsePlan(sql.getQuery)
if (!args.isEmpty) {
- ParameterizedQuery(parsedPlan, args.asScala.mapValues(transformLiteral).toMap)
+ NameParameterizedQuery(parsedPlan, args.asScala.mapValues(transformLiteral).toMap)
} else {
parsedPlan
}
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
index ecd5f5912fd..6c9b3a71266 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
@@ -445,6 +445,7 @@ COLON: ':';
ARROW: '->';
HENT_START: '/*+';
HENT_END: '*/';
+QUESTION: '?';
STRING_LITERAL
: '\'' ( ~('\''|'\\') | ('\\' .) )* '\''
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index 240310a426d..d1e672e9472 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -952,9 +952,10 @@ literalType
constant
: NULL #nullLiteral
- | COLON identifier #parameterLiteral
+ | QUESTION #posParameterLiteral
+ | COLON identifier #namedParameterLiteral
| interval #intervalLiteral
- | literalType stringLit #typeConstructor
+ | literalType stringLit #typeConstructor
| number #numericLiteral
| booleanValue #booleanLiteral
| stringLit+ #stringLiteral
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala
index a00f9cec92c..2e3cabce24a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala
@@ -25,12 +25,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{PARAMETER, PARAMETERIZED
import org.apache.spark.sql.errors.QueryErrorsBase
import org.apache.spark.sql.types.DataType
-/**
- * The expression represents a named parameter that should be replaced by a literal.
- *
- * @param name The identifier of the parameter without the marker.
- */
-case class Parameter(name: String) extends LeafExpression with Unevaluable {
+sealed trait Parameter extends LeafExpression with Unevaluable {
override lazy val resolved: Boolean = false
private def unboundError(methodName: String): Nothing = {
@@ -41,17 +36,56 @@ case class Parameter(name: String) extends LeafExpression with Unevaluable {
override def nullable: Boolean = unboundError("nullable")
final override val nodePatterns: Seq[TreePattern] = Seq(PARAMETER)
+
+ def name: String
+}
+
+/**
+ * The expression represents a named parameter that should be replaced by a literal.
+ *
+ * @param name The identifier of the parameter without the marker.
+ */
+case class NamedParameter(name: String) extends Parameter
+
+/**
+ * The expression represents a positional parameter that should be replaced by a literal.
+ *
+ * @param pos An unique position of the parameter in a SQL query text.
+ */
+case class PosParameter(pos: Int) extends Parameter {
+ override def name: String = s"_$pos"
}
/**
* The logical plan representing a parameterized query. It will be removed during analysis after
* the parameters are bind.
*/
-case class ParameterizedQuery(child: LogicalPlan, args: Map[String, Expression])
- extends UnresolvedUnaryNode {
+abstract class ParameterizedQuery(child: LogicalPlan) extends UnresolvedUnaryNode {
+ final override val nodePatterns: Seq[TreePattern] = Seq(PARAMETERIZED_QUERY)
+}
+/**
+ * The logical plan representing a parameterized query with named parameters.
+ *
+ * @param child The parameterized logical plan.
+ * @param args The map of parameter names to its literal values.
+ */
+case class NameParameterizedQuery(child: LogicalPlan, args: Map[String, Expression])
+ extends ParameterizedQuery(child) {
+ assert(args.nonEmpty)
+ override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
+ copy(child = newChild)
+}
+
+/**
+ * The logical plan representing a parameterized query with positional parameters.
+ *
+ * @param child The parameterized logical plan.
+ * @param args The literal values of positional parameters.
+ */
+case class PosParameterizedQuery(child: LogicalPlan, args: Array[Expression])
+ extends ParameterizedQuery(child) {
assert(args.nonEmpty)
- final override val nodePatterns: Seq[TreePattern] = Seq(PARAMETERIZED_QUERY)
override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
copy(child = newChild)
}
@@ -61,6 +95,20 @@ case class ParameterizedQuery(child: LogicalPlan, args: Map[String, Expression])
* user-specified arguments.
*/
object BindParameters extends Rule[LogicalPlan] with QueryErrorsBase {
+ private def checkArgs(args: Iterable[(String, Expression)]): Unit = {
+ args.find(!_._2.isInstanceOf[Literal]).foreach { case (name, expr) =>
+ expr.failAnalysis(
+ errorClass = "INVALID_SQL_ARG",
+ messageParameters = Map("name" -> name))
+ }
+ }
+
+ private def bind(p: LogicalPlan)(f: PartialFunction[Expression, Expression]): LogicalPlan = {
+ p.resolveExpressionsWithPruning(_.containsPattern(PARAMETER)) (f orElse {
+ case sub: SubqueryExpression => sub.withNewPlan(bind(sub.plan)(f))
+ })
+ }
+
override def apply(plan: LogicalPlan): LogicalPlan = {
if (plan.containsPattern(PARAMETERIZED_QUERY)) {
// One unresolved plan can have at most one ParameterizedQuery.
@@ -71,23 +119,22 @@ object BindParameters extends Rule[LogicalPlan] with QueryErrorsBase {
plan.resolveOperatorsWithPruning(_.containsPattern(PARAMETERIZED_QUERY)) {
// We should wait for `CTESubstitution` to resolve CTE before binding parameters, as CTE
// relations are not children of `UnresolvedWith`.
- case p @ ParameterizedQuery(child, args) if !child.containsPattern(UNRESOLVED_WITH) =>
- args.find(!_._2.isInstanceOf[Literal]).foreach { case (name, expr) =>
- expr.failAnalysis(
- errorClass = "INVALID_SQL_ARG",
- messageParameters = Map("name" -> name))
- }
+ case p @ NameParameterizedQuery(child, args) if !child.containsPattern(UNRESOLVED_WITH) =>
+ checkArgs(args)
+ bind(child) { case NamedParameter(name) if args.contains(name) => args(name) }
+
+ case p @ PosParameterizedQuery(child, args) if !child.containsPattern(UNRESOLVED_WITH) =>
+ val indexedArgs = args.zipWithIndex
+ checkArgs(indexedArgs.map(arg => (s"_${arg._2}", arg._1)))
+
+ val positions = scala.collection.mutable.Set.empty[Int]
+ bind(child) { case p @ PosParameter(pos) => positions.add(pos); p }
+ val posToIndex = positions.toSeq.sorted.zipWithIndex.toMap
- def bind(p: LogicalPlan): LogicalPlan = {
- p.resolveExpressionsWithPruning(_.containsPattern(PARAMETER)) {
- case Parameter(name) if args.contains(name) =>
- args(name)
- case sub: SubqueryExpression => sub.withNewPlan(bind(sub.plan))
- }
+ bind(child) {
+ case PosParameter(pos) if posToIndex.contains(pos) && args.size > posToIndex(pos) =>
+ args(posToIndex(pos))
}
- val res = bind(child)
- res.copyTagsFrom(p)
- res
case _ => plan
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 07721424a86..ca62de12e7b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -5113,7 +5113,17 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
/**
* Create a named parameter which represents a literal with a non-bound value and unknown type.
* */
- override def visitParameterLiteral(ctx: ParameterLiteralContext): Expression = withOrigin(ctx) {
- Parameter(ctx.identifier().getText)
+ override def visitNamedParameterLiteral(
+ ctx: NamedParameterLiteralContext): Expression = withOrigin(ctx) {
+ NamedParameter(ctx.identifier().getText)
+ }
+
+ /**
+ * Create a positional parameter which represents a literal
+ * with a non-bound value and unknown type.
+ * */
+ override def visitPosParameterLiteral(
+ ctx: PosParameterLiteralContext): Expression = withOrigin(ctx) {
+ PosParameter(ctx.QUESTION().getSymbol.getStartIndex)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 1e844e22bec..dae42453f0d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -1370,7 +1370,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
test("SPARK-41271: bind named parameters to literals") {
CTERelationDef.curId.set(0)
- val actual1 = ParameterizedQuery(
+ val actual1 = NameParameterizedQuery(
child = parsePlan("WITH a AS (SELECT 1 c) SELECT * FROM a LIMIT :limitA"),
args = Map("limitA" -> Literal(10))).analyze
CTERelationDef.curId.set(0)
@@ -1378,7 +1378,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
comparePlans(actual1, expected1)
// Ignore unused arguments
CTERelationDef.curId.set(0)
- val actual2 = ParameterizedQuery(
+ val actual2 = NameParameterizedQuery(
child = parsePlan("WITH a AS (SELECT 1 c) SELECT c FROM a WHERE c < :param2"),
args = Map("param1" -> Literal(10), "param2" -> Literal(20))).analyze
CTERelationDef.curId.set(0)
@@ -1386,6 +1386,24 @@ class AnalysisSuite extends AnalysisTest with Matchers {
comparePlans(actual2, expected2)
}
+ test("SPARK-44066: bind positional parameters to literals") {
+ CTERelationDef.curId.set(0)
+ val actual1 = PosParameterizedQuery(
+ child = parsePlan("WITH a AS (SELECT 1 c) SELECT * FROM a LIMIT ?"),
+ args = Array(Literal(10))).analyze
+ CTERelationDef.curId.set(0)
+ val expected1 = parsePlan("WITH a AS (SELECT 1 c) SELECT * FROM a LIMIT 10").analyze
+ comparePlans(actual1, expected1)
+ // Ignore unused arguments
+ CTERelationDef.curId.set(0)
+ val actual2 = PosParameterizedQuery(
+ child = parsePlan("WITH a AS (SELECT 1 c) SELECT c FROM a WHERE c < ?"),
+ args = Array(Literal(20), Literal(10))).analyze
+ CTERelationDef.curId.set(0)
+ val expected2 = parsePlan("WITH a AS (SELECT 1 c) SELECT c FROM a WHERE c < 20").analyze
+ comparePlans(actual2, expected2)
+ }
+
test("SPARK-41489: type of filter expression should be a bool") {
assertAnalysisErrorClass(parsePlan(
s"""
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 5a28ef847dc..ded8aaf7430 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.parser
import org.apache.spark.SparkThrowable
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
-import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, Parameter, RelationTimeTravel, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedStar, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction, UnresolvedTVFAliases}
+import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, NamedParameter, PosParameter, RelationTimeTravel, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedStar, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction, UnresolvedTVFAliases}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{PercentileCont, PercentileDisc}
import org.apache.spark.sql.catalyst.plans._
@@ -1630,18 +1630,18 @@ class PlanParserSuite extends AnalysisTest {
test("SPARK-41271: parsing of named parameters") {
comparePlans(
parsePlan("SELECT :param_1"),
- Project(UnresolvedAlias(Parameter("param_1"), None) :: Nil, OneRowRelation()))
+ Project(UnresolvedAlias(NamedParameter("param_1"), None) :: Nil, OneRowRelation()))
comparePlans(
parsePlan("SELECT abs(:1Abc)"),
Project(UnresolvedAlias(
UnresolvedFunction(
"abs" :: Nil,
- Parameter("1Abc") :: Nil,
+ NamedParameter("1Abc") :: Nil,
isDistinct = false), None) :: Nil,
OneRowRelation()))
comparePlans(
parsePlan("SELECT * FROM a LIMIT :limitA"),
- table("a").select(star()).limit(Parameter("limitA")))
+ table("a").select(star()).limit(NamedParameter("limitA")))
// Invalid empty name and invalid symbol in a name
checkError(
exception = parseException(s"SELECT :-"),
@@ -1661,4 +1661,21 @@ class PlanParserSuite extends AnalysisTest {
Seq(Literal("abc")) :: Nil).as("tbl").select($"interval")
)
}
+
+ test("SPARK-44066: parsing of positional parameters") {
+ comparePlans(
+ parsePlan("SELECT ?"),
+ Project(UnresolvedAlias(PosParameter(7), None) :: Nil, OneRowRelation()))
+ comparePlans(
+ parsePlan("SELECT abs(?)"),
+ Project(UnresolvedAlias(
+ UnresolvedFunction(
+ "abs" :: Nil,
+ PosParameter(11) :: Nil,
+ isDistinct = false), None) :: Nil,
+ OneRowRelation()))
+ comparePlans(
+ parsePlan("SELECT * FROM a LIMIT ?"),
+ table("a").select(star()).limit(PosParameter(22)))
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 642006fb8dc..2a1c2474bc6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -35,7 +35,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
import org.apache.spark.sql.catalog.Catalog
import org.apache.spark.sql.catalyst._
-import org.apache.spark.sql.catalyst.analysis.{ParameterizedQuery, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.analysis.{NameParameterizedQuery, PosParameterizedQuery, UnresolvedRelation}
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range}
@@ -609,6 +609,36 @@ class SparkSession private(
| Everything else |
* ----------------- */
+ /**
+ * Executes a SQL query substituting positional parameters by the given arguments,
+ * returning the result as a `DataFrame`.
+ * This API eagerly runs DDL/DML commands, but not for SELECT queries.
+ *
+ * @param sqlText A SQL statement with positional parameters to execute.
+ * @param args An array of Java/Scala objects that can be converted to
+ * SQL literal expressions. See
+ * <a href="https://spark.apache.org/docs/latest/sql-ref-datatypes.html">
+ * Supported Data Types</a> for supported value types in Scala/Java.
+ * For example, 1, "Steven", LocalDate.of(2023, 4, 2).
+ * A value can be also a `Column` of literal expression, in that case
+ * it is taken as is.
+ *
+ * @since 3.5.0
+ */
+ @Experimental
+ def sql(sqlText: String, args: Array[_]): DataFrame = withActive {
+ val tracker = new QueryPlanningTracker
+ val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) {
+ val parsedPlan = sessionState.sqlParser.parsePlan(sqlText)
+ if (args.nonEmpty) {
+ PosParameterizedQuery(parsedPlan, args.map(lit(_).expr))
+ } else {
+ parsedPlan
+ }
+ }
+ Dataset.ofRows(self, plan, tracker)
+ }
+
/**
* Executes a SQL query substituting named parameters by the given arguments,
* returning the result as a `DataFrame`.
@@ -632,7 +662,7 @@ class SparkSession private(
val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) {
val parsedPlan = sessionState.sqlParser.parsePlan(sqlText)
if (args.nonEmpty) {
- ParameterizedQuery(parsedPlan, args.mapValues(lit(_).expr).toMap)
+ NameParameterizedQuery(parsedPlan, args.mapValues(lit(_).expr).toMap)
} else {
parsedPlan
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaSparkSessionSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaSparkSessionSuite.java
index b1df377936d..0d6d773d930 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaSparkSessionSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaSparkSessionSuite.java
@@ -18,11 +18,13 @@
package test.org.apache.spark.sql;
import org.apache.spark.sql.*;
+import org.apache.spark.sql.test.TestSparkSession;
import org.junit.After;
import org.junit.Assert;
import org.junit.Test;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
public class JavaSparkSessionSuite {
@@ -54,4 +56,30 @@ public class JavaSparkSessionSuite {
Assert.assertEquals(spark.conf().get(e.getKey()), e.getValue().toString());
}
}
+
+ @Test
+ public void testPositionalParameters() {
+ spark = new TestSparkSession();
+
+ int[] emptyArgs = {};
+ List<Row> collected1 = spark.sql("select 'abc'", emptyArgs).collectAsList();
+ Assert.assertEquals("abc", collected1.get(0).getString(0));
+
+ Object[] singleArg = new String[] { "abc" };
+ List<Row> collected2 = spark.sql("select ?", singleArg).collectAsList();
+ Assert.assertEquals("abc", collected2.get(0).getString(0));
+
+ int[] args = new int[] { 1, 2, 3 };
+ List<Row> collected3 = spark.sql("select ?, ?, ?", args).collectAsList();
+ Row r0 = collected3.get(0);
+ Assert.assertEquals(1, r0.getInt(0));
+ Assert.assertEquals(2, r0.getInt(1));
+ Assert.assertEquals(3, r0.getInt(2));
+
+ Object[] mixedArgs = new Object[] { 1, "abc" };
+ List<Row> collected4 = spark.sql("select ?, ?", mixedArgs).collectAsList();
+ Row r1 = collected4.get(0);
+ Assert.assertEquals(1, r1.getInt(0));
+ Assert.assertEquals("abc", r1.getString(1));
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
index 985d0373c4f..725956e259b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.test.SharedSparkSession
class ParametersSuite extends QueryTest with SharedSparkSession {
- test("bind parameters") {
+ test("bind named parameters") {
val sqlText =
"""
|SELECT id, id % :div as c0
@@ -42,6 +42,23 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
Row(true))
}
+ test("bind positional parameters") {
+ val sqlText =
+ """
+ |SELECT id, id % ? as c0
+ |FROM VALUES (0), (1), (2), (3), (4), (5), (6), (7), (8), (9) AS t(id)
+ |WHERE id < ?
+ |""".stripMargin
+ val args = Array(3, 4L)
+ checkAnswer(
+ spark.sql(sqlText, args),
+ Row(0, 0) :: Row(1, 1) :: Row(2, 2) :: Row(3, 0) :: Nil)
+
+ checkAnswer(
+ spark.sql("""SELECT contains('Spark \'SQL\'', ?)""", Array("SQL")),
+ Row(true))
+ }
+
test("parameter binding is case sensitive") {
checkAnswer(
spark.sql("SELECT :p, :P", Map("p" -> 1, "P" -> 2)),
@@ -60,7 +77,7 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
stop = 8))
}
- test("parameters in CTE") {
+ test("named parameters in CTE") {
val sqlText =
"""
|WITH w1 AS (SELECT :p1 AS p)
@@ -72,7 +89,19 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
Row(3))
}
- test("parameters in nested CTE") {
+ test("positional parameters in CTE") {
+ val sqlText =
+ """
+ |WITH w1 AS (SELECT ? AS p)
+ |SELECT p + ? FROM w1
+ |""".stripMargin
+ val args = Array(1, 2)
+ checkAnswer(
+ spark.sql(sqlText, args),
+ Row(3))
+ }
+
+ test("named parameters in nested CTE") {
val sqlText =
"""
|WITH w1 AS
@@ -85,7 +114,20 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
Row(6))
}
- test("parameters in subquery expression") {
+ test("positional parameters in nested CTE") {
+ val sqlText =
+ """
+ |WITH w1 AS
+ | (WITH w2 AS (SELECT ? AS p) SELECT p + ? AS p2 FROM w2)
+ |SELECT p2 + ? FROM w1
+ |""".stripMargin
+ val args = Array(1, 2, 3)
+ checkAnswer(
+ spark.sql(sqlText, args),
+ Row(6))
+ }
+
+ test("named parameters in subquery expression") {
val sqlText = "SELECT (SELECT max(id) + :p1 FROM range(10)) + :p2"
val args = Map("p1" -> 1, "p2" -> 2)
checkAnswer(
@@ -93,7 +135,15 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
Row(12))
}
- test("parameters in nested subquery expression") {
+ test("positional parameters in subquery expression") {
+ val sqlText = "SELECT (SELECT max(id) + ? FROM range(10)) + ?"
+ val args = Array(1, 2)
+ checkAnswer(
+ spark.sql(sqlText, args),
+ Row(12))
+ }
+
+ test("named parameters in nested subquery expression") {
val sqlText = "SELECT (SELECT (SELECT max(id) + :p1 FROM range(10)) + :p2) + :p3"
val args = Map("p1" -> 1, "p2" -> 2, "p3" -> 3)
checkAnswer(
@@ -101,7 +151,15 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
Row(15))
}
- test("parameters in subquery expression inside CTE") {
+ test("positional parameters in nested subquery expression") {
+ val sqlText = "SELECT (SELECT (SELECT max(id) + ? FROM range(10)) + ?) + ?"
+ val args = Array(1, 2, 3)
+ checkAnswer(
+ spark.sql(sqlText, args),
+ Row(15))
+ }
+
+ test("named parameters in subquery expression inside CTE") {
val sqlText =
"""
|WITH w1 AS (SELECT (SELECT max(id) + :p1 FROM range(10)) + :p2 AS p)
@@ -113,7 +171,19 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
Row(15))
}
- test("parameter in identifier clause") {
+ test("positional parameters in subquery expression inside CTE") {
+ val sqlText =
+ """
+ |WITH w1 AS (SELECT (SELECT max(id) + ? FROM range(10)) + ? AS p)
+ |SELECT p + ? FROM w1
+ |""".stripMargin
+ val args = Array(1, 2, 3)
+ checkAnswer(
+ spark.sql(sqlText, args),
+ Row(15))
+ }
+
+ test("named parameter in identifier clause") {
val sqlText =
"SELECT IDENTIFIER('T.' || :p1 || '1') FROM VALUES(1) T(c1)"
val args = Map("p1" -> "c")
@@ -122,7 +192,16 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
Row(1))
}
- test("parameter in identifier clause in DDL and utility commands") {
+ test("positional parameter in identifier clause") {
+ val sqlText =
+ "SELECT IDENTIFIER('T.' || ? || '1') FROM VALUES(1) T(c1)"
+ val args = Array("c")
+ checkAnswer(
+ spark.sql(sqlText, args),
+ Row(1))
+ }
+
+ test("named parameter in identifier clause in DDL and utility commands") {
spark.sql("CREATE VIEW IDENTIFIER(:p1)(c1) AS SELECT 1", args = Map("p1" -> "v"))
spark.sql("ALTER VIEW IDENTIFIER(:p1) AS SELECT 2 AS c1", args = Map("p1" -> "v"))
checkAnswer(
@@ -131,7 +210,16 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
spark.sql("DROP VIEW IDENTIFIER(:p1)", args = Map("p1" -> "v"))
}
- test("parameters in INSERT") {
+ test("positional parameter in identifier clause in DDL and utility commands") {
+ spark.sql("CREATE VIEW IDENTIFIER(?)(c1) AS SELECT 1", args = Array("v"))
+ spark.sql("ALTER VIEW IDENTIFIER(?) AS SELECT 2 AS c1", args = Array("v"))
+ checkAnswer(
+ spark.sql("SHOW COLUMNS FROM IDENTIFIER(?)", args = Array("v")),
+ Row("c1"))
+ spark.sql("DROP VIEW IDENTIFIER(?)", args = Array("v"))
+ }
+
+ test("named parameters in INSERT") {
withTable("t") {
sql("CREATE TABLE t (col INT) USING json")
spark.sql("INSERT INTO t SELECT :p", Map("p" -> 1))
@@ -139,7 +227,15 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
}
}
- test("parameters not allowed in view body ") {
+ test("positional parameters in INSERT") {
+ withTable("t") {
+ sql("CREATE TABLE t (col INT) USING json")
+ spark.sql("INSERT INTO t SELECT ?", Array(1))
+ checkAnswer(spark.table("t"), Row(1))
+ }
+ }
+
+ test("named parameters not allowed in view body ") {
val sqlText = "CREATE VIEW v AS SELECT :p AS p"
val args = Map("p" -> 1)
checkError(
@@ -154,7 +250,22 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
stop = sqlText.length - 1))
}
- test("parameters not allowed in view body - WITH and scalar subquery") {
+ test("positional parameters not allowed in view body ") {
+ val sqlText = "CREATE VIEW v AS SELECT ? AS p"
+ val args = Array(1)
+ checkError(
+ exception = intercept[AnalysisException] {
+ spark.sql(sqlText, args)
+ },
+ errorClass = "UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT",
+ parameters = Map("statement" -> "CREATE VIEW body"),
+ context = ExpectedContext(
+ fragment = sqlText,
+ start = 0,
+ stop = sqlText.length - 1))
+ }
+
+ test("named parameters not allowed in view body - WITH and scalar subquery") {
val sqlText = "CREATE VIEW v AS WITH cte(a) AS (SELECT (SELECT :p) AS a) SELECT a FROM cte"
val args = Map("p" -> 1)
checkError(
@@ -169,7 +280,22 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
stop = sqlText.length - 1))
}
- test("parameters not allowed in view body - nested WITH and EXIST") {
+ test("positional parameters not allowed in view body - WITH and scalar subquery") {
+ val sqlText = "CREATE VIEW v AS WITH cte(a) AS (SELECT (SELECT ?) AS a) SELECT a FROM cte"
+ val args = Array(1)
+ checkError(
+ exception = intercept[AnalysisException] {
+ spark.sql(sqlText, args)
+ },
+ errorClass = "UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT",
+ parameters = Map("statement" -> "CREATE VIEW body"),
+ context = ExpectedContext(
+ fragment = sqlText,
+ start = 0,
+ stop = sqlText.length - 1))
+ }
+
+ test("named parameters not allowed in view body - nested WITH and EXIST") {
val sqlText =
"""CREATE VIEW v AS
|SELECT a as a
@@ -188,7 +314,26 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
stop = sqlText.length - 1))
}
- test("non-substituted parameters") {
+ test("positional parameters not allowed in view body - nested WITH and EXIST") {
+ val sqlText =
+ """CREATE VIEW v AS
+ |SELECT a as a
+ |FROM (WITH cte(a) AS (SELECT CASE WHEN EXISTS(SELECT ?) THEN 1 END AS a)
+ |SELECT a FROM cte)""".stripMargin
+ val args = Array(1)
+ checkError(
+ exception = intercept[AnalysisException] {
+ spark.sql(sqlText, args)
+ },
+ errorClass = "UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT",
+ parameters = Map("statement" -> "CREATE VIEW body"),
+ context = ExpectedContext(
+ fragment = sqlText,
+ start = 0,
+ stop = sqlText.length - 1))
+ }
+
+ test("non-substituted named parameters") {
checkError(
exception = intercept[AnalysisException] {
spark.sql("select :abc, :def", Map("abc" -> 1))
@@ -211,7 +356,30 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
stop = 10))
}
- test("literal argument of `sql()`") {
+ test("non-substituted positional parameters") {
+ checkError(
+ exception = intercept[AnalysisException] {
+ spark.sql("select ?, ?", Array(1))
+ },
+ errorClass = "UNBOUND_SQL_PARAMETER",
+ parameters = Map("name" -> "_10"),
+ context = ExpectedContext(
+ fragment = "?",
+ start = 10,
+ stop = 10))
+ checkError(
+ exception = intercept[AnalysisException] {
+ sql("select ?").collect()
+ },
+ errorClass = "UNBOUND_SQL_PARAMETER",
+ parameters = Map("name" -> "_7"),
+ context = ExpectedContext(
+ fragment = "?",
+ start = 7,
+ stop = 7))
+ }
+
+ test("literal argument of named parameter in `sql()`") {
val sqlText =
"""SELECT s FROM VALUES ('Jeff /*__*/ Green'), ('E\'Twaun Moore'), ('Vander Blue') AS t(s)
|WHERE s = :player_name""".stripMargin
@@ -249,4 +417,73 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
.toInstant) :: Nil)
}
}
+
+ test("literal argument of positional parameter in `sql()`") {
+ val sqlText =
+ """SELECT s FROM VALUES ('Jeff /*__*/ Green'), ('E\'Twaun Moore'), ('Vander Blue') AS t(s)
+ |WHERE s = ?""".stripMargin
+ checkAnswer(
+ spark.sql(sqlText, args = Array(lit("E'Twaun Moore"))),
+ Row("E'Twaun Moore") :: Nil)
+ checkAnswer(
+ spark.sql(sqlText, args = Array(lit("Vander Blue--comment"))),
+ Nil)
+ checkAnswer(
+ spark.sql(sqlText, args = Array(lit("Jeff /*__*/ Green"))),
+ Row("Jeff /*__*/ Green") :: Nil)
+
+ withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") {
+ checkAnswer(
+ spark.sql(
+ sqlText = """
+ |SELECT d
+ |FROM VALUES (DATE'1970-01-01'), (DATE'2023-12-31') AS t(d)
+ |WHERE d < ?
+ |""".stripMargin,
+ args = Array(lit(LocalDate.of(2023, 4, 1)))),
+ Row(LocalDate.of(1970, 1, 1)) :: Nil)
+ checkAnswer(
+ spark.sql(
+ sqlText = """
+ |SELECT d
+ |FROM VALUES (TIMESTAMP_LTZ'1970-01-01 01:02:03 Europe/Amsterdam'),
+ | (TIMESTAMP_LTZ'2023-12-31 04:05:06 America/Los_Angeles') AS t(d)
+ |WHERE d < ?
+ |""".stripMargin,
+ args = Array(lit(Instant.parse("2023-04-01T00:00:00Z")))),
+ Row(LocalDateTime.of(1970, 1, 1, 1, 2, 3)
+ .atZone(ZoneId.of("Europe/Amsterdam"))
+ .toInstant) :: Nil)
+ }
+ }
+
+ test("unused positional arguments") {
+ checkAnswer(
+ spark.sql("SELECT ?, ?", Array(1, "abc", 3.14f)),
+ Row(1, "abc"))
+ }
+
+ test("mixing of positional and named parameters") {
+ checkError(
+ exception = intercept[AnalysisException] {
+ spark.sql("select :param1, ?", Map("param1" -> 1))
+ },
+ errorClass = "UNBOUND_SQL_PARAMETER",
+ parameters = Map("name" -> "_16"),
+ context = ExpectedContext(
+ fragment = "?",
+ start = 16,
+ stop = 16))
+
+ checkError(
+ exception = intercept[AnalysisException] {
+ spark.sql("select :param1, ?", Array(1))
+ },
+ errorClass = "UNBOUND_SQL_PARAMETER",
+ parameters = Map("name" -> "param1"),
+ context = ExpectedContext(
+ fragment = ":param1",
+ start = 7,
+ stop = 13))
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
index 61b3610e64e..8f47b06d855 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
@@ -29,7 +29,7 @@ import org.mockito.Mockito.{mock, spy, when}
import org.apache.spark._
import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, QueryTest, Row, SaveMode}
import org.apache.spark.sql.catalyst.FunctionIdentifier
-import org.apache.spark.sql.catalyst.analysis.{Parameter, UnresolvedGenerator}
+import org.apache.spark.sql.catalyst.analysis.{NamedParameter, UnresolvedGenerator}
import org.apache.spark.sql.catalyst.expressions.{Grouping, Literal, RowNumber}
import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
@@ -866,26 +866,26 @@ class QueryExecutionErrorsSuite
test("INTERNAL_ERROR: Calling eval on Unevaluable expression") {
val e = intercept[SparkException] {
- Parameter("foo").eval()
+ NamedParameter("foo").eval()
}
checkError(
exception = e,
errorClass = "INTERNAL_ERROR",
- parameters = Map("message" -> "Cannot evaluate expression: parameter(foo)"),
+ parameters = Map("message" -> "Cannot evaluate expression: namedparameter(foo)"),
sqlState = "XX000")
}
test("INTERNAL_ERROR: Calling doGenCode on unresolved") {
val e = intercept[SparkException] {
val ctx = new CodegenContext
- Grouping(Parameter("foo")).genCode(ctx)
+ Grouping(NamedParameter("foo")).genCode(ctx)
}
checkError(
exception = e,
errorClass = "INTERNAL_ERROR",
parameters = Map(
"message" -> ("Cannot generate code for expression: " +
- "grouping(parameter(foo))")),
+ "grouping(namedparameter(foo))")),
sqlState = "XX000")
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org