You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2022/08/15 13:29:45 UTC

[spark] branch master updated: [SPARK-40019][SQL] Refactor comment of ArrayType's containsNull and refactor the misunderstanding logics in collectionOperator's expression about `containsNull`

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f66bcd4fba8 [SPARK-40019][SQL] Refactor comment of ArrayType's containsNull and refactor the misunderstanding logics in collectionOperator's expression about `containsNull`
f66bcd4fba8 is described below

commit f66bcd4fba8d0947fd3c7a9c2f9621e78c1fbc0f
Author: Angerszhuuuu <an...@gmail.com>
AuthorDate: Mon Aug 15 21:29:20 2022 +0800

    [SPARK-40019][SQL] Refactor comment of ArrayType's containsNull and refactor the misunderstanding logics in collectionOperator's expression about `containsNull`
    
    ### What changes were proposed in this pull request?
    ArrayType's parameter `containsNull` means that the array can contain null elements; it is related to nullability, and this is easy to misunderstand when reading the logic. In this PR, we refactor the comment about `containsNull` and refactor the code in ArrayType-related expressions so that each code path has a clear meaning.
    
    ### Why are the changes needed?
    Refactor code
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Not needed.
    
    Closes #37453 from AngersZhuuuu/SPARK-40019.
    
    Lead-authored-by: Angerszhuuuu <an...@gmail.com>
    Co-authored-by: AngersZhuuuu <an...@gmail.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../expressions/collectionOperations.scala         | 79 ++++++++++++----------
 .../expressions/complexTypeExtractors.scala        |  7 +-
 .../org/apache/spark/sql/types/ArrayType.scala     |  7 +-
 3 files changed, 52 insertions(+), 41 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index d6a9601f884..f40f5a98232 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -71,6 +71,9 @@ trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression
         s"[${left.dataType.catalogString}, ${right.dataType.catalogString}]")
     }
   }
+
+  protected def leftArrayElementNullable = left.dataType.asInstanceOf[ArrayType].containsNull
+  protected def rightArrayElementNullable = right.dataType.asInstanceOf[ArrayType].containsNull
 }
 
 
@@ -895,7 +898,8 @@ trait ArraySortLike extends ExpectsInputTypes {
   @transient lazy val elementType: DataType =
     arrayExpression.dataType.asInstanceOf[ArrayType].elementType
 
-  def containsNull: Boolean = arrayExpression.dataType.asInstanceOf[ArrayType].containsNull
+  private def resultArrayElementNullable: Boolean =
+    arrayExpression.dataType.asInstanceOf[ArrayType].containsNull
 
   def sortEval(array: Any, ascending: Boolean): Any = {
     val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType)
@@ -930,8 +934,8 @@ trait ArraySortLike extends ExpectsInputTypes {
       } else {
         s"int $c = ${ctx.genComp(elementType, s"(($jt) $o1)", s"(($jt) $o2)")};"
       }
-      val canPerformFastSort =
-        CodeGenerator.isPrimitiveType(elementType) && elementType != BooleanType && !containsNull
+      val canPerformFastSort = CodeGenerator.isPrimitiveType(elementType) &&
+        elementType != BooleanType && !resultArrayElementNullable
       val nonNullPrimitiveAscendingSort = if (canPerformFastSort) {
           val javaType = CodeGenerator.javaType(elementType)
           val primitiveTypeName = CodeGenerator.primitiveTypeName(elementType)
@@ -1079,6 +1083,8 @@ case class Shuffle(child: Expression, randomSeed: Option[Long] = None)
 
   override def dataType: DataType = child.dataType
 
+  private def resultArrayElementNullable = dataType.asInstanceOf[ArrayType].containsNull
+
   @transient lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
 
   @transient private[this] var random: RandomIndicesGenerator = _
@@ -1118,7 +1124,7 @@ case class Shuffle(child: Expression, randomSeed: Option[Long] = None)
     val initialization = CodeGenerator.createArrayData(
       arrayData, elementType, numElements, s" $prettyName failed.")
     val assignment = CodeGenerator.createArrayAssignment(arrayData, elementType, childName,
-      i, s"$indices[$i]", dataType.asInstanceOf[ArrayType].containsNull)
+      i, s"$indices[$i]", resultArrayElementNullable)
 
     s"""
        |int $numElements = $childName.numElements();
@@ -1162,6 +1168,8 @@ case class Reverse(child: Expression)
 
   override def dataType: DataType = child.dataType
 
+  private def resultArrayElementNullable = dataType.asInstanceOf[ArrayType].containsNull
+
   override def nullSafeEval(input: Any): Any = doReverse(input)
 
   @transient private lazy val doReverse: Any => Any = dataType match {
@@ -1196,7 +1204,7 @@ case class Reverse(child: Expression)
     val initialization = CodeGenerator.createArrayData(
       arrayData, elementType, numElements, s" $prettyName failed.")
     val assignment = CodeGenerator.createArrayAssignment(
-      arrayData, elementType, childName, i, j, dataType.asInstanceOf[ArrayType].containsNull)
+      arrayData, elementType, childName, i, j, resultArrayElementNullable)
 
     s"""
        |final int $numElements = $childName.numElements();
@@ -1347,8 +1355,7 @@ case class ArraysOverlap(left: Expression, right: Expression)
   }
 
   override def nullable: Boolean = {
-    left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull ||
-      right.dataType.asInstanceOf[ArrayType].containsNull
+    left.nullable || right.nullable || leftArrayElementNullable || rightArrayElementNullable
   }
 
   override def nullSafeEval(a1: Any, a2: Any): Any = {
@@ -1560,6 +1567,8 @@ case class Slice(x: Expression, start: Expression, length: Expression)
 
   override def dataType: DataType = x.dataType
 
+  private def resultArrayElementNullable = dataType.asInstanceOf[ArrayType].containsNull
+
   override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType)
 
   override def first: Expression = x
