You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2019/03/04 13:27:46 UTC
[spark] branch master updated: [SPARK-26965][SQL] Makes ElementAt
nullability more precise for array cases
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 68fbbbe [SPARK-26965][SQL] Makes ElementAt nullability more precise for array cases
68fbbbe is described below
commit 68fbbbea4e13b53ae8304f21ae727a159bd12559
Author: Takeshi Yamamuro <ya...@apache.org>
AuthorDate: Mon Mar 4 21:27:18 2019 +0800
[SPARK-26965][SQL] Makes ElementAt nullability more precise for array cases
## What changes were proposed in this pull request?
In master, the `nullable` property of `ElementAt` is always true:
https://github.com/apache/spark/blob/be1cadf16dc70e22eae144b3dfce9e269ef95acc/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala#L1977
But if the input is an array and foldable, we can make its nullability more precise.
This fix is based on SPARK-26637 (#23566).
## How was this patch tested?
Added tests in `CollectionExpressionsSuite`.
Closes #23867 from maropu/SPARK-26965.
Authored-by: Takeshi Yamamuro <ya...@apache.org>
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../expressions/collectionOperations.scala | 8 ++-
.../expressions/complexTypeExtractors.scala | 74 +++++++++++-----------
.../expressions/CollectionExpressionsSuite.scala | 33 ++++++++++
.../catalyst/expressions/ComplexTypeSuite.scala | 19 ------
4 files changed, 76 insertions(+), 58 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 67f6739..018b6b9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -1929,7 +1929,8 @@ case class ArrayPosition(left: Expression, right: Expression)
b
""",
since = "2.4.0")
-case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil {
+case class ElementAt(left: Expression, right: Expression)
+ extends GetMapValueUtil with GetArrayItemUtil {
@transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType
@@ -1974,7 +1975,10 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
}
}
- override def nullable: Boolean = true
+ override def nullable: Boolean = left.dataType match {
+ case _: ArrayType => computeNullabilityFromArray(left, right)
+ case _: MapType => true
+ }
override def nullSafeEval(value: Any, ordinal: Any): Any = doElementAt(value, ordinal)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 55ed617..e9d60ed 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -221,7 +221,8 @@ case class GetArrayStructFields(
* We need to do type checking here as `ordinal` expression maybe unresolved.
*/
case class GetArrayItem(child: Expression, ordinal: Expression)
- extends BinaryExpression with ExpectsInputTypes with ExtractValue with NullIntolerant {
+ extends BinaryExpression with GetArrayItemUtil with ExpectsInputTypes with ExtractValue
+ with NullIntolerant {
// We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType)
@@ -231,23 +232,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
override def left: Expression = child
override def right: Expression = ordinal
-
- /** `Null` is returned for invalid ordinals. */
- override def nullable: Boolean = if (ordinal.foldable && !ordinal.nullable) {
- val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue()
- child match {
- case CreateArray(ar) if intOrdinal < ar.length =>
- ar(intOrdinal).nullable
- case GetArrayStructFields(CreateArray(elements), field, _, _, _)
- if intOrdinal < elements.length =>
- elements(intOrdinal).nullable || field.nullable
- case _ =>
- true
- }
- } else {
- true
- }
-
+ override def nullable: Boolean = computeNullabilityFromArray(left, right)
override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType
protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
@@ -281,10 +266,34 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
}
/**
- * Common base class for [[GetMapValue]] and [[ElementAt]].
+ * Common trait for [[GetArrayItem]] and [[ElementAt]].
*/
+trait GetArrayItemUtil {
+
+ /** `Null` is returned for invalid ordinals. */
+ protected def computeNullabilityFromArray(child: Expression, ordinal: Expression): Boolean = {
+ if (ordinal.foldable && !ordinal.nullable) {
+ val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue()
+ child match {
+ case CreateArray(ar) if intOrdinal < ar.length =>
+ ar(intOrdinal).nullable
+ case GetArrayStructFields(CreateArray(elements), field, _, _, _)
+ if intOrdinal < elements.length =>
+ elements(intOrdinal).nullable || field.nullable
+ case _ =>
+ true
+ }
+ } else {
+ true
+ }
+ }
+}
+
+/**
+ * Common trait for [[GetMapValue]] and [[ElementAt]].
+ */
+trait GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {
-abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {
// todo: current search is O(n), improve it.
def getValueEval(value: Any, ordinal: Any, keyType: DataType, ordering: Ordering[Any]): Any = {
val map = value.asInstanceOf[MapData]
@@ -380,23 +389,14 @@ case class GetMapValue(child: Expression, key: Expression)
override def left: Expression = child
override def right: Expression = key
- /** `Null` is returned for invalid ordinals. */
- override def nullable: Boolean = if (key.foldable && !key.nullable) {
- val keyObj = key.eval()
- child match {
- case m: CreateMap if m.resolved =>
- m.keys.zip(m.values).filter { case (k, _) => k.foldable && !k.nullable }.find {
- case (k, _) if k.eval() == keyObj => true
- case _ => false
- }.map(_._2.nullable).getOrElse(true)
- case _ =>
- true
- }
- } else {
- true
- }
-
-
+ /**
+ * `Null` is returned for invalid ordinals.
+ *
+ * TODO: We could make nullability more precise in foldable cases (e.g., literal input).
+ * But, since the key search is O(n), it takes much time to compute nullability.
+ * If we find efficient key searches, revisit this.
+ */
+ override def nullable: Boolean = true
override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType
// todo: current search is O(n), improve it.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 910e6c8..c1c0459 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -1092,6 +1092,39 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(ElementAt(mb0, Literal(Array[Byte](3, 4))), null)
}
+ test("correctly handles ElementAt nullability for arrays") {
+ // CreateArray case
+ val a = AttributeReference("a", IntegerType, nullable = false)()
+ val b = AttributeReference("b", IntegerType, nullable = true)()
+ val array = CreateArray(a :: b :: Nil)
+ assert(!ElementAt(array, Literal(0)).nullable)
+ assert(ElementAt(array, Literal(1)).nullable)
+ assert(!ElementAt(array, Subtract(Literal(2), Literal(2))).nullable)
+ assert(ElementAt(array, AttributeReference("ordinal", IntegerType)()).nullable)
+
+ // GetArrayStructFields case
+ val f1 = StructField("a", IntegerType, nullable = false)
+ val f2 = StructField("b", IntegerType, nullable = true)
+ val structType = StructType(f1 :: f2 :: Nil)
+ val c = AttributeReference("c", structType, nullable = false)()
+ val inputArray1 = CreateArray(c :: Nil)
+ val inputArray1ContainsNull = c.nullable
+ val stArray1 = GetArrayStructFields(inputArray1, f1, 0, 2, inputArray1ContainsNull)
+ assert(!ElementAt(stArray1, Literal(0)).nullable)
+ val stArray2 = GetArrayStructFields(inputArray1, f2, 1, 2, inputArray1ContainsNull)
+ assert(ElementAt(stArray2, Literal(0)).nullable)
+
+ val d = AttributeReference("d", structType, nullable = true)()
+ val inputArray2 = CreateArray(c :: d :: Nil)
+ val inputArray2ContainsNull = c.nullable || d.nullable
+ val stArray3 = GetArrayStructFields(inputArray2, f1, 0, 2, inputArray2ContainsNull)
+ assert(!ElementAt(stArray3, Literal(0)).nullable)
+ assert(ElementAt(stArray3, Literal(1)).nullable)
+ val stArray4 = GetArrayStructFields(inputArray2, f2, 1, 2, inputArray2ContainsNull)
+ assert(ElementAt(stArray4, Literal(0)).nullable)
+ assert(ElementAt(stArray4, Literal(1)).nullable)
+ }
+
test("Concat") {
// Primitive-type elements
val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 67f748c..0c44389 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -110,25 +110,6 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(GetMapValue(nestedMap, Literal("a")), Map("b" -> "c"))
}
- test("SPARK-26747 handles GetMapValue nullability correctly when input key is foldable") {
- // String key test
- val k1 = Literal("k1")
- val v1 = AttributeReference("v1", StringType, nullable = true)()
- val k2 = Literal("k2")
- val v2 = AttributeReference("v2", StringType, nullable = false)()
- val map1 = CreateMap(k1 :: v1 :: k2 :: v2 :: Nil)
- assert(GetMapValue(map1, Literal("k1")).nullable)
- assert(!GetMapValue(map1, Literal("k2")).nullable)
- assert(GetMapValue(map1, Literal("non-existent-key")).nullable)
-
- // Complex type key test
- val k3 = Literal.create((1, "a"))
- val k4 = Literal.create((2, "b"))
- val map2 = CreateMap(k3 :: v1 :: k4 :: v2 :: Nil)
- assert(GetMapValue(map2, Literal.create((1, "a"))).nullable)
- assert(!GetMapValue(map2, Literal.create((2, "b"))).nullable)
- }
-
test("GetStructField") {
val typeS = StructType(StructField("a", IntegerType) :: Nil)
val struct = Literal.create(create_row(1), typeS)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org