You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2022/06/10 08:57:39 UTC

[spark] branch master updated: [SPARK-39419][SQL] Fix ArraySort to throw an exception when the comparator returns null

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 65261721a52 [SPARK-39419][SQL] Fix ArraySort to throw an exception when the comparator returns null
65261721a52 is described below

commit 65261721a5205e2ee0e370cc451eb6bcc2d015d1
Author: Takuya UESHIN <ue...@databricks.com>
AuthorDate: Fri Jun 10 16:57:23 2022 +0800

    [SPARK-39419][SQL] Fix ArraySort to throw an exception when the comparator returns null
    
    ### What changes were proposed in this pull request?
    
    Fixes `ArraySort` to throw an exception when the comparator returns `null`.
    
    Also updates the doc to reflect the corrected behavior.
    
    ### Why are the changes needed?
    
    Currently, when the comparator of `ArraySort` returns `null`, the result is treated as `0` (equal).
    
    However, according to the doc:
    
    ```
    It returns -1, 0, or 1 as the first element is less than, equal to, or greater than
    the second element. If the comparator function returns other
    values (including null), the function will fail and raise an error.
    ```
    
    Returning integers other than -1, 0, or 1 is fine because it follows the Java `Comparator` convention (the doc still needs to be updated to say so), but a `null` result should throw an exception.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. After this PR, `array_sort` throws an error if the user's comparator returns `null`.
    
    Setting the legacy flag `spark.sql.legacy.allowNullComparisonResultInArraySort` to `true` restores the legacy behavior, which treats `null` as `0` (equal).
    
    ### How was this patch tested?
    
    Added a test to `HigherOrderFunctionsSuite` covering both settings of the legacy flag.
    
    Closes #36812 from ueshin/issues/SPARK-39419/array_sort.
    
    Authored-by: Takuya UESHIN <ue...@databricks.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 core/src/main/resources/error/error-classes.json   |  5 +++++
 .../expressions/higherOrderFunctions.scala         | 26 +++++++++++++++++-----
 .../spark/sql/errors/QueryExecutionErrors.scala    |  5 +++++
 .../org/apache/spark/sql/internal/SQLConf.scala    | 10 +++++++++
 .../expressions/HigherOrderFunctionsSuite.scala    | 22 +++++++++++++++++-
 5 files changed, 62 insertions(+), 6 deletions(-)

diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json
index 833ecc0a3c0..275566589f5 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -295,6 +295,11 @@
       "UDF class <className> doesn't implement any UDF interface"
     ]
   },
+  "NULL_COMPARISON_RESULT" : {
+    "message" : [
+      "The comparison result is null. If you want to handle null as 0 (equal), you can set \"spark.sql.legacy.allowNullComparisonResultInArraySort\" to \"true\"."
+    ]
+  },
   "PARSE_CHAR_MISSING_LENGTH" : {
     "message" : [
       "DataType <type> requires a length parameter, for example <type>(10). Please specify the length."
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index 79b76f799d9..135a423b38a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -357,9 +357,9 @@ case class ArrayTransform(
     Since 3.0.0 this function also sorts and returns the array based on the
     given comparator function. The comparator will take two arguments representing
     two elements of the array.
-    It returns -1, 0, or 1 as the first element is less than, equal to, or greater
-    than the second element. If the comparator function returns other
-    values (including null), the function will fail and raise an error.
+    It returns a negative integer, 0, or a positive integer as the first element is less than,
+    equal to, or greater than the second element. If the comparator function returns null,
+    the function will fail and raise an error.
     """,
   examples = """
     Examples:
@@ -375,9 +375,17 @@ case class ArrayTransform(
 // scalastyle:on line.size.limit
 case class ArraySort(
     argument: Expression,
-    function: Expression)
+    function: Expression,
+    allowNullComparisonResult: Boolean)
   extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {
 
+  def this(argument: Expression, function: Expression) = {
+    this(
+      argument,
+      function,
+      SQLConf.get.getConf(SQLConf.LEGACY_ALLOW_NULL_COMPARISON_RESULT_IN_ARRAY_SORT))
+  }
+
   def this(argument: Expression) = this(argument, ArraySort.defaultComparator)
 
   @transient lazy val elementType: DataType =
@@ -416,7 +424,11 @@ case class ArraySort(
     (o1: Any, o2: Any) => {
       firstElemVar.value.set(o1)
       secondElemVar.value.set(o2)
-      f.eval(inputRow).asInstanceOf[Int]
+      val cmp = f.eval(inputRow)
+      if (!allowNullComparisonResult && cmp == null) {
+        throw QueryExecutionErrors.nullComparisonResultError()
+      }
+      cmp.asInstanceOf[Int]
     }
   }
 
@@ -437,6 +449,10 @@ case class ArraySort(
 
 object ArraySort {
 
+  def apply(argument: Expression, function: Expression): ArraySort = {
+    new ArraySort(argument, function)
+  }
+
   def comparator(left: Expression, right: Expression): Expression = {
     val lit0 = Literal(0)
     val lit1 = Literal(1)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index cd258e3649a..2b573b2385c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -2021,4 +2021,9 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
     new SparkException(
       errorClass = "MULTI_VALUE_SUBQUERY_ERROR", messageParameters = Array(plan), cause = null)
   }
+
+  def nullComparisonResultError(): Throwable = {
+    new SparkException(errorClass = "NULL_COMPARISON_RESULT",
+      messageParameters = Array(), cause = null)
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 8c7702efd47..5e1f3956159 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3818,6 +3818,16 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val LEGACY_ALLOW_NULL_COMPARISON_RESULT_IN_ARRAY_SORT =
+    buildConf("spark.sql.legacy.allowNullComparisonResultInArraySort")
+      .internal()
+      .doc("When set to false, `array_sort` function throws an error " +
+        "if the comparator function returns null. " +
+        "If set to true, it restores the legacy behavior that handles null as zero (equal).")
+      .version("3.2.2")
+      .booleanConf
+      .createWithDefault(false)
+
   /**
    * Holds information about keys that have been deprecated.
    *
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
index c0db6d8dc29..b1c4c441427 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -838,4 +838,24 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
       Literal.create(Seq(Double.NaN, 1d, 2d, null), ArrayType(DoubleType))),
       Seq(1d, 2d, Double.NaN, null))
   }
+
+  test("SPARK-39419: ArraySort should throw an exception when the comparator returns null") {
+    val comparator = {
+      val comp = ArraySort.comparator _
+      (left: Expression, right: Expression) =>
+        If(comp(left, right) === 0, Literal.create(null, IntegerType), comp(left, right))
+    }
+
+    withSQLConf(
+        SQLConf.LEGACY_ALLOW_NULL_COMPARISON_RESULT_IN_ARRAY_SORT.key -> "false") {
+      checkExceptionInExpression[SparkException](
+        arraySort(Literal.create(Seq(3, 1, 1, 2)), comparator), "The comparison result is null")
+    }
+
+    withSQLConf(
+        SQLConf.LEGACY_ALLOW_NULL_COMPARISON_RESULT_IN_ARRAY_SORT.key -> "true") {
+      checkEvaluation(arraySort(Literal.create(Seq(3, 1, 1, 2)), comparator),
+        Seq(1, 1, 2, 3))
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org