You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2022/08/12 02:53:09 UTC

[spark] branch master updated: [SPARK-39976][SQL] ArrayIntersect should handle null in left expression correctly

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new dff5c2f2e9c [SPARK-39976][SQL] ArrayIntersect should handle null in left expression correctly
dff5c2f2e9c is described below

commit dff5c2f2e9ce233e270e0e5cde0a40f682ba9534
Author: Angerszhuuuu <an...@gmail.com>
AuthorDate: Fri Aug 12 10:52:33 2022 +0800

    [SPARK-39976][SQL] ArrayIntersect should handle null in left expression correctly
    
    ### What changes were proposed in this pull request?
    `ArrayIntersect` misjudges whether null is contained in the right expression's hash set.
    
    ```
    >>> a = [1, 2, 3]
    >>> b = [3, None, 5]
    >>> df = spark.sparkContext.parallelize([(a, b)]).toDF(["a","b"])
    >>> df.show()
    +---------+------------+
    |        a|           b|
    +---------+------------+
    |[1, 2, 3]|[3, null, 5]|
    +---------+------------+
    
    >>> df.selectExpr("array_intersect(a,b)").show()
    +---------------------+
    |array_intersect(a, b)|
    +---------------------+
    |                  [3]|
    +---------------------+
    
    >>> df.selectExpr("array_intersect(b,a)").show()
    +---------------------+
    |array_intersect(b, a)|
    +---------------------+
    |            [3, null]|
    +---------------------+
    ```
    
    In origin code gen's code path, when handle `ArrayIntersect`'s array1, it use the below code
    ```
            def withArray1NullAssignment(body: String) =
              if (left.dataType.asInstanceOf[ArrayType].containsNull) {
                if (right.dataType.asInstanceOf[ArrayType].containsNull) {
                  s"""
                     |if ($array1.isNullAt($i)) {
                     |  if ($foundNullElement) {
                     |    $nullElementIndex = $size;
                     |    $foundNullElement = false;
                     |    $size++;
                     |    $builder.$$plus$$eq($nullValueHolder);
                     |  }
                     |} else {
                     |  $body
                     |}
                   """.stripMargin
                } else {
                  s"""
                     |if (!$array1.isNullAt($i)) {
                     |  $body
                     |}
                   """.stripMargin
                }
              } else {
                body
              }
    ```
    We have a flag `foundNullElement` to indicate whether array2 really contains a null value. But when implementing https://issues.apache.org/jira/browse/SPARK-36829, the meaning of `ArrayType.containsNull` was misunderstood,
    so when implementing `SQLOpenHashSet.withNullCheckCode()`
    ```
      def withNullCheckCode(
          arrayContainsNull: Boolean,
          setContainsNull: Boolean,
          array: String,
          index: String,
          hashSet: String,
          handleNotNull: (String, String) => String,
          handleNull: String): String = {
        if (arrayContainsNull) {
          if (setContainsNull) {
            s"""
               |if ($array.isNullAt($index)) {
               |  if (!$hashSet.containsNull()) {
               |    $hashSet.addNull();
               |    $handleNull
               |  }
               |} else {
               |  ${handleNotNull(array, index)}
               |}
             """.stripMargin
          } else {
            s"""
               |if (!$array.isNullAt($index)) {
               | ${handleNotNull(array, index)}
               |}
             """.stripMargin
          }
        } else {
          handleNotNull(array, index)
        }
      }
    ```
    The code path of `if (arrayContainsNull && setContainsNull)` was misinterpreted as meaning that the array's OpenHashSet really contains a null value.
    
    In this PR we add a new parameter `additionalCondition` to complement the previous implementation of `foundNullElement`. We also refactor the method's parameter names.
    
    ### Why are the changes needed?
    Fix a data correctness issue
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Added UT
    
    Closes #37436 from AngersZhuuuu/SPARK-39776-FOLLOW_UP.
    
    Lead-authored-by: Angerszhuuuu <an...@gmail.com>
    Co-authored-by: AngersZhuuuu <an...@gmail.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../expressions/collectionOperations.scala         |  8 +++--
 .../org/apache/spark/sql/util/SQLOpenHashSet.scala |  8 ++---
 .../expressions/CollectionExpressionsSuite.scala   | 34 ++++++++++++++++++++++
 3 files changed, 43 insertions(+), 7 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index ae23775b62d..d6a9601f884 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -4155,9 +4155,11 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina
           right.dataType.asInstanceOf[ArrayType].containsNull,
           array1, i, hashSetResult, withArray1NaNCheckCodeGenerator,
           s"""
-             |$nullElementIndex = $size;
-             |$size++;
-             |$builder.$$plus$$eq($nullValueHolder);
+             |if ($hashSet.containsNull()) {
+             |  $nullElementIndex = $size;
+             |  $size++;
+             |  $builder.$$plus$$eq($nullValueHolder);
+             |}
            """.stripMargin)
 
         // Only need to track null element index when result array's element is nullable.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SQLOpenHashSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SQLOpenHashSet.scala
