You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2018/07/16 15:16:33 UTC

spark git commit: [SPARK-24734][SQL] Fix type coercions and nullabilities of nested data types of some functions.

Repository: spark
Updated Branches:
  refs/heads/master cf9704534 -> b045315e5


[SPARK-24734][SQL] Fix type coercions and nullabilities of nested data types of some functions.

## What changes were proposed in this pull request?

We have some functions which need to be aware of the nullabilities of all children, such as `CreateArray`, `CreateMap`, `Concat`, and so on. Currently we add casts to fix the nullabilities, but the casts might be removed during the optimization phase.
After the discussion, we decided not to add extra casts just to fix the nullabilities of the nested types, and instead to have the functions handle them themselves.

## How was this patch tested?

Modified and added some tests.

Author: Takuya UESHIN <ue...@databricks.com>

Closes #21704 from ueshin/issues/SPARK-24734/concat_containsnull.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b045315e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b045315e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b045315e

Branch: refs/heads/master
Commit: b045315e5d87b7ea3588436053aaa4d5a7bd103f
Parents: cf97045
Author: Takuya UESHIN <ue...@databricks.com>
Authored: Mon Jul 16 23:16:25 2018 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Mon Jul 16 23:16:25 2018 +0800

----------------------------------------------------------------------
 .../sql/catalyst/analysis/TypeCoercion.scala    | 113 ++++++++++---------
 .../sql/catalyst/expressions/Expression.scala   |  12 +-
 .../sql/catalyst/expressions/arithmetic.scala   |  14 +--
 .../expressions/collectionOperations.scala      |  22 ++--
 .../expressions/complexTypeCreator.scala        |  15 ++-
 .../expressions/conditionalExpressions.scala    |   4 +-
 .../sql/catalyst/expressions/literals.scala     |   2 +-
 .../catalyst/expressions/nullExpressions.scala  |   6 +-
 .../spark/sql/catalyst/util/TypeUtils.scala     |  16 +--
 .../catalyst/analysis/TypeCoercionSuite.scala   |  43 ++++---
 .../expressions/ArithmeticExpressionSuite.scala |  12 ++
 .../CollectionExpressionsSuite.scala            |  60 +++++++---
 .../catalyst/expressions/ComplexTypeSuite.scala |  19 ++++
 .../expressions/NullExpressionsSuite.scala      |   7 ++
 .../spark/sql/DataFrameFunctionsSuite.scala     |   8 ++
 15 files changed, 211 insertions(+), 142 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b045315e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index e8331c9..316aebd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -184,6 +184,17 @@ object TypeCoercion {
     }
   }
 
+  def findCommonTypeDifferentOnlyInNullFlags(types: Seq[DataType]): Option[DataType] = {
+    if (types.isEmpty) {
+      None
+    } else {
+      types.tail.foldLeft[Option[DataType]](Some(types.head)) {
+        case (Some(t1), t2) => findCommonTypeDifferentOnlyInNullFlags(t1, t2)
+        case _ => None
+      }
+    }
+  }
+
   /**
    * Case 2 type widening (see the classdoc comment above for TypeCoercion).
    *
@@ -259,8 +270,25 @@ object TypeCoercion {
     }
   }
 
-  private def haveSameType(exprs: Seq[Expression]): Boolean =
-    exprs.map(_.dataType).distinct.length == 1
+  /**
+   * Check whether the given types are equal ignoring nullable, containsNull and valueContainsNull.
+   */
+  def haveSameType(types: Seq[DataType]): Boolean = {
+    if (types.size <= 1) {
+      true
+    } else {
+      val head = types.head
+      types.tail.forall(_.sameType(head))
+    }
+  }
+
+  private def castIfNotSameType(expr: Expression, dt: DataType): Expression = {
+    if (!expr.dataType.sameType(dt)) {
+      Cast(expr, dt)
+    } else {
+      expr
+    }
+  }
 
   /**
    * Widens numeric types and converts strings to numbers when appropriate.
@@ -525,23 +553,24 @@ object TypeCoercion {
    * This ensure that the types for various functions are as expected.
    */
   object FunctionArgumentConversion extends TypeCoercionRule {
+
     override protected def coerceTypes(
         plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
       // Skip nodes who's children have not been resolved yet.
       case e if !e.childrenResolved => e
 
-      case a @ CreateArray(children) if !haveSameType(children) =>
+      case a @ CreateArray(children) if !haveSameType(children.map(_.dataType)) =>
         val types = children.map(_.dataType)
         findWiderCommonType(types) match {
-          case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType)))
+          case Some(finalDataType) => CreateArray(children.map(castIfNotSameType(_, finalDataType)))
           case None => a
         }
 
       case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) &&
-        !haveSameType(children) =>
+        !haveSameType(c.inputTypesForMerging) =>
         val types = children.map(_.dataType)
         findWiderCommonType(types) match {
-          case Some(finalDataType) => Concat(children.map(Cast(_, finalDataType)))
+          case Some(finalDataType) => Concat(children.map(castIfNotSameType(_, finalDataType)))
           case None => c
         }
 
