Posted to commits@spark.apache.org by ma...@apache.org on 2022/08/01 08:40:36 UTC

[spark] branch master updated: [SPARK-39923][SQL] Multiple query contexts in Spark exceptions

This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9ee2c753b98 [SPARK-39923][SQL] Multiple query contexts in Spark exceptions
9ee2c753b98 is described below

commit 9ee2c753b98b290fab9b2ec1f02d90c7c9441271
Author: Max Gekk <ma...@gmail.com>
AuthorDate: Mon Aug 1 13:40:22 2022 +0500

    [SPARK-39923][SQL] Multiple query contexts in Spark exceptions
    
    ### What changes were proposed in this pull request?
    1. Replace `Option[QueryContext]` with `Array[QueryContext]` in Spark exceptions such as `SparkRuntimeException`.
    2. Pass `SQLQueryContext` to `QueryExecutionErrors` functions instead of `Option[SQLQueryContext]`.
    3. Add the methods `getContextOrNull()` and `getContextOrNullCode()` to `SupportQueryContext` to get an expression's SQL query context, or `null` if it is missing (see the excerpt below).
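
    For reference, the new `SupportQueryContext` helpers, excerpted from the diff below (comments added here for explanation):
    ```scala
    def getContextOrNull(): SQLQueryContext = queryContext.getOrElse(null)

    def getContextOrNullCode(ctx: CodegenContext, withErrorContext: Boolean = true): String = {
      if (withErrorContext && queryContext.isDefined) {
        // A context exists and is wanted: reference it from the generated Java code.
        ctx.addReferenceObj("errCtx", queryContext.get)
      } else {
        // Otherwise emit a plain `null` literal instead of constructing `scala.None$.MODULE$`.
        "null"
      }
    }
    ```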
    
    ### Why are the changes needed?
    1. The changes allow chaining multiple error contexts in Spark exceptions. For instance, suppose a user's query references a view v1, v1 references another view v2, and v2 performs a division. The error contexts will then be: the SQL fragment of v2 that performs the division -> the SQL fragment of v1 that references v2 -> the SQL fragment of the user's query that references v1 (a sketch follows this list).
    2. Passing `SQLQueryContext` to `QueryExecutionErrors` directly simplifies the generated code because it avoids constructing Scala objects such as `scala.None`.
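
    A hypothetical illustration of the nested-view scenario above (the SQL and session calls are made up for illustration; with this PR, `getQueryContext()` returns an `Array` and so can hold one context per level):
    ```scala
    // Assumes ANSI mode so that the division fails at runtime.
    spark.sql("SET spark.sql.ansi.enabled=true")
    spark.sql("CREATE TEMP VIEW v2 AS SELECT 1 / 0 AS c")
    spark.sql("CREATE TEMP VIEW v1 AS SELECT c FROM v2")
    // The thrown SparkThrowable can now carry a context for v2's division,
    // another for v1's reference to v2, and one for this query's reference to v1.
    spark.sql("SELECT c FROM v1").collect()
    ```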
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, this PR changes user-facing exceptions.
    
    ### How was this patch tested?
    By running the modified test suites:
    ```
    $ build/sbt "test:testOnly *DecimalExpressionSuite"
    ```
    and potentially affected tests:
    ```
    $ build/sbt "sql/testOnly org.apache.spark.sql.SQLQueryTestSuite"
    ```
    
    Closes #37343 from MaxGekk/array-as-query-context.
    
    Authored-by: Max Gekk <ma...@gmail.com>
    Signed-off-by: Max Gekk <ma...@gmail.com>
---
 .../scala/org/apache/spark/SparkException.scala    | 28 ++++-----
 .../spark/sql/catalyst/expressions/Cast.scala      | 67 +++++++++++-----------
 .../sql/catalyst/expressions/Expression.scala      | 10 ++++
 .../catalyst/expressions/aggregate/Average.scala   |  6 +-
 .../sql/catalyst/expressions/aggregate/Sum.scala   | 23 ++++----
 .../sql/catalyst/expressions/arithmetic.scala      | 48 +++++++++-------
 .../expressions/collectionOperations.scala         |  4 +-
 .../expressions/complexTypeExtractors.scala        |  8 +--
 .../catalyst/expressions/decimalExpressions.scala  | 32 ++++-------
 .../catalyst/expressions/intervalExpressions.scala | 16 +++---
 .../sql/catalyst/expressions/mathExpressions.scala |  2 +-
 .../catalyst/expressions/stringExpressions.scala   |  5 +-
 .../spark/sql/catalyst/util/DateTimeUtils.scala    | 10 ++--
 .../spark/sql/catalyst/util/IntervalUtils.scala    |  2 +-
 .../apache/spark/sql/catalyst/util/MathUtils.scala | 14 ++---
 .../spark/sql/catalyst/util/UTF8StringUtils.scala  | 10 ++--
 .../apache/spark/sql/errors/QueryErrorsBase.scala  |  9 ++-
 .../spark/sql/errors/QueryExecutionErrors.scala    | 54 ++++++++---------
 .../scala/org/apache/spark/sql/types/Decimal.scala |  4 +-
 .../expressions/DecimalExpressionSuite.scala       |  2 +-
 20 files changed, 182 insertions(+), 172 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/SparkException.scala b/core/src/main/scala/org/apache/spark/SparkException.scala
index d6add48ffb1..6548a114d41 100644
--- a/core/src/main/scala/org/apache/spark/SparkException.scala
+++ b/core/src/main/scala/org/apache/spark/SparkException.scala
@@ -119,7 +119,7 @@ private[spark] class SparkArithmeticException(
     errorClass: String,
     errorSubClass: Option[String] = None,
     messageParameters: Array[String],
-    context: Option[QueryContext],
+    context: Array[QueryContext],
     summary: String)
   extends ArithmeticException(
     SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters, summary))
@@ -128,7 +128,7 @@ private[spark] class SparkArithmeticException(
   override def getMessageParameters: Array[String] = messageParameters
   override def getErrorClass: String = errorClass
   override def getErrorSubClass: String = errorSubClass.orNull
-  override def getQueryContext: Array[QueryContext] = context.toArray
+  override def getQueryContext: Array[QueryContext] = context
 }
 
 /**
@@ -195,7 +195,7 @@ private[spark] class SparkDateTimeException(
     errorClass: String,
     errorSubClass: Option[String] = None,
     messageParameters: Array[String],
-    context: Option[QueryContext],
+    context: Array[QueryContext],
     summary: String)
   extends DateTimeException(
     SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters, summary))
@@ -204,7 +204,7 @@ private[spark] class SparkDateTimeException(
   override def getMessageParameters: Array[String] = messageParameters
   override def getErrorClass: String = errorClass
   override def getErrorSubClass: String = errorSubClass.orNull
-  override def getQueryContext: Array[QueryContext] = context.toArray
+  override def getQueryContext: Array[QueryContext] = context
 }
 
 /**
@@ -244,7 +244,7 @@ private[spark] class SparkNumberFormatException(
     errorClass: String,
     errorSubClass: Option[String] = None,
     messageParameters: Array[String],
-    context: Option[QueryContext],
+    context: Array[QueryContext],
     summary: String)
   extends NumberFormatException(
     SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters, summary))
@@ -253,7 +253,7 @@ private[spark] class SparkNumberFormatException(
   override def getMessageParameters: Array[String] = messageParameters
   override def getErrorClass: String = errorClass
   override def getErrorSubClass: String = errorSubClass.orNull
-  override def getQueryContext: Array[QueryContext] = context.toArray
+  override def getQueryContext: Array[QueryContext] = context
 }
 
 /**
@@ -323,7 +323,7 @@ private[spark] class SparkRuntimeException(
     errorSubClass: Option[String] = None,
     messageParameters: Array[String],
     cause: Throwable = null,
-    context: Option[QueryContext] = None,
+    context: Array[QueryContext] = Array.empty,
     summary: String = "")
   extends RuntimeException(
     SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters, summary),
@@ -334,7 +334,7 @@ private[spark] class SparkRuntimeException(
            errorSubClass: String,
            messageParameters: Array[String],
            cause: Throwable,
-           context: Option[QueryContext])
+           context: Array[QueryContext])
     = this(errorClass = errorClass,
     errorSubClass = Some(errorSubClass),
     messageParameters = messageParameters,
@@ -348,12 +348,12 @@ private[spark] class SparkRuntimeException(
     errorSubClass = Some(errorSubClass),
     messageParameters = messageParameters,
     cause = null,
-    context = None)
+    context = Array.empty[QueryContext])
 
   override def getMessageParameters: Array[String] = messageParameters
   override def getErrorClass: String = errorClass
   override def getErrorSubClass: String = errorSubClass.orNull
-  override def getQueryContext: Array[QueryContext] = context.toArray
+  override def getQueryContext: Array[QueryContext] = context
 }
 
 /**
@@ -379,7 +379,7 @@ private[spark] class SparkArrayIndexOutOfBoundsException(
     errorClass: String,
     errorSubClass: Option[String] = None,
     messageParameters: Array[String],
-    context: Option[QueryContext],
+    context: Array[QueryContext],
     summary: String)
   extends ArrayIndexOutOfBoundsException(
     SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters, summary))
@@ -388,7 +388,7 @@ private[spark] class SparkArrayIndexOutOfBoundsException(
   override def getMessageParameters: Array[String] = messageParameters
   override def getErrorClass: String = errorClass
   override def getErrorSubClass: String = errorSubClass.orNull
-  override def getQueryContext: Array[QueryContext] = context.toArray
+  override def getQueryContext: Array[QueryContext] = context
 }
 
 /**
@@ -420,7 +420,7 @@ private[spark] class SparkNoSuchElementException(
     errorClass: String,
     errorSubClass: Option[String] = None,
     messageParameters: Array[String],
-    context: Option[QueryContext],
+    context: Array[QueryContext],
     summary: String)
   extends NoSuchElementException(
     SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters, summary))
@@ -429,7 +429,7 @@ private[spark] class SparkNoSuchElementException(
   override def getMessageParameters: Array[String] = messageParameters
   override def getErrorClass: String = errorClass
   override def getErrorSubClass: String = errorSubClass.orNull
-  override def getQueryContext: Array[QueryContext] = context.toArray
+  override def getQueryContext: Array[QueryContext] = context
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 0ba651b5650..f740ecd9dcb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -653,7 +653,7 @@ case class Cast(
           false
         } else {
           if (ansiEnabled) {
-            throw QueryExecutionErrors.invalidInputSyntaxForBooleanError(s, queryContext)
+            throw QueryExecutionErrors.invalidInputSyntaxForBooleanError(s, getContextOrNull())
           } else {
             null
           }
@@ -685,7 +685,7 @@ case class Cast(
     case StringType =>
       buildCast[UTF8String](_, utfs => {
         if (ansiEnabled) {
-          DateTimeUtils.stringToTimestampAnsi(utfs, zoneId, queryContext)
+          DateTimeUtils.stringToTimestampAnsi(utfs, zoneId, getContextOrNull())
         } else {
           DateTimeUtils.stringToTimestamp(utfs, zoneId).orNull
         }
@@ -710,14 +710,14 @@ case class Cast(
     // TimestampWritable.doubleToTimestamp
     case DoubleType =>
       if (ansiEnabled) {
-        buildCast[Double](_, d => doubleToTimestampAnsi(d, queryContext))
+        buildCast[Double](_, d => doubleToTimestampAnsi(d, getContextOrNull()))
       } else {
         buildCast[Double](_, d => doubleToTimestamp(d))
       }
     // TimestampWritable.floatToTimestamp
     case FloatType =>
       if (ansiEnabled) {
-        buildCast[Float](_, f => doubleToTimestampAnsi(f.toDouble, queryContext))
+        buildCast[Float](_, f => doubleToTimestampAnsi(f.toDouble, getContextOrNull()))
       } else {
         buildCast[Float](_, f => doubleToTimestamp(f.toDouble))
       }
@@ -727,7 +727,7 @@ case class Cast(
     case StringType =>
       buildCast[UTF8String](_, utfs => {
         if (ansiEnabled) {
-          DateTimeUtils.stringToTimestampWithoutTimeZoneAnsi(utfs, queryContext)
+          DateTimeUtils.stringToTimestampWithoutTimeZoneAnsi(utfs, getContextOrNull())
         } else {
           DateTimeUtils.stringToTimestampWithoutTimeZone(utfs).orNull
         }
@@ -760,7 +760,7 @@ case class Cast(
   private[this] def castToDate(from: DataType): Any => Any = from match {
     case StringType =>
       if (ansiEnabled) {
-        buildCast[UTF8String](_, s => DateTimeUtils.stringToDateAnsi(s, queryContext))
+        buildCast[UTF8String](_, s => DateTimeUtils.stringToDateAnsi(s, getContextOrNull()))
       } else {
         buildCast[UTF8String](_, s => DateTimeUtils.stringToDate(s).orNull)
       }
@@ -817,7 +817,7 @@ case class Cast(
   // LongConverter
   private[this] def castToLong(from: DataType): Any => Any = from match {
     case StringType if ansiEnabled =>
-      buildCast[UTF8String](_, v => UTF8StringUtils.toLongExact(v, queryContext))
+      buildCast[UTF8String](_, v => UTF8StringUtils.toLongExact(v, getContextOrNull()))
     case StringType =>
       val result = new LongWrapper()
       buildCast[UTF8String](_, s => if (s.toLong(result)) result.value else null)
@@ -840,7 +840,7 @@ case class Cast(
   // IntConverter
   private[this] def castToInt(from: DataType): Any => Any = from match {
     case StringType if ansiEnabled =>
-      buildCast[UTF8String](_, v => UTF8StringUtils.toIntExact(v, queryContext))
+      buildCast[UTF8String](_, v => UTF8StringUtils.toIntExact(v, getContextOrNull()))
     case StringType =>
       val result = new IntWrapper()
       buildCast[UTF8String](_, s => if (s.toInt(result)) result.value else null)
@@ -872,7 +872,7 @@ case class Cast(
   // ShortConverter
   private[this] def castToShort(from: DataType): Any => Any = from match {
     case StringType if ansiEnabled =>
-      buildCast[UTF8String](_, v => UTF8StringUtils.toShortExact(v, queryContext))
+      buildCast[UTF8String](_, v => UTF8StringUtils.toShortExact(v, getContextOrNull()))
     case StringType =>
       val result = new IntWrapper()
       buildCast[UTF8String](_, s => if (s.toShort(result)) {
@@ -919,7 +919,7 @@ case class Cast(
   // ByteConverter
   private[this] def castToByte(from: DataType): Any => Any = from match {
     case StringType if ansiEnabled =>
-      buildCast[UTF8String](_, v => UTF8StringUtils.toByteExact(v, queryContext))
+      buildCast[UTF8String](_, v => UTF8StringUtils.toByteExact(v, getContextOrNull()))
     case StringType =>
       val result = new IntWrapper()
       buildCast[UTF8String](_, s => if (s.toByte(result)) {
@@ -986,7 +986,7 @@ case class Cast(
         null
       } else {
         throw QueryExecutionErrors.cannotChangeDecimalPrecisionError(
-          value, decimalType.precision, decimalType.scale, queryContext)
+          value, decimalType.precision, decimalType.scale, getContextOrNull())
       }
     }
   }
@@ -999,7 +999,7 @@ case class Cast(
   private[this] def toPrecision(
       value: Decimal,
       decimalType: DecimalType,
-      context: Option[SQLQueryContext]): Decimal =
+      context: SQLQueryContext): Decimal =
     value.toPrecision(
       decimalType.precision, decimalType.scale, Decimal.ROUND_HALF_UP, !ansiEnabled, context)
 
@@ -1012,17 +1012,17 @@ case class Cast(
       })
     case StringType if ansiEnabled =>
       buildCast[UTF8String](_,
-        s => changePrecision(Decimal.fromStringANSI(s, target, queryContext), target))
+        s => changePrecision(Decimal.fromStringANSI(s, target, getContextOrNull()), target))
     case BooleanType =>
       buildCast[Boolean](_,
-        b => toPrecision(if (b) Decimal.ONE else Decimal.ZERO, target, queryContext))
+        b => toPrecision(if (b) Decimal.ONE else Decimal.ZERO, target, getContextOrNull()))
     case DateType =>
       buildCast[Int](_, d => null) // date can't cast to decimal in Hive
     case TimestampType =>
       // Note that we lose precision here.
       buildCast[Long](_, t => changePrecision(Decimal(timestampToDouble(t)), target))
     case dt: DecimalType =>
-      b => toPrecision(b.asInstanceOf[Decimal], target, queryContext)
+      b => toPrecision(b.asInstanceOf[Decimal], target, getContextOrNull())
     case t: IntegralType =>
       b => changePrecision(Decimal(t.integral.asInstanceOf[Integral[Any]].toLong(b)), target)
     case x: FractionalType =>
@@ -1055,7 +1055,7 @@ case class Cast(
             val d = Cast.processFloatingPointSpecialLiterals(doubleStr, false)
             if(ansiEnabled && d == null) {
               throw QueryExecutionErrors.invalidInputInCastToNumberError(
-                DoubleType, s, queryContext)
+                DoubleType, s, getContextOrNull())
             } else {
               d
             }
@@ -1081,7 +1081,7 @@ case class Cast(
             val f = Cast.processFloatingPointSpecialLiterals(floatStr, true)
             if (ansiEnabled && f == null) {
               throw QueryExecutionErrors.invalidInputInCastToNumberError(
-                FloatType, s, queryContext)
+                FloatType, s, getContextOrNull())
             } else {
               f
             }
@@ -1196,10 +1196,6 @@ case class Cast(
     }
   }
 
-  def errorContextCode(codegenContext: CodegenContext): String = {
-    codegenContext.addReferenceObj("errCtx", queryContext)
-  }
-
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val eval = child.genCode(ctx)
     val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx)
@@ -1512,7 +1508,7 @@ case class Cast(
         val intOpt = ctx.freshVariable("intOpt", classOf[Option[Integer]])
         (c, evPrim, evNull) =>
           if (ansiEnabled) {
-            val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+            val errorContext = getContextOrNullCode(ctx)
             code"""
               $evPrim = $dateTimeUtilsCls.stringToDateAnsi($c, $errorContext);
             """
@@ -1556,12 +1552,13 @@ case class Cast(
          |$evPrim = $d;
        """.stripMargin
     } else {
+      val errorContextCode = getContextOrNullCode(ctx, !nullOnOverflow)
       val overflowCode = if (nullOnOverflow) {
         s"$evNull = true;"
       } else {
         s"""
            |throw QueryExecutionErrors.cannotChangeDecimalPrecisionError(
-           |  $d, ${decimalType.precision}, ${decimalType.scale}, ${errorContextCode(ctx)});
+           |  $d, ${decimalType.precision}, ${decimalType.scale}, $errorContextCode);
          """.stripMargin
       }
       code"""
@@ -1602,7 +1599,7 @@ case class Cast(
               }
           """
       case StringType if ansiEnabled =>
-        val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+        val errorContext = getContextOrNullCode(ctx)
         val toType = ctx.addReferenceObj("toType", target)
         (c, evPrim, evNull) =>
           code"""
@@ -1679,7 +1676,7 @@ case class Cast(
       val longOpt = ctx.freshVariable("longOpt", classOf[Option[Long]])
       (c, evPrim, evNull) =>
         if (ansiEnabled) {
-          val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+          val errorContext = getContextOrNullCode(ctx)
           code"""
             $evPrim = $dateTimeUtilsCls.stringToTimestampAnsi($c, $zid, $errorContext);
            """
@@ -1718,7 +1715,7 @@ case class Cast(
     case DoubleType =>
       (c, evPrim, evNull) =>
         if (ansiEnabled) {
-          val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+          val errorContext = getContextOrNullCode(ctx)
           code"$evPrim = $dateTimeUtilsCls.doubleToTimestampAnsi($c, $errorContext);"
         } else {
           code"""
@@ -1732,7 +1729,7 @@ case class Cast(
     case FloatType =>
       (c, evPrim, evNull) =>
         if (ansiEnabled) {
-          val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+          val errorContext = getContextOrNullCode(ctx)
           code"$evPrim = $dateTimeUtilsCls.doubleToTimestampAnsi((double)$c, $errorContext);"
         } else {
           code"""
@@ -1752,7 +1749,7 @@ case class Cast(
       val longOpt = ctx.freshVariable("longOpt", classOf[Option[Long]])
       (c, evPrim, evNull) =>
         if (ansiEnabled) {
-          val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+          val errorContext = getContextOrNullCode(ctx)
           code"""
             $evPrim = $dateTimeUtilsCls.stringToTimestampWithoutTimeZoneAnsi($c, $errorContext);
            """
@@ -1869,7 +1866,7 @@ case class Cast(
       val stringUtils = inline"${StringUtils.getClass.getName.stripSuffix("$")}"
       (c, evPrim, evNull) =>
         val castFailureCode = if (ansiEnabled) {
-          val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+          val errorContext = getContextOrNullCode(ctx)
           s"throw QueryExecutionErrors.invalidInputSyntaxForBooleanError($c, $errorContext);"
         } else {
           s"$evNull = true;"
@@ -2004,7 +2001,7 @@ case class Cast(
   private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
     case StringType if ansiEnabled =>
       val stringUtils = UTF8StringUtils.getClass.getCanonicalName.stripSuffix("$")
-      val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+      val errorContext = getContextOrNullCode(ctx)
       (c, evPrim, evNull) => code"$evPrim = $stringUtils.toByteExact($c, $errorContext);"
     case StringType =>
       val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper])
@@ -2041,7 +2038,7 @@ case class Cast(
       ctx: CodegenContext): CastFunction = from match {
     case StringType if ansiEnabled =>
       val stringUtils = UTF8StringUtils.getClass.getCanonicalName.stripSuffix("$")
-      val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+      val errorContext = getContextOrNullCode(ctx)
       (c, evPrim, evNull) => code"$evPrim = $stringUtils.toShortExact($c, $errorContext);"
     case StringType =>
       val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper])
@@ -2076,7 +2073,7 @@ case class Cast(
   private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
     case StringType if ansiEnabled =>
       val stringUtils = UTF8StringUtils.getClass.getCanonicalName.stripSuffix("$")
-      val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+      val errorContext = getContextOrNullCode(ctx)
       (c, evPrim, evNull) => code"$evPrim = $stringUtils.toIntExact($c, $errorContext);"
     case StringType =>
       val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper])
@@ -2111,7 +2108,7 @@ case class Cast(
   private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
     case StringType if ansiEnabled =>
       val stringUtils = UTF8StringUtils.getClass.getCanonicalName.stripSuffix("$")
-      val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+      val errorContext = getContextOrNullCode(ctx)
       (c, evPrim, evNull) => code"$evPrim = $stringUtils.toLongExact($c, $errorContext);"
     case StringType =>
       val wrapper = ctx.freshVariable("longWrapper", classOf[UTF8String.LongWrapper])
@@ -2148,7 +2145,7 @@ case class Cast(
         val floatStr = ctx.freshVariable("floatStr", StringType)
         (c, evPrim, evNull) =>
           val handleNull = if (ansiEnabled) {
-            val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+            val errorContext = getContextOrNullCode(ctx)
             s"throw QueryExecutionErrors.invalidInputInCastToNumberError(" +
               s"org.apache.spark.sql.types.FloatType$$.MODULE$$,$c, $errorContext);"
           } else {
@@ -2186,7 +2183,7 @@ case class Cast(
         val doubleStr = ctx.freshVariable("doubleStr", StringType)
         (c, evPrim, evNull) =>
           val handleNull = if (ansiEnabled) {
-            val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+            val errorContext = getContextOrNullCode(ctx)
             s"throw QueryExecutionErrors.invalidInputInCastToNumberError(" +
               s"org.apache.spark.sql.types.DoubleType$$.MODULE$$, $c, $errorContext);"
           } else {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index d623357b9da..261d9a0cb63 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -597,6 +597,16 @@ trait SupportQueryContext extends Expression with Serializable {
 
   def initQueryContext(): Option[SQLQueryContext]
 
+  def getContextOrNull(): SQLQueryContext = queryContext.getOrElse(null)
+
+  def getContextOrNullCode(ctx: CodegenContext, withErrorContext: Boolean = true): String = {
+    if (withErrorContext && queryContext.isDefined) {
+      ctx.addReferenceObj("errCtx", queryContext.get)
+    } else {
+      "null"
+    }
+  }
+
   // Note: Even though query contexts are serialized to executors, it will be regenerated from an
   //       empty "Origin" during rule transforms since "Origin"s are not serialized to executors
   //       for better performance. Thus, we need to copy the original query context during
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index b749dfdaea1..36ffcd8f764 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -86,7 +86,7 @@ abstract class AverageBase
 
   // If all input are nulls, count will be 0 and we will get null after the division.
   // We can't directly use `/` as it throws an exception under ansi mode.
-  protected def getEvaluateExpression(context: Option[SQLQueryContext]) = child.dataType match {
+  protected def getEvaluateExpression(context: SQLQueryContext = null) = child.dataType match {
     case _: DecimalType =>
       If(EqualTo(count, Literal(0L)),
         Literal(null, resultType),
@@ -141,7 +141,7 @@ case class Average(
 
   override lazy val mergeExpressions: Seq[Expression] = getMergeExpressions
 
-  override lazy val evaluateExpression: Expression = getEvaluateExpression(queryContext)
+  override lazy val evaluateExpression: Expression = getEvaluateExpression(getContextOrNull())
 
   override def initQueryContext(): Option[SQLQueryContext] = if (useAnsiAdd) {
     Some(origin.context)
@@ -206,7 +206,7 @@ case class TryAverage(child: Expression) extends AverageBase {
   }
 
   override lazy val evaluateExpression: Expression = {
-    addTryEvalIfNeeded(getEvaluateExpression(None))
+    addTryEvalIfNeeded(getEvaluateExpression())
   }
 
   override protected def withNewChildInternal(newChild: Expression): Expression =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index 9230bd9bf44..e8492c0e5dc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -148,14 +148,15 @@ abstract class SumBase(child: Expression) extends DeclarativeAggregate
    * So now, if ansi is enabled, then throw exception, if not then return null.
    * If sum is not null, then return the sum.
    */
-  protected def getEvaluateExpression(
-      context: Option[SQLQueryContext]): Expression = resultType match {
-    case d: DecimalType =>
-      val checkOverflowInSum = CheckOverflowInSum(sum, d, !useAnsiAdd, context)
-      If(isEmpty, Literal.create(null, resultType), checkOverflowInSum)
-    case _ if shouldTrackIsEmpty =>
-      If(isEmpty, Literal.create(null, resultType), sum)
-    case _ => sum
+  protected def getEvaluateExpression(context: SQLQueryContext = null): Expression = {
+    resultType match {
+      case d: DecimalType =>
+        val checkOverflowInSum = CheckOverflowInSum(sum, d, !useAnsiAdd, context)
+        If(isEmpty, Literal.create(null, resultType), checkOverflowInSum)
+      case _ if shouldTrackIsEmpty =>
+        If(isEmpty, Literal.create(null, resultType), sum)
+      case _ => sum
+    }
   }
 
   // The flag `useAnsiAdd` won't be shown in the `toString` or `toAggString` methods
@@ -192,7 +193,7 @@ case class Sum(
 
   override lazy val mergeExpressions: Seq[Expression] = getMergeExpressions
 
-  override lazy val evaluateExpression: Expression = getEvaluateExpression(queryContext)
+  override lazy val evaluateExpression: Expression = getEvaluateExpression(getContextOrNull())
 
   override def initQueryContext(): Option[SQLQueryContext] = if (useAnsiAdd) {
     Some(origin.context)
@@ -255,9 +256,9 @@ case class TrySum(child: Expression) extends SumBase(child) {
 
   override lazy val evaluateExpression: Expression =
     if (useAnsiAdd) {
-      TryEval(getEvaluateExpression(None))
+      TryEval(getEvaluateExpression())
     } else {
-      getEvaluateExpression(None)
+      getEvaluateExpression()
     }
 
   override protected def withNewChildInternal(newChild: Expression): Expression =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 7bbe5d15b91..86e6e6d7323 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -265,7 +265,7 @@ abstract class BinaryArithmetic extends BinaryOperator
   }
 
   protected def checkDecimalOverflow(value: Decimal, precision: Int, scale: Int): Decimal = {
-    value.toPrecision(precision, scale, Decimal.ROUND_HALF_UP, !failOnError, queryContext)
+    value.toPrecision(precision, scale, Decimal.ROUND_HALF_UP, !failOnError, getContextOrNull())
   }
 
   /** Name of the function for this expression on a [[Decimal]] type. */
@@ -285,11 +285,7 @@ abstract class BinaryArithmetic extends BinaryOperator
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
     case DecimalType.Fixed(precision, scale) =>
-      val errorContextCode = if (failOnError) {
-        ctx.addReferenceObj("errCtx", queryContext)
-      } else {
-        "scala.None$.MODULE$"
-      }
+      val errorContextCode = getContextOrNullCode(ctx, failOnError)
       val updateIsNull = if (failOnError) {
         ""
       } else {
@@ -334,7 +330,7 @@ abstract class BinaryArithmetic extends BinaryOperator
       })
     case IntegerType | LongType if failOnError && exactMathMethod.isDefined =>
       nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
-        val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+        val errorContext = getContextOrNullCode(ctx)
         val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$")
         s"""
            |${ev.value} = $mathUtils.${exactMathMethod.get}($eval1, $eval2, $errorContext);
@@ -414,9 +410,9 @@ case class Add(
     case _: YearMonthIntervalType =>
       MathUtils.addExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int])
     case _: IntegerType if failOnError =>
-      MathUtils.addExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int], queryContext)
+      MathUtils.addExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int], getContextOrNull())
     case _: LongType if failOnError =>
-      MathUtils.addExact(input1.asInstanceOf[Long], input2.asInstanceOf[Long], queryContext)
+      MathUtils.addExact(input1.asInstanceOf[Long], input2.asInstanceOf[Long], getContextOrNull())
     case _ => numeric.plus(input1, input2)
   }
 
@@ -483,9 +479,15 @@ case class Subtract(
     case _: YearMonthIntervalType =>
       MathUtils.subtractExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int])
     case _: IntegerType if failOnError =>
-      MathUtils.subtractExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int], queryContext)
+      MathUtils.subtractExact(
+        input1.asInstanceOf[Int],
+        input2.asInstanceOf[Int],
+        getContextOrNull())
     case _: LongType if failOnError =>
-      MathUtils.subtractExact(input1.asInstanceOf[Long], input2.asInstanceOf[Long], queryContext)
+      MathUtils.subtractExact(
+        input1.asInstanceOf[Long],
+        input2.asInstanceOf[Long],
+        getContextOrNull())
     case _ => numeric.minus(input1, input2)
   }
 
@@ -539,9 +541,15 @@ case class Multiply(
     case DecimalType.Fixed(precision, scale) =>
       checkDecimalOverflow(numeric.times(input1, input2).asInstanceOf[Decimal], precision, scale)
     case _: IntegerType if failOnError =>
-      MathUtils.multiplyExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int], queryContext)
+      MathUtils.multiplyExact(
+        input1.asInstanceOf[Int],
+        input2.asInstanceOf[Int],
+        getContextOrNull())
     case _: LongType if failOnError =>
-      MathUtils.multiplyExact(input1.asInstanceOf[Long], input2.asInstanceOf[Long], queryContext)
+      MathUtils.multiplyExact(
+        input1.asInstanceOf[Long],
+        input2.asInstanceOf[Long],
+        getContextOrNull())
     case _ => numeric.times(input1, input2)
   }
 
@@ -578,10 +586,10 @@ trait DivModLike extends BinaryArithmetic {
       } else {
         if (isZero(input2)) {
           // when we reach here, failOnError must be true.
-          throw QueryExecutionErrors.divideByZeroError(queryContext)
+          throw QueryExecutionErrors.divideByZeroError(getContextOrNull())
         }
         if (checkDivideOverflow && input1 == Long.MinValue && input2 == -1) {
-          throw QueryExecutionErrors.overflowInIntegralDivideError(queryContext)
+          throw QueryExecutionErrors.overflowInIntegralDivideError(getContextOrNull())
         }
         evalOperation(input1, input2)
       }
@@ -603,11 +611,7 @@ trait DivModLike extends BinaryArithmetic {
       s"${eval2.value} == 0"
     }
     val javaType = CodeGenerator.javaType(dataType)
-    val errorContextCode = if (failOnError) {
-      ctx.addReferenceObj("errCtx", queryContext)
-    } else {
-      "scala.None$.MODULE$"
-    }
+    val errorContextCode = getContextOrNullCode(ctx, failOnError)
     val operation = super.dataType match {
       case DecimalType.Fixed(precision, scale) =>
         val decimalValue = ctx.freshName("decimalValue")
@@ -962,7 +966,7 @@ case class Pmod(
       } else {
         if (isZero(input2)) {
          // when we reach here, failOnError must be true.
-          throw QueryExecutionErrors.divideByZeroError(queryContext)
+          throw QueryExecutionErrors.divideByZeroError(getContextOrNull())
         }
         pmodFunc(input1, input2)
       }
@@ -979,7 +983,7 @@ case class Pmod(
     }
     val remainder = ctx.freshName("remainder")
     val javaType = CodeGenerator.javaType(dataType)
-    val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+    val errorContext = getContextOrNullCode(ctx)
     val result = dataType match {
       case DecimalType.Fixed(precision, scale) =>
         val decimalAdd = "$plus"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 098b3a88084..ae23775b62d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -2174,7 +2174,7 @@ case class ElementAt(
         if (array.numElements() < math.abs(index)) {
           if (failOnError) {
             throw QueryExecutionErrors.invalidElementAtIndexError(
-              index, array.numElements(), queryContext)
+              index, array.numElements(), getContextOrNull())
           } else {
             defaultValueOutOfBound match {
               case Some(value) => value.eval()
@@ -2216,7 +2216,7 @@ case class ElementAt(
           }
 
           val indexOutOfBoundBranch = if (failOnError) {
-            val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+            val errorContext = getContextOrNullCode(ctx)
             // scalastyle:off line.size.limit
             s"throw QueryExecutionErrors.invalidElementAtIndexError($index, $eval1.numElements(), $errorContext);"
             // scalastyle:on line.size.limit
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index fedfcfb978f..b6cbb1d0005 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -268,7 +268,7 @@ case class GetArrayItem(
     if (index >= baseValue.numElements() || index < 0) {
       if (failOnError) {
         throw QueryExecutionErrors.invalidArrayIndexError(
-          index, baseValue.numElements, queryContext)
+          index, baseValue.numElements, getContextOrNull())
       } else {
         null
       }
@@ -292,7 +292,7 @@ case class GetArrayItem(
       }
 
       val indexOutOfBoundBranch = if (failOnError) {
-        val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+        val errorContext = getContextOrNullCode(ctx)
         // scalastyle:off line.size.limit
         s"throw QueryExecutionErrors.invalidArrayIndexError($index, $eval1.numElements(), $errorContext);"
         // scalastyle:on line.size.limit
@@ -380,7 +380,7 @@ trait GetMapValueUtil
 
     if (!found) {
       if (failOnError) {
-        throw QueryExecutionErrors.mapKeyNotExistError(ordinal, keyType, queryContext)
+        throw QueryExecutionErrors.mapKeyNotExistError(ordinal, keyType, getContextOrNull())
       } else {
         null
       }
@@ -413,7 +413,7 @@ trait GetMapValueUtil
     }
 
     val keyJavaType = CodeGenerator.javaType(keyType)
-    lazy val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+    lazy val errorContext = getContextOrNullCode(ctx)
     val keyDt = ctx.addReferenceObj("keyType", keyType, keyType.getClass.getName)
     nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
       val keyNotFoundBranch = if (failOnError) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
index e672fffda19..37e3dd5ea89 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
@@ -123,14 +123,10 @@ case class CheckOverflow(
       dataType.scale,
       Decimal.ROUND_HALF_UP,
       nullOnOverflow,
-      queryContext)
+      getContextOrNull())
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val errorContextCode = if (nullOnOverflow) {
-      "scala.None$.MODULE$"
-    } else {
-      ctx.addReferenceObj("errCtx", queryContext)
-    }
+    val errorContextCode = getContextOrNullCode(ctx, !nullOnOverflow)
     nullSafeCodeGen(ctx, ev, eval => {
       // scalastyle:off line.size.limit
       s"""
@@ -161,7 +157,7 @@ case class CheckOverflowInSum(
     child: Expression,
     dataType: DecimalType,
     nullOnOverflow: Boolean,
-    context: Option[SQLQueryContext] = None) extends UnaryExpression {
+    context: SQLQueryContext) extends UnaryExpression with SupportQueryContext {
 
   override def nullable: Boolean = true
 
@@ -182,11 +178,7 @@ case class CheckOverflowInSum(
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val childGen = child.genCode(ctx)
-    val errorContextCode = if (nullOnOverflow) {
-      "scala.None$.MODULE$"
-    } else {
-      ctx.addReferenceObj("errCtx", context)
-    }
+    val errorContextCode = getContextOrNullCode(ctx, !nullOnOverflow)
     val nullHandling = if (nullOnOverflow) {
       ""
     } else {
@@ -216,6 +208,8 @@ case class CheckOverflowInSum(
 
   override protected def withNewChildInternal(newChild: Expression): CheckOverflowInSum =
     copy(child = newChild)
+
+  override def initQueryContext(): Option[SQLQueryContext] = Option(context)
 }
 
 /**
@@ -261,12 +255,12 @@ case class DecimalDivideWithOverflowCheck(
     left: Expression,
     right: Expression,
     override val dataType: DecimalType,
-    context: Option[SQLQueryContext],
+    context: SQLQueryContext,
     nullOnOverflow: Boolean)
   extends BinaryExpression with ExpectsInputTypes with SupportQueryContext {
   override def nullable: Boolean = nullOnOverflow
   override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType, DecimalType)
-  override def initQueryContext(): Option[SQLQueryContext] = context
+  override def initQueryContext(): Option[SQLQueryContext] = Option(context)
   def decimalMethod: String = "$div"
 
   override def eval(input: InternalRow): Any = {
@@ -275,22 +269,18 @@ case class DecimalDivideWithOverflowCheck(
       if (nullOnOverflow)  {
         null
       } else {
-        throw QueryExecutionErrors.overflowInSumOfDecimalError(queryContext)
+        throw QueryExecutionErrors.overflowInSumOfDecimalError(getContextOrNull())
       }
     } else {
       val value2 = right.eval(input)
       dataType.fractional.asInstanceOf[Fractional[Any]].div(value1, value2).asInstanceOf[Decimal]
         .toPrecision(dataType.precision, dataType.scale, Decimal.ROUND_HALF_UP, nullOnOverflow,
-          queryContext)
+          getContextOrNull())
     }
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val errorContextCode = if (nullOnOverflow) {
-      "scala.None$.MODULE$"
-    } else {
-      ctx.addReferenceObj("errCtx", queryContext)
-    }
+    val errorContextCode = getContextOrNullCode(ctx, !nullOnOverflow)
     val nullHandling = if (nullOnOverflow) {
       ""
     } else {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
index 17a2714c611..f7ec82de11b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
@@ -604,7 +604,7 @@ trait IntervalDivide {
       minValue: Any,
       num: Expression,
       numValue: Any,
-      context: Option[SQLQueryContext]): Unit = {
+      context: SQLQueryContext): Unit = {
     if (value == minValue && num.dataType.isInstanceOf[IntegralType]) {
       if (numValue.asInstanceOf[Number].longValue() == -1) {
         throw QueryExecutionErrors.overflowInIntegralDivideError(context)
@@ -615,7 +615,7 @@ trait IntervalDivide {
   def divideByZeroCheck(
       dataType: DataType,
       num: Any,
-      context: Option[SQLQueryContext]): Unit = dataType match {
+      context: SQLQueryContext): Unit = dataType match {
     case _: DecimalType =>
       if (num.asInstanceOf[Decimal].isZero) {
         throw QueryExecutionErrors.intervalDividedByZeroError(context)
@@ -665,13 +665,13 @@ case class DivideYMInterval(
 
   override def nullSafeEval(interval: Any, num: Any): Any = {
     checkDivideOverflow(
-      interval.asInstanceOf[Int], Int.MinValue, right, num, Some(origin.context))
-    divideByZeroCheck(right.dataType, num, Some(origin.context))
+      interval.asInstanceOf[Int], Int.MinValue, right, num, origin.context)
+    divideByZeroCheck(right.dataType, num, origin.context)
     evalFunc(interval.asInstanceOf[Int], num)
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val errorContext = ctx.addReferenceObj("errCtx", Some(origin.context))
+    val errorContext = ctx.addReferenceObj("errCtx", origin.context)
     right.dataType match {
       case t: IntegralType =>
         val math = t match {
@@ -743,13 +743,13 @@ case class DivideDTInterval(
 
   override def nullSafeEval(interval: Any, num: Any): Any = {
     checkDivideOverflow(
-      interval.asInstanceOf[Long], Long.MinValue, right, num, Some(origin.context))
-    divideByZeroCheck(right.dataType, num, Some(origin.context))
+      interval.asInstanceOf[Long], Long.MinValue, right, num, origin.context)
+    divideByZeroCheck(right.dataType, num, origin.context)
     evalFunc(interval.asInstanceOf[Long], num)
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val errorContext = ctx.addReferenceObj("errCtx", Some(origin.context))
+    val errorContext = ctx.addReferenceObj("errCtx", origin.context)
     right.dataType match {
       case _: IntegralType =>
         val math = classOf[LongMath].getName
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index 55ff36e9863..dfbc041b259 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -1520,7 +1520,7 @@ abstract class RoundBase(child: Expression, scale: Expression,
         if (_scale >= 0) {
           s"""
             ${ev.value} = ${ce.value}.toPrecision(${ce.value}.precision(), $s,
-            Decimal.$modeStr(), true, scala.None$$.MODULE$$);
+            Decimal.$modeStr(), true, null);
             ${ev.isNull} = ${ev.value} == null;"""
        } else {
           s"""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 815eb8977b6..d4504c36e4e 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -296,7 +296,8 @@ case class Elt(
       val index = indexObj.asInstanceOf[Int]
       if (index <= 0 || index > inputExprs.length) {
         if (failOnError) {
-          throw QueryExecutionErrors.invalidArrayIndexError(index, inputExprs.length, queryContext)
+          throw QueryExecutionErrors.invalidArrayIndexError(
+            index, inputExprs.length, getContextOrNull())
         } else {
           null
         }
@@ -348,7 +349,7 @@ case class Elt(
       }.mkString)
 
     val indexOutOfBoundBranch = if (failOnError) {
-      val errorContext = ctx.addReferenceObj("errCtx", queryContext)
+      val errorContext = getContextOrNullCode(ctx)
       // scalastyle:off line.size.limit
       s"""
          |if (!$indexMatched) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index 172c2e54034..af0666a98fa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -468,14 +468,14 @@ object DateTimeUtils {
   def stringToTimestampAnsi(
       s: UTF8String,
       timeZoneId: ZoneId,
-      context: Option[SQLQueryContext] = None): Long = {
+      context: SQLQueryContext = null): Long = {
     stringToTimestamp(s, timeZoneId).getOrElse {
       throw QueryExecutionErrors.invalidInputInCastToDatetimeError(
         s, StringType, TimestampType, context)
     }
   }
 
-  def doubleToTimestampAnsi(d: Double, context: Option[SQLQueryContext]): Long = {
+  def doubleToTimestampAnsi(d: Double, context: SQLQueryContext): Long = {
     if (d.isNaN || d.isInfinite) {
       throw QueryExecutionErrors.invalidInputInCastToDatetimeError(
         d, DoubleType, TimestampType, context)
@@ -527,7 +527,7 @@ object DateTimeUtils {
 
   def stringToTimestampWithoutTimeZoneAnsi(
       s: UTF8String,
-      context: Option[SQLQueryContext]): Long = {
+      context: SQLQueryContext): Long = {
     stringToTimestampWithoutTimeZone(s, true).getOrElse {
       throw QueryExecutionErrors.invalidInputInCastToDatetimeError(
         s, StringType, TimestampNTZType, context)
@@ -646,7 +646,9 @@ object DateTimeUtils {
     }
   }
 
-  def stringToDateAnsi(s: UTF8String, context: Option[SQLQueryContext] = None): Int = {
+  def stringToDateAnsi(
+      s: UTF8String,
+      context: SQLQueryContext = null): Int = {
     stringToDate(s).getOrElse {
       throw QueryExecutionErrors.invalidInputInCastToDatetimeError(
         s, StringType, DateType, context)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
index b4695062c08..f2c4236ad7b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
@@ -733,7 +733,7 @@ object IntervalUtils {
    * @throws ArithmeticException if the result overflows any field value or divided by zero
    */
   def divideExact(interval: CalendarInterval, num: Double): CalendarInterval = {
-    if (num == 0) throw QueryExecutionErrors.intervalDividedByZeroError(None)
+    if (num == 0) throw QueryExecutionErrors.intervalDividedByZeroError(null)
     fromDoubles(interval.months / num, interval.days / num, interval.microseconds / num)
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala
index 6cb3616d4e7..e79e483076d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala
@@ -27,37 +27,37 @@ object MathUtils {
 
   def addExact(a: Int, b: Int): Int = withOverflow(Math.addExact(a, b))
 
-  def addExact(a: Int, b: Int, context: Option[SQLQueryContext]): Int = {
+  def addExact(a: Int, b: Int, context: SQLQueryContext): Int = {
     withOverflow(Math.addExact(a, b), hint = "try_add", context)
   }
 
   def addExact(a: Long, b: Long): Long = withOverflow(Math.addExact(a, b))
 
-  def addExact(a: Long, b: Long, context: Option[SQLQueryContext]): Long = {
+  def addExact(a: Long, b: Long, context: SQLQueryContext): Long = {
     withOverflow(Math.addExact(a, b), hint = "try_add", context)
   }
 
   def subtractExact(a: Int, b: Int): Int = withOverflow(Math.subtractExact(a, b))
 
-  def subtractExact(a: Int, b: Int, context: Option[SQLQueryContext]): Int = {
+  def subtractExact(a: Int, b: Int, context: SQLQueryContext): Int = {
     withOverflow(Math.subtractExact(a, b), hint = "try_subtract", context)
   }
 
   def subtractExact(a: Long, b: Long): Long = withOverflow(Math.subtractExact(a, b))
 
-  def subtractExact(a: Long, b: Long, context: Option[SQLQueryContext]): Long = {
+  def subtractExact(a: Long, b: Long, context: SQLQueryContext): Long = {
     withOverflow(Math.subtractExact(a, b), hint = "try_subtract", context)
   }
 
   def multiplyExact(a: Int, b: Int): Int = withOverflow(Math.multiplyExact(a, b))
 
-  def multiplyExact(a: Int, b: Int, context: Option[SQLQueryContext]): Int = {
+  def multiplyExact(a: Int, b: Int, context: SQLQueryContext): Int = {
     withOverflow(Math.multiplyExact(a, b), hint = "try_multiply", context)
   }
 
   def multiplyExact(a: Long, b: Long): Long = withOverflow(Math.multiplyExact(a, b))
 
-  def multiplyExact(a: Long, b: Long, context: Option[SQLQueryContext]): Long = {
+  def multiplyExact(a: Long, b: Long, context: SQLQueryContext): Long = {
     withOverflow(Math.multiplyExact(a, b), hint = "try_multiply", context)
   }
 
@@ -78,7 +78,7 @@ object MathUtils {
   private def withOverflow[A](
       f: => A,
       hint: String = "",
-      context: Option[SQLQueryContext] = None): A = {
+      context: SQLQueryContext = null): A = {
     try {
       f
     } catch {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala
index 503c0e181ca..f7800469c35 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala
@@ -27,21 +27,21 @@ import org.apache.spark.unsafe.types.UTF8String
  */
 object UTF8StringUtils {
 
-  def toLongExact(s: UTF8String, context: Option[SQLQueryContext]): Long =
+  def toLongExact(s: UTF8String, context: SQLQueryContext): Long =
     withException(s.toLongExact, context, LongType, s)
 
-  def toIntExact(s: UTF8String, context: Option[SQLQueryContext]): Int =
+  def toIntExact(s: UTF8String, context: SQLQueryContext): Int =
     withException(s.toIntExact, context, IntegerType, s)
 
-  def toShortExact(s: UTF8String, context: Option[SQLQueryContext]): Short =
+  def toShortExact(s: UTF8String, context: SQLQueryContext): Short =
     withException(s.toShortExact, context, ShortType, s)
 
-  def toByteExact(s: UTF8String, context: Option[SQLQueryContext]): Byte =
+  def toByteExact(s: UTF8String, context: SQLQueryContext): Byte =
     withException(s.toByteExact, context, ByteType, s)
 
   private def withException[A](
       f: => A,
-      context: Option[SQLQueryContext],
+      context: SQLQueryContext,
       to: DataType,
       s: UTF8String): A = {
     try {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryErrorsBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryErrorsBase.scala
index 9617f7d4b0f..4785073f80b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryErrorsBase.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryErrorsBase.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.errors
 
 import java.util.Locale
 
+import org.apache.spark.QueryContext
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
 import org.apache.spark.sql.catalyst.trees.SQLQueryContext
@@ -97,7 +98,11 @@ private[sql] trait QueryErrorsBase {
     quoteByDefault(toPrettySQL(e))
   }
 
-  def getSummary(context: Option[SQLQueryContext]): String = {
-    context.map(_.summary).getOrElse("")
+  def getSummary(sqlContext: SQLQueryContext): String = {
+    if (sqlContext == null) "" else sqlContext.summary
+  }
+
+  def getQueryContext(sqlContext: SQLQueryContext): Array[QueryContext] = {
+    if (sqlContext == null) Array.empty else Array(sqlContext.asInstanceOf[QueryContext])
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index bad95afa139..3644e7c0df8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -89,7 +89,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
         toSQLType(from),
         toSQLType(to),
         toSQLConf(SQLConf.ANSI_ENABLED.key)),
-      context = None,
+      context = Array.empty,
       summary = "")
   }
 
@@ -103,7 +103,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
         toSQLType(from),
         toSQLType(to),
         toSQLId(columnName)),
-      context = None,
+      context = Array.empty,
       summary = ""
     )
   }
@@ -112,7 +112,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
       value: Decimal,
       decimalPrecision: Int,
       decimalScale: Int,
-      context: Option[SQLQueryContext] = None): ArithmeticException = {
+      context: SQLQueryContext = null): ArithmeticException = {
     new SparkArithmeticException(
       errorClass = "CANNOT_CHANGE_DECIMAL_PRECISION",
       messageParameters = Array(
@@ -120,7 +120,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
         decimalPrecision.toString,
         decimalScale.toString,
         toSQLConf(SQLConf.ANSI_ENABLED.key)),
-      context = context,
+      context = getQueryContext(context),
       summary = getSummary(context))
   }
 
@@ -128,7 +128,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
       value: Any,
       from: DataType,
       to: DataType,
-      context: Option[SQLQueryContext]): Throwable = {
+      context: SQLQueryContext): Throwable = {
     new SparkDateTimeException(
       errorClass = "CAST_INVALID_INPUT",
       messageParameters = Array(
@@ -136,13 +136,13 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
         toSQLType(from),
         toSQLType(to),
         toSQLConf(SQLConf.ANSI_ENABLED.key)),
-      context = context,
+      context = getQueryContext(context),
       summary = getSummary(context))
   }
 
   def invalidInputSyntaxForBooleanError(
       s: UTF8String,
-      context: Option[SQLQueryContext]): SparkRuntimeException = {
+      context: SQLQueryContext): SparkRuntimeException = {
     new SparkRuntimeException(
       errorClass = "CAST_INVALID_INPUT",
       messageParameters = Array(
@@ -150,14 +150,14 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
         toSQLType(StringType),
         toSQLType(BooleanType),
         toSQLConf(SQLConf.ANSI_ENABLED.key)),
-      context = context,
+      context = getQueryContext(context),
       summary = getSummary(context))
   }
 
   def invalidInputInCastToNumberError(
       to: DataType,
       s: UTF8String,
-      context: Option[SQLQueryContext]): SparkNumberFormatException = {
+      context: SQLQueryContext): SparkNumberFormatException = {
     new SparkNumberFormatException(
       errorClass = "CAST_INVALID_INPUT",
       messageParameters = Array(
@@ -165,7 +165,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
         toSQLType(StringType),
         toSQLType(to),
         toSQLConf(SQLConf.ANSI_ENABLED.key)),
-      context = context,
+      context = getQueryContext(context),
       summary = getSummary(context))
   }
 
@@ -196,40 +196,40 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
       messageParameters = Array(funcCls, inputTypes, outputType), e)
   }
 
-  def divideByZeroError(context: Option[SQLQueryContext]): ArithmeticException = {
+  def divideByZeroError(context: SQLQueryContext): ArithmeticException = {
     new SparkArithmeticException(
       errorClass = "DIVIDE_BY_ZERO",
       messageParameters = Array(toSQLConf(SQLConf.ANSI_ENABLED.key)),
-      context = context,
+      context = getQueryContext(context),
       summary = getSummary(context))
   }
 
-  def intervalDividedByZeroError(context: Option[SQLQueryContext]): ArithmeticException = {
+  def intervalDividedByZeroError(context: SQLQueryContext): ArithmeticException = {
     new SparkArithmeticException(
       errorClass = "INTERVAL_DIVIDED_BY_ZERO",
       messageParameters = Array.empty,
-      context = context,
+      context = getQueryContext(context),
       summary = getSummary(context))
   }
 
   def invalidArrayIndexError(
       index: Int,
       numElements: Int,
-      context: Option[SQLQueryContext]): ArrayIndexOutOfBoundsException = {
+      context: SQLQueryContext): ArrayIndexOutOfBoundsException = {
     new SparkArrayIndexOutOfBoundsException(
       errorClass = "INVALID_ARRAY_INDEX",
       messageParameters = Array(
         toSQLValue(index, IntegerType),
         toSQLValue(numElements, IntegerType),
         toSQLConf(SQLConf.ANSI_ENABLED.key)),
-      context = context,
+      context = getQueryContext(context),
       summary = getSummary(context))
   }
 
   def invalidElementAtIndexError(
       index: Int,
       numElements: Int,
-      context: Option[SQLQueryContext]): ArrayIndexOutOfBoundsException = {
+      context: SQLQueryContext): ArrayIndexOutOfBoundsException = {
     new SparkArrayIndexOutOfBoundsException(
       errorClass = "INVALID_ARRAY_INDEX_IN_ELEMENT_AT",
       messageParameters =
@@ -237,20 +237,20 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
           toSQLValue(index, IntegerType),
           toSQLValue(numElements, IntegerType),
           toSQLConf(SQLConf.ANSI_ENABLED.key)),
-      context = context,
+      context = getQueryContext(context),
       summary = getSummary(context))
   }
 
   def mapKeyNotExistError(
       key: Any,
       dataType: DataType,
-      context: Option[SQLQueryContext]): NoSuchElementException = {
+      context: SQLQueryContext): NoSuchElementException = {
     new SparkNoSuchElementException(
       errorClass = "MAP_KEY_DOES_NOT_EXIST",
       messageParameters = Array(
         toSQLValue(key, dataType),
         toSQLConf(SQLConf.ANSI_ENABLED.key)),
-      context = context,
+      context = getQueryContext(context),
       summary = getSummary(context))
   }
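
The constructor shape used throughout this file pairs an `Array[QueryContext]` with a preformatted `summary` string. A simplified sketch of one such exception and its factory, under the assumption of stand-in types (not Spark's actual exception hierarchy):

```
trait QueryContext { def summary: String }
class SQLQueryContext(val summary: String) extends QueryContext

// Stand-in for SparkNoSuchElementException: it carries the error class, the
// context chain, and the summary alongside the rendered message.
class SparkNoSuchElementLike(
    errorClass: String,
    messageParameters: Array[String],
    val context: Array[QueryContext],
    val summary: String)
  extends NoSuchElementException(
    s"[$errorClass] ${messageParameters.mkString(", ")}$summary")

def mapKeyNotExistError(key: Any, context: SQLQueryContext): NoSuchElementException =
  new SparkNoSuchElementLike(
    "MAP_KEY_DOES_NOT_EXIST",
    Array(key.toString),
    if (context == null) Array.empty else Array(context),
    if (context == null) "" else context.summary)
```
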
 
@@ -259,7 +259,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
       errorClass = "INVALID_FRACTION_OF_SECOND",
       errorSubClass = None,
       Array(toSQLConf(SQLConf.ANSI_ENABLED.key)),
-      context = None,
+      context = Array.empty,
       summary = "")
   }
 
@@ -268,7 +268,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
       errorClass = "CANNOT_PARSE_TIMESTAMP",
       errorSubClass = None,
       Array(e.getMessage, toSQLConf(SQLConf.ANSI_ENABLED.key)),
-      context = None,
+      context = Array.empty,
       summary = "")
   }
 
@@ -294,11 +294,11 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
     ansiIllegalArgumentError(e.getMessage)
   }
 
-  def overflowInSumOfDecimalError(context: Option[SQLQueryContext]): ArithmeticException = {
+  def overflowInSumOfDecimalError(context: SQLQueryContext): ArithmeticException = {
     arithmeticOverflowError("Overflow in sum of decimals", context = context)
   }
 
-  def overflowInIntegralDivideError(context: Option[SQLQueryContext]): ArithmeticException = {
+  def overflowInIntegralDivideError(context: SQLQueryContext): ArithmeticException = {
     arithmeticOverflowError("Overflow in integral divide", "try_divide", context)
   }
 
@@ -514,14 +514,14 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
   def arithmeticOverflowError(
       message: String,
       hint: String = "",
-      context: Option[SQLQueryContext] = None): ArithmeticException = {
+      context: SQLQueryContext = null): ArithmeticException = {
     val alternative = if (hint.nonEmpty) {
       s" Use '$hint' to tolerate overflow and return NULL instead."
     } else ""
     new SparkArithmeticException(
       errorClass = "ARITHMETIC_OVERFLOW",
       messageParameters = Array(message, alternative, SQLConf.ANSI_ENABLED.key),
-      context = context,
+      context = getQueryContext(context),
       summary = getSummary(context))
   }
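
The overflow factory folds an optional 'try_*' hint into the message; with the null default, generated code can call it with two or three arguments and no Scala object construction. A simplified sketch mirroring the logic above:

```
class SQLQueryContext(val summary: String)

def arithmeticOverflowError(
    message: String,
    hint: String = "",
    context: SQLQueryContext = null): ArithmeticException = {
  val alternative = if (hint.nonEmpty) {
    s" Use '$hint' to tolerate overflow and return NULL instead."
  } else ""
  val summary = if (context == null) "" else context.summary
  new ArithmeticException(s"$message.$alternative$summary")
}

// The specialized wrappers above stay one-liners:
def overflowInIntegralDivideError(context: SQLQueryContext): ArithmeticException =
  arithmeticOverflowError("Overflow in integral divide", "try_divide", context)
```
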
 
@@ -2061,7 +2061,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
       messageParameters = Array(
         s"add ${toSQLValue(amount, IntegerType)} $unit to " +
         s"${toSQLValue(DateTimeUtils.microsToInstant(micros), TimestampType)}"),
-      context = None,
+      context = Array.empty,
       summary = "")
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 00172f69fda..aa683a06a8e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -367,7 +367,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
       scale: Int,
       roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP,
       nullOnOverflow: Boolean = true,
-      context: Option[SQLQueryContext] = None): Decimal = {
+      context: SQLQueryContext = null): Decimal = {
     val copy = clone()
     if (copy.changePrecision(precision, scale, roundMode)) {
       copy
@@ -632,7 +632,7 @@ object Decimal {
   def fromStringANSI(
       str: UTF8String,
       to: DecimalType = DecimalType.USER_DEFAULT,
-      context: Option[SQLQueryContext] = None): Decimal = {
+      context: SQLQueryContext = null): Decimal = {
     try {
       val bigDecimal = stringToJavaBigDecimal(str)
       // We fast fail because constructing a very large JavaBigDecimal to Decimal is very slow.
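
From the caller's side, `fromStringANSI` now takes a nullable context with the same default behavior. A hedged sketch of that shape, with the parsing and the real Spark error plumbing simplified away:

```
class SQLQueryContext(val summary: String)

// Simplified: parse to BigDecimal, attaching the context summary to the
// ANSI cast error when parsing fails.
def fromStringANSI(str: String, context: SQLQueryContext = null): BigDecimal =
  try BigDecimal(str.trim) catch {
    case _: NumberFormatException =>
      val summary = if (context == null) "" else context.summary
      throw new NumberFormatException(s"Cannot cast '$str' to DECIMAL.$summary")
  }

fromStringANSI("  1.23 ")                        // null default, no context
fromStringANSI("1.23", new SQLQueryContext(""))  // explicit context
```
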
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala
index d96ca4b87f0..513a62dc7f0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala
@@ -91,7 +91,7 @@ class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkExceptionInExpression[ArithmeticException](expr1, query)
 
     val expr2 = CheckOverflowInSum(
-      Literal(d), DecimalType(4, 3), false, context = Some(origin.context))
+      Literal(d), DecimalType(4, 3), false, context = origin.context)
     checkExceptionInExpression[ArithmeticException](expr2, query)
   }
 }
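
The test-side change is mechanical: where the suite previously wrapped the origin's context in `Some(...)`, it now passes it straight through. An illustrative sketch with a stand-in expression (not the real `CheckOverflowInSum`):

```
class SQLQueryContext(val summary: String)

// Stand-in with the same nullable-context parameter shape.
case class CheckOverflowLike(
    value: BigDecimal,
    precision: Int,
    scale: Int,
    context: SQLQueryContext = null)

val originContext = new SQLQueryContext("\n== SQL ==\nSUM(...)")
// No Some(...): the context flows through as-is, or stays null when absent.
val expr = CheckOverflowLike(BigDecimal("12.345"), 4, 3, context = originContext)
```
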