@@ -1632,7 +1641,7 @@ case class Slice(x: Expression, start: Expression, length: Expression)
     val allocation = CodeGenerator.createArrayData(
       values, elementType, resLength, s" $prettyName failed.")
     val assignment = CodeGenerator.createArrayAssignment(values, elementType, inputArray,
-      i, s"$i + $startIdx", dataType.asInstanceOf[ArrayType].containsNull)
+      i, s"$i + $startIdx", resultArrayElementNullable)
 
     s"""
        |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) {
@@ -2103,7 +2112,8 @@ case class ElementAt(
   @transient private lazy val mapValueContainsNull =
     left.dataType.asInstanceOf[MapType].valueContainsNull
 
-  @transient private lazy val arrayContainsNull = left.dataType.asInstanceOf[ArrayType].containsNull
+  @transient private lazy val arrayElementNullable =
+    left.dataType.asInstanceOf[ArrayType].containsNull
 
   @transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(mapKeyType)
 
@@ -2189,7 +2199,7 @@ case class ElementAt(
           } else {
             array.numElements() + index
           }
-          if (arrayContainsNull && array.isNullAt(idx)) {
+          if (arrayElementNullable && array.isNullAt(idx)) {
             null
           } else {
             array.get(idx, dataType)
@@ -2205,7 +2215,7 @@ case class ElementAt(
       case _: ArrayType =>
         nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
           val index = ctx.freshName("elementAtIndex")
-          val nullCheck = if (arrayContainsNull) {
+          val nullCheck = if (arrayElementNullable) {
             s"""
                |if ($eval1.isNullAt($index)) {
                |  ${ev.isNull} = true;
@@ -2353,6 +2363,8 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
     }
   }
 
+  private def resultArrayElementNullable = dataType.asInstanceOf[ArrayType].containsNull
+
   private def javaType: String = CodeGenerator.javaType(dataType)
 
   override def nullable: Boolean = children.exists(_.nullable)
@@ -2484,8 +2496,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
     val initialization = CodeGenerator.createArrayData(
       arrayData, elementType, numElemName, s" $prettyName failed.")
     val assignment = CodeGenerator.createArrayAssignment(
-      arrayData, elementType, s"args[$y]", counter, z,
-      dataType.asInstanceOf[ArrayType].containsNull)
+      arrayData, elementType, s"args[$y]", counter, z, resultArrayElementNullable)
 
     val concat = ctx.freshName("concat")
     val concatDef =
