You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ge...@apache.org on 2022/04/12 12:39:53 UTC

[spark] branch master updated: [SPARK-38589][SQL] New SQL function: try_avg

This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a7f0adb2dd8 [SPARK-38589][SQL] New SQL function: try_avg
a7f0adb2dd8 is described below

commit a7f0adb2dd8449af6f9e9b5a25f11b5dcf5868f1
Author: Gengliang Wang <ge...@apache.org>
AuthorDate: Tue Apr 12 20:39:08 2022 +0800

    [SPARK-38589][SQL] New SQL function: try_avg
    
    ### What changes were proposed in this pull request?
    
    Add a new SQL function: try_avg. It is identical to the function `avg`, except that it returns a NULL result instead of throwing an exception on decimal/interval value overflow.
    Note that it also differs from `avg` on interval overflows when ANSI mode is off:
    | Overflow case                | avg   | try_avg     |
    |------------------------------|-------|-------------|
    | year-month interval overflow | Error | Return NULL |
    | day-time interval overflow   | Error | Return NULL |
    
    ### Why are the changes needed?
    
    * Users can finish queries in ANSI mode without them being interrupted by overflow errors.
    * Users can get NULLs instead of runtime errors if an interval overflow occurs when ANSI mode is off. For example:
    ```
    > SELECT avg(col) FROM VALUES (interval '2147483647 months'),(interval '1 months') AS tab(col)
    java.lang.ArithmeticException: integer overflow.
    
    > SELECT try_avg(col) FROM VALUES (interval '2147483647 months'),(interval '1 months') AS tab(col)
    NULL
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, this adds a new SQL function, try_avg. It is identical to the function `avg`, except that it returns a NULL result instead of throwing an exception on decimal/interval value overflow.
    
    ### How was this patch tested?
    
    Unit tests: new SQL query tests in `try_aggregates.sql` and a new test case in `SQLQuerySuite`.
    
    Closes #35896 from gengliangwang/tryAvg.
    
    Lead-authored-by: Gengliang Wang <ge...@apache.org>
    Co-authored-by: Gengliang Wang <lt...@gmail.com>
    Signed-off-by: Gengliang Wang <ge...@apache.org>
---
 docs/sql-ref-ansi-compliance.md                    |   3 +-
 .../sql/catalyst/analysis/FunctionRegistry.scala   |   1 +
 .../catalyst/expressions/aggregate/Average.scala   | 125 +++++++++++++++++----
 .../sql/catalyst/expressions/aggregate/Sum.scala   |  35 +++---
 .../sql-functions/sql-expression-schema.md         |   5 +-
 .../resources/sql-tests/inputs/try_aggregates.sql  |  14 +++
 .../sql-tests/results/ansi/try_aggregates.sql.out  |  82 +++++++++++++-
 .../sql-tests/results/try_aggregates.sql.out       |  82 +++++++++++++-
 .../scala/org/apache/spark/sql/SQLQuerySuite.scala |  12 ++
 9 files changed, 313 insertions(+), 46 deletions(-)

diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md
index 0f7f29cde7f..66161a112b1 100644
--- a/docs/sql-ref-ansi-compliance.md
+++ b/docs/sql-ref-ansi-compliance.md
@@ -316,7 +316,8 @@ When ANSI mode is on, it throws exceptions for invalid operations. You can use t
   - `try_subtract`: identical to the add operator `-`, except that it returns `NULL` result instead of throwing an exception on integral value overflow.
   - `try_multiply`: identical to the add operator `*`, except that it returns `NULL` result instead of throwing an exception on integral value overflow.
   - `try_divide`: identical to the division operator `/`, except that it returns `NULL` result instead of throwing an exception on dividing 0.
-  - `try_sum`: identical to the function `sum`, except that it returns `NULL` result instead of throwing an exception on integral/decimal value overflow.
+  - `try_sum`: identical to the function `sum`, except that it returns `NULL` result instead of throwing an exception on integral/decimal/interval value overflow.
+  - `try_avg`: identical to the function `avg`, except that it returns `NULL` result instead of throwing an exception on decimal/interval value overflow.
   - `try_element_at`: identical to the function `element_at`, except that it returns `NULL` result instead of throwing an exception on array's index out of bound or map's key not found.
 
 ### SQL Keywords (optional, disabled by default)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 1824fb68f76..80374f769a2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -453,6 +453,7 @@ object FunctionRegistry {
     expression[TrySubtract]("try_subtract"),
     expression[TryMultiply]("try_multiply"),
     expression[TryElementAt]("try_element_at"),
+    expression[TryAverage]("try_avg"),
     expression[TrySum]("try_sum"),
     expression[TryToBinary]("try_to_binary"),
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index 533f7f20b25..14914576091 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -26,25 +26,13 @@ import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
-@ExpressionDescription(
-  usage = "_FUNC_(expr) - Returns the mean calculated from values of a group.",
-  examples = """
-    Examples:
-      > SELECT _FUNC_(col) FROM VALUES (1), (2), (3) AS tab(col);
-       2.0
-      > SELECT _FUNC_(col) FROM VALUES (1), (2), (NULL) AS tab(col);
-       1.5
-  """,
-  group = "agg_funcs",
-  since = "1.0.0")
-case class Average(
-    child: Expression,
-    failOnError: Boolean = SQLConf.get.ansiEnabled)
+abstract class AverageBase
   extends DeclarativeAggregate
   with ImplicitCastInputTypes
   with UnaryLike[Expression] {
 
-  def this(child: Expression) = this(child, failOnError = SQLConf.get.ansiEnabled)
+  // Whether to use ANSI add or not during the execution.
+  def useAnsiAdd: Boolean
 
   override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("avg")
 
@@ -61,7 +49,7 @@ case class Average(
 
   final override val nodePatterns: Seq[TreePattern] = Seq(AVERAGE)
 
-  private lazy val resultType = child.dataType match {
+  protected lazy val resultType = child.dataType match {
     case DecimalType.Fixed(p, s) =>
       DecimalType.bounded(p + 4, s + 4)
     case _: YearMonthIntervalType => YearMonthIntervalType()
@@ -86,18 +74,18 @@ case class Average(
     /* count = */ Literal(0L)
   )
 
-  override lazy val mergeExpressions = Seq(
-    /* sum = */ sum.left + sum.right,
+  protected def getMergeExpressions = Seq(
+    /* sum = */ Add(sum.left, sum.right, useAnsiAdd),
     /* count = */ count.left + count.right
   )
 
   // If all input are nulls, count will be 0 and we will get null after the division.
   // We can't directly use `/` as it throws an exception under ansi mode.
-  override lazy val evaluateExpression = child.dataType match {
+  protected def getEvaluateExpression = child.dataType match {
     case _: DecimalType =>
       DecimalPrecision.decimalAndDecimal()(
         Divide(
-          CheckOverflowInSum(sum, sumDataType.asInstanceOf[DecimalType], !failOnError),
+          CheckOverflowInSum(sum, sumDataType.asInstanceOf[DecimalType], !useAnsiAdd),
           count.cast(DecimalType.LongDecimal), failOnError = false)).cast(resultType)
     case _: YearMonthIntervalType =>
       If(EqualTo(count, Literal(0L)),
@@ -109,17 +97,106 @@ case class Average(
       Divide(sum.cast(resultType), count.cast(resultType), failOnError = false)
   }
 
-  override lazy val updateExpressions: Seq[Expression] = Seq(
+  protected def getUpdateExpressions: Seq[Expression] = Seq(
     /* sum = */
     Add(
       sum,
-      coalesce(child.cast(sumDataType), Literal.default(sumDataType))),
+      coalesce(child.cast(sumDataType), Literal.default(sumDataType)),
+      failOnError = useAnsiAdd),
     /* count = */ If(child.isNull, count, count + 1L)
   )
 
+  // The flag `useAnsiAdd` won't be shown in the `toString` or `toAggString` methods
+  override def flatArguments: Iterator[Any] = Iterator(child)
+}
+
+@ExpressionDescription(
+  usage = "_FUNC_(expr) - Returns the mean calculated from values of a group.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(col) FROM VALUES (1), (2), (3) AS tab(col);
+       2.0
+      > SELECT _FUNC_(col) FROM VALUES (1), (2), (NULL) AS tab(col);
+       1.5
+  """,
+  group = "agg_funcs",
+  since = "1.0.0")
+case class Average(
+    child: Expression,
+    useAnsiAdd: Boolean = SQLConf.get.ansiEnabled) extends AverageBase {
+  def this(child: Expression) = this(child, useAnsiAdd = SQLConf.get.ansiEnabled)
+
   override protected def withNewChildInternal(newChild: Expression): Average =
     copy(child = newChild)
 
-  // The flag `failOnError` won't be shown in the `toString` or `toAggString` methods
-  override def flatArguments: Iterator[Any] = Iterator(child)
+  override lazy val updateExpressions: Seq[Expression] = getUpdateExpressions
+
+  override lazy val mergeExpressions: Seq[Expression] = getMergeExpressions
+
+  override lazy val evaluateExpression: Expression = getEvaluateExpression
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = "_FUNC_(expr) - Returns the mean calculated from values of a group and the result is null on overflow.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(col) FROM VALUES (1), (2), (3) AS tab(col);
+       2.0
+      > SELECT _FUNC_(col) FROM VALUES (1), (2), (NULL) AS tab(col);
+       1.5
+      > SELECT _FUNC_(col) FROM VALUES (interval '2147483647 months'), (interval '1 months') AS tab(col);
+       NULL
+  """,
+  group = "agg_funcs",
+  since = "3.3.0")
+// scalastyle:on line.size.limit
+case class TryAverage(child: Expression) extends AverageBase {
+  override def useAnsiAdd: Boolean = resultType match {
+    // Double type won't fail, thus we can always use non-Ansi Add.
+    // Decimal overflow already yields NULL on the non-ANSI path (AverageBase builds
+    // CheckOverflowInSum with nullOnOverflow = !useAnsiAdd), so non-ANSI Add suffices.
+    case _: DoubleType | _: DecimalType => false
+    case _ => true
+  }
+
+  private def addTryEvalIfNeeded(expression: Expression): Expression = {
+    if (useAnsiAdd) {
+      TryEval(expression)
+    } else {
+      expression
+    }
+  }
+
+  override lazy val updateExpressions: Seq[Expression] = {
+    val expressions = getUpdateExpressions
+    addTryEvalIfNeeded(expressions.head) +: expressions.tail
+  }
+
+  override lazy val mergeExpressions: Seq[Expression] = {
+    val expressions = getMergeExpressions
+    if (useAnsiAdd) {
+      val bufferOverflow = sum.left.isNull && count.left > 0L
+      val inputOverflow = sum.right.isNull && count.right > 0L
+      Seq(
+        If(
+          bufferOverflow || inputOverflow,
+          Literal.create(null, sum.dataType),
+          // Neither side has overflowed, so both sums are non-null (hence KnownNotNull);
+          // the addition itself may still overflow, which TryEval turns into NULL.
+          TryEval(Add(KnownNotNull(sum.left), KnownNotNull(sum.right), useAnsiAdd))),
+        expressions(1))
+    } else {
+      expressions
+    }
+  }
+
+  override lazy val evaluateExpression: Expression = {
+    addTryEvalIfNeeded(getEvaluateExpression)
+  }
+
+  override protected def withNewChildInternal(newChild: Expression): TryAverage =
+    copy(child = newChild)
+
+  override def prettyName: String = "try_avg"
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index fd27edfc8fc..f2c6925b837 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -30,7 +30,8 @@ abstract class SumBase(child: Expression) extends DeclarativeAggregate
   with ImplicitCastInputTypes
   with UnaryLike[Expression] {
 
-  def failOnError: Boolean
+  // Whether to use ANSI add or not during the execution.
+  def useAnsiAdd: Boolean
 
   protected def shouldTrackIsEmpty: Boolean
 
@@ -81,9 +82,9 @@ abstract class SumBase(child: Expression) extends DeclarativeAggregate
     // null if overflow happens under non-ansi mode.
     val sumExpr = if (child.nullable) {
       If(child.isNull, sum,
-        Add(sum, KnownNotNull(child).cast(resultType), failOnError = failOnError))
+        Add(sum, KnownNotNull(child).cast(resultType), failOnError = useAnsiAdd))
     } else {
-      Add(sum, child.cast(resultType), failOnError = failOnError)
+      Add(sum, child.cast(resultType), failOnError = useAnsiAdd)
     }
     // The buffer becomes non-empty after seeing the first not-null input.
     val isEmptyExpr = if (child.nullable) {
@@ -98,10 +99,10 @@ abstract class SumBase(child: Expression) extends DeclarativeAggregate
     // in case the input is nullable. The `sum` can only be null if there is no value, as
     // non-decimal type can produce overflowed value under non-ansi mode.
     if (child.nullable) {
-      Seq(coalesce(Add(coalesce(sum, zero), child.cast(resultType), failOnError = failOnError),
+      Seq(coalesce(Add(coalesce(sum, zero), child.cast(resultType), failOnError = useAnsiAdd),
         sum))
     } else {
-      Seq(Add(coalesce(sum, zero), child.cast(resultType), failOnError = failOnError))
+      Seq(Add(coalesce(sum, zero), child.cast(resultType), failOnError = useAnsiAdd))
     }
   }
 
@@ -127,11 +128,11 @@ abstract class SumBase(child: Expression) extends DeclarativeAggregate
         // If both the buffer and the input do not overflow, just add them, as they can't be
         // null. See the comments inside `updateExpressions`: `sum` can only be null if
         // overflow happens.
-        Add(KnownNotNull(sum.left), KnownNotNull(sum.right), failOnError)),
+        Add(KnownNotNull(sum.left), KnownNotNull(sum.right), useAnsiAdd)),
       isEmpty.left && isEmpty.right)
   } else {
     Seq(coalesce(
-      Add(coalesce(sum.left, zero), sum.right, failOnError = failOnError),
+      Add(coalesce(sum.left, zero), sum.right, failOnError = useAnsiAdd),
       sum.left))
   }
 
@@ -145,13 +146,13 @@ abstract class SumBase(child: Expression) extends DeclarativeAggregate
   protected def getEvaluateExpression: Expression = resultType match {
     case d: DecimalType =>
       If(isEmpty, Literal.create(null, resultType),
-        CheckOverflowInSum(sum, d, !failOnError))
+        CheckOverflowInSum(sum, d, !useAnsiAdd))
     case _ if shouldTrackIsEmpty =>
       If(isEmpty, Literal.create(null, resultType), sum)
     case _ => sum
   }
 
-  // The flag `failOnError` won't be shown in the `toString` or `toAggString` methods
+  // The flag `useAnsiAdd` won't be shown in the `toString` or `toAggString` methods
   override def flatArguments: Iterator[Any] = Iterator(child)
 }
 
@@ -170,9 +171,9 @@ abstract class SumBase(child: Expression) extends DeclarativeAggregate
   since = "1.0.0")
 case class Sum(
     child: Expression,
-    failOnError: Boolean = SQLConf.get.ansiEnabled)
+    useAnsiAdd: Boolean = SQLConf.get.ansiEnabled)
   extends SumBase(child) {
-  def this(child: Expression) = this(child, failOnError = SQLConf.get.ansiEnabled)
+  def this(child: Expression) = this(child, useAnsiAdd = SQLConf.get.ansiEnabled)
 
   override def shouldTrackIsEmpty: Boolean = resultType match {
     case _: DecimalType => true
@@ -207,10 +208,10 @@ case class Sum(
 // scalastyle:on line.size.limit
 case class TrySum(child: Expression) extends SumBase(child) {
 
-  override def failOnError: Boolean = dataType match {
-    // Double type won't fail, thus the failOnError is always false
+  override def useAnsiAdd: Boolean = dataType match {
+    // Double type won't fail, thus useAnsiAdd is always false
     // For decimal type, it returns NULL on overflow. It behaves the same as TrySum when
-    // `failOnError` is false.
+    // `useAnsiAdd` is false.
     case _: DoubleType | _: DecimalType => false
     case _ => true
   }
@@ -224,7 +225,7 @@ case class TrySum(child: Expression) extends SumBase(child) {
   }
 
   override lazy val updateExpressions: Seq[Expression] =
-    if (failOnError) {
+    if (useAnsiAdd) {
       val expressions = getUpdateExpressions
       // If the length of updateExpressions is larger than 1, the tail expressions are for
       // tracking whether the input is empty, which doesn't need `TryEval` execution.
@@ -234,14 +235,14 @@ case class TrySum(child: Expression) extends SumBase(child) {
     }
 
   override lazy val mergeExpressions: Seq[Expression] =
-    if (failOnError) {
+    if (useAnsiAdd) {
       getMergeExpressions.map(TryEval)
     } else {
       getMergeExpressions
     }
 
   override lazy val evaluateExpression: Expression =
-    if (failOnError) {
+    if (useAnsiAdd) {
       TryEval(getEvaluateExpression)
     } else {
       getEvaluateExpression
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 14902b08549..9f8faf517a4 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -1,6 +1,6 @@
 <!-- Automatically generated by ExpressionsSchemaSuite -->
 ## Summary
-  - Number of queries: 387
+  - Number of queries: 388
   - Number of expressions that missing example: 12
   - Expressions missing examples: bigint,binary,boolean,date,decimal,double,float,int,smallint,string,timestamp,tinyint
 ## Schema of Built-in Functions
@@ -380,6 +380,7 @@
 | org.apache.spark.sql.catalyst.expressions.aggregate.StddevSamp | stddev | SELECT stddev(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<stddev(col):double> |
 | org.apache.spark.sql.catalyst.expressions.aggregate.StddevSamp | stddev_samp | SELECT stddev_samp(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<stddev_samp(col):double> |
 | org.apache.spark.sql.catalyst.expressions.aggregate.Sum | sum | SELECT sum(col) FROM VALUES (5), (10), (15) AS tab(col) | struct<sum(col):bigint> |
+| org.apache.spark.sql.catalyst.expressions.aggregate.TryAverage | try_avg | SELECT try_avg(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<try_avg(col):double> |
 | org.apache.spark.sql.catalyst.expressions.aggregate.TrySum | try_sum | SELECT try_sum(col) FROM VALUES (5), (10), (15) AS tab(col) | struct<try_sum(col):bigint> |
 | org.apache.spark.sql.catalyst.expressions.aggregate.VariancePop | var_pop | SELECT var_pop(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<var_pop(col):double> |
 | org.apache.spark.sql.catalyst.expressions.aggregate.VarianceSamp | var_samp | SELECT var_samp(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<var_samp(col):double> |
@@ -392,4 +393,4 @@
 | org.apache.spark.sql.catalyst.expressions.xml.XPathList | xpath | SELECT xpath('<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>','a/b/text()') | struct<xpath(<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>, a/b/text()):array<string>> |
 | org.apache.spark.sql.catalyst.expressions.xml.XPathLong | xpath_long | SELECT xpath_long('<a><b>1</b><b>2</b></a>', 'sum(a/b)') | struct<xpath_long(<a><b>1</b><b>2</b></a>, sum(a/b)):bigint> |
 | org.apache.spark.sql.catalyst.expressions.xml.XPathShort | xpath_short | SELECT xpath_short('<a><b>1</b><b>2</b></a>', 'sum(a/b)') | struct<xpath_short(<a><b>1</b><b>2</b></a>, sum(a/b)):smallint> |
-| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('<a><b>b</b><c>cc</c></a>','a/c') | struct<xpath_string(<a><b>b</b><c>cc</c></a>, a/c):string> |
+| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('<a><b>b</b><c>cc</c></a>','a/c') | struct<xpath_string(<a><b>b</b><c>cc</c></a>, a/c):string> |
\ No newline at end of file
diff --git a/sql/core/src/test/resources/sql-tests/inputs/try_aggregates.sql b/sql/core/src/test/resources/sql-tests/inputs/try_aggregates.sql
index ffa8eefe828..cdd2e632319 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/try_aggregates.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/try_aggregates.sql
@@ -11,3 +11,17 @@ SELECT try_sum(col) FROM VALUES (interval '1 months'), (interval '1 months') AS
 SELECT try_sum(col) FROM VALUES (interval '2147483647 months'), (interval '1 months') AS tab(col);
 SELECT try_sum(col) FROM VALUES (interval '1 seconds'), (interval '1 seconds') AS tab(col);
 SELECT try_sum(col) FROM VALUES (interval '106751991 DAYS'), (interval '1 DAYS') AS tab(col);
+
+-- try_avg
+SELECT try_avg(col) FROM VALUES (5), (10), (15) AS tab(col);
+SELECT try_avg(col) FROM VALUES (5.0), (10.0), (15.0) AS tab(col);
+SELECT try_avg(col) FROM VALUES (NULL), (10), (15) AS tab(col);
+SELECT try_avg(col) FROM VALUES (NULL), (NULL) AS tab(col);
+SELECT try_avg(col) FROM VALUES (9223372036854775807L), (1L) AS tab(col);
+-- test overflow in Decimal(38, 0)
+SELECT try_avg(col) FROM VALUES (98765432109876543210987654321098765432BD), (98765432109876543210987654321098765432BD) AS tab(col);
+
+SELECT try_avg(col) FROM VALUES (interval '1 months'), (interval '1 months') AS tab(col);
+SELECT try_avg(col) FROM VALUES (interval '2147483647 months'), (interval '1 months') AS tab(col);
+SELECT try_avg(col) FROM VALUES (interval '1 seconds'), (interval '1 seconds') AS tab(col);
+SELECT try_avg(col) FROM VALUES (interval '106751991 DAYS'), (interval '1 DAYS') AS tab(col);
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/try_aggregates.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/try_aggregates.sql.out
index 7ae217ad758..724553f6bd1 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/try_aggregates.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/try_aggregates.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 10
+-- Number of queries: 20
 
 
 -- !query
@@ -80,3 +80,83 @@ SELECT try_sum(col) FROM VALUES (interval '106751991 DAYS'), (interval '1 DAYS')
 struct<try_sum(col):interval day>
 -- !query output
 NULL
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (5), (10), (15) AS tab(col)
+-- !query schema
+struct<try_avg(col):double>
+-- !query output
+10.0
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (5.0), (10.0), (15.0) AS tab(col)
+-- !query schema
+struct<try_avg(col):decimal(7,5)>
+-- !query output
+10.00000
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (NULL), (10), (15) AS tab(col)
+-- !query schema
+struct<try_avg(col):double>
+-- !query output
+12.5
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (NULL), (NULL) AS tab(col)
+-- !query schema
+struct<try_avg(col):double>
+-- !query output
+NULL
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (9223372036854775807L), (1L) AS tab(col)
+-- !query schema
+struct<try_avg(col):double>
+-- !query output
+4.6116860184273879E18
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (98765432109876543210987654321098765432BD), (98765432109876543210987654321098765432BD) AS tab(col)
+-- !query schema
+struct<try_avg(col):decimal(38,4)>
+-- !query output
+NULL
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (interval '1 months'), (interval '1 months') AS tab(col)
+-- !query schema
+struct<try_avg(col):interval year to month>
+-- !query output
+0-1
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (interval '2147483647 months'), (interval '1 months') AS tab(col)
+-- !query schema
+struct<try_avg(col):interval year to month>
+-- !query output
+NULL
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (interval '1 seconds'), (interval '1 seconds') AS tab(col)
+-- !query schema
+struct<try_avg(col):interval day to second>
+-- !query output
+0 00:00:01.000000000
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (interval '106751991 DAYS'), (interval '1 DAYS') AS tab(col)
+-- !query schema
+struct<try_avg(col):interval day to second>
+-- !query output
+NULL
diff --git a/sql/core/src/test/resources/sql-tests/results/try_aggregates.sql.out b/sql/core/src/test/resources/sql-tests/results/try_aggregates.sql.out
index 7ae217ad758..724553f6bd1 100644
--- a/sql/core/src/test/resources/sql-tests/results/try_aggregates.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/try_aggregates.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 10
+-- Number of queries: 20
 
 
 -- !query
@@ -80,3 +80,83 @@ SELECT try_sum(col) FROM VALUES (interval '106751991 DAYS'), (interval '1 DAYS')
 struct<try_sum(col):interval day>
 -- !query output
 NULL
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (5), (10), (15) AS tab(col)
+-- !query schema
+struct<try_avg(col):double>
+-- !query output
+10.0
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (5.0), (10.0), (15.0) AS tab(col)
+-- !query schema
+struct<try_avg(col):decimal(7,5)>
+-- !query output
+10.00000
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (NULL), (10), (15) AS tab(col)
+-- !query schema
+struct<try_avg(col):double>
+-- !query output
+12.5
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (NULL), (NULL) AS tab(col)
+-- !query schema
+struct<try_avg(col):double>
+-- !query output
+NULL
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (9223372036854775807L), (1L) AS tab(col)
+-- !query schema
+struct<try_avg(col):double>
+-- !query output
+4.6116860184273879E18
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (98765432109876543210987654321098765432BD), (98765432109876543210987654321098765432BD) AS tab(col)
+-- !query schema
+struct<try_avg(col):decimal(38,4)>
+-- !query output
+NULL
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (interval '1 months'), (interval '1 months') AS tab(col)
+-- !query schema
+struct<try_avg(col):interval year to month>
+-- !query output
+0-1
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (interval '2147483647 months'), (interval '1 months') AS tab(col)
+-- !query schema
+struct<try_avg(col):interval year to month>
+-- !query output
+NULL
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (interval '1 seconds'), (interval '1 seconds') AS tab(col)
+-- !query schema
+struct<try_avg(col):interval day to second>
+-- !query output
+0 00:00:01.000000000
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (interval '106751991 DAYS'), (interval '1 DAYS') AS tab(col)
+-- !query schema
+struct<try_avg(col):interval day to second>
+-- !query output
+NULL
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 81067eef401..0b00659f73b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -4331,6 +4331,18 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
       checkAnswer(df.repartitionByRange(2, col("v")).selectExpr("try_sum(v)"), Row(null))
     }
   }
+
+  test("SPARK-38589: try_avg should return null if overflow happens before merging") {
+    val yearMonthDf = Seq(Int.MaxValue, Int.MaxValue, 2)
+      .map(Period.ofMonths)
+      .toDF("v")
+    val dayTimeDf = Seq(106751991L, 106751991L, 2L)
+      .map(Duration.ofDays)
+      .toDF("v")
+    Seq(yearMonthDf, dayTimeDf).foreach { df =>
+      checkAnswer(df.repartitionByRange(2, col("v")).selectExpr("try_avg(v)"), Row(null))
+    }
+  }
 }
 
 case class Foo(bar: Option[String])


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org