You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2019/03/04 13:27:46 UTC

[spark] branch master updated: [SPARK-26965][SQL] Makes ElementAt nullability more precise for array cases

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 68fbbbe  [SPARK-26965][SQL] Makes ElementAt nullability more precise for array cases
68fbbbe is described below

commit 68fbbbea4e13b53ae8304f21ae727a159bd12559
Author: Takeshi Yamamuro <ya...@apache.org>
AuthorDate: Mon Mar 4 21:27:18 2019 +0800

    [SPARK-26965][SQL] Makes ElementAt nullability more precise for array cases
    
    ## What changes were proposed in this pull request?
    In master, `ElementAt.nullable` is always `true`:
    https://github.com/apache/spark/blob/be1cadf16dc70e22eae144b3dfce9e269ef95acc/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala#L1977
    
    But if the input is an array and the ordinal is foldable (and non-nullable), we can make its nullability more precise.
    This fix is based on SPARK-26637 (#23566).
    
    ## How was this patch tested?
    Added tests in `CollectionExpressionsSuite`.
    
    Closes #23867 from maropu/SPARK-26965.
    
    Authored-by: Takeshi Yamamuro <ya...@apache.org>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../expressions/collectionOperations.scala         |  8 ++-
 .../expressions/complexTypeExtractors.scala        | 74 +++++++++++-----------
 .../expressions/CollectionExpressionsSuite.scala   | 33 ++++++++++
 .../catalyst/expressions/ComplexTypeSuite.scala    | 19 ------
 4 files changed, 76 insertions(+), 58 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 67f6739..018b6b9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -1929,7 +1929,8 @@ case class ArrayPosition(left: Expression, right: Expression)
        b
   """,
   since = "2.4.0")
-case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil {
+case class ElementAt(left: Expression, right: Expression)
+  extends GetMapValueUtil with GetArrayItemUtil {
 
   @transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType
 
@@ -1974,7 +1975,10 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
     }
   }
 
-  override def nullable: Boolean = true
+  override def nullable: Boolean = left.dataType match {
+    case _: ArrayType => computeNullabilityFromArray(left, right)
+    case _: MapType => true
+  }
 
   override def nullSafeEval(value: Any, ordinal: Any): Any = doElementAt(value, ordinal)
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 55ed617..e9d60ed 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -221,7 +221,8 @@ case class GetArrayStructFields(
  * We need to do type checking here as `ordinal` expression maybe unresolved.
  */
 case class GetArrayItem(child: Expression, ordinal: Expression)
-  extends BinaryExpression with ExpectsInputTypes with ExtractValue with NullIntolerant {
+  extends BinaryExpression with GetArrayItemUtil with ExpectsInputTypes with ExtractValue
+  with NullIntolerant {
 
   // We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`.
   override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType)
@@ -231,23 +232,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
 
   override def left: Expression = child
   override def right: Expression = ordinal
-
-  /** `Null` is returned for invalid ordinals. */
-  override def nullable: Boolean = if (ordinal.foldable && !ordinal.nullable) {
-    val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue()
-    child match {
-      case CreateArray(ar) if intOrdinal < ar.length =>
-        ar(intOrdinal).nullable
-      case GetArrayStructFields(CreateArray(elements), field, _, _, _)
-          if intOrdinal < elements.length =>
-        elements(intOrdinal).nullable || field.nullable
-      case _ =>
-        true
-    }
-  } else {
-    true
-  }
-
+  override def nullable: Boolean = computeNullabilityFromArray(left, right)
   override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType
 
   protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
@@ -281,10 +266,34 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
 }
 
 /**
- * Common base class for [[GetMapValue]] and [[ElementAt]].
+ * Common trait for [[GetArrayItem]] and [[ElementAt]].
  */
+trait GetArrayItemUtil {
+
+  /** `Null` is returned for invalid ordinals. */
+  protected def computeNullabilityFromArray(child: Expression, ordinal: Expression): Boolean = {
+    if (ordinal.foldable && !ordinal.nullable) {
+      val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue()
+      child match {
+        case CreateArray(ar) if intOrdinal < ar.length =>
+          ar(intOrdinal).nullable
+        case GetArrayStructFields(CreateArray(elements), field, _, _, _)
+          if intOrdinal < elements.length =>
+          elements(intOrdinal).nullable || field.nullable
+        case _ =>
+          true
+      }
+    } else {
+      true
+    }
+  }
+}
+
+/**
+ * Common trait for [[GetMapValue]] and [[ElementAt]].
+ */
+trait GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {
 
-abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {
   // todo: current search is O(n), improve it.
   def getValueEval(value: Any, ordinal: Any, keyType: DataType, ordering: Ordering[Any]): Any = {
     val map = value.asInstanceOf[MapData]
@@ -380,23 +389,14 @@ case class GetMapValue(child: Expression, key: Expression)
   override def left: Expression = child
   override def right: Expression = key
 
-  /** `Null` is returned for invalid ordinals. */
-  override def nullable: Boolean = if (key.foldable && !key.nullable) {
-    val keyObj = key.eval()
-    child match {
-      case m: CreateMap if m.resolved =>
-        m.keys.zip(m.values).filter { case (k, _) => k.foldable && !k.nullable }.find {
-          case (k, _) if k.eval() == keyObj => true
-          case _ => false
-        }.map(_._2.nullable).getOrElse(true)
-      case _ =>
-        true
-    }
-  } else {
-    true
-  }
-
-
+  /**
+   * `Null` is returned for invalid ordinals.
+   *
+   * TODO: We could make nullability more precise in foldable cases (e.g., literal input).
+   * But, since the key search is O(n), it takes much time to compute nullability.
+   * If we find efficient key searches, revisit this.
+   */
+  override def nullable: Boolean = true
   override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType
 
   // todo: current search is O(n), improve it.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 910e6c8..c1c0459 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -1092,6 +1092,39 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(ElementAt(mb0, Literal(Array[Byte](3, 4))), null)
   }
 
+  test("correctly handles ElementAt nullability for arrays") {
+    // CreateArray case
+    val a = AttributeReference("a", IntegerType, nullable = false)()
+    val b = AttributeReference("b", IntegerType, nullable = true)()
+    val array = CreateArray(a :: b :: Nil)
+    assert(!ElementAt(array, Literal(0)).nullable)
+    assert(ElementAt(array, Literal(1)).nullable)
+    assert(!ElementAt(array, Subtract(Literal(2), Literal(2))).nullable)
+    assert(ElementAt(array, AttributeReference("ordinal", IntegerType)()).nullable)
+
+    // GetArrayStructFields case
+    val f1 = StructField("a", IntegerType, nullable = false)
+    val f2 = StructField("b", IntegerType, nullable = true)
+    val structType = StructType(f1 :: f2 :: Nil)
+    val c = AttributeReference("c", structType, nullable = false)()
+    val inputArray1 = CreateArray(c :: Nil)
+    val inputArray1ContainsNull = c.nullable
+    val stArray1 = GetArrayStructFields(inputArray1, f1, 0, 2, inputArray1ContainsNull)
+    assert(!ElementAt(stArray1, Literal(0)).nullable)
+    val stArray2 = GetArrayStructFields(inputArray1, f2, 1, 2, inputArray1ContainsNull)
+    assert(ElementAt(stArray2, Literal(0)).nullable)
+
+    val d = AttributeReference("d", structType, nullable = true)()
+    val inputArray2 = CreateArray(c :: d :: Nil)
+    val inputArray2ContainsNull = c.nullable || d.nullable
+    val stArray3 = GetArrayStructFields(inputArray2, f1, 0, 2, inputArray2ContainsNull)
+    assert(!ElementAt(stArray3, Literal(0)).nullable)
+    assert(ElementAt(stArray3, Literal(1)).nullable)
+    val stArray4 = GetArrayStructFields(inputArray2, f2, 1, 2, inputArray2ContainsNull)
+    assert(ElementAt(stArray4, Literal(0)).nullable)
+    assert(ElementAt(stArray4, Literal(1)).nullable)
+  }
+
   test("Concat") {
     // Primitive-type elements
     val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 67f748c..0c44389 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -110,25 +110,6 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(GetMapValue(nestedMap, Literal("a")), Map("b" -> "c"))
   }
 
-  test("SPARK-26747 handles GetMapValue nullability correctly when input key is foldable") {
-    // String key test
-    val k1 = Literal("k1")
-    val v1 = AttributeReference("v1", StringType, nullable = true)()
-    val k2 = Literal("k2")
-    val v2 = AttributeReference("v2", StringType, nullable = false)()
-    val map1 = CreateMap(k1 :: v1 :: k2 :: v2 :: Nil)
-    assert(GetMapValue(map1, Literal("k1")).nullable)
-    assert(!GetMapValue(map1, Literal("k2")).nullable)
-    assert(GetMapValue(map1, Literal("non-existent-key")).nullable)
-
-    // Complex type key test
-    val k3 = Literal.create((1, "a"))
-    val k4 = Literal.create((2, "b"))
-    val map2 = CreateMap(k3 :: v1 :: k4 :: v2 :: Nil)
-    assert(GetMapValue(map2, Literal.create((1, "a"))).nullable)
-    assert(!GetMapValue(map2, Literal.create((2, "b"))).nullable)
-  }
-
   test("GetStructField") {
     val typeS = StructType(StructField("a", IntegerType) :: Nil)
     val struct = Literal.create(create_row(1), typeS)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org