You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2019/02/27 04:45:50 UTC

[spark] branch master updated: [SPARK-26837][SQL] Pruning nested fields from object serializers

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0f2c0b5  [SPARK-26837][SQL] Pruning nested fields from object serializers
0f2c0b5 is described below

commit 0f2c0b53e8fb18c86c67b5dd679c006db93f94a5
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Wed Feb 27 12:45:24 2019 +0800

    [SPARK-26837][SQL] Pruning nested fields from object serializers
    
    ## What changes were proposed in this pull request?
    
    In SPARK-26619, we made a change to prune unnecessary individual serializers when serializing objects. This PR extends SPARK-26619: we can further prune nested fields from object serializers if they are not used.
    
    For example, in the following query, we only use one field of a struct column:
    
    ```scala
    val data = Seq((("a", 1), 1), (("b", 2), 2), (("c", 3), 3))
    val df = data.toDS().map(t => (t._1, t._2 + 1)).select("_1._1")
    ```
    
    So, instead of having a serializer create a two-field struct, we can prune the unnecessary field from it. This is what this PR proposes to do.
    
    In order to make this change conservative and safer, a SQL config is added to control it. It is disabled by default.
    
    TODO: Support pruning nested fields inside MapType's key and value.
    
    ## How was this patch tested?
    
    Added tests.
    
    Closes #23740 from viirya/nested-pruning-serializer-2.
    
    Authored-by: Liang-Chi Hsieh <vi...@gmail.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../expressions}/GetStructFieldObject.scala        |   5 +-
 .../expressions}/ProjectionOverSchema.scala        |   5 +-
 .../sql/catalyst/expressions/SchemaPruning.scala   | 147 +++++++++++++++++++++
 .../sql/catalyst/expressions}/SelectedField.scala  |   5 +-
 .../spark/sql/catalyst/optimizer/Optimizer.scala   |   8 +-
 .../spark/sql/catalyst/optimizer/objects.scala     | 102 ++++++++++++++
 .../org/apache/spark/sql/internal/SQLConf.scala    |  12 ++
 .../catalyst/expressions/SchemaPruningSuite.scala  |  62 +++++++++
 .../catalyst/expressions}/SelectedFieldSuite.scala |   3 +-
 .../catalyst/optimizer/ColumnPruningSuite.scala    |  10 --
 .../optimizer/ObjectSerializerPruningSuite.scala   | 103 +++++++++++++++
 .../datasources/parquet/ParquetSchemaPruning.scala | 125 +-----------------
 .../spark/sql/DatasetOptimizationSuite.scala       | 107 +++++++++++++++
 .../scala/org/apache/spark/sql/DatasetSuite.scala  |  11 --
 14 files changed, 545 insertions(+), 160 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GetStructFieldObject.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GetStructFieldObject.scala
similarity index 88%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/GetStructFieldObject.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GetStructFieldObject.scala
index c88b2f8..0bea0cb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GetStructFieldObject.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GetStructFieldObject.scala
@@ -15,9 +15,8 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.execution
+package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField}
 import org.apache.spark.sql.types.StructField
 
 /**
@@ -25,7 +24,7 @@ import org.apache.spark.sql.types.StructField
  * This is in contrast to the [[GetStructField]] case class extractor which returns the field
  * ordinal instead of the field itself.
  */
