Posted to commits@spark.apache.org by we...@apache.org on 2019/02/27 05:47:47 UTC

[spark] branch master updated: [SPARK-22000][SQL] Address missing Upcast in JavaTypeInference.deserializerFor

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new dea18ee  [SPARK-22000][SQL] Address missing Upcast in JavaTypeInference.deserializerFor
dea18ee is described below

commit dea18ee85b53511a012f3d0ca95626776a5241ba
Author: Jungtaek Lim (HeartSaVioR) <ka...@gmail.com>
AuthorDate: Wed Feb 27 13:47:20 2019 +0800

    [SPARK-22000][SQL] Address missing Upcast in JavaTypeInference.deserializerFor
    
    ## What changes were proposed in this pull request?
    
    Spark expects the type of a column and the type of the matching bean field to be the same when deserializing to an Object, but Spark has never actually enforced this (at least for the Java bean encoder), and some users rely on mismatched types and experience undefined behavior (in SPARK-22000, Spark throws a compilation failure on generated code because it calls `.toString()` against a primitive type).
    
    It doesn't produce an error on the Scala side because `ScalaReflection.deserializerFor` properly injects an `UpCast` when necessary. This patch applies the same approach to `JavaTypeInference.deserializerFor` as well.
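    
    A minimal sketch of the failure mode (the bean, the column name, and the `spark` session here are illustrative assumptions, not code from the patch): a bean field declared as `String` matched against a primitive `int` column used to make the generated deserializer call the equivalent of `intValue.toString()`, which does not compile. With the injected `UpCast`, the column is cast to string before the setter runs.
    
    ```java
    import org.apache.spark.sql.*;
    
    public class IntToStringBean {  // hypothetical bean, for illustration only
      private String intField;
      public String getIntField() { return intField; }
      public void setIntField(String intField) { this.intField = intField; }
    }
    
    // Assuming an active SparkSession named `spark`: the column "intField" is int,
    // but the bean declares String. Before this patch the generated code failed to
    // compile; with it, an up-cast to string is injected first.
    Dataset<Row> df = spark.range(5).selectExpr("CAST(id AS INT) AS intField");
    Dataset<IntToStringBean> ds = df.as(Encoders.bean(IntToStringBean.class));
    ds.collectAsList();
    ```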
    
    Credit to srowen, maropu, and cloud-fan, who suggested various approaches to solve this.
    
    ## How was this patch tested?
    
    Added a unit test whose query is slightly modified from the sample code attached to the JIRA issue.
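    
    For the opposite direction (a string column into an `Integer` bean field) no upcast is legal, so the mismatch now fails at analysis time with an `AnalysisException` instead of a codegen error. Roughly what the new negative test below exercises (a sketch, assuming an active `SparkSession` named `spark` and an input `List<Row> rows`):
    
    ```java
    StructType schema = new StructType().add("id", DataTypes.StringType);
    Dataset<Row> df = spark.createDataFrame(rows, schema);
    // Throws AnalysisException containing "Cannot up cast ", since string
    // cannot be up-cast to int.
    df.as(Encoders.bean(RecordSpark22000FailToUpcast.class)).collect();
    ```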
    
    Closes #23854 from HeartSaVioR/SPARK-22000.
    
    Authored-by: Jungtaek Lim (HeartSaVioR) <ka...@gmail.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../sql/catalyst/DeserializerBuildHelper.scala     | 158 +++++++++++++
 .../spark/sql/catalyst/JavaTypeInference.scala     | 142 ++++++------
 .../spark/sql/catalyst/ScalaReflection.scala       | 221 +++++++-----------
 .../spark/sql/JavaBeanDeserializationSuite.java    | 248 ++++++++++++++++++++-
 4 files changed, 559 insertions(+), 210 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
new file mode 100644
index 0000000..e71955a
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst
+
+import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
+import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, UpCast}
+import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, StaticInvoke}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.types._
+
+object DeserializerBuildHelper {
+  /** Returns the current path with a sub-field extracted. */
+  def addToPath(
+      path: Expression,
+      part: String,
+      dataType: DataType,
+      walkedTypePath: Seq[String]): Expression = {
+    val newPath = UnresolvedExtractValue(path, expressions.Literal(part))
+    upCastToExpectedType(newPath, dataType, walkedTypePath)
+  }
+
+  /** Returns the current path with a field at ordinal extracted. */
+  def addToPathOrdinal(
+      path: Expression,
+      ordinal: Int,
+      dataType: DataType,
+      walkedTypePath: Seq[String]): Expression = {
+    val newPath = GetStructField(path, ordinal)
+    upCastToExpectedType(newPath, dataType, walkedTypePath)
+  }
+
+  def deserializerForWithNullSafety(
+      expr: Expression,
+      dataType: DataType,
+      nullable: Boolean,
+      walkedTypePath: Seq[String],
+      funcForCreatingNewExpr: (Expression, Seq[String]) => Expression): Expression = {
+    val newExpr = funcForCreatingNewExpr(expr, walkedTypePath)
+    expressionWithNullSafety(newExpr, nullable, walkedTypePath)
+  }
+
+  def deserializerForWithNullSafetyAndUpcast(
+      expr: Expression,
+      dataType: DataType,
+      nullable: Boolean,
+      walkedTypePath: Seq[String],
+      funcForCreatingNewExpr: (Expression, Seq[String]) => Expression): Expression = {
+    val casted = upCastToExpectedType(expr, dataType, walkedTypePath)
+    deserializerForWithNullSafety(casted, dataType, nullable, walkedTypePath,
+      funcForCreatingNewExpr)
+  }
+
+  private def expressionWithNullSafety(
+      expr: Expression,
+      nullable: Boolean,
+      walkedTypePath: Seq[String]): Expression = {
+    if (nullable) {
+      expr
+    } else {
+      AssertNotNull(expr, walkedTypePath)
+    }
+  }
+
+  def createDeserializerForTypesSupportValueOf(
+      path: Expression,
+      clazz: Class[_]): Expression = {
+    StaticInvoke(
+      clazz,
+      ObjectType(clazz),
+      "valueOf",
+      path :: Nil,
+      returnNullable = false)
+  }
+
+  def createDeserializerForString(path: Expression, returnNullable: Boolean): Expression = {
+    Invoke(path, "toString", ObjectType(classOf[java.lang.String]),
+      returnNullable = returnNullable)
+  }
+
+  def createDeserializerForSqlDate(path: Expression): Expression = {
+    StaticInvoke(
+      DateTimeUtils.getClass,
+      ObjectType(classOf[java.sql.Date]),
+      "toJavaDate",
+      path :: Nil,
+      returnNullable = false)
+  }
+
+  def createDeserializerForSqlTimestamp(path: Expression): Expression = {
+    StaticInvoke(
+      DateTimeUtils.getClass,
+      ObjectType(classOf[java.sql.Timestamp]),
+      "toJavaTimestamp",
+      path :: Nil,
+      returnNullable = false)
+  }
+
+  def createDeserializerForJavaBigDecimal(
+      path: Expression,
+      returnNullable: Boolean): Expression = {
+    Invoke(path, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]),
+      returnNullable = returnNullable)
+  }
+
+  def createDeserializerForScalaBigDecimal(
+      path: Expression,
+      returnNullable: Boolean): Expression = {
+    Invoke(path, "toBigDecimal", ObjectType(classOf[BigDecimal]), returnNullable = returnNullable)
+  }
+
+  def createDeserializerForJavaBigInteger(
+      path: Expression,
+      returnNullable: Boolean): Expression = {
+    Invoke(path, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]),
+      returnNullable = returnNullable)
+  }
+
+  def createDeserializerForScalaBigInt(path: Expression): Expression = {
+    Invoke(path, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]),
+      returnNullable = false)
+  }
+
+  /**
+   * When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff
+   * and lose the required data type, which may lead to a runtime error if the real type
+   * doesn't match the encoder's schema.
+   * For example, if we build an encoder for `case class Data(a: Int, b: String)` and the real
+   * type is [a: int, b: long], then we will hit a runtime error saying that we can't
+   * construct class `Data` with int and long, because we lost the information that `b`
+   * should be a string.
+   *
+   * This method helps us "remember" the required data type by adding an `UpCast`. Note that
+   * we only need to do this for leaf nodes.
+   */
+  private def upCastToExpectedType(
+      expr: Expression,
+      expected: DataType,
+      walkedTypePath: Seq[String]): Expression = expected match {
+    case _: StructType => expr
+    case _: ArrayType => expr
+    case _: MapType => expr
+    case _ => UpCast(expr, expected, walkedTypePath)
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 311060e..dafa878 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -26,7 +26,8 @@ import scala.language.existentials
 
 import com.google.common.reflect.TypeToken
 
-import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.DeserializerBuildHelper._
+import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
@@ -194,14 +195,20 @@ object JavaTypeInference {
    */
   def deserializerFor(beanClass: Class[_]): Expression = {
     val typeToken = TypeToken.of(beanClass)
-    deserializerFor(typeToken, GetColumnByOrdinal(0, inferDataType(typeToken)._1))
+    val walkedTypePath = s"""- root class: "${beanClass.getCanonicalName}"""" :: Nil
+    val (dataType, nullable) = inferDataType(typeToken)
+
+    // Assumes we are deserializing the first column of a row.
+    deserializerForWithNullSafetyAndUpcast(GetColumnByOrdinal(0, dataType), dataType,
+      nullable = nullable, walkedTypePath, (casted, walkedTypePath) => {
+        deserializerFor(typeToken, casted, walkedTypePath)
+      })
   }
 
-  private def deserializerFor(typeToken: TypeToken[_], path: Expression): Expression = {
-    /** Returns the current path with a sub-field extracted. */
-    def addToPath(part: String): Expression = UnresolvedExtractValue(path,
-      expressions.Literal(part))
-
+  private def deserializerFor(
+      typeToken: TypeToken[_],
+      path: Expression,
+      walkedTypePath: Seq[String]): Expression = {
     typeToken.getRawType match {
       case c if !inferExternalType(c).isInstanceOf[ObjectType] => path
 
@@ -212,74 +219,79 @@ object JavaTypeInference {
                 c == classOf[java.lang.Float] ||
                 c == classOf[java.lang.Byte] ||
                 c == classOf[java.lang.Boolean] =>
-        StaticInvoke(
-          c,
-          ObjectType(c),
-          "valueOf",
-          path :: Nil,
-          returnNullable = false)
+        createDeserializerForTypesSupportValueOf(path, c)
 
       case c if c == classOf[java.sql.Date] =>
-        StaticInvoke(
-          DateTimeUtils.getClass,
-          ObjectType(c),
-          "toJavaDate",
-          path :: Nil,
-          returnNullable = false)
+        createDeserializerForSqlDate(path)
 
       case c if c == classOf[java.sql.Timestamp] =>
-        StaticInvoke(
-          DateTimeUtils.getClass,
-          ObjectType(c),
-          "toJavaTimestamp",
-          path :: Nil,
-          returnNullable = false)
+        createDeserializerForSqlTimestamp(path)
 
       case c if c == classOf[java.lang.String] =>
-        Invoke(path, "toString", ObjectType(classOf[String]))
+        createDeserializerForString(path, returnNullable = true)
 
       case c if c == classOf[java.math.BigDecimal] =>
-        Invoke(path, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))
+        createDeserializerForJavaBigDecimal(path, returnNullable = true)
+
+      case c if c == classOf[java.math.BigInteger] =>
+        createDeserializerForJavaBigInteger(path, returnNullable = true)
 
       case c if c.isArray =>
         val elementType = c.getComponentType
-        val primitiveMethod = elementType match {
-          case c if c == java.lang.Boolean.TYPE => Some("toBooleanArray")
-          case c if c == java.lang.Byte.TYPE => Some("toByteArray")
-          case c if c == java.lang.Short.TYPE => Some("toShortArray")
-          case c if c == java.lang.Integer.TYPE => Some("toIntArray")
-          case c if c == java.lang.Long.TYPE => Some("toLongArray")
-          case c if c == java.lang.Float.TYPE => Some("toFloatArray")
-          case c if c == java.lang.Double.TYPE => Some("toDoubleArray")
-          case _ => None
+        val newTypePath = s"""- array element class: "${elementType.getCanonicalName}"""" +:
+          walkedTypePath
+        val (dataType, elementNullable) = inferDataType(elementType)
+        val mapFunction: Expression => Expression = element => {
+          // upcast the array element to the data type the encoder expected.
+          deserializerForWithNullSafetyAndUpcast(
+            element,
+            dataType,
+            nullable = elementNullable,
+            newTypePath,
+            (casted, typePath) => deserializerFor(typeToken.getComponentType, casted, typePath))
         }
 
-        primitiveMethod.map { method =>
-          Invoke(path, method, ObjectType(c))
-        }.getOrElse {
-          Invoke(
-            MapObjects(
-              p => deserializerFor(typeToken.getComponentType, p),
-              path,
-              inferDataType(elementType)._1),
-            "array",
-            ObjectType(c))
+        val arrayData = UnresolvedMapObjects(mapFunction, path)
+
+        val methodName = elementType match {
+          case c if c == java.lang.Integer.TYPE => "toIntArray"
+          case c if c == java.lang.Long.TYPE => "toLongArray"
+          case c if c == java.lang.Double.TYPE => "toDoubleArray"
+          case c if c == java.lang.Float.TYPE => "toFloatArray"
+          case c if c == java.lang.Short.TYPE => "toShortArray"
+          case c if c == java.lang.Byte.TYPE => "toByteArray"
+          case c if c == java.lang.Boolean.TYPE => "toBooleanArray"
+          // non-primitive
+          case _ => "array"
         }
+        Invoke(arrayData, methodName, ObjectType(c))
 
       case c if listType.isAssignableFrom(typeToken) =>
         val et = elementType(typeToken)
-        UnresolvedMapObjects(
-          p => deserializerFor(et, p),
-          path,
-          customCollectionCls = Some(c))
+        val newTypePath = s"""- array element class: "${et.getType.getTypeName}"""" +:
+          walkedTypePath
+        val (dataType, elementNullable) = inferDataType(et)
+        val mapFunction: Expression => Expression = element => {
+          // upcast the array element to the data type the encoder expected.
+          deserializerForWithNullSafetyAndUpcast(
+            element,
+            dataType,
+            nullable = elementNullable,
+            newTypePath,
+            (casted, typePath) => deserializerFor(et, casted, typePath))
+        }
+
+        UnresolvedMapObjects(mapFunction, path, customCollectionCls = Some(c))
 
       case _ if mapType.isAssignableFrom(typeToken) =>
         val (keyType, valueType) = mapKeyValueType(typeToken)
+        val newTypePath = (s"""- map key class: "${keyType.getType.getTypeName}"""" +
+          s""", value class: "${valueType.getType.getTypeName}"""") +: walkedTypePath
 
         val keyData =
           Invoke(
             UnresolvedMapObjects(
-              p => deserializerFor(keyType, p),
+              p => deserializerFor(keyType, p, newTypePath),
               MapKeys(path)),
             "array",
             ObjectType(classOf[Array[Any]]))
@@ -287,7 +299,7 @@ object JavaTypeInference {
         val valueData =
           Invoke(
             UnresolvedMapObjects(
-              p => deserializerFor(valueType, p),
+              p => deserializerFor(valueType, p, newTypePath),
               MapValues(path)),
             "array",
             ObjectType(classOf[Array[Any]]))
@@ -300,25 +312,25 @@ object JavaTypeInference {
           returnNullable = false)
 
       case other if other.isEnum =>
-        StaticInvoke(
-          other,
-          ObjectType(other),
-          "valueOf",
-          Invoke(path, "toString", ObjectType(classOf[String]), returnNullable = false) :: Nil,
-          returnNullable = false)
+        createDeserializerForTypesSupportValueOf(
+          createDeserializerForString(path, returnNullable = false),
+          other)
 
       case other =>
         val properties = getJavaBeanReadableAndWritableProperties(other)
         val setters = properties.map { p =>
           val fieldName = p.getName
           val fieldType = typeToken.method(p.getReadMethod).getReturnType
-          val (_, nullable) = inferDataType(fieldType)
-          val constructor = deserializerFor(fieldType, addToPath(fieldName))
-          val setter = if (nullable) {
-            constructor
-          } else {
-            AssertNotNull(constructor, Seq("currently no type path record in java"))
-          }
+          val (dataType, nullable) = inferDataType(fieldType)
+          val newTypePath = (s"""- field (class: "${fieldType.getType.getTypeName}"""" +
+            s""", name: "$fieldName")""") +: walkedTypePath
+          val setter = deserializerForWithNullSafety(
+            path,
+            dataType,
+            nullable = nullable,
+            newTypePath,
+            (expr, typePath) => deserializerFor(fieldType,
+              addToPath(expr, fieldName, dataType, typePath), typePath))
           p.getWriteMethod.getName -> setter
         }.toMap
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index d5af91a..741cba8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -17,15 +17,12 @@
 
 package org.apache.spark.sql.catalyst
 
-import java.lang.reflect.Constructor
-
-import scala.util.Properties
-
 import org.apache.commons.lang3.reflect.ConstructorUtils
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue}
-import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.DeserializerBuildHelper._
+import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
+import org.apache.spark.sql.catalyst.expressions.{Expression, _}
 import org.apache.spark.sql.catalyst.expressions.objects._
 import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, GenericArrayData, MapData}
 import org.apache.spark.sql.types._
@@ -129,25 +126,6 @@ object ScalaReflection extends ScalaReflection {
   }
 
   /**
-   * When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff
-   * and lost the required data type, which may lead to runtime error if the real type doesn't
-   * match the encoder's schema.
-   * For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type
-   * is [a: int, b: long], then we will hit runtime error and say that we can't construct class
-   * `Data` with int and long, because we lost the information that `b` should be a string.
-   *
-   * This method help us "remember" the required data type by adding a `UpCast`. Note that we
-   * only need to do this for leaf nodes.
-   */
-  private def upCastToExpectedType(expr: Expression, expected: DataType,
-      walkedTypePath: Seq[String]): Expression = expected match {
-    case _: StructType => expr
-    case _: ArrayType => expr
-    case _: MapType => expr
-    case _ => UpCast(expr, expected, walkedTypePath)
-  }
-
-  /**
    * Returns an expression that can be used to deserialize a Spark SQL representation to an object
    * of type `T` with a compatible schema. The Spark SQL representation is located at ordinal 0 of
    * a row, i.e., `GetColumnByOrdinal(0, _)`. Nested classes will have their fields accessed using
@@ -162,15 +140,9 @@ object ScalaReflection extends ScalaReflection {
     val Schema(dataType, nullable) = schemaFor(tpe)
 
     // Assumes we are deserializing the first column of a row.
-    val input = upCastToExpectedType(
-      GetColumnByOrdinal(0, dataType), dataType, walkedTypePath)
-
-    val expr = deserializerFor(tpe, input, walkedTypePath)
-    if (nullable) {
-      expr
-    } else {
-      AssertNotNull(expr, walkedTypePath)
-    }
+    deserializerForWithNullSafetyAndUpcast(GetColumnByOrdinal(0, dataType), dataType,
+      nullable = nullable, walkedTypePath,
+      (casted, typePath) => deserializerFor(tpe, casted, typePath))
   }
 
   /**
@@ -185,22 +157,6 @@ object ScalaReflection extends ScalaReflection {
       tpe: `Type`,
       path: Expression,
       walkedTypePath: Seq[String]): Expression = cleanUpReflectionObjects {
-
-    /** Returns the current path with a sub-field extracted. */
-    def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = {
-      val newPath = UnresolvedExtractValue(path, expressions.Literal(part))
-      upCastToExpectedType(newPath, dataType, walkedTypePath)
-    }
-
-    /** Returns the current path with a field at ordinal extracted. */
-    def addToPathOrdinal(
-        ordinal: Int,
-        dataType: DataType,
-        walkedTypePath: Seq[String]): Expression = {
-      val newPath = GetStructField(path, ordinal)
-      upCastToExpectedType(newPath, dataType, walkedTypePath)
-    }
-
     tpe.dealias match {
       case t if !dataTypeFor(t).isInstanceOf[ObjectType] => path
 
@@ -211,73 +167,53 @@ object ScalaReflection extends ScalaReflection {
         WrapOption(deserializerFor(optType, path, newTypePath), dataTypeFor(optType))
 
       case t if t <:< localTypeOf[java.lang.Integer] =>
-        val boxedType = classOf[java.lang.Integer]
-        val objectType = ObjectType(boxedType)
-        StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false)
+        createDeserializerForTypesSupportValueOf(path,
+          classOf[java.lang.Integer])
 
       case t if t <:< localTypeOf[java.lang.Long] =>
-        val boxedType = classOf[java.lang.Long]
-        val objectType = ObjectType(boxedType)
-        StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false)
+        createDeserializerForTypesSupportValueOf(path,
+          classOf[java.lang.Long])
 
       case t if t <:< localTypeOf[java.lang.Double] =>
-        val boxedType = classOf[java.lang.Double]
-        val objectType = ObjectType(boxedType)
-        StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false)
+        createDeserializerForTypesSupportValueOf(path,
+          classOf[java.lang.Double])
 
       case t if t <:< localTypeOf[java.lang.Float] =>
-        val boxedType = classOf[java.lang.Float]
-        val objectType = ObjectType(boxedType)
-        StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false)
+        createDeserializerForTypesSupportValueOf(path,
+          classOf[java.lang.Float])
 
       case t if t <:< localTypeOf[java.lang.Short] =>
-        val boxedType = classOf[java.lang.Short]
-        val objectType = ObjectType(boxedType)
-        StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false)
+        createDeserializerForTypesSupportValueOf(path,
+          classOf[java.lang.Short])
 
       case t if t <:< localTypeOf[java.lang.Byte] =>
-        val boxedType = classOf[java.lang.Byte]
-        val objectType = ObjectType(boxedType)
-        StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false)
+        createDeserializerForTypesSupportValueOf(path,
+          classOf[java.lang.Byte])
 
       case t if t <:< localTypeOf[java.lang.Boolean] =>
-        val boxedType = classOf[java.lang.Boolean]
-        val objectType = ObjectType(boxedType)
-        StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false)
+        createDeserializerForTypesSupportValueOf(path,
+          classOf[java.lang.Boolean])
 
       case t if t <:< localTypeOf[java.sql.Date] =>
-        StaticInvoke(
-          DateTimeUtils.getClass,
-          ObjectType(classOf[java.sql.Date]),
-          "toJavaDate",
-          path :: Nil,
-          returnNullable = false)
+        createDeserializerForSqlDate(path)
 
       case t if t <:< localTypeOf[java.sql.Timestamp] =>
-        StaticInvoke(
-          DateTimeUtils.getClass,
-          ObjectType(classOf[java.sql.Timestamp]),
-          "toJavaTimestamp",
-          path :: Nil,
-          returnNullable = false)
+        createDeserializerForSqlTimestamp(path)
 
       case t if t <:< localTypeOf[java.lang.String] =>
-        Invoke(path, "toString", ObjectType(classOf[String]), returnNullable = false)
+        createDeserializerForString(path, returnNullable = false)
 
       case t if t <:< localTypeOf[java.math.BigDecimal] =>
-        Invoke(path, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]),
-          returnNullable = false)
+        createDeserializerForJavaBigDecimal(path, returnNullable = false)
 
       case t if t <:< localTypeOf[BigDecimal] =>
-        Invoke(path, "toBigDecimal", ObjectType(classOf[BigDecimal]), returnNullable = false)
+        createDeserializerForScalaBigDecimal(path, returnNullable = false)
 
       case t if t <:< localTypeOf[java.math.BigInteger] =>
-        Invoke(path, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]),
-          returnNullable = false)
+        createDeserializerForJavaBigInteger(path, returnNullable = false)
 
       case t if t <:< localTypeOf[scala.math.BigInt] =>
-        Invoke(path, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]),
-          returnNullable = false)
+        createDeserializerForScalaBigInt(path)
 
       case t if t <:< localTypeOf[Array[_]] =>
         val TypeRef(_, _, Seq(elementType)) = t
@@ -287,34 +223,29 @@ object ScalaReflection extends ScalaReflection {
 
         val mapFunction: Expression => Expression = element => {
           // upcast the array element to the data type the encoder expected.
-          val casted = upCastToExpectedType(element, dataType, newTypePath)
-          val converter = deserializerFor(elementType, casted, newTypePath)
-          if (elementNullable) {
-            converter
-          } else {
-            AssertNotNull(converter, newTypePath)
-          }
+          deserializerForWithNullSafetyAndUpcast(
+            element,
+            dataType,
+            nullable = elementNullable,
+            newTypePath,
+            (casted, typePath) => deserializerFor(elementType, casted, typePath))
         }
 
         val arrayData = UnresolvedMapObjects(mapFunction, path)
         val arrayCls = arrayClassFor(elementType)
 
-        if (elementNullable) {
-          Invoke(arrayData, "array", arrayCls, returnNullable = false)
-        } else {
-          val primitiveMethod = elementType match {
-            case t if t <:< definitions.IntTpe => "toIntArray"
-            case t if t <:< definitions.LongTpe => "toLongArray"
-            case t if t <:< definitions.DoubleTpe => "toDoubleArray"
-            case t if t <:< definitions.FloatTpe => "toFloatArray"
-            case t if t <:< definitions.ShortTpe => "toShortArray"
-            case t if t <:< definitions.ByteTpe => "toByteArray"
-            case t if t <:< definitions.BooleanTpe => "toBooleanArray"
-            case other => throw new IllegalStateException("expect primitive array element type " +
-              "but got " + other)
-          }
-          Invoke(arrayData, primitiveMethod, arrayCls, returnNullable = false)
+        val methodName = elementType match {
+          case t if t <:< definitions.IntTpe => "toIntArray"
+          case t if t <:< definitions.LongTpe => "toLongArray"
+          case t if t <:< definitions.DoubleTpe => "toDoubleArray"
+          case t if t <:< definitions.FloatTpe => "toFloatArray"
+          case t if t <:< definitions.ShortTpe => "toShortArray"
+          case t if t <:< definitions.ByteTpe => "toByteArray"
+          case t if t <:< definitions.BooleanTpe => "toBooleanArray"
+          // non-primitive
+          case _ => "array"
         }
+        Invoke(arrayData, methodName, arrayCls, returnNullable = false)
 
       // We serialize a `Set` to Catalyst array. When we deserialize a Catalyst array
       // to a `Set`, if there are duplicated elements, the elements will be de-duplicated.
@@ -326,14 +257,12 @@ object ScalaReflection extends ScalaReflection {
         val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
 
         val mapFunction: Expression => Expression = element => {
-          // upcast the array element to the data type the encoder expected.
-          val casted = upCastToExpectedType(element, dataType, newTypePath)
-          val converter = deserializerFor(elementType, casted, newTypePath)
-          if (elementNullable) {
-            converter
-          } else {
-            AssertNotNull(converter, newTypePath)
-          }
+          deserializerForWithNullSafetyAndUpcast(
+            element,
+            dataType,
+            nullable = elementNullable,
+            newTypePath,
+            (casted, typePath) => deserializerFor(elementType, casted, typePath))
         }
 
         val companion = t.dealias.typeSymbol.companion.typeSignature
@@ -346,13 +275,18 @@ object ScalaReflection extends ScalaReflection {
         UnresolvedMapObjects(mapFunction, path, Some(cls))
 
       case t if t <:< localTypeOf[Map[_, _]] =>
-        // TODO: add walked type path for map
         val TypeRef(_, _, Seq(keyType, valueType)) = t
 
+        val classNameForKey = getClassNameFromType(keyType)
+        val classNameForValue = getClassNameFromType(valueType)
+
+        val newTypePath = (s"""- map key class: "${classNameForKey}"""" +
+          s""", value class: "${classNameForValue}"""") +: walkedTypePath
+
         UnresolvedCatalystToExternalMap(
           path,
-          p => deserializerFor(keyType, p, walkedTypePath),
-          p => deserializerFor(valueType, p, walkedTypePath),
+          p => deserializerFor(keyType, p, newTypePath),
+          p => deserializerFor(valueType, p, newTypePath),
           mirror.runtimeClass(t.typeSymbol.asClass)
         )
 
@@ -382,25 +316,28 @@ object ScalaReflection extends ScalaReflection {
         val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) =>
           val Schema(dataType, nullable) = schemaFor(fieldType)
           val clsName = getClassNameFromType(fieldType)
-          val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
-          // For tuples, we based grab the inner fields by ordinal instead of name.
-          val constructor = if (cls.getName startsWith "scala.Tuple") {
-            deserializerFor(
-              fieldType,
-              addToPathOrdinal(i, dataType, newTypePath),
-              newTypePath)
-          } else {
-            deserializerFor(
-              fieldType,
-              addToPath(fieldName, dataType, newTypePath),
-              newTypePath)
-          }
+          val newTypePath = (s"""- field (class: "$clsName", """ +
+              s"""name: "$fieldName")""") +: walkedTypePath
 
-          if (!nullable) {
-            AssertNotNull(constructor, newTypePath)
-          } else {
-            constructor
-          }
+          // For tuples, we grab the inner fields by ordinal instead of by name.
+          deserializerForWithNullSafety(
+            path,
+            dataType,
+            nullable = nullable,
+            newTypePath,
+            (expr, typePath) => {
+              if (cls.getName startsWith "scala.Tuple") {
+                deserializerFor(
+                  fieldType,
+                  addToPathOrdinal(expr, i, dataType, typePath),
+                  newTypePath)
+              } else {
+                deserializerFor(
+                  fieldType,
+                  addToPath(expr, fieldName, dataType, typePath),
+                  newTypePath)
+              }
+            })
         }
 
         val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false)
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
index 8f35abe..49ff522 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
@@ -20,11 +20,12 @@ package test.org.apache.spark.sql;
 import java.io.Serializable;
 import java.util.*;
 
+import org.apache.spark.sql.*;
+import org.apache.spark.sql.catalyst.expressions.GenericRow;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructType;
 import org.junit.*;
 
-import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.Encoder;
-import org.apache.spark.sql.Encoders;
 import org.apache.spark.sql.test.TestSparkSession;
 
 public class JavaBeanDeserializationSuite implements Serializable {
@@ -115,6 +116,109 @@ public class JavaBeanDeserializationSuite implements Serializable {
     Assert.assertEquals(records, MAP_RECORDS);
   }
 
+  @Test
+  public void testSpark22000() {
+    List<Row> inputRows = new ArrayList<>();
+    List<RecordSpark22000> expectedRecords = new ArrayList<>();
+
+    for (long idx = 0; idx < 5; idx++) {
+      Row row = createRecordSpark22000Row(idx);
+      inputRows.add(row);
+      expectedRecords.add(createRecordSpark22000(row));
+    }
+
+    // Here we try to convert the fields from various types to string.
+    // Before SPARK-22000, Spark called toString() against a variable whose type might be
+    // primitive.
+    // With SPARK-22000, Spark calls String.valueOf(), which eventually calls toString() but
+    // handles boxing when the type is primitive.
+    Encoder<RecordSpark22000> encoder = Encoders.bean(RecordSpark22000.class);
+
+    StructType schema = new StructType()
+      .add("shortField", DataTypes.ShortType)
+      .add("intField", DataTypes.IntegerType)
+      .add("longField", DataTypes.LongType)
+      .add("floatField", DataTypes.FloatType)
+      .add("doubleField", DataTypes.DoubleType)
+      .add("stringField", DataTypes.StringType)
+      .add("booleanField", DataTypes.BooleanType)
+      .add("timestampField", DataTypes.TimestampType)
+      // explicitly setting nullable = true to make the intention clear
+      .add("nullIntField", DataTypes.IntegerType, true);
+
+    Dataset<Row> dataFrame = spark.createDataFrame(inputRows, schema);
+    Dataset<RecordSpark22000> dataset = dataFrame.as(encoder);
+
+    List<RecordSpark22000> records = dataset.collectAsList();
+
+    Assert.assertEquals(expectedRecords, records);
+  }
+
+  @Test
+  public void testSpark22000FailToUpcast() {
+    List<Row> inputRows = new ArrayList<>();
+    for (long idx = 0; idx < 5; idx++) {
+      Row row = createRecordSpark22000FailToUpcastRow(idx);
+      inputRows.add(row);
+    }
+
+    // Here we try to convert the field from string type to int, where upcast doesn't help.
+    Encoder<RecordSpark22000FailToUpcast> encoder =
+            Encoders.bean(RecordSpark22000FailToUpcast.class);
+
+    StructType schema = new StructType().add("id", DataTypes.StringType);
+
+    Dataset<Row> dataFrame = spark.createDataFrame(inputRows, schema);
+
+    try {
+      dataFrame.as(encoder).collect();
+      Assert.fail("Expected AnalysisException, but passed.");
+    } catch (Throwable e) {
+      // Here we need to handle a weird case: the compiler complains that AnalysisException
+      // can never be thrown in the try statement, but it actually can be. Maybe a Scala-Java
+      // interop issue?
+      if (e instanceof AnalysisException) {
+        Assert.assertTrue(e.getMessage().contains("Cannot up cast "));
+      } else {
+        throw e;
+      }
+    }
+  }
+
+  private static Row createRecordSpark22000Row(Long index) {
+    Object[] values = new Object[] {
+            index.shortValue(),
+            index.intValue(),
+            index,
+            index.floatValue(),
+            index.doubleValue(),
+            String.valueOf(index),
+            index % 2 == 0,
+            new java.sql.Timestamp(System.currentTimeMillis()),
+            null
+    };
+    return new GenericRow(values);
+  }
+
+  private static RecordSpark22000 createRecordSpark22000(Row recordRow) {
+    RecordSpark22000 record = new RecordSpark22000();
+    record.setShortField(String.valueOf(recordRow.getShort(0)));
+    record.setIntField(String.valueOf(recordRow.getInt(1)));
+    record.setLongField(String.valueOf(recordRow.getLong(2)));
+    record.setFloatField(String.valueOf(recordRow.getFloat(3)));
+    record.setDoubleField(String.valueOf(recordRow.getDouble(4)));
+    record.setStringField(recordRow.getString(5));
+    record.setBooleanField(String.valueOf(recordRow.getBoolean(6)));
+    record.setTimestampField(String.valueOf(recordRow.getTimestamp(7).getTime() * 1000));
+    // This verifies that a null value does not become the string "null".
+    record.setNullIntField(null);
+    return record;
+  }
+
+  private static Row createRecordSpark22000FailToUpcastRow(Long index) {
+    Object[] values = new Object[] { String.valueOf(index) };
+    return new GenericRow(values);
+  }
+
   public static class ArrayRecord {
 
     private int id;
@@ -252,4 +356,142 @@ public class JavaBeanDeserializationSuite implements Serializable {
       return String.format("[%d,%d]", startTime, endTime);
     }
   }
+
+  public static final class RecordSpark22000 {
+    private String shortField;
+    private String intField;
+    private String longField;
+    private String floatField;
+    private String doubleField;
+    private String stringField;
+    private String booleanField;
+    private String timestampField;
+    private String nullIntField;
+
+    public RecordSpark22000() { }
+
+    public String getShortField() {
+      return shortField;
+    }
+
+    public void setShortField(String shortField) {
+      this.shortField = shortField;
+    }
+
+    public String getIntField() {
+      return intField;
+    }
+
+    public void setIntField(String intField) {
+      this.intField = intField;
+    }
+
+    public String getLongField() {
+      return longField;
+    }
+
+    public void setLongField(String longField) {
+      this.longField = longField;
+    }
+
+    public String getFloatField() {
+      return floatField;
+    }
+
+    public void setFloatField(String floatField) {
+      this.floatField = floatField;
+    }
+
+    public String getDoubleField() {
+      return doubleField;
+    }
+
+    public void setDoubleField(String doubleField) {
+      this.doubleField = doubleField;
+    }
+
+    public String getStringField() {
+      return stringField;
+    }
+
+    public void setStringField(String stringField) {
+      this.stringField = stringField;
+    }
+
+    public String getBooleanField() {
+      return booleanField;
+    }
+
+    public void setBooleanField(String booleanField) {
+      this.booleanField = booleanField;
+    }
+
+    public String getTimestampField() {
+      return timestampField;
+    }
+
+    public void setTimestampField(String timestampField) {
+      this.timestampField = timestampField;
+    }
+
+    public String getNullIntField() {
+      return nullIntField;
+    }
+
+    public void setNullIntField(String nullIntField) {
+      this.nullIntField = nullIntField;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) return true;
+      if (o == null || getClass() != o.getClass()) return false;
+      RecordSpark22000 that = (RecordSpark22000) o;
+      return Objects.equals(shortField, that.shortField) &&
+              Objects.equals(intField, that.intField) &&
+              Objects.equals(longField, that.longField) &&
+              Objects.equals(floatField, that.floatField) &&
+              Objects.equals(doubleField, that.doubleField) &&
+              Objects.equals(stringField, that.stringField) &&
+              Objects.equals(booleanField, that.booleanField) &&
+              Objects.equals(timestampField, that.timestampField) &&
+              Objects.equals(nullIntField, that.nullIntField);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hash(shortField, intField, longField, floatField, doubleField, stringField,
+              booleanField, timestampField, nullIntField);
+    }
+
+    @Override
+    public String toString() {
+      return com.google.common.base.Objects.toStringHelper(this)
+              .add("shortField", shortField)
+              .add("intField", intField)
+              .add("longField", longField)
+              .add("floatField", floatField)
+              .add("doubleField", doubleField)
+              .add("stringField", stringField)
+              .add("booleanField", booleanField)
+              .add("timestampField", timestampField)
+              .add("nullIntField", nullIntField)
+              .toString();
+    }
+  }
+
+  public static final class RecordSpark22000FailToUpcast {
+    private Integer id;
+
+    public RecordSpark22000FailToUpcast() {
+    }
+
+    public Integer getId() {
+      return id;
+    }
+
+    public void setId(Integer id) {
+      this.id = id;
+    }
+  }
 }

