You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2022/12/23 11:24:29 UTC

[GitHub] [spark] cloud-fan commented on a diff in pull request #39186: [SPARK-41690][SQL][CONNECT] Agnostic Encoders

cloud-fan commented on code in PR #39186:
URL: https://github.com/apache/spark/pull/39186#discussion_r1056263221


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala:
##########
@@ -186,237 +167,129 @@ object ScalaReflection extends ScalaReflection {
    * @param walkedTypePath The paths from top to bottom to access current field when deserializing.
    */
   private def deserializerFor(
-      tpe: `Type`,
-      walkedTypePath: WalkedTypePath): Expression => Expression = cleanUpReflectionObjects {
-    baseType(tpe) match {
-      case t if !dataTypeFor(t).isInstanceOf[ObjectType] => identity
-
-      case t if isSubtype(t, localTypeOf[Option[_]]) =>
-        val TypeRef(_, _, Seq(optType)) = t
-        val className = getClassNameFromType(optType)
-        val newTypePath = walkedTypePath.recordOption(className)
-        val dataType = dataTypeFor(optType)
-        val deserializerFunc = deserializerFor(optType, newTypePath)
-        path => WrapOption(deserializerFunc(path), dataType)
-
-      case t if isSubtype(t, localTypeOf[java.lang.Integer]) =>
-        createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Integer])
-
-      case t if isSubtype(t, localTypeOf[java.lang.Long]) =>
-        createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Long])
-
-      case t if isSubtype(t, localTypeOf[java.lang.Double]) =>
-        createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Double])
-
-      case t if isSubtype(t, localTypeOf[java.lang.Float]) =>
-        createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Float])
-
-      case t if isSubtype(t, localTypeOf[java.lang.Short]) =>
-        createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Short])
-
-      case t if isSubtype(t, localTypeOf[java.lang.Byte]) =>
-        createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Byte])
-
-      case t if isSubtype(t, localTypeOf[java.lang.Boolean]) =>
-        createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Boolean])
-
-      case t if isSubtype(t, localTypeOf[java.time.LocalDate]) =>
-        createDeserializerForLocalDate
-
-      case t if isSubtype(t, localTypeOf[java.sql.Date]) =>
-        createDeserializerForSqlDate
-
-      case t if isSubtype(t, localTypeOf[java.time.Instant]) =>
-        createDeserializerForInstant
-
-      case t if isSubtype(t, localTypeOf[java.lang.Enum[_]]) =>
-        // Code touching Scala Reflection should be called outside the returned function to allow
-        // caching the Scala Reflection result
-        val cls = getClassFromType(t)
-        path => createDeserializerForTypesSupportValueOf(
-          Invoke(path, "toString", ObjectType(classOf[String]), returnNullable = false), cls)
-
-      case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) =>
-        createDeserializerForSqlTimestamp
-
-      case t if isSubtype(t, localTypeOf[java.time.LocalDateTime]) =>
-        createDeserializerForLocalDateTime
-
-      case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
-        createDeserializerForDuration
-
-      case t if isSubtype(t, localTypeOf[java.time.Period]) =>
-        createDeserializerForPeriod
-
-      case t if isSubtype(t, localTypeOf[java.lang.String]) =>
-        createDeserializerForString(_, returnNullable = false)
-
-      case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) =>
-        createDeserializerForJavaBigDecimal(_, returnNullable = false)
-
-      case t if isSubtype(t, localTypeOf[BigDecimal]) =>
-        createDeserializerForScalaBigDecimal(_, returnNullable = false)
-
-      case t if isSubtype(t, localTypeOf[java.math.BigInteger]) =>
-        createDeserializerForJavaBigInteger(_, returnNullable = false)
-
-      case t if isSubtype(t, localTypeOf[scala.math.BigInt]) =>
-        createDeserializerForScalaBigInt
-
-      case t if isSubtype(t, localTypeOf[Array[_]]) =>
-        val TypeRef(_, _, Seq(elementType)) = t
-        val Schema(dataType, elementNullable) = schemaFor(elementType)
-        val className = getClassNameFromType(elementType)
-        val newTypePath = walkedTypePath.recordArray(className)
-        val deserializerFunc = deserializerFor(elementType, newTypePath)
-        val mapFunction: Expression => Expression = element => {
-          // upcast the array element to the data type the encoder expected.
-          deserializerForWithNullSafetyAndUpcast(
-            element,
-            dataType,
-            nullable = elementNullable,
-            newTypePath,
-            deserializerFunc)
-        }
-
-        val arrayCls = arrayClassFor(elementType)
-
-        val methodName = elementType match {
-          case t if isSubtype(t, definitions.IntTpe) => "toIntArray"
-          case t if isSubtype(t, definitions.LongTpe) => "toLongArray"
-          case t if isSubtype(t, definitions.DoubleTpe) => "toDoubleArray"
-          case t if isSubtype(t, definitions.FloatTpe) => "toFloatArray"
-          case t if isSubtype(t, definitions.ShortTpe) => "toShortArray"
-          case t if isSubtype(t, definitions.ByteTpe) => "toByteArray"
-          case t if isSubtype(t, definitions.BooleanTpe) => "toBooleanArray"
-          // non-primitive
-          case _ => "array"
-        }
-        path => {
-          val arrayData = UnresolvedMapObjects(mapFunction, path)
-          Invoke(arrayData, methodName, arrayCls, returnNullable = false)
-        }
-
-      // We serialize a `Set` to Catalyst array. When we deserialize a Catalyst array
-      // to a `Set`, if there are duplicated elements, the elements will be de-duplicated.
-      case t if isSubtype(t, localTypeOf[scala.collection.Seq[_]]) ||
-          isSubtype(t, localTypeOf[scala.collection.Set[_]]) =>
-        val TypeRef(_, _, Seq(elementType)) = t
-        val Schema(dataType, elementNullable) = schemaFor(elementType)
-        val className = getClassNameFromType(elementType)
-        val newTypePath = walkedTypePath.recordArray(className)
-        val deserializerFunc = deserializerFor(elementType, newTypePath)
-        val mapFunction: Expression => Expression = element => {
-          deserializerForWithNullSafetyAndUpcast(
-            element,
-            dataType,
-            nullable = elementNullable,
-            newTypePath,
-            deserializerFunc)
-        }
-
-        val companion = t.dealias.typeSymbol.companion.typeSignature
-        val cls = companion.member(TermName("newBuilder")) match {
-          case NoSymbol if isSubtype(t, localTypeOf[Seq[_]]) => classOf[Seq[_]]
-          case NoSymbol if isSubtype(t, localTypeOf[scala.collection.Set[_]]) =>
-            classOf[scala.collection.Set[_]]
-          case _ => mirror.runtimeClass(t.typeSymbol.asClass)
-        }
-        UnresolvedMapObjects(mapFunction, _, Some(cls))
-
-      case t if isSubtype(t, localTypeOf[Map[_, _]]) =>
-        val TypeRef(_, _, Seq(keyType, valueType)) = t
-
-        val classNameForKey = getClassNameFromType(keyType)
-        val classNameForValue = getClassNameFromType(valueType)
-
-        val newTypePath = walkedTypePath.recordMap(classNameForKey, classNameForValue)
-
-        // Code touching Scala Reflection should be called outside the returned function to allow
-        // caching the Scala Reflection result
-        val keyDeserializerFunc = deserializerFor(keyType, newTypePath)
-        val valueDeserializerFunc = deserializerFor(valueType, newTypePath)
-        val cls = mirror.runtimeClass(t.typeSymbol.asClass)
-        UnresolvedCatalystToExternalMap(_, keyDeserializerFunc, valueDeserializerFunc, cls)
-
-      case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) =>
-        val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().
-          getConstructor().newInstance()
-        val obj = NewInstance(
-          udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
-          Nil,
-          dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
-        // Code touching Scala Reflection should be called outside the returned function to allow
-        // caching the Scala Reflection result
-        val cls = udt.userClass
-        path => Invoke(obj, "deserialize", ObjectType(cls), Seq(path))
-
-      case t if UDTRegistration.exists(getClassNameFromType(t)) =>
-        val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.getConstructor().
-          newInstance().asInstanceOf[UserDefinedType[_]]
-        val obj = NewInstance(
-          udt.getClass,
-          Nil,
-          dataType = ObjectType(udt.getClass))
-        // Code touching Scala Reflection should be called outside the returned function to allow
-        // caching the Scala Reflection result
-        val cls = udt.userClass
-        path => Invoke(obj, "deserialize", ObjectType(cls), Seq(path))
-
-      case t if definedByConstructorParams(t) =>
-        val params = getConstructorParameters(t)
-
-        val cls = getClassFromType(tpe)
-
-        val arguDeserializerFuncs = params.zipWithIndex.map { case ((fieldName, fieldType), i) =>
-          val Schema(dataType, nullable) = schemaFor(fieldType)
-          val clsName = getClassNameFromType(fieldType)
-          val newTypePath = walkedTypePath.recordField(clsName, fieldName)
-
-          // For tuples, we based grab the inner fields by ordinal instead of name.
-          val newPathFunc = if (cls.getName startsWith "scala.Tuple") {
-            addToPathOrdinal(_, i, dataType, newTypePath)
+      enc: AgnosticEncoder[_],
+      input: Expression,
+      typePath: WalkedTypePath): Expression = enc match {
+    case _ if isNativeEncoder(enc) =>
+      input
+    case BooleanEncoder =>

Review Comment:
   Aren't these cases already handled by `case _ if isNativeEncoder(enc) =>`?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org