-private[execution] object GetStructFieldObject {
+object GetStructFieldObject {
   def unapply(getStructField: GetStructField): Option[(Expression, StructField)] =
     Some((
       getStructField.child,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ProjectionOverSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
similarity index 94%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/ProjectionOverSchema.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
index 612a7b8..f4956a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ProjectionOverSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
@@ -15,9 +15,8 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.execution
+package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
 
 /**
@@ -26,7 +25,7 @@ import org.apache.spark.sql.types._
  * are adjusted to fit the schema. All other expressions are left as-is. This
  * class is motivated by columnar nested schema pruning.
  */
-private[execution] case class ProjectionOverSchema(schema: StructType) {
+case class ProjectionOverSchema(schema: StructType) {
   private val fieldNames = schema.fieldNames.toSet
 
   def unapply(expr: Expression): Option[Expression] = getProjection(expr)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
new file mode 100644
index 0000000..6213267
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.types._
+
+object SchemaPruning {
+  /**
+   * Filters the schema by the requested root fields, which must be non-empty. For example, if
+   * the schema is struct<a:int, b:int> and the requested field is "a", the field "b" is pruned
+   * from the returned schema. The field ordering of the original schema is preserved.
+   */
+  def pruneDataSchema(
+      dataSchema: StructType,
+      requestedRootFields: Seq[RootField]): StructType = {
+    // Merge the requested root fields into a single schema. Note the ordering of the fields
+    // in the merged schema may differ from their ordering in `dataSchema`; it is restored
+    // below. `reduceLeft` throws if `requestedRootFields` is empty.
+    val mergedSchema = requestedRootFields
+      .map { case root: RootField => StructType(Array(root.field)) }
+      .reduceLeft(_ merge _)
+    val dataSchemaFieldNames = dataSchema.fieldNames.toSet
+    val mergedDataSchema =
+      StructType(mergedSchema.filter(f => dataSchemaFieldNames.contains(f.name)))
+    // Sort the fields of mergedDataSchema according to their order in dataSchema,
+    // recursively. This makes mergedDataSchema a pruned schema of dataSchema.
+    sortLeftFieldsByRight(mergedDataSchema, dataSchema).asInstanceOf[StructType]
+  }
+
+  /**
+   * Sorts the fields and descendant fields of structs in left according to their order in
+   * right. This function assumes that the fields of left are a subset of the fields of
+   * right, recursively. That is, left is a "subschema" of right, ignoring order of
+   * fields.
+   */
+  private def sortLeftFieldsByRight(left: DataType, right: DataType): DataType =
+    (left, right) match {
+      case (ArrayType(leftElementType, containsNull), ArrayType(rightElementType, _)) =>
+        ArrayType(
+          sortLeftFieldsByRight(leftElementType, rightElementType),
+          containsNull)
+      case (MapType(leftKeyType, leftValueType, containsNull),
+          MapType(rightKeyType, rightValueType, _)) =>
+        MapType(
+          sortLeftFieldsByRight(leftKeyType, rightKeyType),
+          sortLeftFieldsByRight(leftValueType, rightValueType),
+          containsNull)
+      case (leftStruct: StructType, rightStruct: StructType) =>
+        val filteredRightFieldNames = rightStruct.fieldNames.filter(leftStruct.fieldNames.contains)
+        val sortedLeftFields = filteredRightFieldNames.map { fieldName =>
+          val leftFieldType = leftStruct(fieldName).dataType
+          val rightFieldType = rightStruct(fieldName).dataType
+          val sortedLeftFieldType = sortLeftFieldsByRight(leftFieldType, rightFieldType)
+          StructField(fieldName, sortedLeftFieldType, nullable = leftStruct(fieldName).nullable)
+        }
+        StructType(sortedLeftFields)
+      case _ => left
+    }
+
+  /**
+   * Returns the set of fields from projection and filtering predicates that the query plan needs.
+   */
+  def identifyRootFields(
+      projects: Seq[NamedExpression],
+      filters: Seq[Expression]): Seq[RootField] = {
+    val projectionRootFields = projects.flatMap(getRootFields)
+    val filterRootFields = filters.flatMap(getRootFields)
+
+    // Some expressions, e.g. `IsNotNull`, don't need to access any nested field of a root field.
+    // For them, if any of their nested fields are accessed in the query, we don't need to add
+    // the root field accesses of those expressions.
+    // For example, for a query `SELECT name.first FROM contacts WHERE name IS NOT NULL`,
+    // we don't need to read nested fields of `name` struct other than `first` field.
+    val (rootFields, optRootFields) = (projectionRootFields ++ filterRootFields)
+      .distinct.partition(!_.prunedIfAnyChildAccessed)
+
+    optRootFields.filter { opt =>
+      !rootFields.exists { root =>
+        root.field.name == opt.field.name && {
+          // Checks whether the current optional root field can be pruned.
+          // For each required root field, we merge it with the optional root field:
+          // 1. If this optional root field has nested fields and any nested field of it is used
+          //    in the query, the merged field type must equal the optional root field type.
+          //    We can prune this optional root field. For example, for optional root field
+          //    `struct<name:struct<middle:string,last:string>>`, if its field
+          //    `struct<name:struct<last:string>>` is used, we don't need to add this optional
+          //    root field.
+          // 2. If this optional root field has no nested fields, the merged field type equals
+          //    the optional root field type only if they are the same. If so, we can prune
+          //    this optional root field too.
+          val rootFieldType = StructType(Array(root.field))
+          val optFieldType = StructType(Array(opt.field))
+          val merged = optFieldType.merge(rootFieldType)
+          merged.sameType(optFieldType)
+        }
+      }
+    } ++ rootFields
+  }
+
+  /**
+   * Gets the root (aka top-level, no-parent) [[StructField]]s for the given [[Expression]].
+   * When expr is an [[Attribute]], construct a field around it and indicate that that
+   * field was derived from an attribute.
+   */
+  private def getRootFields(expr: Expression): Seq[RootField] = {
+    expr match {
+      case att: Attribute =>
+        RootField(StructField(att.name, att.dataType, att.nullable), derivedFromAtt = true) :: Nil
+      case SelectedField(field) => RootField(field, derivedFromAtt = false) :: Nil
+      // Root field accesses by `IsNotNull` and `IsNull` are special cases as the expressions
+      // don't actually use any nested fields. These root field accesses might be excluded later
+      // if there are any nested field accesses in the query plan.
+      case IsNotNull(SelectedField(field)) =>
+        RootField(field, derivedFromAtt = false, prunedIfAnyChildAccessed = true) :: Nil
+      case IsNull(SelectedField(field)) =>
+        RootField(field, derivedFromAtt = false, prunedIfAnyChildAccessed = true) :: Nil
+      case IsNotNull(_: Attribute) | IsNull(_: Attribute) =>
+        expr.children.flatMap(getRootFields).map(_.copy(prunedIfAnyChildAccessed = true))
+      case _ =>
+        expr.children.flatMap(getRootFields)
+    }
+  }
+
+  /**
+   * This represents a "root" schema field (aka top-level, no-parent). `field` is the
+   * `StructField` for field name and datatype. `derivedFromAtt` indicates whether it
+   * was derived from an attribute or had a proper child. `prunedIfAnyChildAccessed` indicates
+   * whether this root field can be pruned if any of its child fields is used in the query.
+   */
+  case class RootField(field: StructField, derivedFromAtt: Boolean,
+    prunedIfAnyChildAccessed: Boolean = false)
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SelectedField.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala
similarity index 97%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/SelectedField.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala
index 68f797a..38a0481 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SelectedField.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala
@@ -15,10 +15,9 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.execution
+package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
 
 /**
@@ -53,7 +52,7 @@ import org.apache.spark.sql.types._
  * is "name" and its data type is a [[org.apache.spark.sql.types.StructType]] with a single string
  * field named "first".
  */
-private[execution] object SelectedField {
+object SelectedField {
   def unapply(expr: Expression): Option[StructField] = {
     // If this expression is an alias, work on its child instead
     val unaliased = expr match {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 38a051c..ad25898 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -197,7 +197,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
       DecimalAggregates) :+
     Batch("Object Expressions Optimization", fixedPoint,
       EliminateMapObjects,
-      CombineTypedFilters) :+
+      CombineTypedFilters,
+      ObjectSerializerPruning) :+
     Batch("LocalRelation", fixedPoint,
       ConvertToLocalRelation,
       PropagateEmptyRelation) :+
@@ -594,11 +595,6 @@ object ColumnPruning extends Rule[LogicalPlan] {
     case d @ DeserializeToObject(_, _, child) if !child.outputSet.subsetOf(d.references) =>
       d.copy(child = prunedChild(child, d.references))
 
-    case p @ Project(_, s: SerializeFromObject) if p.references != s.outputSet =>
-      val usedRefs = p.references
-      val prunedSerializer = s.serializer.filter(usedRefs.contains)
-      p.copy(child = SerializeFromObject(prunedSerializer, s.child))
-
     // Prunes the unused columns from child of Aggregate/Expand/Generate/ScriptTransformation
     case a @ Aggregate(_, _, child) if !child.outputSet.subsetOf(a.references) =>
       a.copy(child = prunedChild(child, a.references))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
index 8cdc642..96a172c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
@@ -17,11 +17,15 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.api.java.function.FilterFunction
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{ArrayType, DataType, StructType}
 
 /*
  * This file defines optimization rules related to object manipulation (for the Dataset API).
@@ -109,3 +113,101 @@ object EliminateMapObjects extends Rule[LogicalPlan] {
      case MapObjects(_, _, _, LambdaVariable(_, _, _, false), inputData, None) => inputData
   }
 }
+
+/**
+ * Prunes unnecessary object serializers from the query plan. This rule prunes both individual
+ * serializers and, when enabled via SQLConf, nested fields in serializers.
+ */
+object ObjectSerializerPruning extends Rule[LogicalPlan] {
+
+  /**
+   * Appends every struct type found in `dt` to `structs` in pre-order (a struct before the
+   * structs nested in it) and returns `structs`. Supports struct and array types for now.
+   * TODO(SPARK-26847): support map type.
+   */
+  def collectStructType(dt: DataType, structs: ArrayBuffer[StructType]): ArrayBuffer[StructType] = {
+    dt match {
+      case s @ StructType(fields) =>
+        structs += s
+        fields.map(f => collectStructType(f.dataType, structs))
+      case ArrayType(elementType, _) =>
+        collectStructType(elementType, structs)
+      case _ =>
+    }
+    structs
+  }
+
+  /**
+   * Prunes the given serializer expression to match `prunedDataType`. For example, given a
+   * serializer creating struct(a int, b int) and pruned data type struct(a int), this method
+   * returns a serializer creating struct(a int). If pruning doesn't yield `prunedDataType`,
+   * the original serializer is returned. Supports struct and array of struct for now.
+   * TODO(SPARK-26847): support pruning nested fields in key and value of map type.
+   */
+  def pruneSerializer(
+      serializer: NamedExpression,
+      prunedDataType: DataType): NamedExpression = {
+    val prunedStructTypes = collectStructType(prunedDataType, ArrayBuffer.empty[StructType])
+    var structTypeIndex = 0
+    // Assumes transformDown visits CreateNamedStructs in the order of `prunedStructTypes`.
+    val prunedSerializer = serializer.transformDown {
+      case s: CreateNamedStruct if structTypeIndex < prunedStructTypes.size =>
+        val prunedType = prunedStructTypes(structTypeIndex)
+
+        // Keeps only fields present in the pruned type, honoring case sensitivity.
+        val prunedFields = s.nameExprs.zip(s.valExprs).filter { case (nameExpr, _) =>
+          val name = nameExpr.eval(EmptyRow).toString
+          prunedType.fieldNames.exists { fieldName =>
+            if (SQLConf.get.caseSensitiveAnalysis) {
+              fieldName.equals(name)
+            } else {
+              fieldName.equalsIgnoreCase(name)
+            }
+          }
+        }.flatMap(pair => Seq(pair._1, pair._2))
+
+        structTypeIndex += 1
+        CreateNamedStruct(prunedFields)
+    }.transformUp {
+      // When we change nested serializer data type, `If` expression will be unresolved because
+      // literal null's data type doesn't match now. We need to align it with new data type.
+      // Note: we should do `transformUp` explicitly to change data types.
+      case i @ If(_: IsNull, Literal(null, dt), ser) if !dt.sameType(ser.dataType) =>
+        i.copy(trueValue = Literal(null, ser.dataType))
+    }.asInstanceOf[NamedExpression]
+
+    if (prunedSerializer.dataType.sameType(prunedDataType)) {
+      prunedSerializer
+    } else {
+      serializer
+    }
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case p @ Project(_, s: SerializeFromObject) =>
+      // Prunes individual serializers that are not used at all by the projection above.
+      val usedRefs = p.references
+      val prunedSerializer = s.serializer.filter(usedRefs.contains)
+
+      val rootFields = SchemaPruning.identifyRootFields(p.projectList, Seq.empty)
+
+      if (SQLConf.get.serializerNestedSchemaPruningEnabled && rootFields.nonEmpty) {
+        // Prunes nested fields. Assumes prunedSchema(idx) corresponds to prunedSerializer(idx).
+        val prunedSchema = SchemaPruning.pruneDataSchema(
+          StructType.fromAttributes(prunedSerializer.map(_.toAttribute)), rootFields)
+        val nestedPrunedSerializer = prunedSerializer.zipWithIndex.map { case (serializer, idx) =>
+          pruneSerializer(serializer, prunedSchema(idx).dataType)
+        }
+
+        // Rewrites field accesses in the projection to fit the pruned schema.
+        val projectionOverSchema = ProjectionOverSchema(prunedSchema)
+        val newProjects = p.projectList.map(_.transformDown {
+          case projectionOverSchema(expr) => expr
+        }).map { case expr: NamedExpression => expr }
+        p.copy(projectList = newProjects,
+          child = SerializeFromObject(nestedPrunedSerializer, s.child))
+      } else {
+        p.copy(child = SerializeFromObject(prunedSerializer, s.child))
+      }
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index bbb79cd..e74c2af 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1528,6 +1528,15 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED =
+    buildConf("spark.sql.optimizer.serializer.nestedSchemaPruning.enabled")
+      .internal()
+      .doc("Prune nested fields from object serialization operator which are unnecessary in " +
+        "satisfying a query. This optimization allows object serializers to avoid " +
+        "executing unnecessary nested expressions.")
+      .booleanConf
+      .createWithDefault(false)
+
   val TOP_K_SORT_FALLBACK_THRESHOLD =
     buildConf("spark.sql.execution.topKSortFallbackThreshold")
       .internal()
@@ -2077,6 +2086,9 @@ class SQLConf extends Serializable with Logging {
 
   def nestedSchemaPruningEnabled: Boolean = getConf(NESTED_SCHEMA_PRUNING_ENABLED)
 
+  def serializerNestedSchemaPruningEnabled: Boolean =
+    getConf(SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED)
+
   def csvColumnPruning: Boolean = getConf(SQLConf.CSV_PARSER_COLUMN_PRUNING)
 
   def legacySizeOfNull: Boolean = getConf(SQLConf.LEGACY_SIZE_OF_NULL)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
new file mode 100644
index 0000000..c04f59e
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.types._
+
+class SchemaPruningSuite extends SparkFunSuite {
+  test("prune schema by the requested fields") {
+    def testPrunedSchema(
+        schema: StructType,
+        requestedFields: StructField*): Unit = {
+      val requestedRootFields = requestedFields.map { f =>
+        // `derivedFromAtt` doesn't affect the pruned schema; only `field` is used.
+        SchemaPruning.RootField(field = f, derivedFromAtt = true)
+      }
+      val expectedSchema = SchemaPruning.pruneDataSchema(schema, requestedRootFields)
+      assert(expectedSchema == StructType(requestedFields))
+    }
+    // Requested fields are listed in schema order, since pruning preserves the schema's order.
+    testPrunedSchema(StructType.fromDDL("a int, b int"), StructField("a", IntegerType))
+    testPrunedSchema(StructType.fromDDL("a int, b int"), StructField("b", IntegerType))
+
+    val structOfStruct = StructType.fromDDL("a struct<a:int, b:int>, b int")
+    testPrunedSchema(structOfStruct, StructField("a", StructType.fromDDL("a int, b int")))
+    testPrunedSchema(structOfStruct, StructField("b", IntegerType))
+    testPrunedSchema(structOfStruct, StructField("a", StructType.fromDDL("b int")))
+
+    val arrayOfStruct = StructField("a", ArrayType(StructType.fromDDL("a int, b int, c string")))
+    val mapOfStruct = StructField("d", MapType(StructType.fromDDL("a int, b int, c string"),
+      StructType.fromDDL("d int, e int, f string")))
+
+    val complexStruct = StructType(
+      arrayOfStruct :: StructField("b", structOfStruct) :: StructField("c", IntegerType) ::
+        mapOfStruct :: Nil)
+
+    testPrunedSchema(complexStruct, StructField("a", ArrayType(StructType.fromDDL("b int"))),
+      StructField("b", StructType.fromDDL("a int")))
+    testPrunedSchema(complexStruct,
+      StructField("a", ArrayType(StructType.fromDDL("b int, c string"))),
+      StructField("b", StructType.fromDDL("b int")))
+
+    val selectFieldInMap = StructField("d", MapType(StructType.fromDDL("a int, b int"),
+      StructType.fromDDL("e int, f string")))
+    testPrunedSchema(complexStruct, StructField("c", IntegerType), selectFieldInMap)
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SelectedFieldSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SelectedFieldSuite.scala
similarity index 99%
rename from sql/core/src/test/scala/org/apache/spark/sql/execution/SelectedFieldSuite.scala
rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SelectedFieldSuite.scala
index 05f7e3c..7cfe4bf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SelectedFieldSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SelectedFieldSuite.scala
@@ -15,14 +15,13 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.execution
+package org.apache.spark.sql.catalyst.expressions
 
 import org.scalatest.BeforeAndAfterAll
 import org.scalatest.exceptions.TestFailedException
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.NamedExpression
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.types._
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index 73112e3..41bc4d8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -399,15 +399,5 @@ class ColumnPruningSuite extends PlanTest {
     val expected = input.where(rand(0L) > 0.5).where('key < 10).select('key).analyze
     comparePlans(optimized, expected)
   }
-
-  test("SPARK-26619: Prune the unused serializers from SerializeFromObject") {
-    val testRelation = LocalRelation('_1.int, '_2.int)
-    val serializerObject = CatalystSerde.serialize[(Int, Int)](
-      CatalystSerde.deserialize[(Int, Int)](testRelation))
-    val query = serializerObject.select('_1)
-    val optimized = Optimize.execute(query.analyze)
-    val expected = serializerObject.copy(serializer = Seq(serializerObject.serializer.head)).analyze
-    comparePlans(optimized, expected)
-  }
   // todo: add more tests for column pruning
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ObjectSerializerPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ObjectSerializerPruningSuite.scala
new file mode 100644
index 0000000..dee685a
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ObjectSerializerPruningSuite.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import scala.collection.mutable.ArrayBuffer
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+
+class ObjectSerializerPruningSuite extends PlanTest {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches = Batch("Object serializer pruning", FixedPoint(100),
+      ObjectSerializerPruning,
+      RemoveNoopOperators) :: Nil
+  }
+
+  implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]()
+
+  test("collect struct types") {
+    val dataTypes = Seq(
+      IntegerType,
+      ArrayType(IntegerType),
+      StructType.fromDDL("a int, b int"),
+      ArrayType(StructType.fromDDL("a int, b int, c string")),
+      StructType.fromDDL("a struct<a:int, b:int>, b int")
+    )
+
+    val expectedTypes = Seq(
+      Seq.empty[StructType],
+      Seq.empty[StructType],
+      Seq(StructType.fromDDL("a int, b int")),
+      Seq(StructType.fromDDL("a int, b int, c string")),
+      Seq(StructType.fromDDL("a struct<a:int, b:int>, b int"),
+        StructType.fromDDL("a int, b int"))
+    )
+
+    dataTypes.zipWithIndex.foreach { case (dt, idx) =>
+      val structs = ObjectSerializerPruning.collectStructType(dt, ArrayBuffer.empty[StructType])
+      assert(structs === expectedTypes(idx))
+    }
+  }
+
+  test("SPARK-26619: Prune the unused serializers from SerializeFromObject") {
+    val testRelation = LocalRelation('_1.int, '_2.int)
+    val serializerObject = CatalystSerde.serialize[(Int, Int)](
+      CatalystSerde.deserialize[(Int, Int)](testRelation))
+    val query = serializerObject.select('_1)
+    val optimized = Optimize.execute(query.analyze)
+    val expected = serializerObject.copy(serializer = Seq(serializerObject.serializer.head)).analyze
+    comparePlans(optimized, expected)
+  }
+
+  test("Prune nested serializers") {
+    withSQLConf(SQLConf.SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED.key -> "true") {
+      val testRelation = LocalRelation('_1.struct(StructType.fromDDL("_1 int, _2 string")), '_2.int)
+      val serializerObject = CatalystSerde.serialize[((Int, String), Int)](
+        CatalystSerde.deserialize[((Int, String), Int)](testRelation))
+      val query = serializerObject.select($"_1._1")
+      val optimized = Optimize.execute(query.analyze)
+      // Expected serializer: the struct keeps only its first name/value pair (field `_1`).
+      val prunedSerializer = serializerObject.serializer.head.transformDown {
+        case CreateNamedStruct(children) =>
+          CreateNamedStruct(children.take(2))
+      }.transformUp {
+        // Aligns null literal in `If` expression to make it resolvable.
+        case i @ If(_: IsNull, Literal(null, dt), ser) if !dt.sameType(ser.dataType) =>
+          i.copy(trueValue = Literal(null, ser.dataType))
+      }.asInstanceOf[NamedExpression]
+
+      // The optional `name` in `GetStructField` affects `comparePlans`, so it is cleared below.
+      // (Alternatively, `GetStructField.equals` could ignore `name`.)
+      val expected = serializerObject.copy(serializer = Seq(prunedSerializer))
+        .select($"_1._1").analyze.transformAllExpressions {
+        case g: GetStructField => g.copy(name = None)
+      }
+      comparePlans(optimized, expected)
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruning.scala
index 840fcae..cc33db9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruning.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruning.scala
@@ -21,7 +21,6 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.{ProjectionOverSchema, SelectedField}
 import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType}
@@ -32,7 +31,9 @@ import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, St
  * Parquet format. In Spark SQL, a root-level Parquet column corresponds to a
  * SQL column, and a nested Parquet column corresponds to a [[StructField]].
  */
-private[sql] object ParquetSchemaPruning extends Rule[LogicalPlan] {
+object ParquetSchemaPruning extends Rule[LogicalPlan] {
+  import org.apache.spark.sql.catalyst.expressions.SchemaPruning._
+
   override def apply(plan: LogicalPlan): LogicalPlan =
     if (SQLConf.get.nestedSchemaPruningEnabled) {
       apply0(plan)
@@ -104,44 +105,6 @@ private[sql] object ParquetSchemaPruning extends Rule[LogicalPlan] {
   }
 
   /**
-   * Returns the set of fields from the Parquet file that the query plan needs.
-   */
-  private def identifyRootFields(projects: Seq[NamedExpression], filters: Seq[Expression]) = {
-    val projectionRootFields = projects.flatMap(getRootFields)
-    val filterRootFields = filters.flatMap(getRootFields)
-
-    // Kind of expressions don't need to access any fields of a root fields, e.g., `IsNotNull`.
-    // For them, if there are any nested fields accessed in the query, we don't need to add root
-    // field access of above expressions.
-    // For example, for a query `SELECT name.first FROM contacts WHERE name IS NOT NULL`,
-    // we don't need to read nested fields of `name` struct other than `first` field.
-    val (rootFields, optRootFields) = (projectionRootFields ++ filterRootFields)
-      .distinct.partition(!_.prunedIfAnyChildAccessed)
-
-    optRootFields.filter { opt =>
-      !rootFields.exists { root =>
-        root.field.name == opt.field.name && {
-          // Checking if current optional root field can be pruned.
-          // For each required root field, we merge it with the optional root field:
-          // 1. If this optional root field has nested fields and any nested field of it is used
-          //    in the query, the merged field type must equal to the optional root field type.
-          //    We can prune this optional root field. For example, for optional root field
-          //    `struct<name:struct<middle:string,last:string>>`, if its field
-          //    `struct<name:struct<last:string>>` is used, we don't need to add this optional
-          //    root field.
-          // 2. If this optional root field has no nested fields, the merged field type equals
-          //    to the optional root field only if they are the same. If they are, we can prune
-          //    this optional root field too.
-          val rootFieldType = StructType(Array(root.field))
-          val optFieldType = StructType(Array(opt.field))
-          val merged = optFieldType.merge(rootFieldType)
-          merged.sameType(optFieldType)
-        }
-      }
-    } ++ rootFields
-  }
-
-  /**
    * Builds the new output [[Project]] Spark SQL operator that has the pruned output relation.
    */
   private def buildNewProjection(
@@ -174,27 +137,6 @@ private[sql] object ParquetSchemaPruning extends Rule[LogicalPlan] {
   }
 
   /**
-   * Filters the schema from the given file by the requested fields.
-   * Schema field ordering from the file is preserved.
-   */
-  private def pruneDataSchema(
-      fileDataSchema: StructType,
-      requestedRootFields: Seq[RootField]) = {
-    // Merge the requested root fields into a single schema. Note the ordering of the fields
-    // in the resulting schema may differ from their ordering in the logical relation's
-    // original schema
-    val mergedSchema = requestedRootFields
-      .map { case root: RootField => StructType(Array(root.field)) }
-      .reduceLeft(_ merge _)
-    val dataSchemaFieldNames = fileDataSchema.fieldNames.toSet
-    val mergedDataSchema =
-      StructType(mergedSchema.filter(f => dataSchemaFieldNames.contains(f.name)))
-    // Sort the fields of mergedDataSchema according to their order in dataSchema,
-    // recursively. This makes mergedDataSchema a pruned schema of dataSchema
-    sortLeftFieldsByRight(mergedDataSchema, fileDataSchema).asInstanceOf[StructType]
-  }
-
-  /**
    * Builds a pruned logical relation from the output of the output relation and the schema of the
    * pruned base relation.
    */
@@ -218,30 +160,6 @@ private[sql] object ParquetSchemaPruning extends Rule[LogicalPlan] {
   }
 
   /**
-   * Gets the root (aka top-level, no-parent) [[StructField]]s for the given [[Expression]].
-   * When expr is an [[Attribute]], construct a field around it and indicate that that
-   * field was derived from an attribute.
-   */
-  private def getRootFields(expr: Expression): Seq[RootField] = {
-    expr match {
-      case att: Attribute =>
-        RootField(StructField(att.name, att.dataType, att.nullable), derivedFromAtt = true) :: Nil
-      case SelectedField(field) => RootField(field, derivedFromAtt = false) :: Nil
-      // Root field accesses by `IsNotNull` and `IsNull` are special cases as the expressions
-      // don't actually use any nested fields. These root field accesses might be excluded later
-      // if there are any nested fields accesses in the query plan.
-      case IsNotNull(SelectedField(field)) =>
-        RootField(field, derivedFromAtt = false, prunedIfAnyChildAccessed = true) :: Nil
-      case IsNull(SelectedField(field)) =>
-        RootField(field, derivedFromAtt = false, prunedIfAnyChildAccessed = true) :: Nil
-      case IsNotNull(_: Attribute) | IsNull(_: Attribute) =>
-        expr.children.flatMap(getRootFields).map(_.copy(prunedIfAnyChildAccessed = true))
-      case _ =>
-        expr.children.flatMap(getRootFields)
-    }
-  }
-
-  /**
    * Counts the "leaf" fields of the given dataType. Informally, this is the
    * number of fields of non-complex data type in the tree representation of
    * [[DataType]].
@@ -256,42 +174,5 @@ private[sql] object ParquetSchemaPruning extends Rule[LogicalPlan] {
     }
   }
 
-  /**
-  * Sorts the fields and descendant fields of structs in left according to their order in
-  * right. This function assumes that the fields of left are a subset of the fields of
-  * right, recursively. That is, left is a "subschema" of right, ignoring order of
-  * fields.
-  */
-  private def sortLeftFieldsByRight(left: DataType, right: DataType): DataType =
-    (left, right) match {
-      case (ArrayType(leftElementType, containsNull), ArrayType(rightElementType, _)) =>
-        ArrayType(
-          sortLeftFieldsByRight(leftElementType, rightElementType),
-          containsNull)
-      case (MapType(leftKeyType, leftValueType, containsNull),
-          MapType(rightKeyType, rightValueType, _)) =>
-        MapType(
-          sortLeftFieldsByRight(leftKeyType, rightKeyType),
-          sortLeftFieldsByRight(leftValueType, rightValueType),
-          containsNull)
-      case (leftStruct: StructType, rightStruct: StructType) =>
-        val filteredRightFieldNames = rightStruct.fieldNames.filter(leftStruct.fieldNames.contains)
-        val sortedLeftFields = filteredRightFieldNames.map { fieldName =>
-          val leftFieldType = leftStruct(fieldName).dataType
-          val rightFieldType = rightStruct(fieldName).dataType
-          val sortedLeftFieldType = sortLeftFieldsByRight(leftFieldType, rightFieldType)
-          StructField(fieldName, sortedLeftFieldType)
-        }
-        StructType(sortedLeftFields)
-      case _ => left
-    }
 
-  /**
-   * This represents a "root" schema field (aka top-level, no-parent). `field` is the
-   * `StructField` for field name and datatype. `derivedFromAtt` indicates whether it
-   * was derived from an attribute or had a proper child. `prunedIfAnyChildAccessed` means
-   * whether this root field can be pruned if any of child field is used in the query.
-   */
-  private case class RootField(field: StructField, derivedFromAtt: Boolean,
-    prunedIfAnyChildAccessed: Boolean = false)
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetOptimizationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetOptimizationSuite.scala
new file mode 100644
index 0000000..2b1dbf0
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetOptimizationSuite.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct
+import org.apache.spark.sql.catalyst.plans.logical.SerializeFromObject
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSQLContext
+
+/**
+ * Tests for the optimizer's pruning of object serializers (`SerializeFromObject`) in typed
+ * Dataset queries; each test also checks that the query result is unchanged.
+ */
+class DatasetOptimizationSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
+
+  test("SPARK-26619: Prune the unused serializers from SerializeFromObject") {
+    val data = Seq(("a", 1), ("b", 2), ("c", 3))
+    val ds = data.toDS().map(t => (t._1, t._2 + 1)).select("_1")
+    val serializer = ds.queryExecution.optimizedPlan.collect {
+      case s: SerializeFromObject => s
+    }.head
+    // Only `_1` is selected, so the serializer of `_2` is expected to be pruned.
+    assert(serializer.serializer.size == 1)
+    checkAnswer(ds, Seq(Row("a"), Row("b"), Row("c")))
+  }
+
+  /**
+   * Checks that the first `SerializeFromObject` in the optimized plan of `df` creates exactly
+   * the expected structs.
+   *
+   * @param df the DataFrame whose optimized plan is inspected.
+   * @param structFields one entry per serializer expression, aligned by position. Each entry
+   *                     lists the field names of every `CreateNamedStruct` in that serializer,
+   *                     in the order `collect` finds them (outer struct before inner struct).
+   */
+  private def testSerializer(df: DataFrame, structFields: Seq[Seq[String]]*): Unit = {
+    val serializer = df.queryExecution.optimizedPlan.collect {
+      case s: SerializeFromObject => s
+    }.head
+
+    // `zip` silently drops unmatched elements, so check the counts first; otherwise an
+    // unexpected extra or missing serializer would go unnoticed.
+    assert(serializer.serializer.size == structFields.size)
+    serializer.serializer.zip(structFields).foreach { case (ser, fields) =>
+      val structs = ser.collect {
+        case c: CreateNamedStruct => c
+      }
+      assert(structs.size == fields.size)
+      structs.zip(fields).foreach { case (struct, fieldNames) =>
+        assert(struct.names.map(_.toString) == fieldNames)
+      }
+    }
+  }
+
+  test("Prune nested serializers: struct") {
+    withSQLConf(SQLConf.SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED.key -> "true") {
+      val data = Seq((("a", 1, ("aa", 1.0)), 1), (("b", 2, ("bb", 2.0)), 2),
+        (("c", 3, ("cc", 3.0)), 3))
+      val ds = data.toDS().map(t => (t._1, t._2 + 1))
+
+      val df1 = ds.select("_1._1")
+      testSerializer(df1, Seq(Seq("_1")))
+      checkAnswer(df1, Seq(Row("a"), Row("b"), Row("c")))
+
+      val df2 = ds.select("_1._2")
+      testSerializer(df2, Seq(Seq("_2")))
+      checkAnswer(df2, Seq(Row(1), Row(2), Row(3)))
+
+      val df3 = ds.select("_1._3._1")
+      testSerializer(df3, Seq(Seq("_3"), Seq("_1")))
+      checkAnswer(df3, Seq(Row("aa"), Row("bb"), Row("cc")))
+
+      val df4 = ds.select("_1._3._1", "_1._2")
+      testSerializer(df4, Seq(Seq("_2", "_3"), Seq("_1")))
+      checkAnswer(df4, Seq(Row("aa", 1), Row("bb", 2), Row("cc", 3)))
+    }
+  }
+
+  test("Prune nested serializers: array of struct") {
+    withSQLConf(SQLConf.SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED.key -> "true") {
+      val arrayData = Seq((Seq(("a", 1, ("a_1", 11)), ("b", 2, ("b_1", 22))), 1, ("aa", 1.0)),
+        (Seq(("c", 3, ("c_1", 33)), ("d", 4, ("d_1", 44))), 2, ("bb", 2.0)))
+      val arrayDs = arrayData.toDS().map(t => (t._1, t._2 + 1, t._3))
+      val df1 = arrayDs.select("_1._1")
+      // The serializer creates an array of structs with the single field "_1".
+      testSerializer(df1, Seq(Seq("_1")))
+      checkAnswer(df1, Seq(Row(Seq("a", "b")), Row(Seq("c", "d"))))
+
+      val df2 = arrayDs.select("_3._2")
+      testSerializer(df2, Seq(Seq("_2")))
+      checkAnswer(df2, Seq(Row(1.0), Row(2.0)))
+
+      // This is a more complex case. We select two root fields "_1" and "_3".
+      // The first serializer creates an array of structs with two fields ("_1", "_3"),
+      // where the field "_3" is itself a struct with the single field "_2".
+      // The second serializer creates a struct with just one field "_1".
+      val df3 = arrayDs.select("_1._1", "_1._3._2", "_3._1")
+      testSerializer(df3, Seq(Seq("_1", "_3"), Seq("_2")), Seq(Seq("_1")))
+      checkAnswer(df3, Seq(Row(Seq("a", "b"), Seq(11, 22), "aa"),
+        Row(Seq("c", "d"), Seq(33, 44), "bb")))
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 64c4aab..a4ca9e6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -26,7 +26,6 @@ import org.apache.spark.{SparkException, TaskContext}
 import org.apache.spark.sql.catalyst.ScroogeLikeExample
 import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
-import org.apache.spark.sql.catalyst.plans.logical.SerializeFromObject
 import org.apache.spark.sql.catalyst.util.sideBySide
 import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SQLExecution}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
@@ -1707,16 +1706,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     checkAnswer(inputDF.except(exceptDF), Seq(Row("1", null)))
   }
 
-  test("SPARK-26619: Prune the unused serializers from SerializeFromObjec") {
-    val data = Seq(("a", 1), ("b", 2), ("c", 3))
-    val ds = data.toDS().map(t => (t._1, t._2 + 1)).select("_1")
-    val serializer = ds.queryExecution.optimizedPlan.collect {
-      case s: SerializeFromObject => s
-    }.head
-    assert(serializer.serializer.size == 1)
-    checkAnswer(ds, Seq(Row("a"), Row("b"), Row("c")))
-  }
-
   test("SPARK-26706: Fix Cast.mayTruncate for bytes") {
     val thrownException = intercept[AnalysisException] {
       spark.range(Long.MaxValue - 10, Long.MaxValue).as[Byte]


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org