Posted to commits@spark.apache.org by we...@apache.org on 2023/02/02 02:53:51 UTC
[spark] branch branch-3.4 updated: [SPARK-42093][SQL] Move JavaTypeInference to AgnosticEncoders
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new 18672003513 [SPARK-42093][SQL] Move JavaTypeInference to AgnosticEncoders
18672003513 is described below
commit 18672003513d5a4aa610b6b94dbbc15c33185d3a
Author: Herman van Hovell <he...@databricks.com>
AuthorDate: Thu Feb 2 10:53:11 2023 +0800
[SPARK-42093][SQL] Move JavaTypeInference to AgnosticEncoders
### What changes were proposed in this pull request?
This PR makes `JavaTypeInference` produce an `AgnosticEncoder`. The expression generation for these encoders is moved to `ScalaReflection`.
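A minimal sketch of the new entry points (the `Person` bean below is hypothetical; `encoderFor` and the `ExpressionEncoder.javaBean` wiring are from this patch):

    import scala.beans.BeanProperty
    import org.apache.spark.sql.catalyst.JavaTypeInference
    import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder}

    // Hypothetical Java-style bean, for illustration only.
    class Person {
      @BeanProperty var name: String = _
      @BeanProperty var age: java.lang.Integer = _
    }

    object EncoderForExample {
      // JavaTypeInference now produces an expression-free AgnosticEncoder...
      val agnostic: AgnosticEncoder[Person] = JavaTypeInference.encoderFor(classOf[Person])

      // ...and ExpressionEncoder.javaBean just wraps it; the serializer and
      // deserializer expressions are generated by ScalaReflection.
      val beanEncoder: ExpressionEncoder[Person] = ExpressionEncoder.javaBean(classOf[Person])
    }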
### Why are the changes needed?
We also want the Spark Connect Scala Client to be able to return Java Bean based results.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
I have added a lot of tests to `JavaTypeInferenceSuite`.
Closes #39615 from hvanhovell/SPARK-42093.
Authored-by: Herman van Hovell <he...@databricks.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
(cherry picked from commit 0d93bb2c0a47f652727accfc36b652bdac33f894)
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../spark/sql/catalyst/JavaTypeInference.scala | 565 ++++++---------------
.../spark/sql/catalyst/ScalaReflection.scala | 64 ++-
.../sql/catalyst/encoders/AgnosticEncoder.scala | 13 +-
.../sql/catalyst/encoders/ExpressionEncoder.scala | 11 +-
.../sql/catalyst/expressions/objects/objects.scala | 8 +-
.../sql/catalyst/JavaTypeInferenceSuite.scala | 203 +++++++-
6 files changed, 418 insertions(+), 446 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 81f363dda36..105bed38704 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -14,25 +14,18 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
package org.apache.spark.sql.catalyst
import java.beans.{Introspector, PropertyDescriptor}
-import java.lang.{Iterable => JIterable}
-import java.lang.reflect.Type
-import java.util.{Iterator => JIterator, List => JList, Map => JMap}
+import java.lang.reflect.{ParameterizedType, Type, TypeVariable}
+import java.util.{ArrayDeque, List => JList, Map => JMap}
import javax.annotation.Nonnull
-import scala.language.existentials
-
-import com.google.common.reflect.TypeToken
+import scala.annotation.tailrec
+import scala.reflect.ClassTag
-import org.apache.spark.sql.catalyst.DeserializerBuildHelper._
-import org.apache.spark.sql.catalyst.SerializerBuildHelper._
-import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, DayTimeIntervalEncoder, DEFAULT_JAVA_DECIMAL_ENCODER, EncoderField, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaEnumEncoder, LocalDateTimeEncoder, MapEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, P [...]
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types._
@@ -40,123 +33,112 @@ import org.apache.spark.sql.types._
* Type-inference utilities for POJOs and Java collections.
*/
object JavaTypeInference {
-
- private val iterableType = TypeToken.of(classOf[JIterable[_]])
- private val mapType = TypeToken.of(classOf[JMap[_, _]])
- private val listType = TypeToken.of(classOf[JList[_]])
- private val iteratorReturnType = classOf[JIterable[_]].getMethod("iterator").getGenericReturnType
- private val nextReturnType = classOf[JIterator[_]].getMethod("next").getGenericReturnType
- private val keySetReturnType = classOf[JMap[_, _]].getMethod("keySet").getGenericReturnType
- private val valuesReturnType = classOf[JMap[_, _]].getMethod("values").getGenericReturnType
-
- // Guava changed the name of this method; this tries to stay compatible with both
- // TODO replace with isSupertypeOf when Guava 14 support no longer needed for Hadoop
- private val ttIsAssignableFrom: (TypeToken[_], TypeToken[_]) => Boolean = {
- val ttMethods = classOf[TypeToken[_]].getMethods.
- filter(_.getParameterCount == 1).
- filter(_.getParameterTypes.head == classOf[TypeToken[_]])
- val isAssignableFromMethod = ttMethods.find(_.getName == "isSupertypeOf").getOrElse(
- ttMethods.find(_.getName == "isAssignableFrom").get)
- (a: TypeToken[_], b: TypeToken[_]) => isAssignableFromMethod.invoke(a, b).asInstanceOf[Boolean]
- }
-
/**
- * Infers the corresponding SQL data type of a JavaBean class.
- * @param beanClass Java type
+ * Infers the corresponding SQL data type of a Java type.
+ * @param beanType Java type
* @return (SQL data type, nullable)
*/
- def inferDataType(beanClass: Class[_]): (DataType, Boolean) = {
- inferDataType(TypeToken.of(beanClass))
+ def inferDataType(beanType: Type): (DataType, Boolean) = {
+ val encoder = encoderFor(beanType)
+ (encoder.dataType, encoder.nullable)
}
/**
- * Infers the corresponding SQL data type of a Java type.
- * @param beanType Java type
- * @return (SQL data type, nullable)
+ * Infer an [[AgnosticEncoder]] for the [[Class]] `cls`.
*/
- private[sql] def inferDataType(beanType: Type): (DataType, Boolean) = {
- inferDataType(TypeToken.of(beanType))
+ def encoderFor[T](cls: Class[T]): AgnosticEncoder[T] = {
+ encoderFor(cls.asInstanceOf[Type])
}
/**
- * Infers the corresponding SQL data type of a Java type.
- * @param typeToken Java type
- * @return (SQL data type, nullable)
+ * Infer an [[AgnosticEncoder]] for the `beanType`.
*/
- private def inferDataType(typeToken: TypeToken[_], seenTypeSet: Set[Class[_]] = Set.empty)
- : (DataType, Boolean) = {
- typeToken.getRawType match {
- case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
- (c.getAnnotation(classOf[SQLUserDefinedType]).udt().getConstructor().newInstance(), true)
-
- case c: Class[_] if UDTRegistration.exists(c.getName) =>
- val udt = UDTRegistration.getUDTFor(c.getName).get.getConstructor().newInstance()
- .asInstanceOf[UserDefinedType[_ >: Null]]
- (udt, true)
-
- case c: Class[_] if c == classOf[java.lang.String] => (StringType, true)
- case c: Class[_] if c == classOf[Array[Byte]] => (BinaryType, true)
-
- case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false)
- case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false)
- case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false)
- case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false)
- case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false)
- case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false)
- case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false)
-
- case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true)
- case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true)
- case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true)
- case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true)
- case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true)
- case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true)
- case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true)
-
- case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType.SYSTEM_DEFAULT, true)
- case c: Class[_] if c == classOf[java.math.BigInteger] => (DecimalType.BigIntDecimal, true)
- case c: Class[_] if c == classOf[java.time.LocalDate] => (DateType, true)
- case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
- case c: Class[_] if c == classOf[java.time.Instant] => (TimestampType, true)
- case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
- case c: Class[_] if c == classOf[java.time.LocalDateTime] => (TimestampNTZType, true)
- case c: Class[_] if c == classOf[java.time.Duration] => (DayTimeIntervalType(), true)
- case c: Class[_] if c == classOf[java.time.Period] => (YearMonthIntervalType(), true)
-
- case _ if typeToken.isArray =>
- val (dataType, nullable) = inferDataType(typeToken.getComponentType, seenTypeSet)
- (ArrayType(dataType, nullable), true)
-
- case _ if ttIsAssignableFrom(iterableType, typeToken) =>
- val (dataType, nullable) = inferDataType(elementType(typeToken), seenTypeSet)
- (ArrayType(dataType, nullable), true)
-
- case _ if ttIsAssignableFrom(mapType, typeToken) =>
- val (keyType, valueType) = mapKeyValueType(typeToken)
- val (keyDataType, _) = inferDataType(keyType, seenTypeSet)
- val (valueDataType, nullable) = inferDataType(valueType, seenTypeSet)
- (MapType(keyDataType, valueDataType, nullable), true)
+ def encoderFor[T](beanType: Type): AgnosticEncoder[T] = {
+ encoderFor(beanType, Set.empty).asInstanceOf[AgnosticEncoder[T]]
+ }
- case other if other.isEnum =>
- (StringType, true)
+ private def encoderFor(t: Type, seenTypeSet: Set[Class[_]]): AgnosticEncoder[_] = t match {
+
+ case c: Class[_] if c == java.lang.Boolean.TYPE => PrimitiveBooleanEncoder
+ case c: Class[_] if c == java.lang.Byte.TYPE => PrimitiveByteEncoder
+ case c: Class[_] if c == java.lang.Short.TYPE => PrimitiveShortEncoder
+ case c: Class[_] if c == java.lang.Integer.TYPE => PrimitiveIntEncoder
+ case c: Class[_] if c == java.lang.Long.TYPE => PrimitiveLongEncoder
+ case c: Class[_] if c == java.lang.Float.TYPE => PrimitiveFloatEncoder
+ case c: Class[_] if c == java.lang.Double.TYPE => PrimitiveDoubleEncoder
+
+ case c: Class[_] if c == classOf[java.lang.Boolean] => BoxedBooleanEncoder
+ case c: Class[_] if c == classOf[java.lang.Byte] => BoxedByteEncoder
+ case c: Class[_] if c == classOf[java.lang.Short] => BoxedShortEncoder
+ case c: Class[_] if c == classOf[java.lang.Integer] => BoxedIntEncoder
+ case c: Class[_] if c == classOf[java.lang.Long] => BoxedLongEncoder
+ case c: Class[_] if c == classOf[java.lang.Float] => BoxedFloatEncoder
+ case c: Class[_] if c == classOf[java.lang.Double] => BoxedDoubleEncoder
+
+ case c: Class[_] if c == classOf[java.lang.String] => StringEncoder
+ case c: Class[_] if c == classOf[Array[Byte]] => BinaryEncoder
+ case c: Class[_] if c == classOf[java.math.BigDecimal] => DEFAULT_JAVA_DECIMAL_ENCODER
+ case c: Class[_] if c == classOf[java.math.BigInteger] => JavaBigIntEncoder
+ case c: Class[_] if c == classOf[java.time.LocalDate] => STRICT_LOCAL_DATE_ENCODER
+ case c: Class[_] if c == classOf[java.sql.Date] => STRICT_DATE_ENCODER
+ case c: Class[_] if c == classOf[java.time.Instant] => STRICT_INSTANT_ENCODER
+ case c: Class[_] if c == classOf[java.sql.Timestamp] => STRICT_TIMESTAMP_ENCODER
+ case c: Class[_] if c == classOf[java.time.LocalDateTime] => LocalDateTimeEncoder
+ case c: Class[_] if c == classOf[java.time.Duration] => DayTimeIntervalEncoder
+ case c: Class[_] if c == classOf[java.time.Period] => YearMonthIntervalEncoder
+
+ case c: Class[_] if c.isEnum => JavaEnumEncoder(ClassTag(c))
+
+ case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
+ val udt = c.getAnnotation(classOf[SQLUserDefinedType]).udt()
+ .getConstructor().newInstance().asInstanceOf[UserDefinedType[Any]]
+ val udtClass = udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()
+ UDTEncoder(udt, udtClass)
+
+ case c: Class[_] if UDTRegistration.exists(c.getName) =>
+ val udt = UDTRegistration.getUDTFor(c.getName).get.getConstructor().
+ newInstance().asInstanceOf[UserDefinedType[Any]]
+ UDTEncoder(udt, udt.getClass)
+
+ case c: Class[_] if c.isArray =>
+ val elementEncoder = encoderFor(c.getComponentType, seenTypeSet)
+ ArrayEncoder(elementEncoder, elementEncoder.nullable)
+
+ case ImplementsList(c, Array(elementCls)) =>
+ val element = encoderFor(elementCls, seenTypeSet)
+ IterableEncoder(ClassTag(c), element, element.nullable, lenientSerialization = false)
+
+ case ImplementsMap(c, Array(keyCls, valueCls)) =>
+ val keyEncoder = encoderFor(keyCls, seenTypeSet)
+ val valueEncoder = encoderFor(valueCls, seenTypeSet)
+ MapEncoder(ClassTag(c), keyEncoder, valueEncoder, valueEncoder.nullable)
+
+ case c: Class[_] =>
+ if (seenTypeSet.contains(c)) {
+ throw QueryExecutionErrors.cannotHaveCircularReferencesInBeanClassError(c)
+ }
- case other =>
- if (seenTypeSet.contains(other)) {
- throw QueryExecutionErrors.cannotHaveCircularReferencesInBeanClassError(other)
- }
+ // TODO: we should only collect properties that have both a getter and a setter. However,
+ // some tests pass in a Scala case class as the Java bean class, which has neither.
+ val properties = getJavaBeanReadableProperties(c)
+ // Note that the fields are ordered by name.
+ val fields = properties.map { property =>
+ val readMethod = property.getReadMethod
+ val encoder = encoderFor(readMethod.getGenericReturnType, seenTypeSet + c)
+ // The presence of `javax.annotation.Nonnull` means this field is not nullable.
+ val hasNonNull = readMethod.isAnnotationPresent(classOf[Nonnull])
+ EncoderField(
+ property.getName,
+ encoder,
+ encoder.nullable && !hasNonNull,
+ Metadata.empty,
+ Option(readMethod.getName),
+ Option(property.getWriteMethod).map(_.getName))
+ }
+ JavaBeanEncoder(ClassTag(c), fields)
- // TODO: we should only collect properties that have getter and setter. However, some tests
- // pass in scala case class as java bean class which doesn't have getter and setter.
- val properties = getJavaBeanReadableProperties(other)
- val fields = properties.map { property =>
- val returnType = typeToken.method(property.getReadMethod).getReturnType
- val (dataType, nullable) = inferDataType(returnType, seenTypeSet + other)
- // The existence of `javax.annotation.Nonnull`, means this field is not nullable.
- val hasNonNull = property.getReadMethod.isAnnotationPresent(classOf[Nonnull])
- new StructField(property.getName, dataType, nullable && !hasNonNull)
- }
- (new StructType(fields), true)
- }
+ case _ =>
+ throw QueryExecutionErrors.cannotFindEncoderForTypeError(t.toString)
}
def getJavaBeanReadableProperties(beanClass: Class[_]): Array[PropertyDescriptor] = {
@@ -166,317 +148,58 @@ object JavaTypeInference {
.filter(_.getReadMethod != null)
}
- private def getJavaBeanReadableAndWritableProperties(
- beanClass: Class[_]): Array[PropertyDescriptor] = {
- getJavaBeanReadableProperties(beanClass).filter(_.getWriteMethod != null)
- }
-
- private def elementType(typeToken: TypeToken[_]): TypeToken[_] = {
- val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JIterable[_]]]
- val iterableSuperType = typeToken2.getSupertype(classOf[JIterable[_]])
- val iteratorType = iterableSuperType.resolveType(iteratorReturnType)
- iteratorType.resolveType(nextReturnType)
- }
-
- private def mapKeyValueType(typeToken: TypeToken[_]): (TypeToken[_], TypeToken[_]) = {
- val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]]
- val mapSuperType = typeToken2.getSupertype(classOf[JMap[_, _]])
- val keyType = elementType(mapSuperType.resolveType(keySetReturnType))
- val valueType = elementType(mapSuperType.resolveType(valuesReturnType))
- keyType -> valueType
- }
-
- /**
- * Returns the Spark SQL DataType for a given java class. Where this is not an exact mapping
- * to a native type, an ObjectType is returned.
- *
- * Unlike `inferDataType`, this function doesn't do any massaging of types into the Spark SQL type
- * system. As a result, ObjectType will be returned for things like boxed Integers.
- */
- private def inferExternalType(cls: Class[_]): DataType = cls match {
- case c if c == java.lang.Boolean.TYPE => BooleanType
- case c if c == java.lang.Byte.TYPE => ByteType
- case c if c == java.lang.Short.TYPE => ShortType
- case c if c == java.lang.Integer.TYPE => IntegerType
- case c if c == java.lang.Long.TYPE => LongType
- case c if c == java.lang.Float.TYPE => FloatType
- case c if c == java.lang.Double.TYPE => DoubleType
- case c if c == classOf[Array[Byte]] => BinaryType
- case _ => ObjectType(cls)
- }
-
- /**
- * Returns an expression that can be used to deserialize a Spark SQL representation to an object
- * of java bean `T` with a compatible schema. The Spark SQL representation is located at ordinal
- * 0 of a row, i.e., `GetColumnByOrdinal(0, _)`. Nested classes will have their fields accessed
- * using `UnresolvedExtractValue`.
- */
- def deserializerFor(beanClass: Class[_]): Expression = {
- val typeToken = TypeToken.of(beanClass)
- val walkedTypePath = new WalkedTypePath().recordRoot(beanClass.getCanonicalName)
- val (dataType, nullable) = inferDataType(typeToken)
-
- // Assumes we are deserializing the first column of a row.
- deserializerForWithNullSafetyAndUpcast(GetColumnByOrdinal(0, dataType), dataType,
- nullable = nullable, walkedTypePath, deserializerFor(typeToken, _, walkedTypePath))
- }
-
- private def deserializerFor(
- typeToken: TypeToken[_],
- path: Expression,
- walkedTypePath: WalkedTypePath): Expression = {
- typeToken.getRawType match {
- case c if !inferExternalType(c).isInstanceOf[ObjectType] => path
-
- case c if c == classOf[java.lang.Short] ||
- c == classOf[java.lang.Integer] ||
- c == classOf[java.lang.Long] ||
- c == classOf[java.lang.Double] ||
- c == classOf[java.lang.Float] ||
- c == classOf[java.lang.Byte] ||
- c == classOf[java.lang.Boolean] =>
- createDeserializerForTypesSupportValueOf(path, c)
-
- case c if c == classOf[java.time.LocalDate] =>
- createDeserializerForLocalDate(path)
-
- case c if c == classOf[java.sql.Date] =>
- createDeserializerForSqlDate(path)
-
- case c if c == classOf[java.time.Instant] =>
- createDeserializerForInstant(path)
-
- case c if c == classOf[java.sql.Timestamp] =>
- createDeserializerForSqlTimestamp(path)
+ private class ImplementsGenericInterface(interface: Class[_]) {
+ assert(interface.isInterface)
+ assert(interface.getTypeParameters.nonEmpty)
- case c if c == classOf[java.time.LocalDateTime] =>
- createDeserializerForLocalDateTime(path)
-
- case c if c == classOf[java.time.Duration] =>
- createDeserializerForDuration(path)
-
- case c if c == classOf[java.time.Period] =>
- createDeserializerForPeriod(path)
-
- case c if c == classOf[java.lang.String] =>
- createDeserializerForString(path, returnNullable = true)
-
- case c if c == classOf[java.math.BigDecimal] =>
- createDeserializerForJavaBigDecimal(path, returnNullable = true)
-
- case c if c == classOf[java.math.BigInteger] =>
- createDeserializerForJavaBigInteger(path, returnNullable = true)
-
- case c if c.isArray =>
- val elementType = c.getComponentType
- val newTypePath = walkedTypePath.recordArray(elementType.getCanonicalName)
- val (dataType, elementNullable) = inferDataType(elementType)
- val mapFunction: Expression => Expression = element => {
- // upcast the array element to the data type the encoder expected.
- deserializerForWithNullSafetyAndUpcast(
- element,
- dataType,
- nullable = elementNullable,
- newTypePath,
- deserializerFor(typeToken.getComponentType, _, newTypePath))
- }
-
- val arrayData = UnresolvedMapObjects(mapFunction, path)
-
- val methodName = elementType match {
- case c if c == java.lang.Integer.TYPE => "toIntArray"
- case c if c == java.lang.Long.TYPE => "toLongArray"
- case c if c == java.lang.Double.TYPE => "toDoubleArray"
- case c if c == java.lang.Float.TYPE => "toFloatArray"
- case c if c == java.lang.Short.TYPE => "toShortArray"
- case c if c == java.lang.Byte.TYPE => "toByteArray"
- case c if c == java.lang.Boolean.TYPE => "toBooleanArray"
- // non-primitive
- case _ => "array"
- }
- Invoke(arrayData, methodName, ObjectType(c))
-
- case c if ttIsAssignableFrom(listType, typeToken) =>
- val et = elementType(typeToken)
- val newTypePath = walkedTypePath.recordArray(et.getType.getTypeName)
- val (dataType, elementNullable) = inferDataType(et)
- val mapFunction: Expression => Expression = element => {
- // upcast the array element to the data type the encoder expected.
- deserializerForWithNullSafetyAndUpcast(
- element,
- dataType,
- nullable = elementNullable,
- newTypePath,
- deserializerFor(et, _, newTypePath))
- }
-
- UnresolvedMapObjects(mapFunction, path, customCollectionCls = Some(c))
-
- case _ if ttIsAssignableFrom(mapType, typeToken) =>
- val (keyType, valueType) = mapKeyValueType(typeToken)
- val newTypePath = walkedTypePath.recordMap(keyType.getType.getTypeName,
- valueType.getType.getTypeName)
-
- val keyData =
- Invoke(
- UnresolvedMapObjects(
- p => deserializerFor(keyType, p, newTypePath),
- MapKeys(path)),
- "array",
- ObjectType(classOf[Array[Any]]))
-
- val valueData =
- Invoke(
- UnresolvedMapObjects(
- p => deserializerFor(valueType, p, newTypePath),
- MapValues(path)),
- "array",
- ObjectType(classOf[Array[Any]]))
-
- StaticInvoke(
- ArrayBasedMapData.getClass,
- ObjectType(classOf[JMap[_, _]]),
- "toJavaMap",
- keyData :: valueData :: Nil,
- returnNullable = false)
-
- case other if other.isEnum =>
- createDeserializerForTypesSupportValueOf(
- createDeserializerForString(path, returnNullable = false),
- other)
-
- case other =>
- val properties = getJavaBeanReadableAndWritableProperties(other)
- val setters = properties.map { p =>
- val fieldName = p.getName
- val fieldType = typeToken.method(p.getReadMethod).getReturnType
- val (dataType, nullable) = inferDataType(fieldType)
- val newTypePath = walkedTypePath.recordField(fieldType.getType.getTypeName, fieldName)
- // The existence of `javax.annotation.Nonnull`, means this field is not nullable.
- val hasNonNull = p.getReadMethod.isAnnotationPresent(classOf[Nonnull])
- val setter = expressionWithNullSafety(
- deserializerFor(fieldType, addToPath(path, fieldName, dataType, newTypePath),
- newTypePath),
- nullable = nullable && !hasNonNull,
- newTypePath)
- p.getWriteMethod.getName -> setter
- }.toMap
-
- val newInstance = NewInstance(other, Nil, ObjectType(other), propagateNull = false)
- val result = InitializeJavaBean(newInstance, setters)
-
- expressions.If(
- IsNull(path),
- expressions.Literal.create(null, ObjectType(other)),
- result
- )
+ def unapply(t: Type): Option[(Class[_], Array[Type])] = implementsInterface(t).map { cls =>
+ cls -> findTypeArgumentsForInterface(t)
}
- }
- /**
- * Returns an expression for serializing an object of the given type to a Spark SQL
- * representation. The input object is located at ordinal 0 of a row, i.e.,
- * `BoundReference(0, _)`.
- */
- def serializerFor(beanClass: Class[_]): Expression = {
- val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true)
- val nullSafeInput = AssertNotNull(inputObject, Seq("top level input bean"))
- serializerFor(nullSafeInput, TypeToken.of(beanClass))
- }
-
- private def serializerFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = {
-
- def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = {
- val (dataType, nullable) = inferDataType(elementType)
- if (ScalaReflection.isNativeType(dataType)) {
- val cls = input.dataType.asInstanceOf[ObjectType].cls
- if (cls.isArray && cls.getComponentType.isPrimitive) {
- createSerializerForPrimitiveArray(input, dataType)
- } else {
- createSerializerForGenericArray(input, dataType, nullable = nullable)
- }
- } else {
- createSerializerForMapObjects(input, ObjectType(elementType.getRawType),
- serializerFor(_, elementType))
- }
+ @tailrec
+ private def implementsInterface(t: Type): Option[Class[_]] = t match {
+ case pt: ParameterizedType => implementsInterface(pt.getRawType)
+ case c: Class[_] if interface.isAssignableFrom(c) => Option(c)
+ case _ => None
}
- if (!inputObject.dataType.isInstanceOf[ObjectType]) {
- inputObject
- } else {
- typeToken.getRawType match {
- case c if c == classOf[String] => createSerializerForString(inputObject)
-
- case c if c == classOf[java.time.Instant] => createSerializerForJavaInstant(inputObject)
-
- case c if c == classOf[java.sql.Timestamp] => createSerializerForSqlTimestamp(inputObject)
-
- case c if c == classOf[java.time.LocalDateTime] =>
- createSerializerForLocalDateTime(inputObject)
-
- case c if c == classOf[java.time.LocalDate] => createSerializerForJavaLocalDate(inputObject)
-
- case c if c == classOf[java.sql.Date] => createSerializerForSqlDate(inputObject)
-
- case c if c == classOf[java.time.Duration] => createSerializerForJavaDuration(inputObject)
-
- case c if c == classOf[java.time.Period] => createSerializerForJavaPeriod(inputObject)
-
- case c if c == classOf[java.math.BigInteger] =>
- createSerializerForBigInteger(inputObject)
-
- case c if c == classOf[java.math.BigDecimal] =>
- createSerializerForBigDecimal(inputObject)
-
- case c if c == classOf[java.lang.Boolean] => createSerializerForBoolean(inputObject)
- case c if c == classOf[java.lang.Byte] => createSerializerForByte(inputObject)
- case c if c == classOf[java.lang.Short] => createSerializerForShort(inputObject)
- case c if c == classOf[java.lang.Integer] => createSerializerForInteger(inputObject)
- case c if c == classOf[java.lang.Long] => createSerializerForLong(inputObject)
- case c if c == classOf[java.lang.Float] => createSerializerForFloat(inputObject)
- case c if c == classOf[java.lang.Double] => createSerializerForDouble(inputObject)
-
- case _ if typeToken.isArray =>
- toCatalystArray(inputObject, typeToken.getComponentType)
-
- case _ if ttIsAssignableFrom(listType, typeToken) =>
- toCatalystArray(inputObject, elementType(typeToken))
-
- case _ if ttIsAssignableFrom(mapType, typeToken) =>
- val (keyType, valueType) = mapKeyValueType(typeToken)
-
- createSerializerForMap(
- inputObject,
- MapElementInformation(
- ObjectType(keyType.getRawType),
- nullable = true,
- serializerFor(_, keyType)),
- MapElementInformation(
- ObjectType(valueType.getRawType),
- nullable = true,
- serializerFor(_, valueType))
- )
-
- case other if other.isEnum =>
- createSerializerForString(
- Invoke(inputObject, "name", ObjectType(classOf[String]), returnNullable = false))
-
- case other =>
- val properties = getJavaBeanReadableAndWritableProperties(other)
- val fields = properties.map { p =>
- val fieldName = p.getName
- val fieldType = typeToken.method(p.getReadMethod).getReturnType
- val hasNonNull = p.getReadMethod.isAnnotationPresent(classOf[Nonnull])
- val fieldValue = Invoke(
- inputObject,
- p.getReadMethod.getName,
- inferExternalType(fieldType.getRawType),
- propagateNull = !hasNonNull,
- returnNullable = !hasNonNull)
- (fieldName, serializerFor(fieldValue, fieldType))
- }
- createSerializerForObject(inputObject, fields)
+ private def findTypeArgumentsForInterface(t: Type): Array[Type] = {
+ val queue = new ArrayDeque[(Type, Map[Any, Type])]
+ queue.add(t -> Map.empty)
+ while (!queue.isEmpty) {
+ queue.poll() match {
+ case (pt: ParameterizedType, bindings) =>
+ // translate mappings...
+ val mappedTypeArguments = pt.getActualTypeArguments.map {
+ case v: TypeVariable[_] => bindings(v.getName)
+ case v => v
+ }
+ if (pt.getRawType == interface) {
+ return mappedTypeArguments
+ } else {
+ val mappedTypeArgumentMap = mappedTypeArguments
+ .zipWithIndex.map(_.swap)
+ .toMap[Any, Type]
+ queue.add(pt.getRawType -> mappedTypeArgumentMap)
+ }
+ case (c: Class[_], indexedBindings) =>
+ val namedBindings = c.getTypeParameters.zipWithIndex.map {
+ case (parameter, index) =>
+ parameter.getName -> indexedBindings(index)
+ }.toMap[Any, Type]
+ val superClass = c.getGenericSuperclass
+ if (superClass != null) {
+ queue.add(superClass -> namedBindings)
+ }
+ c.getGenericInterfaces.foreach { iface =>
+ queue.add(iface -> namedBindings)
+ }
+ }
}
+ throw QueryExecutionErrors.unreachableError()
}
}
+
+ private object ImplementsList extends ImplementsGenericInterface(classOf[JList[_]])
+ private object ImplementsMap extends ImplementsGenericInterface(classOf[JMap[_, _]])
}
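The ImplementsList/ImplementsMap extractors above recover an interface's actual type arguments by walking the generic supertype graph with a queue of type-variable bindings. A standalone sketch of the raw-reflection input they start from (the `Holder` bean is illustrative, not part of the patch):

    import java.lang.reflect.ParameterizedType
    import java.util.LinkedList
    import scala.beans.BeanProperty

    // Illustrative bean whose getter has a generic return type.
    class Holder {
      @BeanProperty var names: LinkedList[String] = _
    }

    object TypeArgumentsExample extends App {
      classOf[Holder].getMethod("getNames").getGenericReturnType match {
        case pt: ParameterizedType =>
          println(pt.getRawType)                             // class java.util.LinkedList
          println(pt.getActualTypeArguments.mkString(", "))  // class java.lang.String
          // String is bound to LinkedList's type parameter E; the queue-based
          // walk in findTypeArgumentsForInterface carries that binding through
          // the generic superclass and interfaces until it reaches the
          // java.util.List[E] declaration, yielding String as the element type.
        case other =>
          println(s"not a parameterized type: $other")
      }
    }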
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 42208cd1098..4680a2aec2b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, MapData}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -264,6 +264,36 @@ object ScalaReflection extends ScalaReflection {
Option(clsTag.runtimeClass),
walkedTypePath)
+ case MapEncoder(tag, keyEncoder, valueEncoder, _)
+ if classOf[java.util.Map[_, _]].isAssignableFrom(tag.runtimeClass) =>
+ // TODO (hvanhovell): this can be improved.
+ val newTypePath = walkedTypePath.recordMap(
+ keyEncoder.clsTag.runtimeClass.getName,
+ valueEncoder.clsTag.runtimeClass.getName)
+
+ val keyData =
+ Invoke(
+ UnresolvedMapObjects(
+ p => deserializerFor(keyEncoder, p, newTypePath),
+ MapKeys(path)),
+ "array",
+ ObjectType(classOf[Array[Any]]))
+
+ val valueData =
+ Invoke(
+ UnresolvedMapObjects(
+ p => deserializerFor(valueEncoder, p, newTypePath),
+ MapValues(path)),
+ "array",
+ ObjectType(classOf[Array[Any]]))
+
+ StaticInvoke(
+ ArrayBasedMapData.getClass,
+ ObjectType(classOf[java.util.Map[_, _]]),
+ "toJavaMap",
+ keyData :: valueData :: Nil,
+ returnNullable = false)
+
case MapEncoder(tag, keyEncoder, valueEncoder, _) =>
val newTypePath = walkedTypePath.recordMap(
keyEncoder.clsTag.runtimeClass.getName,
@@ -312,6 +342,26 @@ object ScalaReflection extends ScalaReflection {
exprs.If(IsNull(path),
exprs.Literal.create(null, externalDataTypeFor(enc)),
CreateExternalRow(convertedFields, enc.schema))
+
+ case JavaBeanEncoder(tag, fields) =>
+ val setters = fields.map { f =>
+ val newTypePath = walkedTypePath.recordField(
+ f.enc.clsTag.runtimeClass.getName,
+ f.name)
+ val setter = expressionWithNullSafety(
+ deserializerFor(
+ f.enc,
+ addToPath(path, f.name, f.enc.dataType, newTypePath),
+ newTypePath),
+ nullable = f.nullable,
+ newTypePath)
+ f.writeMethod.get -> setter
+ }
+
+ val cls = tag.runtimeClass
+ val newInstance = NewInstance(cls, Nil, ObjectType(cls), propagateNull = false)
+ val result = InitializeJavaBean(newInstance, setters.toMap)
+ exprs.If(IsNull(path), exprs.Literal.create(null, ObjectType(cls)), result)
}
private def deserializeArray(
@@ -446,6 +496,18 @@ object ScalaReflection extends ScalaReflection {
field.name -> convertedField
}
createSerializerForObject(input, serializedFields)
+
+ case JavaBeanEncoder(_, fields) =>
+ val serializedFields = fields.map { f =>
+ val fieldValue = Invoke(
+ KnownNotNull(input),
+ f.readMethod.get,
+ externalDataTypeFor(f.enc),
+ propagateNull = f.nullable,
+ returnNullable = f.nullable)
+ f.name -> serializerFor(f.enc, fieldValue)
+ }
+ createSerializerForObject(input, serializedFields)
}
private def serializerForArray(
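Together, these two cases let a Java bean round-trip through Catalyst's InternalRow format. A rough sketch, assuming ExpressionEncoder's existing `createSerializer`/`resolveAndBind`/`createDeserializer` API (the bean mirrors `DummyBean` from the test suite):

    import scala.beans.BeanProperty
    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

    class DummyBean {
      @BeanProperty var bigInteger: java.math.BigInteger = _
    }

    object BeanRoundTripExample extends App {
      val enc = ExpressionEncoder.javaBean(classOf[DummyBean])

      val bean = new DummyBean
      bean.setBigInteger(java.math.BigInteger.valueOf(42))

      // The expressions built by serializerFor(JavaBeanEncoder(...)) drive this...
      val row = enc.createSerializer()(bean)
      // ...and those built by deserializerFor(JavaBeanEncoder(...)) drive the way back.
      val back = enc.resolveAndBind().createDeserializer()(row)
      assert(back.getBigInteger == bean.getBigInteger)
    }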
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
index cdc64f2ddb5..1a3c1089649 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
@@ -91,7 +91,9 @@ object AgnosticEncoders {
name: String,
enc: AgnosticEncoder[_],
nullable: Boolean,
- metadata: Metadata) {
+ metadata: Metadata,
+ readMethod: Option[String] = None,
+ writeMethod: Option[String] = None) {
def structField: StructField = StructField(name, enc.dataType, nullable, metadata)
}
@@ -112,6 +114,15 @@ object AgnosticEncoders {
override def clsTag: ClassTag[Row] = classTag[Row]
}
+ case class JavaBeanEncoder[K](
+ override val clsTag: ClassTag[K],
+ fields: Seq[EncoderField])
+ extends AgnosticEncoder[K] {
+ override def isPrimitive: Boolean = false
+ override val schema: StructType = StructType(fields.map(_.structField))
+ override def dataType: DataType = schema
+ }
+
// This will only work for encoding from/to Spark's InternalRow format.
// It is here for compatibility.
case class UDTEncoder[E >: Null](
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 9ca2fc72ad9..faa165c298d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -59,16 +59,7 @@ object ExpressionEncoder {
// TODO: improve error message for java bean encoder.
def javaBean[T](beanClass: Class[T]): ExpressionEncoder[T] = {
- val schema = JavaTypeInference.inferDataType(beanClass)._1
- assert(schema.isInstanceOf[StructType])
-
- val objSerializer = JavaTypeInference.serializerFor(beanClass)
- val objDeserializer = JavaTypeInference.deserializerFor(beanClass)
-
- new ExpressionEncoder[T](
- objSerializer,
- objDeserializer,
- ClassTag[T](beanClass))
+ apply(JavaTypeInference.encoderFor(beanClass))
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 299a928f267..929beb660ad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -1927,7 +1927,8 @@ case class ValidateExternalType(child: Expression, expected: DataType, externalD
(value: Any) => {
value.getClass.isArray ||
value.isInstanceOf[scala.collection.Seq[_]] ||
- value.isInstanceOf[Set[_]]
+ value.isInstanceOf[Set[_]] ||
+ value.isInstanceOf[java.util.List[_]]
}
case _: DateType =>
(value: Any) => {
@@ -1968,7 +1969,10 @@ case class ValidateExternalType(child: Expression, expected: DataType, externalD
classOf[scala.math.BigDecimal],
classOf[Decimal]))
case _: ArrayType =>
- val check = genCheckTypes(Seq(classOf[scala.collection.Seq[_]], classOf[Set[_]]))
+ val check = genCheckTypes(Seq(
+ classOf[scala.collection.Seq[_]],
+ classOf[Set[_]],
+ classOf[java.util.List[_]]))
s"$obj.getClass().isArray() || $check"
case _: DateType =>
genCheckTypes(Seq(classOf[java.sql.Date], classOf[java.time.LocalDate]))
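This check matters for bean encoders because a java.util.List field maps to ArrayType, so its external values reach ValidateExternalType as java.util.List instances. A sketch using the public `Encoders.bean` API (the bean is illustrative):

    import java.util.{List => JList}
    import scala.beans.BeanProperty
    import org.apache.spark.sql.Encoders

    // Illustrative bean: the list field is inferred as array<string>.
    class TagsBean {
      @BeanProperty var tags: JList[String] = _
    }

    object BeanListSchemaExample extends App {
      val enc = Encoders.bean(classOf[TagsBean])
      println(enc.schema.simpleString)  // struct<tags:array<string>>
    }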
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala
index 9c1d0c17777..35f5bf739bf 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala
@@ -18,25 +18,206 @@
package org.apache.spark.sql.catalyst
import java.math.BigInteger
+import java.util.{LinkedList, List => JList, Map => JMap}
-import scala.beans.BeanProperty
+import scala.beans.{BeanProperty, BooleanBeanProperty}
+import scala.reflect.{classTag, ClassTag}
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.{CheckOverflow, Expression, Literal}
-import org.apache.spark.sql.types.DecimalType
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, UDTCaseClass, UDTForCaseClass}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
+import org.apache.spark.sql.types.{DecimalType, MapType, Metadata, StringType, StructField, StructType}
-class DummyBean() {
- @BeanProperty var bigInteger = null: BigInteger
+class DummyBean {
+ @BeanProperty var bigInteger: BigInteger = _
}
+class GenericCollectionBean {
+ @BeanProperty var listOfListOfStrings: JList[JList[String]] = _
+ @BeanProperty var mapOfDummyBeans: JMap[String, DummyBean] = _
+ @BeanProperty var linkedListOfStrings: LinkedList[String] = _
+}
+
+class LeafBean {
+ @BooleanBeanProperty var primitiveBoolean: Boolean = false
+ @BeanProperty var primitiveByte: Byte = 0
+ @BeanProperty var primitiveShort: Short = 0
+ @BeanProperty var primitiveInt: Int = 0
+ @BeanProperty var primitiveLong: Long = 0
+ @BeanProperty var primitiveFloat: Float = 0
+ @BeanProperty var primitiveDouble: Double = 0
+ @BeanProperty var boxedBoolean: java.lang.Boolean = false
+ @BeanProperty var boxedByte: java.lang.Byte = 0.toByte
+ @BeanProperty var boxedShort: java.lang.Short = 0.toShort
+ @BeanProperty var boxedInt: java.lang.Integer = 0
+ @BeanProperty var boxedLong: java.lang.Long = 0
+ @BeanProperty var boxedFloat: java.lang.Float = 0
+ @BeanProperty var boxedDouble: java.lang.Double = 0
+ @BeanProperty var string: String = _
+ @BeanProperty var binary: Array[Byte] = _
+ @BeanProperty var bigDecimal: java.math.BigDecimal = _
+ @BeanProperty var bigInteger: java.math.BigInteger = _
+ @BeanProperty var localDate: java.time.LocalDate = _
+ @BeanProperty var date: java.sql.Date = _
+ @BeanProperty var instant: java.time.Instant = _
+ @BeanProperty var timestamp: java.sql.Timestamp = _
+ @BeanProperty var localDateTime: java.time.LocalDateTime = _
+ @BeanProperty var duration: java.time.Duration = _
+ @BeanProperty var period: java.time.Period = _
+ @BeanProperty var enum: java.time.Month = _
+ @BeanProperty val readOnlyString = "read-only"
+
+ var nonNullString: String = "value"
+ @javax.annotation.Nonnull
+ def getNonNullString: String = nonNullString
+ def setNonNullString(v: String): Unit = nonNullString = {
+ java.util.Objects.requireNonNull(v)
+ v
+ }
+}
+
+class ArrayBean {
+ @BeanProperty var dummyBeanArray: Array[DummyBean] = _
+ @BeanProperty var primitiveIntArray: Array[Int] = _
+ @BeanProperty var stringArray: Array[String] = _
+}
+
+class UDTBean {
+ @BeanProperty var udt: UDTCaseClass = _
+}
+
+/**
+ * Test suite for Encoders produced by [[JavaTypeInference]].
+ */
class JavaTypeInferenceSuite extends SparkFunSuite {
+ private def encoderField(
+ name: String,
+ encoder: AgnosticEncoder[_],
+ overrideNullable: Option[Boolean] = None,
+ readOnly: Boolean = false): EncoderField = {
+ val readPrefix = if (encoder == PrimitiveBooleanEncoder) "is" else "get"
+ EncoderField(
+ name,
+ encoder,
+ overrideNullable.getOrElse(encoder.nullable),
+ Metadata.empty,
+ Option(readPrefix + name.capitalize),
+ Option("set" + name.capitalize).filterNot(_ => readOnly))
+ }
+
+ private val expectedDummyBeanEncoder =
+ JavaBeanEncoder[DummyBean](
+ ClassTag(classOf[DummyBean]),
+ Seq(encoderField("bigInteger", JavaBigIntEncoder)))
+
+ private val expectedDummyBeanSchema =
+ StructType(StructField("bigInteger", DecimalType(38, 0)) :: Nil)
+
test("SPARK-41007: JavaTypeInference returns the correct serializer for BigInteger") {
- var serializer = JavaTypeInference.serializerFor(classOf[DummyBean])
- var bigIntegerFieldName: Expression = serializer.children(0)
- assert(bigIntegerFieldName.asInstanceOf[Literal].value.toString == "bigInteger")
- var bigIntegerFieldExpression: Expression = serializer.children(1)
- assert(bigIntegerFieldExpression.asInstanceOf[CheckOverflow].dataType ==
- DecimalType.BigIntDecimal)
+ val encoder = JavaTypeInference.encoderFor(classOf[DummyBean])
+ assert(encoder === expectedDummyBeanEncoder)
+ assert(encoder.schema === expectedDummyBeanSchema)
+ }
+
+ test("resolve schema for class") {
+ val (schema, nullable) = JavaTypeInference.inferDataType(classOf[DummyBean])
+ assert(nullable)
+ assert(schema === expectedDummyBeanSchema)
+ }
+
+ test("resolve schema for type") {
+ val getter = classOf[GenericCollectionBean].getDeclaredMethods
+ .find(_.getName == "getMapOfDummyBeans")
+ .get
+ val (schema, nullable) = JavaTypeInference.inferDataType(getter.getGenericReturnType)
+ val expected = MapType(StringType, expectedDummyBeanSchema, valueContainsNull = true)
+ assert(nullable)
+ assert(schema === expected)
+ }
+
+ test("resolve type parameters for map and list") {
+ val encoder = JavaTypeInference.encoderFor(classOf[GenericCollectionBean])
+ val expected = JavaBeanEncoder(ClassTag(classOf[GenericCollectionBean]), Seq(
+ encoderField(
+ "linkedListOfStrings",
+ IterableEncoder(
+ ClassTag(classOf[LinkedList[_]]),
+ StringEncoder,
+ containsNull = true,
+ lenientSerialization = false)),
+ encoderField(
+ "listOfListOfStrings",
+ IterableEncoder(
+ ClassTag(classOf[JList[_]]),
+ IterableEncoder(
+ ClassTag(classOf[JList[_]]),
+ StringEncoder,
+ containsNull = true,
+ lenientSerialization = false),
+ containsNull = true,
+ lenientSerialization = false)),
+ encoderField(
+ "mapOfDummyBeans",
+ MapEncoder(
+ ClassTag(classOf[JMap[_, _]]),
+ StringEncoder,
+ expectedDummyBeanEncoder,
+ valueContainsNull = true))))
+ assert(encoder === expected)
+ }
+
+ test("resolve leaf encoders") {
+ val encoder = JavaTypeInference.encoderFor(classOf[LeafBean])
+ val expected = JavaBeanEncoder(ClassTag(classOf[LeafBean]), Seq(
+ // The order is different from the definition because fields are ordered by name.
+ encoderField("bigDecimal", DEFAULT_JAVA_DECIMAL_ENCODER),
+ encoderField("bigInteger", JavaBigIntEncoder),
+ encoderField("binary", BinaryEncoder),
+ encoderField("boxedBoolean", BoxedBooleanEncoder),
+ encoderField("boxedByte", BoxedByteEncoder),
+ encoderField("boxedDouble", BoxedDoubleEncoder),
+ encoderField("boxedFloat", BoxedFloatEncoder),
+ encoderField("boxedInt", BoxedIntEncoder),
+ encoderField("boxedLong", BoxedLongEncoder),
+ encoderField("boxedShort", BoxedShortEncoder),
+ encoderField("date", STRICT_DATE_ENCODER),
+ encoderField("duration", DayTimeIntervalEncoder),
+ encoderField("enum", JavaEnumEncoder(classTag[java.time.Month])),
+ encoderField("instant", STRICT_INSTANT_ENCODER),
+ encoderField("localDate", STRICT_LOCAL_DATE_ENCODER),
+ encoderField("localDateTime", LocalDateTimeEncoder),
+ encoderField("nonNullString", StringEncoder, overrideNullable = Option(false)),
+ encoderField("period", YearMonthIntervalEncoder),
+ encoderField("primitiveBoolean", PrimitiveBooleanEncoder),
+ encoderField("primitiveByte", PrimitiveByteEncoder),
+ encoderField("primitiveDouble", PrimitiveDoubleEncoder),
+ encoderField("primitiveFloat", PrimitiveFloatEncoder),
+ encoderField("primitiveInt", PrimitiveIntEncoder),
+ encoderField("primitiveLong", PrimitiveLongEncoder),
+ encoderField("primitiveShort", PrimitiveShortEncoder),
+ encoderField("readOnlyString", StringEncoder, readOnly = true),
+ encoderField("string", StringEncoder),
+ encoderField("timestamp", STRICT_TIMESTAMP_ENCODER)
+ ))
+ assert(encoder === expected)
+ }
+
+ test("resolve array encoders") {
+ val encoder = JavaTypeInference.encoderFor(classOf[ArrayBean])
+ val expected = JavaBeanEncoder(ClassTag(classOf[ArrayBean]), Seq(
+ encoderField("dummyBeanArray", ArrayEncoder(expectedDummyBeanEncoder, containsNull = true)),
+ encoderField("primitiveIntArray", ArrayEncoder(PrimitiveIntEncoder, containsNull = false)),
+ encoderField("stringArray", ArrayEncoder(StringEncoder, containsNull = true))
+ ))
+ assert(encoder === expected)
+ }
+
+ test("resolve UDT encoders") {
+ val encoder = JavaTypeInference.encoderFor(classOf[UDTBean])
+ val expected = JavaBeanEncoder(ClassTag(classOf[UDTBean]), Seq(
+ encoderField("udt", UDTEncoder(new UDTForCaseClass, classOf[UDTForCaseClass]))
+ ))
+ assert(encoder === expected)
}
}