@@ -553,7 +582,8 @@ object TypeCoercion {
           case None => aj
         }
 
-      case s @ Sequence(_, _, _, timeZoneId) if !haveSameType(s.coercibleChildren) =>
+      case s @ Sequence(_, _, _, timeZoneId)
+          if !haveSameType(s.coercibleChildren.map(_.dataType)) =>
         val types = s.coercibleChildren.map(_.dataType)
         findWiderCommonType(types) match {
           case Some(widerDataType) => s.castChildrenTo(widerDataType)
@@ -561,33 +591,25 @@ object TypeCoercion {
         }
 
       case m @ MapConcat(children) if children.forall(c => MapType.acceptsType(c.dataType)) &&
-        !haveSameType(children) =>
+          !haveSameType(m.inputTypesForMerging) =>
         val types = children.map(_.dataType)
         findWiderCommonType(types) match {
-          case Some(finalDataType) => MapConcat(children.map(Cast(_, finalDataType)))
+          case Some(finalDataType) => MapConcat(children.map(castIfNotSameType(_, finalDataType)))
           case None => m
         }
 
       case m @ CreateMap(children) if m.keys.length == m.values.length &&
-        (!haveSameType(m.keys) || !haveSameType(m.values)) =>
-        val newKeys = if (haveSameType(m.keys)) {
-          m.keys
-        } else {
-          val types = m.keys.map(_.dataType)
-          findWiderCommonType(types) match {
-            case Some(finalDataType) => m.keys.map(Cast(_, finalDataType))
-            case None => m.keys
-          }
+          (!haveSameType(m.keys.map(_.dataType)) || !haveSameType(m.values.map(_.dataType))) =>
+        val keyTypes = m.keys.map(_.dataType)
+        val newKeys = findWiderCommonType(keyTypes) match {
+          case Some(finalDataType) => m.keys.map(castIfNotSameType(_, finalDataType))
+          case None => m.keys
         }
 
-        val newValues = if (haveSameType(m.values)) {
-          m.values
-        } else {
-          val types = m.values.map(_.dataType)
-          findWiderCommonType(types) match {
-            case Some(finalDataType) => m.values.map(Cast(_, finalDataType))
-            case None => m.values
-          }
+        val valueTypes = m.values.map(_.dataType)
+        val newValues = findWiderCommonType(valueTypes) match {
+          case Some(finalDataType) => m.values.map(castIfNotSameType(_, finalDataType))
+          case None => m.values
         }
 
         CreateMap(newKeys.zip(newValues).flatMap { case (k, v) => Seq(k, v) })
@@ -610,27 +632,27 @@ object TypeCoercion {
       // Coalesce should return the first non-null value, which could be any column
       // from the list. So we need to make sure the return type is deterministic and
       // compatible with every child column.
-      case c @ Coalesce(es) if !haveSameType(es) =>
+      case c @ Coalesce(es) if !haveSameType(c.inputTypesForMerging) =>
         val types = es.map(_.dataType)
         findWiderCommonType(types) match {
-          case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType)))
+          case Some(finalDataType) => Coalesce(es.map(castIfNotSameType(_, finalDataType)))
           case None => c
         }
 
       // When finding wider type for `Greatest` and `Least`, we should handle decimal types even if
       // we need to truncate, but we should not promote one side to string if the other side is
       // string.g
-      case g @ Greatest(children) if !haveSameType(children) =>
+      case g @ Greatest(children) if !haveSameType(g.inputTypesForMerging) =>
         val types = children.map(_.dataType)
         findWiderTypeWithoutStringPromotion(types) match {
-          case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType)))
+          case Some(finalDataType) => Greatest(children.map(castIfNotSameType(_, finalDataType)))
           case None => g
         }
 
-      case l @ Least(children) if !haveSameType(children) =>
+      case l @ Least(children) if !haveSameType(l.inputTypesForMerging) =>
         val types = children.map(_.dataType)
         findWiderTypeWithoutStringPromotion(types) match {
-          case Some(finalDataType) => Least(children.map(Cast(_, finalDataType)))
+          case Some(finalDataType) => Least(children.map(castIfNotSameType(_, finalDataType)))
           case None => l
         }
 
