You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by zs...@apache.org on 2022/11/08 16:19:07 UTC

[spark] branch master updated: [SPARK-41045][SQL] Pre-compute to eliminate ScalaReflection calls after deserializer is created

This is an automated email from the ASF dual-hosted git repository.

zsxwing pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ef402edff91 [SPARK-41045][SQL] Pre-compute to eliminate ScalaReflection calls after deserializer is created
ef402edff91 is described below

commit ef402edff91377d37c0c1b8d40921ed7bd9f7160
Author: Shixiong Zhu <zs...@gmail.com>
AuthorDate: Tue Nov 8 08:18:50 2022 -0800

    [SPARK-41045][SQL] Pre-compute to eliminate ScalaReflection calls after deserializer is created
    
    ### What changes were proposed in this pull request?
    
    Currently, when `ScalaReflection` returns a deserializer for certain complex types (such as array, map, UDT, etc.), it creates functions that may still touch `ScalaReflection` after the deserializer is created.
    
    `ScalaReflection` is a performance bottleneck for multiple threads because it holds multiple global locks. We can refactor `ScalaReflection.deserializerFor` to pre-compute everything that needs to touch `ScalaReflection` before creating the deserializer. After this, once the deserializer is created, it can be reused by multiple threads without touching `ScalaReflection.deserializerFor` anymore, and it will be much faster.
    
    ### Why are the changes needed?
    
    Optimize `ScalaReflection.deserializerFor` to make deserializers faster under multiple threads.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    This refactors `deserializerFor` to optimize the code. Existing tests should already cover correctness.
    
    Closes #38556 from zsxwing/scala-ref.
    
    Authored-by: Shixiong Zhu <zs...@gmail.com>
    Signed-off-by: Shixiong Zhu <zs...@gmail.com>
---
 .../sql/catalyst/DeserializerBuildHelper.scala     |   5 +-
 .../spark/sql/catalyst/JavaTypeInference.scala     |   8 +-
 .../spark/sql/catalyst/ScalaReflection.scala       | 157 +++++++++++----------
 3 files changed, 85 insertions(+), 85 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
index 0d3b9977e4f..7051c2d2264 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
@@ -49,10 +49,9 @@ object DeserializerBuildHelper {
       dataType: DataType,
       nullable: Boolean,
       walkedTypePath: WalkedTypePath,
-      funcForCreatingDeserializer: (Expression, WalkedTypePath) => Expression): Expression = {
+      funcForCreatingDeserializer: Expression => Expression): Expression = {
     val casted = upCastToExpectedType(expr, dataType, walkedTypePath)
-    expressionWithNullSafety(funcForCreatingDeserializer(casted, walkedTypePath),
-      nullable, walkedTypePath)
+    expressionWithNullSafety(funcForCreatingDeserializer(casted), nullable, walkedTypePath)
   }
 
   def expressionWithNullSafety(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index dccaf1c4835..827807055ce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -218,9 +218,7 @@ object JavaTypeInference {
 
     // Assumes we are deserializing the first column of a row.
     deserializerForWithNullSafetyAndUpcast(GetColumnByOrdinal(0, dataType), dataType,
-      nullable = nullable, walkedTypePath, (casted, walkedTypePath) => {
-        deserializerFor(typeToken, casted, walkedTypePath)
-      })
+      nullable = nullable, walkedTypePath, deserializerFor(typeToken, _, walkedTypePath))
   }
 
   private def deserializerFor(
@@ -280,7 +278,7 @@ object JavaTypeInference {
             dataType,
             nullable = elementNullable,
             newTypePath,
-            (casted, typePath) => deserializerFor(typeToken.getComponentType, casted, typePath))
+            deserializerFor(typeToken.getComponentType, _, newTypePath))
         }
 
         val arrayData = UnresolvedMapObjects(mapFunction, path)
