You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by zs...@apache.org on 2022/11/08 16:19:07 UTC
[spark] branch master updated: [SPARK-41045][SQL] Pre-compute to eliminate ScalaReflection calls after deserializer is created
This is an automated email from the ASF dual-hosted git repository.
zsxwing pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new ef402edff91 [SPARK-41045][SQL] Pre-compute to eliminate ScalaReflection calls after deserializer is created
ef402edff91 is described below
commit ef402edff91377d37c0c1b8d40921ed7bd9f7160
Author: Shixiong Zhu <zs...@gmail.com>
AuthorDate: Tue Nov 8 08:18:50 2022 -0800
[SPARK-41045][SQL] Pre-compute to eliminate ScalaReflection calls after deserializer is created
### What changes were proposed in this pull request?
Currently, when `ScalaReflection` returns a deserializer for certain complex types (such as arrays, maps, and UDTs), it creates functions that may still call into `ScalaReflection` after the deserializer has been created.
`ScalaReflection` is a performance bottleneck when used from multiple threads, because it holds multiple global locks. This PR refactors `ScalaReflection.deserializerFor` to pre-compute everything that needs `ScalaReflection` before the deserializer is created. As a result, once a deserializer is created, multiple threads can reuse it without calling `ScalaReflection.deserializerFor` anymore, which makes it much faster.
### Why are the changes needed?
Optimize `ScalaReflection.deserializerFor` to make deserializers faster under multiple threads.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
This PR only refactors `deserializerFor` to optimize the code; existing tests should already cover its correctness.
Closes #38556 from zsxwing/scala-ref.
Authored-by: Shixiong Zhu <zs...@gmail.com>
Signed-off-by: Shixiong Zhu <zs...@gmail.com>
---
.../sql/catalyst/DeserializerBuildHelper.scala | 5 +-
.../spark/sql/catalyst/JavaTypeInference.scala | 8 +-
.../spark/sql/catalyst/ScalaReflection.scala | 157 +++++++++++----------
3 files changed, 85 insertions(+), 85 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
index 0d3b9977e4f..7051c2d2264 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
@@ -49,10 +49,9 @@ object DeserializerBuildHelper {
dataType: DataType,
nullable: Boolean,
walkedTypePath: WalkedTypePath,
- funcForCreatingDeserializer: (Expression, WalkedTypePath) => Expression): Expression = {
+ funcForCreatingDeserializer: Expression => Expression): Expression = {
val casted = upCastToExpectedType(expr, dataType, walkedTypePath)
- expressionWithNullSafety(funcForCreatingDeserializer(casted, walkedTypePath),
- nullable, walkedTypePath)
+ expressionWithNullSafety(funcForCreatingDeserializer(casted), nullable, walkedTypePath)
}
def expressionWithNullSafety(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index dccaf1c4835..827807055ce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -218,9 +218,7 @@ object JavaTypeInference {
// Assumes we are deserializing the first column of a row.
deserializerForWithNullSafetyAndUpcast(GetColumnByOrdinal(0, dataType), dataType,
- nullable = nullable, walkedTypePath, (casted, walkedTypePath) => {
- deserializerFor(typeToken, casted, walkedTypePath)
- })
+ nullable = nullable, walkedTypePath, deserializerFor(typeToken, _, walkedTypePath))
}
private def deserializerFor(
@@ -280,7 +278,7 @@ object JavaTypeInference {
dataType,
nullable = elementNullable,
newTypePath,
- (casted, typePath) => deserializerFor(typeToken.getComponentType, casted, typePath))
+ deserializerFor(typeToken.getComponentType, _, newTypePath))
}
val arrayData = UnresolvedMapObjects(mapFunction, path)
@@ -309,7 +307,7 @@ object JavaTypeInference {
dataType,
nullable = elementNullable,
newTypePath,
- (casted, typePath) => deserializerFor(et, casted, typePath))
+ deserializerFor(et, _, newTypePath))
}
UnresolvedMapObjects(mapFunction, path, customCollectionCls = Some(c))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 12093b9f4b2..d895a0fbe19 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -172,109 +172,103 @@ object ScalaReflection extends ScalaReflection {
val clsName = getClassNameFromType(tpe)
val walkedTypePath = new WalkedTypePath().recordRoot(clsName)
val Schema(dataType, nullable) = schemaFor(tpe)
-
+ val deserializerFunc = deserializerFor(tpe, walkedTypePath)
// Assumes we are deserializing the first column of a row.
deserializerForWithNullSafetyAndUpcast(GetColumnByOrdinal(0, dataType), dataType,
- nullable = nullable, walkedTypePath,
- (casted, typePath) => deserializerFor(tpe, casted, typePath))
+ nullable = nullable, walkedTypePath, deserializerFunc)
}
/**
- * Returns an expression that can be used to deserialize an input expression to an object of type
- * `T` with a compatible schema.
+ * Returns a function that receives an input expression and turns it to an expression that can be
+ * used to deserialize the input expression to an object of type `T` with a compatible schema.
*
* @param tpe The `Type` of deserialized object.
- * @param path The expression which can be used to extract serialized value.
* @param walkedTypePath The paths from top to bottom to access current field when deserializing.
*/
private def deserializerFor(
tpe: `Type`,
- path: Expression,
- walkedTypePath: WalkedTypePath): Expression = cleanUpReflectionObjects {
+ walkedTypePath: WalkedTypePath): Expression => Expression = cleanUpReflectionObjects {
baseType(tpe) match {
- case t if !dataTypeFor(t).isInstanceOf[ObjectType] => path
+ case t if !dataTypeFor(t).isInstanceOf[ObjectType] => identity
case t if isSubtype(t, localTypeOf[Option[_]]) =>
val TypeRef(_, _, Seq(optType)) = t
val className = getClassNameFromType(optType)
val newTypePath = walkedTypePath.recordOption(className)
- WrapOption(deserializerFor(optType, path, newTypePath), dataTypeFor(optType))
+ val dataType = dataTypeFor(optType)
+ val deserializerFunc = deserializerFor(optType, newTypePath)
+ path => WrapOption(deserializerFunc(path), dataType)
case t if isSubtype(t, localTypeOf[java.lang.Integer]) =>
- createDeserializerForTypesSupportValueOf(path,
- classOf[java.lang.Integer])
+ createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Integer])
case t if isSubtype(t, localTypeOf[java.lang.Long]) =>
- createDeserializerForTypesSupportValueOf(path,
- classOf[java.lang.Long])
+ createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Long])
case t if isSubtype(t, localTypeOf[java.lang.Double]) =>
- createDeserializerForTypesSupportValueOf(path,
- classOf[java.lang.Double])
+ createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Double])
case t if isSubtype(t, localTypeOf[java.lang.Float]) =>
- createDeserializerForTypesSupportValueOf(path,
- classOf[java.lang.Float])
+ createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Float])
case t if isSubtype(t, localTypeOf[java.lang.Short]) =>
- createDeserializerForTypesSupportValueOf(path,
- classOf[java.lang.Short])
+ createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Short])
case t if isSubtype(t, localTypeOf[java.lang.Byte]) =>
- createDeserializerForTypesSupportValueOf(path,
- classOf[java.lang.Byte])
+ createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Byte])
case t if isSubtype(t, localTypeOf[java.lang.Boolean]) =>
- createDeserializerForTypesSupportValueOf(path,
- classOf[java.lang.Boolean])
+ createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Boolean])
case t if isSubtype(t, localTypeOf[java.time.LocalDate]) =>
- createDeserializerForLocalDate(path)
+ createDeserializerForLocalDate
case t if isSubtype(t, localTypeOf[java.sql.Date]) =>
- createDeserializerForSqlDate(path)
+ createDeserializerForSqlDate
case t if isSubtype(t, localTypeOf[java.time.Instant]) =>
- createDeserializerForInstant(path)
+ createDeserializerForInstant
case t if isSubtype(t, localTypeOf[java.lang.Enum[_]]) =>
- createDeserializerForTypesSupportValueOf(
- Invoke(path, "toString", ObjectType(classOf[String]), returnNullable = false),
- getClassFromType(t))
+ // Code touching Scala Reflection should be called outside the returned function to allow
+ // caching the Scala Reflection result
+ val cls = getClassFromType(t)
+ path => createDeserializerForTypesSupportValueOf(
+ Invoke(path, "toString", ObjectType(classOf[String]), returnNullable = false), cls)
case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) =>
- createDeserializerForSqlTimestamp(path)
+ createDeserializerForSqlTimestamp
case t if isSubtype(t, localTypeOf[java.time.LocalDateTime]) =>
- createDeserializerForLocalDateTime(path)
+ createDeserializerForLocalDateTime
case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
- createDeserializerForDuration(path)
+ createDeserializerForDuration
case t if isSubtype(t, localTypeOf[java.time.Period]) =>
- createDeserializerForPeriod(path)
+ createDeserializerForPeriod
case t if isSubtype(t, localTypeOf[java.lang.String]) =>
- createDeserializerForString(path, returnNullable = false)
+ createDeserializerForString(_, returnNullable = false)
case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) =>
- createDeserializerForJavaBigDecimal(path, returnNullable = false)
+ createDeserializerForJavaBigDecimal(_, returnNullable = false)
case t if isSubtype(t, localTypeOf[BigDecimal]) =>
- createDeserializerForScalaBigDecimal(path, returnNullable = false)
+ createDeserializerForScalaBigDecimal(_, returnNullable = false)
case t if isSubtype(t, localTypeOf[java.math.BigInteger]) =>
- createDeserializerForJavaBigInteger(path, returnNullable = false)
+ createDeserializerForJavaBigInteger(_, returnNullable = false)
case t if isSubtype(t, localTypeOf[scala.math.BigInt]) =>
- createDeserializerForScalaBigInt(path)
+ createDeserializerForScalaBigInt
case t if isSubtype(t, localTypeOf[Array[_]]) =>
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, elementNullable) = schemaFor(elementType)
val className = getClassNameFromType(elementType)
val newTypePath = walkedTypePath.recordArray(className)
-
+ val deserializerFunc = deserializerFor(elementType, newTypePath)
val mapFunction: Expression => Expression = element => {
// upcast the array element to the data type the encoder expected.
deserializerForWithNullSafetyAndUpcast(
@@ -282,10 +276,9 @@ object ScalaReflection extends ScalaReflection {
dataType,
nullable = elementNullable,
newTypePath,
- (casted, typePath) => deserializerFor(elementType, casted, typePath))
+ deserializerFunc)
}
- val arrayData = UnresolvedMapObjects(mapFunction, path)
val arrayCls = arrayClassFor(elementType)
val methodName = elementType match {
@@ -299,7 +292,10 @@ object ScalaReflection extends ScalaReflection {
// non-primitive
case _ => "array"
}
- Invoke(arrayData, methodName, arrayCls, returnNullable = false)
+ path => {
+ val arrayData = UnresolvedMapObjects(mapFunction, path)
+ Invoke(arrayData, methodName, arrayCls, returnNullable = false)
+ }
// We serialize a `Set` to Catalyst array. When we deserialize a Catalyst array
// to a `Set`, if there are duplicated elements, the elements will be de-duplicated.
@@ -309,14 +305,14 @@ object ScalaReflection extends ScalaReflection {
val Schema(dataType, elementNullable) = schemaFor(elementType)
val className = getClassNameFromType(elementType)
val newTypePath = walkedTypePath.recordArray(className)
-
+ val deserializerFunc = deserializerFor(elementType, newTypePath)
val mapFunction: Expression => Expression = element => {
deserializerForWithNullSafetyAndUpcast(
element,
dataType,
nullable = elementNullable,
newTypePath,
- (casted, typePath) => deserializerFor(elementType, casted, typePath))
+ deserializerFunc)
}
val companion = t.dealias.typeSymbol.companion.typeSignature
@@ -326,7 +322,7 @@ object ScalaReflection extends ScalaReflection {
classOf[scala.collection.Set[_]]
case _ => mirror.runtimeClass(t.typeSymbol.asClass)
}
- UnresolvedMapObjects(mapFunction, path, Some(cls))
+ UnresolvedMapObjects(mapFunction, _, Some(cls))
case t if isSubtype(t, localTypeOf[Map[_, _]]) =>
val TypeRef(_, _, Seq(keyType, valueType)) = t
@@ -336,12 +332,12 @@ object ScalaReflection extends ScalaReflection {
val newTypePath = walkedTypePath.recordMap(classNameForKey, classNameForValue)
- UnresolvedCatalystToExternalMap(
- path,
- p => deserializerFor(keyType, p, newTypePath),
- p => deserializerFor(valueType, p, newTypePath),
- mirror.runtimeClass(t.typeSymbol.asClass)
- )
+ // Code touching Scala Reflection should be called outside the returned function to allow
+ // caching the Scala Reflection result
+ val keyDeserializerFunc = deserializerFor(keyType, newTypePath)
+ val valueDeserializerFunc = deserializerFor(valueType, newTypePath)
+ val cls = mirror.runtimeClass(t.typeSymbol.asClass)
+ UnresolvedCatalystToExternalMap(_, keyDeserializerFunc, valueDeserializerFunc, cls)
case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) =>
val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().
@@ -350,7 +346,10 @@ object ScalaReflection extends ScalaReflection {
udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
Nil,
dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
- Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil)
+ // Code touching Scala Reflection should be called outside the returned function to allow
+ // caching the Scala Reflection result
+ val cls = udt.userClass
+ path => Invoke(obj, "deserialize", ObjectType(cls), Seq(path))
case t if UDTRegistration.exists(getClassNameFromType(t)) =>
val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.getConstructor().
@@ -359,43 +358,44 @@ object ScalaReflection extends ScalaReflection {
udt.getClass,
Nil,
dataType = ObjectType(udt.getClass))
- Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil)
+ // Code touching Scala Reflection should be called outside the returned function to allow
+ // caching the Scala Reflection result
+ val cls = udt.userClass
+ path => Invoke(obj, "deserialize", ObjectType(cls), Seq(path))
case t if definedByConstructorParams(t) =>
val params = getConstructorParameters(t)
val cls = getClassFromType(tpe)
- val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) =>
+ val arguDeserializerFuncs = params.zipWithIndex.map { case ((fieldName, fieldType), i) =>
val Schema(dataType, nullable) = schemaFor(fieldType)
val clsName = getClassNameFromType(fieldType)
val newTypePath = walkedTypePath.recordField(clsName, fieldName)
// For tuples, we based grab the inner fields by ordinal instead of name.
- val newPath = if (cls.getName startsWith "scala.Tuple") {
- deserializerFor(
- fieldType,
- addToPathOrdinal(path, i, dataType, newTypePath),
- newTypePath)
+ val newPathFunc = if (cls.getName startsWith "scala.Tuple") {
+ addToPathOrdinal(_, i, dataType, newTypePath)
} else {
- deserializerFor(
- fieldType,
- addToPath(path, fieldName, dataType, newTypePath),
- newTypePath)
+ addToPath(_, fieldName, dataType, newTypePath)
}
- expressionWithNullSafety(
- newPath,
+ val deserializerFunc = deserializerFor(fieldType, newTypePath)
+ (path: Expression) => expressionWithNullSafety(
+ deserializerFunc(newPathFunc(path)),
nullable = nullable,
newTypePath)
}
- val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false)
-
- expressions.If(
- IsNull(path),
- expressions.Literal.create(null, ObjectType(cls)),
- newInstance
- )
+ val nullLit = expressions.Literal.create(null, ObjectType(cls))
+ path => {
+ val arguments = arguDeserializerFuncs.map(_(path))
+ val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false)
+ expressions.If(
+ IsNull(path),
+ nullLit,
+ newInstance
+ )
+ }
case t if isSubtype(t, localTypeOf[Enumeration#Value]) =>
// package example
@@ -406,10 +406,13 @@ object ScalaReflection extends ScalaReflection {
// the fullName of tpe is example.Foo.Foo, but we need example.Foo so that
// we can call example.Foo.withName to deserialize string to enumeration.
val parent = t.asInstanceOf[TypeRef].pre.typeSymbol.asClass
- val cls = mirror.runtimeClass(parent)
- StaticInvoke(
- cls,
- ObjectType(getClassFromType(t)),
+ // Code touching Scala Reflection should be called outside the returned function to allow
+ // caching the Scala Reflection result
+ val parentCls = mirror.runtimeClass(parent)
+ val cls = getClassFromType(t)
+ path => StaticInvoke(
+ parentCls,
+ ObjectType(cls),
"withName",
createDeserializerForString(path, false) :: Nil,
returnNullable = false)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org