You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ge...@apache.org on 2023/09/27 17:02:56 UTC
[spark] branch master updated: [SPARK-44838][SQL] raise_error improvement
This is an automated email from the ASF dual-hosted git repository.
gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 9109d7037f4 [SPARK-44838][SQL] raise_error improvement
9109d7037f4 is described below
commit 9109d7037f44158e72d14019eb33f9c7b8838868
Author: srielau <se...@rielau.com>
AuthorDate: Wed Sep 27 10:02:44 2023 -0700
[SPARK-44838][SQL] raise_error improvement
### What changes were proposed in this pull request?
Extend the raise_error() function to a two-argument version:
raise_error(errorClassStr, errorParamMap)
This new form will accept any error class defined in error-classes.json and require Map<String, String> to provide values for the parameters in the error classes template.
Externally an error raised via raise_error() is indistinguishable from an error raised from within the Spark engine.
The single-parameter raise_error(str) will raise USER_RAISED_EXCEPTION (SQLSTATE P0001 - borrowed from PostgreSQL).
USER_RAISED_EXCEPTION text is: "<errorMessage>", which will be filled in with the str value.
We will also provide `spark.sql.legacy.raiseErrorWithoutErrorClass` (default: false) to revert to the old behavior for the single-parameter version.
Naturally assert_true() will also raise `USER_RAISED_EXCEPTION`.
#### Examples
```
SELECT raise_error('VIEW_NOT_FOUND', map('relationName', '`v1`'));
[VIEW_NOT_FOUND] The view `v1` cannot be found. Verify the spelling ...
SELECT raise_error('Error!');
[USER_RAISED_EXCEPTION] Error!
SELECT assert_true(1 < 0);
[USER_RAISED_EXCEPTION] '(1 < 0)' is not true!
SELECT assert_true(1 < 0, 'bad!');
[USER_RAISED_EXCEPTION] bad!
```
### Why are the changes needed?
This change moves raise_error() and assert_true() to the new error framework.
It greatly expands the ability of users to raise error messages which can be intercepted via SQLSTATE and/or error class.
### Does this PR introduce _any_ user-facing change?
Yes, the result of assert_true() changes and raise_error() gains a new signature.
### How was this patch tested?
Run existing QA and add new tests for assert_true and raise_error
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #42985 from srielau/SPARK-44838-raise_error.
Lead-authored-by: srielau <se...@rielau.com>
Co-authored-by: Serge Rielau <sr...@users.noreply.github.com>
Co-authored-by: Wenchen Fan <we...@databricks.com>
Signed-off-by: Gengliang Wang <ge...@apache.org>
---
.../src/main/resources/error/error-classes.json | 26 +++
.../org/apache/spark/ErrorClassesJSONReader.scala | 18 ++
.../org/apache/spark/SparkThrowableHelper.scala | 10 +-
.../scala/org/apache/spark/sql/functions.scala | 8 +
.../function_assert_true_with_message.explain | 2 +-
.../explain-results/function_raise_error.explain | 2 +-
.../org/apache/spark/SparkThrowableSuite.scala | 4 +-
docs/sql-error-conditions.md | 20 ++
python/pyspark/sql/tests/test_functions.py | 4 +-
.../spark/sql/catalyst/expressions/misc.scala | 71 ++++---
.../spark/sql/errors/QueryExecutionErrors.scala | 49 ++++-
.../org/apache/spark/sql/internal/SQLConf.scala | 14 ++
.../expressions/ExpressionEvalHelper.scala | 4 +
.../expressions/MiscExpressionsSuite.scala | 10 +-
.../catalyst/optimizer/ConstantFoldingSuite.scala | 2 +-
.../scala/org/apache/spark/sql/functions.scala | 8 +
.../sql-functions/sql-expression-schema.md | 2 +-
.../analyzer-results/misc-functions.sql.out | 86 +++++++-
.../resources/sql-tests/inputs/misc-functions.sql | 22 +++
.../sql-tests/results/misc-functions.sql.out | 220 +++++++++++++++++++--
.../apache/spark/sql/ColumnExpressionSuite.scala | 26 ++-
.../spark/sql/execution/ui/UISeleniumSuite.scala | 9 +-
.../sql/expressions/ExpressionInfoSuite.scala | 1 +
23 files changed, 551 insertions(+), 67 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json
index dd0190c3462..0882e387176 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -3502,6 +3502,26 @@
"3. set \"spark.sql.legacy.allowUntypedScalaUDF\" to \"true\" and use this API with caution."
]
},
+ "USER_RAISED_EXCEPTION" : {
+ "message" : [
+ "<errorMessage>"
+ ],
+ "sqlState" : "P0001"
+ },
+ "USER_RAISED_EXCEPTION_PARAMETER_MISMATCH" : {
+ "message" : [
+ "The `raise_error()` function was used to raise error class: <errorClass> which expects parameters: <expectedParms>.",
+ "The provided parameters <providedParms> do not match the expected parameters.",
+ "Please make sure to provide all expected parameters."
+ ],
+ "sqlState" : "P0001"
+ },
+ "USER_RAISED_EXCEPTION_UNKNOWN_ERROR_CLASS" : {
+ "message" : [
+ "The `raise_error()` function was used to raise an unknown error class: <errorClass>"
+ ],
+ "sqlState" : "P0001"
+ },
"VARIABLE_ALREADY_EXISTS" : {
"message" : [
"Cannot create the variable <variableName> because it already exists.",
@@ -6310,5 +6330,11 @@
"message" : [
"Failed to get block <blockId>, which is not a shuffle block"
]
+ },
+ "_LEGACY_ERROR_USER_RAISED_EXCEPTION" : {
+ "message" : [
+ "<errorMessage>"
+ ],
+ "sqlState" : "P0001"
}
}
diff --git a/common/utils/src/main/scala/org/apache/spark/ErrorClassesJSONReader.scala b/common/utils/src/main/scala/org/apache/spark/ErrorClassesJSONReader.scala
index 3c1a9a27bb5..c59e1e376ea 100644
--- a/common/utils/src/main/scala/org/apache/spark/ErrorClassesJSONReader.scala
+++ b/common/utils/src/main/scala/org/apache/spark/ErrorClassesJSONReader.scala
@@ -57,6 +57,13 @@ class ErrorClassesJsonReader(jsonFileURLs: Seq[URL]) {
}
}
+ def getMessageParameters(errorClass: String): Seq[String] = {
+ val messageTemplate = getMessageTemplate(errorClass)
+ val pattern = "<([a-zA-Z0-9_-]+)>".r
+ val matches = pattern.findAllIn(messageTemplate).toSeq
+ matches.map(m => m.stripSuffix(">").stripPrefix("<"))
+ }
+
def getMessageTemplate(errorClass: String): String = {
val errorClasses = errorClass.split("\\.")
assert(errorClasses.length == 1 || errorClasses.length == 2)
@@ -85,6 +92,17 @@ class ErrorClassesJsonReader(jsonFileURLs: Seq[URL]) {
.flatMap(_.sqlState)
.orNull
}
+
+ def isValidErrorClass(errorClass: String): Boolean = {
+ val errorClasses = errorClass.split("\\.")
+ errorClasses match {
+ case Array(mainClass) => errorInfoMap.contains(mainClass)
+ case Array(mainClass, subClass) => errorInfoMap.get(mainClass).map { info =>
+ info.subClass.get.contains(subClass)
+ }.getOrElse(false)
+ case _ => false
+ }
+ }
}
private object ErrorClassesJsonReader {
diff --git a/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala b/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala
index 5d58d66eec3..f56dcab2e48 100644
--- a/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala
+++ b/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala
@@ -52,7 +52,7 @@ private[spark] object SparkThrowableHelper {
context: String): String = {
val displayMessage = errorReader.getErrorMessage(errorClass, messageParameters)
val displayQueryContext = (if (context.isEmpty) "" else "\n") + context
- val prefix = if (errorClass.startsWith("_LEGACY_ERROR_TEMP_")) "" else s"[$errorClass] "
+ val prefix = if (errorClass.startsWith("_LEGACY_ERROR_")) "" else s"[$errorClass] "
s"$prefix$displayMessage$displayQueryContext"
}
@@ -60,6 +60,14 @@ private[spark] object SparkThrowableHelper {
errorReader.getSqlState(errorClass)
}
+ def isValidErrorClass(errorClass: String): Boolean = {
+ errorReader.isValidErrorClass(errorClass)
+ }
+
+ def getMessageParameters(errorClass: String): Seq[String] = {
+ errorReader.getMessageParameters(errorClass)
+ }
+
def isInternalError(errorClass: String): Boolean = {
errorClass.startsWith("INTERNAL_ERROR")
}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
index 36f1aeb3a6f..5e02a5910b1 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3305,6 +3305,14 @@ object functions {
*/
def raise_error(c: Column): Column = Column.fn("raise_error", c)
+ /**
+ * Throws an exception with the provided error message.
+ *
+ * @group misc_funcs
+ * @since 4.0.0
+ */
+ def raise_error(c: Column, e: Column): Column = Column.fn("raise_error", c, e)
+
/**
* Returns the estimated number of unique values given the binary representation of a
* Datasketches HllSketch.
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_assert_true_with_message.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_assert_true_with_message.explain
index dfd0468941b..8ddf8c2f003 100644
--- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_assert_true_with_message.explain
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_assert_true_with_message.explain
@@ -1,2 +1,2 @@
-Project [if ((id#0L > cast(0 as bigint))) null else raise_error(id negative!, NullType) AS assert_true((id > 0), id negative!)#0]
+Project [if ((id#0L > cast(0 as bigint))) null else raise_error(USER_RAISED_EXCEPTION, map(errorMessage, id negative!), NullType) AS assert_true((id > 0), id negative!)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_raise_error.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_raise_error.explain
index c65063a35a1..5b9ce716a4f 100644
--- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_raise_error.explain
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_raise_error.explain
@@ -1,2 +1,2 @@
-Project [raise_error(kaboom, NullType) AS raise_error(kaboom)#0]
+Project [raise_error(USER_RAISED_EXCEPTION, map(errorMessage, kaboom), NullType) AS raise_error(USER_RAISED_EXCEPTION, map(errorMessage, kaboom))#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala
index 5c9009bf8fa..8029275d838 100644
--- a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala
@@ -143,7 +143,7 @@ class SparkThrowableSuite extends SparkFunSuite {
test("Message format invariants") {
val messageFormats = errorReader.errorInfoMap
- .filterKeys(!_.startsWith("_LEGACY_ERROR_TEMP_"))
+ .filterKeys(!_.startsWith("_LEGACY_ERROR_"))
.filterKeys(!_.startsWith("INTERNAL_ERROR"))
.values.toSeq.flatMap { i => Seq(i.messageTemplate) }
checkCondition(messageFormats, s => s != null)
@@ -236,7 +236,7 @@ class SparkThrowableSuite extends SparkFunSuite {
orphans
}
- val sqlErrorParentDocContent = errors.toSeq.filter(!_._1.startsWith("_LEGACY_ERROR_TEMP_"))
+ val sqlErrorParentDocContent = errors.toSeq.filter(!_._1.startsWith("_LEGACY_ERROR"))
.sortBy(_._1).map(error => {
val name = error._1
val info = error._2
diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md
index 660a72dca7d..fda10eceb97 100644
--- a/docs/sql-error-conditions.md
+++ b/docs/sql-error-conditions.md
@@ -2149,6 +2149,26 @@ You're using untyped Scala UDF, which does not have the input type information.
2. use Java UDF APIs, e.g. `udf(new UDF1[String, Integer] { override def call(s: String): Integer = s.length() }, IntegerType)`, if input types are all non primitive.
3. set "spark.sql.legacy.allowUntypedScalaUDF" to "true" and use this API with caution.
+### USER_RAISED_EXCEPTION
+
+SQLSTATE: P0001
+
+`<errorMessage>`
+
+### USER_RAISED_EXCEPTION_PARAMETER_MISMATCH
+
+SQLSTATE: P0001
+
+The `raise_error()` function was used to raise error class: `<errorClass>` which expects parameters: `<expectedParms>`.
+The provided parameters `<providedParms>` do not match the expected parameters.
+Please make sure to provide all expected parameters.
+
+### USER_RAISED_EXCEPTION_UNKNOWN_ERROR_CLASS
+
+SQLSTATE: P0001
+
+The `raise_error()` function was used to raise an unknown error class: `<errorClass>`
+
### VARIABLE_ALREADY_EXISTS
[SQLSTATE: 42723](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index 66de94a6b9b..9753ba5c532 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -1031,10 +1031,10 @@ class FunctionsTestsMixin:
[Row(val=None), Row(val=None), Row(val=None)],
)
- with self.assertRaisesRegex(tpe, "too big"):
+ with self.assertRaisesRegex(tpe, r"\[USER_RAISED_EXCEPTION\] too big"):
df.select(F.assert_true(df.id < 2, "too big")).toDF("val").collect()
- with self.assertRaisesRegex(tpe, "2000000"):
+ with self.assertRaisesRegex(tpe, r"\[USER_RAISED_EXCEPTION\] 2000000.0"):
df.select(F.assert_true(df.id < 2, df.id * 1e6)).toDF("val").collect()
with self.assertRaises(PySparkTypeError) as pe:
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index 6a7f841c324..4a54ccf4a31 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -24,7 +24,8 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.trees.TreePattern.{CURRENT_LIKE, TreePattern}
-import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator
+import org.apache.spark.sql.catalyst.util.{MapData, RandomUUIDGenerator}
+import org.apache.spark.sql.errors.QueryExecutionErrors.raiseError
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -61,64 +62,88 @@ case class PrintToStderr(child: Expression) extends UnaryExpression {
/**
* Throw with the result of an expression (used for debugging).
*/
+// scalastyle:off line.size.limit
@ExpressionDescription(
- usage = "_FUNC_(expr) - Throws an exception with `expr`.",
+ usage = "_FUNC_( expr [, errorParams ]) - Throws a USER_RAISED_EXCEPTION with `expr` as message, or a defined error class in `expr` with a parameter map. A `null` errorParms is equivalent to an empty map.",
examples = """
Examples:
> SELECT _FUNC_('custom error message');
- java.lang.RuntimeException
- custom error message
+ [USER_RAISED_EXCEPTION] custom error message
+
+ > SELECT _FUNC_('VIEW_NOT_FOUND', Map('relationName' -> '`V1`'));
+ [VIEW_NOT_FOUND] The view `V1` cannot be found. ...
""",
since = "3.1.0",
group = "misc_funcs")
-case class RaiseError(child: Expression, dataType: DataType)
- extends UnaryExpression with ImplicitCastInputTypes {
+// scalastyle:on line.size.limit
+case class RaiseError(errorClass: Expression, errorParms: Expression, dataType: DataType)
+ extends BinaryExpression with ImplicitCastInputTypes {
+
+ def this(str: Expression) = {
+ this(Literal(
+ if (SQLConf.get.getConf(SQLConf.LEGACY_RAISE_ERROR_WITHOUT_ERROR_CLASS)) {
+ "_LEGACY_ERROR_USER_RAISED_EXCEPTION"
+ } else {
+ "USER_RAISED_EXCEPTION"
+ }),
+ CreateMap(Seq(Literal("errorMessage"), str)), NullType)
+ }
- def this(child: Expression) = this(child, NullType)
+ def this(errorClass: Expression, errorParms: Expression) = {
+ this(errorClass, errorParms, NullType)
+ }
override def foldable: Boolean = false
override def nullable: Boolean = true
- override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
+ override def inputTypes: Seq[AbstractDataType] =
+ Seq(StringType, MapType(StringType, StringType))
+
+ override def left: Expression = errorClass
+ override def right: Expression = errorParms
override def prettyName: String = "raise_error"
override def eval(input: InternalRow): Any = {
- val value = child.eval(input)
- if (value == null) {
- throw new RuntimeException()
- }
- throw new RuntimeException(value.toString)
+ val error = errorClass.eval(input).asInstanceOf[UTF8String]
+ val parms: MapData = errorParms.eval(input).asInstanceOf[MapData]
+ throw raiseError(error, parms)
}
// if (true) is to avoid codegen compilation exception that statement is unreachable
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val eval = child.genCode(ctx)
+ val error = errorClass.genCode(ctx)
+ val parms = errorParms.genCode(ctx)
ExprCode(
- code = code"""${eval.code}
+ code = code"""${error.code}
+ |${parms.code}
|if (true) {
- | if (${eval.isNull}) {
- | throw new RuntimeException();
- | }
- | throw new RuntimeException(${eval.value}.toString());
+ | throw QueryExecutionErrors.raiseError(
+ | ${error.value},
+ | ${parms.value});
|}""".stripMargin,
isNull = TrueLiteral,
value = JavaCode.defaultLiteral(dataType)
)
}
- override protected def withNewChildInternal(newChild: Expression): RaiseError =
- copy(child = newChild)
+ override protected def withNewChildrenInternal(
+ newLeft: Expression, newRight: Expression): RaiseError = {
+ copy(errorClass = newLeft, errorParms = newRight)
+ }
}
object RaiseError {
- def apply(child: Expression): RaiseError = new RaiseError(child)
+ def apply(str: Expression): RaiseError = new RaiseError(str)
+
+ def apply(errorClass: Expression, parms: Expression): RaiseError =
+ new RaiseError(errorClass, parms)
}
/**
* A function that throws an exception if 'condition' is not true.
*/
@ExpressionDescription(
- usage = "_FUNC_(expr) - Throws an exception if `expr` is not true.",
+ usage = "_FUNC_(expr [, message]) - Throws an exception if `expr` is not true.",
examples = """
Examples:
> SELECT _FUNC_(0 < 1);
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index 84472490128..f34674909bb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -21,6 +21,7 @@ import java.io.{File, FileNotFoundException, IOException}
import java.lang.reflect.InvocationTargetException
import java.net.{URISyntaxException, URL}
import java.time.DateTimeException
+import java.util.Locale
import java.util.concurrent.TimeoutException
import com.fasterxml.jackson.core.{JsonParser, JsonToken}
@@ -41,7 +42,7 @@ import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.ValueInterval
import org.apache.spark.sql.catalyst.trees.{Origin, SQLQueryContext, TreeNode}
-import org.apache.spark.sql.catalyst.util.{sideBySide, BadRecordException, DateTimeUtils, FailFastMode}
+import org.apache.spark.sql.catalyst.util.{sideBySide, BadRecordException, DateTimeUtils, FailFastMode, MapData}
import org.apache.spark.sql.connector.catalog.{CatalogNotFoundException, Table, TableProvider}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.expressions.Transform
@@ -2728,4 +2729,50 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
errorClass = "UNSUPPORTED_FEATURE.PURGE_TABLE",
messageParameters = Map.empty)
}
+
+ def raiseError(
+ errorClass: UTF8String,
+ errorParms: MapData): RuntimeException = {
+ val errorClassStr = if (errorClass != null) {
+ errorClass.toString.toUpperCase(Locale.ROOT)
+ } else {
+ "null"
+ }
+ val errorParmsMap = if (errorParms != null) {
+ val errorParmsMutable = collection.mutable.Map[String, String]()
+ errorParms.foreach(StringType, StringType, { case (key, value) =>
+ errorParmsMutable += (key.toString ->
+ (if (value == null) { "null" } else { value.toString } ))
+ })
+ errorParmsMutable.toMap
+ } else {
+ Map.empty[String, String]
+ }
+
+ // Is the error class a known error class? If not raise an error
+ if (!SparkThrowableHelper.isValidErrorClass(errorClassStr)) {
+ new SparkRuntimeException(
+ errorClass = "USER_RAISED_EXCEPTION_UNKNOWN_ERROR_CLASS",
+ messageParameters = Map("errorClass" -> toSQLValue(errorClassStr)))
+ } else {
+ // Did the user provide all parameters? If not raise an error
+ val expectedParms = SparkThrowableHelper.getMessageParameters(errorClassStr).sorted
+ val providedParms = errorParmsMap.keys.toSeq.sorted
+ if (expectedParms != providedParms) {
+ new SparkRuntimeException(
+ errorClass = "USER_RAISED_EXCEPTION_PARAMETER_MISMATCH",
+ messageParameters = Map("errorClass" -> toSQLValue(errorClassStr),
+ "expectedParms" -> expectedParms.map { p => toSQLValue(p) }.mkString(","),
+ "providedParms" -> providedParms.map { p => toSQLValue(p) }.mkString(",")))
+ } else if (errorClassStr == "_LEGACY_ERROR_USER_RAISED_EXCEPTION") {
+ // Don't break old raise_error() if asked
+ new RuntimeException(errorParmsMap.head._2)
+ } else {
+ // All good, raise the error
+ new SparkRuntimeException(
+ errorClass = errorClassStr,
+ messageParameters = errorParmsMap)
+ }
+ }
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 18d30ac6ac5..0cad85e1296 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -4448,6 +4448,17 @@ object SQLConf {
.booleanConf
.createWithDefault(false)
+ val LEGACY_RAISE_ERROR_WITHOUT_ERROR_CLASS =
+ buildConf("spark.sql.legacy.raiseErrorWithoutErrorClass")
+ .internal()
+ .doc("When set to true, restores the legacy behavior of `raise_error` and `assert_true` to " +
+ "not return the `[USER_RAISED_EXCEPTION]` prefix." +
+ "For example, `raise_error('error!')` returns `error!` instead of " +
+ "`[USER_RAISED_EXCEPTION] Error!`.")
+ .version("4.0.0")
+ .booleanConf
+ .createWithDefault(false)
+
/**
* Holds information about keys that have been deprecated.
*
@@ -5317,6 +5328,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
getConf(SQLConf.LEGACY_NEGATIVE_INDEX_IN_ARRAY_INSERT)
}
+ def legacyRaiseErrorWithoutErrorClass: Boolean =
+ getConf(SQLConf.LEGACY_RAISE_ERROR_WITHOUT_ERROR_CLASS)
+
/** ********************** SQLConf functionality methods ************ */
/** Set Spark SQL configuration properties. */
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index 5be0cae4a22..7ddb92cbbde 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -211,6 +211,10 @@ trait ExpressionEvalHelper extends ScalaCheckDrivenPropertyChecks with PlanTestB
}.getMessage
if (errMsg == null) {
if (expectedErrMsg != null) {
+ fail(s"Expected `$expectedErrMsg` but null error message found")
+ }
+ } else if (expectedErrMsg == null) {
+ if (errMsg != null) {
fail(s"Expected null error message, but `$errMsg` found")
}
} else if (!errMsg.contains(expectedErrMsg)) {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala
index d449de3defb..28da02a68f7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala
@@ -37,7 +37,7 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkExceptionInExpression[RuntimeException](
RaiseError(Literal.create(null, StringType)),
EmptyRow,
- null
+ "[USER_RAISED_EXCEPTION] null"
)
// Expects a string
@@ -45,10 +45,10 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
DataTypeMismatch(
errorSubClass = "UNEXPECTED_INPUT_TYPE",
messageParameters = Map(
- "paramIndex" -> "1",
- "requiredType" -> "\"STRING\"",
- "inputSql" -> "\"5\"",
- "inputType" -> "\"INT\""
+ "paramIndex" -> "2",
+ "requiredType" -> "\"MAP<STRING, STRING>\"",
+ "inputSql" -> "\"map(errorMessage, 5)\"",
+ "inputType" -> "\"MAP<STRING, INT>\""
)
)
)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
index 90882da0cab..8734583d3c3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
@@ -275,7 +275,7 @@ class ConstantFoldingSuite extends PlanTest {
val originalQuery =
testRelation
.select($"a")
- .where(Size(CreateArray(Seq(AssertTrue(false)))) > 0)
+ .where(Size(CreateArray(Seq(rand(0)))) > 0)
val optimized = Optimize.execute(originalQuery.analyze)
comparePlans(optimized, originalQuery.analyze)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index f911e58f6c8..58a994a0ea2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3249,6 +3249,14 @@ object functions {
*/
def raise_error(c: Column): Column = Column.fn("raise_error", c)
+ /**
+ * Throws an exception with the provided error class and parameter map.
+ *
+ * @group misc_funcs
+ * @since 4.0.0
+ */
+ def raise_error(c: Column, e: Column): Column = Column.fn("raise_error", c, e)
+
/**
* Returns the estimated number of unique values given the binary representation
* of a Datasketches HllSketch.
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 4fd493d1a3c..89e840d1242 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -254,7 +254,7 @@
| org.apache.spark.sql.catalyst.expressions.RLike | regexp_like | SELECT regexp_like('%SystemDrive%\Users\John', '%SystemDrive%\\Users.*') | struct<REGEXP_LIKE(%SystemDrive%UsersJohn, %SystemDrive%\Users.*):boolean> |
| org.apache.spark.sql.catalyst.expressions.RLike | rlike | SELECT rlike('%SystemDrive%\Users\John', '%SystemDrive%\\Users.*') | struct<RLIKE(%SystemDrive%UsersJohn, %SystemDrive%\Users.*):boolean> |
| org.apache.spark.sql.catalyst.expressions.RPadExpressionBuilder | rpad | SELECT rpad('hi', 5, '??') | struct<rpad(hi, 5, ??):string> |
-| org.apache.spark.sql.catalyst.expressions.RaiseError | raise_error | SELECT raise_error('custom error message') | struct<raise_error(custom error message):void> |
+| org.apache.spark.sql.catalyst.expressions.RaiseError | raise_error | SELECT raise_error('custom error message') | struct<raise_error(USER_RAISED_EXCEPTION, map(errorMessage, custom error message)):void> |
| org.apache.spark.sql.catalyst.expressions.Rand | rand | SELECT rand() | struct<rand():double> |
| org.apache.spark.sql.catalyst.expressions.Rand | random | SELECT random() | struct<rand():double> |
| org.apache.spark.sql.catalyst.expressions.Randn | randn | SELECT randn() | struct<randn():double> |
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/misc-functions.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/misc-functions.sql.out
index 042f2b64f9a..1fd6e55a4d2 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/misc-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/misc-functions.sql.out
@@ -102,14 +102,14 @@ CreateViewCommand `tbl_misc`, SELECT * FROM (VALUES (1), (8), (2)) AS T(v), fals
-- !query
SELECT raise_error('error message')
-- !query analysis
-Project [raise_error(error message, NullType) AS raise_error(error message)#x]
+Project [raise_error(USER_RAISED_EXCEPTION, map(errorMessage, error message), NullType) AS raise_error(USER_RAISED_EXCEPTION, map(errorMessage, error message))#x]
+- OneRowRelation
-- !query
SELECT if(v > 5, raise_error('too big: ' || v), v + 1) FROM tbl_misc
-- !query analysis
-Project [if ((v#x > 5)) cast(raise_error(concat(too big: , cast(v#x as string)), NullType) as int) else (v#x + 1) AS (IF((v > 5), raise_error(concat(too big: , v)), (v + 1)))#x]
+Project [if ((v#x > 5)) cast(raise_error(USER_RAISED_EXCEPTION, map(errorMessage, concat(too big: , cast(v#x as string))), NullType) as int) else (v#x + 1) AS (IF((v > 5), raise_error(USER_RAISED_EXCEPTION, map(errorMessage, concat(too big: , v))), (v + 1)))#x]
+- SubqueryAlias tbl_misc
+- View (`tbl_misc`, [v#x])
+- Project [cast(v#x as int) AS v#x]
@@ -117,3 +117,85 @@ Project [if ((v#x > 5)) cast(raise_error(concat(too big: , cast(v#x as string)),
+- SubqueryAlias T
+- Project [col1#x AS v#x]
+- LocalRelation [col1#x]
+
+
+-- !query
+SELECT raise_error('VIEW_NOT_FOUND', Map('relationName', '`v`'))
+-- !query analysis
+Project [raise_error(VIEW_NOT_FOUND, map(relationName, `v`), NullType) AS raise_error(VIEW_NOT_FOUND, map(relationName, `v`))#x]
++- OneRowRelation
+
+
+-- !query
+SELECT raise_error('VIEW_NOT_FOund', Map('relationName', '`v`'))
+-- !query analysis
+Project [raise_error(VIEW_NOT_FOund, map(relationName, `v`), NullType) AS raise_error(VIEW_NOT_FOund, map(relationName, `v`))#x]
++- OneRowRelation
+
+
+-- !query
+SELECT raise_error('VIEW_NOT_FOund', Map('relationNAME', '`v`'))
+-- !query analysis
+Project [raise_error(VIEW_NOT_FOund, map(relationNAME, `v`), NullType) AS raise_error(VIEW_NOT_FOund, map(relationNAME, `v`))#x]
++- OneRowRelation
+
+
+-- !query
+SELECT raise_error('VIEW_NOT_FOUND', Map())
+-- !query analysis
+Project [raise_error(VIEW_NOT_FOUND, cast(map() as map<string,string>), NullType) AS raise_error(VIEW_NOT_FOUND, map())#x]
++- OneRowRelation
+
+
+-- !query
+SELECT raise_error('VIEW_NOT_FOUND', Map('relationName', '`v`', 'totallymadeup', '5'))
+-- !query analysis
+Project [raise_error(VIEW_NOT_FOUND, map(relationName, `v`, totallymadeup, 5), NullType) AS raise_error(VIEW_NOT_FOUND, map(relationName, `v`, totallymadeup, 5))#x]
++- OneRowRelation
+
+
+-- !query
+SELECT raise_error('ALL_PARTITION_COLUMNS_NOT_ALLOWED', Map())
+-- !query analysis
+Project [raise_error(ALL_PARTITION_COLUMNS_NOT_ALLOWED, cast(map() as map<string,string>), NullType) AS raise_error(ALL_PARTITION_COLUMNS_NOT_ALLOWED, map())#x]
++- OneRowRelation
+
+
+-- !query
+SELECT raise_error('ALL_PARTITION_COLUMNS_NOT_ALLOWED', NULL)
+-- !query analysis
+Project [raise_error(ALL_PARTITION_COLUMNS_NOT_ALLOWED, cast(null as map<string,string>), NullType) AS raise_error(ALL_PARTITION_COLUMNS_NOT_ALLOWED, NULL)#x]
++- OneRowRelation
+
+
+-- !query
+SELECT raise_error(NULL, NULL)
+-- !query analysis
+Project [raise_error(cast(null as string), cast(null as map<string,string>), NullType) AS raise_error(NULL, NULL)#x]
++- OneRowRelation
+
+
+-- !query
+SET spark.sql.legacy.raiseErrorWithoutErrorClass=true
+-- !query analysis
+SetCommand (spark.sql.legacy.raiseErrorWithoutErrorClass,Some(true))
+
+
+-- !query
+SELECT assert_true(false)
+-- !query analysis
+Project [assert_true(false, 'false' is not true!) AS assert_true(false, 'false' is not true!)#x]
++- OneRowRelation
+
+
+-- !query
+SELECT raise_error('hello')
+-- !query analysis
+Project [raise_error(_LEGACY_ERROR_USER_RAISED_EXCEPTION, map(errorMessage, hello), NullType) AS raise_error(_LEGACY_ERROR_USER_RAISED_EXCEPTION, map(errorMessage, hello))#x]
++- OneRowRelation
+
+
+-- !query
+SET spark.sql.legacy.raiseErrorWithoutErrorClass=false
+-- !query analysis
+SetCommand (spark.sql.legacy.raiseErrorWithoutErrorClass,Some(false))
diff --git a/sql/core/src/test/resources/sql-tests/inputs/misc-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/misc-functions.sql
index 907ff33000d..12e6a36db77 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/misc-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/misc-functions.sql
@@ -20,3 +20,25 @@ SELECT assert_true(false, 'custom error message');
CREATE TEMPORARY VIEW tbl_misc AS SELECT * FROM (VALUES (1), (8), (2)) AS T(v);
SELECT raise_error('error message');
SELECT if(v > 5, raise_error('too big: ' || v), v + 1) FROM tbl_misc;
+
+SELECT raise_error('VIEW_NOT_FOUND', Map('relationName', '`v`'));
+-- Error class is case insensitive
+SELECT raise_error('VIEW_NOT_FOund', Map('relationName', '`v`'));
+-- parameters are case sensitive
+SELECT raise_error('VIEW_NOT_FOund', Map('relationNAME', '`v`'));
+-- Too few parameters
+SELECT raise_error('VIEW_NOT_FOUND', Map());
+-- Too many parameters
+SELECT raise_error('VIEW_NOT_FOUND', Map('relationName', '`v`', 'totallymadeup', '5'));
+
+-- Empty parameter list
+SELECT raise_error('ALL_PARTITION_COLUMNS_NOT_ALLOWED', Map());
+SELECT raise_error('ALL_PARTITION_COLUMNS_NOT_ALLOWED', NULL);
+
+SELECT raise_error(NULL, NULL);
+
+-- Check legacy config disables printing of [USER_RAISED_EXCEPTION]
+SET spark.sql.legacy.raiseErrorWithoutErrorClass=true;
+SELECT assert_true(false);
+SELECT raise_error('hello');
+SET spark.sql.legacy.raiseErrorWithoutErrorClass=false;
\ No newline at end of file
diff --git a/sql/core/src/test/resources/sql-tests/results/misc-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/misc-functions.sql.out
index 7e9bb2f7acd..be7252d8c88 100644
--- a/sql/core/src/test/resources/sql-tests/results/misc-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/misc-functions.sql.out
@@ -68,8 +68,14 @@ SELECT assert_true(false)
-- !query schema
struct<>
-- !query output
-java.lang.RuntimeException
-'false' is not true!
+org.apache.spark.SparkRuntimeException
+{
+ "errorClass" : "USER_RAISED_EXCEPTION",
+ "sqlState" : "P0001",
+ "messageParameters" : {
+ "errorMessage" : "'false' is not true!"
+ }
+}
-- !query
@@ -77,8 +83,14 @@ SELECT assert_true(boolean(0))
-- !query schema
struct<>
-- !query output
-java.lang.RuntimeException
-'cast(0 as boolean)' is not true!
+org.apache.spark.SparkRuntimeException
+{
+ "errorClass" : "USER_RAISED_EXCEPTION",
+ "sqlState" : "P0001",
+ "messageParameters" : {
+ "errorMessage" : "'cast(0 as boolean)' is not true!"
+ }
+}
-- !query
@@ -86,8 +98,14 @@ SELECT assert_true(null)
-- !query schema
struct<>
-- !query output
-java.lang.RuntimeException
-'null' is not true!
+org.apache.spark.SparkRuntimeException
+{
+ "errorClass" : "USER_RAISED_EXCEPTION",
+ "sqlState" : "P0001",
+ "messageParameters" : {
+ "errorMessage" : "'null' is not true!"
+ }
+}
-- !query
@@ -95,8 +113,14 @@ SELECT assert_true(boolean(null))
-- !query schema
struct<>
-- !query output
-java.lang.RuntimeException
-'cast(null as boolean)' is not true!
+org.apache.spark.SparkRuntimeException
+{
+ "errorClass" : "USER_RAISED_EXCEPTION",
+ "sqlState" : "P0001",
+ "messageParameters" : {
+ "errorMessage" : "'cast(null as boolean)' is not true!"
+ }
+}
-- !query
@@ -104,8 +128,14 @@ SELECT assert_true(false, 'custom error message')
-- !query schema
struct<>
-- !query output
-java.lang.RuntimeException
-custom error message
+org.apache.spark.SparkRuntimeException
+{
+ "errorClass" : "USER_RAISED_EXCEPTION",
+ "sqlState" : "P0001",
+ "messageParameters" : {
+ "errorMessage" : "custom error message"
+ }
+}
-- !query
@@ -121,8 +151,14 @@ SELECT raise_error('error message')
-- !query schema
struct<>
-- !query output
-java.lang.RuntimeException
-error message
+org.apache.spark.SparkRuntimeException
+{
+ "errorClass" : "USER_RAISED_EXCEPTION",
+ "sqlState" : "P0001",
+ "messageParameters" : {
+ "errorMessage" : "error message"
+ }
+}
-- !query
@@ -130,5 +166,163 @@ SELECT if(v > 5, raise_error('too big: ' || v), v + 1) FROM tbl_misc
-- !query schema
struct<>
-- !query output
+org.apache.spark.SparkRuntimeException
+{
+ "errorClass" : "USER_RAISED_EXCEPTION",
+ "sqlState" : "P0001",
+ "messageParameters" : {
+ "errorMessage" : "too big: 8"
+ }
+}
+
+
+-- !query
+SELECT raise_error('VIEW_NOT_FOUND', Map('relationName', '`v`'))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkRuntimeException
+{
+ "errorClass" : "VIEW_NOT_FOUND",
+ "sqlState" : "42P01",
+ "messageParameters" : {
+ "relationName" : "`v`"
+ }
+}
+
+
+-- !query
+SELECT raise_error('VIEW_NOT_FOund', Map('relationName', '`v`'))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkRuntimeException
+{
+ "errorClass" : "VIEW_NOT_FOUND",
+ "sqlState" : "42P01",
+ "messageParameters" : {
+ "relationName" : "`v`"
+ }
+}
+
+
+-- !query
+SELECT raise_error('VIEW_NOT_FOund', Map('relationNAME', '`v`'))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkRuntimeException
+{
+ "errorClass" : "USER_RAISED_EXCEPTION_PARAMETER_MISMATCH",
+ "sqlState" : "P0001",
+ "messageParameters" : {
+ "errorClass" : "'VIEW_NOT_FOUND'",
+ "expectedParms" : "'relationName'",
+ "providedParms" : "'relationNAME'"
+ }
+}
+
+
+-- !query
+SELECT raise_error('VIEW_NOT_FOUND', Map())
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkRuntimeException
+{
+ "errorClass" : "USER_RAISED_EXCEPTION_PARAMETER_MISMATCH",
+ "sqlState" : "P0001",
+ "messageParameters" : {
+ "errorClass" : "'VIEW_NOT_FOUND'",
+ "expectedParms" : "'relationName'",
+ "providedParms" : ""
+ }
+}
+
+
+-- !query
+SELECT raise_error('VIEW_NOT_FOUND', Map('relationName', '`v`', 'totallymadeup', '5'))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkRuntimeException
+{
+ "errorClass" : "USER_RAISED_EXCEPTION_PARAMETER_MISMATCH",
+ "sqlState" : "P0001",
+ "messageParameters" : {
+ "errorClass" : "'VIEW_NOT_FOUND'",
+ "expectedParms" : "'relationName'",
+ "providedParms" : "'relationName','totallymadeup'"
+ }
+}
+
+
+-- !query
+SELECT raise_error('ALL_PARTITION_COLUMNS_NOT_ALLOWED', Map())
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkRuntimeException
+{
+ "errorClass" : "ALL_PARTITION_COLUMNS_NOT_ALLOWED"
+}
+
+
+-- !query
+SELECT raise_error('ALL_PARTITION_COLUMNS_NOT_ALLOWED', NULL)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkRuntimeException
+{
+ "errorClass" : "ALL_PARTITION_COLUMNS_NOT_ALLOWED"
+}
+
+
+-- !query
+SELECT raise_error(NULL, NULL)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkRuntimeException
+{
+ "errorClass" : "USER_RAISED_EXCEPTION_UNKNOWN_ERROR_CLASS",
+ "sqlState" : "P0001",
+ "messageParameters" : {
+ "errorClass" : "'null'"
+ }
+}
+
+
+-- !query
+SET spark.sql.legacy.raiseErrorWithoutErrorClass=true
+-- !query schema
+struct<key:string,value:string>
+-- !query output
+spark.sql.legacy.raiseErrorWithoutErrorClass true
+
+
+-- !query
+SELECT assert_true(false)
+-- !query schema
+struct<>
+-- !query output
+java.lang.RuntimeException
+'false' is not true!
+
+
+-- !query
+SELECT raise_error('hello')
+-- !query schema
+struct<>
+-- !query output
java.lang.RuntimeException
-too big: 8
+hello
+
+
+-- !query
+SET spark.sql.legacy.raiseErrorWithoutErrorClass=false
+-- !query schema
+struct<key:string,value:string>
+-- !query output
+spark.sql.legacy.raiseErrorWithoutErrorClass false
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 0baded3323c..8a10050336c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -26,7 +26,7 @@ import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat}
import org.scalatest.matchers.should.Matchers._
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, SparkThrowable}
import org.apache.spark.sql.UpdateFieldsBenchmark._
import org.apache.spark.sql.catalyst.expressions.{InSet, Literal, NamedExpression}
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{outstandingTimezonesIds, outstandingZoneIds}
@@ -2545,8 +2545,9 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
val e1 = intercept[SparkException] {
booleanDf.select(assert_true($"cond", lit(null.asInstanceOf[String]))).collect()
}
- assert(e1.getCause.isInstanceOf[RuntimeException])
- assert(e1.getCause.getMessage == null)
+ checkError(e1.getCause.asInstanceOf[SparkThrowable],
+ errorClass = "USER_RAISED_EXCEPTION",
+ parameters = Map("errorMessage" -> "null"))
val nullDf = Seq(("first row", None), ("second row", Some(true))).toDF("n", "cond")
checkAnswer(
@@ -2556,8 +2557,9 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
val e2 = intercept[SparkException] {
nullDf.select(assert_true($"cond", $"n")).collect()
}
- assert(e2.getCause.isInstanceOf[RuntimeException])
- assert(e2.getCause.getMessage == "first row")
+ checkError(e2.getCause.asInstanceOf[SparkThrowable],
+ errorClass = "USER_RAISED_EXCEPTION",
+ parameters = Map("errorMessage" -> "first row"))
// assert_true(condition)
val intDf = Seq((0, 1)).toDF("a", "b")
@@ -2565,8 +2567,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
val e3 = intercept[SparkException] {
intDf.select(assert_true($"a" > $"b")).collect()
}
+
assert(e3.getCause.isInstanceOf[RuntimeException])
- assert(e3.getCause.getMessage.matches("'\\(a#\\d+ > b#\\d+\\)' is not true!"))
+ assert(e3.getCause.getMessage.matches(
+ "\\[USER_RAISED_EXCEPTION\\] '\\(a#\\d+ > b#\\d+\\)' is not true!"))
}
test("raise_error") {
@@ -2575,14 +2579,16 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
val e1 = intercept[SparkException] {
strDf.select(raise_error(lit(null.asInstanceOf[String]))).collect()
}
- assert(e1.getCause.isInstanceOf[RuntimeException])
- assert(e1.getCause.getMessage == null)
+ checkError(e1.getCause.asInstanceOf[SparkThrowable],
+ errorClass = "USER_RAISED_EXCEPTION",
+ parameters = Map("errorMessage" -> "null"))
val e2 = intercept[SparkException] {
strDf.select(raise_error($"a")).collect()
}
- assert(e2.getCause.isInstanceOf[RuntimeException])
- assert(e2.getCause.getMessage == "hello")
+ checkError(e2.getCause.asInstanceOf[SparkThrowable],
+ errorClass = "USER_RAISED_EXCEPTION",
+ parameters = Map("errorMessage" -> "hello"))
}
test("SPARK-34677: negate/add/subtract year-month and day-time intervals") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/UISeleniumSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/UISeleniumSuite.scala
index f80c456b4ba..04c3953ca40 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/UISeleniumSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/UISeleniumSuite.scala
@@ -27,7 +27,7 @@ import org.scalatest.concurrent.Eventually.eventually
import org.scalatest.concurrent.Futures.{interval, timeout}
import org.scalatestplus.selenium.WebBrowser
-import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.{SparkException, SparkFunSuite, SparkThrowable}
import org.apache.spark.sql.SparkSession
import org.apache.spark.ui.SparkUICssErrorHandler
@@ -120,9 +120,10 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser {
HTML40_EXTENDED_ESCAPE.keySet().asScala).mkString
val errorMsg = escapeJava(escape.mkString)
val e1 = intercept[SparkException](spark.sql(s"SELECT raise_error('$errorMsg')").collect())
- val e2 = e1.getCause
- assert(e2.isInstanceOf[RuntimeException])
- assert(e2.getMessage === escape)
+ val e2 = e1.getCause.asInstanceOf[SparkThrowable]
+ checkError(e2,
+ errorClass = "USER_RAISED_EXCEPTION",
+ parameters = Map("errorMessage" -> escape))
eventually(timeout(10.seconds), interval(100.milliseconds)) {
val summary = findErrorSummaryOnSQLUI()
assert(!summary.contains("&"))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
index 262412f8cdb..1d522718116 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
@@ -194,6 +194,7 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
"org.apache.spark.sql.catalyst.expressions.SparkVersion",
// Throws an error
"org.apache.spark.sql.catalyst.expressions.RaiseError",
+ "org.apache.spark.sql.catalyst.expressions.AssertTrue",
classOf[CurrentUser].getName,
// The encrypt expression includes a random initialization vector to its encrypted result
classOf[AesEncrypt].getName)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org