@@ -309,7 +307,7 @@ object JavaTypeInference {
             dataType,
             nullable = elementNullable,
             newTypePath,
-            (casted, typePath) => deserializerFor(et, casted, typePath))
+            deserializerFor(et, _, newTypePath))
         }
 
         UnresolvedMapObjects(mapFunction, path, customCollectionCls = Some(c))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 12093b9f4b2..d895a0fbe19 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -172,109 +172,103 @@ object ScalaReflection extends ScalaReflection {
     val clsName = getClassNameFromType(tpe)
     val walkedTypePath = new WalkedTypePath().recordRoot(clsName)
     val Schema(dataType, nullable) = schemaFor(tpe)
-
+    val deserializerFunc = deserializerFor(tpe, walkedTypePath)
     // Assumes we are deserializing the first column of a row.
     deserializerForWithNullSafetyAndUpcast(GetColumnByOrdinal(0, dataType), dataType,
-      nullable = nullable, walkedTypePath,
-      (casted, typePath) => deserializerFor(tpe, casted, typePath))
+      nullable = nullable, walkedTypePath, deserializerFunc)
   }
 
   /**
-   * Returns an expression that can be used to deserialize an input expression to an object of type
-   * `T` with a compatible schema.
+   * Returns a function that receives an input expression and turns it to an expression that can be
+   * used to deserialize the input expression to an object of type `T` with a compatible schema.
    *
    * @param tpe The `Type` of deserialized object.
-   * @param path The expression which can be used to extract serialized value.
    * @param walkedTypePath The paths from top to bottom to access current field when deserializing.
    */
   private def deserializerFor(
       tpe: `Type`,
-      path: Expression,
-      walkedTypePath: WalkedTypePath): Expression = cleanUpReflectionObjects {
+      walkedTypePath: WalkedTypePath): Expression => Expression = cleanUpReflectionObjects {
     baseType(tpe) match {
-      case t if !dataTypeFor(t).isInstanceOf[ObjectType] => path
+      case t if !dataTypeFor(t).isInstanceOf[ObjectType] => identity
 
       case t if isSubtype(t, localTypeOf[Option[_]]) =>
         val TypeRef(_, _, Seq(optType)) = t
         val className = getClassNameFromType(optType)
         val newTypePath = walkedTypePath.recordOption(className)
-        WrapOption(deserializerFor(optType, path, newTypePath), dataTypeFor(optType))
+        val dataType = dataTypeFor(optType)
+        val deserializerFunc = deserializerFor(optType, newTypePath)
+        path => WrapOption(deserializerFunc(path), dataType)
 
       case t if isSubtype(t, localTypeOf[java.lang.Integer]) =>
-        createDeserializerForTypesSupportValueOf(path,
-          classOf[java.lang.Integer])
+        createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Integer])
 
       case t if isSubtype(t, localTypeOf[java.lang.Long]) =>
-        createDeserializerForTypesSupportValueOf(path,
-          classOf[java.lang.Long])
+        createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Long])
 
       case t if isSubtype(t, localTypeOf[java.lang.Double]) =>
-        createDeserializerForTypesSupportValueOf(path,
-          classOf[java.lang.Double])
+        createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Double])
 
       case t if isSubtype(t, localTypeOf[java.lang.Float]) =>
-        createDeserializerForTypesSupportValueOf(path,
-          classOf[java.lang.Float])
+        createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Float])
 
       case t if isSubtype(t, localTypeOf[java.lang.Short]) =>
-        createDeserializerForTypesSupportValueOf(path,
-          classOf[java.lang.Short])
+        createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Short])
 
       case t if isSubtype(t, localTypeOf[java.lang.Byte]) =>
-        createDeserializerForTypesSupportValueOf(path,
-          classOf[java.lang.Byte])
+        createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Byte])
 
       case t if isSubtype(t, localTypeOf[java.lang.Boolean]) =>
-        createDeserializerForTypesSupportValueOf(path,
-          classOf[java.lang.Boolean])
+        createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Boolean])
 
       case t if isSubtype(t, localTypeOf[java.time.LocalDate]) =>
-        createDeserializerForLocalDate(path)
+        createDeserializerForLocalDate
 
       case t if isSubtype(t, localTypeOf[java.sql.Date]) =>
-        createDeserializerForSqlDate(path)
+        createDeserializerForSqlDate
 
       case t if isSubtype(t, localTypeOf[java.time.Instant]) =>
-        createDeserializerForInstant(path)
+        createDeserializerForInstant
 
       case t if isSubtype(t, localTypeOf[java.lang.Enum[_]]) =>
-        createDeserializerForTypesSupportValueOf(
-          Invoke(path, "toString", ObjectType(classOf[String]), returnNullable = false),
-          getClassFromType(t))
+        // Code touching Scala Reflection should be called outside the returned function to allow
+        // caching the Scala Reflection result
+        val cls = getClassFromType(t)
+        path => createDeserializerForTypesSupportValueOf(
+          Invoke(path, "toString", ObjectType(classOf[String]), returnNullable = false), cls)
 
       case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) =>
