You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2022/08/01 08:40:36 UTC
[spark] branch master updated: [SPARK-39923][SQL] Multiple query contexts in Spark exceptions
This is an automated email from the ASF dual-hosted git repository.
maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 9ee2c753b98 [SPARK-39923][SQL] Multiple query contexts in Spark exceptions
9ee2c753b98 is described below
commit 9ee2c753b98b290fab9b2ec1f02d90c7c9441271
Author: Max Gekk <ma...@gmail.com>
AuthorDate: Mon Aug 1 13:40:22 2022 +0500
[SPARK-39923][SQL] Multiple query contexts in Spark exceptions
### What changes were proposed in this pull request?
1. Replace `Option[QueryContext]` with `Array[QueryContext]` in Spark exceptions such as `SparkRuntimeException`.
2. Pass `SQLQueryContext` to `QueryExecutionErrors` functions instead of `Option[SQLQueryContext]`.
3. Add the methods `getContextOrNull()` and `getContextOrNullCode()` to `SupportQueryContext` to get an expression's SQL query context, or `null` if the context is missing.
### Why are the changes needed?
1. The changes will allow chaining multiple error contexts in a Spark exception. For instance, suppose a user's query refers to a view v1, v1 refers to another view v2, and v2 performs a division. The error contexts will then be: the SQL fragment of v2 that performs the division -> the SQL fragment of v1 that refers to v2 -> the SQL fragment of the user's query that refers to v1.
2. Passing `SQLQueryContext` to `QueryExecutionErrors` directly simplifies the generated code, because it avoids constructing Scala objects such as `scala.None`.
### Does this PR introduce _any_ user-facing change?
Yes, this PR changes user-facing exceptions.
### How was this patch tested?
By running the modified test suites:
```
$ build/sbt "test:testOnly *DecimalExpressionSuite"
```
and potentially affected tests:
```
$ build/sbt "sql/testOnly org.apache.spark.sql.SQLQueryTestSuite"
```
Closes #37343 from MaxGekk/array-as-query-context.
Authored-by: Max Gekk <ma...@gmail.com>
Signed-off-by: Max Gekk <ma...@gmail.com>
---
.../scala/org/apache/spark/SparkException.scala | 28 ++++-----
.../spark/sql/catalyst/expressions/Cast.scala | 67 +++++++++++-----------
.../sql/catalyst/expressions/Expression.scala | 10 ++++
.../catalyst/expressions/aggregate/Average.scala | 6 +-
.../sql/catalyst/expressions/aggregate/Sum.scala | 23 ++++----
.../sql/catalyst/expressions/arithmetic.scala | 48 +++++++++-------
.../expressions/collectionOperations.scala | 4 +-
.../expressions/complexTypeExtractors.scala | 8 +--
.../catalyst/expressions/decimalExpressions.scala | 32 ++++-------
.../catalyst/expressions/intervalExpressions.scala | 16 +++---
.../sql/catalyst/expressions/mathExpressions.scala | 2 +-
.../catalyst/expressions/stringExpressions.scala | 5 +-
.../spark/sql/catalyst/util/DateTimeUtils.scala | 10 ++--
.../spark/sql/catalyst/util/IntervalUtils.scala | 2 +-
.../apache/spark/sql/catalyst/util/MathUtils.scala | 14 ++---
.../spark/sql/catalyst/util/UTF8StringUtils.scala | 10 ++--
.../apache/spark/sql/errors/QueryErrorsBase.scala | 9 ++-
.../spark/sql/errors/QueryExecutionErrors.scala | 54 ++++++++---------
.../scala/org/apache/spark/sql/types/Decimal.scala | 4 +-
.../expressions/DecimalExpressionSuite.scala | 2 +-
20 files changed, 182 insertions(+), 172 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/SparkException.scala b/core/src/main/scala/org/apache/spark/SparkException.scala
index d6add48ffb1..6548a114d41 100644
--- a/core/src/main/scala/org/apache/spark/SparkException.scala
+++ b/core/src/main/scala/org/apache/spark/SparkException.scala
@@ -119,7 +119,7 @@ private[spark] class SparkArithmeticException(
errorClass: String,
errorSubClass: Option[String] = None,
messageParameters: Array[String],
- context: Option[QueryContext],
+ context: Array[QueryContext],
summary: String)
extends ArithmeticException(
SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters, summary))
@@ -128,7 +128,7 @@ private[spark] class SparkArithmeticException(
override def getMessageParameters: Array[String] = messageParameters
override def getErrorClass: String = errorClass
override def getErrorSubClass: String = errorSubClass.orNull
- override def getQueryContext: Array[QueryContext] = context.toArray
+ override def getQueryContext: Array[QueryContext] = context
}
/**
@@ -195,7 +195,7 @@ private[spark] class SparkDateTimeException(
errorClass: String,
errorSubClass: Option[String] = None,
messageParameters: Array[String],
- context: Option[QueryContext],
+ context: Array[QueryContext],
summary: String)
extends DateTimeException(
SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters, summary))
@@ -204,7 +204,7 @@ private[spark] class SparkDateTimeException(
override def getMessageParameters: Array[String] = messageParameters
override def getErrorClass: String = errorClass
override def getErrorSubClass: String = errorSubClass.orNull
- override def getQueryContext: Array[QueryContext] = context.toArray
+ override def getQueryContext: Array[QueryContext] = context
}
/**
@@ -244,7 +244,7 @@ private[spark] class SparkNumberFormatException(
errorClass: String,
errorSubClass: Option[String] = None,
messageParameters: Array[String],
- context: Option[QueryContext],
+ context: Array[QueryContext],
summary: String)
extends NumberFormatException(
SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters, summary))
@@ -253,7 +253,7 @@ private[spark] class SparkNumberFormatException(
override def getMessageParameters: Array[String] = messageParameters
override def getErrorClass: String = errorClass
override def getErrorSubClass: String = errorSubClass.orNull
- override def getQueryContext: Array[QueryContext] = context.toArray
+ override def getQueryContext: Array[QueryContext] = context
}
/**
@@ -323,7 +323,7 @@ private[spark] class SparkRuntimeException(
errorSubClass: Option[String] = None,
messageParameters: Array[String],
cause: Throwable = null,
- context: Option[QueryContext] = None,
+ context: Array[QueryContext] = Array.empty,
summary: String = "")
extends RuntimeException(
SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters, summary),
@@ -334,7 +334,7 @@ private[spark] class SparkRuntimeException(
errorSubClass: String,
messageParameters: Array[String],
cause: Throwable,
- context: Option[QueryContext])
+ context: Array[QueryContext])
= this(errorClass = errorClass,
errorSubClass = Some(errorSubClass),
messageParameters = messageParameters,
@@ -348,12 +348,12 @@ private[spark] class SparkRuntimeException(
errorSubClass = Some(errorSubClass),
messageParameters = messageParameters,
cause = null,
- context = None)
+ context = Array.empty[QueryContext])
override def getMessageParameters: Array[String] = messageParameters
override def getErrorClass: String = errorClass
override def getErrorSubClass: String = errorSubClass.orNull
- override def getQueryContext: Array[QueryContext] = context.toArray
+ override def getQueryContext: Array[QueryContext] = context
}
/**
@@ -379,7 +379,7 @@ private[spark] class SparkArrayIndexOutOfBoundsException(
errorClass: String,
errorSubClass: Option[String] = None,
messageParameters: Array[String],
- context: Option[QueryContext],
+ context: Array[QueryContext],
summary: String)
extends ArrayIndexOutOfBoundsException(
SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters, summary))
@@ -388,7 +388,7 @@ private[spark] class SparkArrayIndexOutOfBoundsException(
override def getMessageParameters: Array[String] = messageParameters
override def getErrorClass: String = errorClass
override def getErrorSubClass: String = errorSubClass.orNull
- override def getQueryContext: Array[QueryContext] = context.toArray
+ override def getQueryContext: Array[QueryContext] = context
}
/**
@@ -420,7 +420,7 @@ private[spark] class SparkNoSuchElementException(
errorClass: String,
errorSubClass: Option[String] = None,
messageParameters: Array[String],
- context: Option[QueryContext],
+ context: Array[QueryContext],
summary: String)
extends NoSuchElementException(
SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters, summary))
@@ -429,7 +429,7 @@ private[spark] class SparkNoSuchElementException(
override def getMessageParameters: Array[String] = messageParameters
override def getErrorClass: String = errorClass
override def getErrorSubClass: String = errorSubClass.orNull
- override def getQueryContext: Array[QueryContext] = context.toArray
+ override def getQueryContext: Array[QueryContext] = context
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 0ba651b5650..f740ecd9dcb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -653,7 +653,7 @@ case class Cast(
false
} else {
if (ansiEnabled) {
- throw QueryExecutionErrors.invalidInputSyntaxForBooleanError(s, queryContext)
+ throw QueryExecutionErrors.invalidInputSyntaxForBooleanError(s, getContextOrNull())
} else {
null
}
@@ -685,7 +685,7 @@ case class Cast(
case StringType =>
buildCast[UTF8String](_, utfs => {
if (ansiEnabled) {
- DateTimeUtils.stringToTimestampAnsi(utfs, zoneId, queryContext)
+ DateTimeUtils.stringToTimestampAnsi(utfs, zoneId, getContextOrNull())
} else {
DateTimeUtils.stringToTimestamp(utfs, zoneId).orNull
}
@@ -710,14 +710,14 @@ case class Cast(
// TimestampWritable.doubleToTimestamp
case DoubleType =>
if (ansiEnabled) {
- buildCast[Double](_, d => doubleToTimestampAnsi(d, queryContext))
+ buildCast[Double](_, d => doubleToTimestampAnsi(d, getContextOrNull()))
} else {
buildCast[Double](_, d => doubleToTimestamp(d))
}
// TimestampWritable.floatToTimestamp
case FloatType =>
if (ansiEnabled) {
- buildCast[Float](_, f => doubleToTimestampAnsi(f.toDouble, queryContext))
+ buildCast[Float](_, f => doubleToTimestampAnsi(f.toDouble, getContextOrNull()))
} else {
buildCast[Float](_, f => doubleToTimestamp(f.toDouble))
}
@@ -727,7 +727,7 @@ case class Cast(
case StringType =>
buildCast[UTF8String](_, utfs => {
if (ansiEnabled) {
- DateTimeUtils.stringToTimestampWithoutTimeZoneAnsi(utfs, queryContext)
+ DateTimeUtils.stringToTimestampWithoutTimeZoneAnsi(utfs, getContextOrNull())
} else {
DateTimeUtils.stringToTimestampWithoutTimeZone(utfs).orNull
}
@@ -760,7 +760,7 @@ case class Cast(
private[this] def castToDate(from: DataType): Any => Any = from match {
case StringType =>
if (ansiEnabled) {
- buildCast[UTF8String](_, s => DateTimeUtils.stringToDateAnsi(s, queryContext))
+ buildCast[UTF8String](_, s => DateTimeUtils.stringToDateAnsi(s, getContextOrNull()))
} else {
buildCast[UTF8String](_, s => DateTimeUtils.stringToDate(s).orNull)
}
@@ -817,7 +817,7 @@ case class Cast(
// LongConverter
private[this] def castToLong(from: DataType): Any => Any = from match {
case StringType if ansiEnabled =>
- buildCast[UTF8String](_, v => UTF8StringUtils.toLongExact(v, queryContext))
+ buildCast[UTF8String](_, v => UTF8StringUtils.toLongExact(v, getContextOrNull()))
case StringType =>
val result = new LongWrapper()
buildCast[UTF8String](_, s => if (s.toLong(result)) result.value else null)
@@ -840,7 +840,7 @@ case class Cast(
// IntConverter
private[this] def castToInt(from: DataType): Any => Any = from match {
case StringType if ansiEnabled =>
- buildCast[UTF8String](_, v => UTF8StringUtils.toIntExact(v, queryContext))
+ buildCast[UTF8String](_, v => UTF8StringUtils.toIntExact(v, getContextOrNull()))
case StringType =>
val result = new IntWrapper()
buildCast[UTF8String](_, s => if (s.toInt(result)) result.value else null)
@@ -872,7 +872,7 @@ case class Cast(
// ShortConverter
private[this] def castToShort(from: DataType): Any => Any = from match {
case StringType if ansiEnabled =>
- buildCast[UTF8String](_, v => UTF8StringUtils.toShortExact(v, queryContext))
+ buildCast[UTF8String](_, v => UTF8StringUtils.toShortExact(v, getContextOrNull()))
case StringType =>
val result = new IntWrapper()
buildCast[UTF8String](_, s => if (s.toShort(result)) {
@@ -919,7 +919,7 @@ case class Cast(
// ByteConverter
private[this] def castToByte(from: DataType): Any => Any = from match {
case StringType if ansiEnabled =>
- buildCast[UTF8String](_, v => UTF8StringUtils.toByteExact(v, queryContext))
+ buildCast[UTF8String](_, v => UTF8StringUtils.toByteExact(v, getContextOrNull()))
case StringType =>
val result = new IntWrapper()
buildCast[UTF8String](_, s => if (s.toByte(result)) {
@@ -986,7 +986,7 @@ case class Cast(
null
} else {
throw QueryExecutionErrors.cannotChangeDecimalPrecisionError(
- value, decimalType.precision, decimalType.scale, queryContext)
+ value, decimalType.precision, decimalType.scale, getContextOrNull())
}
}
}
@@ -999,7 +999,7 @@ case class Cast(
private[this] def toPrecision(
value: Decimal,
decimalType: DecimalType,
- context: Option[SQLQueryContext]): Decimal =
+ context: SQLQueryContext): Decimal =
value.toPrecision(
decimalType.precision, decimalType.scale, Decimal.ROUND_HALF_UP, !ansiEnabled, context)
@@ -1012,17 +1012,17 @@ case class Cast(
})
case StringType if ansiEnabled =>
buildCast[UTF8String](_,
- s => changePrecision(Decimal.fromStringANSI(s, target, queryContext), target))
+ s => changePrecision(Decimal.fromStringANSI(s, target, getContextOrNull()), target))
case BooleanType =>
buildCast[Boolean](_,
- b => toPrecision(if (b) Decimal.ONE else Decimal.ZERO, target, queryContext))
+ b => toPrecision(if (b) Decimal.ONE else Decimal.ZERO, target, getContextOrNull()))
case DateType =>
buildCast[Int](_, d => null) // date can't cast to decimal in Hive
case TimestampType =>
// Note that we lose precision here.
buildCast[Long](_, t => changePrecision(Decimal(timestampToDouble(t)), target))
case dt: DecimalType =>
- b => toPrecision(b.asInstanceOf[Decimal], target, queryContext)
+ b => toPrecision(b.asInstanceOf[Decimal], target, getContextOrNull())
case t: IntegralType =>
b => changePrecision(Decimal(t.integral.asInstanceOf[Integral[Any]].toLong(b)), target)
case x: FractionalType =>
@@ -1055,7 +1055,7 @@ case class Cast(
val d = Cast.processFloatingPointSpecialLiterals(doubleStr, false)
if(ansiEnabled && d == null) {
throw QueryExecutionErrors.invalidInputInCastToNumberError(
- DoubleType, s, queryContext)
+ DoubleType, s, getContextOrNull())
} else {
d
}
@@ -1081,7 +1081,7 @@ case class Cast(
val f = Cast.processFloatingPointSpecialLiterals(floatStr, true)
if (ansiEnabled && f == null) {
throw QueryExecutionErrors.invalidInputInCastToNumberError(
- FloatType, s, queryContext)
+ FloatType, s, getContextOrNull())
} else {
f
}
@@ -1196,10 +1196,6 @@ case class Cast(
}
}
- def errorContextCode(codegenContext: CodegenContext): String = {
- codegenContext.addReferenceObj("errCtx", queryContext)
- }
-
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val eval = child.genCode(ctx)
val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx)
@@ -1512,7 +1508,7 @@ case class Cast(
val intOpt = ctx.freshVariable("intOpt", classOf[Option[Integer]])
(c, evPrim, evNull) =>
if (ansiEnabled) {
- val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+ val errorContext = getContextOrNullCode(ctx)
code"""
$evPrim = $dateTimeUtilsCls.stringToDateAnsi($c, $errorContext);
"""
@@ -1556,12 +1552,13 @@ case class Cast(
|$evPrim = $d;
""".stripMargin
} else {
+ val errorContextCode = getContextOrNullCode(ctx, !nullOnOverflow)
val overflowCode = if (nullOnOverflow) {
s"$evNull = true;"
} else {
s"""
|throw QueryExecutionErrors.cannotChangeDecimalPrecisionError(
- | $d, ${decimalType.precision}, ${decimalType.scale}, ${errorContextCode(ctx)});
+ | $d, ${decimalType.precision}, ${decimalType.scale}, $errorContextCode);
""".stripMargin
}
code"""
@@ -1602,7 +1599,7 @@ case class Cast(
}
"""
case StringType if ansiEnabled =>
- val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+ val errorContext = getContextOrNullCode(ctx)
val toType = ctx.addReferenceObj("toType", target)
(c, evPrim, evNull) =>
code"""
@@ -1679,7 +1676,7 @@ case class Cast(
val longOpt = ctx.freshVariable("longOpt", classOf[Option[Long]])
(c, evPrim, evNull) =>
if (ansiEnabled) {
- val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+ val errorContext = getContextOrNullCode(ctx)
code"""
$evPrim = $dateTimeUtilsCls.stringToTimestampAnsi($c, $zid, $errorContext);
"""
@@ -1718,7 +1715,7 @@ case class Cast(
case DoubleType =>
(c, evPrim, evNull) =>
if (ansiEnabled) {
- val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+ val errorContext = getContextOrNullCode(ctx)
code"$evPrim = $dateTimeUtilsCls.doubleToTimestampAnsi($c, $errorContext);"
} else {
code"""
@@ -1732,7 +1729,7 @@ case class Cast(
case FloatType =>
(c, evPrim, evNull) =>
if (ansiEnabled) {
- val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+ val errorContext = getContextOrNullCode(ctx)
code"$evPrim = $dateTimeUtilsCls.doubleToTimestampAnsi((double)$c, $errorContext);"
} else {
code"""
@@ -1752,7 +1749,7 @@ case class Cast(
val longOpt = ctx.freshVariable("longOpt", classOf[Option[Long]])
(c, evPrim, evNull) =>
if (ansiEnabled) {
- val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+ val errorContext = getContextOrNullCode(ctx)
code"""
$evPrim = $dateTimeUtilsCls.stringToTimestampWithoutTimeZoneAnsi($c, $errorContext);
"""
@@ -1869,7 +1866,7 @@ case class Cast(
val stringUtils = inline"${StringUtils.getClass.getName.stripSuffix("$")}"
(c, evPrim, evNull) =>
val castFailureCode = if (ansiEnabled) {
- val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+ val errorContext = getContextOrNullCode(ctx)
s"throw QueryExecutionErrors.invalidInputSyntaxForBooleanError($c, $errorContext);"
} else {
s"$evNull = true;"
@@ -2004,7 +2001,7 @@ case class Cast(
private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
case StringType if ansiEnabled =>
val stringUtils = UTF8StringUtils.getClass.getCanonicalName.stripSuffix("$")
- val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+ val errorContext = getContextOrNullCode(ctx)
(c, evPrim, evNull) => code"$evPrim = $stringUtils.toByteExact($c, $errorContext);"
case StringType =>
val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper])
@@ -2041,7 +2038,7 @@ case class Cast(
ctx: CodegenContext): CastFunction = from match {
case StringType if ansiEnabled =>
val stringUtils = UTF8StringUtils.getClass.getCanonicalName.stripSuffix("$")
- val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+ val errorContext = getContextOrNullCode(ctx)
(c, evPrim, evNull) => code"$evPrim = $stringUtils.toShortExact($c, $errorContext);"
case StringType =>
val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper])
@@ -2076,7 +2073,7 @@ case class Cast(
private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
case StringType if ansiEnabled =>
val stringUtils = UTF8StringUtils.getClass.getCanonicalName.stripSuffix("$")
- val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+ val errorContext = getContextOrNullCode(ctx)
(c, evPrim, evNull) => code"$evPrim = $stringUtils.toIntExact($c, $errorContext);"
case StringType =>
val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper])
@@ -2111,7 +2108,7 @@ case class Cast(
private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
case StringType if ansiEnabled =>
val stringUtils = UTF8StringUtils.getClass.getCanonicalName.stripSuffix("$")
- val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+ val errorContext = getContextOrNullCode(ctx)
(c, evPrim, evNull) => code"$evPrim = $stringUtils.toLongExact($c, $errorContext);"
case StringType =>
val wrapper = ctx.freshVariable("longWrapper", classOf[UTF8String.LongWrapper])
@@ -2148,7 +2145,7 @@ case class Cast(
val floatStr = ctx.freshVariable("floatStr", StringType)
(c, evPrim, evNull) =>
val handleNull = if (ansiEnabled) {
- val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+ val errorContext = getContextOrNullCode(ctx)
s"throw QueryExecutionErrors.invalidInputInCastToNumberError(" +
s"org.apache.spark.sql.types.FloatType$$.MODULE$$,$c, $errorContext);"
} else {
@@ -2186,7 +2183,7 @@ case class Cast(
val doubleStr = ctx.freshVariable("doubleStr", StringType)
(c, evPrim, evNull) =>
val handleNull = if (ansiEnabled) {
- val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+ val errorContext = getContextOrNullCode(ctx)
s"throw QueryExecutionErrors.invalidInputInCastToNumberError(" +
s"org.apache.spark.sql.types.DoubleType$$.MODULE$$, $c, $errorContext);"
} else {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index d623357b9da..261d9a0cb63 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -597,6 +597,16 @@ trait SupportQueryContext extends Expression with Serializable {
def initQueryContext(): Option[SQLQueryContext]
+ def getContextOrNull(): SQLQueryContext = queryContext.getOrElse(null)
+
+ def getContextOrNullCode(ctx: CodegenContext, withErrorContext: Boolean = true): String = {
+ if (withErrorContext && queryContext.isDefined) {
+ ctx.addReferenceObj("errCtx", queryContext.get)
+ } else {
+ "null"
+ }
+ }
+
// Note: Even though query contexts are serialized to executors, it will be regenerated from an
// empty "Origin" during rule transforms since "Origin"s are not serialized to executors
// for better performance. Thus, we need to copy the original query context during
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index b749dfdaea1..36ffcd8f764 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -86,7 +86,7 @@ abstract class AverageBase
// If all input are nulls, count will be 0 and we will get null after the division.
// We can't directly use `/` as it throws an exception under ansi mode.
- protected def getEvaluateExpression(context: Option[SQLQueryContext]) = child.dataType match {
+ protected def getEvaluateExpression(context: SQLQueryContext = null) = child.dataType match {
case _: DecimalType =>
If(EqualTo(count, Literal(0L)),
Literal(null, resultType),
@@ -141,7 +141,7 @@ case class Average(
override lazy val mergeExpressions: Seq[Expression] = getMergeExpressions
- override lazy val evaluateExpression: Expression = getEvaluateExpression(queryContext)
+ override lazy val evaluateExpression: Expression = getEvaluateExpression(getContextOrNull())
override def initQueryContext(): Option[SQLQueryContext] = if (useAnsiAdd) {
Some(origin.context)
@@ -206,7 +206,7 @@ case class TryAverage(child: Expression) extends AverageBase {
}
override lazy val evaluateExpression: Expression = {
- addTryEvalIfNeeded(getEvaluateExpression(None))
+ addTryEvalIfNeeded(getEvaluateExpression())
}
override protected def withNewChildInternal(newChild: Expression): Expression =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index 9230bd9bf44..e8492c0e5dc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -148,14 +148,15 @@ abstract class SumBase(child: Expression) extends DeclarativeAggregate
* So now, if ansi is enabled, then throw exception, if not then return null.
* If sum is not null, then return the sum.
*/
- protected def getEvaluateExpression(
- context: Option[SQLQueryContext]): Expression = resultType match {
- case d: DecimalType =>
- val checkOverflowInSum = CheckOverflowInSum(sum, d, !useAnsiAdd, context)
- If(isEmpty, Literal.create(null, resultType), checkOverflowInSum)
- case _ if shouldTrackIsEmpty =>
- If(isEmpty, Literal.create(null, resultType), sum)
- case _ => sum
+ protected def getEvaluateExpression(context: SQLQueryContext = null): Expression = {
+ resultType match {
+ case d: DecimalType =>
+ val checkOverflowInSum = CheckOverflowInSum(sum, d, !useAnsiAdd, context)
+ If(isEmpty, Literal.create(null, resultType), checkOverflowInSum)
+ case _ if shouldTrackIsEmpty =>
+ If(isEmpty, Literal.create(null, resultType), sum)
+ case _ => sum
+ }
}
// The flag `useAnsiAdd` won't be shown in the `toString` or `toAggString` methods
@@ -192,7 +193,7 @@ case class Sum(
override lazy val mergeExpressions: Seq[Expression] = getMergeExpressions
- override lazy val evaluateExpression: Expression = getEvaluateExpression(queryContext)
+ override lazy val evaluateExpression: Expression = getEvaluateExpression(getContextOrNull())
override def initQueryContext(): Option[SQLQueryContext] = if (useAnsiAdd) {
Some(origin.context)
@@ -255,9 +256,9 @@ case class TrySum(child: Expression) extends SumBase(child) {
override lazy val evaluateExpression: Expression =
if (useAnsiAdd) {
- TryEval(getEvaluateExpression(None))
+ TryEval(getEvaluateExpression())
} else {
- getEvaluateExpression(None)
+ getEvaluateExpression()
}
override protected def withNewChildInternal(newChild: Expression): Expression =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 7bbe5d15b91..86e6e6d7323 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -265,7 +265,7 @@ abstract class BinaryArithmetic extends BinaryOperator
}
protected def checkDecimalOverflow(value: Decimal, precision: Int, scale: Int): Decimal = {
- value.toPrecision(precision, scale, Decimal.ROUND_HALF_UP, !failOnError, queryContext)
+ value.toPrecision(precision, scale, Decimal.ROUND_HALF_UP, !failOnError, getContextOrNull())
}
/** Name of the function for this expression on a [[Decimal]] type. */
@@ -285,11 +285,7 @@ abstract class BinaryArithmetic extends BinaryOperator
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
case DecimalType.Fixed(precision, scale) =>
- val errorContextCode = if (failOnError) {
- ctx.addReferenceObj("errCtx", queryContext)
- } else {
- "scala.None$.MODULE$"
- }
+ val errorContextCode = getContextOrNullCode(ctx, failOnError)
val updateIsNull = if (failOnError) {
""
} else {
@@ -334,7 +330,7 @@ abstract class BinaryArithmetic extends BinaryOperator
})
case IntegerType | LongType if failOnError && exactMathMethod.isDefined =>
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
- val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+ val errorContext = getContextOrNullCode(ctx)
val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$")
s"""
|${ev.value} = $mathUtils.${exactMathMethod.get}($eval1, $eval2, $errorContext);
@@ -414,9 +410,9 @@ case class Add(
case _: YearMonthIntervalType =>
MathUtils.addExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int])
case _: IntegerType if failOnError =>
- MathUtils.addExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int], queryContext)
+ MathUtils.addExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int], getContextOrNull())
case _: LongType if failOnError =>
- MathUtils.addExact(input1.asInstanceOf[Long], input2.asInstanceOf[Long], queryContext)
+ MathUtils.addExact(input1.asInstanceOf[Long], input2.asInstanceOf[Long], getContextOrNull())
case _ => numeric.plus(input1, input2)
}
@@ -483,9 +479,15 @@ case class Subtract(
case _: YearMonthIntervalType =>
MathUtils.subtractExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int])
case _: IntegerType if failOnError =>
- MathUtils.subtractExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int], queryContext)
+ MathUtils.subtractExact(
+ input1.asInstanceOf[Int],
+ input2.asInstanceOf[Int],
+ getContextOrNull())
case _: LongType if failOnError =>
- MathUtils.subtractExact(input1.asInstanceOf[Long], input2.asInstanceOf[Long], queryContext)
+ MathUtils.subtractExact(
+ input1.asInstanceOf[Long],
+ input2.asInstanceOf[Long],
+ getContextOrNull())
case _ => numeric.minus(input1, input2)
}
@@ -539,9 +541,15 @@ case class Multiply(
case DecimalType.Fixed(precision, scale) =>
checkDecimalOverflow(numeric.times(input1, input2).asInstanceOf[Decimal], precision, scale)
case _: IntegerType if failOnError =>
- MathUtils.multiplyExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int], queryContext)
+ MathUtils.multiplyExact(
+ input1.asInstanceOf[Int],
+ input2.asInstanceOf[Int],
+ getContextOrNull())
case _: LongType if failOnError =>
- MathUtils.multiplyExact(input1.asInstanceOf[Long], input2.asInstanceOf[Long], queryContext)
+ MathUtils.multiplyExact(
+ input1.asInstanceOf[Long],
+ input2.asInstanceOf[Long],
+ getContextOrNull())
case _ => numeric.times(input1, input2)
}
@@ -578,10 +586,10 @@ trait DivModLike extends BinaryArithmetic {
} else {
if (isZero(input2)) {
// when we reach here, failOnError must be true.
- throw QueryExecutionErrors.divideByZeroError(queryContext)
+ throw QueryExecutionErrors.divideByZeroError(getContextOrNull())
}
if (checkDivideOverflow && input1 == Long.MinValue && input2 == -1) {
- throw QueryExecutionErrors.overflowInIntegralDivideError(queryContext)
+ throw QueryExecutionErrors.overflowInIntegralDivideError(getContextOrNull())
}
evalOperation(input1, input2)
}
@@ -603,11 +611,7 @@ trait DivModLike extends BinaryArithmetic {
s"${eval2.value} == 0"
}
val javaType = CodeGenerator.javaType(dataType)
- val errorContextCode = if (failOnError) {
- ctx.addReferenceObj("errCtx", queryContext)
- } else {
- "scala.None$.MODULE$"
- }
+ val errorContextCode = getContextOrNullCode(ctx, failOnError)
val operation = super.dataType match {
case DecimalType.Fixed(precision, scale) =>
val decimalValue = ctx.freshName("decimalValue")
@@ -962,7 +966,7 @@ case class Pmod(
} else {
if (isZero(input2)) {
// when we reach here, failOnError must bet true.
- throw QueryExecutionErrors.divideByZeroError(queryContext)
+ throw QueryExecutionErrors.divideByZeroError(getContextOrNull())
}
pmodFunc(input1, input2)
}
@@ -979,7 +983,7 @@ case class Pmod(
}
val remainder = ctx.freshName("remainder")
val javaType = CodeGenerator.javaType(dataType)
- val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+ val errorContext = getContextOrNullCode(ctx)
val result = dataType match {
case DecimalType.Fixed(precision, scale) =>
val decimalAdd = "$plus"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 098b3a88084..ae23775b62d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -2174,7 +2174,7 @@ case class ElementAt(
if (array.numElements() < math.abs(index)) {
if (failOnError) {
throw QueryExecutionErrors.invalidElementAtIndexError(
- index, array.numElements(), queryContext)
+ index, array.numElements(), getContextOrNull())
} else {
defaultValueOutOfBound match {
case Some(value) => value.eval()
@@ -2216,7 +2216,7 @@ case class ElementAt(
}
val indexOutOfBoundBranch = if (failOnError) {
- val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+ val errorContext = getContextOrNullCode(ctx)
// scalastyle:off line.size.limit
s"throw QueryExecutionErrors.invalidElementAtIndexError($index, $eval1.numElements(), $errorContext);"
// scalastyle:on line.size.limit
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index fedfcfb978f..b6cbb1d0005 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -268,7 +268,7 @@ case class GetArrayItem(
if (index >= baseValue.numElements() || index < 0) {
if (failOnError) {
throw QueryExecutionErrors.invalidArrayIndexError(
- index, baseValue.numElements, queryContext)
+ index, baseValue.numElements, getContextOrNull())
} else {
null
}
@@ -292,7 +292,7 @@ case class GetArrayItem(
}
val indexOutOfBoundBranch = if (failOnError) {
- val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+ val errorContext = getContextOrNullCode(ctx)
// scalastyle:off line.size.limit
s"throw QueryExecutionErrors.invalidArrayIndexError($index, $eval1.numElements(), $errorContext);"
// scalastyle:on line.size.limit
@@ -380,7 +380,7 @@ trait GetMapValueUtil
if (!found) {
if (failOnError) {
- throw QueryExecutionErrors.mapKeyNotExistError(ordinal, keyType, queryContext)
+ throw QueryExecutionErrors.mapKeyNotExistError(ordinal, keyType, getContextOrNull())
} else {
null
}
@@ -413,7 +413,7 @@ trait GetMapValueUtil
}
val keyJavaType = CodeGenerator.javaType(keyType)
- lazy val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+ lazy val errorContext = getContextOrNullCode(ctx)
val keyDt = ctx.addReferenceObj("keyType", keyType, keyType.getClass.getName)
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
val keyNotFoundBranch = if (failOnError) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
index e672fffda19..37e3dd5ea89 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
@@ -123,14 +123,10 @@ case class CheckOverflow(
dataType.scale,
Decimal.ROUND_HALF_UP,
nullOnOverflow,
- queryContext)
+ getContextOrNull())
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val errorContextCode = if (nullOnOverflow) {
- "scala.None$.MODULE$"
- } else {
- ctx.addReferenceObj("errCtx", queryContext)
- }
+ val errorContextCode = getContextOrNullCode(ctx, !nullOnOverflow)
nullSafeCodeGen(ctx, ev, eval => {
// scalastyle:off line.size.limit
s"""
@@ -161,7 +157,7 @@ case class CheckOverflowInSum(
child: Expression,
dataType: DecimalType,
nullOnOverflow: Boolean,
- context: Option[SQLQueryContext] = None) extends UnaryExpression {
+ context: SQLQueryContext) extends UnaryExpression with SupportQueryContext {
override def nullable: Boolean = true
@@ -182,11 +178,7 @@ case class CheckOverflowInSum(
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val childGen = child.genCode(ctx)
- val errorContextCode = if (nullOnOverflow) {
- "scala.None$.MODULE$"
- } else {
- ctx.addReferenceObj("errCtx", context)
- }
+ val errorContextCode = getContextOrNullCode(ctx, !nullOnOverflow)
val nullHandling = if (nullOnOverflow) {
""
} else {
@@ -216,6 +208,8 @@ case class CheckOverflowInSum(
override protected def withNewChildInternal(newChild: Expression): CheckOverflowInSum =
copy(child = newChild)
+
+ override def initQueryContext(): Option[SQLQueryContext] = Option(context)
}
/**
@@ -261,12 +255,12 @@ case class DecimalDivideWithOverflowCheck(
left: Expression,
right: Expression,
override val dataType: DecimalType,
- context: Option[SQLQueryContext],
+ context: SQLQueryContext,
nullOnOverflow: Boolean)
extends BinaryExpression with ExpectsInputTypes with SupportQueryContext {
override def nullable: Boolean = nullOnOverflow
override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType, DecimalType)
- override def initQueryContext(): Option[SQLQueryContext] = context
+ override def initQueryContext(): Option[SQLQueryContext] = Option(context)
def decimalMethod: String = "$div"
override def eval(input: InternalRow): Any = {
@@ -275,22 +269,18 @@ case class DecimalDivideWithOverflowCheck(
if (nullOnOverflow) {
null
} else {
- throw QueryExecutionErrors.overflowInSumOfDecimalError(queryContext)
+ throw QueryExecutionErrors.overflowInSumOfDecimalError(getContextOrNull())
}
} else {
val value2 = right.eval(input)
dataType.fractional.asInstanceOf[Fractional[Any]].div(value1, value2).asInstanceOf[Decimal]
.toPrecision(dataType.precision, dataType.scale, Decimal.ROUND_HALF_UP, nullOnOverflow,
- queryContext)
+ getContextOrNull())
}
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val errorContextCode = if (nullOnOverflow) {
- "scala.None$.MODULE$"
- } else {
- ctx.addReferenceObj("errCtx", queryContext)
- }
+ val errorContextCode = getContextOrNullCode(ctx, !nullOnOverflow)
val nullHandling = if (nullOnOverflow) {
""
} else {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
index 17a2714c611..f7ec82de11b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
@@ -604,7 +604,7 @@ trait IntervalDivide {
minValue: Any,
num: Expression,
numValue: Any,
- context: Option[SQLQueryContext]): Unit = {
+ context: SQLQueryContext): Unit = {
if (value == minValue && num.dataType.isInstanceOf[IntegralType]) {
if (numValue.asInstanceOf[Number].longValue() == -1) {
throw QueryExecutionErrors.overflowInIntegralDivideError(context)
@@ -615,7 +615,7 @@ trait IntervalDivide {
def divideByZeroCheck(
dataType: DataType,
num: Any,
- context: Option[SQLQueryContext]): Unit = dataType match {
+ context: SQLQueryContext): Unit = dataType match {
case _: DecimalType =>
if (num.asInstanceOf[Decimal].isZero) {
throw QueryExecutionErrors.intervalDividedByZeroError(context)
@@ -665,13 +665,13 @@ case class DivideYMInterval(
override def nullSafeEval(interval: Any, num: Any): Any = {
checkDivideOverflow(
- interval.asInstanceOf[Int], Int.MinValue, right, num, Some(origin.context))
- divideByZeroCheck(right.dataType, num, Some(origin.context))
+ interval.asInstanceOf[Int], Int.MinValue, right, num, origin.context)
+ divideByZeroCheck(right.dataType, num, origin.context)
evalFunc(interval.asInstanceOf[Int], num)
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val errorContext = ctx.addReferenceObj("errCtx", Some(origin.context))
+ val errorContext = ctx.addReferenceObj("errCtx", origin.context)
right.dataType match {
case t: IntegralType =>
val math = t match {
@@ -743,13 +743,13 @@ case class DivideDTInterval(
override def nullSafeEval(interval: Any, num: Any): Any = {
checkDivideOverflow(
- interval.asInstanceOf[Long], Long.MinValue, right, num, Some(origin.context))
- divideByZeroCheck(right.dataType, num, Some(origin.context))
+ interval.asInstanceOf[Long], Long.MinValue, right, num, origin.context)
+ divideByZeroCheck(right.dataType, num, origin.context)
evalFunc(interval.asInstanceOf[Long], num)
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val errorContext = ctx.addReferenceObj("errCtx", Some(origin.context))
+ val errorContext = ctx.addReferenceObj("errCtx", origin.context)
right.dataType match {
case _: IntegralType =>
val math = classOf[LongMath].getName
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index 55ff36e9863..dfbc041b259 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -1520,7 +1520,7 @@ abstract class RoundBase(child: Expression, scale: Expression,
if (_scale >= 0) {
s"""
${ev.value} = ${ce.value}.toPrecision(${ce.value}.precision(), $s,
- Decimal.$modeStr(), true, scala.None$$.MODULE$$);
+ Decimal.$modeStr(), true, null);
${ev.isNull} = ${ev.value} == null;"""
} else {
s"""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 815eb8977b6..d4504c36e4e 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -296,7 +296,8 @@ case class Elt(
val index = indexObj.asInstanceOf[Int]
if (index <= 0 || index > inputExprs.length) {
if (failOnError) {
- throw QueryExecutionErrors.invalidArrayIndexError(index, inputExprs.length, queryContext)
+ throw QueryExecutionErrors.invalidArrayIndexError(
+ index, inputExprs.length, getContextOrNull())
} else {
null
}
@@ -348,7 +349,7 @@ case class Elt(
}.mkString)
val indexOutOfBoundBranch = if (failOnError) {
- val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+ val errorContext = getContextOrNullCode(ctx)
// scalastyle:off line.size.limit
s"""
|if (!$indexMatched) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index 172c2e54034..af0666a98fa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -468,14 +468,14 @@ object DateTimeUtils {
def stringToTimestampAnsi(
s: UTF8String,
timeZoneId: ZoneId,
- context: Option[SQLQueryContext] = None): Long = {
+ context: SQLQueryContext = null): Long = {
stringToTimestamp(s, timeZoneId).getOrElse {
throw QueryExecutionErrors.invalidInputInCastToDatetimeError(
s, StringType, TimestampType, context)
}
}
- def doubleToTimestampAnsi(d: Double, context: Option[SQLQueryContext]): Long = {
+ def doubleToTimestampAnsi(d: Double, context: SQLQueryContext): Long = {
if (d.isNaN || d.isInfinite) {
throw QueryExecutionErrors.invalidInputInCastToDatetimeError(
d, DoubleType, TimestampType, context)
@@ -527,7 +527,7 @@ object DateTimeUtils {
def stringToTimestampWithoutTimeZoneAnsi(
s: UTF8String,
- context: Option[SQLQueryContext]): Long = {
+ context: SQLQueryContext): Long = {
stringToTimestampWithoutTimeZone(s, true).getOrElse {
throw QueryExecutionErrors.invalidInputInCastToDatetimeError(
s, StringType, TimestampNTZType, context)
@@ -646,7 +646,9 @@ object DateTimeUtils {
}
}
- def stringToDateAnsi(s: UTF8String, context: Option[SQLQueryContext] = None): Int = {
+ def stringToDateAnsi(
+ s: UTF8String,
+ context: SQLQueryContext = null): Int = {
stringToDate(s).getOrElse {
throw QueryExecutionErrors.invalidInputInCastToDatetimeError(
s, StringType, DateType, context)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
index b4695062c08..f2c4236ad7b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
@@ -733,7 +733,7 @@ object IntervalUtils {
* @throws ArithmeticException if the result overflows any field value or divided by zero
*/
def divideExact(interval: CalendarInterval, num: Double): CalendarInterval = {
- if (num == 0) throw QueryExecutionErrors.intervalDividedByZeroError(None)
+ if (num == 0) throw QueryExecutionErrors.intervalDividedByZeroError(null)
fromDoubles(interval.months / num, interval.days / num, interval.microseconds / num)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala
index 6cb3616d4e7..e79e483076d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala
@@ -27,37 +27,37 @@ object MathUtils {
def addExact(a: Int, b: Int): Int = withOverflow(Math.addExact(a, b))
- def addExact(a: Int, b: Int, context: Option[SQLQueryContext]): Int = {
+ def addExact(a: Int, b: Int, context: SQLQueryContext): Int = {
withOverflow(Math.addExact(a, b), hint = "try_add", context)
}
def addExact(a: Long, b: Long): Long = withOverflow(Math.addExact(a, b))
- def addExact(a: Long, b: Long, context: Option[SQLQueryContext]): Long = {
+ def addExact(a: Long, b: Long, context: SQLQueryContext): Long = {
withOverflow(Math.addExact(a, b), hint = "try_add", context)
}
def subtractExact(a: Int, b: Int): Int = withOverflow(Math.subtractExact(a, b))
- def subtractExact(a: Int, b: Int, context: Option[SQLQueryContext]): Int = {
+ def subtractExact(a: Int, b: Int, context: SQLQueryContext): Int = {
withOverflow(Math.subtractExact(a, b), hint = "try_subtract", context)
}
def subtractExact(a: Long, b: Long): Long = withOverflow(Math.subtractExact(a, b))
- def subtractExact(a: Long, b: Long, context: Option[SQLQueryContext]): Long = {
+ def subtractExact(a: Long, b: Long, context: SQLQueryContext): Long = {
withOverflow(Math.subtractExact(a, b), hint = "try_subtract", context)
}
def multiplyExact(a: Int, b: Int): Int = withOverflow(Math.multiplyExact(a, b))
- def multiplyExact(a: Int, b: Int, context: Option[SQLQueryContext]): Int = {
+ def multiplyExact(a: Int, b: Int, context: SQLQueryContext): Int = {
withOverflow(Math.multiplyExact(a, b), hint = "try_multiply", context)
}
def multiplyExact(a: Long, b: Long): Long = withOverflow(Math.multiplyExact(a, b))
- def multiplyExact(a: Long, b: Long, context: Option[SQLQueryContext]): Long = {
+ def multiplyExact(a: Long, b: Long, context: SQLQueryContext): Long = {
withOverflow(Math.multiplyExact(a, b), hint = "try_multiply", context)
}
@@ -78,7 +78,7 @@ object MathUtils {
private def withOverflow[A](
f: => A,
hint: String = "",
- context: Option[SQLQueryContext] = None): A = {
+ context: SQLQueryContext = null): A = {
try {
f
} catch {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala
index 503c0e181ca..f7800469c35 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala
@@ -27,21 +27,21 @@ import org.apache.spark.unsafe.types.UTF8String
*/
object UTF8StringUtils {
- def toLongExact(s: UTF8String, context: Option[SQLQueryContext]): Long =
+ def toLongExact(s: UTF8String, context: SQLQueryContext): Long =
withException(s.toLongExact, context, LongType, s)
- def toIntExact(s: UTF8String, context: Option[SQLQueryContext]): Int =
+ def toIntExact(s: UTF8String, context: SQLQueryContext): Int =
withException(s.toIntExact, context, IntegerType, s)
- def toShortExact(s: UTF8String, context: Option[SQLQueryContext]): Short =
+ def toShortExact(s: UTF8String, context: SQLQueryContext): Short =
withException(s.toShortExact, context, ShortType, s)
- def toByteExact(s: UTF8String, context: Option[SQLQueryContext]): Byte =
+ def toByteExact(s: UTF8String, context: SQLQueryContext): Byte =
withException(s.toByteExact, context, ByteType, s)
private def withException[A](
f: => A,
- context: Option[SQLQueryContext],
+ context: SQLQueryContext,
to: DataType,
s: UTF8String): A = {
try {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryErrorsBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryErrorsBase.scala
index 9617f7d4b0f..4785073f80b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryErrorsBase.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryErrorsBase.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.errors
import java.util.Locale
+import org.apache.spark.QueryContext
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
import org.apache.spark.sql.catalyst.trees.SQLQueryContext
@@ -97,7 +98,11 @@ private[sql] trait QueryErrorsBase {
quoteByDefault(toPrettySQL(e))
}
- def getSummary(context: Option[SQLQueryContext]): String = {
- context.map(_.summary).getOrElse("")
+ def getSummary(sqlContext: SQLQueryContext): String = {
+ if (sqlContext == null) "" else sqlContext.summary
+ }
+
+ def getQueryContext(sqlContext: SQLQueryContext): Array[QueryContext] = {
+ if (sqlContext == null) Array.empty else Array(sqlContext.asInstanceOf[QueryContext])
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index bad95afa139..3644e7c0df8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -89,7 +89,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
toSQLType(from),
toSQLType(to),
toSQLConf(SQLConf.ANSI_ENABLED.key)),
- context = None,
+ context = Array.empty,
summary = "")
}
@@ -103,7 +103,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
toSQLType(from),
toSQLType(to),
toSQLId(columnName)),
- context = None,
+ context = Array.empty,
summary = ""
)
}
@@ -112,7 +112,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
value: Decimal,
decimalPrecision: Int,
decimalScale: Int,
- context: Option[SQLQueryContext] = None): ArithmeticException = {
+ context: SQLQueryContext = null): ArithmeticException = {
new SparkArithmeticException(
errorClass = "CANNOT_CHANGE_DECIMAL_PRECISION",
messageParameters = Array(
@@ -120,7 +120,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
decimalPrecision.toString,
decimalScale.toString,
toSQLConf(SQLConf.ANSI_ENABLED.key)),
- context = context,
+ context = getQueryContext(context),
summary = getSummary(context))
}
@@ -128,7 +128,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
value: Any,
from: DataType,
to: DataType,
- context: Option[SQLQueryContext]): Throwable = {
+ context: SQLQueryContext): Throwable = {
new SparkDateTimeException(
errorClass = "CAST_INVALID_INPUT",
messageParameters = Array(
@@ -136,13 +136,13 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
toSQLType(from),
toSQLType(to),
toSQLConf(SQLConf.ANSI_ENABLED.key)),
- context = context,
+ context = getQueryContext(context),
summary = getSummary(context))
}
def invalidInputSyntaxForBooleanError(
s: UTF8String,
- context: Option[SQLQueryContext]): SparkRuntimeException = {
+ context: SQLQueryContext): SparkRuntimeException = {
new SparkRuntimeException(
errorClass = "CAST_INVALID_INPUT",
messageParameters = Array(
@@ -150,14 +150,14 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
toSQLType(StringType),
toSQLType(BooleanType),
toSQLConf(SQLConf.ANSI_ENABLED.key)),
- context = context,
+ context = getQueryContext(context),
summary = getSummary(context))
}
def invalidInputInCastToNumberError(
to: DataType,
s: UTF8String,
- context: Option[SQLQueryContext]): SparkNumberFormatException = {
+ context: SQLQueryContext): SparkNumberFormatException = {
new SparkNumberFormatException(
errorClass = "CAST_INVALID_INPUT",
messageParameters = Array(
@@ -165,7 +165,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
toSQLType(StringType),
toSQLType(to),
toSQLConf(SQLConf.ANSI_ENABLED.key)),
- context = context,
+ context = getQueryContext(context),
summary = getSummary(context))
}
@@ -196,40 +196,40 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
messageParameters = Array(funcCls, inputTypes, outputType), e)
}
- def divideByZeroError(context: Option[SQLQueryContext]): ArithmeticException = {
+ def divideByZeroError(context: SQLQueryContext): ArithmeticException = {
new SparkArithmeticException(
errorClass = "DIVIDE_BY_ZERO",
messageParameters = Array(toSQLConf(SQLConf.ANSI_ENABLED.key)),
- context = context,
+ context = getQueryContext(context),
summary = getSummary(context))
}
- def intervalDividedByZeroError(context: Option[SQLQueryContext]): ArithmeticException = {
+ def intervalDividedByZeroError(context: SQLQueryContext): ArithmeticException = {
new SparkArithmeticException(
errorClass = "INTERVAL_DIVIDED_BY_ZERO",
messageParameters = Array.empty,
- context = context,
+ context = getQueryContext(context),
summary = getSummary(context))
}
def invalidArrayIndexError(
index: Int,
numElements: Int,
- context: Option[SQLQueryContext]): ArrayIndexOutOfBoundsException = {
+ context: SQLQueryContext): ArrayIndexOutOfBoundsException = {
new SparkArrayIndexOutOfBoundsException(
errorClass = "INVALID_ARRAY_INDEX",
messageParameters = Array(
toSQLValue(index, IntegerType),
toSQLValue(numElements, IntegerType),
toSQLConf(SQLConf.ANSI_ENABLED.key)),
- context = context,
+ context = getQueryContext(context),
summary = getSummary(context))
}
def invalidElementAtIndexError(
index: Int,
numElements: Int,
- context: Option[SQLQueryContext]): ArrayIndexOutOfBoundsException = {
+ context: SQLQueryContext): ArrayIndexOutOfBoundsException = {
new SparkArrayIndexOutOfBoundsException(
errorClass = "INVALID_ARRAY_INDEX_IN_ELEMENT_AT",
messageParameters =
@@ -237,20 +237,20 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
toSQLValue(index, IntegerType),
toSQLValue(numElements, IntegerType),
toSQLConf(SQLConf.ANSI_ENABLED.key)),
- context = context,
+ context = getQueryContext(context),
summary = getSummary(context))
}
def mapKeyNotExistError(
key: Any,
dataType: DataType,
- context: Option[SQLQueryContext]): NoSuchElementException = {
+ context: SQLQueryContext): NoSuchElementException = {
new SparkNoSuchElementException(
errorClass = "MAP_KEY_DOES_NOT_EXIST",
messageParameters = Array(
toSQLValue(key, dataType),
toSQLConf(SQLConf.ANSI_ENABLED.key)),
- context = context,
+ context = getQueryContext(context),
summary = getSummary(context))
}
@@ -259,7 +259,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
errorClass = "INVALID_FRACTION_OF_SECOND",
errorSubClass = None,
Array(toSQLConf(SQLConf.ANSI_ENABLED.key)),
- context = None,
+ context = Array.empty,
summary = "")
}
@@ -268,7 +268,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
errorClass = "CANNOT_PARSE_TIMESTAMP",
errorSubClass = None,
Array(e.getMessage, toSQLConf(SQLConf.ANSI_ENABLED.key)),
- context = None,
+ context = Array.empty,
summary = "")
}
@@ -294,11 +294,11 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
ansiIllegalArgumentError(e.getMessage)
}
- def overflowInSumOfDecimalError(context: Option[SQLQueryContext]): ArithmeticException = {
+ def overflowInSumOfDecimalError(context: SQLQueryContext): ArithmeticException = {
arithmeticOverflowError("Overflow in sum of decimals", context = context)
}
- def overflowInIntegralDivideError(context: Option[SQLQueryContext]): ArithmeticException = {
+ def overflowInIntegralDivideError(context: SQLQueryContext): ArithmeticException = {
arithmeticOverflowError("Overflow in integral divide", "try_divide", context)
}
@@ -514,14 +514,14 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
def arithmeticOverflowError(
message: String,
hint: String = "",
- context: Option[SQLQueryContext] = None): ArithmeticException = {
+ context: SQLQueryContext = null): ArithmeticException = {
val alternative = if (hint.nonEmpty) {
s" Use '$hint' to tolerate overflow and return NULL instead."
} else ""
new SparkArithmeticException(
errorClass = "ARITHMETIC_OVERFLOW",
messageParameters = Array(message, alternative, SQLConf.ANSI_ENABLED.key),
- context = context,
+ context = getQueryContext(context),
summary = getSummary(context))
}
@@ -2061,7 +2061,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
messageParameters = Array(
s"add ${toSQLValue(amount, IntegerType)} $unit to " +
s"${toSQLValue(DateTimeUtils.microsToInstant(micros), TimestampType)}"),
- context = None,
+ context = Array.empty,
summary = "")
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 00172f69fda..aa683a06a8e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -367,7 +367,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
scale: Int,
roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP,
nullOnOverflow: Boolean = true,
- context: Option[SQLQueryContext] = None): Decimal = {
+ context: SQLQueryContext = null): Decimal = {
val copy = clone()
if (copy.changePrecision(precision, scale, roundMode)) {
copy
@@ -632,7 +632,7 @@ object Decimal {
def fromStringANSI(
str: UTF8String,
to: DecimalType = DecimalType.USER_DEFAULT,
- context: Option[SQLQueryContext] = None): Decimal = {
+ context: SQLQueryContext = null): Decimal = {
try {
val bigDecimal = stringToJavaBigDecimal(str)
// We fast fail because constructing a very large JavaBigDecimal to Decimal is very slow.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala
index d96ca4b87f0..513a62dc7f0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala
@@ -91,7 +91,7 @@ class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
checkExceptionInExpression[ArithmeticException](expr1, query)
val expr2 = CheckOverflowInSum(
- Literal(d), DecimalType(4, 3), false, context = Some(origin.context))
+ Literal(d), DecimalType(4, 3), false, context = origin.context)
checkExceptionInExpression[ArithmeticException](expr2, query)
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org