You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2014/12/12 07:45:30 UTC

spark git commit: [SPARK-4293][SQL] Make Cast be able to handle complex types.

Repository: spark
Updated Branches:
  refs/heads/master c152dde78 -> 334480362


[SPARK-4293][SQL] Make Cast be able to handle complex types.

Inserting data whose type includes `ArrayType.containsNull == false`, `MapType.valueContainsNull == false`, or `StructType.fields.exists(_.nullable == false)` into a Hive table fails, because the `Cast` inserted by the `Analyzer`'s `HiveMetastoreCatalog.PreInsertionCasts` rule cannot handle these types correctly.

Proposed cast rules for complex types:

- Casts between non-complex types behave exactly as before.
- A cast from one `ArrayType` to another is resolvable if
  - the element type can be cast, and
  - the nullability rule is not violated.
- A cast from one `MapType` to another is resolvable if
  - the key type can be cast,
  - the cast key can never be null (its nullability is `false`),
  - the value type can be cast, and
  - the nullability rule for the value type is not violated.
- A cast from one `StructType` to another is resolvable if
  - both have the same number of fields,
  - each field can be cast, and
  - the nullability rule for each field is not violated.
- The nested structure must be the same.

Nullability rule:

- If the result of a cast can be null (`nullable == true`), the target type must also be nullable (`nullable == true`).

Author: Takuya UESHIN <ue...@happy-camper.st>

Closes #3150 from ueshin/issues/SPARK-4293 and squashes the following commits:

e935939 [Takuya UESHIN] Merge branch 'master' into issues/SPARK-4293
ba14003 [Takuya UESHIN] Merge branch 'master' into issues/SPARK-4293
8999868 [Takuya UESHIN] Fix a test title.
f677c30 [Takuya UESHIN] Merge branch 'master' into issues/SPARK-4293
287f410 [Takuya UESHIN] Add tests to insert data of types ArrayType / MapType / StructType with nullability is false into Hive table.
4f71bb8 [Takuya UESHIN] Make Cast be able to handle complex types.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/33448036
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/33448036
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/33448036

Branch: refs/heads/master
Commit: 334480362b3a133c2fb1e9af898930fe76d7a163
Parents: c152dde
Author: Takuya UESHIN <ue...@happy-camper.st>
Authored: Thu Dec 11 22:45:25 2014 -0800
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Thu Dec 11 22:45:25 2014 -0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/expressions/Cast.scala   | 161 +++++++++----
 .../expressions/ExpressionEvaluationSuite.scala | 236 +++++++++++++++++++
 2 files changed, 353 insertions(+), 44 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/33448036/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index b47865f..4ede0b4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -27,9 +27,14 @@ import org.apache.spark.sql.catalyst.types.decimal.Decimal
 
 /** Cast the child expression to the target data type. */
 case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging {
+  // Resolved only when the child is resolved and its type can be cast to `dataType`.
+  override lazy val resolved = childrenResolved && resolve(child.dataType, dataType)
+
   override def foldable = child.foldable
 
-  override def nullable = (child.dataType, dataType) match {
+  override def nullable = forceNullable(child.dataType, dataType) || child.nullable
+  // Whether casting `from` to `to` may turn non-null input into null (e.g. unparsable strings).
+  private[this] def forceNullable(from: DataType, to: DataType) = (from, to) match {
     case (StringType, _: NumericType) => true
     case (StringType, TimestampType)  => true
     case (DoubleType, TimestampType)  => true
@@ -41,8 +46,62 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
     case (DateType, BooleanType)      => true
     case (DoubleType, _: DecimalType) => true
     case (FloatType, _: DecimalType)  => true
-    case (_, DecimalType.Fixed(_, _)) => true  // TODO: not all upcasts here can really give null
-    case _                            => child.nullable
+    case (_, DecimalType.Fixed(_, _)) => true // TODO: not all upcasts here can really give null
+    case _                            => false
+  }
+  // A nullable source may only be cast to a nullable target; a non-null source fits either.
+  private[this] def resolvableNullability(from: Boolean, to: Boolean) = !from || to
+  // Whether `from` can be cast to `to`; array, map and struct types are checked recursively.
+  private[this] def resolve(from: DataType, to: DataType): Boolean = {
+    (from, to) match {
+      case (from, to) if from == to         => true
+
+      case (NullType, _)                    => true
+
+      case (_, StringType)                  => true
+
+      case (StringType, BinaryType)         => true
+      // Conversions to BooleanType.
+      case (StringType, BooleanType)        => true
+      case (DateType, BooleanType)          => true
+      case (TimestampType, BooleanType)     => true
+      case (_: NumericType, BooleanType)    => true
+      // Conversions to TimestampType.
+      case (StringType, TimestampType)      => true
+      case (BooleanType, TimestampType)     => true
+      case (DateType, TimestampType)        => true
+      case (_: NumericType, TimestampType)  => true
+      // Conversions to DateType are accepted from any source type.
+      case (_, DateType)                    => true
+      // Conversions to numeric types.
+      case (StringType, _: NumericType)     => true
+      case (BooleanType, _: NumericType)    => true
+      case (DateType, _: NumericType)       => true
+      case (TimestampType, _: NumericType)  => true
+      case (_: NumericType, _: NumericType) => true
+      // Complex types: components must be castable and must respect the nullability rule.
+      case (ArrayType(from, fn), ArrayType(to, tn)) =>
+        resolve(from, to) &&
+          resolvableNullability(fn || forceNullable(from, to), tn)
+      // Map keys must never become null, so a key cast that may yield null is rejected.
+      case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
+        resolve(fromKey, toKey) &&
+          (!forceNullable(fromKey, toKey)) &&
+          resolve(fromValue, toValue) &&
+          resolvableNullability(fn || forceNullable(fromValue, toValue), tn)
+      // Structs are matched by field position; field names are not compared.
+      case (StructType(fromFields), StructType(toFields)) =>
+        fromFields.size == toFields.size &&
+          fromFields.zip(toFields).forall {
+            case (fromField, toField) =>
+              resolve(fromField.dataType, toField.dataType) &&
+                resolvableNullability(
+                  fromField.nullable || forceNullable(fromField.dataType, toField.dataType),
+                  toField.nullable)
+          }
+
+      case _ => false
+    }
   }
 
   override def toString = s"CAST($child, $dataType)"
@@ -53,7 +112,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
   @inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T])
 
   // UDFToString