-        createDeserializerForSqlTimestamp(path)
+        createDeserializerForSqlTimestamp
 
       case t if isSubtype(t, localTypeOf[java.time.LocalDateTime]) =>
-        createDeserializerForLocalDateTime(path)
+        createDeserializerForLocalDateTime
 
       case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
-        createDeserializerForDuration(path)
+        createDeserializerForDuration
 
       case t if isSubtype(t, localTypeOf[java.time.Period]) =>
-        createDeserializerForPeriod(path)
+        createDeserializerForPeriod
 
       case t if isSubtype(t, localTypeOf[java.lang.String]) =>
-        createDeserializerForString(path, returnNullable = false)
+        createDeserializerForString(_, returnNullable = false)
 
       case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) =>
-        createDeserializerForJavaBigDecimal(path, returnNullable = false)
+        createDeserializerForJavaBigDecimal(_, returnNullable = false)
 
       case t if isSubtype(t, localTypeOf[BigDecimal]) =>
-        createDeserializerForScalaBigDecimal(path, returnNullable = false)
+        createDeserializerForScalaBigDecimal(_, returnNullable = false)
 
       case t if isSubtype(t, localTypeOf[java.math.BigInteger]) =>
-        createDeserializerForJavaBigInteger(path, returnNullable = false)
+        createDeserializerForJavaBigInteger(_, returnNullable = false)
 
       case t if isSubtype(t, localTypeOf[scala.math.BigInt]) =>
-        createDeserializerForScalaBigInt(path)
+        createDeserializerForScalaBigInt
 
       case t if isSubtype(t, localTypeOf[Array[_]]) =>
         val TypeRef(_, _, Seq(elementType)) = t
         val Schema(dataType, elementNullable) = schemaFor(elementType)
         val className = getClassNameFromType(elementType)
         val newTypePath = walkedTypePath.recordArray(className)
-
+        val deserializerFunc = deserializerFor(elementType, newTypePath)
         val mapFunction: Expression => Expression = element => {
           // upcast the array element to the data type the encoder expected.
           deserializerForWithNullSafetyAndUpcast(
@@ -282,10 +276,9 @@ object ScalaReflection extends ScalaReflection {
             dataType,
             nullable = elementNullable,
             newTypePath,
-            (casted, typePath) => deserializerFor(elementType, casted, typePath))
+            deserializerFunc)
         }
 
-        val arrayData = UnresolvedMapObjects(mapFunction, path)
         val arrayCls = arrayClassFor(elementType)
 
         val methodName = elementType match {
@@ -299,7 +292,10 @@ object ScalaReflection extends ScalaReflection {
           // non-primitive
           case _ => "array"
         }
-        Invoke(arrayData, methodName, arrayCls, returnNullable = false)
+        path => {
+          val arrayData = UnresolvedMapObjects(mapFunction, path)
+          Invoke(arrayData, methodName, arrayCls, returnNullable = false)
+        }
 
       // We serialize a `Set` to Catalyst array. When we deserialize a Catalyst array
       // to a `Set`, if there are duplicated elements, the elements will be de-duplicated.
@@ -309,14 +305,14 @@ object ScalaReflection extends ScalaReflection {
         val Schema(dataType, elementNullable) = schemaFor(elementType)
         val className = getClassNameFromType(elementType)
         val newTypePath = walkedTypePath.recordArray(className)
-
+        val deserializerFunc = deserializerFor(elementType, newTypePath)
         val mapFunction: Expression => Expression = element => {
           deserializerForWithNullSafetyAndUpcast(
             element,
             dataType,
             nullable = elementNullable,
             newTypePath,
-            (casted, typePath) => deserializerFor(elementType, casted, typePath))
+            deserializerFunc)
         }
 
         val companion = t.dealias.typeSymbol.companion.typeSignature
@@ -326,7 +322,7 @@ object ScalaReflection extends ScalaReflection {
             classOf[scala.collection.Set[_]]
           case _ => mirror.runtimeClass(t.typeSymbol.asClass)
         }
-        UnresolvedMapObjects(mapFunction, path, Some(cls))
+        UnresolvedMapObjects(mapFunction, _, Some(cls))
 
       case t if isSubtype(t, localTypeOf[Map[_, _]]) =>
         val TypeRef(_, _, Seq(keyType, valueType)) = t
@@ -336,12 +332,12 @@ object ScalaReflection extends ScalaReflection {
 
         val newTypePath = walkedTypePath.recordMap(classNameForKey, classNameForValue)
 
-        UnresolvedCatalystToExternalMap(
-          path,
-          p => deserializerFor(keyType, p, newTypePath),
-          p => deserializerFor(valueType, p, newTypePath),
-          mirror.runtimeClass(t.typeSymbol.asClass)
-        )
+        // Code touching Scala Reflection should be called outside the returned function to allow
+        // caching the Scala Reflection result
+        val keyDeserializerFunc = deserializerFor(keyType, newTypePath)
+        val valueDeserializerFunc = deserializerFor(valueType, newTypePath)
+        val cls = mirror.runtimeClass(t.typeSymbol.asClass)
+        UnresolvedCatalystToExternalMap(_, keyDeserializerFunc, valueDeserializerFunc, cls)
 
       case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) =>
         val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().
@@ -350,7 +346,10 @@ object ScalaReflection extends ScalaReflection {
           udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
           Nil,
           dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
-        Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil)
+        // Code touching Scala Reflection should be called outside the returned function to allow
+        // caching the Scala Reflection result
+        val cls = udt.userClass
+        path => Invoke(obj, "deserialize", ObjectType(cls), Seq(path))
 
       case t if UDTRegistration.exists(getClassNameFromType(t)) =>
         val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.getConstructor().
@@ -359,43 +358,44 @@ object ScalaReflection extends ScalaReflection {
           udt.getClass,
           Nil,
           dataType = ObjectType(udt.getClass))
-        Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil)
+        // Code touching Scala Reflection should be called outside the returned function to allow
+        // caching the Scala Reflection result
+        val cls = udt.userClass
+        path => Invoke(obj, "deserialize", ObjectType(cls), Seq(path))
 
       case t if definedByConstructorParams(t) =>
         val params = getConstructorParameters(t)
 
         val cls = getClassFromType(tpe)
 
