You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2023/02/02 02:53:51 UTC

[spark] branch branch-3.4 updated: [SPARK-42093][SQL] Move JavaTypeInference to AgnosticEncoders

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 18672003513 [SPARK-42093][SQL] Move JavaTypeInference to AgnosticEncoders
18672003513 is described below

commit 18672003513d5a4aa610b6b94dbbc15c33185d3a
Author: Herman van Hovell <he...@databricks.com>
AuthorDate: Thu Feb 2 10:53:11 2023 +0800

    [SPARK-42093][SQL] Move JavaTypeInference to AgnosticEncoders
    
    ### What changes were proposed in this pull request?
    This PR makes `JavaTypeInference` produce an `AgnosticEncoder`. The expression generation for these encoders is moved to `ScalaReflection`.
    
    ### Why are the changes needed?
    For the Spark Connect Scala Client we also want to be able to use Java Bean based results.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    I have added a lot of tests to `JavaTypeInferenceSuite`.
    
    Closes #39615 from hvanhovell/SPARK-42093.
    
    Authored-by: Herman van Hovell <he...@databricks.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit 0d93bb2c0a47f652727accfc36b652bdac33f894)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../spark/sql/catalyst/JavaTypeInference.scala     | 565 ++++++---------------
 .../spark/sql/catalyst/ScalaReflection.scala       |  64 ++-
 .../sql/catalyst/encoders/AgnosticEncoder.scala    |  13 +-
 .../sql/catalyst/encoders/ExpressionEncoder.scala  |  11 +-
 .../sql/catalyst/expressions/objects/objects.scala |   8 +-
 .../sql/catalyst/JavaTypeInferenceSuite.scala      | 203 +++++++-
 6 files changed, 418 insertions(+), 446 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 81f363dda36..105bed38704 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -14,25 +14,18 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 package org.apache.spark.sql.catalyst
 
 import java.beans.{Introspector, PropertyDescriptor}
-import java.lang.{Iterable => JIterable}
-import java.lang.reflect.Type
-import java.util.{Iterator => JIterator, List => JList, Map => JMap}
+import java.lang.reflect.{ParameterizedType, Type, TypeVariable}
+import java.util.{ArrayDeque, List => JList, Map => JMap}
 import javax.annotation.Nonnull
 
-import scala.language.existentials
-
-import com.google.common.reflect.TypeToken
+import scala.annotation.tailrec
+import scala.reflect.ClassTag
 
-import org.apache.spark.sql.catalyst.DeserializerBuildHelper._
-import org.apache.spark.sql.catalyst.SerializerBuildHelper._
-import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, DayTimeIntervalEncoder, DEFAULT_JAVA_DECIMAL_ENCODER, EncoderField, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaEnumEncoder, LocalDateTimeEncoder, MapEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, STRICT_DATE_ENCODER, STRICT_INSTANT_ENCODER, STRICT_LOCAL_DATE_ENCODER, STRICT_TIMESTAMP_ENCODER, StringEncoder, UDTEncoder, YearMonthIntervalEncoder}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.types._
 
@@ -40,123 +33,112 @@ import org.apache.spark.sql.types._
  * Type-inference utilities for POJOs and Java collections.
  */
 object JavaTypeInference {
-
-  private val iterableType = TypeToken.of(classOf[JIterable[_]])
-  private val mapType = TypeToken.of(classOf[JMap[_, _]])
-  private val listType = TypeToken.of(classOf[JList[_]])
-  private val iteratorReturnType = classOf[JIterable[_]].getMethod("iterator").getGenericReturnType
-  private val nextReturnType = classOf[JIterator[_]].getMethod("next").getGenericReturnType
-  private val keySetReturnType = classOf[JMap[_, _]].getMethod("keySet").getGenericReturnType
-  private val valuesReturnType = classOf[JMap[_, _]].getMethod("values").getGenericReturnType
-
-  // Guava changed the name of this method; this tries to stay compatible with both
-  // TODO replace with isSupertypeOf when Guava 14 support no longer needed for Hadoop
-  private val ttIsAssignableFrom: (TypeToken[_], TypeToken[_]) => Boolean = {
-    val ttMethods = classOf[TypeToken[_]].getMethods.
-      filter(_.getParameterCount == 1).
-      filter(_.getParameterTypes.head == classOf[TypeToken[_]])
-    val isAssignableFromMethod = ttMethods.find(_.getName == "isSupertypeOf").getOrElse(
-      ttMethods.find(_.getName == "isAssignableFrom").get)
-    (a: TypeToken[_], b: TypeToken[_]) => isAssignableFromMethod.invoke(a, b).asInstanceOf[Boolean]
-  }
-
   /**
-   * Infers the corresponding SQL data type of a JavaBean class.
-   * @param beanClass Java type
+   * Infers the corresponding SQL data type of a Java type.
+   * @param beanType Java type
    * @return (SQL data type, nullable)
    */
-  def inferDataType(beanClass: Class[_]): (DataType, Boolean) = {
-    inferDataType(TypeToken.of(beanClass))
+  def inferDataType(beanType: Type): (DataType, Boolean) = {
+    val encoder = encoderFor(beanType)
+    (encoder.dataType, encoder.nullable)
   }
 
   /**
-   * Infers the corresponding SQL data type of a Java type.
-   * @param beanType Java type
-   * @return (SQL data type, nullable)
+   * Infer an [[AgnosticEncoder]] for the [[Class]] `cls`.
    */
-  private[sql] def inferDataType(beanType: Type): (DataType, Boolean) = {
-    inferDataType(TypeToken.of(beanType))
+  def encoderFor[T](cls: Class[T]): AgnosticEncoder[T] = {
+    encoderFor(cls.asInstanceOf[Type])
   }
 
   /**
-   * Infers the corresponding SQL data type of a Java type.
-   * @param typeToken Java type
-   * @return (SQL data type, nullable)
+   * Infer an [[AgnosticEncoder]] for the `beanType`.
    */
-  private def inferDataType(typeToken: TypeToken[_], seenTypeSet: Set[Class[_]] = Set.empty)
-    : (DataType, Boolean) = {
-    typeToken.getRawType match {
-      case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
-        (c.getAnnotation(classOf[SQLUserDefinedType]).udt().getConstructor().newInstance(), true)
-
-      case c: Class[_] if UDTRegistration.exists(c.getName) =>
-        val udt = UDTRegistration.getUDTFor(c.getName).get.getConstructor().newInstance()
-          .asInstanceOf[UserDefinedType[_ >: Null]]
-        (udt, true)
-
-      case c: Class[_] if c == classOf[java.lang.String] => (StringType, true)
-      case c: Class[_] if c == classOf[Array[Byte]] => (BinaryType, true)
-
-      case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false)
-      case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false)
-      case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false)
-      case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false)
-      case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false)
-      case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false)
-      case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false)
-
-      case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true)
-      case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true)
-      case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true)
-      case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true)
-      case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true)
-      case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true)
-      case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true)
-
-      case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType.SYSTEM_DEFAULT, true)
-      case c: Class[_] if c == classOf[java.math.BigInteger] => (DecimalType.BigIntDecimal, true)
-      case c: Class[_] if c == classOf[java.time.LocalDate] => (DateType, true)
-      case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
-      case c: Class[_] if c == classOf[java.time.Instant] => (TimestampType, true)
-      case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
-      case c: Class[_] if c == classOf[java.time.LocalDateTime] => (TimestampNTZType, true)
-      case c: Class[_] if c == classOf[java.time.Duration] => (DayTimeIntervalType(), true)
-      case c: Class[_] if c == classOf[java.time.Period] => (YearMonthIntervalType(), true)
-
-      case _ if typeToken.isArray =>
-        val (dataType, nullable) = inferDataType(typeToken.getComponentType, seenTypeSet)
-        (ArrayType(dataType, nullable), true)
-
-      case _ if ttIsAssignableFrom(iterableType, typeToken) =>
-        val (dataType, nullable) = inferDataType(elementType(typeToken), seenTypeSet)
-        (ArrayType(dataType, nullable), true)
-
-      case _ if ttIsAssignableFrom(mapType, typeToken) =>
-        val (keyType, valueType) = mapKeyValueType(typeToken)
-        val (keyDataType, _) = inferDataType(keyType, seenTypeSet)
-        val (valueDataType, nullable) = inferDataType(valueType, seenTypeSet)
-        (MapType(keyDataType, valueDataType, nullable), true)
+  def encoderFor[T](beanType: Type): AgnosticEncoder[T] = {
+    encoderFor(beanType, Set.empty).asInstanceOf[AgnosticEncoder[T]]
+  }
 