-  private[this] def castToString: Any => Any = child.dataType match {
+  private[this] def castToString(from: DataType): Any => Any = from match {
     case BinaryType => buildCast[Array[Byte]](_, new String(_, "UTF-8"))
     case DateType => buildCast[Date](_, dateToString)
     case TimestampType => buildCast[Timestamp](_, timestampToString)
@@ -61,12 +120,12 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
   }
 
   // BinaryConverter
-  private[this] def castToBinary: Any => Any = child.dataType match {
+  private[this] def castToBinary(from: DataType): Any => Any = from match { // StringType only
     case StringType => buildCast[String](_, _.getBytes("UTF-8"))
   }
 
   // UDFToBoolean
-  private[this] def castToBoolean: Any => Any = child.dataType match {
+  private[this] def castToBoolean(from: DataType): Any => Any = from match {
     case StringType =>
       buildCast[String](_, _.length() != 0)
     case TimestampType =>
@@ -91,7 +150,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
   }
 
   // TimestampConverter
-  private[this] def castToTimestamp: Any => Any = child.dataType match {
+  private[this] def castToTimestamp(from: DataType): Any => Any = from match {
     case StringType =>
       buildCast[String](_, s => {
         // Throw away extra if more than 9 decimal places
@@ -133,7 +192,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
       })
   }
 
-  private[this]  def decimalToTimestamp(d: Decimal) = {
+  private[this] def decimalToTimestamp(d: Decimal) = {
     val seconds = Math.floor(d.toDouble).toLong
     val bd = (d.toBigDecimal - seconds) * 1000000000
     val nanos = bd.intValue()
@@ -172,11 +231,10 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
   }
 
   // DateConverter
-  private[this] def castToDate: Any => Any = child.dataType match {
+  private[this] def castToDate(from: DataType): Any => Any = from match {
     case StringType =>
       buildCast[String](_, s =>
-        try Date.valueOf(s) catch { case _: java.lang.IllegalArgumentException => null }
-      )
+        try Date.valueOf(s) catch { case _: java.lang.IllegalArgumentException => null })
     case TimestampType =>
       // throw valid precision more than seconds, according to Hive.
       // Timestamp.nanos is in 0 to 999,999,999, no more than a second.
@@ -199,7 +257,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
   }
 
   // LongConverter
-  private[this] def castToLong: Any => Any = child.dataType match {
+  private[this] def castToLong(from: DataType): Any => Any = from match {
     case StringType =>
       buildCast[String](_, s => try s.toLong catch {
         case _: NumberFormatException => null
@@ -210,14 +268,12 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
       buildCast[Date](_, d => dateToLong(d))
     case TimestampType =>
       buildCast[Timestamp](_, t => timestampToLong(t))
-    case DecimalType() =>
-      buildCast[Decimal](_, _.toLong)
     case x: NumericType =>
       b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b)
   }
 
   // IntConverter
-  private[this] def castToInt: Any => Any = child.dataType match {
+  private[this] def castToInt(from: DataType): Any => Any = from match {
     case StringType =>
       buildCast[String](_, s => try s.toInt catch {
         case _: NumberFormatException => null
@@ -228,14 +284,12 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
       buildCast[Date](_, d => dateToLong(d))
     case TimestampType =>
       buildCast[Timestamp](_, t => timestampToLong(t).toInt)
-    case DecimalType() =>
-      buildCast[Decimal](_, _.toInt)
     case x: NumericType =>
       b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b)
   }
 
   // ShortConverter
-  private[this] def castToShort: Any => Any = child.dataType match {
+  private[this] def castToShort(from: DataType): Any => Any = from match {
     case StringType =>
       buildCast[String](_, s => try s.toShort catch {
         case _: NumberFormatException => null
@@ -246,14 +300,12 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
       buildCast[Date](_, d => dateToLong(d))
     case TimestampType =>
       buildCast[Timestamp](_, t => timestampToLong(t).toShort)
-    case DecimalType() =>
-      buildCast[Decimal](_, _.toShort)
     case x: NumericType =>
       b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort
   }
 
   // ByteConverter
-  private[this] def castToByte: Any => Any = child.dataType match {
+  private[this] def castToByte(from: DataType): Any => Any = from match {
     case StringType =>
       buildCast[String](_, s => try s.toByte catch {
         case _: NumberFormatException => null
@@ -264,8 +316,6 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
       buildCast[Date](_, d => dateToLong(d))
     case TimestampType =>
       buildCast[Timestamp](_, t => timestampToLong(t).toByte)
-    case DecimalType() =>
-      buildCast[Decimal](_, _.toByte)
     case x: NumericType =>
       b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte
   }
@@ -285,7 +335,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
     }
   }
 
-  private[this] def castToDecimal(target: DecimalType): Any => Any = child.dataType match {
+  private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match {
     case StringType =>
       buildCast[String](_, s => try changePrecision(Decimal(s.toDouble), target) catch {
         case _: NumberFormatException => null
@@ -301,7 +351,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
       b => changePrecision(b.asInstanceOf[Decimal].clone(), target)
     case LongType =>
       b => changePrecision(Decimal(b.asInstanceOf[Long]), target)
-    case x: NumericType =>  // All other numeric types can be represented precisely as Doubles
+    case x: NumericType => // All other numeric types can be represented precisely as Doubles
       b => try {
         changePrecision(Decimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)), target)
       } catch {
@@ -310,7 +360,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
   }
 
   // DoubleConverter
-  private[this] def castToDouble: Any => Any = child.dataType match {
+  private[this] def castToDouble(from: DataType): Any => Any = from match {
     case StringType =>
       buildCast[String](_, s => try s.toDouble catch {
         case _: NumberFormatException => null
@@ -321,14 +371,12 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
       buildCast[Date](_, d => dateToDouble(d))
     case TimestampType =>
       buildCast[Timestamp](_, t => timestampToDouble(t))
-    case DecimalType() =>
-      buildCast[Decimal](_, _.toDouble)
     case x: NumericType =>
       b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)
   }
 
   // FloatConverter
-  private[this] def castToFloat: Any => Any = child.dataType match {
+  private[this] def castToFloat(from: DataType): Any => Any = from match {
     case StringType =>
       buildCast[String](_, s => try s.toFloat catch {
         case _: NumberFormatException => null
@@ -339,28 +387,53 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
       buildCast[Date](_, d => dateToDouble(d))
     case TimestampType =>
       buildCast[Timestamp](_, t => timestampToDouble(t).toFloat)
-    case DecimalType() =>
-      buildCast[Decimal](_, _.toFloat)
     case x: NumericType =>
       b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b)
   }
 
-  private[this] lazy val cast: Any => Any = dataType match {
+  private[this] def castArray(from: ArrayType, to: ArrayType): Any => Any = {
+    val elementCast = cast(from.elementType, to.elementType) // applied to each non-null element
+    buildCast[Seq[Any]](_, _.map(v => if (v == null) null else elementCast(v)))
+  }
+  // Casts keys and values; null values stay null. Keys equal after casting collapse into one.
+  private[this] def castMap(from: MapType, to: MapType): Any => Any = {
+    val keyCast = cast(from.keyType, to.keyType)
+    val valueCast = cast(from.valueType, to.valueType)
+    buildCast[Map[Any, Any]](_, _.map {
+      case (key, value) => (keyCast(key), if (value == null) null else valueCast(value))
+    })
+  }
+  // Casts each field by position with its own converter; null field values stay null.
+  private[this] def castStruct(from: StructType, to: StructType): Any => Any = {
+    val casts = from.fields.zip(to.fields).map {
+      case (fromField, toField) => cast(fromField.dataType, toField.dataType)
+    }
+    buildCast[Row](_, row => Row(row.zip(casts).map {
+      case (v, cast) => if (v == null) null else cast(v)
+    }: _*))
+  }
+  // Builds a `from` -> `to` converter; also used for nested element/key/value/field types.
+  private[this] def cast(from: DataType, to: DataType): Any => Any = to match {
-    case dt if dt == child.dataType => identity[Any]
-    case StringType    => castToString
-    case BinaryType    => castToBinary
-    case DateType      => castToDate
-    case decimal: DecimalType => castToDecimal(decimal)
-    case TimestampType => castToTimestamp
-    case BooleanType   => castToBoolean
-    case ByteType      => castToByte
-    case ShortType     => castToShort
-    case IntegerType   => castToInt
-    case FloatType     => castToFloat
-    case LongType      => castToLong
-    case DoubleType    => castToDouble
+    case dt if dt == from           => identity[Any] // identical types, incl. nested ones
+    case StringType                 => castToString(from)
+    case BinaryType                 => castToBinary(from)
+    case DateType                   => castToDate(from)
+    case decimal: DecimalType       => castToDecimal(from, decimal)
+    case TimestampType              => castToTimestamp(from)
+    case BooleanType                => castToBoolean(from)
+    case ByteType                   => castToByte(from)
+    case ShortType                  => castToShort(from)
+    case IntegerType                => castToInt(from)
+    case FloatType                  => castToFloat(from)
+    case LongType                   => castToLong(from)
+    case DoubleType                 => castToDouble(from)
+    case array: ArrayType           => castArray(from.asInstanceOf[ArrayType], array)
+    case map: MapType               => castMap(from.asInstanceOf[MapType], map)
+    case struct: StructType         => castStruct(from.asInstanceOf[StructType], struct)
   }
 
+  private[this] lazy val cast: Any => Any = cast(child.dataType, dataType) // top-level converter
+
   override def eval(input: Row): Any = {
     val evaluated = child.eval(input)
     if (evaluated == null) null else cast(evaluated)

http://git-wip-us.apache.org/repos/asf/spark/blob/33448036/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index cd2f67f..b030483 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -487,6 +487,242 @@ class ExpressionEvaluationSuite extends FunSuite {
     checkEvaluation(Cast(Literal(1.0f / 0.0f), TimestampType), null)
   }
 
+  test("array casting") {
+    val array = Literal(Seq("123", "abc", "", null), ArrayType(StringType, containsNull = true))
+    val array_notNull = Literal(Seq("123", "abc", ""), ArrayType(StringType, containsNull = false))
+
+    { // Nullable source elements: the target must keep containsNull = true.
+      val cast = Cast(array, ArrayType(IntegerType, containsNull = true))
+      assert(cast.resolved === true)
+      checkEvaluation(cast, Seq(123, null, null, null))
+    }
+    {
+      val cast = Cast(array, ArrayType(IntegerType, containsNull = false))
+      assert(cast.resolved === false)
+    }
+    {
+      val cast = Cast(array, ArrayType(BooleanType, containsNull = true))
+      assert(cast.resolved === true)
+      checkEvaluation(cast, Seq(true, true, false, null))
+    }
+    {
+      val cast = Cast(array, ArrayType(BooleanType, containsNull = false))
+      assert(cast.resolved === false)
+    }
+
+    { // Non-null source elements: containsNull = false needs a cast that can't yield null.
+      val cast = Cast(array_notNull, ArrayType(IntegerType, containsNull = true))
+      assert(cast.resolved === true)
+      checkEvaluation(cast, Seq(123, null, null))
+    }
+    {
+      val cast = Cast(array_notNull, ArrayType(IntegerType, containsNull = false))
+      assert(cast.resolved === false)
+    }
+    {
+      val cast = Cast(array_notNull, ArrayType(BooleanType, containsNull = true))
+      assert(cast.resolved === true)
+      checkEvaluation(cast, Seq(true, true, false))
+    }
+    {
+      val cast = Cast(array_notNull, ArrayType(BooleanType, containsNull = false))
+      assert(cast.resolved === true)
+      checkEvaluation(cast, Seq(true, true, false))
+    }
+
+    { // An array cannot be cast to a primitive type.
+      val cast = Cast(array, IntegerType)
+      assert(cast.resolved === false)
+    }
+  }
+
+  test("map casting") {
+    val map = Literal(
+      Map("a" -> "123", "b" -> "abc", "c" -> "", "d" -> null),
+      MapType(StringType, StringType, valueContainsNull = true))
+    val map_notNull = Literal(
+      Map("a" -> "123", "b" -> "abc", "c" -> ""),
+      MapType(StringType, StringType, valueContainsNull = false))
+
+    { // Nullable values: valueContainsNull = false is rejected; so are String -> Int keys.
+      val cast = Cast(map, MapType(StringType, IntegerType, valueContainsNull = true))
+      assert(cast.resolved === true)
+      checkEvaluation(cast, Map("a" -> 123, "b" -> null, "c" -> null, "d" -> null))
+    }
+    {
+      val cast = Cast(map, MapType(StringType, IntegerType, valueContainsNull = false))
+      assert(cast.resolved === false)
+    }
+    {
+      val cast = Cast(map, MapType(StringType, BooleanType, valueContainsNull = true))
+      assert(cast.resolved === true)
+      checkEvaluation(cast, Map("a" -> true, "b" -> true, "c" -> false, "d" -> null))
+    }
+    {
+      val cast = Cast(map, MapType(StringType, BooleanType, valueContainsNull = false))
+      assert(cast.resolved === false)
+    }
+    {
+      val cast = Cast(map, MapType(IntegerType, StringType, valueContainsNull = true))
+      assert(cast.resolved === false)
+    }
+
+    { // Non-null values: valueContainsNull = false needs a value cast that can't yield null.
+      val cast = Cast(map_notNull, MapType(StringType, IntegerType, valueContainsNull = true))
+      assert(cast.resolved === true)
+      checkEvaluation(cast, Map("a" -> 123, "b" -> null, "c" -> null))
+    }
+    {
+      val cast = Cast(map_notNull, MapType(StringType, IntegerType, valueContainsNull = false))
+      assert(cast.resolved === false)
+    }
+    {
+      val cast = Cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = true))
+      assert(cast.resolved === true)
+      checkEvaluation(cast, Map("a" -> true, "b" -> true, "c" -> false))
+    }
+    {
+      val cast = Cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false))
+      assert(cast.resolved === true)
+      checkEvaluation(cast, Map("a" -> true, "b" -> true, "c" -> false))
+    }
+    {
+      val cast = Cast(map_notNull, MapType(IntegerType, StringType, valueContainsNull = true))
+      assert(cast.resolved === false)
+    }
+
+    { // A map cannot be cast to a primitive type.
+      val cast = Cast(map, IntegerType)
+      assert(cast.resolved === false)
+    }
+  }
+
+  test("struct casting") {
+    val struct = Literal(
+      Row("123", "abc", "", null),
+      StructType(Seq(
+        StructField("a", StringType, nullable = true),
+        StructField("b", StringType, nullable = true),
+        StructField("c", StringType, nullable = true),
+        StructField("d", StringType, nullable = true))))
+    val struct_notNull = Literal(
+      Row("123", "abc", ""),
+      StructType(Seq(
+        StructField("a", StringType, nullable = false),
+        StructField("b", StringType, nullable = false),
+        StructField("c", StringType, nullable = false))))
+
+    { // Nullable source fields: a non-nullable target field is rejected.
+      val cast = Cast(struct, StructType(Seq(
+        StructField("a", IntegerType, nullable = true),
+        StructField("b", IntegerType, nullable = true),
+        StructField("c", IntegerType, nullable = true),
+        StructField("d", IntegerType, nullable = true))))
+      assert(cast.resolved === true)
+      checkEvaluation(cast, Row(123, null, null, null))
+    }
+    {
+      val cast = Cast(struct, StructType(Seq(
+        StructField("a", IntegerType, nullable = true),
+        StructField("b", IntegerType, nullable = true),
+        StructField("c", IntegerType, nullable = false),
+        StructField("d", IntegerType, nullable = true))))
+      assert(cast.resolved === false)
+    }
+    {
+      val cast = Cast(struct, StructType(Seq(
+        StructField("a", BooleanType, nullable = true),
+        StructField("b", BooleanType, nullable = true),
+        StructField("c", BooleanType, nullable = true),
+        StructField("d", BooleanType, nullable = true))))
+      assert(cast.resolved === true)
+      checkEvaluation(cast, Row(true, true, false, null))
+    }
+    {
+      val cast = Cast(struct, StructType(Seq(
+        StructField("a", BooleanType, nullable = true),
+        StructField("b", BooleanType, nullable = true),
+        StructField("c", BooleanType, nullable = false),
+        StructField("d", BooleanType, nullable = true))))
+      assert(cast.resolved === false)
+    }
+
+    { // Non-null source fields: a non-nullable target needs a cast that can't yield null.
+      val cast = Cast(struct_notNull, StructType(Seq(
+        StructField("a", IntegerType, nullable = true),
+        StructField("b", IntegerType, nullable = true),
+        StructField("c", IntegerType, nullable = true))))
+      assert(cast.resolved === true)
+      checkEvaluation(cast, Row(123, null, null))
+    }
+    {
+      val cast = Cast(struct_notNull, StructType(Seq(
+        StructField("a", IntegerType, nullable = true),
+        StructField("b", IntegerType, nullable = true),
+        StructField("c", IntegerType, nullable = false))))
+      assert(cast.resolved === false)
+    }
+    {
+      val cast = Cast(struct_notNull, StructType(Seq(
+        StructField("a", BooleanType, nullable = true),
+        StructField("b", BooleanType, nullable = true),
+        StructField("c", BooleanType, nullable = true))))
+      assert(cast.resolved === true)
+      checkEvaluation(cast, Row(true, true, false))
+    }
+    {
+      val cast = Cast(struct_notNull, StructType(Seq(
+        StructField("a", BooleanType, nullable = true),
+        StructField("b", BooleanType, nullable = true),
+        StructField("c", BooleanType, nullable = false))))
+      assert(cast.resolved === true)
+      checkEvaluation(cast, Row(true, true, false))
+    }
+
+    { // Different field counts and non-struct targets are rejected.
+      val cast = Cast(struct, StructType(Seq(
+        StructField("a", StringType, nullable = true),
+        StructField("b", StringType, nullable = true),
+        StructField("c", StringType, nullable = true))))
+      assert(cast.resolved === false)
+    }
+    {
+      val cast = Cast(struct, IntegerType)
+      assert(cast.resolved === false)
+    }
+  }
+
+  test("complex casting") {
+    val complex = Literal(
+      Row(
+        Seq("123", "abc", ""),
+        Map("a" -> "123", "b" -> "abc", "c" -> ""),
+        Row(0)),
+      StructType(Seq(
+        StructField("a",
+          ArrayType(StringType, containsNull = false), nullable = true),
+        StructField("m",
+          MapType(StringType, StringType, valueContainsNull = false), nullable = true),
+        StructField("s",
+          StructType(Seq(
+            StructField("i", IntegerType, nullable = true)))))))
+    // Array elements, map values and a nested struct field are all cast in one expression.
+    val cast = Cast(complex, StructType(Seq(
+      StructField("a",
+        ArrayType(IntegerType, containsNull = true), nullable = true),
+      StructField("m",
+        MapType(StringType, BooleanType, valueContainsNull = false), nullable = true),
+      StructField("s",
+        StructType(Seq(
+          StructField("l", LongType, nullable = true)))))))
+    // String -> Boolean map values cannot yield null, so valueContainsNull = false resolves.
+    assert(cast.resolved === true)
+    checkEvaluation(cast, Row(
+      Seq(123, null, null),
+      Map("a" -> true, "b" -> true, "c" -> false),
+      Row(0L)))
+  }
+
   test("null checking") {
     val row = new GenericRow(Array[Any]("^Ba*n", null, true, null))
     val c1 = 'a.string.at(0)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org