You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ge...@apache.org on 2022/04/12 12:39:53 UTC
[spark] branch master updated: [SPARK-38589][SQL] New SQL function: try_avg
This is an automated email from the ASF dual-hosted git repository.
gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new a7f0adb2dd8 [SPARK-38589][SQL] New SQL function: try_avg
a7f0adb2dd8 is described below
commit a7f0adb2dd8449af6f9e9b5a25f11b5dcf5868f1
Author: Gengliang Wang <ge...@apache.org>
AuthorDate: Tue Apr 12 20:39:08 2022 +0800
[SPARK-38589][SQL] New SQL function: try_avg
### What changes were proposed in this pull request?
Add a new SQL function: try_avg. It is identical to the function `avg`, except that it returns NULL result instead of throwing an exception on decimal/interval value overflow.
Note that it also differs from `avg` on interval overflows when ANSI mode is off:
| Function | avg | try_avg |
|------------------|------------------------------------|-------------|
| year-month interval overflow | Error | Return NULL |
| day-time interval overflow | Error | Return NULL |
### Why are the changes needed?
* Users can finish queries without interruptions in ANSI mode.
* Users can get NULLs instead of runtime errors if interval overflow occurs when ANSI mode is off. For example
```
> SELECT avg(col) FROM VALUES (interval '2147483647 months'),(interval '1 months') AS tab(col)
java.lang.ArithmeticException: integer overflow.
> SELECT try_avg(col) FROM VALUES (interval '2147483647 months'),(interval '1 months') AS tab(col)
NULL
```
### Does this PR introduce _any_ user-facing change?
Yes, adding a new SQL function: try_avg. It is identical to the function `avg`, except that it returns NULL result instead of throwing an exception on decimal/interval value overflow.
### How was this patch tested?
UT
Closes #35896 from gengliangwang/tryAvg.
Lead-authored-by: Gengliang Wang <ge...@apache.org>
Co-authored-by: Gengliang Wang <lt...@gmail.com>
Signed-off-by: Gengliang Wang <ge...@apache.org>
---
docs/sql-ref-ansi-compliance.md | 3 +-
.../sql/catalyst/analysis/FunctionRegistry.scala | 1 +
.../catalyst/expressions/aggregate/Average.scala | 125 +++++++++++++++++----
.../sql/catalyst/expressions/aggregate/Sum.scala | 35 +++---
.../sql-functions/sql-expression-schema.md | 5 +-
.../resources/sql-tests/inputs/try_aggregates.sql | 14 +++
.../sql-tests/results/ansi/try_aggregates.sql.out | 82 +++++++++++++-
.../sql-tests/results/try_aggregates.sql.out | 82 +++++++++++++-
.../scala/org/apache/spark/sql/SQLQuerySuite.scala | 12 ++
9 files changed, 313 insertions(+), 46 deletions(-)
diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md
index 0f7f29cde7f..66161a112b1 100644
--- a/docs/sql-ref-ansi-compliance.md
+++ b/docs/sql-ref-ansi-compliance.md
@@ -316,7 +316,8 @@ When ANSI mode is on, it throws exceptions for invalid operations. You can use t
- `try_subtract`: identical to the subtraction operator `-`, except that it returns `NULL` result instead of throwing an exception on integral value overflow.
- `try_multiply`: identical to the multiplication operator `*`, except that it returns `NULL` result instead of throwing an exception on integral value overflow.
- `try_divide`: identical to the division operator `/`, except that it returns `NULL` result instead of throwing an exception on dividing 0.
- - `try_sum`: identical to the function `sum`, except that it returns `NULL` result instead of throwing an exception on integral/decimal value overflow.
+ - `try_sum`: identical to the function `sum`, except that it returns `NULL` result instead of throwing an exception on integral/decimal/interval value overflow.
+ - `try_avg`: identical to the function `avg`, except that it returns `NULL` result instead of throwing an exception on decimal/interval value overflow.
- `try_element_at`: identical to the function `element_at`, except that it returns `NULL` result instead of throwing an exception on array's index out of bound or map's key not found.
### SQL Keywords (optional, disabled by default)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 1824fb68f76..80374f769a2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -453,6 +453,7 @@ object FunctionRegistry {
expression[TrySubtract]("try_subtract"),
expression[TryMultiply]("try_multiply"),
expression[TryElementAt]("try_element_at"),
+ expression[TryAverage]("try_avg"),
expression[TrySum]("try_sum"),
expression[TryToBinary]("try_to_binary"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index 533f7f20b25..14914576091 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -26,25 +26,13 @@ import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
-@ExpressionDescription(
- usage = "_FUNC_(expr) - Returns the mean calculated from values of a group.",
- examples = """
- Examples:
- > SELECT _FUNC_(col) FROM VALUES (1), (2), (3) AS tab(col);
- 2.0
- > SELECT _FUNC_(col) FROM VALUES (1), (2), (NULL) AS tab(col);
- 1.5
- """,
- group = "agg_funcs",
- since = "1.0.0")
-case class Average(
- child: Expression,
- failOnError: Boolean = SQLConf.get.ansiEnabled)
+abstract class AverageBase
extends DeclarativeAggregate
with ImplicitCastInputTypes
with UnaryLike[Expression] {
- def this(child: Expression) = this(child, failOnError = SQLConf.get.ansiEnabled)
+ // Whether to use ANSI add or not during the execution.
+ def useAnsiAdd: Boolean
override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("avg")
@@ -61,7 +49,7 @@ case class Average(
final override val nodePatterns: Seq[TreePattern] = Seq(AVERAGE)
- private lazy val resultType = child.dataType match {
+ protected lazy val resultType = child.dataType match {
case DecimalType.Fixed(p, s) =>
DecimalType.bounded(p + 4, s + 4)
case _: YearMonthIntervalType => YearMonthIntervalType()
@@ -86,18 +74,18 @@ case class Average(
/* count = */ Literal(0L)
)
- override lazy val mergeExpressions = Seq(
- /* sum = */ sum.left + sum.right,
+ protected def getMergeExpressions = Seq(
+ /* sum = */ Add(sum.left, sum.right, useAnsiAdd),
/* count = */ count.left + count.right
)
// If all input are nulls, count will be 0 and we will get null after the division.
// We can't directly use `/` as it throws an exception under ansi mode.
- override lazy val evaluateExpression = child.dataType match {
+ protected def getEvaluateExpression = child.dataType match {
case _: DecimalType =>
DecimalPrecision.decimalAndDecimal()(
Divide(
- CheckOverflowInSum(sum, sumDataType.asInstanceOf[DecimalType], !failOnError),
+ CheckOverflowInSum(sum, sumDataType.asInstanceOf[DecimalType], !useAnsiAdd),
count.cast(DecimalType.LongDecimal), failOnError = false)).cast(resultType)
case _: YearMonthIntervalType =>
If(EqualTo(count, Literal(0L)),
@@ -109,17 +97,106 @@ case class Average(
Divide(sum.cast(resultType), count.cast(resultType), failOnError = false)
}
- override lazy val updateExpressions: Seq[Expression] = Seq(
+ protected def getUpdateExpressions: Seq[Expression] = Seq(
/* sum = */
Add(
sum,
- coalesce(child.cast(sumDataType), Literal.default(sumDataType))),
+ coalesce(child.cast(sumDataType), Literal.default(sumDataType)),
+ failOnError = useAnsiAdd),
/* count = */ If(child.isNull, count, count + 1L)
)
+ // The flag `useAnsiAdd` won't be shown in the `toString` or `toAggString` methods
+ override def flatArguments: Iterator[Any] = Iterator(child)
+}
+
+@ExpressionDescription(
+ usage = "_FUNC_(expr) - Returns the mean calculated from values of a group.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_(col) FROM VALUES (1), (2), (3) AS tab(col);
+ 2.0
+ > SELECT _FUNC_(col) FROM VALUES (1), (2), (NULL) AS tab(col);
+ 1.5
+ """,
+ group = "agg_funcs",
+ since = "1.0.0")
+case class Average(
+ child: Expression,
+ useAnsiAdd: Boolean = SQLConf.get.ansiEnabled) extends AverageBase {
+ def this(child: Expression) = this(child, useAnsiAdd = SQLConf.get.ansiEnabled)
+
override protected def withNewChildInternal(newChild: Expression): Average =
copy(child = newChild)
- // The flag `failOnError` won't be shown in the `toString` or `toAggString` methods
- override def flatArguments: Iterator[Any] = Iterator(child)
+ override lazy val updateExpressions: Seq[Expression] = getUpdateExpressions
+
+ override lazy val mergeExpressions: Seq[Expression] = getMergeExpressions
+
+ override lazy val evaluateExpression: Expression = getEvaluateExpression
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(expr) - Returns the mean calculated from values of a group and the result is null on overflow.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_(col) FROM VALUES (1), (2), (3) AS tab(col);
+ 2.0
+ > SELECT _FUNC_(col) FROM VALUES (1), (2), (NULL) AS tab(col);
+ 1.5
+ > SELECT _FUNC_(col) FROM VALUES (interval '2147483647 months'), (interval '1 months') AS tab(col);
+ NULL
+ """,
+ group = "agg_funcs",
+ since = "3.3.0")
+// scalastyle:on line.size.limit
+case class TryAverage(child: Expression) extends AverageBase {
+ override def useAnsiAdd: Boolean = resultType match {
+ // Double type won't fail, thus we can always use non-Ansi Add.
+ // For decimal type, it returns NULL on overflow. It behaves the same as TrySum when
+ // `failOnError` is false.
+ case _: DoubleType | _: DecimalType => false
+ case _ => true
+ }
+
+ private def addTryEvalIfNeeded(expression: Expression): Expression = {
+ if (useAnsiAdd) {
+ TryEval(expression)
+ } else {
+ expression
+ }
+ }
+
+ override lazy val updateExpressions: Seq[Expression] = {
+ val expressions = getUpdateExpressions
+ addTryEvalIfNeeded(expressions.head) +: expressions.tail
+ }
+
+ override lazy val mergeExpressions: Seq[Expression] = {
+ val expressions = getMergeExpressions
+ if (useAnsiAdd) {
+ val bufferOverflow = sum.left.isNull && count.left > 0L
+ val inputOverflow = sum.right.isNull && count.right > 0L
+ Seq(
+ If(
+ bufferOverflow || inputOverflow,
+ Literal.create(null, resultType),
+ // If both the buffer and the input do not overflow, just add them, as they can't be
+ // null.
+ TryEval(Add(KnownNotNull(sum.left), KnownNotNull(sum.right), useAnsiAdd))),
+ expressions(1))
+ } else {
+ expressions
+ }
+ }
+
+ override lazy val evaluateExpression: Expression = {
+ addTryEvalIfNeeded(getEvaluateExpression)
+ }
+
+ override protected def withNewChildInternal(newChild: Expression): Expression =
+ copy(child = newChild)
+
+ override def prettyName: String = "try_avg"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index fd27edfc8fc..f2c6925b837 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -30,7 +30,8 @@ abstract class SumBase(child: Expression) extends DeclarativeAggregate
with ImplicitCastInputTypes
with UnaryLike[Expression] {
- def failOnError: Boolean
+ // Whether to use ANSI add or not during the execution.
+ def useAnsiAdd: Boolean
protected def shouldTrackIsEmpty: Boolean
@@ -81,9 +82,9 @@ abstract class SumBase(child: Expression) extends DeclarativeAggregate
// null if overflow happens under non-ansi mode.
val sumExpr = if (child.nullable) {
If(child.isNull, sum,
- Add(sum, KnownNotNull(child).cast(resultType), failOnError = failOnError))
+ Add(sum, KnownNotNull(child).cast(resultType), failOnError = useAnsiAdd))
} else {
- Add(sum, child.cast(resultType), failOnError = failOnError)
+ Add(sum, child.cast(resultType), failOnError = useAnsiAdd)
}
// The buffer becomes non-empty after seeing the first not-null input.
val isEmptyExpr = if (child.nullable) {
@@ -98,10 +99,10 @@ abstract class SumBase(child: Expression) extends DeclarativeAggregate
// in case the input is nullable. The `sum` can only be null if there is no value, as
// non-decimal type can produce overflowed value under non-ansi mode.
if (child.nullable) {
- Seq(coalesce(Add(coalesce(sum, zero), child.cast(resultType), failOnError = failOnError),
+ Seq(coalesce(Add(coalesce(sum, zero), child.cast(resultType), failOnError = useAnsiAdd),
sum))
} else {
- Seq(Add(coalesce(sum, zero), child.cast(resultType), failOnError = failOnError))
+ Seq(Add(coalesce(sum, zero), child.cast(resultType), failOnError = useAnsiAdd))
}
}
@@ -127,11 +128,11 @@ abstract class SumBase(child: Expression) extends DeclarativeAggregate
// If both the buffer and the input do not overflow, just add them, as they can't be
// null. See the comments inside `updateExpressions`: `sum` can only be null if
// overflow happens.
- Add(KnownNotNull(sum.left), KnownNotNull(sum.right), failOnError)),
+ Add(KnownNotNull(sum.left), KnownNotNull(sum.right), useAnsiAdd)),
isEmpty.left && isEmpty.right)
} else {
Seq(coalesce(
- Add(coalesce(sum.left, zero), sum.right, failOnError = failOnError),
+ Add(coalesce(sum.left, zero), sum.right, failOnError = useAnsiAdd),
sum.left))
}
@@ -145,13 +146,13 @@ abstract class SumBase(child: Expression) extends DeclarativeAggregate
protected def getEvaluateExpression: Expression = resultType match {
case d: DecimalType =>
If(isEmpty, Literal.create(null, resultType),
- CheckOverflowInSum(sum, d, !failOnError))
+ CheckOverflowInSum(sum, d, !useAnsiAdd))
case _ if shouldTrackIsEmpty =>
If(isEmpty, Literal.create(null, resultType), sum)
case _ => sum
}
- // The flag `failOnError` won't be shown in the `toString` or `toAggString` methods
+ // The flag `useAnsiAdd` won't be shown in the `toString` or `toAggString` methods
override def flatArguments: Iterator[Any] = Iterator(child)
}
@@ -170,9 +171,9 @@ abstract class SumBase(child: Expression) extends DeclarativeAggregate
since = "1.0.0")
case class Sum(
child: Expression,
- failOnError: Boolean = SQLConf.get.ansiEnabled)
+ useAnsiAdd: Boolean = SQLConf.get.ansiEnabled)
extends SumBase(child) {
- def this(child: Expression) = this(child, failOnError = SQLConf.get.ansiEnabled)
+ def this(child: Expression) = this(child, useAnsiAdd = SQLConf.get.ansiEnabled)
override def shouldTrackIsEmpty: Boolean = resultType match {
case _: DecimalType => true
@@ -207,10 +208,10 @@ case class Sum(
// scalastyle:on line.size.limit
case class TrySum(child: Expression) extends SumBase(child) {
- override def failOnError: Boolean = dataType match {
- // Double type won't fail, thus the failOnError is always false
+ override def useAnsiAdd: Boolean = dataType match {
+ // Double type won't fail, thus useAnsiAdd is always false
// For decimal type, it returns NULL on overflow. It behaves the same as TrySum when
- // `failOnError` is false.
+ // `useAnsiAdd` is false.
case _: DoubleType | _: DecimalType => false
case _ => true
}
@@ -224,7 +225,7 @@ case class TrySum(child: Expression) extends SumBase(child) {
}
override lazy val updateExpressions: Seq[Expression] =
- if (failOnError) {
+ if (useAnsiAdd) {
val expressions = getUpdateExpressions
// If the length of updateExpressions is larger than 1, the tail expressions are for
// tracking whether the input is empty, which doesn't need `TryEval` execution.
@@ -234,14 +235,14 @@ case class TrySum(child: Expression) extends SumBase(child) {
}
override lazy val mergeExpressions: Seq[Expression] =
- if (failOnError) {
+ if (useAnsiAdd) {
getMergeExpressions.map(TryEval)
} else {
getMergeExpressions
}
override lazy val evaluateExpression: Expression =
- if (failOnError) {
+ if (useAnsiAdd) {
TryEval(getEvaluateExpression)
} else {
getEvaluateExpression
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 14902b08549..9f8faf517a4 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -1,6 +1,6 @@
<!-- Automatically generated by ExpressionsSchemaSuite -->
## Summary
- - Number of queries: 387
+ - Number of queries: 388
- Number of expressions that missing example: 12
- Expressions missing examples: bigint,binary,boolean,date,decimal,double,float,int,smallint,string,timestamp,tinyint
## Schema of Built-in Functions
@@ -380,6 +380,7 @@
| org.apache.spark.sql.catalyst.expressions.aggregate.StddevSamp | stddev | SELECT stddev(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<stddev(col):double> |
| org.apache.spark.sql.catalyst.expressions.aggregate.StddevSamp | stddev_samp | SELECT stddev_samp(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<stddev_samp(col):double> |
| org.apache.spark.sql.catalyst.expressions.aggregate.Sum | sum | SELECT sum(col) FROM VALUES (5), (10), (15) AS tab(col) | struct<sum(col):bigint> |
+| org.apache.spark.sql.catalyst.expressions.aggregate.TryAverage | try_avg | SELECT try_avg(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<try_avg(col):double> |
| org.apache.spark.sql.catalyst.expressions.aggregate.TrySum | try_sum | SELECT try_sum(col) FROM VALUES (5), (10), (15) AS tab(col) | struct<try_sum(col):bigint> |
| org.apache.spark.sql.catalyst.expressions.aggregate.VariancePop | var_pop | SELECT var_pop(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<var_pop(col):double> |
| org.apache.spark.sql.catalyst.expressions.aggregate.VarianceSamp | var_samp | SELECT var_samp(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<var_samp(col):double> |
@@ -392,4 +393,4 @@
| org.apache.spark.sql.catalyst.expressions.xml.XPathList | xpath | SELECT xpath('<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>','a/b/text()') | struct<xpath(<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>, a/b/text()):array<string>> |
| org.apache.spark.sql.catalyst.expressions.xml.XPathLong | xpath_long | SELECT xpath_long('<a><b>1</b><b>2</b></a>', 'sum(a/b)') | struct<xpath_long(<a><b>1</b><b>2</b></a>, sum(a/b)):bigint> |
| org.apache.spark.sql.catalyst.expressions.xml.XPathShort | xpath_short | SELECT xpath_short('<a><b>1</b><b>2</b></a>', 'sum(a/b)') | struct<xpath_short(<a><b>1</b><b>2</b></a>, sum(a/b)):smallint> |
-| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('<a><b>b</b><c>cc</c></a>','a/c') | struct<xpath_string(<a><b>b</b><c>cc</c></a>, a/c):string> |
+| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('<a><b>b</b><c>cc</c></a>','a/c') | struct<xpath_string(<a><b>b</b><c>cc</c></a>, a/c):string> |
\ No newline at end of file
diff --git a/sql/core/src/test/resources/sql-tests/inputs/try_aggregates.sql b/sql/core/src/test/resources/sql-tests/inputs/try_aggregates.sql
index ffa8eefe828..cdd2e632319 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/try_aggregates.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/try_aggregates.sql
@@ -11,3 +11,17 @@ SELECT try_sum(col) FROM VALUES (interval '1 months'), (interval '1 months') AS
SELECT try_sum(col) FROM VALUES (interval '2147483647 months'), (interval '1 months') AS tab(col);
SELECT try_sum(col) FROM VALUES (interval '1 seconds'), (interval '1 seconds') AS tab(col);
SELECT try_sum(col) FROM VALUES (interval '106751991 DAYS'), (interval '1 DAYS') AS tab(col);
+
+-- try_avg
+SELECT try_avg(col) FROM VALUES (5), (10), (15) AS tab(col);
+SELECT try_avg(col) FROM VALUES (5.0), (10.0), (15.0) AS tab(col);
+SELECT try_avg(col) FROM VALUES (NULL), (10), (15) AS tab(col);
+SELECT try_avg(col) FROM VALUES (NULL), (NULL) AS tab(col);
+SELECT try_avg(col) FROM VALUES (9223372036854775807L), (1L) AS tab(col);
+-- test overflow in Decimal(38, 0)
+SELECT try_avg(col) FROM VALUES (98765432109876543210987654321098765432BD), (98765432109876543210987654321098765432BD) AS tab(col);
+
+SELECT try_avg(col) FROM VALUES (interval '1 months'), (interval '1 months') AS tab(col);
+SELECT try_avg(col) FROM VALUES (interval '2147483647 months'), (interval '1 months') AS tab(col);
+SELECT try_avg(col) FROM VALUES (interval '1 seconds'), (interval '1 seconds') AS tab(col);
+SELECT try_avg(col) FROM VALUES (interval '106751991 DAYS'), (interval '1 DAYS') AS tab(col);
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/try_aggregates.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/try_aggregates.sql.out
index 7ae217ad758..724553f6bd1 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/try_aggregates.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/try_aggregates.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 10
+-- Number of queries: 20
-- !query
@@ -80,3 +80,83 @@ SELECT try_sum(col) FROM VALUES (interval '106751991 DAYS'), (interval '1 DAYS')
struct<try_sum(col):interval day>
-- !query output
NULL
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (5), (10), (15) AS tab(col)
+-- !query schema
+struct<try_avg(col):double>
+-- !query output
+10.0
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (5.0), (10.0), (15.0) AS tab(col)
+-- !query schema
+struct<try_avg(col):decimal(7,5)>
+-- !query output
+10.00000
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (NULL), (10), (15) AS tab(col)
+-- !query schema
+struct<try_avg(col):double>
+-- !query output
+12.5
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (NULL), (NULL) AS tab(col)
+-- !query schema
+struct<try_avg(col):double>
+-- !query output
+NULL
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (9223372036854775807L), (1L) AS tab(col)
+-- !query schema
+struct<try_avg(col):double>
+-- !query output
+4.6116860184273879E18
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (98765432109876543210987654321098765432BD), (98765432109876543210987654321098765432BD) AS tab(col)
+-- !query schema
+struct<try_avg(col):decimal(38,4)>
+-- !query output
+NULL
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (interval '1 months'), (interval '1 months') AS tab(col)
+-- !query schema
+struct<try_avg(col):interval year to month>
+-- !query output
+0-1
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (interval '2147483647 months'), (interval '1 months') AS tab(col)
+-- !query schema
+struct<try_avg(col):interval year to month>
+-- !query output
+NULL
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (interval '1 seconds'), (interval '1 seconds') AS tab(col)
+-- !query schema
+struct<try_avg(col):interval day to second>
+-- !query output
+0 00:00:01.000000000
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (interval '106751991 DAYS'), (interval '1 DAYS') AS tab(col)
+-- !query schema
+struct<try_avg(col):interval day to second>
+-- !query output
+NULL
diff --git a/sql/core/src/test/resources/sql-tests/results/try_aggregates.sql.out b/sql/core/src/test/resources/sql-tests/results/try_aggregates.sql.out
index 7ae217ad758..724553f6bd1 100644
--- a/sql/core/src/test/resources/sql-tests/results/try_aggregates.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/try_aggregates.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 10
+-- Number of queries: 20
-- !query
@@ -80,3 +80,83 @@ SELECT try_sum(col) FROM VALUES (interval '106751991 DAYS'), (interval '1 DAYS')
struct<try_sum(col):interval day>
-- !query output
NULL
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (5), (10), (15) AS tab(col)
+-- !query schema
+struct<try_avg(col):double>
+-- !query output
+10.0
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (5.0), (10.0), (15.0) AS tab(col)
+-- !query schema
+struct<try_avg(col):decimal(7,5)>
+-- !query output
+10.00000
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (NULL), (10), (15) AS tab(col)
+-- !query schema
+struct<try_avg(col):double>
+-- !query output
+12.5
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (NULL), (NULL) AS tab(col)
+-- !query schema
+struct<try_avg(col):double>
+-- !query output
+NULL
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (9223372036854775807L), (1L) AS tab(col)
+-- !query schema
+struct<try_avg(col):double>
+-- !query output
+4.6116860184273879E18
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (98765432109876543210987654321098765432BD), (98765432109876543210987654321098765432BD) AS tab(col)
+-- !query schema
+struct<try_avg(col):decimal(38,4)>
+-- !query output
+NULL
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (interval '1 months'), (interval '1 months') AS tab(col)
+-- !query schema
+struct<try_avg(col):interval year to month>
+-- !query output
+0-1
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (interval '2147483647 months'), (interval '1 months') AS tab(col)
+-- !query schema
+struct<try_avg(col):interval year to month>
+-- !query output
+NULL
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (interval '1 seconds'), (interval '1 seconds') AS tab(col)
+-- !query schema
+struct<try_avg(col):interval day to second>
+-- !query output
+0 00:00:01.000000000
+
+
+-- !query
+SELECT try_avg(col) FROM VALUES (interval '106751991 DAYS'), (interval '1 DAYS') AS tab(col)
+-- !query schema
+struct<try_avg(col):interval day to second>
+-- !query output
+NULL
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 81067eef401..0b00659f73b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -4331,6 +4331,18 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
checkAnswer(df.repartitionByRange(2, col("v")).selectExpr("try_sum(v)"), Row(null))
}
}
+
+ test("SPARK-38589: try_avg should return null if overflow happens before merging") {
+ val yearMonthDf = Seq(Int.MaxValue, Int.MaxValue, 2)
+ .map(Period.ofMonths)
+ .toDF("v")
+ val dayTimeDf = Seq(106751991L, 106751991L, 2L)
+ .map(Duration.ofDays)
+ .toDF("v")
+ Seq(yearMonthDf, dayTimeDf).foreach { df =>
+ checkAnswer(df.repartitionByRange(2, col("v")).selectExpr("try_avg(v)"), Row(null))
+ }
+ }
}
case class Foo(bar: Option[String])
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org