-        val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) =>
+        val arguDeserializerFuncs = params.zipWithIndex.map { case ((fieldName, fieldType), i) =>
           val Schema(dataType, nullable) = schemaFor(fieldType)
           val clsName = getClassNameFromType(fieldType)
           val newTypePath = walkedTypePath.recordField(clsName, fieldName)
 
           // For tuples, we based grab the inner fields by ordinal instead of name.
-          val newPath = if (cls.getName startsWith "scala.Tuple") {
-            deserializerFor(
-              fieldType,
-              addToPathOrdinal(path, i, dataType, newTypePath),
-              newTypePath)
+          val newPathFunc = if (cls.getName startsWith "scala.Tuple") {
+            addToPathOrdinal(_, i, dataType, newTypePath)
           } else {
-            deserializerFor(
-              fieldType,
-              addToPath(path, fieldName, dataType, newTypePath),
-              newTypePath)
+            addToPath(_, fieldName, dataType, newTypePath)
           }
-          expressionWithNullSafety(
-            newPath,
+          val deserializerFunc = deserializerFor(fieldType, newTypePath)
+          (path: Expression) => expressionWithNullSafety(
+            deserializerFunc(newPathFunc(path)),
             nullable = nullable,
             newTypePath)
         }
 
-        val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false)
-
-        expressions.If(
-          IsNull(path),
-          expressions.Literal.create(null, ObjectType(cls)),
-          newInstance
-        )
+        val nullLit = expressions.Literal.create(null, ObjectType(cls))
+        path => {
+          val arguments = arguDeserializerFuncs.map(_(path))
+          val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false)
+          expressions.If(
+            IsNull(path),
+            nullLit,
+            newInstance
+          )
+        }
 
       case t if isSubtype(t, localTypeOf[Enumeration#Value]) =>
         // package example
@@ -406,10 +406,13 @@ object ScalaReflection extends ScalaReflection {
         // the fullName of tpe is example.Foo.Foo, but we need example.Foo so that
         // we can call example.Foo.withName to deserialize string to enumeration.
         val parent = t.asInstanceOf[TypeRef].pre.typeSymbol.asClass
-        val cls = mirror.runtimeClass(parent)
-        StaticInvoke(
-          cls,
-          ObjectType(getClassFromType(t)),
+        // Code touching Scala Reflection should be called outside the returned function to allow
+        // caching the Scala Reflection result
+        val parentCls = mirror.runtimeClass(parent)
+        val cls = getClassFromType(t)
+        path => StaticInvoke(
+          parentCls,
+          ObjectType(cls),
           "withName",
           createDeserializerForString(path, false) :: Nil,
           returnNullable = false)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org