-      case other if other.isEnum =>
-        (StringType, true)
+  private def encoderFor(t: Type, seenTypeSet: Set[Class[_]]): AgnosticEncoder[_] = t match {
+
+    case c: Class[_] if c == java.lang.Boolean.TYPE => PrimitiveBooleanEncoder
+    case c: Class[_] if c == java.lang.Byte.TYPE => PrimitiveByteEncoder
+    case c: Class[_] if c == java.lang.Short.TYPE => PrimitiveShortEncoder
+    case c: Class[_] if c == java.lang.Integer.TYPE => PrimitiveIntEncoder
+    case c: Class[_] if c == java.lang.Long.TYPE => PrimitiveLongEncoder
+    case c: Class[_] if c == java.lang.Float.TYPE => PrimitiveFloatEncoder
+    case c: Class[_] if c == java.lang.Double.TYPE => PrimitiveDoubleEncoder
+
+    case c: Class[_] if c == classOf[java.lang.Boolean] => BoxedBooleanEncoder
+    case c: Class[_] if c == classOf[java.lang.Byte] => BoxedByteEncoder
+    case c: Class[_] if c == classOf[java.lang.Short] => BoxedShortEncoder
+    case c: Class[_] if c == classOf[java.lang.Integer] => BoxedIntEncoder
+    case c: Class[_] if c == classOf[java.lang.Long] => BoxedLongEncoder
+    case c: Class[_] if c == classOf[java.lang.Float] => BoxedFloatEncoder
+    case c: Class[_] if c == classOf[java.lang.Double] => BoxedDoubleEncoder
+
+    case c: Class[_] if c == classOf[java.lang.String] => StringEncoder
+    case c: Class[_] if c == classOf[Array[Byte]] => BinaryEncoder
+    case c: Class[_] if c == classOf[java.math.BigDecimal] => DEFAULT_JAVA_DECIMAL_ENCODER
+    case c: Class[_] if c == classOf[java.math.BigInteger] => JavaBigIntEncoder
+    case c: Class[_] if c == classOf[java.time.LocalDate] => STRICT_LOCAL_DATE_ENCODER
+    case c: Class[_] if c == classOf[java.sql.Date] => STRICT_DATE_ENCODER
+    case c: Class[_] if c == classOf[java.time.Instant] => STRICT_INSTANT_ENCODER
+    case c: Class[_] if c == classOf[java.sql.Timestamp] => STRICT_TIMESTAMP_ENCODER
+    case c: Class[_] if c == classOf[java.time.LocalDateTime] => LocalDateTimeEncoder
+    case c: Class[_] if c == classOf[java.time.Duration] => DayTimeIntervalEncoder
+    case c: Class[_] if c == classOf[java.time.Period] => YearMonthIntervalEncoder
+
+    case c: Class[_] if c.isEnum => JavaEnumEncoder(ClassTag(c))
+
+    case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
+      val udt = c.getAnnotation(classOf[SQLUserDefinedType]).udt()
+        .getConstructor().newInstance().asInstanceOf[UserDefinedType[Any]]
+      val udtClass = udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()
+      UDTEncoder(udt, udtClass)
+
+    case c: Class[_] if UDTRegistration.exists(c.getName) =>
+      val udt = UDTRegistration.getUDTFor(c.getName).get.getConstructor().
+        newInstance().asInstanceOf[UserDefinedType[Any]]
+      UDTEncoder(udt, udt.getClass)
+
+    case c: Class[_] if c.isArray =>
+      val elementEncoder = encoderFor(c.getComponentType, seenTypeSet)
+      ArrayEncoder(elementEncoder, elementEncoder.nullable)
+
+    case ImplementsList(c, Array(elementCls)) =>
+      val element = encoderFor(elementCls, seenTypeSet)
+      IterableEncoder(ClassTag(c), element, element.nullable, lenientSerialization = false)
+
+    case ImplementsMap(c, Array(keyCls, valueCls)) =>
+      val keyEncoder = encoderFor(keyCls, seenTypeSet)
+      val valueEncoder = encoderFor(valueCls, seenTypeSet)
+      MapEncoder(ClassTag(c), keyEncoder, valueEncoder, valueEncoder.nullable)
+
+    case c: Class[_] =>
+      if (seenTypeSet.contains(c)) {
+        throw QueryExecutionErrors.cannotHaveCircularReferencesInBeanClassError(c)
+      }
 
-      case other =>
-        if (seenTypeSet.contains(other)) {
-          throw QueryExecutionErrors.cannotHaveCircularReferencesInBeanClassError(other)
-        }
+      // TODO: we should only collect properties that have getter and setter. However, some tests
+      //   pass in scala case class as java bean class which doesn't have getter and setter.
+      val properties = getJavaBeanReadableProperties(c)
+      // Note that the fields are ordered by name.
+      val fields = properties.map { property =>
+        val readMethod = property.getReadMethod
+        val encoder = encoderFor(readMethod.getGenericReturnType, seenTypeSet + c)
+        // The existence of `javax.annotation.Nonnull`, means this field is not nullable.
+        val hasNonNull = readMethod.isAnnotationPresent(classOf[Nonnull])
+        EncoderField(
+          property.getName,
+          encoder,
+          encoder.nullable && !hasNonNull,
+          Metadata.empty,
+          Option(readMethod.getName),
+          Option(property.getWriteMethod).map(_.getName))
+      }
+      JavaBeanEncoder(ClassTag(c), fields)
 
-        // TODO: we should only collect properties that have getter and setter. However, some tests
-        // pass in scala case class as java bean class which doesn't have getter and setter.
-        val properties = getJavaBeanReadableProperties(other)
-        val fields = properties.map { property =>
-          val returnType = typeToken.method(property.getReadMethod).getReturnType
-          val (dataType, nullable) = inferDataType(returnType, seenTypeSet + other)
-          // The existence of `javax.annotation.Nonnull`, means this field is not nullable.
-          val hasNonNull = property.getReadMethod.isAnnotationPresent(classOf[Nonnull])
-          new StructField(property.getName, dataType, nullable && !hasNonNull)
-        }
-        (new StructType(fields), true)
-    }
+    case _ =>
+      throw QueryExecutionErrors.cannotFindEncoderForTypeError(t.toString)
   }
 
   def getJavaBeanReadableProperties(beanClass: Class[_]): Array[PropertyDescriptor] = {
@@ -166,317 +148,58 @@ object JavaTypeInference {
       .filter(_.getReadMethod != null)
   }
 
-  private def getJavaBeanReadableAndWritableProperties(
-      beanClass: Class[_]): Array[PropertyDescriptor] = {
-    getJavaBeanReadableProperties(beanClass).filter(_.getWriteMethod != null)
-  }
-
-  private def elementType(typeToken: TypeToken[_]): TypeToken[_] = {
-    val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JIterable[_]]]
-    val iterableSuperType = typeToken2.getSupertype(classOf[JIterable[_]])
-    val iteratorType = iterableSuperType.resolveType(iteratorReturnType)
-    iteratorType.resolveType(nextReturnType)
-  }
-
-  private def mapKeyValueType(typeToken: TypeToken[_]): (TypeToken[_], TypeToken[_]) = {
-    val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]]
-    val mapSuperType = typeToken2.getSupertype(classOf[JMap[_, _]])
-    val keyType = elementType(mapSuperType.resolveType(keySetReturnType))
-    val valueType = elementType(mapSuperType.resolveType(valuesReturnType))
-    keyType -> valueType
-  }
-
-  /**
-   * Returns the Spark SQL DataType for a given java class.  Where this is not an exact mapping
-   * to a native type, an ObjectType is returned.
-   *
-   * Unlike `inferDataType`, this function doesn't do any massaging of types into the Spark SQL type
-   * system.  As a result, ObjectType will be returned for things like boxed Integers.
-   */
-  private def inferExternalType(cls: Class[_]): DataType = cls match {
-    case c if c == java.lang.Boolean.TYPE => BooleanType
-    case c if c == java.lang.Byte.TYPE => ByteType
-    case c if c == java.lang.Short.TYPE => ShortType
-    case c if c == java.lang.Integer.TYPE => IntegerType
-    case c if c == java.lang.Long.TYPE => LongType
-    case c if c == java.lang.Float.TYPE => FloatType
-    case c if c == java.lang.Double.TYPE => DoubleType
-    case c if c == classOf[Array[Byte]] => BinaryType
-    case _ => ObjectType(cls)
-  }
-
-  /**
-   * Returns an expression that can be used to deserialize a Spark SQL representation to an object
-   * of java bean `T` with a compatible schema.  The Spark SQL representation is located at ordinal
-   * 0 of a row, i.e., `GetColumnByOrdinal(0, _)`. Nested classes will have their fields accessed
-   * using `UnresolvedExtractValue`.
-   */
-  def deserializerFor(beanClass: Class[_]): Expression = {
-    val typeToken = TypeToken.of(beanClass)
-    val walkedTypePath = new WalkedTypePath().recordRoot(beanClass.getCanonicalName)
-    val (dataType, nullable) = inferDataType(typeToken)
-
-    // Assumes we are deserializing the first column of a row.
-    deserializerForWithNullSafetyAndUpcast(GetColumnByOrdinal(0, dataType), dataType,
-      nullable = nullable, walkedTypePath, deserializerFor(typeToken, _, walkedTypePath))
-  }
-
-  private def deserializerFor(
-      typeToken: TypeToken[_],
-      path: Expression,
-      walkedTypePath: WalkedTypePath): Expression = {
-    typeToken.getRawType match {
-      case c if !inferExternalType(c).isInstanceOf[ObjectType] => path
-
-      case c if c == classOf[java.lang.Short] ||
-                c == classOf[java.lang.Integer] ||
-                c == classOf[java.lang.Long] ||
-                c == classOf[java.lang.Double] ||
-                c == classOf[java.lang.Float] ||
-                c == classOf[java.lang.Byte] ||
-                c == classOf[java.lang.Boolean] =>
-        createDeserializerForTypesSupportValueOf(path, c)
-
-      case c if c == classOf[java.time.LocalDate] =>
-        createDeserializerForLocalDate(path)
-
-      case c if c == classOf[java.sql.Date] =>
-        createDeserializerForSqlDate(path)
-
-      case c if c == classOf[java.time.Instant] =>
-        createDeserializerForInstant(path)
-
-      case c if c == classOf[java.sql.Timestamp] =>
-        createDeserializerForSqlTimestamp(path)
+  private class ImplementsGenericInterface(interface: Class[_]) {
+    assert(interface.isInterface)
+    assert(interface.getTypeParameters.nonEmpty)
 
-      case c if c == classOf[java.time.LocalDateTime] =>
-        createDeserializerForLocalDateTime(path)
-
-      case c if c == classOf[java.time.Duration] =>
-        createDeserializerForDuration(path)
-
-      case c if c == classOf[java.time.Period] =>
-        createDeserializerForPeriod(path)
-
-      case c if c == classOf[java.lang.String] =>
-        createDeserializerForString(path, returnNullable = true)
-
-      case c if c == classOf[java.math.BigDecimal] =>
-        createDeserializerForJavaBigDecimal(path, returnNullable = true)
-
-      case c if c == classOf[java.math.BigInteger] =>
-        createDeserializerForJavaBigInteger(path, returnNullable = true)
-
-      case c if c.isArray =>
-        val elementType = c.getComponentType
-        val newTypePath = walkedTypePath.recordArray(elementType.getCanonicalName)
-        val (dataType, elementNullable) = inferDataType(elementType)
-        val mapFunction: Expression => Expression = element => {
-          // upcast the array element to the data type the encoder expected.
-          deserializerForWithNullSafetyAndUpcast(
-            element,
-            dataType,
-            nullable = elementNullable,
-            newTypePath,
-            deserializerFor(typeToken.getComponentType, _, newTypePath))
-        }
-
-        val arrayData = UnresolvedMapObjects(mapFunction, path)
-
-        val methodName = elementType match {
-          case c if c == java.lang.Integer.TYPE => "toIntArray"
-          case c if c == java.lang.Long.TYPE => "toLongArray"
-          case c if c == java.lang.Double.TYPE => "toDoubleArray"
-          case c if c == java.lang.Float.TYPE => "toFloatArray"
-          case c if c == java.lang.Short.TYPE => "toShortArray"
-          case c if c == java.lang.Byte.TYPE => "toByteArray"
-          case c if c == java.lang.Boolean.TYPE => "toBooleanArray"
-          // non-primitive
-          case _ => "array"
-        }
-        Invoke(arrayData, methodName, ObjectType(c))
-
-      case c if ttIsAssignableFrom(listType, typeToken) =>
-        val et = elementType(typeToken)
-        val newTypePath = walkedTypePath.recordArray(et.getType.getTypeName)
-        val (dataType, elementNullable) = inferDataType(et)
-        val mapFunction: Expression => Expression = element => {
-          // upcast the array element to the data type the encoder expected.
-          deserializerForWithNullSafetyAndUpcast(
-            element,
-            dataType,
-            nullable = elementNullable,
-            newTypePath,
-            deserializerFor(et, _, newTypePath))
-        }
-
-        UnresolvedMapObjects(mapFunction, path, customCollectionCls = Some(c))
-
-      case _ if ttIsAssignableFrom(mapType, typeToken) =>
-        val (keyType, valueType) = mapKeyValueType(typeToken)
-        val newTypePath = walkedTypePath.recordMap(keyType.getType.getTypeName,
-          valueType.getType.getTypeName)
-
-        val keyData =
-          Invoke(
-            UnresolvedMapObjects(
-              p => deserializerFor(keyType, p, newTypePath),
-              MapKeys(path)),
-            "array",
-            ObjectType(classOf[Array[Any]]))
-
-        val valueData =
-          Invoke(
-            UnresolvedMapObjects(
-              p => deserializerFor(valueType, p, newTypePath),
-              MapValues(path)),
-            "array",
-            ObjectType(classOf[Array[Any]]))
-
-        StaticInvoke(
-          ArrayBasedMapData.getClass,
-          ObjectType(classOf[JMap[_, _]]),
-          "toJavaMap",
-          keyData :: valueData :: Nil,
-          returnNullable = false)
-
-      case other if other.isEnum =>
-        createDeserializerForTypesSupportValueOf(
-          createDeserializerForString(path, returnNullable = false),
-          other)
-
-      case other =>
-        val properties = getJavaBeanReadableAndWritableProperties(other)
-        val setters = properties.map { p =>
-          val fieldName = p.getName
-          val fieldType = typeToken.method(p.getReadMethod).getReturnType
-          val (dataType, nullable) = inferDataType(fieldType)
-          val newTypePath = walkedTypePath.recordField(fieldType.getType.getTypeName, fieldName)
-          // The existence of `javax.annotation.Nonnull`, means this field is not nullable.
-          val hasNonNull = p.getReadMethod.isAnnotationPresent(classOf[Nonnull])
-          val setter = expressionWithNullSafety(
-            deserializerFor(fieldType, addToPath(path, fieldName, dataType, newTypePath),
-              newTypePath),
-            nullable = nullable && !hasNonNull,
-            newTypePath)
-          p.getWriteMethod.getName -> setter
-        }.toMap
-
-        val newInstance = NewInstance(other, Nil, ObjectType(other), propagateNull = false)
-        val result = InitializeJavaBean(newInstance, setters)
-
-        expressions.If(
-          IsNull(path),
-          expressions.Literal.create(null, ObjectType(other)),
-          result
-        )
+    def unapply(t: Type): Option[(Class[_], Array[Type])] = implementsInterface(t).map { cls =>
+      cls -> findTypeArgumentsForInterface(t)
     }
-  }
 
-  /**
-   * Returns an expression for serializing an object of the given type to a Spark SQL
-   * representation. The input object is located at ordinal 0 of a row, i.e.,
-   * `BoundReference(0, _)`.
-   */
-  def serializerFor(beanClass: Class[_]): Expression = {
-    val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true)
-    val nullSafeInput = AssertNotNull(inputObject, Seq("top level input bean"))
-    serializerFor(nullSafeInput, TypeToken.of(beanClass))
-  }
-
-  private def serializerFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = {
-
-    def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = {
-      val (dataType, nullable) = inferDataType(elementType)
-      if (ScalaReflection.isNativeType(dataType)) {
-        val cls = input.dataType.asInstanceOf[ObjectType].cls
-        if (cls.isArray && cls.getComponentType.isPrimitive) {
-          createSerializerForPrimitiveArray(input, dataType)
-        } else {
-          createSerializerForGenericArray(input, dataType, nullable = nullable)
-        }
-      } else {
-        createSerializerForMapObjects(input, ObjectType(elementType.getRawType),
-          serializerFor(_, elementType))
-      }
+    @tailrec
+    private def implementsInterface(t: Type): Option[Class[_]] = t match {
+      case pt: ParameterizedType => implementsInterface(pt.getRawType)
+      case c: Class[_] if interface.isAssignableFrom(c) => Option(c)
+      case _ => None
     }
 
-    if (!inputObject.dataType.isInstanceOf[ObjectType]) {
-      inputObject
-    } else {
-      typeToken.getRawType match {
-        case c if c == classOf[String] => createSerializerForString(inputObject)
-
-        case c if c == classOf[java.time.Instant] => createSerializerForJavaInstant(inputObject)
-
-        case c if c == classOf[java.sql.Timestamp] => createSerializerForSqlTimestamp(inputObject)
-
-        case c if c == classOf[java.time.LocalDateTime] =>
-          createSerializerForLocalDateTime(inputObject)
-
-        case c if c == classOf[java.time.LocalDate] => createSerializerForJavaLocalDate(inputObject)
-
-        case c if c == classOf[java.sql.Date] => createSerializerForSqlDate(inputObject)
-
-        case c if c == classOf[java.time.Duration] => createSerializerForJavaDuration(inputObject)
-
-        case c if c == classOf[java.time.Period] => createSerializerForJavaPeriod(inputObject)
-
-        case c if c == classOf[java.math.BigInteger] =>
-          createSerializerForBigInteger(inputObject)
-
-        case c if c == classOf[java.math.BigDecimal] =>
-          createSerializerForBigDecimal(inputObject)
-
-        case c if c == classOf[java.lang.Boolean] => createSerializerForBoolean(inputObject)
-        case c if c == classOf[java.lang.Byte] => createSerializerForByte(inputObject)
-        case c if c == classOf[java.lang.Short] => createSerializerForShort(inputObject)
-        case c if c == classOf[java.lang.Integer] => createSerializerForInteger(inputObject)
-        case c if c == classOf[java.lang.Long] => createSerializerForLong(inputObject)
-        case c if c == classOf[java.lang.Float] => createSerializerForFloat(inputObject)
-        case c if c == classOf[java.lang.Double] => createSerializerForDouble(inputObject)
-
-        case _ if typeToken.isArray =>
-          toCatalystArray(inputObject, typeToken.getComponentType)
-
-        case _ if ttIsAssignableFrom(listType, typeToken) =>
-          toCatalystArray(inputObject, elementType(typeToken))
-
-        case _ if ttIsAssignableFrom(mapType, typeToken) =>
-          val (keyType, valueType) = mapKeyValueType(typeToken)
-
-          createSerializerForMap(
-            inputObject,
-            MapElementInformation(
-              ObjectType(keyType.getRawType),
-              nullable = true,
-              serializerFor(_, keyType)),
-            MapElementInformation(
-              ObjectType(valueType.getRawType),
-              nullable = true,
-              serializerFor(_, valueType))
-          )
-
-        case other if other.isEnum =>
-          createSerializerForString(
-            Invoke(inputObject, "name", ObjectType(classOf[String]), returnNullable = false))
-
-        case other =>
-          val properties = getJavaBeanReadableAndWritableProperties(other)
-          val fields = properties.map { p =>
-            val fieldName = p.getName
-            val fieldType = typeToken.method(p.getReadMethod).getReturnType
-            val hasNonNull = p.getReadMethod.isAnnotationPresent(classOf[Nonnull])
-            val fieldValue = Invoke(
-              inputObject,
-              p.getReadMethod.getName,
-              inferExternalType(fieldType.getRawType),
-              propagateNull = !hasNonNull,
-              returnNullable = !hasNonNull)
-            (fieldName, serializerFor(fieldValue, fieldType))
-          }
-          createSerializerForObject(inputObject, fields)
+    private def findTypeArgumentsForInterface(t: Type): Array[Type] = {
+      val queue = new ArrayDeque[(Type, Map[Any, Type])]
+      queue.add(t -> Map.empty)
+      while (!queue.isEmpty) {
+        queue.poll() match {
+          case (pt: ParameterizedType, bindings) =>
+            // translate mappings...
+            val mappedTypeArguments = pt.getActualTypeArguments.map {
+              case v: TypeVariable[_] => bindings(v.getName)
+              case v => v
+            }
+            if (pt.getRawType == interface) {
+              return mappedTypeArguments
+            } else {
+              val mappedTypeArgumentMap = mappedTypeArguments
+                .zipWithIndex.map(_.swap)
+                .toMap[Any, Type]
+              queue.add(pt.getRawType -> mappedTypeArgumentMap)
+            }
+          case (c: Class[_], indexedBindings) =>
+            val namedBindings = c.getTypeParameters.zipWithIndex.map {
+              case (parameter, index) =>
+                parameter.getName -> indexedBindings(index)
+            }.toMap[Any, Type]
+            val superClass = c.getGenericSuperclass
+            if (superClass != null) {
+              queue.add(superClass -> namedBindings)
+            }
+            c.getGenericInterfaces.foreach { iface =>
+              queue.add(iface -> namedBindings)
+            }
+        }
       }
+      throw QueryExecutionErrors.unreachableError()
     }
   }
+
+  private object ImplementsList extends ImplementsGenericInterface(classOf[JList[_]])
+  private object ImplementsMap extends ImplementsGenericInterface(classOf[JMap[_, _]])
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 42208cd1098..4680a2aec2b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, MapData}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -264,6 +264,36 @@ object ScalaReflection extends ScalaReflection {
         Option(clsTag.runtimeClass),
         walkedTypePath)
 
+    case MapEncoder(tag, keyEncoder, valueEncoder, _)
+        if classOf[java.util.Map[_, _]].isAssignableFrom(tag.runtimeClass) =>
+      // TODO (hvanhovell) this can be improved.
+      val newTypePath = walkedTypePath.recordMap(
+        keyEncoder.clsTag.runtimeClass.getName,
+        valueEncoder.clsTag.runtimeClass.getName)
+
+      val keyData =
+        Invoke(
+          UnresolvedMapObjects(
+            p => deserializerFor(keyEncoder, p, newTypePath),
+            MapKeys(path)),
+          "array",
+          ObjectType(classOf[Array[Any]]))
+
+      val valueData =
+        Invoke(
+          UnresolvedMapObjects(
+            p => deserializerFor(valueEncoder, p, newTypePath),
+            MapValues(path)),
+          "array",
+          ObjectType(classOf[Array[Any]]))
+
+      StaticInvoke(
+        ArrayBasedMapData.getClass,
+        ObjectType(classOf[java.util.Map[_, _]]),
+        "toJavaMap",
+        keyData :: valueData :: Nil,
+        returnNullable = false)
+
     case MapEncoder(tag, keyEncoder, valueEncoder, _) =>
       val newTypePath = walkedTypePath.recordMap(
         keyEncoder.clsTag.runtimeClass.getName,
@@ -312,6 +342,26 @@ object ScalaReflection extends ScalaReflection {
       exprs.If(IsNull(path),
         exprs.Literal.create(null, externalDataTypeFor(enc)),
         CreateExternalRow(convertedFields, enc.schema))
+
+    case JavaBeanEncoder(tag, fields) =>
+      val setters = fields.map { f =>
+        val newTypePath = walkedTypePath.recordField(
+          f.enc.clsTag.runtimeClass.getName,
+          f.name)
+        val setter = expressionWithNullSafety(
+          deserializerFor(
+            f.enc,
+            addToPath(path, f.name, f.enc.dataType, newTypePath),
+            newTypePath),
+          nullable = f.nullable,
+          newTypePath)
+        f.writeMethod.get -> setter
+      }
+
+      val cls = tag.runtimeClass
+      val newInstance = NewInstance(cls, Nil, ObjectType(cls), propagateNull = false)
+      val result = InitializeJavaBean(newInstance, setters.toMap)
+      exprs.If(IsNull(path), exprs.Literal.create(null, ObjectType(cls)), result)
   }
 
   private def deserializeArray(
@@ -446,6 +496,18 @@ object ScalaReflection extends ScalaReflection {
         field.name -> convertedField
       }
       createSerializerForObject(input, serializedFields)
+
+    case JavaBeanEncoder(_, fields) =>
+      val serializedFields = fields.map { f =>
+        val fieldValue = Invoke(
+          KnownNotNull(input),
+          f.readMethod.get,
+          externalDataTypeFor(f.enc),
+          propagateNull = f.nullable,
+          returnNullable = f.nullable)
+        f.name -> serializerFor(f.enc, fieldValue)
+      }
+      createSerializerForObject(input, serializedFields)
   }
 
   private def serializerForArray(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
index cdc64f2ddb5..1a3c1089649 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
@@ -91,7 +91,9 @@ object AgnosticEncoders {
       name: String,
       enc: AgnosticEncoder[_],
       nullable: Boolean,
-      metadata: Metadata) {
+      metadata: Metadata,
+      readMethod: Option[String] = None,
+      writeMethod: Option[String] = None) {
     def structField: StructField = StructField(name, enc.dataType, nullable, metadata)
   }
 
@@ -112,6 +114,15 @@ object AgnosticEncoders {
     override def clsTag: ClassTag[Row] = classTag[Row]
   }
 
+  case class JavaBeanEncoder[K](
+      override val clsTag: ClassTag[K],
+      fields: Seq[EncoderField])
+    extends AgnosticEncoder[K] {
+    override def isPrimitive: Boolean = false
+    override val schema: StructType = StructType(fields.map(_.structField))
+    override def dataType: DataType = schema
+  }
+
   // This will only work for encoding from/to Sparks' InternalRow format.
   // It is here for compatibility.
   case class UDTEncoder[E >: Null](
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 9ca2fc72ad9..faa165c298d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -59,16 +59,7 @@ object ExpressionEncoder {
 
   // TODO: improve error message for java bean encoder.
   def javaBean[T](beanClass: Class[T]): ExpressionEncoder[T] = {
-    val schema = JavaTypeInference.inferDataType(beanClass)._1
-    assert(schema.isInstanceOf[StructType])
-
-    val objSerializer = JavaTypeInference.serializerFor(beanClass)
-    val objDeserializer = JavaTypeInference.deserializerFor(beanClass)
-
-    new ExpressionEncoder[T](
-      objSerializer,
-      objDeserializer,
-      ClassTag[T](beanClass))
+     apply(JavaTypeInference.encoderFor(beanClass))
   }
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 299a928f267..929beb660ad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -1927,7 +1927,8 @@ case class ValidateExternalType(child: Expression, expected: DataType, externalD
       (value: Any) => {
         value.getClass.isArray ||
           value.isInstanceOf[scala.collection.Seq[_]] ||
-          value.isInstanceOf[Set[_]]
+          value.isInstanceOf[Set[_]] ||
+          value.isInstanceOf[java.util.List[_]]
       }
     case _: DateType =>
       (value: Any) => {
@@ -1968,7 +1969,10 @@ case class ValidateExternalType(child: Expression, expected: DataType, externalD
           classOf[scala.math.BigDecimal],
           classOf[Decimal]))
       case _: ArrayType =>
-        val check = genCheckTypes(Seq(classOf[scala.collection.Seq[_]], classOf[Set[_]]))
+        val check = genCheckTypes(Seq(
+          classOf[scala.collection.Seq[_]],
+          classOf[Set[_]],
+          classOf[java.util.List[_]]))
         s"$obj.getClass().isArray() || $check"
       case _: DateType =>
         genCheckTypes(Seq(classOf[java.sql.Date], classOf[java.time.LocalDate]))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala
index 9c1d0c17777..35f5bf739bf 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala
@@ -18,25 +18,206 @@
 package org.apache.spark.sql.catalyst
 
 import java.math.BigInteger
+import java.util.{LinkedList, List => JList, Map => JMap}
 
-import scala.beans.BeanProperty
+import scala.beans.{BeanProperty, BooleanBeanProperty}
+import scala.reflect.{classTag, ClassTag}
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.{CheckOverflow, Expression, Literal}
-import org.apache.spark.sql.types.DecimalType
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, UDTCaseClass, UDTForCaseClass}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
+import org.apache.spark.sql.types.{DecimalType, MapType, Metadata, StringType, StructField, StructType}
 
-class DummyBean() {
-  @BeanProperty var bigInteger = null: BigInteger
+class DummyBean {
+  @BeanProperty var bigInteger: BigInteger = _
 }
 
+class GenericCollectionBean {
+  @BeanProperty var listOfListOfStrings: JList[JList[String]] = _
+  @BeanProperty var mapOfDummyBeans: JMap[String, DummyBean] = _
+  @BeanProperty var linkedListOfStrings: LinkedList[String] = _
+}
+
+class LeafBean {
+  @BooleanBeanProperty var primitiveBoolean: Boolean = false
+  @BeanProperty var primitiveByte: Byte = 0
+  @BeanProperty var primitiveShort: Short = 0
+  @BeanProperty var primitiveInt: Int = 0
+  @BeanProperty var primitiveLong: Long = 0
+  @BeanProperty var primitiveFloat: Float = 0
+  @BeanProperty var primitiveDouble: Double = 0
+  @BeanProperty var boxedBoolean: java.lang.Boolean = false
+  @BeanProperty var boxedByte: java.lang.Byte = 0.toByte
+  @BeanProperty var boxedShort: java.lang.Short = 0.toShort
+  @BeanProperty var boxedInt: java.lang.Integer = 0
+  @BeanProperty var boxedLong: java.lang.Long = 0
+  @BeanProperty var boxedFloat: java.lang.Float = 0
+  @BeanProperty var boxedDouble: java.lang.Double = 0
+  @BeanProperty var string: String = _
+  @BeanProperty var binary: Array[Byte] = _
+  @BeanProperty var bigDecimal: java.math.BigDecimal = _
+  @BeanProperty var bigInteger: java.math.BigInteger = _
+  @BeanProperty var localDate: java.time.LocalDate = _
+  @BeanProperty var date: java.sql.Date = _
+  @BeanProperty var instant: java.time.Instant = _
+  @BeanProperty var timestamp: java.sql.Timestamp = _
+  @BeanProperty var localDateTime: java.time.LocalDateTime = _
+  @BeanProperty var duration: java.time.Duration = _
+  @BeanProperty var period: java.time.Period = _
+  @BeanProperty var enum: java.time.Month = _
+  @BeanProperty val readOnlyString = "read-only"
+
+  var nonNullString: String = "value"
+  @javax.annotation.Nonnull
+  def getNonNullString: String = nonNullString
+  def setNonNullString(v: String): Unit = nonNullString = {
+    java.util.Objects.nonNull(v)
+    v
+  }
+}
+
+class ArrayBean {
+  @BeanProperty var dummyBeanArray: Array[DummyBean] = _
+  @BeanProperty var primitiveIntArray: Array[Int] = _
+  @BeanProperty var stringArray: Array[String] = _
+}
+
+class UDTBean {
+  @BeanProperty var udt: UDTCaseClass = _
+}
+
+/**
+ * Test suite for Encoders produced by [[JavaTypeInference]].
+ */
 class JavaTypeInferenceSuite extends SparkFunSuite {
 
+  private def encoderField(
+      name: String,
+      encoder: AgnosticEncoder[_],
+      overrideNullable: Option[Boolean] = None,
+      readOnly: Boolean = false): EncoderField = {
+    val readPrefix = if (encoder == PrimitiveBooleanEncoder) "is" else "get"
+    EncoderField(
+      name,
+      encoder,
+      overrideNullable.getOrElse(encoder.nullable),
+      Metadata.empty,
+      Option(readPrefix + name.capitalize),
+      Option("set" + name.capitalize).filterNot(_ => readOnly))
+  }
+
+  private val expectedDummyBeanEncoder =
+    JavaBeanEncoder[DummyBean](
+      ClassTag(classOf[DummyBean]),
+      Seq(encoderField("bigInteger", JavaBigIntEncoder)))
+
+  private val expectedDummyBeanSchema =
+    StructType(StructField("bigInteger", DecimalType(38, 0)) :: Nil)
+
   test("SPARK-41007: JavaTypeInference returns the correct serializer for BigInteger") {
-    var serializer = JavaTypeInference.serializerFor(classOf[DummyBean])
-    var bigIntegerFieldName: Expression = serializer.children(0)
-    assert(bigIntegerFieldName.asInstanceOf[Literal].value.toString == "bigInteger")
-    var bigIntegerFieldExpression: Expression = serializer.children(1)
-    assert(bigIntegerFieldExpression.asInstanceOf[CheckOverflow].dataType ==
-      DecimalType.BigIntDecimal)
+    val encoder = JavaTypeInference.encoderFor(classOf[DummyBean])
+    assert(encoder === expectedDummyBeanEncoder)
+    assert(encoder.schema === expectedDummyBeanSchema)
+  }
+
+  test("resolve schema for class") {
+    val (schema, nullable) = JavaTypeInference.inferDataType(classOf[DummyBean])
+    assert(nullable)
+    assert(schema === expectedDummyBeanSchema)
+  }
+
+  test("resolve schema for type") {
+    val getter = classOf[GenericCollectionBean].getDeclaredMethods
+      .find(_.getName == "getMapOfDummyBeans")
+      .get
+    val (schema, nullable) = JavaTypeInference.inferDataType(getter.getGenericReturnType)
+    val expected = MapType(StringType, expectedDummyBeanSchema, valueContainsNull = true)
+    assert(nullable)
+    assert(schema === expected)
+  }
+
+  test("resolve type parameters for map and list") {
+    val encoder = JavaTypeInference.encoderFor(classOf[GenericCollectionBean])
+    val expected = JavaBeanEncoder(ClassTag(classOf[GenericCollectionBean]), Seq(
+      encoderField(
+        "linkedListOfStrings",
+        IterableEncoder(
+          ClassTag(classOf[LinkedList[_]]),
+          StringEncoder,
+          containsNull = true,
+          lenientSerialization = false)),
+      encoderField(
+        "listOfListOfStrings",
+        IterableEncoder(
+          ClassTag(classOf[JList[_]]),
+          IterableEncoder(
+            ClassTag(classOf[JList[_]]),
+            StringEncoder,
+            containsNull = true,
+            lenientSerialization = false),
+          containsNull = true,
+          lenientSerialization = false)),
+      encoderField(
+        "mapOfDummyBeans",
+        MapEncoder(
+          ClassTag(classOf[JMap[_, _]]),
+          StringEncoder,
+          expectedDummyBeanEncoder,
+          valueContainsNull = true))))
+    assert(encoder === expected)
+  }
+
+  test("resolve leaf encoders") {
+    val encoder = JavaTypeInference.encoderFor(classOf[LeafBean])
+    val expected = JavaBeanEncoder(ClassTag(classOf[LeafBean]), Seq(
+      // The order is different from the definition because fields are ordered by name.
+      encoderField("bigDecimal", DEFAULT_JAVA_DECIMAL_ENCODER),
+      encoderField("bigInteger", JavaBigIntEncoder),
+      encoderField("binary", BinaryEncoder),
+      encoderField("boxedBoolean", BoxedBooleanEncoder),
+      encoderField("boxedByte", BoxedByteEncoder),
+      encoderField("boxedDouble", BoxedDoubleEncoder),
+      encoderField("boxedFloat", BoxedFloatEncoder),
+      encoderField("boxedInt", BoxedIntEncoder),
+      encoderField("boxedLong", BoxedLongEncoder),
+      encoderField("boxedShort", BoxedShortEncoder),
+      encoderField("date", STRICT_DATE_ENCODER),
+      encoderField("duration", DayTimeIntervalEncoder),
+      encoderField("enum", JavaEnumEncoder(classTag[java.time.Month])),
+      encoderField("instant", STRICT_INSTANT_ENCODER),
+      encoderField("localDate", STRICT_LOCAL_DATE_ENCODER),
+      encoderField("localDateTime", LocalDateTimeEncoder),
+      encoderField("nonNullString", StringEncoder, overrideNullable = Option(false)),
+      encoderField("period", YearMonthIntervalEncoder),
+      encoderField("primitiveBoolean", PrimitiveBooleanEncoder),
+      encoderField("primitiveByte", PrimitiveByteEncoder),
+      encoderField("primitiveDouble", PrimitiveDoubleEncoder),
+      encoderField("primitiveFloat", PrimitiveFloatEncoder),
+      encoderField("primitiveInt", PrimitiveIntEncoder),
+      encoderField("primitiveLong", PrimitiveLongEncoder),
+      encoderField("primitiveShort", PrimitiveShortEncoder),
+      encoderField("readOnlyString", StringEncoder, readOnly = true),
+      encoderField("string", StringEncoder),
+      encoderField("timestamp", STRICT_TIMESTAMP_ENCODER)
+    ))
+    assert(encoder === expected)
+  }
+
+  test("resolve array encoders") {
+    val encoder = JavaTypeInference.encoderFor(classOf[ArrayBean])
+    val expected = JavaBeanEncoder(ClassTag(classOf[ArrayBean]), Seq(
+      encoderField("dummyBeanArray", ArrayEncoder(expectedDummyBeanEncoder, containsNull = true)),
+      encoderField("primitiveIntArray", ArrayEncoder(PrimitiveIntEncoder, containsNull = false)),
+      encoderField("stringArray", ArrayEncoder(StringEncoder, containsNull = true))
+    ))
+    assert(encoder === expected)
+  }
+
+  test("resolve UDT encoders") {
+    val encoder = JavaTypeInference.encoderFor(classOf[UDTBean])
+    val expected = JavaBeanEncoder(ClassTag(classOf[UDTBean]), Seq(
+      encoderField("udt", UDTEncoder(new UDTForCaseClass, classOf[UDTForCaseClass]))
+    ))
+    assert(encoder === expected)
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org