You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2022/11/18 19:15:21 UTC
[spark] branch master updated: [SPARK-41173][SQL] Move `require()` out from the constructors of string expressions
This is an automated email from the ASF dual-hosted git repository.
maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new b96ddce77aa [SPARK-41173][SQL] Move `require()` out from the constructors of string expressions
b96ddce77aa is described below
commit b96ddce77aa3f17eb0dea95083a9ac35d6077a94
Author: yangjie01 <ya...@baidu.com>
AuthorDate: Fri Nov 18 22:14:32 2022 +0300
[SPARK-41173][SQL] Move `require()` out from the constructors of string expressions
### What changes were proposed in this pull request?
This PR aims to move `require()` out from the constructors of string expressions, including `ConcatWs` and `FormatString`.
The argument-number checking logic is moved into `checkInputDataTypes()`.
### Why are the changes needed?
Migration onto error classes unifies Spark SQL error messages.
### Does this PR introduce _any_ user-facing change?
Yes. The PR changes user-facing error messages.
### How was this patch tested?
Pass GitHub Actions
Closes #38705 from LuciferYang/SPARK-41173.
Authored-by: yangjie01 <ya...@baidu.com>
Signed-off-by: Max Gekk <ma...@gmail.com>
---
.../catalyst/expressions/stringExpressions.scala | 35 +++++++++++++++++++---
.../results/ansi/string-functions.sql.out | 34 +++++++++++++++++++--
.../sql-tests/results/string-functions.sql.out | 34 +++++++++++++++++++--
3 files changed, 95 insertions(+), 8 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 45bed3e2387..60b56f4fef7 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -67,8 +67,6 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
case class ConcatWs(children: Seq[Expression])
extends Expression with ImplicitCastInputTypes {
- require(children.nonEmpty, s"$prettyName requires at least one argument.")
-
override def prettyName: String = "concat_ws"
/** The 1st child (separator) is str, and rest are either str or array of str. */
@@ -82,6 +80,21 @@ case class ConcatWs(children: Seq[Expression])
override def nullable: Boolean = children.head.nullable
override def foldable: Boolean = children.forall(_.foldable)
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (children.isEmpty) {
+ DataTypeMismatch(
+ errorSubClass = "WRONG_NUM_ARGS",
+ messageParameters = Map(
+ "functionName" -> toSQLId(prettyName),
+ "expectedNum" -> "> 0",
+ "actualNum" -> children.length.toString
+ )
+ )
+ } else {
+ super.checkInputDataTypes()
+ }
+ }
+
override def eval(input: InternalRow): Any = {
val flatInputs = children.flatMap { child =>
child.eval(input) match {
@@ -1662,8 +1675,7 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression = Litera
// scalastyle:on line.size.limit
case class FormatString(children: Expression*) extends Expression with ImplicitCastInputTypes {
- require(children.nonEmpty, s"$prettyName() should take at least 1 argument")
- if (!SQLConf.get.getConf(SQLConf.ALLOW_ZERO_INDEX_IN_FORMAT_STRING)) {
+ if (children.nonEmpty && !SQLConf.get.getConf(SQLConf.ALLOW_ZERO_INDEX_IN_FORMAT_STRING)) {
checkArgumentIndexNotZero(children(0))
}
@@ -1675,6 +1687,21 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
override def inputTypes: Seq[AbstractDataType] =
StringType :: List.fill(children.size - 1)(AnyDataType)
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (children.isEmpty) {
+ DataTypeMismatch(
+ errorSubClass = "WRONG_NUM_ARGS",
+ messageParameters = Map(
+ "functionName" -> toSQLId(prettyName),
+ "expectedNum" -> "> 0",
+ "actualNum" -> children.length.toString
+ )
+ )
+ } else {
+ super.checkInputDataTypes()
+ }
+ }
+
override def eval(input: InternalRow): Any = {
val pattern = children(0).eval(input)
if (pattern == null) {
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out
index 5b82cfa957d..41f1922f8bd 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out
@@ -5,7 +5,22 @@ select concat_ws()
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
-requirement failed: concat_ws requires at least one argument.; line 1 pos 7
+{
+ "errorClass" : "DATATYPE_MISMATCH.WRONG_NUM_ARGS",
+ "messageParameters" : {
+ "actualNum" : "0",
+ "expectedNum" : "> 0",
+ "functionName" : "`concat_ws`",
+ "sqlExpr" : "\"concat_ws()\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 18,
+ "fragment" : "concat_ws()"
+ } ]
+}
-- !query
@@ -14,7 +29,22 @@ select format_string()
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
-requirement failed: format_string() should take at least 1 argument; line 1 pos 7
+{
+ "errorClass" : "DATATYPE_MISMATCH.WRONG_NUM_ARGS",
+ "messageParameters" : {
+ "actualNum" : "0",
+ "expectedNum" : "> 0",
+ "functionName" : "`format_string`",
+ "sqlExpr" : "\"format_string()\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 22,
+ "fragment" : "format_string()"
+ } ]
+}
-- !query
diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
index 58a36b3299f..4bcb69ed773 100644
--- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
@@ -5,7 +5,22 @@ select concat_ws()
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
-requirement failed: concat_ws requires at least one argument.; line 1 pos 7
+{
+ "errorClass" : "DATATYPE_MISMATCH.WRONG_NUM_ARGS",
+ "messageParameters" : {
+ "actualNum" : "0",
+ "expectedNum" : "> 0",
+ "functionName" : "`concat_ws`",
+ "sqlExpr" : "\"concat_ws()\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 18,
+ "fragment" : "concat_ws()"
+ } ]
+}
-- !query
@@ -14,7 +29,22 @@ select format_string()
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
-requirement failed: format_string() should take at least 1 argument; line 1 pos 7
+{
+ "errorClass" : "DATATYPE_MISMATCH.WRONG_NUM_ARGS",
+ "messageParameters" : {
+ "actualNum" : "0",
+ "expectedNum" : "> 0",
+ "functionName" : "`format_string`",
+ "sqlExpr" : "\"format_string()\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 22,
+ "fragment" : "format_string()"
+ } ]
+}
-- !query
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org