Posted to commits@spark.apache.org by ma...@apache.org on 2015/12/01 19:24:56 UTC

spark git commit: [SPARK-11856][SQL] add type cast if the real type is different but compatible with encoder schema

Repository: spark
Updated Branches:
  refs/heads/master 8ddc55f1d -> 9df24624a


[SPARK-11856][SQL] add type cast if the real type is different but compatible with encoder schema

When we build the `fromRowExpression` for an encoder, we set up a lot of "unresolved" placeholders and lose the required data type, which may lead to a runtime error if the real type doesn't match the encoder's schema.
For example, if we build an encoder for `case class Data(a: Int, b: String)` and the real type is `[a: int, b: long]`, we hit a runtime error saying we can't construct class `Data` with an int and a long, because we lost the information that `b` should be a string.
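
A minimal sketch of the scenario (an illustrative session, not part of the patch; the class and data are made up):

    // assumes a spark-shell session with sqlContext.implicits._ imported
    case class Data(a: Int, b: String)

    // The underlying rows have schema [a: int, b: bigint], which differs from
    // the encoder schema [a: int, b: string] but is compatible with it.
    val df = Seq((1, 2L), (3, 4L)).toDF("a", "b")

    // Before this patch the required type for `b` was lost and constructing
    // Data failed at runtime; with this patch an UpCast to string is inserted
    // (see castSuccess[Long, String] in the new suite), so this succeeds:
    df.as[Data].collect()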

Author: Wenchen Fan <we...@databricks.com>

Closes #9840 from cloud-fan/err-msg.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9df24624
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9df24624
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9df24624

Branch: refs/heads/master
Commit: 9df24624afedd993a39ab46c8211ae153aedef1a
Parents: 8ddc55f
Author: Wenchen Fan <we...@databricks.com>
Authored: Tue Dec 1 10:24:53 2015 -0800
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Tue Dec 1 10:24:53 2015 -0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/ScalaReflection.scala    |  93 ++++++++--
 .../spark/sql/catalyst/analysis/Analyzer.scala  |  40 +++++
 .../catalyst/analysis/HiveTypeCoercion.scala    |   2 +-
 .../catalyst/encoders/ExpressionEncoder.scala   |   4 +-
 .../spark/sql/catalyst/expressions/Cast.scala   |   9 +
 .../expressions/complexTypeCreator.scala        |   2 +-
 .../apache/spark/sql/types/DecimalType.scala    |  12 ++
 .../encoders/EncoderResolutionSuite.scala       | 180 +++++++++++++++++++
 .../spark/sql/DatasetAggregatorSuite.scala      |   4 +-
 .../org/apache/spark/sql/DatasetSuite.scala     |  21 ++-
 10 files changed, 335 insertions(+), 32 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9df24624/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index d133ad3..9b6b5b8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -18,9 +18,8 @@
 package org.apache.spark.sql.catalyst
 
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedExtractValue, UnresolvedAttribute}
-import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, ArrayData, DateTimeUtils}
+import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils}
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
@@ -117,31 +116,75 @@ object ScalaReflection extends ScalaReflection {
    * from ordinal 0 (since there are no names to map to).  The actual location can be moved by
    * calling resolve/bind with a new schema.
    */
-  def constructorFor[T : TypeTag]: Expression = constructorFor(localTypeOf[T], None)
+  def constructorFor[T : TypeTag]: Expression = {
+    val tpe = localTypeOf[T]
+    val clsName = getClassNameFromType(tpe)
+    val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil
+    constructorFor(tpe, None, walkedTypePath)
+  }
 
   private def constructorFor(
       tpe: `Type`,
-      path: Option[Expression]): Expression = ScalaReflectionLock.synchronized {
+      path: Option[Expression],
+      walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized {
 
     /** Returns the current path with a sub-field extracted. */
-    def addToPath(part: String): Expression = path
-      .map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
-      .getOrElse(UnresolvedAttribute(part))
+    def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = {
+      val newPath = path
+        .map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
+        .getOrElse(UnresolvedAttribute(part))
+      upCastToExpectedType(newPath, dataType, walkedTypePath)
+    }
 
     /** Returns the current path with a field at ordinal extracted. */
-    def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path
-      .map(p => GetStructField(p, ordinal))
-      .getOrElse(BoundReference(ordinal, dataType, false))
+    def addToPathOrdinal(
+        ordinal: Int,
+        dataType: DataType,
+        walkedTypePath: Seq[String]): Expression = {
+      val newPath = path
+        .map(p => GetStructField(p, ordinal))
+        .getOrElse(BoundReference(ordinal, dataType, false))
+      upCastToExpectedType(newPath, dataType, walkedTypePath)
+    }
 
     /** Returns the current path or `BoundReference`. */
-    def getPath: Expression = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true))
+    def getPath: Expression = {
+      val dataType = schemaFor(tpe).dataType
+      if (path.isDefined) {
+        path.get
+      } else {
+        upCastToExpectedType(BoundReference(0, dataType, true), dataType, walkedTypePath)
+      }
+    }
+
+    /**
+     * When we build the `fromRowExpression` for an encoder, we set up a lot of "unresolved"
+     * placeholders and lose the required data type, which may lead to a runtime error if the
+     * real type doesn't match the encoder's schema.
+     * For example, if we build an encoder for `case class Data(a: Int, b: String)` and the
+     * real type is [a: int, b: long], we hit a runtime error saying we can't construct class
+     * `Data` with an int and a long, because we lost the information that `b` should be a string.
+     *
+     * This method helps us "remember" the required data type by adding an `UpCast`. Note that
+     * we don't need to cast a struct type, because there must be an `UnresolvedExtractValue`
+     * or `GetStructField` wrapping it, so we only need to handle leaf types.
+     */
+    def upCastToExpectedType(
+        expr: Expression,
+        expected: DataType,
+        walkedTypePath: Seq[String]): Expression = expected match {
+      case _: StructType => expr
+      case _ => UpCast(expr, expected, walkedTypePath)
+    }
 
     tpe match {
       case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath
 
       case t if t <:< localTypeOf[Option[_]] =>
         val TypeRef(_, _, Seq(optType)) = t
-        WrapOption(constructorFor(optType, path))
+        val className = getClassNameFromType(optType)
+        val newTypePath = s"""- option value class: "$className"""" +: walkedTypePath
+        WrapOption(constructorFor(optType, path, newTypePath))
 
       case t if t <:< localTypeOf[java.lang.Integer] =>
         val boxedType = classOf[java.lang.Integer]
@@ -219,9 +262,11 @@ object ScalaReflection extends ScalaReflection {
         primitiveMethod.map { method =>
           Invoke(getPath, method, arrayClassFor(elementType))
         }.getOrElse {
+          val className = getClassNameFromType(elementType)
+          val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
           Invoke(
             MapObjects(
-              p => constructorFor(elementType, Some(p)),
+              p => constructorFor(elementType, Some(p), newTypePath),
               getPath,
               schemaFor(elementType).dataType),
             "array",
@@ -230,10 +275,12 @@ object ScalaReflection extends ScalaReflection {
 
       case t if t <:< localTypeOf[Seq[_]] =>
         val TypeRef(_, _, Seq(elementType)) = t
+        val className = getClassNameFromType(elementType)
+        val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
         val arrayData =
           Invoke(
             MapObjects(
-              p => constructorFor(elementType, Some(p)),
+              p => constructorFor(elementType, Some(p), newTypePath),
               getPath,
               schemaFor(elementType).dataType),
             "array",
@@ -246,12 +293,13 @@ object ScalaReflection extends ScalaReflection {
           arrayData :: Nil)
 
       case t if t <:< localTypeOf[Map[_, _]] =>
+        // TODO: add walked type path for map
         val TypeRef(_, _, Seq(keyType, valueType)) = t
 
         val keyData =
           Invoke(
             MapObjects(
-              p => constructorFor(keyType, Some(p)),
+              p => constructorFor(keyType, Some(p), walkedTypePath),
               Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)),
               schemaFor(keyType).dataType),
             "array",
@@ -260,7 +308,7 @@ object ScalaReflection extends ScalaReflection {
         val valueData =
           Invoke(
             MapObjects(
-              p => constructorFor(valueType, Some(p)),
+              p => constructorFor(valueType, Some(p), walkedTypePath),
               Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)),
               schemaFor(valueType).dataType),
             "array",
@@ -297,12 +345,19 @@ object ScalaReflection extends ScalaReflection {
           val fieldName = p.name.toString
           val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
           val dataType = schemaFor(fieldType).dataType
-
+          val clsName = getClassNameFromType(fieldType)
+          val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
           // For tuples, we grab the inner fields by ordinal instead of name.
           if (cls.getName startsWith "scala.Tuple") {
-            constructorFor(fieldType, Some(addToPathOrdinal(i, dataType)))
+            constructorFor(
+              fieldType,
+              Some(addToPathOrdinal(i, dataType, newTypePath)),
+              newTypePath)
           } else {
-            constructorFor(fieldType, Some(addToPath(fieldName)))
+            constructorFor(
+              fieldType,
+              Some(addToPath(fieldName, dataType, newTypePath)),
+              newTypePath)
           }
         }
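
For reference, the walkedTypePath threaded through these calls accumulates entries newest-first and ends up in error messages; e.g. for field `b` of a root case class it looks like this (format as asserted in the new EncoderResolutionSuite below):

    - field (class: "scala.Int", name: "b")
    - root class: "org.apache.spark.sql.catalyst.encoders.StringIntClass"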
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9df24624/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index b8f212f..765327c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -72,6 +72,7 @@ class Analyzer(
       ResolveReferences ::
       ResolveGroupingAnalytics ::
       ResolvePivot ::
+      ResolveUpCast ::
       ResolveSortReferences ::
       ResolveGenerate ::
       ResolveFunctions ::
@@ -1182,3 +1183,42 @@ object ComputeCurrentTime extends Rule[LogicalPlan] {
     }
   }
 }
+
+/**
+ * Replaces `UpCast` expressions with `Cast`s, and throws an exception if the cast may truncate.
+ */
+object ResolveUpCast extends Rule[LogicalPlan] {
+  private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = {
+    throw new AnalysisException(s"Cannot up cast `${from.prettyString}` from " +
+      s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" +
+      "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") +
+      "You can either add an explicit cast to the input data or choose a higher precision " +
+      "type of the field in the target object")
+  }
+
+  private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = {
+    val fromPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(from)
+    val toPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(to)
+    toPrecedence > 0 && fromPrecedence > toPrecedence
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    plan transformAllExpressions {
+      case u @ UpCast(child, _, _) if !child.resolved => u
+
+      case UpCast(child, dataType, walkedTypePath) => (child.dataType, dataType) match {
+        case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) =>
+          fail(child, to, walkedTypePath)
+        case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) =>
+          fail(child, to, walkedTypePath)
+        case (from, to) if illegalNumericPrecedence(from, to) =>
+          fail(child, to, walkedTypePath)
+        case (TimestampType, DateType) =>
+          fail(child, DateType, walkedTypePath)
+        case (StringType, to: NumericType) =>
+          fail(child, to, walkedTypePath)
+        case _ => Cast(child, dataType)
+      }
+    }
+  }
+}
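
A sketch of what this rule surfaces to users (the session is illustrative; the message format comes from fail() above and the assertions in the new EncoderResolutionSuite):

    case class Narrow(i: Int)
    val df = Seq(1L, 2L).toDF("i")   // i: bigint

    df.as[Narrow]
    // org.apache.spark.sql.AnalysisException:
    // Cannot up cast `i` from bigint to int as it may truncate
    // The type path of the target object is:
    // - field (class: "scala.Int", name: "i")
    // - root class: "Narrow"
    // You can either add an explicit cast to the input data or choose a higher
    // precision type of the field in the target object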

http://git-wip-us.apache.org/repos/asf/spark/blob/9df24624/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index f90fc3c..29502a5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -53,7 +53,7 @@ object HiveTypeCoercion {
 
   // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
   // The conversion for integral and floating point types have a linear widening hierarchy:
-  private val numericPrecedence =
+  private[sql] val numericPrecedence =
     IndexedSeq(
       ByteType,
       ShortType,

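The diff context cuts the list off; the full widening order in this version of HiveTypeCoercion is Byte < Short < Int < Long < Float < Double. A standalone sketch of how ResolveUpCast.illegalNumericPrecedence consults it (restated here, not new code in the patch):

    import org.apache.spark.sql.types._

    val numericPrecedence: IndexedSeq[DataType] =
      IndexedSeq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType)

    def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = {
      val fromPrecedence = numericPrecedence.indexOf(from)
      val toPrecedence = numericPrecedence.indexOf(to)
      toPrecedence > 0 && fromPrecedence > toPrecedence
    }

    illegalNumericPrecedence(LongType, IntegerType)  // true: may truncate
    illegalNumericPrecedence(IntegerType, LongType)  // false: safe widening
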
http://git-wip-us.apache.org/repos/asf/spark/blob/9df24624/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 0c10a56..06ffe86 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtract
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
+import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.types.{StructField, ObjectType, StructType}
@@ -235,12 +236,13 @@ case class ExpressionEncoder[T](
 
     val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema))
     val analyzedPlan = SimpleAnalyzer.execute(plan)
+    val optimizedPlan = SimplifyCasts(analyzedPlan)
 
     // In order to construct instances of inner classes (for example those declared in a REPL cell),
     // we need an instance of the outer scope.  This rule substitutes those outer objects into
     // expressions that are missing them by looking up the name in the SQLContext's `outerScopes`
     // registry.
-    copy(fromRowExpression = analyzedPlan.expressions.head.children.head transform {
+    copy(fromRowExpression = optimizedPlan.expressions.head.children.head transform {
       case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass =>
         val outer = outerScopes.get(n.cls.getDeclaringClass.getName)
         if (outer == null) {

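Why SimplifyCasts is applied here (reasoning restated, not part of the diff): when the input type already matches the encoder schema, ResolveUpCast turns the UpCast into a Cast whose target type equals the child's type, and SimplifyCasts drops such no-op casts, so the fromRowExpression stays as compact as it was before this patch:

    // before SimplifyCasts, assuming column b is already bigint:
    Cast('b.long, LongType)
    // after SimplifyCasts:
    'b.long
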
http://git-wip-us.apache.org/repos/asf/spark/blob/9df24624/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index a2c6c39..cb60d59 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -914,3 +914,12 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
       """
   }
 }
+
+/**
+ * Cast the child expression to the target data type, but throw an error if the cast might
+ * truncate, e.g. long -> int, timestamp -> date.
+ */
+case class UpCast(child: Expression, dataType: DataType, walkedTypePath: Seq[String])
+  extends UnaryExpression with Unevaluable {
+  override lazy val resolved = false
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/9df24624/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 1854dfa..72cc89c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -126,7 +126,7 @@ case class CreateStruct(children: Seq[Expression]) extends Expression {
 case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
 
   /**
-   * Returns Aliased [[Expressions]] that could be used to construct a flattened version of this
+   * Returns Aliased [[Expression]]s that could be used to construct a flattened version of this
    * StructType.
    */
   def flatten: Seq[NamedExpression] = valExprs.zip(names).map {

http://git-wip-us.apache.org/repos/asf/spark/blob/9df24624/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index 0cd352d..ce45245 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -91,6 +91,18 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType {
   }
 
   /**
+   * Returns whether this DecimalType is tighter than `other`. If so, it means `this`
+   * can be cast into `other` safely without losing any precision or range.
+   */
+  private[sql] def isTighterThan(other: DataType): Boolean = other match {
+    case dt: DecimalType =>
+      (precision - scale) <= (dt.precision - dt.scale) && scale <= dt.scale
+    case dt: IntegralType =>
+      isTighterThan(DecimalType.forType(dt))
+    case _ => false
+  }
+
+  /**
    * The default size of a value of the DecimalType is 4096 bytes.
    */
   override def defaultSize: Int = 4096

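A worked example of the check above (isTighterThan is private[sql]; the helper below just restates its decimal-to-decimal case):

    // decimal(p, s) has p - s integral digits and s fractional digits.
    def tighter(p1: Int, s1: Int, p2: Int, s2: Int): Boolean =
      (p1 - s1) <= (p2 - s2) && s1 <= s2

    tighter(5, 2, 10, 4)  // true:  decimal(5,2) fits in decimal(10,4) losslessly
    tighter(5, 2, 4, 2)   // false: 3 integral digits don't fit in 2
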
http://git-wip-us.apache.org/repos/asf/spark/blob/9df24624/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
new file mode 100644
index 0000000..0289988
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.types._
+
+case class StringLongClass(a: String, b: Long)
+
+case class StringIntClass(a: String, b: Int)
+
+case class ComplexClass(a: Long, b: StringLongClass)
+
+class EncoderResolutionSuite extends PlanTest {
+  test("real type doesn't match encoder schema but they are compatible: product") {
+    val encoder = ExpressionEncoder[StringLongClass]
+    val cls = classOf[StringLongClass]
+
+    {
+      val attrs = Seq('a.string, 'b.int)
+      val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
+      val expected: Expression = NewInstance(
+        cls,
+        toExternalString('a.string) :: 'b.int.cast(LongType) :: Nil,
+        false,
+        ObjectType(cls))
+      compareExpressions(fromRowExpr, expected)
+    }
+
+    {
+      val attrs = Seq('a.int, 'b.long)
+      val fromRowExpr = encoder.resolve(attrs, null).fromRowExpression
+      val expected = NewInstance(
+        cls,
+        toExternalString('a.int.cast(StringType)) :: 'b.long :: Nil,
+        false,
+        ObjectType(cls))
+      compareExpressions(fromRowExpr, expected)
+    }
+  }
+
+  test("real type doesn't match encoder schema but they are compatible: nested product") {
+    val encoder = ExpressionEncoder[ComplexClass]
+    val innerCls = classOf[StringLongClass]
+    val cls = classOf[ComplexClass]
+
+    val structType = new StructType().add("a", IntegerType).add("b", LongType)
+    val attrs = Seq('a.int, 'b.struct(structType))
+    val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
+    val expected: Expression = NewInstance(
+      cls,
+      Seq(
+        'a.int.cast(LongType),
+        If(
+          'b.struct(structType).isNull,
+          Literal.create(null, ObjectType(innerCls)),
+          NewInstance(
+            innerCls,
+            Seq(
+              toExternalString(
+                GetStructField('b.struct(structType), 0, Some("a")).cast(StringType)),
+              GetStructField('b.struct(structType), 1, Some("b"))),
+            false,
+            ObjectType(innerCls))
+        )),
+      false,
+      ObjectType(cls))
+    compareExpressions(fromRowExpr, expected)
+  }
+
+  test("real type doesn't match encoder schema but they are compatible: tupled encoder") {
+    val encoder = ExpressionEncoder.tuple(
+      ExpressionEncoder[StringLongClass],
+      ExpressionEncoder[Long])
+    val cls = classOf[StringLongClass]
+
+    val structType = new StructType().add("a", StringType).add("b", ByteType)
+    val attrs = Seq('a.struct(structType), 'b.int)
+    val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
+    val expected: Expression = NewInstance(
+      classOf[Tuple2[_, _]],
+      Seq(
+        NewInstance(
+          cls,
+          Seq(
+            toExternalString(GetStructField('a.struct(structType), 0, Some("a"))),
+            GetStructField('a.struct(structType), 1, Some("b")).cast(LongType)),
+          false,
+          ObjectType(cls)),
+        'b.int.cast(LongType)),
+      false,
+      ObjectType(classOf[Tuple2[_, _]]))
+    compareExpressions(fromRowExpr, expected)
+  }
+
+  private def toExternalString(e: Expression): Expression = {
+    Invoke(e, "toString", ObjectType(classOf[String]), Nil)
+  }
+
+  test("throw exception if real type is not compatible with encoder schema") {
+    val msg1 = intercept[AnalysisException] {
+      ExpressionEncoder[StringIntClass].resolve(Seq('a.string, 'b.long), null)
+    }.message
+    assert(msg1 ==
+      s"""
+         |Cannot up cast `b` from bigint to int as it may truncate
+         |The type path of the target object is:
+         |- field (class: "scala.Int", name: "b")
+         |- root class: "org.apache.spark.sql.catalyst.encoders.StringIntClass"
+         |You can either add an explicit cast to the input data or choose a higher precision type
+       """.stripMargin.trim + " of the field in the target object")
+
+    val msg2 = intercept[AnalysisException] {
+      val structType = new StructType().add("a", StringType).add("b", DecimalType.SYSTEM_DEFAULT)
+      ExpressionEncoder[ComplexClass].resolve(Seq('a.long, 'b.struct(structType)), null)
+    }.message
+    assert(msg2 ==
+      s"""
+         |Cannot up cast `b.b` from decimal(38,18) to bigint as it may truncate
+         |The type path of the target object is:
+         |- field (class: "scala.Long", name: "b")
+         |- field (class: "org.apache.spark.sql.catalyst.encoders.StringLongClass", name: "b")
+         |- root class: "org.apache.spark.sql.catalyst.encoders.ComplexClass"
+         |You can either add an explicit cast to the input data or choose a higher precision type
+       """.stripMargin.trim + " of the field in the target object")
+  }
+
+  // test for leaf types
+  castSuccess[Int, Long]
+  castSuccess[java.sql.Date, java.sql.Timestamp]
+  castSuccess[Long, String]
+  castSuccess[Int, java.math.BigDecimal]
+  castSuccess[Long, java.math.BigDecimal]
+
+  castFail[Long, Int]
+  castFail[java.sql.Timestamp, java.sql.Date]
+  castFail[java.math.BigDecimal, Double]
+  castFail[Double, java.math.BigDecimal]
+  castFail[java.math.BigDecimal, Int]
+  castFail[String, Long]
+
+
+  private def castSuccess[T: TypeTag, U: TypeTag]: Unit = {
+    val from = ExpressionEncoder[T]
+    val to = ExpressionEncoder[U]
+    val catalystType = from.schema.head.dataType.simpleString
+    test(s"cast from $catalystType to ${implicitly[TypeTag[U]].tpe} should success") {
+      to.resolve(from.schema.toAttributes, null)
+    }
+  }
+
+  private def castFail[T: TypeTag, U: TypeTag]: Unit = {
+    val from = ExpressionEncoder[T]
+    val to = ExpressionEncoder[U]
+    val catalystType = from.schema.head.dataType.simpleString
+    test(s"cast from $catalystType to ${implicitly[TypeTag[U]].tpe} should fail") {
+      intercept[AnalysisException](to.resolve(from.schema.toAttributes, null))
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/9df24624/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 19dce5d..c6d2bf0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -131,9 +131,9 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
     checkAnswer(
       ds.groupBy(_._1).agg(
         sum(_._2),
-        expr("sum(_2)").as[Int],
+        expr("sum(_2)").as[Long],
         count("*")),
-      ("a", 30, 30, 2L), ("b", 3, 3, 2L), ("c", 1, 1, 1L))
+      ("a", 30, 30L, 2L), ("b", 3, 3L, 2L), ("c", 1, 1L, 1L))
   }
 
   test("typed aggregation: complex case") {

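These aggregation tests (here and in DatasetSuite below) changed as a consequence of the new up-cast check: sum() over an int column yields bigint, and asking for .as[Int] on a bigint result is now rejected at analysis time since bigint -> int may truncate, so the tests request the wider type explicitly. A sketch:

    expr("sum(_2)").as[Int]    // AnalysisException: Cannot up cast ... as it may truncate
    expr("sum(_2)").as[Long]   // OK: widening from the result type is safe
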
http://git-wip-us.apache.org/repos/asf/spark/blob/9df24624/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index a2c8d20..542e4d6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -335,24 +335,24 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
 
     checkAnswer(
-      ds.groupBy(_._1).agg(sum("_2").as[Int]),
-      ("a", 30), ("b", 3), ("c", 1))
+      ds.groupBy(_._1).agg(sum("_2").as[Long]),
+      ("a", 30L), ("b", 3L), ("c", 1L))
   }
 
   test("typed aggregation: expr, expr") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
 
     checkAnswer(
-      ds.groupBy(_._1).agg(sum("_2").as[Int], sum($"_2" + 1).as[Long]),
-      ("a", 30, 32L), ("b", 3, 5L), ("c", 1, 2L))
+      ds.groupBy(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long]),
+      ("a", 30L, 32L), ("b", 3L, 5L), ("c", 1L, 2L))
   }
 
   test("typed aggregation: expr, expr, expr") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
 
     checkAnswer(
-      ds.groupBy(_._1).agg(sum("_2").as[Int], sum($"_2" + 1).as[Long], count("*").as[Long]),
-      ("a", 30, 32L, 2L), ("b", 3, 5L, 2L), ("c", 1, 2L, 1L))
+      ds.groupBy(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long], count("*")),
+      ("a", 30L, 32L, 2L), ("b", 3L, 5L, 2L), ("c", 1L, 2L, 1L))
   }
 
   test("typed aggregation: expr, expr, expr, expr") {
@@ -360,11 +360,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
 
     checkAnswer(
       ds.groupBy(_._1).agg(
-        sum("_2").as[Int],
+        sum("_2").as[Long],
         sum($"_2" + 1).as[Long],
         count("*").as[Long],
         avg("_2").as[Double]),
-      ("a", 30, 32L, 2L, 15.0), ("b", 3, 5L, 2L, 1.5), ("c", 1, 2L, 1L, 1.0))
+      ("a", 30L, 32L, 2L, 15.0), ("b", 3L, 5L, 2L, 1.5), ("c", 1L, 2L, 1L, 1.0))
   }
 
   test("cogroup") {
@@ -476,6 +476,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       ((nullInt, "1"), (new java.lang.Integer(22), "2")),
       ((new java.lang.Integer(22), "2"), (new java.lang.Integer(22), "2")))
   }
+
+  test("change encoder with compatible schema") {
+    val ds = Seq(2 -> 2.toByte, 3 -> 3.toByte).toDF("a", "b").as[ClassData]
+    assert(ds.collect().toSeq == Seq(ClassData("2", 2), ClassData("3", 3)))
+  }
 }
 
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org