You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2022/11/18 19:15:21 UTC

[spark] branch master updated: [SPARK-41173][SQL] Move `require()` out from the constructors of string expressions

This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b96ddce77aa [SPARK-41173][SQL] Move `require()` out from the constructors of string expressions
b96ddce77aa is described below

commit b96ddce77aa3f17eb0dea95083a9ac35d6077a94
Author: yangjie01 <ya...@baidu.com>
AuthorDate: Fri Nov 18 22:14:32 2022 +0300

    [SPARK-41173][SQL] Move `require()` out from the constructors of string expressions
    
    ### What changes were proposed in this pull request?
    This PR aims to move `require()` out from the constructors of string expressions, including `ConcatWs` and `FormatString`.
    The argument-count checking logic was moved into `checkInputDataTypes()`.
    
    ### Why are the changes needed?
    Migration onto error classes unifies Spark SQL error messages.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. The PR changes user-facing error messages.
    
    ### How was this patch tested?
    Pass GitHub Actions
    
    Closes #38705 from LuciferYang/SPARK-41173.
    
    Authored-by: yangjie01 <ya...@baidu.com>
    Signed-off-by: Max Gekk <ma...@gmail.com>
---
 .../catalyst/expressions/stringExpressions.scala   | 35 +++++++++++++++++++---
 .../results/ansi/string-functions.sql.out          | 34 +++++++++++++++++++--
 .../sql-tests/results/string-functions.sql.out     | 34 +++++++++++++++++++--
 3 files changed, 95 insertions(+), 8 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 45bed3e2387..60b56f4fef7 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -67,8 +67,6 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
 case class ConcatWs(children: Seq[Expression])
   extends Expression with ImplicitCastInputTypes {
 
-  require(children.nonEmpty, s"$prettyName requires at least one argument.")
-
   override def prettyName: String = "concat_ws"
 
   /** The 1st child (separator) is str, and rest are either str or array of str. */
@@ -82,6 +80,21 @@ case class ConcatWs(children: Seq[Expression])
   override def nullable: Boolean = children.head.nullable
   override def foldable: Boolean = children.forall(_.foldable)
 
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (children.isEmpty) {
+      DataTypeMismatch(
+        errorSubClass = "WRONG_NUM_ARGS",
+        messageParameters = Map(
+          "functionName" -> toSQLId(prettyName),
+          "expectedNum" -> "> 0",
+          "actualNum" -> children.length.toString
+        )
+      )
+    } else {
+      super.checkInputDataTypes()
+    }
+  }
+
   override def eval(input: InternalRow): Any = {
     val flatInputs = children.flatMap { child =>
       child.eval(input) match {
@@ -1662,8 +1675,7 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression = Litera
 // scalastyle:on line.size.limit
 case class FormatString(children: Expression*) extends Expression with ImplicitCastInputTypes {
 
-  require(children.nonEmpty, s"$prettyName() should take at least 1 argument")
-  if (!SQLConf.get.getConf(SQLConf.ALLOW_ZERO_INDEX_IN_FORMAT_STRING)) {
+  if (children.nonEmpty && !SQLConf.get.getConf(SQLConf.ALLOW_ZERO_INDEX_IN_FORMAT_STRING)) {
     checkArgumentIndexNotZero(children(0))
   }
 
@@ -1675,6 +1687,21 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
   override def inputTypes: Seq[AbstractDataType] =
     StringType :: List.fill(children.size - 1)(AnyDataType)
 
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (children.isEmpty) {
+      DataTypeMismatch(
+        errorSubClass = "WRONG_NUM_ARGS",
+        messageParameters = Map(
+          "functionName" -> toSQLId(prettyName),
+          "expectedNum" -> "> 0",
+          "actualNum" -> children.length.toString
+        )
+      )
+    } else {
+      super.checkInputDataTypes()
+    }
+  }
+
   override def eval(input: InternalRow): Any = {
     val pattern = children(0).eval(input)
     if (pattern == null) {
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out
index 5b82cfa957d..41f1922f8bd 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out
@@ -5,7 +5,22 @@ select concat_ws()
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-requirement failed: concat_ws requires at least one argument.; line 1 pos 7
+{
+  "errorClass" : "DATATYPE_MISMATCH.WRONG_NUM_ARGS",
+  "messageParameters" : {
+    "actualNum" : "0",
+    "expectedNum" : "> 0",
+    "functionName" : "`concat_ws`",
+    "sqlExpr" : "\"concat_ws()\""
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 18,
+    "fragment" : "concat_ws()"
+  } ]
+}
 
 
 -- !query
@@ -14,7 +29,22 @@ select format_string()
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-requirement failed: format_string() should take at least 1 argument; line 1 pos 7
+{
+  "errorClass" : "DATATYPE_MISMATCH.WRONG_NUM_ARGS",
+  "messageParameters" : {
+    "actualNum" : "0",
+    "expectedNum" : "> 0",
+    "functionName" : "`format_string`",
+    "sqlExpr" : "\"format_string()\""
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 22,
+    "fragment" : "format_string()"
+  } ]
+}
 
 
 -- !query
diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
index 58a36b3299f..4bcb69ed773 100644
--- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
@@ -5,7 +5,22 @@ select concat_ws()
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-requirement failed: concat_ws requires at least one argument.; line 1 pos 7
+{
+  "errorClass" : "DATATYPE_MISMATCH.WRONG_NUM_ARGS",
+  "messageParameters" : {
+    "actualNum" : "0",
+    "expectedNum" : "> 0",
+    "functionName" : "`concat_ws`",
+    "sqlExpr" : "\"concat_ws()\""
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 18,
+    "fragment" : "concat_ws()"
+  } ]
+}
 
 
 -- !query
@@ -14,7 +29,22 @@ select format_string()
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-requirement failed: format_string() should take at least 1 argument; line 1 pos 7
+{
+  "errorClass" : "DATATYPE_MISMATCH.WRONG_NUM_ARGS",
+  "messageParameters" : {
+    "actualNum" : "0",
+    "expectedNum" : "> 0",
+    "functionName" : "`format_string`",
+    "sqlExpr" : "\"format_string()\""
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 22,
+    "fragment" : "format_string()"
+  } ]
+}
 
 
 -- !query


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org