You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2022/08/15 13:29:45 UTC
[spark] branch master updated: [SPARK-40019][SQL] Refactor comment of ArrayType's containsNull and refactor the misunderstanding logics in collectionOperator's expression about `containsNull`
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new f66bcd4fba8 [SPARK-40019][SQL] Refactor comment of ArrayType's containsNull and refactor the misunderstanding logics in collectionOperator's expression about `containsNull`
f66bcd4fba8 is described below
commit f66bcd4fba8d0947fd3c7a9c2f9621e78c1fbc0f
Author: Angerszhuuuu <an...@gmail.com>
AuthorDate: Mon Aug 15 21:29:20 2022 +0800
[SPARK-40019][SQL] Refactor comment of ArrayType's containsNull and refactor the misunderstanding logics in collectionOperator's expression about `containsNull`
### What changes were proposed in this pull request?
ArrayType's parameter `containsNull` means that the array can contain `null` elements, i.e. it describes the nullability of the elements rather than whether a given array actually holds a null. This is easy to misunderstand when reading the logic. In this PR, we refactor the comment about `containsNull` and refactor the code in ArrayType-related expressions so that each code path states clearly which array's element nullability it depends on.
### Why are the changes needed?
Refactor code
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Not needed; this is a pure refactoring with no behavior change.
Closes #37453 from AngersZhuuuu/SPARK-40019.
Lead-authored-by: Angerszhuuuu <an...@gmail.com>
Co-authored-by: AngersZhuuuu <an...@gmail.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../expressions/collectionOperations.scala | 79 ++++++++++++----------
.../expressions/complexTypeExtractors.scala | 7 +-
.../org/apache/spark/sql/types/ArrayType.scala | 7 +-
3 files changed, 52 insertions(+), 41 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index d6a9601f884..f40f5a98232 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -71,6 +71,9 @@ trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression
s"[${left.dataType.catalogString}, ${right.dataType.catalogString}]")
}
}
+
+ protected def leftArrayElementNullable = left.dataType.asInstanceOf[ArrayType].containsNull
+ protected def rightArrayElementNullable = right.dataType.asInstanceOf[ArrayType].containsNull
}
@@ -895,7 +898,8 @@ trait ArraySortLike extends ExpectsInputTypes {
@transient lazy val elementType: DataType =
arrayExpression.dataType.asInstanceOf[ArrayType].elementType
- def containsNull: Boolean = arrayExpression.dataType.asInstanceOf[ArrayType].containsNull
+ private def resultArrayElementNullable: Boolean =
+ arrayExpression.dataType.asInstanceOf[ArrayType].containsNull
def sortEval(array: Any, ascending: Boolean): Any = {
val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType)
@@ -930,8 +934,8 @@ trait ArraySortLike extends ExpectsInputTypes {
} else {
s"int $c = ${ctx.genComp(elementType, s"(($jt) $o1)", s"(($jt) $o2)")};"
}
- val canPerformFastSort =
- CodeGenerator.isPrimitiveType(elementType) && elementType != BooleanType && !containsNull
+ val canPerformFastSort = CodeGenerator.isPrimitiveType(elementType) &&
+ elementType != BooleanType && !resultArrayElementNullable
val nonNullPrimitiveAscendingSort = if (canPerformFastSort) {
val javaType = CodeGenerator.javaType(elementType)
val primitiveTypeName = CodeGenerator.primitiveTypeName(elementType)
@@ -1079,6 +1083,8 @@ case class Shuffle(child: Expression, randomSeed: Option[Long] = None)
override def dataType: DataType = child.dataType
+ private def resultArrayElementNullable = dataType.asInstanceOf[ArrayType].containsNull
+
@transient lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
@transient private[this] var random: RandomIndicesGenerator = _
@@ -1118,7 +1124,7 @@ case class Shuffle(child: Expression, randomSeed: Option[Long] = None)
val initialization = CodeGenerator.createArrayData(
arrayData, elementType, numElements, s" $prettyName failed.")
val assignment = CodeGenerator.createArrayAssignment(arrayData, elementType, childName,
- i, s"$indices[$i]", dataType.asInstanceOf[ArrayType].containsNull)
+ i, s"$indices[$i]", resultArrayElementNullable)
s"""
|int $numElements = $childName.numElements();
@@ -1162,6 +1168,8 @@ case class Reverse(child: Expression)
override def dataType: DataType = child.dataType
+ private def resultArrayElementNullable = dataType.asInstanceOf[ArrayType].containsNull
+
override def nullSafeEval(input: Any): Any = doReverse(input)
@transient private lazy val doReverse: Any => Any = dataType match {
@@ -1196,7 +1204,7 @@ case class Reverse(child: Expression)
val initialization = CodeGenerator.createArrayData(
arrayData, elementType, numElements, s" $prettyName failed.")
val assignment = CodeGenerator.createArrayAssignment(
- arrayData, elementType, childName, i, j, dataType.asInstanceOf[ArrayType].containsNull)
+ arrayData, elementType, childName, i, j, resultArrayElementNullable)
s"""
|final int $numElements = $childName.numElements();
@@ -1347,8 +1355,7 @@ case class ArraysOverlap(left: Expression, right: Expression)
}
override def nullable: Boolean = {
- left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull ||
- right.dataType.asInstanceOf[ArrayType].containsNull
+ left.nullable || right.nullable || leftArrayElementNullable || rightArrayElementNullable
}
override def nullSafeEval(a1: Any, a2: Any): Any = {
@@ -1560,6 +1567,8 @@ case class Slice(x: Expression, start: Expression, length: Expression)
override def dataType: DataType = x.dataType
+ private def resultArrayElementNullable = dataType.asInstanceOf[ArrayType].containsNull
+
override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType)
override def first: Expression = x
@@ -1632,7 +1641,7 @@ case class Slice(x: Expression, start: Expression, length: Expression)
val allocation = CodeGenerator.createArrayData(
values, elementType, resLength, s" $prettyName failed.")
val assignment = CodeGenerator.createArrayAssignment(values, elementType, inputArray,
- i, s"$i + $startIdx", dataType.asInstanceOf[ArrayType].containsNull)
+ i, s"$i + $startIdx", resultArrayElementNullable)
s"""
|if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) {
@@ -2103,7 +2112,8 @@ case class ElementAt(
@transient private lazy val mapValueContainsNull =
left.dataType.asInstanceOf[MapType].valueContainsNull
- @transient private lazy val arrayContainsNull = left.dataType.asInstanceOf[ArrayType].containsNull
+ @transient private lazy val arrayElementNullable =
+ left.dataType.asInstanceOf[ArrayType].containsNull
@transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(mapKeyType)
@@ -2189,7 +2199,7 @@ case class ElementAt(
} else {
array.numElements() + index
}
- if (arrayContainsNull && array.isNullAt(idx)) {
+ if (arrayElementNullable && array.isNullAt(idx)) {
null
} else {
array.get(idx, dataType)
@@ -2205,7 +2215,7 @@ case class ElementAt(
case _: ArrayType =>
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
val index = ctx.freshName("elementAtIndex")
- val nullCheck = if (arrayContainsNull) {
+ val nullCheck = if (arrayElementNullable) {
s"""
|if ($eval1.isNullAt($index)) {
| ${ev.isNull} = true;
@@ -2353,6 +2363,8 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
}
}
+ private def resultArrayElementNullable = dataType.asInstanceOf[ArrayType].containsNull
+
private def javaType: String = CodeGenerator.javaType(dataType)
override def nullable: Boolean = children.exists(_.nullable)
@@ -2484,8 +2496,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
val initialization = CodeGenerator.createArrayData(
arrayData, elementType, numElemName, s" $prettyName failed.")
val assignment = CodeGenerator.createArrayAssignment(
- arrayData, elementType, s"args[$y]", counter, z,
- dataType.asInstanceOf[ArrayType].containsNull)
+ arrayData, elementType, s"args[$y]", counter, z, resultArrayElementNullable)
val concat = ctx.freshName("concat")
val concatDef =
@@ -2535,6 +2546,8 @@ case class Flatten(child: Expression) extends UnaryExpression with NullIntoleran
@transient override lazy val dataType: DataType = childDataType.elementType
+ private def resultArrayElementNullable = dataType.asInstanceOf[ArrayType].containsNull
+
@transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
@@ -2604,8 +2617,7 @@ case class Flatten(child: Expression) extends UnaryExpression with NullIntoleran
val allocation = CodeGenerator.createArrayData(
tempArrayDataName, elementType, numElemName, s" $prettyName failed.")
val assignment = CodeGenerator.createArrayAssignment(
- tempArrayDataName, elementType, arr, counter, l,
- dataType.asInstanceOf[ArrayType].containsNull)
+ tempArrayDataName, elementType, arr, counter, l, resultArrayElementNullable)
s"""
|$numElemCode
@@ -3486,6 +3498,8 @@ trait ArraySetLike {
@transient protected lazy val ordering: Ordering[Any] =
TypeUtils.getInterpretedOrdering(et)
+ protected def resultArrayElementNullable = dt.asInstanceOf[ArrayType].containsNull
+
protected def genGetValue(array: String, i: String): String =
CodeGenerator.getValue(array, et, i)
@@ -3521,7 +3535,7 @@ trait ArraySetLike {
body: String,
value: String,
nullElementIndex: String): String = {
- if (dt.asInstanceOf[ArrayType].containsNull) {
+ if (resultArrayElementNullable) {
s"""
|$body
|if ($nullElementIndex >= 0) {
@@ -3662,7 +3676,7 @@ case class ArrayDistinct(child: Expression)
val arrayBuilderClass = s"$arrayBuilder$$of$ptName"
// Only need to track null element index when array's element is nullable.
- val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) {
+ val declareNullTrackVariables = if (resultArrayElementNullable) {
s"""
|int $nullElementIndex = -1;
""".stripMargin
@@ -3692,8 +3706,8 @@ case class ArrayDistinct(child: Expression)
""".stripMargin)
val processArray = SQLOpenHashSet.withNullCheckCode(
- dataType.asInstanceOf[ArrayType].containsNull,
- dataType.asInstanceOf[ArrayType].containsNull,
+ resultArrayElementNullable,
+ resultArrayElementNullable,
array, i, hashSet, withNaNCheckCodeGenerator,
s"""
|$nullElementIndex = $size;
@@ -3880,8 +3894,8 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
""".stripMargin)
val processArray = SQLOpenHashSet.withNullCheckCode(
- dataType.asInstanceOf[ArrayType].containsNull,
- dataType.asInstanceOf[ArrayType].containsNull,
+ resultArrayElementNullable,
+ resultArrayElementNullable,
array, i, hashSet, withNaNCheckCodeGenerator,
s"""
|$nullElementIndex = $size;
@@ -3890,7 +3904,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
""".stripMargin)
// Only need to track null element index when result array's element is nullable.
- val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) {
+ val declareNullTrackVariables = if (resultArrayElementNullable) {
s"""
|int $nullElementIndex = -1;
""".stripMargin
@@ -3985,9 +3999,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina
private lazy val internalDataType: DataType = {
dataTypeCheck
- ArrayType(elementType,
- left.dataType.asInstanceOf[ArrayType].containsNull &&
- right.dataType.asInstanceOf[ArrayType].containsNull)
+ ArrayType(elementType, leftArrayElementNullable && rightArrayElementNullable)
}
override def dataType: DataType = internalDataType
@@ -4122,8 +4134,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina
(valueNaN: String) => "")
val writeArray2ToHashSet = SQLOpenHashSet.withNullCheckCode(
- right.dataType.asInstanceOf[ArrayType].containsNull,
- left.dataType.asInstanceOf[ArrayType].containsNull,
+ rightArrayElementNullable, leftArrayElementNullable,
array2, i, hashSet, withArray2NaNCheckCodeGenerator, "")
val body =
@@ -4151,8 +4162,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina
""".stripMargin)
val processArray1 = SQLOpenHashSet.withNullCheckCode(
- left.dataType.asInstanceOf[ArrayType].containsNull,
- right.dataType.asInstanceOf[ArrayType].containsNull,
+ leftArrayElementNullable, rightArrayElementNullable,
array1, i, hashSetResult, withArray1NaNCheckCodeGenerator,
s"""
|if ($hashSet.containsNull()) {
@@ -4163,7 +4173,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina
""".stripMargin)
// Only need to track null element index when result array's element is nullable.
- val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) {
+ val declareNullTrackVariables = if (resultArrayElementNullable) {
s"""
|int $nullElementIndex = -1;
""".stripMargin
@@ -4340,8 +4350,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
(valueNaN: Any) => "")
val writeArray2ToHashSet = SQLOpenHashSet.withNullCheckCode(
- right.dataType.asInstanceOf[ArrayType].containsNull,
- left.dataType.asInstanceOf[ArrayType].containsNull,
+ rightArrayElementNullable, leftArrayElementNullable,
array2, i, hashSet, withArray2NaNCheckCodeGenerator, "")
val body =
@@ -4366,8 +4375,8 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
""".stripMargin)
val processArray1 = SQLOpenHashSet.withNullCheckCode(
- left.dataType.asInstanceOf[ArrayType].containsNull,
- left.dataType.asInstanceOf[ArrayType].containsNull,
+ leftArrayElementNullable,
+ leftArrayElementNullable,
array1, i, hashSet, withArray1NaNCheckCodeGenerator,
s"""
|$nullElementIndex = $size;
@@ -4376,7 +4385,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
""".stripMargin)
// Only need to track null element index when array1's element is nullable.
- val declareNullTrackVariables = if (left.dataType.asInstanceOf[ArrayType].containsNull) {
+ val declareNullTrackVariables = if (leftArrayElementNullable) {
s"""
|int $nullElementIndex = -1;
""".stripMargin
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index b6cbb1d0005..7b99b9d1082 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -282,7 +282,8 @@ case class GetArrayItem(
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
val index = ctx.freshName("index")
- val nullCheck = if (child.dataType.asInstanceOf[ArrayType].containsNull) {
+ val childArrayElementNullable = child.dataType.asInstanceOf[ArrayType].containsNull
+ val nullCheck = if (childArrayElementNullable) {
s"""else if ($eval1.isNullAt($index)) {
${ev.isNull} = true;
}
@@ -333,7 +334,7 @@ trait GetArrayItemUtil {
ordinal: Expression,
failOnError: Boolean,
nullability: (Seq[Expression], Int) => Boolean): Boolean = {
- val arrayContainsNull = child.dataType.asInstanceOf[ArrayType].containsNull
+ val arrayElementNullable = child.dataType.asInstanceOf[ArrayType].containsNull
if (ordinal.foldable && !ordinal.nullable) {
val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue()
child match {
@@ -345,7 +346,7 @@ trait GetArrayItemUtil {
true
}
} else {
- if (failOnError) arrayContainsNull else true
+ if (failOnError) arrayElementNullable else true
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
index b5708bae923..e139823b2bd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
@@ -53,11 +53,12 @@ object ArrayType extends AbstractDataType {
* Please use `DataTypes.createArrayType()` to create a specific instance.
*
* An [[ArrayType]] object comprises two fields, `elementType: [[DataType]]` and
- * `containsNull: Boolean`. The field of `elementType` is used to specify the type of
- * array elements. The field of `containsNull` is used to specify if the array has `null` values.
+ * `containsNull: Boolean`.
+ * The field of `elementType` is used to specify the type of array elements.
+ * The field of `containsNull` is used to specify if the array can have `null` values.
*
* @param elementType The data type of values.
- * @param containsNull Indicates if values have `null` values
+ * @param containsNull Indicates if the array can have `null` values
*
* @since 1.3.0
*/
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org