@@ -2535,6 +2546,8 @@ case class Flatten(child: Expression) extends UnaryExpression with NullIntoleran
 
   @transient override lazy val dataType: DataType = childDataType.elementType
 
+  private def resultArrayElementNullable = dataType.asInstanceOf[ArrayType].containsNull
+
   @transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
 
   override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
@@ -2604,8 +2617,7 @@ case class Flatten(child: Expression) extends UnaryExpression with NullIntoleran
     val allocation = CodeGenerator.createArrayData(
       tempArrayDataName, elementType, numElemName, s" $prettyName failed.")
     val assignment = CodeGenerator.createArrayAssignment(
-      tempArrayDataName, elementType, arr, counter, l,
-      dataType.asInstanceOf[ArrayType].containsNull)
+      tempArrayDataName, elementType, arr, counter, l, resultArrayElementNullable)
 
     s"""
     |$numElemCode
@@ -3486,6 +3498,8 @@ trait ArraySetLike {
   @transient protected lazy val ordering: Ordering[Any] =
     TypeUtils.getInterpretedOrdering(et)
 
+  protected def resultArrayElementNullable = dt.asInstanceOf[ArrayType].containsNull
+
   protected def genGetValue(array: String, i: String): String =
     CodeGenerator.getValue(array, et, i)
 
@@ -3521,7 +3535,7 @@ trait ArraySetLike {
       body: String,
       value: String,
       nullElementIndex: String): String = {
-    if (dt.asInstanceOf[ArrayType].containsNull) {
+    if (resultArrayElementNullable) {
       s"""
          |$body
          |if ($nullElementIndex >= 0) {
@@ -3662,7 +3676,7 @@ case class ArrayDistinct(child: Expression)
         val arrayBuilderClass = s"$arrayBuilder$$of$ptName"
 
         // Only need to track null element index when array's element is nullable.
-        val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) {
+        val declareNullTrackVariables = if (resultArrayElementNullable) {
           s"""
              |int $nullElementIndex = -1;
            """.stripMargin
@@ -3692,8 +3706,8 @@ case class ArrayDistinct(child: Expression)
                      """.stripMargin)
 
         val processArray = SQLOpenHashSet.withNullCheckCode(
-          dataType.asInstanceOf[ArrayType].containsNull,
-          dataType.asInstanceOf[ArrayType].containsNull,
+          resultArrayElementNullable,
+          resultArrayElementNullable,
           array, i, hashSet, withNaNCheckCodeGenerator,
           s"""
              |$nullElementIndex = $size;
@@ -3880,8 +3894,8 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
                      """.stripMargin)
 
         val processArray = SQLOpenHashSet.withNullCheckCode(
-          dataType.asInstanceOf[ArrayType].containsNull,
-          dataType.asInstanceOf[ArrayType].containsNull,
+          resultArrayElementNullable,
+          resultArrayElementNullable,
           array, i, hashSet, withNaNCheckCodeGenerator,
           s"""
              |$nullElementIndex = $size;
@@ -3890,7 +3904,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
            """.stripMargin)
 
         // Only need to track null element index when result array's element is nullable.
-        val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) {
+        val declareNullTrackVariables = if (resultArrayElementNullable) {
           s"""
              |int $nullElementIndex = -1;
            """.stripMargin
@@ -3985,9 +3999,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina
 
   private lazy val internalDataType: DataType = {
     dataTypeCheck
-    ArrayType(elementType,
-      left.dataType.asInstanceOf[ArrayType].containsNull &&
-        right.dataType.asInstanceOf[ArrayType].containsNull)
+    ArrayType(elementType, leftArrayElementNullable && rightArrayElementNullable)
   }
 
   override def dataType: DataType = internalDataType
@@ -4122,8 +4134,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina
                 (valueNaN: String) => "")
 
         val writeArray2ToHashSet = SQLOpenHashSet.withNullCheckCode(
-          right.dataType.asInstanceOf[ArrayType].containsNull,
-          left.dataType.asInstanceOf[ArrayType].containsNull,
+          rightArrayElementNullable, leftArrayElementNullable,
           array2, i, hashSet, withArray2NaNCheckCodeGenerator, "")
 
         val body =
@@ -4151,8 +4162,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina
                  """.stripMargin)
 
         val processArray1 = SQLOpenHashSet.withNullCheckCode(
-          left.dataType.asInstanceOf[ArrayType].containsNull,
-          right.dataType.asInstanceOf[ArrayType].containsNull,
+          leftArrayElementNullable, rightArrayElementNullable,
           array1, i, hashSetResult, withArray1NaNCheckCodeGenerator,
           s"""
              |if ($hashSet.containsNull()) {
@@ -4163,7 +4173,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina
            """.stripMargin)
 
         // Only need to track null element index when result array's element is nullable.
-        val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) {
+        val declareNullTrackVariables = if (resultArrayElementNullable) {
           s"""
              |int $nullElementIndex = -1;
            """.stripMargin
@@ -4340,8 +4350,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
                 (valueNaN: Any) => "")
 
         val writeArray2ToHashSet = SQLOpenHashSet.withNullCheckCode(
-          right.dataType.asInstanceOf[ArrayType].containsNull,
-          left.dataType.asInstanceOf[ArrayType].containsNull,
+          rightArrayElementNullable, leftArrayElementNullable,
           array2, i, hashSet, withArray2NaNCheckCodeGenerator, "")
 
         val body =
@@ -4366,8 +4375,8 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
                  """.stripMargin)
 
         val processArray1 = SQLOpenHashSet.withNullCheckCode(
-          left.dataType.asInstanceOf[ArrayType].containsNull,
-          left.dataType.asInstanceOf[ArrayType].containsNull,
+          leftArrayElementNullable,
+          leftArrayElementNullable,
           array1, i, hashSet, withArray1NaNCheckCodeGenerator,
           s"""
              |$nullElementIndex = $size;
@@ -4376,7 +4385,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
            """.stripMargin)
 
         // Only need to track null element index when array1's element is nullable.
-        val declareNullTrackVariables = if (left.dataType.asInstanceOf[ArrayType].containsNull) {
+        val declareNullTrackVariables = if (leftArrayElementNullable) {
           s"""
              |int $nullElementIndex = -1;
            """.stripMargin
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index b6cbb1d0005..7b99b9d1082 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -282,7 +282,8 @@ case class GetArrayItem(
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
       val index = ctx.freshName("index")
-      val nullCheck = if (child.dataType.asInstanceOf[ArrayType].containsNull) {
+      val childArrayElementNullable = child.dataType.asInstanceOf[ArrayType].containsNull
+      val nullCheck = if (childArrayElementNullable) {
         s"""else if ($eval1.isNullAt($index)) {
                ${ev.isNull} = true;
             }
@@ -333,7 +334,7 @@ trait GetArrayItemUtil {
       ordinal: Expression,
       failOnError: Boolean,
       nullability: (Seq[Expression], Int) => Boolean): Boolean = {
-    val arrayContainsNull = child.dataType.asInstanceOf[ArrayType].containsNull
+    val arrayElementNullable = child.dataType.asInstanceOf[ArrayType].containsNull
     if (ordinal.foldable && !ordinal.nullable) {
       val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue()
       child match {
@@ -345,7 +346,7 @@ trait GetArrayItemUtil {
           true
       }
     } else {
-      if (failOnError) arrayContainsNull else true
+      if (failOnError) arrayElementNullable else true
     }
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
index b5708bae923..e139823b2bd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
@@ -53,11 +53,12 @@ object ArrayType extends AbstractDataType {
  * Please use `DataTypes.createArrayType()` to create a specific instance.
  *
  * An [[ArrayType]] object comprises two fields, `elementType: [[DataType]]` and
- * `containsNull: Boolean`. The field of `elementType` is used to specify the type of
- * array elements. The field of `containsNull` is used to specify if the array has `null` values.
+ * `containsNull: Boolean`.
+ * The field of `elementType` is used to specify the type of array elements.
+ * The field of `containsNull` is used to specify if the array can have `null` values.
  *
  * @param elementType The data type of values.
- * @param containsNull Indicates if values have `null` values
+ * @param containsNull Indicates if the array can have `null` values
  *
  * @since 1.3.0
  */


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org