index 5f0366941de..ee4dd54f28e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SQLOpenHashSet.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SQLOpenHashSet.scala
@@ -79,15 +79,15 @@ object SQLOpenHashSet {
   }
 
   def withNullCheckCode(
-      arrayContainsNull: Boolean,
-      setContainsNull: Boolean,
+      array1ElementNullable: Boolean,
+      array2ElementNullable: Boolean,
       array: String,
       index: String,
       hashSet: String,
       handleNotNull: (String, String) => String,
       handleNull: String): String = {
-    if (arrayContainsNull) {
-      if (setContainsNull) {
+    if (array1ElementNullable) {
+      if (array2ElementNullable) {
         s"""
            |if ($array.isNullAt($index)) {
            |  if (!$hashSet.containsNull()) {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index c14a0839b1a..1e466469973 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -2163,6 +2163,23 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(ArrayExcept(empty, oneNull), Seq.empty)
     checkEvaluation(ArrayExcept(oneNull, empty), Seq(null))
     checkEvaluation(ArrayExcept(twoNulls, empty), Seq(null))
+
+    checkEvaluation(ArrayExcept(
+      Literal.create(Seq(1d, 2d, null), ArrayType(DoubleType)),
+      Literal.create(Seq(1d), ArrayType(DoubleType))),
+      Seq(2d, null))
+    checkEvaluation(ArrayExcept(
+      Literal.create(Seq(1d, 2d, null), ArrayType(DoubleType)),
+      Literal.create(Seq(1d), ArrayType(DoubleType, false))),
+      Seq(2d, null))
+    checkEvaluation(ArrayExcept(
+      Literal.create(Seq(1d, 2d), ArrayType(DoubleType)),
+      Literal.create(Seq(1d, null), ArrayType(DoubleType))),
+      Seq(2d))
+    checkEvaluation(ArrayExcept(
+      Literal.create(Seq(1d, 2d), ArrayType(DoubleType, false)),
+      Literal.create(Seq(1d, null), ArrayType(DoubleType))),
+      Seq(2d))
   }
 
   test("Array Intersect") {
@@ -2288,6 +2305,23 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(ArrayIntersect(oneNull, twoNulls), Seq(null))
     checkEvaluation(ArrayIntersect(empty, oneNull), Seq.empty)
     checkEvaluation(ArrayIntersect(oneNull, empty), Seq.empty)
+
+    checkEvaluation(ArrayIntersect(
+      Literal.create(Seq(1d, 2d, null), ArrayType(DoubleType)),
+      Literal.create(Seq(1d), ArrayType(DoubleType))),
+      Seq(1d))
+    checkEvaluation(ArrayIntersect(
+      Literal.create(Seq(1d, 2d, null), ArrayType(DoubleType)),
+      Literal.create(Seq(1d), ArrayType(DoubleType, false))),
+      Seq(1d))
+    checkEvaluation(ArrayIntersect(
+      Literal.create(Seq(1d, 2d), ArrayType(DoubleType)),
+      Literal.create(Seq(1d, null), ArrayType(DoubleType))),
+      Seq(1d))
+    checkEvaluation(ArrayIntersect(
+      Literal.create(Seq(1d, 2d), ArrayType(DoubleType, false)),
+      Literal.create(Seq(1d, null), ArrayType(DoubleType))),
+      Seq(1d))
   }
 
   test("SPARK-31980: Start and end equal in month range") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org