@@ -672,27 +694,14 @@ object TypeCoercion {
   object CaseWhenCoercion extends TypeCoercionRule {
     override protected def coerceTypes(
         plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
-      case c: CaseWhen if c.childrenResolved && !c.areInputTypesForMergingEqual =>
+      case c: CaseWhen if c.childrenResolved && !haveSameType(c.inputTypesForMerging) =>
         val maybeCommonType = findWiderCommonType(c.inputTypesForMerging)
         maybeCommonType.map { commonType =>
-          var changed = false
           val newBranches = c.branches.map { case (condition, value) =>
-            if (value.dataType.sameType(commonType)) {
-              (condition, value)
-            } else {
-              changed = true
-              (condition, Cast(value, commonType))
-            }
-          }
-          val newElseValue = c.elseValue.map { value =>
-            if (value.dataType.sameType(commonType)) {
-              value
-            } else {
-              changed = true
-              Cast(value, commonType)
-            }
+            (condition, castIfNotSameType(value, commonType))
           }
-          if (changed) CaseWhen(newBranches, newElseValue) else c
+          val newElseValue = c.elseValue.map(castIfNotSameType(_, commonType))
+          CaseWhen(newBranches, newElseValue)
         }.getOrElse(c)
     }
   }
@@ -705,10 +714,10 @@ object TypeCoercion {
         plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
       case e if !e.childrenResolved => e
       // Find tightest common type for If, if the true value and false value have different types.
-      case i @ If(pred, left, right) if !i.areInputTypesForMergingEqual =>
+      case i @ If(pred, left, right) if !haveSameType(i.inputTypesForMerging) =>
         findWiderTypeForTwo(left.dataType, right.dataType).map { widestType =>
-          val newLeft = if (left.dataType.sameType(widestType)) left else Cast(left, widestType)
-          val newRight = if (right.dataType.sameType(widestType)) right else Cast(right, widestType)
+          val newLeft = castIfNotSameType(left, widestType)
+          val newRight = castIfNotSameType(right, widestType)
           If(pred, newLeft, newRight)
         }.getOrElse(i)  // If there is no applicable conversion, leave expression unchanged.
       case If(Literal(null, NullType), left, right) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/b045315e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 44c5556..f7d1b10 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -709,22 +709,12 @@ trait ComplexTypeMergingExpression extends Expression {
   @transient
   lazy val inputTypesForMerging: Seq[DataType] = children.map(_.dataType)
 
-  /**
-   * A method determining whether the input types are equal ignoring nullable, containsNull and
-   * valueContainsNull flags and thus convenient for resolution of the final data type.
-   */
-  def areInputTypesForMergingEqual: Boolean = {
-    inputTypesForMerging.length <= 1 || inputTypesForMerging.sliding(2, 1).forall {
-      case Seq(dt1, dt2) => dt1.sameType(dt2)
-    }
-  }
-
   override def dataType: DataType = {
     require(
       inputTypesForMerging.nonEmpty,
       "The collection of input data types must not be empty.")
     require(
-      areInputTypesForMergingEqual,
+      TypeCoercion.haveSameType(inputTypesForMerging),
       "All input types must be the same except nullable, containsNull, valueContainsNull flags.")
     inputTypesForMerging.reduceLeft(TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(_, _).get)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/b045315e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index fe91e52..5594041 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.util.TypeUtils
@@ -514,7 +514,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
       > SELECT _FUNC_(10, 9, 2, 4, 3);
        2
   """)
-case class Least(children: Seq[Expression]) extends Expression {
+case class Least(children: Seq[Expression]) extends ComplexTypeMergingExpression {
 
   override def nullable: Boolean = children.forall(_.nullable)
   override def foldable: Boolean = children.forall(_.foldable)
@@ -525,7 +525,7 @@ case class Least(children: Seq[Expression]) extends Expression {
     if (children.length <= 1) {
       TypeCheckResult.TypeCheckFailure(
         s"input to function $prettyName requires at least two arguments")
-    } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
+    } else if (!TypeCoercion.haveSameType(inputTypesForMerging)) {
       TypeCheckResult.TypeCheckFailure(
         s"The expressions should all have the same type," +
           s" got LEAST(${children.map(_.dataType.simpleString).mkString(", ")}).")
@@ -534,8 +534,6 @@ case class Least(children: Seq[Expression]) extends Expression {
     }
   }
 
-  override def dataType: DataType = children.head.dataType
-
   override def eval(input: InternalRow): Any = {
     children.foldLeft[Any](null)((r, c) => {
       val evalc = c.eval(input)
@@ -589,7 +587,7 @@ case class Least(children: Seq[Expression]) extends Expression {
       > SELECT _FUNC_(10, 9, 2, 4, 3);
        10
   """)
-case class Greatest(children: Seq[Expression]) extends Expression {
+case class Greatest(children: Seq[Expression]) extends ComplexTypeMergingExpression {
 
   override def nullable: Boolean = children.forall(_.nullable)
   override def foldable: Boolean = children.forall(_.foldable)
@@ -600,7 +598,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {
     if (children.length <= 1) {
       TypeCheckResult.TypeCheckFailure(
         s"input to function $prettyName requires at least two arguments")
-    } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
+    } else if (!TypeCoercion.haveSameType(inputTypesForMerging)) {
       TypeCheckResult.TypeCheckFailure(
         s"The expressions should all have the same type," +
           s" got GREATEST(${children.map(_.dataType.simpleString).mkString(", ")}).")
@@ -609,8 +607,6 @@ case class Greatest(children: Seq[Expression]) extends Expression {
     }
   }
 
-  override def dataType: DataType = children.head.dataType
-
   override def eval(input: InternalRow): Any = {
     children.foldLeft[Any](null)((r, c) => {
       val evalc = c.eval(input)

http://git-wip-us.apache.org/repos/asf/spark/blob/b045315e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 0f4f4f1..972bc6e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -507,7 +507,7 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp
       > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(2, 'c', 3, 'd'));
        [[1 -> "a"], [2 -> "b"], [2 -> "c"], [3 -> "d"]]
   """, since = "2.4.0")
-case class MapConcat(children: Seq[Expression]) extends Expression {
+case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpression {
 
   override def checkInputDataTypes(): TypeCheckResult = {
     var funcName = s"function $prettyName"
@@ -521,14 +521,10 @@ case class MapConcat(children: Seq[Expression]) extends Expression {
   }
 
   override def dataType: MapType = {
-    val dt = children.map(_.dataType.asInstanceOf[MapType]).headOption
-      .getOrElse(MapType(StringType, StringType))
-    val valueContainsNull = children.map(_.dataType.asInstanceOf[MapType])
-      .exists(_.valueContainsNull)
-    if (dt.valueContainsNull != valueContainsNull) {
-      dt.copy(valueContainsNull = valueContainsNull)
+    if (children.isEmpty) {
+      MapType(StringType, StringType)
     } else {
-      dt
+      super.dataType.asInstanceOf[MapType]
     }
   }
 
@@ -2211,7 +2207,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
       > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6));
  |     [1,2,3,4,5,6]
   """)
-case class Concat(children: Seq[Expression]) extends Expression {
+case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpression {
 
   private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
 
@@ -2232,7 +2228,13 @@ case class Concat(children: Seq[Expression]) extends Expression {
     }
   }
 
-  override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType)
+  override def dataType: DataType = {
+    if (children.isEmpty) {
+      StringType
+    } else {
+      super.dataType
+    }
+  }
 
   lazy val javaType: String = CodeGenerator.javaType(dataType)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b045315e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 0a5f8a9..a43de02 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -18,8 +18,8 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.util._
@@ -48,7 +48,8 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
 
   override def dataType: ArrayType = {
     ArrayType(
-      children.headOption.map(_.dataType).getOrElse(StringType),
+      TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(children.map(_.dataType))
+        .getOrElse(StringType),
       containsNull = children.exists(_.nullable))
   }
 
@@ -179,11 +180,11 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
     if (children.size % 2 != 0) {
       TypeCheckResult.TypeCheckFailure(
         s"$prettyName expects a positive even number of arguments.")
-    } else if (keys.map(_.dataType).distinct.length > 1) {
+    } else if (!TypeCoercion.haveSameType(keys.map(_.dataType))) {
       TypeCheckResult.TypeCheckFailure(
         "The given keys of function map should all be the same type, but they are " +
           keys.map(_.dataType.simpleString).mkString("[", ", ", "]"))
-    } else if (values.map(_.dataType).distinct.length > 1) {
+    } else if (!TypeCoercion.haveSameType(values.map(_.dataType))) {
       TypeCheckResult.TypeCheckFailure(
         "The given values of function map should all be the same type, but they are " +
           values.map(_.dataType.simpleString).mkString("[", ", ", "]"))
@@ -194,8 +195,10 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
 
   override def dataType: DataType = {
     MapType(
-      keyType = keys.headOption.map(_.dataType).getOrElse(StringType),
-      valueType = values.headOption.map(_.dataType).getOrElse(StringType),
+      keyType = TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(keys.map(_.dataType))
+        .getOrElse(StringType),
+      valueType = TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(values.map(_.dataType))
+        .getOrElse(StringType),
       valueContainsNull = values.exists(_.nullable))
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b045315e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index 30ce9e4..3b597e8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -48,7 +48,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
       TypeCheckResult.TypeCheckFailure(
         "type of predicate expression in If should be boolean, " +
           s"not ${predicate.dataType.simpleString}")
-    } else if (!areInputTypesForMergingEqual) {
+    } else if (!TypeCoercion.haveSameType(inputTypesForMerging)) {
       TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " +
         s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).")
     } else {
@@ -137,7 +137,7 @@ case class CaseWhen(
   }
 
   override def checkInputDataTypes(): TypeCheckResult = {
-    if (areInputTypesForMergingEqual) {
+    if (TypeCoercion.haveSameType(inputTypesForMerging)) {
       // Make sure all branch conditions are boolean types.
       if (branches.forall(_._1.dataType == BooleanType)) {
         TypeCheckResult.TypeCheckSuccess

http://git-wip-us.apache.org/repos/asf/spark/blob/b045315e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 0cc2a33..0efd122 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -186,7 +186,7 @@ object Literal {
     case map: MapType => create(Map(), map)
     case struct: StructType =>
       create(InternalRow.fromSeq(struct.fields.map(f => default(f.dataType).value)), struct)
-    case udt: UserDefinedType[_] => default(udt.sqlType)
+    case udt: UserDefinedType[_] => Literal(default(udt.sqlType).value, udt)
     case other =>
       throw new RuntimeException(s"no default for type $dataType")
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/b045315e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index 2eeed3b..b683d2a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.util.TypeUtils
@@ -44,7 +44,7 @@ import org.apache.spark.sql.types._
        1
   """)
 // scalastyle:on line.size.limit
-case class Coalesce(children: Seq[Expression]) extends Expression {
+case class Coalesce(children: Seq[Expression]) extends ComplexTypeMergingExpression {
 
   /** Coalesce is nullable if all of its children are nullable, or if it has no children. */
   override def nullable: Boolean = children.forall(_.nullable)
@@ -61,8 +61,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
     }
   }
 
-  override def dataType: DataType = children.head.dataType
-
   override def eval(input: InternalRow): Any = {
     var result: Any = null
     val childIterator = children.iterator

http://git-wip-us.apache.org/repos/asf/spark/blob/b045315e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index 1dcda49..b795abe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.util
 
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
 import org.apache.spark.sql.catalyst.expressions.RowOrdering
 import org.apache.spark.sql.types._
 
@@ -42,18 +42,12 @@ object TypeUtils {
   }
 
   def checkForSameTypeInputExpr(types: Seq[DataType], caller: String): TypeCheckResult = {
-    if (types.size <= 1) {
+    if (TypeCoercion.haveSameType(types)) {
       TypeCheckResult.TypeCheckSuccess
     } else {
-      val firstType = types.head
-      types.foreach { t =>
-        if (!t.sameType(firstType)) {
-          return TypeCheckResult.TypeCheckFailure(
-            s"input to $caller should all be the same type, but it's " +
-              types.map(_.simpleString).mkString("[", ", ", "]"))
-        }
-      }
-      TypeCheckResult.TypeCheckSuccess
+      return TypeCheckResult.TypeCheckFailure(
+        s"input to $caller should all be the same type, but it's " +
+          types.map(_.simpleString).mkString("[", ", ", "]"))
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b045315e/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index 8cc5a23..4161f09 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -687,46 +687,43 @@ class TypeCoercionSuite extends AnalysisTest {
 
     ruleTest(rule,
       Coalesce(Seq(doubleLit, intLit, floatLit)),
-      Coalesce(Seq(Cast(doubleLit, DoubleType),
-        Cast(intLit, DoubleType), Cast(floatLit, DoubleType))))
+      Coalesce(Seq(doubleLit, Cast(intLit, DoubleType), Cast(floatLit, DoubleType))))
 
     ruleTest(rule,
       Coalesce(Seq(longLit, intLit, decimalLit)),
       Coalesce(Seq(Cast(longLit, DecimalType(22, 0)),
-        Cast(intLit, DecimalType(22, 0)), Cast(decimalLit, DecimalType(22, 0)))))
+        Cast(intLit, DecimalType(22, 0)), decimalLit)))
 
     ruleTest(rule,
       Coalesce(Seq(nullLit, intLit)),
-      Coalesce(Seq(Cast(nullLit, IntegerType), Cast(intLit, IntegerType))))
+      Coalesce(Seq(Cast(nullLit, IntegerType), intLit)))
 
     ruleTest(rule,
       Coalesce(Seq(timestampLit, stringLit)),
-      Coalesce(Seq(Cast(timestampLit, StringType), Cast(stringLit, StringType))))
+      Coalesce(Seq(Cast(timestampLit, StringType), stringLit)))
 
     ruleTest(rule,
       Coalesce(Seq(nullLit, floatNullLit, intLit)),
-      Coalesce(Seq(Cast(nullLit, FloatType), Cast(floatNullLit, FloatType),
-        Cast(intLit, FloatType))))
+      Coalesce(Seq(Cast(nullLit, FloatType), floatNullLit, Cast(intLit, FloatType))))
 
     ruleTest(rule,
       Coalesce(Seq(nullLit, intLit, decimalLit, doubleLit)),
       Coalesce(Seq(Cast(nullLit, DoubleType), Cast(intLit, DoubleType),
-        Cast(decimalLit, DoubleType), Cast(doubleLit, DoubleType))))
+        Cast(decimalLit, DoubleType), doubleLit)))
 
     ruleTest(rule,
       Coalesce(Seq(nullLit, floatNullLit, doubleLit, stringLit)),
       Coalesce(Seq(Cast(nullLit, StringType), Cast(floatNullLit, StringType),
-        Cast(doubleLit, StringType), Cast(stringLit, StringType))))
+        Cast(doubleLit, StringType), stringLit)))
 
     ruleTest(rule,
       Coalesce(Seq(timestampLit, intLit, stringLit)),
-      Coalesce(Seq(Cast(timestampLit, StringType), Cast(intLit, StringType),
-        Cast(stringLit, StringType))))
+      Coalesce(Seq(Cast(timestampLit, StringType), Cast(intLit, StringType), stringLit)))
 
     ruleTest(rule,
       Coalesce(Seq(tsArrayLit, intArrayLit, strArrayLit)),
       Coalesce(Seq(Cast(tsArrayLit, ArrayType(StringType)),
-        Cast(intArrayLit, ArrayType(StringType)), Cast(strArrayLit, ArrayType(StringType)))))
+        Cast(intArrayLit, ArrayType(StringType)), strArrayLit)))
   }
 
   test("CreateArray casts") {
@@ -735,7 +732,7 @@ class TypeCoercionSuite extends AnalysisTest {
         :: Literal(1)
         :: Literal.create(1.0, FloatType)
         :: Nil),
-      CreateArray(Cast(Literal(1.0), DoubleType)
+      CreateArray(Literal(1.0)
         :: Cast(Literal(1), DoubleType)
         :: Cast(Literal.create(1.0, FloatType), DoubleType)
         :: Nil))
@@ -747,7 +744,7 @@ class TypeCoercionSuite extends AnalysisTest {
         :: Nil),
       CreateArray(Cast(Literal(1.0), StringType)
         :: Cast(Literal(1), StringType)
-        :: Cast(Literal("a"), StringType)
+        :: Literal("a")
         :: Nil))
 
     ruleTest(TypeCoercion.FunctionArgumentConversion,
@@ -765,7 +762,7 @@ class TypeCoercionSuite extends AnalysisTest {
         :: Nil),
       CreateArray(Literal.create(null, DecimalType(5, 3)).cast(DecimalType(38, 38))
         :: Literal.create(null, DecimalType(22, 10)).cast(DecimalType(38, 38))
-        :: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38))
+        :: Literal.create(null, DecimalType(38, 38))
         :: Nil))
   }
 
@@ -779,7 +776,7 @@ class TypeCoercionSuite extends AnalysisTest {
         :: Nil),
       CreateMap(Cast(Literal(1), FloatType)
         :: Literal("a")
-        :: Cast(Literal.create(2.0, FloatType), FloatType)
+        :: Literal.create(2.0, FloatType)
         :: Literal("b")
         :: Nil))
     ruleTest(TypeCoercion.FunctionArgumentConversion,
@@ -801,7 +798,7 @@ class TypeCoercionSuite extends AnalysisTest {
         :: Literal(3.0)
         :: Nil),
       CreateMap(Literal(1)
-        :: Cast(Literal("a"), StringType)
+        :: Literal("a")
         :: Literal(2)
         :: Cast(Literal(3.0), StringType)
         :: Nil))
@@ -814,7 +811,7 @@ class TypeCoercionSuite extends AnalysisTest {
       CreateMap(Literal(1)
         :: Literal.create(null, DecimalType(38, 0)).cast(DecimalType(38, 38))
         :: Literal(2)
-        :: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38))
+        :: Literal.create(null, DecimalType(38, 38))
         :: Nil))
     // type coercion for both map keys and values
     ruleTest(TypeCoercion.FunctionArgumentConversion,
@@ -824,8 +821,8 @@ class TypeCoercionSuite extends AnalysisTest {
         :: Literal(3.0)
         :: Nil),
       CreateMap(Cast(Literal(1), DoubleType)
-        :: Cast(Literal("a"), StringType)
-        :: Cast(Literal(2.0), DoubleType)
+        :: Literal("a")
+        :: Literal(2.0)
         :: Cast(Literal(3.0), StringType)
         :: Nil))
   }
@@ -837,7 +834,7 @@ class TypeCoercionSuite extends AnalysisTest {
           :: Literal(1)
           :: Literal.create(1.0, FloatType)
           :: Nil),
-        operator(Cast(Literal(1.0), DoubleType)
+        operator(Literal(1.0)
           :: Cast(Literal(1), DoubleType)
           :: Cast(Literal.create(1.0, FloatType), DoubleType)
           :: Nil))
@@ -848,14 +845,14 @@ class TypeCoercionSuite extends AnalysisTest {
           :: Nil),
         operator(Cast(Literal(1L), DecimalType(22, 0))
           :: Cast(Literal(1), DecimalType(22, 0))
-          :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0))
+          :: Literal(new java.math.BigDecimal("1000000000000000000000"))
           :: Nil))
       ruleTest(TypeCoercion.FunctionArgumentConversion,
         operator(Literal(1.0)
           :: Literal.create(null, DecimalType(10, 5))
           :: Literal(1)
           :: Nil),
-        operator(Literal(1.0).cast(DoubleType)
+        operator(Literal(1.0)
           :: Literal.create(null, DecimalType(10, 5)).cast(DoubleType)
           :: Literal(1).cast(DoubleType)
           :: Nil))

http://git-wip-us.apache.org/repos/asf/spark/blob/b045315e/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index 6edb434..0212176 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -282,6 +282,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     DataTypeTestUtils.ordered.foreach { dt =>
       checkConsistencyBetweenInterpretedAndCodegen(Least, dt, 2)
     }
+
+    val least = Least(Seq(
+      Literal.create(Seq(1, 2), ArrayType(IntegerType, containsNull = false)),
+      Literal.create(Seq(1, 3, null), ArrayType(IntegerType, containsNull = true))))
+    assert(least.dataType === ArrayType(IntegerType, containsNull = true))
+    checkEvaluation(least, Seq(1, 2))
   }
 
   test("function greatest") {
@@ -334,6 +340,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     DataTypeTestUtils.ordered.foreach { dt =>
       checkConsistencyBetweenInterpretedAndCodegen(Greatest, dt, 2)
     }
+
+    val greatest = Greatest(Seq(
+      Literal.create(Seq(1, 2), ArrayType(IntegerType, containsNull = false)),
+      Literal.create(Seq(1, 3, null), ArrayType(IntegerType, containsNull = true))))
+    assert(greatest.dataType === ArrayType(IntegerType, containsNull = true))
+    checkEvaluation(greatest, Seq(1, 3, null))
   }
 
   test("SPARK-22499: Least and greatest should not generate codes beyond 64KB") {

http://git-wip-us.apache.org/repos/asf/spark/blob/b045315e/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 85d6a1b..f1e3bd0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -227,6 +227,27 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     assert(MapConcat(Seq(m6, m5)).dataType.valueContainsNull)
     assert(!MapConcat(Seq(m1, m2)).nullable)
     assert(MapConcat(Seq(m1, mNull)).nullable)
+
+    val mapConcat = MapConcat(Seq(
+      Literal.create(Map(Seq(1, 2) -> Seq("a", "b")),
+        MapType(
+          ArrayType(IntegerType, containsNull = false),
+          ArrayType(StringType, containsNull = false),
+          valueContainsNull = false)),
+      Literal.create(Map(Seq(3, 4, null) -> Seq("c", "d", null), Seq(6) -> null),
+        MapType(
+          ArrayType(IntegerType, containsNull = true),
+          ArrayType(StringType, containsNull = true),
+          valueContainsNull = true))))
+    assert(mapConcat.dataType ===
+      MapType(
+        ArrayType(IntegerType, containsNull = true),
+        ArrayType(StringType, containsNull = true),
+        valueContainsNull = true))
+    checkEvaluation(mapConcat, Map(
+      Seq(1, 2) -> Seq("a", "b"),
+      Seq(3, 4, null) -> Seq("c", "d", null),
+      Seq(6) -> null))
   }
 
   test("MapFromEntries") {
@@ -1050,11 +1071,11 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
 
   test("Concat") {
     // Primitive-type elements
-    val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
-    val ai1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType))
-    val ai2 = Literal.create(Seq(4, null, 5), ArrayType(IntegerType))
-    val ai3 = Literal.create(Seq(null, null), ArrayType(IntegerType))
-    val ai4 = Literal.create(null, ArrayType(IntegerType))
+    val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
+    val ai1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType, containsNull = false))
+    val ai2 = Literal.create(Seq(4, null, 5), ArrayType(IntegerType, containsNull = true))
+    val ai3 = Literal.create(Seq(null, null), ArrayType(IntegerType, containsNull = true))
+    val ai4 = Literal.create(null, ArrayType(IntegerType, containsNull = false))
 
     checkEvaluation(Concat(Seq(ai0)), Seq(1, 2, 3))
     checkEvaluation(Concat(Seq(ai0, ai1)), Seq(1, 2, 3))
@@ -1067,14 +1088,18 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(Concat(Seq(ai4, ai0)), null)
 
     // Non-primitive-type elements
-    val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType))
-    val as1 = Literal.create(Seq.empty[String], ArrayType(StringType))
-    val as2 = Literal.create(Seq("d", null, "e"), ArrayType(StringType))
-    val as3 = Literal.create(Seq(null, null), ArrayType(StringType))
-    val as4 = Literal.create(null, ArrayType(StringType))
-
-    val aa0 = Literal.create(Seq(Seq("a", "b"), Seq("c")), ArrayType(ArrayType(StringType)))
-    val aa1 = Literal.create(Seq(Seq("d"), Seq("e", "f")), ArrayType(ArrayType(StringType)))
+    val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType, containsNull = false))
+    val as1 = Literal.create(Seq.empty[String], ArrayType(StringType, containsNull = false))
+    val as2 = Literal.create(Seq("d", null, "e"), ArrayType(StringType, containsNull = true))
+    val as3 = Literal.create(Seq(null, null), ArrayType(StringType, containsNull = true))
+    val as4 = Literal.create(null, ArrayType(StringType, containsNull = false))
+
+    val aa0 = Literal.create(Seq(Seq("a", "b"), Seq("c")),
+      ArrayType(ArrayType(StringType, containsNull = false), containsNull = false))
+    val aa1 = Literal.create(Seq(Seq("d"), Seq("e", "f")),
+      ArrayType(ArrayType(StringType, containsNull = false), containsNull = false))
+    val aa2 = Literal.create(Seq(Seq("g", null), null),
+      ArrayType(ArrayType(StringType, containsNull = true), containsNull = true))
 
     checkEvaluation(Concat(Seq(as0)), Seq("a", "b", "c"))
     checkEvaluation(Concat(Seq(as0, as1)), Seq("a", "b", "c"))
@@ -1087,6 +1112,15 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(Concat(Seq(as4, as0)), null)
 
     checkEvaluation(Concat(Seq(aa0, aa1)), Seq(Seq("a", "b"), Seq("c"), Seq("d"), Seq("e", "f")))
+
+    assert(Concat(Seq(ai0, ai1)).dataType.asInstanceOf[ArrayType].containsNull === false)
+    assert(Concat(Seq(ai0, ai2)).dataType.asInstanceOf[ArrayType].containsNull === true)
+    assert(Concat(Seq(as0, as1)).dataType.asInstanceOf[ArrayType].containsNull === false)
+    assert(Concat(Seq(as0, as2)).dataType.asInstanceOf[ArrayType].containsNull === true)
+    assert(Concat(Seq(aa0, aa1)).dataType ===
+      ArrayType(ArrayType(StringType, containsNull = false), containsNull = false))
+    assert(Concat(Seq(aa0, aa2)).dataType ===
+      ArrayType(ArrayType(StringType, containsNull = true), containsNull = true))
   }
 
   test("Flatten") {

http://git-wip-us.apache.org/repos/asf/spark/blob/b045315e/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 726193b..77aaf55 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -144,6 +144,13 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(CreateArray(byteWithNull), byteSeq :+ null, EmptyRow)
     checkEvaluation(CreateArray(strWithNull), strSeq :+ null, EmptyRow)
     checkEvaluation(CreateArray(Literal.create(null, IntegerType) :: Nil), null :: Nil)
+
+    val array = CreateArray(Seq(
+      Literal.create(intSeq, ArrayType(IntegerType, containsNull = false)),
+      Literal.create(intSeq :+ null, ArrayType(IntegerType, containsNull = true))))
+    assert(array.dataType ===
+      ArrayType(ArrayType(IntegerType, containsNull = true), containsNull = false))
+    checkEvaluation(array, Seq(intSeq, intSeq :+ null))
   }
 
   test("CreateMap") {
@@ -184,6 +191,18 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
         CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))),
         null, null)
     }
+
+    val map = CreateMap(Seq(
+      Literal.create(intSeq, ArrayType(IntegerType, containsNull = false)),
+      Literal.create(strSeq, ArrayType(StringType, containsNull = false)),
+      Literal.create(intSeq :+ null, ArrayType(IntegerType, containsNull = true)),
+      Literal.create(strSeq :+ null, ArrayType(StringType, containsNull = true))))
+    assert(map.dataType ===
+      MapType(
+        ArrayType(IntegerType, containsNull = true),
+        ArrayType(StringType, containsNull = true),
+        valueContainsNull = false))
+    checkEvaluation(map, createMap(Seq(intSeq, intSeq :+ null), Seq(strSeq, strSeq :+ null)))
   }
 
   test("MapFromArrays") {

http://git-wip-us.apache.org/repos/asf/spark/blob/b045315e/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala
index 424c3a4..6e07f7a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala
@@ -86,6 +86,13 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       checkEvaluation(Coalesce(Seq(nullLit, lit, lit)), value)
       checkEvaluation(Coalesce(Seq(nullLit, nullLit, lit)), value)
     }
+
+    val coalesce = Coalesce(Seq(
+      Literal.create(null, ArrayType(IntegerType, containsNull = false)),
+      Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)),
+      Literal.create(Seq(1, 2, 3, null), ArrayType(IntegerType, containsNull = true))))
+    assert(coalesce.dataType === ArrayType(IntegerType, containsNull = true))
+    checkEvaluation(coalesce, Seq(1, 2, 3))
   }
 
   test("SPARK-16602 Nvl should support numeric-string cases") {

http://git-wip-us.apache.org/repos/asf/spark/blob/b045315e/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index d461571..3f6f455 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -1622,6 +1622,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
       assert(errMsg.contains(s"input to function $name requires at least two arguments"))
     }
   }
+
+  test("SPARK-24734: Fix containsNull of Concat for array type") {
+    val df = Seq((Seq(1), Seq[Integer](null), Seq("a", "b"))).toDF("k1", "k2", "v")
+    val ex = intercept[RuntimeException] {
+      df.select(map_from_arrays(concat($"k1", $"k2"), $"v")).show()
+    }
+    assert(ex.getMessage.contains("Cannot use null as map key"))
+  }
 }
 
 object DataFrameFunctionsSuite {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org